mirror of
https://github.com/bentoml/OpenLLM.git
synced 2025-12-23 23:57:46 -05:00
infra: using ruff formatter (#594)
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
@@ -7,21 +7,8 @@ default_language_version:
|
||||
python: python3.9 # NOTE: sync with .python-version-default
|
||||
exclude: '.*\.(css|js|svg)$'
|
||||
repos:
|
||||
- repo: https://github.com/google/yapf
|
||||
rev: v0.40.2
|
||||
hooks:
|
||||
- id: yapf
|
||||
alias: f
|
||||
verbose: true
|
||||
exclude: |
|
||||
(?x)^(
|
||||
examples/.*|
|
||||
bench.py |
|
||||
cz.py |
|
||||
openllm-client/src/openllm_client/pb.*
|
||||
)$
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: 'v0.1.4'
|
||||
rev: 'v0.1.5'
|
||||
hooks:
|
||||
- id: ruff
|
||||
alias: r
|
||||
@@ -30,7 +17,6 @@ repos:
|
||||
- id: ruff-format
|
||||
alias: rf
|
||||
verbose: true
|
||||
types: [pyi]
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: mypy
|
||||
@@ -45,6 +31,10 @@ repos:
|
||||
- id: editorconfig-checker
|
||||
verbose: true
|
||||
alias: ec
|
||||
exclude: |
|
||||
(?x)^(
|
||||
openllm-python/src/openllm/cli/entrypoint.py
|
||||
)$
|
||||
- repo: meta
|
||||
hooks:
|
||||
- id: check-hooks-apply
|
||||
|
||||
16
.style.yapf
16
.style.yapf
@@ -1,16 +0,0 @@
|
||||
[style]
|
||||
BASED_ON_STYLE = google
|
||||
INDENT_WIDTH = 2
|
||||
JOIN_MULTIPLE_LINES = true
|
||||
COLUMN_LIMIT = 192
|
||||
USE_TABS = false
|
||||
BLANK_LINES_AROUND_TOP_LEVEL_DEFINITION = 1
|
||||
BLANK_LINES_BETWEEN_TOP_LEVEL_IMPORTS_AND_VARIABLES = 1
|
||||
DISABLE_ENDING_COMMA_HEURISTIC = true
|
||||
BLANK_LINE_BEFORE_CLASS_DOCSTRING = false
|
||||
BLANK_LINE_BEFORE_MODULE_DOCSTRING = false
|
||||
BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = false
|
||||
ALIGN_CLOSING_BRACKET_WITH_VISUAL_INDENT = true
|
||||
ALLOW_MULTILINE_DICTIONARY_KEYS = false
|
||||
ALLOW_SPLIT_BEFORE_DICT_VALUE = false
|
||||
COALESCE_BRACKETS = true
|
||||
@@ -1,4 +0,0 @@
|
||||
openllm-python/src/openllm/playground/
|
||||
openllm-python/src/openllm/models/__init__.py
|
||||
openllm-client/src/openllm_client/pb/**
|
||||
examples/
|
||||
13
bench.py
13
bench.py
@@ -8,6 +8,7 @@ import aiohttp
|
||||
|
||||
import openllm
|
||||
|
||||
|
||||
async def send_request(url, it, prompt, session, model, **attrs):
|
||||
headers = {'accept': 'application/json', 'Content-Type': 'application/json'}
|
||||
config = openllm.AutoConfig.for_model(model).model_construct_env(**attrs).model_dump()
|
||||
@@ -223,12 +224,20 @@ async def main(args: argparse.Namespace) -> int:
|
||||
'Write a letter to your future self, offering reflections on personal growth, achievements, and aspirations, as well as words of encouragement and guidance for your future journey.',
|
||||
]
|
||||
async with aiohttp.ClientSession() as session:
|
||||
await asyncio.gather(*[send_request(url, it, prompt, session, args.model, max_new_tokens=2048) for it, prompt in enumerate(prompts)])
|
||||
await asyncio.gather(
|
||||
*[send_request(url, it, prompt, session, args.model, max_new_tokens=2048) for it, prompt in enumerate(prompts)]
|
||||
)
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--generate', default=False, action='store_true', help='Whether to test with stream endpoint.')
|
||||
parser.add_argument('--model', default='llama', choices=openllm.CONFIG_MAPPING_NAMES.keys(), action='store', help='Whether to test with stream endpoint.')
|
||||
parser.add_argument(
|
||||
'--model',
|
||||
default='llama',
|
||||
choices=openllm.CONFIG_MAPPING_NAMES.keys(),
|
||||
action='store',
|
||||
help='Whether to test with stream endpoint.',
|
||||
)
|
||||
raise SystemExit(asyncio.run(main(parser.parse_args())))
|
||||
|
||||
16
cz.py
16
cz.py
@@ -7,6 +7,7 @@ import tokenize
|
||||
|
||||
from tabulate import tabulate
|
||||
|
||||
|
||||
TOKEN_WHITELIST = [token.OP, token.NAME, token.NUMBER, token.STRING]
|
||||
|
||||
|
||||
@@ -21,12 +22,23 @@ def run_cz(dir: str, package: str):
|
||||
with tokenize.open(filepath) as file_:
|
||||
tokens = [t for t in tokenize.generate_tokens(file_.readline) if t.type in TOKEN_WHITELIST]
|
||||
token_count, line_count = len(tokens), len(set([t.start[0] for t in tokens]))
|
||||
table.append([filepath.replace(os.path.join(dir, 'src'), ''), line_count, token_count / line_count if line_count != 0 else 0])
|
||||
table.append(
|
||||
[
|
||||
filepath.replace(os.path.join(dir, 'src'), ''),
|
||||
line_count,
|
||||
token_count / line_count if line_count != 0 else 0,
|
||||
]
|
||||
)
|
||||
print(f'\n{"=" * 80}\n')
|
||||
print(tabulate([headers, *sorted(table, key=lambda x: -x[1])], headers='firstrow', floatfmt='.1f') + '\n')
|
||||
print(
|
||||
tabulate(
|
||||
[(dir_name, sum([x[1] for x in group])) for dir_name, group in itertools.groupby(sorted([(x[0].rsplit('/', 1)[0], x[1]) for x in table]), key=lambda x: x[0])],
|
||||
[
|
||||
(dir_name, sum([x[1] for x in group]))
|
||||
for dir_name, group in itertools.groupby(
|
||||
sorted([(x[0].rsplit('/', 1)[0], x[1]) for x in table]), key=lambda x: x[0]
|
||||
)
|
||||
],
|
||||
headers=['Directory', 'LOC'],
|
||||
floatfmt='.1f',
|
||||
)
|
||||
|
||||
@@ -46,10 +46,15 @@ chain = LLMChain(llm=llm, prompt=prompt)
|
||||
svc = bentoml.Service('fb-ads-copy', runners=[llm.runner])
|
||||
|
||||
SAMPLE_INPUT = Query(
|
||||
industry='SAAS', product_name='BentoML', keywords=['open source', 'developer tool', 'AI application platform', 'serverless', 'cost-efficient'], llm_config=llm.runner.config.model_dump()
|
||||
industry='SAAS',
|
||||
product_name='BentoML',
|
||||
keywords=['open source', 'developer tool', 'AI application platform', 'serverless', 'cost-efficient'],
|
||||
llm_config=llm.runner.config.model_dump(),
|
||||
)
|
||||
|
||||
|
||||
@svc.api(input=JSON.from_sample(sample=SAMPLE_INPUT), output=Text())
|
||||
def generate(query: Query):
|
||||
return chain.run({'industry': query.industry, 'product_name': query.product_name, 'keywords': ', '.join(query.keywords)})
|
||||
return chain.run(
|
||||
{'industry': query.industry, 'product_name': query.product_name, 'keywords': ', '.join(query.keywords)}
|
||||
)
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
# NOTE: Make sure to install openai>1
|
||||
import os, openai
|
||||
from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam, ChatCompletionAssistantMessageParam
|
||||
from openai.types.chat import (
|
||||
ChatCompletionSystemMessageParam,
|
||||
ChatCompletionUserMessageParam,
|
||||
ChatCompletionAssistantMessageParam,
|
||||
)
|
||||
|
||||
client = openai.OpenAI(base_url=os.getenv('OPENLLM_ENDPOINT', 'http://localhost:3000') + '/v1', api_key='na')
|
||||
|
||||
|
||||
@@ -8,7 +8,9 @@ model = models.data[0].id
|
||||
|
||||
# Completion API
|
||||
stream = str(os.getenv('STREAM', False)).upper() in ['TRUE', '1', 'YES', 'Y', 'ON']
|
||||
completions = client.completions.create(prompt='Write me a tag line for an ice cream shop.', model=model, max_tokens=64, stream=stream)
|
||||
completions = client.completions.create(
|
||||
prompt='Write me a tag line for an ice cream shop.', model=model, max_tokens=64, stream=stream
|
||||
)
|
||||
|
||||
print(f'Completion result (stream={stream}):')
|
||||
if stream:
|
||||
|
||||
@@ -17,13 +17,21 @@ class HTTPClient:
|
||||
address: str
|
||||
client_args: Dict[str, Any]
|
||||
@staticmethod
|
||||
def wait_until_server_ready(addr: str, timeout: float = ..., verify: bool = ..., check_interval: int = ..., **client_args: Any) -> None: ...
|
||||
def wait_until_server_ready(
|
||||
addr: str, timeout: float = ..., verify: bool = ..., check_interval: int = ..., **client_args: Any
|
||||
) -> None: ...
|
||||
@overload
|
||||
def __init__(self, address: str, timeout: int = ..., verify: bool = ..., api_version: str = ..., **client_args: Any) -> None: ...
|
||||
def __init__(
|
||||
self, address: str, timeout: int = ..., verify: bool = ..., api_version: str = ..., **client_args: Any
|
||||
) -> None: ...
|
||||
@overload
|
||||
def __init__(self, address: str = ..., timeout: int = ..., verify: bool = ..., api_version: str = ..., **client_args: Any) -> None: ...
|
||||
def __init__(
|
||||
self, address: str = ..., timeout: int = ..., verify: bool = ..., api_version: str = ..., **client_args: Any
|
||||
) -> None: ...
|
||||
@overload
|
||||
def __init__(self, address: None = ..., timeout: int = ..., verify: bool = ..., api_version: str = ..., **client_args: Any) -> None: ...
|
||||
def __init__(
|
||||
self, address: None = ..., timeout: int = ..., verify: bool = ..., api_version: str = ..., **client_args: Any
|
||||
) -> None: ...
|
||||
@property
|
||||
def is_ready(self) -> bool: ...
|
||||
def health(self) -> None: ...
|
||||
@@ -54,13 +62,21 @@ class AsyncHTTPClient:
|
||||
address: str
|
||||
client_args: Dict[str, Any]
|
||||
@staticmethod
|
||||
async def wait_until_server_ready(addr: str, timeout: float = ..., verify: bool = ..., check_interval: int = ..., **client_args: Any) -> None: ...
|
||||
async def wait_until_server_ready(
|
||||
addr: str, timeout: float = ..., verify: bool = ..., check_interval: int = ..., **client_args: Any
|
||||
) -> None: ...
|
||||
@overload
|
||||
def __init__(self, address: str, timeout: int = ..., verify: bool = ..., api_version: str = ..., **client_args: Any) -> None: ...
|
||||
def __init__(
|
||||
self, address: str, timeout: int = ..., verify: bool = ..., api_version: str = ..., **client_args: Any
|
||||
) -> None: ...
|
||||
@overload
|
||||
def __init__(self, address: str = ..., timeout: int = ..., verify: bool = ..., api_version: str = ..., **client_args: Any) -> None: ...
|
||||
def __init__(
|
||||
self, address: str = ..., timeout: int = ..., verify: bool = ..., api_version: str = ..., **client_args: Any
|
||||
) -> None: ...
|
||||
@overload
|
||||
def __init__(self, address: None = ..., timeout: int = ..., verify: bool = ..., api_version: str = ..., **client_args: Any) -> None: ...
|
||||
def __init__(
|
||||
self, address: None = ..., timeout: int = ..., verify: bool = ..., api_version: str = ..., **client_args: Any
|
||||
) -> None: ...
|
||||
@property
|
||||
def is_ready(self) -> bool: ...
|
||||
async def health(self) -> None: ...
|
||||
|
||||
@@ -17,21 +17,29 @@ from ._schemas import Request
|
||||
from ._schemas import Response
|
||||
from ._schemas import StreamingResponse
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _address_validator(_, attr, value):
|
||||
if not isinstance(value, str): raise TypeError(f'{attr.name} must be a string')
|
||||
if not urlparse(value).netloc: raise ValueError(f'{attr.name} must be a valid URL')
|
||||
if not isinstance(value, str):
|
||||
raise TypeError(f'{attr.name} must be a string')
|
||||
if not urlparse(value).netloc:
|
||||
raise ValueError(f'{attr.name} must be a valid URL')
|
||||
|
||||
|
||||
def _address_converter(addr: str):
|
||||
return addr if '://' in addr else 'http://' + addr
|
||||
|
||||
|
||||
class ServerState(enum.Enum):
|
||||
CLOSED = 1 # CLOSED: The server is not yet ready or `wait_until_server_ready` has not been called/failed.
|
||||
READY = 2 # READY: The server is ready and `wait_until_server_ready` has been called.
|
||||
|
||||
|
||||
_object_setattr = object.__setattr__
|
||||
|
||||
|
||||
@attr.define(init=False)
|
||||
class HTTPClient:
|
||||
address: str = attr.field(validator=_address_validator, converter=_address_converter)
|
||||
@@ -47,7 +55,9 @@ class HTTPClient:
|
||||
__config: dict[str, t.Any] | None = None
|
||||
|
||||
def __repr__(self):
|
||||
return f'<HTTPClient timeout={self._timeout} api_version={self._api_version} verify={self._verify} state={self._state}>'
|
||||
return (
|
||||
f'<HTTPClient timeout={self._timeout} api_version={self._api_version} verify={self._verify} state={self._state}>'
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def wait_until_server_ready(addr, timeout=30, verify=False, check_interval=1, **client_args):
|
||||
@@ -58,8 +68,10 @@ class HTTPClient:
|
||||
try:
|
||||
with httpx.Client(base_url=addr, verify=verify, **client_args) as sess:
|
||||
status = sess.get('/readyz').status_code
|
||||
if status == 200: break
|
||||
else: time.sleep(check_interval)
|
||||
if status == 200:
|
||||
break
|
||||
else:
|
||||
time.sleep(check_interval)
|
||||
except (httpx.ConnectError, urllib.error.URLError, ConnectionError):
|
||||
logger.debug('Server is not ready yet, retrying in %d seconds...', check_interval)
|
||||
time.sleep(check_interval)
|
||||
@@ -75,12 +87,21 @@ class HTTPClient:
|
||||
def __init__(self, address=None, timeout=30, verify=False, api_version='v1', **client_args):
|
||||
if address is None:
|
||||
env = os.getenv('OPENLLM_ENDPOINT')
|
||||
if env is None: raise ValueError('address must be provided')
|
||||
if env is None:
|
||||
raise ValueError('address must be provided')
|
||||
address = env
|
||||
self.__attrs_init__(address, client_args, httpx.Client(base_url=address, timeout=timeout, verify=verify, **client_args), timeout, api_version, verify)
|
||||
self.__attrs_init__(
|
||||
address,
|
||||
client_args,
|
||||
httpx.Client(base_url=address, timeout=timeout, verify=verify, **client_args),
|
||||
timeout,
|
||||
api_version,
|
||||
verify,
|
||||
)
|
||||
|
||||
def _metadata(self) -> dict[str, t.Any]:
|
||||
if self.__metadata is None: self.__metadata = self._inner.post(self._build_endpoint('metadata')).json()
|
||||
if self.__metadata is None:
|
||||
self.__metadata = self._inner.post(self._build_endpoint('metadata')).json()
|
||||
return self.__metadata
|
||||
|
||||
def _config(self) -> dict[str, t.Any]:
|
||||
@@ -111,41 +132,62 @@ class HTTPClient:
|
||||
logger.error('Server is not healthy (Scroll up for traceback)\n%s', e)
|
||||
_object_setattr(self, '_state', ServerState.CLOSED)
|
||||
|
||||
def generate(self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs) -> Response:
|
||||
def generate(
|
||||
self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs
|
||||
) -> Response:
|
||||
if not self.is_ready:
|
||||
self.health()
|
||||
if not self.is_ready: raise RuntimeError('Server is not ready. Check server logs for more information.')
|
||||
if timeout is None: timeout = self._timeout
|
||||
if verify is None: verify = self._verify
|
||||
if not self.is_ready:
|
||||
raise RuntimeError('Server is not ready. Check server logs for more information.')
|
||||
if timeout is None:
|
||||
timeout = self._timeout
|
||||
if verify is None:
|
||||
verify = self._verify
|
||||
_meta, _config = self._metadata(), self._config()
|
||||
if llm_config is not None: llm_config = {**_config, **llm_config, **attrs}
|
||||
else: llm_config = {**_config, **attrs}
|
||||
if _meta['prompt_template'] is not None: prompt = _meta['prompt_template'].format(system_message=_meta['system_message'], instruction=prompt)
|
||||
if llm_config is not None:
|
||||
llm_config = {**_config, **llm_config, **attrs}
|
||||
else:
|
||||
llm_config = {**_config, **attrs}
|
||||
if _meta['prompt_template'] is not None:
|
||||
prompt = _meta['prompt_template'].format(system_message=_meta['system_message'], instruction=prompt)
|
||||
|
||||
req = Request(prompt=prompt, llm_config=llm_config, stop=stop, adapter_name=adapter_name)
|
||||
with httpx.Client(base_url=self.address, timeout=timeout, verify=verify, **self.client_args) as client:
|
||||
r = client.post(self._build_endpoint('generate'), json=req.model_dump_json(), **self.client_args)
|
||||
if r.status_code != 200: raise ValueError("Failed to get generation from '/v1/generate'. Check server logs for more details.")
|
||||
if r.status_code != 200:
|
||||
raise ValueError("Failed to get generation from '/v1/generate'. Check server logs for more details.")
|
||||
return Response.model_construct(r.json())
|
||||
|
||||
def generate_stream(self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs) -> t.Iterator[StreamingResponse]:
|
||||
def generate_stream(
|
||||
self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs
|
||||
) -> t.Iterator[StreamingResponse]:
|
||||
if not self.is_ready:
|
||||
self.health()
|
||||
if not self.is_ready: raise RuntimeError('Server is not ready. Check server logs for more information.')
|
||||
if timeout is None: timeout = self._timeout
|
||||
if verify is None: verify = self._verify
|
||||
if not self.is_ready:
|
||||
raise RuntimeError('Server is not ready. Check server logs for more information.')
|
||||
if timeout is None:
|
||||
timeout = self._timeout
|
||||
if verify is None:
|
||||
verify = self._verify
|
||||
_meta, _config = self._metadata(), self._config()
|
||||
if llm_config is not None: llm_config = {**_config, **llm_config, **attrs}
|
||||
else: llm_config = {**_config, **attrs}
|
||||
if _meta['prompt_template'] is not None: prompt = _meta['prompt_template'].format(system_message=_meta['system_message'], instruction=prompt)
|
||||
if llm_config is not None:
|
||||
llm_config = {**_config, **llm_config, **attrs}
|
||||
else:
|
||||
llm_config = {**_config, **attrs}
|
||||
if _meta['prompt_template'] is not None:
|
||||
prompt = _meta['prompt_template'].format(system_message=_meta['system_message'], instruction=prompt)
|
||||
|
||||
req = Request(prompt=prompt, llm_config=llm_config, stop=stop, adapter_name=adapter_name)
|
||||
with httpx.Client(base_url=self.address, timeout=timeout, verify=verify, **self.client_args) as client:
|
||||
with client.stream('POST', self._build_endpoint('generate_stream'), json=req.model_dump_json(), **self.client_args) as r:
|
||||
with client.stream(
|
||||
'POST', self._build_endpoint('generate_stream'), json=req.model_dump_json(), **self.client_args
|
||||
) as r:
|
||||
for payload in r.iter_bytes():
|
||||
if payload == b'data: [DONE]\n\n': break
|
||||
if payload == b'data: [DONE]\n\n':
|
||||
break
|
||||
# Skip line
|
||||
if payload == b'\n': continue
|
||||
if payload == b'\n':
|
||||
continue
|
||||
if payload.startswith(b'data: '):
|
||||
try:
|
||||
proc = payload.decode('utf-8').lstrip('data: ').rstrip('\n')
|
||||
@@ -154,6 +196,7 @@ class HTTPClient:
|
||||
except Exception:
|
||||
pass # FIXME: Handle this
|
||||
|
||||
|
||||
@attr.define(init=False)
|
||||
class AsyncHTTPClient:
|
||||
address: str = attr.field(validator=_address_validator, converter=_address_converter)
|
||||
@@ -180,8 +223,10 @@ class AsyncHTTPClient:
|
||||
try:
|
||||
async with httpx.AsyncClient(base_url=addr, verify=verify, **client_args) as sess:
|
||||
status = (await sess.get('/readyz')).status_code
|
||||
if status == 200: break
|
||||
else: await asyncio.sleep(check_interval)
|
||||
if status == 200:
|
||||
break
|
||||
else:
|
||||
await asyncio.sleep(check_interval)
|
||||
except (httpx.ConnectError, urllib.error.URLError, ConnectionError):
|
||||
logger.debug('Server is not ready yet, retrying in %d seconds...', check_interval)
|
||||
await asyncio.sleep(check_interval)
|
||||
@@ -197,12 +242,21 @@ class AsyncHTTPClient:
|
||||
def __init__(self, address=None, timeout=30, verify=False, api_version='v1', **client_args):
|
||||
if address is None:
|
||||
env = os.getenv('OPENLLM_ENDPOINT')
|
||||
if env is None: raise ValueError('address must be provided')
|
||||
if env is None:
|
||||
raise ValueError('address must be provided')
|
||||
address = env
|
||||
self.__attrs_init__(address, client_args, httpx.AsyncClient(base_url=address, timeout=timeout, verify=verify, **client_args), timeout, api_version, verify)
|
||||
self.__attrs_init__(
|
||||
address,
|
||||
client_args,
|
||||
httpx.AsyncClient(base_url=address, timeout=timeout, verify=verify, **client_args),
|
||||
timeout,
|
||||
api_version,
|
||||
verify,
|
||||
)
|
||||
|
||||
async def _metadata(self) -> dict[str, t.Any]:
|
||||
if self.__metadata is None: self.__metadata = (await self._inner.post(self._build_endpoint('metadata'))).json()
|
||||
if self.__metadata is None:
|
||||
self.__metadata = (await self._inner.post(self._build_endpoint('metadata'))).json()
|
||||
return self.__metadata
|
||||
|
||||
async def _config(self) -> dict[str, t.Any]:
|
||||
@@ -230,41 +284,62 @@ class AsyncHTTPClient:
|
||||
logger.error('Server is not healthy (Scroll up for traceback)\n%s', e)
|
||||
_object_setattr(self, '_state', ServerState.CLOSED)
|
||||
|
||||
async def generate(self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs) -> Response:
|
||||
async def generate(
|
||||
self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs
|
||||
) -> Response:
|
||||
if not self.is_ready:
|
||||
await self.health()
|
||||
if not self.is_ready: raise RuntimeError('Server is not ready. Check server logs for more information.')
|
||||
if timeout is None: timeout = self._timeout
|
||||
if verify is None: verify = self._verify
|
||||
if not self.is_ready:
|
||||
raise RuntimeError('Server is not ready. Check server logs for more information.')
|
||||
if timeout is None:
|
||||
timeout = self._timeout
|
||||
if verify is None:
|
||||
verify = self._verify
|
||||
_meta, _config = await self._metadata(), await self._config()
|
||||
if llm_config is not None: llm_config = {**_config, **llm_config, **attrs}
|
||||
else: llm_config = {**_config, **attrs}
|
||||
if _meta['prompt_template'] is not None: prompt = _meta['prompt_template'].format(system_message=_meta['system_message'], instruction=prompt)
|
||||
if llm_config is not None:
|
||||
llm_config = {**_config, **llm_config, **attrs}
|
||||
else:
|
||||
llm_config = {**_config, **attrs}
|
||||
if _meta['prompt_template'] is not None:
|
||||
prompt = _meta['prompt_template'].format(system_message=_meta['system_message'], instruction=prompt)
|
||||
|
||||
req = Request(prompt=prompt, llm_config=llm_config, stop=stop, adapter_name=adapter_name)
|
||||
async with httpx.AsyncClient(base_url=self.address, timeout=timeout, verify=verify, **self.client_args) as client:
|
||||
r = await client.post(self._build_endpoint('generate'), json=req.model_dump_json(), **self.client_args)
|
||||
if r.status_code != 200: raise ValueError("Failed to get generation from '/v1/generate'. Check server logs for more details.")
|
||||
if r.status_code != 200:
|
||||
raise ValueError("Failed to get generation from '/v1/generate'. Check server logs for more details.")
|
||||
return Response.model_construct(r.json())
|
||||
|
||||
async def generate_stream(self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs) -> t.AsyncGenerator[StreamingResponse, t.Any]:
|
||||
async def generate_stream(
|
||||
self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs
|
||||
) -> t.AsyncGenerator[StreamingResponse, t.Any]:
|
||||
if not self.is_ready:
|
||||
await self.health()
|
||||
if not self.is_ready: raise RuntimeError('Server is not ready. Check server logs for more information.')
|
||||
if timeout is None: timeout = self._timeout
|
||||
if verify is None: verify = self._verify
|
||||
if not self.is_ready:
|
||||
raise RuntimeError('Server is not ready. Check server logs for more information.')
|
||||
if timeout is None:
|
||||
timeout = self._timeout
|
||||
if verify is None:
|
||||
verify = self._verify
|
||||
_meta, _config = await self._metadata(), await self._config()
|
||||
if llm_config is not None: llm_config = {**_config, **llm_config, **attrs}
|
||||
else: llm_config = {**_config, **attrs}
|
||||
if _meta['prompt_template'] is not None: prompt = _meta['prompt_template'].format(system_message=_meta['system_message'], instruction=prompt)
|
||||
if llm_config is not None:
|
||||
llm_config = {**_config, **llm_config, **attrs}
|
||||
else:
|
||||
llm_config = {**_config, **attrs}
|
||||
if _meta['prompt_template'] is not None:
|
||||
prompt = _meta['prompt_template'].format(system_message=_meta['system_message'], instruction=prompt)
|
||||
|
||||
req = Request(prompt=prompt, llm_config=llm_config, stop=stop, adapter_name=adapter_name)
|
||||
async with httpx.AsyncClient(base_url=self.address, timeout=timeout, verify=verify, **self.client_args) as client:
|
||||
async with client.stream('POST', self._build_endpoint('generate_stream'), json=req.model_dump_json(), **self.client_args) as r:
|
||||
async with client.stream(
|
||||
'POST', self._build_endpoint('generate_stream'), json=req.model_dump_json(), **self.client_args
|
||||
) as r:
|
||||
async for payload in r.aiter_bytes():
|
||||
if payload == b'data: [DONE]\n\n': break
|
||||
if payload == b'data: [DONE]\n\n':
|
||||
break
|
||||
# Skip line
|
||||
if payload == b'\n': continue
|
||||
if payload == b'\n':
|
||||
continue
|
||||
if payload.startswith(b'data: '):
|
||||
try:
|
||||
proc = payload.decode('utf-8').lstrip('data: ').rstrip('\n')
|
||||
|
||||
@@ -4,6 +4,7 @@ import typing as t
|
||||
import attr
|
||||
import cattr
|
||||
|
||||
|
||||
# XXX: sync with openllm-core/src/openllm_core/_schemas.py
|
||||
@attr.define
|
||||
class Request:
|
||||
@@ -19,10 +20,12 @@ class Request:
|
||||
def model_construct(cls, data: t.Dict[str, t.Any]) -> Request:
|
||||
return cattr.structure(data, cls)
|
||||
|
||||
|
||||
SampleLogprobs = t.List[t.Dict[int, float]]
|
||||
PromptLogprobs = t.List[t.Optional[t.Dict[int, float]]]
|
||||
FinishReason = t.Literal['length', 'stop']
|
||||
|
||||
|
||||
@attr.define
|
||||
class CompletionChunk:
|
||||
index: int
|
||||
@@ -32,6 +35,7 @@ class CompletionChunk:
|
||||
logprobs: t.Optional[SampleLogprobs] = None
|
||||
finish_reason: t.Optional[FinishReason] = None
|
||||
|
||||
|
||||
@attr.define
|
||||
class Response:
|
||||
prompt: str
|
||||
@@ -48,6 +52,7 @@ class Response:
|
||||
def model_construct(cls, data: t.Dict[str, t.Any]) -> Response:
|
||||
return cattr.structure(data, cls)
|
||||
|
||||
|
||||
@attr.define
|
||||
class StreamingResponse:
|
||||
request_id: str
|
||||
@@ -57,7 +62,12 @@ class StreamingResponse:
|
||||
|
||||
@classmethod
|
||||
def from_response_chunk(cls, response: Response) -> StreamingResponse:
|
||||
return cls(request_id=response.request_id, index=response.outputs[0].index, text=response.outputs[0].text, token_ids=response.outputs[0].token_ids[0])
|
||||
return cls(
|
||||
request_id=response.request_id,
|
||||
index=response.outputs[0].index,
|
||||
text=response.outputs[0].text,
|
||||
token_ids=response.outputs[0].token_ids[0],
|
||||
)
|
||||
|
||||
def model_dump_json(self) -> t.Dict[str, t.Any]:
|
||||
return cattr.unstructure(self)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -6,8 +6,10 @@ from enum import auto
|
||||
|
||||
import attr
|
||||
|
||||
|
||||
_object_setattr = object.__setattr__
|
||||
|
||||
|
||||
class SeparatorStyle(IntEnum):
|
||||
# Generic separator styles for chat models
|
||||
ADD_COLON_SINGLE = auto()
|
||||
@@ -24,6 +26,7 @@ class SeparatorStyle(IntEnum):
|
||||
MPT = auto()
|
||||
STARCODER = auto()
|
||||
|
||||
|
||||
@attr.define
|
||||
class Conversation:
|
||||
# The name of this template
|
||||
@@ -129,8 +132,10 @@ class Conversation:
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if message:
|
||||
ret += role + ':\n' + message + seps[i % 2]
|
||||
if i % 2 == 1: ret += '\n\n'
|
||||
else: ret += role + ':\n'
|
||||
if i % 2 == 1:
|
||||
ret += '\n\n'
|
||||
else:
|
||||
ret += role + ':\n'
|
||||
return ret
|
||||
elif self.sep_style == SeparatorStyle.MPT:
|
||||
ret = f'<|im_start|>system\n{system_prompt}<|im_end|>{self.sep}' if system_prompt else ''
|
||||
@@ -156,7 +161,7 @@ class Conversation:
|
||||
|
||||
def to_openai_messages(self) -> t.List[t.Dict[str, str]]:
|
||||
ret = [{'role': 'system', 'content': self.system_message}]
|
||||
for i, (_, msg) in enumerate(self.messages[self.offset:]):
|
||||
for i, (_, msg) in enumerate(self.messages[self.offset :]):
|
||||
if i % 2 == 0:
|
||||
ret.append({'role': 'user', 'content': msg})
|
||||
elif msg is not None:
|
||||
@@ -164,14 +169,16 @@ class Conversation:
|
||||
return ret
|
||||
|
||||
def copy(self) -> Conversation:
|
||||
return Conversation(name=self.name,
|
||||
system_template=self.system_template,
|
||||
system_message=self.system_message,
|
||||
roles=self.roles,
|
||||
messages=self.messages,
|
||||
offset=self.offset,
|
||||
sep_style=self.sep_style,
|
||||
sep=self.sep,
|
||||
sep2=self.sep2,
|
||||
stop_str=self.stop_str,
|
||||
stop_token_ids=self.stop_token_ids)
|
||||
return Conversation(
|
||||
name=self.name,
|
||||
system_template=self.system_template,
|
||||
system_message=self.system_message,
|
||||
roles=self.roles,
|
||||
messages=self.messages,
|
||||
offset=self.offset,
|
||||
sep_style=self.sep_style,
|
||||
sep=self.sep,
|
||||
sep2=self.sep2,
|
||||
stop_str=self.stop_str,
|
||||
stop_token_ids=self.stop_token_ids,
|
||||
)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Schema definition for OpenLLM. This schema is used throughout openllm core components library."""
|
||||
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
|
||||
@@ -11,9 +12,11 @@ from .config import AutoConfig
|
||||
from .utils import converter
|
||||
from .utils import gen_random_uuid
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import vllm
|
||||
|
||||
|
||||
@attr.frozen(slots=True)
|
||||
class MetadataOutput:
|
||||
model_id: str
|
||||
@@ -30,6 +33,7 @@ class MetadataOutput:
|
||||
def model_dump_json(self) -> str:
|
||||
return orjson.dumps(self.model_dump(), option=orjson.OPT_INDENT_2).decode('utf-8')
|
||||
|
||||
|
||||
@attr.define(slots=True, frozen=True)
|
||||
class GenerationInput:
|
||||
prompt: str
|
||||
@@ -42,7 +46,12 @@ class GenerationInput:
|
||||
return cls.from_llm_config(AutoConfig.for_model(model_name, **attrs))
|
||||
|
||||
def model_dump(self) -> dict[str, t.Any]:
|
||||
return {'prompt': self.prompt, 'stop': self.stop, 'llm_config': self.llm_config.model_dump(flatten=True), 'adapter_name': self.adapter_name}
|
||||
return {
|
||||
'prompt': self.prompt,
|
||||
'stop': self.stop,
|
||||
'llm_config': self.llm_config.model_dump(flatten=True),
|
||||
'adapter_name': self.adapter_name,
|
||||
}
|
||||
|
||||
def model_dump_json(self) -> str:
|
||||
return orjson.dumps(self.model_dump(), option=orjson.OPT_INDENT_2).decode('utf-8')
|
||||
@@ -53,19 +62,20 @@ class GenerationInput:
|
||||
self.__attrs_init__(prompt=prompt, llm_config=llm_config, stop=stop, adapter_name=adapter_name) # type: ignore
|
||||
|
||||
def _llm_config_converter(data: dict[str, t.Any] | LLMConfig) -> LLMConfig:
|
||||
if isinstance(data, LLMConfig): return data
|
||||
if isinstance(data, LLMConfig):
|
||||
return data
|
||||
return llm_config.__class__(**data)
|
||||
|
||||
klass: type[GenerationInput] = attr.make_class(inflection.camelize(llm_config['model_name']) + 'GenerationInput', {
|
||||
'__init__': init,
|
||||
'llm_config': attr.field(default=llm_config, converter=_llm_config_converter)
|
||||
},
|
||||
bases=(cls,),
|
||||
slots=True,
|
||||
weakref_slot=True,
|
||||
frozen=True,
|
||||
repr=True,
|
||||
collect_by_mro=True)
|
||||
klass: type[GenerationInput] = attr.make_class(
|
||||
inflection.camelize(llm_config['model_name']) + 'GenerationInput',
|
||||
{'__init__': init, 'llm_config': attr.field(default=llm_config, converter=_llm_config_converter)},
|
||||
bases=(cls,),
|
||||
slots=True,
|
||||
weakref_slot=True,
|
||||
frozen=True,
|
||||
repr=True,
|
||||
collect_by_mro=True,
|
||||
)
|
||||
|
||||
def examples(_: type[GenerationInput]) -> dict[str, t.Any]:
|
||||
return klass(prompt='What is the meaning of life?', llm_config=llm_config, stop=['\n']).model_dump()
|
||||
@@ -78,6 +88,7 @@ class GenerationInput:
|
||||
pass
|
||||
return klass
|
||||
|
||||
|
||||
# NOTE: parameters from vllm.RequestOutput and vllm.CompletionOutput since vllm is not available on CPU.
|
||||
# OpenLLM will adapt CPU outputs to similar architecture with vLLM outputs for consistency
|
||||
|
||||
@@ -85,6 +96,7 @@ SampleLogprobs = t.List[t.Dict[int, float]]
|
||||
PromptLogprobs = t.List[t.Optional[t.Dict[int, float]]]
|
||||
FinishReason = t.Literal['length', 'stop']
|
||||
|
||||
|
||||
@attr.define
|
||||
class CompletionChunk:
|
||||
index: int
|
||||
@@ -100,6 +112,7 @@ class CompletionChunk:
|
||||
def model_dump_json(self)->str:return orjson.dumps(self.model_dump(),option=orjson.OPT_NON_STR_KEYS).decode('utf-8')
|
||||
# yapf: enable
|
||||
|
||||
|
||||
@attr.define
|
||||
class GenerationOutput:
|
||||
prompt: str
|
||||
@@ -111,31 +124,38 @@ class GenerationOutput:
|
||||
|
||||
@classmethod
|
||||
def examples(cls) -> dict[str, t.Any]:
|
||||
return cls(prompt='What is the meaning of life?',
|
||||
finished=True,
|
||||
outputs=[
|
||||
CompletionChunk(index=0,
|
||||
text='\nLife is the process by which organisms, such as bacteria and cells, reproduce themselves and continue to exist.',
|
||||
token_ids=[50118, 12116, 16, 5, 609, 30, 61, 28340, 6, 215, 25, 9436, 8, 4590, 6, 33942, 1235, 8, 535],
|
||||
cumulative_logprob=0.0,
|
||||
logprobs=None,
|
||||
finish_reason='length')
|
||||
],
|
||||
prompt_token_ids=[2, 2264, 16, 5, 3099, 9, 301, 116],
|
||||
prompt_logprobs=None,
|
||||
request_id=gen_random_uuid()).model_dump()
|
||||
return cls(
|
||||
prompt='What is the meaning of life?',
|
||||
finished=True,
|
||||
outputs=[
|
||||
CompletionChunk(
|
||||
index=0,
|
||||
text='\nLife is the process by which organisms, such as bacteria and cells, reproduce themselves and continue to exist.',
|
||||
token_ids=[50118, 12116, 16, 5, 609, 30, 61, 28340, 6, 215, 25, 9436, 8, 4590, 6, 33942, 1235, 8, 535],
|
||||
cumulative_logprob=0.0,
|
||||
logprobs=None,
|
||||
finish_reason='length',
|
||||
)
|
||||
],
|
||||
prompt_token_ids=[2, 2264, 16, 5, 3099, 9, 301, 116],
|
||||
prompt_logprobs=None,
|
||||
request_id=gen_random_uuid(),
|
||||
).model_dump()
|
||||
|
||||
@staticmethod
|
||||
def _preprocess_sse_message(data: str) -> str:
|
||||
proc = [line[6:] for line in data.strip().split('\n') if line.startswith('data: ')]
|
||||
if not proc: return data
|
||||
if len(proc) > 1: raise ValueError('Multiple data found in SSE message.')
|
||||
if not proc:
|
||||
return data
|
||||
if len(proc) > 1:
|
||||
raise ValueError('Multiple data found in SSE message.')
|
||||
return proc[0]
|
||||
|
||||
@classmethod
|
||||
def from_runner(cls, data: str) -> GenerationOutput:
|
||||
data = cls._preprocess_sse_message(data)
|
||||
if not data: raise ValueError('No data found from messages.')
|
||||
if not data:
|
||||
raise ValueError('No data found from messages.')
|
||||
try:
|
||||
return converter.structure(orjson.loads(data), cls)
|
||||
except orjson.JSONDecodeError as e:
|
||||
@@ -143,15 +163,24 @@ class GenerationOutput:
|
||||
|
||||
@classmethod
|
||||
def from_vllm(cls, request_output: vllm.RequestOutput) -> GenerationOutput:
|
||||
return cls(prompt=request_output.prompt,
|
||||
finished=request_output.finished,
|
||||
request_id=request_output.request_id,
|
||||
outputs=[
|
||||
CompletionChunk(index=it.index, text=it.text, token_ids=it.token_ids, cumulative_logprob=it.cumulative_logprob, logprobs=it.logprobs, finish_reason=it.finish_reason)
|
||||
for it in request_output.outputs
|
||||
],
|
||||
prompt_token_ids=request_output.prompt_token_ids,
|
||||
prompt_logprobs=request_output.prompt_logprobs)
|
||||
return cls(
|
||||
prompt=request_output.prompt,
|
||||
finished=request_output.finished,
|
||||
request_id=request_output.request_id,
|
||||
outputs=[
|
||||
CompletionChunk(
|
||||
index=it.index,
|
||||
text=it.text,
|
||||
token_ids=it.token_ids,
|
||||
cumulative_logprob=it.cumulative_logprob,
|
||||
logprobs=it.logprobs,
|
||||
finish_reason=it.finish_reason,
|
||||
)
|
||||
for it in request_output.outputs
|
||||
],
|
||||
prompt_token_ids=request_output.prompt_token_ids,
|
||||
prompt_logprobs=request_output.prompt_logprobs,
|
||||
)
|
||||
|
||||
def with_options(self, **options: t.Any) -> GenerationOutput:
|
||||
return attr.evolve(self, **options)
|
||||
|
||||
@@ -4,6 +4,7 @@ import typing as t
|
||||
|
||||
import attr
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from peft.peft_model import PeftModel
|
||||
from transformers import PreTrainedModel
|
||||
@@ -16,9 +17,11 @@ if t.TYPE_CHECKING:
|
||||
M = t.TypeVar('M', bound='t.Union[PreTrainedModel, PeftModel]')
|
||||
T = t.TypeVar('T', bound='t.Union[PreTrainedTokenizerFast, PreTrainedTokenizer, PreTrainedTokenizerBase]')
|
||||
|
||||
|
||||
def get_literal_args(typ: t.Any) -> tuple[str, ...]:
|
||||
return getattr(typ, '__args__', tuple())
|
||||
|
||||
|
||||
AnyCallable = t.Callable[..., t.Any]
|
||||
DictStrAny = t.Dict[str, t.Any]
|
||||
ListAny = t.List[t.Any]
|
||||
@@ -29,7 +32,9 @@ At = t.TypeVar('At', bound=attr.AttrsInstance)
|
||||
LiteralSerialisation = t.Literal['safetensors', 'legacy']
|
||||
LiteralQuantise = t.Literal['int8', 'int4', 'gptq', 'awq', 'squeezellm']
|
||||
LiteralBackend = t.Literal['pt', 'vllm', 'ggml', 'mlc']
|
||||
AdapterType = t.Literal['lora', 'adalora', 'adaption_prompt', 'prefix_tuning', 'p_tuning', 'prompt_tuning', 'ia3', 'loha', 'lokr']
|
||||
AdapterType = t.Literal[
|
||||
'lora', 'adalora', 'adaption_prompt', 'prefix_tuning', 'p_tuning', 'prompt_tuning', 'ia3', 'loha', 'lokr'
|
||||
]
|
||||
|
||||
# TODO: support quay
|
||||
LiteralContainerRegistry = t.Literal['docker', 'gh', 'ecr']
|
||||
@@ -65,13 +70,16 @@ else:
|
||||
from typing_extensions import TypeAlias as TypeAlias
|
||||
from typing_extensions import TypeGuard as TypeGuard
|
||||
|
||||
|
||||
class AdapterTuple(TupleAny):
|
||||
adapter_id: str
|
||||
name: str
|
||||
config: DictStrAny
|
||||
|
||||
|
||||
AdapterMap = t.Dict[AdapterType, t.Tuple[AdapterTuple, ...]]
|
||||
|
||||
|
||||
class RefTuple(TupleAny):
|
||||
git_hash: str
|
||||
version: VersionInfo
|
||||
|
||||
@@ -14,6 +14,7 @@ from openllm_core.utils import ReprMixin
|
||||
from openllm_core.utils import is_bentoml_available
|
||||
from openllm_core.utils.import_utils import is_transformers_available
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import types
|
||||
|
||||
@@ -27,6 +28,7 @@ if t.TYPE_CHECKING:
|
||||
from openllm_core._typing_compat import LiteralString
|
||||
from openllm_core._typing_compat import M
|
||||
from openllm_core._typing_compat import T
|
||||
|
||||
ConfigKeysView = _odict_keys[str, type[openllm_core.LLMConfig]]
|
||||
ConfigValuesView = _odict_values[str, type[openllm_core.LLMConfig]]
|
||||
ConfigItemsView = _odict_items[str, type[openllm_core.LLMConfig]]
|
||||
@@ -35,9 +37,23 @@ else:
|
||||
OrderedDictType = OrderedDict
|
||||
|
||||
# NOTE: This is the entrypoint when adding new model config
|
||||
CONFIG_MAPPING_NAMES = OrderedDict([('chatglm', 'ChatGLMConfig'), ('dolly_v2', 'DollyV2Config'), ('falcon', 'FalconConfig'), ('flan_t5', 'FlanT5Config'), ('gpt_neox', 'GPTNeoXConfig'),
|
||||
('llama', 'LlamaConfig'), ('mpt', 'MPTConfig'), ('opt', 'OPTConfig'), ('stablelm', 'StableLMConfig'), ('starcoder', 'StarCoderConfig'),
|
||||
('mistral', 'MistralConfig'), ('baichuan', 'BaichuanConfig')])
|
||||
CONFIG_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
('chatglm', 'ChatGLMConfig'),
|
||||
('dolly_v2', 'DollyV2Config'),
|
||||
('falcon', 'FalconConfig'),
|
||||
('flan_t5', 'FlanT5Config'),
|
||||
('gpt_neox', 'GPTNeoXConfig'),
|
||||
('llama', 'LlamaConfig'),
|
||||
('mpt', 'MPTConfig'),
|
||||
('opt', 'OPTConfig'),
|
||||
('stablelm', 'StableLMConfig'),
|
||||
('starcoder', 'StarCoderConfig'),
|
||||
('mistral', 'MistralConfig'),
|
||||
('baichuan', 'BaichuanConfig'),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class _LazyConfigMapping(OrderedDictType, ReprMixin):
|
||||
def __init__(self, mapping: OrderedDict[LiteralString, LiteralString]):
|
||||
@@ -46,13 +62,17 @@ class _LazyConfigMapping(OrderedDictType, ReprMixin):
|
||||
self._modules: dict[str, types.ModuleType] = {}
|
||||
|
||||
def __getitem__(self, key: str) -> t.Any:
|
||||
if key in self._extra_content: return self._extra_content[key]
|
||||
if key in self._extra_content:
|
||||
return self._extra_content[key]
|
||||
if key not in self._mapping:
|
||||
if inflection.underscore(key) in self._mapping: return self.__getitem__(inflection.underscore(key))
|
||||
if inflection.underscore(key) in self._mapping:
|
||||
return self.__getitem__(inflection.underscore(key))
|
||||
raise KeyError(key)
|
||||
value, module_name = self._mapping[key], inflection.underscore(key)
|
||||
if module_name not in self._modules: self._modules[module_name] = importlib.import_module(f'.configuration_{module_name}', 'openllm_core.config')
|
||||
if hasattr(self._modules[module_name], value): return getattr(self._modules[module_name], value)
|
||||
if module_name not in self._modules:
|
||||
self._modules[module_name] = importlib.import_module(f'.configuration_{module_name}', 'openllm_core.config')
|
||||
if hasattr(self._modules[module_name], value):
|
||||
return getattr(self._modules[module_name], value)
|
||||
# Some of the mappings have entries model_type -> config of another model type. In that case we try to grab the object at the top level.
|
||||
return getattr(importlib.import_module('openllm'), value)
|
||||
|
||||
@@ -82,35 +102,57 @@ class _LazyConfigMapping(OrderedDictType, ReprMixin):
|
||||
return item in self._mapping or item in self._extra_content
|
||||
|
||||
def register(self, key: str, value: t.Any) -> None:
|
||||
if key in self._mapping.keys(): raise ValueError(f"'{key}' is already used by a OpenLLM config, pick another name.")
|
||||
if key in self._mapping.keys():
|
||||
raise ValueError(f"'{key}' is already used by a OpenLLM config, pick another name.")
|
||||
self._extra_content[key] = value
|
||||
|
||||
|
||||
CONFIG_MAPPING: dict[LiteralString, type[openllm_core.LLMConfig]] = _LazyConfigMapping(CONFIG_MAPPING_NAMES)
|
||||
# The below handle special alias when we call underscore to the name directly without processing camelcase first.
|
||||
CONFIG_NAME_ALIASES: dict[str, str] = {'chat_glm': 'chatglm', 'stable_lm': 'stablelm', 'star_coder': 'starcoder', 'gpt_neo_x': 'gpt_neox'}
|
||||
CONFIG_NAME_ALIASES: dict[str, str] = {
|
||||
'chat_glm': 'chatglm',
|
||||
'stable_lm': 'stablelm',
|
||||
'star_coder': 'starcoder',
|
||||
'gpt_neo_x': 'gpt_neox',
|
||||
}
|
||||
CONFIG_FILE_NAME = 'config.json'
|
||||
|
||||
|
||||
class AutoConfig:
|
||||
def __init__(self, *_: t.Any, **__: t.Any):
|
||||
raise EnvironmentError('Cannot instantiate AutoConfig directly. Please use `AutoConfig.for_model(model_name)` instead.')
|
||||
raise EnvironmentError(
|
||||
'Cannot instantiate AutoConfig directly. Please use `AutoConfig.for_model(model_name)` instead.'
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def for_model(cls, model_name: str, **attrs: t.Any) -> openllm_core.LLMConfig:
|
||||
model_name = inflection.underscore(model_name)
|
||||
if model_name in CONFIG_MAPPING: return CONFIG_MAPPING[model_name].model_construct_env(**attrs)
|
||||
raise ValueError(f"Unrecognized configuration class for {model_name}. Model name should be one of {', '.join(CONFIG_MAPPING.keys())}.")
|
||||
if model_name in CONFIG_MAPPING:
|
||||
return CONFIG_MAPPING[model_name].model_construct_env(**attrs)
|
||||
raise ValueError(
|
||||
f"Unrecognized configuration class for {model_name}. Model name should be one of {', '.join(CONFIG_MAPPING.keys())}."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def infer_class_from_name(cls, name: str) -> type[openllm_core.LLMConfig]:
|
||||
model_name = inflection.underscore(name)
|
||||
if model_name in CONFIG_NAME_ALIASES: model_name = CONFIG_NAME_ALIASES[model_name]
|
||||
if model_name in CONFIG_MAPPING: return CONFIG_MAPPING[model_name]
|
||||
raise ValueError(f"Unrecognized configuration class for {model_name}. Model name should be one of {', '.join(CONFIG_MAPPING.keys())}.")
|
||||
if model_name in CONFIG_NAME_ALIASES:
|
||||
model_name = CONFIG_NAME_ALIASES[model_name]
|
||||
if model_name in CONFIG_MAPPING:
|
||||
return CONFIG_MAPPING[model_name]
|
||||
raise ValueError(
|
||||
f"Unrecognized configuration class for {model_name}. Model name should be one of {', '.join(CONFIG_MAPPING.keys())}."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def infer_class_from_llm(cls, llm: openllm.LLM[M, T]) -> type[openllm_core.LLMConfig]:
|
||||
if not is_bentoml_available(): raise MissingDependencyError("'infer_class_from_llm' requires 'bentoml' to be available. Make sure to install it with 'pip install bentoml'")
|
||||
CONFIG_MAPPING_NAMES_TO_ARCHITECTURE: dict[str, str] = {v.__config__['architecture']: k for k, v in CONFIG_MAPPING.items()}
|
||||
if not is_bentoml_available():
|
||||
raise MissingDependencyError(
|
||||
"'infer_class_from_llm' requires 'bentoml' to be available. Make sure to install it with 'pip install bentoml'"
|
||||
)
|
||||
CONFIG_MAPPING_NAMES_TO_ARCHITECTURE: dict[str, str] = {
|
||||
v.__config__['architecture']: k for k, v in CONFIG_MAPPING.items()
|
||||
}
|
||||
if llm._local:
|
||||
config_file = os.path.join(llm.model_id, CONFIG_FILE_NAME)
|
||||
else:
|
||||
@@ -118,16 +160,25 @@ class AutoConfig:
|
||||
config_file = llm.bentomodel.path_of(CONFIG_FILE_NAME)
|
||||
except OpenLLMException:
|
||||
if not is_transformers_available():
|
||||
raise MissingDependencyError("'infer_class_from_llm' requires 'transformers' to be available. Make sure to install it with 'pip install transformers'")
|
||||
raise MissingDependencyError(
|
||||
"'infer_class_from_llm' requires 'transformers' to be available. Make sure to install it with 'pip install transformers'"
|
||||
)
|
||||
from transformers.utils import cached_file
|
||||
|
||||
try:
|
||||
config_file = cached_file(llm.model_id, CONFIG_FILE_NAME)
|
||||
except Exception as err:
|
||||
raise ValueError("Failed to determine architecture from 'config.json'. If this is a gated model, make sure to pass in HUGGING_FACE_HUB_TOKEN") from err
|
||||
if not os.path.exists(config_file): raise ValueError(f"Failed to find 'config.json' (config_json_path={config_file})")
|
||||
raise ValueError(
|
||||
"Failed to determine architecture from 'config.json'. If this is a gated model, make sure to pass in HUGGING_FACE_HUB_TOKEN"
|
||||
) from err
|
||||
if not os.path.exists(config_file):
|
||||
raise ValueError(f"Failed to find 'config.json' (config_json_path={config_file})")
|
||||
with open(config_file, 'r', encoding='utf-8') as f:
|
||||
loaded_config = orjson.loads(f.read())
|
||||
if 'architectures' in loaded_config:
|
||||
for architecture in loaded_config['architectures']:
|
||||
if architecture in CONFIG_MAPPING_NAMES_TO_ARCHITECTURE: return cls.infer_class_from_name(CONFIG_MAPPING_NAMES_TO_ARCHITECTURE[architecture])
|
||||
raise ValueError(f"Failed to determine config class for '{llm.model_id}'. Make sure {llm.model_id} is saved with openllm.")
|
||||
if architecture in CONFIG_MAPPING_NAMES_TO_ARCHITECTURE:
|
||||
return cls.infer_class_from_name(CONFIG_MAPPING_NAMES_TO_ARCHITECTURE[architecture])
|
||||
raise ValueError(
|
||||
f"Failed to determine config class for '{llm.model_id}'. Make sure {llm.model_id} is saved with openllm."
|
||||
)
|
||||
|
||||
@@ -6,7 +6,8 @@ import openllm_core
|
||||
from openllm_core._conversation import SeparatorStyle
|
||||
from openllm_core.prompts import PromptTemplate
|
||||
|
||||
START_BAICHUAN_COMMAND_DOCSTRING = '''\
|
||||
|
||||
START_BAICHUAN_COMMAND_DOCSTRING = """\
|
||||
Run a LLMServer for Baichuan model.
|
||||
|
||||
\b
|
||||
@@ -24,10 +25,11 @@ or provide `--model-id` flag when running ``openllm start baichuan``:
|
||||
|
||||
\b
|
||||
$ openllm start baichuan --model-id='fireballoon/baichuan-vicuna-chinese-7b'
|
||||
'''
|
||||
"""
|
||||
DEFAULT_SYSTEM_MESSAGE = ''
|
||||
DEFAULT_PROMPT_TEMPLATE = PromptTemplate('{instruction}')
|
||||
|
||||
|
||||
class BaichuanConfig(openllm_core.LLMConfig):
|
||||
"""Baichuan-7B is an open-source, large-scale pre-trained language model developed by Baichuan Intelligent Technology.
|
||||
|
||||
@@ -38,23 +40,28 @@ class BaichuanConfig(openllm_core.LLMConfig):
|
||||
and English benchmarks (C-Eval, MMLU, etc).
|
||||
Refer to [Baichuan-7B's GitHub page](https://github.com/baichuan-inc/Baichuan-7B) for more information.
|
||||
"""
|
||||
|
||||
__config__ = {
|
||||
'name_type': 'lowercase',
|
||||
'trust_remote_code': True,
|
||||
'timeout': 3600000,
|
||||
'url': 'https://github.com/baichuan-inc/Baichuan-7B',
|
||||
'requirements': ['cpm-kernels', 'sentencepiece'],
|
||||
'architecture': 'BaiChuanForCausalLM',
|
||||
# NOTE: See the following
|
||||
# https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/19ef51ba5bad8935b03acd20ff04a269210983bc/modeling_baichuan.py#L555
|
||||
# https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/main/generation_config.json
|
||||
# https://github.com/baichuan-inc/Baichuan-13B/issues/25
|
||||
'conversation': dict(roles=('<reserved_102>', '<reserved_103>'), sep_style=SeparatorStyle.NO_COLON_SINGLE, sep=''),
|
||||
'default_id': 'baichuan-inc/baichuan-7b',
|
||||
'model_ids': [
|
||||
'baichuan-inc/baichuan-7b', 'baichuan-inc/baichuan-13b-base', 'baichuan-inc/baichuan-13b-chat', 'fireballoon/baichuan-vicuna-chinese-7b', 'fireballoon/baichuan-vicuna-7b',
|
||||
'hiyouga/baichuan-7b-sft'
|
||||
]
|
||||
'name_type': 'lowercase',
|
||||
'trust_remote_code': True,
|
||||
'timeout': 3600000,
|
||||
'url': 'https://github.com/baichuan-inc/Baichuan-7B',
|
||||
'requirements': ['cpm-kernels', 'sentencepiece'],
|
||||
'architecture': 'BaiChuanForCausalLM',
|
||||
# NOTE: See the following
|
||||
# https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/19ef51ba5bad8935b03acd20ff04a269210983bc/modeling_baichuan.py#L555
|
||||
# https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/main/generation_config.json
|
||||
# https://github.com/baichuan-inc/Baichuan-13B/issues/25
|
||||
'conversation': dict(roles=('<reserved_102>', '<reserved_103>'), sep_style=SeparatorStyle.NO_COLON_SINGLE, sep=''),
|
||||
'default_id': 'baichuan-inc/baichuan-7b',
|
||||
'model_ids': [
|
||||
'baichuan-inc/baichuan-7b',
|
||||
'baichuan-inc/baichuan-13b-base',
|
||||
'baichuan-inc/baichuan-13b-chat',
|
||||
'fireballoon/baichuan-vicuna-chinese-7b',
|
||||
'fireballoon/baichuan-vicuna-7b',
|
||||
'hiyouga/baichuan-7b-sft',
|
||||
],
|
||||
}
|
||||
|
||||
class GenerationConfig:
|
||||
@@ -70,18 +77,26 @@ class BaichuanConfig(openllm_core.LLMConfig):
|
||||
def default_system_message(self) -> str:
|
||||
return DEFAULT_SYSTEM_MESSAGE
|
||||
|
||||
def sanitize_parameters(self,
|
||||
prompt: str,
|
||||
prompt_template: PromptTemplate | str | None = None,
|
||||
system_message: str | None = None,
|
||||
max_new_tokens: int | None = None,
|
||||
top_p: float | None = None,
|
||||
temperature: float | None = None,
|
||||
**attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
def sanitize_parameters(
|
||||
self,
|
||||
prompt: str,
|
||||
prompt_template: PromptTemplate | str | None = None,
|
||||
system_message: str | None = None,
|
||||
max_new_tokens: int | None = None,
|
||||
top_p: float | None = None,
|
||||
temperature: float | None = None,
|
||||
**attrs: t.Any,
|
||||
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
system_message = DEFAULT_SYSTEM_MESSAGE if system_message is None else system_message
|
||||
if prompt_template is None: prompt_template = DEFAULT_PROMPT_TEMPLATE
|
||||
elif isinstance(prompt_template, str): prompt_template = PromptTemplate(template=prompt_template)
|
||||
return prompt_template.with_options(system_message=system_message).format(instruction=prompt), {'max_new_tokens': max_new_tokens, 'top_p': top_p, 'temperature': temperature, **attrs}, {}
|
||||
if prompt_template is None:
|
||||
prompt_template = DEFAULT_PROMPT_TEMPLATE
|
||||
elif isinstance(prompt_template, str):
|
||||
prompt_template = PromptTemplate(template=prompt_template)
|
||||
return (
|
||||
prompt_template.with_options(system_message=system_message).format(instruction=prompt),
|
||||
{'max_new_tokens': max_new_tokens, 'top_p': top_p, 'temperature': temperature, **attrs},
|
||||
{},
|
||||
)
|
||||
|
||||
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str:
|
||||
return generation_result[0]
|
||||
|
||||
@@ -6,10 +6,11 @@ import openllm_core
|
||||
from openllm_core._conversation import SeparatorStyle
|
||||
from openllm_core.utils import dantic
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from openllm_core.prompts import PromptTemplate
|
||||
|
||||
START_CHATGLM_COMMAND_DOCSTRING = '''\
|
||||
START_CHATGLM_COMMAND_DOCSTRING = """\
|
||||
Run a LLMServer for ChatGLM model.
|
||||
|
||||
\b
|
||||
@@ -27,8 +28,9 @@ or provide `--model-id` flag when running ``openllm start chatglm``:
|
||||
|
||||
\b
|
||||
$ openllm start chatglm --model-id='thudm/chatglm-6b-int8'
|
||||
'''
|
||||
DEFAULT_PROMPT_TEMPLATE = '''{instruction}'''
|
||||
"""
|
||||
DEFAULT_PROMPT_TEMPLATE = """{instruction}"""
|
||||
|
||||
|
||||
class ChatGLMConfig(openllm_core.LLMConfig):
|
||||
"""ChatGLM is an open bilingual language model based on [General Language Model (GLM)](https://github.com/THUDM/GLM) framework.
|
||||
@@ -44,19 +46,29 @@ class ChatGLMConfig(openllm_core.LLMConfig):
|
||||
|
||||
Refer to [ChatGLM's GitHub page](https://github.com/THUDM/ChatGLM-6B) for more information.
|
||||
"""
|
||||
|
||||
__config__ = {
|
||||
'name_type': 'lowercase',
|
||||
'trust_remote_code': True,
|
||||
'timeout': 3600000,
|
||||
'backend': ('pt',),
|
||||
'url': 'https://github.com/THUDM/ChatGLM-6B',
|
||||
'conversation': dict(roles=('问', '答'), sep_style=SeparatorStyle.CHATGLM, sep='\n'),
|
||||
'requirements': ['cpm-kernels', 'sentencepiece'],
|
||||
'architecture': 'ChatGLMModel',
|
||||
'default_id': 'thudm/chatglm-6b',
|
||||
'model_ids': ['thudm/chatglm-6b', 'thudm/chatglm-6b-int8', 'thudm/chatglm-6b-int4', 'thudm/chatglm2-6b', 'thudm/chatglm2-6b-int4']
|
||||
'name_type': 'lowercase',
|
||||
'trust_remote_code': True,
|
||||
'timeout': 3600000,
|
||||
'backend': ('pt',),
|
||||
'url': 'https://github.com/THUDM/ChatGLM-6B',
|
||||
'conversation': dict(roles=('问', '答'), sep_style=SeparatorStyle.CHATGLM, sep='\n'),
|
||||
'requirements': ['cpm-kernels', 'sentencepiece'],
|
||||
'architecture': 'ChatGLMModel',
|
||||
'default_id': 'thudm/chatglm-6b',
|
||||
'model_ids': [
|
||||
'thudm/chatglm-6b',
|
||||
'thudm/chatglm-6b-int8',
|
||||
'thudm/chatglm-6b-int4',
|
||||
'thudm/chatglm2-6b',
|
||||
'thudm/chatglm2-6b-int4',
|
||||
],
|
||||
}
|
||||
retain_history: bool = dantic.Field(False, description='Whether to retain history given to the model. If set to True, then the model will retain given history.')
|
||||
retain_history: bool = dantic.Field(
|
||||
False,
|
||||
description='Whether to retain history given to the model. If set to True, then the model will retain given history.',
|
||||
)
|
||||
use_half_precision: bool = dantic.Field(True, description='Whether to use half precision for model.')
|
||||
|
||||
class GenerationConfig:
|
||||
@@ -65,17 +77,19 @@ class ChatGLMConfig(openllm_core.LLMConfig):
|
||||
top_p: float = 0.7
|
||||
temperature: float = 0.95
|
||||
|
||||
def sanitize_parameters(self,
|
||||
prompt: str,
|
||||
prompt_template: PromptTemplate | str | None = None,
|
||||
system_message: str | None = None,
|
||||
max_new_tokens: int | None = None,
|
||||
num_beams: int | None = None,
|
||||
top_p: float | None = None,
|
||||
temperature: float | None = None,
|
||||
chat_history: list[tuple[str, str]] | None = None,
|
||||
use_default_prompt_template: bool = False,
|
||||
**attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
def sanitize_parameters(
|
||||
self,
|
||||
prompt: str,
|
||||
prompt_template: PromptTemplate | str | None = None,
|
||||
system_message: str | None = None,
|
||||
max_new_tokens: int | None = None,
|
||||
num_beams: int | None = None,
|
||||
top_p: float | None = None,
|
||||
temperature: float | None = None,
|
||||
chat_history: list[tuple[str, str]] | None = None,
|
||||
use_default_prompt_template: bool = False,
|
||||
**attrs: t.Any,
|
||||
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
prompt_text = ''
|
||||
if use_default_prompt_template and chat_history is not None:
|
||||
for i, (old_query, response) in enumerate(chat_history):
|
||||
@@ -84,11 +98,23 @@ class ChatGLMConfig(openllm_core.LLMConfig):
|
||||
else:
|
||||
prompt_text = prompt
|
||||
postprocess_generate_kwargs = {'chat_history': chat_history if chat_history is not None else None}
|
||||
return prompt_text, {'max_new_tokens': max_new_tokens, 'num_beams': num_beams, 'top_p': top_p, 'temperature': temperature, **attrs}, postprocess_generate_kwargs
|
||||
return (
|
||||
prompt_text,
|
||||
{'max_new_tokens': max_new_tokens, 'num_beams': num_beams, 'top_p': top_p, 'temperature': temperature, **attrs},
|
||||
postprocess_generate_kwargs,
|
||||
)
|
||||
|
||||
def postprocess_generate(self, prompt: str, generation_result: tuple[str, list[tuple[str, str]]], *, chat_history: list[tuple[str, str]] | None = None, **attrs: t.Any) -> str:
|
||||
def postprocess_generate(
|
||||
self,
|
||||
prompt: str,
|
||||
generation_result: tuple[str, list[tuple[str, str]]],
|
||||
*,
|
||||
chat_history: list[tuple[str, str]] | None = None,
|
||||
**attrs: t.Any,
|
||||
) -> str:
|
||||
generated, history = generation_result
|
||||
if self.config.retain_history:
|
||||
if chat_history is None: raise ValueError("'retain_history' is True while there is no history provided.")
|
||||
if chat_history is None:
|
||||
raise ValueError("'retain_history' is True while there is no history provided.")
|
||||
chat_history.extend(history)
|
||||
return generated
|
||||
|
||||
@@ -8,10 +8,11 @@ from openllm_core.prompts import PromptTemplate
|
||||
from openllm_core.prompts import process_prompt
|
||||
from openllm_core.utils import dantic
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import transformers
|
||||
|
||||
START_DOLLY_V2_COMMAND_DOCSTRING = '''\
|
||||
START_DOLLY_V2_COMMAND_DOCSTRING = """\
|
||||
Run a LLMServer for dolly-v2 model.
|
||||
|
||||
\b
|
||||
@@ -29,22 +30,25 @@ or provide `--model-id` flag when running ``openllm start dolly-v2``:
|
||||
|
||||
\b
|
||||
$ openllm start dolly-v2 --model-id databricks/dolly-v2-7b
|
||||
'''
|
||||
"""
|
||||
INSTRUCTION_KEY = '### Instruction:'
|
||||
RESPONSE_KEY = '### Response:'
|
||||
END_KEY = '### End'
|
||||
INTRO_BLURB = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.'
|
||||
INTRO_BLURB = (
|
||||
'Below is an instruction that describes a task. Write a response that appropriately completes the request.'
|
||||
)
|
||||
# NOTE: This is the prompt that is used for generating responses using an already
|
||||
# trained model. It ends with the response key, where the job of the model is to provide
|
||||
# the completion that follows it (i.e. the response itself).
|
||||
DEFAULT_PROMPT_TEMPLATE = '''{intro}
|
||||
DEFAULT_PROMPT_TEMPLATE = """{intro}
|
||||
{instruction_key}
|
||||
{instruction}
|
||||
{response_key}
|
||||
'''.format(intro=INTRO_BLURB, instruction_key=INSTRUCTION_KEY, instruction='{instruction}', response_key=RESPONSE_KEY)
|
||||
""".format(intro=INTRO_BLURB, instruction_key=INSTRUCTION_KEY, instruction='{instruction}', response_key=RESPONSE_KEY)
|
||||
|
||||
|
||||
def get_special_token_id(tokenizer: transformers.PreTrainedTokenizer, key: str) -> int:
|
||||
'''Gets the token ID for a given string that has been added to the tokenizer as a special token.
|
||||
"""Gets the token ID for a given string that has been added to the tokenizer as a special token.
|
||||
|
||||
When training, we configure the tokenizer so that the sequences like "### Instruction:" and "### End" are
|
||||
treated specially and converted to a single, new token. This retrieves the token ID each of these keys map to.
|
||||
@@ -58,11 +62,13 @@ def get_special_token_id(tokenizer: transformers.PreTrainedTokenizer, key: str)
|
||||
|
||||
Returns:
|
||||
int: the token ID for the given key.
|
||||
'''
|
||||
"""
|
||||
token_ids = tokenizer.encode(key)
|
||||
if len(token_ids) > 1: raise ValueError(f"Expected only a single token for '{key}' but found {token_ids}")
|
||||
if len(token_ids) > 1:
|
||||
raise ValueError(f"Expected only a single token for '{key}' but found {token_ids}")
|
||||
return token_ids[0]
|
||||
|
||||
|
||||
class DollyV2Config(openllm_core.LLMConfig):
|
||||
"""Databricks` Dolly is an instruction-following large language model trained on the Databricks machine learning platform that is licensed for commercial use.
|
||||
|
||||
@@ -75,17 +81,20 @@ class DollyV2Config(openllm_core.LLMConfig):
|
||||
|
||||
Refer to [Databricks's Dolly page](https://github.com/databrickslabs/dolly) for more information.
|
||||
"""
|
||||
|
||||
__config__ = {
|
||||
'timeout': 3600000,
|
||||
'url': 'https://github.com/databrickslabs/dolly',
|
||||
'architecture': 'GPTNeoXForCausalLM',
|
||||
'default_id': 'databricks/dolly-v2-3b',
|
||||
'conversation': dict(system_message='Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n',
|
||||
roles=('### Instruction', '### Response'),
|
||||
sep_style=SeparatorStyle.DOLLY,
|
||||
sep='\n\n',
|
||||
sep2='### End'),
|
||||
'model_ids': ['databricks/dolly-v2-3b', 'databricks/dolly-v2-7b', 'databricks/dolly-v2-12b']
|
||||
'timeout': 3600000,
|
||||
'url': 'https://github.com/databrickslabs/dolly',
|
||||
'architecture': 'GPTNeoXForCausalLM',
|
||||
'default_id': 'databricks/dolly-v2-3b',
|
||||
'conversation': dict(
|
||||
system_message='Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n',
|
||||
roles=('### Instruction', '### Response'),
|
||||
sep_style=SeparatorStyle.DOLLY,
|
||||
sep='\n\n',
|
||||
sep2='### End',
|
||||
),
|
||||
'model_ids': ['databricks/dolly-v2-3b', 'databricks/dolly-v2-7b', 'databricks/dolly-v2-12b'],
|
||||
}
|
||||
return_full_text: bool = dantic.Field(False, description='Whether to return the full prompt to the users.')
|
||||
|
||||
@@ -96,23 +105,25 @@ class DollyV2Config(openllm_core.LLMConfig):
|
||||
max_new_tokens: int = 256
|
||||
eos_token_id: int = 50277 # NOTE: from get_special_token_id(self.tokenizer, END_KEY)
|
||||
|
||||
def sanitize_parameters(self,
|
||||
prompt: str,
|
||||
prompt_template: PromptTemplate | str | None = None,
|
||||
system_message: str | None = None,
|
||||
max_new_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
top_k: int | None = None,
|
||||
top_p: float | None = None,
|
||||
use_default_prompt_template: bool = True,
|
||||
**attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {
|
||||
'max_new_tokens': max_new_tokens,
|
||||
'top_k': top_k,
|
||||
'top_p': top_p,
|
||||
'temperature': temperature,
|
||||
**attrs
|
||||
}, {}
|
||||
def sanitize_parameters(
|
||||
self,
|
||||
prompt: str,
|
||||
prompt_template: PromptTemplate | str | None = None,
|
||||
system_message: str | None = None,
|
||||
max_new_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
top_k: int | None = None,
|
||||
top_p: float | None = None,
|
||||
use_default_prompt_template: bool = True,
|
||||
**attrs: t.Any,
|
||||
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
return (
|
||||
process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs),
|
||||
{'max_new_tokens': max_new_tokens, 'top_k': top_k, 'top_p': top_p, 'temperature': temperature, **attrs},
|
||||
{},
|
||||
)
|
||||
|
||||
def postprocess_generate(self, prompt: str, generation_result: list[dict[t.Literal['generated_text'], str]], **_: t.Any) -> str:
|
||||
def postprocess_generate(
|
||||
self, prompt: str, generation_result: list[dict[t.Literal['generated_text'], str]], **_: t.Any
|
||||
) -> str:
|
||||
return generation_result[0]['generated_text']
|
||||
|
||||
@@ -7,7 +7,8 @@ from openllm_core._conversation import SeparatorStyle
|
||||
from openllm_core.prompts import PromptTemplate
|
||||
from openllm_core.prompts import process_prompt
|
||||
|
||||
START_FALCON_COMMAND_DOCSTRING = '''\
|
||||
|
||||
START_FALCON_COMMAND_DOCSTRING = """\
|
||||
Run a LLMServer for FalconLM model.
|
||||
|
||||
\b
|
||||
@@ -27,11 +28,12 @@ or provide `--model-id` flag when running ``openllm start falcon``:
|
||||
|
||||
\b
|
||||
$ openllm start falcon --model-id tiiuae/falcon-7b-instruct
|
||||
'''
|
||||
DEFAULT_PROMPT_TEMPLATE = '''{context}
|
||||
"""
|
||||
DEFAULT_PROMPT_TEMPLATE = """{context}
|
||||
{user_name}: {instruction}
|
||||
{agent}:
|
||||
'''
|
||||
"""
|
||||
|
||||
|
||||
class FalconConfig(openllm_core.LLMConfig):
|
||||
"""Falcon-7B is a 7B parameters causal decoder-only model built by TII and trained on 1,500B tokens of [RefinedWeb](https://huggingface.co/datasets/tiiuae/falcon-refinedweb) enhanced with curated corpora.
|
||||
@@ -40,25 +42,30 @@ class FalconConfig(openllm_core.LLMConfig):
|
||||
|
||||
Refer to [Falcon's HuggingFace page](https://huggingface.co/tiiuae/falcon-7b) for more information.
|
||||
"""
|
||||
|
||||
__config__ = {
|
||||
'name_type': 'lowercase',
|
||||
'trust_remote_code': False,
|
||||
'timeout': int(36e6),
|
||||
'url': 'https://falconllm.tii.ae/',
|
||||
'requirements': ['einops', 'xformers'],
|
||||
'architecture': 'FalconForCausalLM',
|
||||
# NOTE: See https://huggingface.co/tiiuae/falcon-7b-instruct/discussions/1
|
||||
'conversation': dict(roles=('User', 'Assistant'), messages=[], sep_style=SeparatorStyle.ADD_COLON_SINGLE, sep='\n'), # No space after colon
|
||||
'default_id': 'tiiuae/falcon-7b',
|
||||
'model_ids': ['tiiuae/falcon-7b', 'tiiuae/falcon-40b', 'tiiuae/falcon-7b-instruct', 'tiiuae/falcon-40b-instruct'],
|
||||
'fine_tune_strategies': ({
|
||||
'adapter_type': 'lora',
|
||||
'r': 64,
|
||||
'lora_alpha': 16,
|
||||
'lora_dropout': 0.1,
|
||||
'bias': 'none',
|
||||
'target_modules': ['query_key_value', 'dense', 'dense_h_to_4h', 'dense_4h_to_h']
|
||||
},)
|
||||
'name_type': 'lowercase',
|
||||
'trust_remote_code': False,
|
||||
'timeout': int(36e6),
|
||||
'url': 'https://falconllm.tii.ae/',
|
||||
'requirements': ['einops', 'xformers'],
|
||||
'architecture': 'FalconForCausalLM',
|
||||
# NOTE: See https://huggingface.co/tiiuae/falcon-7b-instruct/discussions/1
|
||||
'conversation': dict(
|
||||
roles=('User', 'Assistant'), messages=[], sep_style=SeparatorStyle.ADD_COLON_SINGLE, sep='\n'
|
||||
), # No space after colon
|
||||
'default_id': 'tiiuae/falcon-7b',
|
||||
'model_ids': ['tiiuae/falcon-7b', 'tiiuae/falcon-40b', 'tiiuae/falcon-7b-instruct', 'tiiuae/falcon-40b-instruct'],
|
||||
'fine_tune_strategies': (
|
||||
{
|
||||
'adapter_type': 'lora',
|
||||
'r': 64,
|
||||
'lora_alpha': 16,
|
||||
'lora_dropout': 0.1,
|
||||
'bias': 'none',
|
||||
'target_modules': ['query_key_value', 'dense', 'dense_h_to_4h', 'dense_4h_to_h'],
|
||||
},
|
||||
),
|
||||
}
|
||||
|
||||
class GenerationConfig:
|
||||
@@ -68,23 +75,29 @@ class FalconConfig(openllm_core.LLMConfig):
|
||||
num_beams: int = 4
|
||||
early_stopping: bool = True
|
||||
|
||||
def sanitize_parameters(self,
|
||||
prompt: str,
|
||||
prompt_template: PromptTemplate | str | None = None,
|
||||
system_message: str | None = None,
|
||||
max_new_tokens: int | None = None,
|
||||
top_k: int | None = None,
|
||||
num_return_sequences: int | None = None,
|
||||
eos_token_id: int | None = None,
|
||||
use_default_prompt_template: bool = False,
|
||||
**attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {
|
||||
def sanitize_parameters(
|
||||
self,
|
||||
prompt: str,
|
||||
prompt_template: PromptTemplate | str | None = None,
|
||||
system_message: str | None = None,
|
||||
max_new_tokens: int | None = None,
|
||||
top_k: int | None = None,
|
||||
num_return_sequences: int | None = None,
|
||||
eos_token_id: int | None = None,
|
||||
use_default_prompt_template: bool = False,
|
||||
**attrs: t.Any,
|
||||
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
return (
|
||||
process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs),
|
||||
{
|
||||
'max_new_tokens': max_new_tokens,
|
||||
'top_k': top_k,
|
||||
'num_return_sequences': num_return_sequences,
|
||||
'eos_token_id': eos_token_id,
|
||||
**attrs
|
||||
}, {}
|
||||
**attrs,
|
||||
},
|
||||
{},
|
||||
)
|
||||
|
||||
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str:
|
||||
return generation_result[0]
|
||||
|
||||
@@ -7,7 +7,8 @@ from openllm_core._conversation import SeparatorStyle
|
||||
from openllm_core.prompts import PromptTemplate
|
||||
from openllm_core.prompts import process_prompt
|
||||
|
||||
START_FLAN_T5_COMMAND_DOCSTRING = '''\
|
||||
|
||||
START_FLAN_T5_COMMAND_DOCSTRING = """\
|
||||
Run a LLMServer for FLAN-T5 model.
|
||||
|
||||
\b
|
||||
@@ -25,8 +26,9 @@ or provide `--model-id` flag when running ``openllm start flan-t5``:
|
||||
|
||||
\b
|
||||
$ openllm start flan-t5 --model-id google/flan-t5-xxl
|
||||
'''
|
||||
DEFAULT_PROMPT_TEMPLATE = '''Answer the following question:\nQuestion: {instruction}\nAnswer:'''
|
||||
"""
|
||||
DEFAULT_PROMPT_TEMPLATE = """Answer the following question:\nQuestion: {instruction}\nAnswer:"""
|
||||
|
||||
|
||||
class FlanT5Config(openllm_core.LLMConfig):
|
||||
"""FLAN-T5 was released in the paper [Scaling Instruction-Finetuned Language Models](https://arxiv.org/pdf/2210.11416.pdf).
|
||||
@@ -35,15 +37,24 @@ class FlanT5Config(openllm_core.LLMConfig):
|
||||
|
||||
Refer to [FLAN-T5's page](https://huggingface.co/docs/transformers/model_doc/flan-t5) for more information.
|
||||
"""
|
||||
|
||||
__config__ = {
|
||||
'url': 'https://huggingface.co/docs/transformers/model_doc/flan-t5',
|
||||
'architecture': 'T5ForConditionalGeneration',
|
||||
'model_type': 'seq2seq_lm',
|
||||
'backend': ('pt',),
|
||||
# NOTE: See https://www.philschmid.de/fine-tune-flan-t5. No specific template found, but seems to have the same dialogue style
|
||||
'conversation': dict(system_message='', roles=('User', 'Assistant'), sep_style=SeparatorStyle.ADD_COLON_SINGLE, sep='\n'),
|
||||
'default_id': 'google/flan-t5-large',
|
||||
'model_ids': ['google/flan-t5-small', 'google/flan-t5-base', 'google/flan-t5-large', 'google/flan-t5-xl', 'google/flan-t5-xxl']
|
||||
'url': 'https://huggingface.co/docs/transformers/model_doc/flan-t5',
|
||||
'architecture': 'T5ForConditionalGeneration',
|
||||
'model_type': 'seq2seq_lm',
|
||||
'backend': ('pt',),
|
||||
# NOTE: See https://www.philschmid.de/fine-tune-flan-t5. No specific template found, but seems to have the same dialogue style
|
||||
'conversation': dict(
|
||||
system_message='', roles=('User', 'Assistant'), sep_style=SeparatorStyle.ADD_COLON_SINGLE, sep='\n'
|
||||
),
|
||||
'default_id': 'google/flan-t5-large',
|
||||
'model_ids': [
|
||||
'google/flan-t5-small',
|
||||
'google/flan-t5-base',
|
||||
'google/flan-t5-large',
|
||||
'google/flan-t5-xl',
|
||||
'google/flan-t5-xxl',
|
||||
],
|
||||
}
|
||||
|
||||
class GenerationConfig:
|
||||
@@ -53,24 +64,30 @@ class FlanT5Config(openllm_core.LLMConfig):
|
||||
top_p: float = 0.4
|
||||
repetition_penalty = 1.0
|
||||
|
||||
def sanitize_parameters(self,
|
||||
prompt: str,
|
||||
prompt_template: PromptTemplate | str | None = None,
|
||||
system_message: str | None = None,
|
||||
max_new_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
top_k: int | None = None,
|
||||
top_p: float | None = None,
|
||||
repetition_penalty: float | None = None,
|
||||
use_default_prompt_template: bool = True,
|
||||
**attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {
|
||||
def sanitize_parameters(
|
||||
self,
|
||||
prompt: str,
|
||||
prompt_template: PromptTemplate | str | None = None,
|
||||
system_message: str | None = None,
|
||||
max_new_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
top_k: int | None = None,
|
||||
top_p: float | None = None,
|
||||
repetition_penalty: float | None = None,
|
||||
use_default_prompt_template: bool = True,
|
||||
**attrs: t.Any,
|
||||
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
return (
|
||||
process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs),
|
||||
{
|
||||
'max_new_tokens': max_new_tokens,
|
||||
'temperature': temperature,
|
||||
'top_k': top_k,
|
||||
'top_p': top_p,
|
||||
'repetition_penalty': repetition_penalty
|
||||
}, {}
|
||||
'repetition_penalty': repetition_penalty,
|
||||
},
|
||||
{},
|
||||
)
|
||||
|
||||
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str:
|
||||
return generation_result[0]
|
||||
|
||||
@@ -8,7 +8,8 @@ from openllm_core.prompts import PromptTemplate
|
||||
from openllm_core.prompts import process_prompt
|
||||
from openllm_core.utils import dantic
|
||||
|
||||
START_GPT_NEOX_COMMAND_DOCSTRING = '''\
|
||||
|
||||
START_GPT_NEOX_COMMAND_DOCSTRING = """\
|
||||
Run a LLMServer for GPTNeoX model.
|
||||
|
||||
\b
|
||||
@@ -26,8 +27,9 @@ or provide `--model-id` flag when running ``openllm start gpt-neox``:
|
||||
|
||||
\b
|
||||
$ openllm start gpt-neox --model-id 'stabilityai/stablelm-tuned-alpha-3b'
|
||||
'''
|
||||
DEFAULT_PROMPT_TEMPLATE = '''{instruction}'''
|
||||
"""
|
||||
DEFAULT_PROMPT_TEMPLATE = """{instruction}"""
|
||||
|
||||
|
||||
class GPTNeoXConfig(openllm_core.LLMConfig):
|
||||
"""GPTNeoX is an autoregressive language model trained on the Pile, whose weights will be made freely and openly available to the public through a permissive license.
|
||||
@@ -44,15 +46,18 @@ class GPTNeoXConfig(openllm_core.LLMConfig):
|
||||
Refer to [GPTNeoX's model card](https://huggingface.co/docs/transformers/model_doc/gpt_neox)
|
||||
for more information.
|
||||
"""
|
||||
|
||||
__config__ = {
|
||||
'model_name': 'gpt_neox',
|
||||
'start_name': 'gpt-neox',
|
||||
'architecture': 'GPTNeoXForCausalLM',
|
||||
# NOTE: See https://huggingface.co/togethercomputer/GPT-NeoXT-Chat-Base-20B
|
||||
'conversation': dict(system_message='', roles=('<human>', '<bot>'), sep_style=SeparatorStyle.ADD_COLON_SPACE_SINGLE, sep='\n'),
|
||||
'url': 'https://github.com/EleutherAI/gpt-neox',
|
||||
'default_id': 'eleutherai/gpt-neox-20b',
|
||||
'model_ids': ['eleutherai/gpt-neox-20b']
|
||||
'model_name': 'gpt_neox',
|
||||
'start_name': 'gpt-neox',
|
||||
'architecture': 'GPTNeoXForCausalLM',
|
||||
# NOTE: See https://huggingface.co/togethercomputer/GPT-NeoXT-Chat-Base-20B
|
||||
'conversation': dict(
|
||||
system_message='', roles=('<human>', '<bot>'), sep_style=SeparatorStyle.ADD_COLON_SPACE_SINGLE, sep='\n'
|
||||
),
|
||||
'url': 'https://github.com/EleutherAI/gpt-neox',
|
||||
'default_id': 'eleutherai/gpt-neox-20b',
|
||||
'model_ids': ['eleutherai/gpt-neox-20b'],
|
||||
}
|
||||
use_half_precision: bool = dantic.Field(True, description='Whether to use half precision for model.')
|
||||
|
||||
@@ -60,15 +65,21 @@ class GPTNeoXConfig(openllm_core.LLMConfig):
|
||||
temperature: float = 0.9
|
||||
max_new_tokens: int = 100
|
||||
|
||||
def sanitize_parameters(self,
|
||||
prompt: str,
|
||||
prompt_template: PromptTemplate | str | None = None,
|
||||
system_message: str | None = None,
|
||||
temperature: float | None = None,
|
||||
max_new_tokens: int | None = None,
|
||||
use_default_prompt_template: bool = True,
|
||||
**attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {'max_new_tokens': max_new_tokens, 'temperature': temperature}, {}
|
||||
def sanitize_parameters(
|
||||
self,
|
||||
prompt: str,
|
||||
prompt_template: PromptTemplate | str | None = None,
|
||||
system_message: str | None = None,
|
||||
temperature: float | None = None,
|
||||
max_new_tokens: int | None = None,
|
||||
use_default_prompt_template: bool = True,
|
||||
**attrs: t.Any,
|
||||
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
return (
|
||||
process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs),
|
||||
{'max_new_tokens': max_new_tokens, 'temperature': temperature},
|
||||
{},
|
||||
)
|
||||
|
||||
def postprocess_generate(self, prompt: str, generation_result: list[str], **_: t.Any) -> str:
|
||||
return generation_result[0]
|
||||
|
||||
@@ -6,7 +6,8 @@ import openllm_core
|
||||
from openllm_core._conversation import SeparatorStyle
|
||||
from openllm_core.prompts import PromptTemplate
|
||||
|
||||
START_LLAMA_COMMAND_DOCSTRING = '''\
|
||||
|
||||
START_LLAMA_COMMAND_DOCSTRING = """\
|
||||
Run a LLMServer for Llama model.
|
||||
|
||||
\b
|
||||
@@ -37,26 +38,34 @@ OpenLLM also supports running Llama-2 and its fine-tune and variants. To import
|
||||
|
||||
\b
|
||||
$ CONVERTER=hf-llama2 openllm import llama /path/to/llama-2
|
||||
'''
|
||||
DEFAULT_SYSTEM_MESSAGE = '''
|
||||
"""
|
||||
DEFAULT_SYSTEM_MESSAGE = """
|
||||
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
|
||||
|
||||
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
|
||||
'''
|
||||
"""
|
||||
SINST_KEY, EINST_KEY, SYS_KEY, EOS_TOKEN, BOS_TOKEN = '[INST]', '[/INST]', '<<SYS>>', '</s>', '<s>'
|
||||
# TODO: support history and v1 prompt implementation
|
||||
_v1_prompt, _v2_prompt = '''{instruction}''', '''{start_key} {sys_key}\n{system_message}\n{sys_key}\n\n{instruction}\n{end_key}\n'''.format(start_key=SINST_KEY,
|
||||
sys_key=SYS_KEY,
|
||||
system_message='{system_message}',
|
||||
instruction='{instruction}',
|
||||
end_key=EINST_KEY)
|
||||
_v1_prompt, _v2_prompt = (
|
||||
"""{instruction}""",
|
||||
"""{start_key} {sys_key}\n{system_message}\n{sys_key}\n\n{instruction}\n{end_key}\n""".format(
|
||||
start_key=SINST_KEY,
|
||||
sys_key=SYS_KEY,
|
||||
system_message='{system_message}',
|
||||
instruction='{instruction}',
|
||||
end_key=EINST_KEY,
|
||||
),
|
||||
)
|
||||
PROMPT_MAPPING = {'v1': _v1_prompt, 'v2': _v2_prompt}
|
||||
|
||||
|
||||
def _get_prompt(model_type: t.Literal['v1', 'v2']) -> PromptTemplate:
|
||||
return PromptTemplate(PROMPT_MAPPING[model_type])
|
||||
|
||||
|
||||
DEFAULT_PROMPT_TEMPLATE = _get_prompt
|
||||
|
||||
|
||||
class LlamaConfig(openllm_core.LLMConfig):
|
||||
"""LLaMA model was proposed in [LLaMA: Open and Efficient Foundation Language Models](https://arxiv.org/abs/2302.13971) by Hugo Touvron, Thibaut Lavril, Gautier Izacard, Xavier Martinet, Marie-Anne Lachaux, Timothée Lacroix, Baptiste Rozière, Naman Goyal, Eric Hambro, Faisal Azhar, Aurelien Rodriguez, Armand Joulin, Edouard Grave, Guillaume Lample.
|
||||
|
||||
@@ -69,27 +78,39 @@ class LlamaConfig(openllm_core.LLMConfig):
|
||||
Refer to [Llama's model card](https://huggingface.co/docs/transformers/main/model_doc/llama)
|
||||
for more information.
|
||||
"""
|
||||
|
||||
__config__ = {
|
||||
'name_type': 'lowercase',
|
||||
'url': 'https://github.com/facebookresearch/llama',
|
||||
'architecture': 'LlamaForCausalLM',
|
||||
'requirements': ['fairscale', 'sentencepiece', 'scipy'],
|
||||
'default_id': 'NousResearch/llama-2-7b-hf',
|
||||
'serialisation': 'safetensors',
|
||||
# NOTE: see https://huggingface.co/blog/codellama#conversational-instructions
|
||||
'conversation': dict(system_template='<s>[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n', roles=('[INST]', '[/INST]'), sep_style=SeparatorStyle.LLAMA, sep=' ', sep2=' </s><s>'),
|
||||
'model_ids': [
|
||||
'meta-llama/Llama-2-70b-chat-hf', 'meta-llama/Llama-2-13b-chat-hf', 'meta-llama/Llama-2-7b-chat-hf', 'meta-llama/Llama-2-70b-hf', 'meta-llama/Llama-2-13b-hf',
|
||||
'meta-llama/Llama-2-7b-hf', 'NousResearch/llama-2-70b-chat-hf', 'NousResearch/llama-2-13b-chat-hf', 'NousResearch/llama-2-7b-chat-hf', 'NousResearch/llama-2-70b-hf',
|
||||
'NousResearch/llama-2-13b-hf', 'NousResearch/llama-2-7b-hf',
|
||||
],
|
||||
'fine_tune_strategies': ({
|
||||
'adapter_type': 'lora',
|
||||
'r': 64,
|
||||
'lora_alpha': 16,
|
||||
'lora_dropout': 0.1,
|
||||
'bias': 'none'
|
||||
},)
|
||||
'name_type': 'lowercase',
|
||||
'url': 'https://github.com/facebookresearch/llama',
|
||||
'architecture': 'LlamaForCausalLM',
|
||||
'requirements': ['fairscale', 'sentencepiece', 'scipy'],
|
||||
'default_id': 'NousResearch/llama-2-7b-hf',
|
||||
'serialisation': 'safetensors',
|
||||
# NOTE: see https://huggingface.co/blog/codellama#conversational-instructions
|
||||
'conversation': dict(
|
||||
system_template='<s>[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n',
|
||||
roles=('[INST]', '[/INST]'),
|
||||
sep_style=SeparatorStyle.LLAMA,
|
||||
sep=' ',
|
||||
sep2=' </s><s>',
|
||||
),
|
||||
'model_ids': [
|
||||
'meta-llama/Llama-2-70b-chat-hf',
|
||||
'meta-llama/Llama-2-13b-chat-hf',
|
||||
'meta-llama/Llama-2-7b-chat-hf',
|
||||
'meta-llama/Llama-2-70b-hf',
|
||||
'meta-llama/Llama-2-13b-hf',
|
||||
'meta-llama/Llama-2-7b-hf',
|
||||
'NousResearch/llama-2-70b-chat-hf',
|
||||
'NousResearch/llama-2-13b-chat-hf',
|
||||
'NousResearch/llama-2-7b-chat-hf',
|
||||
'NousResearch/llama-2-70b-hf',
|
||||
'NousResearch/llama-2-13b-hf',
|
||||
'NousResearch/llama-2-7b-hf',
|
||||
],
|
||||
'fine_tune_strategies': (
|
||||
{'adapter_type': 'lora', 'r': 64, 'lora_alpha': 16, 'lora_dropout': 0.1, 'bias': 'none'},
|
||||
),
|
||||
}
|
||||
|
||||
class GenerationConfig:
|
||||
@@ -110,24 +131,27 @@ class LlamaConfig(openllm_core.LLMConfig):
|
||||
def default_system_message(self) -> str:
|
||||
return DEFAULT_SYSTEM_MESSAGE
|
||||
|
||||
def sanitize_parameters(self,
|
||||
prompt: str,
|
||||
prompt_template: PromptTemplate | str | None = None,
|
||||
system_message: str | None = None,
|
||||
top_k: int | None = None,
|
||||
top_p: float | None = None,
|
||||
temperature: float | None = None,
|
||||
max_new_tokens: int | None = None,
|
||||
**attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
def sanitize_parameters(
|
||||
self,
|
||||
prompt: str,
|
||||
prompt_template: PromptTemplate | str | None = None,
|
||||
system_message: str | None = None,
|
||||
top_k: int | None = None,
|
||||
top_p: float | None = None,
|
||||
temperature: float | None = None,
|
||||
max_new_tokens: int | None = None,
|
||||
**attrs: t.Any,
|
||||
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
system_message = DEFAULT_SYSTEM_MESSAGE if system_message is None else system_message
|
||||
if prompt_template is None: prompt_template = DEFAULT_PROMPT_TEMPLATE('v2')
|
||||
elif isinstance(prompt_template, str): prompt_template = PromptTemplate(template=prompt_template)
|
||||
return prompt_template.with_options(system_message=system_message).format(instruction=prompt), {
|
||||
'max_new_tokens': max_new_tokens,
|
||||
'temperature': temperature,
|
||||
'top_p': top_p,
|
||||
'top_k': top_k
|
||||
}, {}
|
||||
if prompt_template is None:
|
||||
prompt_template = DEFAULT_PROMPT_TEMPLATE('v2')
|
||||
elif isinstance(prompt_template, str):
|
||||
prompt_template = PromptTemplate(template=prompt_template)
|
||||
return (
|
||||
prompt_template.with_options(system_message=system_message).format(instruction=prompt),
|
||||
{'max_new_tokens': max_new_tokens, 'temperature': temperature, 'top_p': top_p, 'top_k': top_k},
|
||||
{},
|
||||
)
|
||||
|
||||
def postprocess_generate(self, prompt: str, generation_result: list[str], **_: t.Any) -> str:
|
||||
return generation_result[0]
|
||||
|
||||
@@ -2,7 +2,8 @@ from __future__ import annotations
|
||||
|
||||
import openllm_core
|
||||
|
||||
START_MISTRAL_COMMAND_DOCSTRING = '''\
|
||||
|
||||
START_MISTRAL_COMMAND_DOCSTRING = """\
|
||||
Run a LLMServer for Mistral model.
|
||||
|
||||
\b
|
||||
@@ -23,8 +24,9 @@ or provide `--model-id` flag when running ``openllm start mistral``:
|
||||
|
||||
\b
|
||||
$ openllm start mistral --model-id HuggingFaceH4/zephyr-7b-alpha
|
||||
'''
|
||||
DEFAULT_PROMPT_TEMPLATE = '''{instruction}'''
|
||||
"""
|
||||
DEFAULT_PROMPT_TEMPLATE = """{instruction}"""
|
||||
|
||||
|
||||
class MistralConfig(openllm_core.LLMConfig):
|
||||
"""Mistral's [paper](https://arxiv.org/abs/2310.06825) and first released by [MistralAI](https://mistral.ai/news/announcing-mistral-7b/).
|
||||
@@ -32,13 +34,20 @@ class MistralConfig(openllm_core.LLMConfig):
|
||||
Mistral-7B-v0.1 is Mistral AI\'s first Large Language Model (LLM).
|
||||
Refer to [Mistral's HuggingFace page](https://huggingface.co/docs/transformers/v4.35.0/en/model_doc/mistral#overview) for more information.
|
||||
"""
|
||||
|
||||
__config__ = {
|
||||
'name_type': 'lowercase',
|
||||
'url': 'https://huggingface.co/docs/transformers/v4.35.0/en/model_doc/mistral#overview',
|
||||
'default_id': 'mistralai/Mistral-7B-Instruct-v0.1',
|
||||
'architecture': 'MistralForCausalLM',
|
||||
'add_generation_prompt': True,
|
||||
'model_ids': ['mistralai/Mistral-7B-v0.1', 'mistralai/Mistral-7B-Instruct-v0.1', 'amazon/MistralLite', 'HuggingFaceH4/zephyr-7b-beta', 'HuggingFaceH4/zephyr-7b-alpha'],
|
||||
'name_type': 'lowercase',
|
||||
'url': 'https://huggingface.co/docs/transformers/v4.35.0/en/model_doc/mistral#overview',
|
||||
'default_id': 'mistralai/Mistral-7B-Instruct-v0.1',
|
||||
'architecture': 'MistralForCausalLM',
|
||||
'add_generation_prompt': True,
|
||||
'model_ids': [
|
||||
'mistralai/Mistral-7B-v0.1',
|
||||
'mistralai/Mistral-7B-Instruct-v0.1',
|
||||
'amazon/MistralLite',
|
||||
'HuggingFaceH4/zephyr-7b-beta',
|
||||
'HuggingFaceH4/zephyr-7b-alpha',
|
||||
],
|
||||
}
|
||||
|
||||
class GenerationConfig:
|
||||
|
||||
@@ -8,9 +8,10 @@ from openllm_core.prompts import PromptTemplate
|
||||
from openllm_core.prompts import process_prompt
|
||||
from openllm_core.utils import dantic
|
||||
|
||||
|
||||
MPTPromptType = t.Literal['default', 'instruct', 'chat', 'storywriter']
|
||||
|
||||
START_MPT_COMMAND_DOCSTRING = '''\
|
||||
START_MPT_COMMAND_DOCSTRING = """\
|
||||
Run a LLMServer for MPT model.
|
||||
|
||||
\b
|
||||
@@ -35,24 +36,38 @@ or provide `--model-id` flag when running ``openllm start mpt``:
|
||||
|
||||
\b
|
||||
$ openllm start mpt --model-id mosaicml/mpt-30b
|
||||
'''
|
||||
"""
|
||||
INSTRUCTION_KEY, RESPONSE_KEY, END_KEY = '### Instruction:', '### Response:', '### End'
|
||||
INTRO_BLURB = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.'
|
||||
INTRO_BLURB = (
|
||||
'Below is an instruction that describes a task. Write a response that appropriately completes the request.'
|
||||
)
|
||||
# NOTE: This is the prompt that is used for generating responses using an already
|
||||
# trained model. It ends with the response key, where the job of the model is to provide
|
||||
# the completion that follows it (i.e. the response itself).
|
||||
_chat_prompt, _default_prompt, _instruct_prompt = '''{instruction}''', '''{instruction}''', '''{intro}
|
||||
_chat_prompt, _default_prompt, _instruct_prompt = (
|
||||
"""{instruction}""",
|
||||
"""{instruction}""",
|
||||
"""{intro}
|
||||
{instruction_key}
|
||||
{instruction}
|
||||
{response_key}
|
||||
'''.format(intro=INTRO_BLURB, instruction_key=INSTRUCTION_KEY, instruction='{instruction}', response_key=RESPONSE_KEY)
|
||||
PROMPT_MAPPING = {'default': _default_prompt, 'instruct': _instruct_prompt, 'storywriter': _default_prompt, 'chat': _chat_prompt}
|
||||
""".format(intro=INTRO_BLURB, instruction_key=INSTRUCTION_KEY, instruction='{instruction}', response_key=RESPONSE_KEY),
|
||||
)
|
||||
PROMPT_MAPPING = {
|
||||
'default': _default_prompt,
|
||||
'instruct': _instruct_prompt,
|
||||
'storywriter': _default_prompt,
|
||||
'chat': _chat_prompt,
|
||||
}
|
||||
|
||||
|
||||
def _get_prompt(model_type: str) -> str:
|
||||
return PROMPT_MAPPING[model_type]
|
||||
|
||||
|
||||
DEFAULT_PROMPT_TEMPLATE = _get_prompt
|
||||
|
||||
|
||||
class MPTConfig(openllm_core.LLMConfig):
|
||||
"""MPT is a decoder-style transformer pretrained from scratch on English text and code.
|
||||
|
||||
@@ -62,25 +77,34 @@ class MPTConfig(openllm_core.LLMConfig):
|
||||
on HuggingFace. Refers [HuggingFace's MosaicML page](https://huggingface.co/mosaicml)
|
||||
for more details on specific models.
|
||||
"""
|
||||
|
||||
__config__ = {
|
||||
'name_type': 'lowercase',
|
||||
'trust_remote_code': True,
|
||||
'url': 'https://huggingface.co/mosaicml',
|
||||
'timeout': int(36e6),
|
||||
'requirements': ['triton', 'einops'],
|
||||
'architecture': 'MPTForCausalLM',
|
||||
# NOTE: See https://huggingface.co/TheBloke/mpt-30B-chat-GGML/discussions/4
|
||||
'conversation': dict(roles=('user', 'assistant'), messages=[], sep_style=SeparatorStyle.MPT, sep='\n'),
|
||||
'default_id': 'mosaicml/mpt-7b-instruct',
|
||||
'model_ids': [
|
||||
'mosaicml/mpt-7b', 'mosaicml/mpt-7b-instruct', 'mosaicml/mpt-7b-chat', 'mosaicml/mpt-7b-storywriter', 'mosaicml/mpt-30b', 'mosaicml/mpt-30b-instruct', 'mosaicml/mpt-30b-chat'
|
||||
]
|
||||
'name_type': 'lowercase',
|
||||
'trust_remote_code': True,
|
||||
'url': 'https://huggingface.co/mosaicml',
|
||||
'timeout': int(36e6),
|
||||
'requirements': ['triton', 'einops'],
|
||||
'architecture': 'MPTForCausalLM',
|
||||
# NOTE: See https://huggingface.co/TheBloke/mpt-30B-chat-GGML/discussions/4
|
||||
'conversation': dict(roles=('user', 'assistant'), messages=[], sep_style=SeparatorStyle.MPT, sep='\n'),
|
||||
'default_id': 'mosaicml/mpt-7b-instruct',
|
||||
'model_ids': [
|
||||
'mosaicml/mpt-7b',
|
||||
'mosaicml/mpt-7b-instruct',
|
||||
'mosaicml/mpt-7b-chat',
|
||||
'mosaicml/mpt-7b-storywriter',
|
||||
'mosaicml/mpt-30b',
|
||||
'mosaicml/mpt-30b-instruct',
|
||||
'mosaicml/mpt-30b-chat',
|
||||
],
|
||||
}
|
||||
prompt_type: MPTPromptType = dantic.Field('"default"', description='Given prompt type for running MPT. Default will be inferred from model name if pretrained.')
|
||||
prompt_type: MPTPromptType = dantic.Field(
|
||||
'"default"',
|
||||
description='Given prompt type for running MPT. Default will be inferred from model name if pretrained.',
|
||||
)
|
||||
max_sequence_length: int = dantic.Field(
|
||||
2048,
|
||||
description=
|
||||
'Max sequence length to run MPT with. Note that MPT is trained ith sequence length of 2048, but with [ALiBi](https://arxiv.org/abs/2108.12409) it can set up to 4096 (for 7b models) and 16384 (for 30b models)'
|
||||
2048,
|
||||
description='Max sequence length to run MPT with. Note that MPT is trained ith sequence length of 2048, but with [ALiBi](https://arxiv.org/abs/2108.12409) it can set up to 4096 (for 7b models) and 16384 (for 30b models)',
|
||||
)
|
||||
|
||||
class GenerationConfig:
|
||||
@@ -88,26 +112,35 @@ class MPTConfig(openllm_core.LLMConfig):
|
||||
temperature: float = 0
|
||||
top_p: float = 0.8
|
||||
|
||||
def sanitize_parameters(self,
|
||||
prompt: str,
|
||||
prompt_template: PromptTemplate | str | None = None,
|
||||
system_message: str | None = None,
|
||||
max_new_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
prompt_type: MPTPromptType | None = None,
|
||||
use_default_prompt_template: bool = True,
|
||||
**attrs: t.Any,
|
||||
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
def sanitize_parameters(
|
||||
self,
|
||||
prompt: str,
|
||||
prompt_template: PromptTemplate | str | None = None,
|
||||
system_message: str | None = None,
|
||||
max_new_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
prompt_type: MPTPromptType | None = None,
|
||||
use_default_prompt_template: bool = True,
|
||||
**attrs: t.Any,
|
||||
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
_template = None
|
||||
if use_default_prompt_template:
|
||||
if prompt_type is None:
|
||||
if 'instruct' in self.model_id: prompt_type = 'instruct'
|
||||
elif 'storywriter' in self.model_id: prompt_type = 'storywriter'
|
||||
elif 'chat' in self.model_id: prompt_type = 'chat'
|
||||
else: prompt_type = 'default'
|
||||
if 'instruct' in self.model_id:
|
||||
prompt_type = 'instruct'
|
||||
elif 'storywriter' in self.model_id:
|
||||
prompt_type = 'storywriter'
|
||||
elif 'chat' in self.model_id:
|
||||
prompt_type = 'chat'
|
||||
else:
|
||||
prompt_type = 'default'
|
||||
_template = DEFAULT_PROMPT_TEMPLATE(prompt_type)
|
||||
return process_prompt(prompt, _template, use_default_prompt_template), {'max_new_tokens': max_new_tokens, 'temperature': temperature, 'top_p': top_p}, {}
|
||||
return (
|
||||
process_prompt(prompt, _template, use_default_prompt_template),
|
||||
{'max_new_tokens': max_new_tokens, 'temperature': temperature, 'top_p': top_p},
|
||||
{},
|
||||
)
|
||||
|
||||
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **attrs: t.Any) -> str:
|
||||
return generation_result[0]
|
||||
|
||||
@@ -6,10 +6,11 @@ import openllm_core
|
||||
from openllm_core._conversation import SeparatorStyle
|
||||
from openllm_core.prompts import process_prompt
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from openllm_core.prompts.prompt_template import PromptTemplate
|
||||
|
||||
START_OPT_COMMAND_DOCSTRING = '''\
|
||||
START_OPT_COMMAND_DOCSTRING = """\
|
||||
Run a LLMServer for OPT model.
|
||||
|
||||
\b
|
||||
@@ -30,8 +31,9 @@ or provide `--model-id` flag when running ``openllm start opt``:
|
||||
|
||||
\b
|
||||
$ openllm start opt --model-id facebook/opt-6.7b
|
||||
'''
|
||||
DEFAULT_PROMPT_TEMPLATE = '''{instruction}'''
|
||||
"""
|
||||
DEFAULT_PROMPT_TEMPLATE = """{instruction}"""
|
||||
|
||||
|
||||
class OPTConfig(openllm_core.LLMConfig):
|
||||
"""OPT was first introduced in [Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) and first released in [metaseq's repository](https://github.com/facebookresearch/metaseq) on May 3rd 2022 by Meta AI.
|
||||
@@ -43,22 +45,34 @@ class OPTConfig(openllm_core.LLMConfig):
|
||||
|
||||
Refer to [OPT's HuggingFace page](https://huggingface.co/docs/transformers/model_doc/opt) for more information.
|
||||
"""
|
||||
|
||||
__config__ = {
|
||||
'name_type': 'lowercase',
|
||||
'trust_remote_code': False,
|
||||
'url': 'https://huggingface.co/docs/transformers/model_doc/opt',
|
||||
'default_id': 'facebook/opt-1.3b',
|
||||
'architecture': 'OPTForCausalLM',
|
||||
'conversation': dict(roles=('User', 'Assistant'), messages=[], sep_style=SeparatorStyle.ADD_COLON_SINGLE, sep='\n'),
|
||||
'model_ids': ['facebook/opt-125m', 'facebook/opt-350m', 'facebook/opt-1.3b', 'facebook/opt-2.7b', 'facebook/opt-6.7b', 'facebook/opt-66b'],
|
||||
'fine_tune_strategies': ({
|
||||
'adapter_type': 'lora',
|
||||
'r': 16,
|
||||
'lora_alpha': 32,
|
||||
'target_modules': ['q_proj', 'v_proj'],
|
||||
'lora_dropout': 0.05,
|
||||
'bias': 'none'
|
||||
},)
|
||||
'name_type': 'lowercase',
|
||||
'trust_remote_code': False,
|
||||
'url': 'https://huggingface.co/docs/transformers/model_doc/opt',
|
||||
'default_id': 'facebook/opt-1.3b',
|
||||
'architecture': 'OPTForCausalLM',
|
||||
'conversation': dict(
|
||||
roles=('User', 'Assistant'), messages=[], sep_style=SeparatorStyle.ADD_COLON_SINGLE, sep='\n'
|
||||
),
|
||||
'model_ids': [
|
||||
'facebook/opt-125m',
|
||||
'facebook/opt-350m',
|
||||
'facebook/opt-1.3b',
|
||||
'facebook/opt-2.7b',
|
||||
'facebook/opt-6.7b',
|
||||
'facebook/opt-66b',
|
||||
],
|
||||
'fine_tune_strategies': (
|
||||
{
|
||||
'adapter_type': 'lora',
|
||||
'r': 16,
|
||||
'lora_alpha': 32,
|
||||
'target_modules': ['q_proj', 'v_proj'],
|
||||
'lora_dropout': 0.05,
|
||||
'bias': 'none',
|
||||
},
|
||||
),
|
||||
}
|
||||
|
||||
class GenerationConfig:
|
||||
@@ -67,23 +81,30 @@ class OPTConfig(openllm_core.LLMConfig):
|
||||
max_new_tokens: int = 256
|
||||
num_return_sequences: int = 1
|
||||
|
||||
def sanitize_parameters(self,
|
||||
prompt: str,
|
||||
prompt_template: PromptTemplate | str | None = None,
|
||||
system_message: str | None = None,
|
||||
max_new_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
top_k: int | None = None,
|
||||
num_return_sequences: int | None = None,
|
||||
use_default_prompt_template: bool = False,
|
||||
**attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {
|
||||
def sanitize_parameters(
|
||||
self,
|
||||
prompt: str,
|
||||
prompt_template: PromptTemplate | str | None = None,
|
||||
system_message: str | None = None,
|
||||
max_new_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
top_k: int | None = None,
|
||||
num_return_sequences: int | None = None,
|
||||
use_default_prompt_template: bool = False,
|
||||
**attrs: t.Any,
|
||||
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
return (
|
||||
process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs),
|
||||
{
|
||||
'max_new_tokens': max_new_tokens,
|
||||
'temperature': temperature,
|
||||
'top_k': top_k,
|
||||
'num_return_sequences': num_return_sequences
|
||||
}, {}
|
||||
'num_return_sequences': num_return_sequences,
|
||||
},
|
||||
{},
|
||||
)
|
||||
|
||||
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **attrs: t.Any) -> str:
|
||||
if len(generation_result) == 1: return generation_result[0]
|
||||
if len(generation_result) == 1:
|
||||
return generation_result[0]
|
||||
return '\n'.join(generation_result)
|
||||
|
||||
@@ -7,7 +7,8 @@ from openllm_core._conversation import SeparatorStyle
|
||||
from openllm_core.prompts import PromptTemplate
|
||||
from openllm_core.prompts import process_prompt
|
||||
|
||||
START_STABLELM_COMMAND_DOCSTRING = '''\
|
||||
|
||||
START_STABLELM_COMMAND_DOCSTRING = """\
|
||||
Run a LLMServer for StableLM model.
|
||||
|
||||
\b
|
||||
@@ -25,14 +26,15 @@ or provide `--model-id` flag when running ``openllm start stablelm``:
|
||||
|
||||
\b
|
||||
$ openllm start stablelm --model-id 'stabilityai/stablelm-tuned-alpha-3b'
|
||||
'''
|
||||
SYSTEM_PROMPT = '''<|SYSTEM|># StableLM Tuned (Alpha version)
|
||||
"""
|
||||
SYSTEM_PROMPT = """<|SYSTEM|># StableLM Tuned (Alpha version)
|
||||
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
|
||||
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
|
||||
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
|
||||
- StableLM will refuse to participate in anything that could harm a human.
|
||||
'''
|
||||
DEFAULT_PROMPT_TEMPLATE = '''{system_prompt}<|USER|>{instruction}<|ASSISTANT|>'''
|
||||
"""
|
||||
DEFAULT_PROMPT_TEMPLATE = """{system_prompt}<|USER|>{instruction}<|ASSISTANT|>"""
|
||||
|
||||
|
||||
class StableLMConfig(openllm_core.LLMConfig):
|
||||
"""StableLM-Base-Alpha is a suite of 3B and 7B parameter decoder-only language models.
|
||||
@@ -48,23 +50,31 @@ class StableLMConfig(openllm_core.LLMConfig):
|
||||
and [StableLM-base's model card](https://huggingface.co/stabilityai/stablelm-base-alpha-7b)
|
||||
for more information.
|
||||
"""
|
||||
|
||||
__config__ = {
|
||||
'name_type': 'lowercase',
|
||||
'url': 'https://github.com/Stability-AI/StableLM',
|
||||
'conversation': dict(system_template='<|SYSTEM|>{system_message}',
|
||||
system_message='''# StableLM Tuned (Alpha version)
|
||||
'name_type': 'lowercase',
|
||||
'url': 'https://github.com/Stability-AI/StableLM',
|
||||
'conversation': dict(
|
||||
system_template='<|SYSTEM|>{system_message}',
|
||||
system_message="""# StableLM Tuned (Alpha version)
|
||||
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
|
||||
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
|
||||
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
|
||||
- StableLM will refuse to participate in anything that could harm a human.
|
||||
''',
|
||||
roles=('<|USER|>', '<|ASSISTANT|>'),
|
||||
sep_style=SeparatorStyle.NO_COLON_SINGLE,
|
||||
sep='',
|
||||
stop_token_ids=[50278, 50279, 50277, 1, 0]),
|
||||
'architecture': 'GPTNeoXForCausalLM',
|
||||
'default_id': 'stabilityai/stablelm-tuned-alpha-3b',
|
||||
'model_ids': ['stabilityai/stablelm-tuned-alpha-3b', 'stabilityai/stablelm-tuned-alpha-7b', 'stabilityai/stablelm-base-alpha-3b', 'stabilityai/stablelm-base-alpha-7b']
|
||||
""",
|
||||
roles=('<|USER|>', '<|ASSISTANT|>'),
|
||||
sep_style=SeparatorStyle.NO_COLON_SINGLE,
|
||||
sep='',
|
||||
stop_token_ids=[50278, 50279, 50277, 1, 0],
|
||||
),
|
||||
'architecture': 'GPTNeoXForCausalLM',
|
||||
'default_id': 'stabilityai/stablelm-tuned-alpha-3b',
|
||||
'model_ids': [
|
||||
'stabilityai/stablelm-tuned-alpha-3b',
|
||||
'stabilityai/stablelm-tuned-alpha-7b',
|
||||
'stabilityai/stablelm-base-alpha-3b',
|
||||
'stabilityai/stablelm-base-alpha-7b',
|
||||
],
|
||||
}
|
||||
|
||||
class GenerationConfig:
|
||||
@@ -73,22 +83,30 @@ class StableLMConfig(openllm_core.LLMConfig):
|
||||
top_k: int = 0
|
||||
top_p: float = 0.9
|
||||
|
||||
def sanitize_parameters(self,
|
||||
prompt: str,
|
||||
prompt_template: PromptTemplate | str | None = None,
|
||||
system_message: str | None = None,
|
||||
temperature: float | None = None,
|
||||
max_new_tokens: int | None = None,
|
||||
top_k: int | None = None,
|
||||
top_p: float | None = None,
|
||||
use_default_prompt_template: bool = False,
|
||||
**attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
def sanitize_parameters(
|
||||
self,
|
||||
prompt: str,
|
||||
prompt_template: PromptTemplate | str | None = None,
|
||||
system_message: str | None = None,
|
||||
temperature: float | None = None,
|
||||
max_new_tokens: int | None = None,
|
||||
top_k: int | None = None,
|
||||
top_p: float | None = None,
|
||||
use_default_prompt_template: bool = False,
|
||||
**attrs: t.Any,
|
||||
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
if 'tuned' in self._model_id and use_default_prompt_template:
|
||||
system_prompt = attrs.pop('system_prompt', SYSTEM_PROMPT)
|
||||
prompt_text = process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, system_prompt=system_prompt, **attrs)
|
||||
prompt_text = process_prompt(
|
||||
prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, system_prompt=system_prompt, **attrs
|
||||
)
|
||||
else:
|
||||
prompt_text = prompt
|
||||
return prompt_text, {'max_new_tokens': max_new_tokens, 'temperature': temperature, 'top_k': top_k, 'top_p': top_p}, {}
|
||||
return (
|
||||
prompt_text,
|
||||
{'max_new_tokens': max_new_tokens, 'temperature': temperature, 'top_k': top_k, 'top_p': top_p},
|
||||
{},
|
||||
)
|
||||
|
||||
def postprocess_generate(self, prompt: str, generation_result: list[str], **_: t.Any) -> str:
|
||||
return generation_result[0]
|
||||
|
||||
@@ -5,10 +5,11 @@ import openllm_core
|
||||
|
||||
from openllm_core._conversation import SeparatorStyle
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from openllm_core.prompts import PromptTemplate
|
||||
|
||||
START_STARCODER_COMMAND_DOCSTRING = '''\
|
||||
START_STARCODER_COMMAND_DOCSTRING = """\
|
||||
Run a LLMServer for StarCoder model.
|
||||
|
||||
\b
|
||||
@@ -26,9 +27,17 @@ or provide `--model-id` flag when running ``openllm start starcoder``:
|
||||
|
||||
\b
|
||||
$ openllm start starcoder --model-id 'bigcode/starcoder'
|
||||
'''
|
||||
DEFAULT_PROMPT_TEMPLATE = '''{instruction}'''
|
||||
FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD, EOD, FIM_INDICATOR = '<fim-prefix>', '<fim-middle>', '<fim-suffix>', '<fim-pad>', '<|endoftext|>', '<FILL_HERE>'
|
||||
"""
|
||||
DEFAULT_PROMPT_TEMPLATE = """{instruction}"""
|
||||
FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD, EOD, FIM_INDICATOR = (
|
||||
'<fim-prefix>',
|
||||
'<fim-middle>',
|
||||
'<fim-suffix>',
|
||||
'<fim-pad>',
|
||||
'<|endoftext|>',
|
||||
'<FILL_HERE>',
|
||||
)
|
||||
|
||||
|
||||
class StarCoderConfig(openllm_core.LLMConfig):
|
||||
"""The StarCoder models are 15.5B parameter models trained on 80+ programming languages from [The Stack (v1.2)](https://huggingface.co/datasets/bigcode/the-stack), with opt-out requests excluded.
|
||||
@@ -39,14 +48,17 @@ class StarCoderConfig(openllm_core.LLMConfig):
|
||||
|
||||
Refer to [StarCoder's model card](https://huggingface.co/bigcode/starcoder) for more information.
|
||||
"""
|
||||
|
||||
__config__ = {
|
||||
'name_type': 'lowercase',
|
||||
'url': 'https://github.com/bigcode-project/starcoder',
|
||||
'architecture': 'GPTBigCodeForCausalLM',
|
||||
'conversation': dict(system_message='', roles=('<|user|>', '<|assistant|>'), sep_style=SeparatorStyle.STARCODER, sep='\n'),
|
||||
'requirements': ['bitsandbytes'],
|
||||
'default_id': 'bigcode/starcoder',
|
||||
'model_ids': ['bigcode/starcoder', 'bigcode/starcoderbase']
|
||||
'name_type': 'lowercase',
|
||||
'url': 'https://github.com/bigcode-project/starcoder',
|
||||
'architecture': 'GPTBigCodeForCausalLM',
|
||||
'conversation': dict(
|
||||
system_message='', roles=('<|user|>', '<|assistant|>'), sep_style=SeparatorStyle.STARCODER, sep='\n'
|
||||
),
|
||||
'requirements': ['bitsandbytes'],
|
||||
'default_id': 'bigcode/starcoder',
|
||||
'model_ids': ['bigcode/starcoder', 'bigcode/starcoderbase'],
|
||||
}
|
||||
|
||||
class GenerationConfig:
|
||||
@@ -58,15 +70,17 @@ class StarCoderConfig(openllm_core.LLMConfig):
|
||||
pad_token_id: int = 49152
|
||||
repetition_penalty: float = 1.2
|
||||
|
||||
def sanitize_parameters(self,
|
||||
prompt: str,
|
||||
prompt_template: PromptTemplate | str | None = None,
|
||||
system_message: str | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
max_new_tokens: int | None = None,
|
||||
repetition_penalty: float | None = None,
|
||||
**attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
def sanitize_parameters(
|
||||
self,
|
||||
prompt: str,
|
||||
prompt_template: PromptTemplate | str | None = None,
|
||||
system_message: str | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
max_new_tokens: int | None = None,
|
||||
repetition_penalty: float | None = None,
|
||||
**attrs: t.Any,
|
||||
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
fim_mode, prefix, suffix = FIM_INDICATOR in prompt, None, None
|
||||
if fim_mode:
|
||||
try:
|
||||
@@ -77,7 +91,18 @@ class StarCoderConfig(openllm_core.LLMConfig):
|
||||
else:
|
||||
prompt_text = prompt
|
||||
# XXX: This value for pad_token_id is currently a hack, need more investigate why the default starcoder doesn't include the same value as santacoder EOD
|
||||
return prompt_text, {'temperature': temperature, 'top_p': top_p, 'max_new_tokens': max_new_tokens, 'repetition_penalty': repetition_penalty, 'pad_token_id': 49152, **attrs}, {}
|
||||
return (
|
||||
prompt_text,
|
||||
{
|
||||
'temperature': temperature,
|
||||
'top_p': top_p,
|
||||
'max_new_tokens': max_new_tokens,
|
||||
'repetition_penalty': repetition_penalty,
|
||||
'pad_token_id': 49152,
|
||||
**attrs,
|
||||
},
|
||||
{},
|
||||
)
|
||||
|
||||
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str:
|
||||
return generation_result[0]
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
"""Base exceptions for OpenLLM. This extends BentoML exceptions."""
|
||||
|
||||
from __future__ import annotations
|
||||
from http import HTTPStatus
|
||||
|
||||
|
||||
class OpenLLMException(Exception):
|
||||
"""Base class for all OpenLLM exceptions. This shares similar interface with BentoMLException."""
|
||||
|
||||
@@ -11,23 +13,30 @@ class OpenLLMException(Exception):
|
||||
self.message = message
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class GpuNotAvailableError(OpenLLMException):
|
||||
"""Raised when there is no GPU available in given system."""
|
||||
|
||||
|
||||
class ValidationError(OpenLLMException):
|
||||
"""Raised when a validation fails."""
|
||||
|
||||
|
||||
class ForbiddenAttributeError(OpenLLMException):
|
||||
"""Raised when using an _internal field."""
|
||||
|
||||
|
||||
class MissingAnnotationAttributeError(OpenLLMException):
|
||||
"""Raised when a field under openllm.LLMConfig is missing annotations."""
|
||||
|
||||
|
||||
class MissingDependencyError(BaseException):
|
||||
"""Raised when a dependency is missing."""
|
||||
|
||||
|
||||
class Error(BaseException):
|
||||
"""To be used instead of naked raise."""
|
||||
|
||||
|
||||
class FineTuneStrategyNotSupportedError(OpenLLMException):
|
||||
"""Raised when a fine-tune strategy is not supported for given LLM."""
|
||||
|
||||
@@ -5,9 +5,11 @@ import attr
|
||||
|
||||
from .utils import default_formatter
|
||||
|
||||
|
||||
# equivocal setattr to save one lookup per assignment
|
||||
_object_setattr = object.__setattr__
|
||||
|
||||
|
||||
@attr.define(slots=True)
|
||||
class PromptTemplate:
|
||||
template: str
|
||||
@@ -30,20 +32,31 @@ class PromptTemplate:
|
||||
try:
|
||||
return self.template.format(**prompt_variables)
|
||||
except KeyError as e:
|
||||
raise RuntimeError(f"Missing variable '{e.args[0]}' (required: {self._input_variables}) in the prompt template.") from None
|
||||
raise RuntimeError(
|
||||
f"Missing variable '{e.args[0]}' (required: {self._input_variables}) in the prompt template."
|
||||
) from None
|
||||
|
||||
|
||||
# TODO: remove process_prompt after refactor config for all models
|
||||
def process_prompt(prompt: str, template: PromptTemplate | str | None = None, use_prompt_template: bool = True, **attrs: t.Any) -> str:
|
||||
def process_prompt(
|
||||
prompt: str, template: PromptTemplate | str | None = None, use_prompt_template: bool = True, **attrs: t.Any
|
||||
) -> str:
|
||||
# Currently, all default prompt will always have `instruction` key.
|
||||
if not use_prompt_template: return prompt
|
||||
elif template is None: raise ValueError("'template' can't be None while 'use_prompt_template=False'")
|
||||
if isinstance(template, PromptTemplate): template = template.to_string()
|
||||
if not use_prompt_template:
|
||||
return prompt
|
||||
elif template is None:
|
||||
raise ValueError("'template' can't be None while 'use_prompt_template=False'")
|
||||
if isinstance(template, PromptTemplate):
|
||||
template = template.to_string()
|
||||
template_variables = default_formatter.extract_template_variables(template)
|
||||
prompt_variables = {k: v for k, v in attrs.items() if k in template_variables}
|
||||
if 'instruction' in prompt_variables:
|
||||
raise RuntimeError("'instruction' should be passed as the first argument instead of kwargs when 'use_prompt_template=True'")
|
||||
raise RuntimeError(
|
||||
"'instruction' should be passed as the first argument instead of kwargs when 'use_prompt_template=True'"
|
||||
)
|
||||
try:
|
||||
return template.format(instruction=prompt, **prompt_variables)
|
||||
except KeyError as e:
|
||||
raise RuntimeError(
|
||||
f"Missing variable '{e.args[0]}' (required: {template_variables}) in the prompt template. Use 'use_prompt_template=False' to disable the default prompt template.") from None
|
||||
f"Missing variable '{e.args[0]}' (required: {template_variables}) in the prompt template. Use 'use_prompt_template=False' to disable the default prompt template."
|
||||
) from None
|
||||
|
||||
@@ -2,17 +2,24 @@ from __future__ import annotations
|
||||
import string
|
||||
import typing as t
|
||||
|
||||
|
||||
class PromptFormatter(string.Formatter):
|
||||
"""This PromptFormatter is largely based on langchain's implementation."""
|
||||
|
||||
def vformat(self, format_string: str, args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]) -> t.Any:
|
||||
if len(args) > 0: raise ValueError('Positional arguments are not supported')
|
||||
if len(args) > 0:
|
||||
raise ValueError('Positional arguments are not supported')
|
||||
return super().vformat(format_string, args, kwargs)
|
||||
|
||||
def check_unused_args(self, used_args: set[int | str], args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]) -> None:
|
||||
def check_unused_args(
|
||||
self, used_args: set[int | str], args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]
|
||||
) -> None:
|
||||
extras = set(kwargs).difference(used_args)
|
||||
if extras: raise KeyError(f'Extra params passed: {extras}')
|
||||
if extras:
|
||||
raise KeyError(f'Extra params passed: {extras}')
|
||||
|
||||
def extract_template_variables(self, template: str) -> t.Sequence[str]:
|
||||
return [field[1] for field in self.parse(template) if field[1] is not None]
|
||||
|
||||
|
||||
default_formatter = PromptFormatter()
|
||||
|
||||
@@ -21,6 +21,7 @@ from .lazy import LazyModule as LazyModule
|
||||
from .lazy import VersionInfo as VersionInfo
|
||||
from .._typing_compat import overload
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from bentoml._internal.models.model import ModelContext
|
||||
from bentoml._internal.types import PathType
|
||||
@@ -38,13 +39,15 @@ try:
|
||||
except ImportError:
|
||||
# python < 3.9 does not have GenericAlias (list[int], tuple[str, ...] and so on)
|
||||
_TypingGenericAlias = () # type: ignore
|
||||
if sys.version_info < (3, 10): _WithArgsTypes = (_TypingGenericAlias,)
|
||||
if sys.version_info < (3, 10):
|
||||
_WithArgsTypes = (_TypingGenericAlias,)
|
||||
else:
|
||||
# _GenericAlias is the actual GenericAlias implementation
|
||||
_WithArgsTypes: t.Any = (t._GenericAlias, types.GenericAlias, types.UnionType) # type: ignore
|
||||
|
||||
DEV_DEBUG_VAR = 'OPENLLMDEVDEBUG'
|
||||
|
||||
|
||||
def resolve_user_filepath(filepath: str, ctx: str | None) -> str:
|
||||
# Return if filepath exist after expanduser
|
||||
|
||||
@@ -60,11 +63,18 @@ def resolve_user_filepath(filepath: str, ctx: str | None) -> str:
|
||||
|
||||
raise FileNotFoundError(f'file {filepath} not found')
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def reserve_free_port(host: str = 'localhost', port: int | None = None, prefix: str | None = None, max_retry: int = 50, enable_so_reuseport: bool = False,) -> t.Iterator[int]:
|
||||
def reserve_free_port(
|
||||
host: str = 'localhost',
|
||||
port: int | None = None,
|
||||
prefix: str | None = None,
|
||||
max_retry: int = 50,
|
||||
enable_so_reuseport: bool = False,
|
||||
) -> t.Iterator[int]:
|
||||
"""
|
||||
detect free port and reserve until exit the context
|
||||
"""
|
||||
detect free port and reserve until exit the context
|
||||
"""
|
||||
import psutil
|
||||
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
@@ -78,8 +88,8 @@ def reserve_free_port(host: str = 'localhost', port: int | None = None, prefix:
|
||||
if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) == 0:
|
||||
raise RuntimeError('Failed to set SO_REUSEPORT.') from None
|
||||
if prefix is not None:
|
||||
prefix_num = int(prefix) * 10**(5 - len(prefix))
|
||||
suffix_range = min(65535 - prefix_num, 10**(5 - len(prefix)))
|
||||
prefix_num = int(prefix) * 10 ** (5 - len(prefix))
|
||||
suffix_range = min(65535 - prefix_num, 10 ** (5 - len(prefix)))
|
||||
for _ in range(max_retry):
|
||||
suffix = random.randint(0, suffix_range)
|
||||
port = int(f'{prefix_num + suffix}')
|
||||
@@ -99,22 +109,28 @@ def reserve_free_port(host: str = 'localhost', port: int | None = None, prefix:
|
||||
finally:
|
||||
sock.close()
|
||||
|
||||
|
||||
def calc_dir_size(path: PathType) -> int:
|
||||
return sum(f.stat().st_size for f in Path(path).glob('**/*') if f.is_file())
|
||||
|
||||
|
||||
def set_debug_mode(enabled: bool, level: int = 1) -> None:
|
||||
# monkeypatch bentoml._internal.configuration.set_debug_mode to remove unused logs
|
||||
if enabled: os.environ[DEV_DEBUG_VAR] = str(level)
|
||||
if enabled:
|
||||
os.environ[DEV_DEBUG_VAR] = str(level)
|
||||
os.environ[DEBUG_ENV_VAR] = str(enabled)
|
||||
os.environ[_GRPC_DEBUG_ENV_VAR] = 'DEBUG' if enabled else 'ERROR'
|
||||
|
||||
|
||||
def lenient_issubclass(cls: t.Any, class_or_tuple: type[t.Any] | tuple[type[t.Any], ...] | None) -> bool:
|
||||
try:
|
||||
return isinstance(cls, type) and issubclass(cls, class_or_tuple) # type: ignore[arg-type]
|
||||
except TypeError:
|
||||
if isinstance(cls, _WithArgsTypes): return False
|
||||
if isinstance(cls, _WithArgsTypes):
|
||||
return False
|
||||
raise
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=128)
|
||||
def generate_hash_from_file(f: str, algorithm: t.Literal['md5', 'sha1'] = 'sha1') -> str:
|
||||
"""Generate a hash from given file's modification time.
|
||||
@@ -128,44 +144,61 @@ def generate_hash_from_file(f: str, algorithm: t.Literal['md5', 'sha1'] = 'sha1'
|
||||
"""
|
||||
return getattr(hashlib, algorithm)(str(os.path.getmtime(resolve_filepath(f))).encode()).hexdigest()
|
||||
|
||||
|
||||
def check_bool_env(env: str, default: bool = True) -> bool:
|
||||
v = os.environ.get(env, str(default)).upper()
|
||||
if v.isdigit(): return bool(int(v)) # special check for digits
|
||||
if v.isdigit():
|
||||
return bool(int(v)) # special check for digits
|
||||
return v in ENV_VARS_TRUE_VALUES
|
||||
|
||||
|
||||
# equivocal setattr to save one lookup per assignment
|
||||
_object_setattr = object.__setattr__
|
||||
|
||||
|
||||
def field_env_key(key: str, suffix: str | None = None) -> str:
|
||||
return '_'.join(filter(None, map(str.upper, ['OPENLLM', suffix.strip('_') if suffix else '', key])))
|
||||
|
||||
|
||||
# Special debug flag controled via OPENLLMDEVDEBUG
|
||||
DEBUG: bool = sys.flags.dev_mode or (not sys.flags.ignore_environment and check_bool_env(DEV_DEBUG_VAR, default=False))
|
||||
# Whether to show the codenge for debug purposes
|
||||
SHOW_CODEGEN: bool = DEBUG and (os.environ.get(DEV_DEBUG_VAR, str(0)).isdigit() and int(os.environ.get(DEV_DEBUG_VAR, str(0))) > 3)
|
||||
SHOW_CODEGEN: bool = DEBUG and (
|
||||
os.environ.get(DEV_DEBUG_VAR, str(0)).isdigit() and int(os.environ.get(DEV_DEBUG_VAR, str(0))) > 3
|
||||
)
|
||||
# MYPY is like t.TYPE_CHECKING, but reserved for Mypy plugins
|
||||
MYPY = False
|
||||
|
||||
|
||||
def get_debug_mode() -> bool:
|
||||
if not DEBUG and DEBUG_ENV_VAR in os.environ: return check_bool_env(DEBUG_ENV_VAR, False)
|
||||
if not DEBUG and DEBUG_ENV_VAR in os.environ:
|
||||
return check_bool_env(DEBUG_ENV_VAR, False)
|
||||
return DEBUG
|
||||
|
||||
|
||||
def get_quiet_mode() -> bool:
|
||||
if QUIET_ENV_VAR in os.environ: return check_bool_env(QUIET_ENV_VAR, False)
|
||||
if DEBUG: return False
|
||||
if QUIET_ENV_VAR in os.environ:
|
||||
return check_bool_env(QUIET_ENV_VAR, False)
|
||||
if DEBUG:
|
||||
return False
|
||||
return False
|
||||
|
||||
|
||||
def set_quiet_mode(enabled: bool) -> None:
|
||||
# do not log setting quiet mode
|
||||
os.environ[QUIET_ENV_VAR] = str(enabled)
|
||||
os.environ[_GRPC_DEBUG_ENV_VAR] = 'NONE'
|
||||
|
||||
|
||||
class ExceptionFilter(logging.Filter):
|
||||
def __init__(self, exclude_exceptions: list[type[Exception]] | None = None, **kwargs: t.Any):
|
||||
if exclude_exceptions is None: exclude_exceptions = []
|
||||
if exclude_exceptions is None:
|
||||
exclude_exceptions = []
|
||||
try:
|
||||
from circus.exc import ConflictError
|
||||
if ConflictError not in exclude_exceptions: exclude_exceptions.append(ConflictError)
|
||||
|
||||
if ConflictError not in exclude_exceptions:
|
||||
exclude_exceptions.append(ConflictError)
|
||||
except ImportError:
|
||||
pass
|
||||
super(ExceptionFilter, self).__init__(**kwargs)
|
||||
@@ -176,55 +209,43 @@ class ExceptionFilter(logging.Filter):
|
||||
etype, _, _ = record.exc_info
|
||||
if etype is not None:
|
||||
for exc in self.EXCLUDE_EXCEPTIONS:
|
||||
if issubclass(etype, exc): return False
|
||||
if issubclass(etype, exc):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class InfoFilter(logging.Filter):
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
return logging.INFO <= record.levelno < logging.WARNING
|
||||
|
||||
|
||||
def gen_random_uuid(prefix: str | None = None) -> str:
|
||||
return '-'.join([prefix or 'openllm', str(uuid.uuid4().hex)])
|
||||
|
||||
|
||||
_LOGGING_CONFIG: dict[str, t.Any] = {
|
||||
'version': 1,
|
||||
'disable_existing_loggers': True,
|
||||
'filters': {
|
||||
'excfilter': {
|
||||
'()': 'openllm_core.utils.ExceptionFilter'
|
||||
},
|
||||
'infofilter': {
|
||||
'()': 'openllm_core.utils.InfoFilter'
|
||||
}
|
||||
},
|
||||
'handlers': {
|
||||
'bentomlhandler': {
|
||||
'class': 'logging.StreamHandler',
|
||||
'filters': ['excfilter', 'infofilter'],
|
||||
'stream': 'ext://sys.stdout'
|
||||
},
|
||||
'defaulthandler': {
|
||||
'class': 'logging.StreamHandler',
|
||||
'level': logging.WARNING
|
||||
}
|
||||
},
|
||||
'loggers': {
|
||||
'bentoml': {
|
||||
'handlers': ['bentomlhandler', 'defaulthandler'],
|
||||
'level': logging.INFO,
|
||||
'propagate': False
|
||||
},
|
||||
'openllm': {
|
||||
'handlers': ['bentomlhandler', 'defaulthandler'],
|
||||
'level': logging.INFO,
|
||||
'propagate': False
|
||||
}
|
||||
},
|
||||
'root': {
|
||||
'level': logging.WARNING
|
||||
'version': 1,
|
||||
'disable_existing_loggers': True,
|
||||
'filters': {
|
||||
'excfilter': {'()': 'openllm_core.utils.ExceptionFilter'},
|
||||
'infofilter': {'()': 'openllm_core.utils.InfoFilter'},
|
||||
},
|
||||
'handlers': {
|
||||
'bentomlhandler': {
|
||||
'class': 'logging.StreamHandler',
|
||||
'filters': ['excfilter', 'infofilter'],
|
||||
'stream': 'ext://sys.stdout',
|
||||
},
|
||||
'defaulthandler': {'class': 'logging.StreamHandler', 'level': logging.WARNING},
|
||||
},
|
||||
'loggers': {
|
||||
'bentoml': {'handlers': ['bentomlhandler', 'defaulthandler'], 'level': logging.INFO, 'propagate': False},
|
||||
'openllm': {'handlers': ['bentomlhandler', 'defaulthandler'], 'level': logging.INFO, 'propagate': False},
|
||||
},
|
||||
'root': {'level': logging.WARNING},
|
||||
}
|
||||
|
||||
|
||||
def configure_logging() -> None:
|
||||
"""Configure logging for OpenLLM.
|
||||
|
||||
@@ -246,16 +267,21 @@ def configure_logging() -> None:
|
||||
|
||||
logging.config.dictConfig(_LOGGING_CONFIG)
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def in_notebook() -> bool:
|
||||
try:
|
||||
from IPython.core.getipython import get_ipython
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from IPython.core.interactiveshell import InteractiveShell
|
||||
return 'IPKernelApp' in t.cast('dict[str, t.Any]', t.cast(t.Callable[[], 'InteractiveShell'], get_ipython)().config)
|
||||
return 'IPKernelApp' in t.cast(
|
||||
'dict[str, t.Any]', t.cast(t.Callable[[], 'InteractiveShell'], get_ipython)().config
|
||||
)
|
||||
except (ImportError, AttributeError):
|
||||
return False
|
||||
|
||||
|
||||
class suppress(contextlib.suppress, contextlib.ContextDecorator):
|
||||
"""A version of contextlib.suppress with decorator support.
|
||||
|
||||
@@ -265,6 +291,7 @@ class suppress(contextlib.suppress, contextlib.ContextDecorator):
|
||||
>>> key_error()
|
||||
"""
|
||||
|
||||
|
||||
def compose(*funcs: AnyCallable) -> AnyCallable:
|
||||
"""Compose any number of unary functions into a single unary function.
|
||||
|
||||
@@ -281,11 +308,13 @@ def compose(*funcs: AnyCallable) -> AnyCallable:
|
||||
>>> [f(3*x, x+1) for x in range(1,10)]
|
||||
[1.5, 2.0, 2.25, 2.4, 2.5, 2.571, 2.625, 2.667, 2.7]
|
||||
"""
|
||||
|
||||
def compose_two(f1: AnyCallable, f2: AnyCallable) -> AnyCallable:
|
||||
return lambda *args, **kwargs: f1(f2(*args, **kwargs))
|
||||
|
||||
return functools.reduce(compose_two, funcs)
|
||||
|
||||
|
||||
def apply(transform: AnyCallable) -> t.Callable[[AnyCallable], AnyCallable]:
|
||||
"""Decorate a function with a transform function that is invoked on results returned from the decorated function.
|
||||
|
||||
@@ -304,9 +333,11 @@ def apply(transform: AnyCallable) -> t.Callable[[AnyCallable], AnyCallable]:
|
||||
"""
|
||||
return lambda func: functools.wraps(func)(compose(transform, func))
|
||||
|
||||
|
||||
T = t.TypeVar('T')
|
||||
K = t.TypeVar('K')
|
||||
|
||||
|
||||
# yapf: disable
|
||||
@overload
|
||||
def first_not_none(*args: T | None, default: T) -> T: ...
|
||||
@@ -315,6 +346,7 @@ def first_not_none(*args: T | None) -> T | None: ...
|
||||
def first_not_none(*args: T | None, default: None | T = None) -> T | None: return next((arg for arg in args if arg is not None), default)
|
||||
# yapf: enable
|
||||
|
||||
|
||||
def resolve_filepath(path: str, ctx: str | None = None) -> str:
|
||||
"""Resolve a file path to an absolute path, expand user and environment variables."""
|
||||
try:
|
||||
@@ -322,26 +354,34 @@ def resolve_filepath(path: str, ctx: str | None = None) -> str:
|
||||
except FileNotFoundError:
|
||||
return path
|
||||
|
||||
|
||||
def validate_is_path(maybe_path: str) -> bool:
|
||||
return os.path.exists(os.path.dirname(resolve_filepath(maybe_path)))
|
||||
|
||||
|
||||
def generate_context(framework_name: str) -> ModelContext:
|
||||
import openllm_core
|
||||
|
||||
from bentoml._internal.models.model import ModelContext
|
||||
|
||||
framework_versions = {'transformers': pkg.get_pkg_version('transformers')}
|
||||
if openllm_core.utils.is_torch_available(): framework_versions['torch'] = pkg.get_pkg_version('torch')
|
||||
if openllm_core.utils.is_torch_available():
|
||||
framework_versions['torch'] = pkg.get_pkg_version('torch')
|
||||
return ModelContext(framework_name=framework_name, framework_versions=framework_versions)
|
||||
|
||||
|
||||
_TOKENIZER_PREFIX = '_tokenizer_'
|
||||
|
||||
|
||||
def flatten_attrs(**attrs: t.Any) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
|
||||
"""Normalize the given attrs to a model and tokenizer kwargs accordingly."""
|
||||
tokenizer_attrs = {k[len(_TOKENIZER_PREFIX):]: v for k, v in attrs.items() if k.startswith(_TOKENIZER_PREFIX)}
|
||||
tokenizer_attrs = {k[len(_TOKENIZER_PREFIX) :]: v for k, v in attrs.items() if k.startswith(_TOKENIZER_PREFIX)}
|
||||
for k in tuple(attrs.keys()):
|
||||
if k.startswith(_TOKENIZER_PREFIX): del attrs[k]
|
||||
if k.startswith(_TOKENIZER_PREFIX):
|
||||
del attrs[k]
|
||||
return attrs, tokenizer_attrs
|
||||
|
||||
|
||||
# NOTE: The set marks contains a set of modules name
|
||||
# that are available above and are whitelisted
|
||||
# to be included in the extra_objects map.
|
||||
@@ -349,20 +389,36 @@ _whitelist_modules = {'pkg'}
|
||||
|
||||
# XXX: define all classes, functions import above this line
|
||||
# since _extras will be the locals() import from this file.
|
||||
_extras: dict[str, t.Any] = {k: v for k, v in locals().items() if k in _whitelist_modules or (not isinstance(v, types.ModuleType) and not k.startswith('_'))}
|
||||
_extras: dict[str, t.Any] = {
|
||||
k: v
|
||||
for k, v in locals().items()
|
||||
if k in _whitelist_modules or (not isinstance(v, types.ModuleType) and not k.startswith('_'))
|
||||
}
|
||||
_extras['__openllm_migration__'] = {'bentoml_cattr': 'converter'}
|
||||
_import_structure: dict[str, list[str]] = {
|
||||
'analytics': [],
|
||||
'codegen': [],
|
||||
'dantic': [],
|
||||
'lazy': [],
|
||||
'pkg': [],
|
||||
'representation': ['ReprMixin'],
|
||||
'serde': ['converter'],
|
||||
'import_utils': [
|
||||
'OPTIONAL_DEPENDENCIES', 'is_vllm_available', 'is_torch_available', 'is_bitsandbytes_available', 'is_peft_available', 'is_jupyter_available', 'is_jupytext_available',
|
||||
'is_notebook_available', 'is_autogptq_available', 'is_grpc_available', 'is_transformers_available', 'is_optimum_supports_gptq', 'is_autoawq_available', 'is_bentoml_available'
|
||||
]
|
||||
'analytics': [],
|
||||
'codegen': [],
|
||||
'dantic': [],
|
||||
'lazy': [],
|
||||
'pkg': [],
|
||||
'representation': ['ReprMixin'],
|
||||
'serde': ['converter'],
|
||||
'import_utils': [
|
||||
'OPTIONAL_DEPENDENCIES',
|
||||
'is_vllm_available',
|
||||
'is_torch_available',
|
||||
'is_bitsandbytes_available',
|
||||
'is_peft_available',
|
||||
'is_jupyter_available',
|
||||
'is_jupytext_available',
|
||||
'is_notebook_available',
|
||||
'is_autogptq_available',
|
||||
'is_grpc_available',
|
||||
'is_transformers_available',
|
||||
'is_optimum_supports_gptq',
|
||||
'is_autoawq_available',
|
||||
'is_bentoml_available',
|
||||
],
|
||||
}
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
|
||||
@@ -13,6 +13,7 @@ import openllm_core
|
||||
|
||||
from openllm_core._typing_compat import ParamSpec
|
||||
|
||||
|
||||
P = ParamSpec('P')
|
||||
T = t.TypeVar('T')
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -20,14 +21,17 @@ logger = logging.getLogger(__name__)
|
||||
# This variable is a proxy that will control BENTOML_DO_NOT_TRACK
|
||||
OPENLLM_DO_NOT_TRACK = 'OPENLLM_DO_NOT_TRACK'
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def do_not_track() -> bool:
|
||||
return openllm_core.utils.check_bool_env(OPENLLM_DO_NOT_TRACK)
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def _usage_event_debugging() -> bool:
|
||||
return os.environ.get('__BENTOML_DEBUG_USAGE', str(False)).lower() == 'true'
|
||||
|
||||
|
||||
def silent(func: t.Callable[P, T]) -> t.Callable[P, T]:
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> t.Any:
|
||||
@@ -35,21 +39,29 @@ def silent(func: t.Callable[P, T]) -> t.Callable[P, T]:
|
||||
return func(*args, **kwargs)
|
||||
except Exception as err:
|
||||
if _usage_event_debugging():
|
||||
if openllm_core.utils.get_debug_mode(): logger.error('Tracking Error: %s', err, stack_info=True, stacklevel=3)
|
||||
else: logger.info('Tracking Error: %s', err)
|
||||
else: logger.debug('Tracking Error: %s', err)
|
||||
if openllm_core.utils.get_debug_mode():
|
||||
logger.error('Tracking Error: %s', err, stack_info=True, stacklevel=3)
|
||||
else:
|
||||
logger.info('Tracking Error: %s', err)
|
||||
else:
|
||||
logger.debug('Tracking Error: %s', err)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@silent
|
||||
def track(event_properties: attr.AttrsInstance) -> None:
|
||||
from bentoml._internal.utils import analytics
|
||||
if do_not_track(): return
|
||||
|
||||
if do_not_track():
|
||||
return
|
||||
analytics.track(t.cast('analytics.schemas.EventMeta', event_properties))
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def set_bentoml_tracking() -> t.Generator[None, None, None]:
|
||||
from bentoml._internal.utils import analytics
|
||||
|
||||
original_value = os.environ.pop(analytics.BENTOML_DO_NOT_TRACK, str(False))
|
||||
try:
|
||||
os.environ[analytics.BENTOML_DO_NOT_TRACK] = str(do_not_track())
|
||||
@@ -57,6 +69,7 @@ def set_bentoml_tracking() -> t.Generator[None, None, None]:
|
||||
finally:
|
||||
os.environ[analytics.BENTOML_DO_NOT_TRACK] = original_value
|
||||
|
||||
|
||||
class EventMeta:
|
||||
@property
|
||||
def event_name(self) -> str:
|
||||
@@ -64,14 +77,17 @@ class EventMeta:
|
||||
event_name = re.sub(r'(?<!^)(?=[A-Z])', '_', self.__class__.__name__).lower()
|
||||
# remove "_event" suffix
|
||||
suffix_to_remove = '_event'
|
||||
if event_name.endswith(suffix_to_remove): event_name = event_name[:-len(suffix_to_remove)]
|
||||
if event_name.endswith(suffix_to_remove):
|
||||
event_name = event_name[: -len(suffix_to_remove)]
|
||||
return event_name
|
||||
|
||||
|
||||
@attr.define
|
||||
class ModelSaveEvent(EventMeta):
|
||||
module: str
|
||||
model_size_in_kb: float
|
||||
|
||||
|
||||
@attr.define
|
||||
class OpenllmCliEvent(EventMeta):
|
||||
cmd_group: str
|
||||
@@ -82,6 +98,7 @@ class OpenllmCliEvent(EventMeta):
|
||||
error_type: str = attr.field(default=None)
|
||||
return_code: int = attr.field(default=None)
|
||||
|
||||
|
||||
@attr.define
|
||||
class StartInitEvent(EventMeta):
|
||||
model_name: str
|
||||
@@ -91,6 +108,8 @@ class StartInitEvent(EventMeta):
|
||||
def handler(llm_config: openllm_core.LLMConfig) -> StartInitEvent:
|
||||
return StartInitEvent(model_name=llm_config['model_name'], llm_config=llm_config.model_dump())
|
||||
|
||||
|
||||
def track_start_init(llm_config: openllm_core.LLMConfig) -> None:
|
||||
if do_not_track(): return
|
||||
if do_not_track():
|
||||
return
|
||||
track(StartInitEvent.handler(llm_config))
|
||||
|
||||
@@ -10,6 +10,7 @@ from operator import itemgetter
|
||||
|
||||
import orjson
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import openllm_core
|
||||
|
||||
@@ -17,6 +18,7 @@ if t.TYPE_CHECKING:
|
||||
from openllm_core._typing_compat import DictStrAny
|
||||
from openllm_core._typing_compat import ListStr
|
||||
from openllm_core._typing_compat import LiteralString
|
||||
|
||||
PartialAny = functools.partial[t.Any]
|
||||
|
||||
_T = t.TypeVar('_T', bound=t.Callable[..., t.Any])
|
||||
@@ -25,24 +27,32 @@ logger = logging.getLogger(__name__)
|
||||
# sentinel object for unequivocal object() getattr
|
||||
_sentinel = object()
|
||||
|
||||
|
||||
def has_own_attribute(cls: type[t.Any], attrib_name: t.Any) -> bool:
|
||||
"""Check whether *cls* defines *attrib_name* (and doesn't just inherit it)."""
|
||||
attr = getattr(cls, attrib_name, _sentinel)
|
||||
if attr is _sentinel: return False
|
||||
if attr is _sentinel:
|
||||
return False
|
||||
for base_cls in cls.__mro__[1:]:
|
||||
a = getattr(base_cls, attrib_name, None)
|
||||
if attr is a: return False
|
||||
if attr is a:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_annotations(cls: type[t.Any]) -> DictStrAny:
|
||||
if has_own_attribute(cls, '__annotations__'): return cls.__annotations__
|
||||
if has_own_attribute(cls, '__annotations__'):
|
||||
return cls.__annotations__
|
||||
return t.cast('DictStrAny', {})
|
||||
|
||||
|
||||
def is_class_var(annot: str | t.Any) -> bool:
|
||||
annot = str(annot)
|
||||
# Annotation can be quoted.
|
||||
if annot.startswith(("'", '"')) and annot.endswith(("'", '"')): annot = annot[1:-1]
|
||||
return annot.startswith(('typing.ClassVar', 't.ClassVar', 'ClassVar', 'typing_extensions.ClassVar',))
|
||||
if annot.startswith(("'", '"')) and annot.endswith(("'", '"')):
|
||||
annot = annot[1:-1]
|
||||
return annot.startswith(('typing.ClassVar', 't.ClassVar', 'ClassVar', 'typing_extensions.ClassVar'))
|
||||
|
||||
|
||||
def add_method_dunders(cls: type[t.Any], method_or_cls: _T, _overwrite_doc: str | None = None) -> _T:
|
||||
try:
|
||||
@@ -59,9 +69,11 @@ def add_method_dunders(cls: type[t.Any], method_or_cls: _T, _overwrite_doc: str
|
||||
pass
|
||||
return method_or_cls
|
||||
|
||||
|
||||
def _compile_and_eval(script: str, globs: DictStrAny, locs: t.Any = None, filename: str = '') -> None:
|
||||
eval(compile(script, filename, 'exec'), globs, locs)
|
||||
|
||||
|
||||
def _make_method(name: str, script: str, filename: str, globs: DictStrAny) -> AnyCallable:
|
||||
locs: DictStrAny = {}
|
||||
# In order of debuggers like PDB being able to step through the code, we add a fake linecache entry.
|
||||
@@ -70,13 +82,15 @@ def _make_method(name: str, script: str, filename: str, globs: DictStrAny) -> An
|
||||
while True:
|
||||
linecache_tuple = (len(script), None, script.splitlines(True), filename)
|
||||
old_val = linecache.cache.setdefault(filename, linecache_tuple)
|
||||
if old_val == linecache_tuple: break
|
||||
if old_val == linecache_tuple:
|
||||
break
|
||||
else:
|
||||
filename = f'{base_filename[:-1]}-{count}>'
|
||||
count += 1
|
||||
_compile_and_eval(script, globs, locs, filename)
|
||||
return locs[name]
|
||||
|
||||
|
||||
def make_attr_tuple_class(cls_name: str, attr_names: t.Sequence[str]) -> type[t.Any]:
|
||||
"""Create a tuple subclass to hold class attributes.
|
||||
|
||||
@@ -89,39 +103,53 @@ def make_attr_tuple_class(cls_name: str, attr_names: t.Sequence[str]) -> type[t.
|
||||
from . import SHOW_CODEGEN
|
||||
|
||||
attr_class_name = f'{cls_name}Attributes'
|
||||
attr_class_template = [f'class {attr_class_name}(tuple):', ' __slots__ = ()',]
|
||||
attr_class_template = [f'class {attr_class_name}(tuple):', ' __slots__ = ()']
|
||||
if attr_names:
|
||||
for i, attr_name in enumerate(attr_names):
|
||||
attr_class_template.append(f' {attr_name} = _attrs_property(_attrs_itemgetter({i}))')
|
||||
else:
|
||||
attr_class_template.append(' pass')
|
||||
globs: DictStrAny = {'_attrs_itemgetter': itemgetter, '_attrs_property': property}
|
||||
if SHOW_CODEGEN: print(f'Generated class for {attr_class_name}:\n\n', '\n'.join(attr_class_template))
|
||||
if SHOW_CODEGEN:
|
||||
print(f'Generated class for {attr_class_name}:\n\n', '\n'.join(attr_class_template))
|
||||
_compile_and_eval('\n'.join(attr_class_template), globs)
|
||||
return globs[attr_class_name]
|
||||
|
||||
|
||||
def generate_unique_filename(cls: type[t.Any], func_name: str) -> str:
|
||||
return f"<{cls.__name__} generated {func_name} {cls.__module__}.{getattr(cls, '__qualname__', cls.__name__)}>"
|
||||
|
||||
def generate_function(typ: type[t.Any],
|
||||
func_name: str,
|
||||
lines: list[str] | None,
|
||||
args: tuple[str, ...] | None,
|
||||
globs: dict[str, t.Any],
|
||||
annotations: dict[str, t.Any] | None = None) -> AnyCallable:
|
||||
|
||||
def generate_function(
|
||||
typ: type[t.Any],
|
||||
func_name: str,
|
||||
lines: list[str] | None,
|
||||
args: tuple[str, ...] | None,
|
||||
globs: dict[str, t.Any],
|
||||
annotations: dict[str, t.Any] | None = None,
|
||||
) -> AnyCallable:
|
||||
from openllm_core.utils import SHOW_CODEGEN
|
||||
script = 'def %s(%s):\n %s\n' % (func_name, ', '.join(args) if args is not None else '', '\n '.join(lines) if lines else 'pass')
|
||||
|
||||
script = 'def %s(%s):\n %s\n' % (
|
||||
func_name,
|
||||
', '.join(args) if args is not None else '',
|
||||
'\n '.join(lines) if lines else 'pass',
|
||||
)
|
||||
meth = _make_method(func_name, script, generate_unique_filename(typ, func_name), globs)
|
||||
if annotations: meth.__annotations__ = annotations
|
||||
if SHOW_CODEGEN: print(f'Generated script for {typ}:\n\n', script)
|
||||
if annotations:
|
||||
meth.__annotations__ = annotations
|
||||
if SHOW_CODEGEN:
|
||||
print(f'Generated script for {typ}:\n\n', script)
|
||||
return meth
|
||||
|
||||
def make_env_transformer(cls: type[openllm_core.LLMConfig],
|
||||
model_name: str,
|
||||
suffix: LiteralString | None = None,
|
||||
default_callback: t.Callable[[str, t.Any], t.Any] | None = None,
|
||||
globs: DictStrAny | None = None,
|
||||
) -> AnyCallable:
|
||||
|
||||
def make_env_transformer(
|
||||
cls: type[openllm_core.LLMConfig],
|
||||
model_name: str,
|
||||
suffix: LiteralString | None = None,
|
||||
default_callback: t.Callable[[str, t.Any], t.Any] | None = None,
|
||||
globs: DictStrAny | None = None,
|
||||
) -> AnyCallable:
|
||||
from openllm_core.utils import dantic
|
||||
from openllm_core.utils import field_env_key
|
||||
|
||||
@@ -130,17 +158,35 @@ def make_env_transformer(cls: type[openllm_core.LLMConfig],
|
||||
|
||||
default_callback = identity if default_callback is None else default_callback
|
||||
globs = {} if globs is None else globs
|
||||
globs.update({'__populate_env': dantic.env_converter, '__default_callback': default_callback, '__field_env': field_env_key, '__suffix': suffix or '', '__model_name': model_name,})
|
||||
globs.update(
|
||||
{
|
||||
'__populate_env': dantic.env_converter,
|
||||
'__default_callback': default_callback,
|
||||
'__field_env': field_env_key,
|
||||
'__suffix': suffix or '',
|
||||
'__model_name': model_name,
|
||||
}
|
||||
)
|
||||
lines: ListStr = [
|
||||
'__env=lambda field_name:__field_env(field_name,__suffix)',
|
||||
"return [f.evolve(default=__populate_env(__default_callback(f.name,f.default),__env(f.name)),metadata={'env':f.metadata.get('env',__env(f.name)),'description':f.metadata.get('description', '(not provided)')}) for f in fields]"
|
||||
'__env=lambda field_name:__field_env(field_name,__suffix)',
|
||||
"return [f.evolve(default=__populate_env(__default_callback(f.name,f.default),__env(f.name)),metadata={'env':f.metadata.get('env',__env(f.name)),'description':f.metadata.get('description', '(not provided)')}) for f in fields]",
|
||||
]
|
||||
fields_ann = 'list[attr.Attribute[t.Any]]'
|
||||
return generate_function(cls, '__auto_env', lines, args=('_', 'fields'), globs=globs, annotations={'_': 'type[LLMConfig]', 'fields': fields_ann, 'return': fields_ann})
|
||||
return generate_function(
|
||||
cls,
|
||||
'__auto_env',
|
||||
lines,
|
||||
args=('_', 'fields'),
|
||||
globs=globs,
|
||||
annotations={'_': 'type[LLMConfig]', 'fields': fields_ann, 'return': fields_ann},
|
||||
)
|
||||
|
||||
|
||||
def gen_sdk(func: _T, name: str | None = None, **attrs: t.Any) -> _T:
|
||||
from .representation import ReprMixin
|
||||
if name is None: name = func.__name__.strip('_')
|
||||
|
||||
if name is None:
|
||||
name = func.__name__.strip('_')
|
||||
_signatures = inspect.signature(func).parameters
|
||||
|
||||
def _repr(self: ReprMixin) -> str:
|
||||
@@ -149,20 +195,29 @@ def gen_sdk(func: _T, name: str | None = None, **attrs: t.Any) -> _T:
|
||||
def _repr_args(self: ReprMixin) -> t.Iterator[t.Tuple[str, t.Any]]:
|
||||
return ((k, _signatures[k].annotation) for k in self.__repr_keys__)
|
||||
|
||||
if func.__doc__ is None: doc = f'Generated SDK for {func.__name__}'
|
||||
else: doc = func.__doc__
|
||||
if func.__doc__ is None:
|
||||
doc = f'Generated SDK for {func.__name__}'
|
||||
else:
|
||||
doc = func.__doc__
|
||||
return t.cast(
|
||||
_T,
|
||||
functools.update_wrapper(
|
||||
types.new_class(name, (t.cast('PartialAny', functools.partial), ReprMixin),
|
||||
exec_body=lambda ns: ns.update({
|
||||
'__repr_keys__': property(lambda _: [i for i in _signatures.keys() if not i.startswith('_')]),
|
||||
'__repr_args__': _repr_args,
|
||||
'__repr__': _repr,
|
||||
'__doc__': inspect.cleandoc(doc),
|
||||
'__module__': 'openllm'
|
||||
}),
|
||||
)(func, **attrs), func,
|
||||
))
|
||||
_T,
|
||||
functools.update_wrapper(
|
||||
types.new_class(
|
||||
name,
|
||||
(t.cast('PartialAny', functools.partial), ReprMixin),
|
||||
exec_body=lambda ns: ns.update(
|
||||
{
|
||||
'__repr_keys__': property(lambda _: [i for i in _signatures.keys() if not i.startswith('_')]),
|
||||
'__repr_args__': _repr_args,
|
||||
'__repr__': _repr,
|
||||
'__doc__': inspect.cleandoc(doc),
|
||||
'__module__': 'openllm',
|
||||
}
|
||||
),
|
||||
)(func, **attrs),
|
||||
func,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
__all__ = ['gen_sdk', 'make_attr_tuple_class', 'make_env_transformer', 'generate_unique_filename', 'generate_function']
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""An interface provides the best of pydantic and attrs."""
|
||||
|
||||
from __future__ import annotations
|
||||
import functools
|
||||
import importlib
|
||||
@@ -18,6 +19,7 @@ from click import ParamType
|
||||
from click import shell_completion as sc
|
||||
from click import types as click_types
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from attr import _ValidatorType
|
||||
|
||||
@@ -25,14 +27,38 @@ AnyCallable = t.Callable[..., t.Any]
|
||||
FC = t.TypeVar('FC', bound=t.Union[AnyCallable, click.Command])
|
||||
|
||||
__all__ = [
|
||||
'FC', 'attrs_to_options', 'Field', 'parse_type', 'is_typing', 'is_literal', 'ModuleType', 'EnumChoice', 'LiteralChoice', 'allows_multiple', 'is_mapping', 'is_container',
|
||||
'parse_container_args', 'parse_single_arg', 'CUDA', 'JsonType', 'BytesType'
|
||||
'FC',
|
||||
'attrs_to_options',
|
||||
'Field',
|
||||
'parse_type',
|
||||
'is_typing',
|
||||
'is_literal',
|
||||
'ModuleType',
|
||||
'EnumChoice',
|
||||
'LiteralChoice',
|
||||
'allows_multiple',
|
||||
'is_mapping',
|
||||
'is_container',
|
||||
'parse_container_args',
|
||||
'parse_single_arg',
|
||||
'CUDA',
|
||||
'JsonType',
|
||||
'BytesType',
|
||||
]
|
||||
|
||||
|
||||
def __dir__() -> list[str]:
|
||||
return sorted(__all__)
|
||||
|
||||
def attrs_to_options(name: str, field: attr.Attribute[t.Any], model_name: str, typ: t.Any = None, suffix_generation: bool = False, suffix_sampling: bool = False) -> t.Callable[[FC], FC]:
|
||||
|
||||
def attrs_to_options(
|
||||
name: str,
|
||||
field: attr.Attribute[t.Any],
|
||||
model_name: str,
|
||||
typ: t.Any = None,
|
||||
suffix_generation: bool = False,
|
||||
suffix_sampling: bool = False,
|
||||
) -> t.Callable[[FC], FC]:
|
||||
# TODO: support parsing nested attrs class and Union
|
||||
envvar = field.metadata['env']
|
||||
dasherized = inflection.dasherize(name)
|
||||
@@ -40,25 +66,32 @@ def attrs_to_options(name: str, field: attr.Attribute[t.Any], model_name: str, t
|
||||
|
||||
if typ in (None, attr.NOTHING):
|
||||
typ = field.type
|
||||
if typ is None: raise RuntimeError(f'Failed to parse type for {name}')
|
||||
if typ is None:
|
||||
raise RuntimeError(f'Failed to parse type for {name}')
|
||||
|
||||
full_option_name = f'--{dasherized}'
|
||||
if field.type is bool: full_option_name += f'/--no-{dasherized}'
|
||||
if suffix_generation: identifier = f'{model_name}_generation_{underscored}'
|
||||
elif suffix_sampling: identifier = f'{model_name}_sampling_{underscored}'
|
||||
else: identifier = f'{model_name}_{underscored}'
|
||||
if field.type is bool:
|
||||
full_option_name += f'/--no-{dasherized}'
|
||||
if suffix_generation:
|
||||
identifier = f'{model_name}_generation_{underscored}'
|
||||
elif suffix_sampling:
|
||||
identifier = f'{model_name}_sampling_{underscored}'
|
||||
else:
|
||||
identifier = f'{model_name}_{underscored}'
|
||||
|
||||
return cog.optgroup.option(
|
||||
identifier,
|
||||
full_option_name,
|
||||
type=parse_type(typ),
|
||||
required=field.default is attr.NOTHING,
|
||||
default=field.default if field.default not in (attr.NOTHING, None) else None,
|
||||
show_default=True,
|
||||
multiple=allows_multiple(typ) if typ else False,
|
||||
help=field.metadata.get('description', '(No description provided)'),
|
||||
show_envvar=True,
|
||||
envvar=envvar,
|
||||
)
|
||||
|
||||
return cog.optgroup.option(identifier,
|
||||
full_option_name,
|
||||
type=parse_type(typ),
|
||||
required=field.default is attr.NOTHING,
|
||||
default=field.default if field.default not in (attr.NOTHING, None) else None,
|
||||
show_default=True,
|
||||
multiple=allows_multiple(typ) if typ else False,
|
||||
help=field.metadata.get('description', '(No description provided)'),
|
||||
show_envvar=True,
|
||||
envvar=envvar,
|
||||
)
|
||||
|
||||
def env_converter(value: t.Any, env: str | None = None) -> t.Any:
|
||||
if env is not None:
|
||||
@@ -70,16 +103,19 @@ def env_converter(value: t.Any, env: str | None = None) -> t.Any:
|
||||
raise RuntimeError(f"Failed to parse ({value!r}) from '{env}': {err}") from None
|
||||
return value
|
||||
|
||||
def Field(default: t.Any = None,
|
||||
*,
|
||||
ge: int | float | None = None,
|
||||
le: int | float | None = None,
|
||||
validator: _ValidatorType[t.Any] | None = None,
|
||||
description: str | None = None,
|
||||
env: str | None = None,
|
||||
auto_default: bool = False,
|
||||
use_default_converter: bool = True,
|
||||
**attrs: t.Any) -> t.Any:
|
||||
|
||||
def Field(
|
||||
default: t.Any = None,
|
||||
*,
|
||||
ge: int | float | None = None,
|
||||
le: int | float | None = None,
|
||||
validator: _ValidatorType[t.Any] | None = None,
|
||||
description: str | None = None,
|
||||
env: str | None = None,
|
||||
auto_default: bool = False,
|
||||
use_default_converter: bool = True,
|
||||
**attrs: t.Any,
|
||||
) -> t.Any:
|
||||
"""A decorator that extends attr.field with additional arguments, which provides the same interface as pydantic's Field.
|
||||
|
||||
By default, if both validator and ge are provided, then then ge will be
|
||||
@@ -103,27 +139,39 @@ def Field(default: t.Any = None,
|
||||
**attrs: The rest of the arguments are passed to attr.field
|
||||
"""
|
||||
metadata = attrs.pop('metadata', {})
|
||||
if description is None: description = '(No description provided)'
|
||||
if description is None:
|
||||
description = '(No description provided)'
|
||||
metadata['description'] = description
|
||||
if env is not None: metadata['env'] = env
|
||||
if env is not None:
|
||||
metadata['env'] = env
|
||||
piped: list[_ValidatorType[t.Any]] = []
|
||||
|
||||
converter = attrs.pop('converter', None)
|
||||
if use_default_converter: converter = functools.partial(env_converter, env=env)
|
||||
if use_default_converter:
|
||||
converter = functools.partial(env_converter, env=env)
|
||||
|
||||
if ge is not None: piped.append(attr.validators.ge(ge))
|
||||
if le is not None: piped.append(attr.validators.le(le))
|
||||
if validator is not None: piped.append(validator)
|
||||
if ge is not None:
|
||||
piped.append(attr.validators.ge(ge))
|
||||
if le is not None:
|
||||
piped.append(attr.validators.le(le))
|
||||
if validator is not None:
|
||||
piped.append(validator)
|
||||
|
||||
if len(piped) == 0: _validator = None
|
||||
elif len(piped) == 1: _validator = piped[0]
|
||||
else: _validator = attr.validators.and_(*piped)
|
||||
if len(piped) == 0:
|
||||
_validator = None
|
||||
elif len(piped) == 1:
|
||||
_validator = piped[0]
|
||||
else:
|
||||
_validator = attr.validators.and_(*piped)
|
||||
|
||||
factory = attrs.pop('factory', None)
|
||||
if factory is not None and default is not None: raise RuntimeError("'factory' and 'default' are mutually exclusive.")
|
||||
if factory is not None and default is not None:
|
||||
raise RuntimeError("'factory' and 'default' are mutually exclusive.")
|
||||
# NOTE: the behaviour of this is we will respect factory over the default
|
||||
if factory is not None: attrs['factory'] = factory
|
||||
else: attrs['default'] = default
|
||||
if factory is not None:
|
||||
attrs['factory'] = factory
|
||||
else:
|
||||
attrs['default'] = default
|
||||
|
||||
kw_only = attrs.pop('kw_only', False)
|
||||
if auto_default and kw_only:
|
||||
@@ -131,6 +179,7 @@ def Field(default: t.Any = None,
|
||||
|
||||
return attr.field(metadata=metadata, validator=_validator, converter=converter, **attrs)
|
||||
|
||||
|
||||
def parse_type(field_type: t.Any) -> ParamType | tuple[ParamType, ...]:
|
||||
"""Transforms the pydantic field's type into a click-compatible type.
|
||||
|
||||
@@ -151,17 +200,22 @@ def parse_type(field_type: t.Any) -> ParamType | tuple[ParamType, ...]:
|
||||
if is_literal(field_type):
|
||||
return LiteralChoice(value=field_type, case_sensitive=True)
|
||||
# modules, classes, functions
|
||||
if is_typing(field_type): return ModuleType()
|
||||
if is_typing(field_type):
|
||||
return ModuleType()
|
||||
# entire dictionaries:
|
||||
# using a Dict, convert in advance
|
||||
if is_mapping(field_type): return JsonType()
|
||||
if is_mapping(field_type):
|
||||
return JsonType()
|
||||
# list, List[p], Tuple[p], Set[p] and so on
|
||||
if is_container(field_type): return parse_container_args(field_type)
|
||||
if is_container(field_type):
|
||||
return parse_container_args(field_type)
|
||||
# bytes are not natively supported by click
|
||||
if lenient_issubclass(field_type, bytes): return BytesType()
|
||||
if lenient_issubclass(field_type, bytes):
|
||||
return BytesType()
|
||||
# return the current type: it should be a primitive
|
||||
return field_type
|
||||
|
||||
|
||||
def is_typing(field_type: type) -> bool:
|
||||
"""Checks whether the current type is a module-like type.
|
||||
|
||||
@@ -172,10 +226,13 @@ def is_typing(field_type: type) -> bool:
|
||||
bool: true if the type is itself a type
|
||||
"""
|
||||
raw = t.get_origin(field_type)
|
||||
if raw is None: return False
|
||||
if raw is type or raw is t.Type: return True
|
||||
if raw is None:
|
||||
return False
|
||||
if raw is type or raw is t.Type:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def is_literal(field_type: type) -> bool:
|
||||
"""Checks whether the given field type is a Literal type or not.
|
||||
|
||||
@@ -191,6 +248,7 @@ def is_literal(field_type: type) -> bool:
|
||||
origin = t.get_origin(field_type)
|
||||
return origin is not None and origin is t.Literal
|
||||
|
||||
|
||||
class ModuleType(ParamType):
|
||||
name = 'module'
|
||||
|
||||
@@ -198,7 +256,8 @@ class ModuleType(ParamType):
|
||||
module_name, class_name = value.rsplit('.', maxsplit=1)
|
||||
if not all(s.isidentifier() for s in module_name.split('.')):
|
||||
raise ValueError(f"'{value}' is not a valid module name")
|
||||
if not class_name.isidentifier(): raise ValueError(f"Variable '{class_name}' is not a valid identifier")
|
||||
if not class_name.isidentifier():
|
||||
raise ValueError(f"Variable '{class_name}' is not a valid identifier")
|
||||
|
||||
module = importlib.import_module(module_name)
|
||||
if class_name:
|
||||
@@ -209,11 +268,13 @@ class ModuleType(ParamType):
|
||||
|
||||
def convert(self, value: str | t.Any, param: click.Parameter | None, ctx: click.Context | None) -> t.Any:
|
||||
try:
|
||||
if isinstance(value, str): return self._import_object(value)
|
||||
if isinstance(value, str):
|
||||
return self._import_object(value)
|
||||
return value
|
||||
except Exception as exc:
|
||||
self.fail(f"'{value}' is not a valid object ({type(exc)}: {exc!s})", param, ctx)
|
||||
|
||||
|
||||
class EnumChoice(click.Choice):
|
||||
name = 'enum'
|
||||
|
||||
@@ -237,6 +298,7 @@ class EnumChoice(click.Choice):
|
||||
result = self.internal_type[result]
|
||||
return result
|
||||
|
||||
|
||||
class LiteralChoice(EnumChoice):
|
||||
name = 'literal'
|
||||
|
||||
@@ -251,6 +313,7 @@ class LiteralChoice(EnumChoice):
|
||||
super(EnumChoice, self).__init__(list(_mapping), case_sensitive)
|
||||
self.internal_type = item_type
|
||||
|
||||
|
||||
def allows_multiple(field_type: type[t.Any]) -> bool:
|
||||
"""Checks whether the current type allows for multiple arguments to be provided as input or not.
|
||||
|
||||
@@ -276,6 +339,7 @@ def allows_multiple(field_type: type[t.Any]) -> bool:
|
||||
return not isinstance(args, tuple)
|
||||
return False
|
||||
|
||||
|
||||
def is_mapping(field_type: type) -> bool:
|
||||
"""Checks whether this field represents a dictionary or JSON object.
|
||||
|
||||
@@ -287,12 +351,16 @@ def is_mapping(field_type: type) -> bool:
|
||||
"""
|
||||
# Early out for standard containers.
|
||||
from . import lenient_issubclass
|
||||
if lenient_issubclass(field_type, t.Mapping): return True
|
||||
|
||||
if lenient_issubclass(field_type, t.Mapping):
|
||||
return True
|
||||
# for everything else or when the typing is more complex, check its origin
|
||||
origin = t.get_origin(field_type)
|
||||
if origin is None: return False
|
||||
if origin is None:
|
||||
return False
|
||||
return lenient_issubclass(origin, t.Mapping)
|
||||
|
||||
|
||||
def is_container(field_type: type) -> bool:
|
||||
"""Checks whether the current type is a container type ('contains' other types), like lists and tuples.
|
||||
|
||||
@@ -303,15 +371,20 @@ def is_container(field_type: type) -> bool:
|
||||
bool: true if a container, false otherwise
|
||||
"""
|
||||
# do not consider strings or byte arrays as containers
|
||||
if field_type in (str, bytes): return False
|
||||
if field_type in (str, bytes):
|
||||
return False
|
||||
# Early out for standard containers: list, tuple, range
|
||||
from . import lenient_issubclass
|
||||
if lenient_issubclass(field_type, t.Container): return True
|
||||
|
||||
if lenient_issubclass(field_type, t.Container):
|
||||
return True
|
||||
origin = t.get_origin(field_type)
|
||||
# Early out for non-typing objects
|
||||
if origin is None: return False
|
||||
if origin is None:
|
||||
return False
|
||||
return lenient_issubclass(origin, t.Container)
|
||||
|
||||
|
||||
def parse_container_args(field_type: type[t.Any]) -> ParamType | tuple[ParamType, ...]:
|
||||
"""Parses the arguments inside a container type (lists, tuples and so on).
|
||||
|
||||
@@ -335,6 +408,7 @@ def parse_container_args(field_type: type[t.Any]) -> ParamType | tuple[ParamType
|
||||
# Then deal with fixed-length containers: Tuple[str, int, int]
|
||||
return tuple(parse_single_arg(arg) for arg in args)
|
||||
|
||||
|
||||
def parse_single_arg(arg: type) -> ParamType:
|
||||
"""Returns the click-compatible type for container origin types.
|
||||
|
||||
@@ -349,35 +423,44 @@ def parse_single_arg(arg: type) -> ParamType:
|
||||
ParamType: click-compatible type
|
||||
"""
|
||||
from . import lenient_issubclass
|
||||
|
||||
# When we don't know the type, we choose 'str'
|
||||
if arg is t.Any: return click_types.convert_type(str)
|
||||
if arg is t.Any:
|
||||
return click_types.convert_type(str)
|
||||
# For containers and nested models, we use JSON
|
||||
if is_container(arg): return JsonType()
|
||||
if lenient_issubclass(arg, bytes): return BytesType()
|
||||
if is_container(arg):
|
||||
return JsonType()
|
||||
if lenient_issubclass(arg, bytes):
|
||||
return BytesType()
|
||||
return click_types.convert_type(arg)
|
||||
|
||||
|
||||
class BytesType(ParamType):
|
||||
name = 'bytes'
|
||||
|
||||
def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> t.Any:
|
||||
if isinstance(value, bytes): return value
|
||||
if isinstance(value, bytes):
|
||||
return value
|
||||
try:
|
||||
return str.encode(value)
|
||||
except Exception as exc:
|
||||
self.fail(f"'{value}' is not a valid string ({exc!s})", param, ctx)
|
||||
|
||||
|
||||
CYGWIN = sys.platform.startswith('cygwin')
|
||||
WIN = sys.platform.startswith('win')
|
||||
if sys.platform.startswith('win') and WIN:
|
||||
|
||||
def _get_argv_encoding() -> str:
|
||||
import locale
|
||||
|
||||
return locale.getpreferredencoding()
|
||||
else:
|
||||
|
||||
def _get_argv_encoding() -> str:
|
||||
return getattr(sys.stdin, 'encoding', None) or sys.getfilesystemencoding()
|
||||
|
||||
|
||||
class CudaValueType(ParamType):
|
||||
name = 'cuda'
|
||||
envvar_list_splitter = ','
|
||||
@@ -386,7 +469,7 @@ class CudaValueType(ParamType):
|
||||
def split_envvar_value(self, rv: str) -> t.Sequence[str]:
|
||||
var = tuple(i for i in rv.split(self.envvar_list_splitter))
|
||||
if '-1' in var:
|
||||
return var[:var.index('-1')]
|
||||
return var[: var.index('-1')]
|
||||
return var
|
||||
|
||||
def shell_complete(self, ctx: click.Context, param: click.Parameter, incomplete: str) -> list[sc.CompletionItem]:
|
||||
@@ -400,6 +483,7 @@ class CudaValueType(ParamType):
|
||||
incomplete: Value being completed. May be empty.
|
||||
"""
|
||||
from openllm.utils import available_devices
|
||||
|
||||
mapping = incomplete.split(self.envvar_list_splitter) if incomplete else available_devices()
|
||||
return [sc.CompletionItem(str(i), help=f'CUDA device index {i}') for i in mapping]
|
||||
|
||||
@@ -423,8 +507,10 @@ class CudaValueType(ParamType):
|
||||
def __repr__(self) -> str:
|
||||
return 'STRING'
|
||||
|
||||
|
||||
CUDA = CudaValueType()
|
||||
|
||||
|
||||
class JsonType(ParamType):
|
||||
name = 'json'
|
||||
|
||||
@@ -438,7 +524,8 @@ class JsonType(ParamType):
|
||||
self.should_load = should_load
|
||||
|
||||
def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> t.Any:
|
||||
if isinstance(value, dict) or not self.should_load: return value
|
||||
if isinstance(value, dict) or not self.should_load:
|
||||
return value
|
||||
try:
|
||||
return orjson.loads(value)
|
||||
except orjson.JSONDecodeError as exc:
|
||||
|
||||
@@ -6,17 +6,32 @@ import logging
|
||||
import os
|
||||
import typing as t
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from collections import OrderedDict
|
||||
|
||||
BackendOrderedDict = OrderedDict[str, t.Tuple[t.Callable[[], bool], str]]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
OPTIONAL_DEPENDENCIES = {'opt', 'flan-t5', 'vllm', 'fine-tune', 'ggml', 'agents', 'openai', 'playground', 'gptq', 'grpc', 'awq'}
|
||||
OPTIONAL_DEPENDENCIES = {
|
||||
'opt',
|
||||
'flan-t5',
|
||||
'vllm',
|
||||
'fine-tune',
|
||||
'ggml',
|
||||
'agents',
|
||||
'openai',
|
||||
'playground',
|
||||
'gptq',
|
||||
'grpc',
|
||||
'awq',
|
||||
}
|
||||
ENV_VARS_TRUE_VALUES = {'1', 'ON', 'YES', 'TRUE'}
|
||||
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({'AUTO'})
|
||||
USE_TORCH = os.environ.get('USE_TORCH', 'AUTO').upper()
|
||||
USE_VLLM = os.environ.get('USE_VLLM', 'AUTO').upper()
|
||||
|
||||
|
||||
def _is_package_available(package: str) -> bool:
|
||||
_package_available = importlib.util.find_spec(package) is not None
|
||||
if _package_available:
|
||||
@@ -26,6 +41,7 @@ def _is_package_available(package: str) -> bool:
|
||||
_package_available = False
|
||||
return _package_available
|
||||
|
||||
|
||||
_torch_available = importlib.util.find_spec('torch') is not None
|
||||
_vllm_available = importlib.util.find_spec('vllm') is not None
|
||||
_transformers_available = _is_package_available('transformers')
|
||||
@@ -39,37 +55,49 @@ _notebook_available = _is_package_available('notebook')
|
||||
_autogptq_available = _is_package_available('auto_gptq')
|
||||
_autoawq_available = importlib.util.find_spec('awq') is not None
|
||||
|
||||
|
||||
def is_bentoml_available() -> bool:
|
||||
return _bentoml_available
|
||||
|
||||
|
||||
def is_transformers_available() -> bool:
|
||||
return _transformers_available
|
||||
|
||||
|
||||
def is_grpc_available() -> bool:
|
||||
return _grpc_available
|
||||
|
||||
|
||||
def is_optimum_supports_gptq() -> bool:
|
||||
from . import pkg
|
||||
|
||||
return pkg.pkg_version_info('optimum')[:2] >= (0, 12)
|
||||
|
||||
|
||||
def is_jupyter_available() -> bool:
|
||||
return _jupyter_available
|
||||
|
||||
|
||||
def is_jupytext_available() -> bool:
|
||||
return _jupytext_available
|
||||
|
||||
|
||||
def is_notebook_available() -> bool:
|
||||
return _notebook_available
|
||||
|
||||
|
||||
def is_peft_available() -> bool:
|
||||
return _peft_available
|
||||
|
||||
|
||||
def is_bitsandbytes_available() -> bool:
|
||||
return _bitsandbytes_available
|
||||
|
||||
|
||||
def is_autogptq_available() -> bool:
|
||||
return _autogptq_available
|
||||
|
||||
|
||||
def is_torch_available() -> bool:
|
||||
global _torch_available
|
||||
if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and _torch_available:
|
||||
@@ -79,6 +107,7 @@ def is_torch_available() -> bool:
|
||||
_torch_available = False
|
||||
return _torch_available
|
||||
|
||||
|
||||
def is_autoawq_available() -> bool:
|
||||
global _autoawq_available
|
||||
try:
|
||||
@@ -87,6 +116,7 @@ def is_autoawq_available() -> bool:
|
||||
_autoawq_available = False
|
||||
return _autoawq_available
|
||||
|
||||
|
||||
def is_vllm_available() -> bool:
|
||||
global _vllm_available
|
||||
if USE_VLLM in ENV_VARS_TRUE_AND_AUTO_VALUES and _vllm_available:
|
||||
|
||||
@@ -17,29 +17,28 @@ import attr
|
||||
|
||||
import openllm_core
|
||||
|
||||
|
||||
__all__ = ['VersionInfo', 'LazyModule', 'LazyLoader']
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class LazyLoader(types.ModuleType):
|
||||
'''
|
||||
LazyLoader module borrowed from Tensorflow
|
||||
https://github.com/tensorflow/tensorflow/blob/v2.2.0/tensorflow/python/util/lazy_loader.py
|
||||
with a addition of "module caching". This will throw an
|
||||
exception if module cannot be imported.
|
||||
|
||||
Lazily import a module, mainly to avoid pulling in large dependencies.
|
||||
`contrib`, and `ffmpeg` are examples of modules that are large and not always
|
||||
needed, and this allows them to only be loaded when they are used.
|
||||
'''
|
||||
def __init__(self,
|
||||
local_name: str,
|
||||
parent_module_globals: dict[str, t.Any],
|
||||
name: str,
|
||||
warning: str | None = None,
|
||||
exc_msg: str | None = None,
|
||||
exc: type[BaseException] = openllm_core.exceptions.MissingDependencyError,
|
||||
):
|
||||
class LazyLoader(types.ModuleType):
|
||||
"""
|
||||
LazyLoader module borrowed from Tensorflow https://github.com/tensorflow/tensorflow/blob/v2.2.0/tensorflow/python/util/lazy_loader.py with a addition of "module caching". This will throw an exception if module cannot be imported.
|
||||
|
||||
Lazily import a module, mainly to avoid pulling in large dependencies. `contrib`, and `ffmpeg` are examples of modules that are large and not always needed, and this allows them to only be loaded when they are used.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
local_name: str,
|
||||
parent_module_globals: dict[str, t.Any],
|
||||
name: str,
|
||||
warning: str | None = None,
|
||||
exc_msg: str | None = None,
|
||||
exc: type[BaseException] = openllm_core.exceptions.MissingDependencyError,
|
||||
):
|
||||
self._local_name = local_name
|
||||
self._parent_module_globals = parent_module_globals
|
||||
self._warning = warning
|
||||
@@ -82,6 +81,7 @@ class LazyLoader(types.ModuleType):
|
||||
self._module = self._load()
|
||||
return dir(self._module)
|
||||
|
||||
|
||||
# vendorred from attrs
|
||||
@functools.total_ordering
|
||||
@attr.attrs(eq=False, order=False, slots=True, frozen=True, repr=False)
|
||||
@@ -94,14 +94,19 @@ class VersionInfo:
|
||||
@classmethod
|
||||
def from_version_string(cls, s: str) -> VersionInfo:
|
||||
v = s.split('.')
|
||||
if len(v) == 3: v.append('final')
|
||||
if len(v) == 3:
|
||||
v.append('final')
|
||||
return cls(major=int(v[0]), minor=int(v[1]), micro=int(v[2]), releaselevel=v[3])
|
||||
|
||||
def _ensure_tuple(self, other: VersionInfo) -> tuple[tuple[int, int, int, str], tuple[int, int, int, str]]:
|
||||
cmp = attr.astuple(other) if self.__class__ is other.__class__ else other
|
||||
if not isinstance(cmp, tuple): raise NotImplementedError
|
||||
if not (1 <= len(cmp) <= 4): raise NotImplementedError
|
||||
return t.cast(t.Tuple[int, int, int, str], attr.astuple(self)[:len(cmp)]), t.cast(t.Tuple[int, int, int, str], cmp)
|
||||
if not isinstance(cmp, tuple):
|
||||
raise NotImplementedError
|
||||
if not (1 <= len(cmp) <= 4):
|
||||
raise NotImplementedError
|
||||
return t.cast(t.Tuple[int, int, int, str], attr.astuple(self)[: len(cmp)]), t.cast(
|
||||
t.Tuple[int, int, int, str], cmp
|
||||
)
|
||||
|
||||
def __eq__(self, other: t.Any) -> bool:
|
||||
try:
|
||||
@@ -121,17 +126,21 @@ class VersionInfo:
|
||||
def __repr__(self) -> str:
|
||||
return '{0}.{1}.{2}'.format(*attr.astuple(self)[:3])
|
||||
|
||||
|
||||
_sentinel, _reserved_namespace = object(), {'__openllm_migration__'}
|
||||
|
||||
|
||||
class LazyModule(types.ModuleType):
|
||||
# Very heavily inspired by optuna.integration._IntegrationModule: https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py
|
||||
def __init__(self,
|
||||
name: str,
|
||||
module_file: str,
|
||||
import_structure: dict[str, list[str]],
|
||||
module_spec: importlib.machinery.ModuleSpec | None = None,
|
||||
doc: str | None = None,
|
||||
extra_objects: dict[str, t.Any] | None = None):
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
module_file: str,
|
||||
import_structure: dict[str, list[str]],
|
||||
module_spec: importlib.machinery.ModuleSpec | None = None,
|
||||
doc: str | None = None,
|
||||
extra_objects: dict[str, t.Any] | None = None,
|
||||
):
|
||||
"""Lazily load this module as an object.
|
||||
|
||||
It does instantiate a __all__ and __dir__ for IDE support
|
||||
@@ -175,31 +184,39 @@ class LazyModule(types.ModuleType):
|
||||
It also contains a special case for all of the metadata information, such as __version__ and __version_info__.
|
||||
"""
|
||||
if name in _reserved_namespace:
|
||||
raise openllm_core.exceptions.ForbiddenAttributeError(f"'{name}' is a reserved namespace for {self._name} and should not be access nor modified.")
|
||||
raise openllm_core.exceptions.ForbiddenAttributeError(
|
||||
f"'{name}' is a reserved namespace for {self._name} and should not be access nor modified."
|
||||
)
|
||||
dunder_to_metadata = {
|
||||
'__title__': 'Name',
|
||||
'__copyright__': '',
|
||||
'__version__': 'version',
|
||||
'__version_info__': 'version',
|
||||
'__description__': 'summary',
|
||||
'__uri__': '',
|
||||
'__url__': '',
|
||||
'__author__': '',
|
||||
'__email__': '',
|
||||
'__license__': 'license',
|
||||
'__homepage__': ''
|
||||
'__title__': 'Name',
|
||||
'__copyright__': '',
|
||||
'__version__': 'version',
|
||||
'__version_info__': 'version',
|
||||
'__description__': 'summary',
|
||||
'__uri__': '',
|
||||
'__url__': '',
|
||||
'__author__': '',
|
||||
'__email__': '',
|
||||
'__license__': 'license',
|
||||
'__homepage__': '',
|
||||
}
|
||||
if name in dunder_to_metadata:
|
||||
if name not in {'__version_info__', '__copyright__', '__version__'}:
|
||||
warnings.warn(f"Accessing '{self._name}.{name}' is deprecated. Please consider using 'importlib.metadata' directly to query for openllm packaging metadata.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2)
|
||||
warnings.warn(
|
||||
f"Accessing '{self._name}.{name}' is deprecated. Please consider using 'importlib.metadata' directly to query for openllm packaging metadata.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
meta = importlib.metadata.metadata('openllm')
|
||||
project_url = dict(url.split(', ') for url in t.cast(t.List[str], meta.get_all('Project-URL')))
|
||||
if name == '__license__': return 'Apache-2.0'
|
||||
elif name == '__copyright__': return f"Copyright (c) 2023-{time.strftime('%Y')}, Aaron Pham et al."
|
||||
elif name in ('__uri__', '__url__'): return project_url['GitHub']
|
||||
elif name == '__homepage__': return project_url['Homepage']
|
||||
if name == '__license__':
|
||||
return 'Apache-2.0'
|
||||
elif name == '__copyright__':
|
||||
return f"Copyright (c) 2023-{time.strftime('%Y')}, Aaron Pham et al."
|
||||
elif name in ('__uri__', '__url__'):
|
||||
return project_url['GitHub']
|
||||
elif name == '__homepage__':
|
||||
return project_url['Homepage']
|
||||
elif name == '__version_info__':
|
||||
return VersionInfo.from_version_string(meta['version']) # similar to how attrs handle __version_info__
|
||||
elif name == '__author__':
|
||||
@@ -210,10 +227,16 @@ class LazyModule(types.ModuleType):
|
||||
if '__openllm_migration__' in self._objects:
|
||||
cur_value = self._objects['__openllm_migration__'].get(name, _sentinel)
|
||||
if cur_value is not _sentinel:
|
||||
warnings.warn(f"'{name}' is deprecated and will be removed in future version. Make sure to use '{cur_value}' instead", DeprecationWarning, stacklevel=3)
|
||||
warnings.warn(
|
||||
f"'{name}' is deprecated and will be removed in future version. Make sure to use '{cur_value}' instead",
|
||||
DeprecationWarning,
|
||||
stacklevel=3,
|
||||
)
|
||||
return getattr(self, cur_value)
|
||||
if name in self._objects: return self._objects.__getitem__(name)
|
||||
if name in self._modules: value = self._get_module(name)
|
||||
if name in self._objects:
|
||||
return self._objects.__getitem__(name)
|
||||
if name in self._modules:
|
||||
value = self._get_module(name)
|
||||
elif name in self._class_to_module.keys():
|
||||
value = getattr(self._get_module(self._class_to_module.__getitem__(name)), name)
|
||||
else:
|
||||
@@ -225,7 +248,9 @@ class LazyModule(types.ModuleType):
|
||||
try:
|
||||
return importlib.import_module('.' + module_name, self.__name__)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f'Failed to import {self.__name__}.{module_name} because of the following error (look up to see its traceback):\n{e}') from e
|
||||
raise RuntimeError(
|
||||
f'Failed to import {self.__name__}.{module_name} because of the following error (look up to see its traceback):\n{e}'
|
||||
) from e
|
||||
|
||||
# make sure this module is picklable
|
||||
def __reduce__(self) -> tuple[type[LazyModule], tuple[str, str | None, dict[str, list[str]]]]:
|
||||
|
||||
@@ -10,6 +10,7 @@ from deepmerge import Merger
|
||||
from . import dantic
|
||||
from ..exceptions import ForbiddenAttributeError
|
||||
|
||||
|
||||
config_merger = Merger([(dict, 'merge')], ['override'], ['override'])
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
@@ -18,12 +19,15 @@ if t.TYPE_CHECKING:
|
||||
from .._configuration import LLMConfig
|
||||
from .._typing_compat import AdapterType
|
||||
|
||||
|
||||
# case insensitive, but rename to conform with type
|
||||
class _PeftEnumMeta(enum.EnumMeta):
|
||||
def __getitem__(self, __key: str | t.Any, /) -> t.Any:
|
||||
if isinstance(__key, str): __key = inflection.underscore(__key).upper()
|
||||
if isinstance(__key, str):
|
||||
__key = inflection.underscore(__key).upper()
|
||||
return self._member_map_[__key]
|
||||
|
||||
|
||||
# vendorred from peft.utils.config.PeftType since we don't have hard dependency on peft
|
||||
# see https://github.com/huggingface/peft/blob/main/src/peft/utils/config.py
|
||||
class PeftType(str, enum.Enum, metaclass=_PeftEnumMeta):
|
||||
@@ -42,7 +46,8 @@ class PeftType(str, enum.Enum, metaclass=_PeftEnumMeta):
|
||||
def _missing_(cls, value: object) -> enum.Enum | None:
|
||||
if isinstance(value, str):
|
||||
normalized = inflection.underscore(value).upper()
|
||||
if normalized in cls._member_map_: return cls._member_map_[normalized]
|
||||
if normalized in cls._member_map_:
|
||||
return cls._member_map_[normalized]
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
@@ -53,29 +58,43 @@ class PeftType(str, enum.Enum, metaclass=_PeftEnumMeta):
|
||||
def get(__key: str | t.Any, /) -> PeftType:
|
||||
return PeftType[__key] # type-safe getitem.
|
||||
|
||||
|
||||
PEFT_TASK_TYPE_TARGET_MAPPING = {'causal_lm': 'CAUSAL_LM', 'seq2seq_lm': 'SEQ_2_SEQ_LM'}
|
||||
|
||||
_object_setattr = object.__setattr__
|
||||
|
||||
|
||||
def _adapter_converter(value: AdapterType | str | PeftType | None) -> PeftType:
|
||||
if value is None: raise ValueError("'AdapterType' cannot be None.")
|
||||
if isinstance(value, PeftType): return value
|
||||
if value not in PeftType.supported(): raise ValueError(f"Given '{value}' is not a supported adapter type.")
|
||||
if value is None:
|
||||
raise ValueError("'AdapterType' cannot be None.")
|
||||
if isinstance(value, PeftType):
|
||||
return value
|
||||
if value not in PeftType.supported():
|
||||
raise ValueError(f"Given '{value}' is not a supported adapter type.")
|
||||
return PeftType.get(value)
|
||||
|
||||
|
||||
@attr.define(slots=True, init=True)
|
||||
class FineTuneConfig:
|
||||
adapter_type: PeftType = dantic.Field('lora',
|
||||
description=f"The type of adapter to use for fine-tuning. Available supported methods: {PeftType.supported()}, default to 'lora'",
|
||||
use_default_converter=False,
|
||||
converter=_adapter_converter)
|
||||
adapter_config: t.Dict[str, t.Any] = dantic.Field(None,
|
||||
description='The configuration for the adapter. The content of the dict depends on the adapter type.',
|
||||
validator=attr.validators.optional(attr.validators.instance_of(dict)),
|
||||
converter=attr.converters.default_if_none(factory=dict),
|
||||
use_default_converter=False)
|
||||
inference_mode: bool = dantic.Field(False, description='Whether to use this Adapter for inference', use_default_converter=False)
|
||||
llm_config_class: type[LLMConfig] = dantic.Field(None, description='The reference class to openllm.LLMConfig', use_default_converter=False)
|
||||
adapter_type: PeftType = dantic.Field(
|
||||
'lora',
|
||||
description=f"The type of adapter to use for fine-tuning. Available supported methods: {PeftType.supported()}, default to 'lora'",
|
||||
use_default_converter=False,
|
||||
converter=_adapter_converter,
|
||||
)
|
||||
adapter_config: t.Dict[str, t.Any] = dantic.Field(
|
||||
None,
|
||||
description='The configuration for the adapter. The content of the dict depends on the adapter type.',
|
||||
validator=attr.validators.optional(attr.validators.instance_of(dict)),
|
||||
converter=attr.converters.default_if_none(factory=dict),
|
||||
use_default_converter=False,
|
||||
)
|
||||
inference_mode: bool = dantic.Field(
|
||||
False, description='Whether to use this Adapter for inference', use_default_converter=False
|
||||
)
|
||||
llm_config_class: type[LLMConfig] = dantic.Field(
|
||||
None, description='The reference class to openllm.LLMConfig', use_default_converter=False
|
||||
)
|
||||
|
||||
def build(self) -> PeftConfig:
|
||||
try:
|
||||
@@ -85,13 +104,20 @@ class FineTuneConfig:
|
||||
raise ImportError('PEFT is not installed. Please install it via `pip install "openllm[fine-tune]"`.')
|
||||
adapter_config = self.adapter_config.copy()
|
||||
# no need for peft_type
|
||||
if 'peft_type' in adapter_config: adapter_config.pop('peft_type')
|
||||
if 'peft_type' in adapter_config:
|
||||
adapter_config.pop('peft_type')
|
||||
for k in {'enable_lora', 'merge_weights'}: # these keys are from older PEFT and no longer valid.
|
||||
if k in adapter_config: adapter_config.pop(k)
|
||||
if k in adapter_config:
|
||||
adapter_config.pop(k)
|
||||
# respect user set task_type if it is passed, otherwise use one managed by OpenLLM
|
||||
inference_mode = adapter_config.pop('inference_mode', self.inference_mode)
|
||||
task_type = adapter_config.pop('task_type', TaskType[self.llm_config_class.peft_task_type()])
|
||||
adapter_config = {'peft_type': self.adapter_type.value, 'task_type': task_type, 'inference_mode': inference_mode, **adapter_config}
|
||||
adapter_config = {
|
||||
'peft_type': self.adapter_type.value,
|
||||
'task_type': task_type,
|
||||
'inference_mode': inference_mode,
|
||||
**adapter_config,
|
||||
}
|
||||
return get_peft_config(adapter_config)
|
||||
|
||||
def train(self) -> FineTuneConfig:
|
||||
@@ -103,6 +129,15 @@ class FineTuneConfig:
|
||||
return self
|
||||
|
||||
def with_config(self, **attrs: t.Any) -> FineTuneConfig:
|
||||
adapter_type, inference_mode = attrs.pop('adapter_type', self.adapter_type), attrs.get('inference_mode', self.inference_mode)
|
||||
if 'llm_config_class' in attrs: raise ForbiddenAttributeError("'llm_config_class' should not be passed when using 'with_config'.")
|
||||
return attr.evolve(self, adapter_type=adapter_type, inference_mode=inference_mode, adapter_config=config_merger.merge(self.adapter_config, attrs))
|
||||
adapter_type, inference_mode = (
|
||||
attrs.pop('adapter_type', self.adapter_type),
|
||||
attrs.get('inference_mode', self.inference_mode),
|
||||
)
|
||||
if 'llm_config_class' in attrs:
|
||||
raise ForbiddenAttributeError("'llm_config_class' should not be passed when using 'with_config'.")
|
||||
return attr.evolve(
|
||||
self,
|
||||
adapter_type=adapter_type,
|
||||
inference_mode=inference_mode,
|
||||
adapter_config=config_merger.merge(self.adapter_config, attrs),
|
||||
)
|
||||
|
||||
@@ -9,18 +9,23 @@ from typing import cast
|
||||
|
||||
from packaging.version import Version
|
||||
|
||||
|
||||
__all__ = ['PackageNotFoundError', 'pkg_version_info', 'get_pkg_version', 'source_locations', 'find_spec']
|
||||
|
||||
get_pkg_version = importlib.metadata.version
|
||||
find_spec = importlib.util.find_spec
|
||||
|
||||
|
||||
def pkg_version_info(pkg_name: str | ModuleType) -> tuple[int, int, int]:
|
||||
if isinstance(pkg_name, ModuleType): pkg_name = pkg_name.__name__
|
||||
if isinstance(pkg_name, ModuleType):
|
||||
pkg_name = pkg_name.__name__
|
||||
pkg_version = Version(get_pkg_version(pkg_name))
|
||||
return pkg_version.major, pkg_version.minor, pkg_version.micro
|
||||
|
||||
|
||||
def source_locations(pkg: str) -> str | None:
|
||||
module = find_spec(pkg)
|
||||
if module is None: return None
|
||||
if module is None:
|
||||
return None
|
||||
(module_path,) = module.submodule_search_locations # type: ignore[misc]
|
||||
return cast(str, module_path)
|
||||
|
||||
@@ -8,11 +8,13 @@ import orjson
|
||||
|
||||
from openllm_core import utils
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from openllm_core._typing_compat import TypeAlias
|
||||
|
||||
ReprArgs: TypeAlias = t.Generator[t.Tuple[t.Optional[str], t.Any], None, None]
|
||||
|
||||
|
||||
class ReprMixin:
|
||||
@property
|
||||
@abstractmethod
|
||||
|
||||
@@ -9,14 +9,30 @@ from cattr import Converter
|
||||
from cattr.gen import make_dict_structure_fn
|
||||
from cattr.gen import make_dict_unstructure_fn
|
||||
|
||||
|
||||
converter = Converter(omit_if_default=True)
|
||||
|
||||
def datetime_structure_hook(dt_like: str | datetime | t.Any, _: t.Any) -> datetime:
|
||||
if isinstance(dt_like, str): return datetime.fromisoformat(dt_like)
|
||||
elif isinstance(dt_like, datetime): return dt_like
|
||||
else: raise Exception(f"Unable to parse datetime from '{dt_like}'")
|
||||
|
||||
converter.register_structure_hook_factory(attr.has, lambda cls: make_dict_structure_fn(cls, converter, _cattrs_forbid_extra_keys=getattr(cls, '__forbid_extra_keys__', False)))
|
||||
converter.register_unstructure_hook_factory(attr.has, lambda cls: make_dict_unstructure_fn(cls, converter, _cattrs_omit_if_default=getattr(cls, '__omit_if_default__', False)))
|
||||
def datetime_structure_hook(dt_like: str | datetime | t.Any, _: t.Any) -> datetime:
|
||||
if isinstance(dt_like, str):
|
||||
return datetime.fromisoformat(dt_like)
|
||||
elif isinstance(dt_like, datetime):
|
||||
return dt_like
|
||||
else:
|
||||
raise Exception(f"Unable to parse datetime from '{dt_like}'")
|
||||
|
||||
|
||||
converter.register_structure_hook_factory(
|
||||
attr.has,
|
||||
lambda cls: make_dict_structure_fn(
|
||||
cls, converter, _cattrs_forbid_extra_keys=getattr(cls, '__forbid_extra_keys__', False)
|
||||
),
|
||||
)
|
||||
converter.register_unstructure_hook_factory(
|
||||
attr.has,
|
||||
lambda cls: make_dict_unstructure_fn(
|
||||
cls, converter, _cattrs_omit_if_default=getattr(cls, '__omit_if_default__', False)
|
||||
),
|
||||
)
|
||||
converter.register_structure_hook(datetime, datetime_structure_hook)
|
||||
converter.register_unstructure_hook(datetime, lambda dt: dt.isoformat())
|
||||
|
||||
@@ -8,6 +8,7 @@ deploy, and monitor any LLMs with ease.
|
||||
* Online Serving with HTTP, gRPC, SSE(coming soon) or custom API
|
||||
* Native integration with BentoML and LangChain for custom LLM apps
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import logging as _logging
|
||||
import os as _os
|
||||
@@ -51,29 +52,41 @@ else:
|
||||
# configuration for bitsandbytes before import
|
||||
_os.environ['BITSANDBYTES_NOWELCOME'] = _os.environ.get('BITSANDBYTES_NOWELCOME', '1')
|
||||
# NOTE: The following warnings from bitsandbytes, and probably not that important for users to see when DEBUG is False
|
||||
_warnings.filterwarnings('ignore', message='MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization')
|
||||
_warnings.filterwarnings('ignore', message='MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization')
|
||||
_warnings.filterwarnings(
|
||||
'ignore', message='MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization'
|
||||
)
|
||||
_warnings.filterwarnings(
|
||||
'ignore', message='MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization'
|
||||
)
|
||||
_warnings.filterwarnings('ignore', message='The installed version of bitsandbytes was compiled without GPU support.')
|
||||
# NOTE: ignore the following warning from ghapi as it is not important for users
|
||||
_warnings.filterwarnings('ignore', message='Neither GITHUB_TOKEN nor GITHUB_JWT_TOKEN found: running as unauthenticated')
|
||||
_warnings.filterwarnings(
|
||||
'ignore', message='Neither GITHUB_TOKEN nor GITHUB_JWT_TOKEN found: running as unauthenticated'
|
||||
)
|
||||
|
||||
_import_structure: dict[str, list[str]] = {
|
||||
'exceptions': [],
|
||||
'client': [],
|
||||
'bundle': [],
|
||||
'playground': [],
|
||||
'testing': [],
|
||||
'prompts': ['PromptTemplate'],
|
||||
'protocol': [],
|
||||
'utils': [],
|
||||
'_deprecated': ['Runner'],
|
||||
'_strategies': ['CascadingResourceStrategy', 'get_resource'],
|
||||
'entrypoints': ['mount_entrypoints'],
|
||||
'serialisation': ['ggml', 'transformers'],
|
||||
'cli._sdk': ['start', 'start_grpc', 'build', 'import_model', 'list_models'],
|
||||
'_quantisation': ['infer_quantisation_config'],
|
||||
'_llm': ['LLM', 'LLMRunner', 'LLMRunnable'],
|
||||
'_generation': ['StopSequenceCriteria', 'StopOnTokens', 'LogitsProcessorList', 'StoppingCriteriaList', 'prepare_logits_processor'],
|
||||
'exceptions': [],
|
||||
'client': [],
|
||||
'bundle': [],
|
||||
'playground': [],
|
||||
'testing': [],
|
||||
'prompts': ['PromptTemplate'],
|
||||
'protocol': [],
|
||||
'utils': [],
|
||||
'_deprecated': ['Runner'],
|
||||
'_strategies': ['CascadingResourceStrategy', 'get_resource'],
|
||||
'entrypoints': ['mount_entrypoints'],
|
||||
'serialisation': ['ggml', 'transformers'],
|
||||
'cli._sdk': ['start', 'start_grpc', 'build', 'import_model', 'list_models'],
|
||||
'_quantisation': ['infer_quantisation_config'],
|
||||
'_llm': ['LLM', 'LLMRunner', 'LLMRunnable'],
|
||||
'_generation': [
|
||||
'StopSequenceCriteria',
|
||||
'StopOnTokens',
|
||||
'LogitsProcessorList',
|
||||
'StoppingCriteriaList',
|
||||
'prepare_logits_processor',
|
||||
],
|
||||
}
|
||||
COMPILED = _Path(__file__).suffix in ('.pyd', '.so')
|
||||
|
||||
@@ -109,7 +122,9 @@ if _t.TYPE_CHECKING:
|
||||
from .serialisation import transformers as transformers
|
||||
|
||||
# NOTE: update this to sys.modules[__name__] once mypy_extensions can recognize __spec__
|
||||
__lazy = openllm_core.utils.LazyModule(__name__, globals()['__file__'], _import_structure, extra_objects={'COMPILED': COMPILED})
|
||||
__lazy = openllm_core.utils.LazyModule(
|
||||
__name__, globals()['__file__'], _import_structure, extra_objects={'COMPILED': COMPILED}
|
||||
)
|
||||
__all__ = __lazy.__all__
|
||||
__dir__ = __lazy.__dir__
|
||||
__getattr__ = __lazy.__getattr__
|
||||
|
||||
@@ -6,8 +6,11 @@ Usage:
|
||||
To start any OpenLLM model:
|
||||
openllm start <model_name> --options ...
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
from openllm.cli.entrypoint import cli
|
||||
|
||||
cli()
|
||||
|
||||
@@ -9,27 +9,33 @@ from openllm_core._typing_compat import LiteralBackend
|
||||
from openllm_core.utils import first_not_none
|
||||
from openllm_core.utils import is_vllm_available
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from openllm_core import LLMConfig
|
||||
from openllm_core._typing_compat import ParamSpec
|
||||
|
||||
from ._llm import LLMRunner
|
||||
|
||||
P = ParamSpec('P')
|
||||
|
||||
_object_setattr = object.__setattr__
|
||||
|
||||
|
||||
def _mark_deprecated(fn: t.Callable[P, t.Any]) -> t.Callable[P, t.Any]:
|
||||
_object_setattr(fn, '__deprecated__', True)
|
||||
return fn
|
||||
|
||||
|
||||
@_mark_deprecated
|
||||
def Runner(model_name: str,
|
||||
ensure_available: bool = True,
|
||||
init_local: bool = False,
|
||||
backend: LiteralBackend | None = None,
|
||||
llm_config: LLMConfig | None = None,
|
||||
**attrs: t.Any) -> LLMRunner[t.Any, t.Any]:
|
||||
'''Create a Runner for given LLM. For a list of currently supported LLM, check out 'openllm models'.
|
||||
def Runner(
|
||||
model_name: str,
|
||||
ensure_available: bool = True,
|
||||
init_local: bool = False,
|
||||
backend: LiteralBackend | None = None,
|
||||
llm_config: LLMConfig | None = None,
|
||||
**attrs: t.Any,
|
||||
) -> LLMRunner[t.Any, t.Any]:
|
||||
"""Create a Runner for given LLM. For a list of currently supported LLM, check out 'openllm models'.
|
||||
|
||||
> [!WARNING]
|
||||
> This method is now deprecated and in favor of 'openllm.LLM.runner'
|
||||
@@ -54,11 +60,13 @@ def Runner(model_name: str,
|
||||
llm_config: Optional ``openllm.LLMConfig`` to initialise this ``openllm.LLMRunner``.
|
||||
init_local: If True, it will initialize the model locally. This is useful if you want to run the model locally. (Symmetrical to bentoml.Runner.init_local())
|
||||
**attrs: The rest of kwargs will then be passed to the LLM. Refer to the LLM documentation for the kwargs behaviour
|
||||
'''
|
||||
"""
|
||||
from ._llm import LLM
|
||||
if llm_config is None: llm_config = openllm.AutoConfig.for_model(model_name)
|
||||
|
||||
if llm_config is None:
|
||||
llm_config = openllm.AutoConfig.for_model(model_name)
|
||||
model_id = attrs.get('model_id', default=os.getenv('OPENLLM_MODEL_ID', llm_config['default_id']))
|
||||
_RUNNER_MSG = f'''\
|
||||
_RUNNER_MSG = f"""\
|
||||
Using 'openllm.Runner' is now deprecated. Make sure to switch to the following syntax:
|
||||
|
||||
```python
|
||||
@@ -70,24 +78,31 @@ def Runner(model_name: str,
|
||||
async def chat(input: str) -> str:
|
||||
async for it in llm.generate_iterator(input): print(it)
|
||||
```
|
||||
'''
|
||||
"""
|
||||
warnings.warn(_RUNNER_MSG, DeprecationWarning, stacklevel=2)
|
||||
attrs.update({
|
||||
attrs.update(
|
||||
{
|
||||
'model_id': model_id,
|
||||
'quantize': os.getenv('OPENLLM_QUANTIZE', attrs.get('quantize', None)),
|
||||
'serialisation': first_not_none(attrs.get('serialisation'), os.environ.get('OPENLLM_SERIALIZATION'), default=llm_config['serialisation']),
|
||||
'serialisation': first_not_none(
|
||||
attrs.get('serialisation'), os.environ.get('OPENLLM_SERIALIZATION'), default=llm_config['serialisation']
|
||||
),
|
||||
'system_message': first_not_none(os.environ.get('OPENLLM_SYSTEM_MESSAGE'), attrs.get('system_message'), None),
|
||||
'prompt_template': first_not_none(os.environ.get('OPENLLM_PROMPT_TEMPLATE'), attrs.get('prompt_template'), None),
|
||||
})
|
||||
}
|
||||
)
|
||||
|
||||
backend = t.cast(LiteralBackend, first_not_none(backend, default='vllm' if is_vllm_available() else 'pt'))
|
||||
llm = LLM[t.Any, t.Any](backend=backend, llm_config=llm_config, **attrs)
|
||||
if init_local: llm.runner.init_local(quiet=True)
|
||||
if init_local:
|
||||
llm.runner.init_local(quiet=True)
|
||||
return llm.runner
|
||||
|
||||
|
||||
_DEPRECATED = {k: v for k, v in locals().items() if getattr(v, '__deprecated__', False)}
|
||||
|
||||
__all__ = list(_DEPRECATED)
|
||||
|
||||
|
||||
def __dir__() -> list[str]:
|
||||
return sorted(_DEPRECATED.keys())
|
||||
|
||||
@@ -4,6 +4,7 @@ import typing as t
|
||||
|
||||
import transformers
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import torch
|
||||
|
||||
@@ -13,18 +14,30 @@ if t.TYPE_CHECKING:
|
||||
LogitsProcessorList = transformers.LogitsProcessorList
|
||||
StoppingCriteriaList = transformers.StoppingCriteriaList
|
||||
|
||||
|
||||
class StopSequenceCriteria(transformers.StoppingCriteria):
|
||||
def __init__(self, stop_sequences: str | list[str], tokenizer: transformers.PreTrainedTokenizer | transformers.PreTrainedTokenizerBase | transformers.PreTrainedTokenizerFast):
|
||||
if isinstance(stop_sequences, str): stop_sequences = [stop_sequences]
|
||||
def __init__(
|
||||
self,
|
||||
stop_sequences: str | list[str],
|
||||
tokenizer: transformers.PreTrainedTokenizer
|
||||
| transformers.PreTrainedTokenizerBase
|
||||
| transformers.PreTrainedTokenizerFast,
|
||||
):
|
||||
if isinstance(stop_sequences, str):
|
||||
stop_sequences = [stop_sequences]
|
||||
self.stop_sequences, self.tokenizer = stop_sequences, tokenizer
|
||||
|
||||
def __call__(self, input_ids: torch.Tensor, scores: t.Any, **_: t.Any) -> bool:
|
||||
return any(self.tokenizer.decode(input_ids.tolist()[0]).endswith(stop_sequence) for stop_sequence in self.stop_sequences)
|
||||
return any(
|
||||
self.tokenizer.decode(input_ids.tolist()[0]).endswith(stop_sequence) for stop_sequence in self.stop_sequences
|
||||
)
|
||||
|
||||
|
||||
class StopOnTokens(transformers.StoppingCriteria):
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **_: t.Any) -> bool:
|
||||
return input_ids[0][-1] in {50278, 50279, 50277, 1, 0}
|
||||
|
||||
|
||||
def prepare_logits_processor(config: openllm.LLMConfig) -> transformers.LogitsProcessorList:
|
||||
generation_config = config.generation_config
|
||||
logits_processor = transformers.LogitsProcessorList()
|
||||
@@ -34,24 +47,31 @@ def prepare_logits_processor(config: openllm.LLMConfig) -> transformers.LogitsPr
|
||||
logits_processor.append(transformers.RepetitionPenaltyLogitsProcessor(generation_config['repetition_penalty']))
|
||||
if 1e-8 <= generation_config['top_p']:
|
||||
logits_processor.append(transformers.TopPLogitsWarper(generation_config['top_p']))
|
||||
if generation_config['top_k'] > 0: logits_processor.append(transformers.TopKLogitsWarper(generation_config['top_k']))
|
||||
if generation_config['top_k'] > 0:
|
||||
logits_processor.append(transformers.TopKLogitsWarper(generation_config['top_k']))
|
||||
return logits_processor
|
||||
|
||||
|
||||
# NOTE: The ordering here is important. Some models have two of these and we have a preference for which value gets used.
|
||||
SEQLEN_KEYS = ['max_sequence_length', 'seq_length', 'max_position_embeddings', 'max_seq_len', 'model_max_length']
|
||||
|
||||
|
||||
def get_context_length(config: transformers.PretrainedConfig) -> int:
|
||||
rope_scaling = getattr(config, 'rope_scaling', None)
|
||||
rope_scaling_factor = config.rope_scaling['factor'] if rope_scaling else 1.0
|
||||
for key in SEQLEN_KEYS:
|
||||
if getattr(config, key, None) is not None: return int(rope_scaling_factor * getattr(config, key))
|
||||
if getattr(config, key, None) is not None:
|
||||
return int(rope_scaling_factor * getattr(config, key))
|
||||
return 2048
|
||||
|
||||
|
||||
def is_sentence_complete(output: str) -> bool:
|
||||
return output.endswith(('.', '?', '!', '...', '。', '?', '!', '…', '"', "'", '”'))
|
||||
|
||||
|
||||
def is_partial_stop(output: str, stop_str: str) -> bool:
|
||||
"""Check whether the output contains a partial stop str."""
|
||||
for i in range(0, min(len(output), len(stop_str))):
|
||||
if stop_str.startswith(output[-i:]): return True
|
||||
if stop_str.startswith(output[-i:]):
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -54,6 +54,7 @@ from .exceptions import ForbiddenAttributeError
|
||||
from .exceptions import OpenLLMException
|
||||
from .serialisation.constants import PEFT_CONFIG_NAME
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import peft
|
||||
import torch
|
||||
@@ -77,11 +78,14 @@ P = ParamSpec('P')
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def normalise_model_name(name: str) -> str:
|
||||
if validate_is_path(name): return os.path.basename(resolve_filepath(name))
|
||||
if validate_is_path(name):
|
||||
return os.path.basename(resolve_filepath(name))
|
||||
name = name.replace('/', '--')
|
||||
return inflection.dasherize(name)
|
||||
|
||||
|
||||
def resolve_peft_config_type(adapter_map: dict[str, str]) -> AdapterMap:
|
||||
"""Resolve the type of the PeftConfig given the adapter_map.
|
||||
|
||||
@@ -93,7 +97,8 @@ def resolve_peft_config_type(adapter_map: dict[str, str]) -> AdapterMap:
|
||||
resolved: AdapterMap = {}
|
||||
_has_set_default = False
|
||||
for path_or_adapter_id, name in adapter_map.items():
|
||||
if name is None: raise ValueError('Adapter name must be specified.')
|
||||
if name is None:
|
||||
raise ValueError('Adapter name must be specified.')
|
||||
if os.path.isfile(os.path.join(path_or_adapter_id, PEFT_CONFIG_NAME)):
|
||||
config_file = os.path.join(path_or_adapter_id, PEFT_CONFIG_NAME)
|
||||
else:
|
||||
@@ -105,13 +110,16 @@ def resolve_peft_config_type(adapter_map: dict[str, str]) -> AdapterMap:
|
||||
resolved_config = orjson.loads(file.read())
|
||||
# all peft_type should be available in PEFT_CONFIG_NAME
|
||||
_peft_type: AdapterType = resolved_config['peft_type'].lower()
|
||||
if _peft_type not in resolved: resolved[_peft_type] = ()
|
||||
if _peft_type not in resolved:
|
||||
resolved[_peft_type] = ()
|
||||
resolved[_peft_type] += (_AdapterTuple((path_or_adapter_id, name, resolved_config)),)
|
||||
return resolved
|
||||
|
||||
|
||||
_reserved_namespace = {'model', 'tokenizer', 'runner', 'import_kwargs'}
|
||||
_AdapterTuple: type[AdapterTuple] = codegen.make_attr_tuple_class('AdapterTuple', ['adapter_id', 'name', 'config'])
|
||||
|
||||
|
||||
@attr.define(slots=True, repr=False, init=False)
|
||||
class LLM(t.Generic[M, T]):
|
||||
_model_id: str
|
||||
@@ -140,30 +148,44 @@ class LLM(t.Generic[M, T]):
|
||||
device: 'torch.device | None' = None
|
||||
|
||||
def __attrs_post_init__(self) -> None:
|
||||
if self.__llm_backend__ == 'pt': self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
if self.__llm_backend__ == 'pt':
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
def __init__(self,
|
||||
model_id: str,
|
||||
model_version: str | None = None,
|
||||
model_tag: str | bentoml.Tag | None = None,
|
||||
prompt_template: PromptTemplate | str | None = None,
|
||||
system_message: str | None = None,
|
||||
llm_config: LLMConfig | None = None,
|
||||
backend: LiteralBackend | None = None,
|
||||
*args: t.Any,
|
||||
quantize: LiteralQuantise | None = None,
|
||||
quantization_config: transformers.BitsAndBytesConfig | transformers.GPTQConfig | transformers.AwqConfig | None = None,
|
||||
adapter_map: dict[str, str] | None = None,
|
||||
serialisation: LiteralSerialisation = 'safetensors',
|
||||
trust_remote_code: bool = False,
|
||||
**attrs: t.Any):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
model_version: str | None = None,
|
||||
model_tag: str | bentoml.Tag | None = None,
|
||||
prompt_template: PromptTemplate | str | None = None,
|
||||
system_message: str | None = None,
|
||||
llm_config: LLMConfig | None = None,
|
||||
backend: LiteralBackend | None = None,
|
||||
*args: t.Any,
|
||||
quantize: LiteralQuantise | None = None,
|
||||
quantization_config: transformers.BitsAndBytesConfig
|
||||
| transformers.GPTQConfig
|
||||
| transformers.AwqConfig
|
||||
| None = None,
|
||||
adapter_map: dict[str, str] | None = None,
|
||||
serialisation: LiteralSerialisation = 'safetensors',
|
||||
trust_remote_code: bool = False,
|
||||
**attrs: t.Any,
|
||||
):
|
||||
# low_cpu_mem_usage is only available for model this is helpful on system with low memory to avoid OOM
|
||||
low_cpu_mem_usage = attrs.pop('low_cpu_mem_usage', True)
|
||||
_local = False
|
||||
if validate_is_path(model_id): model_id, _local = resolve_filepath(model_id), True
|
||||
backend = t.cast(LiteralBackend, first_not_none(backend, os.getenv('OPENLLM_BACKEND'), default='vllm' if openllm.utils.is_vllm_available() else 'pt'))
|
||||
if validate_is_path(model_id):
|
||||
model_id, _local = resolve_filepath(model_id), True
|
||||
backend = t.cast(
|
||||
LiteralBackend,
|
||||
first_not_none(
|
||||
backend, os.getenv('OPENLLM_BACKEND'), default='vllm' if openllm.utils.is_vllm_available() else 'pt'
|
||||
),
|
||||
)
|
||||
|
||||
quantize = first_not_none(quantize, t.cast(t.Optional[LiteralQuantise], os.getenv('OPENLLM_QUANTIZE')), default=None)
|
||||
quantize = first_not_none(
|
||||
quantize, t.cast(t.Optional[LiteralQuantise], os.getenv('OPENLLM_QUANTIZE')), default=None
|
||||
)
|
||||
# elif quantization_config is None and quantize is not None:
|
||||
# quantization_config, attrs = infer_quantisation_config(self, quantize, **attrs)
|
||||
attrs.update({'low_cpu_mem_usage': low_cpu_mem_usage})
|
||||
@@ -171,28 +193,35 @@ class LLM(t.Generic[M, T]):
|
||||
# parsing tokenizer and model kwargs, as the hierarchy is param pass > default
|
||||
model_attrs, tokenizer_attrs = flatten_attrs(**attrs)
|
||||
|
||||
if adapter_map is not None and not is_peft_available(): raise RuntimeError("LoRA adapter requires 'peft' to be installed. Make sure to do 'pip install \"openllm[fine-tune]\"'")
|
||||
if isinstance(prompt_template, str): prompt_template = PromptTemplate(prompt_template)
|
||||
if adapter_map is not None and not is_peft_available():
|
||||
raise RuntimeError(
|
||||
"LoRA adapter requires 'peft' to be installed. Make sure to do 'pip install \"openllm[fine-tune]\"'"
|
||||
)
|
||||
if isinstance(prompt_template, str):
|
||||
prompt_template = PromptTemplate(prompt_template)
|
||||
if model_tag is None:
|
||||
model_tag, model_version = self._make_tag_components(model_id, model_version, backend=backend)
|
||||
if model_version: model_tag = f'{model_tag}:{model_version}'
|
||||
if model_version:
|
||||
model_tag = f'{model_tag}:{model_version}'
|
||||
|
||||
self.__attrs_init__(model_id=model_id,
|
||||
revision=model_version,
|
||||
tag=bentoml.Tag.from_taglike(t.cast(t.Union[str, bentoml.Tag], model_tag)),
|
||||
quantization_config=quantization_config,
|
||||
quantise=quantize,
|
||||
model_decls=args,
|
||||
model_attrs=dict(**self.import_kwargs[0], **model_attrs),
|
||||
tokenizer_attrs=dict(**self.import_kwargs[-1], **tokenizer_attrs),
|
||||
adapter_map=resolve_peft_config_type(adapter_map) if adapter_map is not None else None,
|
||||
serialisation=serialisation,
|
||||
local=_local,
|
||||
prompt_template=prompt_template,
|
||||
system_message=system_message,
|
||||
llm_backend__=backend,
|
||||
llm_config__=llm_config,
|
||||
llm_trust_remote_code__=trust_remote_code)
|
||||
self.__attrs_init__(
|
||||
model_id=model_id,
|
||||
revision=model_version,
|
||||
tag=bentoml.Tag.from_taglike(t.cast(t.Union[str, bentoml.Tag], model_tag)),
|
||||
quantization_config=quantization_config,
|
||||
quantise=quantize,
|
||||
model_decls=args,
|
||||
model_attrs=dict(**self.import_kwargs[0], **model_attrs),
|
||||
tokenizer_attrs=dict(**self.import_kwargs[-1], **tokenizer_attrs),
|
||||
adapter_map=resolve_peft_config_type(adapter_map) if adapter_map is not None else None,
|
||||
serialisation=serialisation,
|
||||
local=_local,
|
||||
prompt_template=prompt_template,
|
||||
system_message=system_message,
|
||||
llm_backend__=backend,
|
||||
llm_config__=llm_config,
|
||||
llm_trust_remote_code__=trust_remote_code,
|
||||
)
|
||||
|
||||
try:
|
||||
model = bentoml.models.get(self.tag)
|
||||
@@ -202,13 +231,24 @@ class LLM(t.Generic[M, T]):
|
||||
self._tag = model.tag
|
||||
|
||||
@apply(lambda val: tuple(str.lower(i) if i else i for i in val))
|
||||
def _make_tag_components(self, model_id: str, model_version: str | None, backend: LiteralBackend) -> tuple[str, str | None]:
|
||||
def _make_tag_components(
|
||||
self, model_id: str, model_version: str | None, backend: LiteralBackend
|
||||
) -> tuple[str, str | None]:
|
||||
"""Return a valid tag name (<backend>-<repo>--<model_id>) and its tag version."""
|
||||
model_id, *maybe_revision = model_id.rsplit(':')
|
||||
if len(maybe_revision) > 0:
|
||||
if model_version is not None: logger.warning("revision is specified within 'model_id' (%s), and 'model_version=%s' will be ignored.", maybe_revision[0], model_version)
|
||||
if model_version is not None:
|
||||
logger.warning(
|
||||
"revision is specified within 'model_id' (%s), and 'model_version=%s' will be ignored.",
|
||||
maybe_revision[0],
|
||||
model_version,
|
||||
)
|
||||
model_version = maybe_revision[0]
|
||||
if validate_is_path(model_id): model_id, model_version = resolve_filepath(model_id), first_not_none(model_version, default=generate_hash_from_file(model_id))
|
||||
if validate_is_path(model_id):
|
||||
model_id, model_version = (
|
||||
resolve_filepath(model_id),
|
||||
first_not_none(model_version, default=generate_hash_from_file(model_id)),
|
||||
)
|
||||
return f'{backend}-{normalise_model_name(model_id)}', model_version
|
||||
|
||||
# yapf: disable
|
||||
@@ -257,28 +297,44 @@ class LLM(t.Generic[M, T]):
|
||||
try:
|
||||
import peft as _ # noqa: F401
|
||||
except ImportError as err:
|
||||
raise MissingDependencyError("Failed to import 'peft'. Make sure to do 'pip install \"openllm[fine-tune]\"'") from err
|
||||
if not self.has_adapters: raise AttributeError('Adapter map is not available.')
|
||||
raise MissingDependencyError(
|
||||
"Failed to import 'peft'. Make sure to do 'pip install \"openllm[fine-tune]\"'"
|
||||
) from err
|
||||
if not self.has_adapters:
|
||||
raise AttributeError('Adapter map is not available.')
|
||||
assert self._adapter_map is not None
|
||||
if self.__llm_adapter_map__ is None:
|
||||
_map: ResolvedAdapterMap = {k: {} for k in self._adapter_map}
|
||||
for adapter_type, adapter_tuple in self._adapter_map.items():
|
||||
base = first_not_none(self.config['fine_tune_strategies'].get(adapter_type), default=self.config.make_fine_tune_config(adapter_type))
|
||||
base = first_not_none(
|
||||
self.config['fine_tune_strategies'].get(adapter_type),
|
||||
default=self.config.make_fine_tune_config(adapter_type),
|
||||
)
|
||||
for adapter in adapter_tuple:
|
||||
_map[adapter_type][adapter.name] = (base.with_config(**adapter.config).build(), adapter.adapter_id)
|
||||
self.__llm_adapter_map__ = _map
|
||||
return self.__llm_adapter_map__
|
||||
|
||||
def prepare_for_training(self,
|
||||
adapter_type: AdapterType = 'lora',
|
||||
use_gradient_checking: bool = True,
|
||||
**attrs: t.Any) -> tuple[peft.PeftModel | peft.PeftModelForCausalLM | peft.PeftModelForSeq2SeqLM, T]:
|
||||
def prepare_for_training(
|
||||
self, adapter_type: AdapterType = 'lora', use_gradient_checking: bool = True, **attrs: t.Any
|
||||
) -> tuple[peft.PeftModel | peft.PeftModelForCausalLM | peft.PeftModelForSeq2SeqLM, T]:
|
||||
from peft import get_peft_model
|
||||
from peft import prepare_model_for_kbit_training
|
||||
peft_config = self.config['fine_tune_strategies'].get(adapter_type, self.config.make_fine_tune_config(adapter_type)).train().with_config(**attrs).build()
|
||||
if self.has_adapters: raise ValueError('Adapter should not be specified when fine-tuning.')
|
||||
model = get_peft_model(prepare_model_for_kbit_training(self.model, use_gradient_checkpointing=use_gradient_checking), peft_config) # type: ignore[no-untyped-call]
|
||||
if DEBUG: model.print_trainable_parameters() # type: ignore[no-untyped-call]
|
||||
|
||||
peft_config = (
|
||||
self.config['fine_tune_strategies']
|
||||
.get(adapter_type, self.config.make_fine_tune_config(adapter_type))
|
||||
.train()
|
||||
.with_config(**attrs)
|
||||
.build()
|
||||
)
|
||||
if self.has_adapters:
|
||||
raise ValueError('Adapter should not be specified when fine-tuning.')
|
||||
model = get_peft_model(
|
||||
prepare_model_for_kbit_training(self.model, use_gradient_checkpointing=use_gradient_checking), peft_config
|
||||
) # type: ignore[no-untyped-call]
|
||||
if DEBUG:
|
||||
model.print_trainable_parameters() # type: ignore[no-untyped-call]
|
||||
return model, self.tokenizer
|
||||
|
||||
@property
|
||||
@@ -288,13 +344,22 @@ class LLM(t.Generic[M, T]):
|
||||
# If OOM, then it is probably you don't have enough VRAM to run this model.
|
||||
if self.__llm_backend__ == 'pt':
|
||||
if is_torch_available():
|
||||
loaded_in_kbit = getattr(model, 'is_loaded_in_8bit', False) or getattr(model, 'is_loaded_in_4bit', False) or getattr(model, 'is_quantized', False)
|
||||
if torch.cuda.is_available() and torch.cuda.device_count() == 1 and not loaded_in_kbit and not isinstance(model, transformers.Pipeline):
|
||||
loaded_in_kbit = (
|
||||
getattr(model, 'is_loaded_in_8bit', False)
|
||||
or getattr(model, 'is_loaded_in_4bit', False)
|
||||
or getattr(model, 'is_quantized', False)
|
||||
)
|
||||
if (
|
||||
torch.cuda.is_available()
|
||||
and torch.cuda.device_count() == 1
|
||||
and not loaded_in_kbit
|
||||
and not isinstance(model, transformers.Pipeline)
|
||||
):
|
||||
try:
|
||||
model = model.to('cuda')
|
||||
except Exception as err:
|
||||
raise OpenLLMException(
|
||||
f'Failed to load {self} into GPU: {err}\nTip: If you run into OOM issue, maybe try different offload strategy. See https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/quantization#offload-between-cpu-and-gpu for more information.'
|
||||
f'Failed to load {self} into GPU: {err}\nTip: If you run into OOM issue, maybe try different offload strategy. See https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/quantization#offload-between-cpu-and-gpu for more information.'
|
||||
) from err
|
||||
if self.has_adapters:
|
||||
logger.debug('Applying the following adapters: %s', self.adapter_map)
|
||||
@@ -307,83 +372,117 @@ class LLM(t.Generic[M, T]):
|
||||
@property
|
||||
def tokenizer(self) -> T:
|
||||
# NOTE: the signature of load_tokenizer here is the wrapper under _wrapped_load_tokenizer
|
||||
if self.__llm_tokenizer__ is None: self.__llm_tokenizer__ = openllm.serialisation.load_tokenizer(self, **self.llm_parameters[-1])
|
||||
if self.__llm_tokenizer__ is None:
|
||||
self.__llm_tokenizer__ = openllm.serialisation.load_tokenizer(self, **self.llm_parameters[-1])
|
||||
return self.__llm_tokenizer__
|
||||
|
||||
@property
|
||||
def runner(self) -> LLMRunner[M, T]:
|
||||
if self.__llm_runner__ is None: self.__llm_runner__ = _RunnerFactory(self)
|
||||
if self.__llm_runner__ is None:
|
||||
self.__llm_runner__ = _RunnerFactory(self)
|
||||
return self.__llm_runner__
|
||||
|
||||
async def generate(self,
|
||||
prompt: str | None,
|
||||
prompt_token_ids: list[int] | None = None,
|
||||
stop: str | t.Iterable[str] | None = None,
|
||||
stop_token_ids: list[int] | None = None,
|
||||
request_id: str | None = None,
|
||||
adapter_name: str | None = None,
|
||||
**attrs: t.Any) -> GenerationOutput:
|
||||
async def generate(
|
||||
self,
|
||||
prompt: str | None,
|
||||
prompt_token_ids: list[int] | None = None,
|
||||
stop: str | t.Iterable[str] | None = None,
|
||||
stop_token_ids: list[int] | None = None,
|
||||
request_id: str | None = None,
|
||||
adapter_name: str | None = None,
|
||||
**attrs: t.Any,
|
||||
) -> GenerationOutput:
|
||||
config = self.config.model_construct_env(**attrs)
|
||||
texts: list[list[str]] = [[]] * config['n']
|
||||
token_ids: list[list[int]] = [[]] * config['n']
|
||||
final_result: GenerationOutput | None = None
|
||||
async for result in self.generate_iterator(prompt, prompt_token_ids, stop, stop_token_ids, request_id, adapter_name, **config.model_dump(flatten=True)):
|
||||
async for result in self.generate_iterator(
|
||||
prompt, prompt_token_ids, stop, stop_token_ids, request_id, adapter_name, **config.model_dump(flatten=True)
|
||||
):
|
||||
for output in result.outputs:
|
||||
texts[output.index].append(output.text)
|
||||
token_ids[output.index].extend(output.token_ids)
|
||||
final_result = result
|
||||
if final_result is None: raise RuntimeError('No result is returned.')
|
||||
return final_result.with_options(prompt=prompt, outputs=[output.with_options(text=''.join(texts[output.index]), token_ids=token_ids[output.index]) for output in final_result.outputs])
|
||||
if final_result is None:
|
||||
raise RuntimeError('No result is returned.')
|
||||
return final_result.with_options(
|
||||
prompt=prompt,
|
||||
outputs=[
|
||||
output.with_options(text=''.join(texts[output.index]), token_ids=token_ids[output.index])
|
||||
for output in final_result.outputs
|
||||
],
|
||||
)
|
||||
|
||||
async def generate_iterator(self,
|
||||
prompt: str | None,
|
||||
prompt_token_ids: list[int] | None = None,
|
||||
stop: str | t.Iterable[str] | None = None,
|
||||
stop_token_ids: list[int] | None = None,
|
||||
request_id: str | None = None,
|
||||
adapter_name: str | None = None,
|
||||
**attrs: t.Any) -> t.AsyncGenerator[GenerationOutput, None]:
|
||||
async def generate_iterator(
|
||||
self,
|
||||
prompt: str | None,
|
||||
prompt_token_ids: list[int] | None = None,
|
||||
stop: str | t.Iterable[str] | None = None,
|
||||
stop_token_ids: list[int] | None = None,
|
||||
request_id: str | None = None,
|
||||
adapter_name: str | None = None,
|
||||
**attrs: t.Any,
|
||||
) -> t.AsyncGenerator[GenerationOutput, None]:
|
||||
if isinstance(self.runner._runner_handle, DummyRunnerHandle):
|
||||
if os.getenv('BENTO_PATH') is not None: raise RuntimeError('Runner client failed to set up correctly.')
|
||||
else: self.runner.init_local(quiet=True)
|
||||
if os.getenv('BENTO_PATH') is not None:
|
||||
raise RuntimeError('Runner client failed to set up correctly.')
|
||||
else:
|
||||
self.runner.init_local(quiet=True)
|
||||
|
||||
config = self.config.model_construct_env(**attrs)
|
||||
|
||||
if stop_token_ids is None: stop_token_ids = []
|
||||
if self.tokenizer.eos_token_id not in stop_token_ids: stop_token_ids.append(self.tokenizer.eos_token_id)
|
||||
if stop is None: stop = set()
|
||||
elif isinstance(stop, str): stop = {stop}
|
||||
else: stop = set(stop)
|
||||
if stop_token_ids is None:
|
||||
stop_token_ids = []
|
||||
if self.tokenizer.eos_token_id not in stop_token_ids:
|
||||
stop_token_ids.append(self.tokenizer.eos_token_id)
|
||||
if stop is None:
|
||||
stop = set()
|
||||
elif isinstance(stop, str):
|
||||
stop = {stop}
|
||||
else:
|
||||
stop = set(stop)
|
||||
for tid in stop_token_ids:
|
||||
if tid: stop.add(self.tokenizer.decode(tid))
|
||||
if tid:
|
||||
stop.add(self.tokenizer.decode(tid))
|
||||
|
||||
if prompt_token_ids is None:
|
||||
if prompt is None: raise ValueError('Either prompt or prompt_token_ids must be specified.')
|
||||
if prompt is None:
|
||||
raise ValueError('Either prompt or prompt_token_ids must be specified.')
|
||||
prompt_token_ids = self.tokenizer.encode(prompt)
|
||||
|
||||
if request_id is None: request_id = openllm_core.utils.gen_random_uuid()
|
||||
if request_id is None:
|
||||
request_id = openllm_core.utils.gen_random_uuid()
|
||||
previous_texts, previous_num_tokens = [''] * config['n'], [0] * config['n']
|
||||
async for out in self.runner.generate_iterator.async_stream(prompt_token_ids, request_id, stop, adapter_name, **config.model_dump(flatten=True)):
|
||||
async for out in self.runner.generate_iterator.async_stream(
|
||||
prompt_token_ids, request_id, stop, adapter_name, **config.model_dump(flatten=True)
|
||||
):
|
||||
generated = GenerationOutput.from_runner(out).with_options(prompt=prompt)
|
||||
delta_outputs = t.cast(t.List[CompletionChunk], [None] * len(generated.outputs))
|
||||
if generated.finished: break
|
||||
if generated.finished:
|
||||
break
|
||||
for output in generated.outputs:
|
||||
i = output.index
|
||||
delta_tokens, delta_text = output.token_ids[previous_num_tokens[i]:], output.text[len(previous_texts[i]):]
|
||||
delta_tokens, delta_text = output.token_ids[previous_num_tokens[i] :], output.text[len(previous_texts[i]) :]
|
||||
previous_texts[i], previous_num_tokens[i] = output.text, len(output.token_ids)
|
||||
delta_outputs[i] = output.with_options(text=delta_text, token_ids=delta_tokens)
|
||||
yield generated.with_options(outputs=delta_outputs)
|
||||
|
||||
def _RunnerFactory(self: openllm.LLM[M, T],
|
||||
/,
|
||||
models: list[bentoml.Model] | None = None,
|
||||
max_batch_size: int | None = None,
|
||||
max_latency_ms: int | None = None,
|
||||
scheduling_strategy: type[bentoml.Strategy] = CascadingResourceStrategy,
|
||||
*,
|
||||
backend: LiteralBackend | None = None) -> LLMRunner[M, T]:
|
||||
|
||||
def _RunnerFactory(
|
||||
self: openllm.LLM[M, T],
|
||||
/,
|
||||
models: list[bentoml.Model] | None = None,
|
||||
max_batch_size: int | None = None,
|
||||
max_latency_ms: int | None = None,
|
||||
scheduling_strategy: type[bentoml.Strategy] = CascadingResourceStrategy,
|
||||
*,
|
||||
backend: LiteralBackend | None = None,
|
||||
) -> LLMRunner[M, T]:
|
||||
from ._runners import runnable
|
||||
backend = t.cast(LiteralBackend, first_not_none(backend, os.environ.get('OPENLLM_BACKEND'), default=self.__llm_backend__))
|
||||
|
||||
backend = t.cast(
|
||||
LiteralBackend, first_not_none(backend, os.environ.get('OPENLLM_BACKEND'), default=self.__llm_backend__)
|
||||
)
|
||||
|
||||
models = models if models is not None else []
|
||||
try:
|
||||
@@ -391,12 +490,18 @@ def _RunnerFactory(self: openllm.LLM[M, T],
|
||||
except bentoml.exceptions.NotFound as err:
|
||||
raise RuntimeError(f'Failed to locate {self.bentomodel}:{err}') from err
|
||||
|
||||
if self._prompt_template: prompt_template = self._prompt_template.to_string()
|
||||
elif hasattr(self.config, 'default_prompt_template'): prompt_template = self.config.default_prompt_template
|
||||
else: prompt_template = None
|
||||
if self._system_message: system_message = self._system_message
|
||||
elif hasattr(self.config, 'default_system_message'): system_message = self.config.default_system_message
|
||||
else: system_message = None
|
||||
if self._prompt_template:
|
||||
prompt_template = self._prompt_template.to_string()
|
||||
elif hasattr(self.config, 'default_prompt_template'):
|
||||
prompt_template = self.config.default_prompt_template
|
||||
else:
|
||||
prompt_template = None
|
||||
if self._system_message:
|
||||
system_message = self._system_message
|
||||
elif hasattr(self.config, 'default_system_message'):
|
||||
system_message = self.config.default_system_message
|
||||
else:
|
||||
system_message = None
|
||||
|
||||
# yapf: disable
|
||||
def _wrapped_repr_keys(_: LLMRunner[M, T]) -> set[str]: return {'config', 'llm_type', 'runner_methods', 'backend', 'llm_tag'}
|
||||
@@ -408,31 +513,39 @@ def _RunnerFactory(self: openllm.LLM[M, T],
|
||||
yield 'llm_tag', self.tag
|
||||
# yapf: enable
|
||||
|
||||
return types.new_class(self.__class__.__name__ + 'Runner', (bentoml.Runner,),
|
||||
exec_body=lambda ns: ns.update({
|
||||
'llm_type': self.llm_type,
|
||||
'identifying_params': self.identifying_params,
|
||||
'llm_tag': self.tag,
|
||||
'llm': self,
|
||||
'config': self.config,
|
||||
'backend': backend,
|
||||
'__module__': self.__module__,
|
||||
'__doc__': getattr(openllm_core.config, f'START_{self.config["model_name"].upper()}_COMMAND_DOCSTRING'),
|
||||
'__repr__': ReprMixin.__repr__,
|
||||
'__repr_keys__': property(_wrapped_repr_keys),
|
||||
'__repr_args__': _wrapped_repr_args,
|
||||
'has_adapters': self.has_adapters,
|
||||
'prompt_template': prompt_template,
|
||||
'system_message': system_message,
|
||||
}))(runnable(backend),
|
||||
name=self.runner_name,
|
||||
embedded=False,
|
||||
models=models,
|
||||
max_batch_size=max_batch_size,
|
||||
max_latency_ms=max_latency_ms,
|
||||
scheduling_strategy=scheduling_strategy,
|
||||
runnable_init_params=dict(llm=self),
|
||||
method_configs=converter.unstructure({'generate_iterator': ModelSignature(batchable=False)}))
|
||||
return types.new_class(
|
||||
self.__class__.__name__ + 'Runner',
|
||||
(bentoml.Runner,),
|
||||
exec_body=lambda ns: ns.update(
|
||||
{
|
||||
'llm_type': self.llm_type,
|
||||
'identifying_params': self.identifying_params,
|
||||
'llm_tag': self.tag,
|
||||
'llm': self,
|
||||
'config': self.config,
|
||||
'backend': backend,
|
||||
'__module__': self.__module__,
|
||||
'__doc__': getattr(openllm_core.config, f'START_{self.config["model_name"].upper()}_COMMAND_DOCSTRING'),
|
||||
'__repr__': ReprMixin.__repr__,
|
||||
'__repr_keys__': property(_wrapped_repr_keys),
|
||||
'__repr_args__': _wrapped_repr_args,
|
||||
'has_adapters': self.has_adapters,
|
||||
'prompt_template': prompt_template,
|
||||
'system_message': system_message,
|
||||
}
|
||||
),
|
||||
)(
|
||||
runnable(backend),
|
||||
name=self.runner_name,
|
||||
embedded=False,
|
||||
models=models,
|
||||
max_batch_size=max_batch_size,
|
||||
max_latency_ms=max_latency_ms,
|
||||
scheduling_strategy=scheduling_strategy,
|
||||
runnable_init_params=dict(llm=self),
|
||||
method_configs=converter.unstructure({'generate_iterator': ModelSignature(batchable=False)}),
|
||||
)
|
||||
|
||||
|
||||
@t.final
|
||||
class LLMRunnable(bentoml.Runnable, t.Generic[M, T]):
|
||||
@@ -440,6 +553,7 @@ class LLMRunnable(bentoml.Runnable, t.Generic[M, T]):
|
||||
SUPPORTS_CPU_MULTI_THREADING = True
|
||||
generate_iterator: RunnableMethod[LLMRunnable[M, T], [list[int], str, str | t.Iterable[str] | None, str | None], str]
|
||||
|
||||
|
||||
@t.final
|
||||
class LLMRunner(t.Protocol[M, T]):
|
||||
__doc__: str
|
||||
@@ -461,22 +575,23 @@ class LLMRunner(t.Protocol[M, T]):
|
||||
runnable_init_params: dict[str, t.Any]
|
||||
_runner_handle: RunnerHandle
|
||||
|
||||
def __init__(self,
|
||||
runnable_class: type[LLMRunnable[M, T]],
|
||||
*,
|
||||
runnable_init_params: dict[str, t.Any] | None = ...,
|
||||
name: str | None = ...,
|
||||
scheduling_strategy: type[Strategy] = ...,
|
||||
models: list[bentoml.Model] | None = ...,
|
||||
max_batch_size: int | None = ...,
|
||||
max_latency_ms: int | None = ...,
|
||||
method_configs: dict[str, dict[str, int]] | None = ...,
|
||||
embedded: bool = False) -> None:
|
||||
...
|
||||
def __init__(
|
||||
self,
|
||||
runnable_class: type[LLMRunnable[M, T]],
|
||||
*,
|
||||
runnable_init_params: dict[str, t.Any] | None = ...,
|
||||
name: str | None = ...,
|
||||
scheduling_strategy: type[Strategy] = ...,
|
||||
models: list[bentoml.Model] | None = ...,
|
||||
max_batch_size: int | None = ...,
|
||||
max_latency_ms: int | None = ...,
|
||||
method_configs: dict[str, dict[str, int]] | None = ...,
|
||||
embedded: bool = False,
|
||||
) -> None: ...
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def __repr_keys__(self) -> set[str]:
|
||||
...
|
||||
def __repr_keys__(self) -> set[str]: ...
|
||||
|
||||
|
||||
__all__ = ['LLMRunner', 'LLMRunnable', 'LLM']
|
||||
|
||||
@@ -14,6 +14,7 @@ from openllm_core.utils import is_autogptq_available
|
||||
from openllm_core.utils import is_bitsandbytes_available
|
||||
from openllm_core.utils import is_optimum_supports_gptq
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from openllm_core._typing_compat import DictStrAny
|
||||
|
||||
@@ -21,20 +22,28 @@ if t.TYPE_CHECKING:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@overload
|
||||
def infer_quantisation_config(self: LLM[t.Any, t.Any], quantise: t.Literal['int8', 'int4'], **attrs: t.Any) -> tuple[transformers.BitsAndBytesConfig, DictStrAny]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def infer_quantisation_config(self: LLM[t.Any, t.Any], quantise: t.Literal['gptq'], **attrs: t.Any) -> tuple[transformers.GPTQConfig, DictStrAny]:
|
||||
...
|
||||
def infer_quantisation_config(
|
||||
self: LLM[t.Any, t.Any], quantise: t.Literal['int8', 'int4'], **attrs: t.Any
|
||||
) -> tuple[transformers.BitsAndBytesConfig, DictStrAny]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def infer_quantisation_config(self: LLM[t.Any, t.Any], quantise: t.Literal['awq'], **attrs: t.Any) -> tuple[transformers.AwqConfig, DictStrAny]:
|
||||
...
|
||||
def infer_quantisation_config(
|
||||
self: LLM[t.Any, t.Any], quantise: t.Literal['gptq'], **attrs: t.Any
|
||||
) -> tuple[transformers.GPTQConfig, DictStrAny]: ...
|
||||
|
||||
def infer_quantisation_config(self: LLM[t.Any, t.Any], quantise: LiteralQuantise,
|
||||
**attrs: t.Any) -> tuple[transformers.BitsAndBytesConfig | transformers.GPTQConfig | transformers.AwqConfig, DictStrAny]:
|
||||
|
||||
@overload
|
||||
def infer_quantisation_config(
|
||||
self: LLM[t.Any, t.Any], quantise: t.Literal['awq'], **attrs: t.Any
|
||||
) -> tuple[transformers.AwqConfig, DictStrAny]: ...
|
||||
|
||||
|
||||
def infer_quantisation_config(
|
||||
self: LLM[t.Any, t.Any], quantise: LiteralQuantise, **attrs: t.Any
|
||||
) -> tuple[transformers.BitsAndBytesConfig | transformers.GPTQConfig | transformers.AwqConfig, DictStrAny]:
|
||||
# 8 bit configuration
|
||||
int8_threshold = attrs.pop('llm_int8_threshhold', 6.0)
|
||||
int8_enable_fp32_cpu_offload = attrs.pop('llm_int8_enable_fp32_cpu_offload', False)
|
||||
@@ -64,34 +73,39 @@ def infer_quantisation_config(self: LLM[t.Any, t.Any], quantise: LiteralQuantise
|
||||
gptq_pad_token_id = attrs.pop('pad_token_id', None)
|
||||
disable_exllama = attrs.pop('disable_exllama', False) # backward compatibility
|
||||
gptq_use_exllama = attrs.pop('use_exllama', True)
|
||||
if disable_exllama: gptq_use_exllama = False
|
||||
return transformers.GPTQConfig(bits=bits,
|
||||
tokenizer=gptq_tokenizer,
|
||||
dataset=gptq_dataset,
|
||||
group_size=group_size,
|
||||
damp_percent=gptq_damp_percent,
|
||||
desc_act=gptq_desc_act,
|
||||
sym=gptq_sym,
|
||||
true_sequential=gptq_true_sequential,
|
||||
use_cuda_fp16=gptq_use_cuda_fp16,
|
||||
model_seqlen=gptq_model_seqlen,
|
||||
block_name_to_quantize=gptq_block_name_to_quantize,
|
||||
module_name_preceding_first_block=gptq_module_name_preceding_first_block,
|
||||
batch_size=gptq_batch_size,
|
||||
pad_token_id=gptq_pad_token_id,
|
||||
use_exllama=gptq_use_exllama,
|
||||
exllama_config={'version': 1}) # XXX: See how to migrate to v2
|
||||
if disable_exllama:
|
||||
gptq_use_exllama = False
|
||||
return transformers.GPTQConfig(
|
||||
bits=bits,
|
||||
tokenizer=gptq_tokenizer,
|
||||
dataset=gptq_dataset,
|
||||
group_size=group_size,
|
||||
damp_percent=gptq_damp_percent,
|
||||
desc_act=gptq_desc_act,
|
||||
sym=gptq_sym,
|
||||
true_sequential=gptq_true_sequential,
|
||||
use_cuda_fp16=gptq_use_cuda_fp16,
|
||||
model_seqlen=gptq_model_seqlen,
|
||||
block_name_to_quantize=gptq_block_name_to_quantize,
|
||||
module_name_preceding_first_block=gptq_module_name_preceding_first_block,
|
||||
batch_size=gptq_batch_size,
|
||||
pad_token_id=gptq_pad_token_id,
|
||||
use_exllama=gptq_use_exllama,
|
||||
exllama_config={'version': 1},
|
||||
) # XXX: See how to migrate to v2
|
||||
|
||||
def create_int8_config(int8_skip_modules: list[str] | None) -> transformers.BitsAndBytesConfig:
|
||||
# if int8_skip_modules is None: int8_skip_modules = []
|
||||
# if 'lm_head' not in int8_skip_modules and self.config_class.__openllm_model_type__ == 'causal_lm':
|
||||
# logger.debug("Skipping 'lm_head' for quantization for %s", self.__name__)
|
||||
# int8_skip_modules.append('lm_head')
|
||||
return transformers.BitsAndBytesConfig(load_in_8bit=True,
|
||||
llm_int8_enable_fp32_cpu_offload=int8_enable_fp32_cpu_offload,
|
||||
llm_int8_threshhold=int8_threshold,
|
||||
llm_int8_skip_modules=int8_skip_modules,
|
||||
llm_int8_has_fp16_weight=int8_has_fp16_weight)
|
||||
return transformers.BitsAndBytesConfig(
|
||||
load_in_8bit=True,
|
||||
llm_int8_enable_fp32_cpu_offload=int8_enable_fp32_cpu_offload,
|
||||
llm_int8_threshhold=int8_threshold,
|
||||
llm_int8_skip_modules=int8_skip_modules,
|
||||
llm_int8_has_fp16_weight=int8_has_fp16_weight,
|
||||
)
|
||||
|
||||
# 4 bit configuration
|
||||
int4_compute_dtype = attrs.pop('bnb_4bit_compute_dtype', torch.bfloat16)
|
||||
@@ -100,22 +114,30 @@ def infer_quantisation_config(self: LLM[t.Any, t.Any], quantise: LiteralQuantise
|
||||
|
||||
# NOTE: Quantization setup quantize is a openllm.LLM feature, where we can quantize the model with bitsandbytes or quantization aware training.
|
||||
if not is_bitsandbytes_available():
|
||||
raise RuntimeError("Quantization requires bitsandbytes to be installed. Make sure to install OpenLLM with 'pip install \"openllm[fine-tune]\"'")
|
||||
if quantise == 'int8': quantisation_config = create_int8_config(int8_skip_modules)
|
||||
raise RuntimeError(
|
||||
'Quantization requires bitsandbytes to be installed. Make sure to install OpenLLM with \'pip install "openllm[fine-tune]"\''
|
||||
)
|
||||
if quantise == 'int8':
|
||||
quantisation_config = create_int8_config(int8_skip_modules)
|
||||
elif quantise == 'int4':
|
||||
quantisation_config = transformers.BitsAndBytesConfig(load_in_4bit=True,
|
||||
bnb_4bit_compute_dtype=int4_compute_dtype,
|
||||
bnb_4bit_quant_type=int4_quant_type,
|
||||
bnb_4bit_use_double_quant=int4_use_double_quant)
|
||||
quantisation_config = transformers.BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_compute_dtype=int4_compute_dtype,
|
||||
bnb_4bit_quant_type=int4_quant_type,
|
||||
bnb_4bit_use_double_quant=int4_use_double_quant,
|
||||
)
|
||||
elif quantise == 'gptq':
|
||||
if not is_autogptq_available() or not is_optimum_supports_gptq():
|
||||
raise MissingDependencyError(
|
||||
"'quantize=\"gptq\"' requires 'auto-gptq' and 'optimum>=0.12' to be installed (missing or failed to import). Make sure to do 'pip install \"openllm[gptq]\"'")
|
||||
"'quantize=\"gptq\"' requires 'auto-gptq' and 'optimum>=0.12' to be installed (missing or failed to import). Make sure to do 'pip install \"openllm[gptq]\"'"
|
||||
)
|
||||
else:
|
||||
quantisation_config = create_gptq_config()
|
||||
elif quantise == 'awq':
|
||||
if not is_autoawq_available():
|
||||
raise MissingDependencyError("quantize='awq' requires 'auto-awq' to be installed (missing or failed to import). Make sure to do 'pip install \"openllm[awq]\"'.")
|
||||
raise MissingDependencyError(
|
||||
"quantize='awq' requires 'auto-awq' to be installed (missing or failed to import). Make sure to do 'pip install \"openllm[awq]\"'."
|
||||
)
|
||||
else:
|
||||
quantisation_config = create_awq_config()
|
||||
else:
|
||||
|
||||
@@ -19,6 +19,7 @@ from openllm_core.utils import first_not_none
|
||||
from openllm_core.utils import get_debug_mode
|
||||
from openllm_core.utils import is_vllm_available
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import vllm
|
||||
|
||||
@@ -30,10 +31,15 @@ _DEFAULT_TOKENIZER = 'hf-internal-testing/llama-tokenizer'
|
||||
|
||||
__all__ = ['runnable']
|
||||
|
||||
|
||||
def runnable(backend: LiteralBackend | None = None) -> type[bentoml.Runnable]:
|
||||
backend = t.cast(LiteralBackend, first_not_none(backend, os.getenv('OPENLLM_BACKEND'), default='vllm' if is_vllm_available() else 'pt'))
|
||||
backend = t.cast(
|
||||
LiteralBackend,
|
||||
first_not_none(backend, os.getenv('OPENLLM_BACKEND'), default='vllm' if is_vllm_available() else 'pt'),
|
||||
)
|
||||
return vLLMRunnable if backend == 'vllm' else PyTorchRunnable
|
||||
|
||||
|
||||
class vLLMRunnable(bentoml.Runnable):
|
||||
SUPPORTED_RESOURCES = ('nvidia.com/gpu', 'amd.com/gpu', 'cpu')
|
||||
SUPPORTS_CPU_MULTI_THREADING = True
|
||||
@@ -41,47 +47,62 @@ class vLLMRunnable(bentoml.Runnable):
|
||||
def __init__(self, llm: openllm.LLM[M, T]) -> None:
|
||||
self.config = llm.config
|
||||
num_gpus, dev = 1, openllm.utils.device_count()
|
||||
if dev >= 2: num_gpus = min(dev // 2 * 2, dev)
|
||||
if dev >= 2:
|
||||
num_gpus = min(dev // 2 * 2, dev)
|
||||
quantization = None
|
||||
if llm._quantise and llm._quantise in {'awq', 'squeezellm'}: quantization = llm._quantise
|
||||
if llm._quantise and llm._quantise in {'awq', 'squeezellm'}:
|
||||
quantization = llm._quantise
|
||||
try:
|
||||
self.model = vllm.AsyncLLMEngine.from_engine_args(
|
||||
vllm.AsyncEngineArgs(model=llm.bentomodel.path,
|
||||
tokenizer=llm.bentomodel.path,
|
||||
trust_remote_code=llm.trust_remote_code,
|
||||
tokenizer_mode='auto',
|
||||
tensor_parallel_size=num_gpus,
|
||||
dtype='auto',
|
||||
quantization=quantization,
|
||||
disable_log_requests=not get_debug_mode(),
|
||||
worker_use_ray=False,
|
||||
engine_use_ray=False))
|
||||
vllm.AsyncEngineArgs(
|
||||
model=llm.bentomodel.path,
|
||||
tokenizer=llm.bentomodel.path,
|
||||
trust_remote_code=llm.trust_remote_code,
|
||||
tokenizer_mode='auto',
|
||||
tensor_parallel_size=num_gpus,
|
||||
dtype='auto',
|
||||
quantization=quantization,
|
||||
disable_log_requests=not get_debug_mode(),
|
||||
worker_use_ray=False,
|
||||
engine_use_ray=False,
|
||||
)
|
||||
)
|
||||
except Exception as err:
|
||||
traceback.print_exc()
|
||||
raise OpenLLMException(f'Failed to initialise vLLMEngine due to the following error:\n{err}') from err
|
||||
|
||||
@bentoml.Runnable.method(batchable=False)
|
||||
async def generate_iterator(self,
|
||||
prompt_token_ids: list[int],
|
||||
request_id: str,
|
||||
stop: str | t.Iterable[str] | None = None,
|
||||
adapter_name: str | None = None,
|
||||
**attrs: t.Any) -> t.AsyncGenerator[str, None]:
|
||||
if adapter_name is not None: raise NotImplementedError('Adapter is not supported with vLLM.')
|
||||
async def generate_iterator(
|
||||
self,
|
||||
prompt_token_ids: list[int],
|
||||
request_id: str,
|
||||
stop: str | t.Iterable[str] | None = None,
|
||||
adapter_name: str | None = None,
|
||||
**attrs: t.Any,
|
||||
) -> t.AsyncGenerator[str, None]:
|
||||
if adapter_name is not None:
|
||||
raise NotImplementedError('Adapter is not supported with vLLM.')
|
||||
stop_: set[str] = set()
|
||||
if isinstance(stop, str) and stop != '': stop_.add(stop)
|
||||
elif isinstance(stop, t.Iterable): stop_.update(stop)
|
||||
if isinstance(stop, str) and stop != '':
|
||||
stop_.add(stop)
|
||||
elif isinstance(stop, t.Iterable):
|
||||
stop_.update(stop)
|
||||
|
||||
temperature = attrs.pop('temperature', self.config['temperature'])
|
||||
top_p = attrs.pop('top_p', self.config['top_p'])
|
||||
if temperature <= 1e-5: top_p = 1.0
|
||||
sampling_params = self.config.model_construct_env(stop=list(stop_), temperature=temperature, top_p=top_p, **attrs).to_sampling_config()
|
||||
if temperature <= 1e-5:
|
||||
top_p = 1.0
|
||||
sampling_params = self.config.model_construct_env(
|
||||
stop=list(stop_), temperature=temperature, top_p=top_p, **attrs
|
||||
).to_sampling_config()
|
||||
|
||||
async for request_output in self.model.generate(None, sampling_params, request_id, prompt_token_ids):
|
||||
# XXX: Need to write a hook for serialisation None correctly
|
||||
if request_output.prompt_logprobs is not None: request_output.prompt_logprobs = [it if it else {} for it in request_output.prompt_logprobs]
|
||||
if request_output.prompt_logprobs is not None:
|
||||
request_output.prompt_logprobs = [it if it else {} for it in request_output.prompt_logprobs]
|
||||
yield GenerationOutput.from_vllm(request_output).model_dump_json()
|
||||
|
||||
|
||||
class PyTorchRunnable(bentoml.Runnable):
|
||||
SUPPORTED_RESOURCES = ('nvidia.com/gpu', 'amd.com/gpu', 'cpu')
|
||||
SUPPORTS_CPU_MULTI_THREADING = True
|
||||
@@ -93,23 +114,30 @@ class PyTorchRunnable(bentoml.Runnable):
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
@bentoml.Runnable.method(batchable=False)
|
||||
async def generate_iterator(self,
|
||||
prompt_token_ids: list[int],
|
||||
request_id: str,
|
||||
stop: str | t.Iterable[str] | None = None,
|
||||
adapter_name: str | None = None,
|
||||
**attrs: t.Any) -> t.AsyncGenerator[str, None]:
|
||||
if adapter_name is not None: self.model.set_adapter(adapter_name)
|
||||
async def generate_iterator(
|
||||
self,
|
||||
prompt_token_ids: list[int],
|
||||
request_id: str,
|
||||
stop: str | t.Iterable[str] | None = None,
|
||||
adapter_name: str | None = None,
|
||||
**attrs: t.Any,
|
||||
) -> t.AsyncGenerator[str, None]:
|
||||
if adapter_name is not None:
|
||||
self.model.set_adapter(adapter_name)
|
||||
async for generation_output in self.forward(prompt_token_ids, request_id, stop=stop, **attrs):
|
||||
yield generation_output.model_dump_json()
|
||||
|
||||
async def forward(self, prompt_token_ids: list[int], request_id: str, stop: str | t.Iterable[str] | None = None, **attrs: t.Any) -> t.AsyncGenerator[GenerationOutput, None]:
|
||||
async def forward(
|
||||
self, prompt_token_ids: list[int], request_id: str, stop: str | t.Iterable[str] | None = None, **attrs: t.Any
|
||||
) -> t.AsyncGenerator[GenerationOutput, None]:
|
||||
from ._generation import is_partial_stop
|
||||
from ._generation import prepare_logits_processor
|
||||
|
||||
stop_: set[str] = set()
|
||||
if isinstance(stop, str) and stop != '': stop_.add(stop)
|
||||
elif isinstance(stop, t.Iterable): stop_.update(stop)
|
||||
if isinstance(stop, str) and stop != '':
|
||||
stop_.add(stop)
|
||||
elif isinstance(stop, t.Iterable):
|
||||
stop_.update(stop)
|
||||
config = self.config.model_construct_env(**attrs)
|
||||
|
||||
with torch.inference_mode():
|
||||
@@ -129,7 +157,9 @@ class PyTorchRunnable(bentoml.Runnable):
|
||||
if i == 0: # prefill
|
||||
out = self.model(torch.as_tensor([prompt_token_ids], device=self.device), use_cache=True)
|
||||
else: # decoding
|
||||
out = self.model(torch.as_tensor([[token]], device=self.device), use_cache=True, past_key_values=past_key_values)
|
||||
out = self.model(
|
||||
torch.as_tensor([[token]], device=self.device), use_cache=True, past_key_values=past_key_values
|
||||
)
|
||||
logits = out.logits
|
||||
past_key_values = out.past_key_values
|
||||
|
||||
@@ -143,7 +173,8 @@ class PyTorchRunnable(bentoml.Runnable):
|
||||
last_token_logits = logits[0, -1, :]
|
||||
|
||||
# Switch to CPU by avoiding some bugs in mps backend.
|
||||
if self.device.type == 'mps': last_token_logits = last_token_logits.float().to('cpu')
|
||||
if self.device.type == 'mps':
|
||||
last_token_logits = last_token_logits.float().to('cpu')
|
||||
|
||||
if config['temperature'] < 1e-5 or config['top_p'] < 1e-8: # greedy
|
||||
_, indices = torch.topk(last_token_logits, 2)
|
||||
@@ -160,7 +191,12 @@ class PyTorchRunnable(bentoml.Runnable):
|
||||
|
||||
tmp_output_ids, rfind_start = output_token_ids[input_len:], 0
|
||||
# XXX: Move this to API server
|
||||
text = self.tokenizer.decode(tmp_output_ids, skip_special_tokens=True, spaces_between_special_tokens=False, clean_up_tokenization_spaces=True)
|
||||
text = self.tokenizer.decode(
|
||||
tmp_output_ids,
|
||||
skip_special_tokens=True,
|
||||
spaces_between_special_tokens=False,
|
||||
clean_up_tokenization_spaces=True,
|
||||
)
|
||||
partially_stopped = False
|
||||
if stop_:
|
||||
for it in stop_:
|
||||
@@ -170,21 +206,41 @@ class PyTorchRunnable(bentoml.Runnable):
|
||||
break
|
||||
else:
|
||||
partially_stopped = is_partial_stop(text, it)
|
||||
if partially_stopped: break
|
||||
if partially_stopped:
|
||||
break
|
||||
if not partially_stopped:
|
||||
yield GenerationOutput(prompt='',
|
||||
finished=False,
|
||||
outputs=[CompletionChunk(index=0, text=text, token_ids=output_token_ids[input_len:], cumulative_logprob=0.0, finish_reason=None)],
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
request_id=request_id)
|
||||
if stopped: break
|
||||
else: finish_reason = 'length'
|
||||
if stopped: finish_reason = 'stop'
|
||||
yield GenerationOutput(prompt='',
|
||||
finished=True,
|
||||
outputs=[CompletionChunk(index=0, text=text, token_ids=output_token_ids[input_len:], cumulative_logprob=0.0, finish_reason=finish_reason)],
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
request_id=request_id)
|
||||
yield GenerationOutput(
|
||||
prompt='',
|
||||
finished=False,
|
||||
outputs=[
|
||||
CompletionChunk(
|
||||
index=0, text=text, token_ids=output_token_ids[input_len:], cumulative_logprob=0.0, finish_reason=None
|
||||
)
|
||||
],
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
request_id=request_id,
|
||||
)
|
||||
if stopped:
|
||||
break
|
||||
else:
|
||||
finish_reason = 'length'
|
||||
if stopped:
|
||||
finish_reason = 'stop'
|
||||
yield GenerationOutput(
|
||||
prompt='',
|
||||
finished=True,
|
||||
outputs=[
|
||||
CompletionChunk(
|
||||
index=0,
|
||||
text=text,
|
||||
token_ids=output_token_ids[input_len:],
|
||||
cumulative_logprob=0.0,
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
],
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
# Clean
|
||||
del past_key_values, out
|
||||
|
||||
@@ -13,40 +13,60 @@ import openllm
|
||||
from bentoml.io import JSON
|
||||
from bentoml.io import Text
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
llm = openllm.LLM[t.Any, t.Any](svars.model_id,
|
||||
model_tag=svars.model_tag,
|
||||
prompt_template=openllm.utils.first_not_none(os.getenv('OPENLLM_PROMPT_TEMPLATE'), None),
|
||||
system_message=openllm.utils.first_not_none(os.getenv('OPENLLM_SYSTEM_MESSAGE'), None),
|
||||
serialisation=openllm.utils.first_not_none(os.getenv('OPENLLM_SERIALIZATION'), 'safetensors'),
|
||||
adapter_map=orjson.loads(svars.adapter_map),
|
||||
trust_remote_code=openllm.utils.check_bool_env('TRUST_REMOTE_CODE', default=False))
|
||||
llm = openllm.LLM[t.Any, t.Any](
|
||||
svars.model_id,
|
||||
model_tag=svars.model_tag,
|
||||
prompt_template=openllm.utils.first_not_none(os.getenv('OPENLLM_PROMPT_TEMPLATE'), None),
|
||||
system_message=openllm.utils.first_not_none(os.getenv('OPENLLM_SYSTEM_MESSAGE'), None),
|
||||
serialisation=openllm.utils.first_not_none(os.getenv('OPENLLM_SERIALIZATION'), 'safetensors'),
|
||||
adapter_map=orjson.loads(svars.adapter_map),
|
||||
trust_remote_code=openllm.utils.check_bool_env('TRUST_REMOTE_CODE', default=False),
|
||||
)
|
||||
llm_config = llm.config
|
||||
svc = bentoml.Service(name=f"llm-{llm_config['start_name']}-service", runners=[llm.runner])
|
||||
|
||||
llm_model_class = openllm.GenerationInput.from_llm_config(llm_config)
|
||||
|
||||
@svc.api(route='/v1/generate', input=JSON.from_sample(llm_model_class.examples()), output=JSON.from_sample(openllm.GenerationOutput.examples()))
|
||||
|
||||
@svc.api(
|
||||
route='/v1/generate',
|
||||
input=JSON.from_sample(llm_model_class.examples()),
|
||||
output=JSON.from_sample(openllm.GenerationOutput.examples()),
|
||||
)
|
||||
async def generate_v1(input_dict: dict[str, t.Any]) -> openllm.GenerationOutput:
|
||||
return await llm.generate(**llm_model_class(**input_dict).model_dump())
|
||||
|
||||
@svc.api(route='/v1/generate_stream', input=JSON.from_sample(llm_model_class.examples()), output=Text(content_type='text/event-stream'))
|
||||
|
||||
@svc.api(
|
||||
route='/v1/generate_stream',
|
||||
input=JSON.from_sample(llm_model_class.examples()),
|
||||
output=Text(content_type='text/event-stream'),
|
||||
)
|
||||
async def generate_stream_v1(input_dict: dict[str, t.Any]) -> t.AsyncGenerator[str, None]:
|
||||
async for it in llm.generate_iterator(**llm_model_class(**input_dict).model_dump()):
|
||||
yield f'data: {it.model_dump_json()}\n\n'
|
||||
yield 'data: [DONE]\n\n'
|
||||
|
||||
_Metadata = openllm.MetadataOutput(timeout=llm_config['timeout'],
|
||||
model_name=llm_config['model_name'],
|
||||
backend=llm.__llm_backend__,
|
||||
model_id=llm.model_id,
|
||||
configuration=llm_config.model_dump_json(flatten=True).decode(),
|
||||
prompt_template=llm.runner.prompt_template,
|
||||
system_message=llm.runner.system_message)
|
||||
|
||||
_Metadata = openllm.MetadataOutput(
|
||||
timeout=llm_config['timeout'],
|
||||
model_name=llm_config['model_name'],
|
||||
backend=llm.__llm_backend__,
|
||||
model_id=llm.model_id,
|
||||
configuration=llm_config.model_dump_json(flatten=True).decode(),
|
||||
prompt_template=llm.runner.prompt_template,
|
||||
system_message=llm.runner.system_message,
|
||||
)
|
||||
|
||||
|
||||
@svc.api(route='/v1/metadata', input=Text(), output=JSON.from_sample(_Metadata.model_dump()))
|
||||
def metadata_v1(_: str) -> openllm.MetadataOutput:
|
||||
return _Metadata
|
||||
|
||||
openllm.mount_entrypoints(svc, llm) # HACK: This must always be the last line in this file, as we will do some MK for OpenAPI schema.
|
||||
|
||||
openllm.mount_entrypoints(
|
||||
svc, llm
|
||||
) # HACK: This must always be the last line in this file, as we will do some MK for OpenAPI schema.
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
model_id = '{__model_id__}' # openllm: model id
|
||||
model_tag = '{__model_tag__}' # openllm: model tag
|
||||
adapter_map = '''{__model_adapter_map__}''' # openllm: model adapter map
|
||||
adapter_map = """{__model_adapter_map__}""" # openllm: model adapter map
|
||||
|
||||
@@ -20,60 +20,74 @@ from openllm_core._typing_compat import overload
|
||||
from openllm_core.utils import DEBUG
|
||||
from openllm_core.utils import ReprMixin
|
||||
|
||||
|
||||
class DynResource(t.Protocol):
|
||||
resource_id: t.ClassVar[str]
|
||||
|
||||
@classmethod
|
||||
def from_system(cls) -> t.Sequence[t.Any]:
|
||||
...
|
||||
def from_system(cls) -> t.Sequence[t.Any]: ...
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _strtoul(s: str) -> int:
|
||||
"""Return -1 or positive integer sequence string starts with,."""
|
||||
if not s: return -1
|
||||
if not s:
|
||||
return -1
|
||||
idx = 0
|
||||
for idx, c in enumerate(s):
|
||||
if not (c.isdigit() or (idx == 0 and c in '+-')): break
|
||||
if idx + 1 == len(s): idx += 1 # noqa: PLW2901
|
||||
if not (c.isdigit() or (idx == 0 and c in '+-')):
|
||||
break
|
||||
if idx + 1 == len(s):
|
||||
idx += 1 # noqa: PLW2901
|
||||
# NOTE: idx will be set via enumerate
|
||||
return int(s[:idx]) if idx > 0 else -1
|
||||
|
||||
|
||||
def _parse_list_with_prefix(lst: str, prefix: str) -> list[str]:
|
||||
rcs: list[str] = []
|
||||
for elem in lst.split(','):
|
||||
# Repeated id results in empty set
|
||||
if elem in rcs: return []
|
||||
if elem in rcs:
|
||||
return []
|
||||
# Anything other but prefix is ignored
|
||||
if not elem.startswith(prefix): break
|
||||
if not elem.startswith(prefix):
|
||||
break
|
||||
rcs.append(elem)
|
||||
return rcs
|
||||
|
||||
|
||||
_STACK_LEVEL = 3
|
||||
|
||||
|
||||
@overload # variant: default callback
|
||||
def _parse_visible_devices() -> list[str] | None:
|
||||
...
|
||||
def _parse_visible_devices() -> list[str] | None: ...
|
||||
|
||||
|
||||
@overload # variant: specify None, and respect_env
|
||||
def _parse_visible_devices(default_var: None, *, respect_env: t.Literal[True]) -> list[str] | None:
|
||||
...
|
||||
def _parse_visible_devices(default_var: None, *, respect_env: t.Literal[True]) -> list[str] | None: ...
|
||||
|
||||
|
||||
@overload # variant: default var is something other than None
|
||||
def _parse_visible_devices(default_var: str = ..., *, respect_env: t.Literal[False]) -> list[str]:
|
||||
...
|
||||
def _parse_visible_devices(default_var: str = ..., *, respect_env: t.Literal[False]) -> list[str]: ...
|
||||
|
||||
|
||||
def _parse_visible_devices(default_var: str | None = None, respect_env: bool = True) -> list[str] | None:
|
||||
"""CUDA_VISIBLE_DEVICES aware with default var for parsing spec."""
|
||||
if respect_env:
|
||||
spec = os.environ.get('CUDA_VISIBLE_DEVICES', default_var)
|
||||
if not spec: return None
|
||||
if not spec:
|
||||
return None
|
||||
else:
|
||||
if default_var is None: raise ValueError('spec is required to be not None when parsing spec.')
|
||||
if default_var is None:
|
||||
raise ValueError('spec is required to be not None when parsing spec.')
|
||||
spec = default_var
|
||||
|
||||
if spec.startswith('GPU-'): return _parse_list_with_prefix(spec, 'GPU-')
|
||||
if spec.startswith('MIG-'): return _parse_list_with_prefix(spec, 'MIG-')
|
||||
if spec.startswith('GPU-'):
|
||||
return _parse_list_with_prefix(spec, 'GPU-')
|
||||
if spec.startswith('MIG-'):
|
||||
return _parse_list_with_prefix(spec, 'MIG-')
|
||||
# XXX: We need to somehow handle cases such as '100m'
|
||||
# CUDA_VISIBLE_DEVICES uses something like strtoul
|
||||
# which makes `1gpu2,2ampere` is equivalent to `1,2`
|
||||
@@ -81,18 +95,22 @@ def _parse_visible_devices(default_var: str | None = None, respect_env: bool = T
|
||||
for el in spec.split(','):
|
||||
x = _strtoul(el.strip())
|
||||
# Repeated ordinal results in empty set
|
||||
if x in rc: return []
|
||||
if x in rc:
|
||||
return []
|
||||
# Negative value aborts the sequence
|
||||
if x < 0: break
|
||||
if x < 0:
|
||||
break
|
||||
rc.append(x)
|
||||
return [str(i) for i in rc]
|
||||
|
||||
|
||||
def _from_system(cls: type[DynResource]) -> list[str]:
|
||||
visible_devices = _parse_visible_devices()
|
||||
if visible_devices is None:
|
||||
if cls.resource_id == 'amd.com/gpu':
|
||||
if not psutil.LINUX:
|
||||
if DEBUG: logger.debug('AMD GPUs is currently only supported on Linux.')
|
||||
if DEBUG:
|
||||
logger.debug('AMD GPUs is currently only supported on Linux.')
|
||||
return []
|
||||
# ROCm does not currently have the rocm_smi wheel.
|
||||
# So we need to use the ctypes bindings directly.
|
||||
@@ -108,7 +126,8 @@ def _from_system(cls: type[DynResource]) -> list[str]:
|
||||
|
||||
device_count = c_uint32(0)
|
||||
ret = rocmsmi.rsmi_num_monitor_devices(byref(device_count))
|
||||
if ret == rsmi_status_t.RSMI_STATUS_SUCCESS: return [str(i) for i in range(device_count.value)]
|
||||
if ret == rsmi_status_t.RSMI_STATUS_SUCCESS:
|
||||
return [str(i) for i in range(device_count.value)]
|
||||
return []
|
||||
# In this case the binary is not found, returning empty list
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
@@ -118,6 +137,7 @@ def _from_system(cls: type[DynResource]) -> list[str]:
|
||||
else:
|
||||
try:
|
||||
from cuda import cuda
|
||||
|
||||
cuda.cuInit(0)
|
||||
_, dev = cuda.cuDeviceGetCount()
|
||||
return [str(i) for i in range(dev)]
|
||||
@@ -125,31 +145,39 @@ def _from_system(cls: type[DynResource]) -> list[str]:
|
||||
return []
|
||||
return visible_devices
|
||||
|
||||
@overload
|
||||
def _from_spec(cls: type[DynResource], spec: int) -> list[str]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def _from_spec(cls: type[DynResource], spec: list[int | str]) -> list[str]:
|
||||
...
|
||||
def _from_spec(cls: type[DynResource], spec: int) -> list[str]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def _from_spec(cls: type[DynResource], spec: str) -> list[str]:
|
||||
...
|
||||
def _from_spec(cls: type[DynResource], spec: list[int | str]) -> list[str]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def _from_spec(cls: type[DynResource], spec: str) -> list[str]: ...
|
||||
|
||||
|
||||
def _from_spec(cls: type[DynResource], spec: t.Any) -> list[str]:
|
||||
if isinstance(spec, int):
|
||||
if spec in (-1, 0): return []
|
||||
if spec < -1: raise ValueError('Spec cannot be < -1.')
|
||||
if spec in (-1, 0):
|
||||
return []
|
||||
if spec < -1:
|
||||
raise ValueError('Spec cannot be < -1.')
|
||||
return [str(i) for i in range(spec)]
|
||||
elif isinstance(spec, str):
|
||||
if not spec: return []
|
||||
if spec.isdigit(): spec = ','.join([str(i) for i in range(_strtoul(spec))])
|
||||
if not spec:
|
||||
return []
|
||||
if spec.isdigit():
|
||||
spec = ','.join([str(i) for i in range(_strtoul(spec))])
|
||||
return _parse_visible_devices(spec, respect_env=False)
|
||||
elif isinstance(spec, list):
|
||||
return [str(x) for x in spec]
|
||||
else:
|
||||
raise TypeError(f"'{cls.__name__}.from_spec' only supports parsing spec of type int, str, or list, got '{type(spec)}' instead.")
|
||||
raise TypeError(
|
||||
f"'{cls.__name__}.from_spec' only supports parsing spec of type int, str, or list, got '{type(spec)}' instead."
|
||||
)
|
||||
|
||||
|
||||
def _raw_device_uuid_nvml() -> list[str] | None:
|
||||
from ctypes import CDLL
|
||||
@@ -190,10 +218,14 @@ def _raw_device_uuid_nvml() -> list[str] | None:
|
||||
del nvml_h
|
||||
return uuids
|
||||
|
||||
|
||||
def _validate(cls: type[DynResource], val: list[t.Any]) -> None:
|
||||
if cls.resource_id == 'amd.com/gpu':
|
||||
raise RuntimeError("AMD GPU validation is not yet supported. Make sure to call 'get_resource(..., validate=False)'")
|
||||
if not all(isinstance(i, str) for i in val): raise ValueError('Input list should be all string type.')
|
||||
raise RuntimeError(
|
||||
"AMD GPU validation is not yet supported. Make sure to call 'get_resource(..., validate=False)'"
|
||||
)
|
||||
if not all(isinstance(i, str) for i in val):
|
||||
raise ValueError('Input list should be all string type.')
|
||||
|
||||
try:
|
||||
from cuda import cuda
|
||||
@@ -205,25 +237,36 @@ def _validate(cls: type[DynResource], val: list[t.Any]) -> None:
|
||||
for el in val:
|
||||
if el.startswith('GPU-') or el.startswith('MIG-'):
|
||||
uuids = _raw_device_uuid_nvml()
|
||||
if uuids is None: raise ValueError('Failed to parse available GPUs UUID')
|
||||
if el not in uuids: raise ValueError(f'Given UUID {el} is not found with available UUID (available: {uuids})')
|
||||
if uuids is None:
|
||||
raise ValueError('Failed to parse available GPUs UUID')
|
||||
if el not in uuids:
|
||||
raise ValueError(f'Given UUID {el} is not found with available UUID (available: {uuids})')
|
||||
elif el.isdigit():
|
||||
err, _ = cuda.cuDeviceGet(int(el))
|
||||
if err != cuda.CUresult.CUDA_SUCCESS: raise ValueError(f'Failed to get device {el}')
|
||||
if err != cuda.CUresult.CUDA_SUCCESS:
|
||||
raise ValueError(f'Failed to get device {el}')
|
||||
except (ImportError, RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
def _make_resource_class(name: str, resource_kind: str, docstring: str) -> type[DynResource]:
|
||||
return types.new_class(
|
||||
name, (bentoml.Resource[t.List[str]], ReprMixin), {'resource_id': resource_kind}, lambda ns: ns.update({
|
||||
'resource_id': resource_kind,
|
||||
'from_spec': classmethod(_from_spec),
|
||||
'from_system': classmethod(_from_system),
|
||||
'validate': classmethod(_validate),
|
||||
'__repr_keys__': property(lambda _: {'resource_id'}),
|
||||
'__doc__': inspect.cleandoc(docstring),
|
||||
'__module__': 'openllm._strategies'
|
||||
}))
|
||||
name,
|
||||
(bentoml.Resource[t.List[str]], ReprMixin),
|
||||
{'resource_id': resource_kind},
|
||||
lambda ns: ns.update(
|
||||
{
|
||||
'resource_id': resource_kind,
|
||||
'from_spec': classmethod(_from_spec),
|
||||
'from_system': classmethod(_from_system),
|
||||
'validate': classmethod(_validate),
|
||||
'__repr_keys__': property(lambda _: {'resource_id'}),
|
||||
'__doc__': inspect.cleandoc(docstring),
|
||||
'__module__': 'openllm._strategies',
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# NOTE: we need to hint these t.Literal since mypy is to dumb to infer this as literal 🤦
|
||||
_TPU_RESOURCE: t.Literal['cloud-tpus.google.com/v2'] = 'cloud-tpus.google.com/v2'
|
||||
@@ -232,15 +275,22 @@ _NVIDIA_GPU_RESOURCE: t.Literal['nvidia.com/gpu'] = 'nvidia.com/gpu'
|
||||
_CPU_RESOURCE: t.Literal['cpu'] = 'cpu'
|
||||
|
||||
NvidiaGpuResource = _make_resource_class(
|
||||
'NvidiaGpuResource', _NVIDIA_GPU_RESOURCE, '''NVIDIA GPU resource.
|
||||
'NvidiaGpuResource',
|
||||
_NVIDIA_GPU_RESOURCE,
|
||||
"""NVIDIA GPU resource.
|
||||
|
||||
This is a modified version of internal's BentoML's NvidiaGpuResource
|
||||
where it respects and parse CUDA_VISIBLE_DEVICES correctly.''')
|
||||
where it respects and parse CUDA_VISIBLE_DEVICES correctly.""",
|
||||
)
|
||||
AmdGpuResource = _make_resource_class(
|
||||
'AmdGpuResource', _AMD_GPU_RESOURCE, '''AMD GPU resource.
|
||||
'AmdGpuResource',
|
||||
_AMD_GPU_RESOURCE,
|
||||
"""AMD GPU resource.
|
||||
|
||||
Since ROCm will respect CUDA_VISIBLE_DEVICES, the behaviour of from_spec, from_system are similar to
|
||||
``NvidiaGpuResource``. Currently ``validate`` is not yet supported.''')
|
||||
``NvidiaGpuResource``. Currently ``validate`` is not yet supported.""",
|
||||
)
|
||||
|
||||
|
||||
class CascadingResourceStrategy(bentoml.Strategy, ReprMixin):
|
||||
"""This is extends the default BentoML strategy where we check for NVIDIA GPU resource -> AMD GPU resource -> CPU resource.
|
||||
@@ -251,21 +301,27 @@ class CascadingResourceStrategy(bentoml.Strategy, ReprMixin):
|
||||
|
||||
TODO: Support CloudTPUResource
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_worker_count(cls, runnable_class: type[bentoml.Runnable], resource_request: dict[str, t.Any] | None, workers_per_resource: float) -> int:
|
||||
def get_worker_count(
|
||||
cls, runnable_class: type[bentoml.Runnable], resource_request: dict[str, t.Any] | None, workers_per_resource: float
|
||||
) -> int:
|
||||
"""Return the number of workers to be used for the given runnable class.
|
||||
|
||||
Note that for all available GPU, the number of workers will always be 1.
|
||||
"""
|
||||
if resource_request is None: resource_request = system_resources()
|
||||
if resource_request is None:
|
||||
resource_request = system_resources()
|
||||
# use NVIDIA
|
||||
kind = 'nvidia.com/gpu'
|
||||
nvidia_req = get_resource(resource_request, kind)
|
||||
if nvidia_req is not None: return 1
|
||||
if nvidia_req is not None:
|
||||
return 1
|
||||
# use AMD
|
||||
kind = 'amd.com/gpu'
|
||||
amd_req = get_resource(resource_request, kind, validate=False)
|
||||
if amd_req is not None: return 1
|
||||
if amd_req is not None:
|
||||
return 1
|
||||
# use CPU
|
||||
cpus = get_resource(resource_request, 'cpu')
|
||||
if cpus is not None and cpus > 0:
|
||||
@@ -279,10 +335,18 @@ class CascadingResourceStrategy(bentoml.Strategy, ReprMixin):
|
||||
return math.ceil(cpus) * workers_per_resource
|
||||
|
||||
# this should not be reached by user since we always read system resource as default
|
||||
raise ValueError(f'No known supported resource available for {runnable_class}. Please check your resource request. Leaving it blank will allow BentoML to use system resources.')
|
||||
raise ValueError(
|
||||
f'No known supported resource available for {runnable_class}. Please check your resource request. Leaving it blank will allow BentoML to use system resources.'
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_worker_env(cls, runnable_class: type[bentoml.Runnable], resource_request: dict[str, t.Any] | None, workers_per_resource: int | float, worker_index: int) -> dict[str, t.Any]:
|
||||
def get_worker_env(
|
||||
cls,
|
||||
runnable_class: type[bentoml.Runnable],
|
||||
resource_request: dict[str, t.Any] | None,
|
||||
workers_per_resource: int | float,
|
||||
worker_index: int,
|
||||
) -> dict[str, t.Any]:
|
||||
"""Get worker env for this given worker_index.
|
||||
|
||||
Args:
|
||||
@@ -295,7 +359,8 @@ class CascadingResourceStrategy(bentoml.Strategy, ReprMixin):
|
||||
disabled = cuda_env in ('', '-1')
|
||||
environ: dict[str, t.Any] = {}
|
||||
|
||||
if resource_request is None: resource_request = system_resources()
|
||||
if resource_request is None:
|
||||
resource_request = system_resources()
|
||||
# use NVIDIA
|
||||
kind = 'nvidia.com/gpu'
|
||||
typ = get_resource(resource_request, kind)
|
||||
@@ -340,20 +405,34 @@ class CascadingResourceStrategy(bentoml.Strategy, ReprMixin):
|
||||
# NOTE: We hit this branch when workers_per_resource is set to
|
||||
# float, for example 0.5 or 0.25
|
||||
if workers_per_resource > 1:
|
||||
raise ValueError("Currently, the default strategy doesn't support workers_per_resource > 1. It is recommended that one should implement a custom strategy in this case.")
|
||||
raise ValueError(
|
||||
"Currently, the default strategy doesn't support workers_per_resource > 1. It is recommended that one should implement a custom strategy in this case."
|
||||
)
|
||||
# We are round the assigned resource here. This means if workers_per_resource=.4
|
||||
# then it will round down to 2. If workers_per_source=0.6, then it will also round up to 2.
|
||||
assigned_resource_per_worker = round(1 / workers_per_resource)
|
||||
if len(gpus) < assigned_resource_per_worker:
|
||||
logger.warning('Failed to allocate %s GPUs for %s (number of available GPUs < assigned workers per resource [%s])', gpus, worker_index, assigned_resource_per_worker)
|
||||
raise IndexError(f"There aren't enough assigned GPU(s) for given worker id '{worker_index}' [required: {assigned_resource_per_worker}].")
|
||||
assigned_gpu = gpus[assigned_resource_per_worker * worker_index:assigned_resource_per_worker * (worker_index + 1)]
|
||||
logger.warning(
|
||||
'Failed to allocate %s GPUs for %s (number of available GPUs < assigned workers per resource [%s])',
|
||||
gpus,
|
||||
worker_index,
|
||||
assigned_resource_per_worker,
|
||||
)
|
||||
raise IndexError(
|
||||
f"There aren't enough assigned GPU(s) for given worker id '{worker_index}' [required: {assigned_resource_per_worker}]."
|
||||
)
|
||||
assigned_gpu = gpus[
|
||||
assigned_resource_per_worker * worker_index : assigned_resource_per_worker * (worker_index + 1)
|
||||
]
|
||||
dev = ','.join(assigned_gpu)
|
||||
else:
|
||||
idx = worker_index // workers_per_resource
|
||||
if idx >= len(gpus):
|
||||
raise ValueError(f'Number of available GPU ({gpus}) preceeds the given workers_per_resource {workers_per_resource}')
|
||||
raise ValueError(
|
||||
f'Number of available GPU ({gpus}) preceeds the given workers_per_resource {workers_per_resource}'
|
||||
)
|
||||
dev = str(gpus[idx])
|
||||
return dev
|
||||
|
||||
|
||||
__all__ = ['CascadingResourceStrategy', 'get_resource']
|
||||
|
||||
@@ -2,15 +2,24 @@
|
||||
|
||||
These utilities will stay internal, and its API can be changed or updated without backward-compatibility.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import os
|
||||
import typing as t
|
||||
|
||||
from openllm_core.utils import LazyModule
|
||||
|
||||
|
||||
_import_structure: dict[str, list[str]] = {
|
||||
'_package': ['create_bento', 'build_editable', 'construct_python_options', 'construct_docker_options'],
|
||||
'oci': ['CONTAINER_NAMES', 'get_base_container_tag', 'build_container', 'get_base_container_name', 'supported_registries', 'RefResolver']
|
||||
'_package': ['create_bento', 'build_editable', 'construct_python_options', 'construct_docker_options'],
|
||||
'oci': [
|
||||
'CONTAINER_NAMES',
|
||||
'get_base_container_tag',
|
||||
'build_container',
|
||||
'get_base_container_name',
|
||||
'supported_registries',
|
||||
'RefResolver',
|
||||
],
|
||||
}
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
|
||||
@@ -27,6 +27,7 @@ from bentoml._internal.configuration.containers import BentoMLContainer
|
||||
|
||||
from . import oci
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from fs.base import FS
|
||||
|
||||
@@ -43,15 +44,22 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
OPENLLM_DEV_BUILD = 'OPENLLM_DEV_BUILD'
|
||||
|
||||
def build_editable(path: str, package: t.Literal['openllm', 'openllm_core', 'openllm_client'] = 'openllm') -> str | None:
|
||||
|
||||
def build_editable(
|
||||
path: str, package: t.Literal['openllm', 'openllm_core', 'openllm_client'] = 'openllm'
|
||||
) -> str | None:
|
||||
"""Build OpenLLM if the OPENLLM_DEV_BUILD environment variable is set."""
|
||||
if openllm_core.utils.check_bool_env(OPENLLM_DEV_BUILD, default=False): return None
|
||||
if openllm_core.utils.check_bool_env(OPENLLM_DEV_BUILD, default=False):
|
||||
return None
|
||||
# We need to build the package in editable mode, so that we can import it
|
||||
from build import ProjectBuilder
|
||||
from build.env import IsolatedEnvBuilder
|
||||
|
||||
module_location = openllm_core.utils.pkg.source_locations(package)
|
||||
if not module_location:
|
||||
raise RuntimeError('Could not find the source location of OpenLLM. Make sure to unset OPENLLM_DEV_BUILD if you are developing OpenLLM.')
|
||||
raise RuntimeError(
|
||||
'Could not find the source location of OpenLLM. Make sure to unset OPENLLM_DEV_BUILD if you are developing OpenLLM.'
|
||||
)
|
||||
pyproject_path = Path(module_location).parent.parent / 'pyproject.toml'
|
||||
if os.path.isfile(pyproject_path.__fspath__()):
|
||||
logger.info('Generating built wheels for package %s...', package)
|
||||
@@ -61,57 +69,98 @@ def build_editable(path: str, package: t.Literal['openllm', 'openllm_core', 'ope
|
||||
builder.scripts_dir = env.scripts_dir
|
||||
env.install(builder.build_system_requires)
|
||||
return builder.build('wheel', path, config_settings={'--global-option': '--quiet'})
|
||||
raise RuntimeError('Custom OpenLLM build is currently not supported. Please install OpenLLM from PyPI or built it from Git source.')
|
||||
raise RuntimeError(
|
||||
'Custom OpenLLM build is currently not supported. Please install OpenLLM from PyPI or built it from Git source.'
|
||||
)
|
||||
|
||||
def construct_python_options(llm: openllm.LLM[t.Any, t.Any], llm_fs: FS, extra_dependencies: tuple[str, ...] | None = None, adapter_map: dict[str, str] | None = None) -> PythonOptions:
|
||||
|
||||
def construct_python_options(
|
||||
llm: openllm.LLM[t.Any, t.Any],
|
||||
llm_fs: FS,
|
||||
extra_dependencies: tuple[str, ...] | None = None,
|
||||
adapter_map: dict[str, str] | None = None,
|
||||
) -> PythonOptions:
|
||||
packages = ['openllm', 'scipy'] # apparently bnb misses this one
|
||||
if adapter_map is not None: packages += ['openllm[fine-tune]']
|
||||
if adapter_map is not None:
|
||||
packages += ['openllm[fine-tune]']
|
||||
# NOTE: add openllm to the default dependencies
|
||||
# if users has openllm custom built wheels, it will still respect
|
||||
# that since bentoml will always install dependencies from requirements.txt
|
||||
# first, then proceed to install everything inside the wheels/ folder.
|
||||
if extra_dependencies is not None: packages += [f'openllm[{k}]' for k in extra_dependencies]
|
||||
if extra_dependencies is not None:
|
||||
packages += [f'openllm[{k}]' for k in extra_dependencies]
|
||||
|
||||
req = llm.config['requirements']
|
||||
if req is not None: packages.extend(req)
|
||||
if req is not None:
|
||||
packages.extend(req)
|
||||
if str(os.environ.get('BENTOML_BUNDLE_LOCAL_BUILD', False)).lower() == 'false':
|
||||
packages.append(f"bentoml>={'.'.join([str(i) for i in openllm_core.utils.pkg.pkg_version_info('bentoml')])}")
|
||||
|
||||
if not openllm_core.utils.is_torch_available():
|
||||
raise ValueError('PyTorch is not available. Make sure to have it locally installed.')
|
||||
packages.extend(['torch==2.0.1+cu118', 'vllm==0.2.1.post1', 'xformers==0.0.22', 'bentoml[tracing]==1.1.9']) # XXX: Currently locking this for correctness
|
||||
packages.extend(
|
||||
['torch==2.0.1+cu118', 'vllm==0.2.1.post1', 'xformers==0.0.22', 'bentoml[tracing]==1.1.9']
|
||||
) # XXX: Currently locking this for correctness
|
||||
wheels: list[str] = []
|
||||
built_wheels = [build_editable(llm_fs.getsyspath('/'), t.cast(t.Literal['openllm', 'openllm_core', 'openllm_client'], p)) for p in ('openllm_core', 'openllm_client', 'openllm')]
|
||||
built_wheels = [
|
||||
build_editable(llm_fs.getsyspath('/'), t.cast(t.Literal['openllm', 'openllm_core', 'openllm_client'], p))
|
||||
for p in ('openllm_core', 'openllm_client', 'openllm')
|
||||
]
|
||||
if all(i for i in built_wheels):
|
||||
wheels.extend([llm_fs.getsyspath(f"/{i.split('/')[-1]}") for i in t.cast(t.List[str], built_wheels)])
|
||||
return PythonOptions(packages=packages,
|
||||
wheels=wheels,
|
||||
lock_packages=False,
|
||||
extra_index_url=['https://download.pytorch.org/whl/cu118', 'https://huggingface.github.io/autogptq-index/whl/cu118/'])
|
||||
return PythonOptions(
|
||||
packages=packages,
|
||||
wheels=wheels,
|
||||
lock_packages=False,
|
||||
extra_index_url=[
|
||||
'https://download.pytorch.org/whl/cu118',
|
||||
'https://huggingface.github.io/autogptq-index/whl/cu118/',
|
||||
],
|
||||
)
|
||||
|
||||
def construct_docker_options(llm: openllm.LLM[t.Any, t.Any], _: FS, quantize: LiteralString | None, adapter_map: dict[str, str] | None, dockerfile_template: str | None,
|
||||
serialisation: LiteralSerialisation, container_registry: LiteralContainerRegistry, container_version_strategy: LiteralContainerVersionStrategy) -> DockerOptions:
|
||||
|
||||
def construct_docker_options(
|
||||
llm: openllm.LLM[t.Any, t.Any],
|
||||
_: FS,
|
||||
quantize: LiteralString | None,
|
||||
adapter_map: dict[str, str] | None,
|
||||
dockerfile_template: str | None,
|
||||
serialisation: LiteralSerialisation,
|
||||
container_registry: LiteralContainerRegistry,
|
||||
container_version_strategy: LiteralContainerVersionStrategy,
|
||||
) -> DockerOptions:
|
||||
from openllm.cli._factory import parse_config_options
|
||||
|
||||
environ = parse_config_options(llm.config, llm.config['timeout'], 1.0, None, True, os.environ.copy())
|
||||
env_dict = {
|
||||
'OPENLLM_BACKEND': llm.__llm_backend__,
|
||||
'OPENLLM_CONFIG': f"'{llm.config.model_dump_json(flatten=True).decode()}'",
|
||||
'OPENLLM_SERIALIZATION': serialisation,
|
||||
'BENTOML_DEBUG': str(True),
|
||||
'BENTOML_QUIET': str(False),
|
||||
'BENTOML_CONFIG_OPTIONS': f"'{environ['BENTOML_CONFIG_OPTIONS']}'",
|
||||
'OPENLLM_BACKEND': llm.__llm_backend__,
|
||||
'OPENLLM_CONFIG': f"'{llm.config.model_dump_json(flatten=True).decode()}'",
|
||||
'OPENLLM_SERIALIZATION': serialisation,
|
||||
'BENTOML_DEBUG': str(True),
|
||||
'BENTOML_QUIET': str(False),
|
||||
'BENTOML_CONFIG_OPTIONS': f"'{environ['BENTOML_CONFIG_OPTIONS']}'",
|
||||
}
|
||||
if adapter_map: env_dict['BITSANDBYTES_NOWELCOME'] = os.environ.get('BITSANDBYTES_NOWELCOME', '1')
|
||||
if llm._system_message: env_dict['OPENLLM_SYSTEM_MESSAGE'] = repr(llm._system_message)
|
||||
if llm._prompt_template: env_dict['OPENLLM_PROMPT_TEMPLATE'] = repr(llm._prompt_template.to_string())
|
||||
if quantize: env_dict['OPENLLM_QUANTISE'] = str(quantize)
|
||||
return DockerOptions(base_image=f'{oci.CONTAINER_NAMES[container_registry]}:{oci.get_base_container_tag(container_version_strategy)}', env=env_dict, dockerfile_template=dockerfile_template)
|
||||
if adapter_map:
|
||||
env_dict['BITSANDBYTES_NOWELCOME'] = os.environ.get('BITSANDBYTES_NOWELCOME', '1')
|
||||
if llm._system_message:
|
||||
env_dict['OPENLLM_SYSTEM_MESSAGE'] = repr(llm._system_message)
|
||||
if llm._prompt_template:
|
||||
env_dict['OPENLLM_PROMPT_TEMPLATE'] = repr(llm._prompt_template.to_string())
|
||||
if quantize:
|
||||
env_dict['OPENLLM_QUANTISE'] = str(quantize)
|
||||
return DockerOptions(
|
||||
base_image=f'{oci.CONTAINER_NAMES[container_registry]}:{oci.get_base_container_tag(container_version_strategy)}',
|
||||
env=env_dict,
|
||||
dockerfile_template=dockerfile_template,
|
||||
)
|
||||
|
||||
|
||||
OPENLLM_MODEL_NAME = '# openllm: model name'
|
||||
OPENLLM_MODEL_ID = '# openllm: model id'
|
||||
OPENLLM_MODEL_TAG = '# openllm: model tag'
|
||||
OPENLLM_MODEL_ADAPTER_MAP = '# openllm: model adapter map'
|
||||
|
||||
|
||||
class ModelNameFormatter(string.Formatter):
|
||||
model_keyword: LiteralString = '__model_name__'
|
||||
|
||||
@@ -130,75 +179,122 @@ class ModelNameFormatter(string.Formatter):
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
class ModelIdFormatter(ModelNameFormatter):
|
||||
model_keyword: LiteralString = '__model_id__'
|
||||
|
||||
|
||||
class ModelTagFormatter(ModelNameFormatter):
|
||||
model_keyword: LiteralString = '__model_tag__'
|
||||
|
||||
|
||||
class ModelAdapterMapFormatter(ModelNameFormatter):
|
||||
model_keyword: LiteralString = '__model_adapter_map__'
|
||||
|
||||
|
||||
_service_file = Path(os.path.abspath(__file__)).parent.parent / '_service.py'
|
||||
_service_vars_file = Path(os.path.abspath(__file__)).parent.parent / '_service_vars_pkg.py'
|
||||
|
||||
|
||||
def write_service(llm: openllm.LLM[t.Any, t.Any], adapter_map: dict[str, str] | None, llm_fs: FS) -> None:
|
||||
from openllm_core.utils import DEBUG
|
||||
|
||||
model_name = llm.config['model_name']
|
||||
model_id = llm.model_id
|
||||
model_tag = str(llm.tag)
|
||||
logger.debug('Generating service vars file for %s at %s (dir=%s)', model_name, '_service_vars.py', llm_fs.getsyspath('/'))
|
||||
logger.debug(
|
||||
'Generating service vars file for %s at %s (dir=%s)', model_name, '_service_vars.py', llm_fs.getsyspath('/')
|
||||
)
|
||||
with open(_service_vars_file.__fspath__(), 'r') as f:
|
||||
src_contents = f.readlines()
|
||||
for it in src_contents:
|
||||
if OPENLLM_MODEL_NAME in it:
|
||||
src_contents[src_contents.index(it)] = (ModelNameFormatter(model_name).vformat(it)[:-(len(OPENLLM_MODEL_NAME) + 3)] + '\n')
|
||||
src_contents[src_contents.index(it)] = (
|
||||
ModelNameFormatter(model_name).vformat(it)[: -(len(OPENLLM_MODEL_NAME) + 3)] + '\n'
|
||||
)
|
||||
if OPENLLM_MODEL_ID in it:
|
||||
src_contents[src_contents.index(it)] = (ModelIdFormatter(model_id).vformat(it)[:-(len(OPENLLM_MODEL_ID) + 3)] + '\n')
|
||||
src_contents[src_contents.index(it)] = (
|
||||
ModelIdFormatter(model_id).vformat(it)[: -(len(OPENLLM_MODEL_ID) + 3)] + '\n'
|
||||
)
|
||||
elif OPENLLM_MODEL_TAG in it:
|
||||
src_contents[src_contents.index(it)] = (ModelTagFormatter(model_tag).vformat(it)[:-(len(OPENLLM_MODEL_TAG) + 3)] + '\n')
|
||||
src_contents[src_contents.index(it)] = (
|
||||
ModelTagFormatter(model_tag).vformat(it)[: -(len(OPENLLM_MODEL_TAG) + 3)] + '\n'
|
||||
)
|
||||
elif OPENLLM_MODEL_ADAPTER_MAP in it:
|
||||
src_contents[src_contents.index(it)] = (ModelAdapterMapFormatter(orjson.dumps(adapter_map).decode()).vformat(it)[:-(len(OPENLLM_MODEL_ADAPTER_MAP) + 3)] + '\n')
|
||||
src_contents[src_contents.index(it)] = (
|
||||
ModelAdapterMapFormatter(orjson.dumps(adapter_map).decode()).vformat(it)[
|
||||
: -(len(OPENLLM_MODEL_ADAPTER_MAP) + 3)
|
||||
]
|
||||
+ '\n'
|
||||
)
|
||||
script = f"# GENERATED BY 'openllm build {model_name}'. DO NOT EDIT\n\n" + ''.join(src_contents)
|
||||
if DEBUG: logger.info('Generated script:\n%s', script)
|
||||
if DEBUG:
|
||||
logger.info('Generated script:\n%s', script)
|
||||
llm_fs.writetext('_service_vars.py', script)
|
||||
|
||||
logger.debug('Generating service file for %s at %s (dir=%s)', model_name, llm.config['service_name'], llm_fs.getsyspath('/'))
|
||||
logger.debug(
|
||||
'Generating service file for %s at %s (dir=%s)', model_name, llm.config['service_name'], llm_fs.getsyspath('/')
|
||||
)
|
||||
with open(_service_file.__fspath__(), 'r') as f:
|
||||
service_src = f.read()
|
||||
llm_fs.writetext(llm.config['service_name'], service_src)
|
||||
|
||||
|
||||
@inject
|
||||
def create_bento(bento_tag: bentoml.Tag,
|
||||
llm_fs: FS,
|
||||
llm: openllm.LLM[t.Any, t.Any],
|
||||
quantize: LiteralString | None,
|
||||
dockerfile_template: str | None,
|
||||
adapter_map: dict[str, str] | None = None,
|
||||
extra_dependencies: tuple[str, ...] | None = None,
|
||||
serialisation: LiteralSerialisation | None = None,
|
||||
container_registry: LiteralContainerRegistry = 'ecr',
|
||||
container_version_strategy: LiteralContainerVersionStrategy = 'release',
|
||||
_bento_store: BentoStore = Provide[BentoMLContainer.bento_store],
|
||||
_model_store: ModelStore = Provide[BentoMLContainer.model_store]) -> bentoml.Bento:
|
||||
_serialisation: LiteralSerialisation = openllm_core.utils.first_not_none(serialisation, default=llm.config['serialisation'])
|
||||
def create_bento(
|
||||
bento_tag: bentoml.Tag,
|
||||
llm_fs: FS,
|
||||
llm: openllm.LLM[t.Any, t.Any],
|
||||
quantize: LiteralString | None,
|
||||
dockerfile_template: str | None,
|
||||
adapter_map: dict[str, str] | None = None,
|
||||
extra_dependencies: tuple[str, ...] | None = None,
|
||||
serialisation: LiteralSerialisation | None = None,
|
||||
container_registry: LiteralContainerRegistry = 'ecr',
|
||||
container_version_strategy: LiteralContainerVersionStrategy = 'release',
|
||||
_bento_store: BentoStore = Provide[BentoMLContainer.bento_store],
|
||||
_model_store: ModelStore = Provide[BentoMLContainer.model_store],
|
||||
) -> bentoml.Bento:
|
||||
_serialisation: LiteralSerialisation = openllm_core.utils.first_not_none(
|
||||
serialisation, default=llm.config['serialisation']
|
||||
)
|
||||
labels = dict(llm.identifying_params)
|
||||
labels.update({'_type': llm.llm_type, '_framework': llm.__llm_backend__, 'start_name': llm.config['start_name'], 'base_name_or_path': llm.model_id, 'bundler': 'openllm.bundle'})
|
||||
if adapter_map: labels.update(adapter_map)
|
||||
labels.update(
|
||||
{
|
||||
'_type': llm.llm_type,
|
||||
'_framework': llm.__llm_backend__,
|
||||
'start_name': llm.config['start_name'],
|
||||
'base_name_or_path': llm.model_id,
|
||||
'bundler': 'openllm.bundle',
|
||||
}
|
||||
)
|
||||
if adapter_map:
|
||||
labels.update(adapter_map)
|
||||
logger.debug("Building Bento '%s' with model backend '%s'", bento_tag, llm.__llm_backend__)
|
||||
# add service.py definition to this temporary folder
|
||||
write_service(llm, adapter_map, llm_fs)
|
||||
|
||||
llm_spec = ModelSpec.from_item({'tag': str(llm.tag), 'alias': llm.tag.name})
|
||||
build_config = BentoBuildConfig(service=f"{llm.config['service_name']}:svc",
|
||||
name=bento_tag.name,
|
||||
labels=labels,
|
||||
models=[llm_spec],
|
||||
description=f"OpenLLM service for {llm.config['start_name']}",
|
||||
include=list(llm_fs.walk.files()),
|
||||
exclude=['/venv', '/.venv', '__pycache__/', '*.py[cod]', '*$py.class'],
|
||||
python=construct_python_options(llm, llm_fs, extra_dependencies, adapter_map),
|
||||
docker=construct_docker_options(llm, llm_fs, quantize, adapter_map, dockerfile_template, _serialisation, container_registry, container_version_strategy))
|
||||
build_config = BentoBuildConfig(
|
||||
service=f"{llm.config['service_name']}:svc",
|
||||
name=bento_tag.name,
|
||||
labels=labels,
|
||||
models=[llm_spec],
|
||||
description=f"OpenLLM service for {llm.config['start_name']}",
|
||||
include=list(llm_fs.walk.files()),
|
||||
exclude=['/venv', '/.venv', '__pycache__/', '*.py[cod]', '*$py.class'],
|
||||
python=construct_python_options(llm, llm_fs, extra_dependencies, adapter_map),
|
||||
docker=construct_docker_options(
|
||||
llm,
|
||||
llm_fs,
|
||||
quantize,
|
||||
adapter_map,
|
||||
dockerfile_template,
|
||||
_serialisation,
|
||||
container_registry,
|
||||
container_version_strategy,
|
||||
),
|
||||
)
|
||||
|
||||
bento = bentoml.Bento.create(build_config=build_config, version=bento_tag.version, build_ctx=llm_fs.getsyspath('/'))
|
||||
# NOTE: the model_id_path here are only used for setting this environment variable within the container built with for BentoLLM.
|
||||
@@ -208,10 +304,12 @@ def create_bento(bento_tag: bentoml.Tag,
|
||||
service_contents = f.readlines()
|
||||
|
||||
for it in service_contents:
|
||||
if '__bento_name__' in it: service_contents[service_contents.index(it)] = it.format(__bento_name__=str(bento.tag))
|
||||
if '__bento_name__' in it:
|
||||
service_contents[service_contents.index(it)] = it.format(__bento_name__=str(bento.tag))
|
||||
|
||||
script = ''.join(service_contents)
|
||||
if openllm_core.utils.DEBUG: logger.info('Generated script:\n%s', script)
|
||||
if openllm_core.utils.DEBUG:
|
||||
logger.info('Generated script:\n%s', script)
|
||||
|
||||
bento._fs.writetext(service_fs_path, script)
|
||||
if 'model_store' in inspect.signature(bento.save).parameters:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# mypy: disable-error-code="misc"
|
||||
"""OCI-related utilities for OpenLLM. This module is considered to be internal and API are subjected to change."""
|
||||
|
||||
from __future__ import annotations
|
||||
import functools
|
||||
import importlib
|
||||
@@ -23,6 +24,7 @@ import openllm_core
|
||||
|
||||
from openllm_core.utils.lazy import VersionInfo
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ghapi import all
|
||||
|
||||
@@ -42,7 +44,11 @@ ROOT_DIR = pathlib.Path(os.path.abspath('__file__')).parent.parent.parent
|
||||
# but in the future, we can infer based on git repo and everything to make it more options for users
|
||||
# to build the base image. For now, all of the base image will be <registry>/bentoml/openllm:...
|
||||
# NOTE: The ECR registry is the public one and currently only @bentoml team has access to push it.
|
||||
_CONTAINER_REGISTRY: dict[LiteralContainerRegistry, str] = {'docker': 'docker.io/bentoml/openllm', 'gh': 'ghcr.io/bentoml/openllm', 'ecr': 'public.ecr.aws/y5w8i4y6/bentoml/openllm'}
|
||||
_CONTAINER_REGISTRY: dict[LiteralContainerRegistry, str] = {
|
||||
'docker': 'docker.io/bentoml/openllm',
|
||||
'gh': 'ghcr.io/bentoml/openllm',
|
||||
'ecr': 'public.ecr.aws/y5w8i4y6/bentoml/openllm',
|
||||
}
|
||||
|
||||
# TODO: support custom fork. Currently it only support openllm main.
|
||||
_OWNER = 'bentoml'
|
||||
@@ -50,21 +56,29 @@ _REPO = 'openllm'
|
||||
|
||||
_module_location = openllm_core.utils.pkg.source_locations('openllm')
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
@openllm_core.utils.apply(str.lower)
|
||||
def get_base_container_name(reg: LiteralContainerRegistry) -> str:
|
||||
return _CONTAINER_REGISTRY[reg]
|
||||
|
||||
|
||||
def _convert_version_from_string(s: str) -> VersionInfo:
|
||||
return VersionInfo.from_version_string(s)
|
||||
|
||||
|
||||
def _commit_time_range(r: int = 5) -> str:
|
||||
return (datetime.now(timezone.utc) - timedelta(days=r)).strftime('%Y-%m-%dT%H:%M:%SZ')
|
||||
|
||||
|
||||
class VersionNotSupported(openllm.exceptions.OpenLLMException):
|
||||
"""Raised when the stable release is too low that it doesn't include OpenLLM base container."""
|
||||
|
||||
_RefTuple: type[RefTuple] = openllm_core.utils.codegen.make_attr_tuple_class('_RefTuple', ['git_hash', 'version', 'strategy'])
|
||||
|
||||
_RefTuple: type[RefTuple] = openllm_core.utils.codegen.make_attr_tuple_class(
|
||||
'_RefTuple', ['git_hash', 'version', 'strategy']
|
||||
)
|
||||
|
||||
|
||||
def nightly_resolver(cls: type[RefResolver]) -> str:
|
||||
# NOTE: all openllm container will have sha-<git_hash[:7]>
|
||||
@@ -73,12 +87,27 @@ def nightly_resolver(cls: type[RefResolver]) -> str:
|
||||
docker_bin = shutil.which('docker')
|
||||
if docker_bin is None:
|
||||
logger.warning(
|
||||
'To get the correct available nightly container, make sure to have docker available. Fallback to previous behaviour for determine nightly hash (container might not exists due to the lack of GPU machine at a time. See https://github.com/bentoml/OpenLLM/pkgs/container/openllm for available image.)'
|
||||
'To get the correct available nightly container, make sure to have docker available. Fallback to previous behaviour for determine nightly hash (container might not exists due to the lack of GPU machine at a time. See https://github.com/bentoml/OpenLLM/pkgs/container/openllm for available image.)'
|
||||
)
|
||||
commits = t.cast('list[dict[str, t.Any]]', cls._ghapi.repos.list_commits(since=_commit_time_range()))
|
||||
return next(f'sha-{it["sha"][:7]}' for it in commits if '[skip ci]' not in it['commit']['message'])
|
||||
# now is the correct behaviour
|
||||
return orjson.loads(subprocess.check_output([docker_bin, 'run', '--rm', '-it', 'quay.io/skopeo/stable:latest', 'list-tags', 'docker://ghcr.io/bentoml/openllm']).decode().strip())['Tags'][-2]
|
||||
return orjson.loads(
|
||||
subprocess.check_output(
|
||||
[
|
||||
docker_bin,
|
||||
'run',
|
||||
'--rm',
|
||||
'-it',
|
||||
'quay.io/skopeo/stable:latest',
|
||||
'list-tags',
|
||||
'docker://ghcr.io/bentoml/openllm',
|
||||
]
|
||||
)
|
||||
.decode()
|
||||
.strip()
|
||||
)['Tags'][-2]
|
||||
|
||||
|
||||
@attr.attrs(eq=False, order=False, slots=True, frozen=True)
|
||||
class RefResolver:
|
||||
@@ -98,80 +127,124 @@ class RefResolver:
|
||||
# NOTE: This strategy will only support openllm>0.2.12
|
||||
meta: dict[str, t.Any] = cls._ghapi.repos.get_latest_release()
|
||||
version_str = meta['name'].lstrip('v')
|
||||
version: tuple[str, str | None] = (cls._ghapi.git.get_ref(ref=f"tags/{meta['name']}")['object']['sha'], version_str)
|
||||
version: tuple[str, str | None] = (
|
||||
cls._ghapi.git.get_ref(ref=f"tags/{meta['name']}")['object']['sha'],
|
||||
version_str,
|
||||
)
|
||||
else:
|
||||
version = ('', version_str)
|
||||
if openllm_core.utils.VersionInfo.from_version_string(t.cast(str, version_str)) < (0, 2, 12):
|
||||
raise VersionNotSupported(f"Version {version_str} doesn't support OpenLLM base container. Consider using 'nightly' or upgrade 'openllm>=0.2.12'")
|
||||
raise VersionNotSupported(
|
||||
f"Version {version_str} doesn't support OpenLLM base container. Consider using 'nightly' or upgrade 'openllm>=0.2.12'"
|
||||
)
|
||||
return _RefTuple((*version, 'release' if _use_base_strategy else 'custom'))
|
||||
|
||||
@classmethod
|
||||
@functools.lru_cache(maxsize=64)
|
||||
def from_strategy(cls, strategy_or_version: t.Literal['release', 'nightly'] | LiteralString | None = None) -> RefResolver:
|
||||
def from_strategy(
|
||||
cls, strategy_or_version: t.Literal['release', 'nightly'] | LiteralString | None = None
|
||||
) -> RefResolver:
|
||||
# using default strategy
|
||||
if strategy_or_version is None or strategy_or_version == 'release': return cls(*cls._release_ref())
|
||||
elif strategy_or_version == 'latest': return cls('latest', '0.0.0', 'latest')
|
||||
if strategy_or_version is None or strategy_or_version == 'release':
|
||||
return cls(*cls._release_ref())
|
||||
elif strategy_or_version == 'latest':
|
||||
return cls('latest', '0.0.0', 'latest')
|
||||
elif strategy_or_version == 'nightly':
|
||||
_ref = cls._nightly_ref()
|
||||
return cls(_ref[0], '0.0.0', _ref[-1])
|
||||
else:
|
||||
logger.warning('Using custom %s. Make sure that it is at lease 0.2.12 for base container support.', strategy_or_version)
|
||||
logger.warning(
|
||||
'Using custom %s. Make sure that it is at lease 0.2.12 for base container support.', strategy_or_version
|
||||
)
|
||||
return cls(*cls._release_ref(version_str=strategy_or_version))
|
||||
|
||||
@property
|
||||
def tag(self) -> str:
|
||||
# NOTE: latest tag can also be nightly, but discouraged to use it. For nightly refer to use sha-<git_hash_short>
|
||||
if self.strategy == 'latest': return 'latest'
|
||||
elif self.strategy == 'nightly': return self.git_hash
|
||||
else: return repr(self.version)
|
||||
if self.strategy == 'latest':
|
||||
return 'latest'
|
||||
elif self.strategy == 'nightly':
|
||||
return self.git_hash
|
||||
else:
|
||||
return repr(self.version)
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=256)
|
||||
def get_base_container_tag(strategy: LiteralContainerVersionStrategy | None = None) -> str:
|
||||
return RefResolver.from_strategy(strategy).tag
|
||||
|
||||
def build_container(registries: LiteralContainerRegistry | t.Sequence[LiteralContainerRegistry] | None = None,
|
||||
version_strategy: LiteralContainerVersionStrategy = 'release',
|
||||
push: bool = False,
|
||||
machine: bool = False) -> dict[str | LiteralContainerRegistry, str]:
|
||||
|
||||
def build_container(
|
||||
registries: LiteralContainerRegistry | t.Sequence[LiteralContainerRegistry] | None = None,
|
||||
version_strategy: LiteralContainerVersionStrategy = 'release',
|
||||
push: bool = False,
|
||||
machine: bool = False,
|
||||
) -> dict[str | LiteralContainerRegistry, str]:
|
||||
try:
|
||||
if not _BUILDER.health(): raise openllm.exceptions.Error
|
||||
if not _BUILDER.health():
|
||||
raise openllm.exceptions.Error
|
||||
except (openllm.exceptions.Error, subprocess.CalledProcessError):
|
||||
raise RuntimeError('Building base container requires BuildKit (via Buildx) to be installed. See https://docs.docker.com/build/buildx/install/ for instalation instruction.') from None
|
||||
raise RuntimeError(
|
||||
'Building base container requires BuildKit (via Buildx) to be installed. See https://docs.docker.com/build/buildx/install/ for instalation instruction.'
|
||||
) from None
|
||||
if not shutil.which('nvidia-container-runtime'):
|
||||
raise RuntimeError('NVIDIA Container Toolkit is required to compile CUDA kernel in container.')
|
||||
if not _module_location:
|
||||
raise RuntimeError("Failed to determine source location of 'openllm'. (Possible broken installation)")
|
||||
pyproject_path = pathlib.Path(_module_location).parent.parent / 'pyproject.toml'
|
||||
if not pyproject_path.exists():
|
||||
raise ValueError("This utility can only be run within OpenLLM git repository. Clone it first with 'git clone https://github.com/bentoml/OpenLLM.git'")
|
||||
raise ValueError(
|
||||
"This utility can only be run within OpenLLM git repository. Clone it first with 'git clone https://github.com/bentoml/OpenLLM.git'"
|
||||
)
|
||||
if not registries:
|
||||
tags: dict[str | LiteralContainerRegistry, str] = {alias: f'{value}:{get_base_container_tag(version_strategy)}' for alias, value in _CONTAINER_REGISTRY.items()}
|
||||
tags: dict[str | LiteralContainerRegistry, str] = {
|
||||
alias: f'{value}:{get_base_container_tag(version_strategy)}' for alias, value in _CONTAINER_REGISTRY.items()
|
||||
}
|
||||
else:
|
||||
registries = [registries] if isinstance(registries, str) else list(registries)
|
||||
tags = {name: f'{_CONTAINER_REGISTRY[name]}:{get_base_container_tag(version_strategy)}' for name in registries}
|
||||
try:
|
||||
outputs = _BUILDER.build(file=pathlib.Path(__file__).parent.joinpath('Dockerfile').resolve().__fspath__(),
|
||||
context_path=pyproject_path.parent.__fspath__(),
|
||||
tag=tuple(tags.values()),
|
||||
push=push,
|
||||
progress='plain' if openllm_core.utils.get_debug_mode() else 'auto',
|
||||
quiet=machine)
|
||||
if machine and outputs is not None: tags['image_sha'] = outputs.decode('utf-8').strip()
|
||||
outputs = _BUILDER.build(
|
||||
file=pathlib.Path(__file__).parent.joinpath('Dockerfile').resolve().__fspath__(),
|
||||
context_path=pyproject_path.parent.__fspath__(),
|
||||
tag=tuple(tags.values()),
|
||||
push=push,
|
||||
progress='plain' if openllm_core.utils.get_debug_mode() else 'auto',
|
||||
quiet=machine,
|
||||
)
|
||||
if machine and outputs is not None:
|
||||
tags['image_sha'] = outputs.decode('utf-8').strip()
|
||||
except Exception as err:
|
||||
raise openllm.exceptions.OpenLLMException(f'Failed to containerize base container images (Scroll up to see error above, or set OPENLLMDEVDEBUG=True for more traceback):\n{err}') from err
|
||||
raise openllm.exceptions.OpenLLMException(
|
||||
f'Failed to containerize base container images (Scroll up to see error above, or set OPENLLMDEVDEBUG=True for more traceback):\n{err}'
|
||||
) from err
|
||||
return tags
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
CONTAINER_NAMES: dict[LiteralContainerRegistry, str]
|
||||
supported_registries: list[str]
|
||||
|
||||
__all__ = ['CONTAINER_NAMES', 'get_base_container_tag', 'build_container', 'get_base_container_name', 'supported_registries', 'RefResolver']
|
||||
__all__ = [
|
||||
'CONTAINER_NAMES',
|
||||
'get_base_container_tag',
|
||||
'build_container',
|
||||
'get_base_container_name',
|
||||
'supported_registries',
|
||||
'RefResolver',
|
||||
]
|
||||
|
||||
|
||||
def __dir__() -> list[str]:
|
||||
return sorted(__all__)
|
||||
|
||||
|
||||
def __getattr__(name: str) -> t.Any:
|
||||
if name == 'supported_registries': return functools.lru_cache(1)(lambda: list(_CONTAINER_REGISTRY))()
|
||||
elif name == 'CONTAINER_NAMES': return _CONTAINER_REGISTRY
|
||||
elif name in __all__: return importlib.import_module('.' + name, __name__)
|
||||
else: raise AttributeError(f'{name} does not exists under {__name__}')
|
||||
if name == 'supported_registries':
|
||||
return functools.lru_cache(1)(lambda: list(_CONTAINER_REGISTRY))()
|
||||
elif name == 'CONTAINER_NAMES':
|
||||
return _CONTAINER_REGISTRY
|
||||
elif name in __all__:
|
||||
return importlib.import_module('.' + name, __name__)
|
||||
else:
|
||||
raise AttributeError(f'{name} does not exists under {__name__}')
|
||||
|
||||
@@ -25,8 +25,14 @@ from openllm_core._typing_compat import ParamSpec
|
||||
from openllm_core._typing_compat import get_literal_args
|
||||
from openllm_core.utils import DEBUG
|
||||
|
||||
|
||||
class _OpenLLM_GenericInternalConfig(LLMConfig):
|
||||
__config__ = {'name_type': 'lowercase', 'default_id': 'openllm/generic', 'model_ids': ['openllm/generic'], 'architecture': 'PreTrainedModel'}
|
||||
__config__ = {
|
||||
'name_type': 'lowercase',
|
||||
'default_id': 'openllm/generic',
|
||||
'model_ids': ['openllm/generic'],
|
||||
'architecture': 'PreTrainedModel',
|
||||
}
|
||||
|
||||
class GenerationConfig:
|
||||
top_k: int = 15
|
||||
@@ -34,6 +40,7 @@ class _OpenLLM_GenericInternalConfig(LLMConfig):
|
||||
temperature: float = 0.75
|
||||
max_new_tokens: int = 128
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
P = ParamSpec('P')
|
||||
@@ -42,37 +49,76 @@ LiteralOutput = t.Literal['json', 'pretty', 'porcelain']
|
||||
_AnyCallable = t.Callable[..., t.Any]
|
||||
FC = t.TypeVar('FC', bound=t.Union[_AnyCallable, click.Command])
|
||||
|
||||
|
||||
def bento_complete_envvar(ctx: click.Context, param: click.Parameter, incomplete: str) -> list[sc.CompletionItem]:
|
||||
return [sc.CompletionItem(str(it.tag), help='Bento') for it in bentoml.list() if str(it.tag).startswith(incomplete) and all(k in it.info.labels for k in {'start_name', 'bundler'})]
|
||||
return [
|
||||
sc.CompletionItem(str(it.tag), help='Bento')
|
||||
for it in bentoml.list()
|
||||
if str(it.tag).startswith(incomplete) and all(k in it.info.labels for k in {'start_name', 'bundler'})
|
||||
]
|
||||
|
||||
|
||||
def model_complete_envvar(ctx: click.Context, param: click.Parameter, incomplete: str) -> list[sc.CompletionItem]:
|
||||
return [sc.CompletionItem(inflection.dasherize(it), help='Model') for it in openllm.CONFIG_MAPPING if it.startswith(incomplete)]
|
||||
return [
|
||||
sc.CompletionItem(inflection.dasherize(it), help='Model')
|
||||
for it in openllm.CONFIG_MAPPING
|
||||
if it.startswith(incomplete)
|
||||
]
|
||||
|
||||
def parse_config_options(config: LLMConfig, server_timeout: int, workers_per_resource: float, device: t.Tuple[str, ...] | None, cors: bool, environ: DictStrAny) -> DictStrAny:
|
||||
|
||||
def parse_config_options(
|
||||
config: LLMConfig,
|
||||
server_timeout: int,
|
||||
workers_per_resource: float,
|
||||
device: t.Tuple[str, ...] | None,
|
||||
cors: bool,
|
||||
environ: DictStrAny,
|
||||
) -> DictStrAny:
|
||||
# TODO: Support amd.com/gpu on k8s
|
||||
_bentoml_config_options_env = environ.pop('BENTOML_CONFIG_OPTIONS', '')
|
||||
_bentoml_config_options_opts = [
|
||||
'tracing.sample_rate=1.0', f'api_server.traffic.timeout={server_timeout}', f'runners."llm-{config["start_name"]}-runner".traffic.timeout={config["timeout"]}',
|
||||
f'runners."llm-{config["start_name"]}-runner".workers_per_resource={workers_per_resource}'
|
||||
'tracing.sample_rate=1.0',
|
||||
f'api_server.traffic.timeout={server_timeout}',
|
||||
f'runners."llm-{config["start_name"]}-runner".traffic.timeout={config["timeout"]}',
|
||||
f'runners."llm-{config["start_name"]}-runner".workers_per_resource={workers_per_resource}',
|
||||
]
|
||||
if device:
|
||||
if len(device) > 1:
|
||||
_bentoml_config_options_opts.extend([f'runners."llm-{config["start_name"]}-runner".resources."nvidia.com/gpu"[{idx}]={dev}' for idx, dev in enumerate(device)])
|
||||
_bentoml_config_options_opts.extend(
|
||||
[
|
||||
f'runners."llm-{config["start_name"]}-runner".resources."nvidia.com/gpu"[{idx}]={dev}'
|
||||
for idx, dev in enumerate(device)
|
||||
]
|
||||
)
|
||||
else:
|
||||
_bentoml_config_options_opts.append(f'runners."llm-{config["start_name"]}-runner".resources."nvidia.com/gpu"=[{device[0]}]')
|
||||
_bentoml_config_options_opts.append(
|
||||
f'runners."llm-{config["start_name"]}-runner".resources."nvidia.com/gpu"=[{device[0]}]'
|
||||
)
|
||||
if cors:
|
||||
_bentoml_config_options_opts.extend(['api_server.http.cors.enabled=true', 'api_server.http.cors.access_control_allow_origins="*"'])
|
||||
_bentoml_config_options_opts.extend([f'api_server.http.cors.access_control_allow_methods[{idx}]="{it}"' for idx, it in enumerate(['GET', 'OPTIONS', 'POST', 'HEAD', 'PUT'])])
|
||||
_bentoml_config_options_opts.extend(
|
||||
['api_server.http.cors.enabled=true', 'api_server.http.cors.access_control_allow_origins="*"']
|
||||
)
|
||||
_bentoml_config_options_opts.extend(
|
||||
[
|
||||
f'api_server.http.cors.access_control_allow_methods[{idx}]="{it}"'
|
||||
for idx, it in enumerate(['GET', 'OPTIONS', 'POST', 'HEAD', 'PUT'])
|
||||
]
|
||||
)
|
||||
_bentoml_config_options_env += ' ' if _bentoml_config_options_env else '' + ' '.join(_bentoml_config_options_opts)
|
||||
environ['BENTOML_CONFIG_OPTIONS'] = _bentoml_config_options_env
|
||||
if DEBUG: logger.debug('Setting BENTOML_CONFIG_OPTIONS=%s', _bentoml_config_options_env)
|
||||
if DEBUG:
|
||||
logger.debug('Setting BENTOML_CONFIG_OPTIONS=%s', _bentoml_config_options_env)
|
||||
return environ
|
||||
|
||||
|
||||
_adapter_mapping_key = 'adapter_map'
|
||||
|
||||
|
||||
def _id_callback(ctx: click.Context, _: click.Parameter, value: t.Tuple[str, ...] | None) -> None:
|
||||
if not value: return None
|
||||
if _adapter_mapping_key not in ctx.params: ctx.params[_adapter_mapping_key] = {}
|
||||
if not value:
|
||||
return None
|
||||
if _adapter_mapping_key not in ctx.params:
|
||||
ctx.params[_adapter_mapping_key] = {}
|
||||
for v in value:
|
||||
adapter_id, *adapter_name = v.rsplit(':', maxsplit=1)
|
||||
# try to resolve the full path if users pass in relative,
|
||||
@@ -81,20 +127,28 @@ def _id_callback(ctx: click.Context, _: click.Parameter, value: t.Tuple[str, ...
|
||||
adapter_id = openllm.utils.resolve_user_filepath(adapter_id, os.getcwd())
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
if len(adapter_name) == 0: raise ClickException(f'Adapter name is required for {adapter_id}')
|
||||
if len(adapter_name) == 0:
|
||||
raise ClickException(f'Adapter name is required for {adapter_id}')
|
||||
ctx.params[_adapter_mapping_key][adapter_id] = adapter_name[0]
|
||||
return None
|
||||
|
||||
|
||||
def start_decorator(serve_grpc: bool = False) -> t.Callable[[FC], t.Callable[[FC], FC]]:
|
||||
def wrapper(fn: FC) -> t.Callable[[FC], FC]:
|
||||
composed = openllm.utils.compose(
|
||||
_OpenLLM_GenericInternalConfig().to_click_options, _http_server_args if not serve_grpc else _grpc_server_args,
|
||||
cog.optgroup.group('General LLM Options', help='The following options are related to running LLM Server.'), model_version_option(factory=cog.optgroup),
|
||||
system_message_option(factory=cog.optgroup), prompt_template_file_option(factory=cog.optgroup),
|
||||
cog.optgroup.option('--server-timeout', type=int, default=None, help='Server timeout in seconds'), workers_per_resource_option(factory=cog.optgroup), cors_option(factory=cog.optgroup),
|
||||
backend_option(factory=cog.optgroup),
|
||||
cog.optgroup.group('LLM Optimization Options',
|
||||
help='''Optimization related options.
|
||||
_OpenLLM_GenericInternalConfig().to_click_options,
|
||||
_http_server_args if not serve_grpc else _grpc_server_args,
|
||||
cog.optgroup.group('General LLM Options', help='The following options are related to running LLM Server.'),
|
||||
model_version_option(factory=cog.optgroup),
|
||||
system_message_option(factory=cog.optgroup),
|
||||
prompt_template_file_option(factory=cog.optgroup),
|
||||
cog.optgroup.option('--server-timeout', type=int, default=None, help='Server timeout in seconds'),
|
||||
workers_per_resource_option(factory=cog.optgroup),
|
||||
cors_option(factory=cog.optgroup),
|
||||
backend_option(factory=cog.optgroup),
|
||||
cog.optgroup.group(
|
||||
'LLM Optimization Options',
|
||||
help="""Optimization related options.
|
||||
|
||||
OpenLLM supports running model k-bit quantization (8-bit, 4-bit), GPTQ quantization, PagedAttention via vLLM.
|
||||
|
||||
@@ -102,16 +156,22 @@ def start_decorator(serve_grpc: bool = False) -> t.Callable[[FC], t.Callable[[FC
|
||||
|
||||
- DeepSpeed Inference: [link](https://www.deepspeed.ai/inference/)
|
||||
- GGML: Fast inference on [bare metal](https://github.com/ggerganov/ggml)
|
||||
'''), quantize_option(factory=cog.optgroup), serialisation_option(factory=cog.optgroup),
|
||||
cog.optgroup.option('--device',
|
||||
type=openllm.utils.dantic.CUDA,
|
||||
multiple=True,
|
||||
envvar='CUDA_VISIBLE_DEVICES',
|
||||
callback=parse_device_callback,
|
||||
help='Assign GPU devices (if available)',
|
||||
show_envvar=True),
|
||||
cog.optgroup.group('Fine-tuning related options',
|
||||
help='''\
|
||||
""",
|
||||
),
|
||||
quantize_option(factory=cog.optgroup),
|
||||
serialisation_option(factory=cog.optgroup),
|
||||
cog.optgroup.option(
|
||||
'--device',
|
||||
type=openllm.utils.dantic.CUDA,
|
||||
multiple=True,
|
||||
envvar='CUDA_VISIBLE_DEVICES',
|
||||
callback=parse_device_callback,
|
||||
help='Assign GPU devices (if available)',
|
||||
show_envvar=True,
|
||||
),
|
||||
cog.optgroup.group(
|
||||
'Fine-tuning related options',
|
||||
help="""\
|
||||
Note that the argument `--adapter-id` can accept the following format:
|
||||
|
||||
- `--adapter-id /path/to/adapter` (local adapter)
|
||||
@@ -125,46 +185,62 @@ def start_decorator(serve_grpc: bool = False) -> t.Callable[[FC], t.Callable[[FC
|
||||
$ openllm start opt --adapter-id /path/to/adapter_dir --adapter-id remote/adapter:eng_lora
|
||||
|
||||
```
|
||||
'''),
|
||||
cog.optgroup.option('--adapter-id',
|
||||
default=None,
|
||||
help='Optional name or path for given LoRA adapter',
|
||||
multiple=True,
|
||||
callback=_id_callback,
|
||||
metavar='[PATH | [remote/][adapter_name:]adapter_id][, ...]'), click.option('--return-process', is_flag=True, default=False, help='Internal use only.',
|
||||
hidden=True),
|
||||
""",
|
||||
),
|
||||
cog.optgroup.option(
|
||||
'--adapter-id',
|
||||
default=None,
|
||||
help='Optional name or path for given LoRA adapter',
|
||||
multiple=True,
|
||||
callback=_id_callback,
|
||||
metavar='[PATH | [remote/][adapter_name:]adapter_id][, ...]',
|
||||
),
|
||||
click.option('--return-process', is_flag=True, default=False, help='Internal use only.', hidden=True),
|
||||
)
|
||||
return composed(fn)
|
||||
|
||||
return wrapper
|
||||
|
||||
def parse_device_callback(ctx: click.Context, param: click.Parameter, value: tuple[tuple[str], ...] | None) -> t.Tuple[str, ...] | None:
|
||||
if value is None: return value
|
||||
if not isinstance(value, tuple): ctx.fail(f'{param} only accept multiple values, not {type(value)} (value: {value})')
|
||||
|
||||
def parse_device_callback(
|
||||
ctx: click.Context, param: click.Parameter, value: tuple[tuple[str], ...] | None
|
||||
) -> t.Tuple[str, ...] | None:
|
||||
if value is None:
|
||||
return value
|
||||
if not isinstance(value, tuple):
|
||||
ctx.fail(f'{param} only accept multiple values, not {type(value)} (value: {value})')
|
||||
el: t.Tuple[str, ...] = tuple(i for k in value for i in k)
|
||||
# NOTE: --device all is a special case
|
||||
if len(el) == 1 and el[0] == 'all': return tuple(map(str, openllm.utils.available_devices()))
|
||||
if len(el) == 1 and el[0] == 'all':
|
||||
return tuple(map(str, openllm.utils.available_devices()))
|
||||
return el
|
||||
|
||||
|
||||
# NOTE: A list of bentoml option that is not needed for parsing.
|
||||
# NOTE: User shouldn't set '--working-dir', as OpenLLM will setup this.
|
||||
# NOTE: production is also deprecated
|
||||
_IGNORED_OPTIONS = {'working_dir', 'production', 'protocol_version'}
|
||||
|
||||
|
||||
def parse_serve_args(serve_grpc: bool) -> t.Callable[[t.Callable[..., LLMConfig]], t.Callable[[FC], FC]]:
|
||||
"""Parsing `bentoml serve|serve-grpc` click.Option to be parsed via `openllm start`."""
|
||||
from bentoml_cli.cli import cli
|
||||
|
||||
command = 'serve' if not serve_grpc else 'serve-grpc'
|
||||
group = cog.optgroup.group(f"Start a {'HTTP' if not serve_grpc else 'gRPC'} server options",
|
||||
help=f"Related to serving the model [synonymous to `bentoml {'serve-http' if not serve_grpc else command }`]",
|
||||
)
|
||||
group = cog.optgroup.group(
|
||||
f"Start a {'HTTP' if not serve_grpc else 'gRPC'} server options",
|
||||
help=f"Related to serving the model [synonymous to `bentoml {'serve-http' if not serve_grpc else command }`]",
|
||||
)
|
||||
|
||||
def decorator(f: t.Callable[Concatenate[int, t.Optional[str], P], LLMConfig]) -> t.Callable[[FC], FC]:
|
||||
serve_command = cli.commands[command]
|
||||
# The first variable is the argument bento
|
||||
# The last five is from BentoMLCommandGroup.NUMBER_OF_COMMON_PARAMS
|
||||
serve_options = [p for p in serve_command.params[1:-BentoMLCommandGroup.NUMBER_OF_COMMON_PARAMS] if p.name not in _IGNORED_OPTIONS]
|
||||
serve_options = [
|
||||
p
|
||||
for p in serve_command.params[1 : -BentoMLCommandGroup.NUMBER_OF_COMMON_PARAMS]
|
||||
if p.name not in _IGNORED_OPTIONS
|
||||
]
|
||||
for options in reversed(serve_options):
|
||||
attrs = options.to_info_dict()
|
||||
# we don't need param_type_name, since it should all be options
|
||||
@@ -179,8 +255,10 @@ def parse_serve_args(serve_grpc: bool) -> t.Callable[[t.Callable[..., LLMConfig]
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
_http_server_args, _grpc_server_args = parse_serve_args(False), parse_serve_args(True)
|
||||
|
||||
|
||||
def _click_factory_type(*param_decls: t.Any, **attrs: t.Any) -> t.Callable[[FC | None], FC]:
|
||||
"""General ``@click`` decorator with some sauce.
|
||||
|
||||
@@ -189,68 +267,114 @@ def _click_factory_type(*param_decls: t.Any, **attrs: t.Any) -> t.Callable[[FC |
|
||||
"""
|
||||
factory = attrs.pop('factory', click)
|
||||
factory_attr = attrs.pop('attr', 'option')
|
||||
if factory_attr != 'argument': attrs.setdefault('help', 'General option for OpenLLM CLI.')
|
||||
if factory_attr != 'argument':
|
||||
attrs.setdefault('help', 'General option for OpenLLM CLI.')
|
||||
|
||||
def decorator(f: FC | None) -> FC:
|
||||
callback = getattr(factory, factory_attr, None)
|
||||
if callback is None: raise ValueError(f'Factory {factory} has no attribute {factory_attr}.')
|
||||
if callback is None:
|
||||
raise ValueError(f'Factory {factory} has no attribute {factory_attr}.')
|
||||
return t.cast(FC, callback(*param_decls, **attrs)(f) if f is not None else callback(*param_decls, **attrs))
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
cli_option = functools.partial(_click_factory_type, attr='option')
|
||||
cli_argument = functools.partial(_click_factory_type, attr='argument')
|
||||
|
||||
|
||||
def cors_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option('--cors/--no-cors', show_default=True, default=False, envvar='OPENLLM_CORS', show_envvar=True, help='Enable CORS for the server.', **attrs)(f)
|
||||
return cli_option(
|
||||
'--cors/--no-cors',
|
||||
show_default=True,
|
||||
default=False,
|
||||
envvar='OPENLLM_CORS',
|
||||
show_envvar=True,
|
||||
help='Enable CORS for the server.',
|
||||
**attrs,
|
||||
)(f)
|
||||
|
||||
|
||||
def machine_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option('--machine', is_flag=True, default=False, hidden=True, **attrs)(f)
|
||||
|
||||
|
||||
def model_id_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option('--model-id', type=click.STRING, default=None, envvar='OPENLLM_MODEL_ID', show_envvar=True, help='Optional model_id name or path for (fine-tune) weight.', **attrs)(f)
|
||||
return cli_option(
|
||||
'--model-id',
|
||||
type=click.STRING,
|
||||
default=None,
|
||||
envvar='OPENLLM_MODEL_ID',
|
||||
show_envvar=True,
|
||||
help='Optional model_id name or path for (fine-tune) weight.',
|
||||
**attrs,
|
||||
)(f)
|
||||
|
||||
|
||||
def model_version_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option('--model-version', type=click.STRING, default=None, help='Optional model version to save for this model. It will be inferred automatically from model-id.', **attrs)(f)
|
||||
return cli_option(
|
||||
'--model-version',
|
||||
type=click.STRING,
|
||||
default=None,
|
||||
help='Optional model version to save for this model. It will be inferred automatically from model-id.',
|
||||
**attrs,
|
||||
)(f)
|
||||
|
||||
|
||||
def system_message_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option('--system-message',
|
||||
type=click.STRING,
|
||||
default=None,
|
||||
envvar='OPENLLM_SYSTEM_MESSAGE',
|
||||
help='Optional system message for supported LLMs. If given LLM supports system message, OpenLLM will provide a default system message.',
|
||||
**attrs)(f)
|
||||
return cli_option(
|
||||
'--system-message',
|
||||
type=click.STRING,
|
||||
default=None,
|
||||
envvar='OPENLLM_SYSTEM_MESSAGE',
|
||||
help='Optional system message for supported LLMs. If given LLM supports system message, OpenLLM will provide a default system message.',
|
||||
**attrs,
|
||||
)(f)
|
||||
|
||||
|
||||
def prompt_template_file_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option('--prompt-template-file',
|
||||
type=click.File(),
|
||||
default=None,
|
||||
help='Optional file path containing user-defined custom prompt template. By default, the prompt template for the specified LLM will be used.',
|
||||
**attrs)(f)
|
||||
return cli_option(
|
||||
'--prompt-template-file',
|
||||
type=click.File(),
|
||||
default=None,
|
||||
help='Optional file path containing user-defined custom prompt template. By default, the prompt template for the specified LLM will be used.',
|
||||
**attrs,
|
||||
)(f)
|
||||
|
||||
|
||||
def backend_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
# NOTE: LiteralBackend needs to remove the last two item as ggml and mlc is wip
|
||||
# XXX: remove the check for __args__ once we have ggml and mlc supports
|
||||
return cli_option('--backend',
|
||||
type=click.Choice(get_literal_args(LiteralBackend)[:2]),
|
||||
default=None,
|
||||
envvar='OPENLLM_BACKEND',
|
||||
show_envvar=True,
|
||||
help='The implementation for saving this LLM.',
|
||||
**attrs)(f)
|
||||
return cli_option(
|
||||
'--backend',
|
||||
type=click.Choice(get_literal_args(LiteralBackend)[:2]),
|
||||
default=None,
|
||||
envvar='OPENLLM_BACKEND',
|
||||
show_envvar=True,
|
||||
help='The implementation for saving this LLM.',
|
||||
**attrs,
|
||||
)(f)
|
||||
|
||||
|
||||
def model_name_argument(f: _AnyCallable | None = None, required: bool = True, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_argument('model_name', type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING]), required=required, **attrs)(f)
|
||||
return cli_argument(
|
||||
'model_name',
|
||||
type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING]),
|
||||
required=required,
|
||||
**attrs,
|
||||
)(f)
|
||||
|
||||
|
||||
def quantize_option(f: _AnyCallable | None = None, *, build: bool = False, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option('--quantise',
|
||||
'--quantize',
|
||||
'quantize',
|
||||
type=click.Choice(get_literal_args(LiteralQuantise)),
|
||||
default=None,
|
||||
envvar='OPENLLM_QUANTIZE',
|
||||
show_envvar=True,
|
||||
help='''Dynamic quantization for running this LLM.
|
||||
return cli_option(
|
||||
'--quantise',
|
||||
'--quantize',
|
||||
'quantize',
|
||||
type=click.Choice(get_literal_args(LiteralQuantise)),
|
||||
default=None,
|
||||
envvar='OPENLLM_QUANTIZE',
|
||||
show_envvar=True,
|
||||
help="""Dynamic quantization for running this LLM.
|
||||
|
||||
The following quantization strategies are supported:
|
||||
|
||||
@@ -261,18 +385,29 @@ def quantize_option(f: _AnyCallable | None = None, *, build: bool = False, **att
|
||||
- ``gptq``: ``GPTQ`` [quantization](https://arxiv.org/abs/2210.17323)
|
||||
|
||||
> [!NOTE] that the model can also be served with quantized weights.
|
||||
''' + ('''
|
||||
> [!NOTE] that this will set the mode for serving within deployment.''' if build else '') + '''
|
||||
> [!NOTE] that quantization are currently only available in *PyTorch* models.''',
|
||||
**attrs)(f)
|
||||
"""
|
||||
+ (
|
||||
"""
|
||||
> [!NOTE] that this will set the mode for serving within deployment."""
|
||||
if build
|
||||
else ''
|
||||
)
|
||||
+ """
|
||||
> [!NOTE] that quantization are currently only available in *PyTorch* models.""",
|
||||
**attrs,
|
||||
)(f)
|
||||
|
||||
def workers_per_resource_option(f: _AnyCallable | None = None, *, build: bool = False, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option('--workers-per-resource',
|
||||
default=None,
|
||||
callback=workers_per_resource_callback,
|
||||
type=str,
|
||||
required=False,
|
||||
help='''Number of workers per resource assigned.
|
||||
|
||||
def workers_per_resource_option(
|
||||
f: _AnyCallable | None = None, *, build: bool = False, **attrs: t.Any
|
||||
) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
'--workers-per-resource',
|
||||
default=None,
|
||||
callback=workers_per_resource_callback,
|
||||
type=str,
|
||||
required=False,
|
||||
help="""Number of workers per resource assigned.
|
||||
|
||||
See https://docs.bentoml.org/en/latest/guides/scheduling.html#resource-scheduling-strategy
|
||||
for more information. By default, this is set to 1.
|
||||
@@ -282,22 +417,30 @@ def workers_per_resource_option(f: _AnyCallable | None = None, *, build: bool =
|
||||
- ``round_robin``: Similar behaviour when setting ``--workers-per-resource 1``. This is useful for smaller models.
|
||||
|
||||
- ``conserved``: This will determine the number of available GPU resources, and only assign one worker for the LLMRunner. For example, if ther are 4 GPUs available, then ``conserved`` is equivalent to ``--workers-per-resource 0.25``.
|
||||
''' + ("""\n
|
||||
"""
|
||||
+ (
|
||||
"""\n
|
||||
> [!NOTE] The workers value passed into 'build' will determine how the LLM can
|
||||
> be provisioned in Kubernetes as well as in standalone container. This will
|
||||
> ensure it has the same effect with 'openllm start --api-workers ...'""" if build else ''),
|
||||
**attrs)(f)
|
||||
> ensure it has the same effect with 'openllm start --api-workers ...'"""
|
||||
if build
|
||||
else ''
|
||||
),
|
||||
**attrs,
|
||||
)(f)
|
||||
|
||||
|
||||
def serialisation_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option('--serialisation',
|
||||
'--serialization',
|
||||
'serialisation',
|
||||
type=click.Choice(get_literal_args(LiteralSerialisation)),
|
||||
default=None,
|
||||
show_default=True,
|
||||
show_envvar=True,
|
||||
envvar='OPENLLM_SERIALIZATION',
|
||||
help='''Serialisation format for save/load LLM.
|
||||
return cli_option(
|
||||
'--serialisation',
|
||||
'--serialization',
|
||||
'serialisation',
|
||||
type=click.Choice(get_literal_args(LiteralSerialisation)),
|
||||
default=None,
|
||||
show_default=True,
|
||||
show_envvar=True,
|
||||
envvar='OPENLLM_SERIALIZATION',
|
||||
help="""Serialisation format for save/load LLM.
|
||||
|
||||
Currently the following strategies are supported:
|
||||
|
||||
@@ -306,37 +449,51 @@ def serialisation_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Cal
|
||||
> [!NOTE] Safetensors might not work for every cases, and you can always fallback to ``legacy`` if needed.
|
||||
|
||||
- ``legacy``: This will use PyTorch serialisation format, often as ``.bin`` files. This should be used if the model doesn't yet support safetensors.
|
||||
''',
|
||||
**attrs)(f)
|
||||
""",
|
||||
**attrs,
|
||||
)(f)
|
||||
|
||||
|
||||
def container_registry_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option('--container-registry',
|
||||
'container_registry',
|
||||
type=click.Choice(list(openllm.bundle.CONTAINER_NAMES)),
|
||||
default='ecr',
|
||||
show_default=True,
|
||||
show_envvar=True,
|
||||
envvar='OPENLLM_CONTAINER_REGISTRY',
|
||||
callback=container_registry_callback,
|
||||
help='The default container registry to get the base image for building BentoLLM. Currently, it supports ecr, ghcr, docker',
|
||||
**attrs)(f)
|
||||
return cli_option(
|
||||
'--container-registry',
|
||||
'container_registry',
|
||||
type=click.Choice(list(openllm.bundle.CONTAINER_NAMES)),
|
||||
default='ecr',
|
||||
show_default=True,
|
||||
show_envvar=True,
|
||||
envvar='OPENLLM_CONTAINER_REGISTRY',
|
||||
callback=container_registry_callback,
|
||||
help='The default container registry to get the base image for building BentoLLM. Currently, it supports ecr, ghcr, docker',
|
||||
**attrs,
|
||||
)(f)
|
||||
|
||||
|
||||
_wpr_strategies = {'round_robin', 'conserved'}
|
||||
|
||||
|
||||
def workers_per_resource_callback(ctx: click.Context, param: click.Parameter, value: str | None) -> str | None:
|
||||
if value is None: return value
|
||||
if value is None:
|
||||
return value
|
||||
value = inflection.underscore(value)
|
||||
if value in _wpr_strategies: return value
|
||||
if value in _wpr_strategies:
|
||||
return value
|
||||
else:
|
||||
try:
|
||||
float(value) # type: ignore[arg-type]
|
||||
except ValueError:
|
||||
raise click.BadParameter(f"'workers_per_resource' only accept '{_wpr_strategies}' as possible strategies, otherwise pass in float.", ctx, param) from None
|
||||
raise click.BadParameter(
|
||||
f"'workers_per_resource' only accept '{_wpr_strategies}' as possible strategies, otherwise pass in float.",
|
||||
ctx,
|
||||
param,
|
||||
) from None
|
||||
else:
|
||||
return value
|
||||
|
||||
|
||||
def container_registry_callback(ctx: click.Context, param: click.Parameter, value: str | None) -> str | None:
|
||||
if value is None: return value
|
||||
if value is None:
|
||||
return value
|
||||
if value not in openllm.bundle.supported_registries:
|
||||
raise click.BadParameter(f'Value must be one of {openllm.bundle.supported_registries}', ctx, param)
|
||||
return value
|
||||
|
||||
@@ -22,6 +22,7 @@ from openllm_core.utils import codegen
|
||||
from openllm_core.utils import first_not_none
|
||||
from openllm_core.utils import is_vllm_available
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from bentoml._internal.bento import BentoStore
|
||||
from openllm_core._configuration import LLMConfig
|
||||
@@ -33,20 +34,23 @@ if t.TYPE_CHECKING:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def _start(model_id: str,
|
||||
timeout: int = 30,
|
||||
workers_per_resource: t.Literal['conserved', 'round_robin'] | float | None = None,
|
||||
device: tuple[str, ...] | t.Literal['all'] | None = None,
|
||||
quantize: LiteralQuantise | None = None,
|
||||
system_message: str | None = None,
|
||||
prompt_template_file: str | None = None,
|
||||
adapter_map: dict[LiteralString, str | None] | None = None,
|
||||
backend: LiteralBackend | None = None,
|
||||
additional_args: list[str] | None = None,
|
||||
cors: bool = False,
|
||||
_serve_grpc: bool = False,
|
||||
__test__: bool = False,
|
||||
**_: t.Any) -> LLMConfig | subprocess.Popen[bytes]:
|
||||
|
||||
def _start(
|
||||
model_id: str,
|
||||
timeout: int = 30,
|
||||
workers_per_resource: t.Literal['conserved', 'round_robin'] | float | None = None,
|
||||
device: tuple[str, ...] | t.Literal['all'] | None = None,
|
||||
quantize: LiteralQuantise | None = None,
|
||||
system_message: str | None = None,
|
||||
prompt_template_file: str | None = None,
|
||||
adapter_map: dict[LiteralString, str | None] | None = None,
|
||||
backend: LiteralBackend | None = None,
|
||||
additional_args: list[str] | None = None,
|
||||
cors: bool = False,
|
||||
_serve_grpc: bool = False,
|
||||
__test__: bool = False,
|
||||
**_: t.Any,
|
||||
) -> LLMConfig | subprocess.Popen[bytes]:
|
||||
"""Python API to start a LLM server. These provides one-to-one mapping to CLI arguments.
|
||||
|
||||
For all additional arguments, pass it as string to ``additional_args``. For example, if you want to
|
||||
@@ -85,45 +89,68 @@ def _start(model_id: str,
|
||||
"""
|
||||
from .entrypoint import start_command
|
||||
from .entrypoint import start_grpc_command
|
||||
os.environ['OPENLLM_BACKEND'] = openllm_core.utils.first_not_none(backend, default='vllm' if is_vllm_available() else 'pt')
|
||||
|
||||
os.environ['OPENLLM_BACKEND'] = openllm_core.utils.first_not_none(
|
||||
backend, default='vllm' if is_vllm_available() else 'pt'
|
||||
)
|
||||
|
||||
args: list[str] = [model_id]
|
||||
if system_message: args.extend(['--system-message', system_message])
|
||||
if prompt_template_file: args.extend(['--prompt-template-file', openllm_core.utils.resolve_filepath(prompt_template_file)])
|
||||
if timeout: args.extend(['--server-timeout', str(timeout)])
|
||||
if system_message:
|
||||
args.extend(['--system-message', system_message])
|
||||
if prompt_template_file:
|
||||
args.extend(['--prompt-template-file', openllm_core.utils.resolve_filepath(prompt_template_file)])
|
||||
if timeout:
|
||||
args.extend(['--server-timeout', str(timeout)])
|
||||
if workers_per_resource:
|
||||
args.extend(['--workers-per-resource', str(workers_per_resource) if not isinstance(workers_per_resource, str) else workers_per_resource])
|
||||
if device and not os.environ.get('CUDA_VISIBLE_DEVICES'): args.extend(['--device', ','.join(device)])
|
||||
if quantize: args.extend(['--quantize', str(quantize)])
|
||||
if cors: args.append('--cors')
|
||||
args.extend(
|
||||
[
|
||||
'--workers-per-resource',
|
||||
str(workers_per_resource) if not isinstance(workers_per_resource, str) else workers_per_resource,
|
||||
]
|
||||
)
|
||||
if device and not os.environ.get('CUDA_VISIBLE_DEVICES'):
|
||||
args.extend(['--device', ','.join(device)])
|
||||
if quantize:
|
||||
args.extend(['--quantize', str(quantize)])
|
||||
if cors:
|
||||
args.append('--cors')
|
||||
if adapter_map:
|
||||
args.extend(list(itertools.chain.from_iterable([['--adapter-id', f"{k}{':'+v if v else ''}"] for k, v in adapter_map.items()])))
|
||||
if additional_args: args.extend(additional_args)
|
||||
if __test__: args.append('--return-process')
|
||||
args.extend(
|
||||
list(
|
||||
itertools.chain.from_iterable([['--adapter-id', f"{k}{':'+v if v else ''}"] for k, v in adapter_map.items()])
|
||||
)
|
||||
)
|
||||
if additional_args:
|
||||
args.extend(additional_args)
|
||||
if __test__:
|
||||
args.append('--return-process')
|
||||
|
||||
cmd = start_command if not _serve_grpc else start_grpc_command
|
||||
return cmd.main(args=args, standalone_mode=False)
|
||||
|
||||
|
||||
@inject
|
||||
def _build(model_id: str,
|
||||
model_version: str | None = None,
|
||||
bento_version: str | None = None,
|
||||
quantize: LiteralQuantise | None = None,
|
||||
adapter_map: dict[str, str | None] | None = None,
|
||||
system_message: str | None = None,
|
||||
prompt_template_file: str | None = None,
|
||||
build_ctx: str | None = None,
|
||||
enable_features: tuple[str, ...] | None = None,
|
||||
dockerfile_template: str | None = None,
|
||||
overwrite: bool = False,
|
||||
container_registry: LiteralContainerRegistry | None = None,
|
||||
container_version_strategy: LiteralContainerVersionStrategy | None = None,
|
||||
push: bool = False,
|
||||
force_push: bool = False,
|
||||
containerize: bool = False,
|
||||
serialisation: LiteralSerialisation | None = None,
|
||||
additional_args: list[str] | None = None,
|
||||
bento_store: BentoStore = Provide[BentoMLContainer.bento_store]) -> bentoml.Bento:
|
||||
def _build(
|
||||
model_id: str,
|
||||
model_version: str | None = None,
|
||||
bento_version: str | None = None,
|
||||
quantize: LiteralQuantise | None = None,
|
||||
adapter_map: dict[str, str | None] | None = None,
|
||||
system_message: str | None = None,
|
||||
prompt_template_file: str | None = None,
|
||||
build_ctx: str | None = None,
|
||||
enable_features: tuple[str, ...] | None = None,
|
||||
dockerfile_template: str | None = None,
|
||||
overwrite: bool = False,
|
||||
container_registry: LiteralContainerRegistry | None = None,
|
||||
container_version_strategy: LiteralContainerVersionStrategy | None = None,
|
||||
push: bool = False,
|
||||
force_push: bool = False,
|
||||
containerize: bool = False,
|
||||
serialisation: LiteralSerialisation | None = None,
|
||||
additional_args: list[str] | None = None,
|
||||
bento_store: BentoStore = Provide[BentoMLContainer.bento_store],
|
||||
) -> bentoml.Bento:
|
||||
"""Package a LLM into a BentoLLM.
|
||||
|
||||
The LLM will be built into a BentoService with the following structure:
|
||||
@@ -161,49 +188,83 @@ def _build(model_id: str,
|
||||
``bentoml.Bento | str``: BentoLLM instance. This can be used to serve the LLM or can be pushed to BentoCloud.
|
||||
"""
|
||||
from ..serialisation.transformers.weights import has_safetensors_weights
|
||||
|
||||
args: list[str] = [
|
||||
sys.executable, '-m', 'openllm', 'build', model_id, '--machine', '--serialisation',
|
||||
t.cast(LiteralSerialisation, first_not_none(serialisation, default='safetensors' if has_safetensors_weights(model_id) else 'legacy'))
|
||||
sys.executable,
|
||||
'-m',
|
||||
'openllm',
|
||||
'build',
|
||||
model_id,
|
||||
'--machine',
|
||||
'--serialisation',
|
||||
t.cast(
|
||||
LiteralSerialisation,
|
||||
first_not_none(serialisation, default='safetensors' if has_safetensors_weights(model_id) else 'legacy'),
|
||||
),
|
||||
]
|
||||
if quantize: args.extend(['--quantize', quantize])
|
||||
if containerize and push: raise OpenLLMException("'containerize' and 'push' are currently mutually exclusive.")
|
||||
if push: args.extend(['--push'])
|
||||
if containerize: args.extend(['--containerize'])
|
||||
if build_ctx: args.extend(['--build-ctx', build_ctx])
|
||||
if enable_features: args.extend([f'--enable-features={f}' for f in enable_features])
|
||||
if overwrite: args.append('--overwrite')
|
||||
if system_message: args.extend(['--system-message', system_message])
|
||||
if prompt_template_file: args.extend(['--prompt-template-file', openllm_core.utils.resolve_filepath(prompt_template_file)])
|
||||
if adapter_map: args.extend([f"--adapter-id={k}{':'+v if v is not None else ''}" for k, v in adapter_map.items()])
|
||||
if model_version: args.extend(['--model-version', model_version])
|
||||
if bento_version: args.extend(['--bento-version', bento_version])
|
||||
if dockerfile_template: args.extend(['--dockerfile-template', dockerfile_template])
|
||||
if container_registry is None: container_registry = 'ecr'
|
||||
if container_version_strategy is None: container_version_strategy = 'release'
|
||||
if quantize:
|
||||
args.extend(['--quantize', quantize])
|
||||
if containerize and push:
|
||||
raise OpenLLMException("'containerize' and 'push' are currently mutually exclusive.")
|
||||
if push:
|
||||
args.extend(['--push'])
|
||||
if containerize:
|
||||
args.extend(['--containerize'])
|
||||
if build_ctx:
|
||||
args.extend(['--build-ctx', build_ctx])
|
||||
if enable_features:
|
||||
args.extend([f'--enable-features={f}' for f in enable_features])
|
||||
if overwrite:
|
||||
args.append('--overwrite')
|
||||
if system_message:
|
||||
args.extend(['--system-message', system_message])
|
||||
if prompt_template_file:
|
||||
args.extend(['--prompt-template-file', openllm_core.utils.resolve_filepath(prompt_template_file)])
|
||||
if adapter_map:
|
||||
args.extend([f"--adapter-id={k}{':'+v if v is not None else ''}" for k, v in adapter_map.items()])
|
||||
if model_version:
|
||||
args.extend(['--model-version', model_version])
|
||||
if bento_version:
|
||||
args.extend(['--bento-version', bento_version])
|
||||
if dockerfile_template:
|
||||
args.extend(['--dockerfile-template', dockerfile_template])
|
||||
if container_registry is None:
|
||||
container_registry = 'ecr'
|
||||
if container_version_strategy is None:
|
||||
container_version_strategy = 'release'
|
||||
args.extend(['--container-registry', container_registry, '--container-version-strategy', container_version_strategy])
|
||||
if additional_args: args.extend(additional_args)
|
||||
if additional_args:
|
||||
args.extend(additional_args)
|
||||
|
||||
try:
|
||||
output = subprocess.check_output(args, env=os.environ.copy(), cwd=build_ctx or os.getcwd())
|
||||
except subprocess.CalledProcessError as e:
|
||||
logger.error("Exception caught while building Bento for '%s'", model_id, exc_info=e)
|
||||
if e.stderr: raise OpenLLMException(e.stderr.decode('utf-8')) from None
|
||||
if e.stderr:
|
||||
raise OpenLLMException(e.stderr.decode('utf-8')) from None
|
||||
raise OpenLLMException(str(e)) from None
|
||||
matched = re.match(r'__object__:(\{.*\})$', output.decode('utf-8').strip())
|
||||
if matched is None:
|
||||
raise ValueError(f"Failed to find tag from output: {output.decode('utf-8').strip()}\nNote: Output from 'openllm build' might not be correct. Please open an issue on GitHub.")
|
||||
raise ValueError(
|
||||
f"Failed to find tag from output: {output.decode('utf-8').strip()}\nNote: Output from 'openllm build' might not be correct. Please open an issue on GitHub."
|
||||
)
|
||||
try:
|
||||
result = orjson.loads(matched.group(1))
|
||||
except orjson.JSONDecodeError as e:
|
||||
raise ValueError(f"Failed to decode JSON from output: {output.decode('utf-8').strip()}\nNote: Output from 'openllm build' might not be correct. Please open an issue on GitHub.") from e
|
||||
raise ValueError(
|
||||
f"Failed to decode JSON from output: {output.decode('utf-8').strip()}\nNote: Output from 'openllm build' might not be correct. Please open an issue on GitHub."
|
||||
) from e
|
||||
return bentoml.get(result['tag'], _bento_store=bento_store)
|
||||
|
||||
def _import_model(model_id: str,
|
||||
model_version: str | None = None,
|
||||
backend: LiteralBackend | None = None,
|
||||
quantize: LiteralQuantise | None = None,
|
||||
serialisation: LiteralSerialisation | None = None,
|
||||
additional_args: t.Sequence[str] | None = None) -> dict[str, t.Any]:
|
||||
|
||||
def _import_model(
|
||||
model_id: str,
|
||||
model_version: str | None = None,
|
||||
backend: LiteralBackend | None = None,
|
||||
quantize: LiteralQuantise | None = None,
|
||||
serialisation: LiteralSerialisation | None = None,
|
||||
additional_args: t.Sequence[str] | None = None,
|
||||
) -> dict[str, t.Any]:
|
||||
"""Import a LLM into local store.
|
||||
|
||||
> [!NOTE]
|
||||
@@ -232,19 +293,32 @@ def _import_model(model_id: str,
|
||||
``bentoml.Model``:BentoModel of the given LLM. This can be used to serve the LLM or can be pushed to BentoCloud.
|
||||
"""
|
||||
from .entrypoint import import_command
|
||||
|
||||
args = [model_id, '--quiet']
|
||||
if backend is not None: args.extend(['--backend', backend])
|
||||
if model_version is not None: args.extend(['--model-version', str(model_version)])
|
||||
if quantize is not None: args.extend(['--quantize', quantize])
|
||||
if serialisation is not None: args.extend(['--serialisation', serialisation])
|
||||
if additional_args is not None: args.extend(additional_args)
|
||||
if backend is not None:
|
||||
args.extend(['--backend', backend])
|
||||
if model_version is not None:
|
||||
args.extend(['--model-version', str(model_version)])
|
||||
if quantize is not None:
|
||||
args.extend(['--quantize', quantize])
|
||||
if serialisation is not None:
|
||||
args.extend(['--serialisation', serialisation])
|
||||
if additional_args is not None:
|
||||
args.extend(additional_args)
|
||||
return import_command.main(args=args, standalone_mode=False)
|
||||
|
||||
|
||||
def _list_models() -> dict[str, t.Any]:
|
||||
"""List all available models within the local store."""
|
||||
from .entrypoint import models_command
|
||||
|
||||
return models_command.main(args=['--show-available', '--quiet'], standalone_mode=False)
|
||||
|
||||
|
||||
start, start_grpc = codegen.gen_sdk(_start, _serve_grpc=False), codegen.gen_sdk(_start, _serve_grpc=True)
|
||||
build, import_model, list_models = codegen.gen_sdk(_build), codegen.gen_sdk(_import_model), codegen.gen_sdk(_list_models)
|
||||
build, import_model, list_models = (
|
||||
codegen.gen_sdk(_build),
|
||||
codegen.gen_sdk(_import_model),
|
||||
codegen.gen_sdk(_list_models),
|
||||
)
|
||||
__all__ = ['start', 'start_grpc', 'build', 'import_model', 'list_models']
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -10,13 +10,16 @@ from openllm.cli import termui
|
||||
from openllm.cli._factory import container_registry_option
|
||||
from openllm.cli._factory import machine_option
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from openllm_core._typing_compat import LiteralContainerRegistry
|
||||
from openllm_core._typing_compat import LiteralContainerVersionStrategy
|
||||
|
||||
@click.command('build_base_container',
|
||||
context_settings=termui.CONTEXT_SETTINGS,
|
||||
help='''Base image builder for BentoLLM.
|
||||
|
||||
@click.command(
|
||||
'build_base_container',
|
||||
context_settings=termui.CONTEXT_SETTINGS,
|
||||
help="""Base image builder for BentoLLM.
|
||||
|
||||
By default, the base image will include custom kernels (PagedAttention via vllm, FlashAttention-v2, etc.) built with CUDA 11.8, Python 3.9 on Ubuntu22.04.
|
||||
Optionally, this can also be pushed directly to remote registry. Currently support ``docker.io``, ``ghcr.io`` and ``quay.io``.
|
||||
@@ -26,12 +29,24 @@ if t.TYPE_CHECKING:
|
||||
This command is only useful for debugging and for building custom base image for extending BentoML with custom base images and custom kernels.
|
||||
|
||||
Note that we already release images on our CI to ECR and GHCR, so you don't need to build it yourself.
|
||||
''')
|
||||
""",
|
||||
)
|
||||
@container_registry_option
|
||||
@click.option('--version-strategy', type=click.Choice(['release', 'latest', 'nightly']), default='nightly', help='Version strategy to use for tagging the image.')
|
||||
@click.option(
|
||||
'--version-strategy',
|
||||
type=click.Choice(['release', 'latest', 'nightly']),
|
||||
default='nightly',
|
||||
help='Version strategy to use for tagging the image.',
|
||||
)
|
||||
@click.option('--push/--no-push', help='Whether to push to remote repository', is_flag=True, default=False)
|
||||
@machine_option
|
||||
def cli(container_registry: tuple[LiteralContainerRegistry, ...] | None, version_strategy: LiteralContainerVersionStrategy, push: bool, machine: bool) -> dict[str, str]:
|
||||
def cli(
|
||||
container_registry: tuple[LiteralContainerRegistry, ...] | None,
|
||||
version_strategy: LiteralContainerVersionStrategy,
|
||||
push: bool,
|
||||
machine: bool,
|
||||
) -> dict[str, str]:
|
||||
mapping = openllm.bundle.build_container(container_registry, version_strategy, push, machine)
|
||||
if machine: termui.echo(orjson.dumps(mapping, option=orjson.OPT_INDENT_2).decode(), fg='white')
|
||||
if machine:
|
||||
termui.echo(orjson.dumps(mapping, option=orjson.OPT_INDENT_2).decode(), fg='white')
|
||||
return mapping
|
||||
|
||||
@@ -16,24 +16,33 @@ from openllm.cli import termui
|
||||
from openllm.cli._factory import bento_complete_envvar
|
||||
from openllm.cli._factory import machine_option
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from bentoml._internal.bento import BentoStore
|
||||
|
||||
|
||||
@click.command('dive_bentos', context_settings=termui.CONTEXT_SETTINGS)
|
||||
@click.argument('bento', type=str, shell_complete=bento_complete_envvar)
|
||||
@machine_option
|
||||
@click.pass_context
|
||||
@inject
|
||||
def cli(ctx: click.Context, bento: str, machine: bool, _bento_store: BentoStore = Provide[BentoMLContainer.bento_store]) -> str | None:
|
||||
def cli(
|
||||
ctx: click.Context, bento: str, machine: bool, _bento_store: BentoStore = Provide[BentoMLContainer.bento_store]
|
||||
) -> str | None:
|
||||
"""Dive into a BentoLLM. This is synonymous to cd $(b get <bento>:<tag> -o path)."""
|
||||
try:
|
||||
bentomodel = _bento_store.get(bento)
|
||||
except bentoml.exceptions.NotFound:
|
||||
ctx.fail(f'Bento {bento} not found. Make sure to call `openllm build` first.')
|
||||
if 'bundler' not in bentomodel.info.labels or bentomodel.info.labels['bundler'] != 'openllm.bundle':
|
||||
ctx.fail(f"Bento is either too old or not built with OpenLLM. Make sure to use ``openllm build {bentomodel.info.labels['start_name']}`` for correctness.")
|
||||
if machine: return bentomodel.path
|
||||
ctx.fail(
|
||||
f"Bento is either too old or not built with OpenLLM. Make sure to use ``openllm build {bentomodel.info.labels['start_name']}`` for correctness."
|
||||
)
|
||||
if machine:
|
||||
return bentomodel.path
|
||||
# copy and paste this into a new shell
|
||||
if psutil.WINDOWS: subprocess.check_call([shutil.which('dir') or 'dir'], cwd=bentomodel.path)
|
||||
else: subprocess.check_call([shutil.which('ls') or 'ls', '-Rrthla'], cwd=bentomodel.path)
|
||||
if psutil.WINDOWS:
|
||||
subprocess.check_call([shutil.which('dir') or 'dir'], cwd=bentomodel.path)
|
||||
else:
|
||||
subprocess.check_call([shutil.which('ls') or 'ls', '-Rrthla'], cwd=bentomodel.path)
|
||||
ctx.exit(0)
|
||||
|
||||
@@ -16,10 +16,14 @@ from openllm.cli import termui
|
||||
from openllm.cli._factory import bento_complete_envvar
|
||||
from openllm_core.utils import converter
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from bentoml._internal.bento import BentoStore
|
||||
|
||||
@click.command('get_containerfile', context_settings=termui.CONTEXT_SETTINGS, help='Return Containerfile of any given Bento.')
|
||||
|
||||
@click.command(
|
||||
'get_containerfile', context_settings=termui.CONTEXT_SETTINGS, help='Return Containerfile of any given Bento.'
|
||||
)
|
||||
@click.argument('bento', type=str, shell_complete=bento_complete_envvar)
|
||||
@click.pass_context
|
||||
@inject
|
||||
@@ -41,6 +45,13 @@ def cli(ctx: click.Context, bento: str, _bento_store: BentoStore = Provide[Bento
|
||||
# for the reconstruction of the Dockerfile.
|
||||
if 'dockerfile_template' in docker_attrs and docker_attrs['dockerfile_template'] is not None:
|
||||
docker_attrs['dockerfile_template'] = 'env/docker/Dockerfile.template'
|
||||
doc = generate_containerfile(docker=DockerOptions(**docker_attrs), build_ctx=bentomodel.path, conda=options.conda, bento_fs=bentomodel._fs, enable_buildkit=True, add_header=True)
|
||||
doc = generate_containerfile(
|
||||
docker=DockerOptions(**docker_attrs),
|
||||
build_ctx=bentomodel.path,
|
||||
conda=options.conda,
|
||||
bento_fs=bentomodel._fs,
|
||||
enable_buildkit=True,
|
||||
add_header=True,
|
||||
)
|
||||
termui.echo(doc, fg='white')
|
||||
return bentomodel.path
|
||||
|
||||
@@ -16,20 +16,30 @@ from openllm.cli import termui
|
||||
from openllm.cli._factory import model_complete_envvar
|
||||
from openllm_core.prompts import process_prompt
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@click.command('get_prompt', context_settings=termui.CONTEXT_SETTINGS)
|
||||
@click.argument('model_name', type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING.keys()]), shell_complete=model_complete_envvar)
|
||||
@click.argument(
|
||||
'model_name',
|
||||
type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING.keys()]),
|
||||
shell_complete=model_complete_envvar,
|
||||
)
|
||||
@click.argument('prompt', type=click.STRING)
|
||||
@click.option('--format', type=click.STRING, default=None)
|
||||
@click.option('--opt',
|
||||
help="Define additional prompt variables. (format: ``--opt system_prompt='You are a useful assistant'``)",
|
||||
required=False,
|
||||
multiple=True,
|
||||
callback=opt_callback,
|
||||
metavar='ARG=VALUE[,ARG=VALUE]')
|
||||
@click.option(
|
||||
'--opt',
|
||||
help="Define additional prompt variables. (format: ``--opt system_prompt='You are a useful assistant'``)",
|
||||
required=False,
|
||||
multiple=True,
|
||||
callback=opt_callback,
|
||||
metavar='ARG=VALUE[,ARG=VALUE]',
|
||||
)
|
||||
@click.pass_context
|
||||
def cli(ctx: click.Context, /, model_name: str, prompt: str, format: str | None, _memoized: dict[str, t.Any], **_: t.Any) -> str | None:
|
||||
def cli(
|
||||
ctx: click.Context, /, model_name: str, prompt: str, format: str | None, _memoized: dict[str, t.Any], **_: t.Any
|
||||
) -> str | None:
|
||||
"""Get the default prompt used by OpenLLM."""
|
||||
module = getattr(openllm_core.config, f'configuration_{model_name}')
|
||||
_memoized = {k: v[0] for k, v in _memoized.items() if v}
|
||||
@@ -42,11 +52,18 @@ def cli(ctx: click.Context, /, model_name: str, prompt: str, format: str | None,
|
||||
if format is None:
|
||||
if not hasattr(module, 'PROMPT_MAPPING') or module.PROMPT_MAPPING is None:
|
||||
raise RuntimeError('Failed to find prompt mapping while DEFAULT_PROMPT_TEMPLATE is a function.')
|
||||
raise click.BadOptionUsage('format', f"{model_name} prompt requires passing '--format' (available format: {list(module.PROMPT_MAPPING)})")
|
||||
raise click.BadOptionUsage(
|
||||
'format',
|
||||
f"{model_name} prompt requires passing '--format' (available format: {list(module.PROMPT_MAPPING)})",
|
||||
)
|
||||
if prompt_mapping is None:
|
||||
raise click.BadArgumentUsage(f'Failed to fine prompt mapping while the default prompt for {model_name} is a callable.') from None
|
||||
raise click.BadArgumentUsage(
|
||||
f'Failed to fine prompt mapping while the default prompt for {model_name} is a callable.'
|
||||
) from None
|
||||
if format not in prompt_mapping:
|
||||
raise click.BadOptionUsage('format', f'Given format {format} is not valid for {model_name} (available format: {list(prompt_mapping)})')
|
||||
raise click.BadOptionUsage(
|
||||
'format', f'Given format {format} is not valid for {model_name} (available format: {list(prompt_mapping)})'
|
||||
)
|
||||
_prompt_template = template(format)
|
||||
else:
|
||||
_prompt_template = template
|
||||
@@ -55,7 +72,9 @@ def cli(ctx: click.Context, /, model_name: str, prompt: str, format: str | None,
|
||||
fully_formatted = process_prompt(prompt, _prompt_template, True, **_memoized)
|
||||
except RuntimeError as err:
|
||||
logger.debug('Exception caught while formatting prompt: %s', err)
|
||||
fully_formatted = openllm.AutoConfig.for_model(model_name).sanitize_parameters(prompt, prompt_template=_prompt_template)[0]
|
||||
fully_formatted = openllm.AutoConfig.for_model(model_name).sanitize_parameters(
|
||||
prompt, prompt_template=_prompt_template
|
||||
)[0]
|
||||
termui.echo(orjson.dumps({'prompt': fully_formatted}, option=orjson.OPT_INDENT_2).decode(), fg='white')
|
||||
except Exception as err:
|
||||
traceback.print_exc()
|
||||
|
||||
@@ -10,20 +10,25 @@ import openllm
|
||||
from bentoml._internal.utils import human_readable_size
|
||||
from openllm.cli import termui
|
||||
|
||||
|
||||
@click.command('list_bentos', context_settings=termui.CONTEXT_SETTINGS)
|
||||
@click.pass_context
|
||||
def cli(ctx: click.Context) -> None:
|
||||
"""List available bentos built by OpenLLM."""
|
||||
mapping = {
|
||||
k: [{
|
||||
'tag': str(b.tag),
|
||||
'size': human_readable_size(openllm.utils.calc_dir_size(b.path)),
|
||||
'models': [{
|
||||
'tag': str(m.tag),
|
||||
'size': human_readable_size(openllm.utils.calc_dir_size(m.path))
|
||||
} for m in (bentoml.models.get(_.tag) for _ in b.info.models)]
|
||||
} for b in tuple(i for i in bentoml.list() if all(
|
||||
k in i.info.labels for k in {'start_name', 'bundler'})) if b.info.labels['start_name'] == k] for k in tuple(inflection.dasherize(key) for key in openllm.CONFIG_MAPPING.keys())
|
||||
k: [
|
||||
{
|
||||
'tag': str(b.tag),
|
||||
'size': human_readable_size(openllm.utils.calc_dir_size(b.path)),
|
||||
'models': [
|
||||
{'tag': str(m.tag), 'size': human_readable_size(openllm.utils.calc_dir_size(m.path))}
|
||||
for m in (bentoml.models.get(_.tag) for _ in b.info.models)
|
||||
],
|
||||
}
|
||||
for b in tuple(i for i in bentoml.list() if all(k in i.info.labels for k in {'start_name', 'bundler'}))
|
||||
if b.info.labels['start_name'] == k
|
||||
]
|
||||
for k in tuple(inflection.dasherize(key) for key in openllm.CONFIG_MAPPING.keys())
|
||||
}
|
||||
mapping = {k: v for k, v in mapping.items() if v}
|
||||
termui.echo(orjson.dumps(mapping, option=orjson.OPT_INDENT_2).decode(), fg='white')
|
||||
|
||||
@@ -13,21 +13,40 @@ from openllm.cli import termui
|
||||
from openllm.cli._factory import model_complete_envvar
|
||||
from openllm.cli._factory import model_name_argument
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from openllm_core._typing_compat import DictStrAny
|
||||
|
||||
|
||||
@click.command('list_models', context_settings=termui.CONTEXT_SETTINGS)
|
||||
@model_name_argument(required=False, shell_complete=model_complete_envvar)
|
||||
def cli(model_name: str | None) -> DictStrAny:
|
||||
"""This is equivalent to openllm models --show-available less the nice table."""
|
||||
models = tuple(inflection.dasherize(key) for key in openllm.CONFIG_MAPPING.keys())
|
||||
ids_in_local_store = {
|
||||
k: [i for i in bentoml.models.list() if 'framework' in i.info.labels and i.info.labels['framework'] == 'openllm' and 'model_name' in i.info.labels and i.info.labels['model_name'] == k]
|
||||
for k in models
|
||||
k: [
|
||||
i
|
||||
for i in bentoml.models.list()
|
||||
if 'framework' in i.info.labels
|
||||
and i.info.labels['framework'] == 'openllm'
|
||||
and 'model_name' in i.info.labels
|
||||
and i.info.labels['model_name'] == k
|
||||
]
|
||||
for k in models
|
||||
}
|
||||
if model_name is not None:
|
||||
ids_in_local_store = {k: [i for i in v if 'model_name' in i.info.labels and i.info.labels['model_name'] == inflection.dasherize(model_name)] for k, v in ids_in_local_store.items()}
|
||||
ids_in_local_store = {
|
||||
k: [
|
||||
i
|
||||
for i in v
|
||||
if 'model_name' in i.info.labels and i.info.labels['model_name'] == inflection.dasherize(model_name)
|
||||
]
|
||||
for k, v in ids_in_local_store.items()
|
||||
}
|
||||
ids_in_local_store = {k: v for k, v in ids_in_local_store.items() if v}
|
||||
local_models = {k: [{'tag': str(i.tag), 'size': human_readable_size(openllm.utils.calc_dir_size(i.path))} for i in val] for k, val in ids_in_local_store.items()}
|
||||
local_models = {
|
||||
k: [{'tag': str(i.tag), 'size': human_readable_size(openllm.utils.calc_dir_size(i.path))} for i in val]
|
||||
for k, val in ids_in_local_store.items()
|
||||
}
|
||||
termui.echo(orjson.dumps(local_models, option=orjson.OPT_INDENT_2).decode(), fg='white')
|
||||
return local_models
|
||||
|
||||
@@ -19,11 +19,13 @@ from openllm_core.utils import is_jupyter_available
|
||||
from openllm_core.utils import is_jupytext_available
|
||||
from openllm_core.utils import is_notebook_available
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from openllm_core._typing_compat import DictStrAny
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_notebook_metadata() -> DictStrAny:
|
||||
with open(os.path.join(os.path.dirname(playground.__file__), '_meta.yml'), 'r') as f:
|
||||
content = yaml.safe_load(f)
|
||||
@@ -31,9 +33,17 @@ def load_notebook_metadata() -> DictStrAny:
|
||||
raise ValueError("Invalid metadata file. All entries must have a 'description' key.")
|
||||
return content
|
||||
|
||||
|
||||
@click.command('playground', context_settings=termui.CONTEXT_SETTINGS)
|
||||
@click.argument('output-dir', default=None, required=False)
|
||||
@click.option('--port', envvar='JUPYTER_PORT', show_envvar=True, show_default=True, default=8888, help='Default port for Jupyter server')
|
||||
@click.option(
|
||||
'--port',
|
||||
envvar='JUPYTER_PORT',
|
||||
show_envvar=True,
|
||||
show_default=True,
|
||||
default=8888,
|
||||
help='Default port for Jupyter server',
|
||||
)
|
||||
@click.pass_context
|
||||
def cli(ctx: click.Context, output_dir: str | None, port: int) -> None:
|
||||
"""OpenLLM Playground.
|
||||
@@ -54,7 +64,9 @@ def cli(ctx: click.Context, output_dir: str | None, port: int) -> None:
|
||||
> This command requires Jupyter to be installed. Install it with 'pip install "openllm[playground]"'
|
||||
"""
|
||||
if not is_jupyter_available() or not is_jupytext_available() or not is_notebook_available():
|
||||
raise RuntimeError("Playground requires 'jupyter', 'jupytext', and 'notebook'. Install it with 'pip install \"openllm[playground]\"'")
|
||||
raise RuntimeError(
|
||||
"Playground requires 'jupyter', 'jupytext', and 'notebook'. Install it with 'pip install \"openllm[playground]\"'"
|
||||
)
|
||||
metadata = load_notebook_metadata()
|
||||
_temp_dir = False
|
||||
if output_dir is None:
|
||||
@@ -66,20 +78,37 @@ def cli(ctx: click.Context, output_dir: str | None, port: int) -> None:
|
||||
termui.echo('The playground notebooks will be saved to: ' + os.path.abspath(output_dir), fg='blue')
|
||||
for module in pkgutil.iter_modules(playground.__path__):
|
||||
if module.ispkg or os.path.exists(os.path.join(output_dir, module.name + '.ipynb')):
|
||||
logger.debug('Skipping: %s (%s)', module.name, 'File already exists' if not module.ispkg else f'{module.name} is a module')
|
||||
logger.debug(
|
||||
'Skipping: %s (%s)', module.name, 'File already exists' if not module.ispkg else f'{module.name} is a module'
|
||||
)
|
||||
continue
|
||||
if not isinstance(module.module_finder, importlib.machinery.FileFinder):
|
||||
continue
|
||||
if not isinstance(module.module_finder, importlib.machinery.FileFinder): continue
|
||||
termui.echo('Generating notebook for: ' + module.name, fg='magenta')
|
||||
markdown_cell = nbformat.v4.new_markdown_cell(metadata[module.name]['description'])
|
||||
f = jupytext.read(os.path.join(module.module_finder.path, module.name + '.py'))
|
||||
f.cells.insert(0, markdown_cell)
|
||||
jupytext.write(f, os.path.join(output_dir, module.name + '.ipynb'), fmt='notebook')
|
||||
try:
|
||||
subprocess.check_output([sys.executable, '-m', 'jupyter', 'notebook', '--notebook-dir', output_dir, '--port', str(port), '--no-browser', '--debug'])
|
||||
subprocess.check_output(
|
||||
[
|
||||
sys.executable,
|
||||
'-m',
|
||||
'jupyter',
|
||||
'notebook',
|
||||
'--notebook-dir',
|
||||
output_dir,
|
||||
'--port',
|
||||
str(port),
|
||||
'--no-browser',
|
||||
'--debug',
|
||||
]
|
||||
)
|
||||
except subprocess.CalledProcessError as e:
|
||||
termui.echo(e.output, fg='red')
|
||||
raise click.ClickException(f'Failed to start a jupyter server:\n{e}') from None
|
||||
except KeyboardInterrupt:
|
||||
termui.echo('\nShutting down Jupyter server...', fg='yellow')
|
||||
if _temp_dir: termui.echo('Note: You can access the generated notebooks in: ' + output_dir, fg='blue')
|
||||
if _temp_dir:
|
||||
termui.echo('Note: You can access the generated notebooks in: ' + output_dir, fg='blue')
|
||||
ctx.exit(0)
|
||||
|
||||
@@ -13,8 +13,10 @@ from openllm_core._typing_compat import DictStrAny
|
||||
from openllm_core.utils import get_debug_mode
|
||||
from openllm_core.utils import get_quiet_mode
|
||||
|
||||
|
||||
logger = logging.getLogger('openllm')
|
||||
|
||||
|
||||
class Level(enum.IntEnum):
|
||||
NOTSET = logging.DEBUG
|
||||
DEBUG = logging.DEBUG
|
||||
@@ -25,19 +27,31 @@ class Level(enum.IntEnum):
|
||||
|
||||
@property
|
||||
def color(self) -> str | None:
|
||||
return {Level.NOTSET: None, Level.DEBUG: 'cyan', Level.INFO: 'green', Level.WARNING: 'yellow', Level.ERROR: 'red', Level.CRITICAL: 'red'}[self]
|
||||
return {
|
||||
Level.NOTSET: None,
|
||||
Level.DEBUG: 'cyan',
|
||||
Level.INFO: 'green',
|
||||
Level.WARNING: 'yellow',
|
||||
Level.ERROR: 'red',
|
||||
Level.CRITICAL: 'red',
|
||||
}[self]
|
||||
|
||||
|
||||
class JsonLog(t.TypedDict):
|
||||
log_level: Level
|
||||
content: str
|
||||
|
||||
|
||||
def log(content: str, level: Level = Level.INFO, fg: str | None = None) -> None:
|
||||
def caller(text: str) -> None:
|
||||
if get_debug_mode(): logger.log(level.value, text)
|
||||
else: echo(JsonLog(log_level=level, content=content), json=True, fg=fg)
|
||||
if get_debug_mode():
|
||||
logger.log(level.value, text)
|
||||
else:
|
||||
echo(JsonLog(log_level=level, content=content), json=True, fg=fg)
|
||||
|
||||
caller(orjson.dumps(JsonLog(log_level=level, content=content)).decode())
|
||||
|
||||
|
||||
warning = functools.partial(log, level=Level.WARNING)
|
||||
error = functools.partial(log, level=Level.ERROR)
|
||||
critical = functools.partial(log, level=Level.CRITICAL)
|
||||
@@ -45,8 +59,10 @@ debug = functools.partial(log, level=Level.DEBUG)
|
||||
info = functools.partial(log, level=Level.INFO)
|
||||
notset = functools.partial(log, level=Level.NOTSET)
|
||||
|
||||
|
||||
def echo(text: t.Any, fg: str | None = None, _with_style: bool = True, json: bool = False, **attrs: t.Any) -> None:
|
||||
if json and not isinstance(text, dict): raise TypeError('text must be a dict')
|
||||
if json and not isinstance(text, dict):
|
||||
raise TypeError('text must be a dict')
|
||||
if json:
|
||||
if 'content' in text and 'log_level' in text:
|
||||
content = t.cast(DictStrAny, text)['content']
|
||||
@@ -58,8 +74,14 @@ def echo(text: t.Any, fg: str | None = None, _with_style: bool = True, json: boo
|
||||
content = t.cast(str, text)
|
||||
attrs['fg'] = fg if not get_debug_mode() else None
|
||||
|
||||
if not get_quiet_mode(): t.cast(t.Callable[..., None], click.echo if not _with_style else click.secho)(content, **attrs)
|
||||
if not get_quiet_mode():
|
||||
t.cast(t.Callable[..., None], click.echo if not _with_style else click.secho)(content, **attrs)
|
||||
|
||||
|
||||
COLUMNS: int = int(os.environ.get('COLUMNS', str(120)))
|
||||
CONTEXT_SETTINGS: DictStrAny = {'help_option_names': ['-h', '--help'], 'max_content_width': COLUMNS, 'token_normalize_func': inflection.underscore}
|
||||
CONTEXT_SETTINGS: DictStrAny = {
|
||||
'help_option_names': ['-h', '--help'],
|
||||
'max_content_width': COLUMNS,
|
||||
'token_normalize_func': inflection.underscore,
|
||||
}
|
||||
__all__ = ['echo', 'COLUMNS', 'CONTEXT_SETTINGS', 'log', 'warning', 'error', 'critical', 'debug', 'info', 'Level']
|
||||
|
||||
@@ -1,23 +1,27 @@
|
||||
'''OpenLLM Python client.
|
||||
"""OpenLLM Python client.
|
||||
|
||||
```python
|
||||
client = openllm.client.HTTPClient("http://localhost:8080")
|
||||
client.query("What is the difference between gather and scatter?")
|
||||
```
|
||||
'''
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
|
||||
import openllm_client
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from openllm_client import AsyncHTTPClient as AsyncHTTPClient
|
||||
from openllm_client import HTTPClient as HTTPClient
|
||||
# from openllm_client import AsyncGrpcClient as AsyncGrpcClient
|
||||
# from openllm_client import GrpcClient as GrpcClient
|
||||
|
||||
|
||||
def __dir__() -> t.Sequence[str]:
|
||||
return sorted(dir(openllm_client))
|
||||
|
||||
|
||||
def __getattr__(it: str) -> t.Any:
|
||||
return getattr(openllm_client, it)
|
||||
|
||||
@@ -6,6 +6,7 @@ Each module should implement the following API:
|
||||
|
||||
- `mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Service: ...`
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
|
||||
@@ -14,16 +15,21 @@ from openllm_core.utils import LazyModule
|
||||
from . import hf as hf
|
||||
from . import openai as openai
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import bentoml
|
||||
import openllm
|
||||
|
||||
_import_structure: dict[str, list[str]] = {'openai': [], 'hf': []}
|
||||
|
||||
|
||||
def mount_entrypoints(svc: bentoml.Service, llm: openllm.LLM[t.Any, t.Any]) -> bentoml.Service:
|
||||
return openai.mount_to_svc(hf.mount_to_svc(svc, llm), llm)
|
||||
|
||||
__lazy = LazyModule(__name__, globals()['__file__'], _import_structure, extra_objects={'mount_entrypoints': mount_entrypoints})
|
||||
|
||||
__lazy = LazyModule(
|
||||
__name__, globals()['__file__'], _import_structure, extra_objects={'mount_entrypoints': mount_entrypoints}
|
||||
)
|
||||
__all__ = __lazy.__all__
|
||||
__dir__ = __lazy.__dir__
|
||||
__getattr__ = __lazy.__getattr__
|
||||
|
||||
@@ -15,6 +15,7 @@ from starlette.schemas import SchemaGenerator
|
||||
from openllm_core._typing_compat import ParamSpec
|
||||
from openllm_core.utils import first_not_none
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from attr import AttrsInstance
|
||||
|
||||
@@ -23,7 +24,7 @@ if t.TYPE_CHECKING:
|
||||
P = ParamSpec('P')
|
||||
OPENAPI_VERSION, API_VERSION = '3.0.2', '1.0'
|
||||
# NOTE: OpenAI schema
|
||||
LIST_MODEL_SCHEMA = '''\
|
||||
LIST_MODEL_SCHEMA = """\
|
||||
---
|
||||
consumes:
|
||||
- application/json
|
||||
@@ -53,8 +54,8 @@ responses:
|
||||
owned_by: 'na'
|
||||
schema:
|
||||
$ref: '#/components/schemas/ModelList'
|
||||
'''
|
||||
CHAT_COMPLETION_SCHEMA = '''\
|
||||
"""
|
||||
CHAT_COMPLETION_SCHEMA = """\
|
||||
---
|
||||
consumes:
|
||||
- application/json
|
||||
@@ -191,8 +192,8 @@ responses:
|
||||
}
|
||||
}
|
||||
description: Bad Request
|
||||
'''
|
||||
COMPLETION_SCHEMA = '''\
|
||||
"""
|
||||
COMPLETION_SCHEMA = """\
|
||||
---
|
||||
consumes:
|
||||
- application/json
|
||||
@@ -344,8 +345,8 @@ responses:
|
||||
}
|
||||
}
|
||||
description: Bad Request
|
||||
'''
|
||||
HF_AGENT_SCHEMA = '''\
|
||||
"""
|
||||
HF_AGENT_SCHEMA = """\
|
||||
---
|
||||
consumes:
|
||||
- application/json
|
||||
@@ -389,8 +390,8 @@ responses:
|
||||
schema:
|
||||
$ref: '#/components/schemas/HFErrorResponse'
|
||||
description: Not Found
|
||||
'''
|
||||
HF_ADAPTERS_SCHEMA = '''\
|
||||
"""
|
||||
HF_ADAPTERS_SCHEMA = """\
|
||||
---
|
||||
consumes:
|
||||
- application/json
|
||||
@@ -420,16 +421,19 @@ responses:
|
||||
schema:
|
||||
$ref: '#/components/schemas/HFErrorResponse'
|
||||
description: Not Found
|
||||
'''
|
||||
"""
|
||||
|
||||
|
||||
def add_schema_definitions(append_str: str) -> t.Callable[[t.Callable[P, t.Any]], t.Callable[P, t.Any]]:
|
||||
def docstring_decorator(func: t.Callable[P, t.Any]) -> t.Callable[P, t.Any]:
|
||||
if func.__doc__ is None: func.__doc__ = ''
|
||||
if func.__doc__ is None:
|
||||
func.__doc__ = ''
|
||||
func.__doc__ = func.__doc__.strip() + '\n\n' + append_str.strip()
|
||||
return func
|
||||
|
||||
return docstring_decorator
|
||||
|
||||
|
||||
class OpenLLMSchemaGenerator(SchemaGenerator):
|
||||
def get_endpoints(self, routes: list[BaseRoute]) -> list[EndpointInfo]:
|
||||
endpoints_info: list[EndpointInfo] = []
|
||||
@@ -437,20 +441,29 @@ class OpenLLMSchemaGenerator(SchemaGenerator):
|
||||
if isinstance(route, (Mount, Host)):
|
||||
routes = route.routes or []
|
||||
path = self._remove_converter(route.path) if isinstance(route, Mount) else ''
|
||||
sub_endpoints = [EndpointInfo(path=f'{path}{sub_endpoint.path}', http_method=sub_endpoint.http_method, func=sub_endpoint.func) for sub_endpoint in self.get_endpoints(routes)]
|
||||
sub_endpoints = [
|
||||
EndpointInfo(path=f'{path}{sub_endpoint.path}', http_method=sub_endpoint.http_method, func=sub_endpoint.func)
|
||||
for sub_endpoint in self.get_endpoints(routes)
|
||||
]
|
||||
endpoints_info.extend(sub_endpoints)
|
||||
elif not isinstance(route, Route) or not route.include_in_schema:
|
||||
continue
|
||||
elif inspect.isfunction(route.endpoint) or inspect.ismethod(route.endpoint) or isinstance(route.endpoint, functools.partial):
|
||||
elif (
|
||||
inspect.isfunction(route.endpoint)
|
||||
or inspect.ismethod(route.endpoint)
|
||||
or isinstance(route.endpoint, functools.partial)
|
||||
):
|
||||
endpoint = route.endpoint.func if isinstance(route.endpoint, functools.partial) else route.endpoint
|
||||
path = self._remove_converter(route.path)
|
||||
for method in route.methods or ['GET']:
|
||||
if method == 'HEAD': continue
|
||||
if method == 'HEAD':
|
||||
continue
|
||||
endpoints_info.append(EndpointInfo(path, method.lower(), endpoint))
|
||||
else:
|
||||
path = self._remove_converter(route.path)
|
||||
for method in ['get', 'post', 'put', 'patch', 'delete', 'options']:
|
||||
if not hasattr(route.endpoint, method): continue
|
||||
if not hasattr(route.endpoint, method):
|
||||
continue
|
||||
func = getattr(route.endpoint, method)
|
||||
endpoints_info.append(EndpointInfo(path, method.lower(), func))
|
||||
return endpoints_info
|
||||
@@ -459,37 +472,52 @@ class OpenLLMSchemaGenerator(SchemaGenerator):
|
||||
schema = dict(self.base_schema)
|
||||
schema.setdefault('paths', {})
|
||||
endpoints_info = self.get_endpoints(routes)
|
||||
if mount_path: mount_path = f'/{mount_path}' if not mount_path.startswith('/') else mount_path
|
||||
if mount_path:
|
||||
mount_path = f'/{mount_path}' if not mount_path.startswith('/') else mount_path
|
||||
|
||||
for endpoint in endpoints_info:
|
||||
parsed = self.parse_docstring(endpoint.func)
|
||||
if not parsed: continue
|
||||
if not parsed:
|
||||
continue
|
||||
|
||||
path = endpoint.path if mount_path is None else mount_path + endpoint.path
|
||||
if path not in schema['paths']: schema['paths'][path] = {}
|
||||
if path not in schema['paths']:
|
||||
schema['paths'][path] = {}
|
||||
schema['paths'][path][endpoint.http_method] = parsed
|
||||
|
||||
return schema
|
||||
|
||||
def get_generator(title: str, components: list[type[AttrsInstance]] | None = None, tags: list[dict[str, t.Any]] | None = None) -> OpenLLMSchemaGenerator:
|
||||
|
||||
def get_generator(
|
||||
title: str, components: list[type[AttrsInstance]] | None = None, tags: list[dict[str, t.Any]] | None = None
|
||||
) -> OpenLLMSchemaGenerator:
|
||||
base_schema: dict[str, t.Any] = dict(info={'title': title, 'version': API_VERSION}, version=OPENAPI_VERSION)
|
||||
if components: base_schema['components'] = {'schemas': {c.__name__: component_schema_generator(c) for c in components}}
|
||||
if tags is not None and tags: base_schema['tags'] = tags
|
||||
if components:
|
||||
base_schema['components'] = {'schemas': {c.__name__: component_schema_generator(c) for c in components}}
|
||||
if tags is not None and tags:
|
||||
base_schema['tags'] = tags
|
||||
return OpenLLMSchemaGenerator(base_schema)
|
||||
|
||||
|
||||
def component_schema_generator(attr_cls: type[AttrsInstance], description: str | None = None) -> dict[str, t.Any]:
|
||||
schema: dict[str, t.Any] = {'type': 'object', 'required': [], 'properties': {}, 'title': attr_cls.__name__}
|
||||
schema['description'] = first_not_none(getattr(attr_cls, '__doc__', None), description, default=f'Generated components for {attr_cls.__name__}')
|
||||
schema['description'] = first_not_none(
|
||||
getattr(attr_cls, '__doc__', None), description, default=f'Generated components for {attr_cls.__name__}'
|
||||
)
|
||||
for field in attr.fields(attr.resolve_types(attr_cls)): # type: ignore[misc,type-var]
|
||||
attr_type = field.type
|
||||
origin_type = t.get_origin(attr_type)
|
||||
args_type = t.get_args(attr_type)
|
||||
|
||||
# Map Python types to OpenAPI schema types
|
||||
if attr_type == str: schema_type = 'string'
|
||||
elif attr_type == int: schema_type = 'integer'
|
||||
elif attr_type == float: schema_type = 'number'
|
||||
elif attr_type == bool: schema_type = 'boolean'
|
||||
if attr_type == str:
|
||||
schema_type = 'string'
|
||||
elif attr_type == int:
|
||||
schema_type = 'integer'
|
||||
elif attr_type == float:
|
||||
schema_type = 'number'
|
||||
elif attr_type == bool:
|
||||
schema_type = 'boolean'
|
||||
elif origin_type is list or origin_type is tuple:
|
||||
schema_type = 'array'
|
||||
elif origin_type is dict:
|
||||
@@ -504,14 +532,18 @@ def component_schema_generator(attr_cls: type[AttrsInstance], description: str |
|
||||
else:
|
||||
schema_type = 'string'
|
||||
|
||||
if 'prop_schema' not in locals(): prop_schema = {'type': schema_type}
|
||||
if field.default is not attr.NOTHING and not isinstance(field.default, attr.Factory): prop_schema['default'] = field.default # type: ignore[arg-type]
|
||||
if field.default is attr.NOTHING and not isinstance(attr_type, type(t.Optional)): schema['required'].append(field.name)
|
||||
if 'prop_schema' not in locals():
|
||||
prop_schema = {'type': schema_type}
|
||||
if field.default is not attr.NOTHING and not isinstance(field.default, attr.Factory):
|
||||
prop_schema['default'] = field.default # type: ignore[arg-type]
|
||||
if field.default is attr.NOTHING and not isinstance(attr_type, type(t.Optional)):
|
||||
schema['required'].append(field.name)
|
||||
schema['properties'][field.name] = prop_schema
|
||||
locals().pop('prop_schema', None)
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
class MKSchema:
|
||||
def __init__(self, it: dict[str, t.Any]) -> None:
|
||||
self.it = it
|
||||
@@ -519,19 +551,30 @@ class MKSchema:
|
||||
def asdict(self) -> dict[str, t.Any]:
|
||||
return self.it
|
||||
|
||||
def append_schemas(svc: bentoml.Service, generated_schema: dict[str, t.Any], tags_order: t.Literal['prepend', 'append'] = 'prepend') -> bentoml.Service:
|
||||
|
||||
def append_schemas(
|
||||
svc: bentoml.Service, generated_schema: dict[str, t.Any], tags_order: t.Literal['prepend', 'append'] = 'prepend'
|
||||
) -> bentoml.Service:
|
||||
# HACK: Dirty hack to append schemas to existing service. We def need to support mounting Starlette app OpenAPI spec.
|
||||
from bentoml._internal.service.openapi.specification import OpenAPISpecification
|
||||
|
||||
svc_schema: t.Any = svc.openapi_spec
|
||||
if isinstance(svc_schema, (OpenAPISpecification, MKSchema)): svc_schema = svc_schema.asdict()
|
||||
if isinstance(svc_schema, (OpenAPISpecification, MKSchema)):
|
||||
svc_schema = svc_schema.asdict()
|
||||
if 'tags' in generated_schema:
|
||||
if tags_order == 'prepend': svc_schema['tags'] = generated_schema['tags'] + svc_schema['tags']
|
||||
elif tags_order == 'append': svc_schema['tags'].extend(generated_schema['tags'])
|
||||
else: raise ValueError(f'Invalid tags_order: {tags_order}')
|
||||
if 'components' in generated_schema: svc_schema['components']['schemas'].update(generated_schema['components']['schemas'])
|
||||
if tags_order == 'prepend':
|
||||
svc_schema['tags'] = generated_schema['tags'] + svc_schema['tags']
|
||||
elif tags_order == 'append':
|
||||
svc_schema['tags'].extend(generated_schema['tags'])
|
||||
else:
|
||||
raise ValueError(f'Invalid tags_order: {tags_order}')
|
||||
if 'components' in generated_schema:
|
||||
svc_schema['components']['schemas'].update(generated_schema['components']['schemas'])
|
||||
svc_schema['paths'].update(generated_schema['paths'])
|
||||
|
||||
from bentoml._internal.service import openapi # HACK: mk this attribute until we have a better way to add starlette schemas.
|
||||
from bentoml._internal.service import (
|
||||
openapi, # HACK: mk this attribute until we have a better way to add starlette schemas.
|
||||
)
|
||||
|
||||
# yapf: disable
|
||||
def mk_generate_spec(svc:bentoml.Service,openapi_version:str=OPENAPI_VERSION)->MKSchema:return MKSchema(svc_schema)
|
||||
|
||||
@@ -23,17 +23,21 @@ from ..protocol.hf import AgentRequest
|
||||
from ..protocol.hf import AgentResponse
|
||||
from ..protocol.hf import HFErrorResponse
|
||||
|
||||
schemas = get_generator('hf',
|
||||
components=[AgentRequest, AgentResponse, HFErrorResponse],
|
||||
tags=[{
|
||||
'name': 'HF',
|
||||
'description': 'HF integration, including Agent and others schema endpoints.',
|
||||
'externalDocs': 'https://huggingface.co/docs/transformers/main_classes/agent'
|
||||
}])
|
||||
|
||||
schemas = get_generator(
|
||||
'hf',
|
||||
components=[AgentRequest, AgentResponse, HFErrorResponse],
|
||||
tags=[
|
||||
{
|
||||
'name': 'HF',
|
||||
'description': 'HF integration, including Agent and others schema endpoints.',
|
||||
'externalDocs': 'https://huggingface.co/docs/transformers/main_classes/agent',
|
||||
}
|
||||
],
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
|
||||
from peft.config import PeftConfig
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
@@ -44,20 +48,28 @@ if t.TYPE_CHECKING:
|
||||
from openllm_core._typing_compat import M
|
||||
from openllm_core._typing_compat import T
|
||||
|
||||
|
||||
def mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Service:
|
||||
app = Starlette(debug=True,
|
||||
routes=[
|
||||
Route('/agent', endpoint=functools.partial(hf_agent, llm=llm), name='hf_agent', methods=['POST']),
|
||||
Route('/adapters', endpoint=functools.partial(adapters_map, llm=llm), name='adapters', methods=['GET']),
|
||||
Route('/schema', endpoint=openapi_schema, include_in_schema=False)
|
||||
])
|
||||
app = Starlette(
|
||||
debug=True,
|
||||
routes=[
|
||||
Route('/agent', endpoint=functools.partial(hf_agent, llm=llm), name='hf_agent', methods=['POST']),
|
||||
Route('/adapters', endpoint=functools.partial(adapters_map, llm=llm), name='adapters', methods=['GET']),
|
||||
Route('/schema', endpoint=openapi_schema, include_in_schema=False),
|
||||
],
|
||||
)
|
||||
mount_path = '/hf'
|
||||
generated_schema = schemas.get_schema(routes=app.routes, mount_path=mount_path)
|
||||
svc.mount_asgi_app(app, path=mount_path)
|
||||
return append_schemas(svc, generated_schema, tags_order='append')
|
||||
|
||||
|
||||
def error_response(status_code: HTTPStatus, message: str) -> JSONResponse:
|
||||
return JSONResponse(converter.unstructure(HFErrorResponse(message=message, error_code=status_code.value)), status_code=status_code.value)
|
||||
return JSONResponse(
|
||||
converter.unstructure(HFErrorResponse(message=message, error_code=status_code.value)),
|
||||
status_code=status_code.value,
|
||||
)
|
||||
|
||||
|
||||
@add_schema_definitions(HF_AGENT_SCHEMA)
|
||||
async def hf_agent(req: Request, llm: openllm.LLM[M, T]) -> Response:
|
||||
@@ -72,22 +84,26 @@ async def hf_agent(req: Request, llm: openllm.LLM[M, T]) -> Response:
|
||||
stop = request.parameters.pop('stop', ['\n'])
|
||||
try:
|
||||
result = await llm.generate(request.inputs, stop=stop, **request.parameters)
|
||||
return JSONResponse(converter.unstructure([AgentResponse(generated_text=result.outputs[0].text)]), status_code=HTTPStatus.OK.value)
|
||||
return JSONResponse(
|
||||
converter.unstructure([AgentResponse(generated_text=result.outputs[0].text)]), status_code=HTTPStatus.OK.value
|
||||
)
|
||||
except Exception as err:
|
||||
logger.error('Error while generating: %s', err)
|
||||
return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, 'Error while generating (Check server log).')
|
||||
|
||||
|
||||
@add_schema_definitions(HF_ADAPTERS_SCHEMA)
|
||||
def adapters_map(req: Request, llm: openllm.LLM[M, T]) -> Response:
|
||||
if not llm.has_adapters: return error_response(HTTPStatus.NOT_FOUND, 'No adapters found.')
|
||||
if not llm.has_adapters:
|
||||
return error_response(HTTPStatus.NOT_FOUND, 'No adapters found.')
|
||||
return JSONResponse(
|
||||
{
|
||||
adapter_tuple[1]: {
|
||||
'adapter_name': k,
|
||||
'adapter_type': t.cast(Enum, adapter_tuple[0].peft_type).value
|
||||
} for k, adapter_tuple in t.cast(t.Dict[str, t.Tuple['PeftConfig', str]], dict(*llm.adapter_map.values())).items()
|
||||
},
|
||||
status_code=HTTPStatus.OK.value)
|
||||
{
|
||||
adapter_tuple[1]: {'adapter_name': k, 'adapter_type': t.cast(Enum, adapter_tuple[0].peft_type).value}
|
||||
for k, adapter_tuple in t.cast(t.Dict[str, t.Tuple['PeftConfig', str]], dict(*llm.adapter_map.values())).items()
|
||||
},
|
||||
status_code=HTTPStatus.OK.value,
|
||||
)
|
||||
|
||||
|
||||
def openapi_schema(req: Request) -> Response:
|
||||
return schemas.OpenAPIResponse(req)
|
||||
|
||||
@@ -42,14 +42,27 @@ from ..protocol.openai import ModelCard
|
||||
from ..protocol.openai import ModelList
|
||||
from ..protocol.openai import UsageInfo
|
||||
|
||||
|
||||
schemas = get_generator(
|
||||
'openai',
|
||||
components=[ErrorResponse, ModelList, ChatCompletionResponse, ChatCompletionRequest, ChatCompletionStreamResponse, CompletionRequest, CompletionResponse, CompletionStreamResponse],
|
||||
tags=[{
|
||||
'name': 'OpenAI',
|
||||
'description': 'OpenAI Compatible API support',
|
||||
'externalDocs': 'https://platform.openai.com/docs/api-reference/completions/object'
|
||||
}])
|
||||
'openai',
|
||||
components=[
|
||||
ErrorResponse,
|
||||
ModelList,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionStreamResponse,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
CompletionStreamResponse,
|
||||
],
|
||||
tags=[
|
||||
{
|
||||
'name': 'OpenAI',
|
||||
'description': 'OpenAI Compatible API support',
|
||||
'externalDocs': 'https://platform.openai.com/docs/api-reference/completions/object',
|
||||
}
|
||||
],
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
@@ -64,20 +77,34 @@ if t.TYPE_CHECKING:
|
||||
from openllm_core._typing_compat import M
|
||||
from openllm_core._typing_compat import T
|
||||
|
||||
|
||||
def jsonify_attr(obj: AttrsInstance) -> str:
|
||||
return orjson.dumps(converter.unstructure(obj)).decode()
|
||||
|
||||
def error_response(status_code: HTTPStatus, message: str) -> JSONResponse:
|
||||
return JSONResponse({'error': converter.unstructure(ErrorResponse(message=message, type='invalid_request_error', code=str(status_code.value)))}, status_code=status_code.value)
|
||||
|
||||
async def check_model(request: CompletionRequest | ChatCompletionRequest, model: str) -> JSONResponse | None:
|
||||
if request.model == model: return None
|
||||
return error_response(
|
||||
HTTPStatus.NOT_FOUND,
|
||||
f"Model '{request.model}' does not exists. Try 'GET /v1/models' to see available models.\nTip: If you are migrating from OpenAI, make sure to update your 'model' parameters in the request."
|
||||
def error_response(status_code: HTTPStatus, message: str) -> JSONResponse:
|
||||
return JSONResponse(
|
||||
{
|
||||
'error': converter.unstructure(
|
||||
ErrorResponse(message=message, type='invalid_request_error', code=str(status_code.value))
|
||||
)
|
||||
},
|
||||
status_code=status_code.value,
|
||||
)
|
||||
|
||||
def create_logprobs(token_ids: list[int], id_logprobs: list[dict[int, float]], initial_text_offset: int = 0, *, llm: openllm.LLM[M, T]) -> LogProbs:
|
||||
|
||||
async def check_model(request: CompletionRequest | ChatCompletionRequest, model: str) -> JSONResponse | None:
|
||||
if request.model == model:
|
||||
return None
|
||||
return error_response(
|
||||
HTTPStatus.NOT_FOUND,
|
||||
f"Model '{request.model}' does not exists. Try 'GET /v1/models' to see available models.\nTip: If you are migrating from OpenAI, make sure to update your 'model' parameters in the request.",
|
||||
)
|
||||
|
||||
|
||||
def create_logprobs(
|
||||
token_ids: list[int], id_logprobs: list[dict[int, float]], initial_text_offset: int = 0, *, llm: openllm.LLM[M, T]
|
||||
) -> LogProbs:
|
||||
# Create OpenAI-style logprobs.
|
||||
logprobs = LogProbs()
|
||||
last_token_len = 0
|
||||
@@ -94,22 +121,29 @@ def create_logprobs(token_ids: list[int], id_logprobs: list[dict[int, float]], i
|
||||
logprobs.top_logprobs.append({llm.tokenizer.convert_ids_to_tokens(i): p for i, p in id_logprob.items()})
|
||||
return logprobs
|
||||
|
||||
|
||||
def mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Service:
|
||||
app = Starlette(debug=True,
|
||||
routes=[
|
||||
Route('/models', functools.partial(list_models, llm=llm), methods=['GET']),
|
||||
Route('/completions', functools.partial(create_completions, llm=llm), methods=['POST']),
|
||||
Route('/chat/completions', functools.partial(create_chat_completions, llm=llm), methods=['POST'])
|
||||
])
|
||||
app = Starlette(
|
||||
debug=True,
|
||||
routes=[
|
||||
Route('/models', functools.partial(list_models, llm=llm), methods=['GET']),
|
||||
Route('/completions', functools.partial(create_completions, llm=llm), methods=['POST']),
|
||||
Route('/chat/completions', functools.partial(create_chat_completions, llm=llm), methods=['POST']),
|
||||
],
|
||||
)
|
||||
mount_path = '/v1'
|
||||
generated_schema = schemas.get_schema(routes=app.routes, mount_path=mount_path)
|
||||
svc.mount_asgi_app(app, path=mount_path)
|
||||
return append_schemas(svc, generated_schema)
|
||||
|
||||
|
||||
# GET /v1/models
|
||||
@add_schema_definitions(LIST_MODEL_SCHEMA)
|
||||
def list_models(_: Request, llm: openllm.LLM[M, T]) -> Response:
|
||||
return JSONResponse(converter.unstructure(ModelList(data=[ModelCard(id=llm.llm_type)])), status_code=HTTPStatus.OK.value)
|
||||
return JSONResponse(
|
||||
converter.unstructure(ModelList(data=[ModelCard(id=llm.llm_type)])), status_code=HTTPStatus.OK.value
|
||||
)
|
||||
|
||||
|
||||
# POST /v1/chat/completions
|
||||
@add_schema_definitions(CHAT_COMPLETION_SCHEMA)
|
||||
@@ -124,11 +158,14 @@ async def create_chat_completions(req: Request, llm: openllm.LLM[M, T]) -> Respo
|
||||
return error_response(HTTPStatus.BAD_REQUEST, 'Invalid JSON input received (Check server log).')
|
||||
logger.debug('Received chat completion request: %s', request)
|
||||
err_check = await check_model(request, llm.llm_type)
|
||||
if err_check is not None: return err_check
|
||||
if err_check is not None:
|
||||
return err_check
|
||||
|
||||
model_name, request_id = request.model, gen_random_uuid('chatcmpl')
|
||||
created_time = int(time.monotonic())
|
||||
prompt = llm.tokenizer.apply_chat_template(request.messages, tokenize=False, add_generation_prompt=llm.config['add_generation_prompt'])
|
||||
prompt = llm.tokenizer.apply_chat_template(
|
||||
request.messages, tokenize=False, add_generation_prompt=llm.config['add_generation_prompt']
|
||||
)
|
||||
logger.debug('Prompt: %r', prompt)
|
||||
config = llm.config.with_openai_request(request)
|
||||
|
||||
@@ -141,10 +178,15 @@ async def create_chat_completions(req: Request, llm: openllm.LLM[M, T]) -> Respo
|
||||
|
||||
def create_stream_response_json(index: int, text: str, finish_reason: str | None = None) -> str:
|
||||
return jsonify_attr(
|
||||
ChatCompletionStreamResponse(id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=[ChatCompletionResponseStreamChoice(index=index, delta=Delta(content=text), finish_reason=finish_reason)]))
|
||||
ChatCompletionStreamResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=[
|
||||
ChatCompletionResponseStreamChoice(index=index, delta=Delta(content=text), finish_reason=finish_reason)
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
async def completion_stream_generator() -> t.AsyncGenerator[str, None]:
|
||||
# first chunk with role
|
||||
@@ -160,25 +202,47 @@ async def create_chat_completions(req: Request, llm: openllm.LLM[M, T]) -> Respo
|
||||
|
||||
try:
|
||||
# Streaming case
|
||||
if request.stream: return StreamingResponse(completion_stream_generator(), media_type='text/event-stream')
|
||||
if request.stream:
|
||||
return StreamingResponse(completion_stream_generator(), media_type='text/event-stream')
|
||||
# Non-streaming case
|
||||
final_result: GenerationOutput | None = None
|
||||
texts: list[list[str]] = [[]] * config['n']
|
||||
token_ids: list[list[int]] = [[]] * config['n']
|
||||
async for res in result_generator:
|
||||
if await req.is_disconnected(): return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
|
||||
if await req.is_disconnected():
|
||||
return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
|
||||
for output in res.outputs:
|
||||
texts[output.index].append(output.text)
|
||||
token_ids[output.index].extend(output.token_ids)
|
||||
final_result = res
|
||||
if final_result is None: return error_response(HTTPStatus.BAD_REQUEST, 'No response from model.')
|
||||
final_result = final_result.with_options(outputs=[output.with_options(text=''.join(texts[output.index]), token_ids=token_ids[output.index]) for output in final_result.outputs])
|
||||
if final_result is None:
|
||||
return error_response(HTTPStatus.BAD_REQUEST, 'No response from model.')
|
||||
final_result = final_result.with_options(
|
||||
outputs=[
|
||||
output.with_options(text=''.join(texts[output.index]), token_ids=token_ids[output.index])
|
||||
for output in final_result.outputs
|
||||
]
|
||||
)
|
||||
choices = [
|
||||
ChatCompletionResponseChoice(index=output.index, message=ChatMessage(role='assistant', content=output.text), finish_reason=output.finish_reason) for output in final_result.outputs
|
||||
ChatCompletionResponseChoice(
|
||||
index=output.index,
|
||||
message=ChatMessage(role='assistant', content=output.text),
|
||||
finish_reason=output.finish_reason,
|
||||
)
|
||||
for output in final_result.outputs
|
||||
]
|
||||
num_prompt_tokens, num_generated_tokens = len(t.cast(t.List[int], final_result.prompt_token_ids)), sum(len(output.token_ids) for output in final_result.outputs)
|
||||
usage = UsageInfo(prompt_tokens=num_prompt_tokens, completion_tokens=num_generated_tokens, total_tokens=num_prompt_tokens + num_generated_tokens)
|
||||
response = ChatCompletionResponse(id=request_id, created=created_time, model=model_name, usage=usage, choices=choices)
|
||||
num_prompt_tokens, num_generated_tokens = (
|
||||
len(t.cast(t.List[int], final_result.prompt_token_ids)),
|
||||
sum(len(output.token_ids) for output in final_result.outputs),
|
||||
)
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=num_generated_tokens,
|
||||
total_tokens=num_prompt_tokens + num_generated_tokens,
|
||||
)
|
||||
response = ChatCompletionResponse(
|
||||
id=request_id, created=created_time, model=model_name, usage=usage, choices=choices
|
||||
)
|
||||
|
||||
if request.stream: # type: ignore[unreachable]
|
||||
# When user requests streaming but we don't stream, we still need to
|
||||
@@ -187,7 +251,9 @@ async def create_chat_completions(req: Request, llm: openllm.LLM[M, T]) -> Respo
|
||||
yield f'data: {jsonify_attr(response)}\n\n'
|
||||
yield 'data: [DONE]\n\n'
|
||||
|
||||
return StreamingResponse(fake_stream_generator(), media_type='text/event-stream', status_code=HTTPStatus.OK.value)
|
||||
return StreamingResponse(
|
||||
fake_stream_generator(), media_type='text/event-stream', status_code=HTTPStatus.OK.value
|
||||
)
|
||||
|
||||
return JSONResponse(converter.unstructure(response), status_code=HTTPStatus.OK.value)
|
||||
except Exception as err:
|
||||
@@ -195,6 +261,7 @@ async def create_chat_completions(req: Request, llm: openllm.LLM[M, T]) -> Respo
|
||||
logger.error('Error generating completion: %s', err)
|
||||
return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')
|
||||
|
||||
|
||||
# POST /v1/completions
|
||||
@add_schema_definitions(COMPLETION_SCHEMA)
|
||||
async def create_completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
|
||||
@@ -208,18 +275,25 @@ async def create_completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
|
||||
return error_response(HTTPStatus.BAD_REQUEST, 'Invalid JSON input received (Check server log).')
|
||||
logger.debug('Received legacy completion request: %s', request)
|
||||
err_check = await check_model(request, llm.llm_type)
|
||||
if err_check is not None: return err_check
|
||||
if err_check is not None:
|
||||
return err_check
|
||||
|
||||
if request.echo: return error_response(HTTPStatus.BAD_REQUEST, "'echo' is not yet supported.")
|
||||
if request.suffix is not None: return error_response(HTTPStatus.BAD_REQUEST, "'suffix' is not yet supported.")
|
||||
if request.logit_bias is not None and len(request.logit_bias) > 0: return error_response(HTTPStatus.BAD_REQUEST, "'logit_bias' is not yet supported.")
|
||||
if request.echo:
|
||||
return error_response(HTTPStatus.BAD_REQUEST, "'echo' is not yet supported.")
|
||||
if request.suffix is not None:
|
||||
return error_response(HTTPStatus.BAD_REQUEST, "'suffix' is not yet supported.")
|
||||
if request.logit_bias is not None and len(request.logit_bias) > 0:
|
||||
return error_response(HTTPStatus.BAD_REQUEST, "'logit_bias' is not yet supported.")
|
||||
|
||||
if not request.prompt: return error_response(HTTPStatus.BAD_REQUEST, 'Please provide a prompt.')
|
||||
if not request.prompt:
|
||||
return error_response(HTTPStatus.BAD_REQUEST, 'Please provide a prompt.')
|
||||
prompt = request.prompt
|
||||
# TODO: Support multiple prompts
|
||||
|
||||
if request.logprobs is not None and llm.__llm_backend__ == 'pt': # TODO: support logprobs generation for PyTorch
|
||||
return error_response(HTTPStatus.BAD_REQUEST, "'logprobs' is not yet supported for PyTorch models. Make sure to unset `logprobs`.")
|
||||
return error_response(
|
||||
HTTPStatus.BAD_REQUEST, "'logprobs' is not yet supported for PyTorch models. Make sure to unset `logprobs`."
|
||||
)
|
||||
|
||||
model_name, request_id = request.model, gen_random_uuid('cmpl')
|
||||
created_time = int(time.monotonic())
|
||||
@@ -236,12 +310,19 @@ async def create_completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
|
||||
# TODO: support use_beam_search
|
||||
stream = request.stream and (config['best_of'] is None or config['n'] == config['best_of'])
|
||||
|
||||
def create_stream_response_json(index: int, text: str, logprobs: LogProbs | None = None, finish_reason: str | None = None) -> str:
|
||||
def create_stream_response_json(
|
||||
index: int, text: str, logprobs: LogProbs | None = None, finish_reason: str | None = None
|
||||
) -> str:
|
||||
return jsonify_attr(
|
||||
CompletionStreamResponse(id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=[CompletionResponseStreamChoice(index=index, text=text, logprobs=logprobs, finish_reason=finish_reason)]))
|
||||
CompletionStreamResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=[
|
||||
CompletionResponseStreamChoice(index=index, text=text, logprobs=logprobs, finish_reason=finish_reason)
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
async def completion_stream_generator() -> t.AsyncGenerator[str, None]:
|
||||
previous_num_tokens = [0] * config['n']
|
||||
@@ -249,7 +330,11 @@ async def create_completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
|
||||
for output in res.outputs:
|
||||
i = output.index
|
||||
if request.logprobs is not None:
|
||||
logprobs = create_logprobs(token_ids=output.token_ids, id_logprobs=t.cast(SampleLogprobs, output.logprobs)[previous_num_tokens[i]:], llm=llm)
|
||||
logprobs = create_logprobs(
|
||||
token_ids=output.token_ids,
|
||||
id_logprobs=t.cast(SampleLogprobs, output.logprobs)[previous_num_tokens[i] :],
|
||||
llm=llm,
|
||||
)
|
||||
else:
|
||||
logprobs = None
|
||||
previous_num_tokens[i] += len(output.token_ids)
|
||||
@@ -261,32 +346,50 @@ async def create_completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
|
||||
|
||||
try:
|
||||
# Streaming case
|
||||
if stream: return StreamingResponse(completion_stream_generator(), media_type='text/event-stream')
|
||||
if stream:
|
||||
return StreamingResponse(completion_stream_generator(), media_type='text/event-stream')
|
||||
# Non-streaming case
|
||||
final_result: GenerationOutput | None = None
|
||||
texts: list[list[str]] = [[]] * config['n']
|
||||
token_ids: list[list[int]] = [[]] * config['n']
|
||||
async for res in result_generator:
|
||||
if await req.is_disconnected(): return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
|
||||
if await req.is_disconnected():
|
||||
return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
|
||||
for output in res.outputs:
|
||||
texts[output.index].append(output.text)
|
||||
token_ids[output.index].extend(output.token_ids)
|
||||
final_result = res
|
||||
if final_result is None: return error_response(HTTPStatus.BAD_REQUEST, 'No response from model.')
|
||||
final_result = final_result.with_options(outputs=[output.with_options(text=''.join(texts[output.index]), token_ids=token_ids[output.index]) for output in final_result.outputs])
|
||||
if final_result is None:
|
||||
return error_response(HTTPStatus.BAD_REQUEST, 'No response from model.')
|
||||
final_result = final_result.with_options(
|
||||
outputs=[
|
||||
output.with_options(text=''.join(texts[output.index]), token_ids=token_ids[output.index])
|
||||
for output in final_result.outputs
|
||||
]
|
||||
)
|
||||
|
||||
choices: list[CompletionResponseChoice] = []
|
||||
for output in final_result.outputs:
|
||||
if request.logprobs is not None:
|
||||
logprobs = create_logprobs(token_ids=output.token_ids, id_logprobs=t.cast(SampleLogprobs, output.logprobs), llm=llm)
|
||||
logprobs = create_logprobs(
|
||||
token_ids=output.token_ids, id_logprobs=t.cast(SampleLogprobs, output.logprobs), llm=llm
|
||||
)
|
||||
else:
|
||||
logprobs = None
|
||||
choice_data = CompletionResponseChoice(index=output.index, text=output.text, logprobs=logprobs, finish_reason=output.finish_reason)
|
||||
choice_data = CompletionResponseChoice(
|
||||
index=output.index, text=output.text, logprobs=logprobs, finish_reason=output.finish_reason
|
||||
)
|
||||
choices.append(choice_data)
|
||||
|
||||
num_prompt_tokens = len(t.cast(t.List[int], final_result.prompt_token_ids)) # XXX: We will always return prompt_token_ids, so this won't be None
|
||||
num_prompt_tokens = len(
|
||||
t.cast(t.List[int], final_result.prompt_token_ids)
|
||||
) # XXX: We will always return prompt_token_ids, so this won't be None
|
||||
num_generated_tokens = sum(len(output.token_ids) for output in final_result.outputs)
|
||||
usage = UsageInfo(prompt_tokens=num_prompt_tokens, completion_tokens=num_generated_tokens, total_tokens=num_prompt_tokens + num_generated_tokens)
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=num_generated_tokens,
|
||||
total_tokens=num_prompt_tokens + num_generated_tokens,
|
||||
)
|
||||
response = CompletionResponse(id=request_id, created=created_time, model=model_name, usage=usage, choices=choices)
|
||||
|
||||
if request.stream:
|
||||
@@ -296,7 +399,9 @@ async def create_completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
|
||||
yield f'data: {jsonify_attr(response)}\n\n'
|
||||
yield 'data: [DONE]\n\n'
|
||||
|
||||
return StreamingResponse(fake_stream_generator(), media_type='text/event-stream', status_code=HTTPStatus.OK.value)
|
||||
return StreamingResponse(
|
||||
fake_stream_generator(), media_type='text/event-stream', status_code=HTTPStatus.OK.value
|
||||
)
|
||||
|
||||
return JSONResponse(converter.unstructure(response), status_code=HTTPStatus.OK.value)
|
||||
except Exception as err:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Base exceptions for OpenLLM. This extends BentoML exceptions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from openllm_core.exceptions import Error as Error
|
||||
|
||||
@@ -23,14 +23,15 @@ logger = logging.getLogger(__name__)
|
||||
from datasets import load_dataset
|
||||
from trl import SFTTrainer
|
||||
|
||||
DEFAULT_MODEL_ID = "ybelkada/falcon-7b-sharded-bf16"
|
||||
DATASET_NAME = "timdettmers/openassistant-guanaco"
|
||||
DEFAULT_MODEL_ID = 'ybelkada/falcon-7b-sharded-bf16'
|
||||
DATASET_NAME = 'timdettmers/openassistant-guanaco'
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TrainingArguments:
|
||||
per_device_train_batch_size: int = dataclasses.field(default=4)
|
||||
gradient_accumulation_steps: int = dataclasses.field(default=4)
|
||||
optim: str = dataclasses.field(default="paged_adamw_32bit")
|
||||
optim: str = dataclasses.field(default='paged_adamw_32bit')
|
||||
save_steps: int = dataclasses.field(default=10)
|
||||
warmup_steps: int = dataclasses.field(default=10)
|
||||
max_steps: int = dataclasses.field(default=500)
|
||||
@@ -40,47 +41,56 @@ class TrainingArguments:
|
||||
warmup_ratio: float = dataclasses.field(default=0.03)
|
||||
fp16: bool = dataclasses.field(default=True)
|
||||
group_by_length: bool = dataclasses.field(default=True)
|
||||
lr_scheduler_type: str = dataclasses.field(default="constant")
|
||||
output_dir: str = dataclasses.field(default=os.path.join(os.getcwd(), "outputs", "falcon"))
|
||||
lr_scheduler_type: str = dataclasses.field(default='constant')
|
||||
output_dir: str = dataclasses.field(default=os.path.join(os.getcwd(), 'outputs', 'falcon'))
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ModelArguments:
|
||||
model_id: str = dataclasses.field(default=DEFAULT_MODEL_ID)
|
||||
max_sequence_length: int = dataclasses.field(default=512)
|
||||
|
||||
|
||||
parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith('.json'):
|
||||
# If we pass only one argument to the script and it's the path to a json file,
|
||||
# let's parse it to get our arguments.
|
||||
model_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
||||
else:
|
||||
model_args, training_args = t.cast(t.Tuple[ModelArguments, TrainingArguments], parser.parse_args_into_dataclasses())
|
||||
|
||||
llm = openllm.LLM(model_args.model_id, quantize="int4", bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16)
|
||||
model, tokenizer = llm.prepare_for_training(adapter_type="lora",
|
||||
lora_alpha=16,
|
||||
lora_dropout=0.1,
|
||||
r=16,
|
||||
bias="none",
|
||||
target_modules=["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"])
|
||||
llm = openllm.LLM(
|
||||
model_args.model_id, quantize='int4', bnb_4bit_quant_type='nf4', bnb_4bit_compute_dtype=torch.float16
|
||||
)
|
||||
model, tokenizer = llm.prepare_for_training(
|
||||
adapter_type='lora',
|
||||
lora_alpha=16,
|
||||
lora_dropout=0.1,
|
||||
r=16,
|
||||
bias='none',
|
||||
target_modules=['query_key_value', 'dense', 'dense_h_to_4h', 'dense_4h_to_h'],
|
||||
)
|
||||
model.config.use_cache = False
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
dataset = load_dataset(DATASET_NAME, split="train")
|
||||
dataset = load_dataset(DATASET_NAME, split='train')
|
||||
|
||||
trainer = SFTTrainer(model=model,
|
||||
train_dataset=dataset,
|
||||
dataset_text_field="text",
|
||||
max_seq_length=model_args.max_sequence_length,
|
||||
tokenizer=tokenizer,
|
||||
args=dataclasses.replace(transformers.TrainingArguments(training_args.output_dir), **dataclasses.asdict(training_args)),
|
||||
)
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
train_dataset=dataset,
|
||||
dataset_text_field='text',
|
||||
max_seq_length=model_args.max_sequence_length,
|
||||
tokenizer=tokenizer,
|
||||
args=dataclasses.replace(
|
||||
transformers.TrainingArguments(training_args.output_dir), **dataclasses.asdict(training_args)
|
||||
),
|
||||
)
|
||||
|
||||
# upcast layernorm in float32 for more stable training
|
||||
for name, module in trainer.model.named_modules():
|
||||
if "norm" in name:
|
||||
if 'norm' in name:
|
||||
module = module.to(torch.float32)
|
||||
|
||||
trainer.train()
|
||||
|
||||
trainer.model.save_pretrained(os.path.join(training_args.output_dir, "lora"))
|
||||
trainer.model.save_pretrained(os.path.join(training_args.output_dir, 'lora'))
|
||||
|
||||
@@ -15,6 +15,7 @@ MAX_NEW_TOKENS = 384
|
||||
Q = 'Answer the following question, step by step:\n{q}\nA:'
|
||||
question = 'What is the meaning of life?'
|
||||
|
||||
|
||||
async def main() -> int:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('question', default=question)
|
||||
@@ -37,6 +38,7 @@ async def main() -> int:
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
def _mp_fn(index: t.Any): # type: ignore
|
||||
# For xla_spawn (TPUs)
|
||||
asyncio.run(main())
|
||||
|
||||
@@ -30,39 +30,45 @@ from random import randint, randrange
|
||||
import bitsandbytes as bnb
|
||||
from datasets import load_dataset
|
||||
|
||||
|
||||
# COPIED FROM https://github.com/artidoro/qlora/blob/main/qlora.py
|
||||
def find_all_linear_names(model):
|
||||
lora_module_names = set()
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, bnb.nn.Linear4bit):
|
||||
names = name.split(".")
|
||||
names = name.split('.')
|
||||
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
|
||||
|
||||
if "lm_head" in lora_module_names: # needed for 16-bit
|
||||
lora_module_names.remove("lm_head")
|
||||
if 'lm_head' in lora_module_names: # needed for 16-bit
|
||||
lora_module_names.remove('lm_head')
|
||||
return list(lora_module_names)
|
||||
|
||||
|
||||
# Change this to the local converted path if you don't have access to the meta-llama model
|
||||
DEFAULT_MODEL_ID = "meta-llama/Llama-2-7b-hf"
|
||||
DEFAULT_MODEL_ID = 'meta-llama/Llama-2-7b-hf'
|
||||
# change this to 'main' if you want to use the latest llama
|
||||
DEFAULT_MODEL_VERSION = "335a02887eb6684d487240bbc28b5699298c3135"
|
||||
DATASET_NAME = "databricks/databricks-dolly-15k"
|
||||
DEFAULT_MODEL_VERSION = '335a02887eb6684d487240bbc28b5699298c3135'
|
||||
DATASET_NAME = 'databricks/databricks-dolly-15k'
|
||||
|
||||
|
||||
def format_dolly(sample):
|
||||
instruction = f"### Instruction\n{sample['instruction']}"
|
||||
context = f"### Context\n{sample['context']}" if len(sample["context"]) > 0 else None
|
||||
context = f"### Context\n{sample['context']}" if len(sample['context']) > 0 else None
|
||||
response = f"### Answer\n{sample['response']}"
|
||||
# join all the parts together
|
||||
prompt = "\n\n".join([i for i in [instruction, context, response] if i is not None])
|
||||
prompt = '\n\n'.join([i for i in [instruction, context, response] if i is not None])
|
||||
return prompt
|
||||
|
||||
|
||||
# template dataset to add prompt to each sample
|
||||
def template_dataset(sample, tokenizer):
|
||||
sample["text"] = f"{format_dolly(sample)}{tokenizer.eos_token}"
|
||||
sample['text'] = f'{format_dolly(sample)}{tokenizer.eos_token}'
|
||||
return sample
|
||||
|
||||
|
||||
# empty list to save remainder from batches to use in next batch
|
||||
remainder = {"input_ids": [], "attention_mask": [], "token_type_ids": []}
|
||||
remainder = {'input_ids': [], 'attention_mask': [], 'token_type_ids': []}
|
||||
|
||||
|
||||
def chunk(sample, chunk_length=2048):
|
||||
# define global remainder variable to save remainder from batches to use in next batch
|
||||
@@ -78,61 +84,76 @@ def chunk(sample, chunk_length=2048):
|
||||
batch_chunk_length = (batch_total_length // chunk_length) * chunk_length
|
||||
|
||||
# Split by chunks of max_len.
|
||||
result = {k: [t[i:i + chunk_length] for i in range(0, batch_chunk_length, chunk_length)] for k, t in concatenated_examples.items()}
|
||||
result = {
|
||||
k: [t[i : i + chunk_length] for i in range(0, batch_chunk_length, chunk_length)]
|
||||
for k, t in concatenated_examples.items()
|
||||
}
|
||||
# add remainder to global variable for next batch
|
||||
remainder = {k: concatenated_examples[k][batch_chunk_length:] for k in concatenated_examples.keys()}
|
||||
# prepare labels
|
||||
result["labels"] = result["input_ids"].copy()
|
||||
result['labels'] = result['input_ids'].copy()
|
||||
return result
|
||||
|
||||
|
||||
def prepare_datasets(tokenizer, dataset_name=DATASET_NAME):
|
||||
# Load dataset from the hub
|
||||
dataset = load_dataset(dataset_name, split="train")
|
||||
dataset = load_dataset(dataset_name, split='train')
|
||||
|
||||
print(f"dataset size: {len(dataset)}")
|
||||
print(f'dataset size: {len(dataset)}')
|
||||
print(dataset[randrange(len(dataset))])
|
||||
|
||||
# apply prompt template per sample
|
||||
dataset = dataset.map(partial(template_dataset, tokenizer=tokenizer), remove_columns=list(dataset.features))
|
||||
# print random sample
|
||||
print("Sample from dolly-v2 ds:", dataset[randint(0, len(dataset))]["text"])
|
||||
print('Sample from dolly-v2 ds:', dataset[randint(0, len(dataset))]['text'])
|
||||
|
||||
# tokenize and chunk dataset
|
||||
lm_dataset = dataset.map(lambda sample: tokenizer(sample["text"]), batched=True, remove_columns=list(dataset.features)).map(partial(chunk, chunk_length=2048), batched=True)
|
||||
lm_dataset = dataset.map(
|
||||
lambda sample: tokenizer(sample['text']), batched=True, remove_columns=list(dataset.features)
|
||||
).map(partial(chunk, chunk_length=2048), batched=True)
|
||||
|
||||
# Print total number of samples
|
||||
print(f"Total number of samples: {len(lm_dataset)}")
|
||||
print(f'Total number of samples: {len(lm_dataset)}')
|
||||
return lm_dataset
|
||||
|
||||
def prepare_for_int4_training(model_id: str,
|
||||
model_version: str | None = None,
|
||||
gradient_checkpointing: bool = True,
|
||||
bf16: bool = True,
|
||||
) -> tuple[peft.PeftModel, transformers.LlamaTokenizerFast]:
|
||||
|
||||
def prepare_for_int4_training(
|
||||
model_id: str, model_version: str | None = None, gradient_checkpointing: bool = True, bf16: bool = True
|
||||
) -> tuple[peft.PeftModel, transformers.LlamaTokenizerFast]:
|
||||
from peft.tuners.lora import LoraLayer
|
||||
|
||||
llm = openllm.LLM(model_id, revision=model_version, quantize="int4", bnb_4bit_compute_dtype=torch.bfloat16, use_cache=not gradient_checkpointing, device_map="auto")
|
||||
print("Model summary:", llm.model)
|
||||
llm = openllm.LLM(
|
||||
model_id,
|
||||
revision=model_version,
|
||||
quantize='int4',
|
||||
bnb_4bit_compute_dtype=torch.bfloat16,
|
||||
use_cache=not gradient_checkpointing,
|
||||
device_map='auto',
|
||||
)
|
||||
print('Model summary:', llm.model)
|
||||
|
||||
# get lora target modules
|
||||
modules = find_all_linear_names(llm.model)
|
||||
print(f"Found {len(modules)} modules to quantize: {modules}")
|
||||
print(f'Found {len(modules)} modules to quantize: {modules}')
|
||||
|
||||
model, tokenizer = llm.prepare_for_training(adapter_type="lora", use_gradient_checkpointing=gradient_checkpointing, target_modules=modules)
|
||||
model, tokenizer = llm.prepare_for_training(
|
||||
adapter_type='lora', use_gradient_checkpointing=gradient_checkpointing, target_modules=modules
|
||||
)
|
||||
|
||||
# pre-process the model by upcasting the layer norms in float 32 for
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, LoraLayer):
|
||||
if bf16:
|
||||
module = module.to(torch.bfloat16)
|
||||
if "norm" in name:
|
||||
if 'norm' in name:
|
||||
module = module.to(torch.float32)
|
||||
if "lm_head" in name or "embed_tokens" in name:
|
||||
if hasattr(module, "weight"):
|
||||
if 'lm_head' in name or 'embed_tokens' in name:
|
||||
if hasattr(module, 'weight'):
|
||||
if bf16 and module.weight.dtype == torch.float32:
|
||||
module = module.to(torch.bfloat16)
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TrainingArguments:
|
||||
per_device_train_batch_size: int = dataclasses.field(default=1)
|
||||
@@ -141,9 +162,10 @@ class TrainingArguments:
|
||||
learning_rate: float = dataclasses.field(default=5e-5)
|
||||
num_train_epochs: int = dataclasses.field(default=3)
|
||||
logging_steps: int = dataclasses.field(default=1)
|
||||
report_to: str = dataclasses.field(default="none")
|
||||
output_dir: str = dataclasses.field(default=os.path.join(os.getcwd(), "outputs", "llama"))
|
||||
save_strategy: str = dataclasses.field(default="no")
|
||||
report_to: str = dataclasses.field(default='none')
|
||||
output_dir: str = dataclasses.field(default=os.path.join(os.getcwd(), 'outputs', 'llama'))
|
||||
save_strategy: str = dataclasses.field(default='no')
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ModelArguments:
|
||||
@@ -152,32 +174,42 @@ class ModelArguments:
|
||||
seed: int = dataclasses.field(default=42)
|
||||
merge_weights: bool = dataclasses.field(default=False)
|
||||
|
||||
|
||||
if openllm.utils.in_notebook():
|
||||
model_args, training_rags = ModelArguments(), TrainingArguments()
|
||||
else:
|
||||
parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith('.json'):
|
||||
# If we pass only one argument to the script and it's the path to a json file,
|
||||
# let's parse it to get our arguments.
|
||||
model_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
||||
else:
|
||||
model_args, training_args = t.cast(t.Tuple[ModelArguments, TrainingArguments], parser.parse_args_into_dataclasses())
|
||||
model_args, training_args = t.cast(
|
||||
t.Tuple[ModelArguments, TrainingArguments], parser.parse_args_into_dataclasses()
|
||||
)
|
||||
|
||||
# import the model first hand
|
||||
openllm.import_model(model_id=model_args.model_id, model_version=model_args.model_version)
|
||||
|
||||
|
||||
def train_loop(model_args: ModelArguments, training_args: TrainingArguments):
|
||||
import peft
|
||||
|
||||
transformers.set_seed(model_args.seed)
|
||||
|
||||
model, tokenizer = prepare_for_int4_training(model_args.model_id, gradient_checkpointing=training_args.gradient_checkpointing, bf16=training_args.bf16,)
|
||||
model, tokenizer = prepare_for_int4_training(
|
||||
model_args.model_id, gradient_checkpointing=training_args.gradient_checkpointing, bf16=training_args.bf16
|
||||
)
|
||||
datasets = prepare_datasets(tokenizer)
|
||||
|
||||
trainer = transformers.Trainer(model=model,
|
||||
args=dataclasses.replace(transformers.TrainingArguments(training_args.output_dir), **dataclasses.asdict(training_args)),
|
||||
train_dataset=datasets,
|
||||
data_collator=transformers.default_data_collator)
|
||||
trainer = transformers.Trainer(
|
||||
model=model,
|
||||
args=dataclasses.replace(
|
||||
transformers.TrainingArguments(training_args.output_dir), **dataclasses.asdict(training_args)
|
||||
),
|
||||
train_dataset=datasets,
|
||||
data_collator=transformers.default_data_collator,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
|
||||
@@ -192,11 +224,16 @@ def train_loop(model_args: ModelArguments, training_args: TrainingArguments):
|
||||
del model, trainer
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
model = peft.AutoPeftModelForCausalLM.from_pretrained(training_args.output_dir, low_cpu_mem_usage=True, torch_dtype=torch.float16)
|
||||
model = peft.AutoPeftModelForCausalLM.from_pretrained(
|
||||
training_args.output_dir, low_cpu_mem_usage=True, torch_dtype=torch.float16
|
||||
)
|
||||
# merge lora with base weights and save
|
||||
model = model.merge_and_unload()
|
||||
model.save_pretrained(os.path.join(os.getcwd(), "outputs", "merged_llama_lora"), safe_serialization=True, max_shard_size="2GB")
|
||||
model.save_pretrained(
|
||||
os.path.join(os.getcwd(), 'outputs', 'merged_llama_lora'), safe_serialization=True, max_shard_size='2GB'
|
||||
)
|
||||
else:
|
||||
trainer.model.save_pretrained(os.path.join(training_args.output_dir, "lora"))
|
||||
trainer.model.save_pretrained(os.path.join(training_args.output_dir, 'lora'))
|
||||
|
||||
|
||||
train_loop(model_args, training_args)
|
||||
|
||||
@@ -24,13 +24,21 @@ from datasets import load_dataset
|
||||
if t.TYPE_CHECKING:
|
||||
from peft import PeftModel
|
||||
|
||||
DEFAULT_MODEL_ID = "facebook/opt-6.7b"
|
||||
DEFAULT_MODEL_ID = 'facebook/opt-6.7b'
|
||||
|
||||
|
||||
def load_trainer(
|
||||
model: PeftModel, tokenizer: transformers.GPT2TokenizerFast, dataset_dict: t.Any, training_args: TrainingArguments
|
||||
):
|
||||
return transformers.Trainer(
|
||||
model=model,
|
||||
train_dataset=dataset_dict['train'],
|
||||
args=dataclasses.replace(
|
||||
transformers.TrainingArguments(training_args.output_dir), **dataclasses.asdict(training_args)
|
||||
),
|
||||
data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
|
||||
)
|
||||
|
||||
def load_trainer(model: PeftModel, tokenizer: transformers.GPT2TokenizerFast, dataset_dict: t.Any, training_args: TrainingArguments):
|
||||
return transformers.Trainer(model=model,
|
||||
train_dataset=dataset_dict["train"],
|
||||
args=dataclasses.replace(transformers.TrainingArguments(training_args.output_dir), **dataclasses.asdict(training_args)),
|
||||
data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False))
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TrainingArguments:
|
||||
@@ -41,30 +49,34 @@ class TrainingArguments:
|
||||
learning_rate: float = dataclasses.field(default=3e-4)
|
||||
fp16: bool = dataclasses.field(default=True)
|
||||
logging_steps: int = dataclasses.field(default=1)
|
||||
output_dir: str = dataclasses.field(default=os.path.join(os.getcwd(), "outputs", "opt"))
|
||||
output_dir: str = dataclasses.field(default=os.path.join(os.getcwd(), 'outputs', 'opt'))
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ModelArguments:
|
||||
model_id: str = dataclasses.field(default=DEFAULT_MODEL_ID)
|
||||
|
||||
|
||||
parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith('.json'):
|
||||
# If we pass only one argument to the script and it's the path to a json file,
|
||||
# let's parse it to get our arguments.
|
||||
model_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
||||
else:
|
||||
model_args, training_args = t.cast(t.Tuple[ModelArguments, TrainingArguments], parser.parse_args_into_dataclasses())
|
||||
|
||||
llm = openllm.LLM(model_args.model_id, quantize="int8")
|
||||
model, tokenizer = llm.prepare_for_training(adapter_type="lora", r=16, lora_alpha=32, target_modules=["q_proj", "v_proj"], lora_dropout=0.05, bias="none")
|
||||
llm = openllm.LLM(model_args.model_id, quantize='int8')
|
||||
model, tokenizer = llm.prepare_for_training(
|
||||
adapter_type='lora', r=16, lora_alpha=32, target_modules=['q_proj', 'v_proj'], lora_dropout=0.05, bias='none'
|
||||
)
|
||||
|
||||
# ft on english_quotes
|
||||
data = load_dataset("Abirate/english_quotes")
|
||||
data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True)
|
||||
data = load_dataset('Abirate/english_quotes')
|
||||
data = data.map(lambda samples: tokenizer(samples['quote']), batched=True)
|
||||
|
||||
trainer = load_trainer(model, tokenizer, data, training_args)
|
||||
model.config.use_cache = False # silence just for warning, reenable for inference later
|
||||
|
||||
trainer.train()
|
||||
|
||||
trainer.model.save_pretrained(os.path.join(training_args.output_dir, "lora"))
|
||||
trainer.model.save_pretrained(os.path.join(training_args.output_dir, 'lora'))
|
||||
|
||||
@@ -2,12 +2,14 @@
|
||||
|
||||
Currently support OpenAI compatible API.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import os
|
||||
import typing as t
|
||||
|
||||
from openllm_core.utils import LazyModule
|
||||
|
||||
|
||||
_import_structure: dict[str, list[str]] = {'openai': []}
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
|
||||
@@ -3,15 +3,18 @@ import typing as t
|
||||
|
||||
import attr
|
||||
|
||||
|
||||
@attr.define
|
||||
class AgentRequest:
|
||||
inputs: str
|
||||
parameters: t.Dict[str, t.Any]
|
||||
|
||||
|
||||
@attr.define
|
||||
class AgentResponse:
|
||||
generated_text: str
|
||||
|
||||
|
||||
@attr.define
|
||||
class HFErrorResponse:
|
||||
error_code: int
|
||||
|
||||
@@ -8,6 +8,7 @@ import openllm_core
|
||||
|
||||
from openllm_core.utils import converter
|
||||
|
||||
|
||||
@attr.define
|
||||
class ErrorResponse:
|
||||
message: str
|
||||
@@ -16,6 +17,7 @@ class ErrorResponse:
|
||||
param: t.Optional[str] = None
|
||||
code: t.Optional[str] = None
|
||||
|
||||
|
||||
@attr.define
|
||||
class CompletionRequest:
|
||||
prompt: str
|
||||
@@ -37,6 +39,7 @@ class CompletionRequest:
|
||||
top_k: t.Optional[int] = attr.field(default=None)
|
||||
best_of: t.Optional[int] = attr.field(default=1)
|
||||
|
||||
|
||||
@attr.define
|
||||
class ChatCompletionRequest:
|
||||
messages: t.List[t.Dict[str, str]]
|
||||
@@ -57,6 +60,7 @@ class ChatCompletionRequest:
|
||||
top_k: t.Optional[int] = attr.field(default=None)
|
||||
best_of: t.Optional[int] = attr.field(default=1)
|
||||
|
||||
|
||||
@attr.define
|
||||
class LogProbs:
|
||||
text_offset: t.List[int] = attr.field(default=attr.Factory(list))
|
||||
@@ -64,12 +68,14 @@ class LogProbs:
|
||||
tokens: t.List[str] = attr.field(default=attr.Factory(list))
|
||||
top_logprobs: t.List[t.Dict[str, t.Any]] = attr.field(default=attr.Factory(list))
|
||||
|
||||
|
||||
@attr.define
|
||||
class UsageInfo:
|
||||
prompt_tokens: int = attr.field(default=0)
|
||||
completion_tokens: int = attr.field(default=0)
|
||||
total_tokens: int = attr.field(default=0)
|
||||
|
||||
|
||||
@attr.define
|
||||
class CompletionResponseChoice:
|
||||
index: int
|
||||
@@ -77,6 +83,7 @@ class CompletionResponseChoice:
|
||||
logprobs: t.Optional[LogProbs] = None
|
||||
finish_reason: t.Optional[str] = None
|
||||
|
||||
|
||||
@attr.define
|
||||
class CompletionResponseStreamChoice:
|
||||
index: int
|
||||
@@ -84,6 +91,7 @@ class CompletionResponseStreamChoice:
|
||||
logprobs: t.Optional[LogProbs] = None
|
||||
finish_reason: t.Optional[str] = None
|
||||
|
||||
|
||||
@attr.define
|
||||
class CompletionStreamResponse:
|
||||
model: str
|
||||
@@ -92,6 +100,7 @@ class CompletionStreamResponse:
|
||||
id: str = attr.field(default=attr.Factory(lambda: openllm_core.utils.gen_random_uuid('cmpl')))
|
||||
created: int = attr.field(default=attr.Factory(lambda: int(time.monotonic())))
|
||||
|
||||
|
||||
@attr.define
|
||||
class CompletionResponse:
|
||||
choices: t.List[CompletionResponseChoice]
|
||||
@@ -101,32 +110,39 @@ class CompletionResponse:
|
||||
id: str = attr.field(default=attr.Factory(lambda: openllm_core.utils.gen_random_uuid('cmpl')))
|
||||
created: int = attr.field(default=attr.Factory(lambda: int(time.monotonic())))
|
||||
|
||||
|
||||
LiteralRole = t.Literal['system', 'user', 'assistant']
|
||||
|
||||
|
||||
@attr.define
|
||||
class Delta:
|
||||
role: t.Optional[LiteralRole] = None
|
||||
content: t.Optional[str] = None
|
||||
|
||||
|
||||
@attr.define
|
||||
class ChatMessage:
|
||||
role: LiteralRole
|
||||
content: str
|
||||
|
||||
|
||||
converter.register_unstructure_hook(ChatMessage, lambda msg: {'role': msg.role, 'content': msg.content})
|
||||
|
||||
|
||||
@attr.define
|
||||
class ChatCompletionResponseStreamChoice:
|
||||
index: int
|
||||
delta: Delta
|
||||
finish_reason: t.Optional[str] = attr.field(default=None)
|
||||
|
||||
|
||||
@attr.define
|
||||
class ChatCompletionResponseChoice:
|
||||
index: int
|
||||
message: ChatMessage
|
||||
finish_reason: t.Optional[str] = attr.field(default=None)
|
||||
|
||||
|
||||
@attr.define
|
||||
class ChatCompletionResponse:
|
||||
choices: t.List[ChatCompletionResponseChoice]
|
||||
@@ -136,6 +152,7 @@ class ChatCompletionResponse:
|
||||
created: int = attr.field(default=attr.Factory(lambda: int(time.monotonic())))
|
||||
usage: UsageInfo = attr.field(default=attr.Factory(lambda: UsageInfo()))
|
||||
|
||||
|
||||
@attr.define
|
||||
class ChatCompletionStreamResponse:
|
||||
choices: t.List[ChatCompletionResponseStreamChoice]
|
||||
@@ -144,6 +161,7 @@ class ChatCompletionStreamResponse:
|
||||
id: str = attr.field(default=attr.Factory(lambda: openllm_core.utils.gen_random_uuid('chatcmpl')))
|
||||
created: int = attr.field(default=attr.Factory(lambda: int(time.monotonic())))
|
||||
|
||||
|
||||
@attr.define
|
||||
class ModelCard:
|
||||
id: str
|
||||
@@ -151,19 +169,25 @@ class ModelCard:
|
||||
created: int = attr.field(default=attr.Factory(lambda: int(time.monotonic())))
|
||||
owned_by: str = 'na'
|
||||
|
||||
|
||||
@attr.define
|
||||
class ModelList:
|
||||
object: str = 'list'
|
||||
data: t.List[ModelCard] = attr.field(factory=list)
|
||||
|
||||
|
||||
async def get_conversation_prompt(request: ChatCompletionRequest, llm_config: openllm_core.LLMConfig) -> str:
|
||||
conv = llm_config.get_conversation_template()
|
||||
for message in request.messages:
|
||||
msg_role = message['role']
|
||||
if msg_role == 'system': conv.set_system_message(message['content'])
|
||||
elif msg_role == 'user': conv.append_message(conv.roles[0], message['content'])
|
||||
elif msg_role == 'assistant': conv.append_message(conv.roles[1], message['content'])
|
||||
else: raise ValueError(f'Unknown role: {msg_role}')
|
||||
if msg_role == 'system':
|
||||
conv.set_system_message(message['content'])
|
||||
elif msg_role == 'user':
|
||||
conv.append_message(conv.roles[0], message['content'])
|
||||
elif msg_role == 'assistant':
|
||||
conv.append_message(conv.roles[1], message['content'])
|
||||
else:
|
||||
raise ValueError(f'Unknown role: {msg_role}')
|
||||
# Add a blank message for the assistant.
|
||||
conv.append_message(conv.roles[1], '')
|
||||
return conv.get_prompt()
|
||||
|
||||
@@ -4,6 +4,7 @@ Currently supports transformers for PyTorch, and vLLM.
|
||||
|
||||
Currently, GGML format is working in progress.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import importlib
|
||||
import typing as t
|
||||
@@ -18,6 +19,7 @@ from openllm_core._typing_compat import M
|
||||
from openllm_core._typing_compat import ParamSpec
|
||||
from openllm_core._typing_compat import T
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import transformers as _transformers
|
||||
|
||||
@@ -31,6 +33,7 @@ else:
|
||||
|
||||
P = ParamSpec('P')
|
||||
|
||||
|
||||
def load_tokenizer(llm: openllm.LLM[t.Any, T], **tokenizer_attrs: t.Any) -> T:
|
||||
"""Load the tokenizer from BentoML store.
|
||||
|
||||
@@ -47,24 +50,34 @@ def load_tokenizer(llm: openllm.LLM[t.Any, T], **tokenizer_attrs: t.Any) -> T:
|
||||
try:
|
||||
tokenizer = cloudpickle.load(t.cast('t.IO[bytes]', cofile))['tokenizer']
|
||||
except KeyError:
|
||||
raise openllm.exceptions.OpenLLMException("Bento model does not have tokenizer. Make sure to save the tokenizer within the model via 'custom_objects'. "
|
||||
"For example: \"bentoml.transformers.save_model(..., custom_objects={'tokenizer': tokenizer})\"") from None
|
||||
raise openllm.exceptions.OpenLLMException(
|
||||
"Bento model does not have tokenizer. Make sure to save the tokenizer within the model via 'custom_objects'. "
|
||||
'For example: "bentoml.transformers.save_model(..., custom_objects={\'tokenizer\': tokenizer})"'
|
||||
) from None
|
||||
else:
|
||||
tokenizer = _transformers.AutoTokenizer.from_pretrained(bentomodel_fs.getsyspath('/'), trust_remote_code=llm.trust_remote_code, **tokenizer_attrs)
|
||||
tokenizer = _transformers.AutoTokenizer.from_pretrained(
|
||||
bentomodel_fs.getsyspath('/'), trust_remote_code=llm.trust_remote_code, **tokenizer_attrs
|
||||
)
|
||||
|
||||
if tokenizer.pad_token_id is None:
|
||||
if config.pad_token_id is not None: tokenizer.pad_token_id = config.pad_token_id
|
||||
elif config.eos_token_id is not None: tokenizer.pad_token_id = config.eos_token_id
|
||||
elif tokenizer.eos_token_id is not None: tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
else: tokenizer.add_special_tokens({'pad_token': '[PAD]'})
|
||||
if config.pad_token_id is not None:
|
||||
tokenizer.pad_token_id = config.pad_token_id
|
||||
elif config.eos_token_id is not None:
|
||||
tokenizer.pad_token_id = config.eos_token_id
|
||||
elif tokenizer.eos_token_id is not None:
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
else:
|
||||
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
|
||||
return tokenizer
|
||||
|
||||
|
||||
class _Caller(t.Protocol[P]):
|
||||
def __call__(self, llm: openllm.LLM[M, T], *args: P.args, **kwargs: P.kwargs) -> t.Any:
|
||||
...
|
||||
def __call__(self, llm: openllm.LLM[M, T], *args: P.args, **kwargs: P.kwargs) -> t.Any: ...
|
||||
|
||||
|
||||
_extras = ['get', 'import_model', 'load_model']
|
||||
|
||||
|
||||
def _make_dispatch_function(fn: str) -> _Caller[P]:
|
||||
def caller(llm: openllm.LLM[M, T], *args: P.args, **kwargs: P.kwargs) -> t.Any:
|
||||
"""Generic function dispatch to correct serialisation submodules based on LLM runtime.
|
||||
@@ -74,30 +87,36 @@ def _make_dispatch_function(fn: str) -> _Caller[P]:
|
||||
> [!NOTE] See 'openllm.serialisation.ggml' if 'llm.__llm_backend__="ggml"'
|
||||
"""
|
||||
serde = 'transformers'
|
||||
if llm.__llm_backend__ == 'ggml': serde = 'ggml'
|
||||
if llm.__llm_backend__ == 'ggml':
|
||||
serde = 'ggml'
|
||||
return getattr(importlib.import_module(f'.{serde}', __name__), fn)(llm, *args, **kwargs)
|
||||
|
||||
return caller
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
|
||||
def get(llm: openllm.LLM[M, T], *args: t.Any, **kwargs: t.Any) -> bentoml.Model:
|
||||
...
|
||||
def get(llm: openllm.LLM[M, T], *args: t.Any, **kwargs: t.Any) -> bentoml.Model: ...
|
||||
|
||||
def import_model(llm: openllm.LLM[M, T], *args: t.Any, **kwargs: t.Any) -> bentoml.Model:
|
||||
...
|
||||
def import_model(llm: openllm.LLM[M, T], *args: t.Any, **kwargs: t.Any) -> bentoml.Model: ...
|
||||
|
||||
def load_model(llm: openllm.LLM[M, T], *args: t.Any, **kwargs: t.Any) -> M: ...
|
||||
|
||||
def load_model(llm: openllm.LLM[M, T], *args: t.Any, **kwargs: t.Any) -> M:
|
||||
...
|
||||
|
||||
_import_structure: dict[str, list[str]] = {'ggml': [], 'transformers': [], 'constants': []}
|
||||
__all__ = ['ggml', 'transformers', 'constants', 'load_tokenizer', *_extras]
|
||||
|
||||
|
||||
def __dir__() -> list[str]:
|
||||
return sorted(__all__)
|
||||
|
||||
|
||||
def __getattr__(name: str) -> t.Any:
|
||||
if name == 'load_tokenizer': return load_tokenizer
|
||||
elif name in _import_structure: return importlib.import_module(f'.{name}', __name__)
|
||||
elif name in _extras: return _make_dispatch_function(name)
|
||||
else: raise AttributeError(f'{__name__} has no attribute {name}')
|
||||
if name == 'load_tokenizer':
|
||||
return load_tokenizer
|
||||
elif name in _import_structure:
|
||||
return importlib.import_module(f'.{name}', __name__)
|
||||
elif name in _extras:
|
||||
return _make_dispatch_function(name)
|
||||
else:
|
||||
raise AttributeError(f'{__name__} has no attribute {name}')
|
||||
|
||||
@@ -1,7 +1,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
FRAMEWORK_TO_AUTOCLASS_MAPPING = {'pt': ('AutoModelForCausalLM', 'AutoModelForSeq2SeqLM'), 'vllm': ('AutoModelForCausalLM', 'AutoModelForSeq2SeqLM')}
|
||||
HUB_ATTRS = ['cache_dir', 'code_revision', 'force_download', 'local_files_only', 'proxies', 'resume_download', 'revision', 'subfolder', 'use_auth_token']
|
||||
|
||||
FRAMEWORK_TO_AUTOCLASS_MAPPING = {
|
||||
'pt': ('AutoModelForCausalLM', 'AutoModelForSeq2SeqLM'),
|
||||
'vllm': ('AutoModelForCausalLM', 'AutoModelForSeq2SeqLM'),
|
||||
}
|
||||
HUB_ATTRS = [
|
||||
'cache_dir',
|
||||
'code_revision',
|
||||
'force_download',
|
||||
'local_files_only',
|
||||
'proxies',
|
||||
'resume_download',
|
||||
'revision',
|
||||
'subfolder',
|
||||
'use_auth_token',
|
||||
]
|
||||
CONFIG_FILE_NAME = 'config.json'
|
||||
# the below is similar to peft.utils.other.CONFIG_NAME
|
||||
PEFT_CONFIG_NAME = 'adapter_config.json'
|
||||
|
||||
@@ -2,9 +2,11 @@
|
||||
|
||||
This requires ctransformers to be installed.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import bentoml
|
||||
import openllm
|
||||
@@ -13,11 +15,16 @@ if t.TYPE_CHECKING:
|
||||
|
||||
_conversion_strategy = {'pt': 'ggml'}
|
||||
|
||||
def import_model(llm: openllm.LLM[t.Any, t.Any], *decls: t.Any, trust_remote_code: bool = True, **attrs: t.Any,) -> bentoml.Model:
|
||||
|
||||
def import_model(
|
||||
llm: openllm.LLM[t.Any, t.Any], *decls: t.Any, trust_remote_code: bool = True, **attrs: t.Any
|
||||
) -> bentoml.Model:
|
||||
raise NotImplementedError('Currently work in progress.')
|
||||
|
||||
|
||||
def get(llm: openllm.LLM[t.Any, t.Any]) -> bentoml.Model:
|
||||
raise NotImplementedError('Currently work in progress.')
|
||||
|
||||
|
||||
def load_model(llm: openllm.LLM[M, t.Any], *decls: t.Any, **attrs: t.Any) -> M:
|
||||
raise NotImplementedError('Currently work in progress.')
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Serialisation related implementation for Transformers-based implementation."""
|
||||
|
||||
from __future__ import annotations
|
||||
import importlib
|
||||
import logging
|
||||
@@ -27,6 +28,7 @@ from ._helpers import infer_autoclass_from_llm
|
||||
from ._helpers import process_config
|
||||
from .weights import HfIgnore
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import types
|
||||
|
||||
@@ -38,29 +40,52 @@ logger = logging.getLogger(__name__)
|
||||
__all__ = ['import_model', 'get', 'load_model']
|
||||
_object_setattr = object.__setattr__
|
||||
|
||||
def _patch_correct_tag(llm: openllm.LLM[M, T], config: transformers.PretrainedConfig, _revision: str | None = None) -> None:
|
||||
|
||||
def _patch_correct_tag(
|
||||
llm: openllm.LLM[M, T], config: transformers.PretrainedConfig, _revision: str | None = None
|
||||
) -> None:
|
||||
# NOTE: The following won't hit during local since we generated a correct version based on local path hash It will only hit if we use model from HF Hub
|
||||
if llm.revision is not None: return
|
||||
if llm.revision is not None:
|
||||
return
|
||||
if not llm._local:
|
||||
try:
|
||||
if _revision is None: _revision = get_hash(config)
|
||||
if _revision is None:
|
||||
_revision = get_hash(config)
|
||||
except ValueError:
|
||||
pass
|
||||
if _revision is None and llm.tag.version is not None: _revision = llm.tag.version
|
||||
if llm._tag.version is None: _object_setattr(llm, '_tag', attr.evolve(llm.tag, version=_revision)) # HACK: This copies the correct revision into llm.tag
|
||||
if llm._revision is None: _object_setattr(llm, '_revision', _revision) # HACK: This copies the correct revision into llm._model_version
|
||||
if _revision is None and llm.tag.version is not None:
|
||||
_revision = llm.tag.version
|
||||
if llm._tag.version is None:
|
||||
_object_setattr(
|
||||
llm, '_tag', attr.evolve(llm.tag, version=_revision)
|
||||
) # HACK: This copies the correct revision into llm.tag
|
||||
if llm._revision is None:
|
||||
_object_setattr(llm, '_revision', _revision) # HACK: This copies the correct revision into llm._model_version
|
||||
|
||||
|
||||
@inject
|
||||
def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool, _model_store: ModelStore = Provide[BentoMLContainer.model_store], **attrs: t.Any) -> bentoml.Model:
|
||||
def import_model(
|
||||
llm: openllm.LLM[M, T],
|
||||
*decls: t.Any,
|
||||
trust_remote_code: bool,
|
||||
_model_store: ModelStore = Provide[BentoMLContainer.model_store],
|
||||
**attrs: t.Any,
|
||||
) -> bentoml.Model:
|
||||
config, hub_attrs, attrs = process_config(llm.model_id, trust_remote_code, **attrs)
|
||||
_patch_correct_tag(llm, config)
|
||||
_, tokenizer_attrs = llm.llm_parameters
|
||||
quantize = llm._quantise
|
||||
safe_serialisation = openllm.utils.first_not_none(attrs.get('safe_serialization'), default=llm._serialisation == 'safetensors')
|
||||
safe_serialisation = openllm.utils.first_not_none(
|
||||
attrs.get('safe_serialization'), default=llm._serialisation == 'safetensors'
|
||||
)
|
||||
metadata: DictStrAny = {'safe_serialisation': safe_serialisation}
|
||||
if quantize: metadata['_quantize'] = quantize
|
||||
if quantize:
|
||||
metadata['_quantize'] = quantize
|
||||
architectures = getattr(config, 'architectures', [])
|
||||
if not architectures: raise RuntimeError('Failed to determine the architecture for this model. Make sure the `config.json` is valid and can be loaded with `transformers.AutoConfig`')
|
||||
if not architectures:
|
||||
raise RuntimeError(
|
||||
'Failed to determine the architecture for this model. Make sure the `config.json` is valid and can be loaded with `transformers.AutoConfig`'
|
||||
)
|
||||
metadata['_pretrained_class'] = architectures[0]
|
||||
metadata['_revision'] = get_hash(config)
|
||||
|
||||
@@ -69,93 +94,152 @@ def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool,
|
||||
if quantize == 'gptq':
|
||||
if not openllm.utils.is_autogptq_available() or not openllm.utils.is_optimum_supports_gptq():
|
||||
raise openllm.exceptions.OpenLLMException(
|
||||
"GPTQ quantisation requires 'auto-gptq' and 'optimum' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\" --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/'"
|
||||
"GPTQ quantisation requires 'auto-gptq' and 'optimum' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\" --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/'"
|
||||
)
|
||||
signatures['generate'] = {'batchable': False}
|
||||
else:
|
||||
attrs['use_safetensors'] = safe_serialisation
|
||||
metadata['_framework'] = llm.__llm_backend__
|
||||
signatures.update({
|
||||
signatures.update(
|
||||
{
|
||||
k: ModelSignature(batchable=False)
|
||||
for k in ('__call__', 'forward', 'generate', 'contrastive_search', 'greedy_search', 'sample', 'beam_search', 'beam_sample', 'group_beam_search', 'constrained_beam_search')
|
||||
})
|
||||
for k in (
|
||||
'__call__',
|
||||
'forward',
|
||||
'generate',
|
||||
'contrastive_search',
|
||||
'greedy_search',
|
||||
'sample',
|
||||
'beam_search',
|
||||
'beam_sample',
|
||||
'group_beam_search',
|
||||
'constrained_beam_search',
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(llm.model_id, trust_remote_code=trust_remote_code, **hub_attrs, **tokenizer_attrs)
|
||||
if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
||||
llm.model_id, trust_remote_code=trust_remote_code, **hub_attrs, **tokenizer_attrs
|
||||
)
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
model = None
|
||||
external_modules: list[types.ModuleType] = [importlib.import_module(tokenizer.__module__)]
|
||||
imported_modules: list[types.ModuleType] = []
|
||||
bentomodel = bentoml.Model.create(llm.tag,
|
||||
module='openllm.serialisation.transformers',
|
||||
api_version='v2.1.0',
|
||||
options=ModelOptions(),
|
||||
context=openllm.utils.generate_context(framework_name='openllm'),
|
||||
labels=openllm.utils.generate_labels(llm),
|
||||
metadata=metadata,
|
||||
signatures=signatures)
|
||||
bentomodel = bentoml.Model.create(
|
||||
llm.tag,
|
||||
module='openllm.serialisation.transformers',
|
||||
api_version='v2.1.0',
|
||||
options=ModelOptions(),
|
||||
context=openllm.utils.generate_context(framework_name='openllm'),
|
||||
labels=openllm.utils.generate_labels(llm),
|
||||
metadata=metadata,
|
||||
signatures=signatures,
|
||||
)
|
||||
with openllm.utils.analytics.set_bentoml_tracking():
|
||||
try:
|
||||
bentomodel.enter_cloudpickle_context(external_modules, imported_modules)
|
||||
tokenizer.save_pretrained(bentomodel.path)
|
||||
if llm._quantization_config or (llm._quantise and llm._quantise not in {'squeezellm', 'awq'}): attrs['quantization_config'] = llm.quantization_config
|
||||
if llm._quantization_config or (llm._quantise and llm._quantise not in {'squeezellm', 'awq'}):
|
||||
attrs['quantization_config'] = llm.quantization_config
|
||||
if quantize == 'gptq':
|
||||
from optimum.gptq.constants import GPTQ_CONFIG
|
||||
|
||||
with open(bentomodel.path_of(GPTQ_CONFIG), 'w', encoding='utf-8') as f:
|
||||
f.write(orjson.dumps(config.quantization_config, option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS).decode())
|
||||
if llm._local: # possible local path
|
||||
model = infer_autoclass_from_llm(llm, config).from_pretrained(llm.model_id, *decls, config=config, trust_remote_code=trust_remote_code, **hub_attrs, **attrs)
|
||||
model = infer_autoclass_from_llm(llm, config).from_pretrained(
|
||||
llm.model_id, *decls, config=config, trust_remote_code=trust_remote_code, **hub_attrs, **attrs
|
||||
)
|
||||
# for trust_remote_code to work
|
||||
bentomodel.enter_cloudpickle_context([importlib.import_module(model.__module__)], imported_modules)
|
||||
model.save_pretrained(bentomodel.path, max_shard_size='5GB', safe_serialization=safe_serialisation)
|
||||
else:
|
||||
# we will clone the all tings into the bentomodel path without loading model into memory
|
||||
snapshot_download(llm.model_id, local_dir=bentomodel.path, local_dir_use_symlinks=False, ignore_patterns=HfIgnore.ignore_patterns(llm))
|
||||
snapshot_download(
|
||||
llm.model_id,
|
||||
local_dir=bentomodel.path,
|
||||
local_dir_use_symlinks=False,
|
||||
ignore_patterns=HfIgnore.ignore_patterns(llm),
|
||||
)
|
||||
except Exception:
|
||||
raise
|
||||
else:
|
||||
bentomodel.flush() # type: ignore[no-untyped-call]
|
||||
bentomodel.save(_model_store)
|
||||
openllm.utils.analytics.track(openllm.utils.analytics.ModelSaveEvent(module=bentomodel.info.module, model_size_in_kb=openllm.utils.calc_dir_size(bentomodel.path) / 1024))
|
||||
openllm.utils.analytics.track(
|
||||
openllm.utils.analytics.ModelSaveEvent(
|
||||
module=bentomodel.info.module, model_size_in_kb=openllm.utils.calc_dir_size(bentomodel.path) / 1024
|
||||
)
|
||||
)
|
||||
finally:
|
||||
bentomodel.exit_cloudpickle_context(imported_modules)
|
||||
# NOTE: We need to free up the cache after importing the model
|
||||
# in the case where users first run openllm start without the model available locally.
|
||||
if openllm.utils.is_torch_available() and torch.cuda.is_available(): torch.cuda.empty_cache()
|
||||
if openllm.utils.is_torch_available() and torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
del model
|
||||
return bentomodel
|
||||
|
||||
|
||||
def get(llm: openllm.LLM[M, T]) -> bentoml.Model:
|
||||
try:
|
||||
model = bentoml.models.get(llm.tag)
|
||||
backend = model.info.labels['backend']
|
||||
if backend != llm.__llm_backend__: raise openllm.exceptions.OpenLLMException(f"'{model.tag!s}' was saved with backend '{backend}', while loading with '{llm.__llm_backend__}'.")
|
||||
_patch_correct_tag(llm, transformers.AutoConfig.from_pretrained(model.path, trust_remote_code=llm.trust_remote_code), _revision=model.info.metadata.get('_revision'))
|
||||
if backend != llm.__llm_backend__:
|
||||
raise openllm.exceptions.OpenLLMException(
|
||||
f"'{model.tag!s}' was saved with backend '{backend}', while loading with '{llm.__llm_backend__}'."
|
||||
)
|
||||
_patch_correct_tag(
|
||||
llm,
|
||||
transformers.AutoConfig.from_pretrained(model.path, trust_remote_code=llm.trust_remote_code),
|
||||
_revision=model.info.metadata.get('_revision'),
|
||||
)
|
||||
return model
|
||||
except Exception as err:
|
||||
raise openllm.exceptions.OpenLLMException(f'Failed while getting stored artefact (lookup for traceback):\n{err}') from err
|
||||
raise openllm.exceptions.OpenLLMException(
|
||||
f'Failed while getting stored artefact (lookup for traceback):\n{err}'
|
||||
) from err
|
||||
|
||||
|
||||
def load_model(llm: openllm.LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M:
|
||||
if llm._quantise in {'awq', 'squeezellm'}: raise RuntimeError('AWQ is not yet supported with PyTorch backend.')
|
||||
config, attrs = transformers.AutoConfig.from_pretrained(llm.bentomodel.path, return_unused_kwargs=True, trust_remote_code=llm.trust_remote_code, **attrs)
|
||||
if llm._quantise in {'awq', 'squeezellm'}:
|
||||
raise RuntimeError('AWQ is not yet supported with PyTorch backend.')
|
||||
config, attrs = transformers.AutoConfig.from_pretrained(
|
||||
llm.bentomodel.path, return_unused_kwargs=True, trust_remote_code=llm.trust_remote_code, **attrs
|
||||
)
|
||||
auto_class = infer_autoclass_from_llm(llm, config)
|
||||
device_map = attrs.pop('device_map', None)
|
||||
if torch.cuda.is_available():
|
||||
if torch.cuda.device_count() > 1: device_map = 'auto'
|
||||
elif torch.cuda.device_count() == 1: device_map = 'cuda:0'
|
||||
if llm._quantise in {'int8', 'int4'}: attrs['quantization_config'] = llm.quantization_config
|
||||
if torch.cuda.device_count() > 1:
|
||||
device_map = 'auto'
|
||||
elif torch.cuda.device_count() == 1:
|
||||
device_map = 'cuda:0'
|
||||
if llm._quantise in {'int8', 'int4'}:
|
||||
attrs['quantization_config'] = llm.quantization_config
|
||||
|
||||
if '_quantize' in llm.bentomodel.info.metadata:
|
||||
_quantise = llm.bentomodel.info.metadata['_quantize']
|
||||
if _quantise == 'gptq':
|
||||
if not openllm.utils.is_autogptq_available() or not openllm.utils.is_optimum_supports_gptq():
|
||||
raise openllm.exceptions.OpenLLMException(
|
||||
"GPTQ quantisation requires 'auto-gptq' and 'optimum' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\" --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/'"
|
||||
"GPTQ quantisation requires 'auto-gptq' and 'optimum' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\" --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/'"
|
||||
)
|
||||
if llm.config['model_type'] != 'causal_lm':
|
||||
raise openllm.exceptions.OpenLLMException(
|
||||
f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})"
|
||||
)
|
||||
if llm.config['model_type'] != 'causal_lm': raise openllm.exceptions.OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
|
||||
|
||||
# TODO: investigate load with flash attention
|
||||
model = auto_class.from_pretrained(llm.bentomodel.path, device_map=device_map, **attrs)
|
||||
else:
|
||||
model = auto_class.from_pretrained(llm.bentomodel.path, *decls, config=config, trust_remote_code=llm.trust_remote_code, device_map=device_map, **attrs)
|
||||
model = auto_class.from_pretrained(
|
||||
llm.bentomodel.path,
|
||||
*decls,
|
||||
config=config,
|
||||
trust_remote_code=llm.trust_remote_code,
|
||||
device_map=device_map,
|
||||
**attrs,
|
||||
)
|
||||
return t.cast('M', model)
|
||||
|
||||
@@ -10,6 +10,7 @@ import openllm
|
||||
from openllm.serialisation.constants import FRAMEWORK_TO_AUTOCLASS_MAPPING
|
||||
from openllm.serialisation.constants import HUB_ATTRS
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from transformers.models.auto.auto_factory import _BaseAutoModelClass
|
||||
|
||||
@@ -17,12 +18,17 @@ if t.TYPE_CHECKING:
|
||||
from openllm_core._typing_compat import M
|
||||
from openllm_core._typing_compat import T
|
||||
|
||||
|
||||
def get_hash(config: transformers.PretrainedConfig) -> str:
|
||||
_commit_hash = getattr(config, '_commit_hash', None)
|
||||
if _commit_hash is None: raise ValueError(f'Cannot find commit hash in {config}')
|
||||
if _commit_hash is None:
|
||||
raise ValueError(f'Cannot find commit hash in {config}')
|
||||
return _commit_hash
|
||||
|
||||
def process_config(model_id: str, trust_remote_code: bool, **attrs: t.Any) -> tuple[transformers.PretrainedConfig, DictStrAny, DictStrAny]:
|
||||
|
||||
def process_config(
|
||||
model_id: str, trust_remote_code: bool, **attrs: t.Any
|
||||
) -> tuple[transformers.PretrainedConfig, DictStrAny, DictStrAny]:
|
||||
"""A helper function that correctly parse config and attributes for transformers.PretrainedConfig.
|
||||
|
||||
Args:
|
||||
@@ -38,25 +44,36 @@ def process_config(model_id: str, trust_remote_code: bool, **attrs: t.Any) -> tu
|
||||
hub_attrs = {k: attrs.pop(k) for k in HUB_ATTRS if k in attrs}
|
||||
if not isinstance(config, transformers.PretrainedConfig):
|
||||
copied_attrs = copy.deepcopy(attrs)
|
||||
if copied_attrs.get('torch_dtype', None) == 'auto': copied_attrs.pop('torch_dtype')
|
||||
config, attrs = transformers.AutoConfig.from_pretrained(model_id, return_unused_kwargs=True, trust_remote_code=trust_remote_code, **hub_attrs, **copied_attrs)
|
||||
if copied_attrs.get('torch_dtype', None) == 'auto':
|
||||
copied_attrs.pop('torch_dtype')
|
||||
config, attrs = transformers.AutoConfig.from_pretrained(
|
||||
model_id, return_unused_kwargs=True, trust_remote_code=trust_remote_code, **hub_attrs, **copied_attrs
|
||||
)
|
||||
return config, hub_attrs, attrs
|
||||
|
||||
|
||||
def infer_autoclass_from_llm(llm: openllm.LLM[M, T], config: transformers.PretrainedConfig, /) -> _BaseAutoModelClass:
|
||||
if llm.trust_remote_code:
|
||||
autoclass = 'AutoModelForSeq2SeqLM' if llm.config['model_type'] == 'seq2seq_lm' else 'AutoModelForCausalLM'
|
||||
if not hasattr(config, 'auto_map'):
|
||||
raise ValueError(f'Invalid configuration for {llm.model_id}. ``trust_remote_code=True`` requires `transformers.PretrainedConfig` to contain a `auto_map` mapping')
|
||||
raise ValueError(
|
||||
f'Invalid configuration for {llm.model_id}. ``trust_remote_code=True`` requires `transformers.PretrainedConfig` to contain a `auto_map` mapping'
|
||||
)
|
||||
# in case this model doesn't use the correct auto class for model type, for example like chatglm
|
||||
# where it uses AutoModel instead of AutoModelForCausalLM. Then we fallback to AutoModel
|
||||
if autoclass not in config.auto_map: autoclass = 'AutoModel'
|
||||
if autoclass not in config.auto_map:
|
||||
autoclass = 'AutoModel'
|
||||
return getattr(transformers, autoclass)
|
||||
else:
|
||||
if type(config) in transformers.MODEL_FOR_CAUSAL_LM_MAPPING: idx = 0
|
||||
elif type(config) in transformers.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING: idx = 1
|
||||
else: raise openllm.exceptions.OpenLLMException(f'Model type {type(config)} is not supported yet.')
|
||||
if type(config) in transformers.MODEL_FOR_CAUSAL_LM_MAPPING:
|
||||
idx = 0
|
||||
elif type(config) in transformers.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING:
|
||||
idx = 1
|
||||
else:
|
||||
raise openllm.exceptions.OpenLLMException(f'Model type {type(config)} is not supported yet.')
|
||||
return getattr(transformers, FRAMEWORK_TO_AUTOCLASS_MAPPING[llm.__llm_backend__][idx])
|
||||
|
||||
|
||||
def check_unintialised_params(model: torch.nn.Module) -> None:
|
||||
unintialized = [n for n, param in model.named_parameters() if param.data.device == torch.device('meta')]
|
||||
if len(unintialized) > 0:
|
||||
|
||||
@@ -8,6 +8,7 @@ from huggingface_hub import HfApi
|
||||
|
||||
from openllm_core.exceptions import Error
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from huggingface_hub.hf_api import ModelInfo as HfModelInfo
|
||||
|
||||
@@ -19,13 +20,17 @@ if t.TYPE_CHECKING:
|
||||
__global_inst__ = None
|
||||
__cached_id__: dict[str, HfModelInfo] = dict()
|
||||
|
||||
|
||||
def Client() -> HfApi:
|
||||
global __global_inst__ # noqa: PLW0603
|
||||
if __global_inst__ is None: __global_inst__ = HfApi()
|
||||
if __global_inst__ is None:
|
||||
__global_inst__ = HfApi()
|
||||
return __global_inst__
|
||||
|
||||
|
||||
def ModelInfo(model_id: str, revision: str | None = None) -> HfModelInfo:
|
||||
if model_id in __cached_id__: return __cached_id__[model_id]
|
||||
if model_id in __cached_id__:
|
||||
return __cached_id__[model_id]
|
||||
try:
|
||||
__cached_id__[model_id] = Client().model_info(model_id, revision=revision)
|
||||
return __cached_id__[model_id]
|
||||
@@ -33,9 +38,11 @@ def ModelInfo(model_id: str, revision: str | None = None) -> HfModelInfo:
|
||||
traceback.print_exc()
|
||||
raise Error(f'Failed to fetch {model_id} from huggingface.co') from err
|
||||
|
||||
|
||||
def has_safetensors_weights(model_id: str, revision: str | None = None) -> bool:
|
||||
return any(s.rfilename.endswith('.safetensors') for s in ModelInfo(model_id, revision=revision).siblings)
|
||||
|
||||
|
||||
@attr.define(slots=True)
|
||||
class HfIgnore:
|
||||
safetensors = '*.safetensors'
|
||||
@@ -48,9 +55,12 @@ class HfIgnore:
|
||||
def ignore_patterns(cls, llm: openllm.LLM[M, T]) -> list[str]:
|
||||
if llm.__llm_backend__ in {'vllm', 'pt'}:
|
||||
base = [cls.tf, cls.flax, cls.gguf]
|
||||
if has_safetensors_weights(llm.model_id): base.append(cls.pt)
|
||||
else: base.append(cls.safetensors)
|
||||
elif llm.__llm_backend__ == 'ggml': base = [cls.tf, cls.flax, cls.pt, cls.safetensors]
|
||||
if has_safetensors_weights(llm.model_id):
|
||||
base.append(cls.pt)
|
||||
else:
|
||||
base.append(cls.safetensors)
|
||||
elif llm.__llm_backend__ == 'ggml':
|
||||
base = [cls.tf, cls.flax, cls.pt, cls.safetensors]
|
||||
else:
|
||||
raise ValueError('Unknown backend (should never happen at all.)')
|
||||
# filter out these files, since we probably don't need them for now.
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Tests utilities for OpenLLM."""
|
||||
|
||||
from __future__ import annotations
|
||||
import contextlib
|
||||
import logging
|
||||
@@ -9,14 +10,18 @@ import typing as t
|
||||
import bentoml
|
||||
import openllm
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from openllm_core._typing_compat import LiteralBackend
|
||||
from openllm_core._typing_compat import LiteralQuantise
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def build_bento(model: str, model_id: str | None = None, quantize: LiteralQuantise | None = None, cleanup: bool = False) -> t.Iterator[bentoml.Bento]:
|
||||
def build_bento(
|
||||
model: str, model_id: str | None = None, quantize: LiteralQuantise | None = None, cleanup: bool = False
|
||||
) -> t.Iterator[bentoml.Bento]:
|
||||
logger.info('Building BentoML for %s', model)
|
||||
bento = openllm.build(model, model_id=model_id, quantize=quantize)
|
||||
yield bento
|
||||
@@ -24,29 +29,39 @@ def build_bento(model: str, model_id: str | None = None, quantize: LiteralQuanti
|
||||
logger.info('Deleting %s', bento.tag)
|
||||
bentoml.bentos.delete(bento.tag)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def build_container(bento: bentoml.Bento | str | bentoml.Tag, image_tag: str | None = None, cleanup: bool = False, **attrs: t.Any) -> t.Iterator[str]:
|
||||
if isinstance(bento, bentoml.Bento): bento_tag = bento.tag
|
||||
else: bento_tag = bentoml.Tag.from_taglike(bento)
|
||||
if image_tag is None: image_tag = str(bento_tag)
|
||||
def build_container(
|
||||
bento: bentoml.Bento | str | bentoml.Tag, image_tag: str | None = None, cleanup: bool = False, **attrs: t.Any
|
||||
) -> t.Iterator[str]:
|
||||
if isinstance(bento, bentoml.Bento):
|
||||
bento_tag = bento.tag
|
||||
else:
|
||||
bento_tag = bentoml.Tag.from_taglike(bento)
|
||||
if image_tag is None:
|
||||
image_tag = str(bento_tag)
|
||||
executable = shutil.which('docker')
|
||||
if not executable: raise RuntimeError('docker executable not found')
|
||||
if not executable:
|
||||
raise RuntimeError('docker executable not found')
|
||||
try:
|
||||
logger.info('Building container for %s', bento_tag)
|
||||
bentoml.container.build(bento_tag, backend='docker', image_tag=(image_tag,), progress='plain', **attrs,)
|
||||
bentoml.container.build(bento_tag, backend='docker', image_tag=(image_tag,), progress='plain', **attrs)
|
||||
yield image_tag
|
||||
finally:
|
||||
if cleanup:
|
||||
logger.info('Deleting container %s', image_tag)
|
||||
subprocess.check_output([executable, 'rmi', '-f', image_tag])
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def prepare(model: str,
|
||||
model_id: str,
|
||||
backend: LiteralBackend = 'pt',
|
||||
deployment_mode: t.Literal['container', 'local'] = 'local',
|
||||
clean_context: contextlib.ExitStack | None = None,
|
||||
cleanup: bool = True) -> t.Iterator[str]:
|
||||
def prepare(
|
||||
model: str,
|
||||
model_id: str,
|
||||
backend: LiteralBackend = 'pt',
|
||||
deployment_mode: t.Literal['container', 'local'] = 'local',
|
||||
clean_context: contextlib.ExitStack | None = None,
|
||||
cleanup: bool = True,
|
||||
) -> t.Iterator[str]:
|
||||
if clean_context is None:
|
||||
clean_context = contextlib.ExitStack()
|
||||
cleanup = True
|
||||
@@ -60,4 +75,5 @@ def prepare(model: str,
|
||||
if deployment_mode == 'container':
|
||||
container_name = clean_context.enter_context(build_container(bento, image_tag=container_name, cleanup=cleanup))
|
||||
yield container_name
|
||||
if cleanup: clean_context.close()
|
||||
if cleanup:
|
||||
clean_context.close()
|
||||
|
||||
@@ -3,12 +3,14 @@
|
||||
User can import these function for convenience, but
|
||||
we won't ensure backward compatibility for these functions. So use with caution.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import functools
|
||||
import typing as t
|
||||
|
||||
import openllm_core
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import openllm
|
||||
|
||||
@@ -62,23 +64,38 @@ if t.TYPE_CHECKING:
|
||||
from openllm_core.utils import validate_is_path as validate_is_path
|
||||
from openllm_core.utils.serde import converter as converter
|
||||
|
||||
|
||||
def generate_labels(llm: openllm.LLM[t.Any, t.Any]) -> dict[str, t.Any]:
|
||||
return {'backend': llm.__llm_backend__, 'framework': 'openllm', 'model_name': llm.config['model_name'], 'architecture': llm.config['architecture'], 'serialisation': llm._serialisation}
|
||||
return {
|
||||
'backend': llm.__llm_backend__,
|
||||
'framework': 'openllm',
|
||||
'model_name': llm.config['model_name'],
|
||||
'architecture': llm.config['architecture'],
|
||||
'serialisation': llm._serialisation,
|
||||
}
|
||||
|
||||
|
||||
def available_devices() -> tuple[str, ...]:
|
||||
"""Return available GPU under system. Currently only supports NVIDIA GPUs."""
|
||||
from .._strategies import NvidiaGpuResource
|
||||
|
||||
return tuple(NvidiaGpuResource.from_system())
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def device_count() -> int:
|
||||
return len(available_devices())
|
||||
|
||||
|
||||
__all__ = ['generate_labels', 'available_devices', 'device_count']
|
||||
|
||||
|
||||
def __dir__() -> t.Sequence[str]:
|
||||
return sorted(__all__)
|
||||
|
||||
|
||||
def __getattr__(it: str) -> t.Any:
|
||||
if hasattr(openllm_core.utils, it): return getattr(openllm_core.utils, it)
|
||||
else: raise AttributeError(f'module {__name__} has no attribute {it}')
|
||||
if hasattr(openllm_core.utils, it):
|
||||
return getattr(openllm_core.utils, it)
|
||||
else:
|
||||
raise AttributeError(f'module {__name__} has no attribute {it}')
|
||||
|
||||
@@ -4,6 +4,8 @@ import os
|
||||
from hypothesis import HealthCheck
|
||||
from hypothesis import settings
|
||||
|
||||
|
||||
settings.register_profile('CI', settings(suppress_health_check=[HealthCheck.too_slow]), deadline=None)
|
||||
|
||||
if 'CI' in os.environ: settings.load_profile('CI')
|
||||
if 'CI' in os.environ:
|
||||
settings.load_profile('CI')
|
||||
|
||||
@@ -8,30 +8,34 @@ import openllm
|
||||
|
||||
from openllm_core._configuration import ModelSettings
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@st.composite
|
||||
def model_settings(draw: st.DrawFn):
|
||||
"""Strategy for generating ModelSettings objects."""
|
||||
kwargs: dict[str, t.Any] = {
|
||||
'default_id': st.text(min_size=1),
|
||||
'model_ids': st.lists(st.text(), min_size=1),
|
||||
'architecture': st.text(min_size=1),
|
||||
'url': st.text(),
|
||||
'trust_remote_code': st.booleans(),
|
||||
'requirements': st.none() | st.lists(st.text(), min_size=1),
|
||||
'model_type': st.sampled_from(['causal_lm', 'seq2seq_lm']),
|
||||
'name_type': st.sampled_from(['dasherize', 'lowercase']),
|
||||
'timeout': st.integers(min_value=3600),
|
||||
'workers_per_resource': st.one_of(st.integers(min_value=1), st.floats(min_value=0.1, max_value=1.0)),
|
||||
'default_id': st.text(min_size=1),
|
||||
'model_ids': st.lists(st.text(), min_size=1),
|
||||
'architecture': st.text(min_size=1),
|
||||
'url': st.text(),
|
||||
'trust_remote_code': st.booleans(),
|
||||
'requirements': st.none() | st.lists(st.text(), min_size=1),
|
||||
'model_type': st.sampled_from(['causal_lm', 'seq2seq_lm']),
|
||||
'name_type': st.sampled_from(['dasherize', 'lowercase']),
|
||||
'timeout': st.integers(min_value=3600),
|
||||
'workers_per_resource': st.one_of(st.integers(min_value=1), st.floats(min_value=0.1, max_value=1.0)),
|
||||
}
|
||||
return draw(st.builds(ModelSettings, **kwargs))
|
||||
|
||||
def make_llm_config(cls_name: str,
|
||||
dunder_config: dict[str, t.Any] | ModelSettings,
|
||||
fields: tuple[tuple[t.LiteralString, str, t.Any], ...] | None = None,
|
||||
generation_fields: tuple[tuple[t.LiteralString, t.Any], ...] | None = None,
|
||||
) -> type[openllm.LLMConfig]:
|
||||
|
||||
def make_llm_config(
|
||||
cls_name: str,
|
||||
dunder_config: dict[str, t.Any] | ModelSettings,
|
||||
fields: tuple[tuple[t.LiteralString, str, t.Any], ...] | None = None,
|
||||
generation_fields: tuple[tuple[t.LiteralString, t.Any], ...] | None = None,
|
||||
) -> type[openllm.LLMConfig]:
|
||||
globs: dict[str, t.Any] = {'openllm': openllm}
|
||||
_config_args: list[str] = []
|
||||
lines: list[str] = [f'class {cls_name}Config(openllm.LLMConfig):']
|
||||
|
||||
@@ -23,30 +23,44 @@ from openllm_core._configuration import field_env_key
|
||||
from ._strategies._configuration import make_llm_config
|
||||
from ._strategies._configuration import model_settings
|
||||
|
||||
|
||||
# XXX: @aarnphm fixes TypedDict behaviour in 3.11
|
||||
@pytest.mark.skipif(sys.version_info[:2] == (3, 11), reason='TypedDict in 3.11 behaves differently, so we need to fix this')
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info[:2] == (3, 11), reason='TypedDict in 3.11 behaves differently, so we need to fix this'
|
||||
)
|
||||
def test_missing_default():
|
||||
with pytest.raises(ValueError, match='Missing required fields *'):
|
||||
make_llm_config('MissingDefaultId', {'name_type': 'lowercase', 'requirements': ['bentoml']})
|
||||
with pytest.raises(ValueError, match='Missing required fields *'):
|
||||
make_llm_config('MissingModelId', {'default_id': 'huggingface/t5-tiny-testing', 'requirements': ['bentoml']})
|
||||
with pytest.raises(ValueError, match='Missing required fields *'):
|
||||
make_llm_config('MissingArchitecture', {'default_id': 'huggingface/t5-tiny-testing', 'model_ids': ['huggingface/t5-tiny-testing'], 'requirements': ['bentoml'],},)
|
||||
make_llm_config(
|
||||
'MissingArchitecture',
|
||||
{
|
||||
'default_id': 'huggingface/t5-tiny-testing',
|
||||
'model_ids': ['huggingface/t5-tiny-testing'],
|
||||
'requirements': ['bentoml'],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def test_forbidden_access():
|
||||
cl_ = make_llm_config('ForbiddenAccess', {
|
||||
cl_ = make_llm_config(
|
||||
'ForbiddenAccess',
|
||||
{
|
||||
'default_id': 'huggingface/t5-tiny-testing',
|
||||
'model_ids': ['huggingface/t5-tiny-testing', 'bentoml/t5-tiny-testing'],
|
||||
'architecture': 'PreTrainedModel',
|
||||
'requirements': ['bentoml'],
|
||||
},
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
assert pytest.raises(openllm.exceptions.ForbiddenAttributeError, cl_.__getattribute__, cl_(), '__config__',)
|
||||
assert pytest.raises(openllm.exceptions.ForbiddenAttributeError, cl_.__getattribute__, cl_(), 'GenerationConfig',)
|
||||
assert pytest.raises(openllm.exceptions.ForbiddenAttributeError, cl_.__getattribute__, cl_(), 'SamplingParams',)
|
||||
assert pytest.raises(openllm.exceptions.ForbiddenAttributeError, cl_.__getattribute__, cl_(), '__config__')
|
||||
assert pytest.raises(openllm.exceptions.ForbiddenAttributeError, cl_.__getattribute__, cl_(), 'GenerationConfig')
|
||||
assert pytest.raises(openllm.exceptions.ForbiddenAttributeError, cl_.__getattribute__, cl_(), 'SamplingParams')
|
||||
assert openllm.utils.lenient_issubclass(cl_.__openllm_generation_class__, GenerationConfig)
|
||||
|
||||
|
||||
@given(model_settings())
|
||||
def test_class_normal_gen(gen_settings: ModelSettings):
|
||||
assume(gen_settings['default_id'] and all(i for i in gen_settings['model_ids']))
|
||||
@@ -55,25 +69,42 @@ def test_class_normal_gen(gen_settings: ModelSettings):
|
||||
for key in gen_settings:
|
||||
assert object.__getattribute__(cl_, f'__openllm_{key}__') == gen_settings.__getitem__(key)
|
||||
|
||||
|
||||
@given(model_settings(), st.integers())
|
||||
def test_simple_struct_dump(gen_settings: ModelSettings, field1: int):
|
||||
cl_ = make_llm_config('IdempotentLLM', gen_settings, fields=(('field1', 'float', field1),))
|
||||
assert cl_().model_dump()['field1'] == field1
|
||||
|
||||
|
||||
@given(model_settings(), st.integers())
|
||||
def test_config_derivation(gen_settings: ModelSettings, field1: int):
|
||||
cl_ = make_llm_config('IdempotentLLM', gen_settings, fields=(('field1', 'float', field1),))
|
||||
new_cls = cl_.model_derivate('DerivedLLM', default_id='asdfasdf')
|
||||
assert new_cls.__openllm_default_id__ == 'asdfasdf'
|
||||
|
||||
|
||||
@given(model_settings())
|
||||
def test_config_derived_follow_attrs_protocol(gen_settings: ModelSettings):
|
||||
cl_ = make_llm_config('AttrsProtocolLLM', gen_settings)
|
||||
assert attr.has(cl_)
|
||||
|
||||
@given(model_settings(), st.integers(max_value=283473), st.floats(min_value=0.0, max_value=1.0), st.integers(max_value=283473), st.floats(min_value=0.0, max_value=1.0),)
|
||||
def test_complex_struct_dump(gen_settings: ModelSettings, field1: int, temperature: float, input_field1: int, input_temperature: float):
|
||||
cl_ = make_llm_config('ComplexLLM', gen_settings, fields=(('field1', 'float', field1),), generation_fields=(('temperature', temperature),),)
|
||||
|
||||
@given(
|
||||
model_settings(),
|
||||
st.integers(max_value=283473),
|
||||
st.floats(min_value=0.0, max_value=1.0),
|
||||
st.integers(max_value=283473),
|
||||
st.floats(min_value=0.0, max_value=1.0),
|
||||
)
|
||||
def test_complex_struct_dump(
|
||||
gen_settings: ModelSettings, field1: int, temperature: float, input_field1: int, input_temperature: float
|
||||
):
|
||||
cl_ = make_llm_config(
|
||||
'ComplexLLM',
|
||||
gen_settings,
|
||||
fields=(('field1', 'float', field1),),
|
||||
generation_fields=(('temperature', temperature),),
|
||||
)
|
||||
sent = cl_()
|
||||
assert sent.model_dump()['field1'] == field1
|
||||
assert sent.model_dump()['generation_config']['temperature'] == temperature
|
||||
@@ -90,16 +121,18 @@ def test_complex_struct_dump(gen_settings: ModelSettings, field1: int, temperatu
|
||||
assert pas_nested.model_dump()['field1'] == input_field1
|
||||
assert pas_nested.model_dump()['generation_config']['temperature'] == input_temperature
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def patch_env(**attrs: t.Any):
|
||||
with mock.patch.dict(os.environ, attrs, clear=True):
|
||||
yield
|
||||
|
||||
|
||||
def test_struct_envvar():
|
||||
with patch_env(**{field_env_key('field1'): '4', field_env_key('temperature', suffix='generation'): '0.2',}):
|
||||
with patch_env(**{field_env_key('field1'): '4', field_env_key('temperature', suffix='generation'): '0.2'}):
|
||||
|
||||
class EnvLLM(openllm.LLMConfig):
|
||||
__config__ = {'default_id': 'asdfasdf', 'model_ids': ['asdf', 'asdfasdfads'], 'architecture': 'PreTrainedModel',}
|
||||
__config__ = {'default_id': 'asdfasdf', 'model_ids': ['asdf', 'asdfasdfads'], 'architecture': 'PreTrainedModel'}
|
||||
field1: int = 2
|
||||
|
||||
class GenerationConfig:
|
||||
@@ -113,9 +146,10 @@ def test_struct_envvar():
|
||||
assert overwrite_default.field1 == 4
|
||||
assert overwrite_default['temperature'] == 0.2
|
||||
|
||||
|
||||
def test_struct_provided_fields():
|
||||
class EnvLLM(openllm.LLMConfig):
|
||||
__config__ = {'default_id': 'asdfasdf', 'model_ids': ['asdf', 'asdfasdfads'], 'architecture': 'PreTrainedModel',}
|
||||
__config__ = {'default_id': 'asdfasdf', 'model_ids': ['asdf', 'asdfasdfads'], 'architecture': 'PreTrainedModel'}
|
||||
field1: int = 2
|
||||
|
||||
class GenerationConfig:
|
||||
@@ -125,26 +159,27 @@ def test_struct_provided_fields():
|
||||
assert sent.field1 == 20
|
||||
assert sent.generation_config.temperature == 0.4
|
||||
|
||||
|
||||
def test_struct_envvar_with_overwrite_provided_env(monkeypatch: pytest.MonkeyPatch):
|
||||
with monkeypatch.context() as mk:
|
||||
mk.setenv(field_env_key('field1'), str(4.0))
|
||||
mk.setenv(field_env_key('temperature', suffix='generation'), str(0.2))
|
||||
sent = make_llm_config('OverwriteWithEnvAvailable', {
|
||||
'default_id': 'asdfasdf',
|
||||
'model_ids': ['asdf', 'asdfasdfads'],
|
||||
'architecture': 'PreTrainedModel'
|
||||
},
|
||||
fields=(('field1', 'float', 3.0),),
|
||||
).model_construct_env(field1=20.0, temperature=0.4)
|
||||
sent = make_llm_config(
|
||||
'OverwriteWithEnvAvailable',
|
||||
{'default_id': 'asdfasdf', 'model_ids': ['asdf', 'asdfasdfads'], 'architecture': 'PreTrainedModel'},
|
||||
fields=(('field1', 'float', 3.0),),
|
||||
).model_construct_env(field1=20.0, temperature=0.4)
|
||||
assert sent.generation_config.temperature == 0.4
|
||||
assert sent.field1 == 20.0
|
||||
|
||||
|
||||
@given(model_settings())
|
||||
@pytest.mark.parametrize(('return_dict', 'typ'), [(True, dict), (False, transformers.GenerationConfig)])
|
||||
def test_conversion_to_transformers(return_dict: bool, typ: type[t.Any], gen_settings: ModelSettings):
|
||||
cl_ = make_llm_config('ConversionLLM', gen_settings)
|
||||
assert isinstance(cl_().to_generation_config(return_as_dict=return_dict), typ)
|
||||
|
||||
|
||||
@given(model_settings())
|
||||
def test_click_conversion(gen_settings: ModelSettings):
|
||||
# currently our conversion omit Union type.
|
||||
@@ -157,6 +192,7 @@ def test_click_conversion(gen_settings: ModelSettings):
|
||||
click_options_filtered = [i for i in wrapped.__click_params__ if i.name and not i.name.startswith('fake_')]
|
||||
assert len(filtered) == len(click_options_filtered)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('model_name', openllm.CONFIG_MAPPING.keys())
|
||||
def test_configuration_dict_protocol(model_name: str):
|
||||
config = openllm.AutoConfig.for_model(model_name)
|
||||
|
||||
@@ -7,23 +7,44 @@ import pytest
|
||||
|
||||
import openllm
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from openllm_core._typing_compat import LiteralBackend
|
||||
|
||||
_MODELING_MAPPING = {'flan_t5': 'google/flan-t5-small', 'opt': 'facebook/opt-125m', 'baichuan': 'baichuan-inc/Baichuan-7B'}
|
||||
_PROMPT_MAPPING = {'qa': 'Answer the following yes/no question by reasoning step-by-step. Can you write a whole Haiku in a single tweet?'}
|
||||
_MODELING_MAPPING = {
|
||||
'flan_t5': 'google/flan-t5-small',
|
||||
'opt': 'facebook/opt-125m',
|
||||
'baichuan': 'baichuan-inc/Baichuan-7B',
|
||||
}
|
||||
_PROMPT_MAPPING = {
|
||||
'qa': 'Answer the following yes/no question by reasoning step-by-step. Can you write a whole Haiku in a single tweet?'
|
||||
}
|
||||
|
||||
def parametrise_local_llm(model: str) -> t.Generator[tuple[str, openllm.LLMRunner[t.Any, t.Any] | openllm.LLM[t.Any, t.Any]], None, None]:
|
||||
if model not in _MODELING_MAPPING: pytest.skip(f"'{model}' is not yet supported in framework testing.")
|
||||
|
||||
def parametrise_local_llm(
|
||||
model: str
|
||||
) -> t.Generator[tuple[str, openllm.LLMRunner[t.Any, t.Any] | openllm.LLM[t.Any, t.Any]], None, None]:
|
||||
if model not in _MODELING_MAPPING:
|
||||
pytest.skip(f"'{model}' is not yet supported in framework testing.")
|
||||
backends: tuple[LiteralBackend, ...] = ('pt',)
|
||||
for backend, prompt in itertools.product(backends, _PROMPT_MAPPING.keys()):
|
||||
yield prompt, openllm.Runner(model, model_id=_MODELING_MAPPING[model], ensure_available=True, backend=backend, init_local=True)
|
||||
yield (
|
||||
prompt,
|
||||
openllm.Runner(
|
||||
model, model_id=_MODELING_MAPPING[model], ensure_available=True, backend=backend, init_local=True
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def pytest_generate_tests(metafunc: pytest.Metafunc) -> None:
|
||||
if os.getenv('GITHUB_ACTIONS') is None:
|
||||
if 'prompt' in metafunc.fixturenames and 'llm' in metafunc.fixturenames:
|
||||
metafunc.parametrize('prompt,llm', [(p, llm) for p, llm in parametrise_local_llm(metafunc.function.__name__[5:-15])])
|
||||
metafunc.parametrize(
|
||||
'prompt,llm', [(p, llm) for p, llm in parametrise_local_llm(metafunc.function.__name__[5:-15])]
|
||||
)
|
||||
|
||||
|
||||
def pytest_sessionfinish(session: pytest.Session, exitstatus: int):
|
||||
# If no tests are collected, pytest exists with code 5, which makes the CI fail.
|
||||
if exitstatus == 5: session.exitstatus = 0
|
||||
if exitstatus == 5:
|
||||
session.exitstatus = 0
|
||||
|
||||
@@ -27,6 +27,7 @@ from openllm_core._typing_compat import DictStrAny
|
||||
from openllm_core._typing_compat import ListAny
|
||||
from openllm_core._typing_compat import LiteralQuantise
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
@@ -40,8 +41,11 @@ if t.TYPE_CHECKING:
|
||||
|
||||
from openllm.client import BaseAsyncClient
|
||||
|
||||
|
||||
class ResponseComparator(JSONSnapshotExtension):
|
||||
def serialize(self, data: SerializableData, *, exclude: PropertyFilter | None = None, matcher: PropertyMatcher | None = None,) -> SerializedData:
|
||||
def serialize(
|
||||
self, data: SerializableData, *, exclude: PropertyFilter | None = None, matcher: PropertyMatcher | None = None
|
||||
) -> SerializedData:
|
||||
if LazyType(ListAny).isinstance(data):
|
||||
data = [d.unmarshaled for d in data]
|
||||
else:
|
||||
@@ -73,12 +77,16 @@ class ResponseComparator(JSONSnapshotExtension):
|
||||
def eq_output(s: openllm.GenerationOutput, t: openllm.GenerationOutput) -> bool:
|
||||
return len(s.outputs) == len(t.outputs)
|
||||
|
||||
return len(serialized_data) == len(snapshot_data) and all([eq_output(s, t) for s, t in zip(serialized_data, snapshot_data)])
|
||||
return len(serialized_data) == len(snapshot_data) and all(
|
||||
[eq_output(s, t) for s, t in zip(serialized_data, snapshot_data)]
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def response_snapshot(snapshot: SnapshotAssertion):
|
||||
return snapshot.use_extension(ResponseComparator)
|
||||
|
||||
|
||||
@attr.define(init=False)
|
||||
class _Handle(ABC):
|
||||
port: int
|
||||
@@ -88,8 +96,7 @@ class _Handle(ABC):
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
|
||||
def __attrs_init__(self, *args: t.Any, **attrs: t.Any):
|
||||
...
|
||||
def __attrs_init__(self, *args: t.Any, **attrs: t.Any): ...
|
||||
|
||||
def __attrs_post_init__(self):
|
||||
self.client = openllm.client.AsyncHTTPClient(f'http://localhost:{self.port}')
|
||||
@@ -111,42 +118,65 @@ class _Handle(ABC):
|
||||
time.sleep(1)
|
||||
raise RuntimeError(f'Handle failed to initialise within {timeout} seconds.')
|
||||
|
||||
|
||||
@attr.define(init=False)
|
||||
class LocalHandle(_Handle):
|
||||
process: subprocess.Popen[bytes]
|
||||
|
||||
def __init__(self, process: subprocess.Popen[bytes], port: int, deployment_mode: t.Literal['container', 'local'],):
|
||||
def __init__(self, process: subprocess.Popen[bytes], port: int, deployment_mode: t.Literal['container', 'local']):
|
||||
self.__attrs_init__(port, deployment_mode, process)
|
||||
|
||||
def status(self) -> bool:
|
||||
return self.process.poll() is None
|
||||
|
||||
|
||||
class HandleProtocol(t.Protocol):
|
||||
@contextlib.contextmanager
|
||||
def __call__(*, model: str, model_id: str, image_tag: str, quantize: t.AnyStr | None = None,) -> t.Generator[_Handle, None, None]:
|
||||
...
|
||||
def __call__(
|
||||
*, model: str, model_id: str, image_tag: str, quantize: t.AnyStr | None = None
|
||||
) -> t.Generator[_Handle, None, None]: ...
|
||||
|
||||
|
||||
@attr.define(init=False)
|
||||
class DockerHandle(_Handle):
|
||||
container_name: str
|
||||
docker_client: docker.DockerClient
|
||||
|
||||
def __init__(self, docker_client: docker.DockerClient, container_name: str, port: int, deployment_mode: t.Literal['container', 'local'],):
|
||||
def __init__(
|
||||
self,
|
||||
docker_client: docker.DockerClient,
|
||||
container_name: str,
|
||||
port: int,
|
||||
deployment_mode: t.Literal['container', 'local'],
|
||||
):
|
||||
self.__attrs_init__(port, deployment_mode, container_name, docker_client)
|
||||
|
||||
def status(self) -> bool:
|
||||
container = self.docker_client.containers.get(self.container_name)
|
||||
return container.status in ['running', 'created']
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _local_handle(model: str, model_id: str, image_tag: str, deployment_mode: t.Literal['container', 'local'], quantize: LiteralQuantise | None = None, *, _serve_grpc: bool = False):
|
||||
def _local_handle(
|
||||
model: str,
|
||||
model_id: str,
|
||||
image_tag: str,
|
||||
deployment_mode: t.Literal['container', 'local'],
|
||||
quantize: LiteralQuantise | None = None,
|
||||
*,
|
||||
_serve_grpc: bool = False,
|
||||
):
|
||||
with openllm.utils.reserve_free_port() as port:
|
||||
pass
|
||||
|
||||
if not _serve_grpc:
|
||||
proc = openllm.start(model, model_id=model_id, quantize=quantize, additional_args=['--port', str(port)], __test__=True)
|
||||
proc = openllm.start(
|
||||
model, model_id=model_id, quantize=quantize, additional_args=['--port', str(port)], __test__=True
|
||||
)
|
||||
else:
|
||||
proc = openllm.start_grpc(model, model_id=model_id, quantize=quantize, additional_args=['--port', str(port)], __test__=True)
|
||||
proc = openllm.start_grpc(
|
||||
model, model_id=model_id, quantize=quantize, additional_args=['--port', str(port)], __test__=True
|
||||
)
|
||||
|
||||
yield LocalHandle(proc, port, deployment_mode)
|
||||
proc.terminate()
|
||||
@@ -159,8 +189,17 @@ def _local_handle(model: str, model_id: str, image_tag: str, deployment_mode: t.
|
||||
if proc.stderr:
|
||||
proc.stderr.close()
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _container_handle(model: str, model_id: str, image_tag: str, deployment_mode: t.Literal['container', 'local'], quantize: LiteralQuantise | None = None, *, _serve_grpc: bool = False):
|
||||
def _container_handle(
|
||||
model: str,
|
||||
model_id: str,
|
||||
image_tag: str,
|
||||
deployment_mode: t.Literal['container', 'local'],
|
||||
quantize: LiteralQuantise | None = None,
|
||||
*,
|
||||
_serve_grpc: bool = False,
|
||||
):
|
||||
with openllm.utils.reserve_free_port() as port, openllm.utils.reserve_free_port() as prom_port:
|
||||
pass
|
||||
container_name = f'openllm-{model}-{self(model_id)}'.replace('-', '_')
|
||||
@@ -177,22 +216,22 @@ def _container_handle(model: str, model_id: str, image_tag: str, deployment_mode
|
||||
|
||||
env: DictStrAny = {}
|
||||
|
||||
if quantize is not None: env['OPENLLM_QUANTIZE'] = quantize
|
||||
if quantize is not None:
|
||||
env['OPENLLM_QUANTIZE'] = quantize
|
||||
|
||||
gpus = openllm.utils.device_count() or -1
|
||||
devs = [docker.types.DeviceRequest(count=gpus, capabilities=[['gpu']])] if gpus > 0 else None
|
||||
|
||||
container = client.containers.run(image_tag,
|
||||
command=args,
|
||||
name=container_name,
|
||||
environment=env,
|
||||
auto_remove=False,
|
||||
detach=True,
|
||||
device_requests=devs,
|
||||
ports={
|
||||
'3000/tcp': port,
|
||||
'3001/tcp': prom_port
|
||||
})
|
||||
container = client.containers.run(
|
||||
image_tag,
|
||||
command=args,
|
||||
name=container_name,
|
||||
environment=env,
|
||||
auto_remove=False,
|
||||
detach=True,
|
||||
device_requests=devs,
|
||||
ports={'3000/tcp': port, '3001/tcp': prom_port},
|
||||
)
|
||||
|
||||
yield DockerHandle(client, container.name, port, deployment_mode)
|
||||
|
||||
@@ -207,22 +246,26 @@ def _container_handle(model: str, model_id: str, image_tag: str, deployment_mode
|
||||
|
||||
container.remove()
|
||||
|
||||
|
||||
@pytest.fixture(scope='session', autouse=True)
|
||||
def clean_context() -> t.Generator[contextlib.ExitStack, None, None]:
|
||||
stack = contextlib.ExitStack()
|
||||
yield stack
|
||||
stack.close()
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
def el() -> t.Generator[asyncio.AbstractEventLoop, None, None]:
|
||||
loop = asyncio.get_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
|
||||
@pytest.fixture(params=['container', 'local'], scope='session')
|
||||
def deployment_mode(request: pytest.FixtureRequest) -> str:
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
def handler(el: asyncio.AbstractEventLoop, deployment_mode: t.Literal['container', 'local']):
|
||||
if deployment_mode == 'container':
|
||||
|
||||
@@ -5,6 +5,7 @@ import pytest
|
||||
|
||||
import openllm
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import contextlib
|
||||
|
||||
@@ -15,17 +16,24 @@ if t.TYPE_CHECKING:
|
||||
model = 'flan_t5'
|
||||
model_id = 'google/flan-t5-small'
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
def flan_t5_handle(handler: HandleProtocol, deployment_mode: t.Literal['container', 'local'], clean_context: contextlib.ExitStack,):
|
||||
with openllm.testing.prepare(model, model_id=model_id, deployment_mode=deployment_mode, clean_context=clean_context) as image_tag:
|
||||
def flan_t5_handle(
|
||||
handler: HandleProtocol, deployment_mode: t.Literal['container', 'local'], clean_context: contextlib.ExitStack
|
||||
):
|
||||
with openllm.testing.prepare(
|
||||
model, model_id=model_id, deployment_mode=deployment_mode, clean_context=clean_context
|
||||
) as image_tag:
|
||||
with handler(model=model, model_id=model_id, image_tag=image_tag) as handle:
|
||||
yield handle
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
async def flan_t5(flan_t5_handle: _Handle):
|
||||
await flan_t5_handle.health(240)
|
||||
return flan_t5_handle.client
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_flan_t5(flan_t5: t.Awaitable[openllm.client.AsyncHTTPClient], response_snapshot: ResponseComparator):
|
||||
client = await flan_t5
|
||||
|
||||
@@ -5,6 +5,7 @@ import pytest
|
||||
|
||||
import openllm
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import contextlib
|
||||
|
||||
@@ -15,17 +16,24 @@ if t.TYPE_CHECKING:
|
||||
model = 'opt'
|
||||
model_id = 'facebook/opt-125m'
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
def opt_125m_handle(handler: HandleProtocol, deployment_mode: t.Literal['container', 'local'], clean_context: contextlib.ExitStack,):
|
||||
with openllm.testing.prepare(model, model_id=model_id, deployment_mode=deployment_mode, clean_context=clean_context) as image_tag:
|
||||
def opt_125m_handle(
|
||||
handler: HandleProtocol, deployment_mode: t.Literal['container', 'local'], clean_context: contextlib.ExitStack
|
||||
):
|
||||
with openllm.testing.prepare(
|
||||
model, model_id=model_id, deployment_mode=deployment_mode, clean_context=clean_context
|
||||
) as image_tag:
|
||||
with handler(model=model, model_id=model_id, image_tag=image_tag) as handle:
|
||||
yield handle
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
async def opt_125m(opt_125m_handle: _Handle):
|
||||
await opt_125m_handle.health(240)
|
||||
return opt_125m_handle.client
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_opt_125m(opt_125m: t.Awaitable[openllm.client.AsyncHTTPClient], response_snapshot: ResponseComparator):
|
||||
client = await opt_125m
|
||||
|
||||
@@ -4,21 +4,25 @@ import typing as t
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import openllm
|
||||
|
||||
|
||||
@pytest.mark.skipif(os.getenv('GITHUB_ACTIONS') is not None, reason='Model is too large for CI')
|
||||
def test_flan_t5_implementation(prompt: str, llm: openllm.LLM[t.Any, t.Any]):
|
||||
assert llm(prompt)
|
||||
|
||||
assert llm(prompt, temperature=0.8, top_p=0.23)
|
||||
|
||||
|
||||
@pytest.mark.skipif(os.getenv('GITHUB_ACTIONS') is not None, reason='Model is too large for CI')
|
||||
def test_opt_implementation(prompt: str, llm: openllm.LLM[t.Any, t.Any]):
|
||||
assert llm(prompt)
|
||||
|
||||
assert llm(prompt, temperature=0.9, top_k=8)
|
||||
|
||||
|
||||
@pytest.mark.skipif(os.getenv('GITHUB_ACTIONS') is not None, reason='Model is too large for CI')
|
||||
def test_baichuan_implementation(prompt: str, llm: openllm.LLM[t.Any, t.Any]):
|
||||
assert llm(prompt)
|
||||
|
||||
@@ -9,15 +9,18 @@ import openllm
|
||||
|
||||
from bentoml._internal.configuration.containers import BentoMLContainer
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
HF_INTERNAL_T5_TESTING = 'hf-internal-testing/tiny-random-t5'
|
||||
|
||||
actions_xfail = functools.partial(pytest.mark.xfail,
|
||||
condition=os.getenv('GITHUB_ACTIONS') is not None,
|
||||
reason='Marking GitHub Actions to xfail due to flakiness and building environment not isolated.',
|
||||
)
|
||||
actions_xfail = functools.partial(
|
||||
pytest.mark.xfail,
|
||||
condition=os.getenv('GITHUB_ACTIONS') is not None,
|
||||
reason='Marking GitHub Actions to xfail due to flakiness and building environment not isolated.',
|
||||
)
|
||||
|
||||
|
||||
@actions_xfail
|
||||
def test_general_build_with_internal_testing():
|
||||
@@ -32,6 +35,7 @@ def test_general_build_with_internal_testing():
|
||||
bento = openllm.build('flan-t5', model_id=HF_INTERNAL_T5_TESTING)
|
||||
assert len(bento_store.list(bento.tag)) == 1
|
||||
|
||||
|
||||
@actions_xfail
|
||||
def test_general_build_from_local(tmp_path_factory: pytest.TempPathFactory):
|
||||
local_path = tmp_path_factory.mktemp('local_t5')
|
||||
@@ -42,12 +46,16 @@ def test_general_build_from_local(tmp_path_factory: pytest.TempPathFactory):
|
||||
|
||||
assert openllm.build('flan-t5', model_id=local_path.resolve().__fspath__(), model_version='local')
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def dockerfile_template(tmp_path_factory: pytest.TempPathFactory):
|
||||
file = tmp_path_factory.mktemp('dockerfiles') / 'Dockerfile.template'
|
||||
file.write_text("{% extends bento_base_template %}\n{% block SETUP_BENTO_ENTRYPOINT %}\n{{ super() }}\nRUN echo 'sanity from custom dockerfile'\n{% endblock %}")
|
||||
file.write_text(
|
||||
"{% extends bento_base_template %}\n{% block SETUP_BENTO_ENTRYPOINT %}\n{{ super() }}\nRUN echo 'sanity from custom dockerfile'\n{% endblock %}"
|
||||
)
|
||||
return file
|
||||
|
||||
|
||||
@pytest.mark.usefixtures('dockerfile_template')
|
||||
@actions_xfail
|
||||
def test_build_with_custom_dockerfile(dockerfile_template: Path):
|
||||
|
||||
@@ -11,9 +11,11 @@ from openllm._strategies import CascadingResourceStrategy
|
||||
from openllm._strategies import NvidiaGpuResource
|
||||
from openllm._strategies import get_resource
|
||||
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
|
||||
|
||||
def test_nvidia_gpu_resource_from_env(monkeypatch: pytest.MonkeyPatch):
|
||||
with monkeypatch.context() as mcls:
|
||||
mcls.setenv('CUDA_VISIBLE_DEVICES', '0,1')
|
||||
@@ -22,6 +24,7 @@ def test_nvidia_gpu_resource_from_env(monkeypatch: pytest.MonkeyPatch):
|
||||
assert resource == ['0', '1']
|
||||
mcls.delenv('CUDA_VISIBLE_DEVICES')
|
||||
|
||||
|
||||
def test_nvidia_gpu_cutoff_minus(monkeypatch: pytest.MonkeyPatch):
|
||||
with monkeypatch.context() as mcls:
|
||||
mcls.setenv('CUDA_VISIBLE_DEVICES', '0,2,-1,1')
|
||||
@@ -30,6 +33,7 @@ def test_nvidia_gpu_cutoff_minus(monkeypatch: pytest.MonkeyPatch):
|
||||
assert resource == ['0', '2']
|
||||
mcls.delenv('CUDA_VISIBLE_DEVICES')
|
||||
|
||||
|
||||
def test_nvidia_gpu_neg_val(monkeypatch: pytest.MonkeyPatch):
|
||||
with monkeypatch.context() as mcls:
|
||||
mcls.setenv('CUDA_VISIBLE_DEVICES', '-1')
|
||||
@@ -38,6 +42,7 @@ def test_nvidia_gpu_neg_val(monkeypatch: pytest.MonkeyPatch):
|
||||
assert resource == []
|
||||
mcls.delenv('CUDA_VISIBLE_DEVICES')
|
||||
|
||||
|
||||
def test_nvidia_gpu_parse_literal(monkeypatch: pytest.MonkeyPatch):
|
||||
with monkeypatch.context() as mcls:
|
||||
mcls.setenv('CUDA_VISIBLE_DEVICES', 'GPU-5ebe9f43-ac33420d4628')
|
||||
@@ -64,6 +69,7 @@ def test_nvidia_gpu_parse_literal(monkeypatch: pytest.MonkeyPatch):
|
||||
assert resource == ['MIG-GPU-5ebe9f43-ac33420d4628']
|
||||
mcls.delenv('CUDA_VISIBLE_DEVICES')
|
||||
|
||||
|
||||
@pytest.mark.skipif(os.getenv('GITHUB_ACTIONS') is not None, reason='skip GPUs test on CI')
|
||||
def test_nvidia_gpu_validate(monkeypatch: pytest.MonkeyPatch):
|
||||
with monkeypatch.context() as mcls:
|
||||
@@ -71,9 +77,14 @@ def test_nvidia_gpu_validate(monkeypatch: pytest.MonkeyPatch):
|
||||
mcls.setenv('CUDA_VISIBLE_DEVICES', '')
|
||||
assert len(NvidiaGpuResource.from_system()) >= 0 # TODO: real from_system tests
|
||||
|
||||
assert pytest.raises(ValueError, NvidiaGpuResource.validate, [*NvidiaGpuResource.from_system(), 1],).match('Input list should be all string type.')
|
||||
assert pytest.raises(ValueError, NvidiaGpuResource.validate, [*NvidiaGpuResource.from_system(), 1]).match(
|
||||
'Input list should be all string type.'
|
||||
)
|
||||
assert pytest.raises(ValueError, NvidiaGpuResource.validate, [-2]).match('Input list should be all string type.')
|
||||
assert pytest.raises(ValueError, NvidiaGpuResource.validate, ['GPU-5ebe9f43', 'GPU-ac33420d4628']).match('Failed to parse available GPUs UUID')
|
||||
assert pytest.raises(ValueError, NvidiaGpuResource.validate, ['GPU-5ebe9f43', 'GPU-ac33420d4628']).match(
|
||||
'Failed to parse available GPUs UUID'
|
||||
)
|
||||
|
||||
|
||||
def test_nvidia_gpu_from_spec(monkeypatch: pytest.MonkeyPatch):
|
||||
with monkeypatch.context() as mcls:
|
||||
@@ -102,12 +113,15 @@ def test_nvidia_gpu_from_spec(monkeypatch: pytest.MonkeyPatch):
|
||||
with pytest.raises(ValueError):
|
||||
assert NvidiaGpuResource.from_spec(-2)
|
||||
|
||||
|
||||
class GPURunnable(bentoml.Runnable):
|
||||
SUPPORTED_RESOURCES = ('nvidia.com/gpu', 'amd.com/gpu')
|
||||
|
||||
|
||||
def unvalidated_get_resource(x: dict[str, t.Any], y: str, validate: bool = False):
|
||||
return get_resource(x, y, validate=validate)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('gpu_type', ['nvidia.com/gpu', 'amd.com/gpu'])
|
||||
def test_cascade_strategy_worker_count(monkeypatch: MonkeyPatch, gpu_type: str):
|
||||
monkeypatch.setattr(strategy, 'get_resource', unvalidated_get_resource)
|
||||
@@ -119,6 +133,7 @@ def test_cascade_strategy_worker_count(monkeypatch: MonkeyPatch, gpu_type: str):
|
||||
assert CascadingResourceStrategy.get_worker_count(GPURunnable, {gpu_type: [2, 7, 8, 9]}, 0.5) == 1
|
||||
assert CascadingResourceStrategy.get_worker_count(GPURunnable, {gpu_type: [2, 5, 7, 8, 9]}, 0.4) == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize('gpu_type', ['nvidia.com/gpu', 'amd.com/gpu'])
|
||||
def test_cascade_strategy_worker_env(monkeypatch: MonkeyPatch, gpu_type: str):
|
||||
monkeypatch.setattr(strategy, 'get_resource', unvalidated_get_resource)
|
||||
@@ -158,6 +173,7 @@ def test_cascade_strategy_worker_env(monkeypatch: MonkeyPatch, gpu_type: str):
|
||||
envs = CascadingResourceStrategy.get_worker_env(GPURunnable, {gpu_type: [2, 6, 7, 8, 9]}, 0.4, 2)
|
||||
assert envs.get('CUDA_VISIBLE_DEVICES') == '9'
|
||||
|
||||
|
||||
@pytest.mark.parametrize('gpu_type', ['nvidia.com/gpu', 'amd.com/gpu'])
|
||||
def test_cascade_strategy_disabled_via_env(monkeypatch: MonkeyPatch, gpu_type: str):
|
||||
monkeypatch.setattr(strategy, 'get_resource', unvalidated_get_resource)
|
||||
|
||||
13
ruff.toml
13
ruff.toml
@@ -1,4 +1,3 @@
|
||||
indent-width = 2
|
||||
extend-exclude = [
|
||||
"tools",
|
||||
"examples",
|
||||
@@ -38,8 +37,9 @@ ignore = [
|
||||
"RUF012", # mutable attributes to be used with ClassVar
|
||||
"E701", # multiple statement on single line
|
||||
]
|
||||
line-length = 192
|
||||
target-version = "py312"
|
||||
line-length = 119
|
||||
indent-width = 2
|
||||
target-version = "py38"
|
||||
typing-modules = ["openllm_core._typing_compat"]
|
||||
unfixable = ["TCH004"]
|
||||
|
||||
@@ -55,6 +55,7 @@ runtime-evaluated-base-classes = [
|
||||
runtime-evaluated-decorators = ["attrs.define", "attrs.frozen", "trait"]
|
||||
|
||||
[format]
|
||||
preview = true
|
||||
quote-style = "single"
|
||||
indent-style = "space"
|
||||
skip-magic-trailing-comma = true
|
||||
@@ -80,19 +81,19 @@ known-third-party = [
|
||||
"peft",
|
||||
"click_option_group",
|
||||
]
|
||||
lines-after-imports = 1
|
||||
lines-after-imports = 2
|
||||
lines-between-types = 1
|
||||
no-lines-before = ["future", "standard-library"]
|
||||
relative-imports-order = "closest-to-furthest"
|
||||
|
||||
[lint.flake8-quotes]
|
||||
avoid-escape = false
|
||||
multiline-quotes = "single"
|
||||
inline-quotes = "single"
|
||||
multiline-quotes = "double"
|
||||
docstring-quotes = "double"
|
||||
|
||||
[lint.extend-per-file-ignores]
|
||||
"openllm-python/src/openllm/models/**" = ["E", "F", "I001"]
|
||||
"openllm-client/src/openllm_client/__init__.pyi" = ["I001"]
|
||||
"openllm-python/tests/**/*" = ["S101", "TID252", "PT011", "S307"]
|
||||
"openllm-python/src/openllm/_llm.py" = ["F811"]
|
||||
"openllm-core/src/openllm_core/utils/import_utils.py" = ["PLW0603", "F811"]
|
||||
|
||||
@@ -10,7 +10,8 @@ import tomlkit
|
||||
|
||||
from ghapi.all import GhApi
|
||||
|
||||
if t.TYPE_CHECKING: from tomlkit.items import Array, Table
|
||||
if t.TYPE_CHECKING:
|
||||
from tomlkit.items import Array, Table
|
||||
|
||||
ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.insert(0, os.path.join(ROOT, 'openllm-python', 'src'))
|
||||
@@ -19,24 +20,40 @@ import openllm
|
||||
|
||||
_OWNER, _REPO = 'bentoml', 'openllm'
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class Classifier:
|
||||
identifier: t.Dict[str, str] = dataclasses.field(
|
||||
default_factory=lambda: {
|
||||
'status': 'Development Status',
|
||||
'environment': 'Environment',
|
||||
'license': 'License',
|
||||
'topic': 'Topic',
|
||||
'os': 'Operating System',
|
||||
'audience': 'Intended Audience',
|
||||
'typing': 'Typing',
|
||||
'language': 'Programming Language',
|
||||
})
|
||||
default_factory=lambda: {
|
||||
'status': 'Development Status',
|
||||
'environment': 'Environment',
|
||||
'license': 'License',
|
||||
'topic': 'Topic',
|
||||
'os': 'Operating System',
|
||||
'audience': 'Intended Audience',
|
||||
'typing': 'Typing',
|
||||
'language': 'Programming Language',
|
||||
}
|
||||
)
|
||||
joiner: str = ' :: '
|
||||
|
||||
@staticmethod
|
||||
def status() -> dict[int, str]:
|
||||
return {v: status for v, status in zip(range(1, 8), ['1 - Planning', '2 - Pre-Alpha', '3 - Alpha', '4 - Beta', '5 - Production/Stable', '6 - Mature', '7 - Inactive'])}
|
||||
return {
|
||||
v: status
|
||||
for v, status in zip(
|
||||
range(1, 8),
|
||||
[
|
||||
'1 - Planning',
|
||||
'2 - Pre-Alpha',
|
||||
'3 - Alpha',
|
||||
'4 - Beta',
|
||||
'5 - Production/Stable',
|
||||
'6 - Mature',
|
||||
'7 - Inactive',
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def apache() -> str:
|
||||
@@ -50,19 +67,29 @@ class Classifier:
|
||||
return cls_.joiner.join([cls_.identifier[identifier], *decls])
|
||||
|
||||
@staticmethod
|
||||
def create_python_classifier(implementation: list[str] | None = None, supported_version: list[str] | None = None) -> list[str]:
|
||||
if supported_version is None: supported_version = ['3.8', '3.9', '3.10', '3.11', '3.12']
|
||||
if implementation is None: implementation = ['CPython', 'PyPy']
|
||||
base = [Classifier.create_classifier('language', 'Python'), Classifier.create_classifier('language', 'Python', '3')]
|
||||
def create_python_classifier(
|
||||
implementation: list[str] | None = None, supported_version: list[str] | None = None
|
||||
) -> list[str]:
|
||||
if supported_version is None:
|
||||
supported_version = ['3.8', '3.9', '3.10', '3.11', '3.12']
|
||||
if implementation is None:
|
||||
implementation = ['CPython', 'PyPy']
|
||||
base = [
|
||||
Classifier.create_classifier('language', 'Python'),
|
||||
Classifier.create_classifier('language', 'Python', '3'),
|
||||
]
|
||||
base.append(Classifier.create_classifier('language', 'Python', '3', 'Only'))
|
||||
base.extend([Classifier.create_classifier('language', 'Python', version) for version in supported_version])
|
||||
base.extend([Classifier.create_classifier('language', 'Python', 'Implementation', impl) for impl in implementation])
|
||||
base.extend(
|
||||
[Classifier.create_classifier('language', 'Python', 'Implementation', impl) for impl in implementation]
|
||||
)
|
||||
return base
|
||||
|
||||
@staticmethod
|
||||
def create_status_classifier(level: int) -> str:
|
||||
return Classifier.create_classifier('status', Classifier.status()[level])
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class Dependencies:
|
||||
name: str
|
||||
@@ -105,29 +132,31 @@ class Dependencies:
|
||||
else:
|
||||
dep = f'{self.name}{self.pypi_extensions}'
|
||||
deps.append(dep)
|
||||
if self.platform: deps.append(self.platform_restriction(*self.platform))
|
||||
if self.platform:
|
||||
deps.append(self.platform_restriction(*self.platform))
|
||||
return ';'.join(deps)
|
||||
|
||||
@classmethod
|
||||
def from_tuple(cls, *decls: t.Any) -> Dependencies:
|
||||
return cls(*decls)
|
||||
|
||||
|
||||
lower_bentoml_constraint = '1.1.2'
|
||||
_BENTOML_EXT = ['io']
|
||||
_TRANSFORMERS_EXT = ['torch', 'tokenizers']
|
||||
|
||||
_BASE_DEPENDENCIES = [
|
||||
Dependencies(name='bentoml', extensions=_BENTOML_EXT, lower_constraint=lower_bentoml_constraint),
|
||||
Dependencies(name='transformers', extensions=_TRANSFORMERS_EXT, lower_constraint='4.35.0'),
|
||||
Dependencies(name='openllm-client'),
|
||||
Dependencies(name='openllm-core'),
|
||||
Dependencies(name='safetensors'),
|
||||
Dependencies(name='optimum', lower_constraint='1.12.0'),
|
||||
Dependencies(name='accelerate'),
|
||||
Dependencies(name='ghapi'),
|
||||
Dependencies(name='click', lower_constraint='8.1.3'),
|
||||
Dependencies(name='cuda-python', platform=('Darwin', 'ne')),
|
||||
Dependencies(name='bitsandbytes', upper_constraint='0.42'), # 0.41 works with CUDA 11.8
|
||||
Dependencies(name='bentoml', extensions=_BENTOML_EXT, lower_constraint=lower_bentoml_constraint),
|
||||
Dependencies(name='transformers', extensions=_TRANSFORMERS_EXT, lower_constraint='4.35.0'),
|
||||
Dependencies(name='openllm-client'),
|
||||
Dependencies(name='openllm-core'),
|
||||
Dependencies(name='safetensors'),
|
||||
Dependencies(name='optimum', lower_constraint='1.12.0'),
|
||||
Dependencies(name='accelerate'),
|
||||
Dependencies(name='ghapi'),
|
||||
Dependencies(name='click', lower_constraint='8.1.3'),
|
||||
Dependencies(name='cuda-python', platform=('Darwin', 'ne')),
|
||||
Dependencies(name='bitsandbytes', upper_constraint='0.42'), # 0.41 works with CUDA 11.8
|
||||
]
|
||||
|
||||
FINE_TUNE_DEPS = ['peft>=0.6.0', 'bitsandbytes', 'datasets', 'accelerate', 'trl', 'scipy']
|
||||
@@ -143,7 +172,9 @@ GPTQ_DEPS = ['auto-gptq[triton]>=0.4.2', 'optimum>=1.12.0']
|
||||
VLLM_DEPS = ['vllm>=0.2.1post1', 'ray']
|
||||
|
||||
_base_requirements: dict[str, t.Any] = {
|
||||
inflection.dasherize(name): config_cls.__openllm_requirements__ for name, config_cls in openllm.CONFIG_MAPPING.items() if config_cls.__openllm_requirements__
|
||||
inflection.dasherize(name): config_cls.__openllm_requirements__
|
||||
for name, config_cls in openllm.CONFIG_MAPPING.items()
|
||||
if config_cls.__openllm_requirements__
|
||||
}
|
||||
|
||||
# shallow copy from locals()
|
||||
@@ -151,18 +182,23 @@ _locals = locals().copy()
|
||||
|
||||
# NOTE: update this table when adding new external dependencies
|
||||
# sync with openllm.utils.OPTIONAL_DEPENDENCIES
|
||||
_base_requirements.update({v: _locals.get(f'{inflection.underscore(v).upper()}_DEPS') for v in openllm.utils.OPTIONAL_DEPENDENCIES})
|
||||
_base_requirements.update(
|
||||
{v: _locals.get(f'{inflection.underscore(v).upper()}_DEPS') for v in openllm.utils.OPTIONAL_DEPENDENCIES}
|
||||
)
|
||||
|
||||
_base_requirements = {k: v for k, v in sorted(_base_requirements.items())}
|
||||
|
||||
fname = f'{os.path.basename(os.path.dirname(__file__))}/{os.path.basename(__file__)}'
|
||||
|
||||
|
||||
def correct_style(it: t.Any) -> t.Any:
|
||||
return it
|
||||
|
||||
|
||||
def create_classifiers() -> Array:
|
||||
arr = correct_style(tomlkit.array())
|
||||
arr.extend([
|
||||
arr.extend(
|
||||
[
|
||||
Classifier.create_status_classifier(5),
|
||||
Classifier.create_classifier('environment', 'GPU', 'NVIDIA CUDA'),
|
||||
Classifier.create_classifier('environment', 'GPU', 'NVIDIA CUDA', '12'),
|
||||
@@ -175,36 +211,43 @@ def create_classifiers() -> Array:
|
||||
Classifier.create_classifier('audience', 'Developers'),
|
||||
Classifier.create_classifier('audience', 'Science/Research'),
|
||||
Classifier.create_classifier('audience', 'System Administrators'),
|
||||
Classifier.create_classifier('typing', 'Typed'), *Classifier.create_python_classifier(),
|
||||
])
|
||||
Classifier.create_classifier('typing', 'Typed'),
|
||||
*Classifier.create_python_classifier(),
|
||||
]
|
||||
)
|
||||
return arr.multiline(True)
|
||||
|
||||
|
||||
def create_optional_table() -> Table:
|
||||
all_array = tomlkit.array()
|
||||
all_array.append(f"openllm[{','.join(_base_requirements)}]")
|
||||
|
||||
table = tomlkit.table(is_super_table=True)
|
||||
_base_requirements.update({'full': correct_style(all_array.multiline(True)), 'all': tomlkit.array('["openllm[full]"]')})
|
||||
_base_requirements.update(
|
||||
{'full': correct_style(all_array.multiline(True)), 'all': tomlkit.array('["openllm[full]"]')}
|
||||
)
|
||||
table.update({k: v for k, v in sorted(_base_requirements.items())})
|
||||
table.add(tomlkit.nl())
|
||||
|
||||
return table
|
||||
|
||||
|
||||
def create_url_table(_info: t.Any) -> Table:
|
||||
table = tomlkit.table()
|
||||
_urls = {
|
||||
'Blog': 'https://modelserving.com',
|
||||
'Chat': 'https://discord.gg/openllm',
|
||||
'Documentation': 'https://github.com/bentoml/openllm#readme',
|
||||
'GitHub': _info.html_url,
|
||||
'History': f'{_info.html_url}/blob/main/CHANGELOG.md',
|
||||
'Homepage': _info.homepage,
|
||||
'Tracker': f'{_info.html_url}/issues',
|
||||
'Twitter': 'https://twitter.com/bentomlai',
|
||||
'Blog': 'https://modelserving.com',
|
||||
'Chat': 'https://discord.gg/openllm',
|
||||
'Documentation': 'https://github.com/bentoml/openllm#readme',
|
||||
'GitHub': _info.html_url,
|
||||
'History': f'{_info.html_url}/blob/main/CHANGELOG.md',
|
||||
'Homepage': _info.homepage,
|
||||
'Tracker': f'{_info.html_url}/issues',
|
||||
'Twitter': 'https://twitter.com/bentomlai',
|
||||
}
|
||||
table.update({k: v for k, v in sorted(_urls.items())})
|
||||
return table
|
||||
|
||||
|
||||
def build_system() -> Table:
|
||||
table = tomlkit.table()
|
||||
table.add('build-backend', 'hatchling.build')
|
||||
@@ -213,33 +256,61 @@ def build_system() -> Table:
|
||||
table.add('requires', requires_array.multiline(True))
|
||||
return table
|
||||
|
||||
|
||||
def authors() -> Array:
|
||||
arr = correct_style(tomlkit.array())
|
||||
arr.append(dict(name='Aaron Pham', email='aarnphm@bentoml.com'))
|
||||
arr.append(dict(name='BentoML Team', email='contact@bentoml.com'))
|
||||
return arr.multiline(True)
|
||||
|
||||
|
||||
def keywords() -> Array:
|
||||
arr = correct_style(tomlkit.array())
|
||||
arr.extend([
|
||||
'MLOps', 'AI', 'BentoML', 'Model Serving', 'Model Deployment', 'LLMOps', 'Falcon', 'Vicuna', 'Llama 2', 'Fine tuning', 'Serverless', 'Large Language Model', 'Generative AI', 'StableLM',
|
||||
'Alpaca', 'PyTorch', 'Transformers'
|
||||
])
|
||||
arr.extend(
|
||||
[
|
||||
'MLOps',
|
||||
'AI',
|
||||
'BentoML',
|
||||
'Model Serving',
|
||||
'Model Deployment',
|
||||
'LLMOps',
|
||||
'Falcon',
|
||||
'Vicuna',
|
||||
'Llama 2',
|
||||
'Fine tuning',
|
||||
'Serverless',
|
||||
'Large Language Model',
|
||||
'Generative AI',
|
||||
'StableLM',
|
||||
'Alpaca',
|
||||
'PyTorch',
|
||||
'Transformers',
|
||||
]
|
||||
)
|
||||
return arr.multiline(True)
|
||||
|
||||
|
||||
def build_cli_extensions() -> Table:
|
||||
table = tomlkit.table()
|
||||
ext: dict[str, str] = {'openllm': 'openllm.cli.entrypoint:cli'}
|
||||
ext.update({
|
||||
f'openllm-{inflection.dasherize(ke)}': f'openllm.cli.extension.{ke}:cli' for ke in sorted([
|
||||
ext.update(
|
||||
{
|
||||
f'openllm-{inflection.dasherize(ke)}': f'openllm.cli.extension.{ke}:cli'
|
||||
for ke in sorted(
|
||||
[
|
||||
fname[:-3]
|
||||
for fname in os.listdir(os.path.abspath(os.path.join(ROOT, 'openllm-python', 'src', 'openllm', 'cli', 'extension')))
|
||||
for fname in os.listdir(
|
||||
os.path.abspath(os.path.join(ROOT, 'openllm-python', 'src', 'openllm', 'cli', 'extension'))
|
||||
)
|
||||
if fname.endswith('.py') and not fname.startswith('__')
|
||||
])
|
||||
})
|
||||
]
|
||||
)
|
||||
}
|
||||
)
|
||||
table.update(ext)
|
||||
return table
|
||||
|
||||
|
||||
def main() -> int:
|
||||
api = GhApi(owner=_OWNER, repo=_REPO, authenticate=False)
|
||||
_info = api.repos.get()
|
||||
@@ -271,4 +342,6 @@ def main() -> int:
|
||||
f.write(tomlkit.dumps(pyproject))
|
||||
return 0
|
||||
|
||||
if __name__ == '__main__': raise SystemExit(main())
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise SystemExit(main())
|
||||
|
||||
@@ -10,6 +10,7 @@ ROOT = Path(__file__).resolve().parent.parent
|
||||
|
||||
PACKAGES = {'openllm-python/src/openllm/': 'openllm'}
|
||||
|
||||
|
||||
def main() -> int:
|
||||
coverage_report = ROOT / 'coverage.xml'
|
||||
root = etree.fromstring(coverage_report.read_text())
|
||||
@@ -27,8 +28,10 @@ def main() -> int:
|
||||
raise ValueError(message)
|
||||
|
||||
for line in module.find('lines'):
|
||||
if line.attrib['hits'] == '1': data['hits'] += 1
|
||||
else: data['misses'] += 1
|
||||
if line.attrib['hits'] == '1':
|
||||
data['hits'] += 1
|
||||
else:
|
||||
data['misses'] += 1
|
||||
|
||||
total_statements_covered = 0
|
||||
total_statements = 0
|
||||
@@ -45,4 +48,6 @@ def main() -> int:
|
||||
coverage_summary.write_text(orjson.dumps(coverage_data, option=orjson.OPT_INDENT_2).decode(), encoding='utf-8')
|
||||
return 0
|
||||
|
||||
if __name__ == '__main__': raise SystemExit(main())
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise SystemExit(main())
|
||||
|
||||
@@ -7,7 +7,8 @@ from jinja2 import Environment
|
||||
from jinja2.loaders import FileSystemLoader
|
||||
from plumbum.cmd import curl, cut, shasum
|
||||
|
||||
if t.TYPE_CHECKING: from plumbum.commands.base import Pipeline
|
||||
if t.TYPE_CHECKING:
|
||||
from plumbum.commands.base import Pipeline
|
||||
|
||||
# get git root from this file
|
||||
ROOT = Path(__file__).parent.parent
|
||||
@@ -16,43 +17,59 @@ _OWNER = 'bentoml'
|
||||
_REPO = 'openllm'
|
||||
|
||||
_gz_strategies: dict[t.Literal['macos_arm', 'macos_intel', 'linux_intel'], str] = {
|
||||
'macos_arm': 'aarch64-apple-darwin',
|
||||
'macos_intel': 'x86_64-apple-darwin',
|
||||
'linux_intel': 'x86_64-unknown-linux-musl'
|
||||
'macos_arm': 'aarch64-apple-darwin',
|
||||
'macos_intel': 'x86_64-apple-darwin',
|
||||
'linux_intel': 'x86_64-unknown-linux-musl',
|
||||
}
|
||||
|
||||
def determine_release_url(svn_url: str, tag: str, target: t.Literal['macos_arm', 'macos_intel', 'linux_intel', 'archive']) -> str:
|
||||
if target == 'archive': return f'{svn_url}/archive/{tag}.tar.gz'
|
||||
|
||||
def determine_release_url(
|
||||
svn_url: str, tag: str, target: t.Literal['macos_arm', 'macos_intel', 'linux_intel', 'archive']
|
||||
) -> str:
|
||||
if target == 'archive':
|
||||
return f'{svn_url}/archive/{tag}.tar.gz'
|
||||
return f"{svn_url}/releases/download/{tag}/openllm-{tag.replace('v', '')}-{_gz_strategies[target]}.tar.gz"
|
||||
|
||||
|
||||
# curl -sSL <svn_url>/archive/refs/tags/<tag>.tar.gz | shasum -a256 | cut -d'' -f1
|
||||
def get_release_hash_command(svn_url: str, tag: str) -> Pipeline:
|
||||
return curl['-sSL', svn_url] | shasum['-a256'] | cut['-d', ' ', '-f1']
|
||||
|
||||
|
||||
def main() -> int:
|
||||
api = GhApi(owner=_OWNER, repo=_REPO, authenticate=False)
|
||||
_info = api.repos.get()
|
||||
release_tag = api.repos.get_latest_release().name
|
||||
|
||||
shadict: dict[str, t.Any] = {k: get_release_hash_command(determine_release_url(_info.svn_url, release_tag, k), release_tag)().strip() for k in _gz_strategies}
|
||||
shadict['archive'] = get_release_hash_command(determine_release_url(_info.svn_url, release_tag, 'archive'), release_tag)().strip()
|
||||
shadict: dict[str, t.Any] = {
|
||||
k: get_release_hash_command(determine_release_url(_info.svn_url, release_tag, k), release_tag)().strip()
|
||||
for k in _gz_strategies
|
||||
}
|
||||
shadict['archive'] = get_release_hash_command(
|
||||
determine_release_url(_info.svn_url, release_tag, 'archive'), release_tag
|
||||
)().strip()
|
||||
|
||||
ENVIRONMENT = Environment(extensions=['jinja2.ext.do', 'jinja2.ext.loopcontrols', 'jinja2.ext.debug'],
|
||||
trim_blocks=True,
|
||||
lstrip_blocks=True,
|
||||
loader=FileSystemLoader((ROOT / 'Formula').__fspath__(), followlinks=True))
|
||||
ENVIRONMENT = Environment(
|
||||
extensions=['jinja2.ext.do', 'jinja2.ext.loopcontrols', 'jinja2.ext.debug'],
|
||||
trim_blocks=True,
|
||||
lstrip_blocks=True,
|
||||
loader=FileSystemLoader((ROOT / 'Formula').__fspath__(), followlinks=True),
|
||||
)
|
||||
template_file = 'openllm.rb.j2'
|
||||
with (ROOT / 'Formula' / 'openllm.rb').open('w') as f:
|
||||
f.write(
|
||||
ENVIRONMENT.get_template(template_file, globals={
|
||||
'determine_release_url': determine_release_url
|
||||
}).render(shadict=shadict,
|
||||
__tag__=release_tag,
|
||||
__cmd__=fs.path.join(os.path.basename(os.path.dirname(__file__)), os.path.basename(__file__)),
|
||||
__template_file__=fs.path.join('Formula', template_file),
|
||||
__gz_extension__=_gz_strategies,
|
||||
**_info))
|
||||
ENVIRONMENT.get_template(template_file, globals={'determine_release_url': determine_release_url}).render(
|
||||
shadict=shadict,
|
||||
__tag__=release_tag,
|
||||
__cmd__=fs.path.join(os.path.basename(os.path.dirname(__file__)), os.path.basename(__file__)),
|
||||
__template_file__=fs.path.join('Formula', template_file),
|
||||
__gz_extension__=_gz_strategies,
|
||||
**_info,
|
||||
)
|
||||
)
|
||||
f.write('\n')
|
||||
return 0
|
||||
|
||||
if __name__ == '__main__': raise SystemExit(main())
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise SystemExit(main())
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user