Mirror of https://github.com/bentoml/OpenLLM.git (synced 2026-02-18 14:47:30 -05:00)
chore(style): synchronized style across packages [skip ci]
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
cz.py: 26 changed lines
@@ -2,27 +2,25 @@
from __future__ import annotations
import itertools, os, token, tokenize
from tabulate import tabulate

TOKEN_WHITELIST = [token.OP, token.NAME, token.NUMBER, token.STRING]
def run_cz(dir: str, package: str):
-  headers = ["Name", "Lines", "Tokens/Line"]
+  headers = ['Name', 'Lines', 'Tokens/Line']
  table = []
-  for path, _, files in os.walk(os.path.join(dir, "src", package)):
+  for path, _, files in os.walk(os.path.join(dir, 'src', package)):
    for name in files:
-      if not name.endswith(".py"): continue
+      if not name.endswith('.py'): continue
      filepath = os.path.join(path, name)
      with tokenize.open(filepath) as file_:
        tokens = [t for t in tokenize.generate_tokens(file_.readline) if t.type in TOKEN_WHITELIST]
        token_count, line_count = len(tokens), len(set([t.start[0] for t in tokens]))
-        table.append([filepath.replace(os.path.join(dir ,"src"), ""), line_count, token_count / line_count if line_count != 0 else 0])
-  print(tabulate([headers, *sorted(table, key=lambda x: -x[1])], headers="firstrow", floatfmt=".1f") + "\n")
-  for dir_name, group in itertools.groupby(sorted([(x[0].rsplit("/", 1)[0], x[1]) for x in table]), key=lambda x: x[0]):
-    print(f"{dir_name:35s} : {sum([x[1] for x in group]):6d}")
-  print(f"\ntotal line count: {sum([x[1] for x in table])}")
+        table.append([filepath.replace(os.path.join(dir, 'src'), ''), line_count, token_count / line_count if line_count != 0 else 0])
+  print(tabulate([headers, *sorted(table, key=lambda x: -x[1])], headers='firstrow', floatfmt='.1f') + '\n')
+  for dir_name, group in itertools.groupby(sorted([(x[0].rsplit('/', 1)[0], x[1]) for x in table]), key=lambda x: x[0]):
+    print(f'{dir_name:35s} : {sum([x[1] for x in group]):6d}')
+  print(f'\ntotal line count: {sum([x[1] for x in table])}')
def main() -> int:
-  run_cz("openllm-python", "openllm")
-  run_cz("openllm-core", "openllm_core")
-  run_cz("openllm-client", "openllm_client")
+  run_cz('openllm-python', 'openllm')
+  run_cz('openllm-core', 'openllm_core')
+  run_cz('openllm-client', 'openllm_client')
  return 0

-if __name__ == "__main__": raise SystemExit(main())
+if __name__ == '__main__': raise SystemExit(main())

@@ -14,9 +14,9 @@ logger = logging.getLogger(__name__)
|
||||
class _ClientAttr:
|
||||
_address: str
|
||||
_timeout: float = attr.field(default=30)
|
||||
_api_version: str = attr.field(default="v1")
|
||||
_api_version: str = attr.field(default='v1')
|
||||
|
||||
def __init__(self, address: str, timeout: float = 30, api_version: str = "v1"):
|
||||
def __init__(self, address: str, timeout: float = 30, api_version: str = 'v1'):
|
||||
self.__attrs_init__(address, timeout, api_version)
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -29,37 +29,37 @@ class _ClientAttr:
|
||||
|
||||
@overload
|
||||
@abc.abstractmethod
|
||||
def query(self, prompt: str, *, return_response: t.Literal["processed"], **attrs: t.Any) -> str:
|
||||
def query(self, prompt: str, *, return_response: t.Literal['processed'], **attrs: t.Any) -> str:
|
||||
...
|
||||
|
||||
@overload
|
||||
@abc.abstractmethod
|
||||
def query(self, prompt: str, *, return_response: t.Literal["raw"], **attrs: t.Any) -> DictStrAny:
|
||||
def query(self, prompt: str, *, return_response: t.Literal['raw'], **attrs: t.Any) -> DictStrAny:
|
||||
...
|
||||
|
||||
@overload
|
||||
@abc.abstractmethod
|
||||
def query(self, prompt: str, *, return_response: t.Literal["attrs"], **attrs: t.Any) -> openllm_core.GenerationOutput:
|
||||
def query(self, prompt: str, *, return_response: t.Literal['attrs'], **attrs: t.Any) -> openllm_core.GenerationOutput:
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def query(self, prompt: str, return_response: t.Literal["attrs", "raw", "processed"] = "processed", **attrs: t.Any) -> t.Any:
|
||||
def query(self, prompt: str, return_response: t.Literal['attrs', 'raw', 'processed'] = 'processed', **attrs: t.Any) -> t.Any:
|
||||
raise NotImplementedError
|
||||
|
||||
# NOTE: Scikit interface
|
||||
@overload
|
||||
@abc.abstractmethod
|
||||
def predict(self, prompt: str, *, return_response: t.Literal["processed"], **attrs: t.Any) -> str:
|
||||
def predict(self, prompt: str, *, return_response: t.Literal['processed'], **attrs: t.Any) -> str:
|
||||
...
|
||||
|
||||
@overload
|
||||
@abc.abstractmethod
|
||||
def predict(self, prompt: str, *, return_response: t.Literal["raw"], **attrs: t.Any) -> DictStrAny:
|
||||
def predict(self, prompt: str, *, return_response: t.Literal['raw'], **attrs: t.Any) -> DictStrAny:
|
||||
...
|
||||
|
||||
@overload
|
||||
@abc.abstractmethod
|
||||
def predict(self, prompt: str, *, return_response: t.Literal["attrs"], **attrs: t.Any) -> openllm_core.GenerationOutput:
|
||||
def predict(self, prompt: str, *, return_response: t.Literal['attrs'], **attrs: t.Any) -> openllm_core.GenerationOutput:
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -69,63 +69,63 @@ class _ClientAttr:
|
||||
@functools.cached_property
|
||||
def _hf_agent(self) -> transformers.HfAgent:
|
||||
if not is_transformers_available(): raise RuntimeError("transformers is required to use HF agent. Install with 'pip install \"openllm-client[agents]\"'.")
|
||||
if not self.supports_hf_agent: raise RuntimeError(f"{self.model_name} ({self.framework}) does not support running HF agent.")
|
||||
if not self.supports_hf_agent: raise RuntimeError(f'{self.model_name} ({self.framework}) does not support running HF agent.')
|
||||
if not is_transformers_supports_agent(): raise RuntimeError("Current 'transformers' does not support Agent. Make sure to upgrade to at least 4.29: 'pip install -U \"transformers>=4.29\"'")
|
||||
import transformers
|
||||
return transformers.HfAgent(urljoin(self._address, "/hf/agent"))
|
||||
return transformers.HfAgent(urljoin(self._address, '/hf/agent'))
|
||||
|
||||
@property
|
||||
def _metadata(self) -> t.Any:
|
||||
return self.call("metadata")
|
||||
return self.call('metadata')
|
||||
|
||||
@property
|
||||
def model_name(self) -> str:
|
||||
try:
|
||||
return self._metadata["model_name"]
|
||||
return self._metadata['model_name']
|
||||
except KeyError:
|
||||
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
raise RuntimeError('Malformed service endpoint. (Possible malicious)') from None
|
||||
|
||||
@property
|
||||
def model_id(self) -> str:
|
||||
try:
|
||||
return self._metadata["model_id"]
|
||||
return self._metadata['model_id']
|
||||
except KeyError:
|
||||
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
raise RuntimeError('Malformed service endpoint. (Possible malicious)') from None
|
||||
|
||||
@property
|
||||
def framework(self) -> LiteralRuntime:
|
||||
try:
|
||||
return self._metadata["framework"]
|
||||
return self._metadata['framework']
|
||||
except KeyError:
|
||||
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
raise RuntimeError('Malformed service endpoint. (Possible malicious)') from None
|
||||
|
||||
@property
|
||||
def timeout(self) -> int:
|
||||
try:
|
||||
return self._metadata["timeout"]
|
||||
return self._metadata['timeout']
|
||||
except KeyError:
|
||||
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
raise RuntimeError('Malformed service endpoint. (Possible malicious)') from None
|
||||
|
||||
@property
|
||||
def configuration(self) -> dict[str, t.Any]:
|
||||
try:
|
||||
return orjson.loads(self._metadata["configuration"])
|
||||
return orjson.loads(self._metadata['configuration'])
|
||||
except KeyError:
|
||||
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
raise RuntimeError('Malformed service endpoint. (Possible malicious)') from None
|
||||
|
||||
@property
|
||||
def supports_embeddings(self) -> bool:
|
||||
try:
|
||||
return self._metadata.get("supports_embeddings", False)
|
||||
return self._metadata.get('supports_embeddings', False)
|
||||
except KeyError:
|
||||
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
raise RuntimeError('Malformed service endpoint. (Possible malicious)') from None
|
||||
|
||||
@property
|
||||
def supports_hf_agent(self) -> bool:
|
||||
try:
|
||||
return self._metadata.get("supports_hf_agent", False)
|
||||
return self._metadata.get('supports_hf_agent', False)
|
||||
except KeyError:
|
||||
raise RuntimeError("Malformed service endpoint. (Possible malicious)") from None
|
||||
raise RuntimeError('Malformed service endpoint. (Possible malicious)') from None
|
||||
|
||||
@property
|
||||
def config(self) -> openllm_core.LLMConfig:
|
||||
@@ -139,7 +139,7 @@ class _Client(_ClientAttr):
|
||||
_port: str
|
||||
|
||||
def call(self, api_name: str, *args: t.Any, **attrs: t.Any) -> t.Any:
|
||||
return self.inner.call(f"{api_name}_{self._api_version}", *args, **attrs)
|
||||
return self.inner.call(f'{api_name}_{self._api_version}', *args, **attrs)
|
||||
|
||||
def health(self) -> t.Any:
|
||||
return self.inner.health()
|
||||
@@ -150,19 +150,19 @@ class _Client(_ClientAttr):
|
||||
return BentoClient.from_url(self._address)
|
||||
|
||||
# Agent integration
|
||||
def ask_agent(self, task: str, *, return_code: bool = False, remote: bool = False, agent_type: LiteralString = "hf", **attrs: t.Any) -> t.Any:
|
||||
if agent_type == "hf": return self._run_hf_agent(task, return_code=return_code, remote=remote, **attrs)
|
||||
def ask_agent(self, task: str, *, return_code: bool = False, remote: bool = False, agent_type: LiteralString = 'hf', **attrs: t.Any) -> t.Any:
|
||||
if agent_type == 'hf': return self._run_hf_agent(task, return_code=return_code, remote=remote, **attrs)
|
||||
else: raise RuntimeError(f"Unknown 'agent_type={agent_type}'")
|
||||
|
||||
def _run_hf_agent(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
|
||||
if len(args) > 1: raise ValueError("'args' should only take one positional argument.")
|
||||
task = kwargs.pop("task", args[0])
|
||||
return_code = kwargs.pop("return_code", False)
|
||||
remote = kwargs.pop("remote", False)
|
||||
task = kwargs.pop('task', args[0])
|
||||
return_code = kwargs.pop('return_code', False)
|
||||
remote = kwargs.pop('remote', False)
|
||||
try:
|
||||
return self._hf_agent.run(task, return_code=return_code, remote=remote, **kwargs)
|
||||
except Exception as err:
|
||||
logger.error("Exception caught while sending instruction to HF agent: %s", err, exc_info=err)
|
||||
logger.error('Exception caught while sending instruction to HF agent: %s', err, exc_info=err)
|
||||
logger.info("Tip: LLMServer at '%s' might not support 'generate_one'.", self._address)
|
||||
class _AsyncClient(_ClientAttr):
|
||||
_host: str
|
||||
@@ -172,7 +172,7 @@ class _AsyncClient(_ClientAttr):
|
||||
self._address, self._timeout = address, timeout
|
||||
|
||||
async def call(self, api_name: str, *args: t.Any, **attrs: t.Any) -> t.Any:
|
||||
return await self.inner.call(f"{api_name}_{self._api_version}", *args, **attrs)
|
||||
return await self.inner.call(f'{api_name}_{self._api_version}', *args, **attrs)
|
||||
|
||||
async def health(self) -> t.Any:
|
||||
return await self.inner.health()
|
||||
@@ -183,27 +183,26 @@ class _AsyncClient(_ClientAttr):
|
||||
return ensure_exec_coro(AsyncBentoClient.from_url(self._address))
|
||||
|
||||
# Agent integration
|
||||
async def ask_agent(self, task: str, *, return_code: bool = False, remote: bool = False, agent_type: LiteralString = "hf", **attrs: t.Any) -> t.Any:
|
||||
"""Async version of agent.run."""
|
||||
if agent_type == "hf": return await self._run_hf_agent(task, return_code=return_code, remote=remote, **attrs)
|
||||
async def ask_agent(self, task: str, *, return_code: bool = False, remote: bool = False, agent_type: LiteralString = 'hf', **attrs: t.Any) -> t.Any:
|
||||
if agent_type == 'hf': return await self._run_hf_agent(task, return_code=return_code, remote=remote, **attrs)
|
||||
else: raise RuntimeError(f"Unknown 'agent_type={agent_type}'")
|
||||
|
||||
async def _run_hf_agent(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
|
||||
if not is_transformers_supports_agent(): raise RuntimeError("This version of transformers does not support agent.run. Make sure to upgrade to transformers>4.30.0")
|
||||
if not is_transformers_supports_agent(): raise RuntimeError('This version of transformers does not support agent.run. Make sure to upgrade to transformers>4.30.0')
|
||||
if len(args) > 1: raise ValueError("'args' should only take one positional argument.")
|
||||
from transformers.tools.agents import clean_code_for_run, get_tool_creation_code, resolve_tools
|
||||
from transformers.tools.python_interpreter import evaluate
|
||||
|
||||
task = kwargs.pop("task", args[0])
|
||||
return_code = kwargs.pop("return_code", False)
|
||||
remote = kwargs.pop("remote", False)
|
||||
stop = ["Task:"]
|
||||
task = kwargs.pop('task', args[0])
|
||||
return_code = kwargs.pop('return_code', False)
|
||||
remote = kwargs.pop('remote', False)
|
||||
stop = ['Task:']
|
||||
prompt = t.cast(str, self._hf_agent.format_prompt(task))
|
||||
async with httpx.AsyncClient(timeout=httpx.Timeout(self.timeout)) as client:
|
||||
response = await client.post(self._hf_agent.url_endpoint, json={"inputs": prompt, "parameters": {"max_new_tokens": 200, "return_full_text": False, "stop": stop}})
|
||||
if response.status_code != HTTPStatus.OK: raise ValueError(f"Error {response.status_code}: {response.json()}")
|
||||
response = await client.post(self._hf_agent.url_endpoint, json={'inputs': prompt, 'parameters': {'max_new_tokens': 200, 'return_full_text': False, 'stop': stop}})
|
||||
if response.status_code != HTTPStatus.OK: raise ValueError(f'Error {response.status_code}: {response.json()}')
|
||||
|
||||
result = response.json()[0]["generated_text"]
|
||||
result = response.json()[0]['generated_text']
|
||||
# Inference API returns the stop sequence
|
||||
for stop_seq in stop:
|
||||
if result.endswith(stop_seq):
|
||||
@@ -211,62 +210,62 @@ class _AsyncClient(_ClientAttr):
|
||||
break
|
||||
# the below have the same logic as agent.run API
|
||||
explanation, code = clean_code_for_run(result)
|
||||
self._hf_agent.log(f"==Explanation from the agent==\n{explanation}")
|
||||
self._hf_agent.log(f"\n\n==Code generated by the agent==\n{code}")
|
||||
self._hf_agent.log(f'==Explanation from the agent==\n{explanation}')
|
||||
self._hf_agent.log(f'\n\n==Code generated by the agent==\n{code}')
|
||||
if not return_code:
|
||||
self._hf_agent.log("\n\n==Result==")
|
||||
self._hf_agent.log('\n\n==Result==')
|
||||
self._hf_agent.cached_tools = resolve_tools(code, self._hf_agent.toolbox, remote=remote, cached_tools=self._hf_agent.cached_tools)
|
||||
return evaluate(code, self._hf_agent.cached_tools, state=kwargs.copy())
|
||||
else:
|
||||
tool_code = get_tool_creation_code(code, self._hf_agent.toolbox, remote=remote)
|
||||
return f"{tool_code}\n{code}"
|
||||
return f'{tool_code}\n{code}'
|
||||
class BaseClient(_Client):
|
||||
def chat(self, prompt: str, history: list[str], **attrs: t.Any) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def embed(self, prompt: t.Sequence[str] | str) -> openllm_core.EmbeddingsOutput:
|
||||
return openllm_core.EmbeddingsOutput(**self.call("embeddings", list([prompt] if isinstance(prompt, str) else prompt)))
|
||||
return openllm_core.EmbeddingsOutput(**self.call('embeddings', list([prompt] if isinstance(prompt, str) else prompt)))
|
||||
|
||||
def predict(self, prompt: str, **attrs: t.Any) -> openllm_core.GenerationOutput | DictStrAny | str:
|
||||
return self.query(prompt, **attrs)
|
||||
|
||||
def query(self, prompt: str, return_response: t.Literal["attrs", "raw", "processed"] = "processed", **attrs: t.Any) -> t.Any:
|
||||
return_raw_response = attrs.pop("return_raw_response", None)
|
||||
def query(self, prompt: str, return_response: t.Literal['attrs', 'raw', 'processed'] = 'processed', **attrs: t.Any) -> t.Any:
|
||||
return_raw_response = attrs.pop('return_raw_response', None)
|
||||
if return_raw_response is not None:
|
||||
logger.warning("'return_raw_response' is now deprecated. Please use 'return_response=\"raw\"' instead.")
|
||||
if return_raw_response is True: return_response = "raw"
|
||||
return_attrs = attrs.pop("return_attrs", None)
|
||||
if return_raw_response is True: return_response = 'raw'
|
||||
return_attrs = attrs.pop('return_attrs', None)
|
||||
if return_attrs is not None:
|
||||
logger.warning("'return_attrs' is now deprecated. Please use 'return_response=\"attrs\"' instead.")
|
||||
if return_attrs is True: return_response = "attrs"
|
||||
use_default_prompt_template = attrs.pop("use_default_prompt_template", False)
|
||||
if return_attrs is True: return_response = 'attrs'
|
||||
use_default_prompt_template = attrs.pop('use_default_prompt_template', False)
|
||||
prompt, generate_kwargs, postprocess_kwargs = self.config.sanitize_parameters(prompt, use_default_prompt_template=use_default_prompt_template, **attrs)
|
||||
r = openllm_core.GenerationOutput(**self.call("generate", openllm_core.GenerationInput(prompt=prompt, llm_config=self.config.model_construct_env(**generate_kwargs)).model_dump()))
|
||||
if return_response == "attrs": return r
|
||||
elif return_response == "raw": return bentoml_cattr.unstructure(r)
|
||||
r = openllm_core.GenerationOutput(**self.call('generate', openllm_core.GenerationInput(prompt=prompt, llm_config=self.config.model_construct_env(**generate_kwargs)).model_dump()))
|
||||
if return_response == 'attrs': return r
|
||||
elif return_response == 'raw': return bentoml_cattr.unstructure(r)
|
||||
else: return self.config.postprocess_generate(prompt, r.responses, **postprocess_kwargs)
|
||||
class BaseAsyncClient(_AsyncClient):
|
||||
async def chat(self, prompt: str, history: list[str], **attrs: t.Any) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
async def embed(self, prompt: t.Sequence[str] | str) -> openllm_core.EmbeddingsOutput:
|
||||
return openllm_core.EmbeddingsOutput(**(await self.call("embeddings", list([prompt] if isinstance(prompt, str) else prompt))))
|
||||
return openllm_core.EmbeddingsOutput(**(await self.call('embeddings', list([prompt] if isinstance(prompt, str) else prompt))))
|
||||
|
||||
async def predict(self, prompt: str, **attrs: t.Any) -> t.Any:
|
||||
return await self.query(prompt, **attrs)
|
||||
|
||||
async def query(self, prompt: str, return_response: t.Literal["attrs", "raw", "processed"] = "processed", **attrs: t.Any) -> t.Any:
|
||||
return_raw_response = attrs.pop("return_raw_response", None)
|
||||
async def query(self, prompt: str, return_response: t.Literal['attrs', 'raw', 'processed'] = 'processed', **attrs: t.Any) -> t.Any:
|
||||
return_raw_response = attrs.pop('return_raw_response', None)
|
||||
if return_raw_response is not None:
|
||||
logger.warning("'return_raw_response' is now deprecated. Please use 'return_response=\"raw\"' instead.")
|
||||
if return_raw_response is True: return_response = "raw"
|
||||
return_attrs = attrs.pop("return_attrs", None)
|
||||
if return_raw_response is True: return_response = 'raw'
|
||||
return_attrs = attrs.pop('return_attrs', None)
|
||||
if return_attrs is not None:
|
||||
logger.warning("'return_attrs' is now deprecated. Please use 'return_response=\"attrs\"' instead.")
|
||||
if return_attrs is True: return_response = "attrs"
|
||||
use_default_prompt_template = attrs.pop("use_default_prompt_template", False)
|
||||
if return_attrs is True: return_response = 'attrs'
|
||||
use_default_prompt_template = attrs.pop('use_default_prompt_template', False)
|
||||
prompt, generate_kwargs, postprocess_kwargs = self.config.sanitize_parameters(prompt, use_default_prompt_template=use_default_prompt_template, **attrs)
|
||||
r = openllm_core.GenerationOutput(**(await self.call("generate", openllm_core.GenerationInput(prompt=prompt, llm_config=self.config.model_construct_env(**generate_kwargs)).model_dump())))
|
||||
if return_response == "attrs": return r
|
||||
elif return_response == "raw": return bentoml_cattr.unstructure(r)
|
||||
r = openllm_core.GenerationOutput(**(await self.call('generate', openllm_core.GenerationInput(prompt=prompt, llm_config=self.config.model_construct_env(**generate_kwargs)).model_dump())))
|
||||
if return_response == 'attrs': return r
|
||||
elif return_response == 'raw': return bentoml_cattr.unstructure(r)
|
||||
else: return self.config.postprocess_generate(prompt, r.responses, **postprocess_kwargs)
|
||||
|
||||
@@ -16,7 +16,7 @@ import typing as t, bentoml, attr, httpx
from abc import abstractmethod
if t.TYPE_CHECKING: from bentoml._internal.service.inference_api import InferenceAPI

-__all__ = ["Client", "AsyncClient"]
+__all__ = ['Client', 'AsyncClient']
@attr.define(init=False)
class Client:
  server_url: str
@@ -25,7 +25,7 @@ class Client:
  timeout: int = attr.field(default=30)

  def __init__(self, server_url: str, svc: bentoml.Service, **kwargs: t.Any) -> None:
-    if len(svc.apis) == 0: raise bentoml.exceptions.BentoMLException("No APIs was found while constructing clients.")
+    if len(svc.apis) == 0: raise bentoml.exceptions.BentoMLException('No APIs was found while constructing clients.')
    self.__attrs_init__(server_url=server_url, endpoints=list(svc.apis), svc=svc)
    for it, val in kwargs.items():
      object.__setattr__(self, it, val)
@@ -50,7 +50,7 @@ class Client:
      from ._grpc import GrpcClient
      return GrpcClient.from_url(url, **kwargs)
    except Exception as err:
-      raise bentoml.exceptions.BentoMLException("Failed to create client from url: %s" % url) from err
+      raise bentoml.exceptions.BentoMLException('Failed to create client from url: %s' % url) from err

  @staticmethod
  def wait_until_server_ready(host: str, port: int, timeout: float = 30, **kwargs: t.Any) -> None:
@@ -61,7 +61,7 @@ class Client:
      from ._grpc import GrpcClient
      return GrpcClient.wait_until_server_ready(host, port, timeout, **kwargs)
    except Exception as err:
-      raise bentoml.exceptions.BentoMLException("Failed to wait until server ready: %s:%d" % (host, port)) from err
+      raise bentoml.exceptions.BentoMLException('Failed to wait until server ready: %s:%d' % (host, port)) from err
@attr.define(init=False)
class AsyncClient:
  server_url: str
@@ -70,7 +70,7 @@ class AsyncClient:
  timeout: int = attr.field(default=30)

  def __init__(self, server_url: str, svc: bentoml.Service, **kwargs: t.Any) -> None:
-    if len(svc.apis) == 0: raise bentoml.exceptions.BentoMLException("No APIs was found while constructing clients.")
+    if len(svc.apis) == 0: raise bentoml.exceptions.BentoMLException('No APIs was found while constructing clients.')
    self.__attrs_init__(server_url=server_url, endpoints=list(svc.apis), svc=svc)
    for it, val in kwargs.items():
      object.__setattr__(self, it, val)
@@ -95,7 +95,7 @@ class AsyncClient:
      from ._grpc import AsyncGrpcClient
      return await AsyncGrpcClient.from_url(url, **kwargs)
    except Exception as err:
-      raise bentoml.exceptions.BentoMLException("Failed to create client from url: %s" % url) from err
+      raise bentoml.exceptions.BentoMLException('Failed to create client from url: %s' % url) from err

  @staticmethod
  async def wait_until_server_ready(host: str, port: int, timeout: float = 30, **kwargs: t.Any) -> None:
@@ -106,4 +106,4 @@ class AsyncClient:
      from ._grpc import AsyncGrpcClient
      await AsyncGrpcClient.wait_until_server_ready(host, port, timeout, **kwargs)
    except Exception as err:
-      raise bentoml.exceptions.BentoMLException("Failed to wait until server ready: %s:%d" % (host, port)) from err
+      raise bentoml.exceptions.BentoMLException('Failed to wait until server ready: %s:%d' % (host, port)) from err

@@ -10,7 +10,7 @@ if not is_grpc_available() or not is_grpc_health_available(): raise ImportError(
|
||||
from grpc import aio
|
||||
from google.protobuf import json_format
|
||||
import grpc, grpc_health.v1.health_pb2 as pb_health, grpc_health.v1.health_pb2_grpc as services_health
|
||||
pb, services = import_generated_stubs("v1")
|
||||
pb, services = import_generated_stubs('v1')
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from bentoml.grpc.v1.service_pb2 import ServiceMetadataResponse
|
||||
@@ -22,7 +22,7 @@ class ClientCredentials(t.TypedDict):
|
||||
@overload
|
||||
def dispatch_channel(
|
||||
server_url: str,
|
||||
typ: t.Literal["async"],
|
||||
typ: t.Literal['async'],
|
||||
ssl: bool = ...,
|
||||
ssl_client_credentials: ClientCredentials | None = ...,
|
||||
options: t.Any | None = ...,
|
||||
@@ -33,7 +33,7 @@ def dispatch_channel(
|
||||
@overload
|
||||
def dispatch_channel(
|
||||
server_url: str,
|
||||
typ: t.Literal["sync"],
|
||||
typ: t.Literal['sync'],
|
||||
ssl: bool = ...,
|
||||
ssl_client_credentials: ClientCredentials | None = ...,
|
||||
options: t.Any | None = ...,
|
||||
@@ -43,7 +43,7 @@ def dispatch_channel(
|
||||
...
|
||||
def dispatch_channel(
|
||||
server_url: str,
|
||||
typ: t.Literal["async", "sync"] = "sync",
|
||||
typ: t.Literal['async', 'sync'] = 'sync',
|
||||
ssl: bool = False,
|
||||
ssl_client_credentials: ClientCredentials | None = None,
|
||||
options: t.Any | None = None,
|
||||
@@ -55,11 +55,11 @@ def dispatch_channel(
|
||||
if ssl_client_credentials is None: raise RuntimeError("'ssl=True' requires 'ssl_client_credentials'")
|
||||
credentials = grpc.ssl_channel_credentials(**{k: load_from_file(v) if isinstance(v, str) else v for k, v in ssl_client_credentials.items()})
|
||||
|
||||
if typ == "async" and ssl: return aio.secure_channel(server_url, credentials=credentials, options=options, compression=compression, interceptors=interceptors)
|
||||
elif typ == "async": return aio.insecure_channel(server_url, options=options, compression=compression, interceptors=interceptors)
|
||||
elif typ == "sync" and ssl: return grpc.secure_channel(server_url, credentials=credentials, options=options, compression=compression)
|
||||
elif typ == "sync": return grpc.insecure_channel(server_url, options=options, compression=compression)
|
||||
else: raise ValueError(f"Unknown type: {typ}")
|
||||
if typ == 'async' and ssl: return aio.secure_channel(server_url, credentials=credentials, options=options, compression=compression, interceptors=interceptors)
|
||||
elif typ == 'async': return aio.insecure_channel(server_url, options=options, compression=compression, interceptors=interceptors)
|
||||
elif typ == 'sync' and ssl: return grpc.secure_channel(server_url, credentials=credentials, options=options, compression=compression)
|
||||
elif typ == 'sync': return grpc.insecure_channel(server_url, options=options, compression=compression)
|
||||
else: raise ValueError(f'Unknown type: {typ}')
|
||||
class GrpcClient(Client):
|
||||
ssl: bool
|
||||
ssl_client_credentials: t.Optional[ClientCredentials]
|
||||
@@ -91,14 +91,14 @@ class GrpcClient(Client):
|
||||
def wait_until_server_ready(host: str, port: int, timeout: float = 30, check_interval: int = 1, **kwargs: t.Any) -> None:
|
||||
with dispatch_channel(
|
||||
f"{host.replace(r'localhost', '0.0.0.0')}:{port}",
|
||||
typ="sync",
|
||||
options=kwargs.get("options", None),
|
||||
compression=kwargs.get("compression", None),
|
||||
ssl=kwargs.get("ssl", False),
|
||||
ssl_client_credentials=kwargs.get("ssl_client_credentials", None)
|
||||
typ='sync',
|
||||
options=kwargs.get('options', None),
|
||||
compression=kwargs.get('compression', None),
|
||||
ssl=kwargs.get('ssl', False),
|
||||
ssl_client_credentials=kwargs.get('ssl_client_credentials', None)
|
||||
) as channel:
|
||||
req = pb_health.HealthCheckRequest()
|
||||
req.service = "bentoml.grpc.v1.BentoService"
|
||||
req.service = 'bentoml.grpc.v1.BentoService'
|
||||
health_stub = services_health.HealthStub(channel)
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < timeout:
|
||||
@@ -107,30 +107,30 @@ class GrpcClient(Client):
|
||||
if resp.status == pb_health.HealthCheckResponse.SERVING: break
|
||||
else: time.sleep(check_interval)
|
||||
except grpc.RpcError:
|
||||
logger.debug("Waiting for server to be ready...")
|
||||
logger.debug('Waiting for server to be ready...')
|
||||
time.sleep(check_interval)
|
||||
try:
|
||||
resp = health_stub.Check(req)
|
||||
if resp.status != pb_health.HealthCheckResponse.SERVING: raise TimeoutError(f"Timed out waiting {timeout} seconds for server at '{host}:{port}' to be ready.")
|
||||
except grpc.RpcError as err:
|
||||
logger.error("Caught RpcError while connecting to %s:%s:\n", host, port)
|
||||
logger.error('Caught RpcError while connecting to %s:%s:\n', host, port)
|
||||
logger.error(err)
|
||||
raise
|
||||
|
||||
@classmethod
|
||||
def from_url(cls, url: str, **kwargs: t.Any) -> GrpcClient:
|
||||
with dispatch_channel(
|
||||
url.replace(r"localhost", "0.0.0.0"),
|
||||
typ="sync",
|
||||
options=kwargs.get("options", None),
|
||||
compression=kwargs.get("compression", None),
|
||||
ssl=kwargs.get("ssl", False),
|
||||
ssl_client_credentials=kwargs.get("ssl_client_credentials", None)
|
||||
url.replace(r'localhost', '0.0.0.0'),
|
||||
typ='sync',
|
||||
options=kwargs.get('options', None),
|
||||
compression=kwargs.get('compression', None),
|
||||
ssl=kwargs.get('ssl', False),
|
||||
ssl_client_credentials=kwargs.get('ssl_client_credentials', None)
|
||||
) as channel:
|
||||
metadata = t.cast(
|
||||
"ServiceMetadataResponse",
|
||||
'ServiceMetadataResponse',
|
||||
channel.unary_unary(
|
||||
"/bentoml.grpc.v1.BentoService/ServiceMetadata", request_serializer=pb.ServiceMetadataRequest.SerializeToString, response_deserializer=pb.ServiceMetadataResponse.FromString
|
||||
'/bentoml.grpc.v1.BentoService/ServiceMetadata', request_serializer=pb.ServiceMetadataRequest.SerializeToString, response_deserializer=pb.ServiceMetadataResponse.FromString
|
||||
)(pb.ServiceMetadataRequest())
|
||||
)
|
||||
reflection = bentoml.Service(metadata.name)
|
||||
@@ -139,23 +139,23 @@ class GrpcClient(Client):
|
||||
reflection.apis[api.name] = InferenceAPI[t.Any](
|
||||
None,
|
||||
bentoml.io.from_spec({
|
||||
"id": api.input.descriptor_id, "args": json_format.MessageToDict(api.input.attributes).get("args", None)
|
||||
'id': api.input.descriptor_id, 'args': json_format.MessageToDict(api.input.attributes).get('args', None)
|
||||
}),
|
||||
bentoml.io.from_spec({
|
||||
"id": api.output.descriptor_id, "args": json_format.MessageToDict(api.output.attributes).get("args", None)
|
||||
'id': api.output.descriptor_id, 'args': json_format.MessageToDict(api.output.attributes).get('args', None)
|
||||
}),
|
||||
name=api.name,
|
||||
doc=api.docs
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to instantiate client for API %s: ", api.name, e)
|
||||
logger.error('Failed to instantiate client for API %s: ', api.name, e)
|
||||
return cls(url, reflection, **kwargs)
|
||||
|
||||
def health(self) -> t.Any:
|
||||
return services_health.HealthStub(self.inner).Check(pb_health.HealthCheckRequest(service=""))
|
||||
return services_health.HealthStub(self.inner).Check(pb_health.HealthCheckRequest(service=''))
|
||||
|
||||
def _call(self, data: t.Any, /, *, _inference_api: InferenceAPI[t.Any], **kwargs: t.Any) -> t.Any:
|
||||
channel_kwargs = {k: kwargs.pop(f"_grpc_channel_{k}", None) for k in {"timeout", "metadata", "credentials", "wait_for_ready", "compression"}}
|
||||
channel_kwargs = {k: kwargs.pop(f'_grpc_channel_{k}', None) for k in {'timeout', 'metadata', 'credentials', 'wait_for_ready', 'compression'}}
|
||||
if _inference_api.multi_input:
|
||||
if data is not None: raise ValueError(f"'{_inference_api.name}' takes multiple inputs, and thus required to pass as keyword arguments.")
|
||||
fake_resp = ensure_exec_coro(_inference_api.input.to_proto(kwargs))
|
||||
@@ -163,8 +163,8 @@ class GrpcClient(Client):
|
||||
fake_resp = ensure_exec_coro(_inference_api.input.to_proto(data))
|
||||
api_fn = {v: k for k, v in self.svc.apis.items()}
|
||||
stubs = services.BentoServiceStub(self.inner)
|
||||
proto = stubs.Call(pb.Request(**{"api_name": api_fn[_inference_api], _inference_api.input.proto_fields[0]: fake_resp}), **channel_kwargs)
|
||||
return ensure_exec_coro(_inference_api.output.from_proto(getattr(proto, proto.WhichOneof("content"))))
|
||||
proto = stubs.Call(pb.Request(**{'api_name': api_fn[_inference_api], _inference_api.input.proto_fields[0]: fake_resp}), **channel_kwargs)
|
||||
return ensure_exec_coro(_inference_api.output.from_proto(getattr(proto, proto.WhichOneof('content'))))
|
||||
class AsyncGrpcClient(AsyncClient):
|
||||
ssl: bool
|
||||
ssl_client_credentials: t.Optional[ClientCredentials]
|
||||
@@ -198,14 +198,14 @@ class AsyncGrpcClient(AsyncClient):
|
||||
async def wait_until_server_ready(host: str, port: int, timeout: float = 30, check_interval: int = 1, **kwargs: t.Any) -> None:
|
||||
async with dispatch_channel(
|
||||
f"{host.replace(r'localhost', '0.0.0.0')}:{port}",
|
||||
typ="async",
|
||||
options=kwargs.get("options", None),
|
||||
compression=kwargs.get("compression", None),
|
||||
ssl=kwargs.get("ssl", False),
|
||||
ssl_client_credentials=kwargs.get("ssl_client_credentials", None)
|
||||
typ='async',
|
||||
options=kwargs.get('options', None),
|
||||
compression=kwargs.get('compression', None),
|
||||
ssl=kwargs.get('ssl', False),
|
||||
ssl_client_credentials=kwargs.get('ssl_client_credentials', None)
|
||||
) as channel:
|
||||
req = pb_health.HealthCheckRequest()
|
||||
req.service = "bentoml.grpc.v1.BentoService"
|
||||
req.service = 'bentoml.grpc.v1.BentoService'
|
||||
health_stub = services_health.HealthStub(channel)
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < timeout:
|
||||
@@ -214,31 +214,31 @@ class AsyncGrpcClient(AsyncClient):
|
||||
if resp.status == pb_health.HealthCheckResponse.SERVING: break
|
||||
else: time.sleep(check_interval)
|
||||
except grpc.RpcError:
|
||||
logger.debug("Waiting for server to be ready...")
|
||||
logger.debug('Waiting for server to be ready...')
|
||||
time.sleep(check_interval)
|
||||
try:
|
||||
resp = health_stub.Check(req)
|
||||
if resp.status != pb_health.HealthCheckResponse.SERVING: raise TimeoutError(f"Timed out waiting {timeout} seconds for server at '{host}:{port}' to be ready.")
|
||||
except grpc.RpcError as err:
|
||||
logger.error("Caught RpcError while connecting to %s:%s:\n", host, port)
|
||||
logger.error('Caught RpcError while connecting to %s:%s:\n', host, port)
|
||||
logger.error(err)
|
||||
raise
|
||||
|
||||
@classmethod
|
||||
async def from_url(cls, url: str, **kwargs: t.Any) -> AsyncGrpcClient:
|
||||
async with dispatch_channel(
|
||||
url.replace(r"localhost", "0.0.0.0"),
|
||||
typ="async",
|
||||
options=kwargs.get("options", None),
|
||||
compression=kwargs.get("compression", None),
|
||||
ssl=kwargs.get("ssl", False),
|
||||
ssl_client_credentials=kwargs.get("ssl_client_credentials", None),
|
||||
interceptors=kwargs.get("interceptors", None)
|
||||
url.replace(r'localhost', '0.0.0.0'),
|
||||
typ='async',
|
||||
options=kwargs.get('options', None),
|
||||
compression=kwargs.get('compression', None),
|
||||
ssl=kwargs.get('ssl', False),
|
||||
ssl_client_credentials=kwargs.get('ssl_client_credentials', None),
|
||||
interceptors=kwargs.get('interceptors', None)
|
||||
) as channel:
|
||||
metadata = t.cast(
|
||||
"ServiceMetadataResponse",
|
||||
'ServiceMetadataResponse',
|
||||
channel.unary_unary(
|
||||
"/bentoml.grpc.v1.BentoService/ServiceMetadata", request_serializer=pb.ServiceMetadataRequest.SerializeToString, response_deserializer=pb.ServiceMetadataResponse.FromString
|
||||
'/bentoml.grpc.v1.BentoService/ServiceMetadata', request_serializer=pb.ServiceMetadataRequest.SerializeToString, response_deserializer=pb.ServiceMetadataResponse.FromString
|
||||
)(pb.ServiceMetadataRequest())
|
||||
)
|
||||
reflection = bentoml.Service(metadata.name)
|
||||
@@ -247,23 +247,23 @@ class AsyncGrpcClient(AsyncClient):
|
||||
reflection.apis[api.name] = InferenceAPI[t.Any](
|
||||
None,
|
||||
bentoml.io.from_spec({
|
||||
"id": api.input.descriptor_id, "args": json_format.MessageToDict(api.input.attributes).get("args", None)
|
||||
'id': api.input.descriptor_id, 'args': json_format.MessageToDict(api.input.attributes).get('args', None)
|
||||
}),
|
||||
bentoml.io.from_spec({
|
||||
"id": api.output.descriptor_id, "args": json_format.MessageToDict(api.output.attributes).get("args", None)
|
||||
'id': api.output.descriptor_id, 'args': json_format.MessageToDict(api.output.attributes).get('args', None)
|
||||
}),
|
||||
name=api.name,
|
||||
doc=api.docs
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to instantiate client for API %s: ", api.name, e)
|
||||
logger.error('Failed to instantiate client for API %s: ', api.name, e)
|
||||
return cls(url, reflection, **kwargs)
|
||||
|
||||
async def health(self) -> t.Any:
|
||||
return await services_health.HealthStub(self.inner).Check(pb_health.HealthCheckRequest(service=""))
|
||||
return await services_health.HealthStub(self.inner).Check(pb_health.HealthCheckRequest(service=''))
|
||||
|
||||
async def _call(self, data: t.Any, /, *, _inference_api: InferenceAPI[t.Any], **kwargs: t.Any) -> t.Any:
|
||||
channel_kwargs = {k: kwargs.pop(f"_grpc_channel_{k}", None) for k in {"timeout", "metadata", "credentials", "wait_for_ready", "compression"}}
|
||||
channel_kwargs = {k: kwargs.pop(f'_grpc_channel_{k}', None) for k in {'timeout', 'metadata', 'credentials', 'wait_for_ready', 'compression'}}
|
||||
state = self.inner.get_state(try_to_connect=True)
|
||||
if state != grpc.ChannelConnectivity.READY: await self.inner.channel_ready()
|
||||
if _inference_api.multi_input:
|
||||
@@ -274,5 +274,5 @@ class AsyncGrpcClient(AsyncClient):
|
||||
api_fn = {v: k for k, v in self.svc.apis.items()}
|
||||
async with self.inner:
|
||||
stubs = services.BentoServiceStub(self.inner)
|
||||
proto = await stubs.Call(pb.Request(**{"api_name": api_fn[_inference_api], _inference_api.input.proto_fields[0]: fake_resp}), **channel_kwargs)
|
||||
return await _inference_api.output.from_proto(getattr(proto, proto.WhichOneof("content")))
|
||||
proto = await stubs.Call(pb.Request(**{'api_name': api_fn[_inference_api], _inference_api.input.proto_fields[0]: fake_resp}), **channel_kwargs)
|
||||
return await _inference_api.output.from_proto(getattr(proto, proto.WhichOneof('content')))
|
||||
|
||||
@@ -8,64 +8,64 @@ logger = logging.getLogger(__name__)
|
||||
class HttpClient(Client):
|
||||
@functools.cached_property
|
||||
def inner(self) -> httpx.Client:
|
||||
if not urlparse(self.server_url).netloc: raise ValueError(f"Invalid server url: {self.server_url}")
|
||||
if not urlparse(self.server_url).netloc: raise ValueError(f'Invalid server url: {self.server_url}')
|
||||
return httpx.Client(base_url=self.server_url)
|
||||
|
||||
@staticmethod
|
||||
def wait_until_server_ready(host: str, port: int, timeout: float = 30, check_interval: int = 1, **kwargs: t.Any) -> None:
|
||||
host = host if "://" in host else "http://" + host
|
||||
logger.debug("Waiting for server @ `%s:%d` to be ready...", host, port)
|
||||
host = host if '://' in host else 'http://' + host
|
||||
logger.debug('Waiting for server @ `%s:%d` to be ready...', host, port)
|
||||
start = time.time()
|
||||
while time.time() - start < timeout:
|
||||
try:
|
||||
status = httpx.get(f"{host}:{port}/readyz").status_code
|
||||
status = httpx.get(f'{host}:{port}/readyz').status_code
|
||||
if status == 200: break
|
||||
else: time.sleep(check_interval)
|
||||
except (httpx.ConnectError, urllib.error.URLError, ConnectionError):
|
||||
logger.debug("Server is not ready yet, retrying in %d seconds...", check_interval)
|
||||
logger.debug('Server is not ready yet, retrying in %d seconds...', check_interval)
|
||||
time.sleep(check_interval)
|
||||
# Try once more and raise for exception
|
||||
try:
|
||||
httpx.get(f"{host}:{port}/readyz").raise_for_status()
|
||||
httpx.get(f'{host}:{port}/readyz').raise_for_status()
|
||||
except httpx.HTTPStatusError as err:
|
||||
logger.error("Failed to wait until server ready: %s:%d", host, port)
|
||||
logger.error('Failed to wait until server ready: %s:%d', host, port)
|
||||
logger.error(err)
|
||||
raise
|
||||
|
||||
def health(self) -> httpx.Response:
|
||||
return self.inner.get("/readyz")
|
||||
return self.inner.get('/readyz')
|
||||
|
||||
@classmethod
|
||||
def from_url(cls, url: str, **kwargs: t.Any) -> HttpClient:
|
||||
url = url if "://" in url else "http://" + url
|
||||
resp = httpx.get(f"{url}/docs.json")
|
||||
if resp.status_code != 200: raise ValueError(f"Failed to get OpenAPI schema from the server: {resp.status_code} {resp.reason_phrase}:\n{resp.content.decode()}")
|
||||
url = url if '://' in url else 'http://' + url
|
||||
resp = httpx.get(f'{url}/docs.json')
|
||||
if resp.status_code != 200: raise ValueError(f'Failed to get OpenAPI schema from the server: {resp.status_code} {resp.reason_phrase}:\n{resp.content.decode()}')
|
||||
_spec = orjson.loads(resp.content)
|
||||
|
||||
reflection = bentoml.Service(_spec["info"]["title"])
|
||||
reflection = bentoml.Service(_spec['info']['title'])
|
||||
|
||||
for route, spec in _spec["paths"].items():
|
||||
for route, spec in _spec['paths'].items():
|
||||
for meth_spec in spec.values():
|
||||
if "tags" in meth_spec and "Service APIs" in meth_spec["tags"]:
|
||||
if "x-bentoml-io-descriptor" not in meth_spec["requestBody"]: raise ValueError(f"Malformed BentoML spec received from BentoML server {url}")
|
||||
if "x-bentoml-io-descriptor" not in meth_spec["responses"]["200"]: raise ValueError(f"Malformed BentoML spec received from BentoML server {url}")
|
||||
if "x-bentoml-name" not in meth_spec: raise ValueError(f"Malformed BentoML spec received from BentoML server {url}")
|
||||
if 'tags' in meth_spec and 'Service APIs' in meth_spec['tags']:
|
||||
if 'x-bentoml-io-descriptor' not in meth_spec['requestBody']: raise ValueError(f'Malformed BentoML spec received from BentoML server {url}')
|
||||
if 'x-bentoml-io-descriptor' not in meth_spec['responses']['200']: raise ValueError(f'Malformed BentoML spec received from BentoML server {url}')
|
||||
if 'x-bentoml-name' not in meth_spec: raise ValueError(f'Malformed BentoML spec received from BentoML server {url}')
|
||||
try:
|
||||
reflection.apis[meth_spec["x-bentoml-name"]] = InferenceAPI[t.Any](
|
||||
reflection.apis[meth_spec['x-bentoml-name']] = InferenceAPI[t.Any](
|
||||
None,
|
||||
bentoml.io.from_spec(meth_spec["requestBody"]["x-bentoml-io-descriptor"]),
|
||||
bentoml.io.from_spec(meth_spec["responses"]["200"]["x-bentoml-io-descriptor"]),
|
||||
name=meth_spec["x-bentoml-name"],
|
||||
doc=meth_spec["description"],
|
||||
route=route.lstrip("/")
|
||||
bentoml.io.from_spec(meth_spec['requestBody']['x-bentoml-io-descriptor']),
|
||||
bentoml.io.from_spec(meth_spec['responses']['200']['x-bentoml-io-descriptor']),
|
||||
name=meth_spec['x-bentoml-name'],
|
||||
doc=meth_spec['description'],
|
||||
route=route.lstrip('/')
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to instantiate client for API %s: ", meth_spec["x-bentoml-name"], e)
|
||||
logger.error('Failed to instantiate client for API %s: ', meth_spec['x-bentoml-name'], e)
|
||||
return cls(url, reflection)
|
||||
|
||||
def _call(self, data: t.Any, /, *, _inference_api: InferenceAPI[t.Any], **kwargs: t.Any) -> t.Any:
|
||||
# All gRPC kwargs should be popped out.
|
||||
kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_grpc_")}
|
||||
kwargs = {k: v for k, v in kwargs.items() if not k.startswith('_grpc_')}
|
||||
if _inference_api.multi_input:
|
||||
if data is not None: raise ValueError(f"'{_inference_api.name}' takes multiple inputs, and thus required to pass as keyword arguments.")
|
||||
fake_resp = ensure_exec_coro(_inference_api.input.to_http_response(kwargs, None))
|
||||
@@ -77,13 +77,13 @@ class HttpClient(Client):
|
||||
else: body = fake_resp.body
|
||||
|
||||
resp = self.inner.post(
|
||||
"/" + _inference_api.route if not _inference_api.route.startswith("/") else _inference_api.route,
|
||||
'/' + _inference_api.route if not _inference_api.route.startswith('/') else _inference_api.route,
|
||||
data=body,
|
||||
headers={"content-type": fake_resp.headers["content-type"]},
|
||||
headers={'content-type': fake_resp.headers['content-type']},
|
||||
timeout=self.timeout
|
||||
)
|
||||
if resp.status_code != 200: raise ValueError(f"Error while making request: {resp.status_code}: {resp.content!s}")
|
||||
fake_req = starlette.requests.Request(scope={"type": "http"})
|
||||
if resp.status_code != 200: raise ValueError(f'Error while making request: {resp.status_code}: {resp.content!s}')
|
||||
fake_req = starlette.requests.Request(scope={'type': 'http'})
|
||||
headers = starlette.datastructures.Headers(headers=resp.headers)
|
||||
fake_req._body = resp.content
|
||||
# Request.headers sets a _headers variable. We will need to set this value to our fake request object.
|
||||
@@ -92,63 +92,63 @@ class HttpClient(Client):
|
||||
class AsyncHttpClient(AsyncClient):
|
||||
@functools.cached_property
|
||||
def inner(self) -> httpx.AsyncClient:
|
||||
if not urlparse(self.server_url).netloc: raise ValueError(f"Invalid server url: {self.server_url}")
|
||||
if not urlparse(self.server_url).netloc: raise ValueError(f'Invalid server url: {self.server_url}')
|
||||
return httpx.AsyncClient(base_url=self.server_url)
|
||||
|
||||
@staticmethod
|
||||
async def wait_until_server_ready(host: str, port: int, timeout: float = 30, check_interval: int = 1, **kwargs: t.Any) -> None:
|
||||
host = host if "://" in host else "http://" + host
|
||||
logger.debug("Waiting for server @ `%s:%d` to be ready...", host, port)
|
||||
host = host if '://' in host else 'http://' + host
|
||||
logger.debug('Waiting for server @ `%s:%d` to be ready...', host, port)
|
||||
start = time.time()
|
||||
while time.time() - start < timeout:
|
||||
try:
|
||||
async with httpx.AsyncClient(base_url=f"{host}:{port}") as sess:
|
||||
resp = await sess.get("/readyz")
|
||||
async with httpx.AsyncClient(base_url=f'{host}:{port}') as sess:
|
||||
resp = await sess.get('/readyz')
|
||||
if resp.status_code == 200: break
|
||||
else: await asyncio.sleep(check_interval)
|
||||
except (httpx.ConnectError, urllib.error.URLError, ConnectionError):
|
||||
logger.debug("Server is not ready yet, retrying in %d seconds...", check_interval)
|
||||
logger.debug('Server is not ready yet, retrying in %d seconds...', check_interval)
|
||||
await asyncio.sleep(check_interval)
|
||||
# Try once more and raise for exception
|
||||
async with httpx.AsyncClient(base_url=f"{host}:{port}") as sess:
|
||||
resp = await sess.get("/readyz")
|
||||
if resp.status_code != 200: raise TimeoutError(f"Timeout while waiting for server @ `{host}:{port}` to be ready: {resp.status_code}: {resp.content!s}")
|
||||
async with httpx.AsyncClient(base_url=f'{host}:{port}') as sess:
|
||||
resp = await sess.get('/readyz')
|
||||
if resp.status_code != 200: raise TimeoutError(f'Timeout while waiting for server @ `{host}:{port}` to be ready: {resp.status_code}: {resp.content!s}')
|
||||
|
||||
async def health(self) -> httpx.Response:
|
||||
return await self.inner.get("/readyz")
|
||||
return await self.inner.get('/readyz')
|
||||
|
||||
@classmethod
|
||||
async def from_url(cls, url: str, **kwargs: t.Any) -> AsyncHttpClient:
|
||||
url = url if "://" in url else "http://" + url
|
||||
url = url if '://' in url else 'http://' + url
|
||||
async with httpx.AsyncClient(base_url=url) as session:
|
||||
resp = await session.get("/docs.json")
|
||||
if resp.status_code != 200: raise ValueError(f"Failed to get OpenAPI schema from the server: {resp.status_code} {resp.reason_phrase}:\n{(await resp.aread()).decode()}")
|
||||
resp = await session.get('/docs.json')
|
||||
if resp.status_code != 200: raise ValueError(f'Failed to get OpenAPI schema from the server: {resp.status_code} {resp.reason_phrase}:\n{(await resp.aread()).decode()}')
|
||||
_spec = orjson.loads(await resp.aread())
|
||||
|
||||
reflection = bentoml.Service(_spec["info"]["title"])
|
||||
reflection = bentoml.Service(_spec['info']['title'])
|
||||
|
||||
for route, spec in _spec["paths"].items():
|
||||
for route, spec in _spec['paths'].items():
|
||||
for meth_spec in spec.values():
|
||||
if "tags" in meth_spec and "Service APIs" in meth_spec["tags"]:
|
||||
if "x-bentoml-io-descriptor" not in meth_spec["requestBody"]: raise ValueError(f"Malformed BentoML spec received from BentoML server {url}")
|
||||
if "x-bentoml-io-descriptor" not in meth_spec["responses"]["200"]: raise ValueError(f"Malformed BentoML spec received from BentoML server {url}")
|
||||
if "x-bentoml-name" not in meth_spec: raise ValueError(f"Malformed BentoML spec received from BentoML server {url}")
|
||||
if 'tags' in meth_spec and 'Service APIs' in meth_spec['tags']:
|
||||
if 'x-bentoml-io-descriptor' not in meth_spec['requestBody']: raise ValueError(f'Malformed BentoML spec received from BentoML server {url}')
|
||||
if 'x-bentoml-io-descriptor' not in meth_spec['responses']['200']: raise ValueError(f'Malformed BentoML spec received from BentoML server {url}')
|
||||
if 'x-bentoml-name' not in meth_spec: raise ValueError(f'Malformed BentoML spec received from BentoML server {url}')
|
||||
try:
|
||||
reflection.apis[meth_spec["x-bentoml-name"]] = InferenceAPI[t.Any](
|
||||
reflection.apis[meth_spec['x-bentoml-name']] = InferenceAPI[t.Any](
|
||||
None,
|
||||
bentoml.io.from_spec(meth_spec["requestBody"]["x-bentoml-io-descriptor"]),
|
||||
bentoml.io.from_spec(meth_spec["responses"]["200"]["x-bentoml-io-descriptor"]),
|
||||
name=meth_spec["x-bentoml-name"],
|
||||
doc=meth_spec["description"],
|
||||
route=route.lstrip("/")
|
||||
bentoml.io.from_spec(meth_spec['requestBody']['x-bentoml-io-descriptor']),
|
||||
bentoml.io.from_spec(meth_spec['responses']['200']['x-bentoml-io-descriptor']),
|
||||
name=meth_spec['x-bentoml-name'],
|
||||
doc=meth_spec['description'],
|
||||
route=route.lstrip('/')
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error("Failed to instantiate client for API %s: ", meth_spec["x-bentoml-name"], e)
|
||||
logger.error('Failed to instantiate client for API %s: ', meth_spec['x-bentoml-name'], e)
|
||||
return cls(url, reflection)
|
||||
|
||||
async def _call(self, data: t.Any, /, *, _inference_api: InferenceAPI[t.Any], **kwargs: t.Any) -> t.Any:
|
||||
# All gRPC kwargs should be popped out.
|
||||
kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_grpc_")}
|
||||
kwargs = {k: v for k, v in kwargs.items() if not k.startswith('_grpc_')}
|
||||
if _inference_api.multi_input:
|
||||
if data is not None: raise ValueError(f"'{_inference_api.name}' takes multiple inputs, and thus required to pass as keyword arguments.")
|
||||
fake_resp = await _inference_api.input.to_http_response(kwargs, None)
|
||||
@@ -160,13 +160,13 @@ class AsyncHttpClient(AsyncClient):
|
||||
else: body = t.cast(t.Any, fake_resp.body)
|
||||
|
||||
resp = await self.inner.post(
|
||||
"/" + _inference_api.route if not _inference_api.route.startswith("/") else _inference_api.route,
|
||||
'/' + _inference_api.route if not _inference_api.route.startswith('/') else _inference_api.route,
|
||||
data=body,
|
||||
headers={"content-type": fake_resp.headers["content-type"]},
|
||||
headers={'content-type': fake_resp.headers['content-type']},
|
||||
timeout=self.timeout
|
||||
)
|
||||
if resp.status_code != 200: raise ValueError(f"Error making request: {resp.status_code}: {(await resp.aread())!s}")
|
||||
fake_req = starlette.requests.Request(scope={"type": "http"})
|
||||
if resp.status_code != 200: raise ValueError(f'Error making request: {resp.status_code}: {(await resp.aread())!s}')
|
||||
fake_req = starlette.requests.Request(scope={'type': 'http'})
|
||||
headers = starlette.datastructures.Headers(headers=resp.headers)
|
||||
fake_req._body = resp.content
|
||||
# Request.headers sets a _headers variable. We will need to set this value to our fake request object.
|
||||
|
||||
@@ -4,10 +4,10 @@ from urllib.parse import urlparse
from ._base import BaseClient, BaseAsyncClient
logger = logging.getLogger(__name__)
def process_http_address(self: AsyncHTTPClient | HTTPClient, address: str) -> None:
-  address = address if "://" in address else "http://" + address
+  address = address if '://' in address else 'http://' + address
  parsed = urlparse(address)
-  self._host, *_port = parsed.netloc.split(":")
-  if len(_port) == 0: self._port = "80" if parsed.scheme == "http" else "443"
+  self._host, *_port = parsed.netloc.split(':')
+  if len(_port) == 0: self._port = '80' if parsed.scheme == 'http' else '443'
  else: self._port = next(iter(_port))
class HTTPClient(BaseClient):
  def __init__(self, address: str, timeout: int = 30):
@@ -19,9 +19,9 @@ class AsyncHTTPClient(BaseAsyncClient):
    super().__init__(address, timeout)
class GrpcClient(BaseClient):
  def __init__(self, address: str, timeout: int = 30):
-    self._host, self._port = address.split(":")
+    self._host, self._port = address.split(':')
    super().__init__(address, timeout)
class AsyncGrpcClient(BaseAsyncClient):
  def __init__(self, address: str, timeout: int = 30):
-    self._host, self._port = address.split(":")
+    self._host, self._port = address.split(':')
    super().__init__(address, timeout)

File diff suppressed because it is too large
@@ -3,12 +3,12 @@ import string, typing as t
class PromptFormatter(string.Formatter):
  """This PromptFormatter is largely based on langchain's implementation."""
  def vformat(self, format_string: str, args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]) -> t.Any:
-    if len(args) > 0: raise ValueError("Positional arguments are not supported")
+    if len(args) > 0: raise ValueError('Positional arguments are not supported')
    return super().vformat(format_string, args, kwargs)

  def check_unused_args(self, used_args: set[int | str], args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]) -> None:
    extras = set(kwargs).difference(used_args)
-    if extras: raise KeyError(f"Extra params passed: {extras}")
+    if extras: raise KeyError(f'Extra params passed: {extras}')

  def extract_template_variables(self, template: str) -> t.Sequence[str]:
    return [field[1] for field in self.parse(template) if field[1] is not None]
@@ -19,7 +19,7 @@ def process_prompt(prompt: str, template: str | None = None, use_prompt_template
  elif template is None: raise ValueError("'template' can't be None while 'use_prompt_template=False'")
  template_variables = default_formatter.extract_template_variables(template)
  prompt_variables = {k: v for k, v in attrs.items() if k in template_variables}
-  if "instruction" in prompt_variables: raise RuntimeError("'instruction' should be passed as the first argument instead of kwargs when 'use_prompt_template=True'")
+  if 'instruction' in prompt_variables: raise RuntimeError("'instruction' should be passed as the first argument instead of kwargs when 'use_prompt_template=True'")
  try:
    return template.format(instruction=prompt, **prompt_variables)
  except KeyError as e:

@@ -1,4 +1,4 @@
"""Schema definition for OpenLLM. This can be used for client interaction."""
'''Schema definition for OpenLLM. This can be used for client interaction.'''
from __future__ import annotations
import functools, typing as t
import attr, inflection
@@ -12,7 +12,7 @@ class GenerationInput:
adapter_name: str | None = attr.field(default=None)

def model_dump(self) -> dict[str, t.Any]:
return {"prompt": self.prompt, "llm_config": self.llm_config.model_dump(flatten=True), "adapter_name": self.adapter_name}
return {'prompt': self.prompt, 'llm_config': self.llm_config.model_dump(flatten=True), 'adapter_name': self.adapter_name}

@staticmethod
def convert_llm_config(data: dict[str, t.Any] | LLMConfig, cls: type[LLMConfig] | None = None) -> LLMConfig:
@@ -29,11 +29,11 @@ class GenerationInput:
@classmethod
def from_llm_config(cls, llm_config: LLMConfig) -> type[GenerationInput]:
return attr.make_class(
inflection.camelize(llm_config["model_name"]) + "GenerationInput",
inflection.camelize(llm_config['model_name']) + 'GenerationInput',
attrs={
"prompt": attr.field(type=str),
"llm_config": attr.field(type=llm_config.__class__, default=llm_config, converter=functools.partial(cls.convert_llm_config, cls=llm_config.__class__)),
"adapter_name": attr.field(default=None, type=str)
'prompt': attr.field(type=str),
'llm_config': attr.field(type=llm_config.__class__, default=llm_config, converter=functools.partial(cls.convert_llm_config, cls=llm_config.__class__)),
'adapter_name': attr.field(default=None, type=str)
}
)
@attr.frozen(slots=True)
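The interesting part of from_llm_config above is that it builds a model-specific input class at runtime with attr.make_class. A sketch of the same pattern, using a stand-in config class instead of a real LLMConfig (DummyConfig and DummyGenerationInput are made-up names for illustration):

import attr

# Stand-in for an LLMConfig instance; purely illustrative.
DummyConfig = attr.make_class('DummyConfig', {'temperature': attr.field(default=0.9)})

# Same pattern as GenerationInput.from_llm_config: a per-model input class built at runtime.
DummyGenerationInput = attr.make_class(
  'DummyGenerationInput',
  attrs={'prompt': attr.field(type=str), 'llm_config': attr.field(default=DummyConfig()), 'adapter_name': attr.field(default=None, type=str)}
)
inp = DummyGenerationInput(prompt='Hello')
print(attr.asdict(inp))  # {'prompt': 'Hello', 'llm_config': {'temperature': 0.9}, 'adapter_name': None}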
@@ -13,17 +13,17 @@ class DynResource(t.Protocol):
...
logger = logging.getLogger(__name__)
def _strtoul(s: str) -> int:
"""Return -1 or the positive integer that the sequence string starts with."""
'''Return -1 or the positive integer that the sequence string starts with.'''
if not s: return -1
idx = 0
for idx, c in enumerate(s):
if not (c.isdigit() or (idx == 0 and c in "+-")): break
if not (c.isdigit() or (idx == 0 and c in '+-')): break
if idx + 1 == len(s): idx += 1  # noqa: PLW2901
# NOTE: idx will be set via enumerate
return int(s[:idx]) if idx > 0 else -1
def _parse_list_with_prefix(lst: str, prefix: str) -> list[str]:
rcs: list[str] = []
for elem in lst.split(","):
for elem in lst.split(','):
# Repeated id results in empty set
if elem in rcs: return []
# Anything other than the prefix is ignored
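_strtoul behaves like C's strtoul: it reads the leading (optionally signed) digits and ignores whatever follows, which is what makes a spec like `1gpu2,2ampere` act like `1,2` further down. A runnable restatement of the same logic with a few worked examples:

def _strtoul(s: str) -> int:
  # Identical logic to the hunk above: leading integer, or -1 if none.
  if not s: return -1
  idx = 0
  for idx, c in enumerate(s):
    if not (c.isdigit() or (idx == 0 and c in '+-')): break
    if idx + 1 == len(s): idx += 1
  return int(s[:idx]) if idx > 0 else -1

print(_strtoul('2'))       # 2
print(_strtoul('1gpu2'))   # 1
print(_strtoul('ampere'))  # -1
print(_strtoul(''))        # -1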
@@ -41,21 +41,21 @@ def _parse_visible_devices(default_var: None, *, respect_env: t.Literal[True]) -
|
||||
def _parse_visible_devices(default_var: str = ..., *, respect_env: t.Literal[False]) -> list[str]:
|
||||
...
|
||||
def _parse_visible_devices(default_var: str | None = None, respect_env: bool = True) -> list[str] | None:
|
||||
"""CUDA_VISIBLE_DEVICES aware with default var for parsing spec."""
|
||||
'''CUDA_VISIBLE_DEVICES aware with default var for parsing spec.'''
|
||||
if respect_env:
|
||||
spec = os.environ.get("CUDA_VISIBLE_DEVICES", default_var)
|
||||
spec = os.environ.get('CUDA_VISIBLE_DEVICES', default_var)
|
||||
if not spec: return None
|
||||
else:
|
||||
if default_var is None: raise ValueError("spec is required to be not None when parsing spec.")
|
||||
if default_var is None: raise ValueError('spec is required to be not None when parsing spec.')
|
||||
spec = default_var
|
||||
|
||||
if spec.startswith("GPU-"): return _parse_list_with_prefix(spec, "GPU-")
|
||||
if spec.startswith("MIG-"): return _parse_list_with_prefix(spec, "MIG-")
|
||||
if spec.startswith('GPU-'): return _parse_list_with_prefix(spec, 'GPU-')
|
||||
if spec.startswith('MIG-'): return _parse_list_with_prefix(spec, 'MIG-')
|
||||
# XXX: We need to somehow handle cases such as '100m'
|
||||
# CUDA_VISIBLE_DEVICES uses something like strtoul
|
||||
# which makes `1gpu2,2ampere` equivalent to `1,2`
|
||||
rc: list[int] = []
|
||||
for el in spec.split(","):
|
||||
for el in spec.split(','):
|
||||
x = _strtoul(el.strip())
|
||||
# Repeated ordinal results in empty set
|
||||
if x in rc: return []
|
||||
@@ -66,14 +66,14 @@ def _parse_visible_devices(default_var: str | None = None, respect_env: bool = T
|
||||
def _from_system(cls: type[DynResource]) -> list[str]:
|
||||
visible_devices = _parse_visible_devices()
|
||||
if visible_devices is None:
|
||||
if cls.resource_id == "amd.com/gpu":
|
||||
if cls.resource_id == 'amd.com/gpu':
|
||||
if not psutil.LINUX:
|
||||
if DEBUG: warnings.warn("AMD GPUs is currently only supported on Linux.", stacklevel=_STACK_LEVEL)
|
||||
if DEBUG: warnings.warn('AMD GPUs are currently only supported on Linux.', stacklevel=_STACK_LEVEL)
|
||||
return []
|
||||
# ROCm does not currently have the rocm_smi wheel.
|
||||
# So we need to use the ctypes bindings directly.
|
||||
# we don't want to use CLI because parsing is a pain.
|
||||
sys.path.append("/opt/rocm/libexec/rocm_smi")
|
||||
sys.path.append('/opt/rocm/libexec/rocm_smi')
|
||||
try:
|
||||
from ctypes import byref, c_uint32
|
||||
|
||||
@@ -88,7 +88,7 @@ def _from_system(cls: type[DynResource]) -> list[str]:
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
return []
|
||||
finally:
|
||||
sys.path.remove("/opt/rocm/libexec/rocm_smi")
|
||||
sys.path.remove('/opt/rocm/libexec/rocm_smi')
|
||||
else:
|
||||
try:
|
||||
from cuda import cuda
|
||||
@@ -110,11 +110,11 @@ def _from_spec(cls: type[DynResource], spec: str) -> list[str]:
|
||||
def _from_spec(cls: type[DynResource], spec: t.Any) -> list[str]:
|
||||
if isinstance(spec, int):
|
||||
if spec in (-1, 0): return []
|
||||
if spec < -1: raise ValueError("Spec cannot be < -1.")
|
||||
if spec < -1: raise ValueError('Spec cannot be < -1.')
|
||||
return [str(i) for i in range(spec)]
|
||||
elif isinstance(spec, str):
|
||||
if not spec: return []
|
||||
if spec.isdigit(): spec = ",".join([str(i) for i in range(_strtoul(spec))])
|
||||
if spec.isdigit(): spec = ','.join([str(i) for i in range(_strtoul(spec))])
|
||||
return _parse_visible_devices(spec, respect_env=False)
|
||||
elif isinstance(spec, list):
|
||||
return [str(x) for x in spec]
|
||||
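To make the _from_spec branches above concrete, here is a trimmed, illustrative re-implementation (no environment handling and no GPU-/MIG- prefix support) together with the mappings it produces:

def _sketch_from_spec(spec):
  # int: -1/0 means no devices; N means the first N ordinals.
  if isinstance(spec, int):
    if spec in (-1, 0): return []
    if spec < -1: raise ValueError('Spec cannot be < -1.')
    return [str(i) for i in range(spec)]
  # digit string: behaves like the int case.
  if isinstance(spec, str) and spec.isdigit():
    return [str(i) for i in range(int(spec))]
  # list: stringify each entry.
  if isinstance(spec, list):
    return [str(x) for x in spec]
  raise NotImplementedError('full parsing lives in _parse_visible_devices')

print(_sketch_from_spec(2))       # ['0', '1']
print(_sketch_from_spec('3'))     # ['0', '1', '2']
print(_sketch_from_spec([0, 1]))  # ['0', '1']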
@@ -124,9 +124,9 @@ def _raw_device_uuid_nvml() -> list[str] | None:
|
||||
from ctypes import CDLL, byref, c_int, c_void_p, create_string_buffer
|
||||
|
||||
try:
|
||||
nvml_h = CDLL("libnvidia-ml.so.1")
|
||||
nvml_h = CDLL('libnvidia-ml.so.1')
|
||||
except Exception:
|
||||
warnings.warn("Failed to find nvidia binding", stacklevel=_STACK_LEVEL)
|
||||
warnings.warn('Failed to find nvidia binding', stacklevel=_STACK_LEVEL)
|
||||
return None
|
||||
|
||||
rc = nvml_h.nvmlInit()
|
||||
@@ -136,98 +136,98 @@ def _raw_device_uuid_nvml() -> list[str] | None:
|
||||
dev_count = c_int(-1)
|
||||
rc = nvml_h.nvmlDeviceGetCount_v2(byref(dev_count))
|
||||
if rc != 0:
|
||||
warnings.warn("Failed to get available device from system.", stacklevel=_STACK_LEVEL)
|
||||
warnings.warn('Failed to get available device from system.', stacklevel=_STACK_LEVEL)
|
||||
return None
|
||||
uuids: list[str] = []
|
||||
for idx in range(dev_count.value):
|
||||
dev_id = c_void_p()
|
||||
rc = nvml_h.nvmlDeviceGetHandleByIndex_v2(idx, byref(dev_id))
|
||||
if rc != 0:
|
||||
warnings.warn(f"Failed to get device handle for {idx}", stacklevel=_STACK_LEVEL)
|
||||
warnings.warn(f'Failed to get device handle for {idx}', stacklevel=_STACK_LEVEL)
|
||||
return None
|
||||
buf_len = 96
|
||||
buf = create_string_buffer(buf_len)
|
||||
rc = nvml_h.nvmlDeviceGetUUID(dev_id, buf, buf_len)
|
||||
if rc != 0:
|
||||
warnings.warn(f"Failed to get device UUID for {idx}", stacklevel=_STACK_LEVEL)
|
||||
warnings.warn(f'Failed to get device UUID for {idx}', stacklevel=_STACK_LEVEL)
|
||||
return None
|
||||
uuids.append(buf.raw.decode("ascii").strip("\0"))
|
||||
uuids.append(buf.raw.decode('ascii').strip('\0'))
|
||||
del nvml_h
|
||||
return uuids
|
||||
def _validate(cls: type[DynResource], val: list[t.Any]) -> None:
|
||||
if cls.resource_id == "amd.com/gpu":
|
||||
if cls.resource_id == 'amd.com/gpu':
|
||||
raise RuntimeError("AMD GPU validation is not yet supported. Make sure to call 'get_resource(..., validate=False)'")
|
||||
if not all(isinstance(i, str) for i in val): raise ValueError("Input list should be all string type.")
|
||||
if not all(isinstance(i, str) for i in val): raise ValueError('Input list should be all string type.')
|
||||
|
||||
try:
|
||||
from cuda import cuda
|
||||
|
||||
err, *_ = cuda.cuInit(0)
|
||||
if err != cuda.CUresult.CUDA_SUCCESS:
|
||||
raise RuntimeError("Failed to initialise CUDA runtime binding.")
|
||||
raise RuntimeError('Failed to initialise CUDA runtime binding.')
|
||||
# correctly parse handle
|
||||
for el in val:
|
||||
if el.startswith("GPU-") or el.startswith("MIG-"):
|
||||
if el.startswith('GPU-') or el.startswith('MIG-'):
|
||||
uuids = _raw_device_uuid_nvml()
|
||||
if uuids is None: raise ValueError("Failed to parse available GPUs UUID")
|
||||
if el not in uuids: raise ValueError(f"Given UUID {el} is not found with available UUID (available: {uuids})")
|
||||
if uuids is None: raise ValueError('Failed to parse available GPUs UUID')
|
||||
if el not in uuids: raise ValueError(f'Given UUID {el} is not found with available UUID (available: {uuids})')
|
||||
elif el.isdigit():
|
||||
err, _ = cuda.cuDeviceGet(int(el))
|
||||
if err != cuda.CUresult.CUDA_SUCCESS: raise ValueError(f"Failed to get device {el}")
|
||||
if err != cuda.CUresult.CUDA_SUCCESS: raise ValueError(f'Failed to get device {el}')
|
||||
except (ImportError, RuntimeError):
|
||||
pass
|
||||
def _make_resource_class(name: str, resource_kind: str, docstring: str) -> type[DynResource]:
|
||||
return types.new_class(
|
||||
name, (bentoml.Resource[t.List[str]], ReprMixin), {"resource_id": resource_kind},
|
||||
name, (bentoml.Resource[t.List[str]], ReprMixin), {'resource_id': resource_kind},
|
||||
lambda ns: ns.update({
|
||||
"resource_id": resource_kind,
|
||||
"from_spec": classmethod(_from_spec),
|
||||
"from_system": classmethod(_from_system),
|
||||
"validate": classmethod(_validate),
|
||||
"__repr_keys__": property(lambda _: {"resource_id"}),
|
||||
"__doc__": inspect.cleandoc(docstring),
|
||||
"__module__": "openllm._strategies"
|
||||
'resource_id': resource_kind,
|
||||
'from_spec': classmethod(_from_spec),
|
||||
'from_system': classmethod(_from_system),
|
||||
'validate': classmethod(_validate),
|
||||
'__repr_keys__': property(lambda _: {'resource_id'}),
|
||||
'__doc__': inspect.cleandoc(docstring),
|
||||
'__module__': 'openllm._strategies'
|
||||
})
|
||||
)
|
||||
# NOTE: we need to hint these t.Literal since mypy is too dumb to infer this as literal :facepalm:
|
||||
_TPU_RESOURCE: t.Literal["cloud-tpus.google.com/v2"] = "cloud-tpus.google.com/v2"
|
||||
_AMD_GPU_RESOURCE: t.Literal["amd.com/gpu"] = "amd.com/gpu"
|
||||
_NVIDIA_GPU_RESOURCE: t.Literal["nvidia.com/gpu"] = "nvidia.com/gpu"
|
||||
_CPU_RESOURCE: t.Literal["cpu"] = "cpu"
|
||||
_TPU_RESOURCE: t.Literal['cloud-tpus.google.com/v2'] = 'cloud-tpus.google.com/v2'
|
||||
_AMD_GPU_RESOURCE: t.Literal['amd.com/gpu'] = 'amd.com/gpu'
|
||||
_NVIDIA_GPU_RESOURCE: t.Literal['nvidia.com/gpu'] = 'nvidia.com/gpu'
|
||||
_CPU_RESOURCE: t.Literal['cpu'] = 'cpu'
|
||||
|
||||
NvidiaGpuResource = _make_resource_class(
|
||||
"NvidiaGpuResource",
|
||||
'NvidiaGpuResource',
|
||||
_NVIDIA_GPU_RESOURCE,
|
||||
"""NVIDIA GPU resource.
|
||||
'''NVIDIA GPU resource.
|
||||
|
||||
This is a modified version of BentoML's internal NvidiaGpuResource
|
||||
where it respects and parses CUDA_VISIBLE_DEVICES correctly."""
|
||||
where it respects and parses CUDA_VISIBLE_DEVICES correctly.'''
|
||||
)
|
||||
AmdGpuResource = _make_resource_class(
|
||||
"AmdGpuResource",
|
||||
'AmdGpuResource',
|
||||
_AMD_GPU_RESOURCE,
|
||||
"""AMD GPU resource.
|
||||
'''AMD GPU resource.
|
||||
|
||||
Since ROCm will respect CUDA_VISIBLE_DEVICES, the behaviour of from_spec, from_system are similar to
|
||||
``NvidiaGpuResource``. Currently ``validate`` is not yet supported."""
|
||||
``NvidiaGpuResource``. Currently ``validate`` is not yet supported.'''
|
||||
)
|
||||
|
||||
LiteralResourceSpec = t.Literal["cloud-tpus.google.com/v2", "amd.com/gpu", "nvidia.com/gpu", "cpu"]
|
||||
LiteralResourceSpec = t.Literal['cloud-tpus.google.com/v2', 'amd.com/gpu', 'nvidia.com/gpu', 'cpu']
|
||||
# convenient mapping
|
||||
def resource_spec(name: t.Literal["tpu", "amd", "nvidia", "cpu"]) -> LiteralResourceSpec:
|
||||
if name == "tpu": return _TPU_RESOURCE
|
||||
elif name == "amd": return _AMD_GPU_RESOURCE
|
||||
elif name == "nvidia": return _NVIDIA_GPU_RESOURCE
|
||||
elif name == "cpu": return _CPU_RESOURCE
|
||||
def resource_spec(name: t.Literal['tpu', 'amd', 'nvidia', 'cpu']) -> LiteralResourceSpec:
|
||||
if name == 'tpu': return _TPU_RESOURCE
|
||||
elif name == 'amd': return _AMD_GPU_RESOURCE
|
||||
elif name == 'nvidia': return _NVIDIA_GPU_RESOURCE
|
||||
elif name == 'cpu': return _CPU_RESOURCE
|
||||
else: raise ValueError("Unknown alias. Accepted: ['tpu', 'amd', 'nvidia', 'cpu']")
|
||||
@functools.lru_cache
|
||||
def available_resource_spec() -> tuple[LiteralResourceSpec, ...]:
|
||||
"""This is a utility function helps to determine the available resources from given running system.
|
||||
'''This is a utility function helps to determine the available resources from given running system.
|
||||
|
||||
It will first check for TPUs -> AMD GPUS -> NVIDIA GPUS -> CPUs.
|
||||
|
||||
TODO: Supports TPUs
|
||||
"""
|
||||
'''
|
||||
available: list[LiteralResourceSpec] = []
|
||||
if len(AmdGpuResource.from_system()) > 0: available.append(_AMD_GPU_RESOURCE)
|
||||
if len(NvidiaGpuResource.from_system()) > 0: available.append(_NVIDIA_GPU_RESOURCE)
|
||||
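Assuming these helpers keep the module path set above ('openllm._strategies'), typical usage of the alias helper and the cached system probe would look like the following; the exact tuple returned by available_resource_spec depends on the machine:

from openllm._strategies import available_resource_spec, resource_spec

print(resource_spec('nvidia'))    # 'nvidia.com/gpu'
print(available_resource_spec())  # e.g. ('nvidia.com/gpu', 'cpu') on a single-vendor GPU box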
@@ -244,81 +244,81 @@ class CascadingResourceStrategy(bentoml.Strategy, ReprMixin):
|
||||
"""
|
||||
@classmethod
|
||||
def get_worker_count(cls, runnable_class: type[bentoml.Runnable], resource_request: dict[str, t.Any] | None, workers_per_resource: float) -> int:
|
||||
"""Return the number of workers to be used for the given runnable class.
|
||||
'''Return the number of workers to be used for the given runnable class.
|
||||
|
||||
Note that for all available GPU, the number of workers will always be 1.
|
||||
"""
|
||||
'''
|
||||
if resource_request is None: resource_request = system_resources()
|
||||
# use NVIDIA
|
||||
kind = "nvidia.com/gpu"
|
||||
kind = 'nvidia.com/gpu'
|
||||
nvidia_req = get_resource(resource_request, kind)
|
||||
if nvidia_req is not None: return 1
|
||||
# use AMD
|
||||
kind = "amd.com/gpu"
|
||||
kind = 'amd.com/gpu'
|
||||
amd_req = get_resource(resource_request, kind, validate=False)
|
||||
if amd_req is not None: return 1
|
||||
# use CPU
|
||||
cpus = get_resource(resource_request, "cpu")
|
||||
cpus = get_resource(resource_request, 'cpu')
|
||||
if cpus is not None and cpus > 0:
|
||||
if "cpu" not in runnable_class.SUPPORTED_RESOURCES: logger.warning("No known supported resource available for %s, falling back to using CPU.", runnable_class)
|
||||
if 'cpu' not in runnable_class.SUPPORTED_RESOURCES: logger.warning('No known supported resource available for %s, falling back to using CPU.', runnable_class)
|
||||
|
||||
if runnable_class.SUPPORTS_CPU_MULTI_THREADING:
|
||||
if isinstance(workers_per_resource, float) and workers_per_resource < 1.0: raise ValueError("Fractional CPU multi threading support is not yet supported.")
|
||||
if isinstance(workers_per_resource, float) and workers_per_resource < 1.0: raise ValueError('Fractional CPU multi threading support is not yet supported.')
|
||||
return int(workers_per_resource)
|
||||
return math.ceil(cpus) * workers_per_resource
|
||||
|
||||
# this should not be reached by user since we always read system resource as default
|
||||
raise ValueError(f"No known supported resource available for {runnable_class}. Please check your resource request. Leaving it blank will allow BentoML to use system resources.")
|
||||
raise ValueError(f'No known supported resource available for {runnable_class}. Please check your resource request. Leaving it blank will allow BentoML to use system resources.')
|
||||
|
||||
@classmethod
|
||||
def get_worker_env(cls, runnable_class: type[bentoml.Runnable], resource_request: dict[str, t.Any] | None, workers_per_resource: int | float, worker_index: int) -> dict[str, t.Any]:
|
||||
"""Get worker env for this given worker_index.
|
||||
'''Get worker env for this given worker_index.
|
||||
|
||||
Args:
|
||||
runnable_class: The runnable class to be run.
|
||||
resource_request: The resource request of the runnable.
|
||||
workers_per_resource: # of workers per resource.
|
||||
worker_index: The index of the worker, start from 0.
|
||||
"""
|
||||
cuda_env = os.environ.get("CUDA_VISIBLE_DEVICES", None)
|
||||
disabled = cuda_env in ("", "-1")
|
||||
'''
|
||||
cuda_env = os.environ.get('CUDA_VISIBLE_DEVICES', None)
|
||||
disabled = cuda_env in ('', '-1')
|
||||
environ: dict[str, t.Any] = {}
|
||||
|
||||
if resource_request is None: resource_request = system_resources()
|
||||
# use NVIDIA
|
||||
kind = "nvidia.com/gpu"
|
||||
kind = 'nvidia.com/gpu'
|
||||
typ = get_resource(resource_request, kind)
|
||||
if typ is not None and len(typ) > 0 and kind in runnable_class.SUPPORTED_RESOURCES:
|
||||
if disabled:
|
||||
logger.debug("CUDA_VISIBLE_DEVICES is disabled, %s will not be using GPU.", worker_index)
|
||||
environ["CUDA_VISIBLE_DEVICES"] = cuda_env
|
||||
logger.debug('CUDA_VISIBLE_DEVICES is disabled, %s will not be using GPU.', worker_index)
|
||||
environ['CUDA_VISIBLE_DEVICES'] = cuda_env
|
||||
return environ
|
||||
environ["CUDA_VISIBLE_DEVICES"] = cls.transpile_workers_to_cuda_envvar(workers_per_resource, typ, worker_index)
|
||||
logger.debug("Environ for worker %s: %s", worker_index, environ)
|
||||
environ['CUDA_VISIBLE_DEVICES'] = cls.transpile_workers_to_cuda_envvar(workers_per_resource, typ, worker_index)
|
||||
logger.debug('Environ for worker %s: %s', worker_index, environ)
|
||||
return environ
|
||||
# use AMD
|
||||
kind = "amd.com/gpu"
|
||||
kind = 'amd.com/gpu'
|
||||
typ = get_resource(resource_request, kind, validate=False)
|
||||
if typ is not None and len(typ) > 0 and kind in runnable_class.SUPPORTED_RESOURCES:
|
||||
if disabled:
|
||||
logger.debug("CUDA_VISIBLE_DEVICES is disabled, %s will not be using GPU.", worker_index)
|
||||
environ["CUDA_VISIBLE_DEVICES"] = cuda_env
|
||||
logger.debug('CUDA_VISIBLE_DEVICES is disabled, %s will not be using GPU.', worker_index)
|
||||
environ['CUDA_VISIBLE_DEVICES'] = cuda_env
|
||||
return environ
|
||||
environ["CUDA_VISIBLE_DEVICES"] = cls.transpile_workers_to_cuda_envvar(workers_per_resource, typ, worker_index)
|
||||
logger.debug("Environ for worker %s: %s", worker_index, environ)
|
||||
environ['CUDA_VISIBLE_DEVICES'] = cls.transpile_workers_to_cuda_envvar(workers_per_resource, typ, worker_index)
|
||||
logger.debug('Environ for worker %s: %s', worker_index, environ)
|
||||
return environ
|
||||
# use CPU
|
||||
cpus = get_resource(resource_request, "cpu")
|
||||
cpus = get_resource(resource_request, 'cpu')
|
||||
if cpus is not None and cpus > 0:
|
||||
environ["CUDA_VISIBLE_DEVICES"] = "-1" # disable gpu
|
||||
environ['CUDA_VISIBLE_DEVICES'] = '-1' # disable gpu
|
||||
if runnable_class.SUPPORTS_CPU_MULTI_THREADING:
|
||||
thread_count = math.ceil(cpus)
|
||||
for thread_env in THREAD_ENVS:
|
||||
environ[thread_env] = os.environ.get(thread_env, str(thread_count))
|
||||
logger.debug("Environ for worker %s: %s", worker_index, environ)
|
||||
logger.debug('Environ for worker %s: %s', worker_index, environ)
|
||||
return environ
|
||||
for thread_env in THREAD_ENVS:
|
||||
environ[thread_env] = os.environ.get(thread_env, "1")
|
||||
environ[thread_env] = os.environ.get(thread_env, '1')
|
||||
return environ
|
||||
return environ
|
||||
|
||||
@@ -334,13 +334,13 @@ class CascadingResourceStrategy(bentoml.Strategy, ReprMixin):
# then it will round down to 2. If workers_per_resource=0.6, then it will also round up to 2.
assigned_resource_per_worker = round(1 / workers_per_resource)
if len(gpus) < assigned_resource_per_worker:
logger.warning("Failed to allocate %s GPUs for %s (number of available GPUs < assigned workers per resource [%s])", gpus, worker_index, assigned_resource_per_worker)
logger.warning('Failed to allocate %s GPUs for %s (number of available GPUs < assigned workers per resource [%s])', gpus, worker_index, assigned_resource_per_worker)
raise IndexError(f"There aren't enough assigned GPU(s) for given worker id '{worker_index}' [required: {assigned_resource_per_worker}].")
assigned_gpu = gpus[assigned_resource_per_worker * worker_index:assigned_resource_per_worker * (worker_index+1)]
dev = ",".join(assigned_gpu)
dev = ','.join(assigned_gpu)
else:
idx = worker_index // workers_per_resource
if idx >= len(gpus): raise ValueError(f"Number of available GPUs ({gpus}) is not enough for the given workers_per_resource {workers_per_resource}")
if idx >= len(gpus): raise ValueError(f'Number of available GPUs ({gpus}) is not enough for the given workers_per_resource {workers_per_resource}')
dev = str(gpus[idx])
return dev
__all__ = ["CascadingResourceStrategy", "get_resource"]
__all__ = ['CascadingResourceStrategy', 'get_resource']
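The arithmetic in transpile_workers_to_cuda_envvar above splits the visible GPUs between workers: a fractional workers_per_resource means one worker owns several GPUs, while an integer value means several workers share one GPU. A small local restatement with worked values (illustration only, not the package API):

def _sketch_assign(gpus, workers_per_resource, worker_index):
  # Mirrors the two branches above.
  if workers_per_resource < 1:
    per_worker = round(1 / workers_per_resource)  # 0.5 -> each worker gets 2 GPUs
    return ','.join(gpus[per_worker * worker_index:per_worker * (worker_index + 1)])
  return str(gpus[worker_index // workers_per_resource])

print(_sketch_assign(['0', '1', '2', '3'], 0.5, 1))  # '2,3' -> worker 1 gets the second pair
print(_sketch_assign(['0', '1'], 2, 3))              # '1'   -> workers 2 and 3 share GPU 1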
@@ -11,24 +11,24 @@ if t.TYPE_CHECKING:
|
||||
|
||||
from .utils.lazy import VersionInfo
|
||||
M = t.TypeVar(
|
||||
"M",
|
||||
bound="t.Union[transformers.PreTrainedModel, transformers.Pipeline, transformers.TFPreTrainedModel, transformers.FlaxPreTrainedModel, vllm.LLMEngine, vllm.AsyncLLMEngine, peft.PeftModel, autogptq.modeling.BaseGPTQForCausalLM]"
|
||||
'M',
|
||||
bound='t.Union[transformers.PreTrainedModel, transformers.Pipeline, transformers.TFPreTrainedModel, transformers.FlaxPreTrainedModel, vllm.LLMEngine, vllm.AsyncLLMEngine, peft.PeftModel, autogptq.modeling.BaseGPTQForCausalLM]'
|
||||
)
|
||||
T = t.TypeVar("T", bound="t.Union[transformers.PreTrainedTokenizerFast, transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerBase]")
|
||||
T = t.TypeVar('T', bound='t.Union[transformers.PreTrainedTokenizerFast, transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerBase]')
|
||||
|
||||
AnyCallable = t.Callable[..., t.Any]
|
||||
DictStrAny = t.Dict[str, t.Any]
|
||||
ListAny = t.List[t.Any]
|
||||
ListStr = t.List[str]
|
||||
TupleAny = t.Tuple[t.Any, ...]
|
||||
At = t.TypeVar("At", bound=attr.AttrsInstance)
|
||||
At = t.TypeVar('At', bound=attr.AttrsInstance)
|
||||
|
||||
LiteralRuntime = t.Literal["pt", "tf", "flax", "vllm"]
|
||||
AdapterType = t.Literal["lora", "adalora", "adaption_prompt", "prefix_tuning", "p_tuning", "prompt_tuning", "ia3"]
|
||||
LiteralRuntime = t.Literal['pt', 'tf', 'flax', 'vllm']
|
||||
AdapterType = t.Literal['lora', 'adalora', 'adaption_prompt', 'prefix_tuning', 'p_tuning', 'prompt_tuning', 'ia3']
|
||||
|
||||
# TODO: support quay
|
||||
LiteralContainerRegistry = t.Literal["docker", "gh", "ecr"]
|
||||
LiteralContainerVersionStrategy = t.Literal["release", "nightly", "latest", "custom"]
|
||||
LiteralContainerRegistry = t.Literal['docker', 'gh', 'ecr']
|
||||
LiteralContainerVersionStrategy = t.Literal['release', 'nightly', 'latest', 'custom']
|
||||
|
||||
if sys.version_info[:2] >= (3, 11):
|
||||
from typing import LiteralString as LiteralString, Self as Self, overload as overload
|
||||
@@ -58,13 +58,13 @@ class RefTuple(TupleAny):
|
||||
version: VersionInfo
|
||||
strategy: LiteralContainerVersionStrategy
|
||||
class LLMRunnable(bentoml.Runnable, t.Generic[M, T]):
|
||||
SUPPORTED_RESOURCES = ("amd.com/gpu", "nvidia.com/gpu", "cpu")
|
||||
SUPPORTED_RESOURCES = ('amd.com/gpu', 'nvidia.com/gpu', 'cpu')
|
||||
SUPPORTS_CPU_MULTI_THREADING = True
|
||||
__call__: RunnableMethod[LLMRunnable[M, T], [str], list[t.Any]]
|
||||
set_adapter: RunnableMethod[LLMRunnable[M, T], [str], dict[t.Literal["success", "error_msg"], bool | str]]
|
||||
set_adapter: RunnableMethod[LLMRunnable[M, T], [str], dict[t.Literal['success', 'error_msg'], bool | str]]
|
||||
embeddings: RunnableMethod[LLMRunnable[M, T], [list[str]], LLMEmbeddings]
|
||||
generate: RunnableMethod[LLMRunnable[M, T], [str], list[t.Any]]
|
||||
generate_one: RunnableMethod[LLMRunnable[M, T], [str, list[str]], t.Sequence[dict[t.Literal["generated_text"], str]]]
|
||||
generate_one: RunnableMethod[LLMRunnable[M, T], [str, list[str]], t.Sequence[dict[t.Literal['generated_text'], str]]]
|
||||
generate_iterator: RunnableMethod[LLMRunnable[M, T], [str], t.Generator[str, None, str]]
|
||||
class LLMRunner(bentoml.Runner, t.Generic[M, T]):
|
||||
__doc__: str
|
||||
@@ -79,7 +79,7 @@ class LLMRunner(bentoml.Runner, t.Generic[M, T]):
|
||||
has_adapters: bool
|
||||
embeddings: RunnerMethod[LLMRunnable[M, T], [list[str]], t.Sequence[LLMEmbeddings]]
|
||||
generate: RunnerMethod[LLMRunnable[M, T], [str], list[t.Any]]
|
||||
generate_one: RunnerMethod[LLMRunnable[M, T], [str, list[str]], t.Sequence[dict[t.Literal["generated_text"], str]]]
|
||||
generate_one: RunnerMethod[LLMRunnable[M, T], [str, list[str]], t.Sequence[dict[t.Literal['generated_text'], str]]]
|
||||
generate_iterator: RunnerMethod[LLMRunnable[M, T], [str], t.Generator[str, None, str]]
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -13,9 +13,9 @@ if t.TYPE_CHECKING:
|
||||
ConfigItemsView = _odict_items[str, type[openllm_core.LLMConfig]]
|
||||
|
||||
# NOTE: This is the entrypoint when adding new model config
|
||||
CONFIG_MAPPING_NAMES = OrderedDict([("chatglm", "ChatGLMConfig"), ("dolly_v2", "DollyV2Config"), ("falcon", "FalconConfig"), ("flan_t5", "FlanT5Config"), ("gpt_neox", "GPTNeoXConfig"), (
|
||||
"llama", "LlamaConfig"
|
||||
), ("mpt", "MPTConfig"), ("opt", "OPTConfig"), ("stablelm", "StableLMConfig"), ("starcoder", "StarCoderConfig"), ("baichuan", "BaichuanConfig")])
|
||||
CONFIG_MAPPING_NAMES = OrderedDict([('chatglm', 'ChatGLMConfig'), ('dolly_v2', 'DollyV2Config'), ('falcon', 'FalconConfig'), ('flan_t5', 'FlanT5Config'), ('gpt_neox', 'GPTNeoXConfig'), (
|
||||
'llama', 'LlamaConfig'
|
||||
), ('mpt', 'MPTConfig'), ('opt', 'OPTConfig'), ('stablelm', 'StableLMConfig'), ('starcoder', 'StarCoderConfig'), ('baichuan', 'BaichuanConfig')])
|
||||
class _LazyConfigMapping(OrderedDict, ReprMixin):
|
||||
def __init__(self, mapping: OrderedDict[LiteralString, LiteralString]):
|
||||
self._mapping = mapping
|
||||
@@ -31,7 +31,7 @@ class _LazyConfigMapping(OrderedDict, ReprMixin):
|
||||
if module_name not in self._modules: self._modules[module_name] = openllm_core.utils.EnvVarMixin(module_name).module
|
||||
if hasattr(self._modules[module_name], value): return getattr(self._modules[module_name], value)
|
||||
# Some of the mappings have entries model_type -> config of another model type. In that case we try to grab the object at the top level.
|
||||
return getattr(importlib.import_module("openllm"), value)
|
||||
return getattr(importlib.import_module('openllm'), value)
|
||||
|
||||
@property
|
||||
def __repr_keys__(self) -> set[str]:
|
||||
@@ -44,13 +44,13 @@ class _LazyConfigMapping(OrderedDict, ReprMixin):
|
||||
yield from self._mapping.items()
|
||||
|
||||
def keys(self) -> ConfigKeysView:
|
||||
return t.cast("ConfigKeysView", list(self._mapping.keys()) + list(self._extra_content.keys()))
|
||||
return t.cast('ConfigKeysView', list(self._mapping.keys()) + list(self._extra_content.keys()))
|
||||
|
||||
def values(self) -> ConfigValuesView:
|
||||
return t.cast("ConfigValuesView", [self[k] for k in self._mapping.keys()] + list(self._extra_content.values()))
|
||||
return t.cast('ConfigValuesView', [self[k] for k in self._mapping.keys()] + list(self._extra_content.values()))
|
||||
|
||||
def items(self) -> ConfigItemsView:
|
||||
return t.cast("ConfigItemsView", [(k, self[k]) for k in self._mapping.keys()] + list(self._extra_content.items()))
|
||||
return t.cast('ConfigItemsView', [(k, self[k]) for k in self._mapping.keys()] + list(self._extra_content.items()))
|
||||
|
||||
def __iter__(self) -> t.Iterator[str]:
|
||||
return iter(list(self._mapping.keys()) + list(self._extra_content.keys()))
|
||||
@@ -63,10 +63,10 @@ class _LazyConfigMapping(OrderedDict, ReprMixin):
|
||||
self._extra_content[key] = value
|
||||
CONFIG_MAPPING: dict[str, type[openllm_core.LLMConfig]] = _LazyConfigMapping(CONFIG_MAPPING_NAMES)
|
||||
# The mapping below handles special aliases when the underscored name is used directly without camel-case processing first.
|
||||
CONFIG_NAME_ALIASES: dict[str, str] = {"chat_glm": "chatglm", "stable_lm": "stablelm", "star_coder": "starcoder", "gpt_neo_x": "gpt_neox",}
|
||||
CONFIG_NAME_ALIASES: dict[str, str] = {'chat_glm': 'chatglm', 'stable_lm': 'stablelm', 'star_coder': 'starcoder', 'gpt_neo_x': 'gpt_neox',}
|
||||
class AutoConfig:
|
||||
def __init__(self, *_: t.Any, **__: t.Any):
|
||||
raise EnvironmentError("Cannot instantiate AutoConfig directly. Please use `AutoConfig.for_model(model_name)` instead.")
|
||||
raise EnvironmentError('Cannot instantiate AutoConfig directly. Please use `AutoConfig.for_model(model_name)` instead.')
|
||||
|
||||
@classmethod
|
||||
def for_model(cls, model_name: str, **attrs: t.Any) -> openllm_core.LLMConfig:
|
||||
|
||||
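Because _LazyConfigMapping resolves entries on first access, a config class is only imported when it is actually requested. Assuming AutoConfig is re-exported at the package top level (the importlib.import_module('openllm') fallback above suggests it is), usage follows the hint in the error message; exact attributes depend on the installed version:

import openllm

config = openllm.AutoConfig.for_model('flan_t5')  # resolved lazily through CONFIG_MAPPING
print(type(config).__name__)                      # 'FlanT5Config'
print(config['model_name'])                       # 'flan_t5'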
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
import openllm_core, typing as t
|
||||
from openllm_core._prompt import process_prompt
|
||||
START_BAICHUAN_COMMAND_DOCSTRING = """\
|
||||
START_BAICHUAN_COMMAND_DOCSTRING = '''\
|
||||
Run a LLMServer for Baichuan model.
|
||||
|
||||
\b
|
||||
@@ -19,8 +19,8 @@ or provide `--model-id` flag when running ``openllm start baichuan``:
|
||||
|
||||
\b
|
||||
$ openllm start baichuan --model-id='fireballoon/baichuan-vicuna-chinese-7b'
|
||||
"""
|
||||
DEFAULT_PROMPT_TEMPLATE = """{instruction}"""
|
||||
'''
|
||||
DEFAULT_PROMPT_TEMPLATE = '''{instruction}'''
|
||||
class BaichuanConfig(openllm_core.LLMConfig):
|
||||
"""Baichuan-7B is an open-source, large-scale pre-trained language model developed by Baichuan Intelligent Technology.
|
||||
|
||||
@@ -32,21 +32,21 @@ class BaichuanConfig(openllm_core.LLMConfig):
|
||||
Refer to [Baichuan-7B's GitHub page](https://github.com/baichuan-inc/Baichuan-7B) for more information.
|
||||
"""
|
||||
__config__ = {
|
||||
"name_type": "lowercase",
|
||||
"trust_remote_code": True,
|
||||
"timeout": 3600000,
|
||||
"requires_gpu": True,
|
||||
"url": "https://github.com/baichuan-inc/Baichuan-7B",
|
||||
"requirements": ["cpm-kernels", "sentencepiece"],
|
||||
"architecture": "BaiChuanForCausalLM",
|
||||
"default_id": "baichuan-inc/baichuan-7b",
|
||||
"model_ids": [
|
||||
"baichuan-inc/baichuan-7b",
|
||||
"baichuan-inc/baichuan-13b-base",
|
||||
"baichuan-inc/baichuan-13b-chat",
|
||||
"fireballoon/baichuan-vicuna-chinese-7b",
|
||||
"fireballoon/baichuan-vicuna-7b",
|
||||
"hiyouga/baichuan-7b-sft"
|
||||
'name_type': 'lowercase',
|
||||
'trust_remote_code': True,
|
||||
'timeout': 3600000,
|
||||
'requires_gpu': True,
|
||||
'url': 'https://github.com/baichuan-inc/Baichuan-7B',
|
||||
'requirements': ['cpm-kernels', 'sentencepiece'],
|
||||
'architecture': 'BaiChuanForCausalLM',
|
||||
'default_id': 'baichuan-inc/baichuan-7b',
|
||||
'model_ids': [
|
||||
'baichuan-inc/baichuan-7b',
|
||||
'baichuan-inc/baichuan-13b-base',
|
||||
'baichuan-inc/baichuan-13b-chat',
|
||||
'fireballoon/baichuan-vicuna-chinese-7b',
|
||||
'fireballoon/baichuan-vicuna-7b',
|
||||
'hiyouga/baichuan-7b-sft'
|
||||
]
|
||||
}
|
||||
|
||||
@@ -58,7 +58,7 @@ class BaichuanConfig(openllm_core.LLMConfig):
|
||||
def sanitize_parameters(
|
||||
self, prompt: str, max_new_tokens: int | None = None, top_p: float | None = None, temperature: float | None = None, use_default_prompt_template: bool = False, **attrs: t.Any
|
||||
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "top_p": top_p, "temperature": temperature, **attrs}, {}
|
||||
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {'max_new_tokens': max_new_tokens, 'top_p': top_p, 'temperature': temperature, **attrs}, {}
|
||||
|
||||
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str:
|
||||
return generation_result[0]
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
import openllm_core, typing as t
|
||||
from openllm_core.utils import dantic
|
||||
START_CHATGLM_COMMAND_DOCSTRING = """\
|
||||
START_CHATGLM_COMMAND_DOCSTRING = '''\
|
||||
Run a LLMServer for ChatGLM model.
|
||||
|
||||
\b
|
||||
@@ -19,8 +19,8 @@ or provide `--model-id` flag when running ``openllm start chatglm``:
|
||||
|
||||
\b
|
||||
$ openllm start chatglm --model-id='thudm/chatglm-6b-int8'
|
||||
"""
|
||||
DEFAULT_PROMPT_TEMPLATE = """{instruction}"""
|
||||
'''
|
||||
DEFAULT_PROMPT_TEMPLATE = '''{instruction}'''
|
||||
class ChatGLMConfig(openllm_core.LLMConfig):
|
||||
"""ChatGLM is an open bilingual language model based on [General Language Model (GLM)](https://github.com/THUDM/GLM) framework.
|
||||
|
||||
@@ -36,18 +36,18 @@ class ChatGLMConfig(openllm_core.LLMConfig):
|
||||
Refer to [ChatGLM's GitHub page](https://github.com/THUDM/ChatGLM-6B) for more information.
|
||||
"""
|
||||
__config__ = {
|
||||
"name_type": "lowercase",
|
||||
"trust_remote_code": True,
|
||||
"timeout": 3600000,
|
||||
"requires_gpu": True,
|
||||
"url": "https://github.com/THUDM/ChatGLM-6B",
|
||||
"requirements": ["cpm-kernels", "sentencepiece"],
|
||||
"architecture": "ChatGLMForConditionalGeneration",
|
||||
"default_id": "thudm/chatglm-6b",
|
||||
"model_ids": ["thudm/chatglm-6b", "thudm/chatglm-6b-int8", "thudm/chatglm-6b-int4", "thudm/chatglm2-6b", "thudm/chatglm2-6b-int4"]
|
||||
'name_type': 'lowercase',
|
||||
'trust_remote_code': True,
|
||||
'timeout': 3600000,
|
||||
'requires_gpu': True,
|
||||
'url': 'https://github.com/THUDM/ChatGLM-6B',
|
||||
'requirements': ['cpm-kernels', 'sentencepiece'],
|
||||
'architecture': 'ChatGLMForConditionalGeneration',
|
||||
'default_id': 'thudm/chatglm-6b',
|
||||
'model_ids': ['thudm/chatglm-6b', 'thudm/chatglm-6b-int8', 'thudm/chatglm-6b-int4', 'thudm/chatglm2-6b', 'thudm/chatglm2-6b-int4']
|
||||
}
|
||||
retain_history: bool = dantic.Field(False, description="Whether to retain history given to the model. If set to True, then the model will retain given history.")
|
||||
use_half_precision: bool = dantic.Field(True, description="Whether to use half precision for model.")
|
||||
retain_history: bool = dantic.Field(False, description='Whether to retain history given to the model. If set to True, then the model will retain given history.')
|
||||
use_half_precision: bool = dantic.Field(True, description='Whether to use half precision for model.')
|
||||
|
||||
class GenerationConfig:
|
||||
max_new_tokens: int = 2048
|
||||
@@ -66,15 +66,15 @@ class ChatGLMConfig(openllm_core.LLMConfig):
|
||||
use_default_prompt_template: bool = False,
|
||||
**attrs: t.Any
|
||||
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
prompt_text = ""
|
||||
prompt_text = ''
|
||||
if use_default_prompt_template and chat_history is not None:
|
||||
for i, (old_query, response) in enumerate(chat_history):
|
||||
prompt_text += f"[Round {i}]\n问:{old_query}\n答:{response}\n"
|
||||
prompt_text += f"[Round {len(chat_history)}]\n问:{prompt}\n答:"
|
||||
prompt_text += f'[Round {i}]\n问:{old_query}\n答:{response}\n'
|
||||
prompt_text += f'[Round {len(chat_history)}]\n问:{prompt}\n答:'
|
||||
else:
|
||||
prompt_text = prompt
|
||||
postprocess_generate_kwargs = {"chat_history": chat_history if chat_history is not None else None}
|
||||
return prompt_text, {"max_new_tokens": max_new_tokens, "num_beams": num_beams, "top_p": top_p, "temperature": temperature, **attrs}, postprocess_generate_kwargs
|
||||
postprocess_generate_kwargs = {'chat_history': chat_history if chat_history is not None else None}
|
||||
return prompt_text, {'max_new_tokens': max_new_tokens, 'num_beams': num_beams, 'top_p': top_p, 'temperature': temperature, **attrs}, postprocess_generate_kwargs
|
||||
|
||||
def postprocess_generate(self, prompt: str, generation_result: tuple[str, list[tuple[str, str]]], *, chat_history: list[tuple[str, str]] | None = None, **attrs: t.Any) -> str:
|
||||
generated, history = generation_result
|
||||
|
||||
@@ -4,7 +4,7 @@ from openllm_core._prompt import process_prompt
|
||||
from openllm_core.utils import dantic
|
||||
if t.TYPE_CHECKING: import transformers
|
||||
|
||||
START_DOLLY_V2_COMMAND_DOCSTRING = """\
|
||||
START_DOLLY_V2_COMMAND_DOCSTRING = '''\
|
||||
Run a LLMServer for dolly-v2 model.
|
||||
|
||||
\b
|
||||
@@ -22,21 +22,21 @@ or provide `--model-id` flag when running ``openllm start dolly-v2``:
|
||||
|
||||
\b
|
||||
$ openllm start dolly-v2 --model-id databricks/dolly-v2-7b
|
||||
"""
|
||||
INSTRUCTION_KEY = "### Instruction:"
|
||||
RESPONSE_KEY = "### Response:"
|
||||
END_KEY = "### End"
|
||||
INTRO_BLURB = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
|
||||
'''
|
||||
INSTRUCTION_KEY = '### Instruction:'
|
||||
RESPONSE_KEY = '### Response:'
|
||||
END_KEY = '### End'
|
||||
INTRO_BLURB = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.'
|
||||
# NOTE: This is the prompt that is used for generating responses using an already
|
||||
# trained model. It ends with the response key, where the job of the model is to provide
|
||||
# the completion that follows it (i.e. the response itself).
|
||||
DEFAULT_PROMPT_TEMPLATE = """{intro}
|
||||
DEFAULT_PROMPT_TEMPLATE = '''{intro}
|
||||
{instruction_key}
|
||||
{instruction}
|
||||
{response_key}
|
||||
""".format(intro=INTRO_BLURB, instruction_key=INSTRUCTION_KEY, instruction="{instruction}", response_key=RESPONSE_KEY)
|
||||
'''.format(intro=INTRO_BLURB, instruction_key=INSTRUCTION_KEY, instruction='{instruction}', response_key=RESPONSE_KEY)
|
||||
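Rendered, the two-step .format above yields the familiar instruction-following layout. A runnable recreation with the constants copied from this hunk (the instruction text is just an example):

INSTRUCTION_KEY, RESPONSE_KEY = '### Instruction:', '### Response:'
INTRO_BLURB = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.'
TEMPLATE = '{intro}\n{instruction_key}\n{instruction}\n{response_key}\n'.format(
  intro=INTRO_BLURB, instruction_key=INSTRUCTION_KEY, instruction='{instruction}', response_key=RESPONSE_KEY
)
print(TEMPLATE.format(instruction='Summarize the OpenLLM README.'))
# Below is an instruction that describes a task. Write a response that appropriately completes the request.
# ### Instruction:
# Summarize the OpenLLM README.
# ### Response: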
def get_special_token_id(tokenizer: transformers.PreTrainedTokenizer, key: str) -> int:
|
||||
"""Gets the token ID for a given string that has been added to the tokenizer as a special token.
|
||||
'''Gets the token ID for a given string that has been added to the tokenizer as a special token.
|
||||
|
||||
When training, we configure the tokenizer so that the sequences like "### Instruction:" and "### End" are
|
||||
treated specially and converted to a single, new token. This retrieves the token ID each of these keys map to.
|
||||
@@ -50,7 +50,7 @@ def get_special_token_id(tokenizer: transformers.PreTrainedTokenizer, key: str)
|
||||
|
||||
Returns:
|
||||
int: the token ID for the given key.
|
||||
"""
|
||||
'''
|
||||
token_ids = tokenizer.encode(key)
|
||||
if len(token_ids) > 1: raise ValueError(f"Expected only a single token for '{key}' but found {token_ids}")
|
||||
return token_ids[0]
|
||||
@@ -67,13 +67,13 @@ class DollyV2Config(openllm_core.LLMConfig):
|
||||
Refer to [Databricks's Dolly page](https://github.com/databrickslabs/dolly) for more information.
|
||||
"""
|
||||
__config__ = {
|
||||
"timeout": 3600000,
|
||||
"url": "https://github.com/databrickslabs/dolly",
|
||||
"architecture": "GPTNeoXForCausalLM",
|
||||
"default_id": "databricks/dolly-v2-3b",
|
||||
"model_ids": ["databricks/dolly-v2-3b", "databricks/dolly-v2-7b", "databricks/dolly-v2-12b"]
|
||||
'timeout': 3600000,
|
||||
'url': 'https://github.com/databrickslabs/dolly',
|
||||
'architecture': 'GPTNeoXForCausalLM',
|
||||
'default_id': 'databricks/dolly-v2-3b',
|
||||
'model_ids': ['databricks/dolly-v2-3b', 'databricks/dolly-v2-7b', 'databricks/dolly-v2-12b']
|
||||
}
|
||||
return_full_text: bool = dantic.Field(False, description="Whether to return the full prompt to the users.")
|
||||
return_full_text: bool = dantic.Field(False, description='Whether to return the full prompt to the users.')
|
||||
|
||||
class GenerationConfig:
|
||||
temperature: float = 0.9
|
||||
@@ -93,8 +93,8 @@ class DollyV2Config(openllm_core.LLMConfig):
|
||||
**attrs: t.Any
|
||||
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {
|
||||
"max_new_tokens": max_new_tokens, "top_k": top_k, "top_p": top_p, "temperature": temperature, **attrs
|
||||
'max_new_tokens': max_new_tokens, 'top_k': top_k, 'top_p': top_p, 'temperature': temperature, **attrs
|
||||
}, {}
|
||||
|
||||
def postprocess_generate(self, prompt: str, generation_result: list[dict[t.Literal["generated_text"], str]], **_: t.Any) -> str:
|
||||
return generation_result[0]["generated_text"]
|
||||
def postprocess_generate(self, prompt: str, generation_result: list[dict[t.Literal['generated_text'], str]], **_: t.Any) -> str:
|
||||
return generation_result[0]['generated_text']
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
import openllm_core, typing as t
|
||||
from openllm_core._prompt import process_prompt
|
||||
START_FALCON_COMMAND_DOCSTRING = """\
|
||||
START_FALCON_COMMAND_DOCSTRING = '''\
|
||||
Run a LLMServer for FalconLM model.
|
||||
|
||||
\b
|
||||
@@ -21,11 +21,11 @@ or provide `--model-id` flag when running ``openllm start falcon``:
|
||||
|
||||
\b
|
||||
$ openllm start falcon --model-id tiiuae/falcon-7b-instruct
|
||||
"""
|
||||
DEFAULT_PROMPT_TEMPLATE = """{context}
|
||||
'''
|
||||
DEFAULT_PROMPT_TEMPLATE = '''{context}
|
||||
{user_name}: {instruction}
|
||||
{agent}:
|
||||
"""
|
||||
'''
|
||||
class FalconConfig(openllm_core.LLMConfig):
|
||||
"""Falcon-7B is a 7B parameters causal decoder-only model built by TII and trained on 1,500B tokens of [RefinedWeb](https://huggingface.co/datasets/tiiuae/falcon-refinedweb) enhanced with curated corpora.
|
||||
|
||||
@@ -34,17 +34,17 @@ class FalconConfig(openllm_core.LLMConfig):
|
||||
Refer to [Falcon's HuggingFace page](https://huggingface.co/tiiuae/falcon-7b) for more information.
|
||||
"""
|
||||
__config__ = {
|
||||
"name_type": "lowercase",
|
||||
"trust_remote_code": True,
|
||||
"requires_gpu": True,
|
||||
"timeout": int(36e6),
|
||||
"url": "https://falconllm.tii.ae/",
|
||||
"requirements": ["einops", "xformers"],
|
||||
"architecture": "FalconForCausalLM",
|
||||
"default_id": "tiiuae/falcon-7b",
|
||||
"model_ids": ["tiiuae/falcon-7b", "tiiuae/falcon-40b", "tiiuae/falcon-7b-instruct", "tiiuae/falcon-40b-instruct"],
|
||||
"fine_tune_strategies": ({
|
||||
"adapter_type": "lora", "r": 64, "lora_alpha": 16, "lora_dropout": 0.1, "bias": "none", "target_modules": ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"]
|
||||
'name_type': 'lowercase',
|
||||
'trust_remote_code': True,
|
||||
'requires_gpu': True,
|
||||
'timeout': int(36e6),
|
||||
'url': 'https://falconllm.tii.ae/',
|
||||
'requirements': ['einops', 'xformers'],
|
||||
'architecture': 'FalconForCausalLM',
|
||||
'default_id': 'tiiuae/falcon-7b',
|
||||
'model_ids': ['tiiuae/falcon-7b', 'tiiuae/falcon-40b', 'tiiuae/falcon-7b-instruct', 'tiiuae/falcon-40b-instruct'],
|
||||
'fine_tune_strategies': ({
|
||||
'adapter_type': 'lora', 'r': 64, 'lora_alpha': 16, 'lora_dropout': 0.1, 'bias': 'none', 'target_modules': ['query_key_value', 'dense', 'dense_h_to_4h', 'dense_4h_to_h']
|
||||
},)
|
||||
}
|
||||
|
||||
@@ -66,7 +66,7 @@ class FalconConfig(openllm_core.LLMConfig):
|
||||
**attrs: t.Any
|
||||
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {
|
||||
"max_new_tokens": max_new_tokens, "top_k": top_k, "num_return_sequences": num_return_sequences, "eos_token_id": eos_token_id, **attrs
|
||||
'max_new_tokens': max_new_tokens, 'top_k': top_k, 'num_return_sequences': num_return_sequences, 'eos_token_id': eos_token_id, **attrs
|
||||
}, {}
|
||||
|
||||
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
import openllm_core, typing as t
|
||||
from openllm_core._prompt import process_prompt
|
||||
START_FLAN_T5_COMMAND_DOCSTRING = """\
|
||||
START_FLAN_T5_COMMAND_DOCSTRING = '''\
|
||||
Run a LLMServer for FLAN-T5 model.
|
||||
|
||||
\b
|
||||
@@ -25,8 +25,8 @@ or provide `--model-id` flag when running ``openllm start flan-t5``:
|
||||
|
||||
\b
|
||||
$ openllm start flan-t5 --model-id google/flan-t5-xxl
|
||||
"""
|
||||
DEFAULT_PROMPT_TEMPLATE = """Answer the following question:\nQuestion: {instruction}\nAnswer:"""
|
||||
'''
|
||||
DEFAULT_PROMPT_TEMPLATE = '''Answer the following question:\nQuestion: {instruction}\nAnswer:'''
|
||||
class FlanT5Config(openllm_core.LLMConfig):
|
||||
"""FLAN-T5 was released in the paper [Scaling Instruction-Finetuned Language Models](https://arxiv.org/pdf/2210.11416.pdf).
|
||||
|
||||
@@ -35,11 +35,11 @@ class FlanT5Config(openllm_core.LLMConfig):
|
||||
Refer to [FLAN-T5's page](https://huggingface.co/docs/transformers/model_doc/flan-t5) for more information.
|
||||
"""
|
||||
__config__ = {
|
||||
"url": "https://huggingface.co/docs/transformers/model_doc/flan-t5",
|
||||
"architecture": "T5ForConditionalGeneration",
|
||||
"model_type": "seq2seq_lm",
|
||||
"default_id": "google/flan-t5-large",
|
||||
"model_ids": ["google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large", "google/flan-t5-xl", "google/flan-t5-xxl",]
|
||||
'url': 'https://huggingface.co/docs/transformers/model_doc/flan-t5',
|
||||
'architecture': 'T5ForConditionalGeneration',
|
||||
'model_type': 'seq2seq_lm',
|
||||
'default_id': 'google/flan-t5-large',
|
||||
'model_ids': ['google/flan-t5-small', 'google/flan-t5-base', 'google/flan-t5-large', 'google/flan-t5-xl', 'google/flan-t5-xxl',]
|
||||
}
|
||||
|
||||
class GenerationConfig:
|
||||
@@ -61,7 +61,7 @@ class FlanT5Config(openllm_core.LLMConfig):
|
||||
**attrs: t.Any
|
||||
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {
|
||||
"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "top_p": top_p, "repetition_penalty": repetition_penalty
|
||||
'max_new_tokens': max_new_tokens, 'temperature': temperature, 'top_k': top_k, 'top_p': top_p, 'repetition_penalty': repetition_penalty
|
||||
}, {}
|
||||
|
||||
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str:
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
import openllm_core, typing as t
|
||||
from openllm_core._prompt import process_prompt
|
||||
from openllm_core.utils import dantic
|
||||
START_GPT_NEOX_COMMAND_DOCSTRING = """\
|
||||
START_GPT_NEOX_COMMAND_DOCSTRING = '''\
|
||||
Run a LLMServer for GPTNeoX model.
|
||||
|
||||
\b
|
||||
@@ -20,8 +20,8 @@ or provide `--model-id` flag when running ``openllm start gpt-neox``:
|
||||
|
||||
\b
|
||||
$ openllm start gpt-neox --model-id 'stabilityai/stablelm-tuned-alpha-3b'
|
||||
"""
|
||||
DEFAULT_PROMPT_TEMPLATE = """{instruction}"""
|
||||
'''
|
||||
DEFAULT_PROMPT_TEMPLATE = '''{instruction}'''
|
||||
class GPTNeoXConfig(openllm_core.LLMConfig):
|
||||
"""GPTNeoX is an autoregressive language model trained on the Pile, whose weights will be made freely and openly available to the public through a permissive license.
|
||||
|
||||
@@ -38,15 +38,15 @@ class GPTNeoXConfig(openllm_core.LLMConfig):
|
||||
for more information.
|
||||
"""
|
||||
__config__ = {
|
||||
"model_name": "gpt_neox",
|
||||
"start_name": "gpt-neox",
|
||||
"requires_gpu": True,
|
||||
"architecture": "GPTNeoXForCausalLM",
|
||||
"url": "https://github.com/EleutherAI/gpt-neox",
|
||||
"default_id": "eleutherai/gpt-neox-20b",
|
||||
"model_ids": ["eleutherai/gpt-neox-20b"]
|
||||
'model_name': 'gpt_neox',
|
||||
'start_name': 'gpt-neox',
|
||||
'requires_gpu': True,
|
||||
'architecture': 'GPTNeoXForCausalLM',
|
||||
'url': 'https://github.com/EleutherAI/gpt-neox',
|
||||
'default_id': 'eleutherai/gpt-neox-20b',
|
||||
'model_ids': ['eleutherai/gpt-neox-20b']
|
||||
}
|
||||
use_half_precision: bool = dantic.Field(True, description="Whether to use half precision for model.")
|
||||
use_half_precision: bool = dantic.Field(True, description='Whether to use half precision for model.')
|
||||
|
||||
class GenerationConfig:
|
||||
temperature: float = 0.9
|
||||
@@ -54,7 +54,7 @@ class GPTNeoXConfig(openllm_core.LLMConfig):
|
||||
|
||||
def sanitize_parameters(self, prompt: str, temperature: float | None = None, max_new_tokens: int | None = None, use_default_prompt_template: bool = True,
|
||||
**attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {"max_new_tokens": max_new_tokens, "temperature": temperature}, {}
|
||||
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {'max_new_tokens': max_new_tokens, 'temperature': temperature}, {}
|
||||
|
||||
def postprocess_generate(self, prompt: str, generation_result: list[str], **_: t.Any) -> str:
|
||||
return generation_result[0]
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
import typing as t, openllm_core
|
||||
from openllm_core._prompt import process_prompt
|
||||
from openllm_core.utils import dantic
|
||||
START_LLAMA_COMMAND_DOCSTRING = """\
|
||||
START_LLAMA_COMMAND_DOCSTRING = '''\
|
||||
Run a LLMServer for Llama model.
|
||||
|
||||
\b
|
||||
@@ -30,17 +30,17 @@ OpenLLM also supports running Llama-2 and its fine-tune and variants. To import
|
||||
|
||||
\b
|
||||
$ CONVERTER=hf-llama2 openllm import llama /path/to/llama-2
|
||||
"""
|
||||
SYSTEM_MESSAGE = """
|
||||
'''
|
||||
SYSTEM_MESSAGE = '''
|
||||
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
|
||||
|
||||
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
|
||||
"""
|
||||
SINST_KEY, EINST_KEY, SYS_KEY, EOS_TOKEN, BOS_TOKEN = "[INST]", "[/INST]", "<<SYS>>", "</s>", "<s>"
|
||||
'''
|
||||
SINST_KEY, EINST_KEY, SYS_KEY, EOS_TOKEN, BOS_TOKEN = '[INST]', '[/INST]', '<<SYS>>', '</s>', '<s>'
|
||||
# TODO: support history and v1 prompt implementation
|
||||
_v1_prompt, _v2_prompt = """{instruction}""", """{start_key} {sys_key}\n{system_message}\n{sys_key}\n\n{instruction}\n{end_key} """.format(start_key=SINST_KEY, sys_key=SYS_KEY, system_message=SYSTEM_MESSAGE, instruction="{instruction}", end_key=EINST_KEY)
|
||||
PROMPT_MAPPING = {"v1": _v1_prompt, "v2": _v2_prompt}
|
||||
def _get_prompt(model_type: t.Literal["v1", "v2"]) -> str:
|
||||
_v1_prompt, _v2_prompt = '''{instruction}''', '''{start_key} {sys_key}\n{system_message}\n{sys_key}\n\n{instruction}\n{end_key} '''.format(start_key=SINST_KEY, sys_key=SYS_KEY, system_message=SYSTEM_MESSAGE, instruction='{instruction}', end_key=EINST_KEY)
|
||||
PROMPT_MAPPING = {'v1': _v1_prompt, 'v2': _v2_prompt}
|
||||
def _get_prompt(model_type: t.Literal['v1', 'v2']) -> str:
|
||||
return PROMPT_MAPPING[model_type]
|
||||
DEFAULT_PROMPT_TEMPLATE = _get_prompt
|
||||
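For the Llama-2 ('v2') variant, the template above wraps the instruction in [INST] ... [/INST] with the system message between the <<SYS>> markers. A runnable recreation (SYSTEM_MESSAGE is shortened here; the full text is in the hunk above):

SINST_KEY, EINST_KEY, SYS_KEY = '[INST]', '[/INST]', '<<SYS>>'
SYSTEM_MESSAGE = 'You are a helpful, respectful and honest assistant.'
v2_prompt = '{start_key} {sys_key}\n{system_message}\n{sys_key}\n\n{instruction}\n{end_key} '.format(
  start_key=SINST_KEY, sys_key=SYS_KEY, system_message=SYSTEM_MESSAGE, instruction='{instruction}', end_key=EINST_KEY
)
print(v2_prompt.format(instruction='What is OpenLLM?'))
# [INST] <<SYS>>
# You are a helpful, respectful and honest assistant.
# <<SYS>>
#
# What is OpenLLM?
# [/INST]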
class LlamaConfig(openllm_core.LLMConfig):
|
||||
@@ -55,40 +55,40 @@ class LlamaConfig(openllm_core.LLMConfig):
|
||||
Refer to [Llama's model card](https://huggingface.co/docs/transformers/main/model_doc/llama)
|
||||
for more information.
|
||||
"""
|
||||
use_llama2_prompt: bool = dantic.Field(False, description="Whether to use the prompt format for Llama 2. Disable this when working with Llama 1.")
|
||||
use_llama2_prompt: bool = dantic.Field(False, description='Whether to use the prompt format for Llama 2. Disable this when working with Llama 1.')
|
||||
__config__ = {
|
||||
"name_type": "lowercase",
|
||||
"url": "https://github.com/facebookresearch/llama",
|
||||
"default_implementation": {
|
||||
"cpu": "pt", "nvidia.com/gpu": "pt"
|
||||
'name_type': 'lowercase',
|
||||
'url': 'https://github.com/facebookresearch/llama',
|
||||
'default_implementation': {
|
||||
'cpu': 'pt', 'nvidia.com/gpu': 'pt'
|
||||
},
|
||||
"architecture": "LlamaForCausalLM",
|
||||
"requirements": ["fairscale", "sentencepiece"],
|
||||
"tokenizer_class": "LlamaTokenizerFast",
|
||||
"default_id": "NousResearch/llama-2-7b-hf",
|
||||
"model_ids": [
|
||||
"meta-llama/Llama-2-70b-chat-hf",
|
||||
"meta-llama/Llama-2-13b-chat-hf",
|
||||
"meta-llama/Llama-2-7b-chat-hf",
|
||||
"meta-llama/Llama-2-70b-hf",
|
||||
"meta-llama/Llama-2-13b-hf",
|
||||
"meta-llama/Llama-2-7b-hf",
|
||||
"NousResearch/llama-2-70b-chat-hf",
|
||||
"NousResearch/llama-2-13b-chat-hf",
|
||||
"NousResearch/llama-2-7b-chat-hf",
|
||||
"NousResearch/llama-2-70b-hf",
|
||||
"NousResearch/llama-2-13b-hf",
|
||||
"NousResearch/llama-2-7b-hf",
|
||||
"openlm-research/open_llama_7b_v2",
|
||||
"openlm-research/open_llama_3b_v2",
|
||||
"openlm-research/open_llama_13b",
|
||||
"huggyllama/llama-65b",
|
||||
"huggyllama/llama-30b",
|
||||
"huggyllama/llama-13b",
|
||||
"huggyllama/llama-7b"
|
||||
'architecture': 'LlamaForCausalLM',
|
||||
'requirements': ['fairscale', 'sentencepiece'],
|
||||
'tokenizer_class': 'LlamaTokenizerFast',
|
||||
'default_id': 'NousResearch/llama-2-7b-hf',
|
||||
'model_ids': [
|
||||
'meta-llama/Llama-2-70b-chat-hf',
|
||||
'meta-llama/Llama-2-13b-chat-hf',
|
||||
'meta-llama/Llama-2-7b-chat-hf',
|
||||
'meta-llama/Llama-2-70b-hf',
|
||||
'meta-llama/Llama-2-13b-hf',
|
||||
'meta-llama/Llama-2-7b-hf',
|
||||
'NousResearch/llama-2-70b-chat-hf',
|
||||
'NousResearch/llama-2-13b-chat-hf',
|
||||
'NousResearch/llama-2-7b-chat-hf',
|
||||
'NousResearch/llama-2-70b-hf',
|
||||
'NousResearch/llama-2-13b-hf',
|
||||
'NousResearch/llama-2-7b-hf',
|
||||
'openlm-research/open_llama_7b_v2',
|
||||
'openlm-research/open_llama_3b_v2',
|
||||
'openlm-research/open_llama_13b',
|
||||
'huggyllama/llama-65b',
|
||||
'huggyllama/llama-30b',
|
||||
'huggyllama/llama-13b',
|
||||
'huggyllama/llama-7b'
|
||||
],
|
||||
"fine_tune_strategies": ({
|
||||
"adapter_type": "lora", "r": 64, "lora_alpha": 16, "lora_dropout": 0.1, "bias": "none"
|
||||
'fine_tune_strategies': ({
|
||||
'adapter_type': 'lora', 'r': 64, 'lora_alpha': 16, 'lora_dropout': 0.1, 'bias': 'none'
|
||||
},)
|
||||
}
|
||||
|
||||
@@ -113,8 +113,8 @@ class LlamaConfig(openllm_core.LLMConfig):
|
||||
use_llama2_prompt: bool = True,
|
||||
**attrs: t.Any
|
||||
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE("v2" if use_llama2_prompt else "v1") if use_default_prompt_template else None, use_default_prompt_template, **attrs), {
|
||||
"max_new_tokens": max_new_tokens, "temperature": temperature, "top_p": top_p, "top_k": top_k
|
||||
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE('v2' if use_llama2_prompt else 'v1') if use_default_prompt_template else None, use_default_prompt_template, **attrs), {
|
||||
'max_new_tokens': max_new_tokens, 'temperature': temperature, 'top_p': top_p, 'top_k': top_k
|
||||
}, {}
|
||||
|
||||
def postprocess_generate(self, prompt: str, generation_result: list[str], **_: t.Any) -> str:
|
||||
|
||||
@@ -2,9 +2,9 @@ from __future__ import annotations
import typing as t, openllm_core
from openllm_core.utils import dantic
from openllm_core._prompt import process_prompt
MPTPromptType = t.Literal["default", "instruct", "chat", "storywriter"]
MPTPromptType = t.Literal['default', 'instruct', 'chat', 'storywriter']

START_MPT_COMMAND_DOCSTRING = """\
START_MPT_COMMAND_DOCSTRING = '''\
Run a LLMServer for MPT model.

\b
@@ -29,18 +29,18 @@ or provide `--model-id` flag when running ``openllm start mpt``:

\b
$ openllm start mpt --model-id mosaicml/mpt-30b
"""
INSTRUCTION_KEY, RESPONSE_KEY, END_KEY = "### Instruction:", "### Response:", "### End"
INTRO_BLURB = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
'''
INSTRUCTION_KEY, RESPONSE_KEY, END_KEY = '### Instruction:', '### Response:', '### End'
INTRO_BLURB = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.'
# NOTE: This is the prompt that is used for generating responses using an already
# trained model. It ends with the response key, where the job of the model is to provide
# the completion that follows it (i.e. the response itself).
_chat_prompt, _default_prompt, _instruct_prompt = """{instruction}""", """{instruction}""", """{intro}
_chat_prompt, _default_prompt, _instruct_prompt = '''{instruction}''', '''{instruction}''', '''{intro}
{instruction_key}
{instruction}
{response_key}
""".format(intro=INTRO_BLURB, instruction_key=INSTRUCTION_KEY, instruction="{instruction}", response_key=RESPONSE_KEY)
PROMPT_MAPPING = {"default": _default_prompt, "instruct": _instruct_prompt, "storywriter": _default_prompt, "chat": _chat_prompt}
'''.format(intro=INTRO_BLURB, instruction_key=INSTRUCTION_KEY, instruction='{instruction}', response_key=RESPONSE_KEY)
PROMPT_MAPPING = {'default': _default_prompt, 'instruct': _instruct_prompt, 'storywriter': _default_prompt, 'chat': _chat_prompt}
def _get_prompt(model_type: str) -> str:
return PROMPT_MAPPING[model_type]
DEFAULT_PROMPT_TEMPLATE = _get_prompt
@@ -54,21 +54,21 @@ class MPTConfig(openllm_core.LLMConfig):
for more details on specific models.
"""
__config__ = {
"name_type": "lowercase",
"trust_remote_code": True,
"url": "https://huggingface.co/mosaicml",
"timeout": int(36e6),
"requirements": ["triton", "einops"],
"architecture": "MPTForCausalLM",
"default_id": "mosaicml/mpt-7b-instruct",
"model_ids": [
"mosaicml/mpt-7b", "mosaicml/mpt-7b-instruct", "mosaicml/mpt-7b-chat", "mosaicml/mpt-7b-storywriter", "mosaicml/mpt-30b", "mosaicml/mpt-30b-instruct", "mosaicml/mpt-30b-chat"
'name_type': 'lowercase',
'trust_remote_code': True,
'url': 'https://huggingface.co/mosaicml',
'timeout': int(36e6),
'requirements': ['triton', 'einops'],
'architecture': 'MPTForCausalLM',
'default_id': 'mosaicml/mpt-7b-instruct',
'model_ids': [
'mosaicml/mpt-7b', 'mosaicml/mpt-7b-instruct', 'mosaicml/mpt-7b-chat', 'mosaicml/mpt-7b-storywriter', 'mosaicml/mpt-30b', 'mosaicml/mpt-30b-instruct', 'mosaicml/mpt-30b-chat'
]
}
prompt_type: MPTPromptType = dantic.Field('"default"', description="Given prompt type for running MPT. Default will be inferred from model name if pretrained.")
prompt_type: MPTPromptType = dantic.Field('"default"', description='Given prompt type for running MPT. Default will be inferred from model name if pretrained.')
max_sequence_length: int = dantic.Field(
2048,
description="Max sequence length to run MPT with. Note that MPT is trained ith sequence length of 2048, but with [ALiBi](https://arxiv.org/abs/2108.12409) it can set up to 4096 (for 7b models) and 16384 (for 30b models)"
description='Max sequence length to run MPT with. Note that MPT is trained ith sequence length of 2048, but with [ALiBi](https://arxiv.org/abs/2108.12409) it can set up to 4096 (for 7b models) and 16384 (for 30b models)'
)

class GenerationConfig:
@@ -89,12 +89,12 @@ class MPTConfig(openllm_core.LLMConfig):
_template = None
if use_default_prompt_template:
if prompt_type is None:
if "instruct" in self.model_id: prompt_type = "instruct"
elif "storywriter" in self.model_id: prompt_type = "storywriter"
elif "chat" in self.model_id: prompt_type = "chat"
else: prompt_type = "default"
if 'instruct' in self.model_id: prompt_type = 'instruct'
elif 'storywriter' in self.model_id: prompt_type = 'storywriter'
elif 'chat' in self.model_id: prompt_type = 'chat'
else: prompt_type = 'default'
_template = DEFAULT_PROMPT_TEMPLATE(prompt_type)
return process_prompt(prompt, _template, use_default_prompt_template), {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_p": top_p}, {}
return process_prompt(prompt, _template, use_default_prompt_template), {'max_new_tokens': max_new_tokens, 'temperature': temperature, 'top_p': top_p}, {}

def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **attrs: t.Any) -> str:
return generation_result[0]

@@ -2,7 +2,7 @@ from __future__ import annotations
import openllm_core, typing as t
from openllm_core.utils import dantic
from openllm_core._prompt import process_prompt
START_OPT_COMMAND_DOCSTRING = """\
START_OPT_COMMAND_DOCSTRING = '''\
Run a LLMServer for OPT model.

\b
@@ -26,8 +26,8 @@ or provide `--model-id` flag when running ``openllm start opt``:

\b
$ openllm start opt --model-id facebook/opt-6.7b
"""
DEFAULT_PROMPT_TEMPLATE = """{instruction}"""
'''
DEFAULT_PROMPT_TEMPLATE = '''{instruction}'''
class OPTConfig(openllm_core.LLMConfig):
"""OPT was first introduced in [Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) and first released in [metaseq's repository](https://github.com/facebookresearch/metaseq) on May 3rd 2022 by Meta AI.

@@ -39,17 +39,17 @@ class OPTConfig(openllm_core.LLMConfig):
Refer to [OPT's HuggingFace page](https://huggingface.co/docs/transformers/model_doc/opt) for more information.
"""
__config__ = {
"name_type": "lowercase",
"trust_remote_code": False,
"url": "https://huggingface.co/docs/transformers/model_doc/opt",
"default_id": "facebook/opt-1.3b",
"architecture": "OPTForCausalLM",
"model_ids": ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b", "facebook/opt-2.7b", "facebook/opt-6.7b", "facebook/opt-66b"],
"fine_tune_strategies": ({
"adapter_type": "lora", "r": 16, "lora_alpha": 32, "target_modules": ["q_proj", "v_proj"], "lora_dropout": 0.05, "bias": "none"
'name_type': 'lowercase',
'trust_remote_code': False,
'url': 'https://huggingface.co/docs/transformers/model_doc/opt',
'default_id': 'facebook/opt-1.3b',
'architecture': 'OPTForCausalLM',
'model_ids': ['facebook/opt-125m', 'facebook/opt-350m', 'facebook/opt-1.3b', 'facebook/opt-2.7b', 'facebook/opt-6.7b', 'facebook/opt-66b'],
'fine_tune_strategies': ({
'adapter_type': 'lora', 'r': 16, 'lora_alpha': 32, 'target_modules': ['q_proj', 'v_proj'], 'lora_dropout': 0.05, 'bias': 'none'
},)
}
format_outputs: bool = dantic.Field(False, description="""Whether to format the outputs. This can be used when num_return_sequences > 1.""")
format_outputs: bool = dantic.Field(False, description='''Whether to format the outputs. This can be used when num_return_sequences > 1.''')

class GenerationConfig:
top_k: int = 15
@@ -68,10 +68,10 @@ class OPTConfig(openllm_core.LLMConfig):
**attrs: t.Any
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {
"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "num_return_sequences": num_return_sequences
'max_new_tokens': max_new_tokens, 'temperature': temperature, 'top_k': top_k, 'num_return_sequences': num_return_sequences
}, {}

def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **attrs: t.Any) -> str:
if len(generation_result) == 1: return generation_result[0]
if self.config.format_outputs: return "Generated result:\n" + "\n -".join(generation_result)
else: return "\n".join(generation_result)
if self.config.format_outputs: return 'Generated result:\n' + '\n -'.join(generation_result)
else: return '\n'.join(generation_result)

@@ -1,7 +1,7 @@
from __future__ import annotations
import openllm_core, typing as t
from openllm_core._prompt import process_prompt
START_STABLELM_COMMAND_DOCSTRING = """\
START_STABLELM_COMMAND_DOCSTRING = '''\
Run a LLMServer for StableLM model.

\b
@@ -19,14 +19,14 @@ or provide `--model-id` flag when running ``openllm start stablelm``:

\b
$ openllm start stablelm --model-id 'stabilityai/stablelm-tuned-alpha-3b'
"""
SYSTEM_PROMPT = """<|SYSTEM|># StableLM Tuned (Alpha version)
'''
SYSTEM_PROMPT = '''<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
- StableLM will refuse to participate in anything that could harm a human.
"""
DEFAULT_PROMPT_TEMPLATE = """{system_prompt}<|USER|>{instruction}<|ASSISTANT|>"""
'''
DEFAULT_PROMPT_TEMPLATE = '''{system_prompt}<|USER|>{instruction}<|ASSISTANT|>'''
class StableLMConfig(openllm_core.LLMConfig):
"""StableLM-Base-Alpha is a suite of 3B and 7B parameter decoder-only language models.

@@ -42,11 +42,11 @@ class StableLMConfig(openllm_core.LLMConfig):
for more information.
"""
__config__ = {
"name_type": "lowercase",
"url": "https://github.com/Stability-AI/StableLM",
"architecture": "GPTNeoXForCausalLM",
"default_id": "stabilityai/stablelm-tuned-alpha-3b",
"model_ids": ["stabilityai/stablelm-tuned-alpha-3b", "stabilityai/stablelm-tuned-alpha-7b", "stabilityai/stablelm-base-alpha-3b", "stabilityai/stablelm-base-alpha-7b"]
'name_type': 'lowercase',
'url': 'https://github.com/Stability-AI/StableLM',
'architecture': 'GPTNeoXForCausalLM',
'default_id': 'stabilityai/stablelm-tuned-alpha-3b',
'model_ids': ['stabilityai/stablelm-tuned-alpha-3b', 'stabilityai/stablelm-tuned-alpha-7b', 'stabilityai/stablelm-base-alpha-3b', 'stabilityai/stablelm-base-alpha-7b']
}

class GenerationConfig:
@@ -65,12 +65,12 @@ class StableLMConfig(openllm_core.LLMConfig):
use_default_prompt_template: bool = False,
**attrs: t.Any
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
if "tuned" in self._model_id and use_default_prompt_template:
system_prompt = attrs.pop("system_prompt", SYSTEM_PROMPT)
if 'tuned' in self._model_id and use_default_prompt_template:
system_prompt = attrs.pop('system_prompt', SYSTEM_PROMPT)
prompt_text = process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, system_prompt=system_prompt, **attrs)
else:
prompt_text = prompt
return prompt_text, {"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "top_p": top_p}, {}
return prompt_text, {'max_new_tokens': max_new_tokens, 'temperature': temperature, 'top_k': top_k, 'top_p': top_p}, {}

def postprocess_generate(self, prompt: str, generation_result: list[str], **_: t.Any) -> str:
return generation_result[0]

@@ -1,6 +1,6 @@
from __future__ import annotations
import openllm_core, typing as t
START_STARCODER_COMMAND_DOCSTRING = """\
START_STARCODER_COMMAND_DOCSTRING = '''\
Run a LLMServer for StarCoder model.

\b
@@ -18,9 +18,9 @@ or provide `--model-id` flag when running ``openllm start starcoder``:

\b
$ openllm start starcoder --model-id 'bigcode/starcoder'
"""
DEFAULT_PROMPT_TEMPLATE = """{instruction}"""
FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD, EOD, FIM_INDICATOR = "<fim-prefix>", "<fim-middle>", "<fim-suffix>", "<fim-pad>", "<|endoftext|>", "<FILL_HERE>"
'''
DEFAULT_PROMPT_TEMPLATE = '''{instruction}'''
FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD, EOD, FIM_INDICATOR = '<fim-prefix>', '<fim-middle>', '<fim-suffix>', '<fim-pad>', '<|endoftext|>', '<FILL_HERE>'
class StarCoderConfig(openllm_core.LLMConfig):
"""The StarCoder models are 15.5B parameter models trained on 80+ programming languages from [The Stack (v1.2)](https://huggingface.co/datasets/bigcode/the-stack), with opt-out requests excluded.

@@ -31,14 +31,14 @@ class StarCoderConfig(openllm_core.LLMConfig):
Refer to [StarCoder's model card](https://huggingface.co/bigcode/starcoder) for more information.
"""
__config__ = {
"name_type": "lowercase",
"requires_gpu": True,
"url": "https://github.com/bigcode-project/starcoder",
"architecture": "GPTBigCodeForCausalLM",
"requirements": ["bitsandbytes"],
"workers_per_resource": 0.5,
"default_id": "bigcode/starcoder",
"model_ids": ["bigcode/starcoder", "bigcode/starcoderbase"]
'name_type': 'lowercase',
'requires_gpu': True,
'url': 'https://github.com/bigcode-project/starcoder',
'architecture': 'GPTBigCodeForCausalLM',
'requirements': ['bitsandbytes'],
'workers_per_resource': 0.5,
'default_id': 'bigcode/starcoder',
'model_ids': ['bigcode/starcoder', 'bigcode/starcoderbase']
}

class GenerationConfig:
@@ -58,12 +58,12 @@ class StarCoderConfig(openllm_core.LLMConfig):
try:
prefix, suffix = prompt.split(FIM_INDICATOR)
except Exception as err:
raise ValueError(f"Only one {FIM_INDICATOR} allowed in prompt") from err
prompt_text = f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}"
raise ValueError(f'Only one {FIM_INDICATOR} allowed in prompt') from err
prompt_text = f'{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}'
else:
prompt_text = prompt
# XXX: This value for pad_token_id is currently a hack, need more investigate why the default starcoder doesn't include the same value as santacoder EOD
return prompt_text, {"temperature": temperature, "top_p": top_p, "max_new_tokens": max_new_tokens, "repetition_penalty": repetition_penalty, "pad_token_id": 49152, **attrs}, {}
return prompt_text, {'temperature': temperature, 'top_p': top_p, 'max_new_tokens': max_new_tokens, 'repetition_penalty': repetition_penalty, 'pad_token_id': 49152, **attrs}, {}

def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str:
return generation_result[0]

@@ -1,19 +1,19 @@
"""Base exceptions for OpenLLM. This extends BentoML exceptions."""
'''Base exceptions for OpenLLM. This extends BentoML exceptions.'''
from __future__ import annotations
import bentoml
class OpenLLMException(bentoml.exceptions.BentoMLException):
"""Base class for all OpenLLM exceptions. This extends BentoMLException."""
'''Base class for all OpenLLM exceptions. This extends BentoMLException.'''
class GpuNotAvailableError(OpenLLMException):
"""Raised when there is no GPU available in given system."""
'''Raised when there is no GPU available in given system.'''
class ValidationError(OpenLLMException):
"""Raised when a validation fails."""
'''Raised when a validation fails.'''
class ForbiddenAttributeError(OpenLLMException):
"""Raised when using an _internal field."""
'''Raised when using an _internal field.'''
class MissingAnnotationAttributeError(OpenLLMException):
"""Raised when a field under openllm.LLMConfig is missing annotations."""
'''Raised when a field under openllm.LLMConfig is missing annotations.'''
class MissingDependencyError(BaseException):
"""Raised when a dependency is missing."""
'''Raised when a dependency is missing.'''
class Error(BaseException):
"""To be used instead of naked raise."""
'''To be used instead of naked raise.'''
class FineTuneStrategyNotSupportedError(OpenLLMException):
"""Raised when a fine-tune strategy is not supported for given LLM."""
'''Raised when a fine-tune strategy is not supported for given LLM.'''

@@ -37,12 +37,12 @@ except ImportError:
if sys.version_info < (3, 10): _WithArgsTypes = (_TypingGenericAlias,)
else: _WithArgsTypes: t.Any = (t._GenericAlias, types.GenericAlias, types.UnionType) # type: ignore # _GenericAlias is the actual GenericAlias implementation

DEV_DEBUG_VAR = "OPENLLMDEVDEBUG"
DEV_DEBUG_VAR = 'OPENLLMDEVDEBUG'
def set_debug_mode(enabled: bool, level: int = 1) -> None:
# monkeypatch bentoml._internal.configuration.set_debug_mode to remove unused logs
if enabled: os.environ[DEV_DEBUG_VAR] = str(level)
os.environ[DEBUG_ENV_VAR] = str(enabled)
os.environ[_GRPC_DEBUG_ENV_VAR] = "DEBUG" if enabled else "ERROR"
os.environ[_GRPC_DEBUG_ENV_VAR] = 'DEBUG' if enabled else 'ERROR'
def lenient_issubclass(cls: t.Any, class_or_tuple: type[t.Any] | tuple[type[t.Any], ...] | None) -> bool:
try:
return isinstance(cls, type) and issubclass(cls, class_or_tuple) # type: ignore[arg-type]
@@ -54,11 +54,11 @@ def ensure_exec_coro(coro: t.Coroutine[t.Any, t.Any, t.Any]) -> t.Any:
if loop.is_running(): return asyncio.run_coroutine_threadsafe(coro, loop).result()
else: return loop.run_until_complete(coro)
def available_devices() -> tuple[str, ...]:
"""Return available GPU under system. Currently only supports NVIDIA GPUs."""
'''Return available GPU under system. Currently only supports NVIDIA GPUs.'''
from openllm_core._strategies import NvidiaGpuResource
return tuple(NvidiaGpuResource.from_system())
@functools.lru_cache(maxsize=128)
def generate_hash_from_file(f: str, algorithm: t.Literal["md5", "sha1"] = "sha1") -> str:
def generate_hash_from_file(f: str, algorithm: t.Literal['md5', 'sha1'] = 'sha1') -> str:
"""Generate a hash from given file's modification time.

Args:
@@ -79,19 +79,19 @@ def non_intrusive_setattr(obj: t.Any, name: str, value: t.Any) -> None:
_setattr = functools.partial(setattr, obj) if isinstance(obj, type) else _object_setattr.__get__(obj)
if not hasattr(obj, name): _setattr(name, value)
def field_env_key(model_name: str, key: str, suffix: str | None = None) -> str:
return "_".join(filter(None, map(str.upper, ["OPENLLM", model_name, suffix.strip("_") if suffix else "", key])))
return '_'.join(filter(None, map(str.upper, ['OPENLLM', model_name, suffix.strip('_') if suffix else '', key])))
# Special debug flag controled via OPENLLMDEVDEBUG
DEBUG: bool = sys.flags.dev_mode or (not sys.flags.ignore_environment and bool(os.environ.get(DEV_DEBUG_VAR)))
# MYPY is like t.TYPE_CHECKING, but reserved for Mypy plugins
MYPY = False
SHOW_CODEGEN: bool = DEBUG and int(os.environ.get("OPENLLMDEVDEBUG", str(0))) > 3
SHOW_CODEGEN: bool = DEBUG and int(os.environ.get('OPENLLMDEVDEBUG', str(0))) > 3
def get_debug_mode() -> bool:
return DEBUG or _get_debug_mode()
def get_quiet_mode() -> bool:
return not DEBUG and _get_quiet_mode()
class ExceptionFilter(logging.Filter):
def __init__(self, exclude_exceptions: list[type[Exception]] | None = None, **kwargs: t.Any):
"""A filter of all exception."""
'''A filter of all exception.'''
if exclude_exceptions is None: exclude_exceptions = [ConflictError]
if ConflictError not in exclude_exceptions: exclude_exceptions.append(ConflictError)
super(ExceptionFilter, self).__init__(**kwargs)
@@ -108,52 +108,52 @@ class InfoFilter(logging.Filter):
def filter(self, record: logging.LogRecord) -> bool:
return logging.INFO <= record.levelno < logging.WARNING
_LOGGING_CONFIG: dict[str, t.Any] = {
"version": 1,
"disable_existing_loggers": True,
"filters": {
"excfilter": {
"()": "openllm_core.utils.ExceptionFilter"
}, "infofilter": {
"()": "openllm_core.utils.InfoFilter"
'version': 1,
'disable_existing_loggers': True,
'filters': {
'excfilter': {
'()': 'openllm_core.utils.ExceptionFilter'
}, 'infofilter': {
'()': 'openllm_core.utils.InfoFilter'
}
},
"handlers": {
"bentomlhandler": {
"class": "logging.StreamHandler", "filters": ["excfilter", "infofilter"], "stream": "ext://sys.stdout"
'handlers': {
'bentomlhandler': {
'class': 'logging.StreamHandler', 'filters': ['excfilter', 'infofilter'], 'stream': 'ext://sys.stdout'
},
"defaulthandler": {
"class": "logging.StreamHandler", "level": logging.WARNING
'defaulthandler': {
'class': 'logging.StreamHandler', 'level': logging.WARNING
}
},
"loggers": {
"bentoml": {
"handlers": ["bentomlhandler", "defaulthandler"], "level": logging.INFO, "propagate": False
'loggers': {
'bentoml': {
'handlers': ['bentomlhandler', 'defaulthandler'], 'level': logging.INFO, 'propagate': False
},
"openllm": {
"handlers": ["bentomlhandler", "defaulthandler"], "level": logging.INFO, "propagate": False
'openllm': {
'handlers': ['bentomlhandler', 'defaulthandler'], 'level': logging.INFO, 'propagate': False
}
},
"root": {
"level": logging.WARNING
'root': {
'level': logging.WARNING
},
}
def configure_logging() -> None:
"""Configure logging for OpenLLM.
'''Configure logging for OpenLLM.

Behaves similar to how BentoML loggers are being configured.
"""
'''
if get_quiet_mode():
_LOGGING_CONFIG["loggers"]["openllm"]["level"] = logging.ERROR
_LOGGING_CONFIG["loggers"]["bentoml"]["level"] = logging.ERROR
_LOGGING_CONFIG["root"]["level"] = logging.ERROR
_LOGGING_CONFIG['loggers']['openllm']['level'] = logging.ERROR
_LOGGING_CONFIG['loggers']['bentoml']['level'] = logging.ERROR
_LOGGING_CONFIG['root']['level'] = logging.ERROR
elif get_debug_mode() or DEBUG:
_LOGGING_CONFIG["loggers"]["openllm"]["level"] = logging.DEBUG
_LOGGING_CONFIG["loggers"]["bentoml"]["level"] = logging.DEBUG
_LOGGING_CONFIG["root"]["level"] = logging.DEBUG
_LOGGING_CONFIG['loggers']['openllm']['level'] = logging.DEBUG
_LOGGING_CONFIG['loggers']['bentoml']['level'] = logging.DEBUG
_LOGGING_CONFIG['root']['level'] = logging.DEBUG
else:
_LOGGING_CONFIG["loggers"]["openllm"]["level"] = logging.INFO
_LOGGING_CONFIG["loggers"]["bentoml"]["level"] = logging.INFO
_LOGGING_CONFIG["root"]["level"] = logging.INFO
_LOGGING_CONFIG['loggers']['openllm']['level'] = logging.INFO
_LOGGING_CONFIG['loggers']['bentoml']['level'] = logging.INFO
_LOGGING_CONFIG['root']['level'] = logging.INFO

logging.config.dictConfig(_LOGGING_CONFIG)
@functools.lru_cache(maxsize=1)
@@ -162,10 +162,10 @@ def in_notebook() -> bool:
from IPython.core.getipython import get_ipython
if t.TYPE_CHECKING:
from IPython.core.interactiveshell import InteractiveShell
return "IPKernelApp" in t.cast("dict[str, t.Any]", t.cast(t.Callable[[], "InteractiveShell"], get_ipython)().config)
return 'IPKernelApp' in t.cast('dict[str, t.Any]', t.cast(t.Callable[[], 'InteractiveShell'], get_ipython)().config)
except (ImportError, AttributeError):
return False
_dockerenv, _cgroup = Path("/.dockerenv"), Path("/proc/self/cgroup")
_dockerenv, _cgroup = Path('/.dockerenv'), Path('/proc/self/cgroup')
class suppress(contextlib.suppress, contextlib.ContextDecorator):
"""A version of contextlib.suppress with decorator support.

@@ -175,7 +175,7 @@ class suppress(contextlib.suppress, contextlib.ContextDecorator):
>>> key_error()
"""
def compose(*funcs: AnyCallable) -> AnyCallable:
"""Compose any number of unary functions into a single unary function.
'''Compose any number of unary functions into a single unary function.

>>> import textwrap
>>> expected = str.strip(textwrap.dedent(compose.__doc__))
@@ -189,7 +189,7 @@ def compose(*funcs: AnyCallable) -> AnyCallable:
>>> f = compose(round_three, int.__truediv__)
>>> [f(3*x, x+1) for x in range(1,10)]
[1.5, 2.0, 2.25, 2.4, 2.5, 2.571, 2.625, 2.667, 2.7]
"""
'''
def compose_two(f1: AnyCallable, f2: AnyCallable) -> AnyCallable:
return lambda *args, **kwargs: f1(f2(*args, **kwargs))

@@ -216,16 +216,16 @@ def apply(transform: AnyCallable) -> t.Callable[[AnyCallable], AnyCallable]:
def _text_in_file(text: str, filename: Path) -> bool:
return any(text in line for line in filename.open())
def in_docker() -> bool:
"""Is this current environment running in docker?
'''Is this current environment running in docker?

```python
type(in_docker())
```
"""
return _dockerenv.exists() or _text_in_file("docker", _cgroup)
T, K = t.TypeVar("T"), t.TypeVar("K")
'''
return _dockerenv.exists() or _text_in_file('docker', _cgroup)
T, K = t.TypeVar('T'), t.TypeVar('K')
def resolve_filepath(path: str, ctx: str | None = None) -> str:
"""Resolve a file path to an absolute path, expand user and environment variables."""
'''Resolve a file path to an absolute path, expand user and environment variables.'''
try:
return resolve_user_filepath(path, ctx)
except FileNotFoundError:
@@ -233,16 +233,16 @@ def resolve_filepath(path: str, ctx: str | None = None) -> str:
def validate_is_path(maybe_path: str) -> bool:
return os.path.exists(os.path.dirname(resolve_filepath(maybe_path)))
def generate_context(framework_name: str) -> _ModelContext:
framework_versions = {"transformers": pkg.get_pkg_version("transformers")}
if openllm_core.utils.is_torch_available(): framework_versions["torch"] = pkg.get_pkg_version("torch")
framework_versions = {'transformers': pkg.get_pkg_version('transformers')}
if openllm_core.utils.is_torch_available(): framework_versions['torch'] = pkg.get_pkg_version('torch')
if openllm_core.utils.is_tf_available():
from bentoml._internal.frameworks.utils.tensorflow import get_tf_version
framework_versions["tensorflow"] = get_tf_version()
if openllm_core.utils.is_flax_available(): framework_versions.update({"flax": pkg.get_pkg_version("flax"), "jax": pkg.get_pkg_version("jax"), "jaxlib": pkg.get_pkg_version("jaxlib")})
framework_versions['tensorflow'] = get_tf_version()
if openllm_core.utils.is_flax_available(): framework_versions.update({'flax': pkg.get_pkg_version('flax'), 'jax': pkg.get_pkg_version('jax'), 'jaxlib': pkg.get_pkg_version('jaxlib')})
return _ModelContext(framework_name=framework_name, framework_versions=framework_versions)
_TOKENIZER_PREFIX = "_tokenizer_"
_TOKENIZER_PREFIX = '_tokenizer_'
def normalize_attrs_to_model_tokenizer_pair(**attrs: t.Any) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
"""Normalize the given attrs to a model and tokenizer kwargs accordingly."""
'''Normalize the given attrs to a model and tokenizer kwargs accordingly.'''
tokenizer_attrs = {k[len(_TOKENIZER_PREFIX):]: v for k, v in attrs.items() if k.startswith(_TOKENIZER_PREFIX)}
for k in tuple(attrs.keys()):
if k.startswith(_TOKENIZER_PREFIX): del attrs[k]
@@ -250,46 +250,46 @@ def normalize_attrs_to_model_tokenizer_pair(**attrs: t.Any) -> tuple[dict[str, t
# NOTE: The set marks contains a set of modules name
# that are available above and are whitelisted
# to be included in the extra_objects map.
_whitelist_modules = {"pkg"}
_whitelist_modules = {'pkg'}

# XXX: define all classes, functions import above this line
# since _extras will be the locals() import from this file.
_extras: dict[str, t.Any] = {k: v for k, v in locals().items() if k in _whitelist_modules or (not isinstance(v, types.ModuleType) and not k.startswith("_"))}
_extras["__openllm_migration__"] = {"ModelEnv": "EnvVarMixin"}
_extras: dict[str, t.Any] = {k: v for k, v in locals().items() if k in _whitelist_modules or (not isinstance(v, types.ModuleType) and not k.startswith('_'))}
_extras['__openllm_migration__'] = {'ModelEnv': 'EnvVarMixin'}
_import_structure: dict[str, list[str]] = {
"analytics": [],
"codegen": [],
"dantic": [],
"representation": ["ReprMixin"],
"lazy": ["LazyModule"],
"import_utils": [
"OPTIONAL_DEPENDENCIES",
"ENV_VARS_TRUE_VALUES",
"DummyMetaclass",
"EnvVarMixin",
"require_backends",
"is_cpm_kernels_available",
"is_einops_available",
"is_flax_available",
"is_tf_available",
"is_vllm_available",
"is_torch_available",
"is_bitsandbytes_available",
"is_peft_available",
"is_datasets_available",
"is_transformers_supports_kbit",
"is_transformers_supports_agent",
"is_jupyter_available",
"is_jupytext_available",
"is_notebook_available",
"is_triton_available",
"is_autogptq_available",
"is_sentencepiece_available",
"is_xformers_available",
"is_fairscale_available",
"is_grpc_available",
"is_grpc_health_available",
"is_transformers_available"
'analytics': [],
'codegen': [],
'dantic': [],
'representation': ['ReprMixin'],
'lazy': ['LazyModule'],
'import_utils': [
'OPTIONAL_DEPENDENCIES',
'ENV_VARS_TRUE_VALUES',
'DummyMetaclass',
'EnvVarMixin',
'require_backends',
'is_cpm_kernels_available',
'is_einops_available',
'is_flax_available',
'is_tf_available',
'is_vllm_available',
'is_torch_available',
'is_bitsandbytes_available',
'is_peft_available',
'is_datasets_available',
'is_transformers_supports_kbit',
'is_transformers_supports_agent',
'is_jupyter_available',
'is_jupytext_available',
'is_notebook_available',
'is_triton_available',
'is_autogptq_available',
'is_sentencepiece_available',
'is_xformers_available',
'is_fairscale_available',
'is_grpc_available',
'is_grpc_health_available',
'is_transformers_available'
]
}

@@ -326,7 +326,7 @@ if t.TYPE_CHECKING:
require_backends as require_backends,
)
from .representation import ReprMixin as ReprMixin
__lazy = LazyModule(__name__, globals()["__file__"], _import_structure, extra_objects=_extras)
__lazy = LazyModule(__name__, globals()['__file__'], _import_structure, extra_objects=_extras)
__all__ = __lazy.__all__
__dir__ = __lazy.__dir__
__getattr__ = __lazy.__getattr__

@@ -1,24 +1,24 @@
"""Telemetry related for OpenLLM tracking.
'''Telemetry related for OpenLLM tracking.

Users can disable this with OPENLLM_DO_NOT_TRACK envvar.
"""
'''
from __future__ import annotations
import contextlib, functools, logging, os, re, typing as t, importlib.metadata, attr, openllm_core
from bentoml._internal.utils import analytics as _internal_analytics
from openllm_core._typing_compat import ParamSpec
P = ParamSpec("P")
T = t.TypeVar("T")
P = ParamSpec('P')
T = t.TypeVar('T')
logger = logging.getLogger(__name__)

# This variable is a proxy that will control BENTOML_DO_NOT_TRACK
OPENLLM_DO_NOT_TRACK = "OPENLLM_DO_NOT_TRACK"
OPENLLM_DO_NOT_TRACK = 'OPENLLM_DO_NOT_TRACK'
DO_NOT_TRACK = os.environ.get(OPENLLM_DO_NOT_TRACK, str(False)).upper()
@functools.lru_cache(maxsize=1)
def do_not_track() -> bool:
return DO_NOT_TRACK in openllm_core.utils.ENV_VARS_TRUE_VALUES
@functools.lru_cache(maxsize=1)
def _usage_event_debugging() -> bool:
return os.environ.get("__BENTOML_DEBUG_USAGE", str(False)).lower() == "true"
return os.environ.get('__BENTOML_DEBUG_USAGE', str(False)).lower() == 'true'
def silent(func: t.Callable[P, T]) -> t.Callable[P, T]:
@functools.wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> t.Any:
@@ -26,15 +26,15 @@ def silent(func: t.Callable[P, T]) -> t.Callable[P, T]:
return func(*args, **kwargs)
except Exception as err:
if _usage_event_debugging():
if openllm_core.utils.get_debug_mode(): logger.error("Tracking Error: %s", err, stack_info=True, stacklevel=3)
else: logger.info("Tracking Error: %s", err)
else: logger.debug("Tracking Error: %s", err)
if openllm_core.utils.get_debug_mode(): logger.error('Tracking Error: %s', err, stack_info=True, stacklevel=3)
else: logger.info('Tracking Error: %s', err)
else: logger.debug('Tracking Error: %s', err)

return wrapper
@silent
def track(event_properties: attr.AttrsInstance) -> None:
if do_not_track(): return
_internal_analytics.track(t.cast("_internal_analytics.schemas.EventMeta", event_properties))
_internal_analytics.track(t.cast('_internal_analytics.schemas.EventMeta', event_properties))
@contextlib.contextmanager
def set_bentoml_tracking() -> t.Generator[None, None, None]:
original_value = os.environ.pop(_internal_analytics.BENTOML_DO_NOT_TRACK, str(False))
@@ -47,9 +47,9 @@ class EventMeta:
@property
def event_name(self) -> str:
# camel case to snake case
event_name = re.sub(r"(?<!^)(?=[A-Z])", "_", self.__class__.__name__).lower()
event_name = re.sub(r'(?<!^)(?=[A-Z])', '_', self.__class__.__name__).lower()
# remove "_event" suffix
suffix_to_remove = "_event"
suffix_to_remove = '_event'
if event_name.endswith(suffix_to_remove): event_name = event_name[:-len(suffix_to_remove)]
return event_name
@attr.define
@@ -60,7 +60,7 @@ class ModelSaveEvent(EventMeta):
class OpenllmCliEvent(EventMeta):
cmd_group: str
cmd_name: str
openllm_version: str = importlib.metadata.version("openllm")
openllm_version: str = importlib.metadata.version('openllm')
# NOTE: reserved for the do_not_track logics
duration_in_ms: t.Any = attr.field(default=None)
error_type: str = attr.field(default=None)
@@ -72,7 +72,7 @@ class StartInitEvent(EventMeta):

@staticmethod
def handler(llm_config: openllm_core.LLMConfig) -> StartInitEvent:
return StartInitEvent(model_name=llm_config["model_name"], llm_config=llm_config.model_dump())
return StartInitEvent(model_name=llm_config['model_name'], llm_config=llm_config.model_dump())
def track_start_init(llm_config: openllm_core.LLMConfig) -> None:
if do_not_track(): return
track(StartInitEvent.handler(llm_config))

@@ -7,7 +7,7 @@ if t.TYPE_CHECKING:
from openllm_core._typing_compat import LiteralString, AnyCallable, DictStrAny, ListStr
PartialAny = functools.partial[t.Any]

_T = t.TypeVar("_T", bound=t.Callable[..., t.Any])
_T = t.TypeVar('_T', bound=t.Callable[..., t.Any])
logger = logging.getLogger(__name__)

# sentinel object for unequivocal object() getattr
@@ -21,29 +21,29 @@ def has_own_attribute(cls: type[t.Any], attrib_name: t.Any) -> bool:
if attr is a: return False
return True
def get_annotations(cls: type[t.Any]) -> DictStrAny:
if has_own_attribute(cls, "__annotations__"): return cls.__annotations__
return t.cast("DictStrAny", {})
if has_own_attribute(cls, '__annotations__'): return cls.__annotations__
return t.cast('DictStrAny', {})
def is_class_var(annot: str | t.Any) -> bool:
annot = str(annot)
# Annotation can be quoted.
if annot.startswith(("'", '"')) and annot.endswith(("'", '"')): annot = annot[1:-1]
return annot.startswith(("typing.ClassVar", "t.ClassVar", "ClassVar", "typing_extensions.ClassVar",))
return annot.startswith(('typing.ClassVar', 't.ClassVar', 'ClassVar', 'typing_extensions.ClassVar',))
def add_method_dunders(cls: type[t.Any], method_or_cls: _T, _overwrite_doc: str | None = None) -> _T:
try:
method_or_cls.__module__ = cls.__module__
except AttributeError:
pass
try:
method_or_cls.__qualname__ = f"{cls.__qualname__}.{method_or_cls.__name__}"
method_or_cls.__qualname__ = f'{cls.__qualname__}.{method_or_cls.__name__}'
except AttributeError:
pass
try:
method_or_cls.__doc__ = _overwrite_doc or "Generated by ``openllm.LLMConfig`` for class " f"{cls.__qualname__}."
method_or_cls.__doc__ = _overwrite_doc or 'Generated by ``openllm.LLMConfig`` for class ' f'{cls.__qualname__}.'
except AttributeError:
pass
return method_or_cls
def _compile_and_eval(script: str, globs: DictStrAny, locs: t.Any = None, filename: str = "") -> None:
eval(compile(script, filename, "exec"), globs, locs)
def _compile_and_eval(script: str, globs: DictStrAny, locs: t.Any = None, filename: str = '') -> None:
eval(compile(script, filename, 'exec'), globs, locs)
def _make_method(name: str, script: str, filename: str, globs: DictStrAny) -> AnyCallable:
locs: DictStrAny = {}
# In order of debuggers like PDB being able to step through the code, we add a fake linecache entry.
@@ -54,31 +54,31 @@ def _make_method(name: str, script: str, filename: str, globs: DictStrAny) -> An
old_val = linecache.cache.setdefault(filename, linecache_tuple)
if old_val == linecache_tuple: break
else:
filename = f"{base_filename[:-1]}-{count}>"
filename = f'{base_filename[:-1]}-{count}>'
count += 1
_compile_and_eval(script, globs, locs, filename)
return locs[name]
def make_attr_tuple_class(cls_name: str, attr_names: t.Sequence[str]) -> type[t.Any]:
"""Create a tuple subclass to hold class attributes.
'''Create a tuple subclass to hold class attributes.

The subclass is a bare tuple with properties for names.

class MyClassAttributes(tuple):
__slots__ = ()
x = property(itemgetter(0))
"""
'''
from . import SHOW_CODEGEN

attr_class_name = f"{cls_name}Attributes"
attr_class_template = [f"class {attr_class_name}(tuple):", " __slots__ = ()",]
attr_class_name = f'{cls_name}Attributes'
attr_class_template = [f'class {attr_class_name}(tuple):', ' __slots__ = ()',]
if attr_names:
for i, attr_name in enumerate(attr_names):
attr_class_template.append(f" {attr_name} = _attrs_property(_attrs_itemgetter({i}))")
attr_class_template.append(f' {attr_name} = _attrs_property(_attrs_itemgetter({i}))')
else:
attr_class_template.append(" pass")
globs: DictStrAny = {"_attrs_itemgetter": itemgetter, "_attrs_property": property}
if SHOW_CODEGEN: logger.info("Generated class for %s:\n\n%s", attr_class_name, "\n".join(attr_class_template))
_compile_and_eval("\n".join(attr_class_template), globs)
attr_class_template.append(' pass')
globs: DictStrAny = {'_attrs_itemgetter': itemgetter, '_attrs_property': property}
if SHOW_CODEGEN: logger.info('Generated class for %s:\n\n%s', attr_class_name, '\n'.join(attr_class_template))
_compile_and_eval('\n'.join(attr_class_template), globs)
return globs[attr_class_name]
def generate_unique_filename(cls: type[t.Any], func_name: str) -> str:
return f"<{cls.__name__} generated {func_name} {cls.__module__}.{getattr(cls, '__qualname__', cls.__name__)}>"
@@ -86,10 +86,10 @@ def generate_function(
typ: type[t.Any], func_name: str, lines: list[str] | None, args: tuple[str, ...] | None, globs: dict[str, t.Any], annotations: dict[str, t.Any] | None = None
) -> AnyCallable:
from openllm_core.utils import SHOW_CODEGEN
script = "def %s(%s):\n  %s\n" % (func_name, ", ".join(args) if args is not None else "", "\n  ".join(lines) if lines else "pass")
script = 'def %s(%s):\n  %s\n' % (func_name, ', '.join(args) if args is not None else '', '\n  '.join(lines) if lines else 'pass')
meth = _make_method(func_name, script, generate_unique_filename(typ, func_name), globs)
if annotations: meth.__annotations__ = annotations
if SHOW_CODEGEN: logger.info("Generated script for %s:\n\n%s", typ, script)
if SHOW_CODEGEN: logger.info('Generated script for %s:\n\n%s', typ, script)
return meth
def make_env_transformer(
cls: type[openllm_core.LLMConfig], model_name: str, suffix: LiteralString | None = None, default_callback: t.Callable[[str, t.Any], t.Any] | None = None, globs: DictStrAny | None = None,
@@ -101,50 +101,50 @@ def make_env_transformer(

default_callback = identity if default_callback is None else default_callback
globs = {} if globs is None else globs
globs.update({"__populate_env": dantic.env_converter, "__default_callback": default_callback, "__field_env": field_env_key, "__suffix": suffix or "", "__model_name": model_name,})
globs.update({'__populate_env': dantic.env_converter, '__default_callback': default_callback, '__field_env': field_env_key, '__suffix': suffix or '', '__model_name': model_name,})
lines: ListStr = [
"__env = lambda field_name: __field_env(__model_name, field_name, __suffix)",
"return [",
" f.evolve(",
" default=__populate_env(__default_callback(f.name, f.default), __env(f.name)),",
" metadata={",
'__env = lambda field_name: __field_env(__model_name, field_name, __suffix)',
'return [',
' f.evolve(',
' default=__populate_env(__default_callback(f.name, f.default), __env(f.name)),',
' metadata={',
" 'env': f.metadata.get('env', __env(f.name)),",
" 'description': f.metadata.get('description', '(not provided)'),",
" },",
" )",
" for f in fields",
"]"
' },',
' )',
' for f in fields',
']'
]
fields_ann = "list[attr.Attribute[t.Any]]"
return generate_function(cls, "__auto_env", lines, args=("_", "fields"), globs=globs, annotations={"_": "type[LLMConfig]", "fields": fields_ann, "return": fields_ann})
fields_ann = 'list[attr.Attribute[t.Any]]'
return generate_function(cls, '__auto_env', lines, args=('_', 'fields'), globs=globs, annotations={'_': 'type[LLMConfig]', 'fields': fields_ann, 'return': fields_ann})
def gen_sdk(func: _T, name: str | None = None, **attrs: t.Any) -> _T:
"""Enhance sdk with nice repr that plays well with your brain."""
'''Enhance sdk with nice repr that plays well with your brain.'''
from openllm_core.utils import ReprMixin
if name is None: name = func.__name__.strip("_")
if name is None: name = func.__name__.strip('_')
_signatures = inspect.signature(func).parameters

def _repr(self: ReprMixin) -> str:
return f"<generated function {name} {orjson.dumps(dict(self.__repr_args__()), option=orjson.OPT_NON_STR_KEYS | orjson.OPT_INDENT_2).decode()}>"
return f'<generated function {name} {orjson.dumps(dict(self.__repr_args__()), option=orjson.OPT_NON_STR_KEYS | orjson.OPT_INDENT_2).decode()}>'

def _repr_args(self: ReprMixin) -> t.Iterator[t.Tuple[str, t.Any]]:
return ((k, _signatures[k].annotation) for k in self.__repr_keys__)

if func.__doc__ is None: doc = f"Generated SDK for {func.__name__}"
if func.__doc__ is None: doc = f'Generated SDK for {func.__name__}'
else: doc = func.__doc__
return t.cast(
_T,
functools.update_wrapper(
types.new_class(
name, (t.cast("PartialAny", functools.partial), ReprMixin),
name, (t.cast('PartialAny', functools.partial), ReprMixin),
exec_body=lambda ns: ns.update({
"__repr_keys__": property(lambda _: [i for i in _signatures.keys() if not i.startswith("_")]),
"__repr_args__": _repr_args,
"__repr__": _repr,
"__doc__": inspect.cleandoc(doc),
"__module__": "openllm"
'__repr_keys__': property(lambda _: [i for i in _signatures.keys() if not i.startswith('_')]),
'__repr_args__': _repr_args,
'__repr__': _repr,
'__doc__': inspect.cleandoc(doc),
'__module__': 'openllm'
}),
)(func, **attrs),
func,
)
)
__all__ = ["gen_sdk", "make_attr_tuple_class", "make_env_transformer", "generate_unique_filename", "generate_function"]
__all__ = ['gen_sdk', 'make_attr_tuple_class', 'make_env_transformer', 'generate_unique_filename', 'generate_function']

@@ -1,4 +1,4 @@
"""An interface provides the best of pydantic and attrs."""
'''An interface provides the best of pydantic and attrs.'''
from __future__ import annotations
import functools, importlib, os, sys, typing as t
from enum import Enum
@@ -6,44 +6,44 @@ import attr, click, click_option_group as cog, inflection, orjson
from click import ParamType, shell_completion as sc, types as click_types
if t.TYPE_CHECKING: from attr import _ValidatorType
AnyCallable = t.Callable[..., t.Any]
FC = t.TypeVar("FC", bound=t.Union[AnyCallable, click.Command])
FC = t.TypeVar('FC', bound=t.Union[AnyCallable, click.Command])

__all__ = [
"FC",
"attrs_to_options",
"Field",
"parse_type",
"is_typing",
"is_literal",
"ModuleType",
"EnumChoice",
"LiteralChoice",
"allows_multiple",
"is_mapping",
"is_container",
"parse_container_args",
"parse_single_arg",
"CUDA",
"JsonType",
"BytesType"
'FC',
'attrs_to_options',
'Field',
'parse_type',
'is_typing',
'is_literal',
'ModuleType',
'EnumChoice',
'LiteralChoice',
'allows_multiple',
'is_mapping',
'is_container',
'parse_container_args',
'parse_single_arg',
'CUDA',
'JsonType',
'BytesType'
]
def __dir__() -> list[str]:
return sorted(__all__)
def attrs_to_options(name: str, field: attr.Attribute[t.Any], model_name: str, typ: t.Any = None, suffix_generation: bool = False, suffix_sampling: bool = False) -> t.Callable[[FC], FC]:
# TODO: support parsing nested attrs class and Union
envvar = field.metadata["env"]
envvar = field.metadata['env']
dasherized = inflection.dasherize(name)
underscored = inflection.underscore(name)

if typ in (None, attr.NOTHING):
typ = field.type
if typ is None: raise RuntimeError(f"Failed to parse type for {name}")
if typ is None: raise RuntimeError(f'Failed to parse type for {name}')

full_option_name = f"--{dasherized}"
if field.type is bool: full_option_name += f"/--no-{dasherized}"
if suffix_generation: identifier = f"{model_name}_generation_{underscored}"
elif suffix_sampling: identifier = f"{model_name}_sampling_{underscored}"
else: identifier = f"{model_name}_{underscored}"
full_option_name = f'--{dasherized}'
if field.type is bool: full_option_name += f'/--no-{dasherized}'
if suffix_generation: identifier = f'{model_name}_generation_{underscored}'
elif suffix_sampling: identifier = f'{model_name}_sampling_{underscored}'
else: identifier = f'{model_name}_{underscored}'

return cog.optgroup.option(
identifier,
@@ -53,7 +53,7 @@ def attrs_to_options(name: str, field: attr.Attribute[t.Any], model_name: str, t
default=field.default if field.default not in (attr.NOTHING, None) else None,
show_default=True,
multiple=allows_multiple(typ) if typ else False,
help=field.metadata.get("description", "(No description provided)"),
help=field.metadata.get('description', '(No description provided)'),
show_envvar=True,
envvar=envvar,
)
@@ -100,13 +100,13 @@ def Field(
for this given Field.
**attrs: The rest of the arguments are passed to attr.field
"""
metadata = attrs.pop("metadata", {})
if description is None: description = "(No description provided)"
metadata["description"] = description
if env is not None: metadata["env"] = env
metadata = attrs.pop('metadata', {})
if description is None: description = '(No description provided)'
metadata['description'] = description
if env is not None: metadata['env'] = env
piped: list[_ValidatorType[t.Any]] = []

converter = attrs.pop("converter", None)
converter = attrs.pop('converter', None)
if use_default_converter: converter = functools.partial(env_converter, env=env)

if ge is not None: piped.append(attr.validators.ge(ge))
@@ -117,15 +117,15 @@ def Field(
elif len(piped) == 1: _validator = piped[0]
else: _validator = attr.validators.and_(*piped)

factory = attrs.pop("factory", None)
factory = attrs.pop('factory', None)
if factory is not None and default is not None: raise RuntimeError("'factory' and 'default' are mutually exclusive.")
# NOTE: the behaviour of this is we will respect factory over the default
if factory is not None: attrs["factory"] = factory
else: attrs["default"] = default
if factory is not None: attrs['factory'] = factory
else: attrs['default'] = default

kw_only = attrs.pop("kw_only", False)
kw_only = attrs.pop('kw_only', False)
if auto_default and kw_only:
attrs.pop("default")
attrs.pop('default')

return attr.field(metadata=metadata, validator=_validator, converter=converter, **attrs)
def parse_type(field_type: t.Any) -> ParamType | tuple[ParamType, ...]:
@@ -140,7 +140,7 @@ def parse_type(field_type: t.Any) -> ParamType | tuple[ParamType, ...]:
from . import lenient_issubclass

if t.get_origin(field_type) is t.Union:
raise NotImplementedError("Unions are not supported")
raise NotImplementedError('Unions are not supported')
# enumeration strings or other Enum derivatives
if lenient_issubclass(field_type, Enum):
return EnumChoice(enum=field_type, case_sensitive=True)
@@ -159,20 +159,20 @@ def parse_type(field_type: t.Any) -> ParamType | tuple[ParamType, ...]:
# return the current type: it should be a primitive
return field_type
def is_typing(field_type: type) -> bool:
"""Checks whether the current type is a module-like type.
'''Checks whether the current type is a module-like type.

Args:
field_type: pydantic field type

Returns:
bool: true if the type is itself a type
"""
'''
raw = t.get_origin(field_type)
if raw is None: return False
if raw is type or raw is t.Type: return True
return False
def is_literal(field_type: type) -> bool:
"""Checks whether the given field type is a Literal type or not.
'''Checks whether the given field type is a Literal type or not.

Literals are weird: isinstance and subclass do not work, so you compare
the origin with the Literal declaration itself.
@@ -182,15 +182,15 @@ def is_literal(field_type: type) -> bool:

Returns:
bool: true if Literal type, false otherwise
"""
'''
origin = t.get_origin(field_type)
return origin is not None and origin is t.Literal
class ModuleType(ParamType):
name = "module"
name = 'module'

def _import_object(self, value: str) -> t.Any:
module_name, class_name = value.rsplit(".", maxsplit=1)
if not all(s.isidentifier() for s in module_name.split(".")): raise ValueError(f"'{value}' is not a valid module name")
module_name, class_name = value.rsplit('.', maxsplit=1)
if not all(s.isidentifier() for s in module_name.split('.')): raise ValueError(f"'{value}' is not a valid module name")
if not class_name.isidentifier(): raise ValueError(f"Variable '{class_name}' is not a valid identifier")

module = importlib.import_module(module_name)
@@ -207,15 +207,15 @@ class ModuleType(ParamType):
except Exception as exc:
self.fail(f"'{value}' is not a valid object ({type(exc)}: {exc!s})", param, ctx)
class EnumChoice(click.Choice):
name = "enum"
name = 'enum'

def __init__(self, enum: Enum, case_sensitive: bool = False):
"""Enum type support for click that extends ``click.Choice``.
'''Enum type support for click that extends ``click.Choice``.

Args:
enum: Given enum
case_sensitive: Whether this choice should be case case_sensitive.
"""
'''
self.mapping = enum
self.internal_type = type(enum)
choices: list[t.Any] = [e.name for e in enum.__class__]
@@ -229,14 +229,14 @@ class EnumChoice(click.Choice):
result = self.internal_type[result]
return result
class LiteralChoice(EnumChoice):
name = "literal"
name = 'literal'

def __init__(self, value: t.Any, case_sensitive: bool = False):
"""Literal support for click."""
'''Literal support for click.'''
# expect every literal value to belong to the same primitive type
values = list(value.__args__)
item_type = type(values[0])
if not all(isinstance(v, item_type) for v in values): raise ValueError(f"Field {value} contains items of different types.")
if not all(isinstance(v, item_type) for v in values): raise ValueError(f'Field {value} contains items of different types.')
_mapping = {str(v): v for v in values}
super(EnumChoice, self).__init__(list(_mapping), case_sensitive)
self.internal_type = item_type
@@ -265,14 +265,14 @@ def allows_multiple(field_type: type[t.Any]) -> bool:
return not isinstance(args, tuple)
return False
def is_mapping(field_type: type) -> bool:
"""Checks whether this field represents a dictionary or JSON object.
'''Checks whether this field represents a dictionary or JSON object.

Args:
field_type (type): pydantic type

Returns:
bool: true when the field is a dict-like object, false otherwise.
"""
'''
# Early out for standard containers.
from . import lenient_issubclass
if lenient_issubclass(field_type, t.Mapping): return True
@@ -299,16 +299,16 @@ def is_container(field_type: type) -> bool:
if origin is None: return False
return lenient_issubclass(origin, t.Container)
def parse_container_args(field_type: type[t.Any]) -> ParamType | tuple[ParamType, ...]:
"""Parses the arguments inside a container type (lists, tuples and so on).
'''Parses the arguments inside a container type (lists, tuples and so on).

Args:
field_type: pydantic field type

Returns:
ParamType | tuple[ParamType]: single click-compatible type or a tuple
"""
'''
if not is_container(field_type):
raise ValueError("Field type is not a container type.")
raise ValueError('Field type is not a container type.')
args = t.get_args(field_type)
# Early out for untyped containers: standard lists, tuples, List[Any]
# Use strings when the type is unknown, avoid click's type guessing
@@ -341,7 +341,7 @@ def parse_single_arg(arg: type) -> ParamType:
if lenient_issubclass(arg, bytes): return BytesType()
return click_types.convert_type(arg)
class BytesType(ParamType):
name = "bytes"
name = 'bytes'

def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> t.Any:
if isinstance(value, bytes): return value
@@ -349,9 +349,9 @@ class BytesType(ParamType):
return str.encode(value)
except Exception as exc:
self.fail(f"'{value}' is not a valid string ({exc!s})", param, ctx)
CYGWIN = sys.platform.startswith("cygwin")
WIN = sys.platform.startswith("win")
if sys.platform.startswith("win") and WIN:
CYGWIN = sys.platform.startswith('cygwin')
WIN = sys.platform.startswith('win')
if sys.platform.startswith('win') and WIN:

def _get_argv_encoding() -> str:
import locale
@@ -359,20 +359,20 @@ if sys.platform.startswith("win") and WIN:
else:

def _get_argv_encoding() -> str:
return getattr(sys.stdin, "encoding", None) or sys.getfilesystemencoding()
return getattr(sys.stdin, 'encoding', None) or sys.getfilesystemencoding()
class CudaValueType(ParamType):
name = "cuda"
envvar_list_splitter = ","
name = 'cuda'
envvar_list_splitter = ','
is_composite = True

def split_envvar_value(self, rv: str) -> t.Sequence[str]:
var = tuple(i for i in rv.split(self.envvar_list_splitter))
if "-1" in var:
return var[:var.index("-1")]
if '-1' in var:
return var[:var.index('-1')]
return var

def shell_complete(self, ctx: click.Context, param: click.Parameter, incomplete: str) -> list[sc.CompletionItem]:
"""Return a list of :class:`~click.shell_completion.CompletionItem` objects for the incomplete value.
'''Return a list of :class:`~click.shell_completion.CompletionItem` objects for the incomplete value.

Most types do not provide completions, but some do, and this allows custom types to provide custom completions as well.

@@ -380,10 +380,10 @@ class CudaValueType(ParamType):
ctx: Invocation context for this command.
param: The parameter that is requesting completion.
incomplete: Value being completed. May be empty.
"""
'''
from openllm_core.utils import available_devices
mapping = incomplete.split(self.envvar_list_splitter) if incomplete else available_devices()
return [sc.CompletionItem(str(i), help=f"CUDA device index {i}") for i in mapping]
return [sc.CompletionItem(str(i), help=f'CUDA device index {i}') for i in mapping]

def convert(self, value: t.Any, param: click.Parameter | None, ctx: click.Context | None) -> t.Any:
typ = click_types.convert_type(str)
@@ -397,16 +397,16 @@ class CudaValueType(ParamType):
try:
value = value.decode(fs_enc)
except UnicodeError:
value = value.decode("utf-8", "replace")
value = value.decode('utf-8', 'replace')
else:
value = value.decode("utf-8", "replace")
return tuple(typ(x, param, ctx) for x in value.split(","))
value = value.decode('utf-8', 'replace')
return tuple(typ(x, param, ctx) for x in value.split(','))

def __repr__(self) -> str:
return "STRING"
return 'STRING'
CUDA = CudaValueType()
class JsonType(ParamType):
name = "json"
name = 'json'

def __init__(self, should_load: bool = True) -> None:
"""Support JSON type for click.ParamType.

@@ -1,4 +1,4 @@
"""Some imports utils are vendorred from transformers/utils/import_utils.py for performance reasons."""
'''Some imports utils are vendorred from transformers/utils/import_utils.py for performance reasons.'''
from __future__ import annotations
import importlib, importlib.metadata, importlib.util, logging, os, abc, typing as t, openllm_core
from collections import OrderedDict
@@ -12,13 +12,13 @@ if t.TYPE_CHECKING:
BackendOrderedDict = OrderedDict[str, t.Tuple[t.Callable[[], bool], str]]
from openllm_core._typing_compat import LiteralRuntime
logger = logging.getLogger(__name__)
OPTIONAL_DEPENDENCIES = {"opt", "flan-t5", "vllm", "fine-tune", "ggml", "agents", "openai", "playground", "gptq"}
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})
USE_TF = os.environ.get("USE_TF", "AUTO").upper()
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
FORCE_TF_AVAILABLE = os.environ.get("FORCE_TF_AVAILABLE", "AUTO").upper()
OPTIONAL_DEPENDENCIES = {'opt', 'flan-t5', 'vllm', 'fine-tune', 'ggml', 'agents', 'openai', 'playground', 'gptq'}
ENV_VARS_TRUE_VALUES = {'1', 'ON', 'YES', 'TRUE'}
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({'AUTO'})
USE_TF = os.environ.get('USE_TF', 'AUTO').upper()
USE_TORCH = os.environ.get('USE_TORCH', 'AUTO').upper()
USE_JAX = os.environ.get('USE_FLAX', 'AUTO').upper()
FORCE_TF_AVAILABLE = os.environ.get('FORCE_TF_AVAILABLE', 'AUTO').upper()
def _is_package_available(package: str) -> bool:
|
||||
_package_available = importlib.util.find_spec(package) is not None
|
||||
if _package_available:
|
||||
@@ -27,26 +27,26 @@ def _is_package_available(package: str) -> bool:
|
||||
except importlib.metadata.PackageNotFoundError:
|
||||
_package_available = False
|
||||
return _package_available
|
||||
_torch_available = importlib.util.find_spec("torch") is not None
|
||||
_tf_available = importlib.util.find_spec("tensorflow") is not None
|
||||
_flax_available = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("flax") is not None
|
||||
_vllm_available = importlib.util.find_spec("vllm") is not None
|
||||
_transformers_available = _is_package_available("transformers")
|
||||
_grpc_available = importlib.util.find_spec("grpc") is not None
|
||||
_grpc_health_available = importlib.util.find_spec("grpc_health") is not None
|
||||
_peft_available = _is_package_available("peft")
|
||||
_einops_available = _is_package_available("einops")
|
||||
_cpm_kernel_available = _is_package_available("cpm_kernels")
|
||||
_bitsandbytes_available = _is_package_available("bitsandbytes")
|
||||
_datasets_available = _is_package_available("datasets")
|
||||
_triton_available = _is_package_available("triton")
|
||||
_jupyter_available = _is_package_available("jupyter")
|
||||
_jupytext_available = _is_package_available("jupytext")
|
||||
_notebook_available = _is_package_available("notebook")
|
||||
_autogptq_available = _is_package_available("auto_gptq")
|
||||
_sentencepiece_available = _is_package_available("sentencepiece")
|
||||
_xformers_available = _is_package_available("xformers")
|
||||
_fairscale_available = _is_package_available("fairscale")
|
||||
_torch_available = importlib.util.find_spec('torch') is not None
|
||||
_tf_available = importlib.util.find_spec('tensorflow') is not None
|
||||
_flax_available = importlib.util.find_spec('jax') is not None and importlib.util.find_spec('flax') is not None
|
||||
_vllm_available = importlib.util.find_spec('vllm') is not None
|
||||
_transformers_available = _is_package_available('transformers')
|
||||
_grpc_available = importlib.util.find_spec('grpc') is not None
|
||||
_grpc_health_available = importlib.util.find_spec('grpc_health') is not None
|
||||
_peft_available = _is_package_available('peft')
|
||||
_einops_available = _is_package_available('einops')
|
||||
_cpm_kernel_available = _is_package_available('cpm_kernels')
|
||||
_bitsandbytes_available = _is_package_available('bitsandbytes')
|
||||
_datasets_available = _is_package_available('datasets')
|
||||
_triton_available = _is_package_available('triton')
|
||||
_jupyter_available = _is_package_available('jupyter')
|
||||
_jupytext_available = _is_package_available('jupytext')
|
||||
_notebook_available = _is_package_available('notebook')
|
||||
_autogptq_available = _is_package_available('auto_gptq')
|
||||
_sentencepiece_available = _is_package_available('sentencepiece')
|
||||
_xformers_available = _is_package_available('xformers')
|
||||
_fairscale_available = _is_package_available('fairscale')
|
||||
def is_transformers_available() -> bool:
|
||||
return _transformers_available
|
||||
def is_grpc_available() -> bool:
|
||||
@@ -54,9 +54,9 @@ def is_grpc_available() -> bool:
|
||||
def is_grpc_health_available() -> bool:
|
||||
return _grpc_health_available
|
||||
def is_transformers_supports_kbit() -> bool:
|
||||
return pkg.pkg_version_info("transformers")[:2] >= (4, 30)
|
||||
return pkg.pkg_version_info('transformers')[:2] >= (4, 30)
|
||||
def is_transformers_supports_agent() -> bool:
|
||||
return pkg.pkg_version_info("transformers")[:2] >= (4, 29)
|
||||
return pkg.pkg_version_info('transformers')[:2] >= (4, 29)
|
||||
def is_jupyter_available() -> bool:
|
||||
return _jupyter_available
|
||||
def is_jupytext_available() -> bool:
|
||||
@@ -90,11 +90,11 @@ def is_torch_available() -> bool:
|
||||
if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
|
||||
if _torch_available:
|
||||
try:
|
||||
importlib.metadata.version("torch")
|
||||
importlib.metadata.version('torch')
|
||||
except importlib.metadata.PackageNotFoundError:
|
||||
_torch_available = False
|
||||
else:
|
||||
logger.info("Disabling PyTorch because USE_TF is set")
|
||||
logger.info('Disabling PyTorch because USE_TF is set')
|
||||
_torch_available = False
|
||||
return _torch_available
|
||||
def is_tf_available() -> bool:
|
||||
@@ -105,17 +105,17 @@ def is_tf_available() -> bool:
|
||||
if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
|
||||
if _tf_available:
|
||||
candidates = (
|
||||
"tensorflow",
|
||||
"tensorflow-cpu",
|
||||
"tensorflow-gpu",
|
||||
"tf-nightly",
|
||||
"tf-nightly-cpu",
|
||||
"tf-nightly-gpu",
|
||||
"intel-tensorflow",
|
||||
"intel-tensorflow-avx512",
|
||||
"tensorflow-rocm",
|
||||
"tensorflow-macos",
|
||||
"tensorflow-aarch64",
|
||||
'tensorflow',
|
||||
'tensorflow-cpu',
|
||||
'tensorflow-gpu',
|
||||
'tf-nightly',
|
||||
'tf-nightly-cpu',
|
||||
'tf-nightly-gpu',
|
||||
'intel-tensorflow',
|
||||
'intel-tensorflow-avx512',
|
||||
'tensorflow-rocm',
|
||||
'tensorflow-macos',
|
||||
'tensorflow-aarch64',
|
||||
)
|
||||
_tf_version = None
|
||||
# For the metadata, we have to look for both tensorflow and tensorflow-cpu
|
||||
@@ -127,11 +127,11 @@ def is_tf_available() -> bool:
|
||||
pass # Ok to ignore here since we actually need to check for all possible tensorflow distribution.
|
||||
_tf_available = _tf_version is not None
|
||||
if _tf_available:
|
||||
if _tf_version and packaging.version.parse(_tf_version) < packaging.version.parse("2"):
|
||||
logger.info("TensorFlow found but with version %s. OpenLLM only supports TF 2.x", _tf_version)
|
||||
if _tf_version and packaging.version.parse(_tf_version) < packaging.version.parse('2'):
|
||||
logger.info('TensorFlow found but with version %s. OpenLLM only supports TF 2.x', _tf_version)
|
||||
_tf_available = False
|
||||
else:
|
||||
logger.info("Disabling Tensorflow because USE_TORCH is set")
|
||||
logger.info('Disabling Tensorflow because USE_TORCH is set')
|
||||
_tf_available = False
|
||||
return _tf_available
|
||||
def is_flax_available() -> bool:
|
||||
@@ -139,14 +139,14 @@ def is_flax_available() -> bool:
|
||||
if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
|
||||
if _flax_available:
|
||||
try:
|
||||
importlib.metadata.version("jax")
|
||||
importlib.metadata.version("flax")
|
||||
importlib.metadata.version('jax')
|
||||
importlib.metadata.version('flax')
|
||||
except importlib.metadata.PackageNotFoundError:
|
||||
_flax_available = False
|
||||
else:
|
||||
_flax_available = False
|
||||
return _flax_available
|
||||
VLLM_IMPORT_ERROR_WITH_PYTORCH = """\
|
||||
VLLM_IMPORT_ERROR_WITH_PYTORCH = '''\
|
||||
{0} requires the vLLM library but it was not found in your environment.
|
||||
However, we were able to find a PyTorch installation. PyTorch classes do not begin
|
||||
with "VLLM", but are otherwise identically named to our PyTorch classes.
|
||||
@@ -154,8 +154,8 @@ If you want to use PyTorch, please use those classes instead!
|
||||
|
||||
If you really do want to use vLLM, please follow the instructions on the
|
||||
installation page https://github.com/vllm-project/vllm that match your environment.
|
||||
"""
|
||||
VLLM_IMPORT_ERROR_WITH_TF = """\
|
||||
'''
|
||||
VLLM_IMPORT_ERROR_WITH_TF = '''\
|
||||
{0} requires the vLLM library but it was not found in your environment.
|
||||
However, we were able to find a TensorFlow installation. TensorFlow classes begin
|
||||
with "TF", but are otherwise identically named to the PyTorch classes. This
|
||||
@@ -164,8 +164,8 @@ If you want to use TensorFlow, please use TF classes instead!
|
||||
|
||||
If you really do want to use vLLM, please follow the instructions on the
|
||||
installation page https://github.com/vllm-project/vllm that match your environment.
|
||||
"""
|
||||
VLLM_IMPORT_ERROR_WITH_FLAX = """\
|
||||
'''
|
||||
VLLM_IMPORT_ERROR_WITH_FLAX = '''\
|
||||
{0} requires the vLLM library but it was not found in your environment.
|
||||
However, we were able to find a Flax installation. Flax classes begin
|
||||
with "Flax", but are otherwise identically named to the PyTorch classes. This
|
||||
@@ -174,8 +174,8 @@ If you want to use Flax, please use Flax classes instead!
|
||||
|
||||
If you really do want to use vLLM, please follow the instructions on the
|
||||
installation page https://github.com/vllm-project/vllm that match your environment.
|
||||
"""
|
||||
PYTORCH_IMPORT_ERROR_WITH_TF = """\
|
||||
'''
|
||||
PYTORCH_IMPORT_ERROR_WITH_TF = '''\
|
||||
{0} requires the PyTorch library but it was not found in your environment.
|
||||
However, we were able to find a TensorFlow installation. TensorFlow classes begin
|
||||
with "TF", but are otherwise identically named to the PyTorch classes. This
|
||||
@@ -185,8 +185,8 @@ If you want to use TensorFlow, please use TF classes instead!
|
||||
If you really do want to use PyTorch please go to
|
||||
https://pytorch.org/get-started/locally/ and follow the instructions that
|
||||
match your environment.
|
||||
"""
|
||||
TF_IMPORT_ERROR_WITH_PYTORCH = """\
|
||||
'''
|
||||
TF_IMPORT_ERROR_WITH_PYTORCH = '''\
|
||||
{0} requires the TensorFlow library but it was not found in your environment.
|
||||
However, we were able to find a PyTorch installation. PyTorch classes do not begin
|
||||
with "TF", but are otherwise identically named to our TF classes.
|
||||
@@ -194,97 +194,97 @@ If you want to use PyTorch, please use those classes instead!
|
||||
|
||||
If you really do want to use TensorFlow, please follow the instructions on the
|
||||
installation page https://www.tensorflow.org/install that match your environment.
|
||||
"""
|
||||
TENSORFLOW_IMPORT_ERROR = """{0} requires the TensorFlow library but it was not found in your environment.
|
||||
'''
|
||||
TENSORFLOW_IMPORT_ERROR = '''{0} requires the TensorFlow library but it was not found in your environment.
|
||||
Checkout the instructions on the installation page: https://www.tensorflow.org/install and follow the
|
||||
ones that match your environment. Please note that you may need to restart your runtime after installation.
|
||||
"""
|
||||
FLAX_IMPORT_ERROR = """{0} requires the FLAX library but it was not found in your environment.
|
||||
'''
|
||||
FLAX_IMPORT_ERROR = '''{0} requires the FLAX library but it was not found in your environment.
|
||||
Checkout the instructions on the installation page: https://github.com/google/flax and follow the
|
||||
ones that match your environment. Please note that you may need to restart your runtime after installation.
|
||||
"""
|
||||
PYTORCH_IMPORT_ERROR = """{0} requires the PyTorch library but it was not found in your environment.
|
||||
'''
|
||||
PYTORCH_IMPORT_ERROR = '''{0} requires the PyTorch library but it was not found in your environment.
|
||||
Checkout the instructions on the installation page: https://pytorch.org/get-started/locally/ and follow the
|
||||
ones that match your environment. Please note that you may need to restart your runtime after installation.
|
||||
"""
|
||||
VLLM_IMPORT_ERROR = """{0} requires the vLLM library but it was not found in your environment.
|
||||
'''
|
||||
VLLM_IMPORT_ERROR = '''{0} requires the vLLM library but it was not found in your environment.
|
||||
Checkout the instructions on the installation page: https://github.com/vllm-project/vllm
|
||||
ones that match your environment. Please note that you may need to restart your runtime after installation.
|
||||
"""
|
||||
CPM_KERNELS_IMPORT_ERROR = """{0} requires the cpm_kernels library but it was not found in your environment.
|
||||
'''
|
||||
CPM_KERNELS_IMPORT_ERROR = '''{0} requires the cpm_kernels library but it was not found in your environment.
|
||||
You can install it with pip: `pip install cpm_kernels`. Please note that you may need to restart your
|
||||
runtime after installation.
|
||||
"""
|
||||
EINOPS_IMPORT_ERROR = """{0} requires the einops library but it was not found in your environment.
|
||||
'''
|
||||
EINOPS_IMPORT_ERROR = '''{0} requires the einops library but it was not found in your environment.
|
||||
You can install it with pip: `pip install einops`. Please note that you may need to restart
|
||||
your runtime after installation.
|
||||
"""
|
||||
TRITON_IMPORT_ERROR = """{0} requires the triton library but it was not found in your environment.
|
||||
'''
|
||||
TRITON_IMPORT_ERROR = '''{0} requires the triton library but it was not found in your environment.
|
||||
You can install it with pip: 'pip install \"git+https://github.com/openai/triton.git#egg=triton&subdirectory=python\"'.
|
||||
Please note that you may need to restart your runtime after installation.
|
||||
"""
|
||||
DATASETS_IMPORT_ERROR = """{0} requires the datasets library but it was not found in your environment.
|
||||
'''
|
||||
DATASETS_IMPORT_ERROR = '''{0} requires the datasets library but it was not found in your environment.
|
||||
You can install it with pip: `pip install datasets`. Please note that you may need to restart
|
||||
your runtime after installation.
|
||||
"""
|
||||
PEFT_IMPORT_ERROR = """{0} requires the peft library but it was not found in your environment.
|
||||
'''
|
||||
PEFT_IMPORT_ERROR = '''{0} requires the peft library but it was not found in your environment.
|
||||
You can install it with pip: `pip install peft`. Please note that you may need to restart
|
||||
your runtime after installation.
|
||||
"""
|
||||
BITSANDBYTES_IMPORT_ERROR = """{0} requires the bitsandbytes library but it was not found in your environment.
|
||||
'''
|
||||
BITSANDBYTES_IMPORT_ERROR = '''{0} requires the bitsandbytes library but it was not found in your environment.
|
||||
You can install it with pip: `pip install bitsandbytes`. Please note that you may need to restart
|
||||
your runtime after installation.
|
||||
"""
|
||||
AUTOGPTQ_IMPORT_ERROR = """{0} requires the auto-gptq library but it was not found in your environment.
|
||||
'''
|
||||
AUTOGPTQ_IMPORT_ERROR = '''{0} requires the auto-gptq library but it was not found in your environment.
|
||||
You can install it with pip: `pip install auto-gptq`. Please note that you may need to restart
|
||||
your runtime after installation.
|
||||
"""
|
||||
SENTENCEPIECE_IMPORT_ERROR = """{0} requires the sentencepiece library but it was not found in your environment.
|
||||
'''
|
||||
SENTENCEPIECE_IMPORT_ERROR = '''{0} requires the sentencepiece library but it was not found in your environment.
|
||||
You can install it with pip: `pip install sentencepiece`. Please note that you may need to restart
|
||||
your runtime after installation.
|
||||
"""
|
||||
XFORMERS_IMPORT_ERROR = """{0} requires the xformers library but it was not found in your environment.
|
||||
'''
|
||||
XFORMERS_IMPORT_ERROR = '''{0} requires the xformers library but it was not found in your environment.
|
||||
You can install it with pip: `pip install xformers`. Please note that you may need to restart
|
||||
your runtime after installation.
|
||||
"""
|
||||
FAIRSCALE_IMPORT_ERROR = """{0} requires the fairscale library but it was not found in your environment.
|
||||
'''
|
||||
FAIRSCALE_IMPORT_ERROR = '''{0} requires the fairscale library but it was not found in your environment.
|
||||
You can install it with pip: `pip install fairscale`. Please note that you may need to restart
|
||||
your runtime after installation.
|
||||
"""
|
||||
'''
|
||||
|
||||
BACKENDS_MAPPING: BackendOrderedDict = OrderedDict([("flax", (is_flax_available, FLAX_IMPORT_ERROR)), ("tf", (is_tf_available, TENSORFLOW_IMPORT_ERROR)), (
|
||||
"torch", (is_torch_available, PYTORCH_IMPORT_ERROR)
|
||||
), ("vllm", (is_vllm_available, VLLM_IMPORT_ERROR)), ("cpm_kernels", (is_cpm_kernels_available, CPM_KERNELS_IMPORT_ERROR)), ("einops", (is_einops_available, EINOPS_IMPORT_ERROR)), (
|
||||
"triton", (is_triton_available, TRITON_IMPORT_ERROR)
|
||||
), ("datasets", (is_datasets_available, DATASETS_IMPORT_ERROR)), ("peft", (is_peft_available, PEFT_IMPORT_ERROR)), ("bitsandbytes", (is_bitsandbytes_available, BITSANDBYTES_IMPORT_ERROR)), (
|
||||
"auto-gptq", (is_autogptq_available, AUTOGPTQ_IMPORT_ERROR)
|
||||
), ("sentencepiece", (is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR)), ("xformers", (is_xformers_available, XFORMERS_IMPORT_ERROR)), (
|
||||
"fairscale", (is_fairscale_available, FAIRSCALE_IMPORT_ERROR)
|
||||
BACKENDS_MAPPING: BackendOrderedDict = OrderedDict([('flax', (is_flax_available, FLAX_IMPORT_ERROR)), ('tf', (is_tf_available, TENSORFLOW_IMPORT_ERROR)), (
|
||||
'torch', (is_torch_available, PYTORCH_IMPORT_ERROR)
|
||||
), ('vllm', (is_vllm_available, VLLM_IMPORT_ERROR)), ('cpm_kernels', (is_cpm_kernels_available, CPM_KERNELS_IMPORT_ERROR)), ('einops', (is_einops_available, EINOPS_IMPORT_ERROR)), (
|
||||
'triton', (is_triton_available, TRITON_IMPORT_ERROR)
|
||||
), ('datasets', (is_datasets_available, DATASETS_IMPORT_ERROR)), ('peft', (is_peft_available, PEFT_IMPORT_ERROR)), ('bitsandbytes', (is_bitsandbytes_available, BITSANDBYTES_IMPORT_ERROR)), (
|
||||
'auto-gptq', (is_autogptq_available, AUTOGPTQ_IMPORT_ERROR)
|
||||
), ('sentencepiece', (is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR)), ('xformers', (is_xformers_available, XFORMERS_IMPORT_ERROR)), (
|
||||
'fairscale', (is_fairscale_available, FAIRSCALE_IMPORT_ERROR)
|
||||
)])
|
||||
class DummyMetaclass(abc.ABCMeta):
|
||||
"""Metaclass for dummy object.
|
||||
'''Metaclass for dummy object.
|
||||
|
||||
It will raises ImportError generated by ``require_backends`` if users try to access attributes from given class.
|
||||
"""
|
||||
'''
|
||||
_backends: t.List[str]
|
||||
|
||||
def __getattribute__(cls, key: str) -> t.Any:
|
||||
if key.startswith("_"): return super().__getattribute__(key)
|
||||
if key.startswith('_'): return super().__getattribute__(key)
|
||||
require_backends(cls, cls._backends)
|
||||
def require_backends(o: t.Any, backends: t.MutableSequence[str]) -> None:
|
||||
if not isinstance(backends, (list, tuple)): backends = list(backends)
|
||||
name = o.__name__ if hasattr(o, "__name__") else o.__class__.__name__
|
||||
name = o.__name__ if hasattr(o, '__name__') else o.__class__.__name__
|
||||
# Raise an error for users who might not realize that classes without "TF" are torch-only
|
||||
if "torch" in backends and "tf" not in backends and not is_torch_available() and is_tf_available(): raise ImportError(PYTORCH_IMPORT_ERROR_WITH_TF.format(name))
|
||||
if 'torch' in backends and 'tf' not in backends and not is_torch_available() and is_tf_available(): raise ImportError(PYTORCH_IMPORT_ERROR_WITH_TF.format(name))
|
||||
# Raise the inverse error for PyTorch users trying to load TF classes
|
||||
if "tf" in backends and "torch" not in backends and is_torch_available() and not is_tf_available(): raise ImportError(TF_IMPORT_ERROR_WITH_PYTORCH.format(name))
|
||||
if 'tf' in backends and 'torch' not in backends and is_torch_available() and not is_tf_available(): raise ImportError(TF_IMPORT_ERROR_WITH_PYTORCH.format(name))
|
||||
# Raise an error when vLLM is not available to consider the alternative, order from PyTorch -> Tensorflow -> Flax
|
||||
if "vllm" in backends:
|
||||
if "torch" not in backends and is_torch_available() and not is_vllm_available(): raise ImportError(VLLM_IMPORT_ERROR_WITH_PYTORCH.format(name))
|
||||
if "tf" not in backends and is_tf_available() and not is_vllm_available(): raise ImportError(VLLM_IMPORT_ERROR_WITH_TF.format(name))
|
||||
if "flax" not in backends and is_flax_available() and not is_vllm_available(): raise ImportError(VLLM_IMPORT_ERROR_WITH_FLAX.format(name))
|
||||
if 'vllm' in backends:
|
||||
if 'torch' not in backends and is_torch_available() and not is_vllm_available(): raise ImportError(VLLM_IMPORT_ERROR_WITH_PYTORCH.format(name))
|
||||
if 'tf' not in backends and is_tf_available() and not is_vllm_available(): raise ImportError(VLLM_IMPORT_ERROR_WITH_TF.format(name))
|
||||
if 'flax' not in backends and is_flax_available() and not is_vllm_available(): raise ImportError(VLLM_IMPORT_ERROR_WITH_FLAX.format(name))
|
||||
failed = [msg.format(name) for available, msg in (BACKENDS_MAPPING[backend] for backend in backends) if not available()]
|
||||
if failed: raise ImportError("".join(failed))
|
||||
if failed: raise ImportError(''.join(failed))
|
||||
class EnvVarMixin(ReprMixin):
|
||||
model_name: str
|
||||
config: str
|
||||
@@ -295,64 +295,64 @@ class EnvVarMixin(ReprMixin):
|
||||
runtime: str
|
||||
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["config"]) -> str:
|
||||
def __getitem__(self, item: t.Literal['config']) -> str:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["model_id"]) -> str:
|
||||
def __getitem__(self, item: t.Literal['model_id']) -> str:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["quantize"]) -> str:
|
||||
def __getitem__(self, item: t.Literal['quantize']) -> str:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["framework"]) -> str:
|
||||
def __getitem__(self, item: t.Literal['framework']) -> str:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["bettertransformer"]) -> str:
|
||||
def __getitem__(self, item: t.Literal['bettertransformer']) -> str:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["runtime"]) -> str:
|
||||
def __getitem__(self, item: t.Literal['runtime']) -> str:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["framework_value"]) -> LiteralRuntime:
|
||||
def __getitem__(self, item: t.Literal['framework_value']) -> LiteralRuntime:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["quantize_value"]) -> t.Literal["int8", "int4", "gptq"] | None:
|
||||
def __getitem__(self, item: t.Literal['quantize_value']) -> t.Literal['int8', 'int4', 'gptq'] | None:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["model_id_value"]) -> str | None:
|
||||
def __getitem__(self, item: t.Literal['model_id_value']) -> str | None:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["bettertransformer_value"]) -> bool:
|
||||
def __getitem__(self, item: t.Literal['bettertransformer_value']) -> bool:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal["runtime_value"]) -> t.Literal["ggml", "transformers"]:
|
||||
def __getitem__(self, item: t.Literal['runtime_value']) -> t.Literal['ggml', 'transformers']:
|
||||
...
|
||||
|
||||
def __getitem__(self, item: str | t.Any) -> t.Any:
|
||||
if item.endswith("_value") and hasattr(self, f"_{item}"): return object.__getattribute__(self, f"_{item}")()
|
||||
if item.endswith('_value') and hasattr(self, f'_{item}'): return object.__getattribute__(self, f'_{item}')()
|
||||
elif hasattr(self, item): return getattr(self, item)
|
||||
raise KeyError(f"Key {item} not found in {self}")
|
||||
raise KeyError(f'Key {item} not found in {self}')
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
implementation: LiteralRuntime = "pt",
|
||||
implementation: LiteralRuntime = 'pt',
|
||||
model_id: str | None = None,
|
||||
bettertransformer: bool | None = None,
|
||||
quantize: LiteralString | None = None,
|
||||
runtime: t.Literal["ggml", "transformers"] = "transformers"
|
||||
runtime: t.Literal['ggml', 'transformers'] = 'transformers'
|
||||
) -> None:
|
||||
"""EnvVarMixin is a mixin class that returns the value extracted from environment variables."""
|
||||
'''EnvVarMixin is a mixin class that returns the value extracted from environment variables.'''
|
||||
from openllm_core.utils import field_env_key
|
||||
self.model_name = inflection.underscore(model_name)
|
||||
self._implementation = implementation
|
||||
@@ -360,37 +360,37 @@ class EnvVarMixin(ReprMixin):
|
||||
self._bettertransformer = bettertransformer
|
||||
self._quantize = quantize
|
||||
self._runtime = runtime
|
||||
for att in {"config", "model_id", "quantize", "framework", "bettertransformer", "runtime"}:
|
||||
for att in {'config', 'model_id', 'quantize', 'framework', 'bettertransformer', 'runtime'}:
|
||||
setattr(self, att, field_env_key(self.model_name, att.upper()))
|
||||
|
||||
def _quantize_value(self) -> t.Literal["int8", "int4", "gptq"] | None:
|
||||
def _quantize_value(self) -> t.Literal['int8', 'int4', 'gptq'] | None:
|
||||
from . import first_not_none
|
||||
return t.cast(t.Optional[t.Literal["int8", "int4", "gptq"]], first_not_none(os.environ.get(self["quantize"]), default=self._quantize))
|
||||
return t.cast(t.Optional[t.Literal['int8', 'int4', 'gptq']], first_not_none(os.environ.get(self['quantize']), default=self._quantize))
|
||||
|
||||
def _framework_value(self) -> LiteralRuntime:
|
||||
from . import first_not_none
|
||||
return t.cast(t.Literal["pt", "tf", "flax", "vllm"], first_not_none(os.environ.get(self["framework"]), default=self._implementation))
|
||||
return t.cast(t.Literal['pt', 'tf', 'flax', 'vllm'], first_not_none(os.environ.get(self['framework']), default=self._implementation))
|
||||
|
||||
def _bettertransformer_value(self) -> bool:
|
||||
from . import first_not_none
|
||||
return t.cast(bool, first_not_none(os.environ.get(self["bettertransformer"], str(False)).upper() in ENV_VARS_TRUE_VALUES, default=self._bettertransformer))
|
||||
return t.cast(bool, first_not_none(os.environ.get(self['bettertransformer'], str(False)).upper() in ENV_VARS_TRUE_VALUES, default=self._bettertransformer))
|
||||
|
||||
def _model_id_value(self) -> str | None:
|
||||
from . import first_not_none
|
||||
return first_not_none(os.environ.get(self["model_id"]), default=self._model_id)
|
||||
return first_not_none(os.environ.get(self['model_id']), default=self._model_id)
|
||||
|
||||
def _runtime_value(self) -> t.Literal["ggml", "transformers"]:
|
||||
def _runtime_value(self) -> t.Literal['ggml', 'transformers']:
|
||||
from . import first_not_none
|
||||
return t.cast(t.Literal["ggml", "transformers"], first_not_none(os.environ.get(self["runtime"]), default=self._runtime))
|
||||
return t.cast(t.Literal['ggml', 'transformers'], first_not_none(os.environ.get(self['runtime']), default=self._runtime))
|
||||
|
||||
@property
|
||||
def __repr_keys__(self) -> set[str]:
|
||||
return {"config", "model_id", "quantize", "framework", "bettertransformer", "runtime"}
|
||||
return {'config', 'model_id', 'quantize', 'framework', 'bettertransformer', 'runtime'}
|
||||
|
||||
@property
|
||||
def start_docstring(self) -> str:
|
||||
return getattr(openllm_core.config, f"START_{self.model_name.upper()}_COMMAND_DOCSTRING")
|
||||
return getattr(openllm_core.config, f'START_{self.model_name.upper()}_COMMAND_DOCSTRING')
|
||||
|
||||
@property
|
||||
def module(self) -> LazyLoader:
|
||||
return LazyLoader(self.model_name, globals(), f"openllm.models.{self.model_name}")
|
||||
return LazyLoader(self.model_name, globals(), f'openllm.models.{self.model_name}')
|
||||
|
||||
@@ -1,6 +1,6 @@
from __future__ import annotations
import functools, importlib, importlib.machinery, importlib.metadata, importlib.util, itertools, os, time, types, warnings, typing as t, attr, openllm_core
__all__ = ["VersionInfo", "LazyModule"]
__all__ = ['VersionInfo', 'LazyModule']
# vendorred from attrs
@functools.total_ordering
@attr.attrs(eq=False, order=False, slots=True, frozen=True, repr=False)
@@ -12,8 +12,8 @@ class VersionInfo:

@classmethod
def from_version_string(cls, s: str) -> VersionInfo:
v = s.split(".")
if len(v) == 3: v.append("final")
v = s.split('.')
if len(v) == 3: v.append('final')
return cls(major=int(v[0]), minor=int(v[1]), micro=int(v[2]), releaselevel=v[3])

def _ensure_tuple(self, other: VersionInfo) -> tuple[tuple[int, int, int, str], tuple[int, int, int, str]]:
@@ -38,8 +38,8 @@ class VersionInfo:
return us < them

def __repr__(self) -> str:
return "{0}.{1}.{2}".format(*attr.astuple(self)[:3])
_sentinel, _reserved_namespace = object(), {"__openllm_migration__"}
return '{0}.{1}.{2}'.format(*attr.astuple(self)[:3])
_sentinel, _reserved_namespace = object(), {'__openllm_migration__'}
class LazyModule(types.ModuleType):
# Very heavily inspired by optuna.integration._IntegrationModule: https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py
def __init__(
@@ -81,64 +81,64 @@ class LazyModule(types.ModuleType):
|
||||
self._import_structure = import_structure
|
||||
|
||||
def __dir__(self) -> list[str]:
|
||||
result = t.cast("list[str]", super().__dir__())
|
||||
result = t.cast('list[str]', super().__dir__())
|
||||
# The elements of self.__all__ that are submodules may or may not be in the dir already, depending on whether
|
||||
# they have been accessed or not. So we only add the elements of self.__all__ that are not already in the dir.
|
||||
return result + [i for i in self.__all__ if i not in result]
|
||||
|
||||
def __getattr__(self, name: str) -> t.Any:
|
||||
"""Equivocal __getattr__ implementation.
|
||||
'''Equivocal __getattr__ implementation.
|
||||
|
||||
It checks from _objects > _modules and does it recursively.
|
||||
|
||||
It also contains a special case for all of the metadata information, such as __version__ and __version_info__.
|
||||
"""
|
||||
'''
|
||||
if name in _reserved_namespace: raise openllm_core.exceptions.ForbiddenAttributeError(f"'{name}' is a reserved namespace for {self._name} and should not be access nor modified.")
|
||||
dunder_to_metadata = {
|
||||
"__title__": "Name",
|
||||
"__copyright__": "",
|
||||
"__version__": "version",
|
||||
"__version_info__": "version",
|
||||
"__description__": "summary",
|
||||
"__uri__": "",
|
||||
"__url__": "",
|
||||
"__author__": "",
|
||||
"__email__": "",
|
||||
"__license__": "license",
|
||||
"__homepage__": ""
|
||||
'__title__': 'Name',
|
||||
'__copyright__': '',
|
||||
'__version__': 'version',
|
||||
'__version_info__': 'version',
|
||||
'__description__': 'summary',
|
||||
'__uri__': '',
|
||||
'__url__': '',
|
||||
'__author__': '',
|
||||
'__email__': '',
|
||||
'__license__': 'license',
|
||||
'__homepage__': ''
|
||||
}
|
||||
if name in dunder_to_metadata:
|
||||
if name not in {"__version_info__", "__copyright__", "__version__"}:
|
||||
if name not in {'__version_info__', '__copyright__', '__version__'}:
|
||||
warnings.warn(
|
||||
f"Accessing '{self._name}.{name}' is deprecated. Please consider using 'importlib.metadata' directly to query for openllm packaging metadata.", DeprecationWarning, stacklevel=2
|
||||
)
|
||||
meta = importlib.metadata.metadata("openllm")
|
||||
project_url = dict(url.split(", ") for url in t.cast(t.List[str], meta.get_all("Project-URL")))
|
||||
if name == "__license__": return "Apache-2.0"
|
||||
elif name == "__copyright__": return f"Copyright (c) 2023-{time.strftime('%Y')}, Aaron Pham et al."
|
||||
elif name in ("__uri__", "__url__"): return project_url["GitHub"]
|
||||
elif name == "__homepage__": return project_url["Homepage"]
|
||||
elif name == "__version_info__": return VersionInfo.from_version_string(meta["version"]) # similar to how attrs handle __version_info__
|
||||
elif name == "__author__": return meta["Author-email"].rsplit(" ", 1)[0]
|
||||
elif name == "__email__": return meta["Author-email"].rsplit("<", 1)[1][:-1]
|
||||
meta = importlib.metadata.metadata('openllm')
|
||||
project_url = dict(url.split(', ') for url in t.cast(t.List[str], meta.get_all('Project-URL')))
|
||||
if name == '__license__': return 'Apache-2.0'
|
||||
elif name == '__copyright__': return f"Copyright (c) 2023-{time.strftime('%Y')}, Aaron Pham et al."
|
||||
elif name in ('__uri__', '__url__'): return project_url['GitHub']
|
||||
elif name == '__homepage__': return project_url['Homepage']
|
||||
elif name == '__version_info__': return VersionInfo.from_version_string(meta['version']) # similar to how attrs handle __version_info__
|
||||
elif name == '__author__': return meta['Author-email'].rsplit(' ', 1)[0]
|
||||
elif name == '__email__': return meta['Author-email'].rsplit('<', 1)[1][:-1]
|
||||
return meta[dunder_to_metadata[name]]
|
||||
if "__openllm_migration__" in self._objects:
|
||||
cur_value = self._objects["__openllm_migration__"].get(name, _sentinel)
|
||||
if '__openllm_migration__' in self._objects:
|
||||
cur_value = self._objects['__openllm_migration__'].get(name, _sentinel)
|
||||
if cur_value is not _sentinel:
|
||||
warnings.warn(f"'{name}' is deprecated and will be removed in future version. Make sure to use '{cur_value}' instead", DeprecationWarning, stacklevel=3)
|
||||
return getattr(self, cur_value)
|
||||
if name in self._objects: return self._objects.__getitem__(name)
|
||||
if name in self._modules: value = self._get_module(name)
|
||||
elif name in self._class_to_module.keys(): value = getattr(self._get_module(self._class_to_module.__getitem__(name)), name)
|
||||
else: raise AttributeError(f"module {self.__name__} has no attribute {name}")
|
||||
else: raise AttributeError(f'module {self.__name__} has no attribute {name}')
|
||||
setattr(self, name, value)
|
||||
return value
|
||||
|
||||
def _get_module(self, module_name: str) -> types.ModuleType:
|
||||
try:
|
||||
return importlib.import_module("." + module_name, self.__name__)
|
||||
return importlib.import_module('.' + module_name, self.__name__)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to import {self.__name__}.{module_name} because of the following error (look up to see its traceback):\n{e}") from e
|
||||
raise RuntimeError(f'Failed to import {self.__name__}.{module_name} because of the following error (look up to see its traceback):\n{e}') from e
|
||||
|
||||
# make sure this module is picklable
|
||||
def __reduce__(self) -> tuple[type[LazyModule], tuple[str, str | None, dict[str, list[str]]]]:

@@ -11,38 +11,38 @@ class ReprMixin:
def __repr_keys__(self) -> set[str]:
raise NotImplementedError

"""This can be overriden by base class using this mixin."""
'''This can be overriden by base class using this mixin.'''

def __repr__(self) -> str:
return f"{self.__class__.__name__} {orjson.dumps({k: utils.bentoml_cattr.unstructure(v) if attr.has(v) else v for k, v in self.__repr_args__()}, option=orjson.OPT_INDENT_2).decode()}"
return f'{self.__class__.__name__} {orjson.dumps({k: utils.bentoml_cattr.unstructure(v) if attr.has(v) else v for k, v in self.__repr_args__()}, option=orjson.OPT_INDENT_2).decode()}'

"""The `__repr__` for any subclass of Mixin.
'''The `__repr__` for any subclass of Mixin.

It will print nicely the class name with each of the fields under '__repr_keys__' as kv JSON dict.
"""
'''

def __str__(self) -> str:
return self.__repr_str__(" ")
return self.__repr_str__(' ')

"""The string representation of the given Mixin subclass.
'''The string representation of the given Mixin subclass.

It will contains all of the attributes from __repr_keys__
"""
'''

def __repr_name__(self) -> str:
return self.__class__.__name__

"""Name of the instance's class, used in __repr__."""
'''Name of the instance's class, used in __repr__.'''

def __repr_str__(self, join_str: str) -> str:
return join_str.join(repr(v) if a is None else f"{a}={v!r}" for a, v in self.__repr_args__())
return join_str.join(repr(v) if a is None else f'{a}={v!r}' for a, v in self.__repr_args__())

"""To be used with __str__."""
'''To be used with __str__.'''

def __repr_args__(self) -> ReprArgs:
return ((k, getattr(self, k)) for k in self.__repr_keys__)

"""This can also be overriden by base class using this mixin.
'''This can also be overriden by base class using this mixin.

By default it does a getattr of the current object from __repr_keys__.
"""
'''

@@ -1,13 +1,13 @@
"""CLI entrypoint for OpenLLM.
'''CLI entrypoint for OpenLLM.

Usage:
openllm --help

To start any OpenLLM model:
openllm start <model_name> --options ...
"""
'''
from __future__ import annotations

if __name__ == "__main__":
if __name__ == '__main__':
from openllm.cli.entrypoint import cli
cli()

@@ -6,37 +6,37 @@ from bentoml._internal.frameworks.transformers import MODULE_NAME, API_VERSION
|
||||
from bentoml._internal.models.model import ModelOptions, ModelSignature
|
||||
if t.TYPE_CHECKING: import torch
|
||||
|
||||
_GENERIC_EMBEDDING_ID = "sentence-transformers/all-MiniLM-L6-v2"
|
||||
_BENTOMODEL_ID = "sentence-transformers--all-MiniLM-L6-v2"
|
||||
_GENERIC_EMBEDDING_ID = 'sentence-transformers/all-MiniLM-L6-v2'
|
||||
_BENTOMODEL_ID = 'sentence-transformers--all-MiniLM-L6-v2'
|
||||
def get_or_download(ids: str = _BENTOMODEL_ID) -> bentoml.Model:
|
||||
try:
|
||||
return bentoml.transformers.get(ids)
|
||||
except bentoml.exceptions.NotFound:
|
||||
model_signatures = {
|
||||
k: ModelSignature(batchable=False)
|
||||
for k in ("forward", "generate", "contrastive_search", "greedy_search", "sample", "beam_search", "beam_sample", "group_beam_search", "constrained_beam_search", "__call__")
|
||||
for k in ('forward', 'generate', 'contrastive_search', 'greedy_search', 'sample', 'beam_search', 'beam_sample', 'group_beam_search', 'constrained_beam_search', '__call__')
|
||||
}
|
||||
with bentoml.models.create(
|
||||
ids,
|
||||
module=MODULE_NAME,
|
||||
api_version=API_VERSION,
|
||||
options=ModelOptions(),
|
||||
context=openllm.utils.generate_context(framework_name="transformers"),
|
||||
context=openllm.utils.generate_context(framework_name='transformers'),
|
||||
labels={
|
||||
"runtime": "pt", "framework": "openllm"
|
||||
'runtime': 'pt', 'framework': 'openllm'
|
||||
},
|
||||
signatures=model_signatures
|
||||
) as bentomodel:
|
||||
snapshot_download(
|
||||
_GENERIC_EMBEDDING_ID, local_dir=bentomodel.path, local_dir_use_symlinks=False, ignore_patterns=["*.safetensors", "*.h5", "*.ot", "*.pdf", "*.md", ".gitattributes", "LICENSE.txt"]
|
||||
_GENERIC_EMBEDDING_ID, local_dir=bentomodel.path, local_dir_use_symlinks=False, ignore_patterns=['*.safetensors', '*.h5', '*.ot', '*.pdf', '*.md', '.gitattributes', 'LICENSE.txt']
|
||||
)
|
||||
return bentomodel
|
||||
class GenericEmbeddingRunnable(bentoml.Runnable):
|
||||
SUPPORTED_RESOURCES = ("nvidia.com/gpu", "cpu")
|
||||
SUPPORTED_RESOURCES = ('nvidia.com/gpu', 'cpu')
|
||||
SUPPORTS_CPU_MULTI_THREADING = True
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.device = "cuda" if openllm.utils.device_count() > 0 else "cpu"
|
||||
self.device = 'cuda' if openllm.utils.device_count() > 0 else 'cpu'
|
||||
self._bentomodel = get_or_download()
|
||||
self.tokenizer = transformers.AutoTokenizer.from_pretrained(self._bentomodel.path)
|
||||
self.model = transformers.AutoModel.from_pretrained(self._bentomodel.path)
|
||||
@@ -45,8 +45,8 @@ class GenericEmbeddingRunnable(bentoml.Runnable):
|
||||
@bentoml.Runnable.method(batchable=True, batch_dim=0)
|
||||
def encode(self, sentences: list[str]) -> t.Sequence[openllm.LLMEmbeddings]:
|
||||
import torch, torch.nn.functional as F
|
||||
encoded_input = self.tokenizer(sentences, padding=True, truncation=True, return_tensors="pt").to(self.device)
|
||||
attention_mask = encoded_input["attention_mask"]
|
||||
encoded_input = self.tokenizer(sentences, padding=True, truncation=True, return_tensors='pt').to(self.device)
|
||||
attention_mask = encoded_input['attention_mask']
|
||||
# Compute token embeddings
|
||||
with torch.no_grad():
|
||||
model_output = self.model(**encoded_input)
|
||||
@@ -61,4 +61,4 @@ class GenericEmbeddingRunnable(bentoml.Runnable):
|
||||
token_embeddings = model_output[0] # First element of model_output contains all token embeddings
|
||||
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
|
||||
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
|
||||
__all__ = ["GenericEmbeddingRunnable"]
|
||||
__all__ = ['GenericEmbeddingRunnable']
|
||||
|
||||
@@ -19,23 +19,23 @@ class StopOnTokens(transformers.StoppingCriteria):
def prepare_logits_processor(config: openllm.LLMConfig) -> transformers.LogitsProcessorList:
generation_config = config.generation_config
logits_processor = transformers.LogitsProcessorList()
if generation_config["temperature"] >= 1e-5 and generation_config["temperature"] != 1.0: logits_processor.append(transformers.TemperatureLogitsWarper(generation_config["temperature"]))
if generation_config["repetition_penalty"] > 1.0: logits_processor.append(transformers.RepetitionPenaltyLogitsProcessor(generation_config["repetition_penalty"]))
if 1e-8 <= generation_config["top_p"]: logits_processor.append(transformers.TopPLogitsWarper(generation_config["top_p"]))
if generation_config["top_k"] > 0: logits_processor.append(transformers.TopKLogitsWarper(generation_config["top_k"]))
if generation_config['temperature'] >= 1e-5 and generation_config['temperature'] != 1.0: logits_processor.append(transformers.TemperatureLogitsWarper(generation_config['temperature']))
if generation_config['repetition_penalty'] > 1.0: logits_processor.append(transformers.RepetitionPenaltyLogitsProcessor(generation_config['repetition_penalty']))
if 1e-8 <= generation_config['top_p']: logits_processor.append(transformers.TopPLogitsWarper(generation_config['top_p']))
if generation_config['top_k'] > 0: logits_processor.append(transformers.TopKLogitsWarper(generation_config['top_k']))
return logits_processor
# NOTE: The ordering here is important. Some models have two of these and we have a preference for which value gets used.
SEQLEN_KEYS = ["max_sequence_length", "seq_length", "max_position_embeddings", "max_seq_len", "model_max_length"]
SEQLEN_KEYS = ['max_sequence_length', 'seq_length', 'max_position_embeddings', 'max_seq_len', 'model_max_length']
def get_context_length(config: transformers.PretrainedConfig) -> int:
rope_scaling = getattr(config, "rope_scaling", None)
rope_scaling_factor = config.rope_scaling["factor"] if rope_scaling else 1.0
rope_scaling = getattr(config, 'rope_scaling', None)
rope_scaling_factor = config.rope_scaling['factor'] if rope_scaling else 1.0
for key in SEQLEN_KEYS:
if getattr(config, key, None) is not None: return int(rope_scaling_factor * getattr(config, key))
return 2048
def is_sentence_complete(output: str) -> bool:
return output.endswith((".", "?", "!", "...", "。", "?", "!", "…", '"', "'", "”"))
return output.endswith(('.', '?', '!', '...', '。', '?', '!', '…', '"', "'", '”'))
def is_partial_stop(output: str, stop_str: str) -> bool:
"""Check whether the output contains a partial stop str."""
'''Check whether the output contains a partial stop str.'''
for i in range(0, min(len(output), len(stop_str))):
if stop_str.startswith(output[-i:]): return True
return False

File diff suppressed because it is too large
@@ -6,38 +6,38 @@ from openllm_core._typing_compat import overload
|
||||
if t.TYPE_CHECKING:
|
||||
from ._llm import LLM
|
||||
from openllm_core._typing_compat import DictStrAny
|
||||
autogptq, torch, transformers = LazyLoader("autogptq", globals(), "auto_gptq"), LazyLoader("torch", globals(), "torch"), LazyLoader("transformers", globals(), "transformers")
|
||||
autogptq, torch, transformers = LazyLoader('autogptq', globals(), 'auto_gptq'), LazyLoader('torch', globals(), 'torch'), LazyLoader('transformers', globals(), 'transformers')
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
QuantiseMode = t.Literal["int8", "int4", "gptq"]
|
||||
QuantiseMode = t.Literal['int8', 'int4', 'gptq']
|
||||
@overload
|
||||
def infer_quantisation_config(cls: type[LLM[t.Any, t.Any]], quantise: t.Literal["int8", "int4"], **attrs: t.Any) -> tuple[transformers.BitsAndBytesConfig, DictStrAny]:
|
||||
def infer_quantisation_config(cls: type[LLM[t.Any, t.Any]], quantise: t.Literal['int8', 'int4'], **attrs: t.Any) -> tuple[transformers.BitsAndBytesConfig, DictStrAny]:
|
||||
...
|
||||
@overload
|
||||
def infer_quantisation_config(cls: type[LLM[t.Any, t.Any]], quantise: t.Literal["gptq"], **attrs: t.Any) -> tuple[autogptq.BaseQuantizeConfig, DictStrAny]:
|
||||
def infer_quantisation_config(cls: type[LLM[t.Any, t.Any]], quantise: t.Literal['gptq'], **attrs: t.Any) -> tuple[autogptq.BaseQuantizeConfig, DictStrAny]:
|
||||
...
|
||||
def infer_quantisation_config(cls: type[LLM[t.Any, t.Any]], quantise: QuantiseMode, **attrs: t.Any) -> tuple[transformers.BitsAndBytesConfig | autogptq.BaseQuantizeConfig, DictStrAny]:
|
||||
# 8 bit configuration
|
||||
int8_threshold = attrs.pop("llm_int8_threshhold", 6.0)
|
||||
int8_enable_fp32_cpu_offload = attrs.pop("llm_int8_enable_fp32_cpu_offload", False)
|
||||
int8_skip_modules: list[str] | None = attrs.pop("llm_int8_skip_modules", None)
|
||||
int8_has_fp16_weight = attrs.pop("llm_int8_has_fp16_weight", False)
|
||||
int8_threshold = attrs.pop('llm_int8_threshhold', 6.0)
|
||||
int8_enable_fp32_cpu_offload = attrs.pop('llm_int8_enable_fp32_cpu_offload', False)
|
||||
int8_skip_modules: list[str] | None = attrs.pop('llm_int8_skip_modules', None)
|
||||
int8_has_fp16_weight = attrs.pop('llm_int8_has_fp16_weight', False)
|
||||
|
||||
autogptq_attrs: DictStrAny = {
|
||||
"bits": attrs.pop("gptq_bits", 4),
|
||||
"group_size": attrs.pop("gptq_group_size", -1),
|
||||
"damp_percent": attrs.pop("gptq_damp_percent", 0.01),
|
||||
"desc_act": attrs.pop("gptq_desc_act", True),
|
||||
"sym": attrs.pop("gptq_sym", True),
|
||||
"true_sequential": attrs.pop("gptq_true_sequential", True),
|
||||
'bits': attrs.pop('gptq_bits', 4),
|
||||
'group_size': attrs.pop('gptq_group_size', -1),
|
||||
'damp_percent': attrs.pop('gptq_damp_percent', 0.01),
|
||||
'desc_act': attrs.pop('gptq_desc_act', True),
|
||||
'sym': attrs.pop('gptq_sym', True),
|
||||
'true_sequential': attrs.pop('gptq_true_sequential', True),
|
||||
}
|
||||
|
||||
def create_int8_config(int8_skip_modules: list[str] | None) -> transformers.BitsAndBytesConfig:
|
||||
if int8_skip_modules is None: int8_skip_modules = []
|
||||
if "lm_head" not in int8_skip_modules and cls.config_class.__openllm_model_type__ == "causal_lm":
|
||||
if 'lm_head' not in int8_skip_modules and cls.config_class.__openllm_model_type__ == 'causal_lm':
|
||||
logger.debug("Skipping 'lm_head' for quantization for %s", cls.__name__)
|
||||
int8_skip_modules.append("lm_head")
|
||||
int8_skip_modules.append('lm_head')
|
||||
return transformers.BitsAndBytesConfig(
|
||||
load_in_8bit=True,
|
||||
llm_int8_enable_fp32_cpu_offload=int8_enable_fp32_cpu_offload,
|
||||
@@ -47,16 +47,16 @@ def infer_quantisation_config(cls: type[LLM[t.Any, t.Any]], quantise: QuantiseMo
|
||||
)
|
||||
|
||||
# 4 bit configuration
|
||||
int4_compute_dtype = attrs.pop("bnb_4bit_compute_dtype", torch.bfloat16)
|
||||
int4_quant_type = attrs.pop("bnb_4bit_quant_type", "nf4")
|
||||
int4_use_double_quant = attrs.pop("bnb_4bit_use_double_quant", True)
|
||||
int4_compute_dtype = attrs.pop('bnb_4bit_compute_dtype', torch.bfloat16)
|
||||
int4_quant_type = attrs.pop('bnb_4bit_quant_type', 'nf4')
|
||||
int4_use_double_quant = attrs.pop('bnb_4bit_use_double_quant', True)
|
||||
|
||||
# NOTE: Quantization setup
|
||||
# quantize is a openllm.LLM feature, where we can quantize the model
|
||||
# with bitsandbytes or quantization aware training.
|
||||
if not is_bitsandbytes_available(): raise RuntimeError("Quantization requires bitsandbytes to be installed. Make sure to install OpenLLM with 'pip install \"openllm[fine-tune]\"'")
|
||||
if quantise == "int8": quantisation_config = create_int8_config(int8_skip_modules)
|
||||
elif quantise == "int4":
|
||||
if quantise == 'int8': quantisation_config = create_int8_config(int8_skip_modules)
|
||||
elif quantise == 'int4':
|
||||
if is_transformers_supports_kbit():
|
||||
quantisation_config = transformers.BitsAndBytesConfig(
|
||||
load_in_4bit=True, bnb_4bit_compute_dtype=int4_compute_dtype, bnb_4bit_quant_type=int4_quant_type, bnb_4bit_use_double_quant=int4_use_double_quant
|
||||
@@ -64,10 +64,10 @@ def infer_quantisation_config(cls: type[LLM[t.Any, t.Any]], quantise: QuantiseMo
|
||||
else:
|
||||
logger.warning(
|
||||
"'quantize' is set to int4, while the current transformers version %s does not support k-bit quantization. k-bit quantization is supported since transformers 4.30, therefore make sure to install the latest version of transformers either via PyPI or from git source: 'pip install git+https://github.com/huggingface/transformers'. Fallback to int8 quantisation.",
|
||||
pkg.pkg_version_info("transformers")
|
||||
pkg.pkg_version_info('transformers')
|
||||
)
|
||||
quantisation_config = create_int8_config(int8_skip_modules)
|
||||
elif quantise == "gptq":
|
||||
elif quantise == 'gptq':
|
||||
if not is_autogptq_available():
|
||||
logger.warning(
|
||||
"'quantize=\"gptq\"' requires 'auto-gptq' to be installed (not available with local environment). Make sure to have 'auto-gptq' available locally: 'pip install \"openllm[gptq]\"'. OpenLLM will fallback to int8 with bitsandbytes."
|
||||
|
||||
@@ -11,16 +11,16 @@ if t.TYPE_CHECKING:
|
||||
from bentoml._internal.runner.runner import RunnerMethod, AbstractRunner
|
||||
_EmbeddingMethod: TypeAlias = RunnerMethod[t.Union[bentoml.Runnable, openllm.LLMRunnable[t.Any, t.Any]], [t.List[str]], t.Sequence[openllm.LLMEmbeddings]]
|
||||
# The following warnings from bitsandbytes, and probably not that important for users to see
|
||||
warnings.filterwarnings("ignore", message="MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization")
|
||||
warnings.filterwarnings("ignore", message="MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization")
|
||||
warnings.filterwarnings("ignore", message="The installed version of bitsandbytes was compiled without GPU support.")
|
||||
model = os.environ.get("OPENLLM_MODEL", "{__model_name__}") # openllm: model name
|
||||
adapter_map = os.environ.get("OPENLLM_ADAPTER_MAP", """{__model_adapter_map__}""") # openllm: model adapter map
|
||||
warnings.filterwarnings('ignore', message='MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization')
|
||||
warnings.filterwarnings('ignore', message='MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization')
|
||||
warnings.filterwarnings('ignore', message='The installed version of bitsandbytes was compiled without GPU support.')
|
||||
model = os.environ.get('OPENLLM_MODEL', '{__model_name__}') # openllm: model name
|
||||
adapter_map = os.environ.get('OPENLLM_ADAPTER_MAP', '''{__model_adapter_map__}''') # openllm: model adapter map
|
||||
llm_config = openllm.AutoConfig.for_model(model)
|
||||
runner = openllm.Runner(model, llm_config=llm_config, ensure_available=False, adapter_map=orjson.loads(adapter_map))
|
||||
generic_embedding_runner = bentoml.Runner(
|
||||
openllm.GenericEmbeddingRunnable, # XXX: remove arg-type once bentoml.Runner is correct set with type
|
||||
name="llm-generic-embedding",
|
||||
name='llm-generic-embedding',
|
||||
scheduling_strategy=openllm_core.CascadingResourceStrategy,
|
||||
max_batch_size=32,
|
||||
max_latency_ms=300
|
||||
@@ -28,45 +28,45 @@ generic_embedding_runner = bentoml.Runner(
|
||||
runners: list[AbstractRunner] = [runner]
|
||||
if not runner.supports_embeddings: runners.append(generic_embedding_runner)
|
||||
svc = bentoml.Service(name=f"llm-{llm_config['start_name']}-service", runners=runners)
|
||||
_JsonInput = bentoml.io.JSON.from_sample({"prompt": "", "llm_config": llm_config.model_dump(flatten=True), "adapter_name": ""})
|
||||
@svc.api(route="/v1/generate", input=_JsonInput, output=bentoml.io.JSON.from_sample({"responses": [], "configuration": llm_config.model_dump(flatten=True)}))
|
||||
_JsonInput = bentoml.io.JSON.from_sample({'prompt': '', 'llm_config': llm_config.model_dump(flatten=True), 'adapter_name': ''})
|
||||
@svc.api(route='/v1/generate', input=_JsonInput, output=bentoml.io.JSON.from_sample({'responses': [], 'configuration': llm_config.model_dump(flatten=True)}))
|
||||
async def generate_v1(input_dict: dict[str, t.Any]) -> openllm.GenerationOutput:
|
||||
qa_inputs = openllm.GenerationInput.from_llm_config(llm_config)(**input_dict)
|
||||
config = qa_inputs.llm_config.model_dump()
|
||||
responses = await runner.generate.async_run(qa_inputs.prompt, **{"adapter_name": qa_inputs.adapter_name, **config})
|
||||
responses = await runner.generate.async_run(qa_inputs.prompt, **{'adapter_name': qa_inputs.adapter_name, **config})
|
||||
return openllm.GenerationOutput(responses=responses, configuration=config)
|
||||
@svc.api(route="/v1/generate_stream", input=_JsonInput, output=bentoml.io.Text(content_type="text/event_stream"))
|
||||
@svc.api(route='/v1/generate_stream', input=_JsonInput, output=bentoml.io.Text(content_type='text/event_stream'))
|
||||
async def generate_stream_v1(input_dict: dict[str, t.Any]) -> t.AsyncGenerator[str, None]:
|
||||
qa_inputs = openllm.GenerationInput.from_llm_config(llm_config)(**input_dict)
|
||||
return runner.generate_iterator.async_stream(qa_inputs.prompt, adapter_name=qa_inputs.adapter_name, **qa_inputs.llm_config.model_dump())
|
||||
@svc.api(
|
||||
route="/v1/metadata",
|
||||
route='/v1/metadata',
|
||||
input=bentoml.io.Text(),
|
||||
output=bentoml.io.JSON.from_sample({
|
||||
"model_id": runner.llm.model_id,
|
||||
"timeout": 3600,
|
||||
"model_name": llm_config["model_name"],
|
||||
"framework": "pt",
|
||||
"configuration": "",
|
||||
"supports_embeddings": runner.supports_embeddings,
|
||||
"supports_hf_agent": runner.supports_hf_agent
|
||||
'model_id': runner.llm.model_id,
|
||||
'timeout': 3600,
|
||||
'model_name': llm_config['model_name'],
|
||||
'framework': 'pt',
|
||||
'configuration': '',
|
||||
'supports_embeddings': runner.supports_embeddings,
|
||||
'supports_hf_agent': runner.supports_hf_agent
|
||||
})
|
||||
)
|
||||
def metadata_v1(_: str) -> openllm.MetadataOutput:
|
||||
return openllm.MetadataOutput(
|
||||
timeout=llm_config["timeout"],
|
||||
model_name=llm_config["model_name"],
|
||||
framework=llm_config["env"]["framework_value"],
|
||||
timeout=llm_config['timeout'],
|
||||
model_name=llm_config['model_name'],
|
||||
framework=llm_config['env']['framework_value'],
|
||||
model_id=runner.llm.model_id,
|
||||
configuration=llm_config.model_dump_json().decode(),
|
||||
supports_embeddings=runner.supports_embeddings,
|
||||
supports_hf_agent=runner.supports_hf_agent
|
||||
)
|
||||
@svc.api(
|
||||
route="/v1/embeddings",
|
||||
input=bentoml.io.JSON.from_sample(["Hey Jude, welcome to the jungle!", "What is the meaning of life?"]),
|
||||
route='/v1/embeddings',
|
||||
input=bentoml.io.JSON.from_sample(['Hey Jude, welcome to the jungle!', 'What is the meaning of life?']),
|
||||
output=bentoml.io.JSON.from_sample({
|
||||
"embeddings": [
|
||||
'embeddings': [
|
||||
0.007917795330286026,
|
||||
-0.014421648345887661,
|
||||
0.00481307040899992,
|
||||
@@ -94,13 +94,13 @@ def metadata_v1(_: str) -> openllm.MetadataOutput:
|
||||
-0.014814382418990135,
|
||||
0.01796768605709076
|
||||
],
|
||||
"num_tokens": 20
|
||||
'num_tokens': 20
|
||||
})
|
||||
)
|
||||
async def embeddings_v1(phrases: list[str]) -> openllm.EmbeddingsOutput:
|
||||
embed_call: _EmbeddingMethod = runner.embeddings if runner.supports_embeddings else generic_embedding_runner.encode # type: ignore[type-arg,assignment,valid-type]
|
||||
responses = (await embed_call.async_run(phrases))[0]
|
||||
return openllm.EmbeddingsOutput(embeddings=responses["embeddings"], num_tokens=responses["num_tokens"])
|
||||
return openllm.EmbeddingsOutput(embeddings=responses['embeddings'], num_tokens=responses['num_tokens'])
|
||||
if runner.supports_hf_agent and openllm.utils.is_transformers_supports_agent():
|
||||
|
||||
async def hf_agent(request: Request) -> Response:
|
||||
@@ -108,19 +108,19 @@ if runner.supports_hf_agent and openllm.utils.is_transformers_supports_agent():
|
||||
try:
|
||||
input_data = openllm.utils.bentoml_cattr.structure(orjson.loads(json_str), openllm.HfAgentInput)
|
||||
except orjson.JSONDecodeError as err:
|
||||
raise openllm.exceptions.OpenLLMException(f"Invalid JSON input received: {err}") from None
|
||||
stop = input_data.parameters.pop("stop", ["\n"])
|
||||
raise openllm.exceptions.OpenLLMException(f'Invalid JSON input received: {err}') from None
|
||||
stop = input_data.parameters.pop('stop', ['\n'])
|
||||
try:
|
||||
return JSONResponse(await runner.generate_one.async_run(input_data.inputs, stop, **input_data.parameters), status_code=200)
|
||||
except NotImplementedError:
|
||||
return JSONResponse(f"'{model}' is currently not supported with HuggingFace agents.", status_code=500)
|
||||
|
||||
hf_app = Starlette(debug=True, routes=[Route("/agent", hf_agent, methods=["POST"])])
|
||||
svc.mount_asgi_app(hf_app, path="/hf")
|
||||
hf_app = Starlette(debug=True, routes=[Route('/agent', hf_agent, methods=['POST'])])
|
||||
svc.mount_asgi_app(hf_app, path='/hf')
|
||||
async def list_adapter_v1(_: Request) -> Response:
|
||||
res: dict[str, t.Any] = {}
|
||||
if runner.peft_adapters["success"] is True: res["result"] = {k: v.to_dict() for k, v in runner.peft_adapters["result"].items()}
|
||||
res.update({"success": runner.peft_adapters["success"], "error_msg": runner.peft_adapters["error_msg"]})
|
||||
if runner.peft_adapters['success'] is True: res['result'] = {k: v.to_dict() for k, v in runner.peft_adapters['result'].items()}
|
||||
res.update({'success': runner.peft_adapters['success'], 'error_msg': runner.peft_adapters['error_msg']})
|
||||
return JSONResponse(res, status_code=200)
|
||||
adapters_app_v1 = Starlette(debug=True, routes=[Route("/adapters", list_adapter_v1, methods=["GET"])])
|
||||
svc.mount_asgi_app(adapters_app_v1, path="/v1")
|
||||
adapters_app_v1 = Starlette(debug=True, routes=[Route('/adapters', list_adapter_v1, methods=['GET'])])
|
||||
svc.mount_asgi_app(adapters_app_v1, path='/v1')
|
||||
|
||||
@@ -6,15 +6,15 @@ from __future__ import annotations
|
||||
import os, typing as t
|
||||
from openllm_core.utils import LazyModule
|
||||
_import_structure: dict[str, list[str]] = {
|
||||
"_package": ["create_bento", "build_editable", "construct_python_options", "construct_docker_options"],
|
||||
"oci": ["CONTAINER_NAMES", "get_base_container_tag", "build_container", "get_base_container_name", "supported_registries", "RefResolver"]
|
||||
'_package': ['create_bento', 'build_editable', 'construct_python_options', 'construct_docker_options'],
|
||||
'oci': ['CONTAINER_NAMES', 'get_base_container_tag', 'build_container', 'get_base_container_name', 'supported_registries', 'RefResolver']
|
||||
}
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from . import _package as _package, oci as oci
|
||||
from ._package import build_editable as build_editable, construct_docker_options as construct_docker_options, construct_python_options as construct_python_options, create_bento as create_bento
|
||||
from .oci import CONTAINER_NAMES as CONTAINER_NAMES, RefResolver as RefResolver, build_container as build_container, get_base_container_name as get_base_container_name, get_base_container_tag as get_base_container_tag, supported_registries as supported_registries
|
||||
__lazy = LazyModule(__name__, os.path.abspath("__file__"), _import_structure)
|
||||
__lazy = LazyModule(__name__, os.path.abspath('__file__'), _import_structure)
|
||||
__all__ = __lazy.__all__
|
||||
__dir__ = __lazy.__dir__
|
||||
__getattr__ = __lazy.__getattr__
|
||||
|
||||
@@ -15,77 +15,77 @@ if t.TYPE_CHECKING:
|
||||
from bentoml._internal.models.model import ModelStore
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
OPENLLM_DEV_BUILD = "OPENLLM_DEV_BUILD"
|
||||
def build_editable(path: str, package: t.Literal["openllm", "openllm_core", "openllm_client"] = "openllm") -> str | None:
|
||||
"""Build OpenLLM if the OPENLLM_DEV_BUILD environment variable is set."""
|
||||
if str(os.environ.get(OPENLLM_DEV_BUILD, False)).lower() != "true": return None
|
||||
OPENLLM_DEV_BUILD = 'OPENLLM_DEV_BUILD'
|
||||
def build_editable(path: str, package: t.Literal['openllm', 'openllm_core', 'openllm_client'] = 'openllm') -> str | None:
|
||||
'''Build OpenLLM if the OPENLLM_DEV_BUILD environment variable is set.'''
|
||||
if str(os.environ.get(OPENLLM_DEV_BUILD, False)).lower() != 'true': return None
|
||||
# We need to build the package in editable mode, so that we can import it
|
||||
from build import ProjectBuilder
|
||||
from build.env import IsolatedEnvBuilder
|
||||
module_location = openllm_core.utils.pkg.source_locations(package)
|
||||
if not module_location: raise RuntimeError("Could not find the source location of OpenLLM. Make sure to unset OPENLLM_DEV_BUILD if you are developing OpenLLM.")
|
||||
pyproject_path = Path(module_location).parent.parent / "pyproject.toml"
|
||||
if not module_location: raise RuntimeError('Could not find the source location of OpenLLM. Make sure to unset OPENLLM_DEV_BUILD if you are developing OpenLLM.')
|
||||
pyproject_path = Path(module_location).parent.parent / 'pyproject.toml'
|
||||
if os.path.isfile(pyproject_path.__fspath__()):
|
||||
logger.info("Generating built wheels for package %s...", package)
|
||||
logger.info('Generating built wheels for package %s...', package)
|
||||
with IsolatedEnvBuilder() as env:
|
||||
builder = ProjectBuilder(pyproject_path.parent)
|
||||
builder.python_executable = env.executable
|
||||
builder.scripts_dir = env.scripts_dir
|
||||
env.install(builder.build_system_requires)
|
||||
return builder.build("wheel", path, config_settings={"--global-option": "--quiet"})
|
||||
raise RuntimeError("Custom OpenLLM build is currently not supported. Please install OpenLLM from PyPI or built it from Git source.")
|
||||
return builder.build('wheel', path, config_settings={'--global-option': '--quiet'})
|
||||
raise RuntimeError('Custom OpenLLM build is currently not supported. Please install OpenLLM from PyPI or built it from Git source.')
|
||||
def construct_python_options(llm: openllm.LLM[t.Any, t.Any], llm_fs: FS, extra_dependencies: tuple[str, ...] | None = None, adapter_map: dict[str, str | None] | None = None,) -> PythonOptions:
|
||||
packages = ["openllm", "scipy"] # apparently bnb misses this one
|
||||
if adapter_map is not None: packages += ["openllm[fine-tune]"]
|
||||
packages = ['openllm', 'scipy'] # apparently bnb misses this one
|
||||
if adapter_map is not None: packages += ['openllm[fine-tune]']
|
||||
# NOTE: add openllm to the default dependencies
|
||||
# if users has openllm custom built wheels, it will still respect
|
||||
# that since bentoml will always install dependencies from requirements.txt
|
||||
# first, then proceed to install everything inside the wheels/ folder.
|
||||
if extra_dependencies is not None: packages += [f"openllm[{k}]" for k in extra_dependencies]
|
||||
if extra_dependencies is not None: packages += [f'openllm[{k}]' for k in extra_dependencies]
|
||||
|
||||
req = llm.config["requirements"]
|
||||
req = llm.config['requirements']
|
||||
if req is not None: packages.extend(req)
|
||||
if str(os.environ.get("BENTOML_BUNDLE_LOCAL_BUILD", False)).lower() == "false": packages.append(f"bentoml>={'.'.join([str(i) for i in openllm_core.utils.pkg.pkg_version_info('bentoml')])}")
|
||||
if str(os.environ.get('BENTOML_BUNDLE_LOCAL_BUILD', False)).lower() == 'false': packages.append(f"bentoml>={'.'.join([str(i) for i in openllm_core.utils.pkg.pkg_version_info('bentoml')])}")
|
||||
|
||||
env = llm.config["env"]
|
||||
framework_envvar = env["framework_value"]
|
||||
if framework_envvar == "flax":
|
||||
env = llm.config['env']
|
||||
framework_envvar = env['framework_value']
|
||||
if framework_envvar == 'flax':
|
||||
if not openllm_core.utils.is_flax_available(): raise ValueError(f"Flax is not available, while {env.framework} is set to 'flax'")
|
||||
packages.extend([importlib.metadata.version("flax"), importlib.metadata.version("jax"), importlib.metadata.version("jaxlib")])
|
||||
elif framework_envvar == "tf":
|
||||
packages.extend([importlib.metadata.version('flax'), importlib.metadata.version('jax'), importlib.metadata.version('jaxlib')])
|
||||
elif framework_envvar == 'tf':
|
||||
if not openllm_core.utils.is_tf_available(): raise ValueError(f"TensorFlow is not available, while {env.framework} is set to 'tf'")
|
||||
candidates = (
|
||||
"tensorflow",
|
||||
"tensorflow-cpu",
|
||||
"tensorflow-gpu",
|
||||
"tf-nightly",
|
||||
"tf-nightly-cpu",
|
||||
"tf-nightly-gpu",
|
||||
"intel-tensorflow",
|
||||
"intel-tensorflow-avx512",
|
||||
"tensorflow-rocm",
|
||||
"tensorflow-macos",
|
||||
'tensorflow',
|
||||
'tensorflow-cpu',
|
||||
'tensorflow-gpu',
|
||||
'tf-nightly',
|
||||
'tf-nightly-cpu',
|
||||
'tf-nightly-gpu',
|
||||
'intel-tensorflow',
|
||||
'intel-tensorflow-avx512',
|
||||
'tensorflow-rocm',
|
||||
'tensorflow-macos',
|
||||
)
|
||||
# For the metadata, we have to look for both tensorflow and tensorflow-cpu
|
||||
for candidate in candidates:
|
||||
try:
|
||||
pkgver = importlib.metadata.version(candidate)
|
||||
if pkgver == candidate: packages.extend(["tensorflow"])
|
||||
if pkgver == candidate: packages.extend(['tensorflow'])
|
||||
else:
|
||||
_tf_version = importlib.metadata.version(candidate)
|
||||
packages.extend([f"tensorflow>={_tf_version}"])
|
||||
packages.extend([f'tensorflow>={_tf_version}'])
|
||||
break
|
||||
except importlib.metadata.PackageNotFoundError:
|
||||
pass # Ok to ignore here since we actually need to check for all possible tensorflow distribution.
|
||||
else:
|
||||
if not openllm_core.utils.is_torch_available(): raise ValueError("PyTorch is not available. Make sure to have it locally installed.")
|
||||
if not openllm_core.utils.is_torch_available(): raise ValueError('PyTorch is not available. Make sure to have it locally installed.')
|
||||
packages.extend([f'torch>={importlib.metadata.version("torch")}'])
|
||||
wheels: list[str] = []
|
||||
built_wheels: list[str | None] = [
|
||||
build_editable(llm_fs.getsyspath("/"), t.cast(t.Literal["openllm", "openllm_core", "openllm_client"], p)) for p in ("openllm_core", "openllm_client", "openllm")
|
||||
build_editable(llm_fs.getsyspath('/'), t.cast(t.Literal['openllm', 'openllm_core', 'openllm_client'], p)) for p in ('openllm_core', 'openllm_client', 'openllm')
|
||||
]
|
||||
if all(i for i in built_wheels): wheels.extend([llm_fs.getsyspath(f"/{i.split('/')[-1]}") for i in t.cast(t.List[str], built_wheels)])
|
||||
return PythonOptions(packages=packages, wheels=wheels, lock_packages=False, extra_index_url=["https://download.pytorch.org/whl/cu118"])
|
||||
return PythonOptions(packages=packages, wheels=wheels, lock_packages=False, extra_index_url=['https://download.pytorch.org/whl/cu118'])
|
||||
def construct_docker_options(
|
||||
llm: openllm.LLM[t.Any, t.Any],
|
||||
_: FS,
|
||||
@@ -94,39 +94,39 @@ def construct_docker_options(
|
||||
bettertransformer: bool | None,
|
||||
adapter_map: dict[str, str | None] | None,
|
||||
dockerfile_template: str | None,
|
||||
runtime: t.Literal["ggml", "transformers"],
|
||||
serialisation_format: t.Literal["safetensors", "legacy"],
|
||||
runtime: t.Literal['ggml', 'transformers'],
|
||||
serialisation_format: t.Literal['safetensors', 'legacy'],
|
||||
container_registry: LiteralContainerRegistry,
|
||||
container_version_strategy: LiteralContainerVersionStrategy
|
||||
) -> DockerOptions:
|
||||
from openllm.cli._factory import parse_config_options
|
||||
environ = parse_config_options(llm.config, llm.config["timeout"], workers_per_resource, None, True, os.environ.copy())
|
||||
env: openllm_core.utils.EnvVarMixin = llm.config["env"]
|
||||
if env["framework_value"] == "vllm": serialisation_format = "legacy"
|
||||
environ = parse_config_options(llm.config, llm.config['timeout'], workers_per_resource, None, True, os.environ.copy())
|
||||
env: openllm_core.utils.EnvVarMixin = llm.config['env']
|
||||
if env['framework_value'] == 'vllm': serialisation_format = 'legacy'
|
||||
env_dict = {
|
||||
env.framework: env["framework_value"],
|
||||
env.framework: env['framework_value'],
|
||||
env.config: f"'{llm.config.model_dump_json().decode()}'",
|
||||
env.model_id: f"/home/bentoml/bento/models/{llm.tag.path()}",
|
||||
"OPENLLM_MODEL": llm.config["model_name"],
|
||||
"OPENLLM_SERIALIZATION": serialisation_format,
|
||||
"OPENLLM_ADAPTER_MAP": f"'{orjson.dumps(adapter_map).decode()}'",
|
||||
"BENTOML_DEBUG": str(True),
|
||||
"BENTOML_QUIET": str(False),
|
||||
"BENTOML_CONFIG_OPTIONS": f"'{environ['BENTOML_CONFIG_OPTIONS']}'",
|
||||
env.model_id: f'/home/bentoml/bento/models/{llm.tag.path()}',
|
||||
'OPENLLM_MODEL': llm.config['model_name'],
|
||||
'OPENLLM_SERIALIZATION': serialisation_format,
|
||||
'OPENLLM_ADAPTER_MAP': f"'{orjson.dumps(adapter_map).decode()}'",
|
||||
'BENTOML_DEBUG': str(True),
|
||||
'BENTOML_QUIET': str(False),
|
||||
'BENTOML_CONFIG_OPTIONS': f"'{environ['BENTOML_CONFIG_OPTIONS']}'",
|
||||
}
|
||||
if adapter_map: env_dict["BITSANDBYTES_NOWELCOME"] = os.environ.get("BITSANDBYTES_NOWELCOME", "1")
|
||||
if adapter_map: env_dict['BITSANDBYTES_NOWELCOME'] = os.environ.get('BITSANDBYTES_NOWELCOME', '1')
|
||||
|
||||
# We need to handle None separately here, as env from subprocess doesn't accept None value.
|
||||
_env = openllm_core.utils.EnvVarMixin(llm.config["model_name"], bettertransformer=bettertransformer, quantize=quantize, runtime=runtime)
|
||||
_env = openllm_core.utils.EnvVarMixin(llm.config['model_name'], bettertransformer=bettertransformer, quantize=quantize, runtime=runtime)
|
||||
|
||||
env_dict[_env.bettertransformer] = str(_env["bettertransformer_value"])
|
||||
if _env["quantize_value"] is not None: env_dict[_env.quantize] = t.cast(str, _env["quantize_value"])
|
||||
env_dict[_env.runtime] = _env["runtime_value"]
|
||||
return DockerOptions(base_image=f"{oci.CONTAINER_NAMES[container_registry]}:{oci.get_base_container_tag(container_version_strategy)}", env=env_dict, dockerfile_template=dockerfile_template)
|
||||
OPENLLM_MODEL_NAME = "# openllm: model name"
|
||||
OPENLLM_MODEL_ADAPTER_MAP = "# openllm: model adapter map"
|
||||
env_dict[_env.bettertransformer] = str(_env['bettertransformer_value'])
|
||||
if _env['quantize_value'] is not None: env_dict[_env.quantize] = t.cast(str, _env['quantize_value'])
|
||||
env_dict[_env.runtime] = _env['runtime_value']
|
||||
return DockerOptions(base_image=f'{oci.CONTAINER_NAMES[container_registry]}:{oci.get_base_container_tag(container_version_strategy)}', env=env_dict, dockerfile_template=dockerfile_template)
|
||||
OPENLLM_MODEL_NAME = '# openllm: model name'
|
||||
OPENLLM_MODEL_ADAPTER_MAP = '# openllm: model adapter map'
|
||||
class ModelNameFormatter(string.Formatter):
|
||||
model_keyword: LiteralString = "__model_name__"
|
||||
model_keyword: LiteralString = '__model_name__'
|
||||
|
||||
def __init__(self, model_name: str):
|
||||
"""The formatter that extends model_name to be formatted the 'service.py'."""
|
||||
@@ -143,23 +143,23 @@ class ModelNameFormatter(string.Formatter):
|
||||
except ValueError:
|
||||
return False
|
||||
class ModelIdFormatter(ModelNameFormatter):
|
||||
model_keyword: LiteralString = "__model_id__"
|
||||
model_keyword: LiteralString = '__model_id__'
|
||||
class ModelAdapterMapFormatter(ModelNameFormatter):
|
||||
model_keyword: LiteralString = "__model_adapter_map__"
|
||||
_service_file = Path(os.path.abspath(__file__)).parent.parent / "_service.py"
|
||||
model_keyword: LiteralString = '__model_adapter_map__'
|
||||
_service_file = Path(os.path.abspath(__file__)).parent.parent / '_service.py'
|
||||
def write_service(llm: openllm.LLM[t.Any, t.Any], adapter_map: dict[str, str | None] | None, llm_fs: FS) -> None:
|
||||
from openllm_core.utils import DEBUG
|
||||
model_name = llm.config["model_name"]
|
||||
logger.debug("Generating service file for %s at %s (dir=%s)", model_name, llm.config["service_name"], llm_fs.getsyspath("/"))
|
||||
with open(_service_file.__fspath__(), "r") as f:
|
||||
model_name = llm.config['model_name']
|
||||
logger.debug('Generating service file for %s at %s (dir=%s)', model_name, llm.config['service_name'], llm_fs.getsyspath('/'))
|
||||
with open(_service_file.__fspath__(), 'r') as f:
|
||||
src_contents = f.readlines()
|
||||
for it in src_contents:
|
||||
if OPENLLM_MODEL_NAME in it: src_contents[src_contents.index(it)] = (ModelNameFormatter(model_name).vformat(it)[:-(len(OPENLLM_MODEL_NAME) + 3)] + "\n")
|
||||
if OPENLLM_MODEL_NAME in it: src_contents[src_contents.index(it)] = (ModelNameFormatter(model_name).vformat(it)[:-(len(OPENLLM_MODEL_NAME) + 3)] + '\n')
|
||||
elif OPENLLM_MODEL_ADAPTER_MAP in it:
|
||||
src_contents[src_contents.index(it)] = (ModelAdapterMapFormatter(orjson.dumps(adapter_map).decode()).vformat(it)[:-(len(OPENLLM_MODEL_ADAPTER_MAP) + 3)] + "\n")
|
||||
script = f"# GENERATED BY 'openllm build {model_name}'. DO NOT EDIT\n\n" + "".join(src_contents)
|
||||
if DEBUG: logger.info("Generated script:\n%s", script)
|
||||
llm_fs.writetext(llm.config["service_name"], script)
|
||||
src_contents[src_contents.index(it)] = (ModelAdapterMapFormatter(orjson.dumps(adapter_map).decode()).vformat(it)[:-(len(OPENLLM_MODEL_ADAPTER_MAP) + 3)] + '\n')
|
||||
script = f"# GENERATED BY 'openllm build {model_name}'. DO NOT EDIT\n\n" + ''.join(src_contents)
|
||||
if DEBUG: logger.info('Generated script:\n%s', script)
|
||||
llm_fs.writetext(llm.config['service_name'], script)
|
||||
@inject
|
||||
def create_bento(
|
||||
bento_tag: bentoml.Tag,
|
||||
@@ -171,20 +171,20 @@ def create_bento(
|
||||
dockerfile_template: str | None,
|
||||
adapter_map: dict[str, str | None] | None = None,
|
||||
extra_dependencies: tuple[str, ...] | None = None,
|
||||
runtime: t.Literal["ggml", "transformers"] = "transformers",
|
||||
serialisation_format: t.Literal["safetensors", "legacy"] = "safetensors",
|
||||
container_registry: LiteralContainerRegistry = "ecr",
|
||||
container_version_strategy: LiteralContainerVersionStrategy = "release",
|
||||
runtime: t.Literal['ggml', 'transformers'] = 'transformers',
|
||||
serialisation_format: t.Literal['safetensors', 'legacy'] = 'safetensors',
|
||||
container_registry: LiteralContainerRegistry = 'ecr',
|
||||
container_version_strategy: LiteralContainerVersionStrategy = 'release',
|
||||
_bento_store: BentoStore = Provide[BentoMLContainer.bento_store],
|
||||
_model_store: ModelStore = Provide[BentoMLContainer.model_store]
|
||||
) -> bentoml.Bento:
|
||||
framework_envvar = llm.config["env"]["framework_value"]
|
||||
framework_envvar = llm.config['env']['framework_value']
|
||||
labels = dict(llm.identifying_params)
|
||||
labels.update({"_type": llm.llm_type, "_framework": framework_envvar, "start_name": llm.config["start_name"], "base_name_or_path": llm.model_id, "bundler": "openllm.bundle"})
|
||||
labels.update({'_type': llm.llm_type, '_framework': framework_envvar, 'start_name': llm.config['start_name'], 'base_name_or_path': llm.model_id, 'bundler': 'openllm.bundle'})
|
||||
if adapter_map: labels.update(adapter_map)
|
||||
if isinstance(workers_per_resource, str):
|
||||
if workers_per_resource == "round_robin": workers_per_resource = 1.0
|
||||
elif workers_per_resource == "conserved": workers_per_resource = 1.0 if openllm_core.utils.device_count() == 0 else float(1 / openllm_core.utils.device_count())
|
||||
if workers_per_resource == 'round_robin': workers_per_resource = 1.0
|
||||
elif workers_per_resource == 'conserved': workers_per_resource = 1.0 if openllm_core.utils.device_count() == 0 else float(1 / openllm_core.utils.device_count())
|
||||
else:
|
||||
try:
|
||||
workers_per_resource = float(workers_per_resource)
|
||||
@@ -192,18 +192,18 @@ def create_bento(
|
||||
raise ValueError("'workers_per_resource' only accept ['round_robin', 'conserved'] as possible strategies.") from None
|
||||
elif isinstance(workers_per_resource, int):
|
||||
workers_per_resource = float(workers_per_resource)
|
||||
logger.info("Building Bento for '%s'", llm.config["start_name"])
|
||||
logger.info("Building Bento for '%s'", llm.config['start_name'])
|
||||
# add service.py definition to this temporary folder
|
||||
write_service(llm, adapter_map, llm_fs)
|
||||
|
||||
llm_spec = ModelSpec.from_item({"tag": str(llm.tag), "alias": llm.tag.name})
|
||||
llm_spec = ModelSpec.from_item({'tag': str(llm.tag), 'alias': llm.tag.name})
|
||||
build_config = BentoBuildConfig(
|
||||
service=f"{llm.config['service_name']}:svc",
|
||||
name=bento_tag.name,
|
||||
labels=labels,
|
||||
description=f"OpenLLM service for {llm.config['start_name']}",
|
||||
include=list(llm_fs.walk.files()),
|
||||
exclude=["/venv", "/.venv", "__pycache__/", "*.py[cod]", "*$py.class"],
|
||||
exclude=['/venv', '/.venv', '__pycache__/', '*.py[cod]', '*$py.class'],
|
||||
python=construct_python_options(llm, llm_fs, extra_dependencies, adapter_map),
|
||||
models=[llm_spec],
|
||||
docker=construct_docker_options(
|
||||
@@ -211,20 +211,20 @@ def create_bento(
|
||||
)
|
||||
)
|
||||
|
||||
bento = bentoml.Bento.create(build_config=build_config, version=bento_tag.version, build_ctx=llm_fs.getsyspath("/"))
|
||||
bento = bentoml.Bento.create(build_config=build_config, version=bento_tag.version, build_ctx=llm_fs.getsyspath('/'))
|
||||
# NOTE: the model_id_path here are only used for setting this environment variable within the container built with for BentoLLM.
|
||||
service_fs_path = fs.path.join("src", llm.config["service_name"])
|
||||
service_fs_path = fs.path.join('src', llm.config['service_name'])
|
||||
service_path = bento._fs.getsyspath(service_fs_path)
|
||||
with open(service_path, "r") as f:
|
||||
with open(service_path, 'r') as f:
|
||||
service_contents = f.readlines()
|
||||
|
||||
for it in service_contents:
|
||||
if "__bento_name__" in it: service_contents[service_contents.index(it)] = it.format(__bento_name__=str(bento.tag))
|
||||
if '__bento_name__' in it: service_contents[service_contents.index(it)] = it.format(__bento_name__=str(bento.tag))
|
||||
|
||||
script = "".join(service_contents)
|
||||
if openllm_core.utils.DEBUG: logger.info("Generated script:\n%s", script)
|
||||
script = ''.join(service_contents)
|
||||
if openllm_core.utils.DEBUG: logger.info('Generated script:\n%s', script)
|
||||
|
||||
bento._fs.writetext(service_fs_path, script)
|
||||
if "model_store" in inspect.signature(bento.save).parameters: return bento.save(bento_store=_bento_store, model_store=_model_store)
|
||||
if 'model_store' in inspect.signature(bento.save).parameters: return bento.save(bento_store=_bento_store, model_store=_model_store)
|
||||
# backward arguments. `model_store` is added recently
|
||||
return bento.save(bento_store=_bento_store)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# mypy: disable-error-code="misc"
|
||||
"""OCI-related utilities for OpenLLM. This module is considered to be internal and API are subjected to change."""
|
||||
'''OCI-related utilities for OpenLLM. This module is considered to be internal and API are subjected to change.'''
|
||||
from __future__ import annotations
|
||||
import functools, importlib, logging, os, pathlib, shutil, subprocess, typing as t, openllm_core
|
||||
from datetime import datetime, timedelta, timezone
|
||||
@@ -10,24 +10,24 @@ if t.TYPE_CHECKING:
|
||||
from openllm_core._typing_compat import LiteralContainerRegistry, LiteralContainerVersionStrategy
|
||||
from ghapi import all
|
||||
from openllm_core._typing_compat import RefTuple, LiteralString
|
||||
all = openllm_core.utils.LazyLoader("all", globals(), "ghapi.all") # noqa: F811
|
||||
all = openllm_core.utils.LazyLoader('all', globals(), 'ghapi.all') # noqa: F811
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_BUILDER = bentoml.container.get_backend("buildx")
|
||||
ROOT_DIR = pathlib.Path(os.path.abspath("__file__")).parent.parent.parent
|
||||
_BUILDER = bentoml.container.get_backend('buildx')
|
||||
ROOT_DIR = pathlib.Path(os.path.abspath('__file__')).parent.parent.parent
|
||||
|
||||
# XXX: This registry will be hard code for now for easier to maintain
|
||||
# but in the future, we can infer based on git repo and everything to make it more options for users
|
||||
# to build the base image. For now, all of the base image will be <registry>/bentoml/openllm:...
|
||||
# NOTE: The ECR registry is the public one and currently only @bentoml team has access to push it.
|
||||
_CONTAINER_REGISTRY: dict[LiteralContainerRegistry, str] = {"docker": "docker.io/bentoml/openllm", "gh": "ghcr.io/bentoml/openllm", "ecr": "public.ecr.aws/y5w8i4y6/bentoml/openllm"}
|
||||
_CONTAINER_REGISTRY: dict[LiteralContainerRegistry, str] = {'docker': 'docker.io/bentoml/openllm', 'gh': 'ghcr.io/bentoml/openllm', 'ecr': 'public.ecr.aws/y5w8i4y6/bentoml/openllm'}
|
||||
|
||||
# TODO: support custom fork. Currently it only support openllm main.
|
||||
_OWNER = "bentoml"
|
||||
_REPO = "openllm"
|
||||
_OWNER = 'bentoml'
|
||||
_REPO = 'openllm'
|
||||
|
||||
_module_location = openllm_core.utils.pkg.source_locations("openllm")
|
||||
_module_location = openllm_core.utils.pkg.source_locations('openllm')
|
||||
@functools.lru_cache
|
||||
@openllm_core.utils.apply(str.lower)
|
||||
def get_base_container_name(reg: LiteralContainerRegistry) -> str:
|
||||
@@ -35,23 +35,23 @@ def get_base_container_name(reg: LiteralContainerRegistry) -> str:
|
||||
def _convert_version_from_string(s: str) -> VersionInfo:
|
||||
return VersionInfo.from_version_string(s)
|
||||
def _commit_time_range(r: int = 5) -> str:
|
||||
return (datetime.now(timezone.utc) - timedelta(days=r)).strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||
return (datetime.now(timezone.utc) - timedelta(days=r)).strftime('%Y-%m-%dT%H:%M:%SZ')
|
||||
class VersionNotSupported(openllm.exceptions.OpenLLMException):
|
||||
"""Raised when the stable release is too low that it doesn't include OpenLLM base container."""
|
||||
_RefTuple: type[RefTuple] = openllm_core.utils.codegen.make_attr_tuple_class("_RefTuple", ["git_hash", "version", "strategy"])
|
||||
_RefTuple: type[RefTuple] = openllm_core.utils.codegen.make_attr_tuple_class('_RefTuple', ['git_hash', 'version', 'strategy'])
|
||||
def nightly_resolver(cls: type[RefResolver]) -> str:
|
||||
# NOTE: all openllm container will have sha-<git_hash[:7]>
|
||||
# This will use docker to run skopeo to determine the correct latest tag that is available
|
||||
# If docker is not found, then fallback to previous behaviour. Which the container might not exists.
|
||||
docker_bin = shutil.which("docker")
|
||||
docker_bin = shutil.which('docker')
|
||||
if docker_bin is None:
|
||||
logger.warning(
|
||||
"To get the correct available nightly container, make sure to have docker available. Fallback to previous behaviour for determine nightly hash (container might not exists due to the lack of GPU machine at a time. See https://github.com/bentoml/OpenLLM/pkgs/container/openllm for available image.)"
|
||||
'To get the correct available nightly container, make sure to have docker available. Fallback to previous behaviour for determine nightly hash (container might not exists due to the lack of GPU machine at a time. See https://github.com/bentoml/OpenLLM/pkgs/container/openllm for available image.)'
|
||||
)
|
||||
commits = t.cast("list[dict[str, t.Any]]", cls._ghapi.repos.list_commits(since=_commit_time_range()))
|
||||
return next(f'sha-{it["sha"][:7]}' for it in commits if "[skip ci]" not in it["commit"]["message"])
|
||||
commits = t.cast('list[dict[str, t.Any]]', cls._ghapi.repos.list_commits(since=_commit_time_range()))
|
||||
return next(f'sha-{it["sha"][:7]}' for it in commits if '[skip ci]' not in it['commit']['message'])
|
||||
# now is the correct behaviour
|
||||
return orjson.loads(subprocess.check_output([docker_bin, "run", "--rm", "-it", "quay.io/skopeo/stable:latest", "list-tags", "docker://ghcr.io/bentoml/openllm"]).decode().strip())["Tags"][-2]
|
||||
return orjson.loads(subprocess.check_output([docker_bin, 'run', '--rm', '-it', 'quay.io/skopeo/stable:latest', 'list-tags', 'docker://ghcr.io/bentoml/openllm']).decode().strip())['Tags'][-2]
|
||||
@attr.attrs(eq=False, order=False, slots=True, frozen=True)
|
||||
class RefResolver:
|
||||
git_hash: str = attr.field()
|
||||
@@ -61,7 +61,7 @@ class RefResolver:
|
||||
|
||||
@classmethod
|
||||
def _nightly_ref(cls) -> RefTuple:
|
||||
return _RefTuple((nightly_resolver(cls), "refs/heads/main", "nightly"))
|
||||
return _RefTuple((nightly_resolver(cls), 'refs/heads/main', 'nightly'))
|
||||
|
||||
@classmethod
|
||||
def _release_ref(cls, version_str: str | None = None) -> RefTuple:
|
||||
@@ -69,80 +69,80 @@ class RefResolver:
|
||||
if version_str is None:
|
||||
# NOTE: This strategy will only support openllm>0.2.12
|
||||
meta: dict[str, t.Any] = cls._ghapi.repos.get_latest_release()
|
||||
version_str = meta["name"].lstrip("v")
|
||||
version: tuple[str, str | None] = (cls._ghapi.git.get_ref(ref=f"tags/{meta['name']}")["object"]["sha"], version_str)
|
||||
version_str = meta['name'].lstrip('v')
|
||||
version: tuple[str, str | None] = (cls._ghapi.git.get_ref(ref=f"tags/{meta['name']}")['object']['sha'], version_str)
|
||||
else:
|
||||
version = ("", version_str)
|
||||
version = ('', version_str)
|
||||
if openllm_core.utils.VersionInfo.from_version_string(t.cast(str, version_str)) < (0, 2, 12):
|
||||
raise VersionNotSupported(f"Version {version_str} doesn't support OpenLLM base container. Consider using 'nightly' or upgrade 'openllm>=0.2.12'")
|
||||
return _RefTuple((*version, "release" if _use_base_strategy else "custom"))
|
||||
return _RefTuple((*version, 'release' if _use_base_strategy else 'custom'))
|
||||
|
||||
@classmethod
|
||||
@functools.lru_cache(maxsize=64)
|
||||
def from_strategy(cls, strategy_or_version: t.Literal["release", "nightly"] | LiteralString | None = None) -> RefResolver:
|
||||
def from_strategy(cls, strategy_or_version: t.Literal['release', 'nightly'] | LiteralString | None = None) -> RefResolver:
|
||||
# using default strategy
|
||||
if strategy_or_version is None or strategy_or_version == "release": return cls(*cls._release_ref())
|
||||
elif strategy_or_version == "latest": return cls("latest", "0.0.0", "latest")
|
||||
elif strategy_or_version == "nightly":
|
||||
if strategy_or_version is None or strategy_or_version == 'release': return cls(*cls._release_ref())
|
||||
elif strategy_or_version == 'latest': return cls('latest', '0.0.0', 'latest')
|
||||
elif strategy_or_version == 'nightly':
|
||||
_ref = cls._nightly_ref()
|
||||
return cls(_ref[0], "0.0.0", _ref[-1])
|
||||
return cls(_ref[0], '0.0.0', _ref[-1])
|
||||
else:
|
||||
logger.warning("Using custom %s. Make sure that it is at lease 0.2.12 for base container support.", strategy_or_version)
|
||||
logger.warning('Using custom %s. Make sure that it is at lease 0.2.12 for base container support.', strategy_or_version)
|
||||
return cls(*cls._release_ref(version_str=strategy_or_version))
|
||||
|
||||
@property
|
||||
def tag(self) -> str:
|
||||
# NOTE: latest tag can also be nightly, but discouraged to use it. For nightly refer to use sha-<git_hash_short>
|
||||
if self.strategy == "latest": return "latest"
|
||||
elif self.strategy == "nightly": return self.git_hash
|
||||
if self.strategy == 'latest': return 'latest'
|
||||
elif self.strategy == 'nightly': return self.git_hash
|
||||
else: return repr(self.version)
|
||||
@functools.lru_cache(maxsize=256)
|
||||
def get_base_container_tag(strategy: LiteralContainerVersionStrategy | None = None) -> str:
|
||||
return RefResolver.from_strategy(strategy).tag
|
||||
def build_container(
|
||||
registries: LiteralContainerRegistry | t.Sequence[LiteralContainerRegistry] | None = None,
|
||||
version_strategy: LiteralContainerVersionStrategy = "release",
|
||||
version_strategy: LiteralContainerVersionStrategy = 'release',
|
||||
push: bool = False,
|
||||
machine: bool = False
|
||||
) -> dict[str | LiteralContainerRegistry, str]:
|
||||
try:
|
||||
if not _BUILDER.health(): raise openllm.exceptions.Error
|
||||
except (openllm.exceptions.Error, subprocess.CalledProcessError):
|
||||
raise RuntimeError("Building base container requires BuildKit (via Buildx) to be installed. See https://docs.docker.com/build/buildx/install/ for instalation instruction.") from None
|
||||
if openllm_core.utils.device_count() == 0: raise RuntimeError("Building base container requires GPUs (None available)")
|
||||
if not shutil.which("nvidia-container-runtime"): raise RuntimeError("NVIDIA Container Toolkit is required to compile CUDA kernel in container.")
|
||||
raise RuntimeError('Building base container requires BuildKit (via Buildx) to be installed. See https://docs.docker.com/build/buildx/install/ for instalation instruction.') from None
|
||||
if openllm_core.utils.device_count() == 0: raise RuntimeError('Building base container requires GPUs (None available)')
|
||||
if not shutil.which('nvidia-container-runtime'): raise RuntimeError('NVIDIA Container Toolkit is required to compile CUDA kernel in container.')
|
||||
if not _module_location: raise RuntimeError("Failed to determine source location of 'openllm'. (Possible broken installation)")
|
||||
pyproject_path = pathlib.Path(_module_location).parent.parent / "pyproject.toml"
|
||||
pyproject_path = pathlib.Path(_module_location).parent.parent / 'pyproject.toml'
|
||||
if not pyproject_path.exists(): raise ValueError("This utility can only be run within OpenLLM git repository. Clone it first with 'git clone https://github.com/bentoml/OpenLLM.git'")
|
||||
if not registries:
|
||||
tags: dict[str | LiteralContainerRegistry, str] = {
|
||||
alias: f"{value}:{get_base_container_tag(version_strategy)}" for alias, value in _CONTAINER_REGISTRY.items()
|
||||
alias: f'{value}:{get_base_container_tag(version_strategy)}' for alias, value in _CONTAINER_REGISTRY.items()
|
||||
} # default to all registries with latest tag strategy
|
||||
else:
|
||||
registries = [registries] if isinstance(registries, str) else list(registries)
|
||||
tags = {name: f"{_CONTAINER_REGISTRY[name]}:{get_base_container_tag(version_strategy)}" for name in registries}
|
||||
tags = {name: f'{_CONTAINER_REGISTRY[name]}:{get_base_container_tag(version_strategy)}' for name in registries}
|
||||
try:
|
||||
outputs = _BUILDER.build(
|
||||
file=pathlib.Path(__file__).parent.joinpath("Dockerfile").resolve().__fspath__(),
|
||||
file=pathlib.Path(__file__).parent.joinpath('Dockerfile').resolve().__fspath__(),
|
||||
context_path=pyproject_path.parent.__fspath__(),
|
||||
tag=tuple(tags.values()),
|
||||
push=push,
|
||||
progress="plain" if openllm_core.utils.get_debug_mode() else "auto",
|
||||
progress='plain' if openllm_core.utils.get_debug_mode() else 'auto',
|
||||
quiet=machine
|
||||
)
|
||||
if machine and outputs is not None: tags["image_sha"] = outputs.decode("utf-8").strip()
|
||||
if machine and outputs is not None: tags['image_sha'] = outputs.decode('utf-8').strip()
|
||||
except Exception as err:
|
||||
raise openllm.exceptions.OpenLLMException(f"Failed to containerize base container images (Scroll up to see error above, or set OPENLLMDEVDEBUG=True for more traceback):\n{err}") from err
|
||||
raise openllm.exceptions.OpenLLMException(f'Failed to containerize base container images (Scroll up to see error above, or set OPENLLMDEVDEBUG=True for more traceback):\n{err}') from err
|
||||
return tags
|
||||
if t.TYPE_CHECKING:
|
||||
CONTAINER_NAMES: dict[LiteralContainerRegistry, str]
|
||||
supported_registries: list[str]
|
||||
|
||||
__all__ = ["CONTAINER_NAMES", "get_base_container_tag", "build_container", "get_base_container_name", "supported_registries", "RefResolver"]
|
||||
__all__ = ['CONTAINER_NAMES', 'get_base_container_tag', 'build_container', 'get_base_container_name', 'supported_registries', 'RefResolver']
|
||||
def __dir__() -> list[str]:
|
||||
return sorted(__all__)
|
||||
def __getattr__(name: str) -> t.Any:
|
||||
if name == "supported_registries": return functools.lru_cache(1)(lambda: list(_CONTAINER_REGISTRY))()
|
||||
elif name == "CONTAINER_NAMES": return _CONTAINER_REGISTRY
|
||||
elif name in __all__: return importlib.import_module("." + name, __name__)
|
||||
else: raise AttributeError(f"{name} does not exists under {__name__}")
|
||||
if name == 'supported_registries': return functools.lru_cache(1)(lambda: list(_CONTAINER_REGISTRY))()
|
||||
elif name == 'CONTAINER_NAMES': return _CONTAINER_REGISTRY
|
||||
elif name in __all__: return importlib.import_module('.' + name, __name__)
|
||||
else: raise AttributeError(f'{name} does not exists under {__name__}')
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""OpenLLM CLI.
|
||||
'''OpenLLM CLI.
|
||||
|
||||
For more information see ``openllm -h``.
|
||||
"""
|
||||
'''
|
||||
|
||||
@@ -13,21 +13,21 @@ if t.TYPE_CHECKING:
|
||||
from openllm_core._configuration import LLMConfig
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
P = ParamSpec("P")
|
||||
LiteralOutput = t.Literal["json", "pretty", "porcelain"]
|
||||
P = ParamSpec('P')
|
||||
LiteralOutput = t.Literal['json', 'pretty', 'porcelain']
|
||||
|
||||
_AnyCallable = t.Callable[..., t.Any]
|
||||
FC = t.TypeVar("FC", bound=t.Union[_AnyCallable, click.Command])
|
||||
FC = t.TypeVar('FC', bound=t.Union[_AnyCallable, click.Command])
|
||||
def bento_complete_envvar(ctx: click.Context, param: click.Parameter, incomplete: str) -> list[sc.CompletionItem]:
|
||||
return [sc.CompletionItem(str(it.tag), help="Bento") for it in bentoml.list() if str(it.tag).startswith(incomplete) and all(k in it.info.labels for k in {"start_name", "bundler"})]
|
||||
return [sc.CompletionItem(str(it.tag), help='Bento') for it in bentoml.list() if str(it.tag).startswith(incomplete) and all(k in it.info.labels for k in {'start_name', 'bundler'})]
|
||||
def model_complete_envvar(ctx: click.Context, param: click.Parameter, incomplete: str) -> list[sc.CompletionItem]:
|
||||
return [sc.CompletionItem(inflection.dasherize(it), help="Model") for it in openllm.CONFIG_MAPPING if it.startswith(incomplete)]
|
||||
return [sc.CompletionItem(inflection.dasherize(it), help='Model') for it in openllm.CONFIG_MAPPING if it.startswith(incomplete)]
|
||||
def parse_config_options(config: LLMConfig, server_timeout: int, workers_per_resource: float, device: t.Tuple[str, ...] | None, cors: bool, environ: DictStrAny) -> DictStrAny:
|
||||
# TODO: Support amd.com/gpu on k8s
|
||||
_bentoml_config_options_env = environ.pop("BENTOML_CONFIG_OPTIONS", "")
|
||||
_bentoml_config_options_env = environ.pop('BENTOML_CONFIG_OPTIONS', '')
|
||||
_bentoml_config_options_opts = [
|
||||
"tracing.sample_rate=1.0",
|
||||
f"api_server.traffic.timeout={server_timeout}",
|
||||
'tracing.sample_rate=1.0',
|
||||
f'api_server.traffic.timeout={server_timeout}',
|
||||
f'runners."llm-{config["start_name"]}-runner".traffic.timeout={config["timeout"]}',
|
||||
f'runners."llm-{config["start_name"]}-runner".workers_per_resource={workers_per_resource}'
|
||||
]
|
||||
@@ -36,18 +36,18 @@ def parse_config_options(config: LLMConfig, server_timeout: int, workers_per_res
|
||||
else: _bentoml_config_options_opts.append(f'runners."llm-{config["start_name"]}-runner".resources."nvidia.com/gpu"=[{device[0]}]')
|
||||
_bentoml_config_options_opts.append(f'runners."llm-generic-embedding".resources.cpu={openllm.get_resource({"cpu":"system"},"cpu")}')
|
||||
if cors:
|
||||
_bentoml_config_options_opts.extend(["api_server.http.cors.enabled=true", 'api_server.http.cors.access_control_allow_origins="*"'])
|
||||
_bentoml_config_options_opts.extend([f'api_server.http.cors.access_control_allow_methods[{idx}]="{it}"' for idx, it in enumerate(["GET", "OPTIONS", "POST", "HEAD", "PUT"])])
|
||||
_bentoml_config_options_env += " " if _bentoml_config_options_env else "" + " ".join(_bentoml_config_options_opts)
|
||||
environ["BENTOML_CONFIG_OPTIONS"] = _bentoml_config_options_env
|
||||
if DEBUG: logger.debug("Setting BENTOML_CONFIG_OPTIONS=%s", _bentoml_config_options_env)
|
||||
_bentoml_config_options_opts.extend(['api_server.http.cors.enabled=true', 'api_server.http.cors.access_control_allow_origins="*"'])
|
||||
_bentoml_config_options_opts.extend([f'api_server.http.cors.access_control_allow_methods[{idx}]="{it}"' for idx, it in enumerate(['GET', 'OPTIONS', 'POST', 'HEAD', 'PUT'])])
|
||||
_bentoml_config_options_env += ' ' if _bentoml_config_options_env else '' + ' '.join(_bentoml_config_options_opts)
|
||||
environ['BENTOML_CONFIG_OPTIONS'] = _bentoml_config_options_env
|
||||
if DEBUG: logger.debug('Setting BENTOML_CONFIG_OPTIONS=%s', _bentoml_config_options_env)
|
||||
return environ
|
||||
_adapter_mapping_key = "adapter_map"
|
||||
_adapter_mapping_key = 'adapter_map'
|
||||
def _id_callback(ctx: click.Context, _: click.Parameter, value: t.Tuple[str, ...] | None) -> None:
|
||||
if not value: return None
|
||||
if _adapter_mapping_key not in ctx.params: ctx.params[_adapter_mapping_key] = {}
|
||||
for v in value:
|
||||
adapter_id, *adapter_name = v.rsplit(":", maxsplit=1)
|
||||
adapter_id, *adapter_name = v.rsplit(':', maxsplit=1)
|
||||
# try to resolve the full path if users pass in relative,
|
||||
# currently only support one level of resolve path with current directory
|
||||
try:
|
||||
@@ -59,11 +59,11 @@ def _id_callback(ctx: click.Context, _: click.Parameter, value: t.Tuple[str, ...
|
||||
def start_command_factory(group: click.Group, model: str, _context_settings: DictStrAny | None = None, _serve_grpc: bool = False) -> click.Command:
|
||||
llm_config = openllm.AutoConfig.for_model(model)
|
||||
command_attrs: DictStrAny = dict(
|
||||
name=llm_config["model_name"],
|
||||
name=llm_config['model_name'],
|
||||
context_settings=_context_settings or termui.CONTEXT_SETTINGS,
|
||||
short_help=f"Start a LLMServer for '{model}'",
|
||||
aliases=[llm_config["start_name"]] if llm_config["name_type"] == "dasherize" else None,
|
||||
help=f"""\
|
||||
aliases=[llm_config['start_name']] if llm_config['name_type'] == 'dasherize' else None,
|
||||
help=f'''\
|
||||
{llm_config['env'].start_docstring}
|
||||
|
||||
\b
|
||||
@@ -81,13 +81,13 @@ Available official model_id(s): [default: {llm_config['default_id']}]
|
||||
|
||||
\b
|
||||
{orjson.dumps(llm_config['model_ids'], option=orjson.OPT_INDENT_2).decode()}
|
||||
""",
|
||||
''',
|
||||
)
|
||||
|
||||
if llm_config["requires_gpu"] and openllm.utils.device_count() < 1:
|
||||
if llm_config['requires_gpu'] and openllm.utils.device_count() < 1:
|
||||
# NOTE: The model requires GPU, therefore we will return a dummy command
|
||||
command_attrs.update({
|
||||
"short_help": "(Disabled because there is no GPU available)", "help": f"{model} is currently not available to run on your local machine because it requires GPU for inference."
|
||||
'short_help': '(Disabled because there is no GPU available)', 'help': f'{model} is currently not available to run on your local machine because it requires GPU for inference.'
|
||||
})
|
||||
return noop_command(group, llm_config, _serve_grpc, **command_attrs)
|
||||
|
||||
@@ -100,39 +100,39 @@ Available official model_id(s): [default: {llm_config['default_id']}]
|
||||
server_timeout: int,
|
||||
model_id: str | None,
|
||||
model_version: str | None,
|
||||
workers_per_resource: t.Literal["conserved", "round_robin"] | LiteralString,
|
||||
workers_per_resource: t.Literal['conserved', 'round_robin'] | LiteralString,
|
||||
device: t.Tuple[str, ...],
|
||||
quantize: t.Literal["int8", "int4", "gptq"] | None,
|
||||
quantize: t.Literal['int8', 'int4', 'gptq'] | None,
|
||||
bettertransformer: bool | None,
|
||||
runtime: t.Literal["ggml", "transformers"],
|
||||
runtime: t.Literal['ggml', 'transformers'],
|
||||
fast: bool,
|
||||
serialisation_format: t.Literal["safetensors", "legacy"],
|
||||
serialisation_format: t.Literal['safetensors', 'legacy'],
|
||||
cors: bool,
|
||||
adapter_id: str | None,
|
||||
return_process: bool,
|
||||
**attrs: t.Any,
|
||||
) -> LLMConfig | subprocess.Popen[bytes]:
|
||||
fast = str(fast).upper() in openllm.utils.ENV_VARS_TRUE_VALUES
|
||||
if serialisation_format == "safetensors" and quantize is not None and os.environ.get("OPENLLM_SERIALIZATION_WARNING", str(True)).upper() in openllm.utils.ENV_VARS_TRUE_VALUES:
|
||||
if serialisation_format == 'safetensors' and quantize is not None and os.environ.get('OPENLLM_SERIALIZATION_WARNING', str(True)).upper() in openllm.utils.ENV_VARS_TRUE_VALUES:
|
||||
termui.echo(
|
||||
f"'--quantize={quantize}' might not work with 'safetensors' serialisation format. Use with caution!. To silence this warning, set \"OPENLLM_SERIALIZATION_WARNING=False\"\nNote: You can always fallback to '--serialisation legacy' when running quantisation.",
|
||||
fg="yellow"
|
||||
fg='yellow'
|
||||
)
|
||||
adapter_map: dict[str, str | None] | None = attrs.pop(_adapter_mapping_key, None)
|
||||
config, server_attrs = llm_config.model_validate_click(**attrs)
|
||||
server_timeout = openllm.utils.first_not_none(server_timeout, default=config["timeout"])
|
||||
server_attrs.update({"working_dir": os.path.dirname(os.path.dirname(__file__)), "timeout": server_timeout})
|
||||
if _serve_grpc: server_attrs["grpc_protocol_version"] = "v1"
|
||||
server_timeout = openllm.utils.first_not_none(server_timeout, default=config['timeout'])
|
||||
server_attrs.update({'working_dir': os.path.dirname(os.path.dirname(__file__)), 'timeout': server_timeout})
|
||||
if _serve_grpc: server_attrs['grpc_protocol_version'] = 'v1'
|
||||
# NOTE: currently, theres no development args in bentoml.Server. To be fixed upstream.
|
||||
development = server_attrs.pop("development")
|
||||
server_attrs.setdefault("production", not development)
|
||||
wpr = openllm.utils.first_not_none(workers_per_resource, default=config["workers_per_resource"])
|
||||
development = server_attrs.pop('development')
|
||||
server_attrs.setdefault('production', not development)
|
||||
wpr = openllm.utils.first_not_none(workers_per_resource, default=config['workers_per_resource'])
|
||||
|
||||
if isinstance(wpr, str):
|
||||
if wpr == "round_robin": wpr = 1.0
|
||||
elif wpr == "conserved":
|
||||
if wpr == 'round_robin': wpr = 1.0
|
||||
elif wpr == 'conserved':
|
||||
if device and openllm.utils.device_count() == 0:
|
||||
termui.echo("--device will have no effect as there is no GPUs available", fg="yellow")
|
||||
termui.echo('--device will have no effect as there is no GPUs available', fg='yellow')
|
||||
wpr = 1.0
|
||||
else:
|
||||
available_gpu = len(device) if device else openllm.utils.device_count()
|
||||
@@ -144,7 +144,7 @@ Available official model_id(s): [default: {llm_config['default_id']}]
|
||||
|
||||
# Create a new model env to work with the envvar during CLI invocation
|
||||
env = openllm.utils.EnvVarMixin(
|
||||
config["model_name"], config.default_implementation(), model_id=model_id or config["default_id"], bettertransformer=bettertransformer, quantize=quantize, runtime=runtime
|
||||
config['model_name'], config.default_implementation(), model_id=model_id or config['default_id'], bettertransformer=bettertransformer, quantize=quantize, runtime=runtime
|
||||
)
|
||||
prerequisite_check(ctx, config, quantize, adapter_map, int(1 / wpr))
|
||||
|
||||
@@ -152,38 +152,38 @@ Available official model_id(s): [default: {llm_config['default_id']}]
|
||||
start_env = os.environ.copy()
|
||||
start_env = parse_config_options(config, server_timeout, wpr, device, cors, start_env)
|
||||
if fast:
|
||||
termui.echo(f"Fast mode is enabled. Make sure the model is available in local store before 'start': 'openllm import {model}{' --model-id ' + model_id if model_id else ''}'", fg="yellow")
|
||||
termui.echo(f"Fast mode is enabled. Make sure the model is available in local store before 'start': 'openllm import {model}{' --model-id ' + model_id if model_id else ''}'", fg='yellow')
|
||||
|
||||
start_env.update({
|
||||
"OPENLLM_MODEL": model,
|
||||
"BENTOML_DEBUG": str(openllm.utils.get_debug_mode()),
|
||||
"BENTOML_HOME": os.environ.get("BENTOML_HOME", BentoMLContainer.bentoml_home.get()),
|
||||
"OPENLLM_ADAPTER_MAP": orjson.dumps(adapter_map).decode(),
|
||||
"OPENLLM_SERIALIZATION": serialisation_format,
|
||||
env.runtime: env["runtime_value"],
|
env.framework: env["framework_value"]
'OPENLLM_MODEL': model,
'BENTOML_DEBUG': str(openllm.utils.get_debug_mode()),
'BENTOML_HOME': os.environ.get('BENTOML_HOME', BentoMLContainer.bentoml_home.get()),
'OPENLLM_ADAPTER_MAP': orjson.dumps(adapter_map).decode(),
'OPENLLM_SERIALIZATION': serialisation_format,
env.runtime: env['runtime_value'],
env.framework: env['framework_value']
})
if env["model_id_value"]: start_env[env.model_id] = str(env["model_id_value"])
if env['model_id_value']: start_env[env.model_id] = str(env['model_id_value'])
# NOTE: quantize and bettertransformer value is already assigned within env
if bettertransformer is not None: start_env[env.bettertransformer] = str(env["bettertransformer_value"])
if quantize is not None: start_env[env.quantize] = str(t.cast(str, env["quantize_value"]))
if bettertransformer is not None: start_env[env.bettertransformer] = str(env['bettertransformer_value'])
if quantize is not None: start_env[env.quantize] = str(t.cast(str, env['quantize_value']))

llm = openllm.utils.infer_auto_class(env["framework_value"]).for_model(
llm = openllm.utils.infer_auto_class(env['framework_value']).for_model(
model, model_id=start_env[env.model_id], model_version=model_version, llm_config=config, ensure_available=not fast, adapter_map=adapter_map, serialisation=serialisation_format
)
start_env.update({env.config: llm.config.model_dump_json().decode()})

server = bentoml.GrpcServer("_service:svc", **server_attrs) if _serve_grpc else bentoml.HTTPServer("_service:svc", **server_attrs)
server = bentoml.GrpcServer('_service:svc', **server_attrs) if _serve_grpc else bentoml.HTTPServer('_service:svc', **server_attrs)
openllm.utils.analytics.track_start_init(llm.config)

def next_step(model_name: str, adapter_map: DictStrAny | None) -> None:
cmd_name = f"openllm build {model_name}"
if adapter_map is not None: cmd_name += " " + " ".join([f"--adapter-id {s}" for s in [f"{p}:{name}" if name not in (None, "default") else p for p, name in adapter_map.items()]])
if not openllm.utils.get_quiet_mode(): termui.echo(f"\n🚀 Next step: run '{cmd_name}' to create a Bento for {model_name}", fg="blue")
cmd_name = f'openllm build {model_name}'
if adapter_map is not None: cmd_name += ' ' + ' '.join([f'--adapter-id {s}' for s in [f'{p}:{name}' if name not in (None, 'default') else p for p, name in adapter_map.items()]])
if not openllm.utils.get_quiet_mode(): termui.echo(f"\n🚀 Next step: run '{cmd_name}' to create a Bento for {model_name}", fg='blue')

if return_process:
server.start(env=start_env, text=True)
if server.process is None: raise click.ClickException("Failed to start the server.")
if server.process is None: raise click.ClickException('Failed to start the server.')
return server.process
else:
try:
@@ -191,7 +191,7 @@ Available official model_id(s): [default: {llm_config['default_id']}]
except KeyboardInterrupt:
next_step(model, adapter_map)
except Exception as err:
termui.echo(f"Error caught while running LLM Server:\n{err}", fg="red")
termui.echo(f'Error caught while running LLM Server:\n{err}', fg='red')
else:
next_step(model, adapter_map)

@@ -200,40 +200,40 @@ Available official model_id(s): [default: {llm_config['default_id']}]

return start_cmd
def noop_command(group: click.Group, llm_config: LLMConfig, _serve_grpc: bool, **command_attrs: t.Any) -> click.Command:
context_settings = command_attrs.pop("context_settings", {})
context_settings.update({"ignore_unknown_options": True, "allow_extra_args": True})
command_attrs["context_settings"] = context_settings
context_settings = command_attrs.pop('context_settings', {})
context_settings.update({'ignore_unknown_options': True, 'allow_extra_args': True})
command_attrs['context_settings'] = context_settings
# NOTE: The model requires GPU, therefore we will return a dummy command
@group.command(**command_attrs)
def noop(**_: t.Any) -> LLMConfig:
termui.echo("No GPU available, therefore this command is disabled", fg="red")
termui.echo('No GPU available, therefore this command is disabled', fg='red')
openllm.utils.analytics.track_start_init(llm_config)
return llm_config

return noop
def prerequisite_check(ctx: click.Context, llm_config: LLMConfig, quantize: LiteralString | None, adapter_map: dict[str, str | None] | None, num_workers: int) -> None:
if adapter_map and not openllm.utils.is_peft_available(): ctx.fail("Using adapter requires 'peft' to be available. Make sure to install with 'pip install \"openllm[fine-tune]\"'")
if quantize and llm_config.default_implementation() == "vllm":
if quantize and llm_config.default_implementation() == 'vllm':
ctx.fail(f"Quantization is not yet supported with vLLM. Set '{llm_config['env']['framework']}=\"pt\"' to run with quantization.")
requirements = llm_config["requirements"]
requirements = llm_config['requirements']
if requirements is not None and len(requirements) > 0:
missing_requirements = [i for i in requirements if importlib.util.find_spec(inflection.underscore(i)) is None]
if len(missing_requirements) > 0: termui.echo(f"Make sure to have the following dependencies available: {missing_requirements}", fg="yellow")
if len(missing_requirements) > 0: termui.echo(f'Make sure to have the following dependencies available: {missing_requirements}', fg='yellow')
||||
def start_decorator(llm_config: LLMConfig, serve_grpc: bool = False) -> t.Callable[[FC], t.Callable[[FC], FC]]:
|
||||
def wrapper(fn: FC) -> t.Callable[[FC], FC]:
|
||||
composed = openllm.utils.compose(
|
||||
llm_config.to_click_options,
|
||||
_http_server_args if not serve_grpc else _grpc_server_args,
|
||||
cog.optgroup.group("General LLM Options", help=f"The following options are related to running '{llm_config['start_name']}' LLM Server."),
|
||||
model_id_option(factory=cog.optgroup, model_env=llm_config["env"]),
|
||||
cog.optgroup.group('General LLM Options', help=f"The following options are related to running '{llm_config['start_name']}' LLM Server."),
|
||||
model_id_option(factory=cog.optgroup, model_env=llm_config['env']),
|
||||
model_version_option(factory=cog.optgroup),
|
||||
cog.optgroup.option("--server-timeout", type=int, default=None, help="Server timeout in seconds"),
|
||||
cog.optgroup.option('--server-timeout', type=int, default=None, help='Server timeout in seconds'),
|
||||
workers_per_resource_option(factory=cog.optgroup),
|
||||
cors_option(factory=cog.optgroup),
|
||||
fast_option(factory=cog.optgroup),
|
||||
cog.optgroup.group(
|
||||
"LLM Optimization Options",
|
||||
help="""Optimization related options.
|
||||
'LLM Optimization Options',
|
||||
help='''Optimization related options.
|
||||
|
||||
OpenLLM supports running model with [BetterTransformer](https://pytorch.org/blog/a-better-transformer-for-fast-transformer-encoder-inference/),
|
||||
k-bit quantization (8-bit, 4-bit), GPTQ quantization, PagedAttention via vLLM.
|
||||
@@ -242,24 +242,24 @@ def start_decorator(llm_config: LLMConfig, serve_grpc: bool = False) -> t.Callab
|
||||
|
||||
- DeepSpeed Inference: [link](https://www.deepspeed.ai/inference/)
|
||||
- GGML: Fast inference on [bare metal](https://github.com/ggerganov/ggml)
|
||||
""",
|
||||
''',
|
||||
),
|
||||
cog.optgroup.option(
|
||||
"--device",
|
||||
'--device',
|
||||
type=openllm.utils.dantic.CUDA,
|
||||
multiple=True,
|
||||
envvar="CUDA_VISIBLE_DEVICES",
|
||||
envvar='CUDA_VISIBLE_DEVICES',
|
||||
callback=parse_device_callback,
|
||||
help=f"Assign GPU devices (if available) for {llm_config['model_name']}.",
|
||||
show_envvar=True
|
||||
),
|
||||
cog.optgroup.option("--runtime", type=click.Choice(["ggml", "transformers"]), default="transformers", help="The runtime to use for the given model. Default is transformers."),
|
||||
quantize_option(factory=cog.optgroup, model_env=llm_config["env"]),
|
||||
bettertransformer_option(factory=cog.optgroup, model_env=llm_config["env"]),
|
||||
cog.optgroup.option('--runtime', type=click.Choice(['ggml', 'transformers']), default='transformers', help='The runtime to use for the given model. Default is transformers.'),
|
||||
quantize_option(factory=cog.optgroup, model_env=llm_config['env']),
|
||||
bettertransformer_option(factory=cog.optgroup, model_env=llm_config['env']),
|
||||
serialisation_option(factory=cog.optgroup),
|
||||
cog.optgroup.group(
|
||||
"Fine-tuning related options",
|
||||
help="""\
|
||||
'Fine-tuning related options',
|
||||
help='''\
|
||||
Note that the argument `--adapter-id` can accept the following format:
|
||||
|
||||
- `--adapter-id /path/to/adapter` (local adapter)
|
||||
@@ -273,37 +273,37 @@ def start_decorator(llm_config: LLMConfig, serve_grpc: bool = False) -> t.Callab
|
||||
$ openllm start opt --adapter-id /path/to/adapter_dir --adapter-id remote/adapter:eng_lora
|
||||
|
||||
```
|
||||
"""
|
||||
'''
|
||||
),
|
||||
cog.optgroup.option(
|
||||
"--adapter-id",
|
||||
'--adapter-id',
|
||||
default=None,
|
||||
help="Optional name or path for given LoRA adapter" + f" to wrap '{llm_config['model_name']}'",
|
||||
help='Optional name or path for given LoRA adapter' + f" to wrap '{llm_config['model_name']}'",
|
||||
multiple=True,
|
||||
callback=_id_callback,
|
||||
metavar="[PATH | [remote/][adapter_name:]adapter_id][, ...]"
|
||||
metavar='[PATH | [remote/][adapter_name:]adapter_id][, ...]'
|
||||
),
|
||||
click.option("--return-process", is_flag=True, default=False, help="Internal use only.", hidden=True),
|
||||
click.option('--return-process', is_flag=True, default=False, help='Internal use only.', hidden=True),
|
||||
)
|
||||
return composed(fn)
|
||||
|
||||
return wrapper
|
||||
def parse_device_callback(ctx: click.Context, param: click.Parameter, value: tuple[tuple[str], ...] | None) -> t.Tuple[str, ...] | None:
|
||||
if value is None: return value
|
||||
if not isinstance(value, tuple): ctx.fail(f"{param} only accept multiple values, not {type(value)} (value: {value})")
|
||||
if not isinstance(value, tuple): ctx.fail(f'{param} only accept multiple values, not {type(value)} (value: {value})')
|
||||
el: t.Tuple[str, ...] = tuple(i for k in value for i in k)
|
||||
# NOTE: --device all is a special case
|
||||
if len(el) == 1 and el[0] == "all": return tuple(map(str, openllm.utils.available_devices()))
|
||||
if len(el) == 1 and el[0] == 'all': return tuple(map(str, openllm.utils.available_devices()))
|
||||
return el
|
||||
# NOTE: A list of bentoml option that is not needed for parsing.
|
||||
# NOTE: User shouldn't set '--working-dir', as OpenLLM will setup this.
|
||||
# NOTE: production is also deprecated
|
||||
_IGNORED_OPTIONS = {"working_dir", "production", "protocol_version"}
|
||||
_IGNORED_OPTIONS = {'working_dir', 'production', 'protocol_version'}
|
||||
def parse_serve_args(serve_grpc: bool) -> t.Callable[[t.Callable[..., LLMConfig]], t.Callable[[FC], FC]]:
|
||||
"""Parsing `bentoml serve|serve-grpc` click.Option to be parsed via `openllm start`."""
|
||||
'''Parsing `bentoml serve|serve-grpc` click.Option to be parsed via `openllm start`.'''
|
||||
from bentoml_cli.cli import cli
|
||||
|
||||
command = "serve" if not serve_grpc else "serve-grpc"
|
||||
command = 'serve' if not serve_grpc else 'serve-grpc'
|
||||
group = cog.optgroup.group(
|
||||
f"Start a {'HTTP' if not serve_grpc else 'gRPC'} server options", help=f"Related to serving the model [synonymous to `bentoml {'serve-http' if not serve_grpc else command }`]",
|
||||
)
|
||||
@@ -316,95 +316,95 @@ def parse_serve_args(serve_grpc: bool) -> t.Callable[[t.Callable[..., LLMConfig]
|
||||
for options in reversed(serve_options):
|
||||
attrs = options.to_info_dict()
|
||||
# we don't need param_type_name, since it should all be options
|
||||
attrs.pop("param_type_name")
|
||||
attrs.pop('param_type_name')
|
||||
# name is not a valid args
|
||||
attrs.pop("name")
|
||||
attrs.pop('name')
|
||||
# type can be determine from default value
|
||||
attrs.pop("type")
|
||||
param_decls = (*attrs.pop("opts"), *attrs.pop("secondary_opts"))
|
||||
attrs.pop('type')
|
||||
param_decls = (*attrs.pop('opts'), *attrs.pop('secondary_opts'))
|
||||
f = cog.optgroup.option(*param_decls, **attrs)(f)
|
||||
return group(f)
|
||||
|
||||
return decorator
|
||||
_http_server_args, _grpc_server_args = parse_serve_args(False), parse_serve_args(True)
|
||||
def _click_factory_type(*param_decls: t.Any, **attrs: t.Any) -> t.Callable[[FC | None], FC]:
|
||||
"""General ``@click`` decorator with some sauce.
|
||||
'''General ``@click`` decorator with some sauce.
|
||||
|
||||
This decorator extends the default ``@click.option`` plus a factory option and factory attr to
|
||||
provide type-safe click.option or click.argument wrapper for all compatible factory.
|
||||
"""
|
||||
factory = attrs.pop("factory", click)
|
||||
factory_attr = attrs.pop("attr", "option")
|
||||
if factory_attr != "argument": attrs.setdefault("help", "General option for OpenLLM CLI.")
|
||||
'''
|
||||
factory = attrs.pop('factory', click)
|
||||
factory_attr = attrs.pop('attr', 'option')
|
||||
if factory_attr != 'argument': attrs.setdefault('help', 'General option for OpenLLM CLI.')
|
||||
|
||||
def decorator(f: FC | None) -> FC:
|
||||
callback = getattr(factory, factory_attr, None)
|
||||
if callback is None: raise ValueError(f"Factory {factory} has no attribute {factory_attr}.")
|
||||
if callback is None: raise ValueError(f'Factory {factory} has no attribute {factory_attr}.')
|
||||
return t.cast(FC, callback(*param_decls, **attrs)(f) if f is not None else callback(*param_decls, **attrs))
|
||||
|
||||
return decorator
|
||||
cli_option = functools.partial(_click_factory_type, attr="option")
|
||||
cli_argument = functools.partial(_click_factory_type, attr="argument")
|
||||
def output_option(f: _AnyCallable | None = None, *, default_value: LiteralOutput = "pretty", **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
output = ["json", "pretty", "porcelain"]
|
||||
cli_option = functools.partial(_click_factory_type, attr='option')
|
||||
cli_argument = functools.partial(_click_factory_type, attr='argument')
|
||||
def output_option(f: _AnyCallable | None = None, *, default_value: LiteralOutput = 'pretty', **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
output = ['json', 'pretty', 'porcelain']
|
||||
|
||||
def complete_output_var(ctx: click.Context, param: click.Parameter, incomplete: str) -> list[CompletionItem]:
|
||||
return [CompletionItem(it) for it in output]
|
||||
|
||||
return cli_option(
|
||||
"-o",
|
||||
"--output",
|
||||
"output",
|
||||
'-o',
|
||||
'--output',
|
||||
'output',
|
||||
type=click.Choice(output),
|
||||
default=default_value,
|
||||
help="Showing output type.",
|
||||
help='Showing output type.',
|
||||
show_default=True,
|
||||
envvar="OPENLLM_OUTPUT",
|
||||
envvar='OPENLLM_OUTPUT',
|
||||
show_envvar=True,
|
||||
shell_complete=complete_output_var,
|
||||
**attrs
|
||||
)(f)
|
||||
def fast_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
"--fast/--no-fast",
|
||||
'--fast/--no-fast',
|
||||
show_default=True,
|
||||
default=False,
|
||||
envvar="OPENLLM_USE_LOCAL_LATEST",
|
||||
envvar='OPENLLM_USE_LOCAL_LATEST',
|
||||
show_envvar=True,
|
||||
help="""Whether to skip checking if models is already in store.
|
||||
help='''Whether to skip checking if models is already in store.
|
||||
|
||||
This is useful if you already downloaded or setup the model beforehand.
|
||||
""",
|
||||
''',
|
||||
**attrs
|
||||
)(f)
|
||||
def cors_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option("--cors/--no-cors", show_default=True, default=False, envvar="OPENLLM_CORS", show_envvar=True, help="Enable CORS for the server.", **attrs)(f)
|
||||
return cli_option('--cors/--no-cors', show_default=True, default=False, envvar='OPENLLM_CORS', show_envvar=True, help='Enable CORS for the server.', **attrs)(f)
|
||||
def machine_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option("--machine", is_flag=True, default=False, hidden=True, **attrs)(f)
|
||||
return cli_option('--machine', is_flag=True, default=False, hidden=True, **attrs)(f)
|
||||
def model_id_option(f: _AnyCallable | None = None, *, model_env: openllm.utils.EnvVarMixin | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
"--model-id",
|
||||
'--model-id',
|
||||
type=click.STRING,
|
||||
default=None,
|
||||
envvar=model_env.model_id if model_env is not None else None,
|
||||
show_envvar=model_env is not None,
|
||||
help="Optional model_id name or path for (fine-tune) weight.",
|
||||
help='Optional model_id name or path for (fine-tune) weight.',
|
||||
**attrs
|
||||
)(f)
|
||||
def model_version_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option("--model-version", type=click.STRING, default=None, help="Optional model version to save for this model. It will be inferred automatically from model-id.", **attrs)(f)
|
||||
return cli_option('--model-version', type=click.STRING, default=None, help='Optional model version to save for this model. It will be inferred automatically from model-id.', **attrs)(f)
|
||||
def model_name_argument(f: _AnyCallable | None = None, required: bool = True, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_argument("model_name", type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING]), required=required, **attrs)(f)
|
||||
return cli_argument('model_name', type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING]), required=required, **attrs)(f)
|
def quantize_option(f: _AnyCallable | None = None, *, build: bool = False, model_env: openllm.utils.EnvVarMixin | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
return cli_option(
"--quantise",
"--quantize",
"quantize",
type=click.Choice(["int8", "int4", "gptq"]),
'--quantise',
'--quantize',
'quantize',
type=click.Choice(['int8', 'int4', 'gptq']),
default=None,
envvar=model_env.quantize if model_env is not None else None,
show_envvar=model_env is not None,
help="""Dynamic quantization for running this LLM.
help='''Dynamic quantization for running this LLM.

The following quantization strategies are supported:

@@ -415,19 +415,19 @@ def quantize_option(f: _AnyCallable | None = None, *, build: bool = False, model
- ``gptq``: ``GPTQ`` [quantization](https://arxiv.org/abs/2210.17323)

> [!NOTE] that the model can also be served with quantized weights.
""" + ("""
> [!NOTE] that this will set the mode for serving within deployment.""" if build else "") + """
> [!NOTE] that quantization are currently only available in *PyTorch* models.""",
''' + ('''
> [!NOTE] that this will set the mode for serving within deployment.''' if build else '') + '''
> [!NOTE] that quantization are currently only available in *PyTorch* models.''',
**attrs
)(f)
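The `--quantise/--quantize` flag above is the single entry point for the strategies listed in its help text; a minimal sketch of the same choice made through the generated Python SDK (assuming the `openllm.start` wrapper produced by `gen_sdk` later in this diff):

```python
import openllm

# 8-bit quantized weights; 'int4' and 'gptq' are passed the same way.
openllm.start('opt', quantize='int8')
```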
def workers_per_resource_option(f: _AnyCallable | None = None, *, build: bool = False, **attrs: t.Any) -> t.Callable[[FC], FC]:
return cli_option(
"--workers-per-resource",
'--workers-per-resource',
default=None,
callback=workers_per_resource_callback,
type=str,
required=False,
help="""Number of workers per resource assigned.
help='''Number of workers per resource assigned.

See https://docs.bentoml.org/en/latest/guides/scheduling.html#resource-scheduling-strategy
for more information. By default, this is set to 1.
@@ -437,36 +437,36 @@ def workers_per_resource_option(f: _AnyCallable | None = None, *, build: bool =
- ``round_robin``: Similar behaviour when setting ``--workers-per-resource 1``. This is useful for smaller models.

- ``conserved``: This will determine the number of available GPU resources, and only assign one worker for the LLMRunner. For example, if ther are 4 GPUs available, then ``conserved`` is equivalent to ``--workers-per-resource 0.25``.
""" + (
''' + (
"""\n
> [!NOTE] The workers value passed into 'build' will determine how the LLM can
> be provisioned in Kubernetes as well as in standalone container. This will
> ensure it has the same effect with 'openllm start --api-workers ...'""" if build else ""
> ensure it has the same effect with 'openllm start --api-workers ...'""" if build else ''
),
**attrs
)(f)
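A small sketch of the two named strategies described above, passed through the generated Python SDK rather than the CLI flag (`workers_per_resource` mirrors `--workers-per-resource` in the `_start` signature later in this diff):

```python
import openllm

# One worker per resource, the behaviour described for ``round_robin``.
openllm.start('opt', workers_per_resource='round_robin')
# A single worker spread across all visible GPUs, e.g. 0.25 on a 4-GPU host.
openllm.start('opt', workers_per_resource='conserved')
```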
||||
def bettertransformer_option(f: _AnyCallable | None = None, *, build: bool = False, model_env: openllm.utils.EnvVarMixin | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
"--bettertransformer",
|
||||
'--bettertransformer',
|
||||
is_flag=True,
|
||||
default=None,
|
||||
envvar=model_env.bettertransformer if model_env is not None else None,
|
||||
show_envvar=model_env is not None,
|
||||
help="Apply FasterTransformer wrapper to serve model. This will applies during serving time."
|
||||
if not build else "Set default environment variable whether to serve this model with FasterTransformer in build time.",
|
||||
help='Apply FasterTransformer wrapper to serve model. This will applies during serving time.'
|
||||
if not build else 'Set default environment variable whether to serve this model with FasterTransformer in build time.',
|
||||
**attrs
|
||||
)(f)
|
||||
def serialisation_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
"--serialisation",
|
||||
"--serialization",
|
||||
"serialisation_format",
|
||||
type=click.Choice(["safetensors", "legacy"]),
|
||||
default="safetensors",
|
||||
'--serialisation',
|
||||
'--serialization',
|
||||
'serialisation_format',
|
||||
type=click.Choice(['safetensors', 'legacy']),
|
||||
default='safetensors',
|
||||
show_default=True,
|
||||
show_envvar=True,
|
||||
envvar="OPENLLM_SERIALIZATION",
|
||||
help="""Serialisation format for save/load LLM.
|
||||
envvar='OPENLLM_SERIALIZATION',
|
||||
help='''Serialisation format for save/load LLM.
|
||||
|
||||
Currently the following strategies are supported:
|
||||
|
||||
@@ -482,29 +482,29 @@ def serialisation_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Cal
|
||||
- ``legacy``: This will use PyTorch serialisation format, often as ``.bin`` files. This should be used if the model doesn't yet support safetensors.
|
||||
|
||||
> [!NOTE] that GGML format is working in progress.
|
||||
""",
|
||||
''',
|
||||
**attrs
|
||||
)(f)
|
||||
def container_registry_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
"--container-registry",
|
||||
"container_registry",
|
||||
'--container-registry',
|
||||
'container_registry',
|
||||
type=click.Choice(list(openllm.bundle.CONTAINER_NAMES)),
|
||||
default="ecr",
|
||||
default='ecr',
|
||||
show_default=True,
|
||||
show_envvar=True,
|
||||
envvar="OPENLLM_CONTAINER_REGISTRY",
|
||||
envvar='OPENLLM_CONTAINER_REGISTRY',
|
||||
callback=container_registry_callback,
|
||||
help="""The default container registry to get the base image for building BentoLLM.
|
||||
help='''The default container registry to get the base image for building BentoLLM.
|
||||
|
||||
Currently, it supports 'ecr', 'ghcr.io', 'docker.io'
|
||||
|
||||
\b
|
||||
> [!NOTE] that in order to build the base image, you will need a GPUs to compile custom kernel. See ``openllm ext build-base-container`` for more information.
|
||||
""",
|
||||
''',
|
||||
**attrs
|
||||
)(f)
|
||||
_wpr_strategies = {"round_robin", "conserved"}
|
||||
_wpr_strategies = {'round_robin', 'conserved'}
|
||||
def workers_per_resource_callback(ctx: click.Context, param: click.Parameter, value: str | None) -> str | None:
|
||||
if value is None: return value
|
||||
value = inflection.underscore(value)
|
||||
@@ -518,5 +518,5 @@ def workers_per_resource_callback(ctx: click.Context, param: click.Parameter, va
|
||||
return value
|
||||
def container_registry_callback(ctx: click.Context, param: click.Parameter, value: str | None) -> str | None:
|
||||
if value is None: return value
|
||||
if value not in openllm.bundle.supported_registries: raise click.BadParameter(f"Value must be one of {openllm.bundle.supported_registries}", ctx, param)
|
||||
if value not in openllm.bundle.supported_registries: raise click.BadParameter(f'Value must be one of {openllm.bundle.supported_registries}', ctx, param)
|
||||
return value
|
||||
|
||||
@@ -17,11 +17,11 @@ def _start(
|
||||
*,
|
||||
model_id: str | None = None,
|
||||
timeout: int = 30,
|
||||
workers_per_resource: t.Literal["conserved", "round_robin"] | float | None = None,
|
||||
device: tuple[str, ...] | t.Literal["all"] | None = None,
|
||||
quantize: t.Literal["int8", "int4", "gptq"] | None = None,
|
||||
workers_per_resource: t.Literal['conserved', 'round_robin'] | float | None = None,
|
||||
device: tuple[str, ...] | t.Literal['all'] | None = None,
|
||||
quantize: t.Literal['int8', 'int4', 'gptq'] | None = None,
|
||||
bettertransformer: bool | None = None,
|
||||
runtime: t.Literal["ggml", "transformers"] = "transformers",
|
||||
runtime: t.Literal['ggml', 'transformers'] = 'transformers',
|
||||
adapter_map: dict[LiteralString, str | None] | None = None,
|
||||
framework: LiteralRuntime | None = None,
|
||||
additional_args: list[str] | None = None,
|
||||
@@ -79,20 +79,20 @@ def _start(
|
||||
quantize=quantize,
|
||||
runtime=runtime
|
||||
)
|
||||
os.environ[_ModelEnv.framework] = _ModelEnv["framework_value"]
|
||||
os.environ[_ModelEnv.framework] = _ModelEnv['framework_value']
|
||||
|
||||
args: list[str] = ["--runtime", runtime]
|
||||
if model_id: args.extend(["--model-id", model_id])
|
||||
if timeout: args.extend(["--server-timeout", str(timeout)])
|
||||
if workers_per_resource: args.extend(["--workers-per-resource", str(workers_per_resource) if not isinstance(workers_per_resource, str) else workers_per_resource])
|
||||
if device and not os.environ.get("CUDA_VISIBLE_DEVICES"): args.extend(["--device", ",".join(device)])
|
||||
args: list[str] = ['--runtime', runtime]
|
||||
if model_id: args.extend(['--model-id', model_id])
|
||||
if timeout: args.extend(['--server-timeout', str(timeout)])
|
||||
if workers_per_resource: args.extend(['--workers-per-resource', str(workers_per_resource) if not isinstance(workers_per_resource, str) else workers_per_resource])
|
||||
if device and not os.environ.get('CUDA_VISIBLE_DEVICES'): args.extend(['--device', ','.join(device)])
|
||||
if quantize and bettertransformer: raise OpenLLMException("'quantize' and 'bettertransformer' are currently mutually exclusive.")
|
||||
if quantize: args.extend(["--quantize", str(quantize)])
|
||||
elif bettertransformer: args.append("--bettertransformer")
|
||||
if cors: args.append("--cors")
|
||||
if adapter_map: args.extend(list(itertools.chain.from_iterable([["--adapter-id", f"{k}{':'+v if v else ''}"] for k, v in adapter_map.items()])))
|
||||
if quantize: args.extend(['--quantize', str(quantize)])
|
||||
elif bettertransformer: args.append('--bettertransformer')
|
||||
if cors: args.append('--cors')
|
||||
if adapter_map: args.extend(list(itertools.chain.from_iterable([['--adapter-id', f"{k}{':'+v if v else ''}"] for k, v in adapter_map.items()])))
|
||||
if additional_args: args.extend(additional_args)
|
||||
if __test__: args.append("--return-process")
|
||||
if __test__: args.append('--return-process')
|
||||
|
||||
return start_command_factory(start_command if not _serve_grpc else start_grpc_command, model_name, _context_settings=termui.CONTEXT_SETTINGS, _serve_grpc=_serve_grpc).main(
|
||||
args=args if len(args) > 0 else None, standalone_mode=False
|
||||
@@ -105,20 +105,20 @@ def _build(
|
||||
model_id: str | None = None,
|
||||
model_version: str | None = None,
|
||||
bento_version: str | None = None,
|
||||
quantize: t.Literal["int8", "int4", "gptq"] | None = None,
|
||||
quantize: t.Literal['int8', 'int4', 'gptq'] | None = None,
|
||||
bettertransformer: bool | None = None,
|
||||
adapter_map: dict[str, str | None] | None = None,
|
||||
build_ctx: str | None = None,
|
||||
enable_features: tuple[str, ...] | None = None,
|
||||
workers_per_resource: float | None = None,
|
||||
runtime: t.Literal["ggml", "transformers"] = "transformers",
|
||||
runtime: t.Literal['ggml', 'transformers'] = 'transformers',
|
||||
dockerfile_template: str | None = None,
|
||||
overwrite: bool = False,
|
||||
container_registry: LiteralContainerRegistry | None = None,
|
||||
container_version_strategy: LiteralContainerVersionStrategy | None = None,
|
||||
push: bool = False,
|
||||
containerize: bool = False,
|
||||
serialisation_format: t.Literal["safetensors", "legacy"] = "safetensors",
|
||||
serialisation_format: t.Literal['safetensors', 'legacy'] = 'safetensors',
|
||||
additional_args: list[str] | None = None,
|
||||
bento_store: BentoStore = Provide[BentoMLContainer.bento_store]
|
||||
) -> bentoml.Bento:
|
||||
@@ -171,34 +171,34 @@ def _build(
|
||||
Returns:
|
||||
``bentoml.Bento | str``: BentoLLM instance. This can be used to serve the LLM or can be pushed to BentoCloud.
|
||||
"""
|
||||
args: list[str] = [sys.executable, "-m", "openllm", "build", model_name, "--machine", "--runtime", runtime, "--serialisation", serialisation_format]
|
||||
args: list[str] = [sys.executable, '-m', 'openllm', 'build', model_name, '--machine', '--runtime', runtime, '--serialisation', serialisation_format]
|
||||
if quantize and bettertransformer: raise OpenLLMException("'quantize' and 'bettertransformer' are currently mutually exclusive.")
|
||||
if quantize: args.extend(["--quantize", quantize])
|
||||
if bettertransformer: args.append("--bettertransformer")
|
||||
if quantize: args.extend(['--quantize', quantize])
|
||||
if bettertransformer: args.append('--bettertransformer')
|
||||
if containerize and push: raise OpenLLMException("'containerize' and 'push' are currently mutually exclusive.")
|
||||
if push: args.extend(["--push"])
|
||||
if containerize: args.extend(["--containerize"])
|
||||
if model_id: args.extend(["--model-id", model_id])
|
||||
if build_ctx: args.extend(["--build-ctx", build_ctx])
|
||||
if enable_features: args.extend([f"--enable-features={f}" for f in enable_features])
|
||||
if workers_per_resource: args.extend(["--workers-per-resource", str(workers_per_resource)])
|
||||
if overwrite: args.append("--overwrite")
|
||||
if push: args.extend(['--push'])
|
||||
if containerize: args.extend(['--containerize'])
|
||||
if model_id: args.extend(['--model-id', model_id])
|
||||
if build_ctx: args.extend(['--build-ctx', build_ctx])
|
||||
if enable_features: args.extend([f'--enable-features={f}' for f in enable_features])
|
||||
if workers_per_resource: args.extend(['--workers-per-resource', str(workers_per_resource)])
|
||||
if overwrite: args.append('--overwrite')
|
||||
if adapter_map: args.extend([f"--adapter-id={k}{':'+v if v is not None else ''}" for k, v in adapter_map.items()])
|
||||
if model_version: args.extend(["--model-version", model_version])
|
||||
if bento_version: args.extend(["--bento-version", bento_version])
|
||||
if dockerfile_template: args.extend(["--dockerfile-template", dockerfile_template])
|
||||
if container_registry is None: container_registry = "ecr"
|
||||
if container_version_strategy is None: container_version_strategy = "release"
|
||||
args.extend(["--container-registry", container_registry, "--container-version-strategy", container_version_strategy])
|
||||
if model_version: args.extend(['--model-version', model_version])
|
||||
if bento_version: args.extend(['--bento-version', bento_version])
|
||||
if dockerfile_template: args.extend(['--dockerfile-template', dockerfile_template])
|
||||
if container_registry is None: container_registry = 'ecr'
|
||||
if container_version_strategy is None: container_version_strategy = 'release'
|
||||
args.extend(['--container-registry', container_registry, '--container-version-strategy', container_version_strategy])
|
||||
if additional_args: args.extend(additional_args)
|
||||
|
||||
try:
|
||||
output = subprocess.check_output(args, env=os.environ.copy(), cwd=build_ctx or os.getcwd())
|
||||
except subprocess.CalledProcessError as e:
|
||||
logger.error("Exception caught while building %s", model_name, exc_info=e)
|
||||
if e.stderr: raise OpenLLMException(e.stderr.decode("utf-8")) from None
|
||||
logger.error('Exception caught while building %s', model_name, exc_info=e)
|
||||
if e.stderr: raise OpenLLMException(e.stderr.decode('utf-8')) from None
|
||||
raise OpenLLMException(str(e)) from None
|
||||
matched = re.match(r"__tag__:([^:\n]+:[^:\n]+)$", output.decode("utf-8").strip())
|
||||
matched = re.match(r'__tag__:([^:\n]+:[^:\n]+)$', output.decode('utf-8').strip())
|
||||
if matched is None:
|
||||
raise ValueError(f"Failed to find tag from output: {output.decode('utf-8').strip()}\nNote: Output from 'openllm build' might not be correct. Please open an issue on GitHub.")
|
||||
return bentoml.get(matched.group(1), _bento_store=bento_store)
|
||||
@@ -208,10 +208,10 @@ def _import_model(
|
||||
*,
|
||||
model_id: str | None = None,
|
||||
model_version: str | None = None,
|
||||
runtime: t.Literal["ggml", "transformers"] = "transformers",
|
||||
implementation: LiteralRuntime = "pt",
|
||||
quantize: t.Literal["int8", "int4", "gptq"] | None = None,
|
||||
serialisation_format: t.Literal["legacy", "safetensors"] = "safetensors",
|
||||
runtime: t.Literal['ggml', 'transformers'] = 'transformers',
|
||||
implementation: LiteralRuntime = 'pt',
|
||||
quantize: t.Literal['int8', 'int4', 'gptq'] | None = None,
|
||||
serialisation_format: t.Literal['legacy', 'safetensors'] = 'safetensors',
|
||||
additional_args: t.Sequence[str] | None = None
|
||||
) -> bentoml.Model:
|
||||
"""Import a LLM into local store.
|
||||
@@ -245,15 +245,15 @@ def _import_model(
|
||||
``bentoml.Model``:BentoModel of the given LLM. This can be used to serve the LLM or can be pushed to BentoCloud.
|
||||
"""
|
from .entrypoint import import_command
args = [model_name, "--runtime", runtime, "--implementation", implementation, "--machine", "--serialisation", serialisation_format,]
args = [model_name, '--runtime', runtime, '--implementation', implementation, '--machine', '--serialisation', serialisation_format,]
if model_id is not None: args.append(model_id)
if model_version is not None: args.extend(["--model-version", str(model_version)])
if model_version is not None: args.extend(['--model-version', str(model_version)])
if additional_args is not None: args.extend(additional_args)
if quantize is not None: args.extend(["--quantize", quantize])
if quantize is not None: args.extend(['--quantize', quantize])
return import_command.main(args=args, standalone_mode=False)
def _list_models() -> dict[str, t.Any]:
"""List all available models within the local store."""
'''List all available models within the local store.'''
from .entrypoint import models_command
return models_command.main(args=["-o", "json", "--show-available", "--machine"], standalone_mode=False)
return models_command.main(args=['-o', 'json', '--show-available', '--machine'], standalone_mode=False)
start, start_grpc, build, import_model, list_models = openllm_core.utils.codegen.gen_sdk(_start, _serve_grpc=False), openllm_core.utils.codegen.gen_sdk(_start, _serve_grpc=True), openllm_core.utils.codegen.gen_sdk(_build), openllm_core.utils.codegen.gen_sdk(_import_model), openllm_core.utils.codegen.gen_sdk(_list_models)
__all__ = ["start", "start_grpc", "build", "import_model", "list_models"]
__all__ = ['start', 'start_grpc', 'build', 'import_model', 'list_models']

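A rough end-to-end sketch of the SDK surface exported above; keyword names follow the `_import_model`, `_build`, and `_list_models` signatures in this diff, and the model id is only an illustrative placeholder:

```python
import openllm

openllm.import_model('opt', model_id='facebook/opt-1.3b')  # pull the weights into the local store
bento = openllm.build('opt', serialisation_format='safetensors')  # package the model as a BentoLLM
print(openllm.list_models())  # inspect what is available locally
```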
||||
@@ -41,38 +41,38 @@ if t.TYPE_CHECKING:
|
||||
from openllm_core._schema import EmbeddingsOutput
|
||||
from openllm_core._typing_compat import LiteralContainerRegistry, LiteralContainerVersionStrategy
|
||||
else:
|
||||
torch = LazyLoader("torch", globals(), "torch")
|
||||
torch = LazyLoader('torch', globals(), 'torch')
|
||||
|
||||
P = ParamSpec("P")
|
||||
P = ParamSpec('P')
|
||||
logger = logging.getLogger(__name__)
|
||||
OPENLLM_FIGLET = """\
|
||||
OPENLLM_FIGLET = '''\
|
||||
██████╗ ██████╗ ███████╗███╗ ██╗██╗ ██╗ ███╗ ███╗
|
||||
██╔═══██╗██╔══██╗██╔════╝████╗ ██║██║ ██║ ████╗ ████║
|
||||
██║ ██║██████╔╝█████╗ ██╔██╗ ██║██║ ██║ ██╔████╔██║
|
||||
██║ ██║██╔═══╝ ██╔══╝ ██║╚██╗██║██║ ██║ ██║╚██╔╝██║
|
||||
╚██████╔╝██║ ███████╗██║ ╚████║███████╗███████╗██║ ╚═╝ ██║
|
||||
╚═════╝ ╚═╝ ╚══════╝╚═╝ ╚═══╝╚══════╝╚══════╝╚═╝ ╚═╝
|
||||
"""
|
||||
'''
|
||||
|
||||
ServeCommand = t.Literal["serve", "serve-grpc"]
|
||||
ServeCommand = t.Literal['serve', 'serve-grpc']
|
||||
@attr.define
|
||||
class GlobalOptions:
|
||||
cloud_context: str | None = attr.field(default=None)
|
||||
|
||||
def with_options(self, **attrs: t.Any) -> Self:
|
||||
return attr.evolve(self, **attrs)
|
||||
GrpType = t.TypeVar("GrpType", bound=click.Group)
|
||||
GrpType = t.TypeVar('GrpType', bound=click.Group)
|
||||
|
||||
_object_setattr = object.__setattr__
|
||||
|
||||
_EXT_FOLDER = os.path.abspath(os.path.join(os.path.dirname(__file__), "extension"))
|
||||
_EXT_FOLDER = os.path.abspath(os.path.join(os.path.dirname(__file__), 'extension'))
|
||||
class Extensions(click.MultiCommand):
|
||||
def list_commands(self, ctx: click.Context) -> list[str]:
|
||||
return sorted([filename[:-3] for filename in os.listdir(_EXT_FOLDER) if filename.endswith(".py") and not filename.startswith("__")])
|
||||
return sorted([filename[:-3] for filename in os.listdir(_EXT_FOLDER) if filename.endswith('.py') and not filename.startswith('__')])
|
||||
|
||||
def get_command(self, ctx: click.Context, cmd_name: str) -> click.Command | None:
|
||||
try:
|
||||
mod = __import__(f"openllm.cli.extension.{cmd_name}", None, None, ["cli"])
|
||||
mod = __import__(f'openllm.cli.extension.{cmd_name}', None, None, ['cli'])
|
||||
except ImportError:
|
||||
return None
|
||||
return mod.cli
|
||||
@@ -82,11 +82,11 @@ class OpenLLMCommandGroup(BentoMLCommandGroup):
|
||||
@staticmethod
|
||||
def common_params(f: t.Callable[P, t.Any]) -> t.Callable[[FC], FC]:
|
||||
# The following logics is similar to one of BentoMLCommandGroup
|
||||
@cog.optgroup.group(name="Global options", help="Shared globals options for all OpenLLM CLI.")
|
||||
@cog.optgroup.option("-q", "--quiet", envvar=QUIET_ENV_VAR, is_flag=True, default=False, help="Suppress all output.", show_envvar=True)
|
||||
@cog.optgroup.option("--debug", "--verbose", "debug", envvar=DEBUG_ENV_VAR, is_flag=True, default=False, help="Print out debug logs.", show_envvar=True)
|
||||
@cog.optgroup.option("--do-not-track", is_flag=True, default=False, envvar=analytics.OPENLLM_DO_NOT_TRACK, help="Do not send usage info", show_envvar=True)
|
||||
@cog.optgroup.option("--context", "cloud_context", envvar="BENTOCLOUD_CONTEXT", type=click.STRING, default=None, help="BentoCloud context name.", show_envvar=True)
|
||||
@cog.optgroup.group(name='Global options', help='Shared globals options for all OpenLLM CLI.')
|
||||
@cog.optgroup.option('-q', '--quiet', envvar=QUIET_ENV_VAR, is_flag=True, default=False, help='Suppress all output.', show_envvar=True)
|
||||
@cog.optgroup.option('--debug', '--verbose', 'debug', envvar=DEBUG_ENV_VAR, is_flag=True, default=False, help='Print out debug logs.', show_envvar=True)
|
||||
@cog.optgroup.option('--do-not-track', is_flag=True, default=False, envvar=analytics.OPENLLM_DO_NOT_TRACK, help='Do not send usage info', show_envvar=True)
|
||||
@cog.optgroup.option('--context', 'cloud_context', envvar='BENTOCLOUD_CONTEXT', type=click.STRING, default=None, help='BentoCloud context name.', show_envvar=True)
|
||||
@click.pass_context
|
||||
@functools.wraps(f)
|
||||
def wrapper(ctx: click.Context, quiet: bool, debug: bool, cloud_context: str | None, *args: P.args, **attrs: P.kwargs) -> t.Any:
|
||||
@@ -102,7 +102,7 @@ class OpenLLMCommandGroup(BentoMLCommandGroup):
|
||||
|
||||
@staticmethod
|
||||
def usage_tracking(func: t.Callable[P, t.Any], group: click.Group, **attrs: t.Any) -> t.Callable[Concatenate[bool, P], t.Any]:
|
||||
command_name = attrs.get("name", func.__name__)
|
||||
command_name = attrs.get('name', func.__name__)
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(do_not_track: bool, *args: P.args, **attrs: P.kwargs) -> t.Any:
|
||||
@@ -111,7 +111,7 @@ class OpenLLMCommandGroup(BentoMLCommandGroup):
|
||||
return func(*args, **attrs)
|
||||
start_time = time.time_ns()
|
||||
with analytics.set_bentoml_tracking():
|
||||
if group.name is None: raise ValueError("group.name should not be None")
|
||||
if group.name is None: raise ValueError('group.name should not be None')
|
||||
event = analytics.OpenllmCliEvent(cmd_group=group.name, cmd_name=command_name)
|
||||
try:
|
||||
return_value = func(*args, **attrs)
|
||||
@@ -131,22 +131,22 @@ class OpenLLMCommandGroup(BentoMLCommandGroup):
|
||||
|
||||
@staticmethod
|
||||
def exception_handling(func: t.Callable[P, t.Any], group: click.Group, **attrs: t.Any) -> t.Callable[P, t.Any]:
|
||||
command_name = attrs.get("name", func.__name__)
|
||||
command_name = attrs.get('name', func.__name__)
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args: P.args, **attrs: P.kwargs) -> t.Any:
|
||||
try:
|
||||
return func(*args, **attrs)
|
||||
except OpenLLMException as err:
|
||||
raise click.ClickException(click.style(f"[{group.name}] '{command_name}' failed: " + err.message, fg="red")) from err
|
||||
raise click.ClickException(click.style(f"[{group.name}] '{command_name}' failed: " + err.message, fg='red')) from err
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
return wrapper
|
||||
|
||||
def get_command(self, ctx: click.Context, cmd_name: str) -> click.Command | None:
|
||||
if cmd_name in t.cast("Extensions", extension_command).list_commands(ctx):
|
||||
return t.cast("Extensions", extension_command).get_command(ctx, cmd_name)
|
||||
if cmd_name in t.cast('Extensions', extension_command).list_commands(ctx):
|
||||
return t.cast('Extensions', extension_command).get_command(ctx, cmd_name)
|
||||
cmd_name = self.resolve_alias(cmd_name)
|
||||
if ctx.command.name in _start_mapping:
|
||||
try:
|
||||
@@ -158,36 +158,36 @@ class OpenLLMCommandGroup(BentoMLCommandGroup):
|
||||
raise click.ClickException(f"'openllm start {cmd_name}' is currently disabled for the time being. Please let us know if you need this feature by opening an issue on GitHub.")
|
||||
except bentoml.exceptions.NotFound:
|
||||
pass
|
||||
raise click.BadArgumentUsage(f"{cmd_name} is not a valid model identifier supported by OpenLLM.") from None
|
||||
raise click.BadArgumentUsage(f'{cmd_name} is not a valid model identifier supported by OpenLLM.') from None
|
||||
return super().get_command(ctx, cmd_name)
|
||||
|
||||
def list_commands(self, ctx: click.Context) -> list[str]:
|
||||
if ctx.command.name in {"start", "start-grpc"}: return list(CONFIG_MAPPING.keys())
|
||||
return super().list_commands(ctx) + t.cast("Extensions", extension_command).list_commands(ctx)
|
||||
if ctx.command.name in {'start', 'start-grpc'}: return list(CONFIG_MAPPING.keys())
|
||||
return super().list_commands(ctx) + t.cast('Extensions', extension_command).list_commands(ctx)
|
||||
|
||||
def command(self, *args: t.Any, **kwargs: t.Any) -> t.Callable[[t.Callable[..., t.Any]], click.Command]: # type: ignore[override] # XXX: fix decorator on BentoMLCommandGroup
|
||||
"""Override the default 'cli.command' with supports for aliases for given command, and it wraps the implementation with common parameters."""
|
||||
if "context_settings" not in kwargs: kwargs["context_settings"] = {}
|
||||
if "max_content_width" not in kwargs["context_settings"]: kwargs["context_settings"]["max_content_width"] = 120
|
||||
aliases = kwargs.pop("aliases", None)
|
||||
if 'context_settings' not in kwargs: kwargs['context_settings'] = {}
|
||||
if 'max_content_width' not in kwargs['context_settings']: kwargs['context_settings']['max_content_width'] = 120
|
||||
aliases = kwargs.pop('aliases', None)
|
||||
|
||||
def decorator(f: _AnyCallable) -> click.Command:
|
||||
name = f.__name__.lower()
|
||||
if name.endswith("_command"): name = name[:-8]
|
||||
name = name.replace("_", "-")
|
||||
kwargs.setdefault("help", inspect.getdoc(f))
|
||||
kwargs.setdefault("name", name)
|
||||
if name.endswith('_command'): name = name[:-8]
|
||||
name = name.replace('_', '-')
|
||||
kwargs.setdefault('help', inspect.getdoc(f))
|
||||
kwargs.setdefault('name', name)
|
||||
wrapped = self.exception_handling(self.usage_tracking(self.common_params(f), self, **kwargs), self, **kwargs)
|
||||
|
||||
# move common parameters to end of the parameters list
|
||||
_memo = getattr(wrapped, "__click_params__", None)
|
||||
if _memo is None: raise RuntimeError("Click command not register correctly.")
|
||||
_object_setattr(wrapped, "__click_params__", _memo[-self.NUMBER_OF_COMMON_PARAMS:] + _memo[:-self.NUMBER_OF_COMMON_PARAMS])
|
||||
_memo = getattr(wrapped, '__click_params__', None)
|
||||
if _memo is None: raise RuntimeError('Click command not register correctly.')
|
||||
_object_setattr(wrapped, '__click_params__', _memo[-self.NUMBER_OF_COMMON_PARAMS:] + _memo[:-self.NUMBER_OF_COMMON_PARAMS])
|
||||
# NOTE: we need to call super of super to avoid conflict with BentoMLCommandGroup command setup
|
||||
cmd = super(BentoMLCommandGroup, self).command(*args, **kwargs)(wrapped)
|
||||
# NOTE: add aliases to a given commands if it is specified.
|
||||
if aliases is not None:
|
||||
if not cmd.name: raise ValueError("name is required when aliases are available.")
|
||||
if not cmd.name: raise ValueError('name is required when aliases are available.')
|
||||
self._commands[cmd.name] = aliases
|
||||
self._aliases.update({alias: cmd.name for alias in aliases})
|
||||
return cmd
|
||||
@@ -195,11 +195,11 @@ class OpenLLMCommandGroup(BentoMLCommandGroup):
|
||||
return decorator
|
||||
|
||||
def format_commands(self, ctx: click.Context, formatter: click.HelpFormatter) -> None:
|
||||
"""Additional format methods that include extensions as well as the default cli command."""
|
||||
'''Additional format methods that include extensions as well as the default cli command.'''
|
||||
from gettext import gettext as _
|
||||
commands: list[tuple[str, click.Command]] = []
|
||||
extensions: list[tuple[str, click.Command]] = []
|
||||
_cached_extensions: list[str] = t.cast("Extensions", extension_command).list_commands(ctx)
|
||||
_cached_extensions: list[str] = t.cast('Extensions', extension_command).list_commands(ctx)
|
||||
for subcommand in self.list_commands(ctx):
|
||||
cmd = self.get_command(ctx, subcommand)
|
||||
if cmd is None or cmd.hidden: continue
|
||||
@@ -213,7 +213,7 @@ class OpenLLMCommandGroup(BentoMLCommandGroup):
|
||||
help = cmd.get_short_help_str(limit)
|
||||
rows.append((subcommand, help))
|
||||
if rows:
|
||||
with formatter.section(_("Commands")):
|
||||
with formatter.section(_('Commands')):
|
||||
formatter.write_dl(rows)
|
||||
if len(extensions):
|
||||
limit = formatter.width - 6 - max(len(cmd[0]) for cmd in extensions)
|
||||
@@ -222,14 +222,14 @@ class OpenLLMCommandGroup(BentoMLCommandGroup):
|
||||
help = cmd.get_short_help_str(limit)
|
||||
rows.append((inflection.dasherize(subcommand), help))
|
||||
if rows:
|
||||
with formatter.section(_("Extensions")):
|
||||
with formatter.section(_('Extensions')):
|
||||
formatter.write_dl(rows)
|
||||
@click.group(cls=OpenLLMCommandGroup, context_settings=termui.CONTEXT_SETTINGS, name="openllm")
|
||||
@click.group(cls=OpenLLMCommandGroup, context_settings=termui.CONTEXT_SETTINGS, name='openllm')
|
||||
@click.version_option(
|
||||
None, "--version", "-v", message=f"%(prog)s, %(version)s (compiled: {'yes' if openllm.COMPILED else 'no'})\nPython ({platform.python_implementation()}) {platform.python_version()}"
|
||||
None, '--version', '-v', message=f"%(prog)s, %(version)s (compiled: {'yes' if openllm.COMPILED else 'no'})\nPython ({platform.python_implementation()}) {platform.python_version()}"
|
||||
)
|
||||
def cli() -> None:
|
||||
"""\b
|
||||
'''\b
|
||||
██████╗ ██████╗ ███████╗███╗ ██╗██╗ ██╗ ███╗ ███╗
|
||||
██╔═══██╗██╔══██╗██╔════╝████╗ ██║██║ ██║ ████╗ ████║
|
||||
██║ ██║██████╔╝█████╗ ██╔██╗ ██║██║ ██║ ██╔████╔██║
|
||||
@@ -240,43 +240,43 @@ def cli() -> None:
|
||||
\b
|
||||
An open platform for operating large language models in production.
|
||||
Fine-tune, serve, deploy, and monitor any LLMs with ease.
|
||||
"""
|
||||
@cli.group(cls=OpenLLMCommandGroup, context_settings=termui.CONTEXT_SETTINGS, name="start", aliases=["start-http"])
|
||||
'''
|
||||
@cli.group(cls=OpenLLMCommandGroup, context_settings=termui.CONTEXT_SETTINGS, name='start', aliases=['start-http'])
|
||||
def start_command() -> None:
|
||||
"""Start any LLM as a REST server.
|
||||
'''Start any LLM as a REST server.
|
||||
|
||||
\b
|
||||
```bash
|
||||
$ openllm <start|start-http> <model_name> --<options> ...
|
||||
```
|
||||
"""
|
||||
@cli.group(cls=OpenLLMCommandGroup, context_settings=termui.CONTEXT_SETTINGS, name="start-grpc")
|
||||
'''
|
||||
@cli.group(cls=OpenLLMCommandGroup, context_settings=termui.CONTEXT_SETTINGS, name='start-grpc')
|
||||
def start_grpc_command() -> None:
|
||||
"""Start any LLM as a gRPC server.
|
||||
'''Start any LLM as a gRPC server.
|
||||
|
||||
\b
|
||||
```bash
|
||||
$ openllm start-grpc <model_name> --<options> ...
|
||||
```
|
||||
"""
|
||||
'''
|
||||
_start_mapping = {
|
||||
"start": {
|
||||
'start': {
|
||||
key: start_command_factory(start_command, key, _context_settings=termui.CONTEXT_SETTINGS) for key in CONFIG_MAPPING
|
||||
},
|
||||
"start-grpc": {
|
||||
'start-grpc': {
|
||||
key: start_command_factory(start_grpc_command, key, _context_settings=termui.CONTEXT_SETTINGS, _serve_grpc=True) for key in CONFIG_MAPPING
|
||||
}
|
||||
}
|
||||
@cli.command(name="import", aliases=["download"])
|
||||
@cli.command(name='import', aliases=['download'])
|
||||
@model_name_argument
|
||||
@click.argument("model_id", type=click.STRING, default=None, metavar="Optional[REMOTE_REPO/MODEL_ID | /path/to/local/model]", required=False)
|
||||
@click.argument("converter", envvar="CONVERTER", type=click.STRING, default=None, required=False, metavar=None)
|
||||
@click.argument('model_id', type=click.STRING, default=None, metavar='Optional[REMOTE_REPO/MODEL_ID | /path/to/local/model]', required=False)
|
||||
@click.argument('converter', envvar='CONVERTER', type=click.STRING, default=None, required=False, metavar=None)
|
||||
@model_version_option
|
||||
@click.option("--runtime", type=click.Choice(["ggml", "transformers"]), default="transformers", help="The runtime to use for the given model. Default is transformers.")
|
||||
@click.option('--runtime', type=click.Choice(['ggml', 'transformers']), default='transformers', help='The runtime to use for the given model. Default is transformers.')
|
||||
@output_option
|
||||
@quantize_option
|
||||
@machine_option
|
||||
@click.option("--implementation", type=click.Choice(["pt", "tf", "flax", "vllm"]), default=None, help="The implementation for saving this LLM.")
|
||||
@click.option('--implementation', type=click.Choice(['pt', 'tf', 'flax', 'vllm']), default=None, help='The implementation for saving this LLM.')
|
||||
@serialisation_option
|
||||
def import_command(
|
||||
model_name: str,
|
||||
@@ -284,11 +284,11 @@ def import_command(
|
||||
converter: str | None,
|
||||
model_version: str | None,
|
||||
output: LiteralOutput,
|
||||
runtime: t.Literal["ggml", "transformers"],
|
||||
runtime: t.Literal['ggml', 'transformers'],
|
||||
machine: bool,
|
||||
implementation: LiteralRuntime | None,
|
||||
quantize: t.Literal["int8", "int4", "gptq"] | None,
|
||||
serialisation_format: t.Literal["safetensors", "legacy"],
|
||||
quantize: t.Literal['int8', 'int4', 'gptq'] | None,
|
||||
serialisation_format: t.Literal['safetensors', 'legacy'],
|
||||
) -> bentoml.Model:
|
||||
"""Setup LLM interactively.
|
||||
|
||||
@@ -344,73 +344,73 @@ def import_command(
|
||||
"""
|
||||
llm_config = AutoConfig.for_model(model_name)
|
||||
env = EnvVarMixin(model_name, llm_config.default_implementation(), model_id=model_id, runtime=runtime, quantize=quantize)
|
||||
impl: LiteralRuntime = first_not_none(implementation, default=env["framework_value"])
|
||||
impl: LiteralRuntime = first_not_none(implementation, default=env['framework_value'])
|
||||
llm = infer_auto_class(impl).for_model(
|
||||
model_name, model_id=env["model_id_value"], llm_config=llm_config, model_version=model_version, ensure_available=False, serialisation=serialisation_format
|
||||
model_name, model_id=env['model_id_value'], llm_config=llm_config, model_version=model_version, ensure_available=False, serialisation=serialisation_format
|
||||
)
|
||||
_previously_saved = False
|
||||
try:
|
||||
_ref = serialisation.get(llm)
|
||||
_previously_saved = True
|
||||
except bentoml.exceptions.NotFound:
|
||||
if not machine and output == "pretty":
|
||||
if not machine and output == 'pretty':
|
||||
msg = f"'{model_name}' {'with model_id='+ model_id if model_id is not None else ''} does not exists in local store for implementation {llm.__llm_implementation__}. Saving to BENTOML_HOME{' (path=' + os.environ.get('BENTOML_HOME', BentoMLContainer.bentoml_home.get()) + ')' if get_debug_mode() else ''}..."
|
||||
termui.echo(msg, fg="yellow", nl=True)
|
||||
termui.echo(msg, fg='yellow', nl=True)
|
||||
_ref = serialisation.get(llm, auto_import=True)
|
||||
if impl == "pt" and is_torch_available() and torch.cuda.is_available(): torch.cuda.empty_cache()
|
||||
if impl == 'pt' and is_torch_available() and torch.cuda.is_available(): torch.cuda.empty_cache()
|
||||
if machine: return _ref
|
||||
elif output == "pretty":
|
||||
if _previously_saved: termui.echo(f"{model_name} with 'model_id={model_id}' is already setup for framework '{impl}': {_ref.tag!s}", nl=True, fg="yellow")
|
||||
else: termui.echo(f"Saved model: {_ref.tag}")
|
||||
elif output == "json": termui.echo(orjson.dumps({"previously_setup": _previously_saved, "framework": impl, "tag": str(_ref.tag)}, option=orjson.OPT_INDENT_2).decode())
|
||||
elif output == 'pretty':
|
||||
if _previously_saved: termui.echo(f"{model_name} with 'model_id={model_id}' is already setup for framework '{impl}': {_ref.tag!s}", nl=True, fg='yellow')
|
||||
else: termui.echo(f'Saved model: {_ref.tag}')
|
||||
elif output == 'json': termui.echo(orjson.dumps({'previously_setup': _previously_saved, 'framework': impl, 'tag': str(_ref.tag)}, option=orjson.OPT_INDENT_2).decode())
|
||||
else: termui.echo(_ref.tag)
|
||||
return _ref
|
||||
@cli.command(context_settings={"token_normalize_func": inflection.underscore})
|
||||
@cli.command(context_settings={'token_normalize_func': inflection.underscore})
|
||||
@model_name_argument
|
||||
@model_id_option
|
||||
@output_option
|
||||
@machine_option
|
||||
@click.option("--bento-version", type=str, default=None, help="Optional bento version for this BentoLLM. Default is the the model revision.")
|
||||
@click.option("--overwrite", is_flag=True, help="Overwrite existing Bento for given LLM if it already exists.")
|
||||
@click.option('--bento-version', type=str, default=None, help='Optional bento version for this BentoLLM. Default is the the model revision.')
|
||||
@click.option('--overwrite', is_flag=True, help='Overwrite existing Bento for given LLM if it already exists.')
|
||||
@workers_per_resource_option(factory=click, build=True)
|
||||
@click.option("--device", type=dantic.CUDA, multiple=True, envvar="CUDA_VISIBLE_DEVICES", callback=parse_device_callback, help="Set the device", show_envvar=True)
|
||||
@cog.optgroup.group(cls=cog.MutuallyExclusiveOptionGroup, name="Optimisation options")
|
||||
@click.option('--device', type=dantic.CUDA, multiple=True, envvar='CUDA_VISIBLE_DEVICES', callback=parse_device_callback, help='Set the device', show_envvar=True)
|
||||
@cog.optgroup.group(cls=cog.MutuallyExclusiveOptionGroup, name='Optimisation options')
@quantize_option(factory=cog.optgroup, build=True)
@bettertransformer_option(factory=cog.optgroup)
@click.option("--runtime", type=click.Choice(["ggml", "transformers"]), default="transformers", help="The runtime to use for the given model. Default is transformers.")
@click.option('--runtime', type=click.Choice(['ggml', 'transformers']), default='transformers', help='The runtime to use for the given model. Default is transformers.')
@click.option(
"--enable-features",
'--enable-features',
multiple=True,
nargs=1,
metavar="FEATURE[,FEATURE]",
help="Enable additional features for building this LLM Bento. Available: {}".format(", ".join(OPTIONAL_DEPENDENCIES))
metavar='FEATURE[,FEATURE]',
help='Enable additional features for building this LLM Bento. Available: {}'.format(', '.join(OPTIONAL_DEPENDENCIES))
)
@click.option(
"--adapter-id",
'--adapter-id',
default=None,
multiple=True,
metavar="[PATH | [remote/][adapter_name:]adapter_id][, ...]",
metavar='[PATH | [remote/][adapter_name:]adapter_id][, ...]',
help="Optional adapters id to be included within the Bento. Note that if you are using relative path, '--build-ctx' must be passed."
)
@click.option("--build-ctx", help="Build context. This is required if --adapter-id uses relative path", default=None)
@click.option('--build-ctx', help='Build context. This is required if --adapter-id uses relative path', default=None)
@model_version_option
@click.option("--dockerfile-template", default=None, type=click.File(), help="Optional custom dockerfile template to be used with this BentoLLM.")
@click.option('--dockerfile-template', default=None, type=click.File(), help='Optional custom dockerfile template to be used with this BentoLLM.')
@serialisation_option
@container_registry_option
@click.option(
"--container-version-strategy", type=click.Choice(["release", "latest", "nightly"]), default="release", help="Default container version strategy for the image from '--container-registry'"
'--container-version-strategy', type=click.Choice(['release', 'latest', 'nightly']), default='release', help="Default container version strategy for the image from '--container-registry'"
)
@fast_option
@cog.optgroup.group(cls=cog.MutuallyExclusiveOptionGroup, name="Utilities options")
@cog.optgroup.group(cls=cog.MutuallyExclusiveOptionGroup, name='Utilities options')
@cog.optgroup.option(
"--containerize",
'--containerize',
default=False,
is_flag=True,
type=click.BOOL,
help="Whether to containerize the Bento after building. '--containerize' is the shortcut of 'openllm build && bentoml containerize'."
)
@cog.optgroup.option("--push", default=False, is_flag=True, type=click.BOOL, help="Whether to push the result bento to BentoCloud. Make sure to login with 'bentoml cloud login' first.")
@click.option("--force-push", default=False, is_flag=True, type=click.BOOL, help="Whether to force push.")
@cog.optgroup.option('--push', default=False, is_flag=True, type=click.BOOL, help="Whether to push the result bento to BentoCloud. Make sure to login with 'bentoml cloud login' first.")
@click.option('--force-push', default=False, is_flag=True, type=click.BOOL, help='Whether to force push.')
@click.pass_context
def build_command(
ctx: click.Context,
@@ -420,8 +420,8 @@ def build_command(
bento_version: str | None,
overwrite: bool,
output: LiteralOutput,
runtime: t.Literal["ggml", "transformers"],
quantize: t.Literal["int8", "int4", "gptq"] | None,
runtime: t.Literal['ggml', 'transformers'],
quantize: t.Literal['int8', 'int4', 'gptq'] | None,
enable_features: tuple[str, ...] | None,
bettertransformer: bool | None,
workers_per_resource: float | None,
@@ -433,14 +433,14 @@ def build_command(
dockerfile_template: t.TextIO | None,
containerize: bool,
push: bool,
serialisation_format: t.Literal["safetensors", "legacy"],
serialisation_format: t.Literal['safetensors', 'legacy'],
fast: bool,
container_registry: LiteralContainerRegistry,
container_version_strategy: LiteralContainerVersionStrategy,
force_push: bool,
**attrs: t.Any,
) -> bentoml.Bento:
"""Package a given models into a Bento.
'''Package a given models into a Bento.

\b
```bash
@@ -456,9 +456,9 @@ def build_command(
> [!IMPORTANT]
> To build the bento with compiled OpenLLM, make sure to prepend HATCH_BUILD_HOOKS_ENABLE=1. Make sure that the deployment
> target also use the same Python version and architecture as build machine.
"""
if machine: output = "porcelain"
if enable_features: enable_features = tuple(itertools.chain.from_iterable((s.split(",") for s in enable_features)))
'''
if machine: output = 'porcelain'
if enable_features: enable_features = tuple(itertools.chain.from_iterable((s.split(',') for s in enable_features)))

_previously_built = False

@@ -468,32 +468,32 @@ def build_command(
|
||||
# NOTE: We set this environment variable so that our service.py logic won't raise RuntimeError
|
||||
# during build. This is a current limitation of bentoml build where we actually import the service.py into sys.path
|
||||
try:
|
||||
os.environ.update({"OPENLLM_MODEL": inflection.underscore(model_name), env.runtime: str(env["runtime_value"]), "OPENLLM_SERIALIZATION": serialisation_format})
|
||||
if env["model_id_value"]: os.environ[env.model_id] = str(env["model_id_value"])
|
||||
if env["quantize_value"]: os.environ[env.quantize] = str(env["quantize_value"])
|
||||
os.environ[env.bettertransformer] = str(env["bettertransformer_value"])
|
||||
os.environ.update({'OPENLLM_MODEL': inflection.underscore(model_name), env.runtime: str(env['runtime_value']), 'OPENLLM_SERIALIZATION': serialisation_format})
|
||||
if env['model_id_value']: os.environ[env.model_id] = str(env['model_id_value'])
|
||||
if env['quantize_value']: os.environ[env.quantize] = str(env['quantize_value'])
|
||||
os.environ[env.bettertransformer] = str(env['bettertransformer_value'])
|
||||
|
||||
llm = infer_auto_class(env["framework_value"]).for_model(
|
||||
model_name, model_id=env["model_id_value"], llm_config=llm_config, ensure_available=not fast, model_version=model_version, serialisation=serialisation_format, **attrs
|
||||
llm = infer_auto_class(env['framework_value']).for_model(
|
||||
model_name, model_id=env['model_id_value'], llm_config=llm_config, ensure_available=not fast, model_version=model_version, serialisation=serialisation_format, **attrs
|
||||
)
|
||||
|
||||
labels = dict(llm.identifying_params)
|
||||
labels.update({"_type": llm.llm_type, "_framework": env["framework_value"]})
|
||||
workers_per_resource = first_not_none(workers_per_resource, default=llm_config["workers_per_resource"])
|
||||
labels.update({'_type': llm.llm_type, '_framework': env['framework_value']})
|
||||
workers_per_resource = first_not_none(workers_per_resource, default=llm_config['workers_per_resource'])
|
||||
|
||||
with fs.open_fs(f"temp://llm_{llm_config['model_name']}") as llm_fs:
|
||||
dockerfile_template_path = None
|
||||
if dockerfile_template:
|
||||
with dockerfile_template:
|
||||
llm_fs.writetext("Dockerfile.template", dockerfile_template.read())
|
||||
dockerfile_template_path = llm_fs.getsyspath("/Dockerfile.template")
|
||||
llm_fs.writetext('Dockerfile.template', dockerfile_template.read())
|
||||
dockerfile_template_path = llm_fs.getsyspath('/Dockerfile.template')
|
||||
|
||||
adapter_map: dict[str, str | None] | None = None
if adapter_id:
if not build_ctx: ctx.fail("'build_ctx' is required when '--adapter-id' is passed.")
adapter_map = {}
for v in adapter_id:
_adapter_id, *adapter_name = v.rsplit(":", maxsplit=1)
_adapter_id, *adapter_name = v.rsplit(':', maxsplit=1)
name = adapter_name[0] if len(adapter_name) > 0 else None
try:
resolve_user_filepath(_adapter_id, build_ctx)
@@ -508,16 +508,16 @@ def build_command(
# that edge case.
except FileNotFoundError:
adapter_map[_adapter_id] = name
os.environ["OPENLLM_ADAPTER_MAP"] = orjson.dumps(adapter_map).decode()
os.environ['OPENLLM_ADAPTER_MAP'] = orjson.dumps(adapter_map).decode()

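A minimal sketch (not part of this commit) of how the `--adapter-id` values handled above are split: everything before the last colon is kept as the adapter id or local path, and the part after it becomes the adapter name. The helper name and sample values are hypothetical.

```python
# Sketch of the rsplit-based '--adapter-id' parsing above; names and values are illustrative only.
def split_adapter_id(value: str) -> tuple[str, str | None]:
  adapter_id, *adapter_name = value.rsplit(':', maxsplit=1)
  return adapter_id, (adapter_name[0] if adapter_name else None)

# './adapters/sql-lora'      -> ('./adapters/sql-lora', None)
# './adapters/sql-lora:sql'  -> ('./adapters/sql-lora', 'sql')
```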
_bento_version = first_not_none(bento_version, default=llm.tag.version)
|
||||
bento_tag = bentoml.Tag.from_taglike(f"{llm.llm_type}-service:{_bento_version}".lower().strip())
|
||||
bento_tag = bentoml.Tag.from_taglike(f'{llm.llm_type}-service:{_bento_version}'.lower().strip())
|
||||
try:
|
||||
bento = bentoml.get(bento_tag)
|
||||
if overwrite:
|
||||
if output == "pretty": termui.echo(f"Overwriting existing Bento {bento_tag}", fg="yellow")
|
||||
if output == 'pretty': termui.echo(f'Overwriting existing Bento {bento_tag}', fg='yellow')
|
||||
bentoml.delete(bento_tag)
|
||||
raise bentoml.exceptions.NotFound(f"Rebuilding existing Bento {bento_tag}") from None
|
||||
raise bentoml.exceptions.NotFound(f'Rebuilding existing Bento {bento_tag}') from None
|
||||
_previously_built = True
|
||||
except bentoml.exceptions.NotFound:
|
||||
bento = bundle.create_bento(
|
||||
@@ -537,38 +537,38 @@ def build_command(
|
||||
except Exception as err:
|
||||
raise err from None
|
||||
|
||||
if machine: termui.echo(f"__tag__:{bento.tag}", fg="white")
|
||||
elif output == "pretty":
|
||||
if machine: termui.echo(f'__tag__:{bento.tag}', fg='white')
|
||||
elif output == 'pretty':
|
||||
if not get_quiet_mode() and (not push or not containerize):
|
||||
termui.echo("\n" + OPENLLM_FIGLET, fg="white")
|
||||
if not _previously_built: termui.echo(f"Successfully built {bento}.", fg="green")
|
||||
elif not overwrite: termui.echo(f"'{model_name}' already has a Bento built [{bento}]. To overwrite it pass '--overwrite'.", fg="yellow")
|
||||
termui.echo('\n' + OPENLLM_FIGLET, fg='white')
|
||||
if not _previously_built: termui.echo(f'Successfully built {bento}.', fg='green')
|
||||
elif not overwrite: termui.echo(f"'{model_name}' already has a Bento built [{bento}]. To overwrite it pass '--overwrite'.", fg='yellow')
|
||||
termui.echo(
|
||||
"📖 Next steps:\n\n" + f"* Push to BentoCloud with 'bentoml push':\n\t$ bentoml push {bento.tag}\n\n" +
|
||||
'📖 Next steps:\n\n' + f"* Push to BentoCloud with 'bentoml push':\n\t$ bentoml push {bento.tag}\n\n" +
|
||||
f"* Containerize your Bento with 'bentoml containerize':\n\t$ bentoml containerize {bento.tag} --opt progress=plain\n\n" +
|
||||
"\tTip: To enable additional BentoML features for 'containerize', use '--enable-features=FEATURE[,FEATURE]' [see 'bentoml containerize -h' for more advanced usage]\n",
|
||||
fg="blue",
|
||||
fg='blue',
|
||||
)
|
||||
elif output == "json":
|
||||
elif output == 'json':
|
||||
termui.echo(orjson.dumps(bento.info.to_dict(), option=orjson.OPT_INDENT_2).decode())
|
||||
else:
|
||||
termui.echo(bento.tag)
|
||||
|
||||
if push: BentoMLContainer.bentocloud_client.get().push_bento(bento, context=t.cast(GlobalOptions, ctx.obj).cloud_context, force=force_push)
|
||||
elif containerize:
|
||||
backend = t.cast("DefaultBuilder", os.environ.get("BENTOML_CONTAINERIZE_BACKEND", "docker"))
|
||||
backend = t.cast('DefaultBuilder', os.environ.get('BENTOML_CONTAINERIZE_BACKEND', 'docker'))
|
||||
try:
|
||||
bentoml.container.health(backend)
|
||||
except subprocess.CalledProcessError:
|
||||
raise OpenLLMException(f"Failed to use backend {backend}") from None
|
||||
raise OpenLLMException(f'Failed to use backend {backend}') from None
|
||||
try:
|
||||
bentoml.container.build(bento.tag, backend=backend, features=("grpc", "io"))
|
||||
bentoml.container.build(bento.tag, backend=backend, features=('grpc', 'io'))
|
||||
except Exception as err:
|
||||
raise OpenLLMException(f"Exception caught while containerizing '{bento.tag!s}':\n{err}") from err
|
||||
return bento
|
||||
@cli.command()
|
||||
@output_option
|
||||
@click.option("--show-available", is_flag=True, default=False, help="Show available models in local store (mutually exclusive with '-o porcelain').")
|
||||
@click.option('--show-available', is_flag=True, default=False, help="Show available models in local store (mutually exclusive with '-o porcelain').")
|
||||
@machine_option
|
||||
@click.pass_context
|
||||
def models_command(ctx: click.Context, output: LiteralOutput, show_available: bool, machine: bool) -> DictStrAny | None:
|
||||
@@ -585,30 +585,30 @@ def models_command(ctx: click.Context, output: LiteralOutput, show_available: bo
|
||||
from .._llm import normalise_model_name
|
||||
|
||||
models = tuple(inflection.dasherize(key) for key in CONFIG_MAPPING.keys())
|
||||
if output == "porcelain":
|
||||
if show_available: raise click.BadOptionUsage("--show-available", "Cannot use '--show-available' with '-o porcelain' (mutually exclusive).")
|
||||
termui.echo("\n".join(models), fg="white")
|
||||
if output == 'porcelain':
|
||||
if show_available: raise click.BadOptionUsage('--show-available', "Cannot use '--show-available' with '-o porcelain' (mutually exclusive).")
|
||||
termui.echo('\n'.join(models), fg='white')
|
||||
else:
|
||||
failed_initialized: list[tuple[str, Exception]] = []
|
||||
|
||||
json_data: dict[str, dict[t.Literal["architecture", "model_id", "url", "installation", "cpu", "gpu", "runtime_impl"], t.Any] | t.Any] = {}
|
||||
json_data: dict[str, dict[t.Literal['architecture', 'model_id', 'url', 'installation', 'cpu', 'gpu', 'runtime_impl'], t.Any] | t.Any] = {}
|
||||
converted: list[str] = []
|
||||
for m in models:
|
||||
config = AutoConfig.for_model(m)
|
||||
runtime_impl: tuple[str, ...] = ()
|
||||
if config["model_name"] in MODEL_MAPPING_NAMES: runtime_impl += ("pt",)
|
||||
if config["model_name"] in MODEL_FLAX_MAPPING_NAMES: runtime_impl += ("flax",)
|
||||
if config["model_name"] in MODEL_TF_MAPPING_NAMES: runtime_impl += ("tf",)
|
||||
if config["model_name"] in MODEL_VLLM_MAPPING_NAMES: runtime_impl += ("vllm",)
|
||||
if config['model_name'] in MODEL_MAPPING_NAMES: runtime_impl += ('pt',)
|
||||
if config['model_name'] in MODEL_FLAX_MAPPING_NAMES: runtime_impl += ('flax',)
|
||||
if config['model_name'] in MODEL_TF_MAPPING_NAMES: runtime_impl += ('tf',)
|
||||
if config['model_name'] in MODEL_VLLM_MAPPING_NAMES: runtime_impl += ('vllm',)
|
||||
json_data[m] = {
|
||||
"architecture": config["architecture"],
|
||||
"model_id": config["model_ids"],
|
||||
"cpu": not config["requires_gpu"],
|
||||
"gpu": True,
|
||||
"runtime_impl": runtime_impl,
|
||||
"installation": f'"openllm[{m}]"' if m in OPTIONAL_DEPENDENCIES or config["requirements"] else "openllm",
|
||||
'architecture': config['architecture'],
|
||||
'model_id': config['model_ids'],
|
||||
'cpu': not config['requires_gpu'],
|
||||
'gpu': True,
|
||||
'runtime_impl': runtime_impl,
|
||||
'installation': f'"openllm[{m}]"' if m in OPTIONAL_DEPENDENCIES or config['requirements'] else 'openllm',
|
||||
}
|
||||
converted.extend([normalise_model_name(i) for i in config["model_ids"]])
|
||||
converted.extend([normalise_model_name(i) for i in config['model_ids']])
|
||||
if DEBUG:
|
||||
try:
|
||||
AutoLLM.for_model(m, llm_config=config)
|
||||
@@ -617,7 +617,7 @@ def models_command(ctx: click.Context, output: LiteralOutput, show_available: bo
|
||||
|
||||
ids_in_local_store = {
|
||||
k: [
|
||||
i for i in bentoml.models.list() if "framework" in i.info.labels and i.info.labels["framework"] == "openllm" and "model_name" in i.info.labels and i.info.labels["model_name"] == k
|
||||
i for i in bentoml.models.list() if 'framework' in i.info.labels and i.info.labels['framework'] == 'openllm' and 'model_name' in i.info.labels and i.info.labels['model_name'] == k
|
||||
] for k in json_data.keys()
|
||||
}
|
||||
ids_in_local_store = {k: v for k, v in ids_in_local_store.items() if v}
|
||||
@@ -626,74 +626,74 @@ def models_command(ctx: click.Context, output: LiteralOutput, show_available: bo
|
||||
local_models = {k: [str(i.tag) for i in val] for k, val in ids_in_local_store.items()}
|
||||
|
||||
if machine:
|
||||
if show_available: json_data["local"] = local_models
|
||||
if show_available: json_data['local'] = local_models
|
||||
return json_data
|
||||
elif output == "pretty":
|
||||
elif output == 'pretty':
|
||||
import tabulate
|
||||
|
||||
tabulate.PRESERVE_WHITESPACE = True
|
||||
# llm, architecture, url, model_id, installation, cpu, gpu, runtime_impl
|
||||
data: list[str | tuple[str, str, list[str], str, LiteralString, LiteralString, tuple[LiteralRuntime, ...]]] = []
|
||||
for m, v in json_data.items():
|
||||
data.extend([(m, v["architecture"], v["model_id"], v["installation"], "❌" if not v["cpu"] else "✅", "✅", v["runtime_impl"],)])
|
||||
data.extend([(m, v['architecture'], v['model_id'], v['installation'], '❌' if not v['cpu'] else '✅', '✅', v['runtime_impl'],)])
|
||||
column_widths = [
|
||||
int(termui.COLUMNS / 12), int(termui.COLUMNS / 6), int(termui.COLUMNS / 4), int(termui.COLUMNS / 12), int(termui.COLUMNS / 12), int(termui.COLUMNS / 12), int(termui.COLUMNS / 4),
|
||||
]
|
||||
|
||||
if len(data) == 0 and len(failed_initialized) > 0:
|
||||
termui.echo("Exception found while parsing models:\n", fg="yellow")
|
||||
termui.echo('Exception found while parsing models:\n', fg='yellow')
|
||||
for m, err in failed_initialized:
|
||||
termui.echo(f"- {m}: ", fg="yellow", nl=False)
|
||||
termui.echo(traceback.print_exception(None, err, None, limit=5), fg="red") # type: ignore[func-returns-value]
|
||||
termui.echo(f'- {m}: ', fg='yellow', nl=False)
|
||||
termui.echo(traceback.print_exception(None, err, None, limit=5), fg='red') # type: ignore[func-returns-value]
|
||||
sys.exit(1)
|
||||
|
||||
table = tabulate.tabulate(data, tablefmt="fancy_grid", headers=["LLM", "Architecture", "Models Id", "pip install", "CPU", "GPU", "Runtime"], maxcolwidths=column_widths)
|
||||
termui.echo(table, fg="white")
|
||||
table = tabulate.tabulate(data, tablefmt='fancy_grid', headers=['LLM', 'Architecture', 'Models Id', 'pip install', 'CPU', 'GPU', 'Runtime'], maxcolwidths=column_widths)
|
||||
termui.echo(table, fg='white')
|
||||
|
||||
if DEBUG and len(failed_initialized) > 0:
|
||||
termui.echo("\nThe following models are supported but failed to initialize:\n")
|
||||
termui.echo('\nThe following models are supported but failed to initialize:\n')
|
||||
for m, err in failed_initialized:
|
||||
termui.echo(f"- {m}: ", fg="blue", nl=False)
|
||||
termui.echo(err, fg="red")
|
||||
termui.echo(f'- {m}: ', fg='blue', nl=False)
|
||||
termui.echo(err, fg='red')
|
||||
|
||||
if show_available:
|
||||
if len(ids_in_local_store) == 0:
|
||||
termui.echo("No models available locally.")
|
||||
termui.echo('No models available locally.')
|
||||
ctx.exit(0)
|
||||
termui.echo("The following are available in local store:", fg="magenta")
|
||||
termui.echo(orjson.dumps(local_models, option=orjson.OPT_INDENT_2).decode(), fg="white")
|
||||
termui.echo('The following are available in local store:', fg='magenta')
|
||||
termui.echo(orjson.dumps(local_models, option=orjson.OPT_INDENT_2).decode(), fg='white')
|
||||
else:
|
||||
if show_available: json_data["local"] = local_models
|
||||
termui.echo(orjson.dumps(json_data, option=orjson.OPT_INDENT_2,).decode(), fg="white")
|
||||
if show_available: json_data['local'] = local_models
|
||||
termui.echo(orjson.dumps(json_data, option=orjson.OPT_INDENT_2,).decode(), fg='white')
|
||||
ctx.exit(0)
|
||||
@cli.command()
|
||||
@model_name_argument(required=False)
|
||||
@click.option("-y", "--yes", "--assume-yes", is_flag=True, help="Skip confirmation when deleting a specific model")
|
||||
@click.option("--include-bentos/--no-include-bentos", is_flag=True, default=False, help="Whether to also include pruning bentos.")
|
||||
@click.option('-y', '--yes', '--assume-yes', is_flag=True, help='Skip confirmation when deleting a specific model')
|
||||
@click.option('--include-bentos/--no-include-bentos', is_flag=True, default=False, help='Whether to also include pruning bentos.')
|
||||
@inject
|
||||
def prune_command(
|
||||
model_name: str | None, yes: bool, include_bentos: bool, model_store: ModelStore = Provide[BentoMLContainer.model_store], bento_store: BentoStore = Provide[BentoMLContainer.bento_store]
|
||||
) -> None:
|
||||
"""Remove all saved models, (and optionally bentos) built with OpenLLM locally.
|
||||
'''Remove all saved models, (and optionally bentos) built with OpenLLM locally.
|
||||
|
||||
\b
|
||||
If a model type is passed, then only prune models for that given model type.
|
||||
"""
|
||||
'''
|
||||
available: list[tuple[bentoml.Model | bentoml.Bento,
|
||||
ModelStore | BentoStore]] = [(m, model_store) for m in bentoml.models.list() if "framework" in m.info.labels and m.info.labels["framework"] == "openllm"]
|
||||
if model_name is not None: available = [(m, store) for m, store in available if "model_name" in m.info.labels and m.info.labels["model_name"] == inflection.underscore(model_name)]
|
||||
ModelStore | BentoStore]] = [(m, model_store) for m in bentoml.models.list() if 'framework' in m.info.labels and m.info.labels['framework'] == 'openllm']
|
||||
if model_name is not None: available = [(m, store) for m, store in available if 'model_name' in m.info.labels and m.info.labels['model_name'] == inflection.underscore(model_name)]
|
||||
if include_bentos:
|
||||
if model_name is not None:
|
||||
available += [(b, bento_store) for b in bentoml.bentos.list() if "start_name" in b.info.labels and b.info.labels["start_name"] == inflection.underscore(model_name)]
|
||||
available += [(b, bento_store) for b in bentoml.bentos.list() if 'start_name' in b.info.labels and b.info.labels['start_name'] == inflection.underscore(model_name)]
|
||||
else:
|
||||
available += [(b, bento_store) for b in bentoml.bentos.list() if "_type" in b.info.labels and "_framework" in b.info.labels]
|
||||
available += [(b, bento_store) for b in bentoml.bentos.list() if '_type' in b.info.labels and '_framework' in b.info.labels]
|
||||
|
||||
for store_item, store in available:
|
||||
if yes: delete_confirmed = True
|
||||
else: delete_confirmed = click.confirm(f"delete {'model' if isinstance(store, ModelStore) else 'bento'} {store_item.tag}?")
|
||||
if delete_confirmed:
|
||||
store.delete(store_item.tag)
|
||||
termui.echo(f"{store_item} deleted from {'model' if isinstance(store, ModelStore) else 'bento'} store.", fg="yellow")
|
||||
termui.echo(f"{store_item} deleted from {'model' if isinstance(store, ModelStore) else 'bento'} store.", fg='yellow')
|
||||
def parsing_instruction_callback(ctx: click.Context, param: click.Parameter, value: list[str] | str | None) -> tuple[str, bool | str] | list[str] | str | None:
if value is None:
return value
@@ -702,40 +702,40 @@ def parsing_instruction_callback(ctx: click.Context, param: click.Parameter, val
# we only parse --text foo bar -> --text foo and omit bar
value = value[-1]

key, *values = value.split("=")
if not key.startswith("--"):
raise click.BadParameter(f"Invalid option format: {value}")
key, *values = value.split('=')
if not key.startswith('--'):
raise click.BadParameter(f'Invalid option format: {value}')
key = key[2:]
if len(values) == 0:
return key, True
elif len(values) == 1:
return key, values[0]
else:
raise click.BadParameter(f"Invalid option format: {value}")
def shared_client_options(f: _AnyCallable | None = None, output_value: t.Literal["json", "porcelain", "pretty"] = "pretty") -> t.Callable[[FC], FC]:
raise click.BadParameter(f'Invalid option format: {value}')
def shared_client_options(f: _AnyCallable | None = None, output_value: t.Literal['json', 'porcelain', 'pretty'] = 'pretty') -> t.Callable[[FC], FC]:
options = [
click.option("--endpoint", type=click.STRING, help="OpenLLM Server endpoint, i.e: http://localhost:3000", envvar="OPENLLM_ENDPOINT", default="http://localhost:3000",
click.option('--endpoint', type=click.STRING, help='OpenLLM Server endpoint, i.e: http://localhost:3000', envvar='OPENLLM_ENDPOINT', default='http://localhost:3000',
),
click.option("--timeout", type=click.INT, default=30, help="Default server timeout", show_default=True),
click.option('--timeout', type=click.INT, default=30, help='Default server timeout', show_default=True),
output_option(default_value=output_value),
]
return compose(*options)(f) if f is not None else compose(*options)
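A small sketch (not from this commit, names are illustrative) of what `parsing_instruction_callback` above does with a single `--key[=value]` string: a bare flag becomes `(key, True)`, a `key=value` pair becomes `(key, value)`, and anything else is rejected.

```python
# Illustrative re-implementation of the parsing performed by parsing_instruction_callback above.
def parse_instruction(value: str) -> tuple[str, bool | str]:
  key, *values = value.split('=')
  if not key.startswith('--'): raise ValueError(f'Invalid option format: {value}')
  key = key[2:]
  if len(values) == 0: return key, True           # '--remote'           -> ('remote', True)
  elif len(values) == 1: return key, values[0]    # '--text=hello world' -> ('text', 'hello world')
  else: raise ValueError(f'Invalid option format: {value}')
```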
@cli.command()
|
||||
@click.argument("task", type=click.STRING, metavar="TASK")
|
||||
@click.argument('task', type=click.STRING, metavar='TASK')
|
||||
@shared_client_options
|
||||
@click.option("--agent", type=click.Choice(["hf"]), default="hf", help="Whether to interact with Agents from given Server endpoint.", show_default=True)
|
||||
@click.option("--remote", is_flag=True, default=False, help="Whether or not to use remote tools (inference endpoints) instead of local ones.", show_default=True)
|
||||
@click.option('--agent', type=click.Choice(['hf']), default='hf', help='Whether to interact with Agents from given Server endpoint.', show_default=True)
|
||||
@click.option('--remote', is_flag=True, default=False, help='Whether or not to use remote tools (inference endpoints) instead of local ones.', show_default=True)
|
||||
@click.option(
|
||||
"--opt",
|
||||
'--opt',
|
||||
help="Define prompt options. "
|
||||
"(format: ``--opt text='I love this' --opt audio:./path/to/audio --opt image:/path/to/file``)",
|
||||
required=False,
|
||||
multiple=True,
|
||||
callback=opt_callback,
|
||||
metavar="ARG=VALUE[,ARG=VALUE]"
|
||||
metavar='ARG=VALUE[,ARG=VALUE]'
|
||||
)
|
||||
def instruct_command(endpoint: str, timeout: int, agent: LiteralString, output: LiteralOutput, remote: bool, task: str, _memoized: DictStrAny, **attrs: t.Any) -> str:
|
||||
"""Instruct agents interactively for given tasks, from a terminal.
|
||||
'''Instruct agents interactively for given tasks, from a terminal.
|
||||
|
||||
\b
|
||||
```bash
|
||||
@@ -743,92 +743,92 @@ def instruct_command(endpoint: str, timeout: int, agent: LiteralString, output:
|
||||
"Is the following `text` (in Spanish) positive or negative?" \\
|
||||
--text "¡Este es un API muy agradable!"
|
||||
```
|
||||
"""
|
||||
'''
|
||||
client = openllm.client.HTTPClient(endpoint, timeout=timeout)
|
||||
|
||||
try:
|
||||
client.call("metadata")
|
||||
client.call('metadata')
|
||||
except http.client.BadStatusLine:
|
||||
raise click.ClickException(f"{endpoint} is neither a HTTP server nor reachable.") from None
|
||||
if agent == "hf":
|
||||
raise click.ClickException(f'{endpoint} is neither a HTTP server nor reachable.') from None
|
||||
if agent == 'hf':
|
||||
if not is_transformers_supports_agent(): raise click.UsageError("Transformers version should be at least 4.29 to support HfAgent. Upgrade with 'pip install -U transformers'")
|
||||
_memoized = {k: v[0] for k, v in _memoized.items() if v}
|
||||
client._hf_agent.set_stream(logger.info)
|
||||
if output != "porcelain": termui.echo(f"Sending the following prompt ('{task}') with the following vars: {_memoized}", fg="magenta")
|
||||
if output != 'porcelain': termui.echo(f"Sending the following prompt ('{task}') with the following vars: {_memoized}", fg='magenta')
|
||||
result = client.ask_agent(task, agent_type=agent, return_code=False, remote=remote, **_memoized)
|
||||
if output == "json": termui.echo(orjson.dumps(result, option=orjson.OPT_INDENT_2).decode(), fg="white")
|
||||
else: termui.echo(result, fg="white")
|
||||
if output == 'json': termui.echo(orjson.dumps(result, option=orjson.OPT_INDENT_2).decode(), fg='white')
|
||||
else: termui.echo(result, fg='white')
|
||||
return result
|
||||
else:
|
||||
raise click.BadOptionUsage("agent", f"Unknown agent type {agent}")
|
||||
raise click.BadOptionUsage('agent', f'Unknown agent type {agent}')
|
||||
@cli.command()
|
||||
@shared_client_options(output_value="json")
|
||||
@click.option("--server-type", type=click.Choice(["grpc", "http"]), help="Server type", default="http", show_default=True)
|
||||
@click.argument("text", type=click.STRING, nargs=-1)
|
||||
@shared_client_options(output_value='json')
|
||||
@click.option('--server-type', type=click.Choice(['grpc', 'http']), help='Server type', default='http', show_default=True)
|
||||
@click.argument('text', type=click.STRING, nargs=-1)
|
||||
@machine_option
|
||||
@click.pass_context
|
||||
def embed_command(
|
||||
ctx: click.Context, text: tuple[str, ...], endpoint: str, timeout: int, server_type: t.Literal["http", "grpc"], output: LiteralOutput, machine: bool
|
||||
ctx: click.Context, text: tuple[str, ...], endpoint: str, timeout: int, server_type: t.Literal['http', 'grpc'], output: LiteralOutput, machine: bool
|
||||
) -> EmbeddingsOutput | None:
|
||||
"""Get embeddings interactively, from a terminal.
|
||||
'''Get embeddings interactively, from a terminal.
|
||||
|
||||
\b
|
||||
```bash
|
||||
$ openllm embed --endpoint http://12.323.2.1:3000 "What is the meaning of life?" "How many stars are there in the sky?"
|
||||
```
|
||||
"""
|
||||
client = openllm.client.HTTPClient(endpoint, timeout=timeout) if server_type == "http" else openllm.client.GrpcClient(endpoint, timeout=timeout)
|
||||
'''
|
||||
client = openllm.client.HTTPClient(endpoint, timeout=timeout) if server_type == 'http' else openllm.client.GrpcClient(endpoint, timeout=timeout)
|
||||
try:
|
||||
gen_embed = client.embed(text)
|
||||
except ValueError:
|
||||
raise click.ClickException(f"Endpoint {endpoint} does not support embeddings.") from None
|
||||
raise click.ClickException(f'Endpoint {endpoint} does not support embeddings.') from None
|
||||
if machine: return gen_embed
|
||||
elif output == "pretty":
|
||||
termui.echo("Generated embeddings: ", fg="magenta", nl=False)
|
||||
termui.echo(gen_embed.embeddings, fg="white")
|
||||
termui.echo("\nNumber of tokens: ", fg="magenta", nl=False)
|
||||
termui.echo(gen_embed.num_tokens, fg="white")
|
||||
elif output == "json":
|
||||
termui.echo(orjson.dumps(bentoml_cattr.unstructure(gen_embed), option=orjson.OPT_INDENT_2).decode(), fg="white")
|
||||
elif output == 'pretty':
|
||||
termui.echo('Generated embeddings: ', fg='magenta', nl=False)
|
||||
termui.echo(gen_embed.embeddings, fg='white')
|
||||
termui.echo('\nNumber of tokens: ', fg='magenta', nl=False)
|
||||
termui.echo(gen_embed.num_tokens, fg='white')
|
||||
elif output == 'json':
|
||||
termui.echo(orjson.dumps(bentoml_cattr.unstructure(gen_embed), option=orjson.OPT_INDENT_2).decode(), fg='white')
|
||||
else:
|
||||
termui.echo(gen_embed.embeddings, fg="white")
|
||||
termui.echo(gen_embed.embeddings, fg='white')
|
||||
ctx.exit(0)
|
||||
@cli.command()
|
||||
@shared_client_options
|
||||
@click.option("--server-type", type=click.Choice(["grpc", "http"]), help="Server type", default="http", show_default=True)
|
||||
@click.argument("prompt", type=click.STRING)
|
||||
@click.option('--server-type', type=click.Choice(['grpc', 'http']), help='Server type', default='http', show_default=True)
|
||||
@click.argument('prompt', type=click.STRING)
|
||||
@click.option(
|
||||
"--sampling-params", help="Define query options. (format: ``--opt temperature=0.8 --opt=top_k:12)", required=False, multiple=True, callback=opt_callback, metavar="ARG=VALUE[,ARG=VALUE]"
|
||||
'--sampling-params', help='Define query options. (format: ``--opt temperature=0.8 --opt=top_k:12)', required=False, multiple=True, callback=opt_callback, metavar='ARG=VALUE[,ARG=VALUE]'
|
||||
)
|
||||
@click.pass_context
|
||||
def query_command(
|
||||
ctx: click.Context, /, prompt: str, endpoint: str, timeout: int, server_type: t.Literal["http", "grpc"], output: LiteralOutput, _memoized: DictStrAny, **attrs: t.Any
|
||||
ctx: click.Context, /, prompt: str, endpoint: str, timeout: int, server_type: t.Literal['http', 'grpc'], output: LiteralOutput, _memoized: DictStrAny, **attrs: t.Any
|
||||
) -> None:
|
||||
"""Ask a LLM interactively, from a terminal.
|
||||
'''Ask a LLM interactively, from a terminal.
|
||||
|
||||
\b
|
||||
```bash
|
||||
$ openllm query --endpoint http://12.323.2.1:3000 "What is the meaning of life?"
|
||||
```
|
||||
"""
|
||||
'''
|
||||
_memoized = {k: orjson.loads(v[0]) for k, v in _memoized.items() if v}
|
||||
if server_type == "grpc": endpoint = re.sub(r"http://", "", endpoint)
|
||||
client = openllm.client.HTTPClient(endpoint, timeout=timeout) if server_type == "http" else openllm.client.GrpcClient(endpoint, timeout=timeout)
|
||||
input_fg, generated_fg = "magenta", "cyan"
|
||||
if output != "porcelain":
|
||||
termui.echo("==Input==\n", fg="white")
|
||||
termui.echo(f"{prompt}", fg=input_fg)
|
||||
res = client.query(prompt, return_response="raw", **{**client.configuration, **_memoized})
|
||||
if output == "pretty":
|
||||
response = client.config.postprocess_generate(prompt, res["responses"])
|
||||
termui.echo("\n\n==Responses==\n", fg="white")
|
||||
if server_type == 'grpc': endpoint = re.sub(r'http://', '', endpoint)
|
||||
client = openllm.client.HTTPClient(endpoint, timeout=timeout) if server_type == 'http' else openllm.client.GrpcClient(endpoint, timeout=timeout)
|
||||
input_fg, generated_fg = 'magenta', 'cyan'
|
||||
if output != 'porcelain':
|
||||
termui.echo('==Input==\n', fg='white')
|
||||
termui.echo(f'{prompt}', fg=input_fg)
|
||||
res = client.query(prompt, return_response='raw', **{**client.configuration, **_memoized})
|
||||
if output == 'pretty':
|
||||
response = client.config.postprocess_generate(prompt, res['responses'])
|
||||
termui.echo('\n\n==Responses==\n', fg='white')
|
||||
termui.echo(response, fg=generated_fg)
|
||||
elif output == "json":
|
||||
termui.echo(orjson.dumps(res, option=orjson.OPT_INDENT_2).decode(), fg="white")
|
||||
elif output == 'json':
|
||||
termui.echo(orjson.dumps(res, option=orjson.OPT_INDENT_2).decode(), fg='white')
|
||||
else:
|
||||
termui.echo(res["responses"], fg="white")
|
||||
termui.echo(res['responses'], fg='white')
|
||||
ctx.exit(0)
|
||||
@cli.group(cls=Extensions, hidden=True, name="extension")
|
||||
@cli.group(cls=Extensions, hidden=True, name='extension')
|
||||
def extension_command() -> None:
|
||||
"""Extension for OpenLLM CLI."""
|
||||
if __name__ == "__main__": cli()
|
||||
'''Extension for OpenLLM CLI.'''
|
||||
if __name__ == '__main__': cli()
|
||||
|
||||
@@ -4,9 +4,9 @@ from openllm.cli import termui
|
||||
from openllm.cli._factory import machine_option, container_registry_option
|
||||
if t.TYPE_CHECKING: from openllm_core._typing_compat import LiteralContainerRegistry, LiteralContainerVersionStrategy
|
||||
@click.command(
|
||||
"build_base_container",
|
||||
'build_base_container',
|
||||
context_settings=termui.CONTEXT_SETTINGS,
|
||||
help="""Base image builder for BentoLLM.
|
||||
help='''Base image builder for BentoLLM.
|
||||
|
||||
By default, the base image will include custom kernels (PagedAttention via vllm, FlashAttention-v2, etc.) built with CUDA 11.8, Python 3.9 on Ubuntu22.04.
|
||||
Optionally, this can also be pushed directly to remote registry. Currently support ``docker.io``, ``ghcr.io`` and ``quay.io``.
|
||||
@@ -16,13 +16,13 @@ if t.TYPE_CHECKING: from openllm_core._typing_compat import LiteralContainerRegi
|
||||
This command is only useful for debugging and for building custom base image for extending BentoML with custom base images and custom kernels.
|
||||
|
||||
Note that we already release images on our CI to ECR and GHCR, so you don't need to build it yourself.
|
||||
"""
|
||||
'''
|
||||
)
|
||||
@container_registry_option
|
||||
@click.option("--version-strategy", type=click.Choice(["release", "latest", "nightly"]), default="nightly", help="Version strategy to use for tagging the image.")
|
||||
@click.option("--push/--no-push", help="Whether to push to remote repository", is_flag=True, default=False)
|
||||
@click.option('--version-strategy', type=click.Choice(['release', 'latest', 'nightly']), default='nightly', help='Version strategy to use for tagging the image.')
|
||||
@click.option('--push/--no-push', help='Whether to push to remote repository', is_flag=True, default=False)
|
||||
@machine_option
|
||||
def cli(container_registry: tuple[LiteralContainerRegistry, ...] | None, version_strategy: LiteralContainerVersionStrategy, push: bool, machine: bool) -> dict[str, str]:
|
||||
mapping = openllm.bundle.build_container(container_registry, version_strategy, push, machine)
|
||||
if machine: termui.echo(orjson.dumps(mapping, option=orjson.OPT_INDENT_2).decode(), fg="white")
|
||||
if machine: termui.echo(orjson.dumps(mapping, option=orjson.OPT_INDENT_2).decode(), fg='white')
|
||||
return mapping
|
||||
|
||||
@@ -7,21 +7,21 @@ from openllm.cli import termui
from openllm.cli._factory import bento_complete_envvar, machine_option

if t.TYPE_CHECKING: from bentoml._internal.bento import BentoStore
@click.command("dive_bentos", context_settings=termui.CONTEXT_SETTINGS)
@click.argument("bento", type=str, shell_complete=bento_complete_envvar)
@click.command('dive_bentos', context_settings=termui.CONTEXT_SETTINGS)
@click.argument('bento', type=str, shell_complete=bento_complete_envvar)
@machine_option
@click.pass_context
@inject
def cli(ctx: click.Context, bento: str, machine: bool, _bento_store: BentoStore = Provide[BentoMLContainer.bento_store]) -> str | None:
"""Dive into a BentoLLM. This is synonymous to cd $(b get <bento>:<tag> -o path)."""
'''Dive into a BentoLLM. This is synonymous to cd $(b get <bento>:<tag> -o path).'''
try:
bentomodel = _bento_store.get(bento)
except bentoml.exceptions.NotFound:
ctx.fail(f"Bento {bento} not found. Make sure to call `openllm build` first.")
if "bundler" not in bentomodel.info.labels or bentomodel.info.labels["bundler"] != "openllm.bundle":
ctx.fail(f'Bento {bento} not found. Make sure to call `openllm build` first.')
if 'bundler' not in bentomodel.info.labels or bentomodel.info.labels['bundler'] != 'openllm.bundle':
ctx.fail(f"Bento is either too old or not built with OpenLLM. Make sure to use ``openllm build {bentomodel.info.labels['start_name']}`` for correctness.")
if machine: return bentomodel.path
# copy and paste this into a new shell
if psutil.WINDOWS: subprocess.check_call([shutil.which("dir") or "dir"], cwd=bentomodel.path)
else: subprocess.check_call([shutil.which("ls") or "ls", "-Rrthla"], cwd=bentomodel.path)
if psutil.WINDOWS: subprocess.check_call([shutil.which('dir') or 'dir'], cwd=bentomodel.path)
else: subprocess.check_call([shutil.which('ls') or 'ls', '-Rrthla'], cwd=bentomodel.path)
ctx.exit(0)

@@ -10,17 +10,17 @@ from openllm.cli._factory import bento_complete_envvar
|
||||
from openllm_core.utils import bentoml_cattr
|
||||
|
||||
if t.TYPE_CHECKING: from bentoml._internal.bento import BentoStore
|
||||
@click.command("get_containerfile", context_settings=termui.CONTEXT_SETTINGS, help="Return Containerfile of any given Bento.")
|
||||
@click.argument("bento", type=str, shell_complete=bento_complete_envvar)
|
||||
@click.command('get_containerfile', context_settings=termui.CONTEXT_SETTINGS, help='Return Containerfile of any given Bento.')
|
||||
@click.argument('bento', type=str, shell_complete=bento_complete_envvar)
|
||||
@click.pass_context
|
||||
@inject
|
||||
def cli(ctx: click.Context, bento: str, _bento_store: BentoStore = Provide[BentoMLContainer.bento_store]) -> str:
|
||||
try:
|
||||
bentomodel = _bento_store.get(bento)
|
||||
except bentoml.exceptions.NotFound:
|
||||
ctx.fail(f"Bento {bento} not found. Make sure to call `openllm build` first.")
|
||||
ctx.fail(f'Bento {bento} not found. Make sure to call `openllm build` first.')
|
||||
# The logic below are similar to bentoml._internal.container.construct_containerfile
|
||||
with open(bentomodel.path_of("bento.yaml"), "r") as f:
|
||||
with open(bentomodel.path_of('bento.yaml'), 'r') as f:
|
||||
options = BentoInfo.from_yaml_file(f)
|
||||
# NOTE: dockerfile_template is already included in the
|
||||
# Dockerfile inside bento, and it is not relevant to
|
||||
@@ -30,7 +30,7 @@ def cli(ctx: click.Context, bento: str, _bento_store: BentoStore = Provide[Bento
|
||||
# NOTE: if users specify a dockerfile_template, we will
|
||||
# save it to /env/docker/Dockerfile.template. This is necessary
|
||||
# for the reconstruction of the Dockerfile.
|
||||
if "dockerfile_template" in docker_attrs and docker_attrs["dockerfile_template"] is not None: docker_attrs["dockerfile_template"] = "env/docker/Dockerfile.template"
|
||||
if 'dockerfile_template' in docker_attrs and docker_attrs['dockerfile_template'] is not None: docker_attrs['dockerfile_template'] = 'env/docker/Dockerfile.template'
|
||||
doc = generate_containerfile(docker=DockerOptions(**docker_attrs), build_ctx=bentomodel.path, conda=options.conda, bento_fs=bentomodel._fs, enable_buildkit=True, add_header=True)
|
||||
termui.echo(doc, fg="white")
|
||||
termui.echo(doc, fg='white')
|
||||
return bentomodel.path
|
||||
|
||||
@@ -4,46 +4,46 @@ from bentoml_cli.utils import opt_callback
|
||||
from openllm.cli import termui
|
||||
from openllm.cli._factory import model_complete_envvar, output_option, machine_option
|
||||
from openllm_core._prompt import process_prompt
|
||||
LiteralOutput = t.Literal["json", "pretty", "porcelain"]
|
||||
@click.command("get_prompt", context_settings=termui.CONTEXT_SETTINGS)
|
||||
@click.argument("model_name", type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING.keys()]), shell_complete=model_complete_envvar)
|
||||
@click.argument("prompt", type=click.STRING)
|
||||
LiteralOutput = t.Literal['json', 'pretty', 'porcelain']
|
||||
@click.command('get_prompt', context_settings=termui.CONTEXT_SETTINGS)
|
||||
@click.argument('model_name', type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING.keys()]), shell_complete=model_complete_envvar)
|
||||
@click.argument('prompt', type=click.STRING)
|
||||
@output_option
|
||||
@click.option("--format", type=click.STRING, default=None)
|
||||
@click.option('--format', type=click.STRING, default=None)
|
||||
@machine_option
|
||||
@click.option(
|
||||
"--opt",
|
||||
'--opt',
|
||||
help="Define additional prompt variables. (format: ``--opt system_prompt='You are a useful assistant'``)",
|
||||
required=False,
|
||||
multiple=True,
|
||||
callback=opt_callback,
|
||||
metavar="ARG=VALUE[,ARG=VALUE]"
|
||||
metavar='ARG=VALUE[,ARG=VALUE]'
|
||||
)
|
||||
@click.pass_context
|
||||
def cli(ctx: click.Context, /, model_name: str, prompt: str, format: str | None, output: LiteralOutput, machine: bool, _memoized: dict[str, t.Any], **_: t.Any) -> str | None:
|
||||
"""Get the default prompt used by OpenLLM."""
|
||||
'''Get the default prompt used by OpenLLM.'''
|
||||
module = openllm.utils.EnvVarMixin(model_name).module
|
||||
_memoized = {k: v[0] for k, v in _memoized.items() if v}
|
||||
try:
|
||||
template = getattr(module, "DEFAULT_PROMPT_TEMPLATE", None)
|
||||
prompt_mapping = getattr(module, "PROMPT_MAPPING", None)
|
||||
if template is None: raise click.BadArgumentUsage(f"model {model_name} does not have a default prompt template") from None
|
||||
template = getattr(module, 'DEFAULT_PROMPT_TEMPLATE', None)
|
||||
prompt_mapping = getattr(module, 'PROMPT_MAPPING', None)
|
||||
if template is None: raise click.BadArgumentUsage(f'model {model_name} does not have a default prompt template') from None
|
||||
if callable(template):
|
||||
if format is None:
|
||||
if not hasattr(module, "PROMPT_MAPPING") or module.PROMPT_MAPPING is None: raise RuntimeError("Failed to find prompt mapping while DEFAULT_PROMPT_TEMPLATE is a function.")
|
||||
raise click.BadOptionUsage("format", f"{model_name} prompt requires passing '--format' (available format: {list(module.PROMPT_MAPPING)})")
|
||||
if prompt_mapping is None: raise click.BadArgumentUsage(f"Failed to fine prompt mapping while the default prompt for {model_name} is a callable.") from None
|
||||
if format not in prompt_mapping: raise click.BadOptionUsage("format", f"Given format {format} is not valid for {model_name} (available format: {list(prompt_mapping)})")
|
||||
if not hasattr(module, 'PROMPT_MAPPING') or module.PROMPT_MAPPING is None: raise RuntimeError('Failed to find prompt mapping while DEFAULT_PROMPT_TEMPLATE is a function.')
|
||||
raise click.BadOptionUsage('format', f"{model_name} prompt requires passing '--format' (available format: {list(module.PROMPT_MAPPING)})")
|
||||
if prompt_mapping is None: raise click.BadArgumentUsage(f'Failed to fine prompt mapping while the default prompt for {model_name} is a callable.') from None
|
||||
if format not in prompt_mapping: raise click.BadOptionUsage('format', f'Given format {format} is not valid for {model_name} (available format: {list(prompt_mapping)})')
|
||||
_prompt_template = template(format)
|
||||
else:
|
||||
_prompt_template = template
|
||||
fully_formatted = process_prompt(prompt, _prompt_template, True, **_memoized)
|
||||
if machine: return repr(fully_formatted)
|
||||
elif output == "porcelain": termui.echo(repr(fully_formatted), fg="white")
|
||||
elif output == "json": termui.echo(orjson.dumps({"prompt": fully_formatted}, option=orjson.OPT_INDENT_2).decode(), fg="white")
|
||||
elif output == 'porcelain': termui.echo(repr(fully_formatted), fg='white')
|
||||
elif output == 'json': termui.echo(orjson.dumps({'prompt': fully_formatted}, option=orjson.OPT_INDENT_2).decode(), fg='white')
|
||||
else:
|
||||
termui.echo(f"== Prompt for {model_name} ==\n", fg="magenta")
|
||||
termui.echo(fully_formatted, fg="white")
|
||||
termui.echo(f'== Prompt for {model_name} ==\n', fg='magenta')
|
||||
termui.echo(fully_formatted, fg='white')
|
||||
except AttributeError:
|
||||
raise click.ClickException(f"Failed to determine a default prompt template for {model_name}.") from None
|
||||
raise click.ClickException(f'Failed to determine a default prompt template for {model_name}.') from None
|
||||
ctx.exit(0)
|
||||
|
||||
@@ -3,30 +3,30 @@ import click, inflection, orjson, bentoml, openllm
|
||||
from bentoml._internal.utils import human_readable_size
|
||||
from openllm.cli import termui
|
||||
from openllm.cli._factory import LiteralOutput, output_option
|
||||
@click.command("list_bentos", context_settings=termui.CONTEXT_SETTINGS)
|
||||
@output_option(default_value="json")
|
||||
@click.command('list_bentos', context_settings=termui.CONTEXT_SETTINGS)
|
||||
@output_option(default_value='json')
|
||||
@click.pass_context
|
||||
def cli(ctx: click.Context, output: LiteralOutput) -> None:
|
||||
"""List available bentos built by OpenLLM."""
|
||||
'''List available bentos built by OpenLLM.'''
|
||||
mapping = {
|
||||
k: [{
|
||||
"tag": str(b.tag),
|
||||
"size": human_readable_size(openllm.utils.calc_dir_size(b.path)),
|
||||
"models": [{
|
||||
"tag": str(m.tag), "size": human_readable_size(openllm.utils.calc_dir_size(m.path))
|
||||
'tag': str(b.tag),
|
||||
'size': human_readable_size(openllm.utils.calc_dir_size(b.path)),
|
||||
'models': [{
|
||||
'tag': str(m.tag), 'size': human_readable_size(openllm.utils.calc_dir_size(m.path))
|
||||
} for m in (bentoml.models.get(_.tag) for _ in b.info.models)]
|
||||
} for b in tuple(i for i in bentoml.list() if all(k in i.info.labels for k in {"start_name", "bundler"})) if b.info.labels["start_name"] == k] for k in tuple(
|
||||
} for b in tuple(i for i in bentoml.list() if all(k in i.info.labels for k in {'start_name', 'bundler'})) if b.info.labels['start_name'] == k] for k in tuple(
|
||||
inflection.dasherize(key) for key in openllm.CONFIG_MAPPING.keys()
|
||||
)
|
||||
}
|
||||
mapping = {k: v for k, v in mapping.items() if v}
|
||||
if output == "pretty":
|
||||
if output == 'pretty':
|
||||
import tabulate
|
||||
tabulate.PRESERVE_WHITESPACE = True
|
||||
termui.echo(
|
||||
tabulate.tabulate([(k, i["tag"], i["size"], [_["tag"] for _ in i["models"]]) for k, v in mapping.items() for i in v], tablefmt="fancy_grid", headers=["LLM", "Tag", "Size", "Models"]),
|
||||
fg="white"
|
||||
tabulate.tabulate([(k, i['tag'], i['size'], [_['tag'] for _ in i['models']]) for k, v in mapping.items() for i in v], tablefmt='fancy_grid', headers=['LLM', 'Tag', 'Size', 'Models']),
|
||||
fg='white'
|
||||
)
|
||||
else:
|
||||
termui.echo(orjson.dumps(mapping, option=orjson.OPT_INDENT_2).decode(), fg="white")
|
||||
termui.echo(orjson.dumps(mapping, option=orjson.OPT_INDENT_2).decode(), fg='white')
|
||||
ctx.exit(0)
|
||||
|
||||
@@ -5,24 +5,24 @@ from bentoml._internal.utils import human_readable_size
|
||||
from openllm.cli._factory import LiteralOutput, model_name_argument, output_option, model_complete_envvar
|
||||
|
||||
if t.TYPE_CHECKING: from openllm_core._typing_compat import DictStrAny
|
||||
@click.command("list_models", context_settings=termui.CONTEXT_SETTINGS)
|
||||
@click.command('list_models', context_settings=termui.CONTEXT_SETTINGS)
|
||||
@model_name_argument(required=False, shell_complete=model_complete_envvar)
|
||||
@output_option(default_value="json")
|
||||
@output_option(default_value='json')
|
||||
def cli(model_name: str | None, output: LiteralOutput) -> DictStrAny:
|
||||
"""This is equivalent to openllm models --show-available less the nice table."""
|
||||
'''This is equivalent to openllm models --show-available less the nice table.'''
|
||||
models = tuple(inflection.dasherize(key) for key in openllm.CONFIG_MAPPING.keys())
|
||||
ids_in_local_store = {
|
||||
k: [i for i in bentoml.models.list() if "framework" in i.info.labels and i.info.labels["framework"] == "openllm" and "model_name" in i.info.labels and i.info.labels["model_name"] == k
|
||||
k: [i for i in bentoml.models.list() if 'framework' in i.info.labels and i.info.labels['framework'] == 'openllm' and 'model_name' in i.info.labels and i.info.labels['model_name'] == k
|
||||
] for k in models
|
||||
}
|
||||
if model_name is not None:
|
||||
ids_in_local_store = {k: [i for i in v if "model_name" in i.info.labels and i.info.labels["model_name"] == inflection.dasherize(model_name)] for k, v in ids_in_local_store.items()}
|
||||
ids_in_local_store = {k: [i for i in v if 'model_name' in i.info.labels and i.info.labels['model_name'] == inflection.dasherize(model_name)] for k, v in ids_in_local_store.items()}
|
||||
ids_in_local_store = {k: v for k, v in ids_in_local_store.items() if v}
|
||||
local_models = {k: [{"tag": str(i.tag), "size": human_readable_size(openllm.utils.calc_dir_size(i.path))} for i in val] for k, val in ids_in_local_store.items()}
|
||||
if output == "pretty":
|
||||
local_models = {k: [{'tag': str(i.tag), 'size': human_readable_size(openllm.utils.calc_dir_size(i.path))} for i in val] for k, val in ids_in_local_store.items()}
|
||||
if output == 'pretty':
|
||||
import tabulate
|
||||
tabulate.PRESERVE_WHITESPACE = True
|
||||
termui.echo(tabulate.tabulate([(k, i["tag"], i["size"]) for k, v in local_models.items() for i in v], tablefmt="fancy_grid", headers=["LLM", "Tag", "Size"]), fg="white")
|
||||
termui.echo(tabulate.tabulate([(k, i['tag'], i['size']) for k, v in local_models.items() for i in v], tablefmt='fancy_grid', headers=['LLM', 'Tag', 'Size']), fg='white')
|
||||
else:
|
||||
termui.echo(orjson.dumps(local_models, option=orjson.OPT_INDENT_2).decode(), fg="white")
|
||||
termui.echo(orjson.dumps(local_models, option=orjson.OPT_INDENT_2).decode(), fg='white')
|
||||
return local_models
|
||||
|
||||
@@ -9,13 +9,13 @@ if t.TYPE_CHECKING:
|
||||
from openllm_core._typing_compat import DictStrAny
|
||||
logger = logging.getLogger(__name__)
|
||||
def load_notebook_metadata() -> DictStrAny:
|
||||
with open(os.path.join(os.path.dirname(playground.__file__), "_meta.yml"), "r") as f:
|
||||
with open(os.path.join(os.path.dirname(playground.__file__), '_meta.yml'), 'r') as f:
|
||||
content = yaml.safe_load(f)
|
||||
if not all("description" in k for k in content.values()): raise ValueError("Invalid metadata file. All entries must have a 'description' key.")
|
||||
if not all('description' in k for k in content.values()): raise ValueError("Invalid metadata file. All entries must have a 'description' key.")
|
||||
return content
|
||||
@click.command("playground", context_settings=termui.CONTEXT_SETTINGS)
|
||||
@click.argument("output-dir", default=None, required=False)
|
||||
@click.option("--port", envvar="JUPYTER_PORT", show_envvar=True, show_default=True, default=8888, help="Default port for Jupyter server")
|
||||
@click.command('playground', context_settings=termui.CONTEXT_SETTINGS)
|
||||
@click.argument('output-dir', default=None, required=False)
|
||||
@click.option('--port', envvar='JUPYTER_PORT', show_envvar=True, show_default=True, default=8888, help='Default port for Jupyter server')
|
||||
@click.pass_context
|
||||
def cli(ctx: click.Context, output_dir: str | None, port: int) -> None:
|
||||
"""OpenLLM Playground.
|
||||
@@ -41,27 +41,27 @@ def cli(ctx: click.Context, output_dir: str | None, port: int) -> None:
|
||||
_temp_dir = False
|
||||
if output_dir is None:
|
||||
_temp_dir = True
|
||||
output_dir = tempfile.mkdtemp(prefix="openllm-playground-")
|
||||
output_dir = tempfile.mkdtemp(prefix='openllm-playground-')
|
||||
else:
|
||||
os.makedirs(os.path.abspath(os.path.expandvars(os.path.expanduser(output_dir))), exist_ok=True)
|
||||
|
||||
termui.echo("The playground notebooks will be saved to: " + os.path.abspath(output_dir), fg="blue")
|
||||
termui.echo('The playground notebooks will be saved to: ' + os.path.abspath(output_dir), fg='blue')
|
||||
for module in pkgutil.iter_modules(playground.__path__):
|
||||
if module.ispkg or os.path.exists(os.path.join(output_dir, module.name + ".ipynb")):
|
||||
logger.debug("Skipping: %s (%s)", module.name, "File already exists" if not module.ispkg else f"{module.name} is a module")
|
||||
if module.ispkg or os.path.exists(os.path.join(output_dir, module.name + '.ipynb')):
|
||||
logger.debug('Skipping: %s (%s)', module.name, 'File already exists' if not module.ispkg else f'{module.name} is a module')
|
||||
continue
|
||||
if not isinstance(module.module_finder, importlib.machinery.FileFinder): continue
|
||||
termui.echo("Generating notebook for: " + module.name, fg="magenta")
|
||||
markdown_cell = nbformat.v4.new_markdown_cell(metadata[module.name]["description"])
|
||||
f = jupytext.read(os.path.join(module.module_finder.path, module.name + ".py"))
|
||||
termui.echo('Generating notebook for: ' + module.name, fg='magenta')
|
||||
markdown_cell = nbformat.v4.new_markdown_cell(metadata[module.name]['description'])
|
||||
f = jupytext.read(os.path.join(module.module_finder.path, module.name + '.py'))
|
||||
f.cells.insert(0, markdown_cell)
|
||||
jupytext.write(f, os.path.join(output_dir, module.name + ".ipynb"), fmt="notebook")
|
||||
jupytext.write(f, os.path.join(output_dir, module.name + '.ipynb'), fmt='notebook')
|
||||
try:
|
||||
subprocess.check_output([sys.executable, "-m", "jupyter", "notebook", "--notebook-dir", output_dir, "--port", str(port), "--no-browser", "--debug"])
|
||||
subprocess.check_output([sys.executable, '-m', 'jupyter', 'notebook', '--notebook-dir', output_dir, '--port', str(port), '--no-browser', '--debug'])
|
||||
except subprocess.CalledProcessError as e:
|
||||
termui.echo(e.output, fg="red")
|
||||
raise click.ClickException(f"Failed to start a jupyter server:\n{e}") from None
|
||||
termui.echo(e.output, fg='red')
|
||||
raise click.ClickException(f'Failed to start a jupyter server:\n{e}') from None
|
||||
except KeyboardInterrupt:
|
||||
termui.echo("\nShutting down Jupyter server...", fg="yellow")
|
||||
if _temp_dir: termui.echo("Note: You can access the generated notebooks in: " + output_dir, fg="blue")
|
||||
termui.echo('\nShutting down Jupyter server...', fg='yellow')
|
||||
if _temp_dir: termui.echo('Note: You can access the generated notebooks in: ' + output_dir, fg='blue')
|
||||
ctx.exit(0)
|
||||
|
||||
@@ -1,9 +1,9 @@
from __future__ import annotations
import os, typing as t, click, inflection, openllm
if t.TYPE_CHECKING: from openllm_core._typing_compat import DictStrAny
def echo(text: t.Any, fg: str = "green", _with_style: bool = True, **attrs: t.Any) -> None:
attrs["fg"] = fg if not openllm.utils.get_debug_mode() else None
def echo(text: t.Any, fg: str = 'green', _with_style: bool = True, **attrs: t.Any) -> None:
attrs['fg'] = fg if not openllm.utils.get_debug_mode() else None
if not openllm.utils.get_quiet_mode(): t.cast(t.Callable[..., None], click.echo if not _with_style else click.secho)(text, **attrs)
COLUMNS: int = int(os.environ.get("COLUMNS", str(120)))
CONTEXT_SETTINGS: DictStrAny = {"help_option_names": ["-h", "--help"], "max_content_width": COLUMNS, "token_normalize_func": inflection.underscore}
__all__ = ["echo", "COLUMNS", "CONTEXT_SETTINGS"]
COLUMNS: int = int(os.environ.get('COLUMNS', str(120)))
CONTEXT_SETTINGS: DictStrAny = {'help_option_names': ['-h', '--help'], 'max_content_width': COLUMNS, 'token_normalize_func': inflection.underscore}
__all__ = ['echo', 'COLUMNS', 'CONTEXT_SETTINGS']

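A short usage sketch (not part of the commit) for the `echo` helper above: it styles output via `click.secho`, drops the colour when debug mode is active, and prints nothing at all in quiet mode.

```python
from openllm.cli import termui

termui.echo('Successfully built the Bento.')             # default fg='green'
termui.echo('Overwriting existing Bento.', fg='yellow')  # colour is ignored when debug mode is on
```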
@@ -1,4 +1,4 @@
"""OpenLLM Python client.
'''OpenLLM Python client.

```python
client = openllm.client.HTTPClient("http://localhost:8080")
@@ -9,7 +9,7 @@ If the server has embedding supports, use it via `client.embed`:
```python
client.embed("What is the difference between gather and scatter?")
```
"""
'''
from __future__ import annotations
import openllm_client, typing as t
if t.TYPE_CHECKING: from openllm_client import AsyncHTTPClient as AsyncHTTPClient, BaseAsyncClient as BaseAsyncClient, BaseClient as BaseClient, HTTPClient as HTTPClient, GrpcClient as GrpcClient, AsyncGrpcClient as AsyncGrpcClient

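Building on the module docstring above, a hedged end-to-end sketch (it assumes an OpenLLM server is already reachable at localhost:3000; the prompt is illustrative):

```python
import openllm

client = openllm.client.HTTPClient('http://localhost:3000', timeout=30)
result = client.query('What is the difference between gather and scatter?')  # uses the server's default generation config
print(result)
```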
@@ -1,3 +1,3 @@
"""Base exceptions for OpenLLM. This extends BentoML exceptions."""
'''Base exceptions for OpenLLM. This extends BentoML exceptions.'''
from __future__ import annotations
from openllm_core.exceptions import OpenLLMException as OpenLLMException, GpuNotAvailableError as GpuNotAvailableError, ValidationError as ValidationError, ForbiddenAttributeError as ForbiddenAttributeError, MissingAnnotationAttributeError as MissingAnnotationAttributeError, MissingDependencyError as MissingDependencyError, Error as Error, FineTuneStrategyNotSupportedError as FineTuneStrategyNotSupportedError

@@ -4,10 +4,10 @@ import openllm
|
||||
from openllm_core.utils import LazyModule, is_flax_available, is_tf_available, is_torch_available, is_vllm_available
|
||||
from openllm_core.config import AutoConfig as AutoConfig, CONFIG_MAPPING as CONFIG_MAPPING, CONFIG_MAPPING_NAMES as CONFIG_MAPPING_NAMES
|
||||
_import_structure: dict[str, list[str]] = {
|
||||
"modeling_auto": ["MODEL_MAPPING_NAMES"],
|
||||
"modeling_flax_auto": ["MODEL_FLAX_MAPPING_NAMES"],
|
||||
"modeling_tf_auto": ["MODEL_TF_MAPPING_NAMES"],
|
||||
"modeling_vllm_auto": ["MODEL_VLLM_MAPPING_NAMES"]
|
||||
'modeling_auto': ['MODEL_MAPPING_NAMES'],
|
||||
'modeling_flax_auto': ['MODEL_FLAX_MAPPING_NAMES'],
|
||||
'modeling_tf_auto': ['MODEL_TF_MAPPING_NAMES'],
|
||||
'modeling_vllm_auto': ['MODEL_VLLM_MAPPING_NAMES']
|
||||
}
|
||||
if t.TYPE_CHECKING:
|
||||
from .modeling_auto import MODEL_MAPPING_NAMES as MODEL_MAPPING_NAMES
|
||||
@@ -19,31 +19,31 @@ try:
|
||||
except openllm.exceptions.MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_auto"].extend(["AutoLLM", "MODEL_MAPPING"])
|
||||
_import_structure['modeling_auto'].extend(['AutoLLM', 'MODEL_MAPPING'])
|
||||
if t.TYPE_CHECKING: from .modeling_auto import MODEL_MAPPING as MODEL_MAPPING, AutoLLM as AutoLLM
|
||||
try:
|
||||
if not is_vllm_available(): raise openllm.exceptions.MissingDependencyError
|
||||
except openllm.exceptions.MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_vllm_auto"].extend(["AutoVLLM", "MODEL_VLLM_MAPPING"])
|
||||
_import_structure['modeling_vllm_auto'].extend(['AutoVLLM', 'MODEL_VLLM_MAPPING'])
|
||||
if t.TYPE_CHECKING: from .modeling_vllm_auto import MODEL_VLLM_MAPPING as MODEL_VLLM_MAPPING, AutoVLLM as AutoVLLM
|
||||
try:
|
||||
if not is_flax_available(): raise openllm.exceptions.MissingDependencyError
|
||||
except openllm.exceptions.MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_flax_auto"].extend(["AutoFlaxLLM", "MODEL_FLAX_MAPPING"])
|
||||
_import_structure['modeling_flax_auto'].extend(['AutoFlaxLLM', 'MODEL_FLAX_MAPPING'])
|
||||
if t.TYPE_CHECKING: from .modeling_flax_auto import MODEL_FLAX_MAPPING as MODEL_FLAX_MAPPING, AutoFlaxLLM as AutoFlaxLLM
|
||||
try:
|
||||
if not is_tf_available(): raise openllm.exceptions.MissingDependencyError
|
||||
except openllm.exceptions.MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_tf_auto"].extend(["AutoTFLLM", "MODEL_TF_MAPPING"])
|
||||
_import_structure['modeling_tf_auto'].extend(['AutoTFLLM', 'MODEL_TF_MAPPING'])
|
||||
if t.TYPE_CHECKING: from .modeling_tf_auto import MODEL_TF_MAPPING as MODEL_TF_MAPPING, AutoTFLLM as AutoTFLLM
|
||||
|
||||
__lazy = LazyModule(__name__, os.path.abspath("__file__"), _import_structure)
|
||||
__lazy = LazyModule(__name__, os.path.abspath('__file__'), _import_structure)
|
||||
__all__ = __lazy.__all__
|
||||
__dir__ = __lazy.__dir__
|
||||
__getattr__ = __lazy.__getattr__
|
||||
|
||||
@@ -25,20 +25,20 @@ class BaseAutoLLMClass:
|
||||
@classmethod
|
||||
def for_model(cls, model: str, /, model_id: str | None = None, model_version: str | None = None, llm_config: openllm.LLMConfig | None = None, ensure_available: bool = False,
|
||||
**attrs: t.Any) -> openllm.LLM[t.Any, t.Any]:
|
||||
"""The lower level API for creating a LLM instance.
|
||||
'''The lower level API for creating a LLM instance.
|
||||
|
||||
```python
|
||||
>>> import openllm
|
||||
>>> llm = openllm.AutoLLM.for_model("flan-t5")
|
||||
```
|
||||
"""
|
||||
'''
|
||||
llm = cls.infer_class_from_name(model).from_pretrained(model_id=model_id, model_version=model_version, llm_config=llm_config, **attrs)
|
||||
if ensure_available: llm.ensure_model_id_exists()
|
||||
return llm
|
||||
|
||||
@classmethod
|
||||
def create_runner(cls, model: str, model_id: str | None = None, **attrs: t.Any) -> LLMRunner[t.Any, t.Any]:
|
||||
"""Create a LLM Runner for the given model name.
|
||||
'''Create a LLM Runner for the given model name.
|
||||
|
||||
Args:
|
||||
model: The model name to instantiate.
|
||||
@@ -47,7 +47,7 @@ class BaseAutoLLMClass:
|
||||
|
||||
Returns:
|
||||
A LLM instance.
|
||||
"""
|
||||
'''
|
||||
runner_kwargs_name = set(inspect.signature(openllm.LLM[t.Any, t.Any].to_runner).parameters)
|
||||
runner_attrs = {k: v for k, v in attrs.items() if k in runner_kwargs_name}
|
||||
for k in runner_attrs:
|
||||
@@ -56,15 +56,15 @@ class BaseAutoLLMClass:
|
||||
|
||||
@classmethod
|
||||
def register(cls, config_class: type[openllm.LLMConfig], llm_class: type[openllm.LLM[t.Any, t.Any]]) -> None:
|
||||
"""Register a new model for this class.
|
||||
'''Register a new model for this class.
|
||||
|
||||
Args:
|
||||
config_class: The configuration corresponding to the model to register.
|
||||
llm_class: The runnable to register.
|
||||
"""
|
||||
if hasattr(llm_class, "config_class") and llm_class.config_class is not config_class:
|
||||
'''
|
||||
if hasattr(llm_class, 'config_class') and llm_class.config_class is not config_class:
|
||||
raise ValueError(
|
||||
f"The model class you are passing has a `config_class` attribute that is not consistent with the config class you passed (model has {llm_class.config_class} and you passed {config_class}. Fix one of those so they match!"
|
||||
f'The model class you are passing has a `config_class` attribute that is not consistent with the config class you passed (model has {llm_class.config_class} and you passed {config_class}. Fix one of those so they match!'
|
||||
)
|
||||
cls._model_mapping.register(config_class, llm_class)
|
||||
|
||||
@@ -80,13 +80,13 @@ def getattribute_from_module(module: types.ModuleType, attr: t.Any) -> t.Any:
|
||||
if isinstance(attr, tuple): return tuple(getattribute_from_module(module, a) for a in attr)
|
||||
if hasattr(module, attr): return getattr(module, attr)
|
||||
# Some of the mappings have entries model_type -> object of another model type. In that case we try to grab the object at the top level.
|
||||
openllm_module = importlib.import_module("openllm")
|
||||
openllm_module = importlib.import_module('openllm')
|
||||
if module != openllm_module:
|
||||
try:
|
||||
return getattribute_from_module(openllm_module, attr)
|
||||
except ValueError:
|
||||
raise ValueError(f"Could not find {attr} neither in {module} nor in {openllm_module}!") from None
|
||||
raise ValueError(f"Could not find {attr} in {openllm_module}!")
|
||||
raise ValueError(f'Could not find {attr} neither in {module} nor in {openllm_module}!') from None
|
||||
raise ValueError(f'Could not find {attr} in {openllm_module}!')
|
||||
class _LazyAutoMapping(OrderedDict, ReprMixin):
"""Based on transformers.models.auto.configuration_auto._LazyAutoMapping.

@@ -112,7 +112,7 @@ class _LazyAutoMapping(OrderedDict, ReprMixin):

def _load_attr_from_module(self, model_type: str, attr: str) -> t.Any:
module_name = inflection.underscore(model_type)
if module_name not in self._modules: self._modules[module_name] = importlib.import_module(f".{module_name}", "openllm.models")
if module_name not in self._modules: self._modules[module_name] = importlib.import_module(f'.{module_name}', 'openllm.models')
return getattribute_from_module(self._modules[module_name], attr)

def __len__(self) -> int:
@@ -133,33 +133,33 @@ class _LazyAutoMapping(OrderedDict, ReprMixin):

def keys(self) -> ConfigModelKeysView:
return t.cast(
"ConfigModelKeysView", [self._load_attr_from_module(key, name) for key, name in self._config_mapping.items() if key in self._model_mapping.keys()] + list(self._extra_content.keys())
'ConfigModelKeysView', [self._load_attr_from_module(key, name) for key, name in self._config_mapping.items() if key in self._model_mapping.keys()] + list(self._extra_content.keys())
)

def values(self) -> ConfigModelValuesView:
return t.cast(
"ConfigModelValuesView", [self._load_attr_from_module(key, name) for key, name in self._model_mapping.items() if key in self._config_mapping.keys()] + list(
'ConfigModelValuesView', [self._load_attr_from_module(key, name) for key, name in self._model_mapping.items() if key in self._config_mapping.keys()] + list(
self._extra_content.values()
)
)

def items(self) -> ConfigModelItemsView:
return t.cast(
"ConfigModelItemsView",
'ConfigModelItemsView',
[(self._load_attr_from_module(key, self._config_mapping[key]),
self._load_attr_from_module(key, self._model_mapping[key])) for key in self._model_mapping.keys() if key in self._config_mapping.keys()] + list(self._extra_content.items())
)

def __iter__(self) -> t.Iterator[type[openllm.LLMConfig]]:
return iter(t.cast("SupportsIter[t.Iterator[type[openllm.LLMConfig]]]", self.keys()))
return iter(t.cast('SupportsIter[t.Iterator[type[openllm.LLMConfig]]]', self.keys()))

def __contains__(self, item: t.Any) -> bool:
if item in self._extra_content: return True
if not hasattr(item, "__name__") or item.__name__ not in self._reverse_config_mapping: return False
if not hasattr(item, '__name__') or item.__name__ not in self._reverse_config_mapping: return False
return self._reverse_config_mapping[item.__name__] in self._model_mapping

def register(self, key: t.Any, value: t.Any) -> None:
if hasattr(key, "__name__") and key.__name__ in self._reverse_config_mapping:
if hasattr(key, '__name__') and key.__name__ in self._reverse_config_mapping:
if self._reverse_config_mapping[key.__name__] in self._model_mapping.keys(): raise ValueError(f"'{key}' is already used by a OpenLLM model.")
self._extra_content[key] = value
__all__ = ["BaseAutoLLMClass", "_LazyAutoMapping"]
__all__ = ['BaseAutoLLMClass', '_LazyAutoMapping']

@@ -3,9 +3,9 @@ import typing as t
from collections import OrderedDict
from .factory import BaseAutoLLMClass, _LazyAutoMapping
from openllm_core.config import CONFIG_MAPPING_NAMES
MODEL_MAPPING_NAMES = OrderedDict([("chatglm", "ChatGLM"), ("dolly_v2", "DollyV2"), ("falcon", "Falcon"), ("flan_t5", "FlanT5"), ("gpt_neox", "GPTNeoX"), ("llama", "Llama"), ("mpt", "MPT"), (
"opt", "OPT"
), ("stablelm", "StableLM"), ("starcoder", "StarCoder"), ("baichuan", "Baichuan")])
MODEL_MAPPING_NAMES = OrderedDict([('chatglm', 'ChatGLM'), ('dolly_v2', 'DollyV2'), ('falcon', 'Falcon'), ('flan_t5', 'FlanT5'), ('gpt_neox', 'GPTNeoX'), ('llama', 'Llama'), ('mpt', 'MPT'), (
'opt', 'OPT'
), ('stablelm', 'StableLM'), ('starcoder', 'StarCoder'), ('baichuan', 'Baichuan')])
MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
class AutoLLM(BaseAutoLLMClass):
_model_mapping: t.ClassVar = MODEL_MAPPING

@@ -3,7 +3,7 @@ import typing as t
from collections import OrderedDict
from .factory import BaseAutoLLMClass, _LazyAutoMapping
from openllm_core.config import CONFIG_MAPPING_NAMES
MODEL_FLAX_MAPPING_NAMES = OrderedDict([("flan_t5", "FlaxFlanT5"), ("opt", "FlaxOPT")])
MODEL_FLAX_MAPPING_NAMES = OrderedDict([('flan_t5', 'FlaxFlanT5'), ('opt', 'FlaxOPT')])
MODEL_FLAX_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FLAX_MAPPING_NAMES)
class AutoFlaxLLM(BaseAutoLLMClass):
_model_mapping: t.ClassVar = MODEL_FLAX_MAPPING

@@ -3,7 +3,7 @@ import typing as t
from collections import OrderedDict
from .factory import BaseAutoLLMClass, _LazyAutoMapping
from openllm_core.config import CONFIG_MAPPING_NAMES
MODEL_TF_MAPPING_NAMES = OrderedDict([("flan_t5", "TFFlanT5"), ("opt", "TFOPT")])
MODEL_TF_MAPPING_NAMES = OrderedDict([('flan_t5', 'TFFlanT5'), ('opt', 'TFOPT')])
MODEL_TF_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_TF_MAPPING_NAMES)
class AutoTFLLM(BaseAutoLLMClass):
_model_mapping: t.ClassVar = MODEL_TF_MAPPING

@@ -3,9 +3,9 @@ import typing as t
from collections import OrderedDict
from .factory import BaseAutoLLMClass, _LazyAutoMapping
from openllm_core.config import CONFIG_MAPPING_NAMES
MODEL_VLLM_MAPPING_NAMES = OrderedDict([("baichuan", "VLLMBaichuan"), ("dolly_v2", "VLLMDollyV2"), ("falcon", "VLLMFalcon"), ("gpt_neox", "VLLMGPTNeoX"), ("mpt", "VLLMMPT"), (
"opt", "VLLMOPT"
), ("stablelm", "VLLMStableLM"), ("starcoder", "VLLMStarCoder"), ("llama", "VLLMLlama")])
MODEL_VLLM_MAPPING_NAMES = OrderedDict([('baichuan', 'VLLMBaichuan'), ('dolly_v2', 'VLLMDollyV2'), ('falcon', 'VLLMFalcon'), ('gpt_neox', 'VLLMGPTNeoX'), ('mpt', 'VLLMMPT'), (
'opt', 'VLLMOPT'
), ('stablelm', 'VLLMStableLM'), ('starcoder', 'VLLMStarCoder'), ('llama', 'VLLMLlama')])
MODEL_VLLM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_VLLM_MAPPING_NAMES)
class AutoVLLM(BaseAutoLLMClass):
_model_mapping: t.ClassVar = MODEL_VLLM_MAPPING

@@ -9,14 +9,14 @@ try:
except MissingDependencyError:
pass
else:
_import_structure["modeling_baichuan"] = ["Baichuan"]
_import_structure['modeling_baichuan'] = ['Baichuan']
if t.TYPE_CHECKING: from .modeling_baichuan import Baichuan as Baichuan
try:
if not is_vllm_available(): raise MissingDependencyError
except MissingDependencyError:
pass
else:
_import_structure["modeling_vllm_baichuan"] = ["VLLMBaichuan"]
_import_structure['modeling_vllm_baichuan'] = ['VLLMBaichuan']
if t.TYPE_CHECKING: from .modeling_vllm_baichuan import VLLMBaichuan as VLLMBaichuan

sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure)
sys.modules[__name__] = LazyModule(__name__, globals()['__file__'], _import_structure)

@@ -1,12 +1,12 @@
from __future__ import annotations
import typing as t, openllm
if t.TYPE_CHECKING: import transformers
class Baichuan(openllm.LLM["transformers.PreTrainedModel", "transformers.PreTrainedTokenizerBase"]):
class Baichuan(openllm.LLM['transformers.PreTrainedModel', 'transformers.PreTrainedTokenizerBase']):
__openllm_internal__ = True

def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
import torch
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16): # type: ignore[attr-defined]
inputs = self.tokenizer(prompt, return_tensors='pt').to(self.device)
with torch.inference_mode(), torch.autocast('cuda', dtype=torch.float16): # type: ignore[attr-defined]
outputs = self.model.generate(**inputs, generation_config=self.config.model_construct_env(**attrs).to_generation_config())
return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)

@@ -1,6 +1,6 @@
from __future__ import annotations
import typing as t, openllm
if t.TYPE_CHECKING: import vllm, transformers
class VLLMBaichuan(openllm.LLM["vllm.LLMEngine", "transformers.PreTrainedTokenizerBase"]):
class VLLMBaichuan(openllm.LLM['vllm.LLMEngine', 'transformers.PreTrainedTokenizerBase']):
__openllm_internal__ = True
tokenizer_id = "local"
tokenizer_id = 'local'

@@ -9,7 +9,7 @@ try:
except MissingDependencyError:
pass
else:
_import_structure["modeling_chatglm"] = ["ChatGLM"]
_import_structure['modeling_chatglm'] = ['ChatGLM']
if t.TYPE_CHECKING: from .modeling_chatglm import ChatGLM as ChatGLM

sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure)
sys.modules[__name__] = LazyModule(__name__, globals()['__file__'], _import_structure)

@@ -1,7 +1,7 @@
from __future__ import annotations
import typing as t, openllm
if t.TYPE_CHECKING: import transformers
class ChatGLM(openllm.LLM["transformers.PreTrainedModel", "transformers.PreTrainedTokenizerFast"]):
class ChatGLM(openllm.LLM['transformers.PreTrainedModel', 'transformers.PreTrainedTokenizerFast']):
__openllm_internal__ = True

def generate(self, prompt: str, **attrs: t.Any) -> tuple[str, list[tuple[str, str]]]:
@@ -17,7 +17,7 @@ class ChatGLM(openllm.LLM["transformers.PreTrainedModel", "transformers.PreTrain
embeddings: list[list[float]] = []
num_tokens = 0
for prompt in prompts:
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.device)
with torch.inference_mode():
outputs = self.model(input_ids, output_hidden_states=True)
data = F.normalize(torch.mean(outputs.hidden_states[-1].transpose(0, 1), dim=0), p=2, dim=0)

@@ -9,14 +9,14 @@ try:
except MissingDependencyError:
pass
else:
_import_structure["modeling_dolly_v2"] = ["DollyV2"]
_import_structure['modeling_dolly_v2'] = ['DollyV2']
if t.TYPE_CHECKING: from .modeling_dolly_v2 import DollyV2 as DollyV2
try:
if not is_vllm_available(): raise MissingDependencyError
except MissingDependencyError:
pass
else:
_import_structure["modeling_vllm_dolly_v2"] = ["VLLMDollyV2"]
_import_structure['modeling_vllm_dolly_v2'] = ['VLLMDollyV2']
if t.TYPE_CHECKING: from .modeling_vllm_dolly_v2 import VLLMDollyV2 as VLLMDollyV2

sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure)
sys.modules[__name__] = LazyModule(__name__, globals()['__file__'], _import_structure)

@@ -4,7 +4,7 @@ from openllm_core._typing_compat import overload
from openllm_core.config.configuration_dolly_v2 import DEFAULT_PROMPT_TEMPLATE, END_KEY, RESPONSE_KEY, get_special_token_id

if t.TYPE_CHECKING: import torch, transformers, tensorflow as tf
else: torch, transformers, tf = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers"), openllm.utils.LazyLoader("tf", globals(), "tensorflow")
else: torch, transformers, tf = openllm.utils.LazyLoader('torch', globals(), 'torch'), openllm.utils.LazyLoader('transformers', globals(), 'transformers'), openllm.utils.LazyLoader('tf', globals(), 'tensorflow')
logger = logging.getLogger(__name__)
@overload
def get_pipeline(model: transformers.PreTrainedModel, tokenizer: transformers.PreTrainedTokenizer, _init: t.Literal[True] = True, **attrs: t.Any) -> transformers.Pipeline:
@@ -31,25 +31,25 @@ def get_pipeline(model: transformers.PreTrainedModel, tokenizer: transformers.Pr
response_key_token_id = get_special_token_id(self.tokenizer, tokenizer_response_key)
end_key_token_id = get_special_token_id(self.tokenizer, END_KEY)
# Ensure generation stops once it generates "### End"
generate_kwargs["eos_token_id"] = end_key_token_id
generate_kwargs['eos_token_id'] = end_key_token_id
except ValueError:
pass
forward_params = generate_kwargs
postprocess_params = {"response_key_token_id": response_key_token_id, "end_key_token_id": end_key_token_id}
if return_full_text is not None: postprocess_params["return_full_text"] = return_full_text
postprocess_params = {'response_key_token_id': response_key_token_id, 'end_key_token_id': end_key_token_id}
if return_full_text is not None: postprocess_params['return_full_text'] = return_full_text
return preprocess_params, forward_params, postprocess_params

def preprocess(self, input_: str, **generate_kwargs: t.Any) -> t.Dict[str, t.Any]:
if t.TYPE_CHECKING: assert self.tokenizer is not None
prompt_text = DEFAULT_PROMPT_TEMPLATE.format(instruction=input_)
inputs = self.tokenizer(prompt_text, return_tensors="pt")
inputs["prompt_text"] = prompt_text
inputs["instruction_text"] = input_
inputs = self.tokenizer(prompt_text, return_tensors='pt')
inputs['prompt_text'] = prompt_text
inputs['instruction_text'] = input_
return t.cast(t.Dict[str, t.Any], inputs)

def _forward(self, input_tensors: dict[str, t.Any], **generate_kwargs: t.Any) -> transformers.utils.generic.ModelOutput:
if t.TYPE_CHECKING: assert self.tokenizer is not None
input_ids, attention_mask = input_tensors["input_ids"], input_tensors.get("attention_mask", None)
input_ids, attention_mask = input_tensors['input_ids'], input_tensors.get('attention_mask', None)
if input_ids.shape[1] == 0: input_ids, attention_mask, in_b = None, None, 1
else: in_b = input_ids.shape[0]
generated_sequence = self.model.generate(
@@ -59,16 +59,16 @@ def get_pipeline(model: transformers.PreTrainedModel, tokenizer: transformers.Pr
**generate_kwargs
)
out_b = generated_sequence.shape[0]
if self.framework == "pt": generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])
elif self.framework == "tf": generated_sequence = tf.reshape(generated_sequence, (in_b, out_b // in_b, *generated_sequence.shape[1:]))
instruction_text = input_tensors.pop("instruction_text")
return {"generated_sequence": generated_sequence, "input_ids": input_ids, "instruction_text": instruction_text}
if self.framework == 'pt': generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])
elif self.framework == 'tf': generated_sequence = tf.reshape(generated_sequence, (in_b, out_b // in_b, *generated_sequence.shape[1:]))
instruction_text = input_tensors.pop('instruction_text')
return {'generated_sequence': generated_sequence, 'input_ids': input_ids, 'instruction_text': instruction_text}

def postprocess(self, model_outputs: dict[str, t.Any], response_key_token_id: int, end_key_token_id: int, return_full_text: bool = False) -> list[dict[t.Literal["generated_text"], str]]:
def postprocess(self, model_outputs: dict[str, t.Any], response_key_token_id: int, end_key_token_id: int, return_full_text: bool = False) -> list[dict[t.Literal['generated_text'], str]]:
if t.TYPE_CHECKING: assert self.tokenizer is not None
_generated_sequence, instruction_text = model_outputs["generated_sequence"][0], model_outputs["instruction_text"]
_generated_sequence, instruction_text = model_outputs['generated_sequence'][0], model_outputs['instruction_text']
generated_sequence: list[list[int]] = _generated_sequence.numpy().tolist()
records: list[dict[t.Literal["generated_text"], str]] = []
records: list[dict[t.Literal['generated_text'], str]] = []
for sequence in generated_sequence:
# The response will be set to this variable if we can identify it.
decoded = None
@@ -80,7 +80,7 @@ def get_pipeline(model: transformers.PreTrainedModel, tokenizer: transformers.Pr
response_pos = sequence.index(response_key_token_id)
except ValueError:
response_pos = None
if response_pos is None: logger.warning("Could not find response key %s in: %s", response_key_token_id, sequence)
if response_pos is None: logger.warning('Could not find response key %s in: %s', response_key_token_id, sequence)
if response_pos:
# Next find where "### End" is located. The model has been trained to end its responses with this
# sequence (or actually, the token ID it maps to, since it is a special token). We may not find
@@ -96,33 +96,33 @@ def get_pipeline(model: transformers.PreTrainedModel, tokenizer: transformers.Pr
fully_decoded = self.tokenizer.decode(sequence)
# The response appears after "### Response:". The model has been trained to append "### End" at the
# end.
m = re.search(r"#+\s*Response:\s*(.+?)#+\s*End", fully_decoded, flags=re.DOTALL)
m = re.search(r'#+\s*Response:\s*(.+?)#+\s*End', fully_decoded, flags=re.DOTALL)
if m: decoded = m.group(1).strip()
else:
# The model might not generate the "### End" sequence before reaching the max tokens. In this case,
# return everything after "### Response:".
m = re.search(r"#+\s*Response:\s*(.+)", fully_decoded, flags=re.DOTALL)
m = re.search(r'#+\s*Response:\s*(.+)', fully_decoded, flags=re.DOTALL)
if m: decoded = m.group(1).strip()
else: logger.warning("Failed to find response in:\n%s", fully_decoded)
else: logger.warning('Failed to find response in:\n%s', fully_decoded)
# If the full text is requested, then append the decoded text to the original instruction.
# This technically isn't the full text, as we format the instruction in the prompt the model has been
# trained on, but to the client it will appear to be the full text.
if return_full_text: decoded = f"{instruction_text}\n{decoded}"
records.append({"generated_text": t.cast(str, decoded)})
if return_full_text: decoded = f'{instruction_text}\n{decoded}'
records.append({'generated_text': t.cast(str, decoded)})
return records

return InstructionTextGenerationPipeline() if _init else InstructionTextGenerationPipeline
class DollyV2(openllm.LLM["transformers.Pipeline", "transformers.PreTrainedTokenizer"]):
class DollyV2(openllm.LLM['transformers.Pipeline', 'transformers.PreTrainedTokenizer']):
__openllm_internal__ = True

@property
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
return {"device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None, "torch_dtype": torch.bfloat16}, {}
return {'device_map': 'auto' if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None, 'torch_dtype': torch.bfloat16}, {}

def load_model(self, *args: t.Any, **attrs: t.Any) -> transformers.Pipeline:
return get_pipeline(transformers.AutoModelForCausalLM.from_pretrained(self._bentomodel.path, *args, **attrs), self.tokenizer, _init=True, return_full_text=self.config.return_full_text)

def generate(self, prompt: str, **attrs: t.Any) -> list[dict[t.Literal["generated_text"], str]]:
def generate(self, prompt: str, **attrs: t.Any) -> list[dict[t.Literal['generated_text'], str]]:
llm_config = self.config.model_construct_env(**attrs)
with torch.inference_mode():
return self.model(prompt, return_full_text=llm_config.return_full_text, generation_config=llm_config.to_generation_config())

@@ -3,6 +3,6 @@ import logging, typing as t, openllm
if t.TYPE_CHECKING: import vllm, transformers

logger = logging.getLogger(__name__)
class VLLMDollyV2(openllm.LLM["vllm.LLMEngine", "transformers.PreTrainedTokenizer"]):
class VLLMDollyV2(openllm.LLM['vllm.LLMEngine', 'transformers.PreTrainedTokenizer']):
__openllm_internal__ = True
tokenizer_id = "local"
tokenizer_id = 'local'

@@ -9,14 +9,14 @@ try:
except MissingDependencyError:
pass
else:
_import_structure["modeling_falcon"] = ["Falcon"]
_import_structure['modeling_falcon'] = ['Falcon']
if t.TYPE_CHECKING: from .modeling_falcon import Falcon as Falcon
try:
if not is_vllm_available(): raise MissingDependencyError
except MissingDependencyError:
pass
else:
_import_structure["modeling_vllm_falcon"] = ["VLLMFalcon"]
_import_structure['modeling_vllm_falcon'] = ['VLLMFalcon']
if t.TYPE_CHECKING: from .modeling_vllm_falcon import VLLMFalcon as VLLMFalcon

sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure)
sys.modules[__name__] = LazyModule(__name__, globals()['__file__'], _import_structure)

@@ -1,32 +1,32 @@
from __future__ import annotations
import typing as t, openllm
if t.TYPE_CHECKING: import torch, transformers
else: torch, transformers = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers")
class Falcon(openllm.LLM["transformers.PreTrainedModel", "transformers.PreTrainedTokenizerBase"]):
else: torch, transformers = openllm.utils.LazyLoader('torch', globals(), 'torch'), openllm.utils.LazyLoader('transformers', globals(), 'transformers')
class Falcon(openllm.LLM['transformers.PreTrainedModel', 'transformers.PreTrainedTokenizerBase']):
__openllm_internal__ = True

@property
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
return {"torch_dtype": torch.bfloat16, "device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None}, {}
return {'torch_dtype': torch.bfloat16, 'device_map': 'auto' if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None}, {}

def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
eos_token_id, inputs = attrs.pop("eos_token_id", self.tokenizer.eos_token_id), self.tokenizer(prompt, return_tensors="pt").to(self.device)
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16): # type: ignore[attr-defined]
eos_token_id, inputs = attrs.pop('eos_token_id', self.tokenizer.eos_token_id), self.tokenizer(prompt, return_tensors='pt').to(self.device)
with torch.inference_mode(), torch.autocast('cuda', dtype=torch.float16): # type: ignore[attr-defined]
return self.tokenizer.batch_decode(
self.model.generate(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
input_ids=inputs['input_ids'],
attention_mask=inputs['attention_mask'],
generation_config=self.config.model_construct_env(eos_token_id=eos_token_id, **attrs).to_generation_config()
),
skip_special_tokens=True
)

def generate_one(self, prompt: str, stop: list[str], **preprocess_generate_kwds: t.Any) -> list[dict[t.Literal["generated_text"], str]]:
max_new_tokens, encoded_inputs = preprocess_generate_kwds.pop("max_new_tokens", 200), self.tokenizer(prompt, return_tensors="pt").to(self.device)
src_len, stopping_criteria = encoded_inputs["input_ids"].shape[1], preprocess_generate_kwds.pop("stopping_criteria", openllm.StoppingCriteriaList([]))
def generate_one(self, prompt: str, stop: list[str], **preprocess_generate_kwds: t.Any) -> list[dict[t.Literal['generated_text'], str]]:
max_new_tokens, encoded_inputs = preprocess_generate_kwds.pop('max_new_tokens', 200), self.tokenizer(prompt, return_tensors='pt').to(self.device)
src_len, stopping_criteria = encoded_inputs['input_ids'].shape[1], preprocess_generate_kwds.pop('stopping_criteria', openllm.StoppingCriteriaList([]))
stopping_criteria.append(openllm.StopSequenceCriteria(stop, self.tokenizer))
result = self.tokenizer.decode(self.model.generate(encoded_inputs["input_ids"], max_new_tokens=max_new_tokens, stopping_criteria=stopping_criteria)[0].tolist()[src_len:])
result = self.tokenizer.decode(self.model.generate(encoded_inputs['input_ids'], max_new_tokens=max_new_tokens, stopping_criteria=stopping_criteria)[0].tolist()[src_len:])
# Inference API returns the stop sequence
for stop_seq in stop:
if result.endswith(stop_seq): result = result[:-len(stop_seq)]
return [{"generated_text": result}]
return [{'generated_text': result}]

@@ -3,6 +3,6 @@ import logging, typing as t, openllm
if t.TYPE_CHECKING: import vllm, transformers

logger = logging.getLogger(__name__)
class VLLMFalcon(openllm.LLM["vllm.LLMEngine", "transformers.PreTrainedTokenizerBase"]):
class VLLMFalcon(openllm.LLM['vllm.LLMEngine', 'transformers.PreTrainedTokenizerBase']):
__openllm_internal__ = True
tokenizer_id = "local"
tokenizer_id = 'local'

@@ -9,21 +9,21 @@ try:
except MissingDependencyError:
pass
else:
_import_structure["modeling_flan_t5"] = ["FlanT5"]
_import_structure['modeling_flan_t5'] = ['FlanT5']
if t.TYPE_CHECKING: from .modeling_flan_t5 import FlanT5 as FlanT5
try:
if not is_flax_available(): raise MissingDependencyError
except MissingDependencyError:
pass
else:
_import_structure["modeling_flax_flan_t5"] = ["FlaxFlanT5"]
_import_structure['modeling_flax_flan_t5'] = ['FlaxFlanT5']
if t.TYPE_CHECKING: from .modeling_flax_flan_t5 import FlaxFlanT5 as FlaxFlanT5
try:
if not is_tf_available(): raise MissingDependencyError
except MissingDependencyError:
pass
else:
_import_structure["modeling_tf_flan_t5"] = ["TFFlanT5"]
_import_structure['modeling_tf_flan_t5'] = ['TFFlanT5']
if t.TYPE_CHECKING: from .modeling_tf_flan_t5 import TFFlanT5 as TFFlanT5

sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure)
sys.modules[__name__] = LazyModule(__name__, globals()['__file__'], _import_structure)

@@ -1,14 +1,14 @@
from __future__ import annotations
import typing as t, openllm
if t.TYPE_CHECKING: import transformers
class FlanT5(openllm.LLM["transformers.T5ForConditionalGeneration", "transformers.T5TokenizerFast"]):
class FlanT5(openllm.LLM['transformers.T5ForConditionalGeneration', 'transformers.T5TokenizerFast']):
__openllm_internal__ = True

def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
import torch
with torch.inference_mode():
return self.tokenizer.batch_decode(
self.model.generate(**self.tokenizer(prompt, return_tensors="pt").to(self.device), do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config()),
self.model.generate(**self.tokenizer(prompt, return_tensors='pt').to(self.device), do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config()),
skip_special_tokens=True
)

@@ -17,7 +17,7 @@ class FlanT5(openllm.LLM["transformers.T5ForConditionalGeneration", "transformer
embeddings: list[list[float]] = []
num_tokens = 0
for prompt in prompts:
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.device)
with torch.inference_mode():
outputs = self.model(input_ids, decoder_input_ids=input_ids)
data = F.normalize(torch.mean(outputs.encoder_last_hidden_state[0], dim=0), p=2, dim=0)

@@ -3,7 +3,7 @@ import typing as t, openllm
from openllm_core._prompt import process_prompt
from openllm_core.config.configuration_flan_t5 import DEFAULT_PROMPT_TEMPLATE
if t.TYPE_CHECKING: import transformers
class FlaxFlanT5(openllm.LLM["transformers.FlaxT5ForConditionalGeneration", "transformers.T5TokenizerFast"]):
class FlaxFlanT5(openllm.LLM['transformers.FlaxT5ForConditionalGeneration', 'transformers.T5TokenizerFast']):
__openllm_internal__ = True

def sanitize_parameters(
@@ -20,20 +20,20 @@ class FlaxFlanT5(openllm.LLM["transformers.FlaxT5ForConditionalGeneration", "tra
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
if decoder_start_token_id is None: decoder_start_token_id = 0
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {
"max_new_tokens": max_new_tokens,
"temperature": temperature,
"top_k": top_k,
"top_p": top_p,
"repetition_penalty": repetition_penalty,
"decoder_start_token_id": decoder_start_token_id
'max_new_tokens': max_new_tokens,
'temperature': temperature,
'top_k': top_k,
'top_p': top_p,
'repetition_penalty': repetition_penalty,
'decoder_start_token_id': decoder_start_token_id
}, {}

def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
# NOTE: decoder_start_token_id is extracted from https://huggingface.co/google/flan-t5-small/tree/main as it is required for encoder-decoder generation.
decoder_start_token_id = attrs.pop("decoder_start_token_id", 0)
decoder_start_token_id = attrs.pop('decoder_start_token_id', 0)
return self.tokenizer.batch_decode(
self.model.generate(
self.tokenizer(prompt, return_tensors="np")["input_ids"],
self.tokenizer(prompt, return_tensors='np')['input_ids'],
do_sample=True,
generation_config=self.config.model_construct_env(**attrs).to_generation_config(),
decoder_start_token_id=decoder_start_token_id

@@ -1,11 +1,11 @@
from __future__ import annotations
import typing as t, openllm
if t.TYPE_CHECKING: import transformers
class TFFlanT5(openllm.LLM["transformers.TFT5ForConditionalGeneration", "transformers.T5TokenizerFast"]):
class TFFlanT5(openllm.LLM['transformers.TFT5ForConditionalGeneration', 'transformers.T5TokenizerFast']):
__openllm_internal__ = True

def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
return self.tokenizer.batch_decode(
self.model.generate(self.tokenizer(prompt, return_tensors="tf").input_ids, do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config()),
self.model.generate(self.tokenizer(prompt, return_tensors='tf').input_ids, do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config()),
skip_special_tokens=True
)

@@ -9,14 +9,14 @@ try:
except MissingDependencyError:
pass
else:
_import_structure["modeling_gpt_neox"] = ["GPTNeoX"]
_import_structure['modeling_gpt_neox'] = ['GPTNeoX']
if t.TYPE_CHECKING: from .modeling_gpt_neox import GPTNeoX as GPTNeoX
try:
if not is_vllm_available(): raise MissingDependencyError
except MissingDependencyError:
pass
else:
_import_structure["modeling_vllm_gpt_neox"] = ["VLLMGPTNeoX"]
_import_structure['modeling_vllm_gpt_neox'] = ['VLLMGPTNeoX']
if t.TYPE_CHECKING: from .modeling_vllm_gpt_neox import VLLMGPTNeoX as VLLMGPTNeoX

sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure)
sys.modules[__name__] = LazyModule(__name__, globals()['__file__'], _import_structure)

@@ -3,13 +3,13 @@ import logging, typing as t, openllm
if t.TYPE_CHECKING: import transformers

logger = logging.getLogger(__name__)
class GPTNeoX(openllm.LLM["transformers.GPTNeoXForCausalLM", "transformers.GPTNeoXTokenizerFast"]):
class GPTNeoX(openllm.LLM['transformers.GPTNeoXForCausalLM', 'transformers.GPTNeoXTokenizerFast']):
__openllm_internal__ = True

@property
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
import torch
return {"device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None}, {}
return {'device_map': 'auto' if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None}, {}

def load_model(self, *args: t.Any, **attrs: t.Any) -> transformers.GPTNeoXForCausalLM:
import transformers
@@ -22,7 +22,7 @@ class GPTNeoX(openllm.LLM["transformers.GPTNeoXForCausalLM", "transformers.GPTNe
with torch.inference_mode():
return self.tokenizer.batch_decode(
self.model.generate(
self.tokenizer(prompt, return_tensors="pt").to(self.device).input_ids,
self.tokenizer(prompt, return_tensors='pt').to(self.device).input_ids,
do_sample=True,
generation_config=self.config.model_construct_env(**attrs).to_generation_config(),
pad_token_id=self.tokenizer.eos_token_id,

@@ -1,6 +1,6 @@
from __future__ import annotations
import typing as t, openllm
if t.TYPE_CHECKING: import vllm, transformers
class VLLMGPTNeoX(openllm.LLM["vllm.LLMEngine", "transformers.GPTNeoXTokenizerFast"]):
class VLLMGPTNeoX(openllm.LLM['vllm.LLMEngine', 'transformers.GPTNeoXTokenizerFast']):
__openllm_internal__ = True
tokenizer_id = "local"
tokenizer_id = 'local'

@@ -9,14 +9,14 @@ try:
except MissingDependencyError:
pass
else:
_import_structure["modeling_vllm_llama"] = ["VLLMLlama"]
_import_structure['modeling_vllm_llama'] = ['VLLMLlama']
if t.TYPE_CHECKING: from .modeling_vllm_llama import VLLMLlama as VLLMLlama
try:
if not is_torch_available(): raise MissingDependencyError
except MissingDependencyError:
pass
else:
_import_structure["modeling_llama"] = ["Llama"]
_import_structure['modeling_llama'] = ['Llama']
if t.TYPE_CHECKING: from .modeling_llama import Llama as Llama

sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure)
sys.modules[__name__] = LazyModule(__name__, globals()['__file__'], _import_structure)

@@ -1,18 +1,18 @@
from __future__ import annotations
import typing as t, openllm
if t.TYPE_CHECKING: import transformers
class Llama(openllm.LLM["transformers.LlamaForCausalLM", "transformers.LlamaTokenizerFast"]):
class Llama(openllm.LLM['transformers.LlamaForCausalLM', 'transformers.LlamaTokenizerFast']):
__openllm_internal__ = True

@property
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
import torch
return {"torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32}, {}
return {'torch_dtype': torch.float16 if torch.cuda.is_available() else torch.float32}, {}

def embeddings(self, prompts: list[str]) -> openllm.LLMEmbeddings:
import torch, torch.nn.functional as F
encoding = self.tokenizer(prompts, padding=True, return_tensors="pt").to(self.device)
input_ids, attention_mask = encoding["input_ids"], encoding["attention_mask"]
encoding = self.tokenizer(prompts, padding=True, return_tensors='pt').to(self.device)
input_ids, attention_mask = encoding['input_ids'], encoding['attention_mask']
with torch.inference_mode():
data = self.model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True).hidden_states[-1]
mask = attention_mask.unsqueeze(-1).expand(data.size()).float()

@@ -1,5 +1,5 @@
from __future__ import annotations
import typing as t, openllm
if t.TYPE_CHECKING: import vllm, transformers
class VLLMLlama(openllm.LLM["vllm.LLMEngine", "transformers.LlamaTokenizerFast"]):
class VLLMLlama(openllm.LLM['vllm.LLMEngine', 'transformers.LlamaTokenizerFast']):
__openllm_internal__ = True

@@ -9,14 +9,14 @@ try:
except MissingDependencyError:
pass
else:
_import_structure["modeling_mpt"] = ["MPT"]
_import_structure['modeling_mpt'] = ['MPT']
if t.TYPE_CHECKING: from .modeling_mpt import MPT as MPT
try:
if not is_vllm_available(): raise MissingDependencyError
except MissingDependencyError:
pass
else:
_import_structure["modeling_vllm_mpt"] = ["VLLMMPT"]
_import_structure['modeling_vllm_mpt'] = ['VLLMMPT']
if t.TYPE_CHECKING: from .modeling_vllm_mpt import VLLMMPT as VLLMMPT

sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure)
sys.modules[__name__] = LazyModule(__name__, globals()['__file__'], _import_structure)

@@ -9,8 +9,8 @@ def get_mpt_config(
) -> transformers.PretrainedConfig:
import torch
config = transformers.AutoConfig.from_pretrained(model_id_or_path, trust_remote_code=trust_remote_code)
if hasattr(config, "init_device") and device_map is None and isinstance(device, (str, torch.device)): config.init_device = str(device)
if hasattr(config, "attn_config") and is_triton_available(): config.attn_config["attn_impl"] = "triton"
if hasattr(config, 'init_device') and device_map is None and isinstance(device, (str, torch.device)): config.init_device = str(device)
if hasattr(config, 'attn_config') and is_triton_available(): config.attn_config['attn_impl'] = 'triton'
else:
logger.debug(
"'triton' is not available, Flash Attention will use the default Torch implementation. For faster inference, make sure to install triton with 'pip install \"git+https://github.com/openai/triton.git#egg=triton&subdirectory=python\"'"
@@ -18,7 +18,7 @@ def get_mpt_config(
# setting max_seq_len
config.max_seq_len = max_sequence_length
return config
class MPT(openllm.LLM["transformers.PreTrainedModel", "transformers.GPTNeoXTokenizerFast"]):
class MPT(openllm.LLM['transformers.PreTrainedModel', 'transformers.GPTNeoXTokenizerFast']):
__openllm_internal__ = True

def llm_post_init(self) -> None:
@@ -28,28 +28,28 @@ class MPT(openllm.LLM["transformers.PreTrainedModel", "transformers.GPTNeoXToken
@property
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
import torch
return {"device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None, "torch_dtype": torch.bfloat16 if torch.cuda.is_available() else torch.float32}, {}
return {'device_map': 'auto' if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None, 'torch_dtype': torch.bfloat16 if torch.cuda.is_available() else torch.float32}, {}

def import_model(self, *args: t.Any, trust_remote_code: bool = True, **attrs: t.Any) -> bentoml.Model:
import torch, transformers
_, tokenizer_attrs = self.llm_parameters
torch_dtype = attrs.pop("torch_dtype", self.dtype)
device_map = attrs.pop("device_map", None)
attrs.pop("low_cpu_mem_usage", None)
torch_dtype = attrs.pop('torch_dtype', self.dtype)
device_map = attrs.pop('device_map', None)
attrs.pop('low_cpu_mem_usage', None)
config = get_mpt_config(self.model_id, self.config.max_sequence_length, self.device, device_map=device_map, trust_remote_code=trust_remote_code)
tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_id, **tokenizer_attrs)
if tokenizer.pad_token_id is None: tokenizer.pad_token = tokenizer.eos_token
model = transformers.AutoModelForCausalLM.from_pretrained(self.model_id, config=config, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code, device_map=device_map, **attrs)
try:
return bentoml.transformers.save_model(self.tag, model, custom_objects={"tokenizer": tokenizer}, labels=generate_labels(self))
return bentoml.transformers.save_model(self.tag, model, custom_objects={'tokenizer': tokenizer}, labels=generate_labels(self))
finally:
torch.cuda.empty_cache()

def load_model(self, *args: t.Any, **attrs: t.Any) -> transformers.PreTrainedModel:
import transformers
torch_dtype = attrs.pop("torch_dtype", self.dtype)
device_map = attrs.pop("device_map", None)
trust_remote_code = attrs.pop("trust_remote_code", True)
torch_dtype = attrs.pop('torch_dtype', self.dtype)
device_map = attrs.pop('device_map', None)
trust_remote_code = attrs.pop('trust_remote_code', True)
config = get_mpt_config(self._bentomodel.path, self.config.max_sequence_length, self.device, device_map=device_map, trust_remote_code=trust_remote_code,)
model = transformers.AutoModelForCausalLM.from_pretrained(
self._bentomodel.path, config=config, trust_remote_code=trust_remote_code, torch_dtype=torch_dtype, device_map=device_map, **attrs
@@ -60,16 +60,16 @@ class MPT(openllm.LLM["transformers.PreTrainedModel", "transformers.GPTNeoXToken
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
import torch
llm_config = self.config.model_construct_env(**attrs)
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
inputs = self.tokenizer(prompt, return_tensors='pt').to(self.device)
attrs = {
"do_sample": False if llm_config["temperature"] == 0 else True,
"eos_token_id": self.tokenizer.eos_token_id,
"pad_token_id": self.tokenizer.pad_token_id,
"generation_config": llm_config.to_generation_config()
'do_sample': False if llm_config['temperature'] == 0 else True,
'eos_token_id': self.tokenizer.eos_token_id,
'pad_token_id': self.tokenizer.pad_token_id,
'generation_config': llm_config.to_generation_config()
}
with torch.inference_mode():
if torch.cuda.is_available():
with torch.autocast("cuda", torch.float16): # type: ignore[attr-defined]
with torch.autocast('cuda', torch.float16): # type: ignore[attr-defined]
generated_tensors = self.model.generate(**inputs, **attrs)
else:
generated_tensors = self.model.generate(**inputs, **attrs)

@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
import typing as t, openllm
|
||||
if t.TYPE_CHECKING: import transformers, vllm
|
||||
class VLLMMPT(openllm.LLM["vllm.LLMEngine", "transformers.GPTNeoXTokenizerFast"]):
|
||||
class VLLMMPT(openllm.LLM['vllm.LLMEngine', 'transformers.GPTNeoXTokenizerFast']):
|
||||
__openllm_internal__ = True
|
||||
tokenizer_id = "local"
|
||||
tokenizer_id = 'local'
|
||||
|
||||
@@ -9,28 +9,28 @@ try:
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_opt"] = ["OPT"]
|
||||
_import_structure['modeling_opt'] = ['OPT']
|
||||
if t.TYPE_CHECKING: from .modeling_opt import OPT as OPT
|
||||
try:
|
||||
if not is_flax_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_flax_opt"] = ["FlaxOPT"]
|
||||
_import_structure['modeling_flax_opt'] = ['FlaxOPT']
|
||||
if t.TYPE_CHECKING: from .modeling_flax_opt import FlaxOPT as FlaxOPT
|
||||
try:
|
||||
if not is_vllm_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_vllm_opt"] = ["VLLMOPT"]
|
||||
_import_structure['modeling_vllm_opt'] = ['VLLMOPT']
|
||||
if t.TYPE_CHECKING: from .modeling_vllm_opt import VLLMOPT as VLLMOPT
|
||||
try:
|
||||
if not is_tf_available(): raise MissingDependencyError
|
||||
except MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_tf_opt"] = ["TFOPT"]
|
||||
_import_structure['modeling_tf_opt'] = ['TFOPT']
|
||||
if t.TYPE_CHECKING: from .modeling_tf_opt import TFOPT as TFOPT
|
||||
|
||||
sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure)
|
||||
sys.modules[__name__] = LazyModule(__name__, globals()['__file__'], _import_structure)
|
||||
|
||||
@@ -4,17 +4,17 @@ from openllm._prompt import process_prompt
|
||||
from openllm.utils import generate_labels
|
||||
from openllm_core.config.configuration_opt import DEFAULT_PROMPT_TEMPLATE
|
||||
if t.TYPE_CHECKING: import transformers
|
||||
else: transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers")
|
||||
else: transformers = openllm.utils.LazyLoader('transformers', globals(), 'transformers')
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
class FlaxOPT(openllm.LLM["transformers.TFOPTForCausalLM", "transformers.GPT2Tokenizer"]):
|
||||
class FlaxOPT(openllm.LLM['transformers.TFOPTForCausalLM', 'transformers.GPT2Tokenizer']):
|
||||
__openllm_internal__ = True
|
||||
|
||||
def import_model(self, *args: t.Any, trust_remote_code: bool = False, **attrs: t.Any) -> bentoml.Model:
|
||||
config, tokenizer = transformers.AutoConfig.from_pretrained(self.model_id), transformers.AutoTokenizer.from_pretrained(self.model_id, **self.llm_parameters[-1])
|
||||
tokenizer.pad_token_id = config.pad_token_id
|
||||
return bentoml.transformers.save_model(
|
||||
self.tag, transformers.FlaxAutoModelForCausalLM.from_pretrained(self.model_id, **attrs), custom_objects={"tokenizer": tokenizer}, labels=generate_labels(self)
|
||||
self.tag, transformers.FlaxAutoModelForCausalLM.from_pretrained(self.model_id, **attrs), custom_objects={'tokenizer': tokenizer}, labels=generate_labels(self)
|
||||
)
|
||||
|
||||
def sanitize_parameters(
|
||||
@@ -29,11 +29,11 @@ class FlaxOPT(openllm.LLM["transformers.TFOPTForCausalLM", "transformers.GPT2Tok
|
||||
**attrs: t.Any
|
||||
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {
|
||||
"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "num_return_sequences": num_return_sequences, "repetition_penalty": repetition_penalty
|
||||
'max_new_tokens': max_new_tokens, 'temperature': temperature, 'top_k': top_k, 'num_return_sequences': num_return_sequences, 'repetition_penalty': repetition_penalty
|
||||
}, {}
|
||||
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
|
||||
return self.tokenizer.batch_decode(
|
||||
self.model.generate(**self.tokenizer(prompt, return_tensors="np"), do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config()).sequences,
|
||||
self.model.generate(**self.tokenizer(prompt, return_tensors='np'), do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config()).sequences,
|
||||
skip_special_tokens=True
)

@@ -3,18 +3,18 @@ import logging, typing as t, openllm
if t.TYPE_CHECKING: import transformers

logger = logging.getLogger(__name__)
class OPT(openllm.LLM["transformers.OPTForCausalLM", "transformers.GPT2Tokenizer"]):
class OPT(openllm.LLM['transformers.OPTForCausalLM', 'transformers.GPT2Tokenizer']):
__openllm_internal__ = True

@property
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
import torch
return {"torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32}, {}
return {'torch_dtype': torch.float16 if torch.cuda.is_available() else torch.float32}, {}

def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
import torch
with torch.inference_mode():
return self.tokenizer.batch_decode(
self.model.generate(**self.tokenizer(prompt, return_tensors="pt").to(self.device), do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config()),
self.model.generate(**self.tokenizer(prompt, return_tensors='pt').to(self.device), do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config()),
skip_special_tokens=True
)

@@ -2,7 +2,7 @@ from __future__ import annotations
import typing as t, bentoml, openllm
from openllm_core.utils import generate_labels
if t.TYPE_CHECKING: import transformers
class TFOPT(openllm.LLM["transformers.TFOPTForCausalLM", "transformers.GPT2Tokenizer"]):
class TFOPT(openllm.LLM['transformers.TFOPTForCausalLM', 'transformers.GPT2Tokenizer']):
__openllm_internal__ = True

def import_model(self, *args: t.Any, trust_remote_code: bool = False, **attrs: t.Any) -> bentoml.Model:
@@ -12,12 +12,12 @@ class TFOPT(openllm.LLM["transformers.TFOPTForCausalLM", "transformers.GPT2Token
return bentoml.transformers.save_model(
self.tag,
transformers.TFOPTForCausalLM.from_pretrained(self.model_id, trust_remote_code=trust_remote_code, **attrs),
custom_objects={"tokenizer": tokenizer},
custom_objects={'tokenizer': tokenizer},
labels=generate_labels(self)
)

def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
return self.tokenizer.batch_decode(
self.model.generate(**self.tokenizer(prompt, return_tensors="tf"), do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config()),
self.model.generate(**self.tokenizer(prompt, return_tensors='tf'), do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config()),
skip_special_tokens=True
)

@@ -3,9 +3,9 @@ import typing as t, openllm
from openllm_core._prompt import process_prompt
from openllm_core.config.configuration_opt import DEFAULT_PROMPT_TEMPLATE
if t.TYPE_CHECKING: import vllm, transformers
class VLLMOPT(openllm.LLM["vllm.LLMEngine", "transformers.GPT2Tokenizer"]):
class VLLMOPT(openllm.LLM['vllm.LLMEngine', 'transformers.GPT2Tokenizer']):
__openllm_internal__ = True
tokenizer_id = "local"
tokenizer_id = 'local'

def sanitize_parameters(
self,
@@ -18,5 +18,5 @@ class VLLMOPT(openllm.LLM["vllm.LLMEngine", "transformers.GPT2Tokenizer"]):
**attrs: t.Any
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {
"max_new_tokens": max_new_tokens, "temperature": temperature, "top_k": top_k, "num_return_sequences": num_return_sequences
'max_new_tokens': max_new_tokens, 'temperature': temperature, 'top_k': top_k, 'num_return_sequences': num_return_sequences
}, {}

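Note: the `sanitize_parameters` hunk above returns a three-part tuple — the formatted prompt, the keyword arguments forwarded to generation, and the keyword arguments for post-processing. A minimal illustration of that convention follows; the concrete values are made up for the example and are not part of the diff.

# Hypothetical values illustrating the (prompt, generate_kwargs, postprocess_kwargs)
# tuple produced by sanitize_parameters above.
prompt, generate_kwargs, postprocess_kwargs = (
  'What is the meaning of life?',
  {'max_new_tokens': 256, 'temperature': 0.75, 'top_k': 15, 'num_return_sequences': 1},
  {},  # no extra post-processing arguments needed here
)
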
@@ -9,14 +9,14 @@ try:
except MissingDependencyError:
pass
else:
_import_structure["modeling_stablelm"] = ["StableLM"]
_import_structure['modeling_stablelm'] = ['StableLM']
if t.TYPE_CHECKING: from .modeling_stablelm import StableLM as StableLM
try:
if not is_vllm_available(): raise MissingDependencyError
except MissingDependencyError:
pass
else:
_import_structure["modeling_vllm_stablelm"] = ["VLLMStableLM"]
_import_structure['modeling_vllm_stablelm'] = ['VLLMStableLM']
if t.TYPE_CHECKING: from .modeling_vllm_stablelm import VLLMStableLM as VLLMStableLM

sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure)
sys.modules[__name__] = LazyModule(__name__, globals()['__file__'], _import_structure)

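Note: the `_import_structure` hunks above (and the StarCoder ones below) follow an availability-gated lazy-import pattern. A self-contained sketch of the same idea using only the standard library — the module and class names are illustrative, and this is not OpenLLM's actual `LazyModule`:

from __future__ import annotations
import importlib, importlib.util, typing as t

_import_structure: dict[str, list[str]] = {}
# Register a submodule only when its heavy dependency can be imported.
if importlib.util.find_spec('torch') is not None:
  _import_structure['modeling_stablelm'] = ['StableLM']
if importlib.util.find_spec('vllm') is not None:
  _import_structure['modeling_vllm_stablelm'] = ['VLLMStableLM']

def __getattr__(name: str) -> t.Any:
  # Import the owning submodule lazily, on first attribute access (PEP 562).
  for submodule, exported in _import_structure.items():
    if name in exported:
      return getattr(importlib.import_module(f'.{submodule}', __name__), name)
  raise AttributeError(f'{__name__} has no attribute {name}')
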
@@ -1,7 +1,7 @@
from __future__ import annotations
import typing as t, openllm
if t.TYPE_CHECKING: import transformers
class StableLM(openllm.LLM["transformers.GPTNeoXForCausalLM", "transformers.GPTNeoXTokenizerFast"]):
class StableLM(openllm.LLM['transformers.GPTNeoXForCausalLM', 'transformers.GPTNeoXTokenizerFast']):
__openllm_internal__ = True

def llm_post_init(self) -> None:
@@ -11,7 +11,7 @@ class StableLM(openllm.LLM["transformers.GPTNeoXForCausalLM", "transformers.GPTN
@property
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
import torch
return {"torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32}, {}
return {'torch_dtype': torch.float16 if torch.cuda.is_available() else torch.float32}, {}

def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
import torch
@@ -19,7 +19,7 @@ class StableLM(openllm.LLM["transformers.GPTNeoXForCausalLM", "transformers.GPTN
return [
self.tokenizer.decode(
self.model.generate(
**self.tokenizer(prompt, return_tensors="pt").to(self.device),
**self.tokenizer(prompt, return_tensors='pt').to(self.device),
do_sample=True,
generation_config=self.config.model_construct_env(**attrs).to_generation_config(),
pad_token_id=self.tokenizer.eos_token_id,

@@ -1,6 +1,6 @@
from __future__ import annotations
import logging, typing as t, openllm
if t.TYPE_CHECKING: import vllm, transformers
class VLLMStableLM(openllm.LLM["vllm.LLMEngine", "transformers.GPTNeoXTokenizerFast"]):
class VLLMStableLM(openllm.LLM['vllm.LLMEngine', 'transformers.GPTNeoXTokenizerFast']):
__openllm_internal__ = True
tokenizer_id = "local"
tokenizer_id = 'local'

@@ -9,14 +9,14 @@ try:
except MissingDependencyError:
pass
else:
_import_structure["modeling_starcoder"] = ["StarCoder"]
_import_structure['modeling_starcoder'] = ['StarCoder']
if t.TYPE_CHECKING: from .modeling_starcoder import StarCoder as StarCoder
try:
if not is_vllm_available(): raise MissingDependencyError
except MissingDependencyError:
pass
else:
_import_structure["modeling_vllm_starcoder"] = ["VLLMStarCoder"]
_import_structure['modeling_vllm_starcoder'] = ['VLLMStarCoder']
if t.TYPE_CHECKING: from .modeling_vllm_starcoder import VLLMStarCoder as VLLMStarCoder

sys.modules[__name__] = LazyModule(__name__, globals()["__file__"], _import_structure)
sys.modules[__name__] = LazyModule(__name__, globals()['__file__'], _import_structure)

@@ -3,22 +3,22 @@ import logging, typing as t, bentoml, openllm
from openllm.utils import generate_labels
from openllm_core.config.configuration_starcoder import EOD, FIM_MIDDLE, FIM_PAD, FIM_PREFIX, FIM_SUFFIX
if t.TYPE_CHECKING: import transformers
class StarCoder(openllm.LLM["transformers.GPTBigCodeForCausalLM", "transformers.GPT2TokenizerFast"]):
class StarCoder(openllm.LLM['transformers.GPTBigCodeForCausalLM', 'transformers.GPT2TokenizerFast']):
__openllm_internal__ = True

@property
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
import torch
return {"device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None, "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32}, {}
return {'device_map': 'auto' if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None, 'torch_dtype': torch.float16 if torch.cuda.is_available() else torch.float32}, {}

def import_model(self, *args: t.Any, trust_remote_code: bool = False, **attrs: t.Any) -> bentoml.Model:
import torch, transformers
torch_dtype, device_map = attrs.pop("torch_dtype", torch.float16), attrs.pop("device_map", "auto")
torch_dtype, device_map = attrs.pop('torch_dtype', torch.float16), attrs.pop('device_map', 'auto')
tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_id, **self.llm_parameters[-1])
tokenizer.add_special_tokens({"additional_special_tokens": [EOD, FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD], "pad_token": EOD})
tokenizer.add_special_tokens({'additional_special_tokens': [EOD, FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD], 'pad_token': EOD})
model = transformers.AutoModelForCausalLM.from_pretrained(self.model_id, torch_dtype=torch_dtype, device_map=device_map, **attrs)
try:
return bentoml.transformers.save_model(self.tag, model, custom_objects={"tokenizer": tokenizer}, labels=generate_labels(self))
return bentoml.transformers.save_model(self.tag, model, custom_objects={'tokenizer': tokenizer}, labels=generate_labels(self))
finally:
torch.cuda.empty_cache()

@@ -28,7 +28,7 @@ class StarCoder(openllm.LLM["transformers.GPTBigCodeForCausalLM", "transformers.
# eos_token_id=self.tokenizer.convert_tokens_to_ids("<|end|>"), # NOTE: this is for finetuning starcoder
# NOTE: support fine-tuning starcoder
result_tensor = self.model.generate(
self.tokenizer.encode(prompt, return_tensors="pt").to(self.device),
self.tokenizer.encode(prompt, return_tensors='pt').to(self.device),
do_sample=True,
pad_token_id=self.tokenizer.eos_token_id,
generation_config=self.config.model_construct_env(**attrs).to_generation_config()
@@ -37,12 +37,12 @@ class StarCoder(openllm.LLM["transformers.GPTBigCodeForCausalLM", "transformers.
# return (skip_special_tokens=False, clean_up_tokenization_spaces=False))
return self.tokenizer.batch_decode(result_tensor[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)

def generate_one(self, prompt: str, stop: list[str], **preprocess_generate_kwds: t.Any) -> list[dict[t.Literal["generated_text"], str]]:
max_new_tokens, encoded_inputs = preprocess_generate_kwds.pop("max_new_tokens", 200), self.tokenizer(prompt, return_tensors="pt").to(self.device)
src_len, stopping_criteria = encoded_inputs["input_ids"].shape[1], preprocess_generate_kwds.pop("stopping_criteria", openllm.StoppingCriteriaList([]))
def generate_one(self, prompt: str, stop: list[str], **preprocess_generate_kwds: t.Any) -> list[dict[t.Literal['generated_text'], str]]:
max_new_tokens, encoded_inputs = preprocess_generate_kwds.pop('max_new_tokens', 200), self.tokenizer(prompt, return_tensors='pt').to(self.device)
src_len, stopping_criteria = encoded_inputs['input_ids'].shape[1], preprocess_generate_kwds.pop('stopping_criteria', openllm.StoppingCriteriaList([]))
stopping_criteria.append(openllm.StopSequenceCriteria(stop, self.tokenizer))
result = self.tokenizer.decode(self.model.generate(encoded_inputs["input_ids"], max_new_tokens=max_new_tokens, stopping_criteria=stopping_criteria)[0].tolist()[src_len:])
result = self.tokenizer.decode(self.model.generate(encoded_inputs['input_ids'], max_new_tokens=max_new_tokens, stopping_criteria=stopping_criteria)[0].tolist()[src_len:])
# Inference API returns the stop sequence
for stop_seq in stop:
if result.endswith(stop_seq): result = result[:-len(stop_seq)]
return [{"generated_text": result}]
return [{'generated_text': result}]

@@ -1,6 +1,6 @@
from __future__ import annotations
import logging, typing as t, openllm
if t.TYPE_CHECKING: import vllm, transformers
class VLLMStarCoder(openllm.LLM["vllm.LLMEngine", "transformers.GPT2TokenizerFast"]):
class VLLMStarCoder(openllm.LLM['vllm.LLMEngine', 'transformers.GPT2TokenizerFast']):
__openllm_internal__ = True
tokenizer_id = "local"
tokenizer_id = 'local'

@@ -31,21 +31,21 @@ from openllm_core._typing_compat import M, T, ParamSpec
if t.TYPE_CHECKING:
import bentoml
from . import constants as constants, ggml as ggml, transformers as transformers
P = ParamSpec("P")
P = ParamSpec('P')
def load_tokenizer(llm: openllm.LLM[t.Any, T], **tokenizer_attrs: t.Any) -> T:
"""Load the tokenizer from BentoML store.
'''Load the tokenizer from BentoML store.

By default, it will try to find the bentomodel whether it is in store..
If model is not found, it will raises a ``bentoml.exceptions.NotFound``.
"""
'''
from .transformers._helpers import infer_tokenizers_from_llm, process_config

config, *_ = process_config(llm._bentomodel.path, llm.__llm_trust_remote_code__)
bentomodel_fs = fs.open_fs(llm._bentomodel.path)
if bentomodel_fs.isfile(CUSTOM_OBJECTS_FILENAME):
with bentomodel_fs.open(CUSTOM_OBJECTS_FILENAME, "rb") as cofile:
with bentomodel_fs.open(CUSTOM_OBJECTS_FILENAME, 'rb') as cofile:
try:
tokenizer = cloudpickle.load(t.cast("t.IO[bytes]", cofile))["tokenizer"]
tokenizer = cloudpickle.load(t.cast('t.IO[bytes]', cofile))['tokenizer']
except KeyError:
raise openllm.exceptions.OpenLLMException(
"Bento model does not have tokenizer. Make sure to save"
@@ -53,18 +53,18 @@ def load_tokenizer(llm: openllm.LLM[t.Any, T], **tokenizer_attrs: t.Any) -> T:
" For example: \"bentoml.transformers.save_model(..., custom_objects={'tokenizer': tokenizer})\""
) from None
else:
tokenizer = infer_tokenizers_from_llm(llm).from_pretrained(bentomodel_fs.getsyspath("/"), trust_remote_code=llm.__llm_trust_remote_code__, **tokenizer_attrs)
tokenizer = infer_tokenizers_from_llm(llm).from_pretrained(bentomodel_fs.getsyspath('/'), trust_remote_code=llm.__llm_trust_remote_code__, **tokenizer_attrs)

if tokenizer.pad_token_id is None:
if config.pad_token_id is not None: tokenizer.pad_token_id = config.pad_token_id
elif config.eos_token_id is not None: tokenizer.pad_token_id = config.eos_token_id
elif tokenizer.eos_token_id is not None: tokenizer.pad_token_id = tokenizer.eos_token_id
else: tokenizer.add_special_tokens({"pad_token": "[PAD]"})
else: tokenizer.add_special_tokens({'pad_token': '[PAD]'})
return tokenizer
class _Caller(t.Protocol[P]):
def __call__(self, llm: openllm.LLM[M, T], *args: P.args, **kwargs: P.kwargs) -> t.Any:
...
_extras = ["get", "import_model", "save_pretrained", "load_model"]
_extras = ['get', 'import_model', 'save_pretrained', 'load_model']
def _make_dispatch_function(fn: str) -> _Caller[P]:
def caller(llm: openllm.LLM[M, T], *args: P.args, **kwargs: P.kwargs) -> t.Any:
"""Generic function dispatch to correct serialisation submodules based on LLM runtime.
@@ -73,7 +73,7 @@ def _make_dispatch_function(fn: str) -> _Caller[P]:

> [!NOTE] See 'openllm.serialisation.ggml' if 'llm.runtime="ggml"'
"""
return getattr(importlib.import_module(f".{llm.runtime}", __name__), fn)(llm, *args, **kwargs)
return getattr(importlib.import_module(f'.{llm.runtime}', __name__), fn)(llm, *args, **kwargs)

return caller
if t.TYPE_CHECKING:
@@ -89,12 +89,12 @@ if t.TYPE_CHECKING:

def load_model(llm: openllm.LLM[M, T], *args: t.Any, **kwargs: t.Any) -> M:
...
_import_structure: dict[str, list[str]] = {"ggml": [], "transformers": [], "constants": []}
__all__ = ["ggml", "transformers", "constants", "load_tokenizer", *_extras]
_import_structure: dict[str, list[str]] = {'ggml': [], 'transformers': [], 'constants': []}
__all__ = ['ggml', 'transformers', 'constants', 'load_tokenizer', *_extras]
def __dir__() -> list[str]:
return sorted(__all__)
def __getattr__(name: str) -> t.Any:
if name == "load_tokenizer": return load_tokenizer
elif name in _import_structure: return importlib.import_module(f".{name}", __name__)
if name == 'load_tokenizer': return load_tokenizer
elif name in _import_structure: return importlib.import_module(f'.{name}', __name__)
elif name in _extras: return _make_dispatch_function(name)
else: raise AttributeError(f"{__name__} has no attribute {name}")
else: raise AttributeError(f'{__name__} has no attribute {name}')

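Note: the `__getattr__` hunk above builds the public serialisation API lazily — attribute access on `openllm.serialisation` is routed to a runtime-specific submodule. A condensed, illustrative sketch of that dispatch, simplified to take a plain `runtime` string instead of an `LLM` instance:

from __future__ import annotations
import importlib, typing as t

_extras = ['get', 'import_model', 'save_pretrained', 'load_model']

def _make_dispatch_function(fn: str) -> t.Callable[..., t.Any]:
  def caller(runtime: str, *args: t.Any, **kwargs: t.Any) -> t.Any:
    # Route e.g. get(...) to '<package>.transformers.get' or '<package>.ggml.get',
    # depending on the runtime value.
    return getattr(importlib.import_module(f'.{runtime}', __package__), fn)(*args, **kwargs)
  return caller

def __getattr__(name: str) -> t.Any:
  if name in _extras: return _make_dispatch_function(name)
  raise AttributeError(f'{__name__} has no attribute {name}')
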
@@ -1,8 +1,8 @@
from __future__ import annotations
FRAMEWORK_TO_AUTOCLASS_MAPPING = {
"pt": ("AutoModelForCausalLM", "AutoModelForSeq2SeqLM"),
"tf": ("TFAutoModelForCausalLM", "TFAutoModelForSeq2SeqLM"),
"flax": ("FlaxAutoModelForCausalLM", "FlaxAutoModelForSeq2SeqLM"),
"vllm": ("AutoModelForCausalLM", "AutoModelForSeq2SeqLM")
'pt': ('AutoModelForCausalLM', 'AutoModelForSeq2SeqLM'),
'tf': ('TFAutoModelForCausalLM', 'TFAutoModelForSeq2SeqLM'),
'flax': ('FlaxAutoModelForCausalLM', 'FlaxAutoModelForSeq2SeqLM'),
'vllm': ('AutoModelForCausalLM', 'AutoModelForSeq2SeqLM')
}
HUB_ATTRS = ["cache_dir", "code_revision", "force_download", "local_files_only", "proxies", "resume_download", "revision", "subfolder", "use_auth_token"]
HUB_ATTRS = ['cache_dir', 'code_revision', 'force_download', 'local_files_only', 'proxies', 'resume_download', 'revision', 'subfolder', 'use_auth_token']

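Note: `FRAMEWORK_TO_AUTOCLASS_MAPPING` above is indexed later in this diff by `infer_autoclass_from_llm`: index 0 selects the causal-LM autoclass and index 1 the seq2seq one. A small usage sketch, trimmed to two frameworks:

import transformers

# Trimmed copy of the mapping above; index 0 = causal LM, index 1 = seq2seq LM.
FRAMEWORK_TO_AUTOCLASS_MAPPING = {
  'pt': ('AutoModelForCausalLM', 'AutoModelForSeq2SeqLM'),
  'tf': ('TFAutoModelForCausalLM', 'TFAutoModelForSeq2SeqLM'),
}
autoclass = getattr(transformers, FRAMEWORK_TO_AUTOCLASS_MAPPING['pt'][0])
print(autoclass.__name__)  # AutoModelForCausalLM
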
@@ -1,29 +1,29 @@
"""Serialisation related implementation for GGML-based implementation.
'''Serialisation related implementation for GGML-based implementation.

This requires ctransformers to be installed.
"""
'''
from __future__ import annotations
import typing as t
import bentoml, openllm

if t.TYPE_CHECKING: from openllm_core._typing_compat import M

_conversion_strategy = {"pt": "ggml"}
_conversion_strategy = {'pt': 'ggml'}
def import_model(llm: openllm.LLM[t.Any, t.Any], *decls: t.Any, trust_remote_code: bool = True, **attrs: t.Any,) -> bentoml.Model:
raise NotImplementedError("Currently work in progress.")
raise NotImplementedError('Currently work in progress.')
def get(llm: openllm.LLM[t.Any, t.Any], auto_import: bool = False) -> bentoml.Model:
"""Return an instance of ``bentoml.Model`` from given LLM instance.
'''Return an instance of ``bentoml.Model`` from given LLM instance.

By default, it will try to check the model in the local store.
If model is not found, and ``auto_import`` is set to True, it will try to import the model from HuggingFace Hub.

Otherwise, it will raises a ``bentoml.exceptions.NotFound``.
"""
'''
try:
model = bentoml.models.get(llm.tag)
if model.info.module not in ("openllm.serialisation.ggml", __name__):
if model.info.module not in ('openllm.serialisation.ggml', __name__):
raise bentoml.exceptions.NotFound(f"Model {model.tag} was saved with module {model.info.module}, not loading with 'openllm.serialisation.transformers'.")
if "runtime" in model.info.labels and model.info.labels["runtime"] != llm.runtime:
if 'runtime' in model.info.labels and model.info.labels['runtime'] != llm.runtime:
raise openllm.exceptions.OpenLLMException(f"Model {model.tag} was saved with runtime {model.info.labels['runtime']}, not loading with {llm.runtime}.")
return model
except bentoml.exceptions.NotFound:
@@ -31,6 +31,6 @@ def get(llm: openllm.LLM[t.Any, t.Any], auto_import: bool = False) -> bentoml.Mo
return import_model(llm, trust_remote_code=llm.__llm_trust_remote_code__)
raise
def load_model(llm: openllm.LLM[M, t.Any], *decls: t.Any, **attrs: t.Any) -> M:
raise NotImplementedError("Currently work in progress.")
raise NotImplementedError('Currently work in progress.')
def save_pretrained(llm: openllm.LLM[t.Any, t.Any], save_directory: str, **attrs: t.Any) -> None:
raise NotImplementedError("Currently work in progress.")
raise NotImplementedError('Currently work in progress.')

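Note: a hedged usage sketch of the `get`/`auto_import` contract defined in the GGML module above; the wrapper name is made up, and `llm` is assumed to be an already-constructed `openllm.LLM` instance:

from __future__ import annotations
import typing as t

import bentoml, openllm
from openllm.serialisation import ggml

def get_or_import(llm: openllm.LLM[t.Any, t.Any]) -> bentoml.Model:
  # With auto_import=True a cache miss falls through to import_model(...)
  # instead of surfacing bentoml.exceptions.NotFound (see the except branch above).
  return ggml.get(llm, auto_import=True)
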
@@ -1,4 +1,4 @@
"""Serialisation related implementation for Transformers-based implementation."""
'''Serialisation related implementation for Transformers-based implementation.'''
from __future__ import annotations
import importlib, logging, typing as t
import bentoml, openllm
@@ -18,14 +18,14 @@ if t.TYPE_CHECKING:
from bentoml._internal.models import ModelStore
from openllm_core._typing_compat import DictStrAny, M, T
else:
vllm = openllm.utils.LazyLoader("vllm", globals(), "vllm")
autogptq = openllm.utils.LazyLoader("autogptq", globals(), "auto_gptq")
transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers")
torch = openllm.utils.LazyLoader("torch", globals(), "torch")
vllm = openllm.utils.LazyLoader('vllm', globals(), 'vllm')
autogptq = openllm.utils.LazyLoader('autogptq', globals(), 'auto_gptq')
transformers = openllm.utils.LazyLoader('transformers', globals(), 'transformers')
torch = openllm.utils.LazyLoader('torch', globals(), 'torch')

logger = logging.getLogger(__name__)

__all__ = ["import_model", "get", "load_model", "save_pretrained"]
__all__ = ['import_model', 'get', 'load_model', 'save_pretrained']
@inject
def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool, _model_store: ModelStore = Provide[BentoMLContainer.model_store], **attrs: t.Any) -> bentoml.Model:
"""Auto detect model type from given model_id and import it to bentoml's model store.
@@ -48,23 +48,23 @@ def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool,
config, hub_attrs, attrs = process_config(llm.model_id, trust_remote_code, **attrs)
_, tokenizer_attrs = llm.llm_parameters
quantize_method = llm._quantize_method
safe_serialisation = openllm.utils.first_not_none(attrs.get("safe_serialization"), default=llm._serialisation_format == "safetensors")
safe_serialisation = openllm.utils.first_not_none(attrs.get('safe_serialization'), default=llm._serialisation_format == 'safetensors')
# Disable safe serialization with vLLM
if llm.__llm_implementation__ == "vllm": safe_serialisation = False
metadata: DictStrAny = {"safe_serialisation": safe_serialisation, "_quantize": quantize_method is not None and quantize_method}
if llm.__llm_implementation__ == 'vllm': safe_serialisation = False
metadata: DictStrAny = {'safe_serialisation': safe_serialisation, '_quantize': quantize_method is not None and quantize_method}
signatures: DictStrAny = {}

if quantize_method == "gptq":
if quantize_method == 'gptq':
if not openllm.utils.is_autogptq_available():
raise openllm.exceptions.OpenLLMException("GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'")
if llm.config["model_type"] != "causal_lm": raise openllm.exceptions.OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
signatures["generate"] = {"batchable": False}
if llm.config['model_type'] != 'causal_lm': raise openllm.exceptions.OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
signatures['generate'] = {'batchable': False}
else:
# this model might be called with --quantize int4, therefore we need to pop this out
# since saving int4 is not yet supported
if "quantization_config" in attrs and getattr(attrs["quantization_config"], "load_in_4bit", False): attrs.pop("quantization_config")
if llm.__llm_implementation__ != "flax": attrs["use_safetensors"] = safe_serialisation
metadata["_framework"] = "pt" if llm.__llm_implementation__ == "vllm" else llm.__llm_implementation__
if 'quantization_config' in attrs and getattr(attrs['quantization_config'], 'load_in_4bit', False): attrs.pop('quantization_config')
if llm.__llm_implementation__ != 'flax': attrs['use_safetensors'] = safe_serialisation
metadata['_framework'] = 'pt' if llm.__llm_implementation__ == 'vllm' else llm.__llm_implementation__

tokenizer = infer_tokenizers_from_llm(llm).from_pretrained(llm.model_id, trust_remote_code=trust_remote_code, **hub_attrs, **tokenizer_attrs)
if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token
@@ -73,10 +73,10 @@ def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool,
imported_modules: list[types.ModuleType] = []
bentomodel = bentoml.Model.create(
llm.tag,
module="openllm.serialisation.transformers",
api_version="v1",
module='openllm.serialisation.transformers',
api_version='v1',
options=ModelOptions(),
context=openllm.utils.generate_context(framework_name="openllm"),
context=openllm.utils.generate_context(framework_name='openllm'),
labels=openllm.utils.generate_labels(llm),
signatures=signatures if signatures else make_model_signatures(llm)
)
@@ -84,35 +84,35 @@ def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool,
try:
bentomodel.enter_cloudpickle_context(external_modules, imported_modules)
tokenizer.save_pretrained(bentomodel.path)
if quantize_method == "gptq":
if quantize_method == 'gptq':
if not openllm.utils.is_autogptq_available():
raise openllm.exceptions.OpenLLMException("GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'")
if llm.config["model_type"] != "causal_lm": raise openllm.exceptions.OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
logger.debug("Saving model with GPTQ quantisation will require loading model into memory.")
if llm.config['model_type'] != 'causal_lm': raise openllm.exceptions.OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
logger.debug('Saving model with GPTQ quantisation will require loading model into memory.')
model = autogptq.AutoGPTQForCausalLM.from_quantized(
llm.model_id,
*decls,
quantize_config=t.cast("autogptq.BaseQuantizeConfig", llm.quantization_config),
quantize_config=t.cast('autogptq.BaseQuantizeConfig', llm.quantization_config),
trust_remote_code=trust_remote_code,
use_safetensors=safe_serialisation,
**hub_attrs,
**attrs,
)
update_model(bentomodel, metadata={"_pretrained_class": model.__class__.__name__, "_framework": model.model.framework})
update_model(bentomodel, metadata={'_pretrained_class': model.__class__.__name__, '_framework': model.model.framework})
model.save_quantized(bentomodel.path, use_safetensors=safe_serialisation)
else:
architectures = getattr(config, "architectures", [])
architectures = getattr(config, 'architectures', [])
if not architectures:
raise RuntimeError("Failed to determine the architecture for this model. Make sure the `config.json` is valid and can be loaded with `transformers.AutoConfig`")
raise RuntimeError('Failed to determine the architecture for this model. Make sure the `config.json` is valid and can be loaded with `transformers.AutoConfig`')
architecture = architectures[0]
update_model(bentomodel, metadata={"_pretrained_class": architecture})
update_model(bentomodel, metadata={'_pretrained_class': architecture})
if llm._local:
# possible local path
logger.debug("Model will be loaded into memory to save to target store as it is from local path.")
logger.debug('Model will be loaded into memory to save to target store as it is from local path.')
model = infer_autoclass_from_llm(llm, config).from_pretrained(llm.model_id, *decls, config=config, trust_remote_code=trust_remote_code, **hub_attrs, **attrs)
# for trust_remote_code to work
bentomodel.enter_cloudpickle_context([importlib.import_module(model.__module__)], imported_modules)
model.save_pretrained(bentomodel.path, max_shard_size="5GB", safe_serialization=safe_serialisation)
model.save_pretrained(bentomodel.path, max_shard_size='5GB', safe_serialization=safe_serialisation)
else:
# we will clone the all tings into the bentomodel path without loading model into memory
snapshot_download(llm.model_id, local_dir=bentomodel.path, local_dir_use_symlinks=False, ignore_patterns=HfIgnore.ignore_patterns(llm))
@@ -129,58 +129,58 @@ def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool,
if openllm.utils.is_torch_available() and torch.cuda.is_available(): torch.cuda.empty_cache()
return bentomodel
def get(llm: openllm.LLM[M, T], auto_import: bool = False) -> bentoml.Model:
"""Return an instance of ``bentoml.Model`` from given LLM instance.
'''Return an instance of ``bentoml.Model`` from given LLM instance.

By default, it will try to check the model in the local store.
If model is not found, and ``auto_import`` is set to True, it will try to import the model from HuggingFace Hub.

Otherwise, it will raises a ``bentoml.exceptions.NotFound``.
"""
'''
try:
model = bentoml.models.get(llm.tag)
if model.info.module not in (
"openllm.serialisation.transformers"
"bentoml.transformers", "bentoml._internal.frameworks.transformers", __name__
'openllm.serialisation.transformers'
'bentoml.transformers', 'bentoml._internal.frameworks.transformers', __name__
): # NOTE: backward compatible with previous version of OpenLLM.
raise bentoml.exceptions.NotFound(f"Model {model.tag} was saved with module {model.info.module}, not loading with 'openllm.serialisation.transformers'.")
if "runtime" in model.info.labels and model.info.labels["runtime"] != llm.runtime:
if 'runtime' in model.info.labels and model.info.labels['runtime'] != llm.runtime:
raise openllm.exceptions.OpenLLMException(f"Model {model.tag} was saved with runtime {model.info.labels['runtime']}, not loading with {llm.runtime}.")
return model
except bentoml.exceptions.NotFound as err:
if auto_import: return import_model(llm, trust_remote_code=llm.__llm_trust_remote_code__)
raise err from None
def load_model(llm: openllm.LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M:
"""Load the model from BentoML store.
'''Load the model from BentoML store.

By default, it will try to find check the model in the local store.
If model is not found, it will raises a ``bentoml.exceptions.NotFound``.
"""
'''
config, hub_attrs, attrs = process_config(llm.model_id, llm.__llm_trust_remote_code__, **attrs)
safe_serialization = openllm.utils.first_not_none(
t.cast(t.Optional[bool], llm._bentomodel.info.metadata.get("safe_serialisation", None)), attrs.pop("safe_serialization", None), default=llm._serialisation_format == "safetensors"
t.cast(t.Optional[bool], llm._bentomodel.info.metadata.get('safe_serialisation', None)), attrs.pop('safe_serialization', None), default=llm._serialisation_format == 'safetensors'
)
if "_quantize" in llm._bentomodel.info.metadata and llm._bentomodel.info.metadata["_quantize"] == "gptq":
if '_quantize' in llm._bentomodel.info.metadata and llm._bentomodel.info.metadata['_quantize'] == 'gptq':
if not openllm.utils.is_autogptq_available():
raise openllm.exceptions.OpenLLMException("GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'")
if llm.config["model_type"] != "causal_lm": raise openllm.exceptions.OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
if llm.config['model_type'] != 'causal_lm': raise openllm.exceptions.OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
return autogptq.AutoGPTQForCausalLM.from_quantized(
llm._bentomodel.path,
*decls,
quantize_config=t.cast("autogptq.BaseQuantizeConfig", llm.quantization_config),
quantize_config=t.cast('autogptq.BaseQuantizeConfig', llm.quantization_config),
trust_remote_code=llm.__llm_trust_remote_code__,
use_safetensors=safe_serialization,
**hub_attrs,
**attrs
)

device_map = attrs.pop("device_map", "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None)
device_map = attrs.pop('device_map', 'auto' if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None)
model = infer_autoclass_from_llm(llm, config).from_pretrained(
llm._bentomodel.path, *decls, config=config, trust_remote_code=llm.__llm_trust_remote_code__, device_map=device_map, **hub_attrs, **attrs
).eval()
# BetterTransformer is currently only supported on PyTorch.
if llm.bettertransformer and isinstance(model, transformers.PreTrainedModel): model = model.to_bettertransformer()
if llm.__llm_implementation__ in {"pt", "vllm"}: check_unintialised_params(model)
return t.cast("M", model)
if llm.__llm_implementation__ in {'pt', 'vllm'}: check_unintialised_params(model)
return t.cast('M', model)
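
Note: one detail in `load_model` above worth calling out is the default device map — multi-GPU ('auto') placement is only requested when more than one CUDA device is visible. A tiny standalone sketch of that heuristic (the helper name is made up):

from __future__ import annotations
import torch

def _default_device_map() -> str | None:
  # Mirrors the heuristic in load_model above: request 'auto' sharding only when
  # multiple CUDA devices are available, otherwise fall back to single-device placement.
  return 'auto' if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None
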
def save_pretrained(
llm: openllm.LLM[M, T],
save_directory: str,
@@ -188,29 +188,29 @@ def save_pretrained(
state_dict: DictStrAny | None = None,
save_function: t.Any | None = None,
push_to_hub: bool = False,
max_shard_size: int | str = "10GB",
max_shard_size: int | str = '10GB',
safe_serialization: bool = False,
variant: str | None = None,
**attrs: t.Any
) -> None:
save_function = t.cast(t.Callable[..., None], openllm.utils.first_not_none(save_function, default=torch.save))
model_save_attrs, tokenizer_save_attrs = openllm.utils.normalize_attrs_to_model_tokenizer_pair(**attrs)
safe_serialization = safe_serialization or llm._serialisation_format == "safetensors"
safe_serialization = safe_serialization or llm._serialisation_format == 'safetensors'
# NOTE: disable safetensors for vllm
if llm.__llm_implementation__ == "vllm": safe_serialization = False
if llm._quantize_method == "gptq":
if llm.__llm_implementation__ == 'vllm': safe_serialization = False
if llm._quantize_method == 'gptq':
if not openllm.utils.is_autogptq_available():
raise openllm.exceptions.OpenLLMException("GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'")
if llm.config["model_type"] != "causal_lm": raise openllm.exceptions.OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
if not openllm.utils.lenient_issubclass(llm.model, autogptq.modeling.BaseGPTQForCausalLM): raise ValueError(f"Model is not a BaseGPTQForCausalLM (type: {type(llm.model)})")
t.cast("autogptq.modeling.BaseGPTQForCausalLM", llm.model).save_quantized(save_directory, use_safetensors=safe_serialization)
elif openllm.utils.LazyType["vllm.LLMEngine"]("vllm.LLMEngine").isinstance(llm.model):
if llm.config['model_type'] != 'causal_lm': raise openllm.exceptions.OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
if not openllm.utils.lenient_issubclass(llm.model, autogptq.modeling.BaseGPTQForCausalLM): raise ValueError(f'Model is not a BaseGPTQForCausalLM (type: {type(llm.model)})')
t.cast('autogptq.modeling.BaseGPTQForCausalLM', llm.model).save_quantized(save_directory, use_safetensors=safe_serialization)
elif openllm.utils.LazyType['vllm.LLMEngine']('vllm.LLMEngine').isinstance(llm.model):
raise RuntimeError("vllm.LLMEngine cannot be serialisation directly. This happens when 'save_pretrained' is called directly after `openllm.AutoVLLM` is initialized.")
elif isinstance(llm.model, transformers.Pipeline):
llm.model.save_pretrained(save_directory, safe_serialization=safe_serialization)
else:
# We can safely cast here since it will be the PreTrainedModel protocol.
t.cast("transformers.PreTrainedModel", llm.model).save_pretrained(
t.cast('transformers.PreTrainedModel', llm.model).save_pretrained(
save_directory,
is_main_process=is_main_process,
state_dict=state_dict,

@@ -9,11 +9,11 @@ if t.TYPE_CHECKING:
from bentoml._internal.models.model import ModelSignaturesType
from openllm_core._typing_compat import DictStrAny, M, T
else:
transformers, torch = openllm_core.utils.LazyLoader("transformers", globals(), "transformers"), openllm_core.utils.LazyLoader("torch", globals(), "torch")
transformers, torch = openllm_core.utils.LazyLoader('transformers', globals(), 'transformers'), openllm_core.utils.LazyLoader('torch', globals(), 'torch')

_object_setattr = object.__setattr__
def process_config(model_id: str, trust_remote_code: bool, **attrs: t.Any) -> tuple[transformers.PretrainedConfig, DictStrAny, DictStrAny]:
"""A helper function that correctly parse config and attributes for transformers.PretrainedConfig.
'''A helper function that correctly parse config and attributes for transformers.PretrainedConfig.

Args:
model_id: Model id to pass into ``transformers.AutoConfig``.
@@ -22,51 +22,51 @@ def process_config(model_id: str, trust_remote_code: bool, **attrs: t.Any) -> tu

Returns:
A tuple of ``transformers.PretrainedConfig``, all hub attributes, and remanining attributes that can be used by the Model class.
"""
config = attrs.pop("config", None)
'''
config = attrs.pop('config', None)
# this logic below is synonymous to handling `from_pretrained` attrs.
hub_attrs = {k: attrs.pop(k) for k in HUB_ATTRS if k in attrs}
if not isinstance(config, transformers.PretrainedConfig):
copied_attrs = copy.deepcopy(attrs)
if copied_attrs.get("torch_dtype", None) == "auto": copied_attrs.pop("torch_dtype")
if copied_attrs.get('torch_dtype', None) == 'auto': copied_attrs.pop('torch_dtype')
config, attrs = transformers.AutoConfig.from_pretrained(model_id, return_unused_kwargs=True, trust_remote_code=trust_remote_code, **hub_attrs, **copied_attrs)
return config, hub_attrs, attrs
def infer_tokenizers_from_llm(__llm: openllm.LLM[t.Any, T], /) -> T:
__cls = getattr(transformers, openllm_core.utils.first_not_none(__llm.config["tokenizer_class"], default="AutoTokenizer"), None)
if __cls is None: raise ValueError(f"Cannot infer correct tokenizer class for {__llm}. Make sure to unset `tokenizer_class`")
__cls = getattr(transformers, openllm_core.utils.first_not_none(__llm.config['tokenizer_class'], default='AutoTokenizer'), None)
if __cls is None: raise ValueError(f'Cannot infer correct tokenizer class for {__llm}. Make sure to unset `tokenizer_class`')
return __cls
def infer_autoclass_from_llm(llm: openllm.LLM[M, T], config: transformers.PretrainedConfig, /) -> _BaseAutoModelClass:
if llm.config["trust_remote_code"]:
autoclass = "AutoModelForSeq2SeqLM" if llm.config["model_type"] == "seq2seq_lm" else "AutoModelForCausalLM"
if not hasattr(config, "auto_map"):
raise ValueError(f"Invalid configuraiton for {llm.model_id}. ``trust_remote_code=True`` requires `transformers.PretrainedConfig` to contain a `auto_map` mapping")
if llm.config['trust_remote_code']:
autoclass = 'AutoModelForSeq2SeqLM' if llm.config['model_type'] == 'seq2seq_lm' else 'AutoModelForCausalLM'
if not hasattr(config, 'auto_map'):
raise ValueError(f'Invalid configuraiton for {llm.model_id}. ``trust_remote_code=True`` requires `transformers.PretrainedConfig` to contain a `auto_map` mapping')
# in case this model doesn't use the correct auto class for model type, for example like chatglm
# where it uses AutoModel instead of AutoModelForCausalLM. Then we fallback to AutoModel
if autoclass not in config.auto_map: autoclass = "AutoModel"
if autoclass not in config.auto_map: autoclass = 'AutoModel'
return getattr(transformers, autoclass)
else:
if type(config) in transformers.MODEL_FOR_CAUSAL_LM_MAPPING: idx = 0
elif type(config) in transformers.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING: idx = 1
else: raise openllm.exceptions.OpenLLMException(f"Model type {type(config)} is not supported yet.")
else: raise openllm.exceptions.OpenLLMException(f'Model type {type(config)} is not supported yet.')
return getattr(transformers, FRAMEWORK_TO_AUTOCLASS_MAPPING[llm.__llm_implementation__][idx])
def check_unintialised_params(model: torch.nn.Module) -> None:
unintialized = [n for n, param in model.named_parameters() if param.data.device == torch.device("meta")]
if len(unintialized) > 0: raise RuntimeError(f"Found the following unintialized parameters in {model}: {unintialized}")
unintialized = [n for n, param in model.named_parameters() if param.data.device == torch.device('meta')]
if len(unintialized) > 0: raise RuntimeError(f'Found the following unintialized parameters in {model}: {unintialized}')
def update_model(bentomodel: bentoml.Model, metadata: DictStrAny) -> bentoml.Model:
based: DictStrAny = copy.deepcopy(bentomodel.info.metadata)
based.update(metadata)
_object_setattr(bentomodel, "_info", ModelInfo( # type: ignore[call-arg] # XXX: remove me once upstream is merged
_object_setattr(bentomodel, '_info', ModelInfo( # type: ignore[call-arg] # XXX: remove me once upstream is merged
tag=bentomodel.info.tag, module=bentomodel.info.module, labels=bentomodel.info.labels, options=bentomodel.info.options.to_dict(), signatures=bentomodel.info.signatures, context=bentomodel.info.context, api_version=bentomodel.info.api_version, creation_time=bentomodel.info.creation_time, metadata=based
))
return bentomodel
# NOTE: sync with bentoml/_internal/frameworks/transformers.py#make_default_signatures
def make_model_signatures(llm: openllm.LLM[M, T]) -> ModelSignaturesType:
infer_fn: tuple[str, ...] = ("__call__",)
infer_fn: tuple[str, ...] = ('__call__',)
default_config = ModelSignature(batchable=False)
if llm.__llm_implementation__ in {"pt", "vllm"}:
infer_fn += ("forward", "generate", "contrastive_search", "greedy_search", "sample", "beam_search", "beam_sample", "group_beam_search", "constrained_beam_search",)
elif llm.__llm_implementation__ == "tf":
infer_fn += ("predict", "call", "generate", "compute_transition_scores", "greedy_search", "sample", "beam_search", "contrastive_search",)
if llm.__llm_implementation__ in {'pt', 'vllm'}:
infer_fn += ('forward', 'generate', 'contrastive_search', 'greedy_search', 'sample', 'beam_search', 'beam_sample', 'group_beam_search', 'constrained_beam_search',)
elif llm.__llm_implementation__ == 'tf':
infer_fn += ('predict', 'call', 'generate', 'compute_transition_scores', 'greedy_search', 'sample', 'beam_search', 'contrastive_search',)
else:
infer_fn += ("generate",)
infer_fn += ('generate',)
return {k: default_config for k in infer_fn}
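
Note: the signature map built by `make_model_signatures` above is just a dictionary from inference method names to a non-batchable signature. For the fallback branch it amounts to roughly the following, with plain dicts standing in for bentoml's `ModelSignature`:

# Approximate shape of the result for the fallback branch above.
default_config = {'batchable': False}
infer_fn = ('__call__', 'generate')
signatures = {k: default_config for k in infer_fn}
# {'__call__': {'batchable': False}, 'generate': {'batchable': False}}
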
Some files were not shown because too many files have changed in this diff.