diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 581b60b8..7dc681d3 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -8,7 +8,7 @@ default_language_version:
 exclude: '.*\.(css|js|svg)$'
 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: 'v0.1.7'
+    rev: 'v0.1.8'
     hooks:
       - id: ruff
         alias: r
diff --git a/.ruff.toml b/.ruff.toml
index 679cad98..7cc1e070 100644
--- a/.ruff.toml
+++ b/.ruff.toml
@@ -1,77 +1,38 @@
-extend-exclude = [
-  "tools",
-  "examples",
-  "openllm-python/src/openllm/__init__.py",
-  "openllm-python/src/openllm/_version.py",
-  "openllm-python/src/openllm/models/__init__.py",
-  "openllm-python/src/openllm_cli/playground",
-  "openllm-client/src/openllm_client/pb/**",
-]
+exclude = ["tools", "examples", "openllm-python/src/openllm_cli/playground/"]
 extend-include = ["*.ipynb"]
-extend-select = [
-  "E",
+preview = true
+select = [
   "F",
-  "B",
-  "PIE",
   "G", # flake8-logging-format
-  "W", # pycodestyle
-  "Q", # flake8-quotes
-  "FA", # flake8-future-annotations
-  "TCH", # flake8-type-checking
-  "PLW", # pylint-warning
-  "PLR", # pylint-refactor
-  "PT", # flake8-pytest-style
   "PERF", # perflint
   "RUF", # Ruff-specific rules
-  "YTT", # flake8-2020
-]
-fix = true
-ignore = [
-  "PLR0911",
-  "PLR0912",
-  "PLR0913",
-  "PLR0915",
-  "ANN", # Use mypy
-  "PLR2004", # magic value to use constant
-  "E501", # ignore line length violation
-  "E401", # ignore multiple line import
+  "W6",
+  "E71",
+  "E72",
+  "E112",
+  "E113",
+  # "E124",
+  "E203",
+  "E272",
+  # "E303",
+  # "E304",
+  # "E501",
+  # "E502",
   "E702",
-  "TCH004", # don't move runtime import out, just warn about it
-  "RUF012", # mutable attributes to be used with ClassVar
-  "E701", # multiple statement on single line
+  "E703",
+  "E731",
+  "W191",
+  "W291",
+  "W293",
+  "UP039", # unnecessary-class-parentheses
 ]
-line-length = 119
+ignore = ["RUF012"]
+line-length = 150
 indent-width = 2
-target-version = "py38"
 typing-modules = [
   "openllm_core._typing_compat",
   "openllm_client._typing_compat",
 ]
-unfixable = ["TCH004"]
-
-[lint.flake8-type-checking]
-exempt-modules = [
-  "typing",
-  "typing_extensions",
-  "openllm_core._typing_compat",
-  "openllm_client._typing_compat",
-]
-runtime-evaluated-base-classes = [
-  "openllm_core._configuration.LLMConfig",
-  "openllm_core._configuration.GenerationConfig",
-  "openllm_core._configuration.SamplingParams",
-  "openllm_core._configuration.ModelSettings",
-  "openllm.LLMConfig",
-]
-runtime-evaluated-decorators = [
-  "attrs.define",
-  "attrs.frozen",
-  "trait",
-  "attr.attrs",
-  'attr.define',
-  '_attr.define',
-  'attr.frozen',
-]
 
 [format]
 preview = true
@@ -81,20 +42,3 @@ skip-magic-trailing-comma = true
 
 [lint.pydocstyle]
 convention = "google"
-
-[lint.pycodestyle]
-ignore-overlong-task-comments = true
-max-line-length = 119
-
-[lint.flake8-quotes]
-avoid-escape = false
-inline-quotes = "single"
-multiline-quotes = "single"
-docstring-quotes = "single"
-
-[lint.extend-per-file-ignores]
-"openllm-python/tests/**/*" = ["S101", "TID252", "PT011", "S307"]
-"openllm-python/src/openllm/_llm.py" = ["F811"]
-"openllm-core/src/openllm_core/utils/import_utils.py" = ["PLW0603", "F811"]
-"openllm-core/src/openllm_core/_configuration.py" = ["F811", "Q001"]
-"openllm-python/src/openllm/_service_vars_pkg.py" = ["F821"]
diff --git a/cz.py b/cz.py
index 4ef0873f..2d7e38af 100755
--- a/cz.py
+++ b/cz.py
@@ -22,21 +22,13 @@ def run_cz(args):
       with tokenize.open(filepath) as file_:
         tokens = [t for t in tokenize.generate_tokens(file_.readline) if t.type in TOKEN_WHITELIST]
token_count, line_count = len(tokens), len(set([t.start[0] for t in tokens])) - table.append( - [ - filepath.replace(os.path.join(args.dir, 'src'), ''), - line_count, - token_count / line_count if line_count != 0 else 0, - ] - ) + table.append([filepath.replace(os.path.join(args.dir, 'src'), ''), line_count, token_count / line_count if line_count != 0 else 0]) print(tabulate([headers, *sorted(table, key=lambda x: -x[1])], headers='firstrow', floatfmt='.1f') + '\n') print( tabulate( [ (dir_name, sum([x[1] for x in group])) - for dir_name, group in itertools.groupby( - sorted([(x[0].rsplit('/', 1)[0], x[1]) for x in table]), key=lambda x: x[0] - ) + for dir_name, group in itertools.groupby(sorted([(x[0].rsplit('/', 1)[0], x[1]) for x in table]), key=lambda x: x[0]) ], headers=['Directory', 'LOC'], floatfmt='.1f', @@ -54,10 +46,6 @@ if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument( - '--dir', - choices=['openllm-python', 'openllm-core', 'openllm-client'], - help='directory to check', - default='openllm-python', - required=False, + '--dir', choices=['openllm-python', 'openllm-core', 'openllm-client'], help='directory to check', default='openllm-python', required=False ) raise SystemExit(run_cz(parser.parse_args())) diff --git a/openllm-client/src/openllm_client/__init__.pyi b/openllm-client/src/openllm_client/__init__.pyi index bec8fc11..3b5ecfb3 100644 --- a/openllm-client/src/openllm_client/__init__.pyi +++ b/openllm-client/src/openllm_client/__init__.pyi @@ -15,17 +15,11 @@ class HTTPClient: address: str helpers: _Helpers @overload - def __init__( - self, address: str, timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ... - ) -> None: ... + def __init__(self, address: str, timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...) -> None: ... @overload - def __init__( - self, address: str = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ... - ) -> None: ... + def __init__(self, address: str = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...) -> None: ... @overload - def __init__( - self, address: None = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ... - ) -> None: ... + def __init__(self, address: None = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...) -> None: ... @property def is_ready(self) -> bool: ... def health(self) -> bool: ... @@ -66,17 +60,11 @@ class AsyncHTTPClient: address: str helpers: _AsyncHelpers @overload - def __init__( - self, address: str, timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ... - ) -> None: ... + def __init__(self, address: str, timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...) -> None: ... @overload - def __init__( - self, address: str = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ... - ) -> None: ... + def __init__(self, address: str = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...) -> None: ... @overload - def __init__( - self, address: None = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ... - ) -> None: ... + def __init__(self, address: None = ..., timeout: int = ..., verify: bool = ..., max_retries: int = ..., api_version: str = ...) -> None: ... @property def is_ready(self) -> bool: ... 
async def health(self) -> bool: ... diff --git a/openllm-client/src/openllm_client/_http.py b/openllm-client/src/openllm_client/_http.py index ed802a67..9923468e 100644 --- a/openllm-client/src/openllm_client/_http.py +++ b/openllm-client/src/openllm_client/_http.py @@ -70,14 +70,10 @@ class HTTPClient(Client): return self.generate(prompt, **attrs) def health(self): - response = self._get( - '/readyz', response_cls=None, options={'return_raw_response': True, 'max_retries': self._max_retries} - ) + response = self._get('/readyz', response_cls=None, options={'return_raw_response': True, 'max_retries': self._max_retries}) return response.status_code == 200 - def generate( - self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs - ) -> Response: + def generate(self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs) -> Response: if timeout is None: timeout = self._timeout if verify is None: @@ -100,9 +96,7 @@ class HTTPClient(Client): for response_chunk in self.generate_iterator(prompt, llm_config, stop, adapter_name, timeout, verify, **attrs): yield StreamingResponse.from_response_chunk(response_chunk) - def generate_iterator( - self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs - ) -> t.Iterator[Response]: + def generate_iterator(self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs) -> t.Iterator[Response]: if timeout is None: timeout = self._timeout if verify is None: @@ -152,9 +146,7 @@ class AsyncHTTPClient(AsyncClient): @property async def _metadata(self) -> t.Awaitable[Metadata]: if self.__metadata is None: - self.__metadata = await self._post( - f'/{self._api_version}/metadata', response_cls=Metadata, json={}, options={'max_retries': self._max_retries} - ) + self.__metadata = await self._post(f'/{self._api_version}/metadata', response_cls=Metadata, json={}, options={'max_retries': self._max_retries}) return self.__metadata @property @@ -167,14 +159,10 @@ class AsyncHTTPClient(AsyncClient): return await self.generate(prompt, **attrs) async def health(self): - response = await self._get( - '/readyz', response_cls=None, options={'return_raw_response': True, 'max_retries': self._max_retries} - ) + response = await self._get('/readyz', response_cls=None, options={'return_raw_response': True, 'max_retries': self._max_retries}) return response.status_code == 200 - async def generate( - self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs - ) -> Response: + async def generate(self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs) -> Response: if timeout is None: timeout = self._timeout if verify is None: @@ -195,9 +183,7 @@ class AsyncHTTPClient(AsyncClient): async def generate_stream( self, prompt, llm_config=None, stop=None, adapter_name=None, timeout=None, verify=None, **attrs ) -> t.AsyncGenerator[StreamingResponse, t.Any]: - async for response_chunk in self.generate_iterator( - prompt, llm_config, stop, adapter_name, timeout, verify, **attrs - ): + async for response_chunk in self.generate_iterator(prompt, llm_config, stop, adapter_name, timeout, verify, **attrs): yield StreamingResponse.from_response_chunk(response_chunk) async def generate_iterator( diff --git a/openllm-client/src/openllm_client/_schemas.py b/openllm-client/src/openllm_client/_schemas.py index 2269bdd0..037c92b3 100644 --- a/openllm-client/src/openllm_client/_schemas.py 
+++ b/openllm-client/src/openllm_client/_schemas.py
@@ -22,9 +22,9 @@ __all__ = ['Response', 'CompletionChunk', 'Metadata', 'StreamingResponse', 'Help
 
 @attr.define
 class Metadata(_SchemaMixin):
-  '''NOTE: Metadata is a modified version of the original MetadataOutput from openllm-core.
+  """NOTE: Metadata is a modified version of the original MetadataOutput from openllm-core.
 
-  The configuration is now structured into a dictionary for easy of use.'''
+  The configuration is now structured into a dictionary for ease of use."""
 
   model_id: str
   timeout: int
@@ -42,11 +42,7 @@ def _structure_metadata(data: t.Dict[str, t.Any], cls: type[Metadata]) -> Metada
     raise RuntimeError(f'Malformed metadata configuration (Server-side issue): {e}') from None
   try:
     return cls(
-      model_id=data['model_id'],
-      timeout=data['timeout'],
-      model_name=data['model_name'],
-      backend=data['backend'],
-      configuration=configuration,
+      model_id=data['model_id'], timeout=data['timeout'], model_name=data['model_name'], backend=data['backend'], configuration=configuration
     )
   except Exception as e:
     raise RuntimeError(f'Malformed metadata (Server-side issue): {e}') from None
@@ -65,10 +61,7 @@ class StreamingResponse(_SchemaMixin):
   @classmethod
   def from_response_chunk(cls, response: Response) -> StreamingResponse:
     return cls(
-      request_id=response.request_id,
-      index=response.outputs[0].index,
-      text=response.outputs[0].text,
-      token_ids=response.outputs[0].token_ids[0],
+      request_id=response.request_id, index=response.outputs[0].index, text=response.outputs[0].text, token_ids=response.outputs[0].token_ids[0]
     )
 
 
@@ -95,17 +88,11 @@ class Helpers:
     return self._async_client
 
   def messages(self, messages, add_generation_prompt=False):
-    return self.client._post(
-      '/v1/helpers/messages',
-      response_cls=str,
-      json=dict(messages=messages, add_generation_prompt=add_generation_prompt),
-    )
+    return self.client._post('/v1/helpers/messages', response_cls=str, json=dict(messages=messages, add_generation_prompt=add_generation_prompt))
 
   async def async_messages(self, messages, add_generation_prompt=False):
     return await self.async_client._post(
-      '/v1/helpers/messages',
-      response_cls=str,
-      json=dict(messages=messages, add_generation_prompt=add_generation_prompt),
+      '/v1/helpers/messages', response_cls=str, json=dict(messages=messages, add_generation_prompt=add_generation_prompt)
     )
 
   @classmethod
diff --git a/openllm-client/src/openllm_client/_shim.py b/openllm-client/src/openllm_client/_shim.py
index 4c7470e7..04a9a730 100644
--- a/openllm-client/src/openllm_client/_shim.py
+++ b/openllm-client/src/openllm_client/_shim.py
@@ -140,9 +140,7 @@ class APIResponse(t.Generic[Response]):
 
     data = self._raw_response.json()
     try:
-      return self._client._process_response_data(
-        data=data, response_cls=self._response_cls, raw_response=self._raw_response
-      )
+      return self._client._process_response_data(data=data, response_cls=self._response_cls, raw_response=self._raw_response)
     except Exception as exc:
       raise ValueError(exc) from None  # validation error here
 
@@ -273,16 +271,10 @@ class BaseClient(t.Generic[InnerClient, StreamType]):
 
   def _build_request(self, options: RequestOptions) -> httpx.Request:
     return self._inner.build_request(
-      method=options.method,
-      headers=self._build_headers(options),
-      url=self._prepare_url(options.url),
-      json=options.json,
-      params=options.params,
+      method=options.method, headers=self._build_headers(options), url=self._prepare_url(options.url), json=options.json, params=options.params
     )
 
-  def _calculate_retry_timeout(
-    self,
remaining_retries: int, options: RequestOptions, headers: t.Optional[httpx.Headers] = None - ) -> float: + def _calculate_retry_timeout(self, remaining_retries: int, options: RequestOptions, headers: t.Optional[httpx.Headers] = None) -> float: max_retries = options.get_max_retries(self._max_retries) # https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After try: @@ -323,9 +315,7 @@ class BaseClient(t.Generic[InnerClient, StreamType]): return True return False - def _process_response_data( - self, *, response_cls: type[Response], data: t.Dict[str, t.Any], raw_response: httpx.Response - ) -> Response: + def _process_response_data(self, *, response_cls: type[Response], data: t.Dict[str, t.Any], raw_response: httpx.Response) -> Response: return converter.structure(data, response_cls) def _process_response( @@ -338,24 +328,13 @@ class BaseClient(t.Generic[InnerClient, StreamType]): stream_cls: type[_Stream] | type[_AsyncStream] | None, ) -> Response: return APIResponse( - raw_response=raw_response, - client=self, - response_cls=response_cls, - stream=stream, - stream_cls=stream_cls, - options=options, + raw_response=raw_response, client=self, response_cls=response_cls, stream=stream, stream_cls=stream_cls, options=options ).parse() @attr.define(init=False) class Client(BaseClient[httpx.Client, Stream[t.Any]]): - def __init__( - self, - base_url: str | httpx.URL, - version: str, - timeout: int | httpx.Timeout = DEFAULT_TIMEOUT, - max_retries: int = MAX_RETRIES, - ): + def __init__(self, base_url: str | httpx.URL, version: str, timeout: int | httpx.Timeout = DEFAULT_TIMEOUT, max_retries: int = MAX_RETRIES): super().__init__( base_url=base_url, version=version, @@ -387,13 +366,7 @@ class Client(BaseClient[httpx.Client, Stream[t.Any]]): stream: bool = False, stream_cls: type[_Stream] | None = None, ) -> Response | _Stream: - return self._request( - response_cls=response_cls, - options=options, - remaining_retries=remaining_retries, - stream=stream, - stream_cls=stream_cls, - ) + return self._request(response_cls=response_cls, options=options, remaining_retries=remaining_retries, stream=stream, stream_cls=stream_cls) def _request( self, @@ -412,9 +385,7 @@ class Client(BaseClient[httpx.Client, Stream[t.Any]]): response.raise_for_status() except httpx.HTTPStatusError as exc: if retries > 0 and self._should_retry(exc.response): - return self._retry_request( - response_cls, options, retries, exc.response.headers, stream=stream, stream_cls=stream_cls - ) + return self._retry_request(response_cls, options, retries, exc.response.headers, stream=stream, stream_cls=stream_cls) # If the response is streamed then we need to explicitly read the completed response exc.response.read() raise ValueError(exc.message) from None @@ -427,9 +398,7 @@ class Client(BaseClient[httpx.Client, Stream[t.Any]]): return self._retry_request(response_cls, options, retries, stream=stream, stream_cls=stream_cls) raise ValueError(request) from None # connection error - return self._process_response( - response_cls=response_cls, options=options, raw_response=response, stream=stream, stream_cls=stream_cls - ) + return self._process_response(response_cls=response_cls, options=options, raw_response=response, stream=stream, stream_cls=stream_cls) def _retry_request( self, @@ -459,9 +428,7 @@ class Client(BaseClient[httpx.Client, Stream[t.Any]]): ) -> Response | _Stream: if options is None: options = {} - return self.request( - response_cls, RequestOptions(method='GET', url=path, **options), stream=stream, 
stream_cls=stream_cls - ) + return self.request(response_cls, RequestOptions(method='GET', url=path, **options), stream=stream, stream_cls=stream_cls) def _post( self, @@ -475,20 +442,12 @@ class Client(BaseClient[httpx.Client, Stream[t.Any]]): ) -> Response | _Stream: if options is None: options = {} - return self.request( - response_cls, RequestOptions(method='POST', url=path, json=json, **options), stream=stream, stream_cls=stream_cls - ) + return self.request(response_cls, RequestOptions(method='POST', url=path, json=json, **options), stream=stream, stream_cls=stream_cls) @attr.define(init=False) class AsyncClient(BaseClient[httpx.AsyncClient, AsyncStream[t.Any]]): - def __init__( - self, - base_url: str | httpx.URL, - version: str, - timeout: int | httpx.Timeout = DEFAULT_TIMEOUT, - max_retries: int = MAX_RETRIES, - ): + def __init__(self, base_url: str | httpx.URL, version: str, timeout: int | httpx.Timeout = DEFAULT_TIMEOUT, max_retries: int = MAX_RETRIES): super().__init__( base_url=base_url, version=version, @@ -527,9 +486,7 @@ class AsyncClient(BaseClient[httpx.AsyncClient, AsyncStream[t.Any]]): stream: bool = False, stream_cls: type[_AsyncStream] | None = None, ) -> Response | _AsyncStream: - return await self._request( - response_cls, options, remaining_retries=remaining_retries, stream=stream, stream_cls=stream_cls - ) + return await self._request(response_cls, options, remaining_retries=remaining_retries, stream=stream, stream_cls=stream_cls) async def _request( self, @@ -549,9 +506,7 @@ class AsyncClient(BaseClient[httpx.AsyncClient, AsyncStream[t.Any]]): response.raise_for_status() except httpx.HTTPStatusError as exc: if retries > 0 and self._should_retry(exc.response): - return self._retry_request( - response_cls, options, retries, exc.response.headers, stream=stream, stream_cls=stream_cls - ) + return self._retry_request(response_cls, options, retries, exc.response.headers, stream=stream, stream_cls=stream_cls) # If the response is streamed then we need to explicitly read the completed response await exc.response.aread() raise ValueError(exc.message) from None @@ -571,9 +526,7 @@ class AsyncClient(BaseClient[httpx.AsyncClient, AsyncStream[t.Any]]): return await self._retry_request(response_cls, options, retries, stream=stream, stream_cls=stream_cls) raise ValueError(request) from err # connection error - return self._process_response( - response_cls=response_cls, options=options, raw_response=response, stream=stream, stream_cls=stream_cls - ) + return self._process_response(response_cls=response_cls, options=options, raw_response=response, stream=stream, stream_cls=stream_cls) async def _retry_request( self, @@ -602,9 +555,7 @@ class AsyncClient(BaseClient[httpx.AsyncClient, AsyncStream[t.Any]]): ) -> Response | _AsyncStream: if options is None: options = {} - return await self.request( - response_cls, RequestOptions(method='GET', url=path, **options), stream=stream, stream_cls=stream_cls - ) + return await self.request(response_cls, RequestOptions(method='GET', url=path, **options), stream=stream, stream_cls=stream_cls) async def _post( self, @@ -618,6 +569,4 @@ class AsyncClient(BaseClient[httpx.AsyncClient, AsyncStream[t.Any]]): ) -> Response | _AsyncStream: if options is None: options = {} - return await self.request( - response_cls, RequestOptions(method='POST', url=path, json=json, **options), stream=stream, stream_cls=stream_cls - ) + return await self.request(response_cls, RequestOptions(method='POST', url=path, json=json, **options), stream=stream, 
stream_cls=stream_cls) diff --git a/openllm-client/src/openllm_client/_stream.py b/openllm-client/src/openllm_client/_stream.py index e81a7fb0..a5103207 100644 --- a/openllm-client/src/openllm_client/_stream.py +++ b/openllm-client/src/openllm_client/_stream.py @@ -38,9 +38,7 @@ class Stream(t.Generic[Response]): if sse.data.startswith('[DONE]'): break if sse.event is None: - yield self._client._process_response_data( - data=sse.model_dump(), response_cls=self._response_cls, raw_response=self._response - ) + yield self._client._process_response_data(data=sse.model_dump(), response_cls=self._response_cls, raw_response=self._response) @attr.define(auto_attribs=True) @@ -71,9 +69,7 @@ class AsyncStream(t.Generic[Response]): if sse.data.startswith('[DONE]'): break if sse.event is None: - yield self._client._process_response_data( - data=sse.model_dump(), response_cls=self._response_cls, raw_response=self._response - ) + yield self._client._process_response_data(data=sse.model_dump(), response_cls=self._response_cls, raw_response=self._response) @attr.define diff --git a/openllm-client/src/openllm_client/_typing_compat.py b/openllm-client/src/openllm_client/_typing_compat.py index 15d86f8a..48bd0a85 100644 --- a/openllm-client/src/openllm_client/_typing_compat.py +++ b/openllm-client/src/openllm_client/_typing_compat.py @@ -11,7 +11,5 @@ from openllm_core._typing_compat import ( overload as overload, ) -Platform = Annotated[ - LiteralString, Literal['MacOS', 'Linux', 'Windows', 'FreeBSD', 'OpenBSD', 'iOS', 'iPadOS', 'Android', 'Unknown'], str -] +Platform = Annotated[LiteralString, Literal['MacOS', 'Linux', 'Windows', 'FreeBSD', 'OpenBSD', 'iOS', 'iPadOS', 'Android', 'Unknown'], str] Architecture = Annotated[LiteralString, Literal['arm', 'arm64', 'x86', 'x86_64', 'Unknown'], str] diff --git a/openllm-core/src/openllm_core/__init__.py b/openllm-core/src/openllm_core/__init__.py index aea63eb5..bb5414db 100644 --- a/openllm-core/src/openllm_core/__init__.py +++ b/openllm-core/src/openllm_core/__init__.py @@ -1,14 +1,6 @@ from . 
import exceptions as exceptions, utils as utils -from ._configuration import ( - GenerationConfig as GenerationConfig, - LLMConfig as LLMConfig, - SamplingParams as SamplingParams, -) -from ._schemas import ( - GenerationInput as GenerationInput, - GenerationOutput as GenerationOutput, - MetadataOutput as MetadataOutput, -) +from ._configuration import GenerationConfig as GenerationConfig, LLMConfig as LLMConfig, SamplingParams as SamplingParams +from ._schemas import GenerationInput as GenerationInput, GenerationOutput as GenerationOutput, MetadataOutput as MetadataOutput from .config import ( CONFIG_MAPPING as CONFIG_MAPPING, CONFIG_MAPPING_NAMES as CONFIG_MAPPING_NAMES, diff --git a/openllm-core/src/openllm_core/_configuration.py b/openllm-core/src/openllm_core/_configuration.py index 4fb5651c..18a31ec3 100644 --- a/openllm-core/src/openllm_core/_configuration.py +++ b/openllm-core/src/openllm_core/_configuration.py @@ -35,7 +35,7 @@ from ._typing_compat import ( T, overload, ) -from .exceptions import ForbiddenAttributeError, MissingDependencyError +from .exceptions import ForbiddenAttributeError, MissingDependencyError, MissingAnnotationAttributeError, ValidationError from .utils import LazyLoader, ReprMixin, codegen, converter, dantic, field_env_key, first_not_none, lenient_issubclass from .utils.peft import PEFT_TASK_TYPE_TARGET_MAPPING, FineTuneConfig @@ -51,12 +51,7 @@ if t.TYPE_CHECKING: from ._schemas import MessageParam else: - vllm = LazyLoader( - 'vllm', - globals(), - 'vllm', - exc_msg='vLLM is not installed. Make sure to install it with `pip install "openllm[vllm]"`', - ) + vllm = LazyLoader('vllm', globals(), 'vllm', exc_msg='vLLM is not installed. Make sure to install it with `pip install "openllm[vllm]"`') transformers = LazyLoader('transformers', globals(), 'transformers') peft = LazyLoader('peft', globals(), 'peft') @@ -71,7 +66,8 @@ _object_setattr = object.__setattr__ class GenerationConfig(ReprMixin): max_new_tokens: int = dantic.Field(20, ge=0, description='The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.') min_length: int = dantic.Field( - 0, ge=0, # + 0, + ge=0, # description='The minimum length of the sequence to be generated. Corresponds to the length of the input prompt + `min_new_tokens`. Its effect is overridden by `min_new_tokens`, if also set.', ) min_new_tokens: int = dantic.Field(description='The minimum numbers of tokens to generate, ignoring the number of tokens in the prompt.') @@ -79,16 +75,19 @@ class GenerationConfig(ReprMixin): False, description="Controls the stopping condition for beam-based methods, like beam-search. It accepts the following values: `True`, where the generation stops as soon as there are `num_beams` complete candidates; `False`, where an heuristic is applied and the generation stops when is it very unlikely to find better candidates; `'never'`, where the beam search procedure only stops when there cannot be better candidates (canonical beam search algorithm) ", ) - max_time: float = dantic.Field(description='The maximum amount of time you allow the computation to run for in seconds. generation will still finish the current pass after allocated time has been passed.') + max_time: float = dantic.Field( + description='The maximum amount of time you allow the computation to run for in seconds. generation will still finish the current pass after allocated time has been passed.' + ) num_beams: int = dantic.Field(1, description='Number of beams for beam search. 
1 means no beam search.') num_beam_groups: int = dantic.Field( 1, description='Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams. [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.', ) - penalty_alpha: float = dantic.Field(description='The values balance the model confidence and the degeneration penalty in contrastive search decoding.') + penalty_alpha: float = dantic.Field( + description='The values balance the model confidence and the degeneration penalty in contrastive search decoding.' + ) use_cache: bool = dantic.Field( - True, - description='Whether or not the model should use the past last key/values attentions (if applicable to the model) to speed up decoding.', + True, description='Whether or not the model should use the past last key/values attentions (if applicable to the model) to speed up decoding.' ) temperature: float = dantic.Field(1.0, ge=0.0, le=1.0, description='The value used to modulate the next token probabilities.') top_k: int = dantic.Field(50, description='The number of highest probability vocabulary tokens to keep for top-k-filtering.') @@ -124,9 +123,7 @@ class GenerationConfig(ReprMixin): 1.0, description='Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while `length_penalty` < 0.0 encourages shorter sequences.', ) - no_repeat_ngram_size: int = dantic.Field( - 0, description='If set to int > 0, all ngrams of that size can only occur once.' - ) + no_repeat_ngram_size: int = dantic.Field(0, description='If set to int > 0, all ngrams of that size can only occur once.') bad_words_ids: t.List[t.List[int]] = dantic.Field( description='List of token ids that are not allowed to be generated. In order to get the token ids of the words that should not appear in the generated text, use `tokenizer(bad_words, add_prefix_space=True, add_special_tokens=False).input_ids`.' ) @@ -159,46 +156,52 @@ class GenerationConfig(ReprMixin): forced_decoder_ids: t.List[t.List[int]] = dantic.Field( description='A list of pairs of integers which indicates a mapping from generation indices to token indices that will be forced before sampling. For example, `[[1, 123]]` means the second generated token will always be a token of index 123.' ) - num_return_sequences: int = dantic.Field( - 1, description='The number of independently computed returned sequences for each element in the batch.' - ) + num_return_sequences: int = dantic.Field(1, description='The number of independently computed returned sequences for each element in the batch.') output_attentions: bool = dantic.Field( False, description='Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more details.', ) output_hidden_states: bool = dantic.Field( - False, - description='Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more details.', + False, description='Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more details.' ) output_scores: bool = dantic.Field( - False, - description='Whether or not to return the prediction scores. 
See `scores` under returned tensors for more details.', + False, description='Whether or not to return the prediction scores. See `scores` under returned tensors for more details.' ) pad_token_id: int = dantic.Field(description='The id of the *padding* token.') bos_token_id: int = dantic.Field(description='The id of the *beginning-of-sequence* token.') - eos_token_id: t.Union[int, t.List[int]] = dantic.Field(description='The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.') - encoder_no_repeat_ngram_size: int = dantic.Field( - 0, - description='If set to int > 0, all ngrams of that size that occur in the `encoder_input_ids` cannot occur in the `decoder_input_ids`.', + eos_token_id: t.Union[int, t.List[int]] = dantic.Field( + description='The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.' + ) + encoder_no_repeat_ngram_size: int = dantic.Field( + 0, description='If set to int > 0, all ngrams of that size that occur in the `encoder_input_ids` cannot occur in the `decoder_input_ids`.' + ) + decoder_start_token_id: int = dantic.Field( + description='If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token.' ) - decoder_start_token_id: int = dantic.Field(description='If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token.') # NOTE: This is now implemented and supported for both PyTorch and vLLM logprobs: t.Optional[int] = dantic.Field(None, description='Number of log probabilities to return per output token.') prompt_logprobs: t.Optional[int] = dantic.Field(None, description='Number of log probabilities to return per input token.') + def __init__(self, *, _internal: bool = False, **attrs: t.Any): - if not _internal: raise RuntimeError('GenerationConfig is not meant to be used directly, but you can access this via a LLMConfig.generation_config') + if not _internal: + raise RuntimeError('GenerationConfig is not meant to be used directly, but you can access this via a LLMConfig.generation_config') self.__attrs_init__(**attrs) + def __getitem__(self, item: str) -> t.Any: - if hasattr(self, item): return getattr(self, item) + if hasattr(self, item): + return getattr(self, item) raise KeyError(f"'{self.__class__.__name__}' has no attribute {item}.") + @property - def __repr_keys__(self) -> set[str]: return {i.name for i in attr.fields(self.__class__)} + def __repr_keys__(self) -> set[str]: + return {i.name for i in attr.fields(self.__class__)} converter.register_unstructure_hook_factory( lambda cls: attr.has(cls) and lenient_issubclass(cls, GenerationConfig), lambda cls: make_dict_unstructure_fn( - cls, converter, # + cls, + converter, # **{k: override(omit=True) for k, v in attr.fields_dict(cls).items() if v.default in (None, attr.NOTHING)}, ), ) @@ -208,14 +211,14 @@ _GenerationConfigT = t.TypeVar('_GenerationConfigT', bound=GenerationConfig) @attr.frozen(slots=True, repr=False, init=False) class SamplingParams(ReprMixin): - '''SamplingParams is the attr-compatible version of ``vllm.SamplingParams``. It provides some utilities to also respect shared variables from ``openllm.LLMConfig``. + """SamplingParams is the attr-compatible version of ``vllm.SamplingParams``. It provides some utilities to also respect shared variables from ``openllm.LLMConfig``. 
The following value will be parsed directly from ``openllm.LLMConfig``: - temperature - top_k - top_p - max_tokens -> max_new_tokens - ''' + """ n: int = dantic.Field(1, description='Number of output sequences to return for the given prompt.') best_of: int = dantic.Field( @@ -232,8 +235,7 @@ class SamplingParams(ReprMixin): ) use_beam_search: bool = dantic.Field(False, description='Whether to use beam search instead of sampling.') ignore_eos: bool = dantic.Field( - False, - description='Whether to ignore the EOS token and continue generating tokens after the EOS token is generated.', + False, description='Whether to ignore the EOS token and continue generating tokens after the EOS token is generated.' ) skip_special_tokens: bool = dantic.Field(True, description='Whether to skip special tokens in the generated output.') # space_between_special_tokens: bool = dantic.Field(True, description='Whether to add a space between special tokens in the generated output.') @@ -252,9 +254,7 @@ class SamplingParams(ReprMixin): def __init__(self, *, _internal: bool = False, **attrs: t.Any): if not _internal: - raise RuntimeError( - 'SamplingParams is not meant to be used directly, but you can access this via a LLMConfig.sampling_config.' - ) + raise RuntimeError('SamplingParams is not meant to be used directly, but you can access this via a LLMConfig.sampling_config.') _object_setattr(self, 'max_tokens', attrs.pop('max_tokens', 16)) _object_setattr(self, 'temperature', attrs.pop('temperature', 1.0)) _object_setattr(self, 'top_k', attrs.pop('top_k', -1)) @@ -293,18 +293,14 @@ class SamplingParams(ReprMixin): @classmethod def from_generation_config(cls, generation_config: GenerationConfig, **attrs: t.Any) -> Self: - '''The main entrypoint for creating a SamplingParams from ``openllm.LLMConfig``.''' + """The main entrypoint for creating a SamplingParams from ``openllm.LLMConfig``.""" if 'max_tokens' in attrs and 'max_new_tokens' in attrs: raise ValueError("Both 'max_tokens' and 'max_new_tokens' are passed. 
Make sure to only use one of them.") temperature = first_not_none(attrs.pop('temperature', None), default=generation_config['temperature']) top_k = first_not_none(attrs.pop('top_k', None), default=generation_config['top_k']) top_p = first_not_none(attrs.pop('top_p', None), default=generation_config['top_p']) - max_tokens = first_not_none( - attrs.pop('max_tokens', None), attrs.pop('max_new_tokens', None), default=generation_config['max_new_tokens'] - ) - repetition_penalty = first_not_none( - attrs.pop('repetition_penalty', None), default=generation_config['repetition_penalty'] - ) + max_tokens = first_not_none(attrs.pop('max_tokens', None), attrs.pop('max_new_tokens', None), default=generation_config['max_new_tokens']) + repetition_penalty = first_not_none(attrs.pop('repetition_penalty', None), default=generation_config['repetition_penalty']) length_penalty = first_not_none(attrs.pop('length_penalty', None), default=generation_config['length_penalty']) early_stopping = first_not_none(attrs.pop('early_stopping', None), default=generation_config['early_stopping']) logprobs = first_not_none(attrs.pop('logprobs', None), default=generation_config['logprobs']) @@ -339,16 +335,12 @@ converter.register_unstructure_hook_factory( converter, _cattrs_omit_if_default=False, _cattrs_use_linecache=True, - **{ - k: override(omit_if_default=True) for k, v in attr.fields_dict(cls).items() if v.default in (None, attr.NOTHING) - }, + **{k: override(omit_if_default=True) for k, v in attr.fields_dict(cls).items() if v.default in (None, attr.NOTHING)}, ), ) converter.register_structure_hook_factory( lambda cls: attr.has(cls) and lenient_issubclass(cls, SamplingParams), - lambda cls: make_dict_structure_fn( - cls, converter, _cattrs_forbid_extra_keys=True, max_new_tokens=override(rename='max_tokens') - ), + lambda cls: make_dict_structure_fn(cls, converter, _cattrs_forbid_extra_keys=True, max_new_tokens=override(rename='max_tokens')), ) _SamplingParamsT = t.TypeVar('_SamplingParamsT', bound=SamplingParams) @@ -390,13 +382,15 @@ _transformed_type: DictStrAny = {'fine_tune_strategies': t.Dict[AdapterType, Fin @attr.define( - frozen=False, slots=True, # + frozen=False, + slots=True, # field_transformer=lambda _, __: [ attr.Attribute.from_counting_attr( k, dantic.Field( kw_only=False if t.get_origin(ann) is not Required else True, - auto_default=True, use_default_converter=False, # + auto_default=True, + use_default_converter=False, # type=_transformed_type.get(k, ann), metadata={'target': f'__openllm_{k}__'}, description=f'ModelSettings field for {k}.', @@ -407,10 +401,14 @@ _transformed_type: DictStrAny = {'fine_tune_strategies': t.Dict[AdapterType, Fin ) class _ModelSettingsAttr: def __getitem__(self, key: str) -> t.Any: - if key in codegen.get_annotations(ModelSettings): return _object_getattribute(self, key) + if key in codegen.get_annotations(ModelSettings): + return _object_getattribute(self, key) raise KeyError(key) + @classmethod - def from_settings(cls, settings: ModelSettings) -> _ModelSettingsAttr: return cls(**settings) + def from_settings(cls, settings: ModelSettings) -> _ModelSettingsAttr: + return cls(**settings) + if t.TYPE_CHECKING: # update-config-stubs.py: attrs start default_id: str @@ -431,15 +429,22 @@ class _ModelSettingsAttr: fine_tune_strategies: t.Dict[AdapterType, FineTuneConfig] # update-config-stubs.py: attrs stop + _DEFAULT = _ModelSettingsAttr.from_settings( ModelSettings( - name_type='dasherize', url='', # + name_type='dasherize', + url='', # backend=('pt', 'vllm', 'ctranslate'), - 
timeout=int(36e6), service_name='', # - model_type='causal_lm', requirements=None, # - trust_remote_code=False, workers_per_resource=1.0, # - default_id='__default__', model_ids=['__default__'], # - architecture='PreTrainedModel', serialisation='legacy', # + timeout=int(36e6), + service_name='', # + model_type='causal_lm', + requirements=None, # + trust_remote_code=False, + workers_per_resource=1.0, # + default_id='__default__', + model_ids=['__default__'], # + architecture='PreTrainedModel', + serialisation='legacy', # ) ) @@ -456,24 +461,24 @@ def structure_settings(cls: type[LLMConfig], _: type[_ModelSettingsAttr]) -> _Mo else: _attr['model_name'] = _cl_name.lower() _attr['start_name'] = _attr['model_name'] - _attr.update( - { - 'service_name': f'generated_{_attr["model_name"] if "model_name" in _attr else _config.model_name}_service.py', - 'fine_tune_strategies': { - ft_config.get('adapter_type', 'lora'): FineTuneConfig.from_config(ft_config, cls) for ft_config in _config.fine_tune_strategies - } if _config.fine_tune_strategies else {}, + _attr.update({ + 'service_name': f'generated_{_attr["model_name"] if "model_name" in _attr else _config.model_name}_service.py', + 'fine_tune_strategies': { + ft_config.get('adapter_type', 'lora'): FineTuneConfig.from_config(ft_config, cls) for ft_config in _config.fine_tune_strategies } - ) + if _config.fine_tune_strategies + else {}, + }) return attr.evolve(_config, **_attr) converter.register_structure_hook(_ModelSettingsAttr, structure_settings) - _reserved_namespace = {'__config__', 'GenerationConfig', 'SamplingParams'} + def _setattr_class(attr_name: str, value_var: t.Any) -> str: return f"setattr(cls, '{attr_name}', {value_var})" + def _make_assignment_script(cls: type[LLMConfig], attributes: attr.AttrsInstance) -> t.Callable[[type[LLMConfig]], None]: - '''Generate the assignment script with prefix attributes __openllm___.''' args, lines, annotations = [], [], {'return': None} globs = {'cls': cls, '_cached_attribute': attributes} for attr_name, field in attr.fields_dict(attributes.__class__).items(): @@ -483,82 +488,22 @@ def _make_assignment_script(cls: type[LLMConfig], attributes: attr.AttrsInstance annotations[attr_name] = field.type return codegen.generate_function(cls, '__assign_attr', lines, ('cls', *args), globs, annotations) + @attr.define(slots=True) class _ConfigAttr(t.Generic[_GenerationConfigT, _SamplingParamsT]): @staticmethod - def Field(default: t.Any = None, **attrs: t.Any) -> t.Any: - '''Field is a alias to the internal dantic utilities to easily create - attrs.fields with pydantic-compatible interface. For example: - - ```python - class MyModelConfig(openllm.LLMConfig): - field1 = openllm.LLMConfig.Field(...) - ``` - ''' - return dantic.Field(default, **attrs) - - # NOTE: The following is handled via __init_subclass__, and is only used for TYPE_CHECKING + def Field(default: t.Any = None, **attrs: t.Any) -> t.Any: return dantic.Field(default, **attrs) if t.TYPE_CHECKING: - # NOTE: public attributes to override - __config__: ModelSettings = Field(None) - """Internal configuration for this LLM model. Each of the field in here will be populated - and prefixed with __openllm___""" + __config__: t.ClassVar[ModelSettings] = Field(None) GenerationConfig: _GenerationConfigT = Field(None) - """Users can override this subclass of any given LLMConfig to provide GenerationConfig - default value. 
For example: - - ```python - class MyAwesomeModelConfig(openllm.LLMConfig): - class GenerationConfig: - max_new_tokens: int = 200 - top_k: int = 10 - num_return_sequences: int = 1 - eos_token_id: int = 11 - ``` - """ SamplingParams: _SamplingParamsT = Field(None) - """Users can override this subclass of any given LLMConfig to provide SamplingParams - default value. For example: - - ```python - class MyAwesomeModelConfig(openllm.LLMConfig): - class SamplingParams: - max_new_tokens: int = 200 - top_k: int = 10 - num_return_sequences: int = 1 - eos_token_id: int = 11 - ``` - """ - # NOTE: Internal attributes that should only be used by OpenLLM. Users usually shouldn't - # concern any of these. These are here for pyright not to complain. __attrs_attrs__: tuple[attr.Attribute[t.Any], ...] = Field(None, init=False) - """Since we are writing our own __init_subclass__, which is an alternative way for __prepare__, - we want openllm.LLMConfig to be attrs-like dataclass that has pydantic-like interface. - __attrs_attrs__ will be handled dynamically by __init_subclass__. - """ - __openllm_hints__: DictStrAny = Field(None, init=False) - """An internal cache of resolved types for this LLMConfig.""" - __openllm_accepted_keys__: set[str] = Field(None, init=False) - """The accepted keys for this LLMConfig.""" - __openllm_extras__: DictStrAny = Field(None, init=False) - """Extra metadata for this LLMConfig.""" - __openllm_config_override__: DictStrAny = Field(None, init=False) - """Additional override for some variables in LLMConfig.__config__""" - __openllm_generation_class__: type[_GenerationConfigT] = Field(None) - """The result generated GenerationConfig class for this LLMConfig. This will be used - to create the generation_config argument that can be used throughout the lifecycle. - This class will also be managed internally by OpenLLM.""" - __openllm_sampling_class__: type[_SamplingParamsT] = Field(None) - """The result generated SamplingParams class for this LLMConfig. This will be used - to create arguments for vLLM LLMEngine that can be used throughout the lifecycle. - This class will also be managed internally by OpenLLM.""" - - def __attrs_init__(self, *args: t.Any, **attrs: t.Any) -> None: - '''Generated __attrs_init__ for LLMConfig subclass that follows the attrs contract.''' - - # NOTE: The following will be populated from __config__ and also - # considered to be public API. 
Users can also access these via self[key] - # To update the docstring for these field, update it through tools/update-config-stubs.py + __openllm_hints__: DictStrAny = Field(None, init=False) # internal cache for type hint + __openllm_accepted_keys__: set[str] = Field(None, init=False) # accepted keys for LLMConfig + __openllm_extras__: DictStrAny = Field(None, init=False) # Additional metadata + __openllm_config_override__: DictStrAny = Field(None, init=False) # override variables for __config__ + __openllm_generation_class__: type[_GenerationConfigT] = Field(None) # generated GenerationConfig from class scope + __openllm_sampling_class__: type[_SamplingParamsT] = Field(None) # generated SamplingParams from class scope # update-config-stubs.py: special start __openllm_default_id__: str = Field(None) @@ -627,33 +572,13 @@ class _ConfigAttr(t.Generic[_GenerationConfigT, _SamplingParamsT]): class _ConfigBuilder: - __slots__ = ( - '_cls', - '_cls_dict', - '_attr_names', - '_attrs', - '_model_name', - '_base_attr_map', - '_base_names', - '_has_pre_init', - '_has_post_init', - ) + __slots__ = ('_cls', '_cls_dict', '_attr_names', '_attrs', '_model_name', '_base_attr_map', '_base_names', '_has_pre_init', '_has_post_init') def __init__( - self, - cls: type[LLMConfig], - these: dict[str, _CountingAttr], - auto_attribs: bool = False, - kw_only: bool = False, - collect_by_mro: bool = True, + self, cls: type[LLMConfig], these: dict[str, _CountingAttr], auto_attribs: bool = False, kw_only: bool = False, collect_by_mro: bool = True ): attrs, base_attrs, base_attr_map = _transform_attrs( - cls, - these, - auto_attribs, - kw_only, - collect_by_mro, - field_transformer=codegen.make_env_transformer(cls, cls.__openllm_model_name__), + cls, these, auto_attribs, kw_only, collect_by_mro, field_transformer=codegen.make_env_transformer(cls, cls.__openllm_model_name__) ) self._cls, self._model_name, self._cls_dict = cls, cls.__openllm_model_name__, dict(cls.__dict__) self._attrs = attrs @@ -682,15 +607,11 @@ class _ConfigBuilder: for base_cls in self._cls.__mro__[1:-1]: if base_cls.__dict__.get('__weakref__', None) is not None: weakref_inherited = True - existing_slots.update( - {name: getattr(base_cls, name, codegen._sentinel) for name in getattr(base_cls, '__slots__', [])} - ) + existing_slots.update({name: getattr(base_cls, name, codegen._sentinel) for name in getattr(base_cls, '__slots__', [])}) names = self._attr_names base_names = set(self._base_names) - if ( - '__weakref__' not in getattr(self._cls, '__slots__', ()) and '__weakref__' not in names and not weakref_inherited - ): + if '__weakref__' not in getattr(self._cls, '__slots__', ()) and '__weakref__' not in names and not weakref_inherited: names += ('__weakref__',) # We only add the names of attributes that aren't inherited. # Setting __slots__ to inherited attributes wastes memory. 
@@ -772,110 +693,13 @@ class _ConfigBuilder: for key, fn in ReprMixin.__dict__.items(): if key in ('__repr__', '__str__', '__repr_name__', '__repr_str__', '__repr_args__'): self._cls_dict[key] = codegen.add_method_dunders(self._cls, fn) - self._cls_dict['__repr_keys__'] = property( - lambda _: {i.name for i in self._attrs} | {'generation_config', 'sampling_config'} - ) + self._cls_dict['__repr_keys__'] = property(lambda _: {i.name for i in self._attrs} | {'generation_config', 'sampling_config'}) return self @attr.define(slots=True, init=False) class LLMConfig(_ConfigAttr[GenerationConfig, SamplingParams]): - """``openllm.LLMConfig`` is a pydantic-like ``attrs`` interface that offers fast and easy-to-use APIs. - - It lives in between the nice UX of `pydantic` and fast performance of `attrs` where it allows users to quickly formulate - a LLMConfig for any LLM without worrying too much about performance. It does a few things: - - - Automatic environment conversion: Each fields will automatically be provisioned with an environment - variable, make it easy to work with ahead-of-time or during serving time - - Familiar API: It is compatible with cattrs as well as providing a few Pydantic-2 like API, i.e: ``model_construct_env`` - - Automatic CLI generation: It can identify each fields and convert it to compatible Click options. - This means developers can use any of the LLMConfig to create CLI with compatible-Python - CLI library (click, typer, ...) - - > Internally, LLMConfig is an attrs class. All subclass of LLMConfig contains "attrs-like" features, - > which means LLMConfig will actually generate subclass to have attrs-compatible API, so that the subclass - > can be written as any normal Python class. - - To directly configure GenerationConfig for any given LLM, create a GenerationConfig under the subclass: - - ```python - class FlanT5Config(openllm.LLMConfig): - class GenerationConfig: - temperature: float = 0.75 - max_new_tokens: int = 3000 - top_k: int = 50 - top_p: float = 0.4 - repetition_penalty = 1.0 - ``` - By doing so, openllm.LLMConfig will create a compatible GenerationConfig attrs class that can be converted - to ``transformers.GenerationConfig``. These attribute can be accessed via ``LLMConfig.generation_config``. - - By default, all LLMConfig must provide a __config__ with 'default_id' and 'model_ids'. - - All other fields are optional, and will be use default value if not set. 
- - ```python - class FalconConfig(openllm.LLMConfig): - __config__ = { - "name_type": "lowercase", - "trust_remote_code": True, - "timeout": 3600000, - "url": "https://falconllm.tii.ae/", - "requirements": ["einops", "xformers", "safetensors"], - # NOTE: The below are always required - "default_id": "tiiuae/falcon-7b", - "model_ids": [ - "tiiuae/falcon-7b", - "tiiuae/falcon-40b", - "tiiuae/falcon-7b-instruct", - "tiiuae/falcon-40b-instruct", - ], - } - ``` - - > **Changelog**: - > Since 0.1.7, one can also define given fine-tune strategies for given LLM via its config: - ```python - class OPTConfig(openllm.LLMConfig): - __config__ = { - "name_type": "lowercase", - "trust_remote_code": False, - "url": "https://huggingface.co/docs/transformers/model_doc/opt", - "default_id": "facebook/opt-1.3b", - "model_ids": [ - "facebook/opt-125m", - "facebook/opt-350m", - "facebook/opt-1.3b", - "facebook/opt-2.7b", - "facebook/opt-6.7b", - "facebook/opt-66b", - ], - "fine_tune_strategies": ( - { - "adapter_type": "lora", - "r": 16, - "lora_alpha": 32, - "target_modules": ["q_proj", "v_proj"], - "lora_dropout": 0.05, - "bias": "none", - }, - ), - } - ``` - - Future work: - - Support pydantic-core as validation backend. - """ - def __init_subclass__(cls, **_: t.Any): - """The purpose of this ``__init_subclass__`` is to offer pydantic UX while adhering to attrs contract. - - This means we will construct all fields and metadata and hack into - how attrs use some of the 'magic' construction to generate the fields. - - It also does a few more extra features: It also generate all __openllm_*__ config from - ModelSettings (derived from __config__) to the class. - """ if not cls.__name__.endswith('Config'): logger.warning("LLMConfig subclass should end with 'Config'. Updating to %sConfig", cls.__name__) cls.__name__ = f'{cls.__name__}Config' @@ -885,9 +709,7 @@ class LLMConfig(_ConfigAttr[GenerationConfig, SamplingParams]): # auto assignment attributes generated from __config__ after create the new slot class. 
_make_assignment_script(cls, converter.structure(cls, _ModelSettingsAttr))(cls) - def _make_subclass( - class_attr: str, base: type[At], globs: dict[str, t.Any] | None = None, suffix_env: LiteralString | None = None - ) -> type[At]: + def _make_subclass(class_attr: str, base: type[At], globs: dict[str, t.Any] | None = None, suffix_env: LiteralString | None = None) -> type[At]: camel_name = cls.__name__.replace('Config', '') klass = attr.make_class( f'{camel_name}{class_attr}', @@ -904,9 +726,7 @@ class LLMConfig(_ConfigAttr[GenerationConfig, SamplingParams]): cls.__openllm_model_name__, suffix=suffix_env, globs=globs, - default_callback=lambda field_name, field_default: getattr( - getattr(cls, class_attr), field_name, field_default - ) + default_callback=lambda field_name, field_default: getattr(getattr(cls, class_attr), field_name, field_default) if codegen.has_own_attribute(cls, class_attr) else field_default, ), @@ -944,9 +764,7 @@ class LLMConfig(_ConfigAttr[GenerationConfig, SamplingParams]): unannotated = ca_names - annotated_names if len(unannotated) > 0: missing_annotated = sorted(unannotated, key=lambda n: t.cast('_CountingAttr', cd.get(n)).counter) - raise openllm_core.exceptions.MissingAnnotationAttributeError( - f"The following field doesn't have a type annotation: {missing_annotated}" - ) + raise MissingAnnotationAttributeError(f"The following field doesn't have a type annotation: {missing_annotated}") # We need to set the accepted key before generation_config # as generation_config is a special field that users shouldn't pass. cls.__openllm_accepted_keys__ = ( @@ -968,11 +786,7 @@ class LLMConfig(_ConfigAttr[GenerationConfig, SamplingParams]): # the hint cache for easier access cls.__openllm_hints__ = { f.name: f.type - for ite in [ - attr.fields(cls), - attr.fields(cls.__openllm_generation_class__), - attr.fields(cls.__openllm_sampling_class__), - ] + for ite in [attr.fields(cls), attr.fields(cls.__openllm_generation_class__), attr.fields(cls.__openllm_sampling_class__)] for f in ite } @@ -984,7 +798,9 @@ class LLMConfig(_ConfigAttr[GenerationConfig, SamplingParams]): def __setattr__(self, attr: str, value: t.Any) -> None: if attr in _reserved_namespace: - raise ForbiddenAttributeError(f'{attr} should not be set during runtime as these value will be reflected during runtime. Instead, you can create a custom LLM subclass {self.__class__.__name__}.') + raise ForbiddenAttributeError( + f'{attr} should not be set during runtime as these value will be reflected during runtime. Instead, you can create a custom LLM subclass {self.__class__.__name__}.' 
+ ) super().__setattr__(attr, value) def __init__( @@ -1003,24 +819,19 @@ class LLMConfig(_ConfigAttr[GenerationConfig, SamplingParams]): if generation_config is None: generation_config = {k: v for k, v in attrs.items() if k in _generation_cl_dict} else: - generation_config = config_merger.merge( - generation_config, {k: v for k, v in attrs.items() if k in _generation_cl_dict} - ) + generation_config = config_merger.merge(generation_config, {k: v for k, v in attrs.items() if k in _generation_cl_dict}) if sampling_config is None: sampling_config = {k: v for k, v in attrs.items() if k in _sampling_cl_dict} else: - sampling_config = config_merger.merge( - sampling_config, {k: v for k, v in attrs.items() if k in _sampling_cl_dict} - ) + sampling_config = config_merger.merge(sampling_config, {k: v for k, v in attrs.items() if k in _sampling_cl_dict}) for k in _cached_keys: if k in generation_config or k in sampling_config or attrs[k] is None: del attrs[k] self.__openllm_config_override__ = __openllm_config_override__ or {} self.__openllm_extras__ = config_merger.merge( - first_not_none(__openllm_extras__, default={}), - {k: v for k, v in attrs.items() if k not in self.__openllm_accepted_keys__}, + first_not_none(__openllm_extras__, default={}), {k: v for k, v in attrs.items() if k not in self.__openllm_accepted_keys__} ) self.generation_config = self['generation_class'](_internal=True, **generation_config) self.sampling_config = self['sampling_class'].from_generation_config(self.generation_config, **sampling_config) @@ -1200,9 +1011,7 @@ class LLMConfig(_ConfigAttr[GenerationConfig, SamplingParams]): raise TypeError(f"{self} doesn't understand how to index None.") item = inflection.underscore(item) if item in _reserved_namespace: - raise ForbiddenAttributeError( - f"'{item}' is a reserved namespace for {self.__class__} and should not be access nor modified." 
- ) + raise ForbiddenAttributeError(f"'{item}' is a reserved namespace for {self.__class__} and should not be access nor modified.") internal_attributes = f'__openllm_{item}__' if hasattr(self, internal_attributes): if item in self.__openllm_config_override__: @@ -1220,12 +1029,18 @@ class LLMConfig(_ConfigAttr[GenerationConfig, SamplingParams]): return self.__openllm_extras__[item] else: raise KeyError(item) + def __getattribute__(self, item: str) -> t.Any: if item in _reserved_namespace: raise ForbiddenAttributeError(f"'{item}' belongs to a private namespace for {self.__class__} and should not be access nor modified.") return _object_getattribute.__get__(self)(item) - def __len__(self) -> int: return len(self.__openllm_accepted_keys__) + len(self.__openllm_extras__) - def keys(self) -> list[str]: return list(self.__openllm_accepted_keys__) + list(self.__openllm_extras__) + + def __len__(self) -> int: + return len(self.__openllm_accepted_keys__) + len(self.__openllm_extras__) + + def keys(self) -> list[str]: + return list(self.__openllm_accepted_keys__) + list(self.__openllm_extras__) + def values(self) -> list[t.Any]: return ( [getattr(self, k.name) for k in attr.fields(self.__class__)] @@ -1233,6 +1048,7 @@ class LLMConfig(_ConfigAttr[GenerationConfig, SamplingParams]): + [getattr(self.sampling_config, k.name) for k in attr.fields(self.__openllm_sampling_class__)] + list(self.__openllm_extras__.values()) ) + def items(self) -> list[tuple[str, t.Any]]: return ( [(k.name, getattr(self, k.name)) for k in attr.fields(self.__class__)] @@ -1240,9 +1056,13 @@ class LLMConfig(_ConfigAttr[GenerationConfig, SamplingParams]): + [(k.name, getattr(self.sampling_config, k.name)) for k in attr.fields(self.__openllm_sampling_class__)] + list(self.__openllm_extras__.items()) ) - def __iter__(self) -> t.Iterator[str]: return iter(self.keys()) + + def __iter__(self) -> t.Iterator[str]: + return iter(self.keys()) + def __contains__(self, item: t.Any) -> bool: - if item in self.__openllm_extras__: return True + if item in self.__openllm_extras__: + return True return item in self.__openllm_accepted_keys__ @classmethod @@ -1271,12 +1091,10 @@ class LLMConfig(_ConfigAttr[GenerationConfig, SamplingParams]): name or f"{cls.__name__.replace('Config', '')}DerivateConfig", (cls,), {}, - lambda ns: ns.update( - { - '__config__': config_merger.merge(copy.deepcopy(cls.__dict__['__config__']), _new_cfg), - '__base_config__': cls, # keep a reference for easy access - } - ), + lambda ns: ns.update({ + '__config__': config_merger.merge(copy.deepcopy(cls.__dict__['__config__']), _new_cfg), + '__base_config__': cls, # keep a reference for easy access + }), ) # For pickling to work, the __module__ variable needs to be set to the @@ -1308,12 +1126,11 @@ class LLMConfig(_ConfigAttr[GenerationConfig, SamplingParams]): try: attrs = orjson.loads(json_str) except orjson.JSONDecodeError as err: - raise openllm_core.exceptions.ValidationError(f'Failed to load JSON: {err}') from None + raise ValidationError(f'Failed to load JSON: {err}') from None return converter.structure(attrs, cls) @classmethod - def model_construct_env(cls, **attrs: t.Any) -> Self: - '''A helpers that respect configuration values environment variables.''' + def model_construct_env(cls, **attrs: t.Any) -> Self: # All LLMConfig init should start from here. 
attrs = {k: v for k, v in attrs.items() if v is not None} env_json_string = os.environ.get('OPENLLM_CONFIG', None) @@ -1329,9 +1146,7 @@ class LLMConfig(_ConfigAttr[GenerationConfig, SamplingParams]): sampling_config = attrs.pop('sampling_config') elif 'llm_config' in attrs: # NOTE: this is the new key llm_config = attrs.pop('llm_config') - generation_config = { - k: v for k, v in llm_config.items() if k in attr.fields_dict(cls.__openllm_generation_class__) - } + generation_config = {k: v for k, v in llm_config.items() if k in attr.fields_dict(cls.__openllm_generation_class__)} sampling_config = {k: v for k, v in llm_config.items() if k in attr.fields_dict(cls.__openllm_sampling_class__)} else: generation_config = {k: v for k, v in attrs.items() if k in attr.fields_dict(cls.__openllm_generation_class__)} @@ -1347,7 +1162,6 @@ class LLMConfig(_ConfigAttr[GenerationConfig, SamplingParams]): return converter.structure(config_from_env, cls) def model_validate_click(self, **attrs: t.Any) -> tuple[LLMConfig, DictStrAny]: - '''Parse given click attributes into a LLMConfig and return the remaining click attributes.''' llm_config_attrs: DictStrAny = {'generation_config': {}, 'sampling_config': {}} key_to_remove: ListStr = [] for k, v in attrs.items(): @@ -1412,10 +1226,12 @@ class LLMConfig(_ConfigAttr[GenerationConfig, SamplingParams]): no_repeat_ngram_size=config['no_repeat_ngram_size'], end_token=config['stop'], ) + class pt: @staticmethod def build(config: LLMConfig) -> LLMConfig: return config + class hf: @staticmethod def build(config: LLMConfig) -> transformers.GenerationConfig: @@ -1423,13 +1239,13 @@ class LLMConfig(_ConfigAttr[GenerationConfig, SamplingParams]): @overload def compatible_options(self, request: ChatCompletionRequest | CompletionRequest) -> dict[str, t.Any]: ... + @overload def compatible_options(self, request: CohereChatRequest | CohereGenerateRequest) -> dict[str, t.Any]: ... + def compatible_options(self, request: AttrsInstance) -> dict[str, t.Any]: if importlib.util.find_spec('openllm') is None: - raise MissingDependencyError( - "'openllm' is required to use 'compatible_options'. Make sure to install with 'pip install openllm'." - ) + raise MissingDependencyError("'openllm' is required to use 'compatible_options'. Make sure to install with 'pip install openllm'.") from openllm.protocol.cohere import CohereChatRequest, CohereGenerateRequest from openllm.protocol.openai import ChatCompletionRequest, CompletionRequest @@ -1439,6 +1255,7 @@ class LLMConfig(_ConfigAttr[GenerationConfig, SamplingParams]): return self.cohere.build(self, request) else: raise TypeError(f'Unknown request type {type(request)}') + class openai: @staticmethod def build(config: LLMConfig, request: ChatCompletionRequest | CompletionRequest) -> dict[str, t.Any]: @@ -1456,6 +1273,7 @@ class LLMConfig(_ConfigAttr[GenerationConfig, SamplingParams]): if hasattr(request, 'logprobs'): d['logprobs'] = first_not_none(request.logprobs, default=config['logprobs']) return d + class cohere: @staticmethod def build(config: LLMConfig, request: CohereGenerateRequest | CohereChatRequest) -> dict[str, t.Any]: @@ -1482,7 +1300,13 @@ class LLMConfig(_ConfigAttr[GenerationConfig, SamplingParams]): @property def chat_messages(self) -> list[MessageParam]: from ._schemas import MessageParam - return [MessageParam(role='system', content='You are a helpful assistant'), MessageParam(role='user', content="Hello, I'm looking for a chatbot that can help me with my work."), MessageParam(role='assistant', content='Yes? 
What can I help you with?')] + + return [ + MessageParam(role='system', content='You are a helpful assistant'), + MessageParam(role='user', content="Hello, I'm looking for a chatbot that can help me with my work."), + MessageParam(role='assistant', content='Yes? What can I help you with?'), + ] + @classmethod def parse(cls, f: AnyCallable) -> click.Command: for name, field in attr.fields_dict(cls.__openllm_generation_class__).items(): @@ -1501,9 +1325,7 @@ class LLMConfig(_ConfigAttr[GenerationConfig, SamplingParams]): f = dantic.attrs_to_options(name, field, cls.__openllm_model_name__, typ=ty, suffix_sampling=True)(f) f = cog.optgroup.group('SamplingParams sampling options')(f) - total_keys = set(attr.fields_dict(cls.__openllm_generation_class__)) | set( - attr.fields_dict(cls.__openllm_sampling_class__) - ) + total_keys = set(attr.fields_dict(cls.__openllm_generation_class__)) | set(attr.fields_dict(cls.__openllm_sampling_class__)) if len(cls.__openllm_accepted_keys__.difference(total_keys)) == 0: return t.cast('click.Command', f) @@ -1524,16 +1346,12 @@ class LLMConfig(_ConfigAttr[GenerationConfig, SamplingParams]): # deprecated def to_generation_config(self, return_as_dict: bool = False) -> transformers.GenerationConfig | DictStrAny: - warnings.warn( - "'to_generation_config' is deprecated, please use 'inference_options' instead.", DeprecationWarning, stacklevel=3 - ) + warnings.warn("'to_generation_config' is deprecated, please use 'inference_options' instead.", DeprecationWarning, stacklevel=3) _, config = self.inference_options(None, 'hf') return config.to_dict() if return_as_dict else config def to_sampling_config(self) -> vllm.SamplingParams: - warnings.warn( - "'to_sampling_config' is deprecated, please use 'inference_options' instead.", DeprecationWarning, stacklevel=3 - ) + warnings.warn("'to_sampling_config' is deprecated, please use 'inference_options' instead.", DeprecationWarning, stacklevel=3) return self.inference_options(None, 'vllm')[-1] @@ -1577,8 +1395,5 @@ def structure_llm_config(data: t.Any, cls: type[LLMConfig]) -> LLMConfig: converter.register_structure_hook_func(lambda cls: lenient_issubclass(cls, LLMConfig), structure_llm_config) openllm_home = os.path.expanduser( - os.environ.get( - 'OPENLLM_HOME', - os.path.join(os.environ.get('XDG_CACHE_HOME', os.path.join(os.path.expanduser('~'), '.cache')), 'openllm'), - ) + os.environ.get('OPENLLM_HOME', os.path.join(os.environ.get('XDG_CACHE_HOME', os.path.join(os.path.expanduser('~'), '.cache')), 'openllm')) ) diff --git a/openllm-core/src/openllm_core/_schemas.py b/openllm-core/src/openllm_core/_schemas.py index fb702b79..dcbfa431 100644 --- a/openllm-core/src/openllm_core/_schemas.py +++ b/openllm-core/src/openllm_core/_schemas.py @@ -14,10 +14,12 @@ if t.TYPE_CHECKING: from ._typing_compat import Self, LiteralString + class MessageParam(t.TypedDict): role: t.Union[t.Literal['system', 'user', 'assistant'], LiteralString] content: str + @attr.define class _SchemaMixin: def model_dump(self) -> dict[str, t.Any]: @@ -60,12 +62,7 @@ class GenerationInput(_SchemaMixin): return cls.from_llm_config(AutoConfig.for_model(model_name, **attrs)) def model_dump(self) -> dict[str, t.Any]: - return { - 'prompt': self.prompt, - 'stop': self.stop, - 'llm_config': self.llm_config.model_dump(flatten=True), - 'adapter_name': self.adapter_name, - } + return {'prompt': self.prompt, 'stop': self.stop, 'llm_config': self.llm_config.model_dump(flatten=True), 'adapter_name': self.adapter_name} @classmethod def from_llm_config(cls, llm_config: 
LLMConfig) -> type[GenerationInput]: @@ -120,6 +117,7 @@ class CompletionChunk(_SchemaMixin): def model_dump_json(self) -> str: return orjson.dumps(self.model_dump(), option=orjson.OPT_NON_STR_KEYS).decode('utf-8') + @attr.define class GenerationOutput(_SchemaMixin): prompt: str @@ -172,7 +170,8 @@ class GenerationOutput(_SchemaMixin): @classmethod def from_dict(cls, structured: dict[str, t.Any]) -> GenerationOutput: - if structured['prompt_logprobs']: structured['prompt_logprobs'] = [{int(k): v for k,v in it.items()} if it else None for it in structured['prompt_logprobs']] + if structured['prompt_logprobs']: + structured['prompt_logprobs'] = [{int(k): v for k, v in it.items()} if it else None for it in structured['prompt_logprobs']] return cls( prompt=structured['prompt'], finished=structured['finished'], @@ -186,7 +185,7 @@ class GenerationOutput(_SchemaMixin): token_ids=it['token_ids'], cumulative_logprob=it['cumulative_logprob'], finish_reason=it['finish_reason'], - logprobs=[{int(k): v for k,v in s.items()} for s in it['logprobs']] if it['logprobs'] else None, + logprobs=[{int(k): v for k, v in s.items()} for s in it['logprobs']] if it['logprobs'] else None, ) for it in structured['outputs'] ], @@ -216,4 +215,5 @@ class GenerationOutput(_SchemaMixin): def model_dump_json(self) -> str: return orjson.dumps(self.model_dump(), option=orjson.OPT_NON_STR_KEYS).decode('utf-8') + converter.register_structure_hook_func(lambda cls: attr.has(cls) and issubclass(cls, GenerationOutput), lambda data, cls: cls.from_dict(data)) diff --git a/openllm-core/src/openllm_core/_typing_compat.py b/openllm-core/src/openllm_core/_typing_compat.py index 9dc392a1..0b4004d9 100644 --- a/openllm-core/src/openllm_core/_typing_compat.py +++ b/openllm-core/src/openllm_core/_typing_compat.py @@ -7,7 +7,11 @@ import attr M = TypeVar('M') T = TypeVar('T') -def get_literal_args(typ: Any) -> Tuple[str, ...]: return getattr(typ, '__args__', tuple()) + +def get_literal_args(typ: Any) -> Tuple[str, ...]: + return getattr(typ, '__args__', tuple()) + + AnyCallable = Callable[..., Any] DictStrAny = Dict[str, Any] ListStr = List[str] @@ -19,7 +23,12 @@ LiteralBackend = Literal['pt', 'vllm', 'ctranslate', 'triton'] # TODO: ggml AdapterType = Literal['lora', 'adalora', 'adaption_prompt', 'prefix_tuning', 'p_tuning', 'prompt_tuning', 'ia3', 'loha', 'lokr'] LiteralVersionStrategy = Literal['release', 'nightly', 'latest', 'custom'] -class AdapterTuple(Tuple[Any, ...]): adapter_id: str; name: str ; config: DictStrAny + +class AdapterTuple(Tuple[Any, ...]): + adapter_id: str + name: str + config: DictStrAny + AdapterMap = Dict[AdapterType, Tuple[AdapterTuple, ...]] @@ -45,12 +54,7 @@ else: if sys.version_info[:2] >= (3, 10): from typing import Concatenate as Concatenate, ParamSpec as ParamSpec, TypeAlias as TypeAlias, TypeGuard as TypeGuard else: - from typing_extensions import ( - Concatenate as Concatenate, - ParamSpec as ParamSpec, - TypeAlias as TypeAlias, - TypeGuard as TypeGuard, - ) + from typing_extensions import Concatenate as Concatenate, ParamSpec as ParamSpec, TypeAlias as TypeAlias, TypeGuard as TypeGuard if sys.version_info[:2] >= (3, 9): from typing import Annotated as Annotated diff --git a/openllm-core/src/openllm_core/config/configuration_auto.py b/openllm-core/src/openllm_core/config/configuration_auto.py index 9a76df84..49b3a473 100644 --- a/openllm-core/src/openllm_core/config/configuration_auto.py +++ b/openllm-core/src/openllm_core/config/configuration_auto.py @@ -22,28 +22,27 @@ else: # NOTE: This is the 
entrypoint when adding new model config CONFIG_MAPPING_NAMES = OrderedDict( - sorted( - [ - ('flan_t5', 'FlanT5Config'), - ('baichuan', 'BaichuanConfig'), - ('chatglm', 'ChatGLMConfig'), - ('falcon', 'FalconConfig'), - ('gpt_neox', 'GPTNeoXConfig'), - ('dolly_v2', 'DollyV2Config'), - ('stablelm', 'StableLMConfig'), - ('llama', 'LlamaConfig'), - ('mpt', 'MPTConfig'), - ('opt', 'OPTConfig'), - ('phi', 'PhiConfig'), - ('qwen', 'QwenConfig'), - ('starcoder', 'StarCoderConfig'), - ('mistral', 'MistralConfig'), - ('mixtral', 'MixtralConfig'), - ('yi', 'YiConfig'), - ] - ) + sorted([ + ('flan_t5', 'FlanT5Config'), + ('baichuan', 'BaichuanConfig'), + ('chatglm', 'ChatGLMConfig'), + ('falcon', 'FalconConfig'), + ('gpt_neox', 'GPTNeoXConfig'), + ('dolly_v2', 'DollyV2Config'), + ('stablelm', 'StableLMConfig'), + ('llama', 'LlamaConfig'), + ('mpt', 'MPTConfig'), + ('opt', 'OPTConfig'), + ('phi', 'PhiConfig'), + ('qwen', 'QwenConfig'), + ('starcoder', 'StarCoderConfig'), + ('mistral', 'MistralConfig'), + ('mixtral', 'MixtralConfig'), + ('yi', 'YiConfig'), + ]) ) + class _LazyConfigMapping(OrderedDictType, ReprMixin): def __init__(self, mapping: OrderedDict[LiteralString, LiteralString]): self._mapping = mapping @@ -98,12 +97,7 @@ class _LazyConfigMapping(OrderedDictType, ReprMixin): CONFIG_MAPPING: dict[LiteralString, type[openllm_core.LLMConfig]] = _LazyConfigMapping(CONFIG_MAPPING_NAMES) # The below handle special alias when we call underscore to the name directly without processing camelcase first. -CONFIG_NAME_ALIASES: dict[str, str] = { - 'chat_glm': 'chatglm', - 'stable_lm': 'stablelm', - 'star_coder': 'starcoder', - 'gpt_neo_x': 'gpt_neox', -} +CONFIG_NAME_ALIASES: dict[str, str] = {'chat_glm': 'chatglm', 'stable_lm': 'stablelm', 'star_coder': 'starcoder', 'gpt_neo_x': 'gpt_neox'} CONFIG_FILE_NAME = 'config.json' @@ -167,9 +161,7 @@ class AutoConfig: model_name = inflection.underscore(model_name) if model_name in CONFIG_MAPPING: return CONFIG_MAPPING[model_name].model_construct_env(**attrs) - raise ValueError( - f"Unrecognized configuration class for {model_name}. Model name should be one of {', '.join(CONFIG_MAPPING.keys())}." - ) + raise ValueError(f"Unrecognized configuration class for {model_name}. Model name should be one of {', '.join(CONFIG_MAPPING.keys())}.") @classmethod def infer_class_from_name(cls, name: str) -> type[openllm_core.LLMConfig]: @@ -178,9 +170,7 @@ class AutoConfig: model_name = CONFIG_NAME_ALIASES[model_name] if model_name in CONFIG_MAPPING: return CONFIG_MAPPING[model_name] - raise ValueError( - f"Unrecognized configuration class for {model_name}. Model name should be one of {', '.join(CONFIG_MAPPING.keys())}." - ) + raise ValueError(f"Unrecognized configuration class for {model_name}. Model name should be one of {', '.join(CONFIG_MAPPING.keys())}.") _cached_mapping = None @@ -201,9 +191,7 @@ class AutoConfig: config_file = llm.bentomodel.path_of(CONFIG_FILE_NAME) except OpenLLMException as err: if not is_transformers_available(): - raise MissingDependencyError( - "Requires 'transformers' to be available. Do 'pip install transformers'" - ) from err + raise MissingDependencyError("Requires 'transformers' to be available. 
Do 'pip install transformers'") from err from transformers.utils import cached_file try: @@ -220,6 +208,4 @@ class AutoConfig: for architecture in loaded_config['architectures']: if architecture in cls._CONFIG_MAPPING_NAMES_TO_ARCHITECTURE(): return cls.infer_class_from_name(cls._CONFIG_MAPPING_NAMES_TO_ARCHITECTURE()[architecture]) - raise ValueError( - f"Failed to determine config class for '{llm.model_id}'. Make sure {llm.model_id} is saved with openllm." - ) + raise ValueError(f"Failed to determine config class for '{llm.model_id}'. Make sure {llm.model_id} is saved with openllm.") diff --git a/openllm-core/src/openllm_core/config/configuration_dolly_v2.py b/openllm-core/src/openllm_core/config/configuration_dolly_v2.py index 1114c22a..0fc48b1e 100644 --- a/openllm-core/src/openllm_core/config/configuration_dolly_v2.py +++ b/openllm-core/src/openllm_core/config/configuration_dolly_v2.py @@ -9,9 +9,7 @@ if t.TYPE_CHECKING: INSTRUCTION_KEY = '### Instruction:' RESPONSE_KEY = '### Response:' END_KEY = '### End' -INTRO_BLURB = ( - 'Below is an instruction that describes a task. Write a response that appropriately completes the request.' -) +INTRO_BLURB = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.' def get_special_token_id(tokenizer: transformers.PreTrainedTokenizer, key: str) -> int: diff --git a/openllm-core/src/openllm_core/config/configuration_flan_t5.py b/openllm-core/src/openllm_core/config/configuration_flan_t5.py index 4de691f6..8c4c42c4 100644 --- a/openllm-core/src/openllm_core/config/configuration_flan_t5.py +++ b/openllm-core/src/openllm_core/config/configuration_flan_t5.py @@ -18,13 +18,7 @@ class FlanT5Config(openllm_core.LLMConfig): 'backend': ('pt',), # NOTE: See https://www.philschmid.de/fine-tune-flan-t5. 
No specific template found, but seems to have the same dialogue style 'default_id': 'google/flan-t5-large', - 'model_ids': [ - 'google/flan-t5-small', - 'google/flan-t5-base', - 'google/flan-t5-large', - 'google/flan-t5-xl', - 'google/flan-t5-xxl', - ], + 'model_ids': ['google/flan-t5-small', 'google/flan-t5-base', 'google/flan-t5-large', 'google/flan-t5-xl', 'google/flan-t5-xxl'], } class GenerationConfig: diff --git a/openllm-core/src/openllm_core/config/configuration_llama.py b/openllm-core/src/openllm_core/config/configuration_llama.py index ff88db9f..8f034453 100644 --- a/openllm-core/src/openllm_core/config/configuration_llama.py +++ b/openllm-core/src/openllm_core/config/configuration_llama.py @@ -38,9 +38,7 @@ class LlamaConfig(openllm_core.LLMConfig): 'NousResearch/llama-2-13b-hf', 'NousResearch/llama-2-7b-hf', ], - 'fine_tune_strategies': ( - {'adapter_type': 'lora', 'r': 64, 'lora_alpha': 16, 'lora_dropout': 0.1, 'bias': 'none'}, - ), + 'fine_tune_strategies': ({'adapter_type': 'lora', 'r': 64, 'lora_alpha': 16, 'lora_dropout': 0.1, 'bias': 'none'},), } class GenerationConfig: @@ -56,11 +54,7 @@ class LlamaConfig(openllm_core.LLMConfig): @property def template(self) -> str: return '{start_key} {sys_key}\n{system_message}\n{sys_key}\n\n{instruction}\n{end_key}\n'.format( - start_key=SINST_KEY, - sys_key=SYS_KEY, - system_message='{system_message}', - instruction='{instruction}', - end_key=EINST_KEY, + start_key=SINST_KEY, sys_key=SYS_KEY, system_message='{system_message}', instruction='{instruction}', end_key=EINST_KEY ) @property diff --git a/openllm-core/src/openllm_core/config/configuration_mistral.py b/openllm-core/src/openllm_core/config/configuration_mistral.py index fe8a16d0..21cc1723 100644 --- a/openllm-core/src/openllm_core/config/configuration_mistral.py +++ b/openllm-core/src/openllm_core/config/configuration_mistral.py @@ -30,9 +30,7 @@ class MistralConfig(openllm_core.LLMConfig): 'mistralai/Mistral-7B-Instruct-v0.1', 'mistralai/Mistral-7B-v0.1', ], - 'fine_tune_strategies': ( - {'adapter_type': 'lora', 'r': 64, 'lora_alpha': 16, 'lora_dropout': 0.1, 'bias': 'none'}, - ), + 'fine_tune_strategies': ({'adapter_type': 'lora', 'r': 64, 'lora_alpha': 16, 'lora_dropout': 0.1, 'bias': 'none'},), } class GenerationConfig: @@ -48,26 +46,30 @@ class MistralConfig(openllm_core.LLMConfig): # NOTE: see https://docs.mistral.ai/usage/guardrailing/ and https://docs.mistral.ai/llm/mistral-instruct-v0.1 @property def template(self) -> str: - return '''{start_key}{start_inst} {system_message} {instruction} {end_inst}\n'''.format( - start_inst=SINST_KEY, - end_inst=EINST_KEY, - start_key=BOS_TOKEN, - system_message='{system_message}', - instruction='{instruction}', + return """{start_key}{start_inst} {system_message} {instruction} {end_inst}\n""".format( + start_inst=SINST_KEY, end_inst=EINST_KEY, start_key=BOS_TOKEN, system_message='{system_message}', instruction='{instruction}' ) # NOTE: https://docs.mistral.ai/usage/guardrailing/ @property def system_message(self) -> str: - return '''Always assist with care, respect, and truth. Respond with utmost utility yet securely. Avoid harmful, unethical, prejudiced, or negative content. Ensure replies promote fairness and positivity.''' + return """Always assist with care, respect, and truth. Respond with utmost utility yet securely. Avoid harmful, unethical, prejudiced, or negative content. 
Ensure replies promote fairness and positivity.""" @property def chat_template(self) -> str: - return repr("{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}") + return repr( + "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}" + ) @property def chat_messages(self) -> list[MessageParam]: from openllm_core._schemas import MessageParam - return [MessageParam(role='user', content='What is your favourite condiment?'), - MessageParam(role='assistant', content="Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!"), - MessageParam(role='user', content='Do you have mayonnaise recipes?')] + + return [ + MessageParam(role='user', content='What is your favourite condiment?'), + MessageParam( + role='assistant', + content="Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!", + ), + MessageParam(role='user', content='Do you have mayonnaise recipes?'), + ] diff --git a/openllm-core/src/openllm_core/config/configuration_mixtral.py b/openllm-core/src/openllm_core/config/configuration_mixtral.py index 1c514f10..62e9e1c7 100644 --- a/openllm-core/src/openllm_core/config/configuration_mixtral.py +++ b/openllm-core/src/openllm_core/config/configuration_mixtral.py @@ -33,26 +33,30 @@ class MixtralConfig(openllm_core.LLMConfig): # NOTE: see https://docs.mistral.ai/usage/guardrailing/ and https://docs.mistral.ai/llm/mistral-instruct-v0.1 @property def template(self) -> str: - return '''{start_key}{start_inst} {system_message} {instruction} {end_inst}\n'''.format( - start_inst=SINST_KEY, - end_inst=EINST_KEY, - start_key=BOS_TOKEN, - system_message='{system_message}', - instruction='{instruction}', + return """{start_key}{start_inst} {system_message} {instruction} {end_inst}\n""".format( + start_inst=SINST_KEY, end_inst=EINST_KEY, start_key=BOS_TOKEN, system_message='{system_message}', instruction='{instruction}' ) # NOTE: https://docs.mistral.ai/usage/guardrailing/ @property def system_message(self) -> str: - return '''Always assist with care, respect, and truth. Respond with utmost utility yet securely. Avoid harmful, unethical, prejudiced, or negative content. Ensure replies promote fairness and positivity.''' + return """Always assist with care, respect, and truth. Respond with utmost utility yet securely. Avoid harmful, unethical, prejudiced, or negative content. 
Ensure replies promote fairness and positivity.""" @property def chat_template(self) -> str: - return repr("{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}") + return repr( + "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}" + ) @property def chat_messages(self) -> list[MessageParam]: from openllm_core._schemas import MessageParam - return [MessageParam(role='user', content='What is your favourite condiment?'), - MessageParam(role='assistant', content="Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!"), - MessageParam(role='user', content='Do you have mayonnaise recipes?')] + + return [ + MessageParam(role='user', content='What is your favourite condiment?'), + MessageParam( + role='assistant', + content="Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!", + ), + MessageParam(role='user', content='Do you have mayonnaise recipes?'), + ] diff --git a/openllm-core/src/openllm_core/config/configuration_opt.py b/openllm-core/src/openllm_core/config/configuration_opt.py index 1676896d..e403f726 100644 --- a/openllm-core/src/openllm_core/config/configuration_opt.py +++ b/openllm-core/src/openllm_core/config/configuration_opt.py @@ -20,23 +20,9 @@ class OPTConfig(openllm_core.LLMConfig): 'url': 'https://huggingface.co/docs/transformers/model_doc/opt', 'default_id': 'facebook/opt-1.3b', 'architecture': 'OPTForCausalLM', - 'model_ids': [ - 'facebook/opt-125m', - 'facebook/opt-350m', - 'facebook/opt-1.3b', - 'facebook/opt-2.7b', - 'facebook/opt-6.7b', - 'facebook/opt-66b', - ], + 'model_ids': ['facebook/opt-125m', 'facebook/opt-350m', 'facebook/opt-1.3b', 'facebook/opt-2.7b', 'facebook/opt-6.7b', 'facebook/opt-66b'], 'fine_tune_strategies': ( - { - 'adapter_type': 'lora', - 'r': 16, - 'lora_alpha': 32, - 'target_modules': ['q_proj', 'v_proj'], - 'lora_dropout': 0.05, - 'bias': 'none', - }, + {'adapter_type': 'lora', 'r': 16, 'lora_alpha': 32, 'target_modules': ['q_proj', 'v_proj'], 'lora_dropout': 0.05, 'bias': 'none'}, ), } diff --git a/openllm-core/src/openllm_core/config/configuration_phi.py b/openllm-core/src/openllm_core/config/configuration_phi.py index e326b693..40b537e9 100644 --- a/openllm-core/src/openllm_core/config/configuration_phi.py +++ b/openllm-core/src/openllm_core/config/configuration_phi.py @@ -5,6 +5,7 @@ import openllm_core, typing as t if t.TYPE_CHECKING: from openllm_core._schemas import MessageParam + class PhiConfig(openllm_core.LLMConfig): """The language model phi-1.5 is a Transformer with 1.3 billion parameters. 
@@ -26,9 +27,7 @@ class PhiConfig(openllm_core.LLMConfig): 'default_id': 'microsoft/phi-1_5', 'serialisation': 'safetensors', 'model_ids': ['microsoft/phi-1_5'], - 'fine_tune_strategies': ( - {'adapter_type': 'lora', 'r': 64, 'lora_alpha': 16, 'lora_dropout': 0.1, 'bias': 'none'}, - ), + 'fine_tune_strategies': ({'adapter_type': 'lora', 'r': 64, 'lora_alpha': 16, 'lora_dropout': 0.1, 'bias': 'none'},), } class GenerationConfig: @@ -39,11 +38,16 @@ class PhiConfig(openllm_core.LLMConfig): @property def chat_template(self) -> str: - return repr("{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\n\n'}}{% endfor %}{% if add_generation_prompt %}{{ 'assistant: ' }}{% endif %}") + return repr( + "{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\n\n'}}{% endfor %}{% if add_generation_prompt %}{{ 'assistant: ' }}{% endif %}" + ) @property def chat_messages(self) -> list[MessageParam]: from openllm_core._schemas import MessageParam - return [MessageParam(role='user', content="I don't know why, I'm struggling to maintain focus while studying. Any suggestions?"), - MessageParam(role='assistant', content='Have you tried using a timer? It can help you stay on track and avoid distractions.'), - MessageParam(role='user', content="That's a good idea. I'll give it a try. What else can I do to boost my productivity?")] + + return [ + MessageParam(role='user', content="I don't know why, I'm struggling to maintain focus while studying. Any suggestions?"), + MessageParam(role='assistant', content='Have you tried using a timer? It can help you stay on track and avoid distractions.'), + MessageParam(role='user', content="That's a good idea. I'll give it a try. What else can I do to boost my productivity?"), + ] diff --git a/openllm-core/src/openllm_core/config/configuration_qwen.py b/openllm-core/src/openllm_core/config/configuration_qwen.py index a8929ff3..38b8bc64 100644 --- a/openllm-core/src/openllm_core/config/configuration_qwen.py +++ b/openllm-core/src/openllm_core/config/configuration_qwen.py @@ -2,6 +2,7 @@ from __future__ import annotations import openllm_core + class QwenConfig(openllm_core.LLMConfig): """Qwen-7B is the 7B-parameter version of the large language model series, Qwen (abbr. Tongyi Qianwen), proposed by Alibaba Cloud. Qwen-14B is a Transformer-based large language model, diff --git a/openllm-core/src/openllm_core/config/configuration_stablelm.py b/openllm-core/src/openllm_core/config/configuration_stablelm.py index 0918d49e..6402196c 100644 --- a/openllm-core/src/openllm_core/config/configuration_stablelm.py +++ b/openllm-core/src/openllm_core/config/configuration_stablelm.py @@ -43,9 +43,9 @@ class StableLMConfig(openllm_core.LLMConfig): @property def system_message(self) -> str: - return '''<|SYSTEM|># StableLM Tuned (Alpha version) + return """<|SYSTEM|># StableLM Tuned (Alpha version) - StableLM is a helpful and harmless open-source AI language model developed by StabilityAI. - StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user. - StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes. - StableLM will refuse to participate in anything that could harm a human. 
-''' +""" diff --git a/openllm-core/src/openllm_core/exceptions.py b/openllm-core/src/openllm_core/exceptions.py index fca392c5..1b162769 100644 --- a/openllm-core/src/openllm_core/exceptions.py +++ b/openllm-core/src/openllm_core/exceptions.py @@ -1,11 +1,11 @@ -'''Base exceptions for OpenLLM. This extends BentoML exceptions.''' +"""Base exceptions for OpenLLM. This extends BentoML exceptions.""" from __future__ import annotations from http import HTTPStatus class OpenLLMException(Exception): - '''Base class for all OpenLLM exceptions. This shares similar interface with BentoMLException.''' + """Base class for all OpenLLM exceptions. This shares similar interface with BentoMLException.""" error_code = HTTPStatus.INTERNAL_SERVER_ERROR @@ -15,28 +15,28 @@ class OpenLLMException(Exception): class GpuNotAvailableError(OpenLLMException): - '''Raised when there is no GPU available in given system.''' + """Raised when there is no GPU available in given system.""" class ValidationError(OpenLLMException): - '''Raised when a validation fails.''' + """Raised when a validation fails.""" class ForbiddenAttributeError(OpenLLMException): - '''Raised when using an _internal field.''' + """Raised when using an _internal field.""" class MissingAnnotationAttributeError(OpenLLMException): - '''Raised when a field under openllm.LLMConfig is missing annotations.''' + """Raised when a field under openllm.LLMConfig is missing annotations.""" class MissingDependencyError(BaseException): - '''Raised when a dependency is missing.''' + """Raised when a dependency is missing.""" class Error(BaseException): - '''To be used instead of naked raise.''' + """To be used instead of naked raise.""" class FineTuneStrategyNotSupportedError(OpenLLMException): - '''Raised when a fine-tune strategy is not supported for given LLM.''' + """Raised when a fine-tune strategy is not supported for given LLM.""" diff --git a/openllm-core/src/openllm_core/utils/__init__.py b/openllm-core/src/openllm_core/utils/__init__.py index 88e0aaec..085d6423 100644 --- a/openllm-core/src/openllm_core/utils/__init__.py +++ b/openllm-core/src/openllm_core/utils/__init__.py @@ -15,6 +15,8 @@ DEV_DEBUG_VAR = 'DEBUG' # equivocal setattr to save one lookup per assignment _object_setattr = object.__setattr__ logger = logging.getLogger(__name__) + + @functools.lru_cache(maxsize=1) def _WithArgsTypes() -> tuple[type[t.Any], ...]: try: @@ -23,95 +25,176 @@ def _WithArgsTypes() -> tuple[type[t.Any], ...]: _TypingGenericAlias = () # python < 3.9 does not have GenericAlias (list[int], tuple[str, ...] 
and so on) # _GenericAlias is the actual GenericAlias implementation return (_TypingGenericAlias,) if sys.version_info < (3, 10) else (t._GenericAlias, types.GenericAlias, types.UnionType) + + def lenient_issubclass(cls, class_or_tuple): try: return isinstance(cls, type) and issubclass(cls, class_or_tuple) except TypeError: - if isinstance(cls, _WithArgsTypes()): return False + if isinstance(cls, _WithArgsTypes()): + return False raise + + def resolve_user_filepath(filepath, ctx): _path = os.path.expanduser(os.path.expandvars(filepath)) - if os.path.exists(_path): return os.path.realpath(_path) + if os.path.exists(_path): + return os.path.realpath(_path) # Try finding file in ctx if provided if ctx: _path = os.path.expanduser(os.path.join(ctx, filepath)) - if os.path.exists(_path): return os.path.realpath(_path) + if os.path.exists(_path): + return os.path.realpath(_path) raise FileNotFoundError(f'file {filepath} not found') + + # this is the supress version of resolve_user_filepath def resolve_filepath(path, ctx=None): try: return resolve_user_filepath(path, ctx) except FileNotFoundError: return path + + def check_bool_env(env, default=True): v = os.getenv(env, default=str(default)).upper() - if v.isdigit(): return bool(int(v)) # special check for digits + if v.isdigit(): + return bool(int(v)) # special check for digits return v in ENV_VARS_TRUE_VALUES -def calc_dir_size(path): return sum(f.stat().st_size for f in _Path(path).glob('**/*') if f.is_file()) + + +def calc_dir_size(path): + return sum(f.stat().st_size for f in _Path(path).glob('**/*') if f.is_file()) + + @functools.lru_cache(maxsize=128) -def generate_hash_from_file(f, algorithm='sha1'): return str(getattr(hashlib, algorithm)(str(os.path.getmtime(resolve_filepath(f))).encode()).hexdigest()) +def generate_hash_from_file(f, algorithm='sha1'): + return str(getattr(hashlib, algorithm)(str(os.path.getmtime(resolve_filepath(f))).encode()).hexdigest()) + + def getenv(env, default=None, var=None): env_key = {env.upper(), f'OPENLLM_{env.upper()}'} - if var is not None: env_key = set(var) | env_key + if var is not None: + env_key = set(var) | env_key + def callback(k: str) -> t.Any: _var = os.getenv(k) - if _var and k.startswith('OPENLLM_'): logger.warning("Using '%s' environment is deprecated, use '%s' instead.", k.upper(), k[8:].upper()) + if _var and k.startswith('OPENLLM_'): + logger.warning("Using '%s' environment is deprecated, use '%s' instead.", k.upper(), k[8:].upper()) return _var + return first_not_none(*(callback(k) for k in env_key), default=default) -def field_env_key(key, suffix=None): return '_'.join(filter(None, map(str.upper, ['OPENLLM', suffix.strip('_') if suffix else '', key]))) -def get_debug_mode(): return check_bool_env(DEBUG_ENV_VAR, False) if (not DEBUG and DEBUG_ENV_VAR in os.environ) else DEBUG + + +def field_env_key(key, suffix=None): + return '_'.join(filter(None, map(str.upper, ['OPENLLM', suffix.strip('_') if suffix else '', key]))) + + +def get_debug_mode(): + return check_bool_env(DEBUG_ENV_VAR, False) if (not DEBUG and DEBUG_ENV_VAR in os.environ) else DEBUG + + def get_quiet_mode(): - if QUIET_ENV_VAR in os.environ: return check_bool_env(QUIET_ENV_VAR, False) - if DEBUG: return False + if QUIET_ENV_VAR in os.environ: + return check_bool_env(QUIET_ENV_VAR, False) + if DEBUG: + return False return False -def get_disable_warnings(): return check_bool_env(WARNING_ENV_VAR, False) + + +def get_disable_warnings(): + return check_bool_env(WARNING_ENV_VAR, False) + + def set_disable_warnings(disable=True): - 
if disable: os.environ[WARNING_ENV_VAR] = str(disable) + if disable: + os.environ[WARNING_ENV_VAR] = str(disable) + + def set_debug_mode(enabled, level=1): - if enabled: os.environ[DEV_DEBUG_VAR] = str(level) + if enabled: + os.environ[DEV_DEBUG_VAR] = str(level) os.environ.update({ - DEBUG_ENV_VAR: str(enabled), QUIET_ENV_VAR: str(not enabled), # - _GRPC_DEBUG_ENV_VAR: 'DEBUG' if enabled else 'ERROR', 'CT2_VERBOSE': '3', # + DEBUG_ENV_VAR: str(enabled), + QUIET_ENV_VAR: str(not enabled), # + _GRPC_DEBUG_ENV_VAR: 'DEBUG' if enabled else 'ERROR', + 'CT2_VERBOSE': '3', # }) set_disable_warnings(not enabled) + + def set_quiet_mode(enabled): os.environ.update({ - QUIET_ENV_VAR: str(enabled), DEBUG_ENV_VAR: str(not enabled), # - _GRPC_DEBUG_ENV_VAR: 'NONE', 'CT2_VERBOSE': '-1', # + QUIET_ENV_VAR: str(enabled), + DEBUG_ENV_VAR: str(not enabled), # + _GRPC_DEBUG_ENV_VAR: 'NONE', + 'CT2_VERBOSE': '-1', # }) set_disable_warnings(enabled) -def gen_random_uuid(prefix: str | None = None) -> str: return '-'.join([prefix or 'openllm', str(uuid.uuid4().hex)]) + + +def gen_random_uuid(prefix: str | None = None) -> str: + return '-'.join([prefix or 'openllm', str(uuid.uuid4().hex)]) + + # NOTE: `compose` any number of unary functions into a single unary function # compose(f, g, h)(x) == f(g(h(x))); compose(f, g, h)(x, y, z) == f(g(h(x, y, z))) -def compose(*funcs): return functools.reduce(lambda f1, f2: lambda *args, **kwargs: f1(f2(*args, **kwargs)), funcs) +def compose(*funcs): + return functools.reduce(lambda f1, f2: lambda *args, **kwargs: f1(f2(*args, **kwargs)), funcs) + + # NOTE: `apply` a transform function that is invoked on results returned from the decorated function # apply(reversed)(func)(*args, **kwargs) == reversed(func(*args, **kwargs)) -def apply(transform): return lambda func: functools.wraps(func)(compose(transform, func)) -def validate_is_path(maybe_path): return os.path.exists(os.path.dirname(resolve_filepath(maybe_path))) -def first_not_none(*args, default=None): return next((arg for arg in args if arg is not None), default) +def apply(transform): + return lambda func: functools.wraps(func)(compose(transform, func)) + + +def validate_is_path(maybe_path): + return os.path.exists(os.path.dirname(resolve_filepath(maybe_path))) + + +def first_not_none(*args, default=None): + return next((arg for arg in args if arg is not None), default) + + def generate_context(framework_name): from bentoml._internal.models.model import ModelContext + framework_versions = { 'transformers': pkg.get_pkg_version('transformers'), 'safetensors': pkg.get_pkg_version('safetensors'), 'optimum': pkg.get_pkg_version('optimum'), 'accelerate': pkg.get_pkg_version('accelerate'), } - if iutils.is_torch_available(): framework_versions['torch'] = pkg.get_pkg_version('torch') - if iutils.is_ctranslate_available(): framework_versions['ctranslate2'] = pkg.get_pkg_version('ctranslate2') - if iutils.is_vllm_available(): framework_versions['vllm'] = pkg.get_pkg_version('vllm') - if iutils.is_autoawq_available(): framework_versions['autoawq'] = pkg.get_pkg_version('autoawq') - if iutils.is_autogptq_available(): framework_versions['autogptq'] = pkg.get_pkg_version('auto_gptq') - if iutils.is_bentoml_available(): framework_versions['bentoml'] = pkg.get_pkg_version('bentoml') - if iutils.is_triton_available(): framework_versions['triton'] = pkg.get_pkg_version('triton') - if iutils.is_flash_attn_2_available(): framework_versions['flash_attn'] = pkg.get_pkg_version('flash_attn') + if iutils.is_torch_available(): + 
framework_versions['torch'] = pkg.get_pkg_version('torch') + if iutils.is_ctranslate_available(): + framework_versions['ctranslate2'] = pkg.get_pkg_version('ctranslate2') + if iutils.is_vllm_available(): + framework_versions['vllm'] = pkg.get_pkg_version('vllm') + if iutils.is_autoawq_available(): + framework_versions['autoawq'] = pkg.get_pkg_version('autoawq') + if iutils.is_autogptq_available(): + framework_versions['autogptq'] = pkg.get_pkg_version('auto_gptq') + if iutils.is_bentoml_available(): + framework_versions['bentoml'] = pkg.get_pkg_version('bentoml') + if iutils.is_triton_available(): + framework_versions['triton'] = pkg.get_pkg_version('triton') + if iutils.is_flash_attn_2_available(): + framework_versions['flash_attn'] = pkg.get_pkg_version('flash_attn') return ModelContext(framework_name=framework_name, framework_versions=framework_versions) + + @functools.lru_cache(maxsize=1) def in_notebook(): try: - from IPython.core.getipython import get_ipython; return 'IPKernelApp' in get_ipython().config + from IPython.core.getipython import get_ipython + + return 'IPKernelApp' in get_ipython().config except Exception: return False + + def flatten_attrs(**attrs): _TOKENIZER_PREFIX = '_tokenizer_' tokenizer_attrs = {k[len(_TOKENIZER_PREFIX) :]: v for k, v in attrs.items() if k.startswith(_TOKENIZER_PREFIX)} @@ -119,35 +202,52 @@ def flatten_attrs(**attrs): if k.startswith(_TOKENIZER_PREFIX): del attrs[k] return attrs, tokenizer_attrs + + # Special debug flag controled via DEBUG DEBUG = sys.flags.dev_mode or (not sys.flags.ignore_environment and check_bool_env(DEV_DEBUG_VAR, default=False)) # Whether to show the codenge for debug purposes SHOW_CODEGEN = DEBUG and os.environ.get(DEV_DEBUG_VAR, str(0)).isdigit() and int(os.environ.get(DEV_DEBUG_VAR, str(0))) > 3 # MYPY is like t.TYPE_CHECKING, but reserved for Mypy plugins MYPY = False + + class ExceptionFilter(logging.Filter): def __init__(self, exclude_exceptions: list[type[Exception]] | None = None, **kwargs: t.Any): - if exclude_exceptions is None: exclude_exceptions = [] + if exclude_exceptions is None: + exclude_exceptions = [] try: from circus.exc import ConflictError - if ConflictError not in exclude_exceptions: exclude_exceptions.append(ConflictError) + + if ConflictError not in exclude_exceptions: + exclude_exceptions.append(ConflictError) except ImportError: pass super(ExceptionFilter, self).__init__(**kwargs) self.EXCLUDE_EXCEPTIONS = exclude_exceptions + def filter(self, record: logging.LogRecord) -> bool: if record.exc_info: etype, _, _ = record.exc_info if etype is not None: for exc in self.EXCLUDE_EXCEPTIONS: - if issubclass(etype, exc): return False + if issubclass(etype, exc): + return False return True + + class InfoFilter(logging.Filter): - def filter(self, record: logging.LogRecord) -> bool: return logging.INFO <= record.levelno < logging.WARNING + def filter(self, record: logging.LogRecord) -> bool: + return logging.INFO <= record.levelno < logging.WARNING + + class WarningFilter(logging.Filter): # FIXME: Why does this not work? 
def filter(self, record: logging.LogRecord) -> bool: - if get_disable_warnings(): return record.levelno >= logging.ERROR + if get_disable_warnings(): + return record.levelno >= logging.ERROR return True + + _LOGGING_CONFIG = { 'version': 1, 'disable_existing_loggers': True, @@ -157,11 +257,7 @@ _LOGGING_CONFIG = { 'warningfilter': {'()': 'openllm_core.utils.WarningFilter'}, }, 'handlers': { - 'bentomlhandler': { - 'class': 'logging.StreamHandler', - 'filters': ['excfilter', 'warningfilter', 'infofilter'], - 'stream': 'ext://sys.stdout', - }, + 'bentomlhandler': {'class': 'logging.StreamHandler', 'filters': ['excfilter', 'warningfilter', 'infofilter'], 'stream': 'ext://sys.stdout'}, 'defaulthandler': {'class': 'logging.StreamHandler', 'level': logging.WARNING}, }, 'loggers': { @@ -170,6 +266,8 @@ _LOGGING_CONFIG = { }, 'root': {'level': logging.WARNING}, } + + def configure_logging(): if get_quiet_mode(): _LOGGING_CONFIG['loggers']['openllm']['level'] = logging.ERROR @@ -189,13 +287,13 @@ def configure_logging(): _LOGGING_CONFIG['loggers']['openllm']['level'] = logging.ERROR logging.config.dictConfig(_LOGGING_CONFIG) + + # XXX: define all classes, functions import above this line # since _extras will be the locals() import from this file. _extras = { **{ - k: v - for k, v in locals().items() - if k in {'pkg'} or (not isinstance(v, types.ModuleType) and k not in {'annotations'} and not k.startswith('_')) + k: v for k, v in locals().items() if k in {'pkg'} or (not isinstance(v, types.ModuleType) and k not in {'annotations'} and not k.startswith('_')) }, '__openllm_migration__': {'bentoml_cattr': 'converter'}, } diff --git a/openllm-core/src/openllm_core/utils/codegen.py b/openllm-core/src/openllm_core/utils/codegen.py index 939ef750..3c1d43c1 100644 --- a/openllm-core/src/openllm_core/utils/codegen.py +++ b/openllm-core/src/openllm_core/utils/codegen.py @@ -19,18 +19,24 @@ logger = logging.getLogger(__name__) # sentinel object for unequivocal object() getattr _sentinel = object() + def has_own_attribute(cls: type[t.Any], attrib_name: t.Any) -> bool: attr = getattr(cls, attrib_name, _sentinel) - if attr is _sentinel: return False + if attr is _sentinel: + return False for base_cls in cls.__mro__[1:]: a = getattr(base_cls, attrib_name, None) - if attr is a: return False + if attr is a: + return False return True + def get_annotations(cls: type[t.Any]) -> DictStrAny: - if has_own_attribute(cls, '__annotations__'): return cls.__annotations__ + if has_own_attribute(cls, '__annotations__'): + return cls.__annotations__ return {} + def is_class_var(annot: str | t.Any) -> bool: annot = str(annot) # Annotation can be quoted. 
@@ -38,6 +44,7 @@ def is_class_var(annot: str | t.Any) -> bool: annot = annot[1:-1] return annot.startswith(('typing.ClassVar', 't.ClassVar', 'ClassVar', 'typing_extensions.ClassVar')) + def add_method_dunders(cls: type[t.Any], method_or_cls: _T, _overwrite_doc: str | None = None) -> _T: try: method_or_cls.__module__ = cls.__module__ @@ -53,7 +60,11 @@ def add_method_dunders(cls: type[t.Any], method_or_cls: _T, _overwrite_doc: str pass return method_or_cls -def _compile_and_eval(script: str, globs: DictStrAny, locs: t.Any = None, filename: str = '') -> None: eval(compile(script, filename, 'exec'), globs, locs) + +def _compile_and_eval(script: str, globs: DictStrAny, locs: t.Any = None, filename: str = '') -> None: + eval(compile(script, filename, 'exec'), globs, locs) + + def _make_method(name: str, script: str, filename: str, globs: DictStrAny) -> AnyCallable: locs: dict[str, t.Any | AnyCallable] = {} # In order of debuggers like PDB being able to step through the code, we add a fake linecache entry. @@ -70,16 +81,18 @@ def _make_method(name: str, script: str, filename: str, globs: DictStrAny) -> An _compile_and_eval(script, globs, locs, filename) return locs[name] + def make_attr_tuple_class(cls_name: str, attr_names: t.Sequence[str]) -> type[t.Any]: - '''Create a tuple subclass to hold class attributes. + """Create a tuple subclass to hold class attributes. The subclass is a bare tuple with properties for names. class MyClassAttributes(tuple): __slots__ = () x = property(itemgetter(0)) - ''' + """ from . import SHOW_CODEGEN + attr_class_name = f'{cls_name}Attributes' attr_class_template = [f'class {attr_class_name}(tuple):', ' __slots__ = ()'] if attr_names: @@ -88,70 +101,93 @@ def make_attr_tuple_class(cls_name: str, attr_names: t.Sequence[str]) -> type[t. else: attr_class_template.append(' pass') globs = {'_attrs_itemgetter': itemgetter, '_attrs_property': property} - if SHOW_CODEGEN: print(f'Generated class for {attr_class_name}:\n\n', '\n'.join(attr_class_template)) + if SHOW_CODEGEN: + print(f'Generated class for {attr_class_name}:\n\n', '\n'.join(attr_class_template)) _compile_and_eval('\n'.join(attr_class_template), globs) return globs[attr_class_name] -def generate_unique_filename(cls: type[t.Any], func_name: str) -> str: return f"<{cls.__name__} generated {func_name} {cls.__module__}.{getattr(cls, '__qualname__', cls.__name__)}>" + +def generate_unique_filename(cls: type[t.Any], func_name: str) -> str: + return f"<{cls.__name__} generated {func_name} {cls.__module__}.{getattr(cls, '__qualname__', cls.__name__)}>" + def generate_function( - typ: type[t.Any], func_name: str, # + typ: type[t.Any], + func_name: str, # lines: list[str] | None, args: tuple[str, ...] 
| None, globs: dict[str, t.Any], annotations: dict[str, t.Any] | None = None, ) -> AnyCallable: from openllm_core.utils import SHOW_CODEGEN + script = 'def %s(%s):\n %s\n' % (func_name, ', '.join(args) if args is not None else '', '\n '.join(lines) if lines else 'pass') meth = _make_method(func_name, script, generate_unique_filename(typ, func_name), globs) - if annotations: meth.__annotations__ = annotations - if SHOW_CODEGEN: print(f'Generated script for {typ}:\n\n', script) + if annotations: + meth.__annotations__ = annotations + if SHOW_CODEGEN: + print(f'Generated script for {typ}:\n\n', script) return meth + def make_env_transformer( - cls: type[openllm_core.LLMConfig], model_name: str, # + cls: type[openllm_core.LLMConfig], + model_name: str, # suffix: LiteralString | None = None, default_callback: t.Callable[[str, t.Any], t.Any] | None = None, globs: DictStrAny | None = None, ) -> AnyCallable: from openllm_core.utils import dantic, field_env_key - def identity(_: str, x_value: t.Any) -> t.Any: return x_value + def identity(_: str, x_value: t.Any) -> t.Any: + return x_value globs = {} if globs is None else globs - globs.update( - { - '__populate_env': dantic.env_converter, '__field_env': field_env_key, # - '__suffix': suffix or '', '__model_name': model_name, # - '__default_callback': identity if default_callback is None else default_callback, - } - ) + globs.update({ + '__populate_env': dantic.env_converter, + '__field_env': field_env_key, # + '__suffix': suffix or '', + '__model_name': model_name, # + '__default_callback': identity if default_callback is None else default_callback, + }) fields_ann = 'list[attr.Attribute[t.Any]]' return generate_function( - cls, '__auto_env', # - ['__env=lambda field_name:__field_env(field_name,__suffix)', "return [f.evolve(default=__populate_env(__default_callback(f.name,f.default),__env(f.name)),metadata={'env':f.metadata.get('env',__env(f.name)),'description':f.metadata.get('description', '(not provided)')}) for f in fields]"], - ('_', 'fields'), globs, {'_': 'type[LLMConfig]', 'fields': fields_ann, 'return': fields_ann}, # + cls, + '__auto_env', # + [ + '__env=lambda field_name:__field_env(field_name,__suffix)', + "return [f.evolve(default=__populate_env(__default_callback(f.name,f.default),__env(f.name)),metadata={'env':f.metadata.get('env',__env(f.name)),'description':f.metadata.get('description', '(not provided)')}) for f in fields]", + ], + ('_', 'fields'), + globs, + {'_': 'type[LLMConfig]', 'fields': fields_ann, 'return': fields_ann}, # ) def gen_sdk(func: _T, name: str | None = None, **attrs: t.Any) -> _T: from .representation import ReprMixin - if name is None: name = func.__name__.strip('_') + + if name is None: + name = func.__name__.strip('_') _signatures = inspect.signature(func).parameters - def _repr(self: ReprMixin) -> str: return f'' - def _repr_args(self: ReprMixin) -> t.Iterator[t.Tuple[str, t.Any]]: return ((k, _signatures[k].annotation) for k in self.__repr_keys__) + + def _repr(self: ReprMixin) -> str: + return f'' + + def _repr_args(self: ReprMixin) -> t.Iterator[t.Tuple[str, t.Any]]: + return ((k, _signatures[k].annotation) for k in self.__repr_keys__) + return functools.update_wrapper( types.new_class( name, (functools.partial, ReprMixin), - exec_body=lambda ns: ns.update( - { - '__repr_keys__': property(lambda _: [i for i in _signatures.keys() if not i.startswith('_')]), - '__repr_args__': _repr_args, '__repr__': _repr, # - '__doc__': inspect.cleandoc(f'Generated SDK for {func.__name__}' if func.__doc__ is None else 
func.__doc__), - '__module__': 'openllm', - } - ), + exec_body=lambda ns: ns.update({ + '__repr_keys__': property(lambda _: [i for i in _signatures.keys() if not i.startswith('_')]), + '__repr_args__': _repr_args, + '__repr__': _repr, # + '__doc__': inspect.cleandoc(f'Generated SDK for {func.__name__}' if func.__doc__ is None else func.__doc__), + '__module__': 'openllm', + }), )(func, **attrs), func, ) diff --git a/openllm-core/src/openllm_core/utils/dantic.py b/openllm-core/src/openllm_core/utils/dantic.py index 442462ff..a25c6b40 100644 --- a/openllm-core/src/openllm_core/utils/dantic.py +++ b/openllm-core/src/openllm_core/utils/dantic.py @@ -45,12 +45,7 @@ def __dir__() -> list[str]: def attrs_to_options( - name: str, - field: attr.Attribute[t.Any], - model_name: str, - typ: t.Any = None, - suffix_generation: bool = False, - suffix_sampling: bool = False, + name: str, field: attr.Attribute[t.Any], model_name: str, typ: t.Any = None, suffix_generation: bool = False, suffix_sampling: bool = False ) -> t.Callable[[FC], FC]: # TODO: support parsing nested attrs class and Union envvar = field.metadata['env'] @@ -210,14 +205,14 @@ def parse_type(field_type: t.Any) -> ParamType | tuple[ParamType, ...]: def is_typing(field_type: type) -> bool: - '''Checks whether the current type is a module-like type. + """Checks whether the current type is a module-like type. Args: field_type: pydantic field type Returns: bool: true if the type is itself a type - ''' + """ raw = t.get_origin(field_type) if raw is None: return False @@ -227,7 +222,7 @@ def is_typing(field_type: type) -> bool: def is_literal(field_type: type) -> bool: - '''Checks whether the given field type is a Literal type or not. + """Checks whether the given field type is a Literal type or not. Literals are weird: isinstance and subclass do not work, so you compare the origin with the Literal declaration itself. @@ -237,7 +232,7 @@ def is_literal(field_type: type) -> bool: Returns: bool: true if Literal type, false otherwise - ''' + """ origin = t.get_origin(field_type) return origin is not None and origin is t.Literal @@ -272,12 +267,12 @@ class EnumChoice(click.Choice): name = 'enum' def __init__(self, enum: Enum, case_sensitive: bool = False): - '''Enum type support for click that extends ``click.Choice``. + """Enum type support for click that extends ``click.Choice``. Args: enum: Given enum case_sensitive: Whether this choice should be case case_sensitive. - ''' + """ self.mapping = enum self.internal_type = type(enum) choices: list[t.Any] = [e.name for e in enum.__class__] @@ -296,7 +291,7 @@ class LiteralChoice(EnumChoice): name = 'literal' def __init__(self, value: t.Any, case_sensitive: bool = False): - '''Literal support for click.''' + """Literal support for click.""" # expect every literal value to belong to the same primitive type values = list(value.__args__) item_type = type(values[0]) @@ -334,14 +329,14 @@ def allows_multiple(field_type: type[t.Any]) -> bool: def is_mapping(field_type: type) -> bool: - '''Checks whether this field represents a dictionary or JSON object. + """Checks whether this field represents a dictionary or JSON object. Args: field_type (type): pydantic type Returns: bool: true when the field is a dict-like object, false otherwise. - ''' + """ # Early out for standard containers. from . 
import lenient_issubclass @@ -379,14 +374,14 @@ def is_container(field_type: type) -> bool: def parse_container_args(field_type: type[t.Any]) -> ParamType | tuple[ParamType, ...]: - '''Parses the arguments inside a container type (lists, tuples and so on). + """Parses the arguments inside a container type (lists, tuples and so on). Args: field_type: pydantic field type Returns: ParamType | tuple[ParamType]: single click-compatible type or a tuple - ''' + """ if not is_container(field_type): raise ValueError('Field type is not a container type.') args = t.get_args(field_type) @@ -466,7 +461,7 @@ class CudaValueType(ParamType): return var def shell_complete(self, ctx: click.Context, param: click.Parameter, incomplete: str) -> list[sc.CompletionItem]: - '''Return a list of :class:`~click.shell_completion.CompletionItem` objects for the incomplete value. + """Return a list of :class:`~click.shell_completion.CompletionItem` objects for the incomplete value. Most types do not provide completions, but some do, and this allows custom types to provide custom completions as well. @@ -474,7 +469,7 @@ class CudaValueType(ParamType): ctx: Invocation context for this command. param: The parameter that is requesting completion. incomplete: Value being completed. May be empty. - ''' + """ from openllm.utils import available_devices mapping = incomplete.split(self.envvar_list_splitter) if incomplete else available_devices() diff --git a/openllm-core/src/openllm_core/utils/import_utils.py b/openllm-core/src/openllm_core/utils/import_utils.py index c6cd37fc..7d36e6c4 100644 --- a/openllm-core/src/openllm_core/utils/import_utils.py +++ b/openllm-core/src/openllm_core/utils/import_utils.py @@ -1,20 +1,11 @@ import importlib, importlib.metadata, importlib.util, os -OPTIONAL_DEPENDENCIES = { - 'vllm', - 'fine-tune', - 'ggml', - 'ctranslate', - 'agents', - 'openai', - 'playground', - 'gptq', - 'grpc', - 'awq', -} +OPTIONAL_DEPENDENCIES = {'vllm', 'fine-tune', 'ggml', 'ctranslate', 'agents', 'openai', 'playground', 'gptq', 'grpc', 'awq'} ENV_VARS_TRUE_VALUES = {'1', 'ON', 'YES', 'TRUE'} ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({'AUTO'}) USE_VLLM = os.getenv('USE_VLLM', 'AUTO').upper() + + def _is_package_available(package: str) -> bool: _package_available = importlib.util.find_spec(package) is not None if _package_available: @@ -23,6 +14,8 @@ def _is_package_available(package: str) -> bool: except importlib.metadata.PackageNotFoundError: _package_available = False return _package_available + + _ctranslate_available = importlib.util.find_spec('ctranslate2') is not None _vllm_available = importlib.util.find_spec('vllm') is not None _grpc_available = importlib.util.find_spec('grpc') is not None @@ -38,19 +31,60 @@ _jupyter_available = _is_package_available('jupyter') _jupytext_available = _is_package_available('jupytext') _notebook_available = _is_package_available('notebook') _autogptq_available = _is_package_available('auto_gptq') -def is_triton_available() -> bool: return _triton_available -def is_ctranslate_available() -> bool: return _ctranslate_available -def is_bentoml_available() -> bool: return _bentoml_available # needs this since openllm-core doesn't explicitly depends on bentoml -def is_transformers_available() -> bool: return _transformers_available # needs this since openllm-core doesn't explicitly depends on transformers -def is_grpc_available() -> bool: return _grpc_available -def is_jupyter_available() -> bool: return _jupyter_available -def is_jupytext_available() -> bool: return 
_jupytext_available -def is_notebook_available() -> bool: return _notebook_available -def is_peft_available() -> bool: return _peft_available -def is_bitsandbytes_available() -> bool: return _bitsandbytes_available -def is_autogptq_available() -> bool: return _autogptq_available -def is_torch_available() -> bool: return _torch_available -def is_flash_attn_2_available() -> bool: return _flash_attn_available + + +def is_triton_available() -> bool: + return _triton_available + + +def is_ctranslate_available() -> bool: + return _ctranslate_available + + +def is_bentoml_available() -> bool: + return _bentoml_available # needs this since openllm-core doesn't explicitly depends on bentoml + + +def is_transformers_available() -> bool: + return _transformers_available # needs this since openllm-core doesn't explicitly depends on transformers + + +def is_grpc_available() -> bool: + return _grpc_available + + +def is_jupyter_available() -> bool: + return _jupyter_available + + +def is_jupytext_available() -> bool: + return _jupytext_available + + +def is_notebook_available() -> bool: + return _notebook_available + + +def is_peft_available() -> bool: + return _peft_available + + +def is_bitsandbytes_available() -> bool: + return _bitsandbytes_available + + +def is_autogptq_available() -> bool: + return _autogptq_available + + +def is_torch_available() -> bool: + return _torch_available + + +def is_flash_attn_2_available() -> bool: + return _flash_attn_available + + def is_autoawq_available() -> bool: global _autoawq_available try: @@ -58,6 +92,8 @@ def is_autoawq_available() -> bool: except importlib.metadata.PackageNotFoundError: _autoawq_available = False return _autoawq_available + + def is_vllm_available() -> bool: global _vllm_available if USE_VLLM in ENV_VARS_TRUE_AND_AUTO_VALUES or _vllm_available: diff --git a/openllm-core/src/openllm_core/utils/lazy.py b/openllm-core/src/openllm_core/utils/lazy.py index 208879f0..0bcf6d36 100644 --- a/openllm-core/src/openllm_core/utils/lazy.py +++ b/openllm-core/src/openllm_core/utils/lazy.py @@ -23,11 +23,11 @@ logger = logging.getLogger(__name__) class LazyLoader(types.ModuleType): - ''' + """ LazyLoader module borrowed from Tensorflow https://github.com/tensorflow/tensorflow/blob/v2.2.0/tensorflow/python/util/lazy_loader.py with a addition of "module caching". This will throw an exception if module cannot be imported. Lazily import a module, mainly to avoid pulling in large dependencies. `contrib`, and `ffmpeg` are examples of modules that are large and not always needed, and this allows them to only be loaded when they are used. - ''' + """ def __init__( self, @@ -107,9 +107,7 @@ class VersionInfo: raise NotImplementedError if not (1 <= len(cmp) <= 4): raise NotImplementedError - return t.cast(t.Tuple[int, int, int, str], attr.astuple(self)[: len(cmp)]), t.cast( - t.Tuple[int, int, int, str], cmp - ) + return t.cast(t.Tuple[int, int, int, str], attr.astuple(self)[: len(cmp)]), t.cast(t.Tuple[int, int, int, str], cmp) def __eq__(self, other: t.Any) -> bool: try: @@ -180,12 +178,12 @@ class LazyModule(types.ModuleType): return result + [i for i in self.__all__ if i not in result] def __getattr__(self, name: str) -> t.Any: - '''Equivocal __getattr__ implementation. + """Equivocal __getattr__ implementation. It checks from _objects > _modules and does it recursively. It also contains a special case for all of the metadata information, such as __version__ and __version_info__. 
- ''' + """ if name in _reserved_namespace: raise openllm_core.exceptions.ForbiddenAttributeError( f"'{name}' is a reserved namespace for {self._name} and should not be access nor modified." @@ -231,9 +229,7 @@ class LazyModule(types.ModuleType): cur_value = self._objects['__openllm_migration__'].get(name, _sentinel) if cur_value is not _sentinel: warnings.warn( - f"'{name}' is deprecated and will be removed in future version. Make sure to use '{cur_value}' instead", - DeprecationWarning, - stacklevel=3, + f"'{name}' is deprecated and will be removed in future version. Make sure to use '{cur_value}' instead", DeprecationWarning, stacklevel=3 ) return getattr(self, cur_value) if name in self._objects: @@ -254,9 +250,7 @@ class LazyModule(types.ModuleType): try: return importlib.import_module('.' + module_name, self.__name__) except Exception as e: - raise RuntimeError( - f'Failed to import {self.__name__}.{module_name} because of the following error (look up to see its traceback):\n{e}' - ) from e + raise RuntimeError(f'Failed to import {self.__name__}.{module_name} because of the following error (look up to see its traceback):\n{e}') from e # make sure this module is picklable def __reduce__(self) -> tuple[type[LazyModule], tuple[str, str | None, dict[str, list[str]]]]: diff --git a/openllm-core/src/openllm_core/utils/peft.py b/openllm-core/src/openllm_core/utils/peft.py index b0b25fd5..f8d85a65 100644 --- a/openllm-core/src/openllm_core/utils/peft.py +++ b/openllm-core/src/openllm_core/utils/peft.py @@ -87,12 +87,8 @@ class FineTuneConfig: converter=attr.converters.default_if_none(factory=dict), use_default_converter=False, ) - inference_mode: bool = dantic.Field( - False, description='Whether to use this Adapter for inference', use_default_converter=False - ) - llm_config_class: type[LLMConfig] = dantic.Field( - None, description='The reference class to openllm.LLMConfig', use_default_converter=False - ) + inference_mode: bool = dantic.Field(False, description='Whether to use this Adapter for inference', use_default_converter=False) + llm_config_class: type[LLMConfig] = dantic.Field(None, description='The reference class to openllm.LLMConfig', use_default_converter=False) def build(self) -> PeftConfig: try: @@ -110,12 +106,7 @@ class FineTuneConfig: # respect user set task_type if it is passed, otherwise use one managed by OpenLLM inference_mode = adapter_config.pop('inference_mode', self.inference_mode) task_type = adapter_config.pop('task_type', TaskType[self.llm_config_class.peft_task_type()]) - adapter_config = { - 'peft_type': self.adapter_type.value, - 'task_type': task_type, - 'inference_mode': inference_mode, - **adapter_config, - } + adapter_config = {'peft_type': self.adapter_type.value, 'task_type': task_type, 'inference_mode': inference_mode, **adapter_config} return get_peft_config(adapter_config) def train(self) -> FineTuneConfig: @@ -127,18 +118,10 @@ class FineTuneConfig: return self def with_config(self, **attrs: t.Any) -> FineTuneConfig: - adapter_type, inference_mode = ( - attrs.pop('adapter_type', self.adapter_type), - attrs.get('inference_mode', self.inference_mode), - ) + adapter_type, inference_mode = (attrs.pop('adapter_type', self.adapter_type), attrs.get('inference_mode', self.inference_mode)) if 'llm_config_class' in attrs: raise ForbiddenAttributeError("'llm_config_class' should not be passed when using 'with_config'.") - return attr.evolve( - self, - adapter_type=adapter_type, - inference_mode=inference_mode, - 
adapter_config=config_merger.merge(self.adapter_config, attrs), - ) + return attr.evolve(self, adapter_type=adapter_type, inference_mode=inference_mode, adapter_config=config_merger.merge(self.adapter_config, attrs)) @classmethod def from_config(cls, ft_config: dict[str, t.Any], llm_config_cls: type[LLMConfig]) -> FineTuneConfig: @@ -146,9 +129,4 @@ class FineTuneConfig: adapter_type = copied.pop('adapter_type', 'lora') inference_mode = copied.pop('inference_mode', False) llm_config_class = copied.pop('llm_confg_class', llm_config_cls) - return cls( - adapter_type=adapter_type, - adapter_config=copied, - inference_mode=inference_mode, - llm_config_class=llm_config_class, - ) + return cls(adapter_type=adapter_type, adapter_config=copied, inference_mode=inference_mode, llm_config_class=llm_config_class) diff --git a/openllm-core/src/openllm_core/utils/representation.py b/openllm-core/src/openllm_core/utils/representation.py index 1ee36df0..8221b5ff 100644 --- a/openllm-core/src/openllm_core/utils/representation.py +++ b/openllm-core/src/openllm_core/utils/representation.py @@ -4,14 +4,28 @@ from abc import abstractmethod import attr, orjson from openllm_core import utils -if t.TYPE_CHECKING: from openllm_core._typing_compat import TypeAlias +if t.TYPE_CHECKING: + from openllm_core._typing_compat import TypeAlias ReprArgs: TypeAlias = t.Generator[t.Tuple[t.Optional[str], t.Any], None, None] + + class ReprMixin: @property @abstractmethod - def __repr_keys__(self) -> set[str]: raise NotImplementedError - def __repr__(self) -> str: return f'{self.__class__.__name__} {orjson.dumps({k: utils.converter.unstructure(v) if attr.has(v) else v for k, v in self.__repr_args__()}, option=orjson.OPT_INDENT_2).decode()}' - def __str__(self) -> str: return self.__repr_str__(' ') - def __repr_name__(self) -> str: return self.__class__.__name__ - def __repr_str__(self, join_str: str) -> str: return join_str.join(repr(v) if a is None else f'{a}={v!r}' for a, v in self.__repr_args__()) - def __repr_args__(self) -> ReprArgs: return ((k, getattr(self, k)) for k in self.__repr_keys__) + def __repr_keys__(self) -> set[str]: + raise NotImplementedError + + def __repr__(self) -> str: + return f'{self.__class__.__name__} {orjson.dumps({k: utils.converter.unstructure(v) if attr.has(v) else v for k, v in self.__repr_args__()}, option=orjson.OPT_INDENT_2).decode()}' + + def __str__(self) -> str: + return self.__repr_str__(' ') + + def __repr_name__(self) -> str: + return self.__class__.__name__ + + def __repr_str__(self, join_str: str) -> str: + return join_str.join(repr(v) if a is None else f'{a}={v!r}' for a, v in self.__repr_args__()) + + def __repr_args__(self) -> ReprArgs: + return ((k, getattr(self, k)) for k in self.__repr_keys__) diff --git a/openllm-core/src/openllm_core/utils/serde.py b/openllm-core/src/openllm_core/utils/serde.py index 10d35255..c03a1653 100644 --- a/openllm-core/src/openllm_core/utils/serde.py +++ b/openllm-core/src/openllm_core/utils/serde.py @@ -19,16 +19,10 @@ def datetime_structure_hook(dt_like: str | datetime | t.Any, _: t.Any) -> dateti converter.register_structure_hook_factory( - attr.has, - lambda cls: make_dict_structure_fn( - cls, converter, _cattrs_forbid_extra_keys=getattr(cls, '__forbid_extra_keys__', False) - ), + attr.has, lambda cls: make_dict_structure_fn(cls, converter, _cattrs_forbid_extra_keys=getattr(cls, '__forbid_extra_keys__', False)) ) converter.register_unstructure_hook_factory( - attr.has, - lambda cls: make_dict_unstructure_fn( - cls, converter, 
_cattrs_omit_if_default=getattr(cls, '__omit_if_default__', False) - ), + attr.has, lambda cls: make_dict_unstructure_fn(cls, converter, _cattrs_omit_if_default=getattr(cls, '__omit_if_default__', False)) ) converter.register_structure_hook(datetime, datetime_structure_hook) converter.register_unstructure_hook(datetime, lambda dt: dt.isoformat()) diff --git a/openllm-python/src/openllm/__init__.py b/openllm-python/src/openllm/__init__.py index 8085e06e..f53db436 100644 --- a/openllm-python/src/openllm/__init__.py +++ b/openllm-python/src/openllm/__init__.py @@ -1,8 +1,10 @@ import logging as _logging, os as _os, pathlib as _pathlib, warnings as _warnings from openllm_cli import _sdk from . import utils as utils + if utils.DEBUG: - utils.set_debug_mode(True); _logging.basicConfig(level=_logging.NOTSET) + utils.set_debug_mode(True) + _logging.basicConfig(level=_logging.NOTSET) else: # configuration for bitsandbytes before import _os.environ['BITSANDBYTES_NOWELCOME'] = _os.environ.get('BITSANDBYTES_NOWELCOME', '1') @@ -30,8 +32,11 @@ __lazy = utils.LazyModule( # NOTE: update this to sys.modules[__name__] once my '_llm': ['LLM'], }, extra_objects={ - 'COMPILED': COMPILED, 'start': _sdk.start, 'build': _sdk.build, # - 'import_model': _sdk.import_model, 'list_models': _sdk.list_models, # + 'COMPILED': COMPILED, + 'start': _sdk.start, + 'build': _sdk.build, # + 'import_model': _sdk.import_model, + 'list_models': _sdk.list_models, # }, ) __all__, __dir__, __getattr__ = __lazy.__all__, __lazy.__dir__, __lazy.__getattr__ diff --git a/openllm-python/src/openllm/__init__.pyi b/openllm-python/src/openllm/__init__.pyi index 200b0764..4b65bc5a 100644 --- a/openllm-python/src/openllm/__init__.pyi +++ b/openllm-python/src/openllm/__init__.pyi @@ -1,4 +1,4 @@ -'''OpenLLM. +"""OpenLLM. =========== An open platform for operating large language models in production. @@ -8,36 +8,17 @@ Fine-tune, serve, deploy, and monitor any LLMs with ease. 
* Option to bring your own fine-tuned LLMs * Online Serving with HTTP, gRPC, SSE or custom API * Native integration with BentoML, LangChain, OpenAI compatible endpoints, LlamaIndex for custom LLM apps -''' +""" # update-config-stubs.py: import stubs start from openlm_core.config import CONFIG_MAPPING as CONFIG_MAPPING, CONFIG_MAPPING_NAMES as CONFIG_MAPPING_NAMES, AutoConfig as AutoConfig, BaichuanConfig as BaichuanConfig, ChatGLMConfig as ChatGLMConfig, DollyV2Config as DollyV2Config, FalconConfig as FalconConfig, FlanT5Config as FlanT5Config, GPTNeoXConfig as GPTNeoXConfig, LlamaConfig as LlamaConfig, MistralConfig as MistralConfig, MixtralConfig as MixtralConfig, MPTConfig as MPTConfig, OPTConfig as OPTConfig, PhiConfig as PhiConfig, QwenConfig as QwenConfig, StableLMConfig as StableLMConfig, StarCoderConfig as StarCoderConfig, YiConfig as YiConfig # update-config-stubs.py: import stubs stop -from openllm_cli._sdk import ( - build as build, - import_model as import_model, - list_models as list_models, - start as start, -) -from openllm_core._configuration import ( - GenerationConfig as GenerationConfig, - LLMConfig as LLMConfig, - SamplingParams as SamplingParams, -) -from openllm_core._schemas import ( - GenerationInput as GenerationInput, - GenerationOutput as GenerationOutput, - MetadataOutput as MetadataOutput, -) +from openllm_cli._sdk import build as build, import_model as import_model, list_models as list_models, start as start +from openllm_core._configuration import GenerationConfig as GenerationConfig, LLMConfig as LLMConfig, SamplingParams as SamplingParams +from openllm_core._schemas import GenerationInput as GenerationInput, GenerationOutput as GenerationOutput, MetadataOutput as MetadataOutput -from . import ( - bundle as bundle, - client as client, - exceptions as exceptions, - serialisation as serialisation, - utils as utils, -) +from . 
import bundle as bundle, client as client, exceptions as exceptions, serialisation as serialisation, utils as utils from ._deprecated import Runner as Runner from ._llm import LLM as LLM from ._quantisation import infer_quantisation_config as infer_quantisation_config diff --git a/openllm-python/src/openllm/__main__.py b/openllm-python/src/openllm/__main__.py index f6a0db49..72092181 100644 --- a/openllm-python/src/openllm/__main__.py +++ b/openllm-python/src/openllm/__main__.py @@ -1 +1,4 @@ -if __name__ == '__main__': from openllm_cli.entrypoint import cli; cli() +if __name__ == '__main__': + from openllm_cli.entrypoint import cli + + cli() diff --git a/openllm-python/src/openllm/_deprecated.py b/openllm-python/src/openllm/_deprecated.py index a1ffbbdb..559ad0b2 100644 --- a/openllm-python/src/openllm/_deprecated.py +++ b/openllm-python/src/openllm/_deprecated.py @@ -7,15 +7,22 @@ from openllm_core.utils import first_not_none, getenv, is_vllm_available __all__ = ['Runner'] logger = logging.getLogger(__name__) + def Runner( - model_name: str, ensure_available: bool = True, # - init_local: bool = False, backend: LiteralBackend | None = None, # - llm_config: openllm.LLMConfig | None = None, **attrs: t.Any, + model_name: str, + ensure_available: bool = True, # + init_local: bool = False, + backend: LiteralBackend | None = None, # + llm_config: openllm.LLMConfig | None = None, + **attrs: t.Any, ): - if llm_config is None: llm_config = openllm.AutoConfig.for_model(model_name) - if not ensure_available: logger.warning("'ensure_available=False' won't have any effect as LLM will always check to download the model on initialisation.") + if llm_config is None: + llm_config = openllm.AutoConfig.for_model(model_name) + if not ensure_available: + logger.warning("'ensure_available=False' won't have any effect as LLM will always check to download the model on initialisation.") model_id = attrs.get('model_id', os.getenv('OPENLLM_MODEL_ID', llm_config['default_id'])) - warnings.warn(f'''\ + warnings.warn( + f"""\ Using 'openllm.Runner' is now deprecated. Make sure to switch to the following syntax: ```python @@ -26,11 +33,15 @@ def Runner( @svc.api(...) 
async def chat(input: str) -> str: async for it in llm.generate_iterator(input): print(it) - ```''', DeprecationWarning, stacklevel=2) - attrs.update( - { - 'model_id': model_id, 'quantize': getenv('QUANTIZE', var=['QUANTISE'], default=attrs.get('quantize', None)), # - 'serialisation': getenv('serialization', default=attrs.get('serialisation', llm_config['serialisation']), var=['SERIALISATION']), - } + ```""", + DeprecationWarning, + stacklevel=2, ) - return openllm.LLM(backend=first_not_none(backend, default='vllm' if is_vllm_available() else 'pt'), llm_config=llm_config, embedded=init_local, **attrs).runner + attrs.update({ + 'model_id': model_id, + 'quantize': getenv('QUANTIZE', var=['QUANTISE'], default=attrs.get('quantize', None)), # + 'serialisation': getenv('serialization', default=attrs.get('serialisation', llm_config['serialisation']), var=['SERIALISATION']), + }) + return openllm.LLM( + backend=first_not_none(backend, default='vllm' if is_vllm_available() else 'pt'), llm_config=llm_config, embedded=init_local, **attrs + ).runner diff --git a/openllm-python/src/openllm/_generation.py b/openllm-python/src/openllm/_generation.py index 234b88c1..c06980aa 100644 --- a/openllm-python/src/openllm/_generation.py +++ b/openllm-python/src/openllm/_generation.py @@ -1,5 +1,6 @@ def prepare_logits_processor(config): import transformers + generation_config = config.generation_config logits_processor = transformers.LogitsProcessorList() if generation_config['temperature'] >= 1e-5 and generation_config['temperature'] != 1.0: @@ -11,16 +12,27 @@ def prepare_logits_processor(config): if generation_config['top_k'] > 0: logits_processor.append(transformers.TopKLogitsWarper(generation_config['top_k'])) return logits_processor + + # NOTE: The ordering here is important. Some models have two of these and we have a preference for which value gets used. 
SEQLEN_KEYS = ['max_sequence_length', 'seq_length', 'max_position_embeddings', 'max_seq_len', 'model_max_length'] + + def get_context_length(config): rope_scaling = getattr(config, 'rope_scaling', None) rope_scaling_factor = config.rope_scaling['factor'] if rope_scaling else 1.0 for key in SEQLEN_KEYS: - if getattr(config, key, None) is not None: return int(rope_scaling_factor * getattr(config, key)) + if getattr(config, key, None) is not None: + return int(rope_scaling_factor * getattr(config, key)) return 2048 -def is_sentence_complete(output): return output.endswith(('.', '?', '!', '...', '。', '?', '!', '…', '"', "'", '”')) + + +def is_sentence_complete(output): + return output.endswith(('.', '?', '!', '...', '。', '?', '!', '…', '"', "'", '”')) + + def is_partial_stop(output, stop_str): for i in range(min(len(output), len(stop_str))): - if stop_str.startswith(output[-i:]): return True + if stop_str.startswith(output[-i:]): + return True return False diff --git a/openllm-python/src/openllm/_llm.py b/openllm-python/src/openllm/_llm.py index 07962c7e..03caae26 100644 --- a/openllm-python/src/openllm/_llm.py +++ b/openllm-python/src/openllm/_llm.py @@ -48,7 +48,8 @@ ResolvedAdapterMap = t.Dict[AdapterType, t.Dict[str, t.Tuple['PeftConfig', str]] @attr.define(slots=False, repr=False, init=False) class LLM(t.Generic[M, T]): async def generate(self, prompt, prompt_token_ids=None, stop=None, stop_token_ids=None, request_id=None, adapter_name=None, **attrs): - if adapter_name is not None and self.__llm_backend__ != 'pt': raise NotImplementedError(f'Adapter is not supported with {self.__llm_backend__}.') + if adapter_name is not None and self.__llm_backend__ != 'pt': + raise NotImplementedError(f'Adapter is not supported with {self.__llm_backend__}.') config = self.config.model_construct_env(**attrs) texts, token_ids = [[]] * config['n'], [[]] * config['n'] async for result in self.generate_iterator( @@ -57,18 +58,18 @@ class LLM(t.Generic[M, T]): for output in result.outputs: texts[output.index].append(output.text) token_ids[output.index].extend(output.token_ids) - if (final_result := result) is None: raise RuntimeError('No result is returned.') + if (final_result := result) is None: + raise RuntimeError('No result is returned.') return final_result.with_options( prompt=prompt, - outputs=[ - output.with_options(text=''.join(texts[output.index]), token_ids=token_ids[output.index]) - for output in final_result.outputs - ], + outputs=[output.with_options(text=''.join(texts[output.index]), token_ids=token_ids[output.index]) for output in final_result.outputs], ) async def generate_iterator(self, prompt, prompt_token_ids=None, stop=None, stop_token_ids=None, request_id=None, adapter_name=None, **attrs): from bentoml._internal.runner.runner_handle import DummyRunnerHandle - if adapter_name is not None and self.__llm_backend__ != 'pt': raise NotImplementedError(f'Adapter is not supported with {self.__llm_backend__}.') + + if adapter_name is not None and self.__llm_backend__ != 'pt': + raise NotImplementedError(f'Adapter is not supported with {self.__llm_backend__}.') if isinstance(self.runner._runner_handle, DummyRunnerHandle): if os.getenv('BENTO_PATH') is not None: @@ -79,10 +80,13 @@ class LLM(t.Generic[M, T]): stop_token_ids = stop_token_ids or [] eos_token_id = attrs.get('eos_token_id', config['eos_token_id']) - if eos_token_id and not isinstance(eos_token_id, list): eos_token_id = [eos_token_id] + if eos_token_id and not isinstance(eos_token_id, list): + eos_token_id = [eos_token_id] 
stop_token_ids.extend(eos_token_id or []) - if (config_eos := config['eos_token_id']) and config_eos not in stop_token_ids: stop_token_ids.append(config_eos) - if self.tokenizer.eos_token_id not in stop_token_ids: stop_token_ids.append(self.tokenizer.eos_token_id) + if (config_eos := config['eos_token_id']) and config_eos not in stop_token_ids: + stop_token_ids.append(config_eos) + if self.tokenizer.eos_token_id not in stop_token_ids: + stop_token_ids.append(self.tokenizer.eos_token_id) if stop is None: stop = set() elif isinstance(stop, str): @@ -90,16 +94,20 @@ class LLM(t.Generic[M, T]): else: stop = set(stop) for tid in stop_token_ids: - if tid: stop.add(self.tokenizer.decode(tid)) + if tid: + stop.add(self.tokenizer.decode(tid)) if prompt_token_ids is None: - if prompt is None: raise ValueError('Either prompt or prompt_token_ids must be specified.') + if prompt is None: + raise ValueError('Either prompt or prompt_token_ids must be specified.') prompt_token_ids = self.tokenizer.encode(prompt) request_id = gen_random_uuid() if request_id is None else request_id previous_texts, previous_num_tokens = [''] * config['n'], [0] * config['n'] try: - generator = self.runner.generate_iterator.async_stream(prompt_token_ids, request_id, stop=list(stop), adapter_name=adapter_name, **config.model_dump(flatten=True)) + generator = self.runner.generate_iterator.async_stream( + prompt_token_ids, request_id, stop=list(stop), adapter_name=adapter_name, **config.model_dump(flatten=True) + ) except Exception as err: raise RuntimeError(f'Failed to start generation task: {err}') from err @@ -118,11 +126,18 @@ class LLM(t.Generic[M, T]): # NOTE: If you are here to see how generate_iterator and generate works, see above. # The below are mainly for internal implementation that you don't have to worry about. - _model_id: str; _revision: t.Optional[str] # + _model_id: str + _revision: t.Optional[str] # _quantization_config: t.Optional[t.Union[transformers.BitsAndBytesConfig, transformers.GPTQConfig, transformers.AwqConfig]] - _quantise: t.Optional[LiteralQuantise]; _model_decls: t.Tuple[t.Any, ...]; __model_attrs: t.Dict[str, t.Any] # - __tokenizer_attrs: t.Dict[str, t.Any]; _tag: bentoml.Tag; _adapter_map: t.Optional[AdapterMap] # - _serialisation: LiteralSerialisation; _local: bool; _max_model_len: t.Optional[int] # + _quantise: t.Optional[LiteralQuantise] + _model_decls: t.Tuple[t.Any, ...] + __model_attrs: t.Dict[str, t.Any] # + __tokenizer_attrs: t.Dict[str, t.Any] + _tag: bentoml.Tag + _adapter_map: t.Optional[AdapterMap] # + _serialisation: LiteralSerialisation + _local: bool + _max_model_len: t.Optional[int] # _gpu_memory_utilization: float __llm_dtype__: t.Union[LiteralDtype, t.Literal['auto', 'half', 'float']] = 'auto' @@ -159,20 +174,28 @@ class LLM(t.Generic[M, T]): ): torch_dtype = attrs.pop('torch_dtype', None) # backward compatible if torch_dtype is not None: - warnings.warn('The argument "torch_dtype" is deprecated and will be removed in the future. Please use "dtype" instead.', DeprecationWarning, stacklevel=3); dtype = torch_dtype + warnings.warn( + 'The argument "torch_dtype" is deprecated and will be removed in the future. 
Please use "dtype" instead.', DeprecationWarning, stacklevel=3 + ) + dtype = torch_dtype _local = False - if validate_is_path(model_id): model_id, _local = resolve_filepath(model_id), True + if validate_is_path(model_id): + model_id, _local = resolve_filepath(model_id), True backend = getenv('backend', default=backend) - if backend is None: backend = self._cascade_backend() + if backend is None: + backend = self._cascade_backend() dtype = getenv('dtype', default=dtype, var=['TORCH_DTYPE']) - if dtype is None: logger.warning('Setting dtype to auto. Inferring from framework specific models'); dtype = 'auto' + if dtype is None: + logger.warning('Setting dtype to auto. Inferring from framework specific models') + dtype = 'auto' quantize = getenv('quantize', default=quantize, var=['QUANITSE']) attrs.update({'low_cpu_mem_usage': low_cpu_mem_usage}) # parsing tokenizer and model kwargs, as the hierarchy is param pass > default model_attrs, tokenizer_attrs = flatten_attrs(**attrs) if model_tag is None: model_tag, model_version = self._make_tag_components(model_id, model_version, backend=backend) - if model_version: model_tag = f'{model_tag}:{model_version}' + if model_version: + model_tag = f'{model_tag}:{model_version}' self.__attrs_init__( model_id=model_id, @@ -200,42 +223,57 @@ class LLM(t.Generic[M, T]): model = openllm.serialisation.import_model(self, trust_remote_code=self.trust_remote_code) # resolve the tag self._tag = model.tag - if not _eager and embedded: raise RuntimeError("Embedded mode is not supported when '_eager' is False.") + if not _eager and embedded: + raise RuntimeError("Embedded mode is not supported when '_eager' is False.") if embedded: - logger.warning('NOT RECOMMENDED in production and SHOULD ONLY used for development.'); self.runner.init_local(quiet=True) + logger.warning('NOT RECOMMENDED in production and SHOULD ONLY used for development.') + self.runner.init_local(quiet=True) + class _Quantise: @staticmethod - def pt(llm: LLM, quantise=None): return quantise + def pt(llm: LLM, quantise=None): + return quantise + @staticmethod - def vllm(llm: LLM, quantise=None): return quantise + def vllm(llm: LLM, quantise=None): + return quantise + @staticmethod def ctranslate(llm: LLM, quantise=None): - if quantise in {'int4', 'awq', 'gptq', 'squeezellm'}: raise ValueError(f"Quantisation '{quantise}' is not supported for backend 'ctranslate'") - if quantise == 'int8': quantise = 'int8_float16' if llm._has_gpus else 'int8_float32' + if quantise in {'int4', 'awq', 'gptq', 'squeezellm'}: + raise ValueError(f"Quantisation '{quantise}' is not supported for backend 'ctranslate'") + if quantise == 'int8': + quantise = 'int8_float16' if llm._has_gpus else 'int8_float32' return quantise + @apply(lambda val: tuple(str.lower(i) if i else i for i in val)) def _make_tag_components(self, model_id: str, model_version: str | None, backend: str) -> tuple[str, str | None]: model_id, *maybe_revision = model_id.rsplit(':') if len(maybe_revision) > 0: - if model_version is not None: logger.warning("revision is specified (%s). 'model_version=%s' will be ignored.", maybe_revision[0], model_version) + if model_version is not None: + logger.warning("revision is specified (%s). 
'model_version=%s' will be ignored.", maybe_revision[0], model_version) model_version = maybe_revision[0] if validate_is_path(model_id): model_id, model_version = resolve_filepath(model_id), first_not_none(model_version, default=generate_hash_from_file(model_id)) return f'{backend}-{normalise_model_name(model_id)}', model_version + @functools.cached_property def _has_gpus(self): try: from cuda import cuda + err, *_ = cuda.cuInit(0) if err != cuda.CUresult.CUDA_SUCCESS: raise RuntimeError('Failed to initialise CUDA runtime binding.') - err, num_gpus = cuda.cuDeviceGetCount() + err, _ = cuda.cuDeviceGetCount() if err != cuda.CUresult.CUDA_SUCCESS: raise RuntimeError('Failed to get CUDA device count.') return True except (ImportError, RuntimeError): return False + @property def _torch_dtype(self): import torch, transformers + _map = _torch_dtype_mapping() if not isinstance(self.__llm_torch_dtype__, torch.dtype): try: @@ -243,23 +281,32 @@ class LLM(t.Generic[M, T]): except OpenLLMException: hf_config = transformers.AutoConfig.from_pretrained(self.model_id, trust_remote_code=self.trust_remote_code) config_dtype = getattr(hf_config, 'torch_dtype', None) - if config_dtype is None: config_dtype = torch.float32 + if config_dtype is None: + config_dtype = torch.float32 if self.__llm_dtype__ == 'auto': if config_dtype == torch.float32: torch_dtype = torch.float16 else: torch_dtype = config_dtype else: - if self.__llm_dtype__ not in _map: raise ValueError(f"Unknown dtype '{self.__llm_dtype__}'") + if self.__llm_dtype__ not in _map: + raise ValueError(f"Unknown dtype '{self.__llm_dtype__}'") torch_dtype = _map[self.__llm_dtype__] self.__llm_torch_dtype__ = torch_dtype return self.__llm_torch_dtype__ + @property - def _model_attrs(self): return {**self.import_kwargs[0], **self.__model_attrs} + def _model_attrs(self): + return {**self.import_kwargs[0], **self.__model_attrs} + @_model_attrs.setter - def _model_attrs(self, model_attrs): self.__model_attrs = model_attrs + def _model_attrs(self, model_attrs): + self.__model_attrs = model_attrs + @property - def _tokenizer_attrs(self): return {**self.import_kwargs[1], **self.__tokenizer_attrs} + def _tokenizer_attrs(self): + return {**self.import_kwargs[1], **self.__tokenizer_attrs} + def _cascade_backend(self) -> LiteralBackend: logger.warning('It is recommended to specify the backend explicitly. 
Cascading backend might lead to unexpected behaviour.') if self._has_gpus: @@ -271,35 +318,61 @@ class LLM(t.Generic[M, T]): return 'ctranslate' else: return 'pt' + def __setattr__(self, attr, value): - if attr in {'model', 'tokenizer', 'runner', 'import_kwargs'}: raise ForbiddenAttributeError(f'{attr} should not be set during runtime.') + if attr in {'model', 'tokenizer', 'runner', 'import_kwargs'}: + raise ForbiddenAttributeError(f'{attr} should not be set during runtime.') super().__setattr__(attr, value) + def __del__(self): try: del self.__llm_model__, self.__llm_tokenizer__, self.__llm_adapter_map__ except AttributeError: pass - def __repr_args__(self): yield from (('model_id', self._model_id if not self._local else self.tag.name), ('revision', self._revision if self._revision else self.tag.version), ('backend', self.__llm_backend__), ('type', self.llm_type)) - def __repr__(self) -> str: return f'{self.__class__.__name__} {orjson.dumps({k: v for k, v in self.__repr_args__()}, option=orjson.OPT_INDENT_2).decode()}' + + def __repr_args__(self): + yield from ( + ('model_id', self._model_id if not self._local else self.tag.name), + ('revision', self._revision if self._revision else self.tag.version), + ('backend', self.__llm_backend__), + ('type', self.llm_type), + ) + + def __repr__(self) -> str: + return f'{self.__class__.__name__} {orjson.dumps({k: v for k, v in self.__repr_args__()}, option=orjson.OPT_INDENT_2).decode()}' + @property - def import_kwargs(self): return {'device_map': 'auto' if self._has_gpus else None, 'torch_dtype': self._torch_dtype}, {'padding_side': 'left', 'truncation_side': 'left'} + def import_kwargs(self): + return {'device_map': 'auto' if self._has_gpus else None, 'torch_dtype': self._torch_dtype}, {'padding_side': 'left', 'truncation_side': 'left'} + @property def trust_remote_code(self): env = os.getenv('TRUST_REMOTE_CODE') - if env is not None: return check_bool_env('TRUST_REMOTE_CODE', env) + if env is not None: + return check_bool_env('TRUST_REMOTE_CODE', env) return self.__llm_trust_remote_code__ + @property - def model_id(self): return self._model_id + def model_id(self): + return self._model_id + @property - def revision(self): return self._revision + def revision(self): + return self._revision + @property - def tag(self): return self._tag + def tag(self): + return self._tag + @property - def bentomodel(self): return openllm.serialisation.get(self) + def bentomodel(self): + return openllm.serialisation.get(self) + @property def quantization_config(self): if self.__llm_quantization_config__ is None: from ._quantisation import infer_quantisation_config + if self._quantization_config is not None: self.__llm_quantization_config__ = self._quantization_config elif self._quantise is not None: @@ -307,16 +380,27 @@ class LLM(t.Generic[M, T]): else: raise ValueError("Either 'quantization_config' or 'quantise' must be specified.") return self.__llm_quantization_config__ + @property - def has_adapters(self): return self._adapter_map is not None + def has_adapters(self): + return self._adapter_map is not None + @property - def local(self): return self._local + def local(self): + return self._local + @property - def quantise(self): return self._quantise + def quantise(self): + return self._quantise + @property - def llm_type(self): return normalise_model_name(self._model_id) + def llm_type(self): + return normalise_model_name(self._model_id) + @property - def llm_parameters(self): return (self._model_decls, self._model_attrs), self._tokenizer_attrs + def 
llm_parameters(self): + return (self._model_decls, self._model_attrs), self._tokenizer_attrs + @property def identifying_params(self): return { @@ -324,43 +408,55 @@ class LLM(t.Generic[M, T]): 'model_ids': orjson.dumps(self.config['model_ids']).decode(), 'model_id': self.model_id, } + @property def tokenizer(self): - if self.__llm_tokenizer__ is None: self.__llm_tokenizer__ = openllm.serialisation.load_tokenizer(self, **self.llm_parameters[-1]) + if self.__llm_tokenizer__ is None: + self.__llm_tokenizer__ = openllm.serialisation.load_tokenizer(self, **self.llm_parameters[-1]) return self.__llm_tokenizer__ + @property def runner(self): from ._runners import runner - if self.__llm_runner__ is None: self.__llm_runner__ = runner(self) + + if self.__llm_runner__ is None: + self.__llm_runner__ = runner(self) return self.__llm_runner__ + def prepare(self, adapter_type='lora', use_gradient_checking=True, **attrs): - if self.__llm_backend__ != 'pt': raise RuntimeError('Fine tuning is only supported for PyTorch backend.') + if self.__llm_backend__ != 'pt': + raise RuntimeError('Fine tuning is only supported for PyTorch backend.') from peft.mapping import get_peft_model from peft.utils.other import prepare_model_for_kbit_training model = get_peft_model( prepare_model_for_kbit_training(self.model, use_gradient_checkpointing=use_gradient_checking), - self.config['fine_tune_strategies'] - .get(adapter_type, self.config.make_fine_tune_config(adapter_type)) - .train().with_config(**attrs).build(), + self.config['fine_tune_strategies'].get(adapter_type, self.config.make_fine_tune_config(adapter_type)).train().with_config(**attrs).build(), ) - if DEBUG: model.print_trainable_parameters() + if DEBUG: + model.print_trainable_parameters() return model, self.tokenizer - def prepare_for_training(self, *args, **attrs): logger.warning('`prepare_for_training` is deprecated and will be removed in the future. Use `prepare` instead.'); return self.prepare(*args, **attrs) + + def prepare_for_training(self, *args, **attrs): + logger.warning('`prepare_for_training` is deprecated and will be removed in the future. Use `prepare` instead.') + return self.prepare(*args, **attrs) + @property def adapter_map(self): - if not is_peft_available(): raise MissingDependencyError("Failed to import 'peft'. Make sure to do 'pip install \"openllm[fine-tune]\"'") - if not self.has_adapters: raise AttributeError('Adapter map is not available.') + if not is_peft_available(): + raise MissingDependencyError("Failed to import 'peft'. 
Make sure to do 'pip install \"openllm[fine-tune]\"'") + if not self.has_adapters: + raise AttributeError('Adapter map is not available.') assert self._adapter_map is not None if self.__llm_adapter_map__ is None: _map: ResolvedAdapterMap = {k: {} for k in self._adapter_map} for adapter_type, adapter_tuple in self._adapter_map.items(): - base = first_not_none( - self.config['fine_tune_strategies'].get(adapter_type), default=self.config.make_fine_tune_config(adapter_type), - ) - for adapter in adapter_tuple: _map[adapter_type][adapter.name] = (base.with_config(**adapter.config).build(), adapter.adapter_id) + base = first_not_none(self.config['fine_tune_strategies'].get(adapter_type), default=self.config.make_fine_tune_config(adapter_type)) + for adapter in adapter_tuple: + _map[adapter_type][adapter.name] = (base.with_config(**adapter.config).build(), adapter.adapter_id) self.__llm_adapter_map__ = _map return self.__llm_adapter_map__ + @property def model(self): if self.__llm_model__ is None: @@ -368,7 +464,10 @@ class LLM(t.Generic[M, T]): # If OOM, then it is probably you don't have enough VRAM to run this model. if self.__llm_backend__ == 'pt': import torch - loaded_in_kbit = getattr(model, 'is_loaded_in_8bit', False) or getattr(model, 'is_loaded_in_4bit', False) or getattr(model, 'is_quantized', False) + + loaded_in_kbit = ( + getattr(model, 'is_loaded_in_8bit', False) or getattr(model, 'is_loaded_in_4bit', False) or getattr(model, 'is_quantized', False) + ) if torch.cuda.is_available() and torch.cuda.device_count() == 1 and not loaded_in_kbit: try: model = model.to('cuda') @@ -381,9 +480,11 @@ class LLM(t.Generic[M, T]): model.load_adapter(peft_model_id, adapter_name, peft_config=peft_config) self.__llm_model__ = model return self.__llm_model__ + @property def config(self): import transformers + if self.__llm_config__ is None: if self.__llm_backend__ == 'ctranslate': try: @@ -397,26 +498,41 @@ class LLM(t.Generic[M, T]): ).model_construct_env(**self._model_attrs) break else: - raise OpenLLMException(f"Failed to infer the configuration class. Make sure the model is a supported model. Supported models are: {', '.join(openllm.AutoConfig._CONFIG_MAPPING_NAMES_TO_ARCHITECTURE.keys())}") + raise OpenLLMException( + f"Failed to infer the configuration class. Make sure the model is a supported model. Supported models are: {', '.join(openllm.AutoConfig._CONFIG_MAPPING_NAMES_TO_ARCHITECTURE.keys())}" + ) else: config = openllm.AutoConfig.infer_class_from_llm(self).model_construct_env(**self._model_attrs) self.__llm_config__ = config return self.__llm_config__ + @functools.lru_cache(maxsize=1) def _torch_dtype_mapping() -> dict[str, torch.dtype]: - import torch; return { - 'half': torch.float16, 'float16': torch.float16, # - 'float': torch.float32, 'float32': torch.float32, # + import torch + + return { + 'half': torch.float16, + 'float16': torch.float16, # + 'float': torch.float32, + 'float32': torch.float32, # 'bfloat16': torch.bfloat16, } -def normalise_model_name(name: str) -> str: return os.path.basename(resolve_filepath(name)) if validate_is_path(name) else inflection.dasherize(name.replace('/', '--')) + + +def normalise_model_name(name: str) -> str: + return os.path.basename(resolve_filepath(name)) if validate_is_path(name) else inflection.dasherize(name.replace('/', '--')) + + def convert_peft_config_type(adapter_map: dict[str, str]) -> AdapterMap: - if not is_peft_available(): raise RuntimeError("LoRA adapter requires 'peft' to be installed. 
Make sure to do 'pip install \"openllm[fine-tune]\"'") + if not is_peft_available(): + raise RuntimeError("LoRA adapter requires 'peft' to be installed. Make sure to do 'pip install \"openllm[fine-tune]\"'") from huggingface_hub import hf_hub_download + resolved: AdapterMap = {} for path_or_adapter_id, name in adapter_map.items(): - if name is None: raise ValueError('Adapter name must be specified.') + if name is None: + raise ValueError('Adapter name must be specified.') if os.path.isfile(os.path.join(path_or_adapter_id, PEFT_CONFIG_NAME)): config_file = os.path.join(path_or_adapter_id, PEFT_CONFIG_NAME) else: @@ -424,8 +540,10 @@ def convert_peft_config_type(adapter_map: dict[str, str]) -> AdapterMap: config_file = hf_hub_download(path_or_adapter_id, PEFT_CONFIG_NAME) except Exception as err: raise ValueError(f"Can't find '{PEFT_CONFIG_NAME}' at '{path_or_adapter_id}'") from err - with open(config_file, 'r') as file: resolved_config = orjson.loads(file.read()) + with open(config_file, 'r') as file: + resolved_config = orjson.loads(file.read()) _peft_type = resolved_config['peft_type'].lower() - if _peft_type not in resolved: resolved[_peft_type] = () + if _peft_type not in resolved: + resolved[_peft_type] = () resolved[_peft_type] += (_AdapterTuple((path_or_adapter_id, name, resolved_config)),) return resolved diff --git a/openllm-python/src/openllm/_llm.pyi b/openllm-python/src/openllm/_llm.pyi index d4c62bd6..910a12df 100644 --- a/openllm-python/src/openllm/_llm.pyi +++ b/openllm-python/src/openllm/_llm.pyi @@ -8,16 +8,7 @@ from peft.peft_model import PeftModel, PeftModelForCausalLM, PeftModelForSeq2Seq from bentoml import Model, Tag from openllm_core import LLMConfig from openllm_core._schemas import GenerationOutput -from openllm_core._typing_compat import ( - AdapterMap, - AdapterType, - LiteralBackend, - LiteralDtype, - LiteralQuantise, - LiteralSerialisation, - M, - T, -) +from openllm_core._typing_compat import AdapterMap, AdapterType, LiteralBackend, LiteralDtype, LiteralQuantise, LiteralSerialisation, M, T from ._quantisation import QuantizationConfig from ._runners import Runner @@ -121,9 +112,7 @@ class LLM(Generic[M, T]): def runner(self) -> Runner[M, T]: ... @property def adapter_map(self) -> ResolvedAdapterMap: ... - def prepare( - self, adapter_type: AdapterType = ..., use_gradient_checking: bool = ..., **attrs: Any - ) -> Tuple[InjectedModel, T]: ... + def prepare(self, adapter_type: AdapterType = ..., use_gradient_checking: bool = ..., **attrs: Any) -> Tuple[InjectedModel, T]: ... 
async def generate( self, prompt: Optional[str], diff --git a/openllm-python/src/openllm/_quantisation.py b/openllm-python/src/openllm/_quantisation.py index 5e430662..1765a564 100644 --- a/openllm-python/src/openllm/_quantisation.py +++ b/openllm-python/src/openllm/_quantisation.py @@ -1,8 +1,11 @@ from __future__ import annotations from openllm_core.exceptions import MissingDependencyError from openllm_core.utils import is_autoawq_available, is_autogptq_available, is_bitsandbytes_available + + def infer_quantisation_config(llm, quantise, **attrs): import torch, transformers + # 8 bit configuration int8_threshold = attrs.pop('llm_int8_threshhold', 6.0) int8_enable_fp32_cpu_offload = attrs.pop('llm_int8_enable_fp32_cpu_offload', False) diff --git a/openllm-python/src/openllm/_quantisation.pyi b/openllm-python/src/openllm/_quantisation.pyi index d41809f7..a3e01878 100644 --- a/openllm-python/src/openllm/_quantisation.pyi +++ b/openllm-python/src/openllm/_quantisation.pyi @@ -9,18 +9,10 @@ from ._llm import LLM QuantizationConfig = Union[BitsAndBytesConfig, GPTQConfig, AwqConfig] @overload -def infer_quantisation_config( - self: LLM[M, T], quantise: Literal['int8', 'int4'], **attrs: Any -) -> tuple[BitsAndBytesConfig, Dict[str, Any]]: ... +def infer_quantisation_config(self: LLM[M, T], quantise: Literal['int8', 'int4'], **attrs: Any) -> tuple[BitsAndBytesConfig, Dict[str, Any]]: ... @overload -def infer_quantisation_config( - self: LLM[M, T], quantise: Literal['gptq'], **attrs: Any -) -> tuple[GPTQConfig, Dict[str, Any]]: ... +def infer_quantisation_config(self: LLM[M, T], quantise: Literal['gptq'], **attrs: Any) -> tuple[GPTQConfig, Dict[str, Any]]: ... @overload -def infer_quantisation_config( - self: LLM[M, T], quantise: Literal['awq'], **attrs: Any -) -> tuple[AwqConfig, Dict[str, Any]]: ... +def infer_quantisation_config(self: LLM[M, T], quantise: Literal['awq'], **attrs: Any) -> tuple[AwqConfig, Dict[str, Any]]: ... @overload -def infer_quantisation_config( - self: LLM[M, T], quantise: LiteralQuantise, **attrs: Any -) -> tuple[QuantizationConfig, Dict[str, Any]]: ... +def infer_quantisation_config(self: LLM[M, T], quantise: LiteralQuantise, **attrs: Any) -> tuple[QuantizationConfig, Dict[str, Any]]: ... 
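Aside on the `@overload` stubs collapsed in the `_quantisation.pyi` hunk above: the pattern maps each `quantise` literal to a concrete quantisation-config type so call sites keep a precise return type even though the runtime function is a single implementation. The sketch below is only an illustration of that typing pattern under stated assumptions, not OpenLLM's actual implementation; `BnbConfig`, `GptqConfig`, and `make_quant_config` are placeholder names standing in for the real transformers config classes and the library's `infer_quantisation_config`.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Literal, Union, overload


# Placeholder config classes; the real code returns transformers' quantisation configs.
@dataclass
class BnbConfig:
  bits: int


@dataclass
class GptqConfig:
  bits: int = 4


QuantizationConfig = Union[BnbConfig, GptqConfig]


@overload
def make_quant_config(quantise: Literal['int8', 'int4']) -> BnbConfig: ...
@overload
def make_quant_config(quantise: Literal['gptq']) -> GptqConfig: ...
def make_quant_config(quantise: str) -> QuantizationConfig:
  # The overloads above only narrow the return type for static checkers;
  # this single runtime implementation handles every accepted literal.
  if quantise in ('int8', 'int4'):
    return BnbConfig(bits=8 if quantise == 'int8' else 4)
  if quantise == 'gptq':
    return GptqConfig()
  raise ValueError(f'Unknown quantisation scheme: {quantise}')


if __name__ == '__main__':
  print(make_quant_config('int8'))  # BnbConfig(bits=8)
```

Keeping the overloads on one line each, as the reformatted stubs do, is purely cosmetic; type checkers treat the collapsed and expanded forms identically.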
diff --git a/openllm-python/src/openllm/_runners.py b/openllm-python/src/openllm/_runners.py index 23af330e..e0701692 100644 --- a/openllm-python/src/openllm/_runners.py +++ b/openllm-python/src/openllm/_runners.py @@ -11,13 +11,17 @@ if t.TYPE_CHECKING: _registry = {} __all__ = ['runner'] + def registry(cls=None, *, alias=None): def decorator(_cls): _registry[_cls.__name__[:-8].lower() if alias is None else alias] = _cls return _cls - if cls is None: return decorator + + if cls is None: + return decorator return decorator(cls) + def runner(llm: openllm.LLM[M, T]) -> Runner[M, T]: try: assert llm.bentomodel @@ -25,32 +29,36 @@ def runner(llm: openllm.LLM[M, T]) -> Runner[M, T]: raise RuntimeError(f'Failed to locate {llm.bentomodel}: {err}') from err return types.new_class( - llm.config.__class__.__name__[:-6] + 'Runner', (bentoml.Runner,), # - exec_body=lambda ns: ns.update( - { - 'llm_type': llm.llm_type, 'identifying_params': llm.identifying_params, # - 'llm_tag': llm.tag, 'llm': llm, 'config': llm.config, 'backend': llm.__llm_backend__, # - '__module__': llm.__module__, '__repr__': ReprMixin.__repr__, # - '__doc__': llm.config.__class__.__doc__ or f'Generated Runner class for {llm.config["model_name"]}', - '__repr_keys__': property(lambda _: {'config', 'llm_type', 'runner_methods', 'backend', 'llm_tag'}), - '__repr_args__': lambda _: ( - ( - 'runner_methods', - { - method.name: {'batchable': method.config.batchable, 'batch_dim': method.config.batch_dim if method.config.batchable else None} - for method in _.runner_methods - }, - ), - ('config', llm.config.model_dump(flatten=True)), - ('llm_type', llm.llm_type), - ('backend', llm.__llm_backend__), - ('llm_tag', llm.tag), + llm.config.__class__.__name__[:-6] + 'Runner', + (bentoml.Runner,), # + exec_body=lambda ns: ns.update({ + 'llm_type': llm.llm_type, + 'identifying_params': llm.identifying_params, # + 'llm_tag': llm.tag, + 'llm': llm, + 'config': llm.config, + 'backend': llm.__llm_backend__, # + '__module__': llm.__module__, + '__repr__': ReprMixin.__repr__, # + '__doc__': llm.config.__class__.__doc__ or f'Generated Runner class for {llm.config["model_name"]}', + '__repr_keys__': property(lambda _: {'config', 'llm_type', 'runner_methods', 'backend', 'llm_tag'}), + '__repr_args__': lambda _: ( + ( + 'runner_methods', + { + method.name: {'batchable': method.config.batchable, 'batch_dim': method.config.batch_dim if method.config.batchable else None} + for method in _.runner_methods + }, ), - 'has_adapters': llm.has_adapters, - 'template': llm.config.template, - 'system_message': llm.config.system_message, - } - ), + ('config', llm.config.model_dump(flatten=True)), + ('llm_type', llm.llm_type), + ('backend', llm.__llm_backend__), + ('llm_tag', llm.tag), + ), + 'has_adapters': llm.has_adapters, + 'template': llm.config.template, + 'system_message': llm.config.system_message, + }), )( _registry[llm.__llm_backend__], name=f"llm-{llm.config['start_name']}-runner", @@ -59,32 +67,44 @@ def runner(llm: openllm.LLM[M, T]) -> Runner[M, T]: runnable_init_params={'llm': llm}, ) + @registry class CTranslateRunnable(bentoml.Runnable): SUPPORTED_RESOURCES = ('nvidia.com/gpu', 'cpu') SUPPORTS_CPU_MULTI_THREADING = True + def __init__(self, llm): - if not is_ctranslate_available(): raise openllm.exceptions.OpenLLMException('ctranslate is not installed. Do `pip install "openllm[ctranslate]"`') + if not is_ctranslate_available(): + raise openllm.exceptions.OpenLLMException('ctranslate is not installed. 
Do `pip install "openllm[ctranslate]"`') self.llm, self.config, self.model, self.tokenizer = llm, llm.config, llm.model, llm.tokenizer + @bentoml.Runnable.method(batchable=False) async def generate_iterator(self, prompt_token_ids, request_id, stop=None, adapter_name=None, **attrs): config, sampling_params = self.config.model_construct_env(stop=list(stop), **attrs).inference_options(self.llm) cumulative_logprob, output_token_ids, input_len = 0.0, list(prompt_token_ids), len(prompt_token_ids) tokens = self.tokenizer.convert_ids_to_tokens(prompt_token_ids) async for request_output in self.model.async_generate_tokens(tokens, **sampling_params): - if config['logprobs']: cumulative_logprob += request_output.log_prob + if config['logprobs']: + cumulative_logprob += request_output.log_prob output_token_ids.append(request_output.token_id) text = self.tokenizer.decode( - output_token_ids[input_len:], skip_special_tokens=True, # - spaces_between_special_tokens=False, clean_up_tokenization_spaces=True, # + output_token_ids[input_len:], + skip_special_tokens=True, # + spaces_between_special_tokens=False, + clean_up_tokenization_spaces=True, # ) yield GenerationOutput( - prompt_token_ids=prompt_token_ids, # - prompt='', finished=request_output.is_last, request_id=request_id, # + prompt_token_ids=prompt_token_ids, # + prompt='', + finished=request_output.is_last, + request_id=request_id, # outputs=[ CompletionChunk( - index=0, text=text, finish_reason=None, # - token_ids=output_token_ids[input_len:], cumulative_logprob=cumulative_logprob, # + index=0, + text=text, + finish_reason=None, # + token_ids=output_token_ids[input_len:], + cumulative_logprob=cumulative_logprob, # # TODO: logprobs, but seems like we don't have access to the raw logits ) ], @@ -95,30 +115,39 @@ class CTranslateRunnable(bentoml.Runnable): class vLLMRunnable(bentoml.Runnable): SUPPORTED_RESOURCES = ('nvidia.com/gpu', 'amd.com/gpu', 'cpu') SUPPORTS_CPU_MULTI_THREADING = True + def __init__(self, llm): - if not is_vllm_available(): raise openllm.exceptions.OpenLLMException('vLLM is not installed. Do `pip install "openllm[vllm]"`.') + if not is_vllm_available(): + raise openllm.exceptions.OpenLLMException('vLLM is not installed. 
Do `pip install "openllm[vllm]"`.') import vllm self.llm, self.config, self.tokenizer = llm, llm.config, llm.tokenizer num_gpus, dev = 1, openllm.utils.device_count() - if dev >= 2: num_gpus = min(dev // 2 * 2, dev) + if dev >= 2: + num_gpus = min(dev // 2 * 2, dev) try: self.model = vllm.AsyncLLMEngine.from_engine_args( vllm.AsyncEngineArgs( - worker_use_ray=False, engine_use_ray=False, # - tokenizer_mode='auto', tensor_parallel_size=num_gpus, # - model=llm.bentomodel.path, tokenizer=llm.bentomodel.path, # - trust_remote_code=llm.trust_remote_code, dtype=llm._torch_dtype, # - max_model_len=llm._max_model_len, gpu_memory_utilization=llm._gpu_memory_utilization, # + worker_use_ray=False, + engine_use_ray=False, # + tokenizer_mode='auto', + tensor_parallel_size=num_gpus, # + model=llm.bentomodel.path, + tokenizer=llm.bentomodel.path, # + trust_remote_code=llm.trust_remote_code, + dtype=llm._torch_dtype, # + max_model_len=llm._max_model_len, + gpu_memory_utilization=llm._gpu_memory_utilization, # quantization=llm.quantise if llm.quantise and llm.quantise in {'awq', 'squeezellm'} else None, ) ) except Exception as err: traceback.print_exc() raise openllm.exceptions.OpenLLMException(f'Failed to initialise vLLMEngine due to the following error:\n{err}') from err + @bentoml.Runnable.method(batchable=False) async def generate_iterator(self, prompt_token_ids, request_id, stop=None, adapter_name=None, **attrs): - config, sampling_params = self.config.model_construct_env(stop=stop, **attrs).inference_options(self.llm) + _, sampling_params = self.config.model_construct_env(stop=stop, **attrs).inference_options(self.llm) async for request_output in self.model.generate(None, sampling_params, request_id, prompt_token_ids): yield GenerationOutput.from_vllm(request_output).model_dump_json() @@ -127,6 +156,7 @@ class vLLMRunnable(bentoml.Runnable): class PyTorchRunnable(bentoml.Runnable): SUPPORTED_RESOURCES = ('nvidia.com/gpu', 'amd.com/gpu', 'cpu') SUPPORTS_CPU_MULTI_THREADING = True + def __init__(self, llm): self.llm, self.config, self.model, self.tokenizer = llm, llm.config, llm.model, llm.tokenizer self.is_encoder_decoder = llm.model.config.is_encoder_decoder @@ -134,10 +164,13 @@ class PyTorchRunnable(bentoml.Runnable): self.device = llm.model.device else: self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + @bentoml.Runnable.method(batchable=False) async def generate_iterator(self, prompt_token_ids, request_id, stop=None, adapter_name=None, **attrs): from ._generation import get_context_length, prepare_logits_processor - if adapter_name is not None: self.model.set_adapter(adapter_name) + + if adapter_name is not None: + self.model.set_adapter(adapter_name) max_new_tokens = attrs.pop('max_new_tokens', 256) context_length = attrs.pop('context_length', None) @@ -165,9 +198,7 @@ class PyTorchRunnable(bentoml.Runnable): if config['logprobs']: # FIXME: logprobs is not supported raise NotImplementedError('Logprobs is yet to be supported with encoder-decoder models.') encoder_output = self.model.encoder(input_ids=torch.as_tensor([prompt_token_ids], device=self.device))[0] - start_ids = torch.as_tensor( - [[self.model.generation_config.decoder_start_token_id]], dtype=torch.int64, device=self.device - ) + start_ids = torch.as_tensor([[self.model.generation_config.decoder_start_token_id]], dtype=torch.int64, device=self.device) else: start_ids = torch.as_tensor([prompt_token_ids], device=self.device) @@ -195,9 +226,7 @@ class PyTorchRunnable(bentoml.Runnable): ) logits = 
self.model.lm_head(out[0]) else: - out = self.model( - input_ids=torch.as_tensor([[token]], device=self.device), past_key_values=past_key_values, use_cache=True - ) + out = self.model(input_ids=torch.as_tensor([[token]], device=self.device), past_key_values=past_key_values, use_cache=True) logits = out.logits past_key_values = out.past_key_values if logits_processor: @@ -241,12 +270,7 @@ class PyTorchRunnable(bentoml.Runnable): tmp_output_ids, rfind_start = output_token_ids[input_len:], 0 # XXX: Move this to API server - text = self.tokenizer.decode( - tmp_output_ids, - skip_special_tokens=True, - spaces_between_special_tokens=False, - clean_up_tokenization_spaces=True, - ) + text = self.tokenizer.decode(tmp_output_ids, skip_special_tokens=True, spaces_between_special_tokens=False, clean_up_tokenization_spaces=True) if len(stop) > 0: for it in stop: @@ -255,7 +279,8 @@ class PyTorchRunnable(bentoml.Runnable): text, stopped = text[:pos], True break - if config['logprobs']: sample_logprobs.append({token: token_logprobs}) + if config['logprobs']: + sample_logprobs.append({token: token_logprobs}) yield GenerationOutput( prompt='', @@ -296,7 +321,7 @@ class PyTorchRunnable(bentoml.Runnable): prompt_token_ids=prompt_token_ids, prompt_logprobs=prompt_logprobs if config['prompt_logprobs'] else None, request_id=request_id, - ).model_dump_json() + ).model_dump_json() # Clean del past_key_values, out diff --git a/openllm-python/src/openllm/_runners.pyi b/openllm-python/src/openllm/_runners.pyi index 681d8e57..0b422dcc 100644 --- a/openllm-python/src/openllm/_runners.pyi +++ b/openllm-python/src/openllm/_runners.pyi @@ -1,19 +1,4 @@ -from typing import ( - Any, - AsyncGenerator, - Dict, - Generic, - Iterable, - List, - Literal, - Optional, - Protocol, - Tuple, - Type, - TypeVar, - Union, - final, -) +from typing import Any, AsyncGenerator, Dict, Generic, Iterable, List, Literal, Optional, Protocol, Tuple, Type, TypeVar, Union, final import torch from transformers import PreTrainedModel, PreTrainedTokenizer @@ -89,11 +74,7 @@ class Runner(Protocol[Mo, To]): class generate_iterator(RunnerMethod[List[int], AsyncGenerator[str, None]]): @staticmethod def async_stream( - prompt_token_ids: List[int], - request_id: str, - stop: Optional[Union[Iterable[str], str]] = ..., - adapter_name: Optional[str] = ..., - **attrs: Any, + prompt_token_ids: List[int], request_id: str, stop: Optional[Union[Iterable[str], str]] = ..., adapter_name: Optional[str] = ..., **attrs: Any ) -> AsyncGenerator[str, None]: ... 
def __init__( diff --git a/openllm-python/src/openllm/_service.py b/openllm-python/src/openllm/_service.py index f99f4339..0a2db77f 100644 --- a/openllm-python/src/openllm/_service.py +++ b/openllm-python/src/openllm/_service.py @@ -6,38 +6,48 @@ from bentoml.io import JSON, Text logger = logging.getLogger(__name__) llm = openllm.LLM[t.Any, t.Any]( - model_id=svars.model_id, model_tag=svars.model_tag, adapter_map=svars.adapter_map, # - serialisation=svars.serialization, trust_remote_code=svars.trust_remote_code, # - max_model_len=svars.max_model_len, gpu_memory_utilization=svars.gpu_memory_utilization, # + model_id=svars.model_id, + model_tag=svars.model_tag, + adapter_map=svars.adapter_map, # + serialisation=svars.serialization, + trust_remote_code=svars.trust_remote_code, # + max_model_len=svars.max_model_len, + gpu_memory_utilization=svars.gpu_memory_utilization, # ) svc = bentoml.Service(name=f"llm-{llm.config['start_name']}-service", runners=[llm.runner]) llm_model_class = openllm.GenerationInput.from_llm_config(llm.config) -@svc.api( - route='/v1/generate', - input=JSON.from_sample(llm_model_class.examples()), - output=JSON.from_sample(openllm.GenerationOutput.examples()), -) -async def generate_v1(input_dict: dict[str, t.Any]) -> dict[str, t.Any]: return (await llm.generate(**llm_model_class(**input_dict).model_dump())).model_dump() -@svc.api( - route='/v1/generate_stream', - input=JSON.from_sample(llm_model_class.examples()), - output=Text(content_type='text/event-stream'), -) +@svc.api(route='/v1/generate', input=JSON.from_sample(llm_model_class.examples()), output=JSON.from_sample(openllm.GenerationOutput.examples())) +async def generate_v1(input_dict: dict[str, t.Any]) -> dict[str, t.Any]: + return (await llm.generate(**llm_model_class(**input_dict).model_dump())).model_dump() + + +@svc.api(route='/v1/generate_stream', input=JSON.from_sample(llm_model_class.examples()), output=Text(content_type='text/event-stream')) async def generate_stream_v1(input_dict: dict[str, t.Any]) -> t.AsyncGenerator[str, None]: - async for it in llm.generate_iterator(**llm_model_class(**input_dict).model_dump()): yield f'data: {it.model_dump_json()}\n\n' + async for it in llm.generate_iterator(**llm_model_class(**input_dict).model_dump()): + yield f'data: {it.model_dump_json()}\n\n' yield 'data: [DONE]\n\n' + _Metadata = openllm.MetadataOutput( - timeout=llm.config['timeout'], model_name=llm.config['model_name'], # - backend=llm.__llm_backend__, model_id=llm.model_id, configuration=llm.config.model_dump_json().decode(), # + timeout=llm.config['timeout'], + model_name=llm.config['model_name'], # + backend=llm.__llm_backend__, + model_id=llm.model_id, + configuration=llm.config.model_dump_json().decode(), # ) -@svc.api(route='/v1/metadata', input=Text(), output=JSON.from_sample(_Metadata.model_dump())) -def metadata_v1(_: str) -> openllm.MetadataOutput: return _Metadata -class MessagesConverterInput(t.TypedDict): add_generation_prompt: bool; messages: t.List[t.Dict[str, t.Any]] +@svc.api(route='/v1/metadata', input=Text(), output=JSON.from_sample(_Metadata.model_dump())) +def metadata_v1(_: str) -> openllm.MetadataOutput: + return _Metadata + + +class MessagesConverterInput(t.TypedDict): + add_generation_prompt: bool + messages: t.List[t.Dict[str, t.Any]] + @svc.api( route='/v1/helpers/messages', @@ -46,7 +56,8 @@ class MessagesConverterInput(t.TypedDict): add_generation_prompt: bool; messages add_generation_prompt=False, messages=[ MessageParam(role='system', content='You are acting as Ernest 
Hemmingway.'), - MessageParam(role='user', content='Hi there!'), MessageParam(role='assistant', content='Yes?'), # + MessageParam(role='user', content='Hi there!'), + MessageParam(role='assistant', content='Yes?'), # ], ) ), @@ -56,4 +67,5 @@ def helpers_messages_v1(message: MessagesConverterInput) -> str: add_generation_prompt, messages = message['add_generation_prompt'], message['messages'] return llm.tokenizer.apply_chat_template(messages, add_generation_prompt=add_generation_prompt, tokenize=False) -openllm.mount_entrypoints(svc, llm) # HACK: This must always be the last line in this file, as we will do some MK for OpenAPI schema. + +openllm.mount_entrypoints(svc, llm) # HACK: This must always be the last line in this file, as we will do some MK for OpenAPI schema. diff --git a/openllm-python/src/openllm/_service_vars.py b/openllm-python/src/openllm/_service_vars.py index efe8afd9..fad04043 100644 --- a/openllm-python/src/openllm/_service_vars.py +++ b/openllm-python/src/openllm/_service_vars.py @@ -1,3 +1,13 @@ import os, orjson, openllm_core.utils as coreutils -model_id, model_tag, adapter_map, serialization, trust_remote_code = os.environ['OPENLLM_MODEL_ID'], None, orjson.loads(os.getenv('OPENLLM_ADAPTER_MAP', orjson.dumps(None))), os.getenv('OPENLLM_SERIALIZATION', default='safetensors'), coreutils.check_bool_env('TRUST_REMOTE_CODE', False) -max_model_len, gpu_memory_utilization = orjson.loads(os.getenv('MAX_MODEL_LEN', orjson.dumps(None).decode())), orjson.loads(os.getenv('GPU_MEMORY_UTILIZATION', orjson.dumps(0.9).decode())) + +model_id, model_tag, adapter_map, serialization, trust_remote_code = ( + os.environ['OPENLLM_MODEL_ID'], + None, + orjson.loads(os.getenv('OPENLLM_ADAPTER_MAP', orjson.dumps(None))), + os.getenv('OPENLLM_SERIALIZATION', default='safetensors'), + coreutils.check_bool_env('TRUST_REMOTE_CODE', False), +) +max_model_len, gpu_memory_utilization = ( + orjson.loads(os.getenv('MAX_MODEL_LEN', orjson.dumps(None).decode())), + orjson.loads(os.getenv('GPU_MEMORY_UTILIZATION', orjson.dumps(0.9).decode())), +) diff --git a/openllm-python/src/openllm/_strategies.py b/openllm-python/src/openllm/_strategies.py index ec1c297a..0febeaa2 100644 --- a/openllm-python/src/openllm/_strategies.py +++ b/openllm-python/src/openllm/_strategies.py @@ -7,30 +7,42 @@ from bentoml._internal.runner.strategy import THREAD_ENVS __all__ = ['CascadingResourceStrategy', 'get_resource'] logger = logging.getLogger(__name__) + def _strtoul(s: str) -> int: # Return -1 or positive integer sequence string starts with. 
- if not s: return -1 + if not s: + return -1 idx = 0 for idx, c in enumerate(s): - if not (c.isdigit() or (idx == 0 and c in '+-')): break - if idx + 1 == len(s): idx += 1 # noqa: PLW2901 + if not (c.isdigit() or (idx == 0 and c in '+-')): + break + if idx + 1 == len(s): + idx += 1 # NOTE: idx will be set via enumerate return int(s[:idx]) if idx > 0 else -1 + + def _parse_list_with_prefix(lst: str, prefix: str) -> list[str]: rcs = [] for elem in lst.split(','): # Repeated id results in empty set - if elem in rcs: return [] + if elem in rcs: + return [] # Anything other but prefix is ignored - if not elem.startswith(prefix): break + if not elem.startswith(prefix): + break rcs.append(elem) return rcs + + def _parse_cuda_visible_devices(default_var: str | None = None, respect_env: bool = True) -> list[str] | None: if respect_env: spec = os.environ.get('CUDA_VISIBLE_DEVICES', default_var) - if not spec: return None + if not spec: + return None else: - if default_var is None: raise ValueError('spec is required to be not None when parsing spec.') + if default_var is None: + raise ValueError('spec is required to be not None when parsing spec.') spec = default_var if spec.startswith('GPU-'): @@ -44,48 +56,59 @@ def _parse_cuda_visible_devices(default_var: str | None = None, respect_env: boo for el in spec.split(','): x = _strtoul(el.strip()) # Repeated ordinal results in empty set - if x in rc: return [] + if x in rc: + return [] # Negative value aborts the sequence - if x < 0: break + if x < 0: + break rc.append(x) return [str(i) for i in rc] + + def _raw_device_uuid_nvml() -> list[str] | None: from ctypes import CDLL, byref, c_int, c_void_p, create_string_buffer try: nvml_h = CDLL('libnvidia-ml.so.1') except Exception: - warnings.warn('Failed to find nvidia binding', stacklevel=3); return None + warnings.warn('Failed to find nvidia binding', stacklevel=3) + return None rc = nvml_h.nvmlInit() if rc != 0: - warnings.warn("Can't initialize NVML", stacklevel=3); return None + warnings.warn("Can't initialize NVML", stacklevel=3) + return None dev_count = c_int(-1) rc = nvml_h.nvmlDeviceGetCount_v2(byref(dev_count)) if rc != 0: - warnings.warn('Failed to get available device from system.', stacklevel=3); return None + warnings.warn('Failed to get available device from system.', stacklevel=3) + return None uuids = [] for idx in range(dev_count.value): dev_id = c_void_p() rc = nvml_h.nvmlDeviceGetHandleByIndex_v2(idx, byref(dev_id)) if rc != 0: - warnings.warn(f'Failed to get device handle for {idx}', stacklevel=3); return None + warnings.warn(f'Failed to get device handle for {idx}', stacklevel=3) + return None buf_len = 96 buf = create_string_buffer(buf_len) rc = nvml_h.nvmlDeviceGetUUID(dev_id, buf, buf_len) if rc != 0: - warnings.warn(f'Failed to get device UUID for {idx}', stacklevel=3); return None + warnings.warn(f'Failed to get device UUID for {idx}', stacklevel=3) + return None uuids.append(buf.raw.decode('ascii').strip('\0')) del nvml_h return uuids + class _ResourceMixin: @staticmethod def from_system(cls) -> list[str]: visible_devices = _parse_cuda_visible_devices() if visible_devices is None: if cls.resource_id == 'amd.com/gpu': - if not psutil.LINUX: return [] + if not psutil.LINUX: + return [] # ROCm does not currently have the rocm_smi wheel. # So we need to use the ctypes bindings directly. # we don't want to use CLI because parsing is a pain. 
@@ -99,7 +122,8 @@ class _ResourceMixin: device_count = c_uint32(0) ret = rocmsmi.rsmi_num_monitor_devices(byref(device_count)) - if ret == rsmi_status_t.RSMI_STATUS_SUCCESS: return [str(i) for i in range(device_count.value)] + if ret == rsmi_status_t.RSMI_STATUS_SUCCESS: + return [str(i) for i in range(device_count.value)] return [] # In this case the binary is not found, returning empty list except (ModuleNotFoundError, ImportError): @@ -116,20 +140,26 @@ class _ResourceMixin: except (ImportError, RuntimeError, AttributeError): return [] return visible_devices + @staticmethod def from_spec(cls, spec) -> list[str]: if isinstance(spec, int): - if spec in (-1, 0): return [] - if spec < -1: raise ValueError('Spec cannot be < -1.') + if spec in (-1, 0): + return [] + if spec < -1: + raise ValueError('Spec cannot be < -1.') return [str(i) for i in range(spec)] elif isinstance(spec, str): - if not spec: return [] - if spec.isdigit(): spec = ','.join([str(i) for i in range(_strtoul(spec))]) + if not spec: + return [] + if spec.isdigit(): + spec = ','.join([str(i) for i in range(_strtoul(spec))]) return _parse_cuda_visible_devices(spec, respect_env=False) elif isinstance(spec, list): return [str(x) for x in spec] else: raise TypeError(f"'{cls.__name__}.from_spec' only supports parsing spec of type int, str, or list, got '{type(spec)}' instead.") + @staticmethod def validate(cls, val: list[t.Any]) -> None: if cls.resource_id == 'amd.com/gpu': @@ -139,83 +169,102 @@ class _ResourceMixin: try: from cuda import cuda + err, *_ = cuda.cuInit(0) - if err != cuda.CUresult.CUDA_SUCCESS: raise RuntimeError('Failed to initialise CUDA runtime binding.') + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError('Failed to initialise CUDA runtime binding.') # correctly parse handle for el in val: if el.startswith(('GPU-', 'MIG-')): uuids = _raw_device_uuid_nvml() - if uuids is None: raise ValueError('Failed to parse available GPUs UUID') - if el not in uuids: raise ValueError(f'Given UUID {el} is not found with available UUID (available: {uuids})') + if uuids is None: + raise ValueError('Failed to parse available GPUs UUID') + if el not in uuids: + raise ValueError(f'Given UUID {el} is not found with available UUID (available: {uuids})') elif el.isdigit(): err, _ = cuda.cuDeviceGet(int(el)) - if err != cuda.CUresult.CUDA_SUCCESS: raise ValueError(f'Failed to get device {el}') + if err != cuda.CUresult.CUDA_SUCCESS: + raise ValueError(f'Failed to get device {el}') except (ImportError, RuntimeError): pass + def _make_resource_class(name: str, resource_kind: str, docstring: str) -> type[bentoml.Resource[t.List[str]]]: return types.new_class( name, (bentoml.Resource[t.List[str]], coreutils.ReprMixin), {'resource_id': resource_kind}, - lambda ns: ns.update( - { - 'resource_id': resource_kind, - 'from_spec': classmethod(_ResourceMixin.from_spec), 'from_system': classmethod(_ResourceMixin.from_system), # - 'validate': classmethod(_ResourceMixin.validate), '__repr_keys__': property(lambda _: {'resource_id'}), # - '__doc__': inspect.cleandoc(docstring), '__module__': 'openllm._strategies', # - } - ), + lambda ns: ns.update({ + 'resource_id': resource_kind, + 'from_spec': classmethod(_ResourceMixin.from_spec), + 'from_system': classmethod(_ResourceMixin.from_system), # + 'validate': classmethod(_ResourceMixin.validate), + '__repr_keys__': property(lambda _: {'resource_id'}), # + '__doc__': inspect.cleandoc(docstring), + '__module__': 'openllm._strategies', # + }), ) + + NvidiaGpuResource = _make_resource_class( 
'NvidiaGpuResource', 'nvidia.com/gpu', - '''NVIDIA GPU resource. + """NVIDIA GPU resource. This is a modified version of internal's BentoML's NvidiaGpuResource - where it respects and parse CUDA_VISIBLE_DEVICES correctly.''', + where it respects and parse CUDA_VISIBLE_DEVICES correctly.""", ) AmdGpuResource = _make_resource_class( 'AmdGpuResource', 'amd.com/gpu', - '''AMD GPU resource. + """AMD GPU resource. Since ROCm will respect CUDA_VISIBLE_DEVICES, the behaviour of from_spec, from_system are similar to - ``NvidiaGpuResource``. Currently ``validate`` is not yet supported.''', + ``NvidiaGpuResource``. Currently ``validate`` is not yet supported.""", ) + class CascadingResourceStrategy(bentoml.Strategy, coreutils.ReprMixin): @classmethod def get_worker_count(cls, runnable_class, resource_request, workers_per_resource): - if resource_request is None: resource_request = system_resources() + if resource_request is None: + resource_request = system_resources() # use NVIDIA kind = 'nvidia.com/gpu' nvidia_req = get_resource(resource_request, kind) - if nvidia_req is not None: return 1 + if nvidia_req is not None: + return 1 # use AMD kind = 'amd.com/gpu' amd_req = get_resource(resource_request, kind, validate=False) - if amd_req is not None: return 1 + if amd_req is not None: + return 1 # use CPU cpus = get_resource(resource_request, 'cpu') if cpus is not None and cpus > 0: if runnable_class.SUPPORTS_CPU_MULTI_THREADING: - if isinstance(workers_per_resource, float) and workers_per_resource < 1.0: raise ValueError('Fractional CPU multi threading support is not yet supported.') + if isinstance(workers_per_resource, float) and workers_per_resource < 1.0: + raise ValueError('Fractional CPU multi threading support is not yet supported.') return int(workers_per_resource) return math.ceil(cpus) * workers_per_resource # this should not be reached by user since we always read system resource as default - raise ValueError(f'No known supported resource available for {runnable_class}. Please check your resource request. Leaving it blank will allow BentoML to use system resources.') + raise ValueError( + f'No known supported resource available for {runnable_class}. Please check your resource request. Leaving it blank will allow BentoML to use system resources.' 
+ ) + @classmethod def get_worker_env(cls, runnable_class, resource_request, workers_per_resource, worker_index): cuda_env = os.environ.get('CUDA_VISIBLE_DEVICES', None) disabled = cuda_env in ('', '-1') environ = {} - if resource_request is None: resource_request = system_resources() + if resource_request is None: + resource_request = system_resources() # use NVIDIA kind = 'nvidia.com/gpu' typ = get_resource(resource_request, kind) if typ is not None and len(typ) > 0 and kind in runnable_class.SUPPORTED_RESOURCES: if disabled: - environ['CUDA_VISIBLE_DEVICES'] = cuda_env; return environ + environ['CUDA_VISIBLE_DEVICES'] = cuda_env + return environ environ['CUDA_VISIBLE_DEVICES'] = cls.transpile_workers_to_cuda_envvar(workers_per_resource, typ, worker_index) return environ # use AMD @@ -223,7 +272,8 @@ class CascadingResourceStrategy(bentoml.Strategy, coreutils.ReprMixin): typ = get_resource(resource_request, kind, validate=False) if typ is not None and len(typ) > 0 and kind in runnable_class.SUPPORTED_RESOURCES: if disabled: - environ['CUDA_VISIBLE_DEVICES'] = cuda_env; return environ + environ['CUDA_VISIBLE_DEVICES'] = cuda_env + return environ environ['CUDA_VISIBLE_DEVICES'] = cls.transpile_workers_to_cuda_envvar(workers_per_resource, typ, worker_index) return environ # use CPU @@ -232,17 +282,21 @@ class CascadingResourceStrategy(bentoml.Strategy, coreutils.ReprMixin): environ['CUDA_VISIBLE_DEVICES'] = '-1' # disable gpu if runnable_class.SUPPORTS_CPU_MULTI_THREADING: thread_count = math.ceil(cpus) - for thread_env in THREAD_ENVS: environ[thread_env] = os.environ.get(thread_env, str(thread_count)) + for thread_env in THREAD_ENVS: + environ[thread_env] = os.environ.get(thread_env, str(thread_count)) return environ - for thread_env in THREAD_ENVS: environ[thread_env] = os.environ.get(thread_env, '1') + for thread_env in THREAD_ENVS: + environ[thread_env] = os.environ.get(thread_env, '1') return environ return environ + @staticmethod def transpile_workers_to_cuda_envvar(workers_per_resource, gpus, worker_index): # Convert given workers_per_resource to correct CUDA_VISIBLE_DEVICES string. if isinstance(workers_per_resource, float): # NOTE: We hit this branch when workers_per_resource is set to float, for example 0.5 or 0.25 - if workers_per_resource > 1: raise ValueError('workers_per_resource > 1 is not supported.') + if workers_per_resource > 1: + raise ValueError('workers_per_resource > 1 is not supported.') # We are round the assigned resource here. This means if workers_per_resource=.4 # then it will round down to 2. If workers_per_source=0.6, then it will also round up to 2. assigned_resource_per_worker = round(1 / workers_per_resource) diff --git a/openllm-python/src/openllm/_strategies.pyi b/openllm-python/src/openllm/_strategies.pyi index 75989ae2..8a32d00e 100644 --- a/openllm-python/src/openllm/_strategies.pyi +++ b/openllm-python/src/openllm/_strategies.pyi @@ -13,16 +13,11 @@ class CascadingResourceStrategy: TODO: Support CloudTPUResource """ @classmethod - def get_worker_count( - cls, - runnable_class: Type[bentoml.Runnable], - resource_request: Optional[Dict[str, Any]], - workers_per_resource: float, - ) -> int: - '''Return the number of workers to be used for the given runnable class. + def get_worker_count(cls, runnable_class: Type[bentoml.Runnable], resource_request: Optional[Dict[str, Any]], workers_per_resource: float) -> int: + """Return the number of workers to be used for the given runnable class. 
Note that for all available GPU, the number of workers will always be 1. - ''' + """ @classmethod def get_worker_env( cls, @@ -31,16 +26,14 @@ class CascadingResourceStrategy: workers_per_resource: Union[int, float], worker_index: int, ) -> Dict[str, Any]: - '''Get worker env for this given worker_index. + """Get worker env for this given worker_index. Args: runnable_class: The runnable class to be run. resource_request: The resource request of the runnable. workers_per_resource: # of workers per resource. worker_index: The index of the worker, start from 0. - ''' + """ @staticmethod - def transpile_workers_to_cuda_envvar( - workers_per_resource: Union[float, int], gpus: List[str], worker_index: int - ) -> str: - '''Convert given workers_per_resource to correct CUDA_VISIBLE_DEVICES string.''' + def transpile_workers_to_cuda_envvar(workers_per_resource: Union[float, int], gpus: List[str], worker_index: int) -> str: + """Convert given workers_per_resource to correct CUDA_VISIBLE_DEVICES string.""" diff --git a/openllm-python/src/openllm/bundle/__init__.py b/openllm-python/src/openllm/bundle/__init__.py index aa20f61b..119a4f60 100644 --- a/openllm-python/src/openllm/bundle/__init__.py +++ b/openllm-python/src/openllm/bundle/__init__.py @@ -4,11 +4,13 @@ from openllm_core._typing_compat import LiteralVersionStrategy from openllm_core.exceptions import OpenLLMException from openllm_core.utils.lazy import VersionInfo, LazyModule + @attr.attrs(eq=False, order=False, slots=True, frozen=True) class RefResolver: git_hash: str = attr.field() version: VersionInfo = attr.field(converter=lambda s: VersionInfo.from_version_string(s)) strategy: LiteralVersionStrategy = attr.field() + @classmethod @functools.lru_cache(maxsize=64) def from_strategy(cls, strategy_or_version: LiteralVersionStrategy | None = None) -> RefResolver: @@ -16,6 +18,7 @@ class RefResolver: if strategy_or_version is None or strategy_or_version == 'release': try: from ghapi.all import GhApi + ghapi = GhApi(owner='bentoml', repo='openllm', authenticate=False) meta = ghapi.repos.get_latest_release() git_hash = ghapi.git.get_ref(ref=f"tags/{meta['name']}")['object']['sha'] @@ -26,11 +29,16 @@ class RefResolver: return cls('latest', '0.0.0', 'latest') else: raise ValueError(f'Unknown strategy: {strategy_or_version}') + @property - def tag(self) -> str: return 'latest' if self.strategy in {'latest', 'nightly'} else repr(self.version) + def tag(self) -> str: + return 'latest' if self.strategy in {'latest', 'nightly'} else repr(self.version) + + __lazy = LazyModule( - __name__, os.path.abspath('__file__'), # + __name__, + os.path.abspath('__file__'), # {'_package': ['create_bento', 'build_editable', 'construct_python_options', 'construct_docker_options']}, - extra_objects={'RefResolver': RefResolver} + extra_objects={'RefResolver': RefResolver}, ) __all__, __dir__, __getattr__ = __lazy.__all__, __lazy.__dir__, __lazy.__getattr__ diff --git a/openllm-python/src/openllm/bundle/_package.py b/openllm-python/src/openllm/bundle/_package.py index 227ce50e..3a5a4579 100644 --- a/openllm-python/src/openllm/bundle/_package.py +++ b/openllm-python/src/openllm/bundle/_package.py @@ -12,14 +12,18 @@ OPENLLM_DEV_BUILD = 'OPENLLM_DEV_BUILD' _service_file = pathlib.Path(os.path.abspath(__file__)).parent.parent / '_service.py' _SERVICE_VARS = '''import 
orjson;model_id,model_tag,adapter_map,serialization,trust_remote_code,max_model_len,gpu_memory_utilization='{__model_id__}','{__model_tag__}',orjson.loads("""{__model_adapter_map__}"""),'{__model_serialization__}',{__model_trust_remote_code__},{__max_model_len__},{__gpu_memory_utilization__}''' + def build_editable(path, package='openllm'): - if not check_bool_env(OPENLLM_DEV_BUILD, default=False): return None + if not check_bool_env(OPENLLM_DEV_BUILD, default=False): + return None # We need to build the package in editable mode, so that we can import it # TODO: Upgrade to 1.0.3 from build import ProjectBuilder from build.env import IsolatedEnvBuilder + module_location = pkg.source_locations(package) - if not module_location: raise RuntimeError('Could not find the source location of OpenLLM.') + if not module_location: + raise RuntimeError('Could not find the source location of OpenLLM.') pyproject_path = pathlib.Path(module_location).parent.parent / 'pyproject.toml' if os.path.isfile(pyproject_path.__fspath__()): with IsolatedEnvBuilder() as env: @@ -29,56 +33,79 @@ def build_editable(path, package='openllm'): env.install(builder.build_system_requires) return builder.build('wheel', path, config_settings={'--global-option': '--quiet'}) raise RuntimeError('Please install OpenLLM from PyPI or built it from Git source.') + + def construct_python_options(llm, llm_fs, extra_dependencies=None, adapter_map=None): from . import RefResolver + packages = ['scipy', 'bentoml[tracing]>=1.1.10', f'openllm[vllm]>={RefResolver.from_strategy("release").version}'] # apparently bnb misses this one - if adapter_map is not None: packages += ['openllm[fine-tune]'] - if extra_dependencies is not None: packages += [f'openllm[{k}]' for k in extra_dependencies] - if llm.config['requirements'] is not None: packages.extend(llm.config['requirements']) + if adapter_map is not None: + packages += ['openllm[fine-tune]'] + if extra_dependencies is not None: + packages += [f'openllm[{k}]' for k in extra_dependencies] + if llm.config['requirements'] is not None: + packages.extend(llm.config['requirements']) built_wheels = [build_editable(llm_fs.getsyspath('/'), p) for p in ('openllm_core', 'openllm_client', 'openllm')] - return PythonOptions(packages=packages, wheels=[llm_fs.getsyspath(f"/{i.split('/')[-1]}") for i in built_wheels] if all(i for i in built_wheels) else None, lock_packages=False) + return PythonOptions( + packages=packages, + wheels=[llm_fs.getsyspath(f"/{i.split('/')[-1]}") for i in built_wheels] if all(i for i in built_wheels) else None, + lock_packages=False, + ) + + def construct_docker_options(llm, _, quantize, adapter_map, dockerfile_template, serialisation): from openllm_cli.entrypoint import process_environ - environ = process_environ( - llm.config, llm.config['timeout'], - 1.0, None, True, - llm.model_id, None, - llm._serialisation, llm, - use_current_env=False, - ) + + environ = process_environ(llm.config, llm.config['timeout'], 1.0, None, True, llm.model_id, None, llm._serialisation, llm, use_current_env=False) # XXX: We need to quote this so that the envvar in container recognize as valid json environ['OPENLLM_CONFIG'] = f"'{environ['OPENLLM_CONFIG']}'" environ.pop('BENTOML_HOME', None) # NOTE: irrelevant in container environ['NVIDIA_DRIVER_CAPABILITIES'] = 'compute,utility' return DockerOptions(cuda_version='12.1', python_version='3.11', env=environ, dockerfile_template=dockerfile_template) + + @inject def create_bento( - bento_tag, llm_fs, llm, # - quantize, dockerfile_template, # - 
adapter_map=None, extra_dependencies=None, serialisation=None, # - _bento_store=Provide[BentoMLContainer.bento_store], _model_store=Provide[BentoMLContainer.model_store], + bento_tag, + llm_fs, + llm, # + quantize, + dockerfile_template, # + adapter_map=None, + extra_dependencies=None, + serialisation=None, # + _bento_store=Provide[BentoMLContainer.bento_store], + _model_store=Provide[BentoMLContainer.model_store], ): _serialisation = openllm_core.utils.first_not_none(serialisation, default=llm.config['serialisation']) labels = dict(llm.identifying_params) - labels.update( - { - '_type': llm.llm_type, '_framework': llm.__llm_backend__, - 'start_name': llm.config['start_name'], 'base_name_or_path': llm.model_id, 'bundler': 'openllm.bundle', - **{f'{package.replace("-","_")}_version': importlib.metadata.version(package) for package in {'openllm', 'openllm-core', 'openllm-client'}}, - } - ) - if adapter_map: labels.update(adapter_map) + labels.update({ + '_type': llm.llm_type, + '_framework': llm.__llm_backend__, + 'start_name': llm.config['start_name'], + 'base_name_or_path': llm.model_id, + 'bundler': 'openllm.bundle', + **{f'{package.replace("-","_")}_version': importlib.metadata.version(package) for package in {'openllm', 'openllm-core', 'openllm-client'}}, + }) + if adapter_map: + labels.update(adapter_map) logger.debug("Building Bento '%s' with model backend '%s'", bento_tag, llm.__llm_backend__) logger.debug('Generating service vars %s (dir=%s)', llm.model_id, llm_fs.getsyspath('/')) script = f"# fmt: off\n# GENERATED BY 'openllm build {llm.model_id}'. DO NOT EDIT\n" + _SERVICE_VARS.format( - __model_id__=llm.model_id, __model_tag__=str(llm.tag), # - __model_adapter_map__=orjson.dumps(adapter_map).decode(), __model_serialization__=llm.config['serialisation'], # - __model_trust_remote_code__=str(llm.trust_remote_code), __max_model_len__ = llm._max_model_len, __gpu_memory_utilization__=llm._gpu_memory_utilization, # + __model_id__=llm.model_id, + __model_tag__=str(llm.tag), # + __model_adapter_map__=orjson.dumps(adapter_map).decode(), + __model_serialization__=llm.config['serialisation'], # + __model_trust_remote_code__=str(llm.trust_remote_code), + __max_model_len__=llm._max_model_len, + __gpu_memory_utilization__=llm._gpu_memory_utilization, # ) - if SHOW_CODEGEN: logger.info('Generated _service_vars.py:\n%s', script) + if SHOW_CODEGEN: + logger.info('Generated _service_vars.py:\n%s', script) llm_fs.writetext('_service_vars.py', script) - with open(_service_file.__fspath__(), 'r') as f: service_src = f.read() + with open(_service_file.__fspath__(), 'r') as f: + service_src = f.read() llm_fs.writetext(llm.config['service_name'], service_src) return bentoml.Bento.create( version=bento_tag.version, diff --git a/openllm-python/src/openllm/bundle/_package.pyi b/openllm-python/src/openllm/bundle/_package.pyi index eb43116c..0a6f2f38 100644 --- a/openllm-python/src/openllm/bundle/_package.pyi +++ b/openllm-python/src/openllm/bundle/_package.pyi @@ -7,21 +7,13 @@ from bentoml import Bento, Tag from bentoml._internal.bento import BentoStore from bentoml._internal.bento.build_config import DockerOptions, PythonOptions from bentoml._internal.models.model import ModelStore -from openllm_core._typing_compat import ( - LiteralQuantise, - LiteralSerialisation, - M, - T, -) +from openllm_core._typing_compat import LiteralQuantise, LiteralSerialisation, M, T from .._llm import LLM def build_editable(path: str, package: LiteralString) -> Optional[str]: ... 
def construct_python_options( - llm: LLM[M, T], - llm_fs: FS, - extra_dependencies: Optional[Tuple[str, ...]] = ..., - adapter_map: Optional[Dict[str, str]] = ..., + llm: LLM[M, T], llm_fs: FS, extra_dependencies: Optional[Tuple[str, ...]] = ..., adapter_map: Optional[Dict[str, str]] = ... ) -> PythonOptions: ... def construct_docker_options( llm: LLM[M, T], diff --git a/openllm-python/src/openllm/client.py b/openllm-python/src/openllm/client.py index 591aecc1..8c5c8fc1 100644 --- a/openllm-python/src/openllm/client.py +++ b/openllm-python/src/openllm/client.py @@ -1,2 +1,10 @@ -def __dir__(): import openllm_client as _client; return sorted(dir(_client)) -def __getattr__(it): import openllm_client as _client; return getattr(_client, it) +def __dir__(): + import openllm_client as _client + + return sorted(dir(_client)) + + +def __getattr__(it): + import openllm_client as _client + + return getattr(_client, it) diff --git a/openllm-python/src/openllm/client.pyi b/openllm-python/src/openllm/client.pyi index 9b221ce8..f26d4fa4 100644 --- a/openllm-python/src/openllm/client.pyi +++ b/openllm-python/src/openllm/client.pyi @@ -1,9 +1,9 @@ -'''OpenLLM Python client. +"""OpenLLM Python client. ```python client = openllm.client.HTTPClient("http://localhost:8080") client.query("What is the difference between gather and scatter?") ``` -''' +""" from openllm_client import AsyncHTTPClient as AsyncHTTPClient, HTTPClient as HTTPClient diff --git a/openllm-python/src/openllm/entrypoints/__init__.py b/openllm-python/src/openllm/entrypoints/__init__.py index fc64d69b..7ffc455a 100644 --- a/openllm-python/src/openllm/entrypoints/__init__.py +++ b/openllm-python/src/openllm/entrypoints/__init__.py @@ -2,10 +2,14 @@ import importlib from openllm_core.utils import LazyModule _import_structure = {'openai': [], 'hf': [], 'cohere': []} + + def mount_entrypoints(svc, llm): for module_name in _import_structure: module = importlib.import_module(f'.{module_name}', __name__) svc = module.mount_to_svc(svc, llm) return svc + + __lazy = LazyModule(__name__, globals()['__file__'], _import_structure, extra_objects={'mount_entrypoints': mount_entrypoints}) __all__, __dir__, __getattr__ = __lazy.__all__, __lazy.__dir__, __lazy.__getattr__ diff --git a/openllm-python/src/openllm/entrypoints/__init__.pyi b/openllm-python/src/openllm/entrypoints/__init__.pyi index 7ad64ccf..dbd38429 100644 --- a/openllm-python/src/openllm/entrypoints/__init__.pyi +++ b/openllm-python/src/openllm/entrypoints/__init__.pyi @@ -1,11 +1,11 @@ -'''Entrypoint for all third-party apps. +"""Entrypoint for all third-party apps. Currently support OpenAI, Cohere compatible API. 
Each module should implement the following API: - `mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Service: ...` -''' +""" from bentoml import Service from openllm_core._typing_compat import M, T diff --git a/openllm-python/src/openllm/entrypoints/_openapi.py b/openllm-python/src/openllm/entrypoints/_openapi.py index 05d5111b..86515d1d 100644 --- a/openllm-python/src/openllm/entrypoints/_openapi.py +++ b/openllm-python/src/openllm/entrypoints/_openapi.py @@ -11,7 +11,7 @@ from openllm_core.utils import first_not_none OPENAPI_VERSION, API_VERSION = '3.0.2', '1.0' # NOTE: OpenAI schema -LIST_MODELS_SCHEMA = '''\ +LIST_MODELS_SCHEMA = """\ --- consumes: - application/json @@ -41,8 +41,8 @@ responses: owned_by: 'na' schema: $ref: '#/components/schemas/ModelList' -''' -CHAT_COMPLETIONS_SCHEMA = '''\ +""" +CHAT_COMPLETIONS_SCHEMA = """\ --- consumes: - application/json @@ -181,8 +181,8 @@ responses: } } description: Bad Request -''' -COMPLETIONS_SCHEMA = '''\ +""" +COMPLETIONS_SCHEMA = """\ --- consumes: - application/json @@ -334,8 +334,8 @@ responses: } } description: Bad Request -''' -HF_AGENT_SCHEMA = '''\ +""" +HF_AGENT_SCHEMA = """\ --- consumes: - application/json @@ -379,8 +379,8 @@ responses: schema: $ref: '#/components/schemas/HFErrorResponse' description: Not Found -''' -HF_ADAPTERS_SCHEMA = '''\ +""" +HF_ADAPTERS_SCHEMA = """\ --- consumes: - application/json @@ -410,8 +410,8 @@ responses: schema: $ref: '#/components/schemas/HFErrorResponse' description: Not Found -''' -COHERE_GENERATE_SCHEMA = '''\ +""" +COHERE_GENERATE_SCHEMA = """\ --- consumes: - application/json @@ -455,8 +455,8 @@ requestBody: stop_sequences: - "\\n" - "<|endoftext|>" -''' -COHERE_CHAT_SCHEMA = '''\ +""" +COHERE_CHAT_SCHEMA = """\ --- consumes: - application/json @@ -469,7 +469,7 @@ tags: - Cohere x-bentoml-name: cohere_chat summary: Creates a model response for the given chat conversation. 
-''' +""" _SCHEMAS = {k[:-7].lower(): v for k, v in locals().items() if k.endswith('_SCHEMA')} @@ -504,11 +504,7 @@ class OpenLLMSchemaGenerator(SchemaGenerator): endpoints_info.extend(sub_endpoints) elif not isinstance(route, Route) or not route.include_in_schema: continue - elif ( - inspect.isfunction(route.endpoint) - or inspect.ismethod(route.endpoint) - or isinstance(route.endpoint, functools.partial) - ): + elif inspect.isfunction(route.endpoint) or inspect.ismethod(route.endpoint) or isinstance(route.endpoint, functools.partial): endpoint = route.endpoint.func if isinstance(route.endpoint, functools.partial) else route.endpoint path = self._remove_converter(route.path) for method in route.methods or ['GET']: @@ -555,22 +551,20 @@ def get_generator(title, components=None, tags=None, inject=True): def component_schema_generator(attr_cls, description=None): schema = {'type': 'object', 'required': [], 'properties': {}, 'title': attr_cls.__name__} - schema['description'] = first_not_none( - getattr(attr_cls, '__doc__', None), description, default=f'Generated components for {attr_cls.__name__}' - ) + schema['description'] = first_not_none(getattr(attr_cls, '__doc__', None), description, default=f'Generated components for {attr_cls.__name__}') for field in attr.fields(attr.resolve_types(attr_cls)): attr_type = field.type origin_type = t.get_origin(attr_type) args_type = t.get_args(attr_type) # Map Python types to OpenAPI schema types - if attr_type == str: + if isinstance(attr_type, str): schema_type = 'string' - elif attr_type == int: + elif isinstance(attr_type, int): schema_type = 'integer' - elif attr_type == float: + elif isinstance(attr_type, float): schema_type = 'number' - elif attr_type == bool: + elif isinstance(attr_type, bool): schema_type = 'boolean' elif origin_type is list or origin_type is tuple: schema_type = 'array' @@ -599,10 +593,7 @@ def component_schema_generator(attr_cls, description=None): _SimpleSchema = types.new_class( - '_SimpleSchema', - (object,), - {}, - lambda ns: ns.update({'__init__': lambda self, it: setattr(self, 'it', it), 'asdict': lambda self: self.it}), + '_SimpleSchema', (object,), {}, lambda ns: ns.update({'__init__': lambda self, it: setattr(self, 'it', it), 'asdict': lambda self: self.it}) ) diff --git a/openllm-python/src/openllm/entrypoints/_openapi.pyi b/openllm-python/src/openllm/entrypoints/_openapi.pyi index 4ecb9760..e5502636 100644 --- a/openllm-python/src/openllm/entrypoints/_openapi.pyi +++ b/openllm-python/src/openllm/entrypoints/_openapi.pyi @@ -17,13 +17,8 @@ class OpenLLMSchemaGenerator: def apply_schema(func: Callable[P, Any], **attrs: Any) -> Callable[P, Any]: ... def add_schema_definitions(func: Callable[P, Any]) -> Callable[P, Any]: ... -def append_schemas( - svc: Service, generated_schema: Dict[str, Any], tags_order: Literal['prepend', 'append'] = ..., inject: bool = ... -) -> Service: ... +def append_schemas(svc: Service, generated_schema: Dict[str, Any], tags_order: Literal['prepend', 'append'] = ..., inject: bool = ...) -> Service: ... def component_schema_generator(attr_cls: Type[AttrsInstance], description: Optional[str] = ...) -> Dict[str, Any]: ... def get_generator( - title: str, - components: Optional[List[Type[AttrsInstance]]] = ..., - tags: Optional[List[Dict[str, Any]]] = ..., - inject: bool = ..., + title: str, components: Optional[List[Type[AttrsInstance]]] = ..., tags: Optional[List[Dict[str, Any]]] = ..., inject: bool = ... ) -> OpenLLMSchemaGenerator: ... 
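Note on the `component_schema_generator` hunk above: after `attr.resolve_types()`, a field's `.type` is normally the annotation object itself (the `str` class, `typing.List[float]`, and so on), so an identity or equality check against `str` and an `isinstance(..., str)` check answer different questions — the latter only matches annotations that are still plain strings (unresolved forward references). The snippet below is an illustrative sketch of that distinction, not part of the patch; the `Example` class and its fields are made up.

```python
# Illustrative sketch only -- shows how resolved attrs annotations behave under
# the two kinds of checks used when mapping fields to OpenAPI schema types.
import typing as t
import attr

@attr.define
class Example:  # hypothetical class, not from the patch
  name: str
  scores: t.List[float]

for field in attr.fields(attr.resolve_types(Example)):
  ann = field.type
  print(field.name, ann is str, isinstance(ann, str))
# name True False    -> the annotation *is* the `str` type object
# scores False False -> generics go through t.get_origin()/t.get_args(), as in the patch
```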
diff --git a/openllm-python/src/openllm/entrypoints/cohere.py b/openllm-python/src/openllm/entrypoints/cohere.py index 1192e3df..ded3b2a8 100644 --- a/openllm-python/src/openllm/entrypoints/cohere.py +++ b/openllm-python/src/openllm/entrypoints/cohere.py @@ -48,17 +48,19 @@ schemas = get_generator( logger = logging.getLogger(__name__) -def jsonify_attr(obj): return json.dumps(converter.unstructure(obj)) +def jsonify_attr(obj): + return json.dumps(converter.unstructure(obj)) + def error_response(status_code, message): return JSONResponse(converter.unstructure(CohereErrorResponse(text=message)), status_code=status_code.value) + async def check_model(request, model): - if request.model is None or request.model == model: return None - return error_response( - HTTPStatus.NOT_FOUND, - f"Model '{request.model}' does not exists. Try 'GET /v1/models' to see current running models.", - ) + if request.model is None or request.model == model: + return None + return error_response(HTTPStatus.NOT_FOUND, f"Model '{request.model}' does not exists. Try 'GET /v1/models' to see current running models.") + def mount_to_svc(svc, llm): app = Starlette( @@ -74,6 +76,7 @@ def mount_to_svc(svc, llm): svc.mount_asgi_app(app, path=mount_path) return append_schemas(svc, schemas.get_schema(routes=app.routes, mount_path=mount_path), tags_order='append', inject=DEBUG) + @add_schema_definitions async def cohere_generate(req, llm): json_str = await req.body() @@ -130,18 +133,14 @@ async def cohere_generate(req, llm): if final_result is None: return error_response(HTTPStatus.BAD_REQUEST, 'No response from model.') final_result = final_result.with_options( - outputs=[ - output.with_options(text=''.join(texts[output.index]), token_ids=token_ids[output.index]) - for output in final_result.outputs - ] + outputs=[output.with_options(text=''.join(texts[output.index]), token_ids=token_ids[output.index]) for output in final_result.outputs] ) return JSONResponse( converter.unstructure( Generations( id=request_id, generations=[ - Generation(id=request_id, text=output.text, prompt=prompt, finish_reason=output.finish_reason) - for output in final_result.outputs + Generation(id=request_id, text=output.text, prompt=prompt, finish_reason=output.finish_reason) for output in final_result.outputs ], ) ), @@ -165,6 +164,7 @@ def _transpile_cohere_chat_messages(request: CohereChatRequest) -> list[dict[str messages.append({'role': 'user', 'content': request.message}) return messages + @add_schema_definitions async def cohere_chat(req, llm): json_str = await req.body() @@ -247,9 +247,7 @@ async def cohere_chat(req, llm): final_result = res if final_result is None: return error_response(HTTPStatus.BAD_REQUEST, 'No response from model.') - final_result = final_result.with_options( - outputs=[final_result.outputs[0].with_options(text=''.join(texts), token_ids=token_ids)] - ) + final_result = final_result.with_options(outputs=[final_result.outputs[0].with_options(text=''.join(texts), token_ids=token_ids)]) num_prompt_tokens, num_response_tokens = len(final_result.prompt_token_ids), len(token_ids) return JSONResponse( converter.unstructure( diff --git a/openllm-python/src/openllm/entrypoints/cohere.pyi b/openllm-python/src/openllm/entrypoints/cohere.pyi index 51f8d46f..912f1185 100644 --- a/openllm-python/src/openllm/entrypoints/cohere.pyi +++ b/openllm-python/src/openllm/entrypoints/cohere.pyi @@ -14,8 +14,6 @@ from ..protocol.cohere import CohereChatRequest, CohereGenerateRequest def mount_to_svc(svc: Service, llm: LLM[M, T]) -> Service: ... 
def jsonify_attr(obj: AttrsInstance) -> str: ... def error_response(status_code: HTTPStatus, message: str) -> JSONResponse: ... -async def check_model( - request: Union[CohereGenerateRequest, CohereChatRequest], model: str -) -> Optional[JSONResponse]: ... +async def check_model(request: Union[CohereGenerateRequest, CohereChatRequest], model: str) -> Optional[JSONResponse]: ... async def cohere_generate(req: Request, llm: LLM[M, T]) -> Response: ... async def cohere_chat(req: Request, llm: LLM[M, T]) -> Response: ... diff --git a/openllm-python/src/openllm/entrypoints/hf.py b/openllm-python/src/openllm/entrypoints/hf.py index 51f230b8..f950987c 100644 --- a/openllm-python/src/openllm/entrypoints/hf.py +++ b/openllm-python/src/openllm/entrypoints/hf.py @@ -21,6 +21,7 @@ schemas = get_generator( ) logger = logging.getLogger(__name__) + def mount_to_svc(svc, llm): app = Starlette( debug=True, @@ -34,9 +35,11 @@ def mount_to_svc(svc, llm): svc.mount_asgi_app(app, path=mount_path) return append_schemas(svc, schemas.get_schema(routes=app.routes, mount_path=mount_path), tags_order='append') + def error_response(status_code, message): return JSONResponse(converter.unstructure(HFErrorResponse(message=message, error_code=status_code.value)), status_code=status_code.value) + @add_schema_definitions async def hf_agent(req, llm): json_str = await req.body() @@ -55,9 +58,11 @@ async def hf_agent(req, llm): logger.error('Error while generating: %s', err) return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, 'Error while generating (Check server log).') + @add_schema_definitions def hf_adapters(req, llm): - if not llm.has_adapters: return error_response(HTTPStatus.NOT_FOUND, 'No adapters found.') + if not llm.has_adapters: + return error_response(HTTPStatus.NOT_FOUND, 'No adapters found.') return JSONResponse( { adapter_tuple[1]: {'adapter_name': k, 'adapter_type': adapter_tuple[0].peft_type.value} diff --git a/openllm-python/src/openllm/entrypoints/openai.py b/openllm-python/src/openllm/entrypoints/openai.py index 04f203a9..72f7d350 100644 --- a/openllm-python/src/openllm/entrypoints/openai.py +++ b/openllm-python/src/openllm/entrypoints/openai.py @@ -55,7 +55,9 @@ schemas = get_generator( logger = logging.getLogger(__name__) -def jsonify_attr(obj): return orjson.dumps(converter.unstructure(obj)).decode() +def jsonify_attr(obj): + return orjson.dumps(converter.unstructure(obj)).decode() + def error_response(status_code, message): return JSONResponse( @@ -63,8 +65,10 @@ def error_response(status_code, message): status_code=status_code.value, ) + async def check_model(request, model): - if request.model == model: return None + if request.model == model: + return None return error_response( HTTPStatus.NOT_FOUND, f"Model '{request.model}' does not exists. Try 'GET /v1/models' to see available models.\nTip: If you are migrating from OpenAI, make sure to update your 'model' parameters in the request.", @@ -75,11 +79,13 @@ def create_logprobs(token_ids, top_logprobs, num_output_top_logprobs=None, initi # Create OpenAI-style logprobs. 
logprobs = LogProbs() last_token_len = 0 - if num_output_top_logprobs: logprobs.top_logprobs = [] + if num_output_top_logprobs: + logprobs.top_logprobs = [] for i, token_id in enumerate(token_ids): step_top_logprobs = top_logprobs[i] token_logprob = None - if step_top_logprobs is not None: token_logprob = step_top_logprobs[token_id] + if step_top_logprobs is not None: + token_logprob = step_top_logprobs[token_id] token = llm.tokenizer.convert_ids_to_tokens(token_id) logprobs.tokens.append(token) logprobs.token_logprobs.append(token_logprob) @@ -100,23 +106,20 @@ def mount_to_svc(svc, llm): app = Starlette( debug=True, routes=[ - Route( - '/models', - functools.partial(apply_schema(list_models, __model_id__=llm.llm_type), llm=llm), - methods=['GET'] - ), - Route( - '/completions', - functools.partial(apply_schema(completions, __model_id__=llm.llm_type), llm=llm), - methods=['POST'], - ), + Route('/models', functools.partial(apply_schema(list_models, __model_id__=llm.llm_type), llm=llm), methods=['GET']), + Route('/completions', functools.partial(apply_schema(completions, __model_id__=llm.llm_type), llm=llm), methods=['POST']), Route( '/chat/completions', - functools.partial(apply_schema(chat_completions, - __model_id__=llm.llm_type, - __chat_template__=orjson.dumps(llm.config.chat_template).decode(), - __chat_messages__=orjson.dumps(llm.config.chat_messages).decode(), - __add_generation_prompt__=str(True) if llm.config.chat_messages is not None else str(False)), llm=llm), + functools.partial( + apply_schema( + chat_completions, + __model_id__=llm.llm_type, + __chat_template__=orjson.dumps(llm.config.chat_template).decode(), + __chat_messages__=orjson.dumps(llm.config.chat_messages).decode(), + __add_generation_prompt__=str(True) if llm.config.chat_messages is not None else str(False), + ), + llm=llm, + ), methods=['POST'], ), Route('/schema', endpoint=lambda req: schemas.OpenAPIResponse(req), include_in_schema=False), @@ -128,7 +131,9 @@ def mount_to_svc(svc, llm): # GET /v1/models @add_schema_definitions -def list_models(_, llm): return JSONResponse(converter.unstructure(ModelList(data=[ModelCard(id=llm.llm_type)])), status_code=HTTPStatus.OK.value) +def list_models(_, llm): + return JSONResponse(converter.unstructure(ModelList(data=[ModelCard(id=llm.llm_type)])), status_code=HTTPStatus.OK.value) + # POST /v1/chat/completions @add_schema_definitions @@ -138,26 +143,36 @@ async def chat_completions(req, llm): try: request = converter.structure(orjson.loads(json_str), ChatCompletionRequest) except orjson.JSONDecodeError as err: - logger.debug('Sent body: %s', json_str); logger.error('Invalid JSON input received: %s', err) + logger.debug('Sent body: %s', json_str) + logger.error('Invalid JSON input received: %s', err) return error_response(HTTPStatus.BAD_REQUEST, 'Invalid JSON input received (Check server log).') logger.debug('Received chat completion request: %s', request) err_check = await check_model(request, llm.llm_type) - if err_check is not None: return err_check + if err_check is not None: + return err_check - if request.logit_bias is not None and len(request.logit_bias) > 0: return error_response(HTTPStatus.BAD_REQUEST, "'logit_bias' is not yet supported.") + if request.logit_bias is not None and len(request.logit_bias) > 0: + return error_response(HTTPStatus.BAD_REQUEST, "'logit_bias' is not yet supported.") model_name, request_id = request.model, gen_random_uuid('chatcmpl') created_time = int(time.monotonic()) - prompt = llm.tokenizer.apply_chat_template(request.messages, 
tokenize=False, chat_template=request.chat_template if request.chat_template != 'None' else None, add_generation_prompt=request.add_generation_prompt) + prompt = llm.tokenizer.apply_chat_template( + request.messages, + tokenize=False, + chat_template=request.chat_template if request.chat_template != 'None' else None, + add_generation_prompt=request.add_generation_prompt, + ) logger.debug('Prompt: %r', prompt) config = llm.config.compatible_options(request) - def get_role() -> str: return request.messages[-1]['role'] if not request.add_generation_prompt else 'assistant' # TODO: Support custom role here. + def get_role() -> str: + return request.messages[-1]['role'] if not request.add_generation_prompt else 'assistant' # TODO: Support custom role here. try: result_generator = llm.generate_iterator(prompt, request_id=request_id, **config) except Exception as err: - traceback.print_exc(); logger.error('Error generating completion: %s', err) + traceback.print_exc() + logger.error('Error generating completion: %s', err) return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)') def create_stream_response_json(index, text, finish_reason=None, usage=None): @@ -167,25 +182,30 @@ async def chat_completions(req, llm): model=model_name, choices=[ChatCompletionResponseStreamChoice(index=index, delta=Delta(content=text), finish_reason=finish_reason)], ) - if usage is not None: response.usage = usage + if usage is not None: + response.usage = usage return jsonify_attr(response) async def completion_stream_generator(): # first chunk with role role = get_role() - for i in range(config['n']): yield f'data: {jsonify_attr(ChatCompletionStreamResponse(id=request_id, created=created_time, choices=[ChatCompletionResponseStreamChoice(index=i, delta=Delta(role=role), finish_reason=None)], model=model_name))}\n\n' + for i in range(config['n']): + yield f'data: {jsonify_attr(ChatCompletionStreamResponse(id=request_id, created=created_time, choices=[ChatCompletionResponseStreamChoice(index=i, delta=Delta(role=role), finish_reason=None)], model=model_name))}\n\n' if request.echo: last_message, last_content = request.messages[-1], '' - if last_message.get('content') and last_message.get('role') == role: last_content = last_message['content'] + if last_message.get('content') and last_message.get('role') == role: + last_content = last_message['content'] if last_content: - for i in range(config['n']): yield f'data: {jsonify_attr(ChatCompletionStreamResponse(id=request_id, created=created_time, choices=[ChatCompletionResponseStreamChoice(index=i, delta=Delta(content=last_content), finish_reason=None)], model=model_name))}\n\n' + for i in range(config['n']): + yield f'data: {jsonify_attr(ChatCompletionStreamResponse(id=request_id, created=created_time, choices=[ChatCompletionResponseStreamChoice(index=i, delta=Delta(content=last_content), finish_reason=None)], model=model_name))}\n\n' previous_num_tokens = [0] * config['n'] finish_reason_sent = [False] * config['n'] async for res in result_generator: for output in res.outputs: - if finish_reason_sent[output.index]: continue + if finish_reason_sent[output.index]: + continue yield f'data: {create_stream_response_json(output.index, output.text)}\n\n' previous_num_tokens[output.index] += len(output.token_ids) if output.finish_reason is not None: @@ -197,35 +217,32 @@ async def chat_completions(req, llm): try: # Streaming case - if request.stream: return StreamingResponse(completion_stream_generator(), media_type='text/event-stream') + if 
request.stream: + return StreamingResponse(completion_stream_generator(), media_type='text/event-stream') # Non-streaming case final_result, texts, token_ids = None, [[]] * config['n'], [[]] * config['n'] async for res in result_generator: - if await req.is_disconnected(): return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.') + if await req.is_disconnected(): + return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.') for output in res.outputs: texts[output.index].append(output.text) token_ids[output.index].extend(output.token_ids) final_result = res - if final_result is None: return error_response(HTTPStatus.BAD_REQUEST, 'No response from model.') + if final_result is None: + return error_response(HTTPStatus.BAD_REQUEST, 'No response from model.') final_result = final_result.with_options( - outputs=[ - output.with_options(text=''.join(texts[output.index]), token_ids=token_ids[output.index]) - for output in final_result.outputs - ] + outputs=[output.with_options(text=''.join(texts[output.index]), token_ids=token_ids[output.index]) for output in final_result.outputs] ) role = get_role() choices = [ - ChatCompletionResponseChoice( - index=output.index, - message=ChatMessage(role=role, content=output.text), - finish_reason=output.finish_reason, - ) + ChatCompletionResponseChoice(index=output.index, message=ChatMessage(role=role, content=output.text), finish_reason=output.finish_reason) for output in final_result.outputs ] if request.echo: last_message, last_content = request.messages[-1], '' - if last_message.get('content') and last_message.get('role') == role: last_content = last_message['content'] + if last_message.get('content') and last_message.get('role') == role: + last_content = last_message['content'] for choice in choices: full_message = last_content + choice.message.content choice.message.content = full_message @@ -236,7 +253,8 @@ async def chat_completions(req, llm): response = ChatCompletionResponse(id=request_id, created=created_time, model=model_name, usage=usage, choices=choices) return JSONResponse(converter.unstructure(response), status_code=HTTPStatus.OK.value) except Exception as err: - traceback.print_exc(); logger.error('Error generating completion: %s', err) + traceback.print_exc() + logger.error('Error generating completion: %s', err) return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)') @@ -248,20 +266,26 @@ async def completions(req, llm): try: request = converter.structure(orjson.loads(json_str), CompletionRequest) except orjson.JSONDecodeError as err: - logger.debug('Sent body: %s', json_str); logger.error('Invalid JSON input received: %s', err) + logger.debug('Sent body: %s', json_str) + logger.error('Invalid JSON input received: %s', err) return error_response(HTTPStatus.BAD_REQUEST, 'Invalid JSON input received (Check server log).') logger.debug('Received legacy completion request: %s', request) err_check = await check_model(request, llm.llm_type) - if err_check is not None: return err_check + if err_check is not None: + return err_check # OpenAI API supports echoing the prompt when max_tokens is 0. echo_without_generation = request.echo and request.max_tokens == 0 - if echo_without_generation: request.max_tokens = 1 # XXX: Hack to make sure we get the prompt back. + if echo_without_generation: + request.max_tokens = 1 # XXX: Hack to make sure we get the prompt back. 
- if request.suffix is not None: return error_response(HTTPStatus.BAD_REQUEST, "'suffix' is not yet supported.") - if request.logit_bias is not None and len(request.logit_bias) > 0: return error_response(HTTPStatus.BAD_REQUEST, "'logit_bias' is not yet supported.") + if request.suffix is not None: + return error_response(HTTPStatus.BAD_REQUEST, "'suffix' is not yet supported.") + if request.logit_bias is not None and len(request.logit_bias) > 0: + return error_response(HTTPStatus.BAD_REQUEST, "'logit_bias' is not yet supported.") - if not request.prompt: return error_response(HTTPStatus.BAD_REQUEST, 'Please provide a prompt.') + if not request.prompt: + return error_response(HTTPStatus.BAD_REQUEST, 'Please provide a prompt.') prompt = request.prompt # TODO: Support multiple prompts @@ -272,7 +296,8 @@ async def completions(req, llm): try: result_generator = llm.generate_iterator(prompt, request_id=request_id, **config) except Exception as err: - traceback.print_exc(); logger.error('Error generating completion: %s', err) + traceback.print_exc() + logger.error('Error generating completion: %s', err) return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)') # best_of != n then we don't stream @@ -286,7 +311,8 @@ async def completions(req, llm): model=model_name, choices=[CompletionResponseStreamChoice(index=index, text=text, logprobs=logprobs, finish_reason=finish_reason)], ) - if usage: response.usage = usage + if usage: + response.usage = usage return jsonify_attr(response) async def completion_stream_generator(): @@ -301,7 +327,7 @@ async def completions(req, llm): logprobs = None top_logprobs = None if request.logprobs is not None: - top_logprobs = output.logprobs[previous_num_tokens[i]:] + top_logprobs = output.logprobs[previous_num_tokens[i] :] if request.echo and not previous_echo[i]: if not echo_without_generation: @@ -316,7 +342,7 @@ async def completions(req, llm): top_logprobs = res.prompt_logprobs previous_echo[i] = True if request.logprobs is not None: - logprobs = create_logprobs(output.token_ids, output.logprobs[previous_num_tokens[i]:], request.logprobs, len(previous_texts[i]), llm=llm) + logprobs = create_logprobs(output.token_ids, output.logprobs[previous_num_tokens[i] :], request.logprobs, len(previous_texts[i]), llm=llm) previous_num_tokens[i] += len(output.token_ids) previous_texts[i] += output.text yield f'data: {create_stream_response_json(index=i, text=output.text, logprobs=logprobs, finish_reason=output.finish_reason)}\n\n' @@ -329,21 +355,21 @@ async def completions(req, llm): try: # Streaming case - if stream: return StreamingResponse(completion_stream_generator(), media_type='text/event-stream') + if stream: + return StreamingResponse(completion_stream_generator(), media_type='text/event-stream') # Non-streaming case final_result, texts, token_ids = None, [[]] * config['n'], [[]] * config['n'] async for res in result_generator: - if await req.is_disconnected(): return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.') + if await req.is_disconnected(): + return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.') for output in res.outputs: texts[output.index].append(output.text) token_ids[output.index].extend(output.token_ids) final_result = res - if final_result is None: return error_response(HTTPStatus.BAD_REQUEST, 'No response from model.') + if final_result is None: + return error_response(HTTPStatus.BAD_REQUEST, 'No response from model.') final_result = final_result.with_options( - outputs=[ - 
output.with_options(text=''.join(texts[output.index]), token_ids=token_ids[output.index]) - for output in final_result.outputs - ] + outputs=[output.with_options(text=''.join(texts[output.index]), token_ids=token_ids[output.index]) for output in final_result.outputs] ) choices = [] @@ -355,13 +381,15 @@ async def completions(req, llm): if request.logprobs is not None: if not echo_without_generation: token_ids, top_logprobs = output.token_ids, output.logprobs - if request.echo: token_ids, top_logprobs = prompt_token_ids + token_ids, prompt_logprobs + top_logprobs + if request.echo: + token_ids, top_logprobs = prompt_token_ids + token_ids, prompt_logprobs + top_logprobs else: token_ids, top_logprobs = prompt_token_ids, prompt_logprobs logprobs = create_logprobs(token_ids, top_logprobs, request.logprobs, llm=llm) if not echo_without_generation: output_text = output.text - if request.echo: output_text = prompt_text + output_text + if request.echo: + output_text = prompt_text + output_text else: output_text = prompt_text choice_data = CompletionResponseChoice(index=output.index, text=output_text, logprobs=logprobs, finish_reason=output.finish_reason) diff --git a/openllm-python/src/openllm/entrypoints/openai.pyi b/openllm-python/src/openllm/entrypoints/openai.pyi index 110606bf..829728ba 100644 --- a/openllm-python/src/openllm/entrypoints/openai.pyi +++ b/openllm-python/src/openllm/entrypoints/openai.pyi @@ -14,12 +14,14 @@ from ..protocol.openai import ChatCompletionRequest, CompletionRequest, LogProbs def mount_to_svc(svc: Service, llm: LLM[M, T]) -> Service: ... def jsonify_attr(obj: AttrsInstance) -> str: ... def error_response(status_code: HTTPStatus, message: str) -> JSONResponse: ... -async def check_model( - request: Union[CompletionRequest, ChatCompletionRequest], model: str -) -> Optional[JSONResponse]: ... +async def check_model(request: Union[CompletionRequest, ChatCompletionRequest], model: str) -> Optional[JSONResponse]: ... def create_logprobs( - token_ids: List[int], top_logprobs: List[Dict[int, float]], # - num_output_top_logprobs: Optional[int] = ..., initial_text_offset: int = ..., *, llm: LLM[M, T] + token_ids: List[int], + top_logprobs: List[Dict[int, float]], # + num_output_top_logprobs: Optional[int] = ..., + initial_text_offset: int = ..., + *, + llm: LLM[M, T], ) -> LogProbs: ... def list_models(req: Request, llm: LLM[M, T]) -> Response: ... async def chat_completions(req: Request, llm: LLM[M, T]) -> Response: ... 
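For context on the endpoints being reformatted above: `completion_stream_generator` emits plain server-sent events, one `data: {json}\n\n` chunk per `ChatCompletionStreamResponse`, with token usage attached to the final chunk. The sketch below shows one way a client might consume that stream. It is illustrative only — the address (`http://localhost:3000`), the `/v1/chat/completions` route, and the model id are assumptions, not values taken from this diff.

```python
# Minimal SSE consumer for the streaming chat-completions path shown above.
# Assumed: server at localhost:3000, OpenAI-compatible route, hypothetical model id.
import json

import requests

payload = {
    'model': 'my-model',  # hypothetical; must match what the server reports via check_model
    'messages': [{'role': 'user', 'content': 'What is a Bento?'}],
    'stream': True,
}
with requests.post('http://localhost:3000/v1/chat/completions', json=payload, stream=True) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith('data: '):
            continue  # skip the blank separator lines between SSE events
        data = line[len('data: '):]
        if data.strip() == '[DONE]':  # some OpenAI-compatible servers close with a sentinel
            break
        chunk = json.loads(data)
        delta = chunk['choices'][0]['delta']
        print(delta.get('content') or '', end='', flush=True)  # the first chunk carries only the role
```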
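The legacy `completions` handler follows the same pattern without the chat template: `echo` splices the prompt text (and, when requested, its log-probabilities) back into each choice, and `create_logprobs` packages the per-token values. A non-streaming sketch, with the same caveat that the address, route, and model id are assumptions:

```python
# Non-streaming sketch of the legacy completions path, echoing the prompt back
# alongside the generation. Address, route, and model id are assumptions.
import requests

payload = {
    'model': 'my-model',           # hypothetical model id
    'prompt': 'def fibonacci(n):',
    'max_tokens': 32,
    'echo': True,                  # server prepends the prompt text to each choice
    'logprobs': 1,                 # populates choice['logprobs'] via create_logprobs()
}
resp = requests.post('http://localhost:3000/v1/completions', json=payload, timeout=60)
resp.raise_for_status()
body = resp.json()
choice = body['choices'][0]
print(choice['text'])              # prompt + completion, because echo=True
print(body.get('usage'))           # token counts, mirroring the chat response above
```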
diff --git a/openllm-python/src/openllm/exceptions.py b/openllm-python/src/openllm/exceptions.py index 3422fe58..a2c65c38 100644 --- a/openllm-python/src/openllm/exceptions.py +++ b/openllm-python/src/openllm/exceptions.py @@ -1,7 +1,10 @@ from openllm_core.exceptions import ( - Error as Error, FineTuneStrategyNotSupportedError as FineTuneStrategyNotSupportedError, # - ForbiddenAttributeError as ForbiddenAttributeError, GpuNotAvailableError as GpuNotAvailableError, # - OpenLLMException as OpenLLMException, ValidationError as ValidationError, # + Error as Error, + FineTuneStrategyNotSupportedError as FineTuneStrategyNotSupportedError, # + ForbiddenAttributeError as ForbiddenAttributeError, + GpuNotAvailableError as GpuNotAvailableError, # + OpenLLMException as OpenLLMException, + ValidationError as ValidationError, # MissingAnnotationAttributeError as MissingAnnotationAttributeError, MissingDependencyError as MissingDependencyError, ) diff --git a/openllm-python/src/openllm/models/__init__.py b/openllm-python/src/openllm/models/__init__.py index 945d77b7..d815b197 100644 --- a/openllm-python/src/openllm/models/__init__.py +++ b/openllm-python/src/openllm/models/__init__.py @@ -1,3 +1,5 @@ from __future__ import annotations -import openllm, transformers +import openllm, transformers, typing as t + + def load_model(llm: openllm.LLM, config: transformers.PretrainedConfig, **attrs: t.Any): ... diff --git a/openllm-python/src/openllm/protocol/__init__.py b/openllm-python/src/openllm/protocol/__init__.py index 78c9cbaa..8b715b52 100644 --- a/openllm-python/src/openllm/protocol/__init__.py +++ b/openllm-python/src/openllm/protocol/__init__.py @@ -5,6 +5,7 @@ import typing as t from openllm_core.utils import LazyModule _import_structure: dict[str, list[str]] = {'openai': [], 'cohere': [], 'hf': []} -if t.TYPE_CHECKING: from . import cohere as cohere, hf as hf, openai as openai +if t.TYPE_CHECKING: + from . 
import cohere as cohere, hf as hf, openai as openai __lazy = LazyModule(__name__, os.path.abspath('__file__'), _import_structure) __all__, __dir__, __getattr__ = __lazy.__all__, __lazy.__dir__, __lazy.__getattr__ diff --git a/openllm-python/src/openllm/protocol/openai.py b/openllm-python/src/openllm/protocol/openai.py index 390f437f..4a2232b0 100644 --- a/openllm-python/src/openllm/protocol/openai.py +++ b/openllm-python/src/openllm/protocol/openai.py @@ -17,10 +17,13 @@ class ErrorResponse: param: t.Optional[str] = None code: t.Optional[str] = None + def _stop_converter(data: t.Union[str, t.List[str]]) -> t.List[str]: - if not data: return None + if not data: + return None return [data] if isinstance(data, str) else data + @attr.define class CompletionRequest: prompt: str diff --git a/openllm-python/src/openllm/serialisation/__init__.py b/openllm-python/src/openllm/serialisation/__init__.py index 0960b6a6..89586d7e 100644 --- a/openllm-python/src/openllm/serialisation/__init__.py +++ b/openllm-python/src/openllm/serialisation/__init__.py @@ -3,10 +3,13 @@ import importlib, typing as t from openllm_core._typing_compat import M, ParamSpec, T, TypeGuard, Concatenate from openllm_core.exceptions import OpenLLMException -if t.TYPE_CHECKING: from bentoml import Model; from .._llm import LLM +if t.TYPE_CHECKING: + from bentoml import Model + from .._llm import LLM P = ParamSpec('P') + def load_tokenizer(llm: LLM[M, T], **tokenizer_attrs: t.Any) -> TypeGuard[T]: import cloudpickle, fs, transformers from bentoml._internal.models.model import CUSTOM_OBJECTS_FILENAME @@ -38,16 +41,17 @@ def load_tokenizer(llm: LLM[M, T], **tokenizer_attrs: t.Any) -> TypeGuard[T]: tokenizer.pad_token_id = tokenizer.eos_token_id return tokenizer + def _make_dispatch_function(fn: str) -> t.Callable[Concatenate[LLM[M, T], P], TypeGuard[M | T | Model]]: def caller(llm: LLM[M, T], *args: P.args, **kwargs: P.kwargs) -> TypeGuard[M | T | Model]: - '''Generic function dispatch to correct serialisation submodules based on LLM runtime. + """Generic function dispatch to correct serialisation submodules based on LLM runtime. > [!NOTE] See 'openllm.serialisation.transformers' if 'llm.__llm_backend__ in ("pt", "vllm")' > [!NOTE] See 'openllm.serialisation.ggml' if 'llm.__llm_backend__="ggml"' > [!NOTE] See 'openllm.serialisation.ctranslate' if 'llm.__llm_backend__="ctranslate"' - ''' + """ if llm.__llm_backend__ == 'ggml': serde = 'ggml' elif llm.__llm_backend__ == 'ctranslate': @@ -57,12 +61,19 @@ def _make_dispatch_function(fn: str) -> t.Callable[Concatenate[LLM[M, T], P], Ty else: raise OpenLLMException(f'Not supported backend {llm.__llm_backend__}') return getattr(importlib.import_module(f'.{serde}', 'openllm.serialisation'), fn)(llm, *args, **kwargs) + return caller + _extras = ['get', 'import_model', 'load_model'] _import_structure = {'ggml', 'transformers', 'ctranslate', 'constants'} __all__ = ['load_tokenizer', *_extras, *_import_structure] -def __dir__() -> t.Sequence[str]: return sorted(__all__) + + +def __dir__() -> t.Sequence[str]: + return sorted(__all__) + + def __getattr__(name: str) -> t.Any: if name == 'load_tokenizer': return load_tokenizer diff --git a/openllm-python/src/openllm/serialisation/__init__.pyi b/openllm-python/src/openllm/serialisation/__init__.pyi index e8ea2b91..069294b4 100644 --- a/openllm-python/src/openllm/serialisation/__init__.pyi +++ b/openllm-python/src/openllm/serialisation/__init__.pyi @@ -1,9 +1,10 @@ -'''Serialisation utilities for OpenLLM. +"""Serialisation utilities for OpenLLM. 
Currently supports transformers for PyTorch, and vLLM. Currently, GGML format is working in progress. -''' +""" + from typing import Any from bentoml import Model from openllm import LLM @@ -11,11 +12,12 @@ from openllm_core._typing_compat import M, T from . import constants as constants, ggml as ggml, transformers as transformers def load_tokenizer(llm: LLM[M, T], **attrs: Any) -> T: - '''Load the tokenizer from BentoML store. + """Load the tokenizer from BentoML store. By default, it will try to find the bentomodel whether it is in store.. If model is not found, it will raises a ``bentoml.exceptions.NotFound``. - ''' + """ + def get(llm: LLM[M, T]) -> Model: ... def import_model(llm: LLM[M, T], *args: Any, trust_remote_code: bool, **attrs: Any) -> Model: ... def load_model(llm: LLM[M, T], *args: Any, **attrs: Any) -> M: ... diff --git a/openllm-python/src/openllm/serialisation/_helpers.py b/openllm-python/src/openllm/serialisation/_helpers.py index ea364b07..c155a844 100644 --- a/openllm-python/src/openllm/serialisation/_helpers.py +++ b/openllm-python/src/openllm/serialisation/_helpers.py @@ -8,61 +8,82 @@ from openllm_core.utils import is_autogptq_available _object_setattr = object.__setattr__ + def get_hash(config) -> str: _commit_hash = getattr(config, '_commit_hash', None) - if _commit_hash is None: raise ValueError(f'Cannot find commit hash in {config}') + if _commit_hash is None: + raise ValueError(f'Cannot find commit hash in {config}') return _commit_hash + def patch_correct_tag(llm, config, _revision=None) -> None: # NOTE: The following won't hit during local since we generated a correct version based on local path hash It will only hit if we use model from HF Hub - if llm.revision is not None: return + if llm.revision is not None: + return if not llm.local: try: - if _revision is None: _revision = get_hash(config) + if _revision is None: + _revision = get_hash(config) except ValueError: pass - if _revision is None and llm.tag.version is not None: _revision = llm.tag.version + if _revision is None and llm.tag.version is not None: + _revision = llm.tag.version if llm.tag.version is None: - _object_setattr(llm, '_tag', attr.evolve(llm.tag, version=_revision)) # HACK: This copies the correct revision into llm.tag - if llm._revision is None: _object_setattr(llm, '_revision', _revision) # HACK: This copies the correct revision into llm._model_version + _object_setattr(llm, '_tag', attr.evolve(llm.tag, version=_revision)) # HACK: This copies the correct revision into llm.tag + if llm._revision is None: + _object_setattr(llm, '_revision', _revision) # HACK: This copies the correct revision into llm._model_version + def _create_metadata(llm, config, safe_serialisation, trust_remote_code, metadata=None): - if metadata is None: metadata = {} + if metadata is None: + metadata = {} metadata.update({'safe_serialisation': safe_serialisation, '_framework': llm.__llm_backend__}) - if llm.quantise: metadata['_quantize'] = llm.quantise + if llm.quantise: + metadata['_quantize'] = llm.quantise architectures = getattr(config, 'architectures', []) if not architectures: if trust_remote_code: auto_map = getattr(config, 'auto_map', {}) - if not auto_map: raise RuntimeError(f'Failed to determine the architecture from both `auto_map` and `architectures` from {llm.model_id}') + if not auto_map: + raise RuntimeError(f'Failed to determine the architecture from both `auto_map` and `architectures` from {llm.model_id}') autoclass = 'AutoModelForSeq2SeqLM' if llm.config['model_type'] == 'seq2seq_lm' else 
'AutoModelForCausalLM' if autoclass not in auto_map: - raise RuntimeError(f"Given model '{llm.model_id}' is yet to be supported with 'auto_map'. OpenLLM currently only support encoder-decoders or decoders only models.") + raise RuntimeError( + f"Given model '{llm.model_id}' is yet to be supported with 'auto_map'. OpenLLM currently only support encoder-decoders or decoders only models." + ) architectures = [auto_map[autoclass]] else: - raise RuntimeError('Failed to determine the architecture for this model. Make sure the `config.json` is valid and can be loaded with `transformers.AutoConfig`') + raise RuntimeError( + 'Failed to determine the architecture for this model. Make sure the `config.json` is valid and can be loaded with `transformers.AutoConfig`' + ) metadata.update({'_pretrained_class': architectures[0], '_revision': get_hash(config) if not llm.local else llm.revision}) return metadata + def _create_signatures(llm, signatures=None): - if signatures is None: signatures = {} + if signatures is None: + signatures = {} if llm.__llm_backend__ == 'pt': if llm.quantise == 'gptq': if not is_autogptq_available(): raise OpenLLMException("Requires 'auto-gptq' and 'optimum'. Install it with 'pip install \"openllm[gptq]\"'") signatures['generate'] = {'batchable': False} else: - signatures.update( - { - k: ModelSignature(batchable=False) - for k in ( - '__call__', 'forward', 'generate', # - 'contrastive_search', 'greedy_search', # - 'sample', 'beam_search', 'beam_sample', # - 'group_beam_search', 'constrained_beam_search', # - ) - } - ) + signatures.update({ + k: ModelSignature(batchable=False) + for k in ( + '__call__', + 'forward', + 'generate', # + 'contrastive_search', + 'greedy_search', # + 'sample', + 'beam_search', + 'beam_sample', # + 'group_beam_search', + 'constrained_beam_search', # + ) + }) elif llm.__llm_backend__ == 'ctranslate': if llm.config['model_type'] == 'seq2seq_lm': non_batch_keys = {'score_file', 'translate_file'} @@ -70,24 +91,37 @@ def _create_signatures(llm, signatures=None): else: non_batch_keys = set() batch_keys = { - 'async_generate_tokens', 'forward_batch', 'generate_batch', # - 'generate_iterable', 'generate_tokens', 'score_batch', 'score_iterable', # + 'async_generate_tokens', + 'forward_batch', + 'generate_batch', # + 'generate_iterable', + 'generate_tokens', + 'score_batch', + 'score_iterable', # } signatures.update({k: ModelSignature(batchable=False) for k in non_batch_keys}) signatures.update({k: ModelSignature(batchable=True) for k in batch_keys}) return signatures + @inject @contextlib.contextmanager def save_model( - llm, config, safe_serialisation, # - trust_remote_code, module, external_modules, # - _model_store=Provide[BentoMLContainer.model_store], _api_version='v2.1.0', # + llm, + config, + safe_serialisation, # + trust_remote_code, + module, + external_modules, # + _model_store=Provide[BentoMLContainer.model_store], + _api_version='v2.1.0', # ): imported_modules = [] bentomodel = bentoml.Model.create( - llm.tag, module=f'openllm.serialisation.{module}', # - api_version=_api_version, options=ModelOptions(), # + llm.tag, + module=f'openllm.serialisation.{module}', # + api_version=_api_version, + options=ModelOptions(), # context=openllm.utils.generate_context('openllm'), labels=openllm.utils.generate_labels(llm), metadata=_create_metadata(llm, config, safe_serialisation, trust_remote_code), @@ -103,9 +137,7 @@ def save_model( bentomodel.flush() bentomodel.save(_model_store) openllm.utils.analytics.track( - openllm.utils.analytics.ModelSaveEvent( - 
module=bentomodel.info.module, model_size_in_kb=openllm.utils.calc_dir_size(bentomodel.path) / 1024 - ) + openllm.utils.analytics.ModelSaveEvent(module=bentomodel.info.module, model_size_in_kb=openllm.utils.calc_dir_size(bentomodel.path) / 1024) ) finally: bentomodel.exit_cloudpickle_context(imported_modules) diff --git a/openllm-python/src/openllm/serialisation/_helpers.pyi b/openllm-python/src/openllm/serialisation/_helpers.pyi index 7f2628a1..4516599d 100644 --- a/openllm-python/src/openllm/serialisation/_helpers.pyi +++ b/openllm-python/src/openllm/serialisation/_helpers.pyi @@ -10,9 +10,7 @@ from openllm_core._typing_compat import M, T from .._llm import LLM def get_hash(config: transformers.PretrainedConfig) -> str: ... -def patch_correct_tag( - llm: LLM[M, T], config: transformers.PretrainedConfig, _revision: Optional[str] = ... -) -> None: ... +def patch_correct_tag(llm: LLM[M, T], config: transformers.PretrainedConfig, _revision: Optional[str] = ...) -> None: ... @contextmanager def save_model( llm: LLM[M, T], diff --git a/openllm-python/src/openllm/serialisation/constants.py b/openllm-python/src/openllm/serialisation/constants.py index 88614513..17b87eef 100644 --- a/openllm-python/src/openllm/serialisation/constants.py +++ b/openllm-python/src/openllm/serialisation/constants.py @@ -1,7 +1,13 @@ HUB_ATTRS = [ - 'cache_dir', 'code_revision', 'force_download', # - 'local_files_only', 'proxies', 'resume_download', # - 'revision', 'subfolder', 'use_auth_token', # + 'cache_dir', + 'code_revision', + 'force_download', # + 'local_files_only', + 'proxies', + 'resume_download', # + 'revision', + 'subfolder', + 'use_auth_token', # ] CONFIG_FILE_NAME = 'config.json' # the below is similar to peft.utils.other.CONFIG_NAME diff --git a/openllm-python/src/openllm/serialisation/ctranslate/__init__.py b/openllm-python/src/openllm/serialisation/ctranslate/__init__.py index 471e4418..b4a86634 100644 --- a/openllm-python/src/openllm/serialisation/ctranslate/__init__.py +++ b/openllm-python/src/openllm/serialisation/ctranslate/__init__.py @@ -12,9 +12,7 @@ from .._helpers import patch_correct_tag, save_model from ..transformers._helpers import get_tokenizer, process_config if not is_ctranslate_available(): - raise RuntimeError( - "'ctranslate2' is required to use with backend 'ctranslate'. Install it with 'pip install \"openllm[ctranslate]\"'" - ) + raise RuntimeError("'ctranslate2' is required to use with backend 'ctranslate'. 
Install it with 'pip install \"openllm[ctranslate]\"'") import ctranslate2 from ctranslate2.converters.transformers import TransformersConverter @@ -44,17 +42,11 @@ def import_model(llm, *decls, trust_remote_code, **attrs): config, hub_attrs, attrs = process_config(llm.model_id, trust_remote_code, **attrs) patch_correct_tag(llm, config) tokenizer = get_tokenizer(llm.model_id, trust_remote_code=trust_remote_code, **hub_attrs, **tokenizer_attrs) - with save_model( - llm, config, False, trust_remote_code, 'ctranslate', [importlib.import_module(tokenizer.__module__)] - ) as save_metadata: + with save_model(llm, config, False, trust_remote_code, 'ctranslate', [importlib.import_module(tokenizer.__module__)]) as save_metadata: bentomodel, _ = save_metadata if llm._local: shutil.copytree( - llm.model_id, - bentomodel.path, - symlinks=False, - ignore=shutil.ignore_patterns('.git', 'venv', '__pycache__', '.venv'), - dirs_exist_ok=True, + llm.model_id, bentomodel.path, symlinks=False, ignore=shutil.ignore_patterns('.git', 'venv', '__pycache__', '.venv'), dirs_exist_ok=True ) else: TransformersConverter( @@ -74,9 +66,7 @@ def get(llm): model = bentoml.models.get(llm.tag) backend = model.info.labels['backend'] if backend != llm.__llm_backend__: - raise OpenLLMException( - f"'{model.tag!s}' was saved with backend '{backend}', while loading with '{llm.__llm_backend__}'." - ) + raise OpenLLMException(f"'{model.tag!s}' was saved with backend '{backend}', while loading with '{llm.__llm_backend__}'.") patch_correct_tag( llm, transformers.AutoConfig.from_pretrained(model.path_of('/hf/'), trust_remote_code=llm.trust_remote_code), diff --git a/openllm-python/src/openllm/serialisation/transformers/__init__.py b/openllm-python/src/openllm/serialisation/transformers/__init__.py index 20d917af..d07f0af3 100644 --- a/openllm-python/src/openllm/serialisation/transformers/__init__.py +++ b/openllm-python/src/openllm/serialisation/transformers/__init__.py @@ -15,15 +15,18 @@ logger = logging.getLogger(__name__) __all__ = ['import_model', 'get', 'load_model'] _object_setattr = object.__setattr__ + def import_model(llm, *decls, trust_remote_code, **attrs): (_base_decls, _base_attrs), tokenizer_attrs = llm.llm_parameters decls = (*_base_decls, *decls) attrs = {**_base_attrs, **attrs} - if llm._local: logger.warning('Given model is a local model, OpenLLM will load model into memory for serialisation.') + if llm._local: + logger.warning('Given model is a local model, OpenLLM will load model into memory for serialisation.') config, hub_attrs, attrs = process_config(llm.model_id, trust_remote_code, **attrs) patch_correct_tag(llm, config) safe_serialisation = first_not_none(attrs.get('safe_serialization'), default=llm._serialisation == 'safetensors') - if llm.quantise != 'gptq': attrs['use_safetensors'] = safe_serialisation + if llm.quantise != 'gptq': + attrs['use_safetensors'] = safe_serialisation model = None tokenizer = get_tokenizer(llm.model_id, trust_remote_code=trust_remote_code, **hub_attrs, **tokenizer_attrs) @@ -36,31 +39,30 @@ def import_model(llm, *decls, trust_remote_code, **attrs): attrs['quantization_config'] = llm.quantization_config if llm.quantise == 'gptq': from optimum.gptq.constants import GPTQ_CONFIG + with open(bentomodel.path_of(GPTQ_CONFIG), 'w', encoding='utf-8') as f: f.write(orjson.dumps(config.quantization_config, option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS).decode()) if llm._local: # possible local path model = infer_autoclass_from_llm(llm, config).from_pretrained( - llm.model_id, - *decls, 
- local_files_only=True, - config=config, - trust_remote_code=trust_remote_code, - **hub_attrs, - **attrs, + llm.model_id, *decls, local_files_only=True, config=config, trust_remote_code=trust_remote_code, **hub_attrs, **attrs ) # for trust_remote_code to work bentomodel.enter_cloudpickle_context([importlib.import_module(model.__module__)], imported_modules) model.save_pretrained(bentomodel.path, max_shard_size='2GB', safe_serialization=safe_serialisation) del model - if torch.cuda.is_available(): torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.empty_cache() else: # we will clone the all tings into the bentomodel path without loading model into memory snapshot_download( - llm.model_id, local_dir=bentomodel.path, # - local_dir_use_symlinks=False, ignore_patterns=HfIgnore.ignore_patterns(llm), # + llm.model_id, + local_dir=bentomodel.path, # + local_dir_use_symlinks=False, + ignore_patterns=HfIgnore.ignore_patterns(llm), # ) return bentomodel + def get(llm): try: model = bentoml.models.get(llm.tag) @@ -79,15 +81,18 @@ def get(llm): def check_unintialised_params(model): unintialized = [n for n, param in model.named_parameters() if param.data.device == torch.device('meta')] - if len(unintialized) > 0: raise RuntimeError(f'Found the following unintialized parameters in {model}: {unintialized}') + if len(unintialized) > 0: + raise RuntimeError(f'Found the following unintialized parameters in {model}: {unintialized}') def load_model(llm, *decls, **attrs): - if llm.quantise in {'awq', 'squeezellm'}: raise RuntimeError('AWQ is not yet supported with PyTorch backend.') + if llm.quantise in {'awq', 'squeezellm'}: + raise RuntimeError('AWQ is not yet supported with PyTorch backend.') config, attrs = transformers.AutoConfig.from_pretrained( llm.bentomodel.path, return_unused_kwargs=True, trust_remote_code=llm.trust_remote_code, **attrs ) - if llm.__llm_backend__ == 'triton': return openllm.models.load_model(llm, config, **attrs) + if llm.__llm_backend__ == 'triton': + return openllm.models.load_model(llm, config, **attrs) auto_class = infer_autoclass_from_llm(llm, config) device_map = attrs.pop('device_map', None) @@ -111,33 +116,30 @@ def load_model(llm, *decls, **attrs): try: model = auto_class.from_pretrained( - llm.bentomodel.path, device_map=device_map, trust_remote_code=llm.trust_remote_code, use_flash_attention_2=is_flash_attn_2_available(), **attrs + llm.bentomodel.path, + device_map=device_map, + trust_remote_code=llm.trust_remote_code, + use_flash_attention_2=is_flash_attn_2_available(), + **attrs, + ) + except Exception as err: + logger.debug("Failed to load model with 'use_flash_attention_2' (lookup for traceback):\n%s", err) + model = auto_class.from_pretrained(llm.bentomodel.path, device_map=device_map, trust_remote_code=llm.trust_remote_code, **attrs) + else: + try: + model = auto_class.from_pretrained( + llm.bentomodel.path, + *decls, + config=config, + trust_remote_code=llm.trust_remote_code, + device_map=device_map, + use_flash_attention_2=is_flash_attn_2_available(), + **attrs, ) except Exception as err: logger.debug("Failed to load model with 'use_flash_attention_2' (lookup for traceback):\n%s", err) model = auto_class.from_pretrained( - llm.bentomodel.path, device_map=device_map, trust_remote_code=llm.trust_remote_code, **attrs + llm.bentomodel.path, *decls, config=config, trust_remote_code=llm.trust_remote_code, device_map=device_map, **attrs ) - else: - try: - model = auto_class.from_pretrained( - llm.bentomodel.path, - *decls, - config=config, - 
trust_remote_code=llm.trust_remote_code, - device_map=device_map, - use_flash_attention_2=is_flash_attn_2_available(), - **attrs, - ) - except Exception as err: - logger.debug("Failed to load model with 'use_flash_attention_2' (lookup for traceback):\n%s", err) - model = auto_class.from_pretrained( - llm.bentomodel.path, - *decls, - config=config, - trust_remote_code=llm.trust_remote_code, - device_map=device_map, - **attrs, - ) check_unintialised_params(model) return model diff --git a/openllm-python/src/openllm/serialisation/transformers/_helpers.py b/openllm-python/src/openllm/serialisation/transformers/_helpers.py index f77d639a..9b8500e1 100644 --- a/openllm-python/src/openllm/serialisation/transformers/_helpers.py +++ b/openllm-python/src/openllm/serialisation/transformers/_helpers.py @@ -3,26 +3,36 @@ import transformers from openllm.serialisation.constants import HUB_ATTRS logger = logging.getLogger(__name__) + + def get_tokenizer(model_id_or_path, trust_remote_code, **attrs): tokenizer = transformers.AutoTokenizer.from_pretrained(model_id_or_path, trust_remote_code=trust_remote_code, **attrs) - if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token return tokenizer + + def process_config(model_id, trust_remote_code, **attrs): config = attrs.pop('config', None) # this logic below is synonymous to handling `from_pretrained` attrs. hub_attrs = {k: attrs.pop(k) for k in HUB_ATTRS if k in attrs} if not isinstance(config, transformers.PretrainedConfig): copied_attrs = copy.deepcopy(attrs) - if copied_attrs.get('torch_dtype', None) == 'auto': copied_attrs.pop('torch_dtype') + if copied_attrs.get('torch_dtype', None) == 'auto': + copied_attrs.pop('torch_dtype') config, attrs = transformers.AutoConfig.from_pretrained( model_id, return_unused_kwargs=True, trust_remote_code=trust_remote_code, **hub_attrs, **copied_attrs ) return config, hub_attrs, attrs + + def infer_autoclass_from_llm(llm, config, /): autoclass = 'AutoModelForSeq2SeqLM' if llm.config['model_type'] == 'seq2seq_lm' else 'AutoModelForCausalLM' if llm.trust_remote_code: if not hasattr(config, 'auto_map'): - raise ValueError(f'Invalid configuration for {llm.model_id}. ``trust_remote_code=True`` requires `transformers.PretrainedConfig` to contain a `auto_map` mapping') + raise ValueError( + f'Invalid configuration for {llm.model_id}. ``trust_remote_code=True`` requires `transformers.PretrainedConfig` to contain a `auto_map` mapping' + ) # in case this model doesn't use the correct auto class for model type, for example like chatglm # where it uses AutoModel instead of AutoModelForCausalLM. 
Then we fallback to AutoModel if autoclass not in config.auto_map: diff --git a/openllm-python/src/openllm/serialisation/transformers/weights.py b/openllm-python/src/openllm/serialisation/transformers/weights.py index d97e8a5d..0d456947 100644 --- a/openllm-python/src/openllm/serialisation/transformers/weights.py +++ b/openllm-python/src/openllm/serialisation/transformers/weights.py @@ -11,12 +11,14 @@ if t.TYPE_CHECKING: __global_inst__ = None __cached_id__: dict[str, HfModelInfo] = dict() + def Client() -> HfApi: - global __global_inst__ # noqa: PLW0603 + global __global_inst__ if __global_inst__ is None: __global_inst__ = HfApi() return __global_inst__ + def ModelInfo(model_id: str, revision: str | None = None) -> HfModelInfo: if model_id in __cached_id__: return __cached_id__[model_id] @@ -27,14 +29,17 @@ def ModelInfo(model_id: str, revision: str | None = None) -> HfModelInfo: traceback.print_exc() raise Error(f'Failed to fetch {model_id} from huggingface.co') from err + def has_weights(model_id: str, revision: str | None = None, *, extensions: str) -> bool: if validate_is_path(model_id): return next((True for _ in pathlib.Path(resolve_filepath(model_id)).glob(f'*.{extensions}')), False) return any(s.rfilename.endswith(f'.{extensions}') for s in ModelInfo(model_id, revision=revision).siblings) + has_safetensors_weights = functools.partial(has_weights, extensions='safetensors') has_pt_weights = functools.partial(has_weights, extensions='pt') + @attr.define(slots=True) class HfIgnore: safetensors = '*.safetensors' @@ -42,11 +47,12 @@ class HfIgnore: tf = '*.h5' flax = '*.msgpack' gguf = '*.gguf' + @classmethod def ignore_patterns(cls, llm: openllm.LLM[t.Any, t.Any]) -> list[str]: if llm.__llm_backend__ in {'vllm', 'pt'}: base = [cls.tf, cls.flax, cls.gguf] - if llm.config['architecture'] == 'MixtralForCausalLM': # XXX: Hack for Mixtral as safetensors is yet to be working atm + if llm.config['architecture'] == 'MixtralForCausalLM': # XXX: Hack for Mixtral as safetensors is yet to be working atm base.append(cls.safetensors) elif has_safetensors_weights(llm.model_id): base.extend([cls.pt, '*.pt']) diff --git a/openllm-python/src/openllm/utils.py b/openllm-python/src/openllm/utils.py index ca33da9a..3d504fb0 100644 --- a/openllm-python/src/openllm/utils.py +++ b/openllm-python/src/openllm/utils.py @@ -1,16 +1,36 @@ import functools, importlib.metadata, openllm_core + __all__ = ['generate_labels', 'available_devices', 'device_count'] + + def generate_labels(llm): return { - 'backend': llm.__llm_backend__, 'framework': 'openllm', 'model_name': llm.config['model_name'], # - 'architecture': llm.config['architecture'], 'serialisation': llm._serialisation, # + 'backend': llm.__llm_backend__, + 'framework': 'openllm', + 'model_name': llm.config['model_name'], # + 'architecture': llm.config['architecture'], + 'serialisation': llm._serialisation, # **{package: importlib.metadata.version(package) for package in {'openllm', 'openllm-core', 'openllm-client'}}, } -def available_devices(): from ._strategies import NvidiaGpuResource; return tuple(NvidiaGpuResource.from_system()) + + +def available_devices(): + from ._strategies import NvidiaGpuResource + + return tuple(NvidiaGpuResource.from_system()) + + @functools.lru_cache(maxsize=1) -def device_count() -> int: return len(available_devices()) +def device_count() -> int: + return len(available_devices()) + + def __dir__(): - coreutils = set(dir(openllm_core.utils)) | set([it for it in openllm_core.utils._extras if not it.startswith('_')]); return 
sorted(__all__) + sorted(list(coreutils)) + coreutils = set(dir(openllm_core.utils)) | set([it for it in openllm_core.utils._extras if not it.startswith('_')]) + return sorted(__all__) + sorted(list(coreutils)) + + def __getattr__(it): - if hasattr(openllm_core.utils, it): return getattr(openllm_core.utils, it) + if hasattr(openllm_core.utils, it): + return getattr(openllm_core.utils, it) raise AttributeError(f'module {__name__} has no attribute {it}') diff --git a/openllm-python/src/openllm_cli/__init__.py b/openllm-python/src/openllm_cli/__init__.py index 8d27d1b9..cd72ef8d 100644 --- a/openllm-python/src/openllm_cli/__init__.py +++ b/openllm-python/src/openllm_cli/__init__.py @@ -1,4 +1,4 @@ -'''OpenLLM CLI. +"""OpenLLM CLI. For more information see ``openllm -h``. -''' +""" diff --git a/openllm-python/src/openllm_cli/_factory.py b/openllm-python/src/openllm_cli/_factory.py index 309fbfc0..c173977a 100644 --- a/openllm-python/src/openllm_cli/_factory.py +++ b/openllm-python/src/openllm_cli/_factory.py @@ -5,24 +5,12 @@ from bentoml_cli.utils import BentoMLCommandGroup from click import shell_completion as sc from openllm_core._configuration import LLMConfig -from openllm_core._typing_compat import ( - Concatenate, - DictStrAny, - LiteralBackend, - LiteralSerialisation, - ParamSpec, - AnyCallable, - get_literal_args, -) +from openllm_core._typing_compat import Concatenate, DictStrAny, LiteralBackend, LiteralSerialisation, ParamSpec, AnyCallable, get_literal_args from openllm_core.utils import DEBUG, compose, dantic, resolve_user_filepath + class _OpenLLM_GenericInternalConfig(LLMConfig): - __config__ = { - 'name_type': 'lowercase', - 'default_id': 'openllm/generic', - 'model_ids': ['openllm/generic'], - 'architecture': 'PreTrainedModel', - } + __config__ = {'name_type': 'lowercase', 'default_id': 'openllm/generic', 'model_ids': ['openllm/generic'], 'architecture': 'PreTrainedModel'} class GenerationConfig: top_k: int = 15 @@ -30,6 +18,7 @@ class _OpenLLM_GenericInternalConfig(LLMConfig): temperature: float = 0.75 max_new_tokens: int = 128 + logger = logging.getLogger(__name__) P = ParamSpec('P') @@ -38,6 +27,7 @@ LiteralOutput = t.Literal['json', 'pretty', 'porcelain'] _AnyCallable = t.Callable[..., t.Any] FC = t.TypeVar('FC', bound=t.Union[_AnyCallable, click.Command]) + def bento_complete_envvar(ctx: click.Context, param: click.Parameter, incomplete: str) -> list[sc.CompletionItem]: return [ sc.CompletionItem(str(it.tag), help='Bento') @@ -45,20 +35,13 @@ def bento_complete_envvar(ctx: click.Context, param: click.Parameter, incomplete if str(it.tag).startswith(incomplete) and all(k in it.info.labels for k in {'start_name', 'bundler'}) ] + def model_complete_envvar(ctx: click.Context, param: click.Parameter, incomplete: str) -> list[sc.CompletionItem]: - return [ - sc.CompletionItem(inflection.dasherize(it), help='Model') - for it in openllm.CONFIG_MAPPING - if it.startswith(incomplete) - ] + return [sc.CompletionItem(inflection.dasherize(it), help='Model') for it in openllm.CONFIG_MAPPING if it.startswith(incomplete)] + def parse_config_options( - config: LLMConfig, - server_timeout: int, - workers_per_resource: float, - device: t.Tuple[str, ...] | None, - cors: bool, - environ: DictStrAny, + config: LLMConfig, server_timeout: int, workers_per_resource: float, device: t.Tuple[str, ...] 
| None, cors: bool, environ: DictStrAny ) -> DictStrAny: # TODO: Support amd.com/gpu on k8s _bentoml_config_options_env = environ.pop('BENTOML_CONFIG_OPTIONS', '') @@ -72,26 +55,16 @@ def parse_config_options( ] if device: if len(device) > 1: - _bentoml_config_options_opts.extend( - [ - f'runners."llm-{config["start_name"]}-runner".resources."nvidia.com/gpu"[{idx}]={dev}' - for idx, dev in enumerate(device) - ] - ) + _bentoml_config_options_opts.extend([ + f'runners."llm-{config["start_name"]}-runner".resources."nvidia.com/gpu"[{idx}]={dev}' for idx, dev in enumerate(device) + ]) else: - _bentoml_config_options_opts.append( - f'runners."llm-{config["start_name"]}-runner".resources."nvidia.com/gpu"=[{device[0]}]' - ) + _bentoml_config_options_opts.append(f'runners."llm-{config["start_name"]}-runner".resources."nvidia.com/gpu"=[{device[0]}]') if cors: - _bentoml_config_options_opts.extend( - ['api_server.http.cors.enabled=true', 'api_server.http.cors.access_control_allow_origins="*"'] - ) - _bentoml_config_options_opts.extend( - [ - f'api_server.http.cors.access_control_allow_methods[{idx}]="{it}"' - for idx, it in enumerate(['GET', 'OPTIONS', 'POST', 'HEAD', 'PUT']) - ] - ) + _bentoml_config_options_opts.extend(['api_server.http.cors.enabled=true', 'api_server.http.cors.access_control_allow_origins="*"']) + _bentoml_config_options_opts.extend([ + f'api_server.http.cors.access_control_allow_methods[{idx}]="{it}"' for idx, it in enumerate(['GET', 'OPTIONS', 'POST', 'HEAD', 'PUT']) + ]) _bentoml_config_options_env += ' ' if _bentoml_config_options_env else '' + ' '.join(_bentoml_config_options_opts) environ['BENTOML_CONFIG_OPTIONS'] = _bentoml_config_options_env if DEBUG: @@ -119,22 +92,27 @@ def _id_callback(ctx: click.Context, _: click.Parameter, value: t.Tuple[str, ... ctx.params[_adapter_mapping_key][adapter_id] = name return None + def optimization_decorator(fn: FC, *, factory=click, _eager=True) -> FC | list[AnyCallable]: shared = [ - dtype_option(factory=factory), model_version_option(factory=factory), # - backend_option(factory=factory), quantize_option(factory=factory), # + dtype_option(factory=factory), + model_version_option(factory=factory), # + backend_option(factory=factory), + quantize_option(factory=factory), # serialisation_option(factory=factory), ] - if not _eager: return shared + if not _eager: + return shared return compose(*shared)(fn) + def start_decorator(fn: FC) -> FC: composed = compose( _OpenLLM_GenericInternalConfig.parse, parse_serve_args(), cog.optgroup.group( 'LLM Options', - help='''The following options are related to running LLM Server as well as optimization options. + help="""The following options are related to running LLM Server as well as optimization options. OpenLLM supports running model k-bit quantization (8-bit, 4-bit), GPTQ quantization, PagedAttention via vLLM. @@ -142,7 +120,7 @@ def start_decorator(fn: FC) -> FC: - DeepSpeed Inference: [link](https://www.deepspeed.ai/inference/) - GGML: Fast inference on [bare metal](https://github.com/ggerganov/ggml) - ''', + """, ), cog.optgroup.option('--server-timeout', type=int, default=None, help='Server timeout in seconds'), workers_per_resource_option(factory=cog.optgroup), @@ -163,12 +141,14 @@ def start_decorator(fn: FC) -> FC: return composed(fn) + def parse_device_callback(_: click.Context, param: click.Parameter, value: tuple[tuple[str], ...] | None) -> t.Tuple[str, ...] | None: if value is None: return value el: t.Tuple[str, ...] 
= tuple(i for k in value for i in k) # NOTE: --device all is a special case - if len(el) == 1 and el[0] == 'all': return tuple(map(str, openllm.utils.available_devices())) + if len(el) == 1 and el[0] == 'all': + return tuple(map(str, openllm.utils.available_devices())) return el @@ -182,15 +162,12 @@ def parse_serve_args() -> t.Callable[[t.Callable[..., LLMConfig]], t.Callable[[F from bentoml_cli.cli import cli group = cog.optgroup.group('Start a HTTP server options', help='Related to serving the model [synonymous to `bentoml serve-http`]') + def decorator(f: t.Callable[Concatenate[int, t.Optional[str], P], LLMConfig]) -> t.Callable[[FC], FC]: serve_command = cli.commands['serve'] # The first variable is the argument bento # The last five is from BentoMLCommandGroup.NUMBER_OF_COMMON_PARAMS - serve_options = [ - p - for p in serve_command.params[1 : -BentoMLCommandGroup.NUMBER_OF_COMMON_PARAMS] - if p.name not in _IGNORED_OPTIONS - ] + serve_options = [p for p in serve_command.params[1 : -BentoMLCommandGroup.NUMBER_OF_COMMON_PARAMS] if p.name not in _IGNORED_OPTIONS] for options in reversed(serve_options): attrs = options.to_info_dict() # we don't need param_type_name, since it should all be options @@ -202,14 +179,16 @@ def parse_serve_args() -> t.Callable[[t.Callable[..., LLMConfig]], t.Callable[[F param_decls = (*attrs.pop('opts'), *attrs.pop('secondary_opts')) f = cog.optgroup.option(*param_decls, **attrs)(f) return group(f) + return decorator + def _click_factory_type(*param_decls: t.Any, **attrs: t.Any) -> t.Callable[[FC | None], FC]: - '''General ``@click`` decorator with some sauce. + """General ``@click`` decorator with some sauce. This decorator extends the default ``@click.option`` plus a factory option and factory attr to provide type-safe click.option or click.argument wrapper for all compatible factory. 
- ''' + """ factory = attrs.pop('factory', click) factory_attr = attrs.pop('attr', 'option') if factory_attr != 'argument': @@ -242,18 +221,14 @@ def adapter_id_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callab def cors_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: return cli_option( - '--cors/--no-cors', - show_default=True, - default=False, - envvar='OPENLLM_CORS', - show_envvar=True, - help='Enable CORS for the server.', - **attrs, + '--cors/--no-cors', show_default=True, default=False, envvar='OPENLLM_CORS', show_envvar=True, help='Enable CORS for the server.', **attrs )(f) + def machine_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: return cli_option('--machine', is_flag=True, default=False, hidden=True, **attrs)(f) + def dtype_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: return cli_option( '--dtype', @@ -264,6 +239,7 @@ def dtype_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[F **attrs, )(f) + def model_id_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: return cli_option( '--model-id', @@ -294,16 +270,14 @@ def backend_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[ envvar='OPENLLM_BACKEND', show_envvar=True, help='Runtime to use for both serialisation/inference engine.', - **attrs)(f) - -def model_name_argument(f: _AnyCallable | None = None, required: bool = True, **attrs: t.Any) -> t.Callable[[FC], FC]: - return cli_argument( - 'model_name', - type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING]), - required=required, **attrs, )(f) + +def model_name_argument(f: _AnyCallable | None = None, required: bool = True, **attrs: t.Any) -> t.Callable[[FC], FC]: + return cli_argument('model_name', type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING]), required=required, **attrs)(f) + + def quantize_option(f: _AnyCallable | None = None, *, build: bool = False, **attrs: t.Any) -> t.Callable[[FC], FC]: return cli_option( '--quantise', @@ -313,7 +287,7 @@ def quantize_option(f: _AnyCallable | None = None, *, build: bool = False, **att default=None, envvar='OPENLLM_QUANTIZE', show_envvar=True, - help='''Dynamic quantization for running this LLM. + help="""Dynamic quantization for running this LLM. The following quantization strategies are supported: @@ -328,23 +302,25 @@ def quantize_option(f: _AnyCallable | None = None, *, build: bool = False, **att - ``squeezellm``: ``SqueezeLLM`` [SqueezeLLM: Dense-and-Sparse Quantization](https://arxiv.org/abs/2306.07629) > [!NOTE] that the model can also be served with quantized weights. - ''' + """ + ( - ''' - > [!NOTE] that this will set the mode for serving within deployment.''' if build else '' + """ + > [!NOTE] that this will set the mode for serving within deployment.""" + if build + else '' ), - **attrs)(f) + **attrs, + )(f) -def workers_per_resource_option( - f: _AnyCallable | None = None, *, build: bool = False, **attrs: t.Any -) -> t.Callable[[FC], FC]: + +def workers_per_resource_option(f: _AnyCallable | None = None, *, build: bool = False, **attrs: t.Any) -> t.Callable[[FC], FC]: return cli_option( '--workers-per-resource', default=None, callback=workers_per_resource_callback, type=str, required=False, - help='''Number of workers per resource assigned. + help="""Number of workers per resource assigned. 
See https://docs.bentoml.org/en/latest/guides/scheduling.html#resource-scheduling-strategy for more information. By default, this is set to 1. @@ -354,7 +330,7 @@ def workers_per_resource_option( - ``round_robin``: Similar behaviour when setting ``--workers-per-resource 1``. This is useful for smaller models. - ``conserved``: This will determine the number of available GPU resources. For example, if ther are 4 GPUs available, then ``conserved`` is equivalent to ``--workers-per-resource 0.25``. - ''' + """ + ( """\n > [!NOTE] The workers value passed into 'build' will determine how the LLM can @@ -366,6 +342,7 @@ def workers_per_resource_option( **attrs, )(f) + def serialisation_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: return cli_option( '--serialisation', @@ -376,7 +353,7 @@ def serialisation_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Cal show_default=True, show_envvar=True, envvar='OPENLLM_SERIALIZATION', - help='''Serialisation format for save/load LLM. + help="""Serialisation format for save/load LLM. Currently the following strategies are supported: @@ -385,12 +362,14 @@ def serialisation_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Cal > [!NOTE] Safetensors might not work for every cases, and you can always fallback to ``legacy`` if needed. - ``legacy``: This will use PyTorch serialisation format, often as ``.bin`` files. This should be used if the model doesn't yet support safetensors. - ''', + """, **attrs, )(f) + _wpr_strategies = {'round_robin', 'conserved'} + def workers_per_resource_callback(ctx: click.Context, param: click.Parameter, value: str | None) -> str | None: if value is None: return value @@ -402,9 +381,7 @@ def workers_per_resource_callback(ctx: click.Context, param: click.Parameter, va float(value) # type: ignore[arg-type] except ValueError: raise click.BadParameter( - f"'workers_per_resource' only accept '{_wpr_strategies}' as possible strategies, otherwise pass in float.", - ctx, - param, + f"'workers_per_resource' only accept '{_wpr_strategies}' as possible strategies, otherwise pass in float.", ctx, param ) from None else: return value diff --git a/openllm-python/src/openllm_cli/_sdk.py b/openllm-python/src/openllm_cli/_sdk.py index ab36f90c..70d24339 100644 --- a/openllm-python/src/openllm_cli/_sdk.py +++ b/openllm-python/src/openllm_cli/_sdk.py @@ -6,6 +6,7 @@ from bentoml._internal.configuration.containers import BentoMLContainer from openllm_core._typing_compat import LiteralSerialisation from openllm_core.exceptions import OpenLLMException from openllm_core.utils import WARNING_ENV_VAR, codegen, first_not_none, get_disable_warnings, is_vllm_available + if t.TYPE_CHECKING: from bentoml._internal.bento import BentoStore from openllm_core._configuration import LLMConfig @@ -13,6 +14,7 @@ if t.TYPE_CHECKING: logger = logging.getLogger(__name__) + def _start( model_id: str, timeout: int = 30, @@ -61,27 +63,25 @@ def _start( additional_args: Additional arguments to pass to ``openllm start``. 
""" from .entrypoint import start_command + os.environ['BACKEND'] = openllm_core.utils.first_not_none(backend, default='vllm' if is_vllm_available() else 'pt') args: list[str] = [model_id] - if timeout: args.extend(['--server-timeout', str(timeout)]) + if timeout: + args.extend(['--server-timeout', str(timeout)]) if workers_per_resource: - args.extend( - [ - '--workers-per-resource', - str(workers_per_resource) if not isinstance(workers_per_resource, str) else workers_per_resource, - ] - ) - if device and not os.environ.get('CUDA_VISIBLE_DEVICES'): args.extend(['--device', ','.join(device)]) - if quantize: args.extend(['--quantize', str(quantize)]) - if cors: args.append('--cors') + args.extend(['--workers-per-resource', str(workers_per_resource) if not isinstance(workers_per_resource, str) else workers_per_resource]) + if device and not os.environ.get('CUDA_VISIBLE_DEVICES'): + args.extend(['--device', ','.join(device)]) + if quantize: + args.extend(['--quantize', str(quantize)]) + if cors: + args.append('--cors') if adapter_map: - args.extend( - list( - itertools.chain.from_iterable([['--adapter-id', f"{k}{':'+v if v else ''}"] for k, v in adapter_map.items()]) - ) - ) - if additional_args: args.extend(additional_args) - if __test__: args.append('--return-process') + args.extend(list(itertools.chain.from_iterable([['--adapter-id', f"{k}{':'+v if v else ''}"] for k, v in adapter_map.items()]))) + if additional_args: + args.extend(additional_args) + if __test__: + args.append('--return-process') cmd = start_command return cmd.main(args=args, standalone_mode=False) @@ -138,6 +138,7 @@ def _build( ``bentoml.Bento | str``: BentoLLM instance. This can be used to serve the LLM or can be pushed to BentoCloud. """ from openllm.serialisation.transformers.weights import has_safetensors_weights + args: list[str] = [ sys.executable, '-m', @@ -147,23 +148,34 @@ def _build( '--machine', '--quiet', '--serialisation', - first_not_none( - serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy' - ), + first_not_none(serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy'), ] - if quantize: args.extend(['--quantize', quantize]) - if containerize and push: raise OpenLLMException("'containerize' and 'push' are currently mutually exclusive.") - if push: args.extend(['--push']) - if containerize: args.extend(['--containerize']) - if build_ctx: args.extend(['--build-ctx', build_ctx]) - if enable_features: args.extend([f'--enable-features={f}' for f in enable_features]) - if overwrite: args.append('--overwrite') - if adapter_map: args.extend([f"--adapter-id={k}{':'+v if v is not None else ''}" for k, v in adapter_map.items()]) - if model_version: args.extend(['--model-version', model_version]) - if bento_version: args.extend(['--bento-version', bento_version]) - if dockerfile_template: args.extend(['--dockerfile-template', dockerfile_template]) - if additional_args: args.extend(additional_args) - if force_push: args.append('--force-push') + if quantize: + args.extend(['--quantize', quantize]) + if containerize and push: + raise OpenLLMException("'containerize' and 'push' are currently mutually exclusive.") + if push: + args.extend(['--push']) + if containerize: + args.extend(['--containerize']) + if build_ctx: + args.extend(['--build-ctx', build_ctx]) + if enable_features: + args.extend([f'--enable-features={f}' for f in enable_features]) + if overwrite: + args.append('--overwrite') + if adapter_map: + 
args.extend([f"--adapter-id={k}{':'+v if v is not None else ''}" for k, v in adapter_map.items()]) + if model_version: + args.extend(['--model-version', model_version]) + if bento_version: + args.extend(['--bento-version', bento_version]) + if dockerfile_template: + args.extend(['--dockerfile-template', dockerfile_template]) + if additional_args: + args.extend(additional_args) + if force_push: + args.append('--force-push') current_disable_warning = get_disable_warnings() os.environ[WARNING_ENV_VAR] = str(True) @@ -171,17 +183,24 @@ def _build( output = subprocess.check_output(args, env=os.environ.copy(), cwd=build_ctx or os.getcwd()) except subprocess.CalledProcessError as e: logger.error("Exception caught while building Bento for '%s'", model_id, exc_info=e) - if e.stderr: raise OpenLLMException(e.stderr.decode('utf-8')) from None + if e.stderr: + raise OpenLLMException(e.stderr.decode('utf-8')) from None raise OpenLLMException(str(e)) from None matched = re.match(r'__object__:(\{.*\})$', output.decode('utf-8').strip()) if matched is None: - raise ValueError(f"Failed to find tag from output: {output.decode('utf-8').strip()}\nNote: Output from 'openllm build' might not be correct. Please open an issue on GitHub.") + raise ValueError( + f"Failed to find tag from output: {output.decode('utf-8').strip()}\nNote: Output from 'openllm build' might not be correct. Please open an issue on GitHub." + ) os.environ[WARNING_ENV_VAR] = str(current_disable_warning) try: result = orjson.loads(matched.group(1)) except orjson.JSONDecodeError as e: - raise ValueError(f"Failed to decode JSON from output: {output.decode('utf-8').strip()}\nNote: Output from 'openllm build' might not be correct. Please open an issue on GitHub.") from e + raise ValueError( + f"Failed to decode JSON from output: {output.decode('utf-8').strip()}\nNote: Output from 'openllm build' might not be correct. Please open an issue on GitHub." + ) from e return bentoml.get(result['tag'], _bento_store=bento_store) + + def _import_model( model_id: str, model_version: str | None = None, @@ -218,15 +237,32 @@ def _import_model( ``bentoml.Model``:BentoModel of the given LLM. This can be used to serve the LLM or can be pushed to BentoCloud. 
""" from .entrypoint import import_command + args = [model_id, '--quiet'] - if backend is not None: args.extend(['--backend', backend]) - if model_version is not None: args.extend(['--model-version', str(model_version)]) - if quantize is not None: args.extend(['--quantize', quantize]) - if serialisation is not None: args.extend(['--serialisation', serialisation]) - if additional_args is not None: args.extend(additional_args) + if backend is not None: + args.extend(['--backend', backend]) + if model_version is not None: + args.extend(['--model-version', str(model_version)]) + if quantize is not None: + args.extend(['--quantize', quantize]) + if serialisation is not None: + args.extend(['--serialisation', serialisation]) + if additional_args is not None: + args.extend(additional_args) return import_command.main(args=args, standalone_mode=False) + + def _list_models() -> dict[str, t.Any]: - '''List all available models within the local store.''' - from .entrypoint import models_command; return models_command.main(args=['--quiet'], standalone_mode=False) -start, build, import_model, list_models = codegen.gen_sdk(_start), codegen.gen_sdk(_build), codegen.gen_sdk(_import_model), codegen.gen_sdk(_list_models) + """List all available models within the local store.""" + from .entrypoint import models_command + + return models_command.main(args=['--quiet'], standalone_mode=False) + + +start, build, import_model, list_models = ( + codegen.gen_sdk(_start), + codegen.gen_sdk(_build), + codegen.gen_sdk(_import_model), + codegen.gen_sdk(_list_models), +) __all__ = ['start', 'build', 'import_model', 'list_models'] diff --git a/openllm-python/src/openllm_cli/entrypoint.py b/openllm-python/src/openllm_cli/entrypoint.py index 57451d97..28c07ba2 100644 --- a/openllm-python/src/openllm_cli/entrypoint.py +++ b/openllm-python/src/openllm_cli/entrypoint.py @@ -43,15 +43,7 @@ from openllm_core.utils import ( ) from . import termui -from ._factory import ( - FC, - _AnyCallable, - machine_option, - model_name_argument, - parse_config_options, - start_decorator, - optimization_decorator, -) +from ._factory import FC, _AnyCallable, machine_option, model_name_argument, parse_config_options, start_decorator, optimization_decorator if t.TYPE_CHECKING: import torch @@ -65,14 +57,14 @@ else: P = ParamSpec('P') logger = logging.getLogger('openllm') -OPENLLM_FIGLET = '''\ +OPENLLM_FIGLET = """\ ██████╗ ██████╗ ███████╗███╗ ██╗██╗ ██╗ ███╗ ███╗ ██╔═══██╗██╔══██╗██╔════╝████╗ ██║██║ ██║ ████╗ ████║ ██║ ██║██████╔╝█████╗ ██╔██╗ ██║██║ ██║ ██╔████╔██║ ██║ ██║██╔═══╝ ██╔══╝ ██║╚██╗██║██║ ██║ ██║╚██╔╝██║ ╚██████╔╝██║ ███████╗██║ ╚████║███████╗███████╗██║ ╚═╝ ██║ ╚═════╝ ╚═╝ ╚══════╝╚═╝ ╚═══╝╚══════╝╚══════╝╚═╝ ╚═╝ -''' +""" ServeCommand = t.Literal['serve', 'serve-grpc'] @@ -103,20 +95,12 @@ def backend_warning(backend: LiteralBackend, build: bool = False) -> None: 'vLLM is not available. Note that PyTorch backend is not as performant as vLLM and you should always consider using vLLM for production.' ) if build: - logger.info( - "Tip: You can set '--backend vllm' to package your Bento with vLLM backend regardless if vLLM is available locally." 
- ) + logger.info("Tip: You can set '--backend vllm' to package your Bento with vLLM backend regardless if vLLM is available locally.") class Extensions(click.MultiCommand): def list_commands(self, ctx: click.Context) -> list[str]: - return sorted( - [ - filename[:-3] - for filename in os.listdir(_EXT_FOLDER) - if filename.endswith('.py') and not filename.startswith('__') - ] - ) + return sorted([filename[:-3] for filename in os.listdir(_EXT_FOLDER) if filename.endswith('.py') and not filename.startswith('__')]) def get_command(self, ctx: click.Context, cmd_name: str) -> click.Command | None: try: @@ -133,41 +117,19 @@ class OpenLLMCommandGroup(BentoMLCommandGroup): def common_params(f: t.Callable[P, t.Any]) -> t.Callable[[FC], FC]: # The following logics is similar to one of BentoMLCommandGroup @cog.optgroup.group(name='Global options', help='Shared globals options for all OpenLLM CLI.') # type: ignore[misc] + @cog.optgroup.option('-q', '--quiet', envvar=QUIET_ENV_VAR, is_flag=True, default=False, help='Suppress all output.', show_envvar=True) @cog.optgroup.option( - '-q', '--quiet', envvar=QUIET_ENV_VAR, is_flag=True, default=False, help='Suppress all output.', show_envvar=True + '--debug', '--verbose', 'debug', envvar=DEBUG_ENV_VAR, is_flag=True, default=False, help='Print out debug logs.', show_envvar=True ) @cog.optgroup.option( - '--debug', - '--verbose', - 'debug', - envvar=DEBUG_ENV_VAR, - is_flag=True, - default=False, - help='Print out debug logs.', - show_envvar=True, + '--do-not-track', is_flag=True, default=False, envvar=analytics.OPENLLM_DO_NOT_TRACK, help='Do not send usage info', show_envvar=True ) @cog.optgroup.option( - '--do-not-track', - is_flag=True, - default=False, - envvar=analytics.OPENLLM_DO_NOT_TRACK, - help='Do not send usage info', - show_envvar=True, - ) - @cog.optgroup.option( - '--context', - 'cloud_context', - envvar='BENTOCLOUD_CONTEXT', - type=click.STRING, - default=None, - help='BentoCloud context name.', - show_envvar=True, + '--context', 'cloud_context', envvar='BENTOCLOUD_CONTEXT', type=click.STRING, default=None, help='BentoCloud context name.', show_envvar=True ) @click.pass_context @functools.wraps(f) - def wrapper( - ctx: click.Context, quiet: bool, debug: bool, cloud_context: str | None, *args: P.args, **attrs: P.kwargs - ) -> t.Any: + def wrapper(ctx: click.Context, quiet: bool, debug: bool, cloud_context: str | None, *args: P.args, **attrs: P.kwargs) -> t.Any: ctx.obj = GlobalOptions(cloud_context=cloud_context) if quiet: set_quiet_mode(True) @@ -181,9 +143,7 @@ class OpenLLMCommandGroup(BentoMLCommandGroup): return wrapper @staticmethod - def usage_tracking( - func: t.Callable[P, t.Any], group: click.Group, **attrs: t.Any - ) -> t.Callable[Concatenate[bool, P], t.Any]: + def usage_tracking(func: t.Callable[P, t.Any], group: click.Group, **attrs: t.Any) -> t.Callable[Concatenate[bool, P], t.Any]: command_name = attrs.get('name', func.__name__) @functools.wraps(func) @@ -242,9 +202,7 @@ class OpenLLMCommandGroup(BentoMLCommandGroup): _memo = getattr(wrapped, '__click_params__', None) if _memo is None: raise ValueError('Click command not register correctly.') - _object_setattr( - wrapped, '__click_params__', _memo[-self.NUMBER_OF_COMMON_PARAMS :] + _memo[: -self.NUMBER_OF_COMMON_PARAMS] - ) + _object_setattr(wrapped, '__click_params__', _memo[-self.NUMBER_OF_COMMON_PARAMS :] + _memo[: -self.NUMBER_OF_COMMON_PARAMS]) # NOTE: we need to call super of super to avoid conflict with BentoMLCommandGroup command setup cmd = super(BentoMLCommandGroup, 
self).command(*args, **kwargs)(wrapped) # NOTE: add aliases to a given commands if it is specified. @@ -258,7 +216,7 @@ class OpenLLMCommandGroup(BentoMLCommandGroup): return decorator def format_commands(self, ctx: click.Context, formatter: click.HelpFormatter) -> None: - '''Additional format methods that include extensions as well as the default cli command.''' + """Additional format methods that include extensions as well as the default cli command.""" from gettext import gettext as _ commands: list[tuple[str, click.Command]] = [] @@ -305,7 +263,7 @@ _PACKAGE_NAME = 'openllm' message=f'{_PACKAGE_NAME}, %(version)s (compiled: {openllm.COMPILED})\nPython ({platform.python_implementation()}) {platform.python_version()}', ) def cli() -> None: - '''\b + """\b ██████╗ ██████╗ ███████╗███╗ ██╗██╗ ██╗ ███╗ ███╗ ██╔═══██╗██╔══██╗██╔════╝████╗ ██║██║ ██║ ████╗ ████║ ██║ ██║██████╔╝█████╗ ██╔██╗ ██║██║ ██║ ██╔████╔██║ @@ -316,15 +274,10 @@ def cli() -> None: \b An open platform for operating large language models in production. Fine-tune, serve, deploy, and monitor any LLMs with ease. - ''' + """ -@cli.command( - context_settings=termui.CONTEXT_SETTINGS, - name='start', - aliases=['start-http'], - short_help='Start a LLMServer for any supported LLM.', -) +@cli.command(context_settings=termui.CONTEXT_SETTINGS, name='start', aliases=['start-http'], short_help='Start a LLMServer for any supported LLM.') @click.argument('model_id', type=click.STRING, metavar='[REMOTE_REPO/MODEL_ID | /path/to/local/model]', required=True) @click.option( '--model-id', @@ -366,24 +319,30 @@ def start_command( dtype: LiteralDtype, deprecated_model_id: str | None, max_model_len: int | None, - gpu_memory_utilization:float, + gpu_memory_utilization: float, **attrs: t.Any, ) -> LLMConfig | subprocess.Popen[bytes]: - '''Start any LLM as a REST server. + """Start any LLM as a REST server. \b ```bash $ openllm -- ... ``` - ''' - if backend == 'pt': logger.warning('PyTorch backend is deprecated and will be removed in future releases. Make sure to use vLLM instead.') + """ + if backend == 'pt': + logger.warning('PyTorch backend is deprecated and will be removed in future releases. Make sure to use vLLM instead.') if model_id in openllm.CONFIG_MAPPING: _model_name = model_id if deprecated_model_id is not None: model_id = deprecated_model_id else: model_id = openllm.AutoConfig.for_model(_model_name)['default_id'] - logger.warning("Passing 'openllm start %s%s' is deprecated and will be remove in a future version. Use 'openllm start %s' instead.", _model_name, '' if deprecated_model_id is None else f' --model-id {deprecated_model_id}', model_id) + logger.warning( + "Passing 'openllm start %s%s' is deprecated and will be remove in a future version. 
Use 'openllm start %s' instead.", + _model_name, + '' if deprecated_model_id is None else f' --model-id {deprecated_model_id}', + model_id, + ) adapter_map: dict[str, str] | None = attrs.pop('adapter_map', None) @@ -393,11 +352,7 @@ def start_command( if serialisation == 'safetensors' and quantize is not None: logger.warning("'--quantize=%s' might not work with 'safetensors' serialisation format.", quantize) - logger.warning( - "Make sure to check out '%s' repository to see if the weights is in '%s' format if unsure.", - model_id, - serialisation, - ) + logger.warning("Make sure to check out '%s' repository to see if the weights is in '%s' format if unsure.", model_id, serialisation) logger.info("Tip: You can always fallback to '--serialisation legacy' when running quantisation.") import torch @@ -425,7 +380,7 @@ def start_command( config, server_attrs = llm.config.model_validate_click(**attrs) server_timeout = first_not_none(server_timeout, default=config['timeout']) server_attrs.update({'working_dir': pkg.source_locations('openllm'), 'timeout': server_timeout}) - development = server_attrs.pop('development') # XXX: currently, theres no development args in bentoml.Server. To be fixed upstream. + development = server_attrs.pop('development') # XXX: currently, theres no development args in bentoml.Server. To be fixed upstream. server_attrs.setdefault('production', not development) start_env = process_environ( @@ -454,26 +409,27 @@ def start_command( # NOTE: Return the configuration for telemetry purposes. return config + def process_environ(config, server_timeout, wpr, device, cors, model_id, adapter_map, serialisation, llm, use_current_env=True): environ = parse_config_options(config, server_timeout, wpr, device, cors, os.environ.copy() if use_current_env else {}) - environ.update( - { - 'OPENLLM_MODEL_ID': model_id, - 'BENTOML_DEBUG': str(openllm.utils.get_debug_mode()), - 'BENTOML_HOME': os.environ.get('BENTOML_HOME', BentoMLContainer.bentoml_home.get()), - 'OPENLLM_ADAPTER_MAP': orjson.dumps(adapter_map).decode(), - 'OPENLLM_SERIALIZATION': serialisation, - 'OPENLLM_CONFIG': config.model_dump_json(flatten=True).decode(), - 'BACKEND': llm.__llm_backend__, - 'DTYPE': str(llm._torch_dtype).split('.')[-1], - 'TRUST_REMOTE_CODE': str(llm.trust_remote_code), - 'MAX_MODEL_LEN': orjson.dumps(llm._max_model_len).decode(), - 'GPU_MEMORY_UTILIZATION': orjson.dumps(llm._gpu_memory_utilization).decode(), - } - ) - if llm.quantise: environ['QUANTIZE'] = str(llm.quantise) + environ.update({ + 'OPENLLM_MODEL_ID': model_id, + 'BENTOML_DEBUG': str(openllm.utils.get_debug_mode()), + 'BENTOML_HOME': os.environ.get('BENTOML_HOME', BentoMLContainer.bentoml_home.get()), + 'OPENLLM_ADAPTER_MAP': orjson.dumps(adapter_map).decode(), + 'OPENLLM_SERIALIZATION': serialisation, + 'OPENLLM_CONFIG': config.model_dump_json(flatten=True).decode(), + 'BACKEND': llm.__llm_backend__, + 'DTYPE': str(llm._torch_dtype).split('.')[-1], + 'TRUST_REMOTE_CODE': str(llm.trust_remote_code), + 'MAX_MODEL_LEN': orjson.dumps(llm._max_model_len).decode(), + 'GPU_MEMORY_UTILIZATION': orjson.dumps(llm._gpu_memory_utilization).decode(), + }) + if llm.quantise: + environ['QUANTIZE'] = str(llm.quantise) return environ + def process_workers_per_resource(wpr: str | float | int, device: tuple[str, ...]) -> TypeGuard[float]: if isinstance(wpr, str): if wpr == 'round_robin': @@ -491,6 +447,7 @@ def process_workers_per_resource(wpr: str | float | int, device: tuple[str, ...] 
    wpr = float(wpr)
  return wpr

+
 def build_bento_instruction(llm, model_id, serialisation, adapter_map):
   cmd_name = f'openllm build {model_id} --backend {llm.__llm_backend__}'
   if llm.quantise:
@@ -498,12 +455,9 @@ def build_bento_instruction(llm, model_id, serialisation, adapter_map):
   if llm.__llm_backend__ in {'pt', 'vllm'}:
     cmd_name += f' --serialization {serialisation}'
   if adapter_map is not None:
-    cmd_name += ' ' + ' '.join(
-      [
-        f'--adapter-id {s}'
-        for s in [f'{p}:{name}' if name not in (None, 'default') else p for p, name in adapter_map.items()]
-      ]
-    )
+    cmd_name += ' ' + ' '.join([
+      f'--adapter-id {s}' for s in [f'{p}:{name}' if name not in (None, 'default') else p for p, name in adapter_map.items()]
+    ])
   if not openllm.utils.get_quiet_mode():
     termui.info(f"🚀Tip: run '{cmd_name}' to create a BentoLLM for '{model_id}'")
@@ -537,9 +491,8 @@ def run_server(args, env, return_process=False) -> subprocess.Popen[bytes] | int
   if return_process:
     return process
   stop_event = threading.Event()
-  # yapf: disable
   stdout, stderr = threading.Thread(target=handle, args=(process.stdout, stop_event)), threading.Thread(target=handle, args=(process.stderr, stop_event))
-  stdout.start(); stderr.start()
+  stdout.start(); stderr.start()  # noqa: E702

   try:
     process.wait()
@@ -554,10 +507,9 @@ def run_server(args, env, return_process=False) -> subprocess.Popen[bytes] | int
     raise
   finally:
     stop_event.set()
-    stdout.join(); stderr.join()
+    stdout.join(); stderr.join()  # noqa: E702
     if process.poll() is not None:
       process.kill()
-      stdout.join(); stderr.join()
-  # yapf: disable
+      stdout.join(); stderr.join()  # noqa: E702
   return process.returncode
@@ -645,10 +597,7 @@ def import_command(
     backend=backend,
     dtype=dtype,
     serialisation=t.cast(
-      LiteralSerialisation,
-      first_not_none(
-        serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy'
-      ),
+      LiteralSerialisation, first_not_none(serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy')
     ),
   )
   backend_warning(llm.__llm_backend__)
@@ -707,21 +656,14 @@ class BuildBentoOutput(t.TypedDict):
   metavar='[REMOTE_REPO/MODEL_ID | /path/to/local/model]',
   help='Deprecated. Use positional argument instead.',
 )
-@click.option(
-  '--bento-version',
-  type=str,
-  default=None,
-  help='Optional bento version for this BentoLLM. Default is the the model revision.',
-)
+@click.option('--bento-version', type=str, default=None, help='Optional bento version for this BentoLLM. Default is the model revision.')
 @click.option('--overwrite', is_flag=True, help='Overwrite existing Bento for given LLM if it already exists.')
 @click.option(
   '--enable-features',
   multiple=True,
   nargs=1,
   metavar='FEATURE[,FEATURE]',
-  help='Enable additional features for building this LLM Bento. Available: {}'.format(
-    ', '.join(OPTIONAL_DEPENDENCIES)
-  ),
+  help='Enable additional features for building this LLM Bento. Available: {}'.format(', '.join(OPTIONAL_DEPENDENCIES)),
 )
 @optimization_decorator
 @click.option(
@@ -732,12 +674,7 @@ class BuildBentoOutput(t.TypedDict):
   help="Optional adapters id to be included within the Bento. Note that if you are using relative path, '--build-ctx' must be passed.",
 )
 @click.option('--build-ctx', help='Build context. This is required if --adapter-id uses relative path', default=None)
-@click.option(
-  '--dockerfile-template',
-  default=None,
-  type=click.File(),
-  help='Optional custom dockerfile template to be used with this BentoLLM.',
-)
+@click.option('--dockerfile-template', default=None, type=click.File(), help='Optional custom dockerfile template to be used with this BentoLLM.')
 @cog.optgroup.group(cls=cog.MutuallyExclusiveOptionGroup, name='Utilities options')  # type: ignore[misc]
 @cog.optgroup.option(
   '--containerize',
@@ -788,13 +725,13 @@ def build_command(
   build_ctx: str | None,
   dockerfile_template: t.TextIO | None,
   max_model_len: int | None,
-  gpu_memory_utilization:float,
+  gpu_memory_utilization: float,
   containerize: bool,
   push: bool,
   force_push: bool,
   **_: t.Any,
 ) -> BuildBentoOutput:
-  '''Package a given models into a BentoLLM.
+  """Package a given model into a BentoLLM.

   \b
   ```bash
@@ -810,7 +747,7 @@ def build_command(
   > [!IMPORTANT]
   > To build the bento with compiled OpenLLM, make sure to prepend HATCH_BUILD_HOOKS_ENABLE=1. Make sure that the deployment
   > target also use the same Python version and architecture as build machine.
-  '''
+  """
   from openllm.serialisation.transformers.weights import has_safetensors_weights

   if model_id in openllm.CONFIG_MAPPING:
@@ -840,9 +777,7 @@ def build_command(
     dtype=dtype,
     max_model_len=max_model_len,
     gpu_memory_utilization=gpu_memory_utilization,
-    serialisation=first_not_none(
-      serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy'
-    ),
+    serialisation=first_not_none(serialisation, default='safetensors' if has_safetensors_weights(model_id, model_version) else 'legacy'),
     _eager=False,
   )
   if llm.__llm_backend__ not in llm.config['backend']:
@@ -854,9 +789,7 @@ def build_command(
     model = openllm.serialisation.import_model(llm, trust_remote_code=llm.trust_remote_code)
     llm._tag = model.tag

-  os.environ.update(
-    **process_environ(llm.config, llm.config['timeout'], 1.0, None, True, llm.model_id, None, llm._serialisation, llm)
-  )
+  os.environ.update(**process_environ(llm.config, llm.config['timeout'], 1.0, None, True, llm.model_id, None, llm._serialisation, llm))

   try:
     assert llm.bentomodel  # HACK: call it here to patch correct tag with revision and everything
@@ -923,11 +856,7 @@ def build_command(

   def get_current_bentocloud_context() -> str | None:
     try:
-      context = (
-        cloud_config.get_context(ctx.obj.cloud_context)
-        if ctx.obj.cloud_context
-        else cloud_config.get_current_context()
-      )
+      context = cloud_config.get_context(ctx.obj.cloud_context) if ctx.obj.cloud_context else cloud_config.get_current_context()
       return context.name
     except Exception:
       return None
@@ -951,9 +880,7 @@ def build_command(
     tag=str(bento_tag),
     backend=llm.__llm_backend__,
     instructions=[
-      DeploymentInstruction.from_content(
-        type='bentocloud', instr="☁️ Push to BentoCloud with 'bentoml push':\n $ {cmd}", cmd=push_cmd
-      ),
+      DeploymentInstruction.from_content(type='bentocloud', instr="☁️ Push to BentoCloud with 'bentoml push':\n $ {cmd}", cmd=push_cmd),
       DeploymentInstruction.from_content(
         type='container',
         instr="🐳 Container BentoLLM with 'bentoml containerize':\n $ {cmd}",
@@ -979,9 +906,7 @@ def build_command(
      termui.echo(f" * {instruction['content']}\n", nl=False)

   if push:
-    BentoMLContainer.bentocloud_client.get().push_bento(
-      bento, context=t.cast(GlobalOptions, ctx.obj).cloud_context, force=force_push
-    )
+    BentoMLContainer.bentocloud_client.get().push_bento(bento, context=t.cast(GlobalOptions, ctx.obj).cloud_context, 
force=force_push) elif containerize: container_backend = t.cast('DefaultBuilder', os.environ.get('BENTOML_CONTAINERIZE_BACKEND', 'docker')) try: @@ -1009,20 +934,19 @@ class ModelItem(t.TypedDict): @cli.command() @click.option('--show-available', is_flag=True, default=True, hidden=True) def models_command(**_: t.Any) -> dict[t.LiteralString, ModelItem]: - '''List all supported models. + """List all supported models. \b ```bash openllm models ``` - ''' + """ result: dict[t.LiteralString, ModelItem] = { m: ModelItem( architecture=config.__openllm_architecture__, example_id=random.choice(config.__openllm_model_ids__), supported_backends=config.__openllm_backend__, - installation='pip install ' - + (f'"openllm[{m}]"' if m in OPTIONAL_DEPENDENCIES or config.__openllm_requirements__ else 'openllm'), + installation='pip install ' + (f'"openllm[{m}]"' if m in OPTIONAL_DEPENDENCIES or config.__openllm_requirements__ else 'openllm'), items=[ str(md.tag) for md in bentoml.models.list() @@ -1041,13 +965,7 @@ def models_command(**_: t.Any) -> dict[t.LiteralString, ModelItem]: @cli.command() @model_name_argument(required=False) @click.option('-y', '--yes', '--assume-yes', is_flag=True, help='Skip confirmation when deleting a specific model') -@click.option( - '--include-bentos/--no-include-bentos', - is_flag=True, - hidden=True, - default=True, - help='Whether to also include pruning bentos.', -) +@click.option('--include-bentos/--no-include-bentos', is_flag=True, hidden=True, default=True, help='Whether to also include pruning bentos.') @inject @click.pass_context def prune_command( @@ -1058,38 +976,30 @@ def prune_command( bento_store: BentoStore = Provide[BentoMLContainer.bento_store], **_: t.Any, ) -> None: - '''Remove all saved models, and bentos built with OpenLLM locally. + """Remove all saved models, and bentos built with OpenLLM locally. \b If a model type is passed, then only prune models for that given model type. - ''' + """ available: list[tuple[bentoml.Model | bentoml.Bento, ModelStore | BentoStore]] = [ - (m, model_store) - for m in bentoml.models.list() - if 'framework' in m.info.labels and m.info.labels['framework'] == 'openllm' + (m, model_store) for m in bentoml.models.list() if 'framework' in m.info.labels and m.info.labels['framework'] == 'openllm' ] if model_name is not None: available = [ - (m, store) - for m, store in available - if 'model_name' in m.info.labels and m.info.labels['model_name'] == inflection.underscore(model_name) + (m, store) for m, store in available if 'model_name' in m.info.labels and m.info.labels['model_name'] == inflection.underscore(model_name) ] + [ (b, bento_store) for b in bentoml.bentos.list() if 'start_name' in b.info.labels and b.info.labels['start_name'] == inflection.underscore(model_name) ] else: - available += [ - (b, bento_store) for b in bentoml.bentos.list() if '_type' in b.info.labels and '_framework' in b.info.labels - ] + available += [(b, bento_store) for b in bentoml.bentos.list() if '_type' in b.info.labels and '_framework' in b.info.labels] for store_item, store in available: if yes: delete_confirmed = True else: - delete_confirmed = click.confirm( - f"delete {'model' if isinstance(store, ModelStore) else 'bento'} {store_item.tag}?" 
- ) + delete_confirmed = click.confirm(f"delete {'model' if isinstance(store, ModelStore) else 'bento'} {store_item.tag}?") if delete_confirmed: store.delete(store_item.tag) termui.warning(f"{store_item} deleted from {'model' if isinstance(store, ModelStore) else 'bento'} store.") @@ -1136,17 +1046,8 @@ def shared_client_options(f: _AnyCallable | None = None) -> t.Callable[[FC], FC] @cli.command() @shared_client_options -@click.option( - '--server-type', - type=click.Choice(['grpc', 'http']), - help='Server type', - default='http', - show_default=True, - hidden=True, -) -@click.option( - '--stream/--no-stream', type=click.BOOL, is_flag=True, default=True, help='Whether to stream the response.' -) +@click.option('--server-type', type=click.Choice(['grpc', 'http']), help='Server type', default='http', show_default=True, hidden=True) +@click.option('--stream/--no-stream', type=click.BOOL, is_flag=True, default=True, help='Whether to stream the response.') @click.argument('prompt', type=click.STRING) @click.option( '--sampling-params', @@ -1168,14 +1069,15 @@ def query_command( _memoized: DictStrAny, **_: t.Any, ) -> None: - '''Query a LLM interactively, from a terminal. + """Query a LLM interactively, from a terminal. \b ```bash $ openllm query --endpoint http://12.323.2.1:3000 "What is the meaning of life?" ``` - ''' - if server_type == 'grpc': raise click.ClickException("'grpc' is currently disabled.") + """ + if server_type == 'grpc': + raise click.ClickException("'grpc' is currently disabled.") _memoized = {k: orjson.loads(v[0]) for k, v in _memoized.items() if v} # TODO: grpc support client = openllm.HTTPClient(address=endpoint, timeout=timeout) @@ -1194,7 +1096,7 @@ def query_command( @cli.group(cls=Extensions, hidden=True, name='extension') def extension_command() -> None: - '''Extension for OpenLLM CLI.''' + """Extension for OpenLLM CLI.""" if __name__ == '__main__': diff --git a/openllm-python/src/openllm_cli/extension/dive_bentos.py b/openllm-python/src/openllm_cli/extension/dive_bentos.py index 2fb32301..db488004 100644 --- a/openllm-python/src/openllm_cli/extension/dive_bentos.py +++ b/openllm-python/src/openllm_cli/extension/dive_bentos.py @@ -21,10 +21,8 @@ if t.TYPE_CHECKING: @machine_option @click.pass_context @inject -def cli( - ctx: click.Context, bento: str, machine: bool, _bento_store: BentoStore = Provide[BentoMLContainer.bento_store] -) -> str | None: - '''Dive into a BentoLLM. This is synonymous to cd $(b get : -o path).''' +def cli(ctx: click.Context, bento: str, machine: bool, _bento_store: BentoStore = Provide[BentoMLContainer.bento_store]) -> str | None: + """Dive into a BentoLLM. This is synonymous to cd $(b get : -o path).""" try: bentomodel = _bento_store.get(bento) except bentoml.exceptions.NotFound: diff --git a/openllm-python/src/openllm_cli/extension/get_containerfile.py b/openllm-python/src/openllm_cli/extension/get_containerfile.py index 50798829..88605414 100644 --- a/openllm-python/src/openllm_cli/extension/get_containerfile.py +++ b/openllm-python/src/openllm_cli/extension/get_containerfile.py @@ -17,9 +17,7 @@ if t.TYPE_CHECKING: from bentoml._internal.bento import BentoStore -@click.command( - 'get_containerfile', context_settings=termui.CONTEXT_SETTINGS, help='Return Containerfile of any given Bento.' 
-) +@click.command('get_containerfile', context_settings=termui.CONTEXT_SETTINGS, help='Return Containerfile of any given Bento.') @click.argument('bento', type=str, shell_complete=bento_complete_envvar) @click.pass_context @inject diff --git a/openllm-python/src/openllm_cli/extension/get_prompt.py b/openllm-python/src/openllm_cli/extension/get_prompt.py index 0e64c230..b679577f 100644 --- a/openllm-python/src/openllm_cli/extension/get_prompt.py +++ b/openllm-python/src/openllm_cli/extension/get_prompt.py @@ -22,9 +22,7 @@ class PromptFormatter(string.Formatter): raise ValueError('Positional arguments are not supported') return super().vformat(format_string, args, kwargs) - def check_unused_args( - self, used_args: set[int | str], args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any] - ) -> None: + def check_unused_args(self, used_args: set[int | str], args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]) -> None: extras = set(kwargs).difference(used_args) if extras: raise KeyError(f'Extra params passed: {extras}') @@ -58,9 +56,7 @@ class PromptTemplate: try: return self.template.format(**prompt_variables) except KeyError as e: - raise RuntimeError( - f"Missing variable '{e.args[0]}' (required: {self._input_variables}) in the prompt template." - ) from None + raise RuntimeError(f"Missing variable '{e.args[0]}' (required: {self._input_variables}) in the prompt template.") from None @click.command('get_prompt', context_settings=termui.CONTEXT_SETTINGS) @@ -128,21 +124,15 @@ def cli( if prompt_template_file and chat_template_file: ctx.fail('prompt-template-file and chat-template-file are mutually exclusive.') - acceptable = set(openllm.CONFIG_MAPPING_NAMES.keys()) | set( - inflection.dasherize(name) for name in openllm.CONFIG_MAPPING_NAMES.keys() - ) + acceptable = set(openllm.CONFIG_MAPPING_NAMES.keys()) | set(inflection.dasherize(name) for name in openllm.CONFIG_MAPPING_NAMES.keys()) if model_id in acceptable: - logger.warning( - 'Using a default prompt from OpenLLM. Note that this prompt might not work for your intended usage.\n' - ) + logger.warning('Using a default prompt from OpenLLM. Note that this prompt might not work for your intended usage.\n') config = openllm.AutoConfig.for_model(model_id) template = prompt_template_file.read() if prompt_template_file is not None else config.template system_message = system_message or config.system_message try: - formatted = ( - PromptTemplate(template).with_options(system_message=system_message).format(instruction=prompt, **_memoized) - ) + formatted = PromptTemplate(template).with_options(system_message=system_message).format(instruction=prompt, **_memoized) except RuntimeError as err: logger.debug('Exception caught while formatting prompt: %s', err) ctx.fail(str(err)) @@ -159,21 +149,15 @@ def cli( for architecture in config.architectures: if architecture in openllm.AutoConfig._CONFIG_MAPPING_NAMES_TO_ARCHITECTURE(): system_message = ( - openllm.AutoConfig.infer_class_from_name( - openllm.AutoConfig._CONFIG_MAPPING_NAMES_TO_ARCHITECTURE()[architecture] - ) + openllm.AutoConfig.infer_class_from_name(openllm.AutoConfig._CONFIG_MAPPING_NAMES_TO_ARCHITECTURE()[architecture]) .model_construct_env() .system_message ) break else: - ctx.fail( - f'Failed to infer system message from model architecture: {config.architectures}. Please pass in --system-message' - ) + ctx.fail(f'Failed to infer system message from model architecture: {config.architectures}. 
Please pass in --system-message')
       messages = [{'role': 'system', 'content': system_message}, {'role': 'user', 'content': prompt}]
-      formatted = tokenizer.apply_chat_template(
-        messages, chat_template=chat_template_file, add_generation_prompt=add_generation_prompt, tokenize=False
-      )
+      formatted = tokenizer.apply_chat_template(messages, chat_template=chat_template_file, add_generation_prompt=add_generation_prompt, tokenize=False)
   termui.echo(orjson.dumps({'prompt': formatted}, option=orjson.OPT_INDENT_2).decode(), fg='white')
   ctx.exit(0)
diff --git a/openllm-python/src/openllm_cli/extension/list_bentos.py b/openllm-python/src/openllm_cli/extension/list_bentos.py
index 34355b53..f189938d 100644
--- a/openllm-python/src/openllm_cli/extension/list_bentos.py
+++ b/openllm-python/src/openllm_cli/extension/list_bentos.py
@@ -13,7 +13,7 @@ from openllm_cli import termui
 @click.command('list_bentos', context_settings=termui.CONTEXT_SETTINGS)
 @click.pass_context
 def cli(ctx: click.Context) -> None:
-  '''List available bentos built by OpenLLM.'''
+  """List available bentos built by OpenLLM."""
   mapping = {
     k: [
       {
diff --git a/openllm-python/src/openllm_cli/extension/list_models.py b/openllm-python/src/openllm_cli/extension/list_models.py
index 61e4d26a..b6419ed9 100644
--- a/openllm-python/src/openllm_cli/extension/list_models.py
+++ b/openllm-python/src/openllm_cli/extension/list_models.py
@@ -18,7 +18,7 @@ if t.TYPE_CHECKING:
 @click.command('list_models', context_settings=termui.CONTEXT_SETTINGS)
 @model_name_argument(required=False, shell_complete=model_complete_envvar)
 def cli(model_name: str | None) -> DictStrAny:
-  '''List available models in lcoal store to be used wit OpenLLM.'''
+  """List available models in the local store to be used with OpenLLM."""
   models = tuple(inflection.dasherize(key) for key in openllm.CONFIG_MAPPING.keys())
   ids_in_local_store = {
     k: [
@@ -33,17 +33,12 @@ def cli(model_name: str | None) -> DictStrAny:
   }
   if model_name is not None:
     ids_in_local_store = {
-      k: [
-        i
-        for i in v
-        if 'model_name' in i.info.labels and i.info.labels['model_name'] == inflection.dasherize(model_name)
-      ]
+      k: [i for i in v if 'model_name' in i.info.labels and i.info.labels['model_name'] == inflection.dasherize(model_name)]
       for k, v in ids_in_local_store.items()
     }
   ids_in_local_store = {k: v for k, v in ids_in_local_store.items() if v}
   local_models = {
-    k: [{'tag': str(i.tag), 'size': human_readable_size(openllm.utils.calc_dir_size(i.path))} for i in val]
-    for k, val in ids_in_local_store.items()
+    k: [{'tag': str(i.tag), 'size': human_readable_size(openllm.utils.calc_dir_size(i.path))} for i in val] for k, val in ids_in_local_store.items()
   }
   termui.echo(orjson.dumps(local_models, option=orjson.OPT_INDENT_2).decode(), fg='white')
   return local_models
diff --git a/openllm-python/src/openllm_cli/extension/playground.py b/openllm-python/src/openllm_cli/extension/playground.py
index 40ed9831..f8e5b4da 100644
--- a/openllm-python/src/openllm_cli/extension/playground.py
+++ b/openllm-python/src/openllm_cli/extension/playground.py
@@ -32,14 +32,7 @@ def load_notebook_metadata() -> DictStrAny:

 @click.command('playground', context_settings=termui.CONTEXT_SETTINGS)
 @click.argument('output-dir', default=None, required=False)
-@click.option(
-  '--port',
-  envvar='JUPYTER_PORT',
-  show_envvar=True,
-  show_default=True,
-  default=8888,
-  help='Default port for Jupyter server',
-)
+@click.option('--port', envvar='JUPYTER_PORT', show_envvar=True, show_default=True, default=8888, help='Default port for Jupyter 
server') @click.pass_context def cli(ctx: click.Context, output_dir: str | None, port: int) -> None: """OpenLLM Playground. @@ -60,9 +53,7 @@ def cli(ctx: click.Context, output_dir: str | None, port: int) -> None: > This command requires Jupyter to be installed. Install it with 'pip install "openllm[playground]"' """ if not is_jupyter_available() or not is_jupytext_available() or not is_notebook_available(): - raise RuntimeError( - "Playground requires 'jupyter', 'jupytext', and 'notebook'. Install it with 'pip install \"openllm[playground]\"'" - ) + raise RuntimeError("Playground requires 'jupyter', 'jupytext', and 'notebook'. Install it with 'pip install \"openllm[playground]\"'") metadata = load_notebook_metadata() _temp_dir = False if output_dir is None: @@ -74,9 +65,7 @@ def cli(ctx: click.Context, output_dir: str | None, port: int) -> None: termui.echo('The playground notebooks will be saved to: ' + os.path.abspath(output_dir), fg='blue') for module in pkgutil.iter_modules(playground.__path__): if module.ispkg or os.path.exists(os.path.join(output_dir, module.name + '.ipynb')): - logger.debug( - 'Skipping: %s (%s)', module.name, 'File already exists' if not module.ispkg else f'{module.name} is a module' - ) + logger.debug('Skipping: %s (%s)', module.name, 'File already exists' if not module.ispkg else f'{module.name} is a module') continue if not isinstance(module.module_finder, importlib.machinery.FileFinder): continue @@ -86,20 +75,18 @@ def cli(ctx: click.Context, output_dir: str | None, port: int) -> None: f.cells.insert(0, markdown_cell) jupytext.write(f, os.path.join(output_dir, module.name + '.ipynb'), fmt='notebook') try: - subprocess.check_output( - [ - sys.executable, - '-m', - 'jupyter', - 'notebook', - '--notebook-dir', - output_dir, - '--port', - str(port), - '--no-browser', - '--debug', - ] - ) + subprocess.check_output([ + sys.executable, + '-m', + 'jupyter', + 'notebook', + '--notebook-dir', + output_dir, + '--port', + str(port), + '--no-browser', + '--debug', + ]) except subprocess.CalledProcessError as e: termui.echo(e.output, fg='red') raise click.ClickException(f'Failed to start a jupyter server:\n{e}') from None diff --git a/openllm-python/src/openllm_cli/termui.py b/openllm-python/src/openllm_cli/termui.py index 24755d29..c442d3be 100644 --- a/openllm-python/src/openllm_cli/termui.py +++ b/openllm-python/src/openllm_cli/termui.py @@ -25,14 +25,7 @@ class Level(enum.IntEnum): @property def color(self) -> str | None: - return { - Level.NOTSET: None, - Level.DEBUG: 'cyan', - Level.INFO: 'green', - Level.WARNING: 'yellow', - Level.ERROR: 'red', - Level.CRITICAL: 'red', - }[self] + return {Level.NOTSET: None, Level.DEBUG: 'cyan', Level.INFO: 'green', Level.WARNING: 'yellow', Level.ERROR: 'red', Level.CRITICAL: 'red'}[self] @classmethod def from_logging_level(cls, level: int) -> Level: @@ -82,9 +75,5 @@ def echo(text: t.Any, fg: str | None = None, *, _with_style: bool = True, json: COLUMNS: int = int(os.environ.get('COLUMNS', str(120))) -CONTEXT_SETTINGS: DictStrAny = { - 'help_option_names': ['-h', '--help'], - 'max_content_width': COLUMNS, - 'token_normalize_func': inflection.underscore, -} +CONTEXT_SETTINGS: DictStrAny = {'help_option_names': ['-h', '--help'], 'max_content_width': COLUMNS, 'token_normalize_func': inflection.underscore} __all__ = ['echo', 'COLUMNS', 'CONTEXT_SETTINGS', 'log', 'warning', 'error', 'critical', 'debug', 'info', 'Level'] diff --git a/openllm-python/tests/_strategies/_configuration.py b/openllm-python/tests/_strategies/_configuration.py 
index 04bfb236..25e12d47 100644 --- a/openllm-python/tests/_strategies/_configuration.py +++ b/openllm-python/tests/_strategies/_configuration.py @@ -12,7 +12,7 @@ logger = logging.getLogger(__name__) @st.composite def model_settings(draw: st.DrawFn): - '''Strategy for generating ModelSettings objects.''' + """Strategy for generating ModelSettings objects.""" kwargs: dict[str, t.Any] = { 'default_id': st.text(min_size=1), 'model_ids': st.lists(st.text(), min_size=1), diff --git a/openllm-python/tests/configuration_test.py b/openllm-python/tests/configuration_test.py index fafa4098..90069b79 100644 --- a/openllm-python/tests/configuration_test.py +++ b/openllm-python/tests/configuration_test.py @@ -66,15 +66,8 @@ def test_config_derived_follow_attrs_protocol(gen_settings: ModelSettings): st.integers(max_value=283473), st.floats(min_value=0.0, max_value=1.0), ) -def test_complex_struct_dump( - gen_settings: ModelSettings, field1: int, temperature: float, input_field1: int, input_temperature: float -): - cl_ = make_llm_config( - 'ComplexLLM', - gen_settings, - fields=(('field1', 'float', field1),), - generation_fields=(('temperature', temperature),), - ) +def test_complex_struct_dump(gen_settings: ModelSettings, field1: int, temperature: float, input_field1: int, input_temperature: float): + cl_ = make_llm_config('ComplexLLM', gen_settings, fields=(('field1', 'float', field1),), generation_fields=(('temperature', temperature),)) sent = cl_() assert sent.model_dump()['field1'] == field1 assert sent.model_dump()['generation_config']['temperature'] == temperature diff --git a/openllm-python/tests/conftest.py b/openllm-python/tests/conftest.py index e49b2656..1efd9e4d 100644 --- a/openllm-python/tests/conftest.py +++ b/openllm-python/tests/conftest.py @@ -10,14 +10,8 @@ import openllm if t.TYPE_CHECKING: from openllm_core._typing_compat import LiteralBackend -_MODELING_MAPPING = { - 'flan_t5': 'google/flan-t5-small', - 'opt': 'facebook/opt-125m', - 'baichuan': 'baichuan-inc/Baichuan-7B', -} -_PROMPT_MAPPING = { - 'qa': 'Answer the following yes/no question by reasoning step-by-step. Can you write a whole Haiku in a single tweet?' -} +_MODELING_MAPPING = {'flan_t5': 'google/flan-t5-small', 'opt': 'facebook/opt-125m', 'baichuan': 'baichuan-inc/Baichuan-7B'} +_PROMPT_MAPPING = {'qa': 'Answer the following yes/no question by reasoning step-by-step. 
Can you write a whole Haiku in a single tweet?'} def parametrise_local_llm(model: str) -> t.Generator[tuple[str, openllm.LLM[t.Any, t.Any]], None, None]: @@ -31,9 +25,7 @@ def parametrise_local_llm(model: str) -> t.Generator[tuple[str, openllm.LLM[t.An def pytest_generate_tests(metafunc: pytest.Metafunc) -> None: if os.getenv('GITHUB_ACTIONS') is None: if 'prompt' in metafunc.fixturenames and 'llm' in metafunc.fixturenames: - metafunc.parametrize( - 'prompt,llm', [(p, llm) for p, llm in parametrise_local_llm(metafunc.function.__name__[5:-15])] - ) + metafunc.parametrize('prompt,llm', [(p, llm) for p, llm in parametrise_local_llm(metafunc.function.__name__[5:-15])]) def pytest_sessionfinish(session: pytest.Session, exitstatus: int): diff --git a/openllm-python/tests/strategies_test.py b/openllm-python/tests/strategies_test.py index 6b95ac0d..f801ed81 100644 --- a/openllm-python/tests/strategies_test.py +++ b/openllm-python/tests/strategies_test.py @@ -73,13 +73,9 @@ def test_nvidia_gpu_validate(monkeypatch: pytest.MonkeyPatch): mcls.setenv('CUDA_VISIBLE_DEVICES', '') assert len(NvidiaGpuResource.from_system()) >= 0 # TODO: real from_system tests - assert pytest.raises(ValueError, NvidiaGpuResource.validate, [*NvidiaGpuResource.from_system(), 1]).match( - 'Input list should be all string type.' - ) + assert pytest.raises(ValueError, NvidiaGpuResource.validate, [*NvidiaGpuResource.from_system(), 1]).match('Input list should be all string type.') assert pytest.raises(ValueError, NvidiaGpuResource.validate, [-2]).match('Input list should be all string type.') - assert pytest.raises(ValueError, NvidiaGpuResource.validate, ['GPU-5ebe9f43', 'GPU-ac33420d4628']).match( - 'Failed to parse available GPUs UUID' - ) + assert pytest.raises(ValueError, NvidiaGpuResource.validate, ['GPU-5ebe9f43', 'GPU-ac33420d4628']).match('Failed to parse available GPUs UUID') def test_nvidia_gpu_from_spec(monkeypatch: pytest.MonkeyPatch):