chore(releases): remove deadcode

Signed-off-by: Aaron Pham (mbp16) <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
Aaron Pham (mbp16)
2024-05-27 12:37:50 -04:00
parent da42c269c9
commit f4f7f16e81
38 changed files with 97 additions and 3291 deletions

View File

@@ -24,7 +24,7 @@ from openllm_core._typing_compat import (
get_literal_args,
TypedDict,
)
from openllm_cli import termui
from . import _termui as termui
if sys.version_info >= (3, 11):
import tomllib

View File

@@ -0,0 +1,82 @@
from __future__ import annotations
import enum, functools, logging, os, typing as t
import click, inflection, orjson
from openllm_core._typing_compat import DictStrAny, TypedDict
from openllm_core.utils import get_debug_mode
logger = logging.getLogger('openllm')
class Level(enum.IntEnum):
NOTSET = logging.DEBUG
DEBUG = logging.DEBUG
INFO = logging.INFO
WARNING = logging.WARNING
ERROR = logging.ERROR
CRITICAL = logging.CRITICAL
@property
def color(self) -> str | None:
return {
Level.NOTSET: None,
Level.DEBUG: 'cyan',
Level.INFO: 'green',
Level.WARNING: 'yellow',
Level.ERROR: 'red',
Level.CRITICAL: 'red',
}[self]
@classmethod
def from_logging_level(cls, level: int) -> Level:
return {
logging.DEBUG: Level.DEBUG,
logging.INFO: Level.INFO,
logging.WARNING: Level.WARNING,
logging.ERROR: Level.ERROR,
logging.CRITICAL: Level.CRITICAL,
}[level]
class JsonLog(TypedDict):
log_level: Level
content: str
def log(content: str, level: Level = Level.INFO, fg: str | None = None) -> None:
if get_debug_mode():
echo(content, fg=fg)
else:
echo(orjson.dumps(JsonLog(log_level=level, content=content)).decode(), fg=fg, json=True)
warning = functools.partial(log, level=Level.WARNING)
error = functools.partial(log, level=Level.ERROR)
critical = functools.partial(log, level=Level.CRITICAL)
debug = functools.partial(log, level=Level.DEBUG)
info = functools.partial(log, level=Level.INFO)
notset = functools.partial(log, level=Level.NOTSET)
def echo(text: t.Any, fg: str | None = None, *, _with_style: bool = True, json: bool = False, **attrs: t.Any) -> None:
if json:
text = orjson.loads(text)
if 'content' in text and 'log_level' in text:
content = text['content']
fg = Level.from_logging_level(text['log_level']).color
else:
content = orjson.dumps(text).decode()
fg = Level.INFO.color if not get_debug_mode() else Level.DEBUG.color
else:
content = t.cast(str, text)
attrs['fg'] = fg
(click.echo if not _with_style else click.secho)(content, **attrs)
COLUMNS: int = int(os.environ.get('COLUMNS', str(120)))
CONTEXT_SETTINGS: DictStrAny = {
'help_option_names': ['-h', '--help'],
'max_content_width': COLUMNS,
'token_normalize_func': inflection.underscore,
}
__all__ = ['COLUMNS', 'CONTEXT_SETTINGS', 'Level', 'critical', 'debug', 'echo', 'error', 'info', 'log', 'warning']
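For orientation, here is a minimal usage sketch of the helpers above (a sketch only, assuming the new module is importable as `openllm_cli._termui`, which the import rewrite in the first hunk suggests). `log` wraps the message in a `JsonLog` payload and `echo(..., json=True)` unpacks it again, so the net effect is level-coloured output, except in debug mode where the content is echoed verbatim:

# hypothetical usage sketch, not part of this commit
from openllm_cli import _termui as termui

termui.info('model downloaded')                   # rendered green (Level.INFO colour) unless debug mode is on
termui.warning('falling back to legacy weights')  # rendered yellow (Level.WARNING colour)
termui.echo('raw text', fg='cyan')                # direct passthrough to click.secho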

View File

@@ -27,14 +27,11 @@ __lazy = utils.LazyModule( # NOTE: update this to sys.modules[__name__] once my
'exceptions': [],
'client': ['HTTPClient', 'AsyncHTTPClient'],
'bundle': [],
'testing': [],
'utils': ['api'],
'entrypoints': ['mount_entrypoints'],
'serialisation': ['ggml', 'transformers', 'vllm'],
'_llm': ['LLM'],
'_deprecated': ['Runner'],
'_runners': ['runner'],
'_quantisation': ['infer_quantisation_config'],
'_strategies': ['CascadingResourceStrategy', 'get_resource'],
},
extra_objects={'COMPILED': COMPILED},
@@ -44,7 +41,7 @@ __all__, __dir__ = __lazy.__all__, __lazy.__dir__
_BREAKING_INTERNAL = ['_service', '_service_vars']
_NEW_IMPL = ['LLM', *_BREAKING_INTERNAL]
if (_BENTOML_VERSION := utils.pkg.pkg_version_info('bentoml')) > (1, 2):
if utils.pkg.pkg_version_info('bentoml') > (1, 2):
import _openllm_tiny as _tiny
else:
_tiny = None
@@ -58,7 +55,7 @@ def __getattr__(name: str) -> _t.Any:
f'"{name}" is an internal implementation and considered breaking with older OpenLLM. Please migrate your code if you depend on this.'
)
_warnings.warn(
f'"{name}" is considered deprecated implementation and will be removed in the future. Make sure to upgrade to OpenLLM 0.5.x',
f'"{name}" is considered deprecated implementation and could be breaking. See https://github.com/bentoml/OpenLLM for more information on upgrading instruction.',
DeprecationWarning,
stacklevel=3,
)

View File

@@ -26,7 +26,6 @@ from . import (
exceptions as exceptions,
serialisation as serialisation,
utils as utils,
entrypoints as entrypoints,
)
from .serialisation import ggml as ggml, transformers as transformers, vllm as vllm
from ._deprecated import Runner as Runner

View File

@@ -1,16 +0,0 @@
import importlib
from openllm_core.utils import LazyModule
_import_structure = {'openai': [], 'hf': []}
def mount_entrypoints(svc, llm):
for module_name in _import_structure:
svc = importlib.import_module(f'.{module_name}', __name__).mount_to_svc(svc, llm)
return svc
__lazy = LazyModule(
__name__, globals()['__file__'], _import_structure, extra_objects={'mount_entrypoints': mount_entrypoints}
)
__all__, __dir__, __getattr__ = __lazy.__all__, __lazy.__dir__, __lazy.__getattr__

View File

@@ -1,17 +0,0 @@
"""Entrypoint for all third-party apps.
Currently supports the OpenAI-compatible API.
Each module should implement the following API:
- `mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Service: ...`
"""
from typing import Any
from _bentoml_sdk import Service
from openllm_core._typing_compat import M, T
from . import hf as hf, openai as openai
from .._llm import LLM
def mount_entrypoints(svc: Service[Any], llm: LLM[M, T]) -> Service: ...
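A minimal sketch of what a conforming third-party entrypoint module could look like (hypothetical route and names, not part of this diff), following the `mount_to_svc` protocol described in the docstring above:

# hypothetical entrypoint module sketch
from starlette.applications import Starlette
from starlette.responses import JSONResponse
from starlette.routing import Route

def mount_to_svc(svc, llm):
    async def ping(req):  # trivial endpoint exposing the model identifier
        return JSONResponse({'model': llm.llm_type})
    app = Starlette(routes=[Route('/ping', ping, methods=['GET'])])
    svc.mount_asgi_app(app, path='/myext')  # same mounting pattern as the hf/openai modules below
    return svc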

View File

@@ -1,641 +0,0 @@
from __future__ import annotations
import functools
import inspect
import types
import typing as t
import attr
from starlette.routing import Host, Mount, Route
from starlette.schemas import EndpointInfo, SchemaGenerator
from openllm_core.utils import first_not_none
if t.TYPE_CHECKING:
import pydantic
OPENAPI_VERSION, API_VERSION = '3.0.2', '1.0'
# NOTE: OpenAI schema
LIST_MODELS_SCHEMA = """\
---
consumes:
- application/json
description: >
List and describe the various models available in the API.
You can refer to the available supported models with `openllm models` for more
information.
operationId: openai__list_models
produces:
- application/json
summary: Describes a model offering that can be used with the API.
tags:
- OpenAI
x-bentoml-name: list_models
responses:
200:
description: The Model object
content:
application/json:
example:
object: 'list'
data:
- id: __model_id__
object: model
created: 1686935002
owned_by: 'na'
schema:
$ref: '#/components/schemas/ModelList'
"""
CHAT_COMPLETIONS_SCHEMA = """\
---
consumes:
- application/json
description: >-
Given a list of messages comprising a conversation, the model will return a
response.
operationId: openai__chat_completions
produces:
- application/json
tags:
- OpenAI
x-bentoml-name: create_chat_completions
summary: Creates a model response for the given chat conversation.
requestBody:
required: true
content:
application/json:
examples:
one-shot:
summary: One-shot input example
value:
messages: __chat_messages__
model: __model_id__
max_tokens: 256
temperature: 0.7
top_p: 0.43
n: 1
stream: false
chat_template: __chat_template__
add_generation_prompt: __add_generation_prompt__
echo: false
streaming:
summary: Streaming input example
value:
messages:
- role: system
content: You are a helpful assistant.
- role: user
content: Hello, I'm looking for a chatbot that can help me with my work.
model: __model_id__
max_tokens: 256
temperature: 0.7
top_p: 0.43
n: 1
stream: true
stop:
- "<|endoftext|>"
chat_template: __chat_template__
add_generation_prompt: __add_generation_prompt__
echo: false
schema:
$ref: '#/components/schemas/ChatCompletionRequest'
responses:
200:
description: OK
content:
application/json:
schema:
$ref: '#/components/schemas/ChatCompletionResponse'
examples:
streaming:
summary: Streaming output example
value: >
{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-3.5-turbo-0613","choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}
one-shot:
summary: One-shot output example
value: >
{"id": "chatcmpl-123", "object": "chat.completion", "created": 1677652288, "model": "gpt-3.5-turbo-0613", "choices": [{"index": 0, "message": {"role": "assistant", "content": "Hello there, how may I assist you today?"}, "finish_reason": "stop"}], "usage": {"prompt_tokens": 9, "completion_tokens": 12, "total_tokens": 21}}
404:
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
examples:
wrong-model:
summary: Wrong model
value: >
{
"error": {
"message": "Model 'meta-llama--Llama-2-13b-chat-hf' does not exists. Try 'GET /v1/models' to see available models.\\nTip: If you are migrating from OpenAI, make sure to update your 'model' parameters in the request.",
"type": "invalid_request_error",
"object": "error",
"param": null,
"code": 404
}
}
description: NotFound
500:
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
examples:
invalid-parameters:
summary: Invalid parameters
value: >
{
"error": {
"message": "`top_p` has to be a float > 0 and < 1, but is 4.0",
"type": "invalid_request_error",
"object": "error",
"param": null,
"code": 500
}
}
description: Internal Server Error
400:
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
examples:
invalid-json:
summary: Invalid JSON sent
value: >
{
"error": {
"message": "Invalid JSON input received (Check server log).",
"type": "invalid_request_error",
"object": "error",
"param": null,
"code": 400
}
}
invalid-prompt:
summary: Invalid prompt
value: >
{
"error": {
"message": "Please provide a prompt.",
"type": "invalid_request_error",
"object": "error",
"param": null,
"code": 400
}
}
description: Bad Request
"""
COMPLETIONS_SCHEMA = """\
---
consumes:
- application/json
description: >-
Given a prompt, the model will return one or more predicted completions, and can also return the probabilities of alternative tokens at each position. We recommend most users use our Chat completions API.
operationId: openai__completions
produces:
- application/json
tags:
- OpenAI
x-bentoml-name: create_completions
summary: Creates a completion for the provided prompt and parameters.
requestBody:
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/CompletionRequest'
examples:
one-shot:
summary: One-shot input example
value:
prompt: This is a test
model: __model_id__
max_tokens: 256
temperature: 0.7
logprobs: null
top_p: 0.43
n: 1
stream: false
streaming:
summary: Streaming input example
value:
prompt: This is a test
model: __model_id__
max_tokens: 256
temperature: 0.7
top_p: 0.43
logprobs: null
n: 1
stream: true
stop:
- "<|endoftext|>"
responses:
200:
description: OK
content:
application/json:
schema:
$ref: '#/components/schemas/CompletionResponse'
examples:
one-shot:
summary: One-shot output example
value:
id: cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7
object: text_completion
created: 1589478378
model: VAR_model_id
choices:
- text: This is indeed a test
index: 0
logprobs: null
finish_reason: length
usage:
prompt_tokens: 5
completion_tokens: 7
total_tokens: 12
streaming:
summary: Streaming output example
value:
id: cmpl-7iA7iJjj8V2zOkCGvWF2hAkDWBQZe
object: text_completion
created: 1690759702
choices:
- text: This
index: 0
logprobs: null
finish_reason: null
model: gpt-3.5-turbo-instruct
404:
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
examples:
wrong-model:
summary: Wrong model
value: >
{
"error": {
"message": "Model 'meta-llama--Llama-2-13b-chat-hf' does not exists. Try 'GET /v1/models' to see available models.\\nTip: If you are migrating from OpenAI, make sure to update your 'model' parameters in the request.",
"type": "invalid_request_error",
"object": "error",
"param": null,
"code": 404
}
}
description: NotFound
500:
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
examples:
invalid-parameters:
summary: Invalid parameters
value: >
{
"error": {
"message": "`top_p` has to be a float > 0 and < 1, but is 4.0",
"type": "invalid_request_error",
"object": "error",
"param": null,
"code": 500
}
}
description: Internal Server Error
400:
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
examples:
invalid-json:
summary: Invalid JSON sent
value: >
{
"error": {
"message": "Invalid JSON input received (Check server log).",
"type": "invalid_request_error",
"object": "error",
"param": null,
"code": 400
}
}
invalid-prompt:
summary: Invalid prompt
value: >
{
"error": {
"message": "Please provide a prompt.",
"type": "invalid_request_error",
"object": "error",
"param": null,
"code": 400
}
}
description: Bad Request
"""
HF_AGENT_SCHEMA = """\
---
consumes:
- application/json
description: Generate instructions for a given HF Agent chain for all OpenLLM-supported models.
operationId: hf__agent
summary: Generate instructions for a given HF Agent.
tags:
- HF
x-bentoml-name: hf_agent
produces:
- application/json
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/AgentRequest'
example:
inputs: "Is the following `text` positive or negative?"
parameters:
text: "This is a positive text."
stop: []
required: true
responses:
200:
description: Successfully generated instruction.
content:
application/json:
example:
- generated_text: "This is a generated instruction."
schema:
$ref: '#/components/schemas/AgentResponse'
400:
content:
application/json:
schema:
$ref: '#/components/schemas/HFErrorResponse'
description: Bad Request
500:
content:
application/json:
schema:
$ref: '#/components/schemas/HFErrorResponse'
description: Internal Server Error
"""
HF_ADAPTERS_SCHEMA = """\
---
consumes:
- application/json
description: Return current list of adapters for given LLM.
operationId: hf__adapters_map
produces:
- application/json
summary: Describes the LoRA adapters currently attached to the running LLM.
tags:
- HF
x-bentoml-name: hf_adapters
responses:
200:
description: Return list of LoRA adapters.
content:
application/json:
example:
aarnphm/opt-6-7b-quotes:
adapter_name: default
adapter_type: LORA
aarnphm/opt-6-7b-dolly:
adapter_name: dolly
adapter_type: LORA
500:
content:
application/json:
schema:
$ref: '#/components/schemas/HFErrorResponse'
description: Internal Server Error
"""
COHERE_GENERATE_SCHEMA = """\
---
consumes:
- application/json
description: >-
Given a prompt, the model will return one or more predicted completions, and
can also return the probabilities of alternative tokens at each position.
operationId: cohere__generate
produces:
- application/json
tags:
- Cohere
x-bentoml-name: cohere_generate
summary: Creates a completion for the provided prompt and parameters.
requestBody:
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/CohereGenerateRequest'
examples:
one-shot:
summary: One-shot input example
value:
prompt: This is a test
max_tokens: 256
temperature: 0.7
p: 0.43
k: 12
num_generations: 2
stream: false
streaming:
summary: Streaming input example
value:
prompt: This is a test
max_tokens: 256
temperature: 0.7
p: 0.43
k: 12
num_generations: 2
stream: true
stop_sequences:
- "<|endoftext|>"
"""
COHERE_CHAT_SCHEMA = """\
---
consumes:
- application/json
description: >-
Given a list of messages comprising a conversation, the model will return a response.
operationId: cohere__chat
produces:
- application/json
tags:
- Cohere
x-bentoml-name: cohere_chat
summary: Creates a model response for the given chat conversation.
"""
_SCHEMAS = {k[:-7].lower(): v for k, v in locals().items() if k.endswith('_SCHEMA')}
def apply_schema(func, **attrs):
for k, v in attrs.items():
func.__doc__ = func.__doc__.replace(k, v)
return func
def add_schema_definitions(func):
append_str = _SCHEMAS.get(func.__name__.lower(), '')
if not append_str:
return func
if func.__doc__ is None:
func.__doc__ = ''
func.__doc__ = func.__doc__.strip() + '\n\n' + append_str.strip()
return func
class OpenLLMSchemaGenerator(SchemaGenerator):
def get_endpoints(self, routes):
endpoints_info = []
for route in routes:
if isinstance(route, (Mount, Host)):
routes = route.routes or []
path = self._remove_converter(route.path) if isinstance(route, Mount) else ''
sub_endpoints = [
EndpointInfo(path=f'{path}{sub_endpoint.path}', http_method=sub_endpoint.http_method, func=sub_endpoint.func)
for sub_endpoint in self.get_endpoints(routes)
]
endpoints_info.extend(sub_endpoints)
elif not isinstance(route, Route) or not route.include_in_schema:
continue
elif (
inspect.isfunction(route.endpoint)
or inspect.ismethod(route.endpoint)
or isinstance(route.endpoint, functools.partial)
):
endpoint = route.endpoint.func if isinstance(route.endpoint, functools.partial) else route.endpoint
path = self._remove_converter(route.path)
for method in route.methods or ['GET']:
if method == 'HEAD':
continue
endpoints_info.append(EndpointInfo(path, method.lower(), endpoint))
else:
path = self._remove_converter(route.path)
for method in ['get', 'post', 'put', 'patch', 'delete', 'options']:
if not hasattr(route.endpoint, method):
continue
func = getattr(route.endpoint, method)
endpoints_info.append(EndpointInfo(path, method.lower(), func))
return endpoints_info
def get_schema(self, routes, mount_path=None):
schema = dict(self.base_schema)
schema.setdefault('paths', {})
endpoints_info = self.get_endpoints(routes)
if mount_path:
mount_path = f'/{mount_path}' if not mount_path.startswith('/') else mount_path
for endpoint in endpoints_info:
parsed = self.parse_docstring(endpoint.func)
if not parsed:
continue
path = endpoint.path if mount_path is None else mount_path + endpoint.path
if path not in schema['paths']:
schema['paths'][path] = {}
schema['paths'][path][endpoint.http_method] = parsed
return schema
def get_generator(title, components=None, tags=None, inject=True):
base_schema = {'info': {'title': title, 'version': API_VERSION}, 'version': OPENAPI_VERSION}
if components and inject:
base_schema['components'] = {'schemas': {c.__name__: component_schema_generator(c) for c in components}}
if tags is not None and tags and inject:
base_schema['tags'] = tags
return OpenLLMSchemaGenerator(base_schema)
def component_schema_generator(attr_cls: pydantic.BaseModel, description=None):
schema = {'type': 'object', 'required': [], 'properties': {}, 'title': attr_cls.__name__}
schema['description'] = first_not_none(
getattr(attr_cls, '__doc__', None), description, default=f'Generated components for {attr_cls.__name__}'
)
for name, field in attr_cls.model_fields.items():
attr_type = field.annotation
origin_type = t.get_origin(attr_type)
args_type = t.get_args(attr_type)
# Map Python types to OpenAPI schema types
if isinstance(attr_type, str):
schema_type = 'string'
elif isinstance(attr_type, int):
schema_type = 'integer'
elif isinstance(attr_type, float):
schema_type = 'number'
elif isinstance(attr_type, bool):
schema_type = 'boolean'
elif origin_type is list or origin_type is tuple:
schema_type = 'array'
elif origin_type is dict:
schema_type = 'object'
# Assuming string keys for simplicity, and handling Any type for values
prop_schema = {'type': 'object', 'additionalProperties': True if args_type[1] is t.Any else {'type': 'string'}}
elif attr_type == t.Optional[str]:
schema_type = 'string'
elif origin_type is t.Union and t.Any in args_type:
schema_type = 'object'
prop_schema = {'type': 'object', 'additionalProperties': True}
else:
schema_type = 'string'
if 'prop_schema' not in locals():
prop_schema = {'type': schema_type}
if field.default is not attr.NOTHING and not isinstance(field.default, attr.Factory):
prop_schema['default'] = field.default
if field.default is attr.NOTHING and not isinstance(attr_type, type(t.Optional)):
schema['required'].append(name)
schema['properties'][name] = prop_schema
locals().pop('prop_schema', None)
return schema
_SimpleSchema = types.new_class(
'_SimpleSchema',
(object,),
{},
lambda ns: ns.update({'__init__': lambda self, it: setattr(self, 'it', it), 'asdict': lambda self: self.it}),
)
def append_schemas(svc, generated_schema, tags_order='prepend', inject=True):
# HACK: Dirty hack to append schemas to an existing service. We definitely need to support mounting a Starlette app's OpenAPI spec.
from bentoml._internal.service.openapi.specification import OpenAPISpecification
if not inject:
return svc
svc_schema = svc.openapi_spec
if isinstance(svc_schema, (OpenAPISpecification, _SimpleSchema)):
svc_schema = svc_schema.asdict()
if 'tags' in generated_schema:
if tags_order == 'prepend':
svc_schema['tags'] = generated_schema['tags'] + svc_schema['tags']
elif tags_order == 'append':
svc_schema['tags'].extend(generated_schema['tags'])
else:
raise ValueError(f'Invalid tags_order: {tags_order}')
if 'components' in generated_schema:
svc_schema['components']['schemas'].update(generated_schema['components']['schemas'])
svc_schema['paths'].update(generated_schema['paths'])
# HACK: monkey-patch these attributes until we have a better way to add starlette schemas.
from bentoml._internal.service import openapi
def _generate_spec(svc, openapi_version=OPENAPI_VERSION):
return _SimpleSchema(svc_schema)
def asdict(self):
return svc_schema
openapi.generate_spec = _generate_spec
OpenAPISpecification.asdict = asdict
return svc
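In short, `add_schema_definitions` appends the `*_SCHEMA` string whose key matches the decorated function's name, and `apply_schema` substitutes the `__model_id__`-style placeholders at mount time; a condensed, hypothetical illustration (the real wiring lives in the hf/openai modules below):

# hypothetical illustration of the decorator flow above
@add_schema_definitions  # appends LIST_MODELS_SCHEMA because the key 'list_models' matches the function name
def list_models(req, llm): ...

patched = apply_schema(list_models, __model_id__='facebook/opt-125m')  # replaces the placeholder inside __doc__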

View File

@@ -1,29 +0,0 @@
from typing import Any, Callable, Dict, List, Literal, Optional, Type
from attr import AttrsInstance
from starlette.routing import BaseRoute
from starlette.schemas import EndpointInfo
from bentoml import Service
from openllm_core._typing_compat import ParamSpec
P = ParamSpec('P')
class OpenLLMSchemaGenerator:
base_schema: Dict[str, Any]
def get_endpoints(self, routes: list[BaseRoute]) -> list[EndpointInfo]: ...
def get_schema(self, routes: list[BaseRoute], mount_path: Optional[str] = ...) -> Dict[str, Any]: ...
def parse_docstring(self, func_or_method: Callable[P, Any]) -> Dict[str, Any]: ...
def apply_schema(func: Callable[P, Any], **attrs: Any) -> Callable[P, Any]: ...
def add_schema_definitions(func: Callable[P, Any]) -> Callable[P, Any]: ...
def append_schemas(
svc: Service, generated_schema: Dict[str, Any], tags_order: Literal['prepend', 'append'] = ..., inject: bool = ...
) -> Service: ...
def component_schema_generator(attr_cls: Type[AttrsInstance], description: Optional[str] = ...) -> Dict[str, Any]: ...
def get_generator(
title: str,
components: Optional[List[Type[AttrsInstance]]] = ...,
tags: Optional[List[Dict[str, Any]]] = ...,
inject: bool = ...,
) -> OpenLLMSchemaGenerator: ...

View File

@@ -1,63 +0,0 @@
import functools, logging
from http import HTTPStatus
import orjson
from starlette.applications import Starlette
from starlette.responses import JSONResponse
from starlette.routing import Route
from openllm_core.utils import converter
from ._openapi import add_schema_definitions, append_schemas, get_generator
from ..protocol.hf import AgentRequest, AgentResponse, HFErrorResponse
schemas = get_generator(
'hf',
components=[AgentRequest, AgentResponse, HFErrorResponse],
tags=[
{
'name': 'HF',
'description': 'HF integration, including Agent and others schema endpoints.',
'externalDocs': 'https://huggingface.co/docs/transformers/main_classes/agent',
}
],
)
logger = logging.getLogger(__name__)
def mount_to_svc(svc, llm):
app = Starlette(
debug=True,
routes=[
Route('/agent', endpoint=functools.partial(hf_agent, llm=llm), name='hf_agent', methods=['POST']),
Route('/schema', endpoint=lambda req: schemas.OpenAPIResponse(req), include_in_schema=False),
],
)
mount_path = '/hf'
svc.mount_asgi_app(app, path=mount_path)
return append_schemas(svc, schemas.get_schema(routes=app.routes, mount_path=mount_path), tags_order='append')
def error_response(status_code, message):
return JSONResponse(
converter.unstructure(HFErrorResponse(message=message, error_code=status_code.value)),
status_code=status_code.value,
)
@add_schema_definitions
async def hf_agent(req, llm):
json_str = await req.body()
try:
request = converter.structure(orjson.loads(json_str), AgentRequest)
except orjson.JSONDecodeError as err:
logger.debug('Sent body: %s', json_str)
logger.error('Invalid JSON input received: %s', err)
return error_response(HTTPStatus.BAD_REQUEST, 'Invalid JSON input received (Check server log).')
stop = request.parameters.pop('stop', [])
try:
result = await llm.generate(request.inputs, stop=stop, **request.parameters)
return JSONResponse(
converter.unstructure([AgentResponse(generated_text=result.outputs[0].text)]), status_code=HTTPStatus.OK.value
)
except Exception as err:
logger.error('Error while generating: %s', err)
return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, 'Error while generating (Check server log).')
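For reference, the removed agent endpoint is mounted under `/hf` and accepts the `AgentRequest` payload from the schema above; a client-side sketch (assuming a server on BentoML's default `localhost:3000`, not part of this diff):

# hypothetical client-side sketch
import requests

resp = requests.post(
    'http://localhost:3000/hf/agent',
    json={
        'inputs': 'Is the following `text` positive or negative?',
        'parameters': {'text': 'This is a positive text.', 'stop': []},
    },
)
print(resp.json())  # e.g. [{'generated_text': '...'}]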

View File

@@ -1,14 +0,0 @@
from http import HTTPStatus
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from bentoml import Service
from openllm_core._typing_compat import M, T
from .._llm import LLM
def mount_to_svc(svc: Service, llm: LLM[M, T]) -> Service: ...
def error_response(status_code: HTTPStatus, message: str) -> JSONResponse: ...
async def hf_agent(req: Request, llm: LLM[M, T]) -> Response: ...
def hf_adapters(req: Request, llm: LLM[M, T]) -> Response: ...

View File

@@ -1,448 +0,0 @@
import functools
import logging
import time
import traceback
from http import HTTPStatus
import orjson
from starlette.applications import Starlette
from starlette.responses import JSONResponse, StreamingResponse
from starlette.routing import Route
from openllm_core.utils import converter, gen_random_uuid
from ._openapi import add_schema_definitions, append_schemas, apply_schema, get_generator
from openllm_core.protocol.openai import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse,
ChatMessage,
CompletionRequest,
CompletionResponse,
CompletionResponseChoice,
CompletionResponseStreamChoice,
CompletionStreamResponse,
Delta,
ErrorResponse,
LogProbs,
ModelCard,
ModelList,
UsageInfo,
)
schemas = get_generator(
'openai',
components=[
ErrorResponse,
ModelList,
ChatCompletionResponse,
ChatCompletionRequest,
ChatCompletionStreamResponse,
CompletionRequest,
CompletionResponse,
CompletionStreamResponse,
],
tags=[
{
'name': 'OpenAI',
'description': 'OpenAI Compatible API support',
'externalDocs': 'https://platform.openai.com/docs/api-reference/completions/object',
}
],
)
logger = logging.getLogger(__name__)
def jsonify_attr(obj):
return orjson.dumps(converter.unstructure(obj)).decode()
def error_response(status_code, message):
return JSONResponse(
{
'error': converter.unstructure(
ErrorResponse(message=message, type='invalid_request_error', code=str(status_code.value))
)
},
status_code=status_code.value,
)
async def check_model(request, model): # noqa
if request.model == model:
return None
return error_response(
HTTPStatus.NOT_FOUND,
f"Model '{request.model}' does not exists. Try 'GET /v1/models' to see available models.\nTip: If you are migrating from OpenAI, make sure to update your 'model' parameters in the request.",
)
def create_logprobs(token_ids, top_logprobs, num_output_top_logprobs=None, initial_text_offset=0, *, llm):
# Create OpenAI-style logprobs.
logprobs = LogProbs()
last_token_len = 0
if num_output_top_logprobs:
logprobs.top_logprobs = []
for i, token_id in enumerate(token_ids):
step_top_logprobs = top_logprobs[i]
token_logprob = None
if step_top_logprobs is not None:
token_logprob = step_top_logprobs[token_id]
token = llm.tokenizer.convert_ids_to_tokens(token_id)
logprobs.tokens.append(token)
logprobs.token_logprobs.append(token_logprob)
if len(logprobs.text_offset) == 0:
logprobs.text_offset.append(initial_text_offset)
else:
logprobs.text_offset.append(logprobs.text_offset[-1] + last_token_len)
last_token_len = len(token)
if num_output_top_logprobs:
logprobs.top_logprobs.append(
{llm.tokenizer.convert_ids_to_tokens(i): p for i, p in step_top_logprobs.items()}
if step_top_logprobs
else None
)
return logprobs
def mount_to_svc(svc, llm):
list_models.__doc__ = list_models.__doc__.replace('__model_id__', llm.llm_type)
completions.__doc__ = completions.__doc__.replace('__model_id__', llm.llm_type)
chat_completions.__doc__ = chat_completions.__doc__.replace('__model_id__', llm.llm_type)
app = Starlette(
debug=True,
routes=[
Route(
'/models', functools.partial(apply_schema(list_models, __model_id__=llm.llm_type), llm=llm), methods=['GET']
),
Route(
'/completions',
functools.partial(apply_schema(completions, __model_id__=llm.llm_type), llm=llm),
methods=['POST'],
),
Route(
'/chat/completions',
functools.partial(
apply_schema(
chat_completions,
__model_id__=llm.llm_type,
__chat_template__=orjson.dumps(llm.config.chat_template).decode(),
__chat_messages__=orjson.dumps(llm.config.chat_messages).decode(),
__add_generation_prompt__=str(True) if llm.config.chat_messages is not None else str(False),
),
llm=llm,
),
methods=['POST'],
),
Route('/schema', endpoint=lambda req: schemas.OpenAPIResponse(req), include_in_schema=False),
],
)
svc.mount_asgi_app(app, path='/v1')
return append_schemas(svc, schemas.get_schema(routes=app.routes, mount_path='/v1'))
# GET /v1/models
@add_schema_definitions
def list_models(_, llm):
return JSONResponse(
converter.unstructure(ModelList(data=[ModelCard(id=llm.llm_type)])), status_code=HTTPStatus.OK.value
)
# POST /v1/chat/completions
@add_schema_definitions
async def chat_completions(req, llm):
# TODO: Check for length based on model context_length
json_str = await req.body()
try:
request = converter.structure(orjson.loads(json_str), ChatCompletionRequest)
except orjson.JSONDecodeError as err:
logger.debug('Sent body: %s', json_str)
logger.error('Invalid JSON input received: %s', err)
return error_response(HTTPStatus.BAD_REQUEST, 'Invalid JSON input received (Check server log).')
logger.debug('Received chat completion request: %s', request)
err_check = await check_model(request, llm.llm_type)
if err_check is not None:
return err_check
if request.logit_bias is not None and len(request.logit_bias) > 0:
return error_response(HTTPStatus.BAD_REQUEST, "'logit_bias' is not yet supported.")
model_name, request_id = request.model, gen_random_uuid('chatcmpl')
created_time = int(time.monotonic())
prompt = llm.tokenizer.apply_chat_template(
request.messages,
tokenize=False,
chat_template=request.chat_template if request.chat_template != 'None' else None,
add_generation_prompt=request.add_generation_prompt,
)
logger.debug('Prompt: %r', prompt)
config = llm.config.compatible_options(request)
def get_role() -> str:
return (
request.messages[-1]['role'] if not request.add_generation_prompt else 'assistant'
) # TODO: Support custom role here.
try:
result_generator = llm.generate_iterator(prompt, request_id=request_id, **config)
except Exception as err:
traceback.print_exc()
logger.error('Error generating completion: %s', err)
return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')
def create_stream_response_json(index, text, finish_reason=None, usage=None):
response = ChatCompletionStreamResponse(
id=request_id,
created=created_time,
model=model_name,
choices=[
ChatCompletionResponseStreamChoice(index=index, delta=Delta(content=text), finish_reason=finish_reason)
],
)
if usage is not None:
response.usage = usage
return jsonify_attr(response)
async def completion_stream_generator():
# first chunk with role
role = get_role()
for i in range(config['n']):
yield f'data: {jsonify_attr(ChatCompletionStreamResponse(id=request_id, created=created_time, choices=[ChatCompletionResponseStreamChoice(index=i, delta=Delta(role=role), finish_reason=None)], model=model_name))}\n\n'
if request.echo:
last_message, last_content = request.messages[-1], ''
if last_message.get('content') and last_message.get('role') == role:
last_content = last_message['content']
if last_content:
for i in range(config['n']):
yield f'data: {jsonify_attr(ChatCompletionStreamResponse(id=request_id, created=created_time, choices=[ChatCompletionResponseStreamChoice(index=i, delta=Delta(content=last_content), finish_reason=None)], model=model_name))}\n\n'
previous_num_tokens = [0] * config['n']
finish_reason_sent = [False] * config['n']
async for res in result_generator:
for output in res.outputs:
if finish_reason_sent[output.index]:
continue
yield f'data: {create_stream_response_json(output.index, output.text)}\n\n'
previous_num_tokens[output.index] += len(output.token_ids)
if output.finish_reason is not None:
prompt_tokens = len(res.prompt_token_ids)
usage = UsageInfo(prompt_tokens, previous_num_tokens[output.index], prompt_tokens + previous_num_tokens[output.index])
yield f'data: {create_stream_response_json(output.index, "", output.finish_reason, usage)}\n\n'
finish_reason_sent[output.index] = True
yield 'data: [DONE]\n\n'
try:
# Streaming case
if request.stream:
return StreamingResponse(completion_stream_generator(), media_type='text/event-stream')
# Non-streaming case
final_result, texts, token_ids = None, [[] for _ in range(config['n'])], [[] for _ in range(config['n'])]
async for res in result_generator:
if await req.is_disconnected():
return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
for output in res.outputs:
texts[output.index].append(output.text)
token_ids[output.index].extend(output.token_ids)
final_result = res
if final_result is None:
return error_response(HTTPStatus.BAD_REQUEST, 'No response from model.')
final_result = final_result.model_copy(
update=dict(
outputs=[
output.model_copy(update=dict(text=''.join(texts[output.index]), token_ids=token_ids[output.index]))
for output in final_result.outputs
]
)
)
role = get_role()
choices = [
ChatCompletionResponseChoice(
index=output.index, message=ChatMessage(role=role, content=output.text), finish_reason=output.finish_reason
)
for output in final_result.outputs
]
if request.echo:
last_message, last_content = request.messages[-1], ''
if last_message.get('content') and last_message.get('role') == role:
last_content = last_message['content']
for choice in choices:
full_message = last_content + choice.message.content
choice.message.content = full_message
num_prompt_tokens = len(final_result.prompt_token_ids)
num_generated_tokens = sum(len(output.token_ids) for output in final_result.outputs)
usage = UsageInfo(num_prompt_tokens, num_generated_tokens, num_prompt_tokens + num_generated_tokens)
response = ChatCompletionResponse(
id=request_id, created=created_time, model=model_name, usage=usage, choices=choices
)
return JSONResponse(converter.unstructure(response), status_code=HTTPStatus.OK.value)
except Exception as err:
traceback.print_exc()
logger.error('Error generating completion: %s', err)
return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')
# POST /v1/completions
@add_schema_definitions
async def completions(req, llm):
# TODO: Check for length based on model context_length
json_str = await req.body()
try:
request = converter.structure(orjson.loads(json_str), CompletionRequest)
except orjson.JSONDecodeError as err:
logger.debug('Sent body: %s', json_str)
logger.error('Invalid JSON input received: %s', err)
return error_response(HTTPStatus.BAD_REQUEST, 'Invalid JSON input received (Check server log).')
logger.debug('Received legacy completion request: %s', request)
err_check = await check_model(request, llm.llm_type)
if err_check is not None:
return err_check
# OpenAI API supports echoing the prompt when max_tokens is 0.
echo_without_generation = request.echo and request.max_tokens == 0
if echo_without_generation:
request.max_tokens = 1 # XXX: Hack to make sure we get the prompt back.
if request.suffix is not None:
return error_response(HTTPStatus.BAD_REQUEST, "'suffix' is not yet supported.")
if request.logit_bias is not None and len(request.logit_bias) > 0:
return error_response(HTTPStatus.BAD_REQUEST, "'logit_bias' is not yet supported.")
if not request.prompt:
return error_response(HTTPStatus.BAD_REQUEST, 'Please provide a prompt.')
prompt = request.prompt
# TODO: Support multiple prompts
model_name, request_id = request.model, gen_random_uuid('cmpl')
created_time = int(time.monotonic())
config = llm.config.compatible_options(request)
try:
result_generator = llm.generate_iterator(prompt, request_id=request_id, **config)
except Exception as err:
traceback.print_exc()
logger.error('Error generating completion: %s', err)
return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')
# If best_of != n, then we don't stream
# TODO: support use_beam_search
stream = request.stream and (config['best_of'] is None or config['n'] == config['best_of'])
def create_stream_response_json(index, text, logprobs=None, finish_reason=None, usage=None):
response = CompletionStreamResponse(
id=request_id,
created=created_time,
model=model_name,
choices=[CompletionResponseStreamChoice(index=index, text=text, logprobs=logprobs, finish_reason=finish_reason)],
)
if usage:
response.usage = usage
return jsonify_attr(response)
async def completion_stream_generator():
previous_num_tokens = [0] * config['n']
previous_texts = [''] * config['n']
previous_echo = [False] * config['n']
async for res in result_generator:
for output in res.outputs:
i = output.index
delta_text = output.text
token_ids = output.token_ids
logprobs = None
top_logprobs = None
if request.logprobs is not None:
top_logprobs = output.logprobs[previous_num_tokens[i] :]
if request.echo and not previous_echo[i]:
if not echo_without_generation:
delta_text = res.prompt + delta_text
token_ids = res.prompt_token_ids + token_ids
if top_logprobs:
top_logprobs = res.prompt_logprobs + top_logprobs
else:
delta_text = res.prompt
token_ids = res.prompt_token_ids
if top_logprobs:
top_logprobs = res.prompt_logprobs
previous_echo[i] = True
if request.logprobs is not None:
logprobs = create_logprobs(
output.token_ids,
output.logprobs[previous_num_tokens[i] :],
request.logprobs,
len(previous_texts[i]),
llm=llm,
)
previous_num_tokens[i] += len(output.token_ids)
previous_texts[i] += output.text
yield f'data: {create_stream_response_json(index=i, text=output.text, logprobs=logprobs, finish_reason=output.finish_reason)}\n\n'
if output.finish_reason is not None:
logprobs = LogProbs() if request.logprobs is not None else None
prompt_tokens = len(res.prompt_token_ids)
usage = UsageInfo(prompt_tokens, previous_num_tokens[i], prompt_tokens + previous_num_tokens[i])
yield f'data: {create_stream_response_json(i, "", logprobs, output.finish_reason, usage)}\n\n'
yield 'data: [DONE]\n\n'
try:
# Streaming case
if stream:
return StreamingResponse(completion_stream_generator(), media_type='text/event-stream')
# Non-streaming case
final_result, texts, token_ids = None, [[] for _ in range(config['n'])], [[] for _ in range(config['n'])]
async for res in result_generator:
if await req.is_disconnected():
return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
for output in res.outputs:
texts[output.index].append(output.text)
token_ids[output.index].extend(output.token_ids)
final_result = res
if final_result is None:
return error_response(HTTPStatus.BAD_REQUEST, 'No response from model.')
final_result = final_result.model_copy(
update=dict(
outputs=[
output.model_copy(update=dict(text=''.join(texts[output.index]), token_ids=token_ids[output.index]))
for output in final_result.outputs
]
)
)
choices = []
prompt_token_ids = final_result.prompt_token_ids
prompt_logprobs = final_result.prompt_logprobs
prompt_text = final_result.prompt
for output in final_result.outputs:
logprobs = None
if request.logprobs is not None:
if not echo_without_generation:
token_ids, top_logprobs = output.token_ids, output.logprobs
if request.echo:
token_ids, top_logprobs = prompt_token_ids + token_ids, prompt_logprobs + top_logprobs
else:
token_ids, top_logprobs = prompt_token_ids, prompt_logprobs
logprobs = create_logprobs(token_ids, top_logprobs, request.logprobs, llm=llm)
if not echo_without_generation:
output_text = output.text
if request.echo:
output_text = prompt_text + output_text
else:
output_text = prompt_text
choice_data = CompletionResponseChoice(
index=output.index, text=output_text, logprobs=logprobs, finish_reason=output.finish_reason
)
choices.append(choice_data)
num_prompt_tokens = len(final_result.prompt_token_ids)
num_generated_tokens = sum(len(output.token_ids) for output in final_result.outputs)
usage = UsageInfo(num_prompt_tokens, num_generated_tokens, num_prompt_tokens + num_generated_tokens)
response = CompletionResponse(id=request_id, created=created_time, model=model_name, usage=usage, choices=choices)
return JSONResponse(converter.unstructure(response), status_code=HTTPStatus.OK.value)
except Exception as err:
traceback.print_exc()
logger.error('Error generating completion: %s', err)
return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')
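Likewise, the removed OpenAI-compatible routes are mounted under `/v1`, so they can be exercised with the official `openai` Python client pointed at the OpenLLM server (a sketch assuming `openai>=1.0` and a server on the default `localhost:3000`; the model id must match whatever `GET /v1/models` reports):

# hypothetical client-side sketch
import openai

client = openai.OpenAI(base_url='http://localhost:3000/v1', api_key='na')
chat = client.chat.completions.create(
    model='<id reported by GET /v1/models>',
    messages=[{'role': 'user', 'content': 'Hello there'}],
    max_tokens=64,
)
print(chat.choices[0].message.content)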

View File

@@ -1,30 +0,0 @@
from http import HTTPStatus
from typing import Dict, List, Optional, Union
from attr import AttrsInstance
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from bentoml import Service
from openllm_core._typing_compat import M, T
from .._llm import LLM
from ..protocol.openai import ChatCompletionRequest, CompletionRequest, LogProbs
def mount_to_svc(svc: Service, llm: LLM[M, T]) -> Service: ...
def jsonify_attr(obj: AttrsInstance) -> str: ...
def error_response(status_code: HTTPStatus, message: str) -> JSONResponse: ...
async def check_model(
request: Union[CompletionRequest, ChatCompletionRequest], model: str
) -> Optional[JSONResponse]: ...
def create_logprobs(
token_ids: List[int],
top_logprobs: List[Dict[int, float]], #
num_output_top_logprobs: Optional[int] = ...,
initial_text_offset: int = ...,
*,
llm: LLM[M, T],
) -> LogProbs: ...
def list_models(req: Request, llm: LLM[M, T]) -> Response: ...
async def chat_completions(req: Request, llm: LLM[M, T]) -> Response: ...
async def completions(req: Request, llm: LLM[M, T]) -> Response: ...

View File

@@ -1,366 +0,0 @@
from __future__ import annotations
import functools, logging, os, typing as t
import bentoml, openllm, click, inflection, click_option_group as cog
from bentoml_cli.utils import BentoMLCommandGroup
from click import shell_completion as sc
from openllm_core._configuration import LLMConfig
from openllm_core._typing_compat import (
Concatenate,
DictStrAny,
LiteralBackend,
LiteralSerialisation,
ParamSpec,
AnyCallable,
get_literal_args,
)
from openllm_core.utils import DEBUG, compose, dantic, resolve_user_filepath
logger = logging.getLogger(__name__)
P = ParamSpec('P')
LiteralOutput = t.Literal['json', 'pretty', 'porcelain']
_AnyCallable = t.Callable[..., t.Any]
FC = t.TypeVar('FC', bound=t.Union[_AnyCallable, click.Command])
def bento_complete_envvar(ctx: click.Context, param: click.Parameter, incomplete: str) -> list[sc.CompletionItem]:
return [
sc.CompletionItem(str(it.tag), help='Bento')
for it in bentoml.list()
if str(it.tag).startswith(incomplete) and all(k in it.info.labels for k in {'start_name', 'bundler'})
]
def model_complete_envvar(ctx: click.Context, param: click.Parameter, incomplete: str) -> list[sc.CompletionItem]:
return [
sc.CompletionItem(inflection.dasherize(it), help='Model')
for it in openllm.CONFIG_MAPPING
if it.startswith(incomplete)
]
def parse_config_options(
config: LLMConfig,
server_timeout: int,
workers_per_resource: float,
device: t.Tuple[str, ...] | None,
cors: bool,
environ: DictStrAny,
) -> DictStrAny:
# TODO: Support amd.com/gpu on k8s
_bentoml_config_options_env = environ.pop('BENTOML_CONFIG_OPTIONS', '')
_bentoml_config_options_opts = [
'tracing.sample_rate=1.0',
'api_server.max_runner_connections=25',
f'runners."llm-{config["start_name"]}-runner".batching.max_batch_size=128',
f'api_server.traffic.timeout={server_timeout}',
f'runners."llm-{config["start_name"]}-runner".traffic.timeout={config["timeout"]}',
f'runners."llm-{config["start_name"]}-runner".workers_per_resource={workers_per_resource}',
]
if device:
if len(device) > 1:
_bentoml_config_options_opts.extend([
f'runners."llm-{config["start_name"]}-runner".resources."nvidia.com/gpu"[{idx}]={dev}'
for idx, dev in enumerate(device)
])
else:
_bentoml_config_options_opts.append(
f'runners."llm-{config["start_name"]}-runner".resources."nvidia.com/gpu"=[{device[0]}]'
)
if cors:
_bentoml_config_options_opts.extend([
'api_server.http.cors.enabled=true',
'api_server.http.cors.access_control_allow_origins="*"',
])
_bentoml_config_options_opts.extend([
f'api_server.http.cors.access_control_allow_methods[{idx}]="{it}"'
for idx, it in enumerate(['GET', 'OPTIONS', 'POST', 'HEAD', 'PUT'])
])
_bentoml_config_options_env += (' ' if _bentoml_config_options_env else '') + ' '.join(_bentoml_config_options_opts)
environ['BENTOML_CONFIG_OPTIONS'] = _bentoml_config_options_env
if DEBUG:
logger.debug('Setting BENTOML_CONFIG_OPTIONS=%s', _bentoml_config_options_env)
return environ
_adapter_mapping_key = 'adapter_map'
def _id_callback(ctx: click.Context, _: click.Parameter, value: t.Tuple[str, ...] | None) -> None:
if not value:
return None
if _adapter_mapping_key not in ctx.params:
ctx.params[_adapter_mapping_key] = {}
for v in value:
adapter_id, *adapter_name = v.rsplit(':', maxsplit=1)
# try to resolve the full path if users pass in a relative path;
# currently we only support resolving one level relative to the current directory
try:
adapter_id = resolve_user_filepath(adapter_id, os.getcwd())
except FileNotFoundError:
pass
name = adapter_name[0] if len(adapter_name) > 0 else 'default'
ctx.params[_adapter_mapping_key][adapter_id] = name
return None
def optimization_decorator(fn: FC, *, factory=click, _eager=True) -> FC | list[AnyCallable]:
shared = [
dtype_option(factory=factory),
revision_option(factory=factory), #
backend_option(factory=factory),
quantize_option(factory=factory), #
serialisation_option(factory=factory),
]
if not _eager:
return shared
return compose(*shared)(fn)
# NOTE: A list of bentoml options that are not needed for parsing.
# NOTE: Users shouldn't set '--working-dir', as OpenLLM will set this up.
# NOTE: 'production' is also deprecated
_IGNORED_OPTIONS = {'working_dir', 'production', 'protocol_version'}
def parse_serve_args() -> t.Callable[[t.Callable[..., LLMConfig]], t.Callable[[FC], FC]]:
from bentoml_cli.cli import cli
group = cog.optgroup.group(
'Start a HTTP server options', help='Related to serving the model [synonymous to `bentoml serve-http`]'
)
def decorator(f: t.Callable[Concatenate[int, t.Optional[str], P], LLMConfig]) -> t.Callable[[FC], FC]:
serve_command = cli.commands['serve']
# The first parameter is the bento argument
# The last five are from BentoMLCommandGroup.NUMBER_OF_COMMON_PARAMS
serve_options = [
p
for p in serve_command.params[1 : -BentoMLCommandGroup.NUMBER_OF_COMMON_PARAMS]
if p.name not in _IGNORED_OPTIONS
]
for options in reversed(serve_options):
attrs = options.to_info_dict()
# we don't need param_type_name, since these should all be options
attrs.pop('param_type_name')
# name is not a valid argument
attrs.pop('name')
# type can be determined from the default value
attrs.pop('type')
param_decls = (*attrs.pop('opts'), *attrs.pop('secondary_opts'))
f = cog.optgroup.option(*param_decls, **attrs)(f)
return group(f)
return decorator
def start_decorator(fn: FC) -> FC:
composed = compose(
cog.optgroup.group(
'LLM Options',
help="""The following options are related to running LLM Server as well as optimization options.
OpenLLM supports running model k-bit quantization (8-bit, 4-bit), GPTQ quantization, PagedAttention via vLLM.
The following are either in our roadmap or currently being worked on:
- DeepSpeed Inference: [link](https://www.deepspeed.ai/inference/)
- GGML: Fast inference on [bare metal](https://github.com/ggerganov/ggml)
""",
),
*optimization_decorator(fn, factory=cog.optgroup, _eager=False),
cog.optgroup.option(
'--device',
type=dantic.CUDA,
multiple=True,
envvar='CUDA_VISIBLE_DEVICES',
callback=parse_device_callback,
help='Assign GPU devices (if available)',
show_envvar=True,
),
adapter_id_option(factory=cog.optgroup),
)
return composed(fn)
def parse_device_callback(
_: click.Context, param: click.Parameter, value: tuple[tuple[str], ...] | None
) -> t.Tuple[str, ...] | None:
if value is None:
return value
el: t.Tuple[str, ...] = tuple(i for k in value for i in k)
# NOTE: --device all is a special case
if len(el) == 1 and el[0] == 'all':
return tuple(map(str, openllm.utils.available_devices()))
return el
def _click_factory_type(*param_decls: t.Any, **attrs: t.Any) -> t.Callable[[FC | None], FC]:
"""General ``@click`` decorator with some sauce.
This decorator extends the default ``@click.option`` with a ``factory`` option and ``attr`` attribute to
provide a type-safe ``click.option`` or ``click.argument`` wrapper for any compatible factory.
"""
factory = attrs.pop('factory', click)
factory_attr = attrs.pop('attr', 'option')
if factory_attr != 'argument':
attrs.setdefault('help', 'General option for OpenLLM CLI.')
def decorator(f: FC | None) -> FC:
callback = getattr(factory, factory_attr, None)
if callback is None:
raise ValueError(f'Factory {factory} has no attribute {factory_attr}.')
return t.cast(FC, callback(*param_decls, **attrs)(f) if f is not None else callback(*param_decls, **attrs))
return decorator
cli_option = functools.partial(_click_factory_type, attr='option')
cli_argument = functools.partial(_click_factory_type, attr='argument')
def adapter_id_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
return cli_option(
'--adapter-id',
default=None,
help='Optional name or path for given LoRA adapter',
multiple=True,
callback=_id_callback,
metavar='[PATH | [remote/][adapter_name:]adapter_id][, ...]',
**attrs,
)(f)
def cors_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
return cli_option(
'--cors/--no-cors',
show_default=True,
default=False,
envvar='OPENLLM_CORS',
show_envvar=True,
help='Enable CORS for the server.',
**attrs,
)(f)
def machine_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
return cli_option('--machine', is_flag=True, default=False, hidden=True, **attrs)(f)
def dtype_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
return cli_option(
'--dtype',
type=str,
envvar='TORCH_DTYPE',
default='auto',
help="Optional dtype for casting tensors for running inference ['float16', 'float32', 'bfloat16', 'int8', 'int16']",
**attrs,
)(f)
def model_id_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
return cli_option(
'--model-id',
type=click.STRING,
default=None,
envvar='OPENLLM_MODEL_ID',
show_envvar=True,
help='Optional model_id name or path for (fine-tune) weight.',
**attrs,
)(f)
def revision_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
return cli_option(
'--revision',
'--model-version',
'model_version',
type=click.STRING,
default=None,
help='Optional model revision to save for this model. It will be inferred automatically from model-id.',
**attrs,
)(f)
def backend_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
return cli_option(
'--backend',
type=click.Choice(get_literal_args(LiteralBackend)),
default=None,
envvar='OPENLLM_BACKEND',
show_envvar=True,
help='Runtime to use for both serialisation/inference engine.',
**attrs,
)(f)
def model_name_argument(f: _AnyCallable | None = None, required: bool = True, **attrs: t.Any) -> t.Callable[[FC], FC]:
return cli_argument(
'model_name',
type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING]),
required=required,
**attrs,
)(f)
def quantize_option(f: _AnyCallable | None = None, *, build: bool = False, **attrs: t.Any) -> t.Callable[[FC], FC]:
return cli_option(
'--quantise',
'--quantize',
'quantize',
type=str,
default=None,
envvar='OPENLLM_QUANTIZE',
show_envvar=True,
help="""Dynamic quantization for running this LLM.
The following quantization strategies are supported:
- ``int8``: ``LLM.int8`` for [8-bit](https://arxiv.org/abs/2208.07339) quantization.
- ``int4``: ``SpQR`` for [4-bit](https://arxiv.org/abs/2306.03078) quantization.
- ``gptq``: ``GPTQ`` [quantization](https://arxiv.org/abs/2210.17323)
- ``awq``: ``AWQ`` [AWQ: Activation-aware Weight Quantization](https://arxiv.org/abs/2306.00978)
- ``squeezellm``: ``SqueezeLLM`` [SqueezeLLM: Dense-and-Sparse Quantization](https://arxiv.org/abs/2306.07629)
> [!NOTE] The model can also be served with quantized weights.
"""
+ (
"""
> [!NOTE] This will set the mode for serving within deployment."""
if build
else ''
),
**attrs,
)(f)
def serialisation_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
return cli_option(
'--serialisation',
'--serialization',
'serialisation',
type=click.Choice(get_literal_args(LiteralSerialisation)),
default=None,
show_default=True,
show_envvar=True,
envvar='OPENLLM_SERIALIZATION',
help="""Serialisation format for save/load LLM.
Currently the following strategies are supported:
- ``safetensors``: This will use safetensors format, which is synonymous to ``safe_serialization=True``.
> [!NOTE] Safetensors might not work for every case, and you can always fall back to ``legacy`` if needed.
- ``legacy``: This will use PyTorch serialisation format, often as ``.bin`` files. This should be used if the model doesn't yet support safetensors.
""",
**attrs,
)(f)
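To illustrate how these factories compose (a sketch only; the module path of these helpers is not shown in this diff), each `*_option` helper can be stacked onto any click command, while the `factory=cog.optgroup` variant groups them the way `start_decorator` does:

# hypothetical sketch of stacking the option factories onto a click command
import click

@click.command()
@model_id_option()       # adds --model-id, reads OPENLLM_MODEL_ID
@quantize_option()       # adds --quantise/--quantize, reads OPENLLM_QUANTIZE
@serialisation_option()  # adds --serialisation/--serialization, reads OPENLLM_SERIALIZATION
def show_options(model_id, quantize, serialisation):
    click.echo(f'{model_id=} {quantize=} {serialisation=}')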

View File

@@ -1,276 +0,0 @@
from __future__ import annotations
import itertools, logging, os, re, subprocess, sys, typing as t, bentoml, openllm_core, orjson
from simple_di import Provide, inject
from bentoml._internal.configuration.containers import BentoMLContainer
from openllm_core._typing_compat import LiteralSerialisation
from openllm_core.exceptions import OpenLLMException
from openllm_core.utils import WARNING_ENV_VAR, codegen, first_not_none, get_disable_warnings, is_vllm_available
if t.TYPE_CHECKING:
from bentoml._internal.bento import BentoStore
from openllm_core._configuration import LLMConfig
from openllm_core._typing_compat import LiteralBackend, LiteralQuantise, LiteralString
logger = logging.getLogger(__name__)
def _start(
model_id: str,
timeout: int = 30,
workers_per_resource: t.Literal['conserved', 'round_robin'] | float | None = None,
device: tuple[str, ...] | t.Literal['all'] | None = None,
quantize: LiteralQuantise | None = None,
adapter_map: dict[LiteralString, str | None] | None = None,
backend: LiteralBackend | None = None,
additional_args: list[str] | None = None,
cors: bool = False,
__test__: bool = False,
**_: t.Any,
) -> LLMConfig | subprocess.Popen[bytes]:
"""Python API to start a LLM server. These provides one-to-one mapping to CLI arguments.
For all additional arguments, pass it as string to ``additional_args``. For example, if you want to
pass ``--port 5001``, you can pass ``additional_args=["--port", "5001"]``
> [!NOTE] This will create a blocking process, so if you use this API, you can create a running sub thread
> to start the server instead of blocking the main thread.
``openllm.start`` will invoke ``click.Command`` under the hood, so it behaves exactly the same as the CLI interaction.
Args:
model_id: The model id to start this LLMServer
timeout: The server timeout
workers_per_resource: Number of workers per resource assigned.
See [resource scheduling](https://docs.bentoml.org/en/latest/guides/scheduling.html#resource-scheduling-strategy)
for more information. By default, this is set to 1.
> [!NOTE] ``--workers-per-resource`` will also accept the following strategies:
> - ``round_robin``: Similar behaviour when setting ``--workers-per-resource 1``. This is useful for smaller models.
> - ``conserved``: This will determine the number of available GPU resources, and only assign
> one worker for the LLMRunner. For example, if there are 4 GPUs available, then ``conserved`` is
> equivalent to ``--workers-per-resource 0.25``.
device: Assign GPU devices (if available) to this LLM. By default, this is set to ``None``. It also accepts 'all'
argument to assign all available GPUs to this LLM.
quantize: Quantize the model weights. This is only applicable for PyTorch models.
Possible quantisation strategies:
- int8: Quantize the model with 8bit (bitsandbytes required)
- int4: Quantize the model with 4bit (bitsandbytes required)
- gptq: Quantize the model with GPTQ (auto-gptq required)
cors: Whether to enable CORS for this LLM. By default, this is set to ``False``.
adapter_map: The adapter mapping of LoRA to use for this LLM. It accepts a dictionary of ``{adapter_id: adapter_name}``.
backend: The backend to use for this LLM. By default, this is set to ``pt``.
additional_args: Additional arguments to pass to ``openllm start``.
"""
from .entrypoint import start_command
os.environ['BACKEND'] = openllm_core.utils.first_not_none(backend, default='vllm' if is_vllm_available() else 'pt')
args: list[str] = [model_id]
if timeout:
args.extend(['--server-timeout', str(timeout)])
if workers_per_resource:
args.extend([
'--workers-per-resource',
str(workers_per_resource) if not isinstance(workers_per_resource, str) else workers_per_resource,
])
if device and not os.environ.get('CUDA_VISIBLE_DEVICES'):
args.extend(['--device', ','.join(device)])
if quantize:
args.extend(['--quantize', str(quantize)])
if cors:
args.append('--cors')
if adapter_map:
args.extend(
list(
itertools.chain.from_iterable([['--adapter-id', f"{k}{':' + v if v else ''}"] for k, v in adapter_map.items()])
)
)
if additional_args:
args.extend(additional_args)
if __test__:
args.append('--return-process')
cmd = start_command
return cmd.main(args=args, standalone_mode=False)
@inject
def _build(
model_id: str,
model_version: str | None = None,
bento_version: str | None = None,
quantize: LiteralQuantise | None = None,
adapter_map: dict[str, str | None] | None = None,
build_ctx: str | None = None,
enable_features: tuple[str, ...] | None = None,
dockerfile_template: str | None = None,
overwrite: bool = False,
push: bool = False,
force_push: bool = False,
containerize: bool = False,
serialisation: LiteralSerialisation | None = None,
additional_args: list[str] | None = None,
bento_store: BentoStore = Provide[BentoMLContainer.bento_store],
) -> bentoml.Bento:
"""Package a LLM into a BentoLLM.
The LLM will be built into a BentoService with the following structure:
if ``quantize`` is passed, it will instruct the model to be quantized dynamically during serving time.
``openllm.build`` will invoke ``click.Command`` under the hood, so it behaves exactly the same as ``openllm build`` CLI.
Args:
model_id: The model id to build this BentoLLM
model_version: Optional model version for this given LLM
    bento_version: Optional bento version for this given BentoLLM
quantize: Quantize the model weights. This is only applicable for PyTorch models.
Possible quantisation strategies:
- int8: Quantize the model with 8bit (bitsandbytes required)
- int4: Quantize the model with 4bit (bitsandbytes required)
- gptq: Quantize the model with GPTQ (auto-gptq required)
adapter_map: The adapter mapping of LoRA to use for this LLM. It accepts a dictionary of ``{adapter_id: adapter_name}``.
    build_ctx: The build context to use for building BentoLLM. By default, this is set to the current directory.
enable_features: Additional OpenLLM features to be included with this BentoLLM.
dockerfile_template: The dockerfile template to use for building BentoLLM. See https://docs.bentoml.com/en/latest/guides/containerization.html#dockerfile-template.
overwrite: Whether to overwrite the existing BentoLLM. By default, this is set to ``False``.
    push: Whether to push the resulting Bento to BentoCloud. Make sure to login with 'bentoml cloud login' first.
    containerize: Whether to containerize the Bento after building. '--containerize' is the shortcut of 'openllm build && bentoml containerize'.
      Note that 'containerize' and 'push' are mutually exclusive.
    force_push: Whether to force push the resulting Bento to BentoCloud.
    serialisation: Serialisation for saving models. Defaults to 'safetensors', which is equivalent to ``safe_serialization=True``.
    additional_args: Additional arguments to pass to ``openllm build``.
    bento_store: Optional BentoStore for saving this BentoLLM. Defaults to the default BentoML local store.
Returns:
    ``bentoml.Bento``: The built BentoLLM. This can be used to serve the LLM or can be pushed to BentoCloud.
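  Example (a hedged sketch; the model id is illustrative and must be available locally or on the HuggingFace Hub):
  ```python
  import openllm
  bento = openllm.build('facebook/opt-1.3b', serialisation='safetensors')
  print(bento.tag)
  ```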
"""
from openllm.serialisation.transformers.weights import has_safetensors_weights
args: list[str] = [
sys.executable,
'-m',
'openllm',
'build',
model_id,
'--machine',
'--quiet',
'--serialisation',
first_not_none(
serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy'
),
]
if quantize:
args.extend(['--quantize', quantize])
if containerize and push:
raise OpenLLMException("'containerize' and 'push' are currently mutually exclusive.")
if push:
args.extend(['--push'])
if containerize:
args.extend(['--containerize'])
if build_ctx:
args.extend(['--build-ctx', build_ctx])
if enable_features:
args.extend([f'--enable-features={f}' for f in enable_features])
if overwrite:
args.append('--overwrite')
if adapter_map:
args.extend([f"--adapter-id={k}{':' + v if v is not None else ''}" for k, v in adapter_map.items()])
if model_version:
args.extend(['--model-version', model_version])
if bento_version:
args.extend(['--bento-version', bento_version])
if dockerfile_template:
args.extend(['--dockerfile-template', dockerfile_template])
if additional_args:
args.extend(additional_args)
if force_push:
args.append('--force-push')
current_disable_warning = get_disable_warnings()
os.environ[WARNING_ENV_VAR] = str(True)
try:
output = subprocess.check_output(args, env=os.environ.copy(), cwd=build_ctx or os.getcwd())
except subprocess.CalledProcessError as e:
logger.error("Exception caught while building Bento for '%s'", model_id, exc_info=e)
if e.stderr:
raise OpenLLMException(e.stderr.decode('utf-8')) from None
raise OpenLLMException(str(e)) from None
matched = re.match(r'__object__:(\{.*\})$', output.decode('utf-8').strip())
if matched is None:
raise ValueError(
f"Failed to find tag from output: {output.decode('utf-8').strip()}\nNote: Output from 'openllm build' might not be correct. Please open an issue on GitHub."
)
os.environ[WARNING_ENV_VAR] = str(current_disable_warning)
try:
result = orjson.loads(matched.group(1))
except orjson.JSONDecodeError as e:
raise ValueError(
f"Failed to decode JSON from output: {output.decode('utf-8').strip()}\nNote: Output from 'openllm build' might not be correct. Please open an issue on GitHub."
) from e
return bentoml.get(result['tag'], _bento_store=bento_store)
def _import_model(
model_id: str,
model_version: str | None = None,
backend: LiteralBackend | None = None,
quantize: LiteralQuantise | None = None,
serialisation: LiteralSerialisation | None = None,
additional_args: t.Sequence[str] | None = None,
) -> dict[str, t.Any]:
"""Import a LLM into local store.
> [!NOTE]
> If ``quantize`` is passed, the model weights will be saved as quantized weights. You should
> only use this option if you want the weight to be quantized by default. Note that OpenLLM also
> support on-demand quantisation during initial startup.
``openllm.import_model`` will invoke ``click.Command`` under the hood, so it behaves exactly the same as the CLI ``openllm import``.
> [!NOTE]
> ``openllm.start`` will automatically invoke ``openllm.import_model`` under the hood.
Args:
model_id: required model id for this given LLM
model_version: Optional model version for this given LLM
backend: The backend to use for this LLM. By default, this is set to ``pt``.
quantize: Quantize the model weights. This is only applicable for PyTorch models.
Possible quantisation strategies:
- int8: Quantize the model with 8bit (bitsandbytes required)
- int4: Quantize the model with 4bit (bitsandbytes required)
- gptq: Quantize the model with GPTQ (auto-gptq required)
    serialisation: Type of model format to save to local store. If set to 'safetensors', then OpenLLM will save the model using safetensors. Default behaviour is similar to ``safe_serialization=False``.
additional_args: Additional arguments to pass to ``openllm import``.
Returns:
    ``bentoml.Model``: The BentoModel of the given LLM. This can be used to serve the LLM or pushed to BentoCloud.
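  Example (a minimal sketch; the model id is illustrative):
  ```python
  import openllm
  openllm.import_model('facebook/opt-1.3b', backend='pt', serialisation='safetensors')
  ```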
"""
from .entrypoint import import_command
args = [model_id, '--quiet']
if backend is not None:
args.extend(['--backend', backend])
if model_version is not None:
args.extend(['--model-version', str(model_version)])
if quantize is not None:
args.extend(['--quantize', quantize])
if serialisation is not None:
args.extend(['--serialisation', serialisation])
if additional_args is not None:
args.extend(additional_args)
return import_command.main(args=args, standalone_mode=False)
def _list_models() -> dict[str, t.Any]:
"""List all available models within the local store."""
from .entrypoint import models_command
return models_command.main(args=['--quiet'], standalone_mode=False)
start, build, import_model, list_models = (
codegen.gen_sdk(_start),
codegen.gen_sdk(_build),
codegen.gen_sdk(_import_model),
codegen.gen_sdk(_list_models),
)
__all__ = ['build', 'import_model', 'list_models', 'start']

View File

@@ -0,0 +1,5 @@
# mainly for backward compatible entrypoint.
from _openllm_tiny._entrypoint import cli as cli
if __name__ == '__main__':
cli()

View File

@@ -1,16 +0,0 @@
"""OpenLLM CLI Extension.
The following directory contains all possible extensions for OpenLLM CLI
For adding new extension, just simply name that ext to `<name_ext>.py` and define
a ``click.command()`` with the following format:
```python
import click
@click.command(<name_ext>)
...
def cli(...): # <- this is important here, it should always name CLI in order for the extension resolver to know how to import this extensions.
```
NOTE: Make sure to keep this file blank such that it won't mess with the import order.
"""

View File

@@ -1,43 +0,0 @@
from __future__ import annotations
import shutil
import subprocess
import typing as t
import click
import psutil
from simple_di import Provide, inject
import bentoml
from bentoml._internal.configuration.containers import BentoMLContainer
from openllm_cli import termui
from openllm_cli._factory import bento_complete_envvar, machine_option
if t.TYPE_CHECKING:
from bentoml._internal.bento import BentoStore
@click.command('dive_bentos', context_settings=termui.CONTEXT_SETTINGS)
@click.argument('bento', type=str, shell_complete=bento_complete_envvar)
@machine_option
@click.pass_context
@inject
def cli(
ctx: click.Context, bento: str, machine: bool, _bento_store: BentoStore = Provide[BentoMLContainer.bento_store]
) -> str | None:
"""Dive into a BentoLLM. This is synonymous to cd $(b get <bento>:<tag> -o path)."""
try:
bentomodel = _bento_store.get(bento)
except bentoml.exceptions.NotFound:
ctx.fail(f'Bento {bento} not found. Make sure to call `openllm build` first.')
if 'bundler' not in bentomodel.info.labels or bentomodel.info.labels['bundler'] != 'openllm.bundle':
ctx.fail(
f"Bento is either too old or not built with OpenLLM. Make sure to use ``openllm build {bentomodel.info.labels['start_name']}`` for correctness."
)
if machine:
return bentomodel.path
# copy and paste this into a new shell
if psutil.WINDOWS:
subprocess.check_call([shutil.which('dir') or 'dir'], cwd=bentomodel.path)
else:
subprocess.check_call([shutil.which('ls') or 'ls', '-Rrthla'], cwd=bentomodel.path)
ctx.exit(0)

View File

@@ -1,53 +0,0 @@
from __future__ import annotations
import typing as t
import click
from simple_di import Provide, inject
import bentoml
from bentoml._internal.bento.bento import BentoInfo
from bentoml._internal.bento.build_config import DockerOptions
from bentoml._internal.configuration.containers import BentoMLContainer
from bentoml._internal.container.generate import generate_containerfile
from openllm_cli import termui
from openllm_cli._factory import bento_complete_envvar
from openllm_core.utils import converter
if t.TYPE_CHECKING:
from bentoml._internal.bento import BentoStore
@click.command(
'get_containerfile', context_settings=termui.CONTEXT_SETTINGS, help='Return Containerfile of any given Bento.'
)
@click.argument('bento', type=str, shell_complete=bento_complete_envvar)
@click.pass_context
@inject
def cli(ctx: click.Context, bento: str, _bento_store: BentoStore = Provide[BentoMLContainer.bento_store]) -> str:
try:
bentomodel = _bento_store.get(bento)
except bentoml.exceptions.NotFound:
ctx.fail(f'Bento {bento} not found. Make sure to call `openllm build` first.')
  # The logic below is similar to bentoml._internal.container.construct_containerfile
with open(bentomodel.path_of('bento.yaml'), 'r') as f:
options = BentoInfo.from_yaml_file(f)
# NOTE: dockerfile_template is already included in the
# Dockerfile inside bento, and it is not relevant to
# construct_containerfile. Hence it is safe to set it to None here.
# See https://github.com/bentoml/BentoML/issues/3399.
docker_attrs = converter.unstructure(options.docker)
# NOTE: if users specify a dockerfile_template, we will
# save it to /env/docker/Dockerfile.template. This is necessary
# for the reconstruction of the Dockerfile.
if 'dockerfile_template' in docker_attrs and docker_attrs['dockerfile_template'] is not None:
docker_attrs['dockerfile_template'] = 'env/docker/Dockerfile.template'
doc = generate_containerfile(
docker=DockerOptions(**docker_attrs),
build_ctx=bentomodel.path,
conda=options.conda,
bento_fs=bentomodel._fs,
enable_buildkit=True,
add_header=True,
)
termui.echo(doc, fg='white')
return bentomodel.path

View File

@@ -1,179 +0,0 @@
from __future__ import annotations
import logging
import string
import typing as t
import attr
import click
import inflection
import orjson
from bentoml_cli.utils import opt_callback
import openllm
from openllm_cli import termui
from openllm_cli._factory import model_complete_envvar
logger = logging.getLogger(__name__)
class PromptFormatter(string.Formatter):
def vformat(self, format_string: str, args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]) -> t.Any:
if len(args) > 0:
raise ValueError('Positional arguments are not supported')
return super().vformat(format_string, args, kwargs)
def check_unused_args(
self, used_args: set[int | str], args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]
) -> None:
extras = set(kwargs).difference(used_args)
if extras:
raise KeyError(f'Extra params passed: {extras}')
def extract_template_variables(self, template: str) -> t.Sequence[str]:
return [field[1] for field in self.parse(template) if field[1] is not None]
default_formatter = PromptFormatter()
# alias object.__setattr__ to save one lookup per assignment
_object_setattr = object.__setattr__
@attr.define(slots=True)
class PromptTemplate:
template: str
_input_variables: t.Sequence[str] = attr.field(init=False)
def __attrs_post_init__(self) -> None:
self._input_variables = default_formatter.extract_template_variables(self.template)
def with_options(self, **attrs: t.Any) -> PromptTemplate:
prompt_variables = {key: '{' + key + '}' if key not in attrs else attrs[key] for key in self._input_variables}
o = attr.evolve(self, template=self.template.format(**prompt_variables))
_object_setattr(o, '_input_variables', default_formatter.extract_template_variables(o.template))
return o
def format(self, **attrs: t.Any) -> str:
prompt_variables = {k: v for k, v in attrs.items() if k in self._input_variables}
try:
return self.template.format(**prompt_variables)
except KeyError as e:
raise RuntimeError(
f"Missing variable '{e.args[0]}' (required: {self._input_variables}) in the prompt template."
) from None
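# Example usage of PromptTemplate (a hedged sketch, not part of the original module; the template string is illustrative):
# >>> tpl = PromptTemplate('{system_message}\nUSER: {instruction}\nASSISTANT:')
# >>> tpl = tpl.with_options(system_message='You are a helpful assistant')
# >>> tpl.format(instruction='Hello there')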
@click.command('get_prompt', context_settings=termui.CONTEXT_SETTINGS)
@click.argument('model_id', shell_complete=model_complete_envvar)
@click.argument('prompt', type=click.STRING)
@click.option('--prompt-template-file', type=click.File(), default=None)
@click.option('--chat-template-file', type=click.File(), default=None)
@click.option('--system-message', type=str, default=None)
@click.option(
'--add-generation-prompt/--no-add-generation-prompt',
default=False,
  help='See https://huggingface.co/docs/transformers/main/chat_templating#what-template-should-i-use. This is only applicable if model_id is a HF model_id',
)
@click.option(
'--opt',
help="Define additional prompt variables. (format: ``--opt system_message='You are a useful assistant'``)",
required=False,
multiple=True,
callback=opt_callback,
metavar='ARG=VALUE[,ARG=VALUE]',
)
@click.pass_context
def cli(
ctx: click.Context,
/,
model_id: str,
prompt: str,
prompt_template_file: t.IO[t.Any] | None,
chat_template_file: t.IO[t.Any] | None,
system_message: str | None,
add_generation_prompt: bool,
_memoized: dict[str, t.Any],
**_: t.Any,
) -> str | None:
"""Helpers for generating prompts.
\b
  It accepts remote HF model_ids as well as model names passed to `openllm start`.
If you pass in a HF model_id, then it will use the tokenizer to generate the prompt.
```bash
openllm get-prompt WizardLM/WizardCoder-15B-V1.0 "Hello there"
```
  If you need to change the prompt template, you can create a template file containing the Jinja2 template and pass it through `--chat-template-file`.
See https://huggingface.co/docs/transformers/main/chat_templating#templates-for-chat-models for more details.
\b
```bash
openllm get-prompt WizardLM/WizardCoder-15B-V1.0 "Hello there" --chat-template-file template.jinja2
```
\b
If you pass a model name, then it will use OpenLLM configuration to generate the prompt.
  Note that this is mainly a utility, as OpenLLM won't use these prompts to do the formatting for you.
\b
```bash
openllm get-prompt mistral "Hello there"
"""
_memoized = {k: v[0] for k, v in _memoized.items() if v}
if prompt_template_file and chat_template_file:
ctx.fail('prompt-template-file and chat-template-file are mutually exclusive.')
acceptable = set(openllm.CONFIG_MAPPING_NAMES.keys()) | set(
inflection.dasherize(name) for name in openllm.CONFIG_MAPPING_NAMES.keys()
)
if model_id in acceptable:
logger.warning(
'Using a default prompt from OpenLLM. Note that this prompt might not work for your intended usage.\n'
)
config = openllm.AutoConfig.for_model(model_id)
template = prompt_template_file.read() if prompt_template_file is not None else config.template
system_message = system_message or config.system_message
try:
formatted = (
PromptTemplate(template).with_options(system_message=system_message).format(instruction=prompt, **_memoized)
)
except RuntimeError as err:
logger.debug('Exception caught while formatting prompt: %s', err)
ctx.fail(str(err))
else:
import transformers
trust_remote_code = openllm.utils.check_bool_env('TRUST_REMOTE_CODE', False)
config = transformers.AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, trust_remote_code=trust_remote_code)
if chat_template_file is not None:
chat_template_file = chat_template_file.read()
if system_message is None:
      logger.warning('system-message is not provided; inferring a default from the model architecture.\n')
for architecture in config.architectures:
if architecture in openllm.AutoConfig._CONFIG_MAPPING_NAMES_TO_ARCHITECTURE():
system_message = (
openllm.AutoConfig.infer_class_from_name(
openllm.AutoConfig._CONFIG_MAPPING_NAMES_TO_ARCHITECTURE()[architecture]
)
.model_construct_env()
.system_message
)
break
else:
ctx.fail(
f'Failed to infer system message from model architecture: {config.architectures}. Please pass in --system-message'
)
messages = [{'role': 'system', 'content': system_message}, {'role': 'user', 'content': prompt}]
formatted = tokenizer.apply_chat_template(
messages, chat_template=chat_template_file, add_generation_prompt=add_generation_prompt, tokenize=False
)
termui.echo(orjson.dumps({'prompt': formatted}, option=orjson.OPT_INDENT_2).decode(), fg='white')
ctx.exit(0)

View File

@@ -1,34 +0,0 @@
from __future__ import annotations
import click
import inflection
import orjson
import bentoml
import openllm
from bentoml._internal.utils import human_readable_size
from openllm_cli import termui
@click.command('list_bentos', context_settings=termui.CONTEXT_SETTINGS)
@click.pass_context
def cli(ctx: click.Context) -> None:
"""List available bentos built by OpenLLM."""
mapping = {
k: [
{
'tag': str(b.tag),
'size': human_readable_size(openllm.utils.calc_dir_size(b.path)),
'models': [
{'tag': str(m.tag), 'size': human_readable_size(openllm.utils.calc_dir_size(m.path))}
for m in (bentoml.models.get(_.tag) for _ in b.info.models)
],
}
for b in tuple(i for i in bentoml.list() if all(k in i.info.labels for k in {'start_name', 'bundler'}))
if b.info.labels['start_name'] == k
]
for k in tuple(inflection.dasherize(key) for key in openllm.CONFIG_MAPPING.keys())
}
mapping = {k: v for k, v in mapping.items() if v}
termui.echo(orjson.dumps(mapping, option=orjson.OPT_INDENT_2).decode(), fg='white')
ctx.exit(0)

View File

@@ -1,49 +0,0 @@
from __future__ import annotations
import typing as t
import click
import inflection
import orjson
import bentoml
import openllm
from bentoml._internal.utils import human_readable_size
from openllm_cli import termui
from openllm_cli._factory import model_complete_envvar, model_name_argument
if t.TYPE_CHECKING:
from openllm_core._typing_compat import DictStrAny
@click.command('list_models', context_settings=termui.CONTEXT_SETTINGS)
@model_name_argument(required=False, shell_complete=model_complete_envvar)
def cli(model_name: str | None) -> DictStrAny:
"""List available models in local store to be used with OpenLLM."""
models = tuple(inflection.dasherize(key) for key in openllm.CONFIG_MAPPING.keys())
ids_in_local_store = {
k: [
i
for i in bentoml.models.list()
if 'framework' in i.info.labels
and i.info.labels['framework'] == 'openllm'
and 'model_name' in i.info.labels
and i.info.labels['model_name'] == k
]
for k in models
}
if model_name is not None:
ids_in_local_store = {
k: [
i
for i in v
if 'model_name' in i.info.labels and i.info.labels['model_name'] == inflection.dasherize(model_name)
]
for k, v in ids_in_local_store.items()
}
ids_in_local_store = {k: v for k, v in ids_in_local_store.items() if v}
local_models = {
k: [{'tag': str(i.tag), 'size': human_readable_size(openllm.utils.calc_dir_size(i.path))} for i in val]
for k, val in ids_in_local_store.items()
}
termui.echo(orjson.dumps(local_models, option=orjson.OPT_INDENT_2).decode(), fg='white')
return local_models

View File

@@ -1,108 +0,0 @@
from __future__ import annotations
import importlib.machinery
import logging
import os
import pkgutil
import subprocess
import sys
import tempfile
import typing as t
import click
import jupytext
import nbformat
import yaml
from openllm_cli import playground, termui
from openllm_core.utils import is_jupyter_available, is_jupytext_available, is_notebook_available
if t.TYPE_CHECKING:
from openllm_core._typing_compat import DictStrAny
logger = logging.getLogger(__name__)
def load_notebook_metadata() -> DictStrAny:
with open(os.path.join(os.path.dirname(playground.__file__), '_meta.yml'), 'r') as f:
content = yaml.safe_load(f)
if not all('description' in k for k in content.values()):
raise ValueError("Invalid metadata file. All entries must have a 'description' key.")
return content
@click.command('playground', context_settings=termui.CONTEXT_SETTINGS)
@click.argument('output-dir', default=None, required=False)
@click.option(
'--port',
envvar='JUPYTER_PORT',
show_envvar=True,
show_default=True,
default=8888,
help='Default port for Jupyter server',
)
@click.pass_context
def cli(ctx: click.Context, output_dir: str | None, port: int) -> None:
"""OpenLLM Playground.
  A collection of notebooks to explore the capabilities of OpenLLM.
  This includes notebooks for fine-tuning, inference, and more.
  All of the scripts available in the playground can also be run directly as Python scripts:
For example:
\b
```bash
python -m openllm.playground.falcon_tuned --help
```
\b
> [!NOTE]
> This command requires Jupyter to be installed. Install it with 'pip install "openllm[playground]"'
"""
if not is_jupyter_available() or not is_jupytext_available() or not is_notebook_available():
raise RuntimeError(
"Playground requires 'jupyter', 'jupytext', and 'notebook'. Install it with 'pip install \"openllm[playground]\"'"
)
metadata = load_notebook_metadata()
_temp_dir = False
if output_dir is None:
_temp_dir = True
output_dir = tempfile.mkdtemp(prefix='openllm-playground-')
else:
os.makedirs(os.path.abspath(os.path.expandvars(os.path.expanduser(output_dir))), exist_ok=True)
termui.echo('The playground notebooks will be saved to: ' + os.path.abspath(output_dir), fg='blue')
for module in pkgutil.iter_modules(playground.__path__):
if module.ispkg or os.path.exists(os.path.join(output_dir, module.name + '.ipynb')):
logger.debug(
'Skipping: %s (%s)', module.name, 'File already exists' if not module.ispkg else f'{module.name} is a module'
)
continue
if not isinstance(module.module_finder, importlib.machinery.FileFinder):
continue
termui.echo('Generating notebook for: ' + module.name, fg='magenta')
markdown_cell = nbformat.v4.new_markdown_cell(metadata[module.name]['description'])
f = jupytext.read(os.path.join(module.module_finder.path, module.name + '.py'))
f.cells.insert(0, markdown_cell)
jupytext.write(f, os.path.join(output_dir, module.name + '.ipynb'), fmt='notebook')
try:
subprocess.check_output([
sys.executable,
'-m',
'jupyter',
'notebook',
'--notebook-dir',
output_dir,
'--port',
str(port),
'--no-browser',
'--debug',
])
except subprocess.CalledProcessError as e:
termui.echo(e.output, fg='red')
raise click.ClickException(f'Failed to start a jupyter server:\n{e}') from None
except KeyboardInterrupt:
termui.echo('\nShutting down Jupyter server...', fg='yellow')
if _temp_dir:
termui.echo('Note: You can access the generated notebooks in: ' + output_dir, fg='blue')
ctx.exit(0)

View File

@@ -1,14 +0,0 @@
This folder is a playground and the source of truth for feature
development around OpenLLM. Launch it with:
```bash
openllm playground
```
## Usage for developing this module
Write a Python script that `jupytext` can convert into a Jupyter
notebook via `openllm playground` (a minimal sketch is shown below).
Make sure to add the new file and its documentation into
[`_meta.yml`](./_meta.yml).
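A hypothetical playground script (the file name and model id are illustrative)
might look like the following sketch:

```python
from __future__ import annotations
import asyncio
import openllm

openllm.utils.configure_logging()

# jupytext converts this plain script into notebook cells.
llm = openllm.LLM('facebook/opt-1.3b')

async def main() -> None:
  # generate() is asynchronous, so run it inside an event loop
  res = await llm.generate('What is the meaning of life?')
  print(res)

asyncio.run(main())
```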

View File

@@ -1,36 +0,0 @@
features:
description: |
## General introduction to OpenLLM.
This script will demo a few features from OpenLLM:
- Usage of Auto class abstraction and run prediction with `generate`
- Ability to send per-requests parameters
- Runner integration with BentoML
opt_tuned:
description: |
## Fine tuning OPT
    This script demonstrates how one can easily fine-tune OPT
    with [LoRA](https://arxiv.org/abs/2106.09685) and in int8 with bitsandbytes.
    It is based on one of the PEFT example fine-tuning scripts.
    It requires at least one GPU to be available, so make sure you have one.
falcon_tuned:
description: |
## Fine tuning Falcon
    This script demonstrates how one can fine-tune Falcon using [QLoRA](https://arxiv.org/pdf/2305.14314.pdf)
    and [trl](https://github.com/lvwerra/trl).
    It is trained on OpenAssistant's Guanaco [dataset](https://huggingface.co/datasets/timdettmers/openassistant-guanaco).
    It requires at least one GPU to be available, so make sure you have one.
llama2_qlora:
description: |
    ## Fine tuning LLaMA 2
    This script demonstrates how one can fine-tune LLaMA 2 using QLoRA with [trl](https://github.com/lvwerra/trl).
    It is trained on the [Dolly dataset](https://huggingface.co/datasets/databricks/databricks-dolly-15k).
    It requires at least one GPU to be available.

View File

@@ -1,95 +0,0 @@
from __future__ import annotations
import dataclasses
import logging
import os
import sys
import typing as t
import torch
import transformers
import openllm
# Make sure to have at least one GPU to run this script
openllm.utils.configure_logging()
logger = logging.getLogger(__name__)
# In a notebook, make sure to install the following first:
# ! pip install -U "openllm[fine-tune] @ git+https://github.com/bentoml/OpenLLM.git"
from datasets import load_dataset
from trl import SFTTrainer
DEFAULT_MODEL_ID = 'ybelkada/falcon-7b-sharded-bf16'
DATASET_NAME = 'timdettmers/openassistant-guanaco'
@dataclasses.dataclass
class TrainingArguments:
per_device_train_batch_size: int = dataclasses.field(default=4)
gradient_accumulation_steps: int = dataclasses.field(default=4)
optim: str = dataclasses.field(default='paged_adamw_32bit')
save_steps: int = dataclasses.field(default=10)
warmup_steps: int = dataclasses.field(default=10)
max_steps: int = dataclasses.field(default=500)
logging_steps: int = dataclasses.field(default=10)
learning_rate: float = dataclasses.field(default=2e-4)
max_grad_norm: float = dataclasses.field(default=0.3)
warmup_ratio: float = dataclasses.field(default=0.03)
fp16: bool = dataclasses.field(default=True)
group_by_length: bool = dataclasses.field(default=True)
lr_scheduler_type: str = dataclasses.field(default='constant')
output_dir: str = dataclasses.field(default=os.path.join(os.getcwd(), 'outputs', 'falcon'))
@dataclasses.dataclass
class ModelArguments:
model_id: str = dataclasses.field(default=DEFAULT_MODEL_ID)
max_sequence_length: int = dataclasses.field(default=512)
parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith('.json'):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, training_args = t.cast(t.Tuple[ModelArguments, TrainingArguments], parser.parse_args_into_dataclasses())
llm = openllm.LLM(
model_args.model_id, quantize='int4', bnb_4bit_quant_type='nf4', bnb_4bit_compute_dtype=torch.float16
)
model, tokenizer = llm.prepare(
'lora',
lora_alpha=16,
lora_dropout=0.1,
r=16,
bias='none',
target_modules=['query_key_value', 'dense', 'dense_h_to_4h', 'dense_4h_to_h'],
)
model.config.use_cache = False
tokenizer.pad_token = tokenizer.eos_token
dataset = load_dataset(DATASET_NAME, split='train')
trainer = SFTTrainer(
model=model,
train_dataset=dataset,
dataset_text_field='text',
max_seq_length=model_args.max_sequence_length,
tokenizer=tokenizer,
args=dataclasses.replace(
transformers.TrainingArguments(training_args.output_dir), **dataclasses.asdict(training_args)
),
)
# upcast layernorm in float32 for more stable training
for name, module in trainer.model.named_modules():
if 'norm' in name:
module = module.to(torch.float32)
trainer.train()
trainer.model.save_pretrained(os.path.join(training_args.output_dir, 'lora'))

View File

@@ -1,44 +0,0 @@
from __future__ import annotations
import argparse
import asyncio
import logging
import typing as t
import openllm
openllm.utils.configure_logging()
logger = logging.getLogger(__name__)
MAX_NEW_TOKENS = 384
Q = 'Answer the following question, step by step:\n{q}\nA:'
question = 'What is the meaning of life?'
async def main() -> int:
parser = argparse.ArgumentParser()
parser.add_argument('question', default=question)
if openllm.utils.in_notebook():
args = parser.parse_args(args=[question])
else:
args = parser.parse_args()
llm = openllm.LLM[t.Any, t.Any]('facebook/opt-2.7b')
prompt = Q.format(q=args.question)
  logger.info('%s %s %s', '-' * 50, "Running with 'generate()'", '-' * 50)
  res = await llm.generate(prompt)
  logger.info('%s %s %s', '=' * 10, 'Response:', res)
  logger.info('%s %s %s', '-' * 50, "Running with 'generate()' with per-request arguments", '-' * 50)
  res = await llm.generate(prompt, max_new_tokens=MAX_NEW_TOKENS)
  logger.info('%s %s %s', '=' * 10, 'Response:', res)
return 0
def _mp_fn(index: t.Any): # type: ignore
# For xla_spawn (TPUs)
asyncio.run(main())

View File

@@ -1,236 +0,0 @@
from __future__ import annotations
import dataclasses
import logging
import os
import sys
import typing as t
import torch
import transformers
import openllm
if t.TYPE_CHECKING:
import peft
# Make sure to have at least one GPU to run this script
openllm.utils.configure_logging()
logger = logging.getLogger(__name__)
# In a notebook, make sure to install the following first:
# ! pip install -U "openllm[fine-tune] @ git+https://github.com/bentoml/OpenLLM.git"
from functools import partial
from itertools import chain
from random import randint, randrange
import bitsandbytes as bnb
from datasets import load_dataset
# COPIED FROM https://github.com/artidoro/qlora/blob/main/qlora.py
def find_all_linear_names(model):
lora_module_names = set()
for name, module in model.named_modules():
if isinstance(module, bnb.nn.Linear4bit):
names = name.split('.')
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
if 'lm_head' in lora_module_names: # needed for 16-bit
lora_module_names.remove('lm_head')
return list(lora_module_names)
# Change this to the local converted path if you don't have access to the meta-llama model
DEFAULT_MODEL_ID = 'meta-llama/Llama-2-7b-hf'
# change this to 'main' if you want to use the latest llama
DEFAULT_MODEL_VERSION = '335a02887eb6684d487240bbc28b5699298c3135'
DATASET_NAME = 'databricks/databricks-dolly-15k'
def format_dolly(sample):
instruction = f"### Instruction\n{sample['instruction']}"
context = f"### Context\n{sample['context']}" if len(sample['context']) > 0 else None
response = f"### Answer\n{sample['response']}"
# join all the parts together
prompt = '\n\n'.join([i for i in [instruction, context, response] if i is not None])
return prompt
# template dataset to add prompt to each sample
def template_dataset(sample, tokenizer):
sample['text'] = f'{format_dolly(sample)}{tokenizer.eos_token}'
return sample
# empty list to save remainder from batches to use in next batch
remainder = {'input_ids': [], 'attention_mask': [], 'token_type_ids': []}
def chunk(sample, chunk_length=2048):
# define global remainder variable to save remainder from batches to use in next batch
global remainder
# Concatenate all texts and add remainder from previous batch
concatenated_examples = {k: list(chain(*sample[k])) for k in sample.keys()}
concatenated_examples = {k: remainder[k] + concatenated_examples[k] for k in concatenated_examples.keys()}
# get total number of tokens for batch
batch_total_length = len(concatenated_examples[next(iter(sample.keys()))])
# get max number of chunks for batch
if batch_total_length >= chunk_length:
batch_chunk_length = (batch_total_length // chunk_length) * chunk_length
# Split by chunks of max_len.
result = {
k: [t[i : i + chunk_length] for i in range(0, batch_chunk_length, chunk_length)]
for k, t in concatenated_examples.items()
}
# add remainder to global variable for next batch
remainder = {k: concatenated_examples[k][batch_chunk_length:] for k in concatenated_examples.keys()}
# prepare labels
result['labels'] = result['input_ids'].copy()
return result
def prepare_datasets(tokenizer, dataset_name=DATASET_NAME):
# Load dataset from the hub
dataset = load_dataset(dataset_name, split='train')
print(f'dataset size: {len(dataset)}')
print(dataset[randrange(len(dataset))])
# apply prompt template per sample
dataset = dataset.map(partial(template_dataset, tokenizer=tokenizer), remove_columns=list(dataset.features))
# print random sample
  print('Sample from dolly-v2 ds:', dataset[randint(0, len(dataset) - 1)]['text'])
# tokenize and chunk dataset
lm_dataset = dataset.map(
lambda sample: tokenizer(sample['text']), batched=True, remove_columns=list(dataset.features)
).map(partial(chunk, chunk_length=2048), batched=True)
# Print total number of samples
print(f'Total number of samples: {len(lm_dataset)}')
return lm_dataset
def prepare_for_int4_training(
model_id: str, model_version: str | None = None, gradient_checkpointing: bool = True, bf16: bool = True
) -> tuple[peft.PeftModel, transformers.LlamaTokenizerFast]:
from peft.tuners.lora import LoraLayer
llm = openllm.LLM(
model_id,
revision=model_version,
quantize='int4',
bnb_4bit_compute_dtype=torch.bfloat16,
use_cache=not gradient_checkpointing,
device_map='auto',
)
print('Model summary:', llm.model)
# get lora target modules
modules = find_all_linear_names(llm.model)
print(f'Found {len(modules)} modules to quantize: {modules}')
model, tokenizer = llm.prepare('lora', use_gradient_checkpointing=gradient_checkpointing, target_modules=modules)
  # pre-process the model by upcasting the layer norms to float32 for more stable training
for name, module in model.named_modules():
if isinstance(module, LoraLayer):
if bf16:
module = module.to(torch.bfloat16)
if 'norm' in name:
module = module.to(torch.float32)
if 'lm_head' in name or 'embed_tokens' in name:
if hasattr(module, 'weight'):
if bf16 and module.weight.dtype == torch.float32:
module = module.to(torch.bfloat16)
return model, tokenizer
@dataclasses.dataclass
class TrainingArguments:
per_device_train_batch_size: int = dataclasses.field(default=1)
gradient_checkpointing: bool = dataclasses.field(default=True)
bf16: bool = dataclasses.field(default=torch.cuda.get_device_capability()[0] == 8)
learning_rate: float = dataclasses.field(default=5e-5)
num_train_epochs: int = dataclasses.field(default=3)
logging_steps: int = dataclasses.field(default=1)
report_to: str = dataclasses.field(default='none')
output_dir: str = dataclasses.field(default=os.path.join(os.getcwd(), 'outputs', 'llama'))
save_strategy: str = dataclasses.field(default='no')
@dataclasses.dataclass
class ModelArguments:
model_id: str = dataclasses.field(default=DEFAULT_MODEL_ID)
model_version: str = dataclasses.field(default=DEFAULT_MODEL_VERSION)
seed: int = dataclasses.field(default=42)
merge_weights: bool = dataclasses.field(default=False)
if openllm.utils.in_notebook():
  model_args, training_args = ModelArguments(), TrainingArguments()
else:
parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith('.json'):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, training_args = t.cast(
t.Tuple[ModelArguments, TrainingArguments], parser.parse_args_into_dataclasses()
)
# import the model first hand
openllm.import_model(model_id=model_args.model_id, model_version=model_args.model_version)
def train_loop(model_args: ModelArguments, training_args: TrainingArguments):
import peft
transformers.set_seed(model_args.seed)
model, tokenizer = prepare_for_int4_training(
model_args.model_id, gradient_checkpointing=training_args.gradient_checkpointing, bf16=training_args.bf16
)
datasets = prepare_datasets(tokenizer)
trainer = transformers.Trainer(
model=model,
args=dataclasses.replace(
transformers.TrainingArguments(training_args.output_dir), **dataclasses.asdict(training_args)
),
train_dataset=datasets,
data_collator=transformers.default_data_collator,
)
trainer.train()
if model_args.merge_weights:
    # note that this requires a larger GPU, as we will load the whole model into memory
# merge adapter weights with base model and save
# save int4 model
trainer.model.save_pretrained(training_args.output_dir, safe_serialization=False)
# gc mem
del model, trainer
torch.cuda.empty_cache()
model = peft.AutoPeftModelForCausalLM.from_pretrained(
training_args.output_dir, low_cpu_mem_usage=True, torch_dtype=torch.float16
)
# merge lora with base weights and save
model = model.merge_and_unload()
model.save_pretrained(
os.path.join(os.getcwd(), 'outputs', 'merged_llama_lora'), safe_serialization=True, max_shard_size='2GB'
)
else:
trainer.model.save_pretrained(os.path.join(training_args.output_dir, 'lora'))
train_loop(model_args, training_args)

View File

@@ -1,81 +0,0 @@
from __future__ import annotations
import dataclasses
import logging
import os
import sys
import typing as t
import transformers
import openllm
# Make sure to have at least one GPU to run this script
openllm.utils.configure_logging()
logger = logging.getLogger(__name__)
# In a notebook, make sure to install the following first:
# ! pip install -U "openllm[fine-tune] @ git+https://github.com/bentoml/OpenLLM.git"
from datasets import load_dataset
if t.TYPE_CHECKING:
from peft import PeftModel
DEFAULT_MODEL_ID = 'facebook/opt-6.7b'
def load_trainer(
model: PeftModel, tokenizer: transformers.GPT2TokenizerFast, dataset_dict: t.Any, training_args: TrainingArguments
):
return transformers.Trainer(
model=model,
train_dataset=dataset_dict['train'],
args=dataclasses.replace(
transformers.TrainingArguments(training_args.output_dir), **dataclasses.asdict(training_args)
),
data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
@dataclasses.dataclass
class TrainingArguments:
per_device_train_batch_size: int = dataclasses.field(default=4)
gradient_accumulation_steps: int = dataclasses.field(default=4)
warmup_steps: int = dataclasses.field(default=10)
max_steps: int = dataclasses.field(default=50)
learning_rate: float = dataclasses.field(default=3e-4)
fp16: bool = dataclasses.field(default=True)
logging_steps: int = dataclasses.field(default=1)
output_dir: str = dataclasses.field(default=os.path.join(os.getcwd(), 'outputs', 'opt'))
@dataclasses.dataclass
class ModelArguments:
model_id: str = dataclasses.field(default=DEFAULT_MODEL_ID)
parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith('.json'):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, training_args = t.cast(t.Tuple[ModelArguments, TrainingArguments], parser.parse_args_into_dataclasses())
llm = openllm.LLM(model_args.model_id, quantize='int8')
model, tokenizer = llm.prepare(
'lora', r=16, lora_alpha=32, target_modules=['q_proj', 'v_proj'], lora_dropout=0.05, bias='none'
)
# ft on english_quotes
data = load_dataset('Abirate/english_quotes')
data = data.map(lambda samples: tokenizer(samples['quote']), batched=True)
trainer = load_trainer(model, tokenizer, data, training_args)
model.config.use_cache = False  # silence warnings for now; re-enable cache for inference later
trainer.train()
trainer.model.save_pretrained(os.path.join(training_args.output_dir, 'lora'))

View File

@@ -1,90 +1,9 @@
from __future__ import annotations
import enum
import functools
import logging
import os
import typing as t
import click
import inflection
import orjson
from openllm_core._typing_compat import DictStrAny, TypedDict
from openllm_core.utils import get_debug_mode
logger = logging.getLogger('openllm')
from _openllm_tiny import _termui
class Level(enum.IntEnum):
NOTSET = logging.DEBUG
DEBUG = logging.DEBUG
INFO = logging.INFO
WARNING = logging.WARNING
ERROR = logging.ERROR
CRITICAL = logging.CRITICAL
@property
def color(self) -> str | None:
return {
Level.NOTSET: None,
Level.DEBUG: 'cyan',
Level.INFO: 'green',
Level.WARNING: 'yellow',
Level.ERROR: 'red',
Level.CRITICAL: 'red',
}[self]
@classmethod
def from_logging_level(cls, level: int) -> Level:
return {
logging.DEBUG: Level.DEBUG,
logging.INFO: Level.INFO,
logging.WARNING: Level.WARNING,
logging.ERROR: Level.ERROR,
logging.CRITICAL: Level.CRITICAL,
}[level]
def __dir__():
return dir(_termui)
class JsonLog(TypedDict):
log_level: Level
content: str
def log(content: str, level: Level = Level.INFO, fg: str | None = None) -> None:
if get_debug_mode():
echo(content, fg=fg)
else:
echo(orjson.dumps(JsonLog(log_level=level, content=content)).decode(), fg=fg, json=True)
warning = functools.partial(log, level=Level.WARNING)
error = functools.partial(log, level=Level.ERROR)
critical = functools.partial(log, level=Level.CRITICAL)
debug = functools.partial(log, level=Level.DEBUG)
info = functools.partial(log, level=Level.INFO)
notset = functools.partial(log, level=Level.NOTSET)
def echo(text: t.Any, fg: str | None = None, *, _with_style: bool = True, json: bool = False, **attrs: t.Any) -> None:
if json:
text = orjson.loads(text)
if 'content' in text and 'log_level' in text:
content = text['content']
fg = Level.from_logging_level(text['log_level']).color
else:
content = orjson.dumps(text).decode()
fg = Level.INFO.color if not get_debug_mode() else Level.DEBUG.color
else:
content = t.cast(str, text)
attrs['fg'] = fg
(click.echo if not _with_style else click.secho)(content, **attrs)
COLUMNS: int = int(os.environ.get('COLUMNS', str(120)))
CONTEXT_SETTINGS: DictStrAny = {
'help_option_names': ['-h', '--help'],
'max_content_width': COLUMNS,
'token_normalize_func': inflection.underscore,
}
__all__ = ['COLUMNS', 'CONTEXT_SETTINGS', 'Level', 'critical', 'debug', 'echo', 'error', 'info', 'log', 'warning']
def __getattr__(name):
return getattr(_termui, name)