mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-02-18 14:47:30 -05:00
style: google
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
@@ -23,16 +23,6 @@ repos:
|
||||
- id: ruff
|
||||
verbose: true
|
||||
args: [--exit-non-zero-on-fix, --show-fixes]
|
||||
- repo: https://github.com/editorconfig-checker/editorconfig-checker.python
|
||||
rev: '2.7.2'
|
||||
hooks:
|
||||
- id: editorconfig-checker
|
||||
verbose: true
|
||||
exclude: |
|
||||
(?x)^(
|
||||
openllm-client/src/openllm_client/pb.*|
|
||||
openllm-python/src/openllm/cli/entrypoint.py
|
||||
)$
|
||||
- repo: https://github.com/econchick/interrogate
|
||||
rev: 1.5.0
|
||||
hooks:
|
||||
|
||||
@@ -22,7 +22,7 @@
|
||||
</a><a href="https://github.com/pypa/hatch">
|
||||
<img src="https://img.shields.io/badge/%F0%9F%A5%9A-Hatch-4051b5.svg" alt="Hatch" />
|
||||
</a><a href="https://github.com/bentoml/OpenLLM/blob/main/STYLE.md">
|
||||
<img src="https://img.shields.io/badge/code%20style-experimental-000000.svg" alt="code style" />
|
||||
<img src="https://img.shields.io/badge/code%20style-Google-000000.svg" alt="code style" />
|
||||
</a><a href="https://github.com/astral-sh/ruff">
|
||||
<img src="https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/charliermarsh/ruff/main/assets/badge/v2.json" alt="Ruff" />
|
||||
</a><a href="https://github.com/python/mypy">
|
||||
|
||||
5
cz.py
5
cz.py
@@ -19,7 +19,10 @@ def run_cz(dir: str, package: str):
|
||||
with tokenize.open(filepath) as file_:
|
||||
tokens = [t for t in tokenize.generate_tokens(file_.readline) if t.type in TOKEN_WHITELIST]
|
||||
token_count, line_count = len(tokens), len(set([t.start[0] for t in tokens]))
|
||||
table.append([filepath.replace(os.path.join(dir, 'src'), ''), line_count, token_count / line_count if line_count != 0 else 0])
|
||||
table.append([
|
||||
filepath.replace(os.path.join(dir, 'src'), ''), line_count,
|
||||
token_count / line_count if line_count != 0 else 0
|
||||
])
|
||||
print(tabulate([headers, *sorted(table, key=lambda x: -x[1])], headers='firstrow', floatfmt='.1f') + '\n')
|
||||
for dir_name, group in itertools.groupby(sorted([(x[0].rsplit('/', 1)[0], x[1]) for x in table]), key=lambda x: x[0]):
|
||||
print(f'{dir_name:35s} : {sum([x[1] for x in group]):6d}')
|
||||
|
||||
@@ -22,9 +22,8 @@ def gen_llm(model_name: str, model_id: str | None = None) -> OpenLLM:
|
||||
|
||||
llm = gen_llm("dolly-v2", model_id="databricks/dolly-v2-7b")
|
||||
|
||||
prompt = PromptTemplate(
|
||||
input_variables=["industry", "product_name", "keywords"],
|
||||
template="""
|
||||
prompt = PromptTemplate(input_variables=["industry", "product_name", "keywords"],
|
||||
template="""
|
||||
You are a Facebook Ads Copywriter with a strong background in persuasive
|
||||
writing and marketing. You craft compelling copy that appeals to the target
|
||||
audience's emotions and needs, peruading them to take action or make a
|
||||
@@ -36,8 +35,7 @@ Industry: {industry}
|
||||
Product: {product_name}
|
||||
Keywords: {keywords}
|
||||
Facebook Ads copy:
|
||||
""",
|
||||
)
|
||||
""")
|
||||
chain = LLMChain(llm=llm, prompt=prompt)
|
||||
|
||||
svc = bentoml.Service("fb-ads-copy", runners=[llm.runner])
|
||||
@@ -47,9 +45,16 @@ def download(_: bentoml.Context):
|
||||
llm.runner.download_model()
|
||||
|
||||
SAMPLE_INPUT = Query(
|
||||
industry="SAAS", product_name="BentoML", keywords=["open source", "developer tool", "AI application platform", "serverless", "cost-efficient"], llm_config=llm.runner.config.model_dump(),
|
||||
industry="SAAS",
|
||||
product_name="BentoML",
|
||||
keywords=["open source", "developer tool", "AI application platform", "serverless", "cost-efficient"],
|
||||
llm_config=llm.runner.config.model_dump(),
|
||||
)
|
||||
|
||||
@svc.api(input=JSON.from_sample(sample=SAMPLE_INPUT), output=Text())
|
||||
def generate(query: Query):
|
||||
return chain.run({"industry": query.industry, "product_name": query.product_name, "keywords": ", ".join(query.keywords)})
|
||||
return chain.run({
|
||||
"industry": query.industry,
|
||||
"product_name": query.product_name,
|
||||
"keywords": ", ".join(query.keywords)
|
||||
})
|
||||
|
||||
@@ -65,7 +65,10 @@ class _ClientAttr:
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def query(self, prompt: str, return_response: t.Literal['attrs', 'raw', 'processed'] = 'processed', **attrs: t.Any) -> t.Any:
|
||||
def query(self,
|
||||
prompt: str,
|
||||
return_response: t.Literal['attrs', 'raw', 'processed'] = 'processed',
|
||||
**attrs: t.Any) -> t.Any:
|
||||
raise NotImplementedError
|
||||
|
||||
# NOTE: Scikit interface
|
||||
@@ -81,7 +84,8 @@ class _ClientAttr:
|
||||
|
||||
@overload
|
||||
@abc.abstractmethod
|
||||
def predict(self, prompt: str, *, return_response: t.Literal['attrs'], **attrs: t.Any) -> openllm_core.GenerationOutput:
|
||||
def predict(self, prompt: str, *, return_response: t.Literal['attrs'],
|
||||
**attrs: t.Any) -> openllm_core.GenerationOutput:
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -90,9 +94,15 @@ class _ClientAttr:
|
||||
|
||||
@functools.cached_property
|
||||
def _hf_agent(self) -> transformers.HfAgent:
|
||||
if not is_transformers_available(): raise RuntimeError("transformers is required to use HF agent. Install with 'pip install \"openllm-client[agents]\"'.")
|
||||
if not self.supports_hf_agent: raise RuntimeError(f'{self.model_name} ({self.framework}) does not support running HF agent.')
|
||||
if not is_transformers_supports_agent(): raise RuntimeError("Current 'transformers' does not support Agent. Make sure to upgrade to at least 4.29: 'pip install -U \"transformers>=4.29\"'")
|
||||
if not is_transformers_available():
|
||||
raise RuntimeError(
|
||||
"transformers is required to use HF agent. Install with 'pip install \"openllm-client[agents]\"'.")
|
||||
if not self.supports_hf_agent:
|
||||
raise RuntimeError(f'{self.model_name} ({self.framework}) does not support running HF agent.')
|
||||
if not is_transformers_supports_agent():
|
||||
raise RuntimeError(
|
||||
"Current 'transformers' does not support Agent. Make sure to upgrade to at least 4.29: 'pip install -U \"transformers>=4.29\"'"
|
||||
)
|
||||
import transformers
|
||||
return transformers.HfAgent(urljoin(self._address, '/hf/agent'))
|
||||
|
||||
@@ -173,7 +183,13 @@ class _Client(_ClientAttr):
|
||||
return BentoClient.from_url(self._address)
|
||||
|
||||
# Agent integration
|
||||
def ask_agent(self, task: str, *, return_code: bool = False, remote: bool = False, agent_type: LiteralString = 'hf', **attrs: t.Any) -> t.Any:
|
||||
def ask_agent(self,
|
||||
task: str,
|
||||
*,
|
||||
return_code: bool = False,
|
||||
remote: bool = False,
|
||||
agent_type: LiteralString = 'hf',
|
||||
**attrs: t.Any) -> t.Any:
|
||||
if agent_type == 'hf': return self._run_hf_agent(task, return_code=return_code, remote=remote, **attrs)
|
||||
else: raise RuntimeError(f"Unknown 'agent_type={agent_type}'")
|
||||
|
||||
@@ -207,12 +223,20 @@ class _AsyncClient(_ClientAttr):
|
||||
return ensure_exec_coro(AsyncBentoClient.from_url(self._address))
|
||||
|
||||
# Agent integration
|
||||
async def ask_agent(self, task: str, *, return_code: bool = False, remote: bool = False, agent_type: LiteralString = 'hf', **attrs: t.Any) -> t.Any:
|
||||
async def ask_agent(self,
|
||||
task: str,
|
||||
*,
|
||||
return_code: bool = False,
|
||||
remote: bool = False,
|
||||
agent_type: LiteralString = 'hf',
|
||||
**attrs: t.Any) -> t.Any:
|
||||
if agent_type == 'hf': return await self._run_hf_agent(task, return_code=return_code, remote=remote, **attrs)
|
||||
else: raise RuntimeError(f"Unknown 'agent_type={agent_type}'")
|
||||
|
||||
async def _run_hf_agent(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
|
||||
if not is_transformers_supports_agent(): raise RuntimeError('This version of transformers does not support agent.run. Make sure to upgrade to transformers>4.30.0')
|
||||
if not is_transformers_supports_agent():
|
||||
raise RuntimeError(
|
||||
'This version of transformers does not support agent.run. Make sure to upgrade to transformers>4.30.0')
|
||||
if len(args) > 1: raise ValueError("'args' should only take one positional argument.")
|
||||
from transformers.tools.agents import clean_code_for_run
|
||||
from transformers.tools.agents import get_tool_creation_code
|
||||
@@ -225,7 +249,15 @@ class _AsyncClient(_ClientAttr):
|
||||
stop = ['Task:']
|
||||
prompt = t.cast(str, self._hf_agent.format_prompt(task))
|
||||
async with httpx.AsyncClient(timeout=httpx.Timeout(self.timeout)) as client:
|
||||
response = await client.post(self._hf_agent.url_endpoint, json={'inputs': prompt, 'parameters': {'max_new_tokens': 200, 'return_full_text': False, 'stop': stop}})
|
||||
response = await client.post(self._hf_agent.url_endpoint,
|
||||
json={
|
||||
'inputs': prompt,
|
||||
'parameters': {
|
||||
'max_new_tokens': 200,
|
||||
'return_full_text': False,
|
||||
'stop': stop
|
||||
}
|
||||
})
|
||||
if response.status_code != HTTPStatus.OK: raise ValueError(f'Error {response.status_code}: {response.json()}')
|
||||
|
||||
result = response.json()[0]['generated_text']
|
||||
@@ -240,23 +272,31 @@ class _AsyncClient(_ClientAttr):
|
||||
self._hf_agent.log(f'\n\n==Code generated by the agent==\n{code}')
|
||||
if not return_code:
|
||||
self._hf_agent.log('\n\n==Result==')
|
||||
self._hf_agent.cached_tools = resolve_tools(code, self._hf_agent.toolbox, remote=remote, cached_tools=self._hf_agent.cached_tools)
|
||||
self._hf_agent.cached_tools = resolve_tools(code,
|
||||
self._hf_agent.toolbox,
|
||||
remote=remote,
|
||||
cached_tools=self._hf_agent.cached_tools)
|
||||
return evaluate(code, self._hf_agent.cached_tools, state=kwargs.copy())
|
||||
else:
|
||||
tool_code = get_tool_creation_code(code, self._hf_agent.toolbox, remote=remote)
|
||||
return f'{tool_code}\n{code}'
|
||||
|
||||
class BaseClient(_Client):
|
||||
|
||||
def chat(self, prompt: str, history: list[str], **attrs: t.Any) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def embed(self, prompt: t.Sequence[str] | str) -> openllm_core.EmbeddingsOutput:
|
||||
return openllm_core.EmbeddingsOutput(**self.call('embeddings', list([prompt] if isinstance(prompt, str) else prompt)))
|
||||
return openllm_core.EmbeddingsOutput(
|
||||
**self.call('embeddings', list([prompt] if isinstance(prompt, str) else prompt)))
|
||||
|
||||
def predict(self, prompt: str, **attrs: t.Any) -> openllm_core.GenerationOutput | DictStrAny | str:
|
||||
return self.query(prompt, **attrs)
|
||||
|
||||
def query(self, prompt: str, return_response: t.Literal['attrs', 'raw', 'processed'] = 'processed', **attrs: t.Any) -> t.Any:
|
||||
def query(self,
|
||||
prompt: str,
|
||||
return_response: t.Literal['attrs', 'raw', 'processed'] = 'processed',
|
||||
**attrs: t.Any) -> t.Any:
|
||||
return_raw_response = attrs.pop('return_raw_response', None)
|
||||
if return_raw_response is not None:
|
||||
logger.warning("'return_raw_response' is now deprecated. Please use 'return_response=\"raw\"' instead.")
|
||||
@@ -266,23 +306,32 @@ class BaseClient(_Client):
|
||||
logger.warning("'return_attrs' is now deprecated. Please use 'return_response=\"attrs\"' instead.")
|
||||
if return_attrs is True: return_response = 'attrs'
|
||||
use_default_prompt_template = attrs.pop('use_default_prompt_template', False)
|
||||
prompt, generate_kwargs, postprocess_kwargs = self.config.sanitize_parameters(prompt, use_default_prompt_template=use_default_prompt_template, **attrs)
|
||||
r = openllm_core.GenerationOutput(**self.call('generate', openllm_core.GenerationInput(prompt=prompt, llm_config=self.config.model_construct_env(**generate_kwargs)).model_dump()))
|
||||
prompt, generate_kwargs, postprocess_kwargs = self.config.sanitize_parameters(
|
||||
prompt, use_default_prompt_template=use_default_prompt_template, **attrs)
|
||||
r = openllm_core.GenerationOutput(**self.call(
|
||||
'generate',
|
||||
openllm_core.GenerationInput(prompt=prompt, llm_config=self.config.model_construct_env(
|
||||
**generate_kwargs)).model_dump()))
|
||||
if return_response == 'attrs': return r
|
||||
elif return_response == 'raw': return bentoml_cattr.unstructure(r)
|
||||
else: return self.config.postprocess_generate(prompt, r.responses, **postprocess_kwargs)
|
||||
|
||||
class BaseAsyncClient(_AsyncClient):
|
||||
|
||||
async def chat(self, prompt: str, history: list[str], **attrs: t.Any) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
async def embed(self, prompt: t.Sequence[str] | str) -> openllm_core.EmbeddingsOutput:
|
||||
return openllm_core.EmbeddingsOutput(**(await self.call('embeddings', list([prompt] if isinstance(prompt, str) else prompt))))
|
||||
return openllm_core.EmbeddingsOutput(
|
||||
**(await self.call('embeddings', list([prompt] if isinstance(prompt, str) else prompt))))
|
||||
|
||||
async def predict(self, prompt: str, **attrs: t.Any) -> t.Any:
|
||||
return await self.query(prompt, **attrs)
|
||||
|
||||
async def query(self, prompt: str, return_response: t.Literal['attrs', 'raw', 'processed'] = 'processed', **attrs: t.Any) -> t.Any:
|
||||
async def query(self,
|
||||
prompt: str,
|
||||
return_response: t.Literal['attrs', 'raw', 'processed'] = 'processed',
|
||||
**attrs: t.Any) -> t.Any:
|
||||
return_raw_response = attrs.pop('return_raw_response', None)
|
||||
if return_raw_response is not None:
|
||||
logger.warning("'return_raw_response' is now deprecated. Please use 'return_response=\"raw\"' instead.")
|
||||
@@ -292,8 +341,12 @@ class BaseAsyncClient(_AsyncClient):
|
||||
logger.warning("'return_attrs' is now deprecated. Please use 'return_response=\"attrs\"' instead.")
|
||||
if return_attrs is True: return_response = 'attrs'
|
||||
use_default_prompt_template = attrs.pop('use_default_prompt_template', False)
|
||||
prompt, generate_kwargs, postprocess_kwargs = self.config.sanitize_parameters(prompt, use_default_prompt_template=use_default_prompt_template, **attrs)
|
||||
r = openllm_core.GenerationOutput(**(await self.call('generate', openllm_core.GenerationInput(prompt=prompt, llm_config=self.config.model_construct_env(**generate_kwargs)).model_dump())))
|
||||
prompt, generate_kwargs, postprocess_kwargs = self.config.sanitize_parameters(
|
||||
prompt, use_default_prompt_template=use_default_prompt_template, **attrs)
|
||||
r = openllm_core.GenerationOutput(**(await self.call(
|
||||
'generate',
|
||||
openllm_core.GenerationInput(prompt=prompt, llm_config=self.config.model_construct_env(
|
||||
**generate_kwargs)).model_dump())))
|
||||
if return_response == 'attrs': return r
|
||||
elif return_response == 'raw': return bentoml_cattr.unstructure(r)
|
||||
else: return self.config.postprocess_generate(prompt, r.responses, **postprocess_kwargs)
|
||||
|
||||
@@ -18,7 +18,8 @@ from openllm_core.utils import ensure_exec_coro
|
||||
from openllm_core.utils import is_grpc_available
|
||||
from openllm_core.utils import is_grpc_health_available
|
||||
|
||||
if not is_grpc_available() or not is_grpc_health_available(): raise ImportError("gRPC is required to use gRPC client. Install with 'pip install \"openllm-client[grpc]\"'.")
|
||||
if not is_grpc_available() or not is_grpc_health_available():
|
||||
raise ImportError("gRPC is required to use gRPC client. Install with 'pip install \"openllm-client[grpc]\"'.")
|
||||
import grpc
|
||||
import grpc_health.v1.health_pb2 as pb_health
|
||||
import grpc_health.v1.health_pb2_grpc as services_health
|
||||
@@ -39,48 +40,53 @@ class ClientCredentials(t.TypedDict):
|
||||
certificate_chain: NotRequired[t.Union[bytes, str]]
|
||||
|
||||
@overload
|
||||
def dispatch_channel(
|
||||
server_url: str,
|
||||
typ: t.Literal['async'],
|
||||
ssl: bool = ...,
|
||||
ssl_client_credentials: ClientCredentials | None = ...,
|
||||
options: t.Any | None = ...,
|
||||
compression: grpc.Compression | None = ...,
|
||||
interceptors: t.Sequence[aio.ClientInterceptor] | None = ...
|
||||
) -> aio.Channel:
|
||||
def dispatch_channel(server_url: str,
|
||||
typ: t.Literal['async'],
|
||||
ssl: bool = ...,
|
||||
ssl_client_credentials: ClientCredentials | None = ...,
|
||||
options: t.Any | None = ...,
|
||||
compression: grpc.Compression | None = ...,
|
||||
interceptors: t.Sequence[aio.ClientInterceptor] | None = ...) -> aio.Channel:
|
||||
...
|
||||
|
||||
@overload
|
||||
def dispatch_channel(
|
||||
server_url: str,
|
||||
typ: t.Literal['sync'],
|
||||
ssl: bool = ...,
|
||||
ssl_client_credentials: ClientCredentials | None = ...,
|
||||
options: t.Any | None = ...,
|
||||
compression: grpc.Compression | None = ...,
|
||||
interceptors: t.Sequence[aio.ClientInterceptor] | None = None
|
||||
) -> grpc.Channel:
|
||||
def dispatch_channel(server_url: str,
|
||||
typ: t.Literal['sync'],
|
||||
ssl: bool = ...,
|
||||
ssl_client_credentials: ClientCredentials | None = ...,
|
||||
options: t.Any | None = ...,
|
||||
compression: grpc.Compression | None = ...,
|
||||
interceptors: t.Sequence[aio.ClientInterceptor] | None = None) -> grpc.Channel:
|
||||
...
|
||||
|
||||
def dispatch_channel(
|
||||
server_url: str,
|
||||
typ: t.Literal['async', 'sync'] = 'sync',
|
||||
ssl: bool = False,
|
||||
ssl_client_credentials: ClientCredentials | None = None,
|
||||
options: t.Any | None = None,
|
||||
compression: grpc.Compression | None = None,
|
||||
interceptors: t.Sequence[aio.ClientInterceptor] | None = None
|
||||
) -> aio.Channel | grpc.Channel:
|
||||
def dispatch_channel(server_url: str,
|
||||
typ: t.Literal['async', 'sync'] = 'sync',
|
||||
ssl: bool = False,
|
||||
ssl_client_credentials: ClientCredentials | None = None,
|
||||
options: t.Any | None = None,
|
||||
compression: grpc.Compression | None = None,
|
||||
interceptors: t.Sequence[aio.ClientInterceptor] | None = None) -> aio.Channel | grpc.Channel:
|
||||
credentials = None
|
||||
if ssl:
|
||||
if ssl_client_credentials is None: raise RuntimeError("'ssl=True' requires 'ssl_client_credentials'")
|
||||
credentials = grpc.ssl_channel_credentials(**{k: load_from_file(v) if isinstance(v, str) else v for k, v in ssl_client_credentials.items()})
|
||||
credentials = grpc.ssl_channel_credentials(**{
|
||||
k: load_from_file(v) if isinstance(v, str) else v for k, v in ssl_client_credentials.items()
|
||||
})
|
||||
|
||||
if typ == 'async' and ssl: return aio.secure_channel(server_url, credentials=credentials, options=options, compression=compression, interceptors=interceptors)
|
||||
elif typ == 'async': return aio.insecure_channel(server_url, options=options, compression=compression, interceptors=interceptors)
|
||||
elif typ == 'sync' and ssl: return grpc.secure_channel(server_url, credentials=credentials, options=options, compression=compression)
|
||||
elif typ == 'sync': return grpc.insecure_channel(server_url, options=options, compression=compression)
|
||||
else: raise ValueError(f'Unknown type: {typ}')
|
||||
if typ == 'async' and ssl:
|
||||
return aio.secure_channel(server_url,
|
||||
credentials=credentials,
|
||||
options=options,
|
||||
compression=compression,
|
||||
interceptors=interceptors)
|
||||
elif typ == 'async':
|
||||
return aio.insecure_channel(server_url, options=options, compression=compression, interceptors=interceptors)
|
||||
elif typ == 'sync' and ssl:
|
||||
return grpc.secure_channel(server_url, credentials=credentials, options=options, compression=compression)
|
||||
elif typ == 'sync':
|
||||
return grpc.insecure_channel(server_url, options=options, compression=compression)
|
||||
else:
|
||||
raise ValueError(f'Unknown type: {typ}')
|
||||
|
||||
class GrpcClient(Client):
|
||||
ssl: bool
|
||||
@@ -88,16 +94,14 @@ class GrpcClient(Client):
|
||||
options: t.Any
|
||||
compression: t.Optional[grpc.Compression]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_url: str,
|
||||
svc: bentoml.Service, # gRPC specific options
|
||||
ssl: bool = False,
|
||||
options: t.Any | None = None,
|
||||
compression: grpc.Compression | None = None,
|
||||
ssl_client_credentials: ClientCredentials | None = None,
|
||||
**kwargs: t.Any
|
||||
) -> None:
|
||||
def __init__(self,
|
||||
server_url: str,
|
||||
svc: bentoml.Service, # gRPC specific options
|
||||
ssl: bool = False,
|
||||
options: t.Any | None = None,
|
||||
compression: grpc.Compression | None = None,
|
||||
ssl_client_credentials: ClientCredentials | None = None,
|
||||
**kwargs: t.Any) -> None:
|
||||
self.ssl, self.ssl_client_credentials, self.options, self.compression = ssl, ssl_client_credentials, options, compression
|
||||
super().__init__(server_url, svc, **kwargs)
|
||||
|
||||
@@ -105,20 +109,27 @@ class GrpcClient(Client):
|
||||
def inner(self) -> grpc.Channel:
|
||||
if self.ssl:
|
||||
if self.ssl_client_credentials is None: raise RuntimeError("'ssl=True' requires 'ssl_client_credentials'")
|
||||
credentials = grpc.ssl_channel_credentials(**{k: load_from_file(v) if isinstance(v, str) else v for k, v in self.ssl_client_credentials.items()})
|
||||
return grpc.secure_channel(self.server_url, credentials=credentials, options=self.options, compression=self.compression)
|
||||
credentials = grpc.ssl_channel_credentials(**{
|
||||
k: load_from_file(v) if isinstance(v, str) else v for k, v in self.ssl_client_credentials.items()
|
||||
})
|
||||
return grpc.secure_channel(self.server_url,
|
||||
credentials=credentials,
|
||||
options=self.options,
|
||||
compression=self.compression)
|
||||
return grpc.insecure_channel(self.server_url, options=self.options, compression=self.compression)
|
||||
|
||||
@staticmethod
|
||||
def wait_until_server_ready(host: str, port: int, timeout: float = 30, check_interval: int = 1, **kwargs: t.Any) -> None:
|
||||
with dispatch_channel(
|
||||
f"{host.replace(r'localhost', '0.0.0.0')}:{port}",
|
||||
typ='sync',
|
||||
options=kwargs.get('options', None),
|
||||
compression=kwargs.get('compression', None),
|
||||
ssl=kwargs.get('ssl', False),
|
||||
ssl_client_credentials=kwargs.get('ssl_client_credentials', None)
|
||||
) as channel:
|
||||
def wait_until_server_ready(host: str,
|
||||
port: int,
|
||||
timeout: float = 30,
|
||||
check_interval: int = 1,
|
||||
**kwargs: t.Any) -> None:
|
||||
with dispatch_channel(f"{host.replace(r'localhost', '0.0.0.0')}:{port}",
|
||||
typ='sync',
|
||||
options=kwargs.get('options', None),
|
||||
compression=kwargs.get('compression', None),
|
||||
ssl=kwargs.get('ssl', False),
|
||||
ssl_client_credentials=kwargs.get('ssl_client_credentials', None)) as channel:
|
||||
req = pb_health.HealthCheckRequest()
|
||||
req.service = 'bentoml.grpc.v1.BentoService'
|
||||
health_stub = services_health.HealthStub(channel)
|
||||
@@ -133,7 +144,8 @@ class GrpcClient(Client):
|
||||
time.sleep(check_interval)
|
||||
try:
|
||||
resp = health_stub.Check(req)
|
||||
if resp.status != pb_health.HealthCheckResponse.SERVING: raise TimeoutError(f"Timed out waiting {timeout} seconds for server at '{host}:{port}' to be ready.")
|
||||
if resp.status != pb_health.HealthCheckResponse.SERVING:
|
||||
raise TimeoutError(f"Timed out waiting {timeout} seconds for server at '{host}:{port}' to be ready.")
|
||||
except grpc.RpcError as err:
|
||||
logger.error('Caught RpcError while connecting to %s:%s:\n', host, port)
|
||||
logger.error(err)
|
||||
@@ -141,34 +153,32 @@ class GrpcClient(Client):
|
||||
|
||||
@classmethod
|
||||
def from_url(cls, url: str, **kwargs: t.Any) -> GrpcClient:
|
||||
with dispatch_channel(
|
||||
url.replace(r'localhost', '0.0.0.0'),
|
||||
typ='sync',
|
||||
options=kwargs.get('options', None),
|
||||
compression=kwargs.get('compression', None),
|
||||
ssl=kwargs.get('ssl', False),
|
||||
ssl_client_credentials=kwargs.get('ssl_client_credentials', None)
|
||||
) as channel:
|
||||
with dispatch_channel(url.replace(r'localhost', '0.0.0.0'),
|
||||
typ='sync',
|
||||
options=kwargs.get('options', None),
|
||||
compression=kwargs.get('compression', None),
|
||||
ssl=kwargs.get('ssl', False),
|
||||
ssl_client_credentials=kwargs.get('ssl_client_credentials', None)) as channel:
|
||||
metadata = t.cast(
|
||||
'ServiceMetadataResponse',
|
||||
channel.unary_unary(
|
||||
'/bentoml.grpc.v1.BentoService/ServiceMetadata', request_serializer=pb.ServiceMetadataRequest.SerializeToString, response_deserializer=pb.ServiceMetadataResponse.FromString
|
||||
)(pb.ServiceMetadataRequest())
|
||||
)
|
||||
channel.unary_unary('/bentoml.grpc.v1.BentoService/ServiceMetadata',
|
||||
request_serializer=pb.ServiceMetadataRequest.SerializeToString,
|
||||
response_deserializer=pb.ServiceMetadataResponse.FromString)(pb.ServiceMetadataRequest()))
|
||||
reflection = bentoml.Service(metadata.name)
|
||||
for api in metadata.apis:
|
||||
try:
|
||||
reflection.apis[api.name] = InferenceAPI[t.Any](
|
||||
None,
|
||||
bentoml.io.from_spec({
|
||||
'id': api.input.descriptor_id, 'args': json_format.MessageToDict(api.input.attributes).get('args', None)
|
||||
'id': api.input.descriptor_id,
|
||||
'args': json_format.MessageToDict(api.input.attributes).get('args', None)
|
||||
}),
|
||||
bentoml.io.from_spec({
|
||||
'id': api.output.descriptor_id, 'args': json_format.MessageToDict(api.output.attributes).get('args', None)
|
||||
'id': api.output.descriptor_id,
|
||||
'args': json_format.MessageToDict(api.output.attributes).get('args', None)
|
||||
}),
|
||||
name=api.name,
|
||||
doc=api.docs
|
||||
)
|
||||
doc=api.docs)
|
||||
except Exception as e:
|
||||
logger.error('Failed to instantiate client for API %s: ', api.name, e)
|
||||
return cls(url, reflection, **kwargs)
|
||||
@@ -177,15 +187,24 @@ class GrpcClient(Client):
|
||||
return services_health.HealthStub(self.inner).Check(pb_health.HealthCheckRequest(service=''))
|
||||
|
||||
def _call(self, data: t.Any, /, *, _inference_api: InferenceAPI[t.Any], **kwargs: t.Any) -> t.Any:
|
||||
channel_kwargs = {k: kwargs.pop(f'_grpc_channel_{k}', None) for k in {'timeout', 'metadata', 'credentials', 'wait_for_ready', 'compression'}}
|
||||
channel_kwargs = {
|
||||
k: kwargs.pop(f'_grpc_channel_{k}', None)
|
||||
for k in {'timeout', 'metadata', 'credentials', 'wait_for_ready', 'compression'}
|
||||
}
|
||||
if _inference_api.multi_input:
|
||||
if data is not None: raise ValueError(f"'{_inference_api.name}' takes multiple inputs, and thus required to pass as keyword arguments.")
|
||||
if data is not None:
|
||||
raise ValueError(
|
||||
f"'{_inference_api.name}' takes multiple inputs, and thus required to pass as keyword arguments.")
|
||||
fake_resp = ensure_exec_coro(_inference_api.input.to_proto(kwargs))
|
||||
else:
|
||||
fake_resp = ensure_exec_coro(_inference_api.input.to_proto(data))
|
||||
api_fn = {v: k for k, v in self.svc.apis.items()}
|
||||
stubs = services.BentoServiceStub(self.inner)
|
||||
proto = stubs.Call(pb.Request(**{'api_name': api_fn[_inference_api], _inference_api.input.proto_fields[0]: fake_resp}), **channel_kwargs)
|
||||
proto = stubs.Call(
|
||||
pb.Request(**{
|
||||
'api_name': api_fn[_inference_api],
|
||||
_inference_api.input.proto_fields[0]: fake_resp
|
||||
}), **channel_kwargs)
|
||||
return ensure_exec_coro(_inference_api.output.from_proto(getattr(proto, proto.WhichOneof('content'))))
|
||||
|
||||
class AsyncGrpcClient(AsyncClient):
|
||||
@@ -195,17 +214,15 @@ class AsyncGrpcClient(AsyncClient):
|
||||
interceptors: t.Optional[t.Sequence[aio.ClientInterceptor]]
|
||||
compression: t.Optional[grpc.Compression]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_url: str,
|
||||
svc: bentoml.Service, # gRPC specific options
|
||||
ssl: bool = False,
|
||||
options: aio.ChannelArgumentType | None = None,
|
||||
interceptors: t.Sequence[aio.ClientInterceptor] | None = None,
|
||||
compression: grpc.Compression | None = None,
|
||||
ssl_client_credentials: ClientCredentials | None = None,
|
||||
**kwargs: t.Any
|
||||
) -> None:
|
||||
def __init__(self,
|
||||
server_url: str,
|
||||
svc: bentoml.Service, # gRPC specific options
|
||||
ssl: bool = False,
|
||||
options: aio.ChannelArgumentType | None = None,
|
||||
interceptors: t.Sequence[aio.ClientInterceptor] | None = None,
|
||||
compression: grpc.Compression | None = None,
|
||||
ssl_client_credentials: ClientCredentials | None = None,
|
||||
**kwargs: t.Any) -> None:
|
||||
self.ssl, self.ssl_client_credentials, self.options, self.interceptors, self.compression = ssl, ssl_client_credentials, options, interceptors, compression
|
||||
super().__init__(server_url, svc, **kwargs)
|
||||
|
||||
@@ -213,20 +230,31 @@ class AsyncGrpcClient(AsyncClient):
|
||||
def inner(self) -> aio.Channel:
|
||||
if self.ssl:
|
||||
if self.ssl_client_credentials is None: raise RuntimeError("'ssl=True' requires 'ssl_client_credentials'")
|
||||
credentials = grpc.ssl_channel_credentials(**{k: load_from_file(v) if isinstance(v, str) else v for k, v in self.ssl_client_credentials.items()})
|
||||
return aio.secure_channel(self.server_url, credentials=credentials, options=self.options, compression=self.compression, interceptors=self.interceptors)
|
||||
return aio.insecure_channel(self.server_url, options=self.options, compression=self.compression, interceptors=self.interceptors)
|
||||
credentials = grpc.ssl_channel_credentials(**{
|
||||
k: load_from_file(v) if isinstance(v, str) else v for k, v in self.ssl_client_credentials.items()
|
||||
})
|
||||
return aio.secure_channel(self.server_url,
|
||||
credentials=credentials,
|
||||
options=self.options,
|
||||
compression=self.compression,
|
||||
interceptors=self.interceptors)
|
||||
return aio.insecure_channel(self.server_url,
|
||||
options=self.options,
|
||||
compression=self.compression,
|
||||
interceptors=self.interceptors)
|
||||
|
||||
@staticmethod
|
||||
async def wait_until_server_ready(host: str, port: int, timeout: float = 30, check_interval: int = 1, **kwargs: t.Any) -> None:
|
||||
async with dispatch_channel(
|
||||
f"{host.replace(r'localhost', '0.0.0.0')}:{port}",
|
||||
typ='async',
|
||||
options=kwargs.get('options', None),
|
||||
compression=kwargs.get('compression', None),
|
||||
ssl=kwargs.get('ssl', False),
|
||||
ssl_client_credentials=kwargs.get('ssl_client_credentials', None)
|
||||
) as channel:
|
||||
async def wait_until_server_ready(host: str,
|
||||
port: int,
|
||||
timeout: float = 30,
|
||||
check_interval: int = 1,
|
||||
**kwargs: t.Any) -> None:
|
||||
async with dispatch_channel(f"{host.replace(r'localhost', '0.0.0.0')}:{port}",
|
||||
typ='async',
|
||||
options=kwargs.get('options', None),
|
||||
compression=kwargs.get('compression', None),
|
||||
ssl=kwargs.get('ssl', False),
|
||||
ssl_client_credentials=kwargs.get('ssl_client_credentials', None)) as channel:
|
||||
req = pb_health.HealthCheckRequest()
|
||||
req.service = 'bentoml.grpc.v1.BentoService'
|
||||
health_stub = services_health.HealthStub(channel)
|
||||
@@ -241,7 +269,8 @@ class AsyncGrpcClient(AsyncClient):
|
||||
time.sleep(check_interval)
|
||||
try:
|
||||
resp = health_stub.Check(req)
|
||||
if resp.status != pb_health.HealthCheckResponse.SERVING: raise TimeoutError(f"Timed out waiting {timeout} seconds for server at '{host}:{port}' to be ready.")
|
||||
if resp.status != pb_health.HealthCheckResponse.SERVING:
|
||||
raise TimeoutError(f"Timed out waiting {timeout} seconds for server at '{host}:{port}' to be ready.")
|
||||
except grpc.RpcError as err:
|
||||
logger.error('Caught RpcError while connecting to %s:%s:\n', host, port)
|
||||
logger.error(err)
|
||||
@@ -249,35 +278,33 @@ class AsyncGrpcClient(AsyncClient):
|
||||
|
||||
@classmethod
|
||||
async def from_url(cls, url: str, **kwargs: t.Any) -> AsyncGrpcClient:
|
||||
async with dispatch_channel(
|
||||
url.replace(r'localhost', '0.0.0.0'),
|
||||
typ='async',
|
||||
options=kwargs.get('options', None),
|
||||
compression=kwargs.get('compression', None),
|
||||
ssl=kwargs.get('ssl', False),
|
||||
ssl_client_credentials=kwargs.get('ssl_client_credentials', None),
|
||||
interceptors=kwargs.get('interceptors', None)
|
||||
) as channel:
|
||||
async with dispatch_channel(url.replace(r'localhost', '0.0.0.0'),
|
||||
typ='async',
|
||||
options=kwargs.get('options', None),
|
||||
compression=kwargs.get('compression', None),
|
||||
ssl=kwargs.get('ssl', False),
|
||||
ssl_client_credentials=kwargs.get('ssl_client_credentials', None),
|
||||
interceptors=kwargs.get('interceptors', None)) as channel:
|
||||
metadata = t.cast(
|
||||
'ServiceMetadataResponse',
|
||||
channel.unary_unary(
|
||||
'/bentoml.grpc.v1.BentoService/ServiceMetadata', request_serializer=pb.ServiceMetadataRequest.SerializeToString, response_deserializer=pb.ServiceMetadataResponse.FromString
|
||||
)(pb.ServiceMetadataRequest())
|
||||
)
|
||||
channel.unary_unary('/bentoml.grpc.v1.BentoService/ServiceMetadata',
|
||||
request_serializer=pb.ServiceMetadataRequest.SerializeToString,
|
||||
response_deserializer=pb.ServiceMetadataResponse.FromString)(pb.ServiceMetadataRequest()))
|
||||
reflection = bentoml.Service(metadata.name)
|
||||
for api in metadata.apis:
|
||||
try:
|
||||
reflection.apis[api.name] = InferenceAPI[t.Any](
|
||||
None,
|
||||
bentoml.io.from_spec({
|
||||
'id': api.input.descriptor_id, 'args': json_format.MessageToDict(api.input.attributes).get('args', None)
|
||||
'id': api.input.descriptor_id,
|
||||
'args': json_format.MessageToDict(api.input.attributes).get('args', None)
|
||||
}),
|
||||
bentoml.io.from_spec({
|
||||
'id': api.output.descriptor_id, 'args': json_format.MessageToDict(api.output.attributes).get('args', None)
|
||||
'id': api.output.descriptor_id,
|
||||
'args': json_format.MessageToDict(api.output.attributes).get('args', None)
|
||||
}),
|
||||
name=api.name,
|
||||
doc=api.docs
|
||||
)
|
||||
doc=api.docs)
|
||||
except Exception as e:
|
||||
logger.error('Failed to instantiate client for API %s: ', api.name, e)
|
||||
return cls(url, reflection, **kwargs)
|
||||
@@ -286,16 +313,25 @@ class AsyncGrpcClient(AsyncClient):
|
||||
return await services_health.HealthStub(self.inner).Check(pb_health.HealthCheckRequest(service=''))
|
||||
|
||||
async def _call(self, data: t.Any, /, *, _inference_api: InferenceAPI[t.Any], **kwargs: t.Any) -> t.Any:
|
||||
channel_kwargs = {k: kwargs.pop(f'_grpc_channel_{k}', None) for k in {'timeout', 'metadata', 'credentials', 'wait_for_ready', 'compression'}}
|
||||
channel_kwargs = {
|
||||
k: kwargs.pop(f'_grpc_channel_{k}', None)
|
||||
for k in {'timeout', 'metadata', 'credentials', 'wait_for_ready', 'compression'}
|
||||
}
|
||||
state = self.inner.get_state(try_to_connect=True)
|
||||
if state != grpc.ChannelConnectivity.READY: await self.inner.channel_ready()
|
||||
if _inference_api.multi_input:
|
||||
if data is not None: raise ValueError(f"'{_inference_api.name}' takes multiple inputs, and thus required to pass as keyword arguments.")
|
||||
if data is not None:
|
||||
raise ValueError(
|
||||
f"'{_inference_api.name}' takes multiple inputs, and thus required to pass as keyword arguments.")
|
||||
fake_resp = await _inference_api.input.to_proto(kwargs)
|
||||
else:
|
||||
fake_resp = await _inference_api.input.to_proto(data)
|
||||
api_fn = {v: k for k, v in self.svc.apis.items()}
|
||||
async with self.inner:
|
||||
stubs = services.BentoServiceStub(self.inner)
|
||||
proto = await stubs.Call(pb.Request(**{'api_name': api_fn[_inference_api], _inference_api.input.proto_fields[0]: fake_resp}), **channel_kwargs)
|
||||
proto = await stubs.Call(
|
||||
pb.Request(**{
|
||||
'api_name': api_fn[_inference_api],
|
||||
_inference_api.input.proto_fields[0]: fake_resp
|
||||
}), **channel_kwargs)
|
||||
return await _inference_api.output.from_proto(getattr(proto, proto.WhichOneof('content')))
|
||||
|
||||
@@ -24,13 +24,18 @@ from openllm_core.utils import ensure_exec_coro
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class HttpClient(Client):
|
||||
|
||||
@functools.cached_property
|
||||
def inner(self) -> httpx.Client:
|
||||
if not urlparse(self.server_url).netloc: raise ValueError(f'Invalid server url: {self.server_url}')
|
||||
return httpx.Client(base_url=self.server_url)
|
||||
|
||||
@staticmethod
|
||||
def wait_until_server_ready(host: str, port: int, timeout: float = 30, check_interval: int = 1, **kwargs: t.Any) -> None:
|
||||
def wait_until_server_ready(host: str,
|
||||
port: int,
|
||||
timeout: float = 30,
|
||||
check_interval: int = 1,
|
||||
**kwargs: t.Any) -> None:
|
||||
host = host if '://' in host else 'http://' + host
|
||||
logger.debug('Waiting for server @ `%s:%d` to be ready...', host, port)
|
||||
start = time.time()
|
||||
@@ -57,7 +62,10 @@ class HttpClient(Client):
|
||||
def from_url(cls, url: str, **kwargs: t.Any) -> HttpClient:
|
||||
url = url if '://' in url else 'http://' + url
|
||||
resp = httpx.get(f'{url}/docs.json')
|
||||
if resp.status_code != 200: raise ValueError(f'Failed to get OpenAPI schema from the server: {resp.status_code} {resp.reason_phrase}:\n{resp.content.decode()}')
|
||||
if resp.status_code != 200:
|
||||
raise ValueError(
|
||||
f'Failed to get OpenAPI schema from the server: {resp.status_code} {resp.reason_phrase}:\n{resp.content.decode()}'
|
||||
)
|
||||
_spec = orjson.loads(resp.content)
|
||||
|
||||
reflection = bentoml.Service(_spec['info']['title'])
|
||||
@@ -65,9 +73,12 @@ class HttpClient(Client):
|
||||
for route, spec in _spec['paths'].items():
|
||||
for meth_spec in spec.values():
|
||||
if 'tags' in meth_spec and 'Service APIs' in meth_spec['tags']:
|
||||
if 'x-bentoml-io-descriptor' not in meth_spec['requestBody']: raise ValueError(f'Malformed BentoML spec received from BentoML server {url}')
|
||||
if 'x-bentoml-io-descriptor' not in meth_spec['responses']['200']: raise ValueError(f'Malformed BentoML spec received from BentoML server {url}')
|
||||
if 'x-bentoml-name' not in meth_spec: raise ValueError(f'Malformed BentoML spec received from BentoML server {url}')
|
||||
if 'x-bentoml-io-descriptor' not in meth_spec['requestBody']:
|
||||
raise ValueError(f'Malformed BentoML spec received from BentoML server {url}')
|
||||
if 'x-bentoml-io-descriptor' not in meth_spec['responses']['200']:
|
||||
raise ValueError(f'Malformed BentoML spec received from BentoML server {url}')
|
||||
if 'x-bentoml-name' not in meth_spec:
|
||||
raise ValueError(f'Malformed BentoML spec received from BentoML server {url}')
|
||||
try:
|
||||
reflection.apis[meth_spec['x-bentoml-name']] = InferenceAPI[t.Any](
|
||||
None,
|
||||
@@ -75,8 +86,7 @@ class HttpClient(Client):
|
||||
bentoml.io.from_spec(meth_spec['responses']['200']['x-bentoml-io-descriptor']),
|
||||
name=meth_spec['x-bentoml-name'],
|
||||
doc=meth_spec['description'],
|
||||
route=route.lstrip('/')
|
||||
)
|
||||
route=route.lstrip('/'))
|
||||
except Exception as e:
|
||||
logger.error('Failed to instantiate client for API %s: ', meth_spec['x-bentoml-name'], e)
|
||||
return cls(url, reflection)
|
||||
@@ -85,7 +95,9 @@ class HttpClient(Client):
|
||||
# All gRPC kwargs should be popped out.
|
||||
kwargs = {k: v for k, v in kwargs.items() if not k.startswith('_grpc_')}
|
||||
if _inference_api.multi_input:
|
||||
if data is not None: raise ValueError(f"'{_inference_api.name}' takes multiple inputs, and thus required to pass as keyword arguments.")
|
||||
if data is not None:
|
||||
raise ValueError(
|
||||
f"'{_inference_api.name}' takes multiple inputs, and thus required to pass as keyword arguments.")
|
||||
fake_resp = ensure_exec_coro(_inference_api.input.to_http_response(kwargs, None))
|
||||
else:
|
||||
fake_resp = ensure_exec_coro(_inference_api.input.to_http_response(data, None))
|
||||
@@ -94,12 +106,11 @@ class HttpClient(Client):
|
||||
if isinstance(fake_resp, starlette.responses.StreamingResponse): body = None
|
||||
else: body = fake_resp.body
|
||||
|
||||
resp = self.inner.post(
|
||||
'/' + _inference_api.route if not _inference_api.route.startswith('/') else _inference_api.route,
|
||||
data=body,
|
||||
headers={'content-type': fake_resp.headers['content-type']},
|
||||
timeout=self.timeout
|
||||
)
|
||||
resp = self.inner.post('/' +
|
||||
_inference_api.route if not _inference_api.route.startswith('/') else _inference_api.route,
|
||||
data=body,
|
||||
headers={'content-type': fake_resp.headers['content-type']},
|
||||
timeout=self.timeout)
|
||||
if resp.status_code != 200: raise ValueError(f'Error while making request: {resp.status_code}: {resp.content!s}')
|
||||
fake_req = starlette.requests.Request(scope={'type': 'http'})
|
||||
headers = starlette.datastructures.Headers(headers=resp.headers)
|
||||
@@ -109,13 +120,18 @@ class HttpClient(Client):
|
||||
return ensure_exec_coro(_inference_api.output.from_http_request(fake_req))
|
||||
|
||||
class AsyncHttpClient(AsyncClient):
|
||||
|
||||
@functools.cached_property
|
||||
def inner(self) -> httpx.AsyncClient:
|
||||
if not urlparse(self.server_url).netloc: raise ValueError(f'Invalid server url: {self.server_url}')
|
||||
return httpx.AsyncClient(base_url=self.server_url)
|
||||
|
||||
@staticmethod
|
||||
async def wait_until_server_ready(host: str, port: int, timeout: float = 30, check_interval: int = 1, **kwargs: t.Any) -> None:
|
||||
async def wait_until_server_ready(host: str,
|
||||
port: int,
|
||||
timeout: float = 30,
|
||||
check_interval: int = 1,
|
||||
**kwargs: t.Any) -> None:
|
||||
host = host if '://' in host else 'http://' + host
|
||||
logger.debug('Waiting for server @ `%s:%d` to be ready...', host, port)
|
||||
start = time.time()
|
||||
@@ -131,7 +147,9 @@ class AsyncHttpClient(AsyncClient):
|
||||
# Try once more and raise for exception
|
||||
async with httpx.AsyncClient(base_url=f'{host}:{port}') as sess:
|
||||
resp = await sess.get('/readyz')
|
||||
if resp.status_code != 200: raise TimeoutError(f'Timeout while waiting for server @ `{host}:{port}` to be ready: {resp.status_code}: {resp.content!s}')
|
||||
if resp.status_code != 200:
|
||||
raise TimeoutError(
|
||||
f'Timeout while waiting for server @ `{host}:{port}` to be ready: {resp.status_code}: {resp.content!s}')
|
||||
|
||||
async def health(self) -> httpx.Response:
|
||||
return await self.inner.get('/readyz')
|
||||
@@ -141,7 +159,10 @@ class AsyncHttpClient(AsyncClient):
|
||||
url = url if '://' in url else 'http://' + url
|
||||
async with httpx.AsyncClient(base_url=url) as session:
|
||||
resp = await session.get('/docs.json')
|
||||
if resp.status_code != 200: raise ValueError(f'Failed to get OpenAPI schema from the server: {resp.status_code} {resp.reason_phrase}:\n{(await resp.aread()).decode()}')
|
||||
if resp.status_code != 200:
|
||||
raise ValueError(
|
||||
f'Failed to get OpenAPI schema from the server: {resp.status_code} {resp.reason_phrase}:\n{(await resp.aread()).decode()}'
|
||||
)
|
||||
_spec = orjson.loads(await resp.aread())
|
||||
|
||||
reflection = bentoml.Service(_spec['info']['title'])
|
||||
@@ -149,9 +170,12 @@ class AsyncHttpClient(AsyncClient):
|
||||
for route, spec in _spec['paths'].items():
|
||||
for meth_spec in spec.values():
|
||||
if 'tags' in meth_spec and 'Service APIs' in meth_spec['tags']:
|
||||
if 'x-bentoml-io-descriptor' not in meth_spec['requestBody']: raise ValueError(f'Malformed BentoML spec received from BentoML server {url}')
|
||||
if 'x-bentoml-io-descriptor' not in meth_spec['responses']['200']: raise ValueError(f'Malformed BentoML spec received from BentoML server {url}')
|
||||
if 'x-bentoml-name' not in meth_spec: raise ValueError(f'Malformed BentoML spec received from BentoML server {url}')
|
||||
if 'x-bentoml-io-descriptor' not in meth_spec['requestBody']:
|
||||
raise ValueError(f'Malformed BentoML spec received from BentoML server {url}')
|
||||
if 'x-bentoml-io-descriptor' not in meth_spec['responses']['200']:
|
||||
raise ValueError(f'Malformed BentoML spec received from BentoML server {url}')
|
||||
if 'x-bentoml-name' not in meth_spec:
|
||||
raise ValueError(f'Malformed BentoML spec received from BentoML server {url}')
|
||||
try:
|
||||
reflection.apis[meth_spec['x-bentoml-name']] = InferenceAPI[t.Any](
|
||||
None,
|
||||
@@ -159,8 +183,7 @@ class AsyncHttpClient(AsyncClient):
|
||||
bentoml.io.from_spec(meth_spec['responses']['200']['x-bentoml-io-descriptor']),
|
||||
name=meth_spec['x-bentoml-name'],
|
||||
doc=meth_spec['description'],
|
||||
route=route.lstrip('/')
|
||||
)
|
||||
route=route.lstrip('/'))
|
||||
except ValueError as e:
|
||||
logger.error('Failed to instantiate client for API %s: ', meth_spec['x-bentoml-name'], e)
|
||||
return cls(url, reflection)
|
||||
@@ -169,7 +192,9 @@ class AsyncHttpClient(AsyncClient):
|
||||
# All gRPC kwargs should be popped out.
|
||||
kwargs = {k: v for k, v in kwargs.items() if not k.startswith('_grpc_')}
|
||||
if _inference_api.multi_input:
|
||||
if data is not None: raise ValueError(f"'{_inference_api.name}' takes multiple inputs, and thus required to pass as keyword arguments.")
|
||||
if data is not None:
|
||||
raise ValueError(
|
||||
f"'{_inference_api.name}' takes multiple inputs, and thus required to pass as keyword arguments.")
|
||||
fake_resp = await _inference_api.input.to_http_response(kwargs, None)
|
||||
else:
|
||||
fake_resp = await _inference_api.input.to_http_response(data, None)
|
||||
@@ -182,8 +207,7 @@ class AsyncHttpClient(AsyncClient):
|
||||
'/' + _inference_api.route if not _inference_api.route.startswith('/') else _inference_api.route,
|
||||
data=body,
|
||||
headers={'content-type': fake_resp.headers['content-type']},
|
||||
timeout=self.timeout
|
||||
)
|
||||
timeout=self.timeout)
|
||||
if resp.status_code != 200: raise ValueError(f'Error making request: {resp.status_code}: {(await resp.aread())!s}')
|
||||
fake_req = starlette.requests.Request(scope={'type': 'http'})
|
||||
headers = starlette.datastructures.Headers(headers=resp.headers)
|
||||
|
||||
@@ -16,21 +16,25 @@ def process_http_address(self: AsyncHTTPClient | HTTPClient, address: str) -> No
|
||||
else: self._port = next(iter(_port))
|
||||
|
||||
class HTTPClient(BaseClient):
|
||||
|
||||
def __init__(self, address: str, timeout: int = 30):
|
||||
process_http_address(self, address)
|
||||
super().__init__(address, timeout)
|
||||
|
||||
class AsyncHTTPClient(BaseAsyncClient):
|
||||
|
||||
def __init__(self, address: str, timeout: int = 30):
|
||||
process_http_address(self, address)
|
||||
super().__init__(address, timeout)
|
||||
|
||||
class GrpcClient(BaseClient):
|
||||
|
||||
def __init__(self, address: str, timeout: int = 30):
|
||||
self._host, self._port = address.split(':')
|
||||
super().__init__(address, timeout)
|
||||
|
||||
class AsyncGrpcClient(BaseAsyncClient):
|
||||
|
||||
def __init__(self, address: str, timeout: int = 30):
|
||||
self._host, self._port = address.split(':')
|
||||
super().__init__(address, timeout)
|
||||
|
||||
@@ -105,6 +105,7 @@ config_merger = Merger([(dict, 'merge')], ['override'], ['override'])
|
||||
|
||||
# case insensitive, but rename to conform with type
|
||||
class _PeftEnumMeta(enum.EnumMeta):
|
||||
|
||||
def __getitem__(self, __key: str | t.Any, /) -> t.Any:
|
||||
if isinstance(__key, str): __key = inflection.underscore(__key).upper()
|
||||
return self._member_map_[__key]
|
||||
@@ -177,11 +178,19 @@ class FineTuneConfig:
|
||||
if t.TYPE_CHECKING and not MYPY:
|
||||
# The following type stubs makes __init__ aware of attrs internal type converter.
|
||||
@overload
|
||||
def __init__(self, adapter_type: AdapterType = ..., adapter_config: dict[str, t.Any] = ..., inference_mode: bool = ..., llm_config_class: type[LLMConfig] = ...) -> None:
|
||||
def __init__(self,
|
||||
adapter_type: AdapterType = ...,
|
||||
adapter_config: dict[str, t.Any] = ...,
|
||||
inference_mode: bool = ...,
|
||||
llm_config_class: type[LLMConfig] = ...) -> None:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __init__(self, adapter_type: PeftType = ..., adapter_config: dict[str, t.Any] = ..., inference_mode: bool = ..., llm_config_class: type[LLMConfig] = ...) -> None:
|
||||
def __init__(self,
|
||||
adapter_type: PeftType = ...,
|
||||
adapter_config: dict[str, t.Any] = ...,
|
||||
inference_mode: bool = ...,
|
||||
llm_config_class: type[LLMConfig] = ...) -> None:
|
||||
...
|
||||
|
||||
# The below should be generated via attrs. Only here to conform with pyright strict checking.
|
||||
@@ -190,27 +199,35 @@ class FineTuneConfig:
|
||||
|
||||
adapter_type: PeftType = dantic.Field(
|
||||
'lora',
|
||||
description=f"The type of adapter to use for fine-tuning. Available supported methods: {PeftType.supported()}, default to 'lora'",
|
||||
description=
|
||||
f"The type of adapter to use for fine-tuning. Available supported methods: {PeftType.supported()}, default to 'lora'",
|
||||
use_default_converter=False,
|
||||
converter=_adapter_converter
|
||||
)
|
||||
converter=_adapter_converter)
|
||||
adapter_config: t.Dict[str, t.Any] = dantic.Field(
|
||||
None,
|
||||
description='The configuration for the adapter. The content of the dict depends on the adapter type.',
|
||||
validator=attr.validators.optional(attr.validators.instance_of(dict)),
|
||||
converter=attr.converters.default_if_none(factory=dict),
|
||||
use_default_converter=False
|
||||
)
|
||||
inference_mode: bool = dantic.Field(False, description='Whether to use this Adapter for inference', use_default_converter=False)
|
||||
llm_config_class: type[LLMConfig] = dantic.Field(None, description='The reference class to openllm.LLMConfig', use_default_converter=False)
|
||||
use_default_converter=False)
|
||||
inference_mode: bool = dantic.Field(False,
|
||||
description='Whether to use this Adapter for inference',
|
||||
use_default_converter=False)
|
||||
llm_config_class: type[LLMConfig] = dantic.Field(None,
|
||||
description='The reference class to openllm.LLMConfig',
|
||||
use_default_converter=False)
|
||||
|
||||
def to_peft_config(self) -> peft.PeftConfig: # type: ignore[name-defined]
|
||||
adapter_config = self.adapter_config.copy()
|
||||
# no need for peft_type since it is internally managed by OpenLLM and PEFT
|
||||
if 'peft_type' in adapter_config: adapter_config.pop('peft_type')
|
||||
# respect user set task_type if it is passed, otherwise use one managed by OpenLLM
|
||||
task_type, inference_mode = adapter_config.pop('task_type', peft.TaskType[self.llm_config_class.peft_task_type()]), adapter_config.pop('inference_mode', self.inference_mode)
|
||||
return peft.PEFT_TYPE_TO_CONFIG_MAPPING[self.adapter_type.to_str()](task_type=task_type, inference_mode=inference_mode, **adapter_config)
|
||||
task_type, inference_mode = adapter_config.pop(
|
||||
'task_type',
|
||||
peft.TaskType[self.llm_config_class.peft_task_type()]), adapter_config.pop('inference_mode',
|
||||
self.inference_mode)
|
||||
return peft.PEFT_TYPE_TO_CONFIG_MAPPING[self.adapter_type.to_str()](task_type=task_type,
|
||||
inference_mode=inference_mode,
|
||||
**adapter_config)
|
||||
|
||||
def train(self) -> FineTuneConfig:
|
||||
_object_setattr(self, 'inference_mode', False)
|
||||
@@ -221,9 +238,14 @@ class FineTuneConfig:
|
||||
return self
|
||||
|
||||
def with_config(self, **attrs: t.Any) -> FineTuneConfig:
|
||||
adapter_type, inference_mode = attrs.pop('adapter_type', self.adapter_type), attrs.get('inference_mode', self.inference_mode)
|
||||
if 'llm_config_class' in attrs: raise ForbiddenAttributeError("'llm_config_class' should not be passed when using 'with_config'.")
|
||||
return attr.evolve(self, adapter_type=adapter_type, inference_mode=inference_mode, adapter_config=config_merger.merge(self.adapter_config, attrs))
|
||||
adapter_type, inference_mode = attrs.pop('adapter_type',
|
||||
self.adapter_type), attrs.get('inference_mode', self.inference_mode)
|
||||
if 'llm_config_class' in attrs:
|
||||
raise ForbiddenAttributeError("'llm_config_class' should not be passed when using 'with_config'.")
|
||||
return attr.evolve(self,
|
||||
adapter_type=adapter_type,
|
||||
inference_mode=inference_mode,
|
||||
adapter_config=config_merger.merge(self.adapter_config, attrs))
|
||||
|
||||
@attr.frozen(slots=True, repr=False, init=False)
|
||||
class GenerationConfig(ReprMixin):
|
||||
@@ -233,105 +255,162 @@ class GenerationConfig(ReprMixin):
|
||||
to be used conjunction with LLMConfig. The instance of the generation config can then be accessed
|
||||
via ``LLMConfig.generation_config``.
|
||||
'''
|
||||
max_new_tokens: int = dantic.Field(20, ge=0, description='The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.')
|
||||
max_new_tokens: int = dantic.Field(
|
||||
20, ge=0, description='The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.')
|
||||
min_length: int = dantic.Field(
|
||||
0,
|
||||
ge=0,
|
||||
description='The minimum length of the sequence to be generated. Corresponds to the length of the input prompt + `min_new_tokens`. Its effect is overridden by `min_new_tokens`, if also set.'
|
||||
description=
|
||||
'The minimum length of the sequence to be generated. Corresponds to the length of the input prompt + `min_new_tokens`. Its effect is overridden by `min_new_tokens`, if also set.'
|
||||
)
|
||||
min_new_tokens: int = dantic.Field(description='The minimum numbers of tokens to generate, ignoring the number of tokens in the prompt.')
|
||||
min_new_tokens: int = dantic.Field(
|
||||
description='The minimum numbers of tokens to generate, ignoring the number of tokens in the prompt.')
|
||||
early_stopping: bool = dantic.Field(
|
||||
False,
|
||||
description='''Controls the stopping condition for beam-based methods, like beam-search. It accepts the following values: `True`, where the generation stops as soon as there are `num_beams` complete candidates; `False`, where an heuristic is applied and the generation stops when is it very unlikely to find better candidates; `"never"`, where the beam search procedure only stops when there cannot be better candidates (canonical beam search algorithm) '''
|
||||
description=
|
||||
'''Controls the stopping condition for beam-based methods, like beam-search. It accepts the following values: `True`, where the generation stops as soon as there are `num_beams` complete candidates; `False`, where an heuristic is applied and the generation stops when is it very unlikely to find better candidates; `"never"`, where the beam search procedure only stops when there cannot be better candidates (canonical beam search algorithm) '''
|
||||
)
|
||||
max_time: float = dantic.Field(
|
||||
description='The maximum amount of time you allow the computation to run for in seconds. generation will still finish the current pass after allocated time has been passed.'
|
||||
description=
|
||||
'The maximum amount of time you allow the computation to run for in seconds. generation will still finish the current pass after allocated time has been passed.'
|
||||
)
|
||||
num_beams: int = dantic.Field(1, description='Number of beams for beam search. 1 means no beam search.')
|
||||
num_beam_groups: int = dantic.Field(
|
||||
1,
|
||||
description='Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams. [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.'
|
||||
description=
|
||||
'Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams. [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.'
|
||||
)
|
||||
penalty_alpha: float = dantic.Field(description='The values balance the model confidence and the degeneration penalty in contrastive search decoding.')
|
||||
use_cache: bool = dantic.Field(True, description='Whether or not the model should use the past last key/values attentions (if applicable to the model) to speed up decoding.')
|
||||
temperature: float = dantic.Field(1.0, ge=0.0, le=1.0, description='The value used to modulate the next token probabilities.')
|
||||
top_k: int = dantic.Field(50, description='The number of highest probability vocabulary tokens to keep for top-k-filtering.')
|
||||
penalty_alpha: float = dantic.Field(
|
||||
description='The values balance the model confidence and the degeneration penalty in contrastive search decoding.'
|
||||
)
|
||||
use_cache: bool = dantic.Field(
|
||||
True,
|
||||
description=
|
||||
'Whether or not the model should use the past last key/values attentions (if applicable to the model) to speed up decoding.'
|
||||
)
|
||||
temperature: float = dantic.Field(1.0,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description='The value used to modulate the next token probabilities.')
|
||||
top_k: int = dantic.Field(
|
||||
50, description='The number of highest probability vocabulary tokens to keep for top-k-filtering.')
|
||||
top_p: float = dantic.Field(
|
||||
1.0, description='If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or higher are kept for generation.'
|
||||
1.0,
|
||||
description=
|
||||
'If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or higher are kept for generation.'
|
||||
)
|
||||
typical_p: float = dantic.Field(
|
||||
1.0,
|
||||
description='Local typicality measures how similar the conditional probability of predicting a target token next is to the expected conditional probability of predicting a random token next, given the partial text already generated. If set to float < 1, the smallest set of the most locally typical tokens with probabilities that add up to `typical_p` or higher are kept for generation. See [this paper](https://arxiv.org/pdf/2202.00666.pdf) for more details.'
|
||||
description=
|
||||
'Local typicality measures how similar the conditional probability of predicting a target token next is to the expected conditional probability of predicting a random token next, given the partial text already generated. If set to float < 1, the smallest set of the most locally typical tokens with probabilities that add up to `typical_p` or higher are kept for generation. See [this paper](https://arxiv.org/pdf/2202.00666.pdf) for more details.'
|
||||
)
|
||||
epsilon_cutoff: float = dantic.Field(
|
||||
0.0,
|
||||
description='If set to float strictly between 0 and 1, only tokens with a conditional probability greater than `epsilon_cutoff` will be sampled. In the paper, suggested values range from 3e-4 to 9e-4, depending on the size of the model. See [Truncation Sampling as Language Model Desmoothing](https://arxiv.org/abs/2210.15191) for more details.'
|
||||
description=
|
||||
'If set to float strictly between 0 and 1, only tokens with a conditional probability greater than `epsilon_cutoff` will be sampled. In the paper, suggested values range from 3e-4 to 9e-4, depending on the size of the model. See [Truncation Sampling as Language Model Desmoothing](https://arxiv.org/abs/2210.15191) for more details.'
|
||||
)
|
||||
eta_cutoff: float = dantic.Field(
|
||||
0.0,
|
||||
description='''Eta sampling is a hybrid of locally typical sampling and epsilon sampling. If set to float strictly between 0 and 1, a token is only considered if it is greater than either `eta_cutoff` or `sqrt(eta_cutoff) * exp(-entropy(softmax(next_token_logits)))`. The latter term is intuitively the expected next token probability, scaled by `sqrt(eta_cutoff)`. In the paper, suggested values range from 3e-4 to 2e-3, depending on the size of the model. See [Truncation Sampling as Language Model Desmoothing](https://arxiv.org/abs/2210.15191) for more details. '''
|
||||
description=
|
||||
'''Eta sampling is a hybrid of locally typical sampling and epsilon sampling. If set to float strictly between 0 and 1, a token is only considered if it is greater than either `eta_cutoff` or `sqrt(eta_cutoff) * exp(-entropy(softmax(next_token_logits)))`. The latter term is intuitively the expected next token probability, scaled by `sqrt(eta_cutoff)`. In the paper, suggested values range from 3e-4 to 2e-3, depending on the size of the model. See [Truncation Sampling as Language Model Desmoothing](https://arxiv.org/abs/2210.15191) for more details. '''
|
||||
)
|
||||
diversity_penalty: float = dantic.Field(
|
||||
0.0,
|
||||
description="This value is subtracted from a beam's score if it generates a token same as any beam from other group at a particular time. Note that `diversity_penalty` is only effective if `group beam search` is enabled. "
|
||||
description=
|
||||
"This value is subtracted from a beam's score if it generates a token same as any beam from other group at a particular time. Note that `diversity_penalty` is only effective if `group beam search` is enabled. "
|
||||
)
|
||||
repetition_penalty: float = dantic.Field(
|
||||
1.0, description='The parameter for repetition penalty. 1.0 means no penalty. See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.'
|
||||
1.0,
|
||||
description=
|
||||
'The parameter for repetition penalty. 1.0 means no penalty. See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.'
|
||||
)
|
||||
encoder_repetition_penalty: float = dantic.Field(
|
||||
1.0, description='The paramater for encoder_repetition_penalty. An exponential penalty on sequences that are not in the original input. 1.0 means no penalty.'
|
||||
1.0,
|
||||
description=
|
||||
'The paramater for encoder_repetition_penalty. An exponential penalty on sequences that are not in the original input. 1.0 means no penalty.'
|
||||
)
|
||||
length_penalty: float = dantic.Field(
|
||||
1.0,
|
||||
description='Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while `length_penalty` < 0.0 encourages shorter sequences.'
|
||||
description=
|
||||
'Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while `length_penalty` < 0.0 encourages shorter sequences.'
|
||||
)
|
||||
no_repeat_ngram_size: int = dantic.Field(0, description='If set to int > 0, all ngrams of that size can only occur once.')
|
||||
no_repeat_ngram_size: int = dantic.Field(
|
||||
0, description='If set to int > 0, all ngrams of that size can only occur once.')
|
||||
bad_words_ids: t.List[t.List[int]] = dantic.Field(
|
||||
description='List of token ids that are not allowed to be generated. In order to get the token ids of the words that should not appear in the generated text, use `tokenizer(bad_words, add_prefix_space=True, add_special_tokens=False).input_ids`.'
|
||||
description=
|
||||
'List of token ids that are not allowed to be generated. In order to get the token ids of the words that should not appear in the generated text, use `tokenizer(bad_words, add_prefix_space=True, add_special_tokens=False).input_ids`.'
|
||||
)
|
||||
force_words_ids: t.Union[t.List[t.List[int]], t.List[t.List[t.List[int]]]] = dantic.Field(
|
||||
description='List of token ids that must be generated. If given a `List[List[int]]`, this is treated as a simple list of words that must be included, the opposite to `bad_words_ids`. If given `List[List[List[int]]]`, this triggers a [disjunctive constraint](https://github.com/huggingface/transformers/issues/14081), where one can allow different forms of each word. '
|
||||
description=
|
||||
'List of token ids that must be generated. If given a `List[List[int]]`, this is treated as a simple list of words that must be included, the opposite to `bad_words_ids`. If given `List[List[List[int]]]`, this triggers a [disjunctive constraint](https://github.com/huggingface/transformers/issues/14081), where one can allow different forms of each word. '
|
||||
)
|
||||
renormalize_logits: bool = dantic.Field(
|
||||
False,
|
||||
description="Whether to renormalize the logits after applying all the logits processors or warpers (including the custom ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the score logits are normalized but some logit processors or warpers break the normalization. "
|
||||
description=
|
||||
"Whether to renormalize the logits after applying all the logits processors or warpers (including the custom ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the score logits are normalized but some logit processors or warpers break the normalization. "
|
||||
)
|
||||
constraints: t.List[Constraint] = dantic.Field(
|
||||
description='Custom constraints that can be added to the generation to ensure that the output will contain the use of certain tokens as defined by ``Constraint`` objects, in the most sensible way possible.'
|
||||
description=
|
||||
'Custom constraints that can be added to the generation to ensure that the output will contain the use of certain tokens as defined by ``Constraint`` objects, in the most sensible way possible.'
|
||||
)
|
||||
forced_bos_token_id: int = dantic.Field(
|
||||
description='The id of the token to force as the first generated token after the ``decoder_start_token_id``. Useful for multilingual models like [mBART](https://huggingface.co/docs/transformers/model_doc/mbart) where the first generated token needs to be the target language token. '
|
||||
description=
|
||||
'The id of the token to force as the first generated token after the ``decoder_start_token_id``. Useful for multilingual models like [mBART](https://huggingface.co/docs/transformers/model_doc/mbart) where the first generated token needs to be the target language token. '
|
||||
)
|
||||
forced_eos_token_id: t.Union[int, t.List[int]] = dantic.Field(
|
||||
description='The id of the token to force as the last generated token when `max_length` is reached. Optionally, use a list to set multiple *end-of-sequence* tokens.'
|
||||
description=
|
||||
'The id of the token to force as the last generated token when `max_length` is reached. Optionally, use a list to set multiple *end-of-sequence* tokens.'
|
||||
)
|
||||
remove_invalid_values: bool = dantic.Field(
|
||||
False,
|
||||
description='Whether to remove possible *nan* and *inf* outputs of the model to prevent the generation method to crash. Note that using `remove_invalid_values` can slow down generation.'
|
||||
description=
|
||||
'Whether to remove possible *nan* and *inf* outputs of the model to prevent the generation method to crash. Note that using `remove_invalid_values` can slow down generation.'
|
||||
)
|
||||
exponential_decay_length_penalty: t.Tuple[int, float] = dantic.Field(
|
||||
description='This tuple adds an exponentially increasing length penalty, after a certain amount of tokens have been generated. The tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where penalty starts and `decay_factor` represents the factor of exponential decay'
|
||||
description=
|
||||
'This tuple adds an exponentially increasing length penalty, after a certain amount of tokens have been generated. The tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where penalty starts and `decay_factor` represents the factor of exponential decay'
|
||||
)
|
||||
suppress_tokens: t.List[int] = dantic.Field(
|
||||
description='A list of tokens that will be suppressed at generation. The `SupressTokens` logit processor will set their log probs to `-inf` so that they are not sampled.'
|
||||
description=
|
||||
'A list of tokens that will be suppressed at generation. The `SupressTokens` logit processor will set their log probs to `-inf` so that they are not sampled.'
|
||||
)
|
||||
begin_suppress_tokens: t.List[int] = dantic.Field(
|
||||
description='A list of tokens that will be suppressed at the beginning of the generation. The `SupressBeginTokens` logit processor will set their log probs to `-inf` so that they are not sampled. '
|
||||
description=
|
||||
'A list of tokens that will be suppressed at the beginning of the generation. The `SupressBeginTokens` logit processor will set their log probs to `-inf` so that they are not sampled. '
|
||||
)
|
||||
forced_decoder_ids: t.List[t.List[int]] = dantic.Field(
|
||||
description='A list of pairs of integers which indicates a mapping from generation indices to token indices that will be forced before sampling. For example, `[[1, 123]]` means the second generated token will always be a token of index 123.'
|
||||
description=
|
||||
'A list of pairs of integers which indicates a mapping from generation indices to token indices that will be forced before sampling. For example, `[[1, 123]]` means the second generated token will always be a token of index 123.'
|
||||
)
|
||||
num_return_sequences: int = dantic.Field(1, description='The number of independently computed returned sequences for each element in the batch.')
|
||||
num_return_sequences: int = dantic.Field(
|
||||
1, description='The number of independently computed returned sequences for each element in the batch.')
|
||||
output_attentions: bool = dantic.Field(
|
||||
False, description='''Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more details.'''
|
||||
False,
|
||||
description=
|
||||
'''Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more details.'''
|
||||
)
|
||||
output_hidden_states: bool = dantic.Field(False, description='''Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more details.''')
|
||||
output_scores: bool = dantic.Field(False, description='''Whether or not to return the prediction scores. See `scores` under returned tensors for more details.''')
|
||||
output_hidden_states: bool = dantic.Field(
|
||||
False,
|
||||
description=
|
||||
'''Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more details.'''
|
||||
)
|
||||
output_scores: bool = dantic.Field(
|
||||
False,
|
||||
description=
|
||||
'''Whether or not to return the prediction scores. See `scores` under returned tensors for more details.''')
|
||||
pad_token_id: int = dantic.Field(description='The id of the *padding* token.')
|
||||
bos_token_id: int = dantic.Field(description='The id of the *beginning-of-sequence* token.')
|
||||
eos_token_id: t.Union[int, t.List[int]] = dantic.Field(description='The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.')
|
||||
encoder_no_repeat_ngram_size: int = dantic.Field(0, description='If set to int > 0, all ngrams of that size that occur in the `encoder_input_ids` cannot occur in the `decoder_input_ids`.')
|
||||
decoder_start_token_id: int = dantic.Field(description='If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token.')
|
||||
eos_token_id: t.Union[int, t.List[int]] = dantic.Field(
|
||||
description=
|
||||
'The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.')
|
||||
encoder_no_repeat_ngram_size: int = dantic.Field(
|
||||
0,
|
||||
description=
|
||||
'If set to int > 0, all ngrams of that size that occur in the `encoder_input_ids` cannot occur in the `decoder_input_ids`.'
|
||||
)
|
||||
decoder_start_token_id: int = dantic.Field(
|
||||
description='If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token.'
|
||||
)
|
||||
|
||||
if t.TYPE_CHECKING and not MYPY:
|
||||
# stubs this for pyright as mypy already has a attr plugin builtin
|
||||
@@ -339,7 +418,10 @@ class GenerationConfig(ReprMixin):
|
||||
...
|
||||
|
||||
def __init__(self, *, _internal: bool = False, **attrs: t.Any):
|
||||
if not _internal: raise RuntimeError('GenerationConfig is not meant to be used directly, but you can access this via a LLMConfig.generation_config')
|
||||
if not _internal:
|
||||
raise RuntimeError(
|
||||
'GenerationConfig is not meant to be used directly, but you can access this via a LLMConfig.generation_config'
|
||||
)
|
||||
self.__attrs_init__(**attrs)
|
||||
|
||||
def __getitem__(self, item: str) -> t.Any:
|
||||
@@ -352,16 +434,15 @@ class GenerationConfig(ReprMixin):
|
||||
|
||||
bentoml_cattr.register_unstructure_hook_factory(
|
||||
lambda cls: attr.has(cls) and lenient_issubclass(cls, GenerationConfig),
|
||||
lambda cls: make_dict_unstructure_fn(
|
||||
cls,
|
||||
bentoml_cattr,
|
||||
_cattrs_omit_if_default=False,
|
||||
_cattrs_use_linecache=True,
|
||||
**{
|
||||
k: override(omit=True) for k, v in attr.fields_dict(cls).items() if v.default in (None, attr.NOTHING)
|
||||
}
|
||||
)
|
||||
)
|
||||
lambda cls: make_dict_unstructure_fn(cls,
|
||||
bentoml_cattr,
|
||||
_cattrs_omit_if_default=False,
|
||||
_cattrs_use_linecache=True,
|
||||
**{
|
||||
k: override(omit=True)
|
||||
for k, v in attr.fields_dict(cls).items()
|
||||
if v.default in (None, attr.NOTHING)
|
||||
}))
|
||||
|
||||
@attr.frozen(slots=True, repr=False, init=False)
|
||||
class SamplingParams(ReprMixin):
|
||||
@@ -376,19 +457,28 @@ class SamplingParams(ReprMixin):
|
||||
n: int = dantic.Field(1, description='Number of output sequences to return for the given prompt.')
|
||||
best_of: int = dantic.Field(
|
||||
None,
|
||||
description='Number of output sequences that are generated from the prompt. From these `best_of` sequences, the top `n` sequences are returned. `best_of` must be greater than or equal to `n`. This is treated as the beam width when `use_beam_search` is True. By default, `best_of` is set to `n`.'
|
||||
description=
|
||||
'Number of output sequences that are generated from the prompt. From these `best_of` sequences, the top `n` sequences are returned. `best_of` must be greater than or equal to `n`. This is treated as the beam width when `use_beam_search` is True. By default, `best_of` is set to `n`.'
|
||||
)
|
||||
presence_penalty: float = dantic.Field(
|
||||
0.0,
|
||||
description='Float that penalizes new tokens based on whether they appear in the generated text so far. Values > 0 encourage the model to use new tokens, while values < 0 encourage the model to repeat tokens.'
|
||||
description=
|
||||
'Float that penalizes new tokens based on whether they appear in the generated text so far. Values > 0 encourage the model to use new tokens, while values < 0 encourage the model to repeat tokens.'
|
||||
)
|
||||
frequency_penalty: float = dantic.Field(
|
||||
0.0,
|
||||
description='Float that penalizes new tokens based on their frequency in the generated text so far. Values > 0 encourage the model to use new tokens, while values < 0 encourage the model to repeat tokens.'
|
||||
description=
|
||||
'Float that penalizes new tokens based on their frequency in the generated text so far. Values > 0 encourage the model to use new tokens, while values < 0 encourage the model to repeat tokens.'
|
||||
)
|
||||
use_beam_search: bool = dantic.Field(False, description='Whether to use beam search instead of sampling.')
|
||||
stop: t.List[str] = dantic.Field(None, description='List of strings that stop the generation when they are generated. The returned output will not contain the stop strings.')
|
||||
ignore_eos: bool = dantic.Field(False, description='Whether to ignore the EOS token and continue generating tokens after the EOS token is generated.')
|
||||
stop: t.List[str] = dantic.Field(
|
||||
None,
|
||||
description=
|
||||
'List of strings that stop the generation when they are generated. The returned output will not contain the stop strings.'
|
||||
)
|
||||
ignore_eos: bool = dantic.Field(
|
||||
False,
|
||||
description='Whether to ignore the EOS token and continue generating tokens after the EOS token is generated.')
|
||||
logprobs: int = dantic.Field(None, description='Number of log probabilities to return per output token.')
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
@@ -402,7 +492,9 @@ class SamplingParams(ReprMixin):
|
||||
|
||||
def __init__(self, *, _internal: bool = False, **attrs: t.Any):
|
||||
if not _internal:
|
||||
raise RuntimeError("SamplingParams is not meant to be used directly, but you can access this via a LLMConfig.sampling_config or create one with 'SamplingParams.from_generation_config'")
|
||||
raise RuntimeError(
|
||||
"SamplingParams is not meant to be used directly, but you can access this via a LLMConfig.sampling_config or create one with 'SamplingParams.from_generation_config'"
|
||||
)
|
||||
_object_setattr(self, 'max_tokens', attrs.pop('max_tokens', 16))
|
||||
_object_setattr(self, 'temperature', attrs.pop('temperature', 1.0))
|
||||
_object_setattr(self, 'top_k', attrs.pop('top_k', -1))
|
||||
@@ -418,7 +510,11 @@ class SamplingParams(ReprMixin):
|
||||
return {i.name for i in attr.fields(self.__class__)}
|
||||
|
||||
def to_vllm(self) -> vllm.SamplingParams:
|
||||
return vllm.SamplingParams(max_tokens=self.max_tokens, temperature=self.temperature, top_k=self.top_k, top_p=self.top_p, **bentoml_cattr.unstructure(self))
|
||||
return vllm.SamplingParams(max_tokens=self.max_tokens,
|
||||
temperature=self.temperature,
|
||||
top_k=self.top_k,
|
||||
top_p=self.top_p,
|
||||
**bentoml_cattr.unstructure(self))
|
||||
|
||||
@classmethod
|
||||
def from_generation_config(cls, generation_config: GenerationConfig, **attrs: t.Any) -> Self:
|
||||
@@ -426,29 +522,30 @@ class SamplingParams(ReprMixin):
|
||||
stop = attrs.pop('stop', None)
|
||||
if stop is not None and isinstance(stop, str): stop = [stop]
|
||||
attrs['stop'] = stop
|
||||
if 'max_tokens' in attrs and 'max_new_tokens' in attrs: raise ValueError("Both 'max_tokens' and 'max_new_tokens' are passed. Make sure to only use one of them.")
|
||||
if 'max_tokens' in attrs and 'max_new_tokens' in attrs:
|
||||
raise ValueError("Both 'max_tokens' and 'max_new_tokens' are passed. Make sure to only use one of them.")
|
||||
temperature = first_not_none(attrs.pop('temperature', None), default=generation_config['temperature'])
|
||||
top_k = first_not_none(attrs.pop('top_k', None), default=generation_config['top_k'])
|
||||
top_p = first_not_none(attrs.pop('top_p', None), default=generation_config['top_p'])
|
||||
max_tokens = first_not_none(attrs.pop('max_tokens', None), attrs.pop('max_new_tokens', None), default=generation_config['max_new_tokens'])
|
||||
max_tokens = first_not_none(attrs.pop('max_tokens', None),
|
||||
attrs.pop('max_new_tokens', None),
|
||||
default=generation_config['max_new_tokens'])
|
||||
return cls(_internal=True, temperature=temperature, top_k=top_k, top_p=top_p, max_tokens=max_tokens, **attrs)
|
||||
|
||||
bentoml_cattr.register_unstructure_hook_factory(
|
||||
lambda cls: attr.has(cls) and lenient_issubclass(cls, SamplingParams),
|
||||
lambda cls: make_dict_unstructure_fn(
|
||||
cls,
|
||||
bentoml_cattr,
|
||||
_cattrs_omit_if_default=False,
|
||||
_cattrs_use_linecache=True,
|
||||
**{
|
||||
k: override(omit=True) for k, v in attr.fields_dict(cls).items() if v.default in (None, attr.NOTHING)
|
||||
}
|
||||
)
|
||||
)
|
||||
lambda cls: make_dict_unstructure_fn(cls,
|
||||
bentoml_cattr,
|
||||
_cattrs_omit_if_default=False,
|
||||
_cattrs_use_linecache=True,
|
||||
**{
|
||||
k: override(omit=True)
|
||||
for k, v in attr.fields_dict(cls).items()
|
||||
if v.default in (None, attr.NOTHING)
|
||||
}))
|
||||
bentoml_cattr.register_structure_hook_factory(
|
||||
lambda cls: attr.has(cls) and lenient_issubclass(cls, SamplingParams),
|
||||
lambda cls: make_dict_structure_fn(cls, bentoml_cattr, _cattrs_forbid_extra_keys=True, max_new_tokens=override(rename='max_tokens'))
|
||||
)
|
||||
lambda cls: attr.has(cls) and lenient_issubclass(cls, SamplingParams), lambda cls: make_dict_structure_fn(
|
||||
cls, bentoml_cattr, _cattrs_forbid_extra_keys=True, max_new_tokens=override(rename='max_tokens')))
|
||||
|
||||
# cached it here to save one lookup per assignment
|
||||
_object_getattribute = object.__getattribute__
|
||||
@@ -498,29 +595,27 @@ class ModelSettings(t.TypedDict, total=False):
|
||||
# tokenizer_class is the custom tokenizer class for this given LLM
|
||||
tokenizer_class: t.Optional[str]
|
||||
|
||||
_transformed_type: DictStrAny = {'fine_tune_strategies': t.Dict[AdapterType, FineTuneConfig], 'default_implementation': t.Dict[LiteralResourceSpec, LiteralRuntime]}
|
||||
_transformed_type: DictStrAny = {
|
||||
'fine_tune_strategies': t.Dict[AdapterType, FineTuneConfig],
|
||||
'default_implementation': t.Dict[LiteralResourceSpec, LiteralRuntime]
|
||||
}
|
||||
|
||||
@attr.define(
|
||||
frozen=False,
|
||||
slots=True,
|
||||
field_transformer=lambda _,
|
||||
__: [
|
||||
attr.Attribute.from_counting_attr(
|
||||
k,
|
||||
dantic.Field(
|
||||
kw_only=False if t.get_origin(ann) is not Required else True,
|
||||
auto_default=True,
|
||||
use_default_converter=False,
|
||||
type=_transformed_type.get(k, ann),
|
||||
metadata={'target': f'__openllm_{k}__'},
|
||||
description=f'ModelSettings field for {k}.'
|
||||
)
|
||||
) for k,
|
||||
ann in t.get_type_hints(ModelSettings).items()
|
||||
]
|
||||
)
|
||||
@attr.define(frozen=False,
|
||||
slots=True,
|
||||
field_transformer=lambda _, __: [
|
||||
attr.Attribute.from_counting_attr(
|
||||
k,
|
||||
dantic.Field(kw_only=False if t.get_origin(ann) is not Required else True,
|
||||
auto_default=True,
|
||||
use_default_converter=False,
|
||||
type=_transformed_type.get(k, ann),
|
||||
metadata={'target': f'__openllm_{k}__'},
|
||||
description=f'ModelSettings field for {k}.'))
|
||||
for k, ann in t.get_type_hints(ModelSettings).items()
|
||||
])
|
||||
class _ModelSettingsAttr:
|
||||
'''Internal attrs representation of ModelSettings.'''
|
||||
|
||||
def __getitem__(self, key: str) -> t.Any:
|
||||
if key in codegen.get_annotations(ModelSettings):
|
||||
return _object_getattribute(self, key)
|
||||
@@ -528,30 +623,26 @@ class _ModelSettingsAttr:
|
||||
|
||||
@classmethod
|
||||
def default(cls) -> _ModelSettingsAttr:
|
||||
return cls(
|
||||
**t.cast(
|
||||
DictStrAny,
|
||||
ModelSettings(
|
||||
default_id='__default__',
|
||||
model_ids=['__default__'],
|
||||
architecture='PreTrainedModel',
|
||||
default_implementation={
|
||||
'cpu': 'pt', 'nvidia.com/gpu': 'pt'
|
||||
},
|
||||
name_type='dasherize',
|
||||
requires_gpu=False,
|
||||
url='',
|
||||
model_type='causal_lm',
|
||||
trust_remote_code=False,
|
||||
requirements=None,
|
||||
tokenizer_class=None,
|
||||
timeout=int(36e6),
|
||||
service_name='',
|
||||
workers_per_resource=1.,
|
||||
runtime='transformers'
|
||||
)
|
||||
)
|
||||
)
|
||||
return cls(**t.cast(
|
||||
DictStrAny,
|
||||
ModelSettings(default_id='__default__',
|
||||
model_ids=['__default__'],
|
||||
architecture='PreTrainedModel',
|
||||
default_implementation={
|
||||
'cpu': 'pt',
|
||||
'nvidia.com/gpu': 'pt'
|
||||
},
|
||||
name_type='dasherize',
|
||||
requires_gpu=False,
|
||||
url='',
|
||||
model_type='causal_lm',
|
||||
trust_remote_code=False,
|
||||
requirements=None,
|
||||
tokenizer_class=None,
|
||||
timeout=int(36e6),
|
||||
service_name='',
|
||||
workers_per_resource=1.,
|
||||
runtime='transformers')))
|
||||
|
||||
# NOTE: The below are dynamically generated by the field_transformer
|
||||
if t.TYPE_CHECKING:
|
||||
@@ -579,19 +670,25 @@ class _ModelSettingsAttr:
|
||||
# update-config-stubs.py: attrs stop
|
||||
|
||||
# a heuristic cascading implementation resolver based on available resources
|
||||
def get_default_implementation(default_implementation_mapping: dict[LiteralResourceSpec, LiteralRuntime]) -> LiteralRuntime:
|
||||
def get_default_implementation(
|
||||
default_implementation_mapping: dict[LiteralResourceSpec, LiteralRuntime]) -> LiteralRuntime:
|
||||
available_spec = available_resource_spec()
|
||||
if resource_spec('tpu') in available_spec: return default_implementation_mapping.get(resource_spec('tpu'), 'pt')
|
||||
elif resource_spec('amd') in available_spec: return default_implementation_mapping.get(resource_spec('amd'), 'pt')
|
||||
elif resource_spec('nvidia') in available_spec: return default_implementation_mapping.get(resource_spec('nvidia'), 'pt')
|
||||
else: return default_implementation_mapping.get(resource_spec('cpu'), 'pt')
|
||||
elif resource_spec('nvidia') in available_spec:
|
||||
return default_implementation_mapping.get(resource_spec('nvidia'), 'pt')
|
||||
else:
|
||||
return default_implementation_mapping.get(resource_spec('cpu'), 'pt')
|
||||
|
||||
def structure_settings(cl_: type[LLMConfig], cls: type[_ModelSettingsAttr]) -> _ModelSettingsAttr:
|
||||
if 'generation_class' in cl_.__config__:
|
||||
raise ValueError(f"'generation_class' shouldn't be defined in '__config__', rather defining all required attributes under '{cl_}.GenerationConfig' instead.")
|
||||
raise ValueError(
|
||||
f"'generation_class' shouldn't be defined in '__config__', rather defining all required attributes under '{cl_}.GenerationConfig' instead."
|
||||
)
|
||||
|
||||
required_fields = {k for k, ann in t.get_type_hints(ModelSettings).items() if t.get_origin(ann) is Required}
|
||||
if any(i not in cl_.__config__ for i in required_fields): raise ValueError(f"Missing required fields {required_fields} '__config__'.")
|
||||
if any(i not in cl_.__config__ for i in required_fields):
|
||||
raise ValueError(f"Missing required fields {required_fields} '__config__'.")
|
||||
_cl_name = cl_.__name__.replace('Config', '')
|
||||
_settings_attr = cls.default()
|
||||
has_custom_name = all(i in cl_.__config__ for i in {'model_name', 'start_name'})
|
||||
@@ -599,8 +696,11 @@ def structure_settings(cl_: type[LLMConfig], cls: type[_ModelSettingsAttr]) -> _
|
||||
_final_value_dct: DictStrAny = {}
|
||||
|
||||
if not has_custom_name:
|
||||
_final_value_dct['model_name'] = inflection.underscore(_cl_name) if _settings_attr['name_type'] == 'dasherize' else _cl_name.lower()
|
||||
_final_value_dct['start_name'] = inflection.dasherize(_final_value_dct['model_name']) if _settings_attr['name_type'] == 'dasherize' else _final_value_dct['model_name']
|
||||
_final_value_dct['model_name'] = inflection.underscore(
|
||||
_cl_name) if _settings_attr['name_type'] == 'dasherize' else _cl_name.lower()
|
||||
_final_value_dct['start_name'] = inflection.dasherize(
|
||||
_final_value_dct['model_name']
|
||||
) if _settings_attr['name_type'] == 'dasherize' else _final_value_dct['model_name']
|
||||
|
||||
model_name = _final_value_dct['model_name'] if 'model_name' in _final_value_dct else _settings_attr.model_name
|
||||
# if the default implementation dependencies doesn't exist, then always fallback to 'pt'
|
||||
@@ -610,11 +710,15 @@ def structure_settings(cl_: type[LLMConfig], cls: type[_ModelSettingsAttr]) -> _
|
||||
if not BACKENDS_MAPPING[library_stub][0](): default_implementation[rs] = 'pt'
|
||||
_final_value_dct['default_implementation'] = default_implementation
|
||||
|
||||
env = openllm_core.utils.EnvVarMixin(model_name, get_default_implementation(default_implementation), model_id=_settings_attr.default_id, bettertransformer=_settings_attr.bettertransformer)
|
||||
env = openllm_core.utils.EnvVarMixin(model_name,
|
||||
get_default_implementation(default_implementation),
|
||||
model_id=_settings_attr.default_id,
|
||||
bettertransformer=_settings_attr.bettertransformer)
|
||||
_final_value_dct['env'] = env
|
||||
|
||||
# bettertransformer support
|
||||
if _settings_attr['bettertransformer'] is None: _final_value_dct['bettertransformer'] = str(env['bettertransformer_value']).upper() in ENV_VARS_TRUE_VALUES
|
||||
if _settings_attr['bettertransformer'] is None:
|
||||
_final_value_dct['bettertransformer'] = str(env['bettertransformer_value']).upper() in ENV_VARS_TRUE_VALUES
|
||||
# if requires_gpu is True, then disable BetterTransformer for quantization.
|
||||
if _settings_attr['requires_gpu']: _final_value_dct['bettertransformer'] = False
|
||||
_final_value_dct['service_name'] = f'generated_{model_name}_service.py'
|
||||
@@ -626,7 +730,8 @@ def structure_settings(cl_: type[LLMConfig], cls: type[_ModelSettingsAttr]) -> _
|
||||
# the given value is a tuple[dict[str, t.Any] ,...]
|
||||
for _possible_ft_config in _fine_tune_strategies:
|
||||
_adapter_type: AdapterType | None = _possible_ft_config.pop('adapter_type', None)
|
||||
if _adapter_type is None: raise RuntimeError("'adapter_type' is required under config definition (currently missing)'.")
|
||||
if _adapter_type is None:
|
||||
raise RuntimeError("'adapter_type' is required under config definition (currently missing)'.")
|
||||
_llm_config_class = _possible_ft_config.pop('llm_config_class', cl_)
|
||||
_converted[_adapter_type] = FineTuneConfig(PeftType[_adapter_type], _possible_ft_config, False, _llm_config_class)
|
||||
_final_value_dct['fine_tune_strategies'] = _converted
|
||||
@@ -637,10 +742,16 @@ bentoml_cattr.register_structure_hook(_ModelSettingsAttr, structure_settings)
|
||||
def _setattr_class(attr_name: str, value_var: t.Any) -> str:
|
||||
return f"setattr(cls, '{attr_name}', {value_var})"
|
||||
|
||||
def _make_assignment_script(cls: type[LLMConfig], attributes: attr.AttrsInstance, _prefix: LiteralString = 'openllm') -> t.Callable[..., None]:
|
||||
def _make_assignment_script(cls: type[LLMConfig],
|
||||
attributes: attr.AttrsInstance,
|
||||
_prefix: LiteralString = 'openllm') -> t.Callable[..., None]:
|
||||
'''Generate the assignment script with prefix attributes __openllm_<value>__.'''
|
||||
args: ListStr = []
|
||||
globs: DictStrAny = {'cls': cls, '_cached_attribute': attributes, '_cached_getattribute_get': _object_getattribute.__get__}
|
||||
globs: DictStrAny = {
|
||||
'cls': cls,
|
||||
'_cached_attribute': attributes,
|
||||
'_cached_getattribute_get': _object_getattribute.__get__
|
||||
}
|
||||
annotations: DictStrAny = {'return': None}
|
||||
|
||||
lines: ListStr = []
|
||||
@@ -650,12 +761,18 @@ def _make_assignment_script(cls: type[LLMConfig], attributes: attr.AttrsInstance
|
||||
lines.append(_setattr_class(arg_name, attr_name))
|
||||
annotations[attr_name] = field.type
|
||||
|
||||
return codegen.generate_function(cls, '__assign_attr', lines, args=('cls', *args), globs=globs, annotations=annotations)
|
||||
return codegen.generate_function(cls,
|
||||
'__assign_attr',
|
||||
lines,
|
||||
args=('cls', *args),
|
||||
globs=globs,
|
||||
annotations=annotations)
|
||||
|
||||
_reserved_namespace = {'__config__', 'GenerationConfig', 'SamplingParams'}
|
||||
|
||||
@attr.define(slots=True)
|
||||
class _ConfigAttr:
|
||||
|
||||
@staticmethod
|
||||
def Field(default: t.Any = None, **attrs: t.Any) -> t.Any:
|
||||
return dantic.Field(default, **attrs)
|
||||
@@ -721,6 +838,7 @@ class _ConfigAttr:
|
||||
'''The result generated SamplingParams class for this LLMConfig. This will be used
|
||||
to create arguments for vLLM LLMEngine that can be used throughout the lifecycle.
|
||||
This class will also be managed internally by OpenLLM.'''
|
||||
|
||||
def __attrs_init__(self, *args: t.Any, **attrs: t.Any) -> None:
|
||||
'''Generated __attrs_init__ for LLMConfig subclass that follows the attrs contract.'''
|
||||
|
||||
@@ -805,7 +923,6 @@ class _ConfigAttr:
|
||||
'''The fine-tune strategies for this given LLM.'''
|
||||
__openllm_tokenizer_class__: t.Optional[str] = Field(None)
|
||||
'''Optional tokenizer class for this given LLM. See Llama for example.'''
|
||||
|
||||
# update-config-stubs.py: special stop
|
||||
|
||||
class _ConfigBuilder:
|
||||
@@ -823,11 +940,24 @@ class _ConfigBuilder:
|
||||
It takes `these` arguments as a fully parsed attr.Attribute[t.Any] from __init_subclass__
|
||||
"""
|
||||
|
||||
__slots__ = ('_cls', '_cls_dict', '_attr_names', '_attrs', '_model_name', '_base_attr_map', '_base_names', '_has_pre_init', '_has_post_init')
|
||||
__slots__ = ('_cls', '_cls_dict', '_attr_names', '_attrs', '_model_name', '_base_attr_map', '_base_names',
|
||||
'_has_pre_init', '_has_post_init')
|
||||
|
||||
def __init__(self, cls: type[LLMConfig], these: dict[str, _CountingAttr], auto_attribs: bool = False, kw_only: bool = False, collect_by_mro: bool = True):
|
||||
attrs, base_attrs, base_attr_map = _transform_attrs(cls, these, auto_attribs, kw_only, collect_by_mro, field_transformer=codegen.make_env_transformer(cls, cls.__openllm_model_name__))
|
||||
self._cls, self._model_name, self._cls_dict, self._attrs, self._base_names, self._base_attr_map = cls, cls.__openllm_model_name__, dict(cls.__dict__), attrs, {a.name for a in base_attrs}, base_attr_map
|
||||
def __init__(self,
|
||||
cls: type[LLMConfig],
|
||||
these: dict[str, _CountingAttr],
|
||||
auto_attribs: bool = False,
|
||||
kw_only: bool = False,
|
||||
collect_by_mro: bool = True):
|
||||
attrs, base_attrs, base_attr_map = _transform_attrs(cls,
|
||||
these,
|
||||
auto_attribs,
|
||||
kw_only,
|
||||
collect_by_mro,
|
||||
field_transformer=codegen.make_env_transformer(
|
||||
cls, cls.__openllm_model_name__))
|
||||
self._cls, self._model_name, self._cls_dict, self._attrs, self._base_names, self._base_attr_map = cls, cls.__openllm_model_name__, dict(
|
||||
cls.__dict__), attrs, {a.name for a in base_attrs}, base_attr_map
|
||||
self._attr_names = tuple(a.name for a in attrs)
|
||||
self._has_pre_init = bool(getattr(cls, '__attrs_pre_init__', False))
|
||||
self._has_post_init = bool(getattr(cls, '__attrs_post_init__', False))
|
||||
@@ -850,11 +980,14 @@ class _ConfigBuilder:
|
||||
existing_slots: DictStrAny = {}
|
||||
for base_cls in self._cls.__mro__[1:-1]:
|
||||
if base_cls.__dict__.get('__weakref__', None) is not None: weakref_inherited = True
|
||||
existing_slots.update({name: getattr(base_cls, name, codegen._sentinel) for name in getattr(base_cls, '__slots__', [])})
|
||||
existing_slots.update(
|
||||
{name: getattr(base_cls, name, codegen._sentinel) for name in getattr(base_cls, '__slots__', [])})
|
||||
|
||||
names = self._attr_names
|
||||
base_names = set(self._base_names)
|
||||
if '__weakref__' not in getattr(self._cls, '__slots__', ()) and '__weakref__' not in names and not weakref_inherited: names += ('__weakref__',)
|
||||
if '__weakref__' not in getattr(self._cls, '__slots__',
|
||||
()) and '__weakref__' not in names and not weakref_inherited:
|
||||
names += ('__weakref__',)
|
||||
# We only add the names of attributes that aren't inherited.
|
||||
# Setting __slots__ to inherited attributes wastes memory.
|
||||
slot_names = [name for name in names if name not in base_names]
|
||||
@@ -911,14 +1044,17 @@ class _ConfigBuilder:
|
||||
|
||||
def add_attrs_init(self) -> Self:
|
||||
self._cls_dict['__attrs_init__'] = codegen.add_method_dunders(
|
||||
self._cls, _make_init(self._cls, self._attrs, self._has_pre_init, self._has_post_init, False, True, False, self._base_attr_map, False, None, True)
|
||||
)
|
||||
self._cls,
|
||||
_make_init(self._cls, self._attrs, self._has_pre_init, self._has_post_init, False, True, False,
|
||||
self._base_attr_map, False, None, True))
|
||||
return self
|
||||
|
||||
def add_repr(self) -> Self:
|
||||
for key, fn in ReprMixin.__dict__.items():
|
||||
if key in ('__repr__', '__str__', '__repr_name__', '__repr_str__', '__repr_args__'): self._cls_dict[key] = codegen.add_method_dunders(self._cls, fn)
|
||||
self._cls_dict['__repr_keys__'] = property(lambda _: {i.name for i in self._attrs} | {'generation_config', 'sampling_config'})
|
||||
if key in ('__repr__', '__str__', '__repr_name__', '__repr_str__', '__repr_args__'):
|
||||
self._cls_dict[key] = codegen.add_method_dunders(self._cls, fn)
|
||||
self._cls_dict['__repr_keys__'] = property(
|
||||
lambda _: {i.name for i in self._attrs} | {'generation_config', 'sampling_config'})
|
||||
return self
|
||||
|
||||
@attr.define(slots=True, init=False)
|
||||
@@ -1011,6 +1147,7 @@ class LLMConfig(_ConfigAttr):
|
||||
Future work:
|
||||
- Support pydantic-core as validation backend.
|
||||
"""
|
||||
|
||||
def __init_subclass__(cls, **_: t.Any):
|
||||
"""The purpose of this ``__init_subclass__`` is to offer pydantic UX while adhering to attrs contract.
|
||||
|
||||
@@ -1024,31 +1161,33 @@ class LLMConfig(_ConfigAttr):
|
||||
logger.warning("LLMConfig subclass should end with 'Config'. Updating to %sConfig", cls.__name__)
|
||||
cls.__name__ = f'{cls.__name__}Config'
|
||||
|
||||
if not hasattr(cls, '__config__'): raise RuntimeError("Given LLMConfig must have '__config__' that is not None defined.")
|
||||
if not hasattr(cls, '__config__'):
|
||||
raise RuntimeError("Given LLMConfig must have '__config__' that is not None defined.")
|
||||
|
||||
# auto assignment attributes generated from __config__ after create the new slot class.
|
||||
_make_assignment_script(cls, bentoml_cattr.structure(cls, _ModelSettingsAttr))(cls)
|
||||
|
||||
def _make_subclass(class_attr: str, base: type[At], globs: dict[str, t.Any] | None = None, suffix_env: LiteralString | None = None) -> type[At]:
|
||||
def _make_subclass(class_attr: str,
|
||||
base: type[At],
|
||||
globs: dict[str, t.Any] | None = None,
|
||||
suffix_env: LiteralString | None = None) -> type[At]:
|
||||
camel_name = cls.__name__.replace('Config', '')
|
||||
klass = attr.make_class(
|
||||
f'{camel_name}{class_attr}', [],
|
||||
bases=(base,),
|
||||
slots=True,
|
||||
weakref_slot=True,
|
||||
frozen=True,
|
||||
repr=False,
|
||||
init=False,
|
||||
collect_by_mro=True,
|
||||
field_transformer=codegen.make_env_transformer(
|
||||
cls,
|
||||
cls.__openllm_model_name__,
|
||||
suffix=suffix_env,
|
||||
globs=globs,
|
||||
default_callback=lambda field_name,
|
||||
field_default: getattr(getattr(cls, class_attr), field_name, field_default) if codegen.has_own_attribute(cls, class_attr) else field_default
|
||||
)
|
||||
)
|
||||
klass = attr.make_class(f'{camel_name}{class_attr}', [],
|
||||
bases=(base,),
|
||||
slots=True,
|
||||
weakref_slot=True,
|
||||
frozen=True,
|
||||
repr=False,
|
||||
init=False,
|
||||
collect_by_mro=True,
|
||||
field_transformer=codegen.make_env_transformer(
|
||||
cls,
|
||||
cls.__openllm_model_name__,
|
||||
suffix=suffix_env,
|
||||
globs=globs,
|
||||
default_callback=lambda field_name, field_default: getattr(
|
||||
getattr(cls, class_attr), field_name, field_default)
|
||||
if codegen.has_own_attribute(cls, class_attr) else field_default))
|
||||
# For pickling to work, the __module__ variable needs to be set to the
|
||||
# frame where the class is created. This respect the module that is created from cls
|
||||
try:
|
||||
@@ -1079,10 +1218,13 @@ class LLMConfig(_ConfigAttr):
|
||||
unannotated = ca_names - annotated_names
|
||||
if len(unannotated) > 0:
|
||||
missing_annotated = sorted(unannotated, key=lambda n: t.cast('_CountingAttr', cd.get(n)).counter)
|
||||
raise openllm_core.exceptions.MissingAnnotationAttributeError(f"The following field doesn't have a type annotation: {missing_annotated}")
|
||||
raise openllm_core.exceptions.MissingAnnotationAttributeError(
|
||||
f"The following field doesn't have a type annotation: {missing_annotated}")
|
||||
# We need to set the accepted key before generation_config
|
||||
# as generation_config is a special field that users shouldn't pass.
|
||||
cls.__openllm_accepted_keys__ = set(these.keys()) | {a.name for a in attr.fields(cls.__openllm_generation_class__)} | {a.name for a in attr.fields(cls.__openllm_sampling_class__)}
|
||||
cls.__openllm_accepted_keys__ = set(these.keys()) | {
|
||||
a.name for a in attr.fields(cls.__openllm_generation_class__)
|
||||
} | {a.name for a in attr.fields(cls.__openllm_sampling_class__)}
|
||||
cls = _ConfigBuilder(cls, these).add_attrs_init().add_repr().build_class()
|
||||
|
||||
# Finally, resolve the types
|
||||
@@ -1094,7 +1236,12 @@ class LLMConfig(_ConfigAttr):
|
||||
attr.resolve_types(cls.__openllm_sampling_class__, globalns=globs)
|
||||
cls = attr.resolve_types(cls, globalns=globs)
|
||||
# the hint cache for easier access
|
||||
cls.__openllm_hints__ = {f.name: f.type for ite in [attr.fields(cls), attr.fields(cls.__openllm_generation_class__), attr.fields(cls.__openllm_sampling_class__),] for f in ite}
|
||||
cls.__openllm_hints__ = {
|
||||
f.name: f.type for ite in
|
||||
[attr.fields(cls),
|
||||
attr.fields(cls.__openllm_generation_class__),
|
||||
attr.fields(cls.__openllm_sampling_class__),] for f in ite
|
||||
}
|
||||
|
||||
# for pickling to work, need to set the module to the correct outer frame
|
||||
try:
|
||||
@@ -1109,18 +1256,27 @@ class LLMConfig(_ConfigAttr):
|
||||
)
|
||||
super().__setattr__(attr, value)
|
||||
|
||||
def __init__(self, *, generation_config: DictStrAny | None = None, __openllm_extras__: DictStrAny | None = None, **attrs: t.Any):
|
||||
def __init__(self,
|
||||
*,
|
||||
generation_config: DictStrAny | None = None,
|
||||
__openllm_extras__: DictStrAny | None = None,
|
||||
**attrs: t.Any):
|
||||
# create a copy of the keys as cache
|
||||
_cached_keys = tuple(attrs.keys())
|
||||
_generation_cl_dict = attr.fields_dict(self.__openllm_generation_class__)
|
||||
if generation_config is None: generation_config = {k: v for k, v in attrs.items() if k in _generation_cl_dict}
|
||||
else: generation_config = config_merger.merge(generation_config, {k: v for k, v in attrs.items() if k in _generation_cl_dict})
|
||||
else:
|
||||
generation_config = config_merger.merge(generation_config, {
|
||||
k: v for k, v in attrs.items() if k in _generation_cl_dict
|
||||
})
|
||||
|
||||
sampling_config = {k: v for k, v in attrs.items() if k in attr.fields_dict(self.__openllm_sampling_class__)}
|
||||
for k in _cached_keys:
|
||||
if k in generation_config or k in sampling_config or attrs[k] is None: del attrs[k]
|
||||
|
||||
self.__openllm_extras__ = config_merger.merge(first_not_none(__openllm_extras__, default={}), {k: v for k, v in attrs.items() if k not in self.__openllm_accepted_keys__})
|
||||
self.__openllm_extras__ = config_merger.merge(first_not_none(__openllm_extras__, default={}), {
|
||||
k: v for k, v in attrs.items() if k not in self.__openllm_accepted_keys__
|
||||
})
|
||||
self.generation_config = self['generation_class'](_internal=True, **generation_config)
|
||||
self.sampling_config = self['sampling_class'].from_generation_config(self.generation_config, **sampling_config)
|
||||
|
||||
@@ -1302,18 +1458,25 @@ class LLMConfig(_ConfigAttr):
|
||||
"""
|
||||
if item is None: raise TypeError(f"{self} doesn't understand how to index None.")
|
||||
item = inflection.underscore(item)
|
||||
if item in _reserved_namespace: raise ForbiddenAttributeError(f"'{item}' is a reserved namespace for {self.__class__} and should not be access nor modified.")
|
||||
if item in _reserved_namespace:
|
||||
raise ForbiddenAttributeError(
|
||||
f"'{item}' is a reserved namespace for {self.__class__} and should not be access nor modified.")
|
||||
internal_attributes = f'__openllm_{item}__'
|
||||
if hasattr(self, internal_attributes): return getattr(self, internal_attributes)
|
||||
elif hasattr(self, item): return getattr(self, item)
|
||||
elif hasattr(self.__openllm_generation_class__, item): return getattr(self.generation_config, item)
|
||||
elif hasattr(self.__openllm_sampling_class__, item): return getattr(self.sampling_config, item)
|
||||
elif item in self.__class__.__openllm_fine_tune_strategies__: return self.__class__.__openllm_fine_tune_strategies__[t.cast(AdapterType, item)]
|
||||
elif item in self.__openllm_extras__: return self.__openllm_extras__[item]
|
||||
else: raise KeyError(item)
|
||||
elif item in self.__class__.__openllm_fine_tune_strategies__:
|
||||
return self.__class__.__openllm_fine_tune_strategies__[t.cast(AdapterType, item)]
|
||||
elif item in self.__openllm_extras__:
|
||||
return self.__openllm_extras__[item]
|
||||
else:
|
||||
raise KeyError(item)
|
||||
|
||||
def __getattribute__(self, item: str) -> t.Any:
|
||||
if item in _reserved_namespace: raise ForbiddenAttributeError(f"'{item}' belongs to a private namespace for {self.__class__} and should not be access nor modified.")
|
||||
if item in _reserved_namespace:
|
||||
raise ForbiddenAttributeError(
|
||||
f"'{item}' belongs to a private namespace for {self.__class__} and should not be access nor modified.")
|
||||
return _object_getattribute.__get__(self)(item)
|
||||
|
||||
def __len__(self) -> int:
|
||||
@@ -1323,13 +1486,16 @@ class LLMConfig(_ConfigAttr):
|
||||
return list(self.__openllm_accepted_keys__) + list(self.__openllm_extras__)
|
||||
|
||||
def values(self) -> list[t.Any]:
|
||||
return ([getattr(self, k.name) for k in attr.fields(self.__class__)] + [getattr(self.generation_config, k.name) for k in attr.fields(self.__openllm_generation_class__)] + [
|
||||
getattr(self.sampling_config, k.name) for k in attr.fields(self.__openllm_sampling_class__)
|
||||
] + list(self.__openllm_extras__.values()))
|
||||
return ([getattr(self, k.name) for k in attr.fields(self.__class__)] +
|
||||
[getattr(self.generation_config, k.name) for k in attr.fields(self.__openllm_generation_class__)] +
|
||||
[getattr(self.sampling_config, k.name) for k in attr.fields(self.__openllm_sampling_class__)] +
|
||||
list(self.__openllm_extras__.values()))
|
||||
|
||||
def items(self) -> list[tuple[str, t.Any]]:
|
||||
return ([(k.name, getattr(self, k.name)) for k in attr.fields(self.__class__)] + [(k.name, getattr(self.generation_config, k.name)) for k in attr.fields(self.__openllm_generation_class__)]
|
||||
+ [(k.name, getattr(self.sampling_config, k.name)) for k in attr.fields(self.__openllm_sampling_class__)] + list(self.__openllm_extras__.items()))
|
||||
return ([(k.name, getattr(self, k.name)) for k in attr.fields(self.__class__)] + [
|
||||
(k.name, getattr(self.generation_config, k.name)) for k in attr.fields(self.__openllm_generation_class__)
|
||||
] + [(k.name, getattr(self.sampling_config, k.name)) for k in attr.fields(self.__openllm_sampling_class__)] +
|
||||
list(self.__openllm_extras__.items()))
|
||||
|
||||
def __iter__(self) -> t.Iterator[str]:
|
||||
return iter(self.keys())
|
||||
@@ -1361,11 +1527,10 @@ class LLMConfig(_ConfigAttr):
|
||||
_new_cfg = {k: v for k, v in attrs.items() if k in attr.fields_dict(_ModelSettingsAttr)}
|
||||
attrs = {k: v for k, v in attrs.items() if k not in _new_cfg}
|
||||
new_cls = types.new_class(
|
||||
name or f"{cls.__name__.replace('Config', '')}DerivateConfig", (cls,), {},
|
||||
lambda ns: ns.update({
|
||||
'__config__': config_merger.merge(copy.deepcopy(cls.__dict__['__config__']), _new_cfg), '__base_config__': cls, # keep a reference for easy access
|
||||
})
|
||||
)
|
||||
name or f"{cls.__name__.replace('Config', '')}DerivateConfig", (cls,), {}, lambda ns: ns.update({
|
||||
'__config__': config_merger.merge(copy.deepcopy(cls.__dict__['__config__']), _new_cfg),
|
||||
'__base_config__': cls, # keep a reference for easy access
|
||||
}))
|
||||
|
||||
# For pickling to work, the __module__ variable needs to be set to the
|
||||
# frame where the class is created. Bypass this step in environments where
|
||||
@@ -1413,8 +1578,10 @@ class LLMConfig(_ConfigAttr):
|
||||
|
||||
if 'generation_config' in attrs:
|
||||
generation_config = attrs.pop('generation_config')
|
||||
if not isinstance(generation_config, dict): raise RuntimeError(f'Expected a dictionary, but got {type(generation_config)}')
|
||||
else: generation_config = {k: v for k, v in attrs.items() if k in attr.fields_dict(cls.__openllm_generation_class__)}
|
||||
if not isinstance(generation_config, dict):
|
||||
raise RuntimeError(f'Expected a dictionary, but got {type(generation_config)}')
|
||||
else:
|
||||
generation_config = {k: v for k, v in attrs.items() if k in attr.fields_dict(cls.__openllm_generation_class__)}
|
||||
|
||||
for k in tuple(attrs.keys()):
|
||||
if k in generation_config: del attrs[k]
|
||||
@@ -1477,7 +1644,8 @@ class LLMConfig(_ConfigAttr):
|
||||
f = dantic.attrs_to_options(name, field, cls.__openllm_model_name__, typ=ty, suffix_sampling=True)(f)
|
||||
f = cog.optgroup.group(f'{cls.__openllm_sampling_class__.__name__} sampling options')(f)
|
||||
|
||||
total_keys = set(attr.fields_dict(cls.__openllm_generation_class__)) | set(attr.fields_dict(cls.__openllm_sampling_class__))
|
||||
total_keys = set(attr.fields_dict(cls.__openllm_generation_class__)) | set(
|
||||
attr.fields_dict(cls.__openllm_sampling_class__))
|
||||
|
||||
if len(cls.__openllm_accepted_keys__.difference(total_keys)) == 0: return t.cast('click.Command', f)
|
||||
# We pop out 'generation_config' as it is a attribute that we don't need to expose to CLI.
|
||||
@@ -1496,7 +1664,8 @@ class LLMConfig(_ConfigAttr):
|
||||
|
||||
@classmethod
|
||||
def default_implementation(cls) -> LiteralRuntime:
|
||||
return first_not_none(cls.__openllm_env__['framework_value'], default=get_default_implementation(cls.__openllm_default_implementation__))
|
||||
return first_not_none(cls.__openllm_env__['framework_value'],
|
||||
default=get_default_implementation(cls.__openllm_default_implementation__))
|
||||
|
||||
def sanitize_parameters(self, prompt: str, **attrs: t.Any) -> tuple[str, DictStrAny, DictStrAny]:
|
||||
'''This handler will sanitize all attrs and setup prompt text.
|
||||
@@ -1524,8 +1693,8 @@ class LLMConfig(_ConfigAttr):
|
||||
return generation_result
|
||||
|
||||
bentoml_cattr.register_unstructure_hook_factory(
|
||||
lambda cls: lenient_issubclass(cls, LLMConfig), lambda cls: make_dict_unstructure_fn(cls, bentoml_cattr, _cattrs_omit_if_default=False, _cattrs_use_linecache=True)
|
||||
)
|
||||
lambda cls: lenient_issubclass(cls, LLMConfig),
|
||||
lambda cls: make_dict_unstructure_fn(cls, bentoml_cattr, _cattrs_omit_if_default=False, _cattrs_use_linecache=True))
|
||||
|
||||
def structure_llm_config(data: t.Any, cls: type[LLMConfig]) -> LLMConfig:
|
||||
"""Structure a dictionary to a LLMConfig object.
|
||||
@@ -1541,7 +1710,8 @@ def structure_llm_config(data: t.Any, cls: type[LLMConfig]) -> LLMConfig:
|
||||
generation_cls_fields = attr.fields_dict(cls.__openllm_generation_class__)
|
||||
if 'generation_config' in data:
|
||||
generation_config = data.pop('generation_config')
|
||||
if not isinstance(generation_config, dict): raise RuntimeError(f'Expected a dictionary, but got {type(generation_config)}')
|
||||
if not isinstance(generation_config, dict):
|
||||
raise RuntimeError(f'Expected a dictionary, but got {type(generation_config)}')
|
||||
config_merger.merge(generation_config, {k: v for k, v in data.items() if k in generation_cls_fields})
|
||||
else:
|
||||
generation_config = {k: v for k, v in data.items() if k in generation_cls_fields}
|
||||
@@ -1550,4 +1720,7 @@ def structure_llm_config(data: t.Any, cls: type[LLMConfig]) -> LLMConfig:
|
||||
return cls(generation_config=generation_config, __openllm_extras__=data, **cls_attrs)
|
||||
|
||||
bentoml_cattr.register_structure_hook_func(lambda cls: lenient_issubclass(cls, LLMConfig), structure_llm_config)
|
||||
openllm_home = os.path.expanduser(os.environ.get('OPENLLM_HOME', os.path.join(os.environ.get('XDG_CACHE_HOME', os.path.join(os.path.expanduser('~'), '.cache')), 'openllm')))
|
||||
openllm_home = os.path.expanduser(
|
||||
os.environ.get(
|
||||
'OPENLLM_HOME',
|
||||
os.path.join(os.environ.get('XDG_CACHE_HOME', os.path.join(os.path.expanduser('~'), '.cache')), 'openllm')))
|
||||
|
||||
@@ -4,11 +4,13 @@ import typing as t
|
||||
|
||||
class PromptFormatter(string.Formatter):
|
||||
"""This PromptFormatter is largely based on langchain's implementation."""
|
||||
|
||||
def vformat(self, format_string: str, args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]) -> t.Any:
|
||||
if len(args) > 0: raise ValueError('Positional arguments are not supported')
|
||||
return super().vformat(format_string, args, kwargs)
|
||||
|
||||
def check_unused_args(self, used_args: set[int | str], args: t.Sequence[t.Any], kwargs: t.Mapping[str, t.Any]) -> None:
|
||||
def check_unused_args(self, used_args: set[int | str], args: t.Sequence[t.Any], kwargs: t.Mapping[str,
|
||||
t.Any]) -> None:
|
||||
extras = set(kwargs).difference(used_args)
|
||||
if extras: raise KeyError(f'Extra params passed: {extras}')
|
||||
|
||||
@@ -23,7 +25,9 @@ def process_prompt(prompt: str, template: str | None = None, use_prompt_template
|
||||
elif template is None: raise ValueError("'template' can't be None while 'use_prompt_template=False'")
|
||||
template_variables = default_formatter.extract_template_variables(template)
|
||||
prompt_variables = {k: v for k, v in attrs.items() if k in template_variables}
|
||||
if 'instruction' in prompt_variables: raise RuntimeError("'instruction' should be passed as the first argument instead of kwargs when 'use_prompt_template=True'")
|
||||
if 'instruction' in prompt_variables:
|
||||
raise RuntimeError(
|
||||
"'instruction' should be passed as the first argument instead of kwargs when 'use_prompt_template=True'")
|
||||
try:
|
||||
return template.format(instruction=prompt, **prompt_variables)
|
||||
except KeyError as e:
|
||||
|
||||
@@ -21,7 +21,11 @@ class GenerationInput:
|
||||
adapter_name: str | None = attr.field(default=None)
|
||||
|
||||
def model_dump(self) -> dict[str, t.Any]:
|
||||
return {'prompt': self.prompt, 'llm_config': self.llm_config.model_dump(flatten=True), 'adapter_name': self.adapter_name}
|
||||
return {
|
||||
'prompt': self.prompt,
|
||||
'llm_config': self.llm_config.model_dump(flatten=True),
|
||||
'adapter_name': self.adapter_name
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def convert_llm_config(data: dict[str, t.Any] | LLMConfig, cls: type[LLMConfig] | None = None) -> LLMConfig:
|
||||
@@ -37,14 +41,18 @@ class GenerationInput:
|
||||
|
||||
@classmethod
|
||||
def from_llm_config(cls, llm_config: LLMConfig) -> type[GenerationInput]:
|
||||
return attr.make_class(
|
||||
inflection.camelize(llm_config['model_name']) + 'GenerationInput',
|
||||
attrs={
|
||||
'prompt': attr.field(type=str),
|
||||
'llm_config': attr.field(type=llm_config.__class__, default=llm_config, converter=functools.partial(cls.convert_llm_config, cls=llm_config.__class__)),
|
||||
'adapter_name': attr.field(default=None, type=str)
|
||||
}
|
||||
)
|
||||
return attr.make_class(inflection.camelize(llm_config['model_name']) + 'GenerationInput',
|
||||
attrs={
|
||||
'prompt':
|
||||
attr.field(type=str),
|
||||
'llm_config':
|
||||
attr.field(type=llm_config.__class__,
|
||||
default=llm_config,
|
||||
converter=functools.partial(cls.convert_llm_config,
|
||||
cls=llm_config.__class__)),
|
||||
'adapter_name':
|
||||
attr.field(default=None, type=str)
|
||||
})
|
||||
|
||||
@attr.frozen(slots=True)
|
||||
class GenerationOutput:
|
||||
@@ -80,16 +88,18 @@ class EmbeddingsOutput:
|
||||
num_tokens: int
|
||||
|
||||
def unmarshal_vllm_outputs(request_output: vllm.RequestOutput) -> dict[str, t.Any]:
|
||||
return dict(
|
||||
request_id=request_output.request_id,
|
||||
prompt=request_output.prompt,
|
||||
finished=request_output.finished,
|
||||
prompt_token_ids=request_output.prompt_token_ids,
|
||||
outputs=[
|
||||
dict(index=it.index, text=it.text, token_ids=it.token_ids, cumulative_logprob=it.cumulative_logprob, logprobs=it.logprobs, finish_reason=it.finish_reason)
|
||||
for it in request_output.outputs
|
||||
]
|
||||
)
|
||||
return dict(request_id=request_output.request_id,
|
||||
prompt=request_output.prompt,
|
||||
finished=request_output.finished,
|
||||
prompt_token_ids=request_output.prompt_token_ids,
|
||||
outputs=[
|
||||
dict(index=it.index,
|
||||
text=it.text,
|
||||
token_ids=it.token_ids,
|
||||
cumulative_logprob=it.cumulative_logprob,
|
||||
logprobs=it.logprobs,
|
||||
finish_reason=it.finish_reason) for it in request_output.outputs
|
||||
])
|
||||
|
||||
@attr.define
|
||||
class HfAgentInput:
|
||||
|
||||
@@ -151,7 +151,8 @@ def _from_spec(cls: type[DynResource], spec: t.Any) -> list[str]:
|
||||
elif isinstance(spec, list):
|
||||
return [str(x) for x in spec]
|
||||
else:
|
||||
raise TypeError(f"'{cls.__name__}.from_spec' only supports parsing spec of type int, str, or list, got '{type(spec)}' instead.")
|
||||
raise TypeError(
|
||||
f"'{cls.__name__}.from_spec' only supports parsing spec of type int, str, or list, got '{type(spec)}' instead.")
|
||||
|
||||
def _raw_device_uuid_nvml() -> list[str] | None:
|
||||
from ctypes import CDLL
|
||||
@@ -217,8 +218,7 @@ def _validate(cls: type[DynResource], val: list[t.Any]) -> None:
|
||||
|
||||
def _make_resource_class(name: str, resource_kind: str, docstring: str) -> type[DynResource]:
|
||||
return types.new_class(
|
||||
name, (bentoml.Resource[t.List[str]], ReprMixin), {'resource_id': resource_kind},
|
||||
lambda ns: ns.update({
|
||||
name, (bentoml.Resource[t.List[str]], ReprMixin), {'resource_id': resource_kind}, lambda ns: ns.update({
|
||||
'resource_id': resource_kind,
|
||||
'from_spec': classmethod(_from_spec),
|
||||
'from_system': classmethod(_from_system),
|
||||
@@ -226,8 +226,7 @@ def _make_resource_class(name: str, resource_kind: str, docstring: str) -> type[
|
||||
'__repr_keys__': property(lambda _: {'resource_id'}),
|
||||
'__doc__': inspect.cleandoc(docstring),
|
||||
'__module__': 'openllm._strategies'
|
||||
})
|
||||
)
|
||||
}))
|
||||
|
||||
# NOTE: we need to hint these t.Literal since mypy is to dumb to infer this as literal :facepalm:
|
||||
_TPU_RESOURCE: t.Literal['cloud-tpus.google.com/v2'] = 'cloud-tpus.google.com/v2'
|
||||
@@ -236,21 +235,15 @@ _NVIDIA_GPU_RESOURCE: t.Literal['nvidia.com/gpu'] = 'nvidia.com/gpu'
|
||||
_CPU_RESOURCE: t.Literal['cpu'] = 'cpu'
|
||||
|
||||
NvidiaGpuResource = _make_resource_class(
|
||||
'NvidiaGpuResource',
|
||||
_NVIDIA_GPU_RESOURCE,
|
||||
'''NVIDIA GPU resource.
|
||||
'NvidiaGpuResource', _NVIDIA_GPU_RESOURCE, '''NVIDIA GPU resource.
|
||||
|
||||
This is a modified version of internal's BentoML's NvidiaGpuResource
|
||||
where it respects and parse CUDA_VISIBLE_DEVICES correctly.'''
|
||||
)
|
||||
where it respects and parse CUDA_VISIBLE_DEVICES correctly.''')
|
||||
AmdGpuResource = _make_resource_class(
|
||||
'AmdGpuResource',
|
||||
_AMD_GPU_RESOURCE,
|
||||
'''AMD GPU resource.
|
||||
'AmdGpuResource', _AMD_GPU_RESOURCE, '''AMD GPU resource.
|
||||
|
||||
Since ROCm will respect CUDA_VISIBLE_DEVICES, the behaviour of from_spec, from_system are similar to
|
||||
``NvidiaGpuResource``. Currently ``validate`` is not yet supported.'''
|
||||
)
|
||||
``NvidiaGpuResource``. Currently ``validate`` is not yet supported.''')
|
||||
|
||||
LiteralResourceSpec = t.Literal['cloud-tpus.google.com/v2', 'amd.com/gpu', 'nvidia.com/gpu', 'cpu']
|
||||
|
||||
@@ -285,8 +278,10 @@ class CascadingResourceStrategy(bentoml.Strategy, ReprMixin):
|
||||
|
||||
TODO: Support CloudTPUResource
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_worker_count(cls, runnable_class: type[bentoml.Runnable], resource_request: dict[str, t.Any] | None, workers_per_resource: float) -> int:
|
||||
def get_worker_count(cls, runnable_class: type[bentoml.Runnable], resource_request: dict[str, t.Any] | None,
|
||||
workers_per_resource: float) -> int:
|
||||
'''Return the number of workers to be used for the given runnable class.
|
||||
|
||||
Note that for all available GPU, the number of workers will always be 1.
|
||||
@@ -303,18 +298,23 @@ class CascadingResourceStrategy(bentoml.Strategy, ReprMixin):
|
||||
# use CPU
|
||||
cpus = get_resource(resource_request, 'cpu')
|
||||
if cpus is not None and cpus > 0:
|
||||
if 'cpu' not in runnable_class.SUPPORTED_RESOURCES: logger.warning('No known supported resource available for %s, falling back to using CPU.', runnable_class)
|
||||
if 'cpu' not in runnable_class.SUPPORTED_RESOURCES:
|
||||
logger.warning('No known supported resource available for %s, falling back to using CPU.', runnable_class)
|
||||
|
||||
if runnable_class.SUPPORTS_CPU_MULTI_THREADING:
|
||||
if isinstance(workers_per_resource, float) and workers_per_resource < 1.0: raise ValueError('Fractional CPU multi threading support is not yet supported.')
|
||||
if isinstance(workers_per_resource, float) and workers_per_resource < 1.0:
|
||||
raise ValueError('Fractional CPU multi threading support is not yet supported.')
|
||||
return int(workers_per_resource)
|
||||
return math.ceil(cpus) * workers_per_resource
|
||||
|
||||
# this should not be reached by user since we always read system resource as default
|
||||
raise ValueError(f'No known supported resource available for {runnable_class}. Please check your resource request. Leaving it blank will allow BentoML to use system resources.')
|
||||
raise ValueError(
|
||||
f'No known supported resource available for {runnable_class}. Please check your resource request. Leaving it blank will allow BentoML to use system resources.'
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_worker_env(cls, runnable_class: type[bentoml.Runnable], resource_request: dict[str, t.Any] | None, workers_per_resource: int | float, worker_index: int) -> dict[str, t.Any]:
|
||||
def get_worker_env(cls, runnable_class: type[bentoml.Runnable], resource_request: dict[str, t.Any] | None,
|
||||
workers_per_resource: int | float, worker_index: int) -> dict[str, t.Any]:
|
||||
'''Get worker env for this given worker_index.
|
||||
|
||||
Args:
|
||||
@@ -372,18 +372,26 @@ class CascadingResourceStrategy(bentoml.Strategy, ReprMixin):
|
||||
# NOTE: We hit this branch when workers_per_resource is set to
|
||||
# float, for example 0.5 or 0.25
|
||||
if workers_per_resource > 1:
|
||||
raise ValueError("Currently, the default strategy doesn't support workers_per_resource > 1. It is recommended that one should implement a custom strategy in this case.")
|
||||
raise ValueError(
|
||||
"Currently, the default strategy doesn't support workers_per_resource > 1. It is recommended that one should implement a custom strategy in this case."
|
||||
)
|
||||
# We are round the assigned resource here. This means if workers_per_resource=.4
|
||||
# then it will round down to 2. If workers_per_source=0.6, then it will also round up to 2.
|
||||
assigned_resource_per_worker = round(1 / workers_per_resource)
|
||||
if len(gpus) < assigned_resource_per_worker:
|
||||
logger.warning('Failed to allocate %s GPUs for %s (number of available GPUs < assigned workers per resource [%s])', gpus, worker_index, assigned_resource_per_worker)
|
||||
raise IndexError(f"There aren't enough assigned GPU(s) for given worker id '{worker_index}' [required: {assigned_resource_per_worker}].")
|
||||
assigned_gpu = gpus[assigned_resource_per_worker * worker_index:assigned_resource_per_worker * (worker_index+1)]
|
||||
logger.warning(
|
||||
'Failed to allocate %s GPUs for %s (number of available GPUs < assigned workers per resource [%s])', gpus,
|
||||
worker_index, assigned_resource_per_worker)
|
||||
raise IndexError(
|
||||
f"There aren't enough assigned GPU(s) for given worker id '{worker_index}' [required: {assigned_resource_per_worker}]."
|
||||
)
|
||||
assigned_gpu = gpus[assigned_resource_per_worker * worker_index:assigned_resource_per_worker * (worker_index + 1)]
|
||||
dev = ','.join(assigned_gpu)
|
||||
else:
|
||||
idx = worker_index // workers_per_resource
|
||||
if idx >= len(gpus): raise ValueError(f'Number of available GPU ({gpus}) preceeds the given workers_per_resource {workers_per_resource}')
|
||||
if idx >= len(gpus):
|
||||
raise ValueError(
|
||||
f'Number of available GPU ({gpus}) preceeds the given workers_per_resource {workers_per_resource}')
|
||||
dev = str(gpus[idx])
|
||||
return dev
|
||||
|
||||
|
||||
@@ -26,9 +26,14 @@ if t.TYPE_CHECKING:
|
||||
|
||||
M = t.TypeVar(
|
||||
'M',
|
||||
bound='t.Union[transformers.PreTrainedModel, transformers.Pipeline, transformers.TFPreTrainedModel, transformers.FlaxPreTrainedModel, vllm.LLMEngine, peft.PeftModel, autogptq.modeling.BaseGPTQForCausalLM]'
|
||||
bound=
|
||||
't.Union[transformers.PreTrainedModel, transformers.Pipeline, transformers.TFPreTrainedModel, transformers.FlaxPreTrainedModel, vllm.LLMEngine, peft.PeftModel, autogptq.modeling.BaseGPTQForCausalLM]'
|
||||
)
|
||||
T = t.TypeVar(
|
||||
'T',
|
||||
bound=
|
||||
't.Union[transformers.PreTrainedTokenizerFast, transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerBase]'
|
||||
)
|
||||
T = t.TypeVar('T', bound='t.Union[transformers.PreTrainedTokenizerFast, transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerBase]')
|
||||
|
||||
AnyCallable = t.Callable[..., t.Any]
|
||||
DictStrAny = t.Dict[str, t.Any]
|
||||
@@ -93,7 +98,6 @@ class LLMRunnable(bentoml.Runnable, t.Generic[M, T]):
|
||||
SUPPORTED_RESOURCES = ('amd.com/gpu', 'nvidia.com/gpu', 'cpu')
|
||||
SUPPORTS_CPU_MULTI_THREADING = True
|
||||
__call__: RunnableMethod[LLMRunnable[M, T], [str], list[t.Any]]
|
||||
set_adapter: RunnableMethod[LLMRunnable[M, T], [str], dict[t.Literal['success', 'error_msg'], bool | str]]
|
||||
embeddings: RunnableMethod[LLMRunnable[M, T], [list[str]], LLMEmbeddings]
|
||||
generate: RunnableMethod[LLMRunnable[M, T], [str], list[t.Any]]
|
||||
generate_one: RunnableMethod[LLMRunnable[M, T], [str, list[str]], t.Sequence[dict[t.Literal['generated_text'], str]]]
|
||||
@@ -117,19 +121,18 @@ class LLMRunner(bentoml.Runner, t.Generic[M, T]):
|
||||
generate_one: RunnerMethod[LLMRunnable[M, T], [str, list[str]], t.Sequence[dict[t.Literal['generated_text'], str]]]
|
||||
generate_iterator: RunnerMethod[LLMRunnable[M, T], [str], t.Generator[str, None, str]]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
runnable_class: type[LLMRunnable[M, T]],
|
||||
*,
|
||||
runnable_init_params: dict[str, t.Any] | None = ...,
|
||||
name: str | None = ...,
|
||||
scheduling_strategy: type[Strategy] = ...,
|
||||
models: list[bentoml.Model] | None = ...,
|
||||
max_batch_size: int | None = ...,
|
||||
max_latency_ms: int | None = ...,
|
||||
method_configs: dict[str, dict[str, int]] | None = ...,
|
||||
embedded: bool = False,
|
||||
) -> None:
|
||||
def __init__(self,
|
||||
runnable_class: type[LLMRunnable[M, T]],
|
||||
*,
|
||||
runnable_init_params: dict[str, t.Any] | None = ...,
|
||||
name: str | None = ...,
|
||||
scheduling_strategy: type[Strategy] = ...,
|
||||
models: list[bentoml.Model] | None = ...,
|
||||
max_batch_size: int | None = ...,
|
||||
max_latency_ms: int | None = ...,
|
||||
method_configs: dict[str, dict[str, int]] | None = ...,
|
||||
embedded: bool = False,
|
||||
) -> None:
|
||||
...
|
||||
|
||||
def __call__(self, prompt: str, **attrs: t.Any) -> t.Any:
|
||||
|
||||
@@ -24,11 +24,14 @@ if t.TYPE_CHECKING:
|
||||
ConfigItemsView = _odict_items[str, type[openllm_core.LLMConfig]]
|
||||
|
||||
# NOTE: This is the entrypoint when adding new model config
|
||||
CONFIG_MAPPING_NAMES = OrderedDict([('chatglm', 'ChatGLMConfig'), ('dolly_v2', 'DollyV2Config'), ('falcon', 'FalconConfig'), ('flan_t5', 'FlanT5Config'), ('gpt_neox', 'GPTNeoXConfig'), (
|
||||
'llama', 'LlamaConfig'
|
||||
), ('mpt', 'MPTConfig'), ('opt', 'OPTConfig'), ('stablelm', 'StableLMConfig'), ('starcoder', 'StarCoderConfig'), ('baichuan', 'BaichuanConfig')])
|
||||
CONFIG_MAPPING_NAMES = OrderedDict([('chatglm', 'ChatGLMConfig'), ('dolly_v2', 'DollyV2Config'),
|
||||
('falcon', 'FalconConfig'), ('flan_t5', 'FlanT5Config'),
|
||||
('gpt_neox', 'GPTNeoXConfig'), ('llama', 'LlamaConfig'), ('mpt', 'MPTConfig'),
|
||||
('opt', 'OPTConfig'), ('stablelm', 'StableLMConfig'),
|
||||
('starcoder', 'StarCoderConfig'), ('baichuan', 'BaichuanConfig')])
|
||||
|
||||
class _LazyConfigMapping(OrderedDict, ReprMixin):
|
||||
|
||||
def __init__(self, mapping: OrderedDict[LiteralString, LiteralString]):
|
||||
self._mapping = mapping
|
||||
self._extra_content: dict[str, t.Any] = {}
|
||||
@@ -76,21 +79,32 @@ class _LazyConfigMapping(OrderedDict, ReprMixin):
|
||||
|
||||
CONFIG_MAPPING: dict[str, type[openllm_core.LLMConfig]] = _LazyConfigMapping(CONFIG_MAPPING_NAMES)
|
||||
# The below handle special alias when we call underscore to the name directly without processing camelcase first.
|
||||
CONFIG_NAME_ALIASES: dict[str, str] = {'chat_glm': 'chatglm', 'stable_lm': 'stablelm', 'star_coder': 'starcoder', 'gpt_neo_x': 'gpt_neox',}
|
||||
CONFIG_NAME_ALIASES: dict[str, str] = {
|
||||
'chat_glm': 'chatglm',
|
||||
'stable_lm': 'stablelm',
|
||||
'star_coder': 'starcoder',
|
||||
'gpt_neo_x': 'gpt_neox',
|
||||
}
|
||||
|
||||
class AutoConfig:
|
||||
|
||||
def __init__(self, *_: t.Any, **__: t.Any):
|
||||
raise EnvironmentError('Cannot instantiate AutoConfig directly. Please use `AutoConfig.for_model(model_name)` instead.')
|
||||
raise EnvironmentError(
|
||||
'Cannot instantiate AutoConfig directly. Please use `AutoConfig.for_model(model_name)` instead.')
|
||||
|
||||
@classmethod
|
||||
def for_model(cls, model_name: str, **attrs: t.Any) -> openllm_core.LLMConfig:
|
||||
model_name = inflection.underscore(model_name)
|
||||
if model_name in CONFIG_MAPPING: return CONFIG_MAPPING[model_name].model_construct_env(**attrs)
|
||||
raise ValueError(f"Unrecognized configuration class for {model_name}. Model name should be one of {', '.join(CONFIG_MAPPING.keys())}.")
|
||||
raise ValueError(
|
||||
f"Unrecognized configuration class for {model_name}. Model name should be one of {', '.join(CONFIG_MAPPING.keys())}."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def infer_class_from_name(cls, name: str) -> type[openllm_core.LLMConfig]:
|
||||
model_name = inflection.underscore(name)
|
||||
if model_name in CONFIG_NAME_ALIASES: model_name = CONFIG_NAME_ALIASES[model_name]
|
||||
if model_name in CONFIG_MAPPING: return CONFIG_MAPPING[model_name]
|
||||
raise ValueError(f"Unrecognized configuration class for {model_name}. Model name should be one of {', '.join(CONFIG_MAPPING.keys())}.")
|
||||
raise ValueError(
|
||||
f"Unrecognized configuration class for {model_name}. Model name should be one of {', '.join(CONFIG_MAPPING.keys())}."
|
||||
)
|
||||
|
||||
@@ -37,21 +37,24 @@ class BaichuanConfig(openllm_core.LLMConfig):
|
||||
Refer to [Baichuan-7B's GitHub page](https://github.com/baichuan-inc/Baichuan-7B) for more information.
|
||||
"""
|
||||
__config__ = {
|
||||
'name_type': 'lowercase',
|
||||
'trust_remote_code': True,
|
||||
'timeout': 3600000,
|
||||
'requires_gpu': True,
|
||||
'url': 'https://github.com/baichuan-inc/Baichuan-7B',
|
||||
'name_type':
|
||||
'lowercase',
|
||||
'trust_remote_code':
|
||||
True,
|
||||
'timeout':
|
||||
3600000,
|
||||
'requires_gpu':
|
||||
True,
|
||||
'url':
|
||||
'https://github.com/baichuan-inc/Baichuan-7B',
|
||||
'requirements': ['cpm-kernels', 'sentencepiece'],
|
||||
'architecture': 'BaiChuanForCausalLM',
|
||||
'default_id': 'baichuan-inc/baichuan-7b',
|
||||
'model_ids': [
|
||||
'architecture':
|
||||
'BaiChuanForCausalLM',
|
||||
'default_id':
|
||||
'baichuan-inc/baichuan-7b',
|
||||
'baichuan-inc/baichuan-13b-base',
|
||||
'baichuan-inc/baichuan-13b-chat',
|
||||
'fireballoon/baichuan-vicuna-chinese-7b',
|
||||
'fireballoon/baichuan-vicuna-7b',
|
||||
'hiyouga/baichuan-7b-sft'
|
||||
'model_ids': [
|
||||
'baichuan-inc/baichuan-7b', 'baichuan-inc/baichuan-13b-base', 'baichuan-inc/baichuan-13b-chat',
|
||||
'fireballoon/baichuan-vicuna-chinese-7b', 'fireballoon/baichuan-vicuna-7b', 'hiyouga/baichuan-7b-sft'
|
||||
]
|
||||
}
|
||||
|
||||
@@ -60,10 +63,19 @@ class BaichuanConfig(openllm_core.LLMConfig):
|
||||
top_p: float = 0.7
|
||||
temperature: float = 0.95
|
||||
|
||||
def sanitize_parameters(
|
||||
self, prompt: str, max_new_tokens: int | None = None, top_p: float | None = None, temperature: float | None = None, use_default_prompt_template: bool = False, **attrs: t.Any
|
||||
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {'max_new_tokens': max_new_tokens, 'top_p': top_p, 'temperature': temperature, **attrs}, {}
|
||||
def sanitize_parameters(self,
|
||||
prompt: str,
|
||||
max_new_tokens: int | None = None,
|
||||
top_p: float | None = None,
|
||||
temperature: float | None = None,
|
||||
use_default_prompt_template: bool = False,
|
||||
**attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {
|
||||
'max_new_tokens': max_new_tokens,
|
||||
'top_p': top_p,
|
||||
'temperature': temperature,
|
||||
**attrs
|
||||
}, {}
|
||||
|
||||
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str:
|
||||
return generation_result[0]
|
||||
|
||||
@@ -41,17 +41,30 @@ class ChatGLMConfig(openllm_core.LLMConfig):
|
||||
Refer to [ChatGLM's GitHub page](https://github.com/THUDM/ChatGLM-6B) for more information.
|
||||
"""
|
||||
__config__ = {
|
||||
'name_type': 'lowercase',
|
||||
'trust_remote_code': True,
|
||||
'timeout': 3600000,
|
||||
'requires_gpu': True,
|
||||
'url': 'https://github.com/THUDM/ChatGLM-6B',
|
||||
'name_type':
|
||||
'lowercase',
|
||||
'trust_remote_code':
|
||||
True,
|
||||
'timeout':
|
||||
3600000,
|
||||
'requires_gpu':
|
||||
True,
|
||||
'url':
|
||||
'https://github.com/THUDM/ChatGLM-6B',
|
||||
'requirements': ['cpm-kernels', 'sentencepiece'],
|
||||
'architecture': 'ChatGLMForConditionalGeneration',
|
||||
'default_id': 'thudm/chatglm-6b',
|
||||
'model_ids': ['thudm/chatglm-6b', 'thudm/chatglm-6b-int8', 'thudm/chatglm-6b-int4', 'thudm/chatglm2-6b', 'thudm/chatglm2-6b-int4']
|
||||
'architecture':
|
||||
'ChatGLMForConditionalGeneration',
|
||||
'default_id':
|
||||
'thudm/chatglm-6b',
|
||||
'model_ids': [
|
||||
'thudm/chatglm-6b', 'thudm/chatglm-6b-int8', 'thudm/chatglm-6b-int4', 'thudm/chatglm2-6b',
|
||||
'thudm/chatglm2-6b-int4'
|
||||
]
|
||||
}
|
||||
retain_history: bool = dantic.Field(False, description='Whether to retain history given to the model. If set to True, then the model will retain given history.')
|
||||
retain_history: bool = dantic.Field(
|
||||
False,
|
||||
description=
|
||||
'Whether to retain history given to the model. If set to True, then the model will retain given history.')
|
||||
use_half_precision: bool = dantic.Field(True, description='Whether to use half precision for model.')
|
||||
|
||||
class GenerationConfig:
|
||||
@@ -60,17 +73,15 @@ class ChatGLMConfig(openllm_core.LLMConfig):
|
||||
top_p: float = 0.7
|
||||
temperature: float = 0.95
|
||||
|
||||
def sanitize_parameters(
|
||||
self,
|
||||
prompt: str,
|
||||
max_new_tokens: int | None = None,
|
||||
num_beams: int | None = None,
|
||||
top_p: float | None = None,
|
||||
temperature: float | None = None,
|
||||
chat_history: list[tuple[str, str]] | None = None,
|
||||
use_default_prompt_template: bool = False,
|
||||
**attrs: t.Any
|
||||
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
def sanitize_parameters(self,
|
||||
prompt: str,
|
||||
max_new_tokens: int | None = None,
|
||||
num_beams: int | None = None,
|
||||
top_p: float | None = None,
|
||||
temperature: float | None = None,
|
||||
chat_history: list[tuple[str, str]] | None = None,
|
||||
use_default_prompt_template: bool = False,
|
||||
**attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
prompt_text = ''
|
||||
if use_default_prompt_template and chat_history is not None:
|
||||
for i, (old_query, response) in enumerate(chat_history):
|
||||
@@ -79,9 +90,20 @@ class ChatGLMConfig(openllm_core.LLMConfig):
|
||||
else:
|
||||
prompt_text = prompt
|
||||
postprocess_generate_kwargs = {'chat_history': chat_history if chat_history is not None else None}
|
||||
return prompt_text, {'max_new_tokens': max_new_tokens, 'num_beams': num_beams, 'top_p': top_p, 'temperature': temperature, **attrs}, postprocess_generate_kwargs
|
||||
return prompt_text, {
|
||||
'max_new_tokens': max_new_tokens,
|
||||
'num_beams': num_beams,
|
||||
'top_p': top_p,
|
||||
'temperature': temperature,
|
||||
**attrs
|
||||
}, postprocess_generate_kwargs
|
||||
|
||||
def postprocess_generate(self, prompt: str, generation_result: tuple[str, list[tuple[str, str]]], *, chat_history: list[tuple[str, str]] | None = None, **attrs: t.Any) -> str:
|
||||
def postprocess_generate(self,
|
||||
prompt: str,
|
||||
generation_result: tuple[str, list[tuple[str, str]]],
|
||||
*,
|
||||
chat_history: list[tuple[str, str]] | None = None,
|
||||
**attrs: t.Any) -> str:
|
||||
generated, history = generation_result
|
||||
if self.config.retain_history:
|
||||
if chat_history is None: raise ValueError("'retain_history' is True while there is no history provided.")
|
||||
|
||||
@@ -89,19 +89,22 @@ class DollyV2Config(openllm_core.LLMConfig):
|
||||
max_new_tokens: int = 256
|
||||
eos_token_id: int = 50277 # NOTE: from get_special_token_id(self.tokenizer, END_KEY)
|
||||
|
||||
def sanitize_parameters(
|
||||
self,
|
||||
prompt: str,
|
||||
max_new_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
top_k: int | None = None,
|
||||
top_p: float | None = None,
|
||||
use_default_prompt_template: bool = True,
|
||||
**attrs: t.Any
|
||||
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
def sanitize_parameters(self,
|
||||
prompt: str,
|
||||
max_new_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
top_k: int | None = None,
|
||||
top_p: float | None = None,
|
||||
use_default_prompt_template: bool = True,
|
||||
**attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {
|
||||
'max_new_tokens': max_new_tokens, 'top_k': top_k, 'top_p': top_p, 'temperature': temperature, **attrs
|
||||
'max_new_tokens': max_new_tokens,
|
||||
'top_k': top_k,
|
||||
'top_p': top_p,
|
||||
'temperature': temperature,
|
||||
**attrs
|
||||
}, {}
|
||||
|
||||
def postprocess_generate(self, prompt: str, generation_result: list[dict[t.Literal['generated_text'], str]], **_: t.Any) -> str:
|
||||
def postprocess_generate(self, prompt: str, generation_result: list[dict[t.Literal['generated_text'], str]],
|
||||
**_: t.Any) -> str:
|
||||
return generation_result[0]['generated_text']
|
||||
|
||||
@@ -39,17 +39,29 @@ class FalconConfig(openllm_core.LLMConfig):
|
||||
Refer to [Falcon's HuggingFace page](https://huggingface.co/tiiuae/falcon-7b) for more information.
|
||||
"""
|
||||
__config__ = {
|
||||
'name_type': 'lowercase',
|
||||
'trust_remote_code': True,
|
||||
'requires_gpu': True,
|
||||
'timeout': int(36e6),
|
||||
'url': 'https://falconllm.tii.ae/',
|
||||
'name_type':
|
||||
'lowercase',
|
||||
'trust_remote_code':
|
||||
True,
|
||||
'requires_gpu':
|
||||
True,
|
||||
'timeout':
|
||||
int(36e6),
|
||||
'url':
|
||||
'https://falconllm.tii.ae/',
|
||||
'requirements': ['einops', 'xformers'],
|
||||
'architecture': 'FalconForCausalLM',
|
||||
'default_id': 'tiiuae/falcon-7b',
|
||||
'architecture':
|
||||
'FalconForCausalLM',
|
||||
'default_id':
|
||||
'tiiuae/falcon-7b',
|
||||
'model_ids': ['tiiuae/falcon-7b', 'tiiuae/falcon-40b', 'tiiuae/falcon-7b-instruct', 'tiiuae/falcon-40b-instruct'],
|
||||
'fine_tune_strategies': ({
|
||||
'adapter_type': 'lora', 'r': 64, 'lora_alpha': 16, 'lora_dropout': 0.1, 'bias': 'none', 'target_modules': ['query_key_value', 'dense', 'dense_h_to_4h', 'dense_4h_to_h']
|
||||
'adapter_type': 'lora',
|
||||
'r': 64,
|
||||
'lora_alpha': 16,
|
||||
'lora_dropout': 0.1,
|
||||
'bias': 'none',
|
||||
'target_modules': ['query_key_value', 'dense', 'dense_h_to_4h', 'dense_4h_to_h']
|
||||
},)
|
||||
}
|
||||
|
||||
@@ -60,18 +72,20 @@ class FalconConfig(openllm_core.LLMConfig):
|
||||
num_beams: int = 4
|
||||
early_stopping: bool = True
|
||||
|
||||
def sanitize_parameters(
|
||||
self,
|
||||
prompt: str,
|
||||
max_new_tokens: int | None = None,
|
||||
top_k: int | None = None,
|
||||
num_return_sequences: int | None = None,
|
||||
eos_token_id: int | None = None,
|
||||
use_default_prompt_template: bool = False,
|
||||
**attrs: t.Any
|
||||
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
def sanitize_parameters(self,
|
||||
prompt: str,
|
||||
max_new_tokens: int | None = None,
|
||||
top_k: int | None = None,
|
||||
num_return_sequences: int | None = None,
|
||||
eos_token_id: int | None = None,
|
||||
use_default_prompt_template: bool = False,
|
||||
**attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {
|
||||
'max_new_tokens': max_new_tokens, 'top_k': top_k, 'num_return_sequences': num_return_sequences, 'eos_token_id': eos_token_id, **attrs
|
||||
'max_new_tokens': max_new_tokens,
|
||||
'top_k': top_k,
|
||||
'num_return_sequences': num_return_sequences,
|
||||
'eos_token_id': eos_token_id,
|
||||
**attrs
|
||||
}, {}
|
||||
|
||||
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str:
|
||||
|
||||
@@ -40,11 +40,18 @@ class FlanT5Config(openllm_core.LLMConfig):
|
||||
Refer to [FLAN-T5's page](https://huggingface.co/docs/transformers/model_doc/flan-t5) for more information.
|
||||
"""
|
||||
__config__ = {
|
||||
'url': 'https://huggingface.co/docs/transformers/model_doc/flan-t5',
|
||||
'architecture': 'T5ForConditionalGeneration',
|
||||
'model_type': 'seq2seq_lm',
|
||||
'default_id': 'google/flan-t5-large',
|
||||
'model_ids': ['google/flan-t5-small', 'google/flan-t5-base', 'google/flan-t5-large', 'google/flan-t5-xl', 'google/flan-t5-xxl',]
|
||||
'url':
|
||||
'https://huggingface.co/docs/transformers/model_doc/flan-t5',
|
||||
'architecture':
|
||||
'T5ForConditionalGeneration',
|
||||
'model_type':
|
||||
'seq2seq_lm',
|
||||
'default_id':
|
||||
'google/flan-t5-large',
|
||||
'model_ids': [
|
||||
'google/flan-t5-small', 'google/flan-t5-base', 'google/flan-t5-large', 'google/flan-t5-xl',
|
||||
'google/flan-t5-xxl',
|
||||
]
|
||||
}
|
||||
|
||||
class GenerationConfig:
|
||||
@@ -54,19 +61,21 @@ class FlanT5Config(openllm_core.LLMConfig):
|
||||
top_p: float = 0.4
|
||||
repetition_penalty = 1.0
|
||||
|
||||
def sanitize_parameters(
|
||||
self,
|
||||
prompt: str,
|
||||
max_new_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
top_k: int | None = None,
|
||||
top_p: float | None = None,
|
||||
repetition_penalty: float | None = None,
|
||||
use_default_prompt_template: bool = True,
|
||||
**attrs: t.Any
|
||||
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
def sanitize_parameters(self,
|
||||
prompt: str,
|
||||
max_new_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
top_k: int | None = None,
|
||||
top_p: float | None = None,
|
||||
repetition_penalty: float | None = None,
|
||||
use_default_prompt_template: bool = True,
|
||||
**attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {
|
||||
'max_new_tokens': max_new_tokens, 'temperature': temperature, 'top_k': top_k, 'top_p': top_p, 'repetition_penalty': repetition_penalty
|
||||
'max_new_tokens': max_new_tokens,
|
||||
'temperature': temperature,
|
||||
'top_k': top_k,
|
||||
'top_p': top_p,
|
||||
'repetition_penalty': repetition_penalty
|
||||
}, {}
|
||||
|
||||
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str:
|
||||
|
||||
@@ -57,9 +57,16 @@ class GPTNeoXConfig(openllm_core.LLMConfig):
|
||||
temperature: float = 0.9
|
||||
max_new_tokens: int = 100
|
||||
|
||||
def sanitize_parameters(self, prompt: str, temperature: float | None = None, max_new_tokens: int | None = None, use_default_prompt_template: bool = True,
|
||||
def sanitize_parameters(self,
|
||||
prompt: str,
|
||||
temperature: float | None = None,
|
||||
max_new_tokens: int | None = None,
|
||||
use_default_prompt_template: bool = True,
|
||||
**attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {'max_new_tokens': max_new_tokens, 'temperature': temperature}, {}
|
||||
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {
|
||||
'max_new_tokens': max_new_tokens,
|
||||
'temperature': temperature
|
||||
}, {}
|
||||
|
||||
def postprocess_generate(self, prompt: str, generation_result: list[str], **_: t.Any) -> str:
|
||||
return generation_result[0]
|
||||
|
||||
@@ -42,7 +42,8 @@ If a question does not make any sense, or is not factually coherent, explain why
|
||||
'''
|
||||
SINST_KEY, EINST_KEY, SYS_KEY, EOS_TOKEN, BOS_TOKEN = '[INST]', '[/INST]', '<<SYS>>', '</s>', '<s>'
|
||||
# TODO: support history and v1 prompt implementation
|
||||
_v1_prompt, _v2_prompt = '''{instruction}''', '''{start_key} {sys_key}\n{system_message}\n{sys_key}\n\n{instruction}\n{end_key} '''.format(start_key=SINST_KEY, sys_key=SYS_KEY, system_message=SYSTEM_MESSAGE, instruction='{instruction}', end_key=EINST_KEY)
|
||||
_v1_prompt, _v2_prompt = '''{instruction}''', '''{start_key} {sys_key}\n{system_message}\n{sys_key}\n\n{instruction}\n{end_key} '''.format(
|
||||
start_key=SINST_KEY, sys_key=SYS_KEY, system_message=SYSTEM_MESSAGE, instruction='{instruction}', end_key=EINST_KEY)
|
||||
PROMPT_MAPPING = {'v1': _v1_prompt, 'v2': _v2_prompt}
|
||||
|
||||
def _get_prompt(model_type: t.Literal['v1', 'v2']) -> str:
|
||||
@@ -62,40 +63,38 @@ class LlamaConfig(openllm_core.LLMConfig):
|
||||
Refer to [Llama's model card](https://huggingface.co/docs/transformers/main/model_doc/llama)
|
||||
for more information.
|
||||
"""
|
||||
use_llama2_prompt: bool = dantic.Field(False, description='Whether to use the prompt format for Llama 2. Disable this when working with Llama 1.')
|
||||
use_llama2_prompt: bool = dantic.Field(
|
||||
False, description='Whether to use the prompt format for Llama 2. Disable this when working with Llama 1.')
|
||||
__config__ = {
|
||||
'name_type': 'lowercase',
|
||||
'url': 'https://github.com/facebookresearch/llama',
|
||||
'name_type':
|
||||
'lowercase',
|
||||
'url':
|
||||
'https://github.com/facebookresearch/llama',
|
||||
'default_implementation': {
|
||||
'cpu': 'pt', 'nvidia.com/gpu': 'pt'
|
||||
'cpu': 'pt',
|
||||
'nvidia.com/gpu': 'pt'
|
||||
},
|
||||
'architecture': 'LlamaForCausalLM',
|
||||
'architecture':
|
||||
'LlamaForCausalLM',
|
||||
'requirements': ['fairscale', 'sentencepiece'],
|
||||
'tokenizer_class': 'LlamaTokenizerFast',
|
||||
'default_id': 'NousResearch/llama-2-7b-hf',
|
||||
'model_ids': [
|
||||
'meta-llama/Llama-2-70b-chat-hf',
|
||||
'meta-llama/Llama-2-13b-chat-hf',
|
||||
'meta-llama/Llama-2-7b-chat-hf',
|
||||
'meta-llama/Llama-2-70b-hf',
|
||||
'meta-llama/Llama-2-13b-hf',
|
||||
'meta-llama/Llama-2-7b-hf',
|
||||
'NousResearch/llama-2-70b-chat-hf',
|
||||
'NousResearch/llama-2-13b-chat-hf',
|
||||
'NousResearch/llama-2-7b-chat-hf',
|
||||
'NousResearch/llama-2-70b-hf',
|
||||
'NousResearch/llama-2-13b-hf',
|
||||
'tokenizer_class':
|
||||
'LlamaTokenizerFast',
|
||||
'default_id':
|
||||
'NousResearch/llama-2-7b-hf',
|
||||
'openlm-research/open_llama_7b_v2',
|
||||
'openlm-research/open_llama_3b_v2',
|
||||
'openlm-research/open_llama_13b',
|
||||
'huggyllama/llama-65b',
|
||||
'huggyllama/llama-30b',
|
||||
'huggyllama/llama-13b',
|
||||
'huggyllama/llama-7b'
|
||||
'model_ids': [
|
||||
'meta-llama/Llama-2-70b-chat-hf', 'meta-llama/Llama-2-13b-chat-hf', 'meta-llama/Llama-2-7b-chat-hf',
|
||||
'meta-llama/Llama-2-70b-hf', 'meta-llama/Llama-2-13b-hf', 'meta-llama/Llama-2-7b-hf',
|
||||
'NousResearch/llama-2-70b-chat-hf', 'NousResearch/llama-2-13b-chat-hf', 'NousResearch/llama-2-7b-chat-hf',
|
||||
'NousResearch/llama-2-70b-hf', 'NousResearch/llama-2-13b-hf', 'NousResearch/llama-2-7b-hf',
|
||||
'openlm-research/open_llama_7b_v2', 'openlm-research/open_llama_3b_v2', 'openlm-research/open_llama_13b',
|
||||
'huggyllama/llama-65b', 'huggyllama/llama-30b', 'huggyllama/llama-13b', 'huggyllama/llama-7b'
|
||||
],
|
||||
'fine_tune_strategies': ({
|
||||
'adapter_type': 'lora', 'r': 64, 'lora_alpha': 16, 'lora_dropout': 0.1, 'bias': 'none'
|
||||
'adapter_type': 'lora',
|
||||
'r': 64,
|
||||
'lora_alpha': 16,
|
||||
'lora_dropout': 0.1,
|
||||
'bias': 'none'
|
||||
},)
|
||||
}
|
||||
|
||||
@@ -109,20 +108,24 @@ class LlamaConfig(openllm_core.LLMConfig):
|
||||
best_of: int = 1
|
||||
presence_penalty: float = 0.5
|
||||
|
||||
def sanitize_parameters(
|
||||
self,
|
||||
prompt: str,
|
||||
top_k: int | None = None,
|
||||
top_p: float | None = None,
|
||||
temperature: float | None = None,
|
||||
max_new_tokens: int | None = None,
|
||||
use_default_prompt_template: bool = False,
|
||||
use_llama2_prompt: bool = True,
|
||||
**attrs: t.Any
|
||||
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE('v2' if use_llama2_prompt else 'v1') if use_default_prompt_template else None, use_default_prompt_template, **attrs), {
|
||||
'max_new_tokens': max_new_tokens, 'temperature': temperature, 'top_p': top_p, 'top_k': top_k
|
||||
}, {}
|
||||
def sanitize_parameters(self,
|
||||
prompt: str,
|
||||
top_k: int | None = None,
|
||||
top_p: float | None = None,
|
||||
temperature: float | None = None,
|
||||
max_new_tokens: int | None = None,
|
||||
use_default_prompt_template: bool = False,
|
||||
use_llama2_prompt: bool = True,
|
||||
**attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
return process_prompt(
|
||||
prompt,
|
||||
DEFAULT_PROMPT_TEMPLATE('v2' if use_llama2_prompt else 'v1') if use_default_prompt_template else None,
|
||||
use_default_prompt_template, **attrs), {
|
||||
'max_new_tokens': max_new_tokens,
|
||||
'temperature': temperature,
|
||||
'top_p': top_p,
|
||||
'top_k': top_k
|
||||
}, {}
|
||||
|
||||
def postprocess_generate(self, prompt: str, generation_result: list[str], **_: t.Any) -> str:
|
||||
return generation_result[0]
|
||||
|
||||
@@ -44,7 +44,12 @@ _chat_prompt, _default_prompt, _instruct_prompt = '''{instruction}''', '''{instr
|
||||
{instruction}
|
||||
{response_key}
|
||||
'''.format(intro=INTRO_BLURB, instruction_key=INSTRUCTION_KEY, instruction='{instruction}', response_key=RESPONSE_KEY)
|
||||
PROMPT_MAPPING = {'default': _default_prompt, 'instruct': _instruct_prompt, 'storywriter': _default_prompt, 'chat': _chat_prompt}
|
||||
PROMPT_MAPPING = {
|
||||
'default': _default_prompt,
|
||||
'instruct': _instruct_prompt,
|
||||
'storywriter': _default_prompt,
|
||||
'chat': _chat_prompt
|
||||
}
|
||||
|
||||
def _get_prompt(model_type: str) -> str:
|
||||
return PROMPT_MAPPING[model_type]
|
||||
@@ -61,21 +66,31 @@ class MPTConfig(openllm_core.LLMConfig):
|
||||
for more details on specific models.
|
||||
"""
|
||||
__config__ = {
|
||||
'name_type': 'lowercase',
|
||||
'trust_remote_code': True,
|
||||
'url': 'https://huggingface.co/mosaicml',
|
||||
'timeout': int(36e6),
|
||||
'name_type':
|
||||
'lowercase',
|
||||
'trust_remote_code':
|
||||
True,
|
||||
'url':
|
||||
'https://huggingface.co/mosaicml',
|
||||
'timeout':
|
||||
int(36e6),
|
||||
'requirements': ['triton', 'einops'],
|
||||
'architecture': 'MPTForCausalLM',
|
||||
'default_id': 'mosaicml/mpt-7b-instruct',
|
||||
'architecture':
|
||||
'MPTForCausalLM',
|
||||
'default_id':
|
||||
'mosaicml/mpt-7b-instruct',
|
||||
'model_ids': [
|
||||
'mosaicml/mpt-7b', 'mosaicml/mpt-7b-instruct', 'mosaicml/mpt-7b-chat', 'mosaicml/mpt-7b-storywriter', 'mosaicml/mpt-30b', 'mosaicml/mpt-30b-instruct', 'mosaicml/mpt-30b-chat'
|
||||
'mosaicml/mpt-7b', 'mosaicml/mpt-7b-instruct', 'mosaicml/mpt-7b-chat', 'mosaicml/mpt-7b-storywriter',
|
||||
'mosaicml/mpt-30b', 'mosaicml/mpt-30b-instruct', 'mosaicml/mpt-30b-chat'
|
||||
]
|
||||
}
|
||||
prompt_type: MPTPromptType = dantic.Field('"default"', description='Given prompt type for running MPT. Default will be inferred from model name if pretrained.')
|
||||
prompt_type: MPTPromptType = dantic.Field(
|
||||
'"default"',
|
||||
description='Given prompt type for running MPT. Default will be inferred from model name if pretrained.')
|
||||
max_sequence_length: int = dantic.Field(
|
||||
2048,
|
||||
description='Max sequence length to run MPT with. Note that MPT is trained ith sequence length of 2048, but with [ALiBi](https://arxiv.org/abs/2108.12409) it can set up to 4096 (for 7b models) and 16384 (for 30b models)'
|
||||
description=
|
||||
'Max sequence length to run MPT with. Note that MPT is trained ith sequence length of 2048, but with [ALiBi](https://arxiv.org/abs/2108.12409) it can set up to 4096 (for 7b models) and 16384 (for 30b models)'
|
||||
)
|
||||
|
||||
class GenerationConfig:
|
||||
@@ -83,16 +98,15 @@ class MPTConfig(openllm_core.LLMConfig):
|
||||
temperature: float = 0
|
||||
top_p: float = 0.8
|
||||
|
||||
def sanitize_parameters(
|
||||
self,
|
||||
prompt: str,
|
||||
max_new_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
prompt_type: MPTPromptType | None = None,
|
||||
use_default_prompt_template: bool = True,
|
||||
**attrs: t.Any,
|
||||
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
def sanitize_parameters(self,
|
||||
prompt: str,
|
||||
max_new_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
prompt_type: MPTPromptType | None = None,
|
||||
use_default_prompt_template: bool = True,
|
||||
**attrs: t.Any,
|
||||
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
_template = None
|
||||
if use_default_prompt_template:
|
||||
if prompt_type is None:
|
||||
@@ -101,7 +115,11 @@ class MPTConfig(openllm_core.LLMConfig):
|
||||
elif 'chat' in self.model_id: prompt_type = 'chat'
|
||||
else: prompt_type = 'default'
|
||||
_template = DEFAULT_PROMPT_TEMPLATE(prompt_type)
|
||||
return process_prompt(prompt, _template, use_default_prompt_template), {'max_new_tokens': max_new_tokens, 'temperature': temperature, 'top_p': top_p}, {}
|
||||
return process_prompt(prompt, _template, use_default_prompt_template), {
|
||||
'max_new_tokens': max_new_tokens,
|
||||
'temperature': temperature,
|
||||
'top_p': top_p
|
||||
}, {}
|
||||
|
||||
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **attrs: t.Any) -> str:
|
||||
return generation_result[0]
|
||||
|
||||
@@ -44,17 +44,31 @@ class OPTConfig(openllm_core.LLMConfig):
|
||||
Refer to [OPT's HuggingFace page](https://huggingface.co/docs/transformers/model_doc/opt) for more information.
|
||||
"""
|
||||
__config__ = {
|
||||
'name_type': 'lowercase',
|
||||
'trust_remote_code': False,
|
||||
'url': 'https://huggingface.co/docs/transformers/model_doc/opt',
|
||||
'default_id': 'facebook/opt-1.3b',
|
||||
'architecture': 'OPTForCausalLM',
|
||||
'model_ids': ['facebook/opt-125m', 'facebook/opt-350m', 'facebook/opt-1.3b', 'facebook/opt-2.7b', 'facebook/opt-6.7b', 'facebook/opt-66b'],
|
||||
'name_type':
|
||||
'lowercase',
|
||||
'trust_remote_code':
|
||||
False,
|
||||
'url':
|
||||
'https://huggingface.co/docs/transformers/model_doc/opt',
|
||||
'default_id':
|
||||
'facebook/opt-1.3b',
|
||||
'architecture':
|
||||
'OPTForCausalLM',
|
||||
'model_ids': [
|
||||
'facebook/opt-125m', 'facebook/opt-350m', 'facebook/opt-1.3b', 'facebook/opt-2.7b', 'facebook/opt-6.7b',
|
||||
'facebook/opt-66b'
|
||||
],
|
||||
'fine_tune_strategies': ({
|
||||
'adapter_type': 'lora', 'r': 16, 'lora_alpha': 32, 'target_modules': ['q_proj', 'v_proj'], 'lora_dropout': 0.05, 'bias': 'none'
|
||||
'adapter_type': 'lora',
|
||||
'r': 16,
|
||||
'lora_alpha': 32,
|
||||
'target_modules': ['q_proj', 'v_proj'],
|
||||
'lora_dropout': 0.05,
|
||||
'bias': 'none'
|
||||
},)
|
||||
}
|
||||
format_outputs: bool = dantic.Field(False, description='''Whether to format the outputs. This can be used when num_return_sequences > 1.''')
|
||||
format_outputs: bool = dantic.Field(
|
||||
False, description='''Whether to format the outputs. This can be used when num_return_sequences > 1.''')
|
||||
|
||||
class GenerationConfig:
|
||||
top_k: int = 15
|
||||
@@ -62,18 +76,19 @@ class OPTConfig(openllm_core.LLMConfig):
|
||||
max_new_tokens: int = 1024
|
||||
num_return_sequences: int = 1
|
||||
|
||||
def sanitize_parameters(
|
||||
self,
|
||||
prompt: str,
|
||||
max_new_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
top_k: int | None = None,
|
||||
num_return_sequences: int | None = None,
|
||||
use_default_prompt_template: bool = False,
|
||||
**attrs: t.Any
|
||||
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
def sanitize_parameters(self,
|
||||
prompt: str,
|
||||
max_new_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
top_k: int | None = None,
|
||||
num_return_sequences: int | None = None,
|
||||
use_default_prompt_template: bool = False,
|
||||
**attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {
|
||||
'max_new_tokens': max_new_tokens, 'temperature': temperature, 'top_k': top_k, 'num_return_sequences': num_return_sequences
|
||||
'max_new_tokens': max_new_tokens,
|
||||
'temperature': temperature,
|
||||
'top_k': top_k,
|
||||
'num_return_sequences': num_return_sequences
|
||||
}, {}
|
||||
|
||||
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **attrs: t.Any) -> str:
|
||||
|
||||
@@ -47,11 +47,18 @@ class StableLMConfig(openllm_core.LLMConfig):
|
||||
for more information.
|
||||
"""
|
||||
__config__ = {
|
||||
'name_type': 'lowercase',
|
||||
'url': 'https://github.com/Stability-AI/StableLM',
|
||||
'architecture': 'GPTNeoXForCausalLM',
|
||||
'default_id': 'stabilityai/stablelm-tuned-alpha-3b',
|
||||
'model_ids': ['stabilityai/stablelm-tuned-alpha-3b', 'stabilityai/stablelm-tuned-alpha-7b', 'stabilityai/stablelm-base-alpha-3b', 'stabilityai/stablelm-base-alpha-7b']
|
||||
'name_type':
|
||||
'lowercase',
|
||||
'url':
|
||||
'https://github.com/Stability-AI/StableLM',
|
||||
'architecture':
|
||||
'GPTNeoXForCausalLM',
|
||||
'default_id':
|
||||
'stabilityai/stablelm-tuned-alpha-3b',
|
||||
'model_ids': [
|
||||
'stabilityai/stablelm-tuned-alpha-3b', 'stabilityai/stablelm-tuned-alpha-7b',
|
||||
'stabilityai/stablelm-base-alpha-3b', 'stabilityai/stablelm-base-alpha-7b'
|
||||
]
|
||||
}
|
||||
|
||||
class GenerationConfig:
|
||||
@@ -60,22 +67,29 @@ class StableLMConfig(openllm_core.LLMConfig):
|
||||
top_k: int = 0
|
||||
top_p: float = 0.9
|
||||
|
||||
def sanitize_parameters(
|
||||
self,
|
||||
prompt: str,
|
||||
temperature: float | None = None,
|
||||
max_new_tokens: int | None = None,
|
||||
top_k: int | None = None,
|
||||
top_p: float | None = None,
|
||||
use_default_prompt_template: bool = False,
|
||||
**attrs: t.Any
|
||||
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
def sanitize_parameters(self,
|
||||
prompt: str,
|
||||
temperature: float | None = None,
|
||||
max_new_tokens: int | None = None,
|
||||
top_k: int | None = None,
|
||||
top_p: float | None = None,
|
||||
use_default_prompt_template: bool = False,
|
||||
**attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
if 'tuned' in self._model_id and use_default_prompt_template:
|
||||
system_prompt = attrs.pop('system_prompt', SYSTEM_PROMPT)
|
||||
prompt_text = process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, system_prompt=system_prompt, **attrs)
|
||||
prompt_text = process_prompt(prompt,
|
||||
DEFAULT_PROMPT_TEMPLATE,
|
||||
use_default_prompt_template,
|
||||
system_prompt=system_prompt,
|
||||
**attrs)
|
||||
else:
|
||||
prompt_text = prompt
|
||||
return prompt_text, {'max_new_tokens': max_new_tokens, 'temperature': temperature, 'top_k': top_k, 'top_p': top_p}, {}
|
||||
return prompt_text, {
|
||||
'max_new_tokens': max_new_tokens,
|
||||
'temperature': temperature,
|
||||
'top_k': top_k,
|
||||
'top_p': top_p
|
||||
}, {}
|
||||
|
||||
def postprocess_generate(self, prompt: str, generation_result: list[str], **_: t.Any) -> str:
|
||||
return generation_result[0]
|
||||
|
||||
@@ -54,9 +54,13 @@ class StarCoderConfig(openllm_core.LLMConfig):
|
||||
pad_token_id: int = 49152
|
||||
repetition_penalty: float = 1.2
|
||||
|
||||
def sanitize_parameters(
|
||||
self, prompt: str, temperature: float | None = None, top_p: float | None = None, max_new_tokens: int | None = None, repetition_penalty: float | None = None, **attrs: t.Any
|
||||
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
def sanitize_parameters(self,
|
||||
prompt: str,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
max_new_tokens: int | None = None,
|
||||
repetition_penalty: float | None = None,
|
||||
**attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
fim_mode, prefix, suffix = FIM_INDICATOR in prompt, None, None
|
||||
if fim_mode:
|
||||
try:
|
||||
@@ -67,7 +71,14 @@ class StarCoderConfig(openllm_core.LLMConfig):
|
||||
else:
|
||||
prompt_text = prompt
|
||||
# XXX: This value for pad_token_id is currently a hack, need more investigate why the default starcoder doesn't include the same value as santacoder EOD
|
||||
return prompt_text, {'temperature': temperature, 'top_p': top_p, 'max_new_tokens': max_new_tokens, 'repetition_penalty': repetition_penalty, 'pad_token_id': 49152, **attrs}, {}
|
||||
return prompt_text, {
|
||||
'temperature': temperature,
|
||||
'top_p': top_p,
|
||||
'max_new_tokens': max_new_tokens,
|
||||
'repetition_penalty': repetition_penalty,
|
||||
'pad_token_id': 49152,
|
||||
**attrs
|
||||
}, {}
|
||||
|
||||
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str:
|
||||
return generation_result[0]
|
||||
|
||||
@@ -47,9 +47,12 @@ logger = logging.getLogger(__name__)
|
||||
try:
|
||||
from typing import GenericAlias as _TypingGenericAlias # type: ignore
|
||||
except ImportError:
|
||||
_TypingGenericAlias = () # type: ignore # python < 3.9 does not have GenericAlias (list[int], tuple[str, ...] and so on)
|
||||
_TypingGenericAlias = (
|
||||
) # type: ignore # python < 3.9 does not have GenericAlias (list[int], tuple[str, ...] and so on)
|
||||
if sys.version_info < (3, 10): _WithArgsTypes = (_TypingGenericAlias,)
|
||||
else: _WithArgsTypes: t.Any = (t._GenericAlias, types.GenericAlias, types.UnionType) # type: ignore # _GenericAlias is the actual GenericAlias implementation
|
||||
else:
|
||||
_WithArgsTypes: t.Any = (t._GenericAlias, types.GenericAlias, types.UnionType
|
||||
) # type: ignore # _GenericAlias is the actual GenericAlias implementation
|
||||
|
||||
DEV_DEBUG_VAR = 'OPENLLMDEVDEBUG'
|
||||
|
||||
@@ -117,6 +120,7 @@ def get_quiet_mode() -> bool:
|
||||
return not DEBUG and _get_quiet_mode()
|
||||
|
||||
class ExceptionFilter(logging.Filter):
|
||||
|
||||
def __init__(self, exclude_exceptions: list[type[Exception]] | None = None, **kwargs: t.Any):
|
||||
'''A filter of all exception.'''
|
||||
if exclude_exceptions is None: exclude_exceptions = [ConflictError]
|
||||
@@ -133,6 +137,7 @@ class ExceptionFilter(logging.Filter):
|
||||
return True
|
||||
|
||||
class InfoFilter(logging.Filter):
|
||||
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
return logging.INFO <= record.levelno < logging.WARNING
|
||||
|
||||
@@ -145,24 +150,32 @@ _LOGGING_CONFIG: dict[str, t.Any] = {
|
||||
'filters': {
|
||||
'excfilter': {
|
||||
'()': 'openllm_core.utils.ExceptionFilter'
|
||||
}, 'infofilter': {
|
||||
},
|
||||
'infofilter': {
|
||||
'()': 'openllm_core.utils.InfoFilter'
|
||||
}
|
||||
},
|
||||
'handlers': {
|
||||
'bentomlhandler': {
|
||||
'class': 'logging.StreamHandler', 'filters': ['excfilter', 'infofilter'], 'stream': 'ext://sys.stdout'
|
||||
'class': 'logging.StreamHandler',
|
||||
'filters': ['excfilter', 'infofilter'],
|
||||
'stream': 'ext://sys.stdout'
|
||||
},
|
||||
'defaulthandler': {
|
||||
'class': 'logging.StreamHandler', 'level': logging.WARNING
|
||||
'class': 'logging.StreamHandler',
|
||||
'level': logging.WARNING
|
||||
}
|
||||
},
|
||||
'loggers': {
|
||||
'bentoml': {
|
||||
'handlers': ['bentomlhandler', 'defaulthandler'], 'level': logging.INFO, 'propagate': False
|
||||
'handlers': ['bentomlhandler', 'defaulthandler'],
|
||||
'level': logging.INFO,
|
||||
'propagate': False
|
||||
},
|
||||
'openllm': {
|
||||
'handlers': ['bentomlhandler', 'defaulthandler'], 'level': logging.INFO, 'propagate': False
|
||||
'handlers': ['bentomlhandler', 'defaulthandler'],
|
||||
'level': logging.INFO,
|
||||
'propagate': False
|
||||
}
|
||||
},
|
||||
'root': {
|
||||
@@ -227,6 +240,7 @@ def compose(*funcs: AnyCallable) -> AnyCallable:
|
||||
>>> [f(3*x, x+1) for x in range(1,10)]
|
||||
[1.5, 2.0, 2.25, 2.4, 2.5, 2.571, 2.625, 2.667, 2.7]
|
||||
'''
|
||||
|
||||
def compose_two(f1: AnyCallable, f2: AnyCallable) -> AnyCallable:
|
||||
return lambda *args, **kwargs: f1(f2(*args, **kwargs))
|
||||
|
||||
@@ -282,7 +296,12 @@ def generate_context(framework_name: str) -> _ModelContext:
|
||||
if openllm_core.utils.is_tf_available():
|
||||
from bentoml._internal.frameworks.utils.tensorflow import get_tf_version
|
||||
framework_versions['tensorflow'] = get_tf_version()
|
||||
if openllm_core.utils.is_flax_available(): framework_versions.update({'flax': pkg.get_pkg_version('flax'), 'jax': pkg.get_pkg_version('jax'), 'jaxlib': pkg.get_pkg_version('jaxlib')})
|
||||
if openllm_core.utils.is_flax_available():
|
||||
framework_versions.update({
|
||||
'flax': pkg.get_pkg_version('flax'),
|
||||
'jax': pkg.get_pkg_version('jax'),
|
||||
'jaxlib': pkg.get_pkg_version('jaxlib')
|
||||
})
|
||||
return _ModelContext(framework_name=framework_name, framework_versions=framework_versions)
|
||||
|
||||
_TOKENIZER_PREFIX = '_tokenizer_'
|
||||
@@ -301,7 +320,11 @@ _whitelist_modules = {'pkg'}
|
||||
|
||||
# XXX: define all classes, functions import above this line
|
||||
# since _extras will be the locals() import from this file.
|
||||
_extras: dict[str, t.Any] = {k: v for k, v in locals().items() if k in _whitelist_modules or (not isinstance(v, types.ModuleType) and not k.startswith('_'))}
|
||||
_extras: dict[str, t.Any] = {
|
||||
k: v
|
||||
for k, v in locals().items()
|
||||
if k in _whitelist_modules or (not isinstance(v, types.ModuleType) and not k.startswith('_'))
|
||||
}
|
||||
_extras['__openllm_migration__'] = {'ModelEnv': 'EnvVarMixin'}
|
||||
_import_structure: dict[str, list[str]] = {
|
||||
'analytics': [],
|
||||
@@ -310,32 +333,12 @@ _import_structure: dict[str, list[str]] = {
|
||||
'representation': ['ReprMixin'],
|
||||
'lazy': ['LazyModule'],
|
||||
'import_utils': [
|
||||
'OPTIONAL_DEPENDENCIES',
|
||||
'DummyMetaclass',
|
||||
'EnvVarMixin',
|
||||
'require_backends',
|
||||
'is_cpm_kernels_available',
|
||||
'is_einops_available',
|
||||
'is_flax_available',
|
||||
'is_tf_available',
|
||||
'is_vllm_available',
|
||||
'is_torch_available',
|
||||
'is_bitsandbytes_available',
|
||||
'is_peft_available',
|
||||
'is_datasets_available',
|
||||
'is_transformers_supports_kbit',
|
||||
'is_transformers_supports_agent',
|
||||
'is_jupyter_available',
|
||||
'is_jupytext_available',
|
||||
'is_notebook_available',
|
||||
'is_triton_available',
|
||||
'is_autogptq_available',
|
||||
'is_sentencepiece_available',
|
||||
'is_xformers_available',
|
||||
'is_fairscale_available',
|
||||
'is_grpc_available',
|
||||
'is_grpc_health_available',
|
||||
'is_transformers_available'
|
||||
'OPTIONAL_DEPENDENCIES', 'DummyMetaclass', 'EnvVarMixin', 'require_backends', 'is_cpm_kernels_available',
|
||||
'is_einops_available', 'is_flax_available', 'is_tf_available', 'is_vllm_available', 'is_torch_available',
|
||||
'is_bitsandbytes_available', 'is_peft_available', 'is_datasets_available', 'is_transformers_supports_kbit',
|
||||
'is_transformers_supports_agent', 'is_jupyter_available', 'is_jupytext_available', 'is_notebook_available',
|
||||
'is_triton_available', 'is_autogptq_available', 'is_sentencepiece_available', 'is_xformers_available',
|
||||
'is_fairscale_available', 'is_grpc_available', 'is_grpc_health_available', 'is_transformers_available'
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@@ -35,6 +35,7 @@ def _usage_event_debugging() -> bool:
|
||||
return os.environ.get('__BENTOML_DEBUG_USAGE', str(False)).lower() == 'true'
|
||||
|
||||
def silent(func: t.Callable[P, T]) -> t.Callable[P, T]:
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> t.Any:
|
||||
try:
|
||||
@@ -62,6 +63,7 @@ def set_bentoml_tracking() -> t.Generator[None, None, None]:
|
||||
os.environ[_internal_analytics.BENTOML_DO_NOT_TRACK] = original_value
|
||||
|
||||
class EventMeta:
|
||||
|
||||
@property
|
||||
def event_name(self) -> str:
|
||||
# camel case to snake case
|
||||
|
||||
@@ -103,19 +103,26 @@ def make_attr_tuple_class(cls_name: str, attr_names: t.Sequence[str]) -> type[t.
|
||||
def generate_unique_filename(cls: type[t.Any], func_name: str) -> str:
|
||||
return f"<{cls.__name__} generated {func_name} {cls.__module__}.{getattr(cls, '__qualname__', cls.__name__)}>"
|
||||
|
||||
def generate_function(
|
||||
typ: type[t.Any], func_name: str, lines: list[str] | None, args: tuple[str, ...] | None, globs: dict[str, t.Any], annotations: dict[str, t.Any] | None = None
|
||||
) -> AnyCallable:
|
||||
def generate_function(typ: type[t.Any],
|
||||
func_name: str,
|
||||
lines: list[str] | None,
|
||||
args: tuple[str, ...] | None,
|
||||
globs: dict[str, t.Any],
|
||||
annotations: dict[str, t.Any] | None = None) -> AnyCallable:
|
||||
from openllm_core.utils import SHOW_CODEGEN
|
||||
script = 'def %s(%s):\n %s\n' % (func_name, ', '.join(args) if args is not None else '', '\n '.join(lines) if lines else 'pass')
|
||||
script = 'def %s(%s):\n %s\n' % (func_name, ', '.join(args) if args is not None else '',
|
||||
'\n '.join(lines) if lines else 'pass')
|
||||
meth = _make_method(func_name, script, generate_unique_filename(typ, func_name), globs)
|
||||
if annotations: meth.__annotations__ = annotations
|
||||
if SHOW_CODEGEN: logger.info('Generated script for %s:\n\n%s', typ, script)
|
||||
return meth
|
||||
|
||||
def make_env_transformer(
|
||||
cls: type[openllm_core.LLMConfig], model_name: str, suffix: LiteralString | None = None, default_callback: t.Callable[[str, t.Any], t.Any] | None = None, globs: DictStrAny | None = None,
|
||||
) -> AnyCallable:
|
||||
def make_env_transformer(cls: type[openllm_core.LLMConfig],
|
||||
model_name: str,
|
||||
suffix: LiteralString | None = None,
|
||||
default_callback: t.Callable[[str, t.Any], t.Any] | None = None,
|
||||
globs: DictStrAny | None = None,
|
||||
) -> AnyCallable:
|
||||
from openllm_core.utils import dantic
|
||||
from openllm_core.utils import field_env_key
|
||||
|
||||
@@ -124,22 +131,31 @@ def make_env_transformer(
|
||||
|
||||
default_callback = identity if default_callback is None else default_callback
|
||||
globs = {} if globs is None else globs
|
||||
globs.update({'__populate_env': dantic.env_converter, '__default_callback': default_callback, '__field_env': field_env_key, '__suffix': suffix or '', '__model_name': model_name,})
|
||||
globs.update({
|
||||
'__populate_env': dantic.env_converter,
|
||||
'__default_callback': default_callback,
|
||||
'__field_env': field_env_key,
|
||||
'__suffix': suffix or '',
|
||||
'__model_name': model_name,
|
||||
})
|
||||
lines: ListStr = [
|
||||
'__env = lambda field_name: __field_env(__model_name, field_name, __suffix)',
|
||||
'return [',
|
||||
' f.evolve(',
|
||||
' default=__populate_env(__default_callback(f.name, f.default), __env(f.name)),',
|
||||
' metadata={',
|
||||
'__env = lambda field_name: __field_env(__model_name, field_name, __suffix)', 'return [', ' f.evolve(',
|
||||
' default=__populate_env(__default_callback(f.name, f.default), __env(f.name)),', ' metadata={',
|
||||
" 'env': f.metadata.get('env', __env(f.name)),",
|
||||
" 'description': f.metadata.get('description', '(not provided)'),",
|
||||
' },',
|
||||
' )',
|
||||
' for f in fields',
|
||||
']'
|
||||
" 'description': f.metadata.get('description', '(not provided)'),", ' },', ' )',
|
||||
' for f in fields', ']'
|
||||
]
|
||||
fields_ann = 'list[attr.Attribute[t.Any]]'
|
||||
return generate_function(cls, '__auto_env', lines, args=('_', 'fields'), globs=globs, annotations={'_': 'type[LLMConfig]', 'fields': fields_ann, 'return': fields_ann})
|
||||
return generate_function(cls,
|
||||
'__auto_env',
|
||||
lines,
|
||||
args=('_', 'fields'),
|
||||
globs=globs,
|
||||
annotations={
|
||||
'_': 'type[LLMConfig]',
|
||||
'fields': fields_ann,
|
||||
'return': fields_ann
|
||||
})
|
||||
|
||||
def gen_sdk(func: _T, name: str | None = None, **attrs: t.Any) -> _T:
|
||||
'''Enhance sdk with nice repr that plays well with your brain.'''
|
||||
@@ -167,9 +183,7 @@ def gen_sdk(func: _T, name: str | None = None, **attrs: t.Any) -> _T:
|
||||
'__doc__': inspect.cleandoc(doc),
|
||||
'__module__': 'openllm'
|
||||
}),
|
||||
)(func, **attrs),
|
||||
func,
|
||||
)
|
||||
)
|
||||
)(func, **attrs), func,
|
||||
))
|
||||
|
||||
__all__ = ['gen_sdk', 'make_attr_tuple_class', 'make_env_transformer', 'generate_unique_filename', 'generate_function']
|
||||
|
||||
@@ -25,29 +25,20 @@ AnyCallable = t.Callable[..., t.Any]
|
||||
FC = t.TypeVar('FC', bound=t.Union[AnyCallable, click.Command])
|
||||
|
||||
__all__ = [
|
||||
'FC',
|
||||
'attrs_to_options',
|
||||
'Field',
|
||||
'parse_type',
|
||||
'is_typing',
|
||||
'is_literal',
|
||||
'ModuleType',
|
||||
'EnumChoice',
|
||||
'LiteralChoice',
|
||||
'allows_multiple',
|
||||
'is_mapping',
|
||||
'is_container',
|
||||
'parse_container_args',
|
||||
'parse_single_arg',
|
||||
'CUDA',
|
||||
'JsonType',
|
||||
'BytesType'
|
||||
'FC', 'attrs_to_options', 'Field', 'parse_type', 'is_typing', 'is_literal', 'ModuleType', 'EnumChoice',
|
||||
'LiteralChoice', 'allows_multiple', 'is_mapping', 'is_container', 'parse_container_args', 'parse_single_arg',
|
||||
'CUDA', 'JsonType', 'BytesType'
|
||||
]
|
||||
|
||||
def __dir__() -> list[str]:
|
||||
return sorted(__all__)
|
||||
|
||||
def attrs_to_options(name: str, field: attr.Attribute[t.Any], model_name: str, typ: t.Any = None, suffix_generation: bool = False, suffix_sampling: bool = False) -> t.Callable[[FC], FC]:
|
||||
def attrs_to_options(name: str,
|
||||
field: attr.Attribute[t.Any],
|
||||
model_name: str,
|
||||
typ: t.Any = None,
|
||||
suffix_generation: bool = False,
|
||||
suffix_sampling: bool = False) -> t.Callable[[FC], FC]:
|
||||
# TODO: support parsing nested attrs class and Union
|
||||
envvar = field.metadata['env']
|
||||
dasherized = inflection.dasherize(name)
|
||||
@@ -63,18 +54,17 @@ def attrs_to_options(name: str, field: attr.Attribute[t.Any], model_name: str, t
|
||||
elif suffix_sampling: identifier = f'{model_name}_sampling_{underscored}'
|
||||
else: identifier = f'{model_name}_{underscored}'
|
||||
|
||||
return cog.optgroup.option(
|
||||
identifier,
|
||||
full_option_name,
|
||||
type=parse_type(typ),
|
||||
required=field.default is attr.NOTHING,
|
||||
default=field.default if field.default not in (attr.NOTHING, None) else None,
|
||||
show_default=True,
|
||||
multiple=allows_multiple(typ) if typ else False,
|
||||
help=field.metadata.get('description', '(No description provided)'),
|
||||
show_envvar=True,
|
||||
envvar=envvar,
|
||||
)
|
||||
return cog.optgroup.option(identifier,
|
||||
full_option_name,
|
||||
type=parse_type(typ),
|
||||
required=field.default is attr.NOTHING,
|
||||
default=field.default if field.default not in (attr.NOTHING, None) else None,
|
||||
show_default=True,
|
||||
multiple=allows_multiple(typ) if typ else False,
|
||||
help=field.metadata.get('description', '(No description provided)'),
|
||||
show_envvar=True,
|
||||
envvar=envvar,
|
||||
)
|
||||
|
||||
def env_converter(value: t.Any, env: str | None = None) -> t.Any:
|
||||
if env is not None:
|
||||
@@ -86,18 +76,16 @@ def env_converter(value: t.Any, env: str | None = None) -> t.Any:
|
||||
raise RuntimeError(f"Failed to parse ({value!r}) from '{env}': {err}") from None
|
||||
return value
|
||||
|
||||
def Field(
|
||||
default: t.Any = None,
|
||||
*,
|
||||
ge: int | float | None = None,
|
||||
le: int | float | None = None,
|
||||
validator: _ValidatorType[t.Any] | None = None,
|
||||
description: str | None = None,
|
||||
env: str | None = None,
|
||||
auto_default: bool = False,
|
||||
use_default_converter: bool = True,
|
||||
**attrs: t.Any
|
||||
) -> t.Any:
|
||||
def Field(default: t.Any = None,
|
||||
*,
|
||||
ge: int | float | None = None,
|
||||
le: int | float | None = None,
|
||||
validator: _ValidatorType[t.Any] | None = None,
|
||||
description: str | None = None,
|
||||
env: str | None = None,
|
||||
auto_default: bool = False,
|
||||
use_default_converter: bool = True,
|
||||
**attrs: t.Any) -> t.Any:
|
||||
"""A decorator that extends attr.field with additional arguments, which provides the same interface as pydantic's Field.
|
||||
|
||||
By default, if both validator and ge are provided, then then ge will be
|
||||
@@ -214,7 +202,8 @@ class ModuleType(ParamType):
|
||||
|
||||
def _import_object(self, value: str) -> t.Any:
|
||||
module_name, class_name = value.rsplit('.', maxsplit=1)
|
||||
if not all(s.isidentifier() for s in module_name.split('.')): raise ValueError(f"'{value}' is not a valid module name")
|
||||
if not all(s.isidentifier() for s in module_name.split('.')):
|
||||
raise ValueError(f"'{value}' is not a valid module name")
|
||||
if not class_name.isidentifier(): raise ValueError(f"Variable '{class_name}' is not a valid identifier")
|
||||
|
||||
module = importlib.import_module(module_name)
|
||||
@@ -262,7 +251,8 @@ class LiteralChoice(EnumChoice):
|
||||
# expect every literal value to belong to the same primitive type
|
||||
values = list(value.__args__)
|
||||
item_type = type(values[0])
|
||||
if not all(isinstance(v, item_type) for v in values): raise ValueError(f'Field {value} contains items of different types.')
|
||||
if not all(isinstance(v, item_type) for v in values):
|
||||
raise ValueError(f'Field {value} contains items of different types.')
|
||||
_mapping = {str(v): v for v in values}
|
||||
super(EnumChoice, self).__init__(list(_mapping), case_sensitive)
|
||||
self.internal_type = item_type
|
||||
|
||||
@@ -27,7 +27,9 @@ if t.TYPE_CHECKING:
|
||||
from openllm_core._typing_compat import LiteralRuntime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
OPTIONAL_DEPENDENCIES = {'opt', 'flan-t5', 'vllm', 'fine-tune', 'ggml', 'agents', 'openai', 'playground', 'gptq', 'grpc'}
|
||||
OPTIONAL_DEPENDENCIES = {
|
||||
'opt', 'flan-t5', 'vllm', 'fine-tune', 'ggml', 'agents', 'openai', 'playground', 'gptq', 'grpc'
|
||||
}
|
||||
ENV_VARS_TRUE_VALUES = {'1', 'ON', 'YES', 'TRUE'}
|
||||
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({'AUTO'})
|
||||
USE_TF = os.environ.get('USE_TF', 'AUTO').upper()
|
||||
@@ -142,19 +144,10 @@ def is_tf_available() -> bool:
|
||||
_tf_version = None
|
||||
if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
|
||||
if _tf_available:
|
||||
candidates = (
|
||||
'tensorflow',
|
||||
'tensorflow-cpu',
|
||||
'tensorflow-gpu',
|
||||
'tf-nightly',
|
||||
'tf-nightly-cpu',
|
||||
'tf-nightly-gpu',
|
||||
'intel-tensorflow',
|
||||
'intel-tensorflow-avx512',
|
||||
'tensorflow-rocm',
|
||||
'tensorflow-macos',
|
||||
'tensorflow-aarch64',
|
||||
)
|
||||
candidates = ('tensorflow', 'tensorflow-cpu', 'tensorflow-gpu', 'tf-nightly', 'tf-nightly-cpu',
|
||||
'tf-nightly-gpu', 'intel-tensorflow', 'intel-tensorflow-avx512', 'tensorflow-rocm',
|
||||
'tensorflow-macos', 'tensorflow-aarch64',
|
||||
)
|
||||
_tf_version = None
|
||||
# For the metadata, we have to look for both tensorflow and tensorflow-cpu
|
||||
for _pkg in candidates:
|
||||
@@ -292,15 +285,18 @@ You can install it with pip: `pip install fairscale`. Please note that you may n
|
||||
your runtime after installation.
|
||||
'''
|
||||
|
||||
BACKENDS_MAPPING: BackendOrderedDict = OrderedDict([('flax', (is_flax_available, FLAX_IMPORT_ERROR)), ('tf', (is_tf_available, TENSORFLOW_IMPORT_ERROR)), (
|
||||
'torch', (is_torch_available, PYTORCH_IMPORT_ERROR)
|
||||
), ('vllm', (is_vllm_available, VLLM_IMPORT_ERROR)), ('cpm_kernels', (is_cpm_kernels_available, CPM_KERNELS_IMPORT_ERROR)), ('einops', (is_einops_available, EINOPS_IMPORT_ERROR)), (
|
||||
'triton', (is_triton_available, TRITON_IMPORT_ERROR)
|
||||
), ('datasets', (is_datasets_available, DATASETS_IMPORT_ERROR)), ('peft', (is_peft_available, PEFT_IMPORT_ERROR)), ('bitsandbytes', (is_bitsandbytes_available, BITSANDBYTES_IMPORT_ERROR)), (
|
||||
'auto-gptq', (is_autogptq_available, AUTOGPTQ_IMPORT_ERROR)
|
||||
), ('sentencepiece', (is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR)), ('xformers', (is_xformers_available, XFORMERS_IMPORT_ERROR)), (
|
||||
'fairscale', (is_fairscale_available, FAIRSCALE_IMPORT_ERROR)
|
||||
)])
|
||||
BACKENDS_MAPPING: BackendOrderedDict = OrderedDict([
|
||||
('flax', (is_flax_available, FLAX_IMPORT_ERROR)), ('tf', (is_tf_available, TENSORFLOW_IMPORT_ERROR)),
|
||||
('torch', (is_torch_available, PYTORCH_IMPORT_ERROR)), ('vllm', (is_vllm_available, VLLM_IMPORT_ERROR)),
|
||||
('cpm_kernels', (is_cpm_kernels_available, CPM_KERNELS_IMPORT_ERROR)),
|
||||
('einops', (is_einops_available, EINOPS_IMPORT_ERROR)), ('triton', (is_triton_available, TRITON_IMPORT_ERROR)),
|
||||
('datasets', (is_datasets_available, DATASETS_IMPORT_ERROR)), ('peft', (is_peft_available, PEFT_IMPORT_ERROR)),
|
||||
('bitsandbytes', (is_bitsandbytes_available, BITSANDBYTES_IMPORT_ERROR)),
|
||||
('auto-gptq', (is_autogptq_available, AUTOGPTQ_IMPORT_ERROR)),
|
||||
('sentencepiece', (is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR)),
|
||||
('xformers', (is_xformers_available, XFORMERS_IMPORT_ERROR)),
|
||||
('fairscale', (is_fairscale_available, FAIRSCALE_IMPORT_ERROR))
|
||||
])
|
||||
|
||||
class DummyMetaclass(abc.ABCMeta):
|
||||
'''Metaclass for dummy object.
|
||||
@@ -317,15 +313,22 @@ def require_backends(o: t.Any, backends: t.MutableSequence[str]) -> None:
|
||||
if not isinstance(backends, (list, tuple)): backends = list(backends)
|
||||
name = o.__name__ if hasattr(o, '__name__') else o.__class__.__name__
|
||||
# Raise an error for users who might not realize that classes without "TF" are torch-only
|
||||
if 'torch' in backends and 'tf' not in backends and not is_torch_available() and is_tf_available(): raise ImportError(PYTORCH_IMPORT_ERROR_WITH_TF.format(name))
|
||||
if 'torch' in backends and 'tf' not in backends and not is_torch_available() and is_tf_available():
|
||||
raise ImportError(PYTORCH_IMPORT_ERROR_WITH_TF.format(name))
|
||||
# Raise the inverse error for PyTorch users trying to load TF classes
|
||||
if 'tf' in backends and 'torch' not in backends and is_torch_available() and not is_tf_available(): raise ImportError(TF_IMPORT_ERROR_WITH_PYTORCH.format(name))
|
||||
if 'tf' in backends and 'torch' not in backends and is_torch_available() and not is_tf_available():
|
||||
raise ImportError(TF_IMPORT_ERROR_WITH_PYTORCH.format(name))
|
||||
# Raise an error when vLLM is not available to consider the alternative, order from PyTorch -> Tensorflow -> Flax
|
||||
if 'vllm' in backends:
|
||||
if 'torch' not in backends and is_torch_available() and not is_vllm_available(): raise ImportError(VLLM_IMPORT_ERROR_WITH_PYTORCH.format(name))
|
||||
if 'tf' not in backends and is_tf_available() and not is_vllm_available(): raise ImportError(VLLM_IMPORT_ERROR_WITH_TF.format(name))
|
||||
if 'flax' not in backends and is_flax_available() and not is_vllm_available(): raise ImportError(VLLM_IMPORT_ERROR_WITH_FLAX.format(name))
|
||||
failed = [msg.format(name) for available, msg in (BACKENDS_MAPPING[backend] for backend in backends) if not available()]
|
||||
if 'torch' not in backends and is_torch_available() and not is_vllm_available():
|
||||
raise ImportError(VLLM_IMPORT_ERROR_WITH_PYTORCH.format(name))
|
||||
if 'tf' not in backends and is_tf_available() and not is_vllm_available():
|
||||
raise ImportError(VLLM_IMPORT_ERROR_WITH_TF.format(name))
|
||||
if 'flax' not in backends and is_flax_available() and not is_vllm_available():
|
||||
raise ImportError(VLLM_IMPORT_ERROR_WITH_FLAX.format(name))
|
||||
failed = [
|
||||
msg.format(name) for available, msg in (BACKENDS_MAPPING[backend] for backend in backends) if not available()
|
||||
]
|
||||
if failed: raise ImportError(''.join(failed))
|
||||
|
||||
class EnvVarMixin(ReprMixin):
|
||||
@@ -386,15 +389,13 @@ class EnvVarMixin(ReprMixin):
|
||||
elif hasattr(self, item): return getattr(self, item)
|
||||
raise KeyError(f'Key {item} not found in {self}')
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
implementation: LiteralRuntime = 'pt',
|
||||
model_id: str | None = None,
|
||||
bettertransformer: bool | None = None,
|
||||
quantize: LiteralString | None = None,
|
||||
runtime: t.Literal['ggml', 'transformers'] = 'transformers'
|
||||
) -> None:
|
||||
def __init__(self,
|
||||
model_name: str,
|
||||
implementation: LiteralRuntime = 'pt',
|
||||
model_id: str | None = None,
|
||||
bettertransformer: bool | None = None,
|
||||
quantize: LiteralString | None = None,
|
||||
runtime: t.Literal['ggml', 'transformers'] = 'transformers') -> None:
|
||||
'''EnvVarMixin is a mixin class that returns the value extracted from environment variables.'''
|
||||
from openllm_core.utils import field_env_key
|
||||
self.model_name = inflection.underscore(model_name)
|
||||
@@ -408,7 +409,8 @@ class EnvVarMixin(ReprMixin):
|
||||
|
||||
def _quantize_value(self) -> t.Literal['int8', 'int4', 'gptq'] | None:
|
||||
from . import first_not_none
|
||||
return t.cast(t.Optional[t.Literal['int8', 'int4', 'gptq']], first_not_none(os.environ.get(self['quantize']), default=self._quantize))
|
||||
return t.cast(t.Optional[t.Literal['int8', 'int4', 'gptq']],
|
||||
first_not_none(os.environ.get(self['quantize']), default=self._quantize))
|
||||
|
||||
def _framework_value(self) -> LiteralRuntime:
|
||||
from . import first_not_none
|
||||
@@ -416,7 +418,10 @@ class EnvVarMixin(ReprMixin):
|
||||
|
||||
def _bettertransformer_value(self) -> bool:
|
||||
from . import first_not_none
|
||||
return t.cast(bool, first_not_none(os.environ.get(self['bettertransformer'], str(False)).upper() in ENV_VARS_TRUE_VALUES, default=self._bettertransformer))
|
||||
return t.cast(
|
||||
bool,
|
||||
first_not_none(os.environ.get(self['bettertransformer'], str(False)).upper() in ENV_VARS_TRUE_VALUES,
|
||||
default=self._bettertransformer))
|
||||
|
||||
def _model_id_value(self) -> str | None:
|
||||
from . import first_not_none
|
||||
@@ -424,7 +429,8 @@ class EnvVarMixin(ReprMixin):
|
||||
|
||||
def _runtime_value(self) -> t.Literal['ggml', 'transformers']:
|
||||
from . import first_not_none
|
||||
return t.cast(t.Literal['ggml', 'transformers'], first_not_none(os.environ.get(self['runtime']), default=self._runtime))
|
||||
return t.cast(t.Literal['ggml', 'transformers'],
|
||||
first_not_none(os.environ.get(self['runtime']), default=self._runtime))
|
||||
|
||||
@property
|
||||
def __repr_keys__(self) -> set[str]:
|
||||
|
||||
@@ -60,15 +60,13 @@ _sentinel, _reserved_namespace = object(), {'__openllm_migration__'}
|
||||
|
||||
class LazyModule(types.ModuleType):
|
||||
# Very heavily inspired by optuna.integration._IntegrationModule: https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
module_file: str,
|
||||
import_structure: dict[str, list[str]],
|
||||
module_spec: importlib.machinery.ModuleSpec | None = None,
|
||||
doc: str | None = None,
|
||||
extra_objects: dict[str, t.Any] | None = None
|
||||
):
|
||||
def __init__(self,
|
||||
name: str,
|
||||
module_file: str,
|
||||
import_structure: dict[str, list[str]],
|
||||
module_spec: importlib.machinery.ModuleSpec | None = None,
|
||||
doc: str | None = None,
|
||||
extra_objects: dict[str, t.Any] | None = None):
|
||||
"""Lazily load this module as an object.
|
||||
|
||||
It does instantiate a __all__ and __dir__ for IDE support
|
||||
@@ -111,7 +109,9 @@ class LazyModule(types.ModuleType):
|
||||
|
||||
It also contains a special case for all of the metadata information, such as __version__ and __version_info__.
|
||||
'''
|
||||
if name in _reserved_namespace: raise openllm_core.exceptions.ForbiddenAttributeError(f"'{name}' is a reserved namespace for {self._name} and should not be access nor modified.")
|
||||
if name in _reserved_namespace:
|
||||
raise openllm_core.exceptions.ForbiddenAttributeError(
|
||||
f"'{name}' is a reserved namespace for {self._name} and should not be access nor modified.")
|
||||
dunder_to_metadata = {
|
||||
'__title__': 'Name',
|
||||
'__copyright__': '',
|
||||
@@ -128,27 +128,36 @@ class LazyModule(types.ModuleType):
|
||||
if name in dunder_to_metadata:
|
||||
if name not in {'__version_info__', '__copyright__', '__version__'}:
|
||||
warnings.warn(
|
||||
f"Accessing '{self._name}.{name}' is deprecated. Please consider using 'importlib.metadata' directly to query for openllm packaging metadata.", DeprecationWarning, stacklevel=2
|
||||
)
|
||||
f"Accessing '{self._name}.{name}' is deprecated. Please consider using 'importlib.metadata' directly to query for openllm packaging metadata.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2)
|
||||
meta = importlib.metadata.metadata('openllm')
|
||||
project_url = dict(url.split(', ') for url in t.cast(t.List[str], meta.get_all('Project-URL')))
|
||||
if name == '__license__': return 'Apache-2.0'
|
||||
elif name == '__copyright__': return f"Copyright (c) 2023-{time.strftime('%Y')}, Aaron Pham et al."
|
||||
elif name in ('__uri__', '__url__'): return project_url['GitHub']
|
||||
elif name == '__homepage__': return project_url['Homepage']
|
||||
elif name == '__version_info__': return VersionInfo.from_version_string(meta['version']) # similar to how attrs handle __version_info__
|
||||
elif name == '__author__': return meta['Author-email'].rsplit(' ', 1)[0]
|
||||
elif name == '__email__': return meta['Author-email'].rsplit('<', 1)[1][:-1]
|
||||
elif name == '__version_info__':
|
||||
return VersionInfo.from_version_string(meta['version']) # similar to how attrs handle __version_info__
|
||||
elif name == '__author__':
|
||||
return meta['Author-email'].rsplit(' ', 1)[0]
|
||||
elif name == '__email__':
|
||||
return meta['Author-email'].rsplit('<', 1)[1][:-1]
|
||||
return meta[dunder_to_metadata[name]]
|
||||
if '__openllm_migration__' in self._objects:
|
||||
cur_value = self._objects['__openllm_migration__'].get(name, _sentinel)
|
||||
if cur_value is not _sentinel:
|
||||
warnings.warn(f"'{name}' is deprecated and will be removed in future version. Make sure to use '{cur_value}' instead", DeprecationWarning, stacklevel=3)
|
||||
warnings.warn(
|
||||
f"'{name}' is deprecated and will be removed in future version. Make sure to use '{cur_value}' instead",
|
||||
DeprecationWarning,
|
||||
stacklevel=3)
|
||||
return getattr(self, cur_value)
|
||||
if name in self._objects: return self._objects.__getitem__(name)
|
||||
if name in self._modules: value = self._get_module(name)
|
||||
elif name in self._class_to_module.keys(): value = getattr(self._get_module(self._class_to_module.__getitem__(name)), name)
|
||||
else: raise AttributeError(f'module {self.__name__} has no attribute {name}')
|
||||
elif name in self._class_to_module.keys():
|
||||
value = getattr(self._get_module(self._class_to_module.__getitem__(name)), name)
|
||||
else:
|
||||
raise AttributeError(f'module {self.__name__} has no attribute {name}')
|
||||
setattr(self, name, value)
|
||||
return value
|
||||
|
||||
@@ -156,7 +165,9 @@ class LazyModule(types.ModuleType):
|
||||
try:
|
||||
return importlib.import_module('.' + module_name, self.__name__)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f'Failed to import {self.__name__}.{module_name} because of the following error (look up to see its traceback):\n{e}') from e
|
||||
raise RuntimeError(
|
||||
f'Failed to import {self.__name__}.{module_name} because of the following error (look up to see its traceback):\n{e}'
|
||||
) from e
|
||||
|
||||
# make sure this module is picklable
|
||||
def __reduce__(self) -> tuple[type[LazyModule], tuple[str, str | None, dict[str, list[str]]]]:
|
||||
|
||||
@@ -14,6 +14,7 @@ if t.TYPE_CHECKING:
|
||||
ReprArgs: TypeAlias = t.Generator[t.Tuple[t.Optional[str], t.Any], None, None]
|
||||
|
||||
class ReprMixin:
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def __repr_keys__(self) -> set[str]:
|
||||
|
||||
@@ -26,11 +26,14 @@ else:
|
||||
# configuration for bitsandbytes before import
|
||||
_os.environ["BITSANDBYTES_NOWELCOME"] = _os.environ.get("BITSANDBYTES_NOWELCOME", "1")
|
||||
# NOTE: The following warnings from bitsandbytes, and probably not that important for users to see when DEBUG is False
|
||||
_warnings.filterwarnings("ignore", message="MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization")
|
||||
_warnings.filterwarnings("ignore", message="MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization")
|
||||
_warnings.filterwarnings(
|
||||
"ignore", message="MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization")
|
||||
_warnings.filterwarnings(
|
||||
"ignore", message="MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization")
|
||||
_warnings.filterwarnings("ignore", message="The installed version of bitsandbytes was compiled without GPU support.")
|
||||
# NOTE: ignore the following warning from ghapi as it is not important for users
|
||||
_warnings.filterwarnings("ignore", message="Neither GITHUB_TOKEN nor GITHUB_JWT_TOKEN found: running as unauthenticated")
|
||||
_warnings.filterwarnings("ignore",
|
||||
message="Neither GITHUB_TOKEN nor GITHUB_JWT_TOKEN found: running as unauthenticated")
|
||||
|
||||
_import_structure: dict[str, list[str]] = {
|
||||
"exceptions": [],
|
||||
@@ -45,8 +48,13 @@ _import_structure: dict[str, list[str]] = {
|
||||
"_quantisation": ["infer_quantisation_config"],
|
||||
"_embeddings": ["GenericEmbeddingRunnable"],
|
||||
"_llm": ["LLM", "Runner", "LLMRunner", "LLMRunnable", "LLMEmbeddings"],
|
||||
"_generation": ["StopSequenceCriteria", "StopOnTokens", "LogitsProcessorList", "StoppingCriteriaList", "prepare_logits_processor"],
|
||||
"models.auto": ["MODEL_MAPPING_NAMES", "MODEL_FLAX_MAPPING_NAMES", "MODEL_TF_MAPPING_NAMES", "MODEL_VLLM_MAPPING_NAMES"],
|
||||
"_generation": [
|
||||
"StopSequenceCriteria", "StopOnTokens", "LogitsProcessorList", "StoppingCriteriaList",
|
||||
"prepare_logits_processor"
|
||||
],
|
||||
"models.auto": [
|
||||
"MODEL_MAPPING_NAMES", "MODEL_FLAX_MAPPING_NAMES", "MODEL_TF_MAPPING_NAMES", "MODEL_VLLM_MAPPING_NAMES"
|
||||
],
|
||||
"models.chatglm": [],
|
||||
"models.baichuan": [],
|
||||
"models.dolly_v2": [],
|
||||
@@ -73,7 +81,8 @@ if _t.TYPE_CHECKING:
|
||||
from .utils import infer_auto_class as infer_auto_class
|
||||
|
||||
try:
|
||||
if not (openllm_core.utils.is_torch_available() and openllm_core.utils.is_cpm_kernels_available()): raise exceptions.MissingDependencyError
|
||||
if not (openllm_core.utils.is_torch_available() and openllm_core.utils.is_cpm_kernels_available()):
|
||||
raise exceptions.MissingDependencyError
|
||||
except exceptions.MissingDependencyError:
|
||||
_import_structure["utils.dummy_pt_objects"] = ["ChatGLM", "Baichuan"]
|
||||
else:
|
||||
@@ -83,7 +92,8 @@ else:
|
||||
from .models.baichuan import Baichuan as Baichuan
|
||||
from .models.chatglm import ChatGLM as ChatGLM
|
||||
try:
|
||||
if not (openllm_core.utils.is_torch_available() and openllm_core.utils.is_triton_available()): raise exceptions.MissingDependencyError
|
||||
if not (openllm_core.utils.is_torch_available() and openllm_core.utils.is_triton_available()):
|
||||
raise exceptions.MissingDependencyError
|
||||
except exceptions.MissingDependencyError:
|
||||
if "utils.dummy_pt_objects" in _import_structure: _import_structure["utils.dummy_pt_objects"].extend(["MPT"])
|
||||
else: _import_structure["utils.dummy_pt_objects"] = ["MPT"]
|
||||
@@ -91,7 +101,8 @@ else:
|
||||
_import_structure["models.mpt"].extend(["MPT"])
|
||||
if _t.TYPE_CHECKING: from .models.mpt import MPT as MPT
|
||||
try:
|
||||
if not (openllm_core.utils.is_torch_available() and openllm_core.utils.is_einops_available()): raise exceptions.MissingDependencyError
|
||||
if not (openllm_core.utils.is_torch_available() and openllm_core.utils.is_einops_available()):
|
||||
raise exceptions.MissingDependencyError
|
||||
except exceptions.MissingDependencyError:
|
||||
if "utils.dummy_pt_objects" in _import_structure: _import_structure["utils.dummy_pt_objects"].extend(["Falcon"])
|
||||
else: _import_structure["utils.dummy_pt_objects"] = ["Falcon"]
|
||||
@@ -103,7 +114,8 @@ try:
|
||||
if not openllm_core.utils.is_torch_available(): raise exceptions.MissingDependencyError
|
||||
except exceptions.MissingDependencyError:
|
||||
_import_structure["utils.dummy_pt_objects"] = [
|
||||
name for name in dir(utils.dummy_pt_objects) if not name.startswith("_") and name not in ("ChatGLM", "Baichuan", "MPT", "Falcon", "annotations")
|
||||
name for name in dir(utils.dummy_pt_objects)
|
||||
if not name.startswith("_") and name not in ("ChatGLM", "Baichuan", "MPT", "Falcon", "annotations")
|
||||
]
|
||||
else:
|
||||
_import_structure["models.flan_t5"].extend(["FlanT5"])
|
||||
@@ -126,7 +138,9 @@ else:
|
||||
try:
|
||||
if not openllm_core.utils.is_vllm_available(): raise exceptions.MissingDependencyError
|
||||
except exceptions.MissingDependencyError:
|
||||
_import_structure["utils.dummy_vllm_objects"] = [name for name in dir(utils.dummy_vllm_objects) if not name.startswith("_") and name not in ("annotations",)]
|
||||
_import_structure["utils.dummy_vllm_objects"] = [
|
||||
name for name in dir(utils.dummy_vllm_objects) if not name.startswith("_") and name not in ("annotations",)
|
||||
]
|
||||
else:
|
||||
_import_structure["models.baichuan"].extend(["VLLMBaichuan"])
|
||||
_import_structure["models.llama"].extend(["VLLMLlama"])
|
||||
@@ -152,7 +166,9 @@ else:
|
||||
try:
|
||||
if not openllm_core.utils.is_flax_available(): raise exceptions.MissingDependencyError
|
||||
except exceptions.MissingDependencyError:
|
||||
_import_structure["utils.dummy_flax_objects"] = [name for name in dir(utils.dummy_flax_objects) if not name.startswith("_") and name not in ("annotations",)]
|
||||
_import_structure["utils.dummy_flax_objects"] = [
|
||||
name for name in dir(utils.dummy_flax_objects) if not name.startswith("_") and name not in ("annotations",)
|
||||
]
|
||||
else:
|
||||
_import_structure["models.flan_t5"].extend(["FlaxFlanT5"])
|
||||
_import_structure["models.opt"].extend(["FlaxOPT"])
|
||||
@@ -164,7 +180,9 @@ else:
|
||||
try:
|
||||
if not openllm_core.utils.is_tf_available(): raise exceptions.MissingDependencyError
|
||||
except exceptions.MissingDependencyError:
|
||||
_import_structure["utils.dummy_tf_objects"] = [name for name in dir(utils.dummy_tf_objects) if not name.startswith("_") and name not in ("annotations",)]
|
||||
_import_structure["utils.dummy_tf_objects"] = [
|
||||
name for name in dir(utils.dummy_tf_objects) if not name.startswith("_") and name not in ("annotations",)
|
||||
]
|
||||
else:
|
||||
_import_structure["models.flan_t5"].extend(["TFFlanT5"])
|
||||
_import_structure["models.opt"].extend(["TFOPT"])
|
||||
@@ -175,7 +193,10 @@ else:
|
||||
from .models.opt import TFOPT as TFOPT
|
||||
|
||||
# NOTE: update this to sys.modules[__name__] once mypy_extensions can recognize __spec__
|
||||
__lazy = openllm_core.utils.LazyModule(__name__, globals()["__file__"], _import_structure, extra_objects={"COMPILED": COMPILED})
|
||||
__lazy = openllm_core.utils.LazyModule(__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
extra_objects={"COMPILED": COMPILED})
|
||||
__all__ = __lazy.__all__
|
||||
__dir__ = __lazy.__dir__
|
||||
__getattr__ = __lazy.__getattr__
|
||||
|
||||
@@ -26,22 +26,24 @@ def get_or_download(ids: str = _BENTOMODEL_ID) -> bentoml.Model:
|
||||
except bentoml.exceptions.NotFound:
|
||||
model_signatures = {
|
||||
k: ModelSignature(batchable=False)
|
||||
for k in ('forward', 'generate', 'contrastive_search', 'greedy_search', 'sample', 'beam_search', 'beam_sample', 'group_beam_search', 'constrained_beam_search', '__call__')
|
||||
for k in ('forward', 'generate', 'contrastive_search', 'greedy_search', 'sample', 'beam_search', 'beam_sample',
|
||||
'group_beam_search', 'constrained_beam_search', '__call__')
|
||||
}
|
||||
with bentoml.models.create(
|
||||
ids,
|
||||
module=MODULE_NAME,
|
||||
api_version=API_VERSION,
|
||||
options=ModelOptions(),
|
||||
context=openllm.utils.generate_context(framework_name='transformers'),
|
||||
labels={
|
||||
'runtime': 'pt', 'framework': 'openllm'
|
||||
},
|
||||
signatures=model_signatures
|
||||
) as bentomodel:
|
||||
with bentoml.models.create(ids,
|
||||
module=MODULE_NAME,
|
||||
api_version=API_VERSION,
|
||||
options=ModelOptions(),
|
||||
context=openllm.utils.generate_context(framework_name='transformers'),
|
||||
labels={
|
||||
'runtime': 'pt',
|
||||
'framework': 'openllm'
|
||||
},
|
||||
signatures=model_signatures) as bentomodel:
|
||||
snapshot_download(
|
||||
_GENERIC_EMBEDDING_ID, local_dir=bentomodel.path, local_dir_use_symlinks=False, ignore_patterns=['*.safetensors', '*.h5', '*.ot', '*.pdf', '*.md', '.gitattributes', 'LICENSE.txt']
|
||||
)
|
||||
_GENERIC_EMBEDDING_ID,
|
||||
local_dir=bentomodel.path,
|
||||
local_dir_use_symlinks=False,
|
||||
ignore_patterns=['*.safetensors', '*.h5', '*.ot', '*.pdf', '*.md', '.gitattributes', 'LICENSE.txt'])
|
||||
return bentomodel
|
||||
|
||||
class GenericEmbeddingRunnable(bentoml.Runnable):
|
||||
@@ -66,7 +68,10 @@ class GenericEmbeddingRunnable(bentoml.Runnable):
|
||||
model_output = self.model(**encoded_input)
|
||||
# Perform pooling and normalize
|
||||
sentence_embeddings = F.normalize(self.mean_pooling(model_output, attention_mask), p=2, dim=1)
|
||||
return [openllm.LLMEmbeddings(embeddings=sentence_embeddings.cpu().numpy(), num_tokens=int(torch.sum(attention_mask).item()))]
|
||||
return [
|
||||
openllm.LLMEmbeddings(embeddings=sentence_embeddings.cpu().numpy(),
|
||||
num_tokens=int(torch.sum(attention_mask).item()))
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def mean_pooling(model_output: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
@@ -14,23 +14,30 @@ LogitsProcessorList = transformers.LogitsProcessorList
|
||||
StoppingCriteriaList = transformers.StoppingCriteriaList
|
||||
|
||||
class StopSequenceCriteria(transformers.StoppingCriteria):
|
||||
def __init__(self, stop_sequences: str | list[str], tokenizer: transformers.PreTrainedTokenizer | transformers.PreTrainedTokenizerBase | transformers.PreTrainedTokenizerFast):
|
||||
|
||||
def __init__(self, stop_sequences: str | list[str], tokenizer: transformers.PreTrainedTokenizer |
|
||||
transformers.PreTrainedTokenizerBase | transformers.PreTrainedTokenizerFast):
|
||||
if isinstance(stop_sequences, str): stop_sequences = [stop_sequences]
|
||||
self.stop_sequences, self.tokenizer = stop_sequences, tokenizer
|
||||
|
||||
def __call__(self, input_ids: torch.Tensor, scores: t.Any, **_: t.Any) -> bool:
|
||||
return any(self.tokenizer.decode(input_ids.tolist()[0]).endswith(stop_sequence) for stop_sequence in self.stop_sequences)
|
||||
return any(
|
||||
self.tokenizer.decode(input_ids.tolist()[0]).endswith(stop_sequence) for stop_sequence in self.stop_sequences)
|
||||
|
||||
class StopOnTokens(transformers.StoppingCriteria):
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **_: t.Any) -> bool:
|
||||
return input_ids[0][-1] in {50278, 50279, 50277, 1, 0}
|
||||
|
||||
def prepare_logits_processor(config: openllm.LLMConfig) -> transformers.LogitsProcessorList:
|
||||
generation_config = config.generation_config
|
||||
logits_processor = transformers.LogitsProcessorList()
|
||||
if generation_config['temperature'] >= 1e-5 and generation_config['temperature'] != 1.0: logits_processor.append(transformers.TemperatureLogitsWarper(generation_config['temperature']))
|
||||
if generation_config['repetition_penalty'] > 1.0: logits_processor.append(transformers.RepetitionPenaltyLogitsProcessor(generation_config['repetition_penalty']))
|
||||
if 1e-8 <= generation_config['top_p']: logits_processor.append(transformers.TopPLogitsWarper(generation_config['top_p']))
|
||||
if generation_config['temperature'] >= 1e-5 and generation_config['temperature'] != 1.0:
|
||||
logits_processor.append(transformers.TemperatureLogitsWarper(generation_config['temperature']))
|
||||
if generation_config['repetition_penalty'] > 1.0:
|
||||
logits_processor.append(transformers.RepetitionPenaltyLogitsProcessor(generation_config['repetition_penalty']))
|
||||
if 1e-8 <= generation_config['top_p']:
|
||||
logits_processor.append(transformers.TopPLogitsWarper(generation_config['top_p']))
|
||||
if generation_config['top_k'] > 0: logits_processor.append(transformers.TopKLogitsWarper(generation_config['top_k']))
|
||||
return logits_processor
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -15,21 +15,27 @@ if t.TYPE_CHECKING:
|
||||
|
||||
from ._llm import LLM
|
||||
|
||||
autogptq, torch, transformers = LazyLoader('autogptq', globals(), 'auto_gptq'), LazyLoader('torch', globals(), 'torch'), LazyLoader('transformers', globals(), 'transformers')
|
||||
autogptq, torch, transformers = LazyLoader('autogptq', globals(),
|
||||
'auto_gptq'), LazyLoader('torch', globals(), 'torch'), LazyLoader(
|
||||
'transformers', globals(), 'transformers')
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
QuantiseMode = t.Literal['int8', 'int4', 'gptq']
|
||||
|
||||
@overload
|
||||
def infer_quantisation_config(cls: type[LLM[t.Any, t.Any]], quantise: t.Literal['int8', 'int4'], **attrs: t.Any) -> tuple[transformers.BitsAndBytesConfig, DictStrAny]:
|
||||
def infer_quantisation_config(cls: type[LLM[t.Any, t.Any]], quantise: t.Literal['int8', 'int4'],
|
||||
**attrs: t.Any) -> tuple[transformers.BitsAndBytesConfig, DictStrAny]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def infer_quantisation_config(cls: type[LLM[t.Any, t.Any]], quantise: t.Literal['gptq'], **attrs: t.Any) -> tuple[autogptq.BaseQuantizeConfig, DictStrAny]:
|
||||
def infer_quantisation_config(cls: type[LLM[t.Any, t.Any]], quantise: t.Literal['gptq'],
|
||||
**attrs: t.Any) -> tuple[autogptq.BaseQuantizeConfig, DictStrAny]:
|
||||
...
|
||||
|
||||
def infer_quantisation_config(cls: type[LLM[t.Any, t.Any]], quantise: QuantiseMode, **attrs: t.Any) -> tuple[transformers.BitsAndBytesConfig | autogptq.BaseQuantizeConfig, DictStrAny]:
|
||||
def infer_quantisation_config(
|
||||
cls: type[LLM[t.Any, t.Any]], quantise: QuantiseMode,
|
||||
**attrs: t.Any) -> tuple[transformers.BitsAndBytesConfig | autogptq.BaseQuantizeConfig, DictStrAny]:
|
||||
# 8 bit configuration
|
||||
int8_threshold = attrs.pop('llm_int8_threshhold', 6.0)
|
||||
int8_enable_fp32_cpu_offload = attrs.pop('llm_int8_enable_fp32_cpu_offload', False)
|
||||
@@ -50,13 +56,12 @@ def infer_quantisation_config(cls: type[LLM[t.Any, t.Any]], quantise: QuantiseMo
|
||||
if 'lm_head' not in int8_skip_modules and cls.config_class.__openllm_model_type__ == 'causal_lm':
|
||||
logger.debug("Skipping 'lm_head' for quantization for %s", cls.__name__)
|
||||
int8_skip_modules.append('lm_head')
|
||||
return transformers.BitsAndBytesConfig(
|
||||
load_in_8bit=True,
|
||||
llm_int8_enable_fp32_cpu_offload=int8_enable_fp32_cpu_offload,
|
||||
llm_int8_threshhold=int8_threshold,
|
||||
llm_int8_skip_modules=int8_skip_modules,
|
||||
llm_int8_has_fp16_weight=int8_has_fp16_weight,
|
||||
)
|
||||
return transformers.BitsAndBytesConfig(load_in_8bit=True,
|
||||
llm_int8_enable_fp32_cpu_offload=int8_enable_fp32_cpu_offload,
|
||||
llm_int8_threshhold=int8_threshold,
|
||||
llm_int8_skip_modules=int8_skip_modules,
|
||||
llm_int8_has_fp16_weight=int8_has_fp16_weight,
|
||||
)
|
||||
|
||||
# 4 bit configuration
|
||||
int4_compute_dtype = attrs.pop('bnb_4bit_compute_dtype', torch.bfloat16)
|
||||
@@ -66,18 +71,21 @@ def infer_quantisation_config(cls: type[LLM[t.Any, t.Any]], quantise: QuantiseMo
|
||||
# NOTE: Quantization setup
|
||||
# quantize is a openllm.LLM feature, where we can quantize the model
|
||||
# with bitsandbytes or quantization aware training.
|
||||
if not is_bitsandbytes_available(): raise RuntimeError("Quantization requires bitsandbytes to be installed. Make sure to install OpenLLM with 'pip install \"openllm[fine-tune]\"'")
|
||||
if not is_bitsandbytes_available():
|
||||
raise RuntimeError(
|
||||
"Quantization requires bitsandbytes to be installed. Make sure to install OpenLLM with 'pip install \"openllm[fine-tune]\"'"
|
||||
)
|
||||
if quantise == 'int8': quantisation_config = create_int8_config(int8_skip_modules)
|
||||
elif quantise == 'int4':
|
||||
if is_transformers_supports_kbit():
|
||||
quantisation_config = transformers.BitsAndBytesConfig(
|
||||
load_in_4bit=True, bnb_4bit_compute_dtype=int4_compute_dtype, bnb_4bit_quant_type=int4_quant_type, bnb_4bit_use_double_quant=int4_use_double_quant
|
||||
)
|
||||
quantisation_config = transformers.BitsAndBytesConfig(load_in_4bit=True,
|
||||
bnb_4bit_compute_dtype=int4_compute_dtype,
|
||||
bnb_4bit_quant_type=int4_quant_type,
|
||||
bnb_4bit_use_double_quant=int4_use_double_quant)
|
||||
else:
|
||||
logger.warning(
|
||||
"'quantize' is set to int4, while the current transformers version %s does not support k-bit quantization. k-bit quantization is supported since transformers 4.30, therefore make sure to install the latest version of transformers either via PyPI or from git source: 'pip install git+https://github.com/huggingface/transformers'. Fallback to int8 quantisation.",
|
||||
pkg.pkg_version_info('transformers')
|
||||
)
|
||||
pkg.pkg_version_info('transformers'))
|
||||
quantisation_config = create_int8_config(int8_skip_modules)
|
||||
elif quantise == 'gptq':
|
||||
if not is_autogptq_available():
|
||||
|
||||
@@ -21,11 +21,14 @@ if t.TYPE_CHECKING:
|
||||
from bentoml._internal.runner.runner import AbstractRunner
|
||||
from bentoml._internal.runner.runner import RunnerMethod
|
||||
from openllm_core._typing_compat import TypeAlias
|
||||
_EmbeddingMethod: TypeAlias = RunnerMethod[t.Union[bentoml.Runnable, openllm.LLMRunnable[t.Any, t.Any]], [t.List[str]], t.Sequence[openllm.EmbeddingsOutput]]
|
||||
_EmbeddingMethod: TypeAlias = RunnerMethod[t.Union[bentoml.Runnable, openllm.LLMRunnable[t.Any, t.Any]],
|
||||
[t.List[str]], t.Sequence[openllm.EmbeddingsOutput]]
|
||||
|
||||
# The following warnings from bitsandbytes, and probably not that important for users to see
|
||||
warnings.filterwarnings('ignore', message='MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization')
|
||||
warnings.filterwarnings('ignore', message='MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization')
|
||||
warnings.filterwarnings('ignore',
|
||||
message='MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization')
|
||||
warnings.filterwarnings('ignore',
|
||||
message='MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization')
|
||||
warnings.filterwarnings('ignore', message='The installed version of bitsandbytes was compiled without GPU support.')
|
||||
|
||||
model = os.environ.get('OPENLLM_MODEL', '{__model_name__}') # openllm: model name
|
||||
@@ -37,15 +40,23 @@ generic_embedding_runner = bentoml.Runner(
|
||||
name='llm-generic-embedding',
|
||||
scheduling_strategy=openllm_core.CascadingResourceStrategy,
|
||||
max_batch_size=32,
|
||||
max_latency_ms=300
|
||||
)
|
||||
max_latency_ms=300)
|
||||
runners: list[AbstractRunner] = [runner]
|
||||
if not runner.supports_embeddings: runners.append(generic_embedding_runner)
|
||||
svc = bentoml.Service(name=f"llm-{llm_config['start_name']}-service", runners=runners)
|
||||
|
||||
_JsonInput = bentoml.io.JSON.from_sample({'prompt': '', 'llm_config': llm_config.model_dump(flatten=True), 'adapter_name': None})
|
||||
_JsonInput = bentoml.io.JSON.from_sample({
|
||||
'prompt': '',
|
||||
'llm_config': llm_config.model_dump(flatten=True),
|
||||
'adapter_name': None
|
||||
})
|
||||
|
||||
@svc.api(route='/v1/generate', input=_JsonInput, output=bentoml.io.JSON.from_sample({'responses': [], 'configuration': llm_config.model_dump(flatten=True)}))
|
||||
@svc.api(route='/v1/generate',
|
||||
input=_JsonInput,
|
||||
output=bentoml.io.JSON.from_sample({
|
||||
'responses': [],
|
||||
'configuration': llm_config.model_dump(flatten=True)
|
||||
}))
|
||||
async def generate_v1(input_dict: dict[str, t.Any]) -> openllm.GenerationOutput:
|
||||
qa_inputs = openllm.GenerationInput.from_llm_config(llm_config)(**input_dict)
|
||||
config = qa_inputs.llm_config.model_dump()
|
||||
@@ -56,67 +67,45 @@ async def generate_v1(input_dict: dict[str, t.Any]) -> openllm.GenerationOutput:
|
||||
async def generate_stream_v1(input_dict: dict[str, t.Any]) -> t.AsyncGenerator[str, None]:
|
||||
echo = input_dict.pop('echo', False)
|
||||
qa_inputs = openllm.GenerationInput.from_llm_config(llm_config)(**input_dict)
|
||||
return runner.generate_iterator.async_stream(qa_inputs.prompt, adapter_name=qa_inputs.adapter_name, echo=echo, **qa_inputs.llm_config.model_dump())
|
||||
return runner.generate_iterator.async_stream(qa_inputs.prompt,
|
||||
adapter_name=qa_inputs.adapter_name,
|
||||
echo=echo,
|
||||
**qa_inputs.llm_config.model_dump())
|
||||
|
||||
@svc.api(
|
||||
route='/v1/metadata',
|
||||
input=bentoml.io.Text(),
|
||||
output=bentoml.io.JSON.from_sample({
|
||||
'model_id': runner.llm.model_id,
|
||||
'timeout': 3600,
|
||||
'model_name': llm_config['model_name'],
|
||||
'framework': runner.llm_framework,
|
||||
'configuration': '',
|
||||
'supports_embeddings': runner.supports_embeddings,
|
||||
'supports_hf_agent': runner.supports_hf_agent
|
||||
})
|
||||
)
|
||||
@svc.api(route='/v1/metadata',
|
||||
input=bentoml.io.Text(),
|
||||
output=bentoml.io.JSON.from_sample({
|
||||
'model_id': runner.llm.model_id,
|
||||
'timeout': 3600,
|
||||
'model_name': llm_config['model_name'],
|
||||
'framework': runner.llm_framework,
|
||||
'configuration': '',
|
||||
'supports_embeddings': runner.supports_embeddings,
|
||||
'supports_hf_agent': runner.supports_hf_agent
|
||||
}))
|
||||
def metadata_v1(_: str) -> openllm.MetadataOutput:
|
||||
return openllm.MetadataOutput(
|
||||
timeout=llm_config['timeout'],
|
||||
model_name=llm_config['model_name'],
|
||||
framework=llm_config['env']['framework_value'],
|
||||
model_id=runner.llm.model_id,
|
||||
configuration=llm_config.model_dump_json().decode(),
|
||||
supports_embeddings=runner.supports_embeddings,
|
||||
supports_hf_agent=runner.supports_hf_agent
|
||||
)
|
||||
return openllm.MetadataOutput(timeout=llm_config['timeout'],
|
||||
model_name=llm_config['model_name'],
|
||||
framework=llm_config['env']['framework_value'],
|
||||
model_id=runner.llm.model_id,
|
||||
configuration=llm_config.model_dump_json().decode(),
|
||||
supports_embeddings=runner.supports_embeddings,
|
||||
supports_hf_agent=runner.supports_hf_agent)
|
||||
|
||||
@svc.api(
|
||||
route='/v1/embeddings',
|
||||
input=bentoml.io.JSON.from_sample(['Hey Jude, welcome to the jungle!', 'What is the meaning of life?']),
|
||||
output=bentoml.io.JSON.from_sample({
|
||||
'embeddings': [
|
||||
0.007917795330286026,
|
||||
-0.014421648345887661,
|
||||
0.00481307040899992,
|
||||
0.007331526838243008,
|
||||
-0.0066398633643984795,
|
||||
0.00945580005645752,
|
||||
0.0087016262114048,
|
||||
-0.010709521360695362,
|
||||
0.012635177001357079,
|
||||
0.010541186667978764,
|
||||
-0.00730888033285737,
|
||||
-0.001783102168701589,
|
||||
0.02339819073677063,
|
||||
-0.010825827717781067,
|
||||
-0.015888236463069916,
|
||||
0.01876218430697918,
|
||||
0.0076906150206923485,
|
||||
0.0009032754460349679,
|
||||
-0.010024012066423893,
|
||||
0.01090280432254076,
|
||||
-0.008668390102684498,
|
||||
0.02070549875497818,
|
||||
0.0014594447566196322,
|
||||
-0.018775740638375282,
|
||||
-0.014814382418990135,
|
||||
0.01796768605709076
|
||||
],
|
||||
'num_tokens': 20
|
||||
})
|
||||
)
|
||||
@svc.api(route='/v1/embeddings',
|
||||
input=bentoml.io.JSON.from_sample(['Hey Jude, welcome to the jungle!', 'What is the meaning of life?']),
|
||||
output=bentoml.io.JSON.from_sample({
|
||||
'embeddings': [
|
||||
0.007917795330286026, -0.014421648345887661, 0.00481307040899992, 0.007331526838243008,
|
||||
-0.0066398633643984795, 0.00945580005645752, 0.0087016262114048, -0.010709521360695362,
|
||||
0.012635177001357079, 0.010541186667978764, -0.00730888033285737, -0.001783102168701589,
|
||||
0.02339819073677063, -0.010825827717781067, -0.015888236463069916, 0.01876218430697918,
|
||||
0.0076906150206923485, 0.0009032754460349679, -0.010024012066423893, 0.01090280432254076,
|
||||
-0.008668390102684498, 0.02070549875497818, 0.0014594447566196322, -0.018775740638375282,
|
||||
-0.014814382418990135, 0.01796768605709076
|
||||
],
|
||||
'num_tokens': 20
|
||||
}))
|
||||
async def embeddings_v1(phrases: list[str]) -> openllm.EmbeddingsOutput:
|
||||
embed_call: _EmbeddingMethod = runner.embeddings if runner.supports_embeddings else generic_embedding_runner.encode # type: ignore[type-arg,assignment,valid-type]
|
||||
responses = (await embed_call.async_run(phrases))[0]
|
||||
@@ -132,7 +121,8 @@ if runner.supports_hf_agent and openllm.utils.is_transformers_supports_agent():
|
||||
raise openllm.exceptions.OpenLLMException(f'Invalid JSON input received: {err}') from None
|
||||
stop = input_data.parameters.pop('stop', ['\n'])
|
||||
try:
|
||||
return JSONResponse(await runner.generate_one.async_run(input_data.inputs, stop, **input_data.parameters), status_code=200)
|
||||
return JSONResponse(await runner.generate_one.async_run(input_data.inputs, stop, **input_data.parameters),
|
||||
status_code=200)
|
||||
except NotImplementedError:
|
||||
return JSONResponse(f"'{model}' is currently not supported with HuggingFace agents.", status_code=500)
|
||||
|
||||
@@ -142,7 +132,8 @@ if runner.supports_hf_agent and openllm.utils.is_transformers_supports_agent():
|
||||
# general metadata app
|
||||
async def list_adapter_v1(_: Request) -> Response:
|
||||
res: dict[str, t.Any] = {}
|
||||
if runner.peft_adapters['success'] is True: res['result'] = {k: v.to_dict() for k, v in runner.peft_adapters['result'].items()}
|
||||
if runner.peft_adapters['success'] is True:
|
||||
res['result'] = {k: v.to_dict() for k, v in runner.peft_adapters['result'].items()}
|
||||
res.update({'success': runner.peft_adapters['success'], 'error_msg': runner.peft_adapters['error_msg']})
|
||||
return JSONResponse(res, status_code=200)
|
||||
|
||||
|
||||
@@ -10,7 +10,10 @@ from openllm_core.utils import LazyModule
|
||||
|
||||
_import_structure: dict[str, list[str]] = {
|
||||
'_package': ['create_bento', 'build_editable', 'construct_python_options', 'construct_docker_options'],
|
||||
'oci': ['CONTAINER_NAMES', 'get_base_container_tag', 'build_container', 'get_base_container_name', 'supported_registries', 'RefResolver']
|
||||
'oci': [
|
||||
'CONTAINER_NAMES', 'get_base_container_tag', 'build_container', 'get_base_container_name',
|
||||
'supported_registries', 'RefResolver'
|
||||
]
|
||||
}
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
|
||||
@@ -43,14 +43,18 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
OPENLLM_DEV_BUILD = 'OPENLLM_DEV_BUILD'
|
||||
|
||||
def build_editable(path: str, package: t.Literal['openllm', 'openllm_core', 'openllm_client'] = 'openllm') -> str | None:
|
||||
def build_editable(path: str,
|
||||
package: t.Literal['openllm', 'openllm_core', 'openllm_client'] = 'openllm') -> str | None:
|
||||
'''Build OpenLLM if the OPENLLM_DEV_BUILD environment variable is set.'''
|
||||
if str(os.environ.get(OPENLLM_DEV_BUILD, False)).lower() != 'true': return None
|
||||
# We need to build the package in editable mode, so that we can import it
|
||||
from build import ProjectBuilder
|
||||
from build.env import IsolatedEnvBuilder
|
||||
module_location = openllm_core.utils.pkg.source_locations(package)
|
||||
if not module_location: raise RuntimeError('Could not find the source location of OpenLLM. Make sure to unset OPENLLM_DEV_BUILD if you are developing OpenLLM.')
|
||||
if not module_location:
|
||||
raise RuntimeError(
|
||||
'Could not find the source location of OpenLLM. Make sure to unset OPENLLM_DEV_BUILD if you are developing OpenLLM.'
|
||||
)
|
||||
pyproject_path = Path(module_location).parent.parent / 'pyproject.toml'
|
||||
if os.path.isfile(pyproject_path.__fspath__()):
|
||||
logger.info('Generating built wheels for package %s...', package)
|
||||
@@ -60,9 +64,14 @@ def build_editable(path: str, package: t.Literal['openllm', 'openllm_core', 'ope
|
||||
builder.scripts_dir = env.scripts_dir
|
||||
env.install(builder.build_system_requires)
|
||||
return builder.build('wheel', path, config_settings={'--global-option': '--quiet'})
|
||||
raise RuntimeError('Custom OpenLLM build is currently not supported. Please install OpenLLM from PyPI or built it from Git source.')
|
||||
raise RuntimeError(
|
||||
'Custom OpenLLM build is currently not supported. Please install OpenLLM from PyPI or built it from Git source.')
|
||||
|
||||
def construct_python_options(llm: openllm.LLM[t.Any, t.Any], llm_fs: FS, extra_dependencies: tuple[str, ...] | None = None, adapter_map: dict[str, str | None] | None = None,) -> PythonOptions:
|
||||
def construct_python_options(llm: openllm.LLM[t.Any, t.Any],
|
||||
llm_fs: FS,
|
||||
extra_dependencies: tuple[str, ...] | None = None,
|
||||
adapter_map: dict[str, str | None] | None = None,
|
||||
) -> PythonOptions:
|
||||
packages = ['openllm', 'scipy'] # apparently bnb misses this one
|
||||
if adapter_map is not None: packages += ['openllm[fine-tune]']
|
||||
# NOTE: add openllm to the default dependencies
|
||||
@@ -73,27 +82,24 @@ def construct_python_options(llm: openllm.LLM[t.Any, t.Any], llm_fs: FS, extra_d
|
||||
|
||||
req = llm.config['requirements']
|
||||
if req is not None: packages.extend(req)
|
||||
if str(os.environ.get('BENTOML_BUNDLE_LOCAL_BUILD', False)).lower() == 'false': packages.append(f"bentoml>={'.'.join([str(i) for i in openllm_core.utils.pkg.pkg_version_info('bentoml')])}")
|
||||
if str(os.environ.get('BENTOML_BUNDLE_LOCAL_BUILD', False)).lower() == 'false':
|
||||
packages.append(f"bentoml>={'.'.join([str(i) for i in openllm_core.utils.pkg.pkg_version_info('bentoml')])}")
|
||||
|
||||
env = llm.config['env']
|
||||
framework_envvar = env['framework_value']
|
||||
if framework_envvar == 'flax':
|
||||
if not openllm_core.utils.is_flax_available(): raise ValueError(f"Flax is not available, while {env.framework} is set to 'flax'")
|
||||
packages.extend([importlib.metadata.version('flax'), importlib.metadata.version('jax'), importlib.metadata.version('jaxlib')])
|
||||
if not openllm_core.utils.is_flax_available():
|
||||
raise ValueError(f"Flax is not available, while {env.framework} is set to 'flax'")
|
||||
packages.extend(
|
||||
[importlib.metadata.version('flax'),
|
||||
importlib.metadata.version('jax'),
|
||||
importlib.metadata.version('jaxlib')])
|
||||
elif framework_envvar == 'tf':
|
||||
if not openllm_core.utils.is_tf_available(): raise ValueError(f"TensorFlow is not available, while {env.framework} is set to 'tf'")
|
||||
candidates = (
|
||||
'tensorflow',
|
||||
'tensorflow-cpu',
|
||||
'tensorflow-gpu',
|
||||
'tf-nightly',
|
||||
'tf-nightly-cpu',
|
||||
'tf-nightly-gpu',
|
||||
'intel-tensorflow',
|
||||
'intel-tensorflow-avx512',
|
||||
'tensorflow-rocm',
|
||||
'tensorflow-macos',
|
||||
)
|
||||
if not openllm_core.utils.is_tf_available():
|
||||
raise ValueError(f"TensorFlow is not available, while {env.framework} is set to 'tf'")
|
||||
candidates = ('tensorflow', 'tensorflow-cpu', 'tensorflow-gpu', 'tf-nightly', 'tf-nightly-cpu', 'tf-nightly-gpu',
|
||||
'intel-tensorflow', 'intel-tensorflow-avx512', 'tensorflow-rocm', 'tensorflow-macos',
|
||||
)
|
||||
# For the metadata, we have to look for both tensorflow and tensorflow-cpu
|
||||
for candidate in candidates:
|
||||
try:
|
||||
@@ -106,28 +112,28 @@ def construct_python_options(llm: openllm.LLM[t.Any, t.Any], llm_fs: FS, extra_d
|
||||
except importlib.metadata.PackageNotFoundError:
|
||||
pass # Ok to ignore here since we actually need to check for all possible tensorflow distribution.
|
||||
else:
|
||||
if not openllm_core.utils.is_torch_available(): raise ValueError('PyTorch is not available. Make sure to have it locally installed.')
|
||||
if not openllm_core.utils.is_torch_available():
|
||||
raise ValueError('PyTorch is not available. Make sure to have it locally installed.')
|
||||
packages.extend([f'torch>={importlib.metadata.version("torch")}'])
|
||||
wheels: list[str] = []
|
||||
built_wheels: list[str | None] = [
|
||||
build_editable(llm_fs.getsyspath('/'), t.cast(t.Literal['openllm', 'openllm_core', 'openllm_client'], p)) for p in ('openllm_core', 'openllm_client', 'openllm')
|
||||
build_editable(llm_fs.getsyspath('/'), t.cast(t.Literal['openllm', 'openllm_core', 'openllm_client'], p))
|
||||
for p in ('openllm_core', 'openllm_client', 'openllm')
|
||||
]
|
||||
if all(i for i in built_wheels): wheels.extend([llm_fs.getsyspath(f"/{i.split('/')[-1]}") for i in t.cast(t.List[str], built_wheels)])
|
||||
return PythonOptions(packages=packages, wheels=wheels, lock_packages=False, extra_index_url=['https://download.pytorch.org/whl/cu118'])
|
||||
if all(i for i in built_wheels):
|
||||
wheels.extend([llm_fs.getsyspath(f"/{i.split('/')[-1]}") for i in t.cast(t.List[str], built_wheels)])
|
||||
return PythonOptions(packages=packages,
|
||||
wheels=wheels,
|
||||
lock_packages=False,
|
||||
extra_index_url=['https://download.pytorch.org/whl/cu118'])
|
||||
|
||||
def construct_docker_options(
|
||||
llm: openllm.LLM[t.Any, t.Any],
|
||||
_: FS,
|
||||
workers_per_resource: float,
|
||||
quantize: LiteralString | None,
|
||||
bettertransformer: bool | None,
|
||||
adapter_map: dict[str, str | None] | None,
|
||||
dockerfile_template: str | None,
|
||||
runtime: t.Literal['ggml', 'transformers'],
|
||||
serialisation_format: t.Literal['safetensors', 'legacy'],
|
||||
container_registry: LiteralContainerRegistry,
|
||||
container_version_strategy: LiteralContainerVersionStrategy
|
||||
) -> DockerOptions:
|
||||
def construct_docker_options(llm: openllm.LLM[t.Any, t.Any], _: FS, workers_per_resource: float,
|
||||
quantize: LiteralString | None, bettertransformer: bool | None,
|
||||
adapter_map: dict[str, str | None] | None, dockerfile_template: str | None,
|
||||
runtime: t.Literal['ggml', 'transformers'], serialisation_format: t.Literal['safetensors',
|
||||
'legacy'],
|
||||
container_registry: LiteralContainerRegistry,
|
||||
container_version_strategy: LiteralContainerVersionStrategy) -> DockerOptions:
|
||||
from openllm.cli._factory import parse_config_options
|
||||
environ = parse_config_options(llm.config, llm.config['timeout'], workers_per_resource, None, True, os.environ.copy())
|
||||
env: openllm_core.utils.EnvVarMixin = llm.config['env']
|
||||
@@ -146,12 +152,18 @@ def construct_docker_options(
|
||||
if adapter_map: env_dict['BITSANDBYTES_NOWELCOME'] = os.environ.get('BITSANDBYTES_NOWELCOME', '1')
|
||||
|
||||
# We need to handle None separately here, as env from subprocess doesn't accept None value.
|
||||
_env = openllm_core.utils.EnvVarMixin(llm.config['model_name'], bettertransformer=bettertransformer, quantize=quantize, runtime=runtime)
|
||||
_env = openllm_core.utils.EnvVarMixin(llm.config['model_name'],
|
||||
bettertransformer=bettertransformer,
|
||||
quantize=quantize,
|
||||
runtime=runtime)
|
||||
|
||||
env_dict[_env.bettertransformer] = str(_env['bettertransformer_value'])
|
||||
if _env['quantize_value'] is not None: env_dict[_env.quantize] = t.cast(str, _env['quantize_value'])
|
||||
env_dict[_env.runtime] = _env['runtime_value']
|
||||
return DockerOptions(base_image=f'{oci.CONTAINER_NAMES[container_registry]}:{oci.get_base_container_tag(container_version_strategy)}', env=env_dict, dockerfile_template=dockerfile_template)
|
||||
return DockerOptions(
|
||||
base_image=f'{oci.CONTAINER_NAMES[container_registry]}:{oci.get_base_container_tag(container_version_strategy)}',
|
||||
env=env_dict,
|
||||
dockerfile_template=dockerfile_template)
|
||||
|
||||
OPENLLM_MODEL_NAME = '# openllm: model name'
|
||||
OPENLLM_MODEL_ADAPTER_MAP = '# openllm: model adapter map'
|
||||
@@ -185,47 +197,58 @@ _service_file = Path(os.path.abspath(__file__)).parent.parent / '_service.py'
|
||||
def write_service(llm: openllm.LLM[t.Any, t.Any], adapter_map: dict[str, str | None] | None, llm_fs: FS) -> None:
|
||||
from openllm_core.utils import DEBUG
|
||||
model_name = llm.config['model_name']
|
||||
logger.debug('Generating service file for %s at %s (dir=%s)', model_name, llm.config['service_name'], llm_fs.getsyspath('/'))
|
||||
logger.debug('Generating service file for %s at %s (dir=%s)', model_name, llm.config['service_name'],
|
||||
llm_fs.getsyspath('/'))
|
||||
with open(_service_file.__fspath__(), 'r') as f:
|
||||
src_contents = f.readlines()
|
||||
for it in src_contents:
|
||||
if OPENLLM_MODEL_NAME in it: src_contents[src_contents.index(it)] = (ModelNameFormatter(model_name).vformat(it)[:-(len(OPENLLM_MODEL_NAME) + 3)] + '\n')
|
||||
if OPENLLM_MODEL_NAME in it:
|
||||
src_contents[src_contents.index(it)] = (
|
||||
ModelNameFormatter(model_name).vformat(it)[:-(len(OPENLLM_MODEL_NAME) + 3)] + '\n')
|
||||
elif OPENLLM_MODEL_ADAPTER_MAP in it:
|
||||
src_contents[src_contents.index(it)] = (ModelAdapterMapFormatter(orjson.dumps(adapter_map).decode()).vformat(it)[:-(len(OPENLLM_MODEL_ADAPTER_MAP) + 3)] + '\n')
|
||||
src_contents[src_contents.index(it)] = (ModelAdapterMapFormatter(
|
||||
orjson.dumps(adapter_map).decode()).vformat(it)[:-(len(OPENLLM_MODEL_ADAPTER_MAP) + 3)] + '\n')
|
||||
script = f"# GENERATED BY 'openllm build {model_name}'. DO NOT EDIT\n\n" + ''.join(src_contents)
|
||||
if DEBUG: logger.info('Generated script:\n%s', script)
|
||||
llm_fs.writetext(llm.config['service_name'], script)
|
||||
|
||||
@inject
|
||||
def create_bento(
|
||||
bento_tag: bentoml.Tag,
|
||||
llm_fs: FS,
|
||||
llm: openllm.LLM[t.Any, t.Any],
|
||||
workers_per_resource: str | float,
|
||||
quantize: LiteralString | None,
|
||||
bettertransformer: bool | None,
|
||||
dockerfile_template: str | None,
|
||||
adapter_map: dict[str, str | None] | None = None,
|
||||
extra_dependencies: tuple[str, ...] | None = None,
|
||||
runtime: t.Literal['ggml', 'transformers'] = 'transformers',
|
||||
serialisation_format: t.Literal['safetensors', 'legacy'] = 'safetensors',
|
||||
container_registry: LiteralContainerRegistry = 'ecr',
|
||||
container_version_strategy: LiteralContainerVersionStrategy = 'release',
|
||||
_bento_store: BentoStore = Provide[BentoMLContainer.bento_store],
|
||||
_model_store: ModelStore = Provide[BentoMLContainer.model_store]
|
||||
) -> bentoml.Bento:
|
||||
def create_bento(bento_tag: bentoml.Tag,
|
||||
llm_fs: FS,
|
||||
llm: openllm.LLM[t.Any, t.Any],
|
||||
workers_per_resource: str | float,
|
||||
quantize: LiteralString | None,
|
||||
bettertransformer: bool | None,
|
||||
dockerfile_template: str | None,
|
||||
adapter_map: dict[str, str | None] | None = None,
|
||||
extra_dependencies: tuple[str, ...] | None = None,
|
||||
runtime: t.Literal['ggml', 'transformers'] = 'transformers',
|
||||
serialisation_format: t.Literal['safetensors', 'legacy'] = 'safetensors',
|
||||
container_registry: LiteralContainerRegistry = 'ecr',
|
||||
container_version_strategy: LiteralContainerVersionStrategy = 'release',
|
||||
_bento_store: BentoStore = Provide[BentoMLContainer.bento_store],
|
||||
_model_store: ModelStore = Provide[BentoMLContainer.model_store]) -> bentoml.Bento:
|
||||
framework_envvar = llm.config['env']['framework_value']
|
||||
labels = dict(llm.identifying_params)
|
||||
labels.update({'_type': llm.llm_type, '_framework': framework_envvar, 'start_name': llm.config['start_name'], 'base_name_or_path': llm.model_id, 'bundler': 'openllm.bundle'})
|
||||
labels.update({
|
||||
'_type': llm.llm_type,
|
||||
'_framework': framework_envvar,
|
||||
'start_name': llm.config['start_name'],
|
||||
'base_name_or_path': llm.model_id,
|
||||
'bundler': 'openllm.bundle'
|
||||
})
|
||||
if adapter_map: labels.update(adapter_map)
|
||||
if isinstance(workers_per_resource, str):
|
||||
if workers_per_resource == 'round_robin': workers_per_resource = 1.0
|
||||
elif workers_per_resource == 'conserved': workers_per_resource = 1.0 if openllm_core.utils.device_count() == 0 else float(1 / openllm_core.utils.device_count())
|
||||
elif workers_per_resource == 'conserved':
|
||||
workers_per_resource = 1.0 if openllm_core.utils.device_count() == 0 else float(1 /
|
||||
openllm_core.utils.device_count())
|
||||
else:
|
||||
try:
|
||||
workers_per_resource = float(workers_per_resource)
|
||||
except ValueError:
|
||||
raise ValueError("'workers_per_resource' only accept ['round_robin', 'conserved'] as possible strategies.") from None
|
||||
raise ValueError(
|
||||
"'workers_per_resource' only accept ['round_robin', 'conserved'] as possible strategies.") from None
|
||||
elif isinstance(workers_per_resource, int):
|
||||
workers_per_resource = float(workers_per_resource)
|
||||
logger.info("Building Bento for '%s'", llm.config['start_name'])
|
||||
@@ -233,19 +256,18 @@ def create_bento(
|
||||
write_service(llm, adapter_map, llm_fs)
|
||||
|
||||
llm_spec = ModelSpec.from_item({'tag': str(llm.tag), 'alias': llm.tag.name})
|
||||
build_config = BentoBuildConfig(
|
||||
service=f"{llm.config['service_name']}:svc",
|
||||
name=bento_tag.name,
|
||||
labels=labels,
|
||||
description=f"OpenLLM service for {llm.config['start_name']}",
|
||||
include=list(llm_fs.walk.files()),
|
||||
exclude=['/venv', '/.venv', '__pycache__/', '*.py[cod]', '*$py.class'],
|
||||
python=construct_python_options(llm, llm_fs, extra_dependencies, adapter_map),
|
||||
models=[llm_spec],
|
||||
docker=construct_docker_options(
|
||||
llm, llm_fs, workers_per_resource, quantize, bettertransformer, adapter_map, dockerfile_template, runtime, serialisation_format, container_registry, container_version_strategy
|
||||
)
|
||||
)
|
||||
build_config = BentoBuildConfig(service=f"{llm.config['service_name']}:svc",
|
||||
name=bento_tag.name,
|
||||
labels=labels,
|
||||
description=f"OpenLLM service for {llm.config['start_name']}",
|
||||
include=list(llm_fs.walk.files()),
|
||||
exclude=['/venv', '/.venv', '__pycache__/', '*.py[cod]', '*$py.class'],
|
||||
python=construct_python_options(llm, llm_fs, extra_dependencies, adapter_map),
|
||||
models=[llm_spec],
|
||||
docker=construct_docker_options(llm, llm_fs, workers_per_resource, quantize,
|
||||
bettertransformer, adapter_map, dockerfile_template,
|
||||
runtime, serialisation_format, container_registry,
|
||||
container_version_strategy))
|
||||
|
||||
bento = bentoml.Bento.create(build_config=build_config, version=bento_tag.version, build_ctx=llm_fs.getsyspath('/'))
|
||||
# NOTE: the model_id_path here are only used for setting this environment variable within the container built with for BentoLLM.
|
||||
@@ -261,6 +283,7 @@ def create_bento(
|
||||
if openllm_core.utils.DEBUG: logger.info('Generated script:\n%s', script)
|
||||
|
||||
bento._fs.writetext(service_fs_path, script)
|
||||
if 'model_store' in inspect.signature(bento.save).parameters: return bento.save(bento_store=_bento_store, model_store=_model_store)
|
||||
if 'model_store' in inspect.signature(bento.save).parameters:
|
||||
return bento.save(bento_store=_bento_store, model_store=_model_store)
|
||||
# backward arguments. `model_store` is added recently
|
||||
return bento.save(bento_store=_bento_store)
|
||||
|
||||
@@ -42,7 +42,11 @@ ROOT_DIR = pathlib.Path(os.path.abspath('__file__')).parent.parent.parent
|
||||
# but in the future, we can infer based on git repo and everything to make it more options for users
|
||||
# to build the base image. For now, all of the base image will be <registry>/bentoml/openllm:...
|
||||
# NOTE: The ECR registry is the public one and currently only @bentoml team has access to push it.
|
||||
_CONTAINER_REGISTRY: dict[LiteralContainerRegistry, str] = {'docker': 'docker.io/bentoml/openllm', 'gh': 'ghcr.io/bentoml/openllm', 'ecr': 'public.ecr.aws/y5w8i4y6/bentoml/openllm'}
|
||||
_CONTAINER_REGISTRY: dict[LiteralContainerRegistry, str] = {
|
||||
'docker': 'docker.io/bentoml/openllm',
|
||||
'gh': 'ghcr.io/bentoml/openllm',
|
||||
'ecr': 'public.ecr.aws/y5w8i4y6/bentoml/openllm'
|
||||
}
|
||||
|
||||
# TODO: support custom fork. Currently it only support openllm main.
|
||||
_OWNER = 'bentoml'
|
||||
@@ -64,7 +68,8 @@ def _commit_time_range(r: int = 5) -> str:
|
||||
class VersionNotSupported(openllm.exceptions.OpenLLMException):
|
||||
"""Raised when the stable release is too low that it doesn't include OpenLLM base container."""
|
||||
|
||||
_RefTuple: type[RefTuple] = openllm_core.utils.codegen.make_attr_tuple_class('_RefTuple', ['git_hash', 'version', 'strategy'])
|
||||
_RefTuple: type[RefTuple] = openllm_core.utils.codegen.make_attr_tuple_class('_RefTuple',
|
||||
['git_hash', 'version', 'strategy'])
|
||||
|
||||
def nightly_resolver(cls: type[RefResolver]) -> str:
|
||||
# NOTE: all openllm container will have sha-<git_hash[:7]>
|
||||
@@ -78,7 +83,11 @@ def nightly_resolver(cls: type[RefResolver]) -> str:
|
||||
commits = t.cast('list[dict[str, t.Any]]', cls._ghapi.repos.list_commits(since=_commit_time_range()))
|
||||
return next(f'sha-{it["sha"][:7]}' for it in commits if '[skip ci]' not in it['commit']['message'])
|
||||
# now is the correct behaviour
|
||||
return orjson.loads(subprocess.check_output([docker_bin, 'run', '--rm', '-it', 'quay.io/skopeo/stable:latest', 'list-tags', 'docker://ghcr.io/bentoml/openllm']).decode().strip())['Tags'][-2]
|
||||
return orjson.loads(
|
||||
subprocess.check_output([
|
||||
docker_bin, 'run', '--rm', '-it', 'quay.io/skopeo/stable:latest', 'list-tags',
|
||||
'docker://ghcr.io/bentoml/openllm'
|
||||
]).decode().strip())['Tags'][-2]
|
||||
|
||||
@attr.attrs(eq=False, order=False, slots=True, frozen=True)
|
||||
class RefResolver:
|
||||
@@ -98,16 +107,20 @@ class RefResolver:
|
||||
# NOTE: This strategy will only support openllm>0.2.12
|
||||
meta: dict[str, t.Any] = cls._ghapi.repos.get_latest_release()
|
||||
version_str = meta['name'].lstrip('v')
|
||||
version: tuple[str, str | None] = (cls._ghapi.git.get_ref(ref=f"tags/{meta['name']}")['object']['sha'], version_str)
|
||||
version: tuple[str,
|
||||
str | None] = (cls._ghapi.git.get_ref(ref=f"tags/{meta['name']}")['object']['sha'], version_str)
|
||||
else:
|
||||
version = ('', version_str)
|
||||
if openllm_core.utils.VersionInfo.from_version_string(t.cast(str, version_str)) < (0, 2, 12):
|
||||
raise VersionNotSupported(f"Version {version_str} doesn't support OpenLLM base container. Consider using 'nightly' or upgrade 'openllm>=0.2.12'")
|
||||
raise VersionNotSupported(
|
||||
f"Version {version_str} doesn't support OpenLLM base container. Consider using 'nightly' or upgrade 'openllm>=0.2.12'"
|
||||
)
|
||||
return _RefTuple((*version, 'release' if _use_base_strategy else 'custom'))
|
||||
|
||||
@classmethod
|
||||
@functools.lru_cache(maxsize=64)
|
||||
def from_strategy(cls, strategy_or_version: t.Literal['release', 'nightly'] | LiteralString | None = None) -> RefResolver:
|
||||
def from_strategy(cls,
|
||||
strategy_or_version: t.Literal['release', 'nightly'] | LiteralString | None = None) -> RefResolver:
|
||||
# using default strategy
|
||||
if strategy_or_version is None or strategy_or_version == 'release': return cls(*cls._release_ref())
|
||||
elif strategy_or_version == 'latest': return cls('latest', '0.0.0', 'latest')
|
||||
@@ -115,7 +128,8 @@ class RefResolver:
|
||||
_ref = cls._nightly_ref()
|
||||
return cls(_ref[0], '0.0.0', _ref[-1])
|
||||
else:
|
||||
logger.warning('Using custom %s. Make sure that it is at lease 0.2.12 for base container support.', strategy_or_version)
|
||||
logger.warning('Using custom %s. Make sure that it is at lease 0.2.12 for base container support.',
|
||||
strategy_or_version)
|
||||
return cls(*cls._release_ref(version_str=strategy_or_version))
|
||||
|
||||
@property
|
||||
@@ -129,21 +143,27 @@ class RefResolver:
|
||||
def get_base_container_tag(strategy: LiteralContainerVersionStrategy | None = None) -> str:
|
||||
return RefResolver.from_strategy(strategy).tag
|
||||
|
||||
def build_container(
|
||||
registries: LiteralContainerRegistry | t.Sequence[LiteralContainerRegistry] | None = None,
|
||||
version_strategy: LiteralContainerVersionStrategy = 'release',
|
||||
push: bool = False,
|
||||
machine: bool = False
|
||||
) -> dict[str | LiteralContainerRegistry, str]:
|
||||
def build_container(registries: LiteralContainerRegistry | t.Sequence[LiteralContainerRegistry] | None = None,
|
||||
version_strategy: LiteralContainerVersionStrategy = 'release',
|
||||
push: bool = False,
|
||||
machine: bool = False) -> dict[str | LiteralContainerRegistry, str]:
|
||||
try:
|
||||
if not _BUILDER.health(): raise openllm.exceptions.Error
|
||||
except (openllm.exceptions.Error, subprocess.CalledProcessError):
|
||||
raise RuntimeError('Building base container requires BuildKit (via Buildx) to be installed. See https://docs.docker.com/build/buildx/install/ for instalation instruction.') from None
|
||||
if openllm_core.utils.device_count() == 0: raise RuntimeError('Building base container requires GPUs (None available)')
|
||||
if not shutil.which('nvidia-container-runtime'): raise RuntimeError('NVIDIA Container Toolkit is required to compile CUDA kernel in container.')
|
||||
if not _module_location: raise RuntimeError("Failed to determine source location of 'openllm'. (Possible broken installation)")
|
||||
raise RuntimeError(
|
||||
'Building base container requires BuildKit (via Buildx) to be installed. See https://docs.docker.com/build/buildx/install/ for instalation instruction.'
|
||||
) from None
|
||||
if openllm_core.utils.device_count() == 0:
|
||||
raise RuntimeError('Building base container requires GPUs (None available)')
|
||||
if not shutil.which('nvidia-container-runtime'):
|
||||
raise RuntimeError('NVIDIA Container Toolkit is required to compile CUDA kernel in container.')
|
||||
if not _module_location:
|
||||
raise RuntimeError("Failed to determine source location of 'openllm'. (Possible broken installation)")
|
||||
pyproject_path = pathlib.Path(_module_location).parent.parent / 'pyproject.toml'
|
||||
if not pyproject_path.exists(): raise ValueError("This utility can only be run within OpenLLM git repository. Clone it first with 'git clone https://github.com/bentoml/OpenLLM.git'")
|
||||
if not pyproject_path.exists():
|
||||
raise ValueError(
|
||||
"This utility can only be run within OpenLLM git repository. Clone it first with 'git clone https://github.com/bentoml/OpenLLM.git'"
|
||||
)
|
||||
if not registries:
|
||||
tags: dict[str | LiteralContainerRegistry, str] = {
|
||||
alias: f'{value}:{get_base_container_tag(version_strategy)}' for alias, value in _CONTAINER_REGISTRY.items()
|
||||
@@ -152,24 +172,27 @@ def build_container(
|
||||
registries = [registries] if isinstance(registries, str) else list(registries)
|
||||
tags = {name: f'{_CONTAINER_REGISTRY[name]}:{get_base_container_tag(version_strategy)}' for name in registries}
|
||||
try:
|
||||
outputs = _BUILDER.build(
|
||||
file=pathlib.Path(__file__).parent.joinpath('Dockerfile').resolve().__fspath__(),
|
||||
context_path=pyproject_path.parent.__fspath__(),
|
||||
tag=tuple(tags.values()),
|
||||
push=push,
|
||||
progress='plain' if openllm_core.utils.get_debug_mode() else 'auto',
|
||||
quiet=machine
|
||||
)
|
||||
outputs = _BUILDER.build(file=pathlib.Path(__file__).parent.joinpath('Dockerfile').resolve().__fspath__(),
|
||||
context_path=pyproject_path.parent.__fspath__(),
|
||||
tag=tuple(tags.values()),
|
||||
push=push,
|
||||
progress='plain' if openllm_core.utils.get_debug_mode() else 'auto',
|
||||
quiet=machine)
|
||||
if machine and outputs is not None: tags['image_sha'] = outputs.decode('utf-8').strip()
|
||||
except Exception as err:
|
||||
raise openllm.exceptions.OpenLLMException(f'Failed to containerize base container images (Scroll up to see error above, or set OPENLLMDEVDEBUG=True for more traceback):\n{err}') from err
|
||||
raise openllm.exceptions.OpenLLMException(
|
||||
f'Failed to containerize base container images (Scroll up to see error above, or set OPENLLMDEVDEBUG=True for more traceback):\n{err}'
|
||||
) from err
|
||||
return tags
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
CONTAINER_NAMES: dict[LiteralContainerRegistry, str]
|
||||
supported_registries: list[str]
|
||||
|
||||
__all__ = ['CONTAINER_NAMES', 'get_base_container_tag', 'build_container', 'get_base_container_name', 'supported_registries', 'RefResolver']
|
||||
__all__ = [
|
||||
'CONTAINER_NAMES', 'get_base_container_tag', 'build_container', 'get_base_container_name', 'supported_registries',
|
||||
'RefResolver'
|
||||
]
|
||||
|
||||
def __dir__() -> list[str]:
|
||||
return sorted(__all__)
|
||||
|
||||
@@ -40,27 +40,46 @@ _AnyCallable = t.Callable[..., t.Any]
|
||||
FC = t.TypeVar('FC', bound=t.Union[_AnyCallable, click.Command])
|
||||
|
||||
def bento_complete_envvar(ctx: click.Context, param: click.Parameter, incomplete: str) -> list[sc.CompletionItem]:
|
||||
return [sc.CompletionItem(str(it.tag), help='Bento') for it in bentoml.list() if str(it.tag).startswith(incomplete) and all(k in it.info.labels for k in {'start_name', 'bundler'})]
|
||||
return [
|
||||
sc.CompletionItem(str(it.tag), help='Bento')
|
||||
for it in bentoml.list()
|
||||
if str(it.tag).startswith(incomplete) and all(k in it.info.labels for k in {'start_name', 'bundler'})
|
||||
]
|
||||
|
||||
def model_complete_envvar(ctx: click.Context, param: click.Parameter, incomplete: str) -> list[sc.CompletionItem]:
|
||||
return [sc.CompletionItem(inflection.dasherize(it), help='Model') for it in openllm.CONFIG_MAPPING if it.startswith(incomplete)]
|
||||
return [
|
||||
sc.CompletionItem(inflection.dasherize(it), help='Model')
|
||||
for it in openllm.CONFIG_MAPPING
|
||||
if it.startswith(incomplete)
|
||||
]
|
||||
|
||||
def parse_config_options(config: LLMConfig, server_timeout: int, workers_per_resource: float, device: t.Tuple[str, ...] | None, cors: bool, environ: DictStrAny) -> DictStrAny:
|
||||
def parse_config_options(config: LLMConfig, server_timeout: int, workers_per_resource: float,
|
||||
device: t.Tuple[str, ...] | None, cors: bool, environ: DictStrAny) -> DictStrAny:
|
||||
# TODO: Support amd.com/gpu on k8s
|
||||
_bentoml_config_options_env = environ.pop('BENTOML_CONFIG_OPTIONS', '')
|
||||
_bentoml_config_options_opts = [
|
||||
'tracing.sample_rate=1.0',
|
||||
f'api_server.traffic.timeout={server_timeout}',
|
||||
'tracing.sample_rate=1.0', f'api_server.traffic.timeout={server_timeout}',
|
||||
f'runners."llm-{config["start_name"]}-runner".traffic.timeout={config["timeout"]}',
|
||||
f'runners."llm-{config["start_name"]}-runner".workers_per_resource={workers_per_resource}'
|
||||
]
|
||||
if device:
|
||||
if len(device) > 1: _bentoml_config_options_opts.extend([f'runners."llm-{config["start_name"]}-runner".resources."nvidia.com/gpu"[{idx}]={dev}' for idx, dev in enumerate(device)])
|
||||
else: _bentoml_config_options_opts.append(f'runners."llm-{config["start_name"]}-runner".resources."nvidia.com/gpu"=[{device[0]}]')
|
||||
_bentoml_config_options_opts.append(f'runners."llm-generic-embedding".resources.cpu={openllm.get_resource({"cpu":"system"},"cpu")}')
|
||||
if len(device) > 1:
|
||||
_bentoml_config_options_opts.extend([
|
||||
f'runners."llm-{config["start_name"]}-runner".resources."nvidia.com/gpu"[{idx}]={dev}'
|
||||
for idx, dev in enumerate(device)
|
||||
])
|
||||
else:
|
||||
_bentoml_config_options_opts.append(
|
||||
f'runners."llm-{config["start_name"]}-runner".resources."nvidia.com/gpu"=[{device[0]}]')
|
||||
_bentoml_config_options_opts.append(
|
||||
f'runners."llm-generic-embedding".resources.cpu={openllm.get_resource({"cpu":"system"},"cpu")}')
|
||||
if cors:
|
||||
_bentoml_config_options_opts.extend(['api_server.http.cors.enabled=true', 'api_server.http.cors.access_control_allow_origins="*"'])
|
||||
_bentoml_config_options_opts.extend([f'api_server.http.cors.access_control_allow_methods[{idx}]="{it}"' for idx, it in enumerate(['GET', 'OPTIONS', 'POST', 'HEAD', 'PUT'])])
|
||||
_bentoml_config_options_opts.extend(
|
||||
['api_server.http.cors.enabled=true', 'api_server.http.cors.access_control_allow_origins="*"'])
|
||||
_bentoml_config_options_opts.extend([
|
||||
f'api_server.http.cors.access_control_allow_methods[{idx}]="{it}"'
|
||||
for idx, it in enumerate(['GET', 'OPTIONS', 'POST', 'HEAD', 'PUT'])
|
||||
])
|
||||
_bentoml_config_options_env += ' ' if _bentoml_config_options_env else '' + ' '.join(_bentoml_config_options_opts)
|
||||
environ['BENTOML_CONFIG_OPTIONS'] = _bentoml_config_options_env
|
||||
if DEBUG: logger.debug('Setting BENTOML_CONFIG_OPTIONS=%s', _bentoml_config_options_env)
|
||||
@@ -82,7 +101,10 @@ def _id_callback(ctx: click.Context, _: click.Parameter, value: t.Tuple[str, ...
|
||||
ctx.params[_adapter_mapping_key][adapter_id] = adapter_name[0] if len(adapter_name) > 0 else None
|
||||
return None
|
||||
|
||||
def start_command_factory(group: click.Group, model: str, _context_settings: DictStrAny | None = None, _serve_grpc: bool = False) -> click.Command:
|
||||
def start_command_factory(group: click.Group,
|
||||
model: str,
|
||||
_context_settings: DictStrAny | None = None,
|
||||
_serve_grpc: bool = False) -> click.Command:
|
||||
llm_config = openllm.AutoConfig.for_model(model)
|
||||
command_attrs: DictStrAny = dict(
|
||||
name=llm_config['model_name'],
|
||||
@@ -113,37 +135,29 @@ Available official model_id(s): [default: {llm_config['default_id']}]
|
||||
if llm_config['requires_gpu'] and openllm.utils.device_count() < 1:
|
||||
# NOTE: The model requires GPU, therefore we will return a dummy command
|
||||
command_attrs.update({
|
||||
'short_help': '(Disabled because there is no GPU available)', 'help': f'{model} is currently not available to run on your local machine because it requires GPU for inference.'
|
||||
'short_help':
|
||||
'(Disabled because there is no GPU available)',
|
||||
'help':
|
||||
f'{model} is currently not available to run on your local machine because it requires GPU for inference.'
|
||||
})
|
||||
return noop_command(group, llm_config, _serve_grpc, **command_attrs)
|
||||
|
||||
@group.command(**command_attrs)
|
||||
@start_decorator(llm_config, serve_grpc=_serve_grpc)
|
||||
@click.pass_context
|
||||
def start_cmd(
|
||||
ctx: click.Context,
|
||||
/,
|
||||
server_timeout: int,
|
||||
model_id: str | None,
|
||||
model_version: str | None,
|
||||
workers_per_resource: t.Literal['conserved', 'round_robin'] | LiteralString,
|
||||
device: t.Tuple[str, ...],
|
||||
quantize: t.Literal['int8', 'int4', 'gptq'] | None,
|
||||
bettertransformer: bool | None,
|
||||
runtime: t.Literal['ggml', 'transformers'],
|
||||
fast: bool,
|
||||
serialisation_format: t.Literal['safetensors', 'legacy'],
|
||||
cors: bool,
|
||||
adapter_id: str | None,
|
||||
return_process: bool,
|
||||
**attrs: t.Any,
|
||||
) -> LLMConfig | subprocess.Popen[bytes]:
|
||||
def start_cmd(ctx: click.Context, /, server_timeout: int, model_id: str | None, model_version: str | None,
|
||||
workers_per_resource: t.Literal['conserved', 'round_robin'] | LiteralString, device: t.Tuple[str, ...],
|
||||
quantize: t.Literal['int8', 'int4', 'gptq'] | None, bettertransformer: bool | None,
|
||||
runtime: t.Literal['ggml', 'transformers'], fast: bool, serialisation_format: t.Literal['safetensors',
|
||||
'legacy'],
|
||||
cors: bool, adapter_id: str | None, return_process: bool, **attrs: t.Any,
|
||||
) -> LLMConfig | subprocess.Popen[bytes]:
|
||||
fast = str(fast).upper() in openllm.utils.ENV_VARS_TRUE_VALUES
|
||||
if serialisation_format == 'safetensors' and quantize is not None and os.environ.get('OPENLLM_SERIALIZATION_WARNING', str(True)).upper() in openllm.utils.ENV_VARS_TRUE_VALUES:
|
||||
if serialisation_format == 'safetensors' and quantize is not None and os.environ.get(
|
||||
'OPENLLM_SERIALIZATION_WARNING', str(True)).upper() in openllm.utils.ENV_VARS_TRUE_VALUES:
|
||||
termui.echo(
|
||||
f"'--quantize={quantize}' might not work with 'safetensors' serialisation format. Use with caution!. To silence this warning, set \"OPENLLM_SERIALIZATION_WARNING=False\"\nNote: You can always fallback to '--serialisation legacy' when running quantisation.",
|
||||
fg='yellow'
|
||||
)
|
||||
fg='yellow')
|
||||
adapter_map: dict[str, str | None] | None = attrs.pop(_adapter_mapping_key, None)
|
||||
config, server_attrs = llm_config.model_validate_click(**attrs)
|
||||
server_timeout = openllm.utils.first_not_none(server_timeout, default=config['timeout'])
|
||||
@@ -169,16 +183,21 @@ Available official model_id(s): [default: {llm_config['default_id']}]
|
||||
wpr = float(wpr)
|
||||
|
||||
# Create a new model env to work with the envvar during CLI invocation
|
||||
env = openllm.utils.EnvVarMixin(
|
||||
config['model_name'], config.default_implementation(), model_id=model_id or config['default_id'], bettertransformer=bettertransformer, quantize=quantize, runtime=runtime
|
||||
)
|
||||
env = openllm.utils.EnvVarMixin(config['model_name'],
|
||||
config.default_implementation(),
|
||||
model_id=model_id or config['default_id'],
|
||||
bettertransformer=bettertransformer,
|
||||
quantize=quantize,
|
||||
runtime=runtime)
|
||||
prerequisite_check(ctx, config, quantize, adapter_map, int(1 / wpr))
|
||||
|
||||
# NOTE: This is to set current configuration
|
||||
start_env = os.environ.copy()
|
||||
start_env = parse_config_options(config, server_timeout, wpr, device, cors, start_env)
|
||||
if fast:
|
||||
termui.echo(f"Fast mode is enabled. Make sure the model is available in local store before 'start': 'openllm import {model}{' --model-id ' + model_id if model_id else ''}'", fg='yellow')
|
||||
termui.echo(
|
||||
f"Fast mode is enabled. Make sure the model is available in local store before 'start': 'openllm import {model}{' --model-id ' + model_id if model_id else ''}'",
|
||||
fg='yellow')
|
||||
|
||||
start_env.update({
|
||||
'OPENLLM_MODEL': model,
|
||||
@@ -194,18 +213,28 @@ Available official model_id(s): [default: {llm_config['default_id']}]
|
||||
if bettertransformer is not None: start_env[env.bettertransformer] = str(env['bettertransformer_value'])
|
||||
if quantize is not None: start_env[env.quantize] = str(t.cast(str, env['quantize_value']))
|
||||
|
||||
llm = openllm.utils.infer_auto_class(env['framework_value']).for_model(
|
||||
model, model_id=start_env[env.model_id], model_version=model_version, llm_config=config, ensure_available=not fast, adapter_map=adapter_map, serialisation=serialisation_format
|
||||
)
|
||||
llm = openllm.utils.infer_auto_class(env['framework_value']).for_model(model,
|
||||
model_id=start_env[env.model_id],
|
||||
model_version=model_version,
|
||||
llm_config=config,
|
||||
ensure_available=not fast,
|
||||
adapter_map=adapter_map,
|
||||
serialisation=serialisation_format)
|
||||
start_env.update({env.config: llm.config.model_dump_json().decode()})
|
||||
|
||||
server = bentoml.GrpcServer('_service:svc', **server_attrs) if _serve_grpc else bentoml.HTTPServer('_service:svc', **server_attrs)
|
||||
server = bentoml.GrpcServer('_service:svc', **server_attrs) if _serve_grpc else bentoml.HTTPServer(
|
||||
'_service:svc', **server_attrs)
|
||||
openllm.utils.analytics.track_start_init(llm.config)
|
||||
|
||||
def next_step(model_name: str, adapter_map: DictStrAny | None) -> None:
|
||||
cmd_name = f'openllm build {model_name}'
|
||||
if adapter_map is not None: cmd_name += ' ' + ' '.join([f'--adapter-id {s}' for s in [f'{p}:{name}' if name not in (None, 'default') else p for p, name in adapter_map.items()]])
|
||||
if not openllm.utils.get_quiet_mode(): termui.echo(f"\n🚀 Next step: run '{cmd_name}' to create a Bento for {model_name}", fg='blue')
|
||||
if adapter_map is not None:
|
||||
cmd_name += ' ' + ' '.join([
|
||||
f'--adapter-id {s}'
|
||||
for s in [f'{p}:{name}' if name not in (None, 'default') else p for p, name in adapter_map.items()]
|
||||
])
|
||||
if not openllm.utils.get_quiet_mode():
|
||||
termui.echo(f"\n🚀 Next step: run '{cmd_name}' to create a Bento for {model_name}", fg='blue')
|
||||
|
||||
if return_process:
|
||||
server.start(env=start_env, text=True)
|
||||
@@ -239,30 +268,35 @@ def noop_command(group: click.Group, llm_config: LLMConfig, _serve_grpc: bool, *
|
||||
|
||||
return noop
|
||||
|
||||
def prerequisite_check(ctx: click.Context, llm_config: LLMConfig, quantize: LiteralString | None, adapter_map: dict[str, str | None] | None, num_workers: int) -> None:
|
||||
if adapter_map and not openllm.utils.is_peft_available(): ctx.fail("Using adapter requires 'peft' to be available. Make sure to install with 'pip install \"openllm[fine-tune]\"'")
|
||||
def prerequisite_check(ctx: click.Context, llm_config: LLMConfig, quantize: LiteralString | None,
|
||||
adapter_map: dict[str, str | None] | None, num_workers: int) -> None:
|
||||
if adapter_map and not openllm.utils.is_peft_available():
|
||||
ctx.fail(
|
||||
"Using adapter requires 'peft' to be available. Make sure to install with 'pip install \"openllm[fine-tune]\"'")
|
||||
if quantize and llm_config.default_implementation() == 'vllm':
|
||||
ctx.fail(f"Quantization is not yet supported with vLLM. Set '{llm_config['env']['framework']}=\"pt\"' to run with quantization.")
|
||||
ctx.fail(
|
||||
f"Quantization is not yet supported with vLLM. Set '{llm_config['env']['framework']}=\"pt\"' to run with quantization."
|
||||
)
|
||||
requirements = llm_config['requirements']
|
||||
if requirements is not None and len(requirements) > 0:
|
||||
missing_requirements = [i for i in requirements if importlib.util.find_spec(inflection.underscore(i)) is None]
|
||||
if len(missing_requirements) > 0: termui.echo(f'Make sure to have the following dependencies available: {missing_requirements}', fg='yellow')
|
||||
if len(missing_requirements) > 0:
|
||||
termui.echo(f'Make sure to have the following dependencies available: {missing_requirements}', fg='yellow')
|
||||
|
||||
def start_decorator(llm_config: LLMConfig, serve_grpc: bool = False) -> t.Callable[[FC], t.Callable[[FC], FC]]:
|
||||
|
||||
def wrapper(fn: FC) -> t.Callable[[FC], FC]:
|
||||
composed = openllm.utils.compose(
|
||||
llm_config.to_click_options,
|
||||
_http_server_args if not serve_grpc else _grpc_server_args,
|
||||
cog.optgroup.group('General LLM Options', help=f"The following options are related to running '{llm_config['start_name']}' LLM Server."),
|
||||
model_id_option(factory=cog.optgroup, model_env=llm_config['env']),
|
||||
model_version_option(factory=cog.optgroup),
|
||||
cog.optgroup.option('--server-timeout', type=int, default=None, help='Server timeout in seconds'),
|
||||
workers_per_resource_option(factory=cog.optgroup),
|
||||
cors_option(factory=cog.optgroup),
|
||||
fast_option(factory=cog.optgroup),
|
||||
llm_config.to_click_options, _http_server_args if not serve_grpc else _grpc_server_args,
|
||||
cog.optgroup.group(
|
||||
'LLM Optimization Options',
|
||||
help='''Optimization related options.
|
||||
'General LLM Options',
|
||||
help=f"The following options are related to running '{llm_config['start_name']}' LLM Server."),
|
||||
model_id_option(factory=cog.optgroup, model_env=llm_config['env']), model_version_option(factory=cog.optgroup),
|
||||
cog.optgroup.option('--server-timeout', type=int, default=None, help='Server timeout in seconds'),
|
||||
workers_per_resource_option(factory=cog.optgroup), cors_option(factory=cog.optgroup),
|
||||
fast_option(factory=cog.optgroup),
|
||||
cog.optgroup.group('LLM Optimization Options',
|
||||
help='''Optimization related options.
|
||||
|
||||
OpenLLM supports running model with [BetterTransformer](https://pytorch.org/blog/a-better-transformer-for-fast-transformer-encoder-inference/),
|
||||
k-bit quantization (8-bit, 4-bit), GPTQ quantization, PagedAttention via vLLM.
|
||||
@@ -272,23 +306,23 @@ def start_decorator(llm_config: LLMConfig, serve_grpc: bool = False) -> t.Callab
|
||||
- DeepSpeed Inference: [link](https://www.deepspeed.ai/inference/)
|
||||
- GGML: Fast inference on [bare metal](https://github.com/ggerganov/ggml)
|
||||
''',
|
||||
),
|
||||
cog.optgroup.option(
|
||||
'--device',
|
||||
type=openllm.utils.dantic.CUDA,
|
||||
multiple=True,
|
||||
envvar='CUDA_VISIBLE_DEVICES',
|
||||
callback=parse_device_callback,
|
||||
help=f"Assign GPU devices (if available) for {llm_config['model_name']}.",
|
||||
show_envvar=True
|
||||
),
|
||||
cog.optgroup.option('--runtime', type=click.Choice(['ggml', 'transformers']), default='transformers', help='The runtime to use for the given model. Default is transformers.'),
|
||||
),
|
||||
cog.optgroup.option('--device',
|
||||
type=openllm.utils.dantic.CUDA,
|
||||
multiple=True,
|
||||
envvar='CUDA_VISIBLE_DEVICES',
|
||||
callback=parse_device_callback,
|
||||
help=f"Assign GPU devices (if available) for {llm_config['model_name']}.",
|
||||
show_envvar=True),
|
||||
cog.optgroup.option('--runtime',
|
||||
type=click.Choice(['ggml', 'transformers']),
|
||||
default='transformers',
|
||||
help='The runtime to use for the given model. Default is transformers.'),
|
||||
quantize_option(factory=cog.optgroup, model_env=llm_config['env']),
|
||||
bettertransformer_option(factory=cog.optgroup, model_env=llm_config['env']),
|
||||
serialisation_option(factory=cog.optgroup),
|
||||
cog.optgroup.group(
|
||||
'Fine-tuning related options',
|
||||
help='''\
|
||||
cog.optgroup.group('Fine-tuning related options',
|
||||
help='''\
|
||||
Note that the argument `--adapter-id` can accept the following format:
|
||||
|
||||
- `--adapter-id /path/to/adapter` (local adapter)
|
||||
@@ -302,23 +336,22 @@ def start_decorator(llm_config: LLMConfig, serve_grpc: bool = False) -> t.Callab
|
||||
$ openllm start opt --adapter-id /path/to/adapter_dir --adapter-id remote/adapter:eng_lora
|
||||
|
||||
```
|
||||
'''
|
||||
),
|
||||
cog.optgroup.option(
|
||||
'--adapter-id',
|
||||
default=None,
|
||||
help='Optional name or path for given LoRA adapter' + f" to wrap '{llm_config['model_name']}'",
|
||||
multiple=True,
|
||||
callback=_id_callback,
|
||||
metavar='[PATH | [remote/][adapter_name:]adapter_id][, ...]'
|
||||
),
|
||||
'''),
|
||||
cog.optgroup.option('--adapter-id',
|
||||
default=None,
|
||||
help='Optional name or path for given LoRA adapter' +
|
||||
f" to wrap '{llm_config['model_name']}'",
|
||||
multiple=True,
|
||||
callback=_id_callback,
|
||||
metavar='[PATH | [remote/][adapter_name:]adapter_id][, ...]'),
|
||||
click.option('--return-process', is_flag=True, default=False, help='Internal use only.', hidden=True),
|
||||
)
|
||||
return composed(fn)
|
||||
|
||||
return wrapper
|
||||
|
||||
def parse_device_callback(ctx: click.Context, param: click.Parameter, value: tuple[tuple[str], ...] | None) -> t.Tuple[str, ...] | None:
|
||||
def parse_device_callback(ctx: click.Context, param: click.Parameter,
|
||||
value: tuple[tuple[str], ...] | None) -> t.Tuple[str, ...] | None:
|
||||
if value is None: return value
|
||||
if not isinstance(value, tuple): ctx.fail(f'{param} only accept multiple values, not {type(value)} (value: {value})')
|
||||
el: t.Tuple[str, ...] = tuple(i for k in value for i in k)
|
||||
@@ -337,14 +370,18 @@ def parse_serve_args(serve_grpc: bool) -> t.Callable[[t.Callable[..., LLMConfig]
|
||||
|
||||
command = 'serve' if not serve_grpc else 'serve-grpc'
|
||||
group = cog.optgroup.group(
|
||||
f"Start a {'HTTP' if not serve_grpc else 'gRPC'} server options", help=f"Related to serving the model [synonymous to `bentoml {'serve-http' if not serve_grpc else command }`]",
|
||||
f"Start a {'HTTP' if not serve_grpc else 'gRPC'} server options",
|
||||
help=f"Related to serving the model [synonymous to `bentoml {'serve-http' if not serve_grpc else command }`]",
|
||||
)
|
||||
|
||||
def decorator(f: t.Callable[Concatenate[int, t.Optional[str], P], LLMConfig]) -> t.Callable[[FC], FC]:
|
||||
serve_command = cli.commands[command]
|
||||
# The first variable is the argument bento
|
||||
# The last five is from BentoMLCommandGroup.NUMBER_OF_COMMON_PARAMS
|
||||
serve_options = [p for p in serve_command.params[1:-BentoMLCommandGroup.NUMBER_OF_COMMON_PARAMS] if p.name not in _IGNORED_OPTIONS]
|
||||
serve_options = [
|
||||
p for p in serve_command.params[1:-BentoMLCommandGroup.NUMBER_OF_COMMON_PARAMS]
|
||||
if p.name not in _IGNORED_OPTIONS
|
||||
]
|
||||
for options in reversed(serve_options):
|
||||
attrs = options.to_info_dict()
|
||||
# we don't need param_type_name, since it should all be options
|
||||
@@ -381,73 +418,90 @@ def _click_factory_type(*param_decls: t.Any, **attrs: t.Any) -> t.Callable[[FC |
|
||||
cli_option = functools.partial(_click_factory_type, attr='option')
|
||||
cli_argument = functools.partial(_click_factory_type, attr='argument')
|
||||
|
||||
def output_option(f: _AnyCallable | None = None, *, default_value: LiteralOutput = 'pretty', **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
def output_option(f: _AnyCallable | None = None,
|
||||
*,
|
||||
default_value: LiteralOutput = 'pretty',
|
||||
**attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
output = ['json', 'pretty', 'porcelain']
|
||||
|
||||
def complete_output_var(ctx: click.Context, param: click.Parameter, incomplete: str) -> list[CompletionItem]:
|
||||
return [CompletionItem(it) for it in output]
|
||||
|
||||
return cli_option(
|
||||
'-o',
|
||||
'--output',
|
||||
'output',
|
||||
type=click.Choice(output),
|
||||
default=default_value,
|
||||
help='Showing output type.',
|
||||
show_default=True,
|
||||
envvar='OPENLLM_OUTPUT',
|
||||
show_envvar=True,
|
||||
shell_complete=complete_output_var,
|
||||
**attrs
|
||||
)(f)
|
||||
return cli_option('-o',
|
||||
'--output',
|
||||
'output',
|
||||
type=click.Choice(output),
|
||||
default=default_value,
|
||||
help='Showing output type.',
|
||||
show_default=True,
|
||||
envvar='OPENLLM_OUTPUT',
|
||||
show_envvar=True,
|
||||
shell_complete=complete_output_var,
|
||||
**attrs)(f)
|
||||
|
||||
def fast_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
'--fast/--no-fast',
|
||||
show_default=True,
|
||||
default=False,
|
||||
envvar='OPENLLM_USE_LOCAL_LATEST',
|
||||
show_envvar=True,
|
||||
help='''Whether to skip checking if models is already in store.
|
||||
return cli_option('--fast/--no-fast',
|
||||
show_default=True,
|
||||
default=False,
|
||||
envvar='OPENLLM_USE_LOCAL_LATEST',
|
||||
show_envvar=True,
|
||||
help='''Whether to skip checking if models is already in store.
|
||||
|
||||
This is useful if you already downloaded or setup the model beforehand.
|
||||
''',
|
||||
**attrs
|
||||
)(f)
|
||||
**attrs)(f)
|
||||
|
||||
def cors_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option('--cors/--no-cors', show_default=True, default=False, envvar='OPENLLM_CORS', show_envvar=True, help='Enable CORS for the server.', **attrs)(f)
|
||||
return cli_option('--cors/--no-cors',
|
||||
show_default=True,
|
||||
default=False,
|
||||
envvar='OPENLLM_CORS',
|
||||
show_envvar=True,
|
||||
help='Enable CORS for the server.',
|
||||
**attrs)(f)
|
||||
|
||||
def machine_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option('--machine', is_flag=True, default=False, hidden=True, **attrs)(f)
|
||||
|
||||
def model_id_option(f: _AnyCallable | None = None, *, model_env: openllm.utils.EnvVarMixin | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
'--model-id',
|
||||
type=click.STRING,
|
||||
default=None,
|
||||
envvar=model_env.model_id if model_env is not None else None,
|
||||
show_envvar=model_env is not None,
|
||||
help='Optional model_id name or path for (fine-tune) weight.',
|
||||
**attrs
|
||||
)(f)
|
||||
def model_id_option(f: _AnyCallable | None = None,
|
||||
*,
|
||||
model_env: openllm.utils.EnvVarMixin | None = None,
|
||||
**attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option('--model-id',
|
||||
type=click.STRING,
|
||||
default=None,
|
||||
envvar=model_env.model_id if model_env is not None else None,
|
||||
show_envvar=model_env is not None,
|
||||
help='Optional model_id name or path for (fine-tune) weight.',
|
||||
**attrs)(f)
|
||||
|
||||
def model_version_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option('--model-version', type=click.STRING, default=None, help='Optional model version to save for this model. It will be inferred automatically from model-id.', **attrs)(f)
|
||||
return cli_option(
|
||||
'--model-version',
|
||||
type=click.STRING,
|
||||
default=None,
|
||||
help='Optional model version to save for this model. It will be inferred automatically from model-id.',
|
||||
**attrs)(f)
|
||||
|
||||
def model_name_argument(f: _AnyCallable | None = None, required: bool = True, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_argument('model_name', type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING]), required=required, **attrs)(f)
|
||||
return cli_argument('model_name',
|
||||
type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING]),
|
||||
required=required,
|
||||
**attrs)(f)
|
||||
|
||||
def quantize_option(f: _AnyCallable | None = None, *, build: bool = False, model_env: openllm.utils.EnvVarMixin | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
'--quantise',
|
||||
'--quantize',
|
||||
'quantize',
|
||||
type=click.Choice(['int8', 'int4', 'gptq']),
|
||||
default=None,
|
||||
envvar=model_env.quantize if model_env is not None else None,
|
||||
show_envvar=model_env is not None,
|
||||
help='''Dynamic quantization for running this LLM.
|
||||
def quantize_option(f: _AnyCallable | None = None,
|
||||
*,
|
||||
build: bool = False,
|
||||
model_env: openllm.utils.EnvVarMixin | None = None,
|
||||
**attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option('--quantise',
|
||||
'--quantize',
|
||||
'quantize',
|
||||
type=click.Choice(['int8', 'int4', 'gptq']),
|
||||
default=None,
|
||||
envvar=model_env.quantize if model_env is not None else None,
|
||||
show_envvar=model_env is not None,
|
||||
help='''Dynamic quantization for running this LLM.
|
||||
|
||||
The following quantization strategies are supported:
|
||||
|
||||
@@ -461,17 +515,18 @@ def quantize_option(f: _AnyCallable | None = None, *, build: bool = False, model
|
||||
''' + ('''
|
||||
> [!NOTE] that this will set the mode for serving within deployment.''' if build else '') + '''
|
||||
> [!NOTE] that quantization are currently only available in *PyTorch* models.''',
|
||||
**attrs
|
||||
)(f)
|
||||
**attrs)(f)
|
||||
|
||||
def workers_per_resource_option(f: _AnyCallable | None = None, *, build: bool = False, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
'--workers-per-resource',
|
||||
default=None,
|
||||
callback=workers_per_resource_callback,
|
||||
type=str,
|
||||
required=False,
|
||||
help='''Number of workers per resource assigned.
|
||||
def workers_per_resource_option(f: _AnyCallable | None = None,
|
||||
*,
|
||||
build: bool = False,
|
||||
**attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option('--workers-per-resource',
|
||||
default=None,
|
||||
callback=workers_per_resource_callback,
|
||||
type=str,
|
||||
required=False,
|
||||
help='''Number of workers per resource assigned.
|
||||
|
||||
See https://docs.bentoml.org/en/latest/guides/scheduling.html#resource-scheduling-strategy
|
||||
for more information. By default, this is set to 1.
|
||||
@@ -481,38 +536,37 @@ def workers_per_resource_option(f: _AnyCallable | None = None, *, build: bool =
|
||||
- ``round_robin``: Similar behaviour when setting ``--workers-per-resource 1``. This is useful for smaller models.
|
||||
|
||||
- ``conserved``: This will determine the number of available GPU resources, and only assign one worker for the LLMRunner. For example, if ther are 4 GPUs available, then ``conserved`` is equivalent to ``--workers-per-resource 0.25``.
|
||||
''' + (
|
||||
"""\n
|
||||
''' + ("""\n
|
||||
> [!NOTE] The workers value passed into 'build' will determine how the LLM can
|
||||
> be provisioned in Kubernetes as well as in standalone container. This will
|
||||
> ensure it has the same effect with 'openllm start --api-workers ...'""" if build else ''
|
||||
),
|
||||
**attrs
|
||||
)(f)
|
||||
> ensure it has the same effect with 'openllm start --api-workers ...'""" if build else ''),
|
||||
**attrs)(f)
|
||||
|
||||
def bettertransformer_option(f: _AnyCallable | None = None, *, build: bool = False, model_env: openllm.utils.EnvVarMixin | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
def bettertransformer_option(f: _AnyCallable | None = None,
|
||||
*,
|
||||
build: bool = False,
|
||||
model_env: openllm.utils.EnvVarMixin | None = None,
|
||||
**attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
'--bettertransformer',
|
||||
is_flag=True,
|
||||
default=None,
|
||||
envvar=model_env.bettertransformer if model_env is not None else None,
|
||||
show_envvar=model_env is not None,
|
||||
help='Apply FasterTransformer wrapper to serve model. This will applies during serving time.'
|
||||
if not build else 'Set default environment variable whether to serve this model with FasterTransformer in build time.',
|
||||
**attrs
|
||||
)(f)
|
||||
help='Apply FasterTransformer wrapper to serve model. This will applies during serving time.' if not build else
|
||||
'Set default environment variable whether to serve this model with FasterTransformer in build time.',
|
||||
**attrs)(f)
|
||||
|
||||
def serialisation_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
'--serialisation',
|
||||
'--serialization',
|
||||
'serialisation_format',
|
||||
type=click.Choice(['safetensors', 'legacy']),
|
||||
default='safetensors',
|
||||
show_default=True,
|
||||
show_envvar=True,
|
||||
envvar='OPENLLM_SERIALIZATION',
|
||||
help='''Serialisation format for save/load LLM.
|
||||
return cli_option('--serialisation',
|
||||
'--serialization',
|
||||
'serialisation_format',
|
||||
type=click.Choice(['safetensors', 'legacy']),
|
||||
default='safetensors',
|
||||
show_default=True,
|
||||
show_envvar=True,
|
||||
envvar='OPENLLM_SERIALIZATION',
|
||||
help='''Serialisation format for save/load LLM.
|
||||
|
||||
Currently the following strategies are supported:
|
||||
|
||||
@@ -529,28 +583,25 @@ def serialisation_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Cal
|
||||
|
||||
> [!NOTE] that GGML format is working in progress.
|
||||
''',
|
||||
**attrs
|
||||
)(f)
|
||||
**attrs)(f)
|
||||
|
||||
def container_registry_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]:
|
||||
return cli_option(
|
||||
'--container-registry',
|
||||
'container_registry',
|
||||
type=click.Choice(list(openllm.bundle.CONTAINER_NAMES)),
|
||||
default='ecr',
|
||||
show_default=True,
|
||||
show_envvar=True,
|
||||
envvar='OPENLLM_CONTAINER_REGISTRY',
|
||||
callback=container_registry_callback,
|
||||
help='''The default container registry to get the base image for building BentoLLM.
|
||||
return cli_option('--container-registry',
|
||||
'container_registry',
|
||||
type=click.Choice(list(openllm.bundle.CONTAINER_NAMES)),
|
||||
default='ecr',
|
||||
show_default=True,
|
||||
show_envvar=True,
|
||||
envvar='OPENLLM_CONTAINER_REGISTRY',
|
||||
callback=container_registry_callback,
|
||||
help='''The default container registry to get the base image for building BentoLLM.
|
||||
|
||||
Currently, it supports 'ecr', 'ghcr.io', 'docker.io'
|
||||
|
||||
\b
|
||||
> [!NOTE] that in order to build the base image, you will need a GPUs to compile custom kernel. See ``openllm ext build-base-container`` for more information.
|
||||
''',
|
||||
**attrs
|
||||
)(f)
|
||||
**attrs)(f)
|
||||
|
||||
_wpr_strategies = {'round_robin', 'conserved'}
|
||||
|
||||
@@ -562,11 +613,14 @@ def workers_per_resource_callback(ctx: click.Context, param: click.Parameter, va
|
||||
try:
|
||||
float(value) # type: ignore[arg-type]
|
||||
except ValueError:
|
||||
raise click.BadParameter(f"'workers_per_resource' only accept '{_wpr_strategies}' as possible strategies, otherwise pass in float.", ctx, param) from None
|
||||
raise click.BadParameter(
|
||||
f"'workers_per_resource' only accept '{_wpr_strategies}' as possible strategies, otherwise pass in float.",
|
||||
ctx, param) from None
|
||||
else:
|
||||
return value
|
||||
|
||||
def container_registry_callback(ctx: click.Context, param: click.Parameter, value: str | None) -> str | None:
|
||||
if value is None: return value
|
||||
if value not in openllm.bundle.supported_registries: raise click.BadParameter(f'Value must be one of {openllm.bundle.supported_registries}', ctx, param)
|
||||
if value not in openllm.bundle.supported_registries:
|
||||
raise click.BadParameter(f'Value must be one of {openllm.bundle.supported_registries}', ctx, param)
|
||||
return value
|
||||
|
||||
@@ -30,25 +30,23 @@ if t.TYPE_CHECKING:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def _start(
|
||||
model_name: str,
|
||||
/,
|
||||
*,
|
||||
model_id: str | None = None,
|
||||
timeout: int = 30,
|
||||
workers_per_resource: t.Literal['conserved', 'round_robin'] | float | None = None,
|
||||
device: tuple[str, ...] | t.Literal['all'] | None = None,
|
||||
quantize: t.Literal['int8', 'int4', 'gptq'] | None = None,
|
||||
bettertransformer: bool | None = None,
|
||||
runtime: t.Literal['ggml', 'transformers'] = 'transformers',
|
||||
adapter_map: dict[LiteralString, str | None] | None = None,
|
||||
framework: LiteralRuntime | None = None,
|
||||
additional_args: list[str] | None = None,
|
||||
cors: bool = False,
|
||||
_serve_grpc: bool = False,
|
||||
__test__: bool = False,
|
||||
**_: t.Any
|
||||
) -> LLMConfig | subprocess.Popen[bytes]:
|
||||
def _start(model_name: str,
|
||||
/,
|
||||
*,
|
||||
model_id: str | None = None,
|
||||
timeout: int = 30,
|
||||
workers_per_resource: t.Literal['conserved', 'round_robin'] | float | None = None,
|
||||
device: tuple[str, ...] | t.Literal['all'] | None = None,
|
||||
quantize: t.Literal['int8', 'int4', 'gptq'] | None = None,
|
||||
bettertransformer: bool | None = None,
|
||||
runtime: t.Literal['ggml', 'transformers'] = 'transformers',
|
||||
adapter_map: dict[LiteralString, str | None] | None = None,
|
||||
framework: LiteralRuntime | None = None,
|
||||
additional_args: list[str] | None = None,
|
||||
cors: bool = False,
|
||||
_serve_grpc: bool = False,
|
||||
__test__: bool = False,
|
||||
**_: t.Any) -> LLMConfig | subprocess.Popen[bytes]:
|
||||
"""Python API to start a LLM server. These provides one-to-one mapping to CLI arguments.
|
||||
|
||||
For all additional arguments, pass it as string to ``additional_args``. For example, if you want to
|
||||
@@ -91,58 +89,66 @@ def _start(
|
||||
from .entrypoint import start_command
|
||||
from .entrypoint import start_grpc_command
|
||||
llm_config = openllm.AutoConfig.for_model(model_name)
|
||||
_ModelEnv = openllm_core.utils.EnvVarMixin(
|
||||
model_name,
|
||||
openllm_core.utils.first_not_none(framework, default=llm_config.default_implementation()),
|
||||
model_id=model_id,
|
||||
bettertransformer=bettertransformer,
|
||||
quantize=quantize,
|
||||
runtime=runtime
|
||||
)
|
||||
_ModelEnv = openllm_core.utils.EnvVarMixin(model_name,
|
||||
openllm_core.utils.first_not_none(
|
||||
framework, default=llm_config.default_implementation()),
|
||||
model_id=model_id,
|
||||
bettertransformer=bettertransformer,
|
||||
quantize=quantize,
|
||||
runtime=runtime)
|
||||
os.environ[_ModelEnv.framework] = _ModelEnv['framework_value']
|
||||
|
||||
args: list[str] = ['--runtime', runtime]
|
||||
if model_id: args.extend(['--model-id', model_id])
|
||||
if timeout: args.extend(['--server-timeout', str(timeout)])
|
||||
if workers_per_resource: args.extend(['--workers-per-resource', str(workers_per_resource) if not isinstance(workers_per_resource, str) else workers_per_resource])
|
||||
if workers_per_resource:
|
||||
args.extend([
|
||||
'--workers-per-resource',
|
||||
str(workers_per_resource) if not isinstance(workers_per_resource, str) else workers_per_resource
|
||||
])
|
||||
if device and not os.environ.get('CUDA_VISIBLE_DEVICES'): args.extend(['--device', ','.join(device)])
|
||||
if quantize and bettertransformer: raise OpenLLMException("'quantize' and 'bettertransformer' are currently mutually exclusive.")
|
||||
if quantize and bettertransformer:
|
||||
raise OpenLLMException("'quantize' and 'bettertransformer' are currently mutually exclusive.")
|
||||
if quantize: args.extend(['--quantize', str(quantize)])
|
||||
elif bettertransformer: args.append('--bettertransformer')
|
||||
if cors: args.append('--cors')
|
||||
if adapter_map: args.extend(list(itertools.chain.from_iterable([['--adapter-id', f"{k}{':'+v if v else ''}"] for k, v in adapter_map.items()])))
|
||||
if adapter_map:
|
||||
args.extend(
|
||||
list(
|
||||
itertools.chain.from_iterable([['--adapter-id', f"{k}{':'+v if v else ''}"] for k, v in adapter_map.items()
|
||||
])))
|
||||
if additional_args: args.extend(additional_args)
|
||||
if __test__: args.append('--return-process')
|
||||
|
||||
return start_command_factory(start_command if not _serve_grpc else start_grpc_command, model_name, _context_settings=termui.CONTEXT_SETTINGS, _serve_grpc=_serve_grpc).main(
|
||||
args=args if len(args) > 0 else None, standalone_mode=False
|
||||
)
|
||||
return start_command_factory(start_command if not _serve_grpc else start_grpc_command,
|
||||
model_name,
|
||||
_context_settings=termui.CONTEXT_SETTINGS,
|
||||
_serve_grpc=_serve_grpc).main(args=args if len(args) > 0 else None,
|
||||
standalone_mode=False)
|
||||
|
||||
@inject
|
||||
def _build(
|
||||
model_name: str,
|
||||
/,
|
||||
*,
|
||||
model_id: str | None = None,
|
||||
model_version: str | None = None,
|
||||
bento_version: str | None = None,
|
||||
quantize: t.Literal['int8', 'int4', 'gptq'] | None = None,
|
||||
bettertransformer: bool | None = None,
|
||||
adapter_map: dict[str, str | None] | None = None,
|
||||
build_ctx: str | None = None,
|
||||
enable_features: tuple[str, ...] | None = None,
|
||||
workers_per_resource: float | None = None,
|
||||
runtime: t.Literal['ggml', 'transformers'] = 'transformers',
|
||||
dockerfile_template: str | None = None,
|
||||
overwrite: bool = False,
|
||||
container_registry: LiteralContainerRegistry | None = None,
|
||||
container_version_strategy: LiteralContainerVersionStrategy | None = None,
|
||||
push: bool = False,
|
||||
containerize: bool = False,
|
||||
serialisation_format: t.Literal['safetensors', 'legacy'] = 'safetensors',
|
||||
additional_args: list[str] | None = None,
|
||||
bento_store: BentoStore = Provide[BentoMLContainer.bento_store]
|
||||
) -> bentoml.Bento:
|
||||
def _build(model_name: str,
|
||||
/,
|
||||
*,
|
||||
model_id: str | None = None,
|
||||
model_version: str | None = None,
|
||||
bento_version: str | None = None,
|
||||
quantize: t.Literal['int8', 'int4', 'gptq'] | None = None,
|
||||
bettertransformer: bool | None = None,
|
||||
adapter_map: dict[str, str | None] | None = None,
|
||||
build_ctx: str | None = None,
|
||||
enable_features: tuple[str, ...] | None = None,
|
||||
workers_per_resource: float | None = None,
|
||||
runtime: t.Literal['ggml', 'transformers'] = 'transformers',
|
||||
dockerfile_template: str | None = None,
|
||||
overwrite: bool = False,
|
||||
container_registry: LiteralContainerRegistry | None = None,
|
||||
container_version_strategy: LiteralContainerVersionStrategy | None = None,
|
||||
push: bool = False,
|
||||
containerize: bool = False,
|
||||
serialisation_format: t.Literal['safetensors', 'legacy'] = 'safetensors',
|
||||
additional_args: list[str] | None = None,
|
||||
bento_store: BentoStore = Provide[BentoMLContainer.bento_store]) -> bentoml.Bento:
|
||||
"""Package a LLM into a Bento.
|
||||
|
||||
The LLM will be built into a BentoService with the following structure:
|
||||
@@ -192,8 +198,12 @@ def _build(
|
||||
Returns:
|
||||
``bentoml.Bento | str``: BentoLLM instance. This can be used to serve the LLM or can be pushed to BentoCloud.
|
||||
"""
|
||||
args: list[str] = [sys.executable, '-m', 'openllm', 'build', model_name, '--machine', '--runtime', runtime, '--serialisation', serialisation_format]
|
||||
if quantize and bettertransformer: raise OpenLLMException("'quantize' and 'bettertransformer' are currently mutually exclusive.")
|
||||
args: list[str] = [
|
||||
sys.executable, '-m', 'openllm', 'build', model_name, '--machine', '--runtime', runtime, '--serialisation',
|
||||
serialisation_format
|
||||
]
|
||||
if quantize and bettertransformer:
|
||||
raise OpenLLMException("'quantize' and 'bettertransformer' are currently mutually exclusive.")
|
||||
if quantize: args.extend(['--quantize', quantize])
|
||||
if bettertransformer: args.append('--bettertransformer')
|
||||
if containerize and push: raise OpenLLMException("'containerize' and 'push' are currently mutually exclusive.")
|
||||
@@ -221,21 +231,21 @@ def _build(
|
||||
raise OpenLLMException(str(e)) from None
|
||||
matched = re.match(r'__tag__:([^:\n]+:[^:\n]+)$', output.decode('utf-8').strip())
|
||||
if matched is None:
|
||||
raise ValueError(f"Failed to find tag from output: {output.decode('utf-8').strip()}\nNote: Output from 'openllm build' might not be correct. Please open an issue on GitHub.")
|
||||
raise ValueError(
|
||||
f"Failed to find tag from output: {output.decode('utf-8').strip()}\nNote: Output from 'openllm build' might not be correct. Please open an issue on GitHub."
|
||||
)
|
||||
return bentoml.get(matched.group(1), _bento_store=bento_store)
|
||||
|
||||
def _import_model(
|
||||
model_name: str,
|
||||
/,
|
||||
*,
|
||||
model_id: str | None = None,
|
||||
model_version: str | None = None,
|
||||
runtime: t.Literal['ggml', 'transformers'] = 'transformers',
|
||||
implementation: LiteralRuntime = 'pt',
|
||||
quantize: t.Literal['int8', 'int4', 'gptq'] | None = None,
|
||||
serialisation_format: t.Literal['legacy', 'safetensors'] = 'safetensors',
|
||||
additional_args: t.Sequence[str] | None = None
|
||||
) -> bentoml.Model:
|
||||
def _import_model(model_name: str,
|
||||
/,
|
||||
*,
|
||||
model_id: str | None = None,
|
||||
model_version: str | None = None,
|
||||
runtime: t.Literal['ggml', 'transformers'] = 'transformers',
|
||||
implementation: LiteralRuntime = 'pt',
|
||||
quantize: t.Literal['int8', 'int4', 'gptq'] | None = None,
|
||||
serialisation_format: t.Literal['legacy', 'safetensors'] = 'safetensors',
|
||||
additional_args: t.Sequence[str] | None = None) -> bentoml.Model:
|
||||
"""Import a LLM into local store.
|
||||
|
||||
> [!NOTE]
|
||||
@@ -267,7 +277,10 @@ def _import_model(
|
||||
``bentoml.Model``:BentoModel of the given LLM. This can be used to serve the LLM or can be pushed to BentoCloud.
|
||||
"""
|
||||
from .entrypoint import import_command
|
||||
args = [model_name, '--runtime', runtime, '--implementation', implementation, '--machine', '--serialisation', serialisation_format,]
|
||||
args = [
|
||||
model_name, '--runtime', runtime, '--implementation', implementation, '--machine', '--serialisation',
|
||||
serialisation_format,
|
||||
]
|
||||
if model_id is not None: args.append(model_id)
|
||||
if model_version is not None: args.extend(['--model-version', str(model_version)])
|
||||
if additional_args is not None: args.extend(additional_args)
|
||||
@@ -278,5 +291,9 @@ def _list_models() -> dict[str, t.Any]:
|
||||
'''List all available models within the local store.'''
|
||||
from .entrypoint import models_command
|
||||
return models_command.main(args=['-o', 'json', '--show-available', '--machine'], standalone_mode=False)
|
||||
start, start_grpc, build, import_model, list_models = openllm_core.utils.codegen.gen_sdk(_start, _serve_grpc=False), openllm_core.utils.codegen.gen_sdk(_start, _serve_grpc=True), openllm_core.utils.codegen.gen_sdk(_build), openllm_core.utils.codegen.gen_sdk(_import_model), openllm_core.utils.codegen.gen_sdk(_list_models)
|
||||
|
||||
start, start_grpc, build, import_model, list_models = openllm_core.utils.codegen.gen_sdk(
|
||||
_start, _serve_grpc=False), openllm_core.utils.codegen.gen_sdk(
|
||||
_start, _serve_grpc=True), openllm_core.utils.codegen.gen_sdk(_build), openllm_core.utils.codegen.gen_sdk(
|
||||
_import_model), openllm_core.utils.codegen.gen_sdk(_list_models)
|
||||
__all__ = ['start', 'start_grpc', 'build', 'import_model', 'list_models']
|
||||
|
||||
@@ -14,10 +14,9 @@ if t.TYPE_CHECKING:
|
||||
from openllm_core._typing_compat import LiteralContainerRegistry
|
||||
from openllm_core._typing_compat import LiteralContainerVersionStrategy
|
||||
|
||||
@click.command(
|
||||
'build_base_container',
|
||||
context_settings=termui.CONTEXT_SETTINGS,
|
||||
help='''Base image builder for BentoLLM.
|
||||
@click.command('build_base_container',
|
||||
context_settings=termui.CONTEXT_SETTINGS,
|
||||
help='''Base image builder for BentoLLM.
|
||||
|
||||
By default, the base image will include custom kernels (PagedAttention via vllm, FlashAttention-v2, etc.) built with CUDA 11.8, Python 3.9 on Ubuntu22.04.
|
||||
Optionally, this can also be pushed directly to remote registry. Currently support ``docker.io``, ``ghcr.io`` and ``quay.io``.
|
||||
@@ -27,13 +26,16 @@ if t.TYPE_CHECKING:
|
||||
This command is only useful for debugging and for building custom base image for extending BentoML with custom base images and custom kernels.
|
||||
|
||||
Note that we already release images on our CI to ECR and GHCR, so you don't need to build it yourself.
|
||||
'''
|
||||
)
|
||||
''')
|
||||
@container_registry_option
|
||||
@click.option('--version-strategy', type=click.Choice(['release', 'latest', 'nightly']), default='nightly', help='Version strategy to use for tagging the image.')
|
||||
@click.option('--version-strategy',
|
||||
type=click.Choice(['release', 'latest', 'nightly']),
|
||||
default='nightly',
|
||||
help='Version strategy to use for tagging the image.')
|
||||
@click.option('--push/--no-push', help='Whether to push to remote repository', is_flag=True, default=False)
|
||||
@machine_option
|
||||
def cli(container_registry: tuple[LiteralContainerRegistry, ...] | None, version_strategy: LiteralContainerVersionStrategy, push: bool, machine: bool) -> dict[str, str]:
|
||||
def cli(container_registry: tuple[LiteralContainerRegistry, ...] | None,
|
||||
version_strategy: LiteralContainerVersionStrategy, push: bool, machine: bool) -> dict[str, str]:
|
||||
mapping = openllm.bundle.build_container(container_registry, version_strategy, push, machine)
|
||||
if machine: termui.echo(orjson.dumps(mapping, option=orjson.OPT_INDENT_2).decode(), fg='white')
|
||||
return mapping
|
||||
|
||||
@@ -24,14 +24,19 @@ if t.TYPE_CHECKING:
|
||||
@machine_option
|
||||
@click.pass_context
|
||||
@inject
|
||||
def cli(ctx: click.Context, bento: str, machine: bool, _bento_store: BentoStore = Provide[BentoMLContainer.bento_store]) -> str | None:
|
||||
def cli(ctx: click.Context,
|
||||
bento: str,
|
||||
machine: bool,
|
||||
_bento_store: BentoStore = Provide[BentoMLContainer.bento_store]) -> str | None:
|
||||
'''Dive into a BentoLLM. This is synonymous to cd $(b get <bento>:<tag> -o path).'''
|
||||
try:
|
||||
bentomodel = _bento_store.get(bento)
|
||||
except bentoml.exceptions.NotFound:
|
||||
ctx.fail(f'Bento {bento} not found. Make sure to call `openllm build` first.')
|
||||
if 'bundler' not in bentomodel.info.labels or bentomodel.info.labels['bundler'] != 'openllm.bundle':
|
||||
ctx.fail(f"Bento is either too old or not built with OpenLLM. Make sure to use ``openllm build {bentomodel.info.labels['start_name']}`` for correctness.")
|
||||
ctx.fail(
|
||||
f"Bento is either too old or not built with OpenLLM. Make sure to use ``openllm build {bentomodel.info.labels['start_name']}`` for correctness."
|
||||
)
|
||||
if machine: return bentomodel.path
|
||||
# copy and paste this into a new shell
|
||||
if psutil.WINDOWS: subprocess.check_call([shutil.which('dir') or 'dir'], cwd=bentomodel.path)
|
||||
|
||||
@@ -19,7 +19,9 @@ from openllm_core.utils import bentoml_cattr
|
||||
if t.TYPE_CHECKING:
|
||||
from bentoml._internal.bento import BentoStore
|
||||
|
||||
@click.command('get_containerfile', context_settings=termui.CONTEXT_SETTINGS, help='Return Containerfile of any given Bento.')
|
||||
@click.command('get_containerfile',
|
||||
context_settings=termui.CONTEXT_SETTINGS,
|
||||
help='Return Containerfile of any given Bento.')
|
||||
@click.argument('bento', type=str, shell_complete=bento_complete_envvar)
|
||||
@click.pass_context
|
||||
@inject
|
||||
@@ -39,7 +41,13 @@ def cli(ctx: click.Context, bento: str, _bento_store: BentoStore = Provide[Bento
|
||||
# NOTE: if users specify a dockerfile_template, we will
|
||||
# save it to /env/docker/Dockerfile.template. This is necessary
|
||||
# for the reconstruction of the Dockerfile.
|
||||
if 'dockerfile_template' in docker_attrs and docker_attrs['dockerfile_template'] is not None: docker_attrs['dockerfile_template'] = 'env/docker/Dockerfile.template'
|
||||
doc = generate_containerfile(docker=DockerOptions(**docker_attrs), build_ctx=bentomodel.path, conda=options.conda, bento_fs=bentomodel._fs, enable_buildkit=True, add_header=True)
|
||||
if 'dockerfile_template' in docker_attrs and docker_attrs['dockerfile_template'] is not None:
|
||||
docker_attrs['dockerfile_template'] = 'env/docker/Dockerfile.template'
|
||||
doc = generate_containerfile(docker=DockerOptions(**docker_attrs),
|
||||
build_ctx=bentomodel.path,
|
||||
conda=options.conda,
|
||||
bento_fs=bentomodel._fs,
|
||||
enable_buildkit=True,
|
||||
add_header=True)
|
||||
termui.echo(doc, fg='white')
|
||||
return bentomodel.path
|
||||
|
||||
@@ -18,41 +18,51 @@ from openllm_core._prompt import process_prompt
|
||||
LiteralOutput = t.Literal['json', 'pretty', 'porcelain']
|
||||
|
||||
@click.command('get_prompt', context_settings=termui.CONTEXT_SETTINGS)
|
||||
@click.argument('model_name', type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING.keys()]), shell_complete=model_complete_envvar)
|
||||
@click.argument('model_name',
|
||||
type=click.Choice([inflection.dasherize(name) for name in openllm.CONFIG_MAPPING.keys()]),
|
||||
shell_complete=model_complete_envvar)
|
||||
@click.argument('prompt', type=click.STRING)
|
||||
@output_option
|
||||
@click.option('--format', type=click.STRING, default=None)
|
||||
@machine_option
|
||||
@click.option(
|
||||
'--opt',
|
||||
help="Define additional prompt variables. (format: ``--opt system_prompt='You are a useful assistant'``)",
|
||||
required=False,
|
||||
multiple=True,
|
||||
callback=opt_callback,
|
||||
metavar='ARG=VALUE[,ARG=VALUE]'
|
||||
)
|
||||
@click.option('--opt',
|
||||
help="Define additional prompt variables. (format: ``--opt system_prompt='You are a useful assistant'``)",
|
||||
required=False,
|
||||
multiple=True,
|
||||
callback=opt_callback,
|
||||
metavar='ARG=VALUE[,ARG=VALUE]')
|
||||
@click.pass_context
|
||||
def cli(ctx: click.Context, /, model_name: str, prompt: str, format: str | None, output: LiteralOutput, machine: bool, _memoized: dict[str, t.Any], **_: t.Any) -> str | None:
|
||||
def cli(ctx: click.Context, /, model_name: str, prompt: str, format: str | None, output: LiteralOutput, machine: bool,
|
||||
_memoized: dict[str, t.Any], **_: t.Any) -> str | None:
|
||||
'''Get the default prompt used by OpenLLM.'''
|
||||
module = openllm.utils.EnvVarMixin(model_name).module
|
||||
_memoized = {k: v[0] for k, v in _memoized.items() if v}
|
||||
try:
|
||||
template = getattr(module, 'DEFAULT_PROMPT_TEMPLATE', None)
|
||||
prompt_mapping = getattr(module, 'PROMPT_MAPPING', None)
|
||||
if template is None: raise click.BadArgumentUsage(f'model {model_name} does not have a default prompt template') from None
|
||||
if template is None:
|
||||
raise click.BadArgumentUsage(f'model {model_name} does not have a default prompt template') from None
|
||||
if callable(template):
|
||||
if format is None:
|
||||
if not hasattr(module, 'PROMPT_MAPPING') or module.PROMPT_MAPPING is None: raise RuntimeError('Failed to find prompt mapping while DEFAULT_PROMPT_TEMPLATE is a function.')
|
||||
raise click.BadOptionUsage('format', f"{model_name} prompt requires passing '--format' (available format: {list(module.PROMPT_MAPPING)})")
|
||||
if prompt_mapping is None: raise click.BadArgumentUsage(f'Failed to fine prompt mapping while the default prompt for {model_name} is a callable.') from None
|
||||
if format not in prompt_mapping: raise click.BadOptionUsage('format', f'Given format {format} is not valid for {model_name} (available format: {list(prompt_mapping)})')
|
||||
if not hasattr(module, 'PROMPT_MAPPING') or module.PROMPT_MAPPING is None:
|
||||
raise RuntimeError('Failed to find prompt mapping while DEFAULT_PROMPT_TEMPLATE is a function.')
|
||||
raise click.BadOptionUsage(
|
||||
'format',
|
||||
f"{model_name} prompt requires passing '--format' (available format: {list(module.PROMPT_MAPPING)})")
|
||||
if prompt_mapping is None:
|
||||
raise click.BadArgumentUsage(
|
||||
f'Failed to fine prompt mapping while the default prompt for {model_name} is a callable.') from None
|
||||
if format not in prompt_mapping:
|
||||
raise click.BadOptionUsage(
|
||||
'format', f'Given format {format} is not valid for {model_name} (available format: {list(prompt_mapping)})')
|
||||
_prompt_template = template(format)
|
||||
else:
|
||||
_prompt_template = template
|
||||
fully_formatted = process_prompt(prompt, _prompt_template, True, **_memoized)
|
||||
if machine: return repr(fully_formatted)
|
||||
elif output == 'porcelain': termui.echo(repr(fully_formatted), fg='white')
|
||||
elif output == 'json': termui.echo(orjson.dumps({'prompt': fully_formatted}, option=orjson.OPT_INDENT_2).decode(), fg='white')
|
||||
elif output == 'json':
|
||||
termui.echo(orjson.dumps({'prompt': fully_formatted}, option=orjson.OPT_INDENT_2).decode(), fg='white')
|
||||
else:
|
||||
termui.echo(f'== Prompt for {model_name} ==\n', fg='magenta')
|
||||
termui.echo(fully_formatted, fg='white')
|
||||
|
||||
@@ -19,23 +19,27 @@ def cli(ctx: click.Context, output: LiteralOutput) -> None:
|
||||
'''List available bentos built by OpenLLM.'''
|
||||
mapping = {
|
||||
k: [{
|
||||
'tag': str(b.tag),
|
||||
'size': human_readable_size(openllm.utils.calc_dir_size(b.path)),
|
||||
'tag':
|
||||
str(b.tag),
|
||||
'size':
|
||||
human_readable_size(openllm.utils.calc_dir_size(b.path)),
|
||||
'models': [{
|
||||
'tag': str(m.tag), 'size': human_readable_size(openllm.utils.calc_dir_size(m.path))
|
||||
'tag': str(m.tag),
|
||||
'size': human_readable_size(openllm.utils.calc_dir_size(m.path))
|
||||
} for m in (bentoml.models.get(_.tag) for _ in b.info.models)]
|
||||
} for b in tuple(i for i in bentoml.list() if all(k in i.info.labels for k in {'start_name', 'bundler'})) if b.info.labels['start_name'] == k] for k in tuple(
|
||||
inflection.dasherize(key) for key in openllm.CONFIG_MAPPING.keys()
|
||||
)
|
||||
} for b in tuple(i for i in bentoml.list() if all(
|
||||
k in i.info.labels for k in {'start_name', 'bundler'})) if b.info.labels['start_name'] == k
|
||||
] for k in tuple(inflection.dasherize(key) for key in openllm.CONFIG_MAPPING.keys())
|
||||
}
|
||||
mapping = {k: v for k, v in mapping.items() if v}
|
||||
if output == 'pretty':
|
||||
import tabulate
|
||||
tabulate.PRESERVE_WHITESPACE = True
|
||||
termui.echo(
|
||||
tabulate.tabulate([(k, i['tag'], i['size'], [_['tag'] for _ in i['models']]) for k, v in mapping.items() for i in v], tablefmt='fancy_grid', headers=['LLM', 'Tag', 'Size', 'Models']),
|
||||
fg='white'
|
||||
)
|
||||
termui.echo(tabulate.tabulate(
|
||||
[(k, i['tag'], i['size'], [_['tag'] for _ in i['models']]) for k, v in mapping.items() for i in v],
|
||||
tablefmt='fancy_grid',
|
||||
headers=['LLM', 'Tag', 'Size', 'Models']),
|
||||
fg='white')
|
||||
else:
|
||||
termui.echo(orjson.dumps(mapping, option=orjson.OPT_INDENT_2).decode(), fg='white')
|
||||
ctx.exit(0)
|
||||
|
||||
@@ -25,17 +25,33 @@ def cli(model_name: str | None, output: LiteralOutput) -> DictStrAny:
|
||||
'''This is equivalent to openllm models --show-available less the nice table.'''
|
||||
models = tuple(inflection.dasherize(key) for key in openllm.CONFIG_MAPPING.keys())
|
||||
ids_in_local_store = {
|
||||
k: [i for i in bentoml.models.list() if 'framework' in i.info.labels and i.info.labels['framework'] == 'openllm' and 'model_name' in i.info.labels and i.info.labels['model_name'] == k
|
||||
] for k in models
|
||||
k: [
|
||||
i for i in bentoml.models.list() if 'framework' in i.info.labels and
|
||||
i.info.labels['framework'] == 'openllm' and 'model_name' in i.info.labels and i.info.labels['model_name'] == k
|
||||
] for k in models
|
||||
}
|
||||
if model_name is not None:
|
||||
ids_in_local_store = {k: [i for i in v if 'model_name' in i.info.labels and i.info.labels['model_name'] == inflection.dasherize(model_name)] for k, v in ids_in_local_store.items()}
|
||||
ids_in_local_store = {
|
||||
k: [
|
||||
i
|
||||
for i in v
|
||||
if 'model_name' in i.info.labels and i.info.labels['model_name'] == inflection.dasherize(model_name)
|
||||
] for k, v in ids_in_local_store.items()
|
||||
}
|
||||
ids_in_local_store = {k: v for k, v in ids_in_local_store.items() if v}
|
||||
local_models = {k: [{'tag': str(i.tag), 'size': human_readable_size(openllm.utils.calc_dir_size(i.path))} for i in val] for k, val in ids_in_local_store.items()}
|
||||
local_models = {
|
||||
k: [{
|
||||
'tag': str(i.tag),
|
||||
'size': human_readable_size(openllm.utils.calc_dir_size(i.path))
|
||||
} for i in val] for k, val in ids_in_local_store.items()
|
||||
}
|
||||
if output == 'pretty':
|
||||
import tabulate
|
||||
tabulate.PRESERVE_WHITESPACE = True
|
||||
termui.echo(tabulate.tabulate([(k, i['tag'], i['size']) for k, v in local_models.items() for i in v], tablefmt='fancy_grid', headers=['LLM', 'Tag', 'Size']), fg='white')
|
||||
termui.echo(tabulate.tabulate([(k, i['tag'], i['size']) for k, v in local_models.items() for i in v],
|
||||
tablefmt='fancy_grid',
|
||||
headers=['LLM', 'Tag', 'Size']),
|
||||
fg='white')
|
||||
else:
|
||||
termui.echo(orjson.dumps(local_models, option=orjson.OPT_INDENT_2).decode(), fg='white')
|
||||
return local_models
|
||||
|
||||
@@ -28,12 +28,18 @@ logger = logging.getLogger(__name__)
|
||||
def load_notebook_metadata() -> DictStrAny:
|
||||
with open(os.path.join(os.path.dirname(playground.__file__), '_meta.yml'), 'r') as f:
|
||||
content = yaml.safe_load(f)
|
||||
if not all('description' in k for k in content.values()): raise ValueError("Invalid metadata file. All entries must have a 'description' key.")
|
||||
if not all('description' in k for k in content.values()):
|
||||
raise ValueError("Invalid metadata file. All entries must have a 'description' key.")
|
||||
return content
|
||||
|
||||
@click.command('playground', context_settings=termui.CONTEXT_SETTINGS)
|
||||
@click.argument('output-dir', default=None, required=False)
|
||||
@click.option('--port', envvar='JUPYTER_PORT', show_envvar=True, show_default=True, default=8888, help='Default port for Jupyter server')
|
||||
@click.option('--port',
|
||||
envvar='JUPYTER_PORT',
|
||||
show_envvar=True,
|
||||
show_default=True,
|
||||
default=8888,
|
||||
help='Default port for Jupyter server')
|
||||
@click.pass_context
|
||||
def cli(ctx: click.Context, output_dir: str | None, port: int) -> None:
|
||||
"""OpenLLM Playground.
|
||||
@@ -54,7 +60,9 @@ def cli(ctx: click.Context, output_dir: str | None, port: int) -> None:
|
||||
> This command requires Jupyter to be installed. Install it with 'pip install "openllm[playground]"'
|
||||
"""
|
||||
if not is_jupyter_available() or not is_jupytext_available() or not is_notebook_available():
|
||||
raise RuntimeError("Playground requires 'jupyter', 'jupytext', and 'notebook'. Install it with 'pip install \"openllm[playground]\"'")
|
||||
raise RuntimeError(
|
||||
"Playground requires 'jupyter', 'jupytext', and 'notebook'. Install it with 'pip install \"openllm[playground]\"'"
|
||||
)
|
||||
metadata = load_notebook_metadata()
|
||||
_temp_dir = False
|
||||
if output_dir is None:
|
||||
@@ -66,7 +74,8 @@ def cli(ctx: click.Context, output_dir: str | None, port: int) -> None:
|
||||
termui.echo('The playground notebooks will be saved to: ' + os.path.abspath(output_dir), fg='blue')
|
||||
for module in pkgutil.iter_modules(playground.__path__):
|
||||
if module.ispkg or os.path.exists(os.path.join(output_dir, module.name + '.ipynb')):
|
||||
logger.debug('Skipping: %s (%s)', module.name, 'File already exists' if not module.ispkg else f'{module.name} is a module')
|
||||
logger.debug('Skipping: %s (%s)', module.name,
|
||||
'File already exists' if not module.ispkg else f'{module.name} is a module')
|
||||
continue
|
||||
if not isinstance(module.module_finder, importlib.machinery.FileFinder): continue
|
||||
termui.echo('Generating notebook for: ' + module.name, fg='magenta')
|
||||
@@ -75,7 +84,10 @@ def cli(ctx: click.Context, output_dir: str | None, port: int) -> None:
|
||||
f.cells.insert(0, markdown_cell)
|
||||
jupytext.write(f, os.path.join(output_dir, module.name + '.ipynb'), fmt='notebook')
|
||||
try:
|
||||
subprocess.check_output([sys.executable, '-m', 'jupyter', 'notebook', '--notebook-dir', output_dir, '--port', str(port), '--no-browser', '--debug'])
|
||||
subprocess.check_output([
|
||||
sys.executable, '-m', 'jupyter', 'notebook', '--notebook-dir', output_dir, '--port',
|
||||
str(port), '--no-browser', '--debug'
|
||||
])
|
||||
except subprocess.CalledProcessError as e:
|
||||
termui.echo(e.output, fg='red')
|
||||
raise click.ClickException(f'Failed to start a jupyter server:\n{e}') from None
|
||||
|
||||
@@ -12,8 +12,13 @@ if t.TYPE_CHECKING:
|
||||
|
||||
def echo(text: t.Any, fg: str = 'green', _with_style: bool = True, **attrs: t.Any) -> None:
|
||||
attrs['fg'] = fg if not openllm.utils.get_debug_mode() else None
|
||||
if not openllm.utils.get_quiet_mode(): t.cast(t.Callable[..., None], click.echo if not _with_style else click.secho)(text, **attrs)
|
||||
if not openllm.utils.get_quiet_mode():
|
||||
t.cast(t.Callable[..., None], click.echo if not _with_style else click.secho)(text, **attrs)
|
||||
|
||||
COLUMNS: int = int(os.environ.get('COLUMNS', str(120)))
|
||||
CONTEXT_SETTINGS: DictStrAny = {'help_option_names': ['-h', '--help'], 'max_content_width': COLUMNS, 'token_normalize_func': inflection.underscore}
|
||||
CONTEXT_SETTINGS: DictStrAny = {
|
||||
'help_option_names': ['-h', '--help'],
|
||||
'max_content_width': COLUMNS,
|
||||
'token_normalize_func': inflection.underscore
|
||||
}
|
||||
__all__ = ['echo', 'COLUMNS', 'CONTEXT_SETTINGS']
|
||||
|
||||
@@ -43,7 +43,8 @@ except openllm.exceptions.MissingDependencyError:
|
||||
pass
|
||||
else:
|
||||
_import_structure['modeling_flax_auto'].extend(['AutoFlaxLLM', 'MODEL_FLAX_MAPPING'])
|
||||
if t.TYPE_CHECKING: from .modeling_flax_auto import MODEL_FLAX_MAPPING as MODEL_FLAX_MAPPING, AutoFlaxLLM as AutoFlaxLLM
|
||||
if t.TYPE_CHECKING:
|
||||
from .modeling_flax_auto import MODEL_FLAX_MAPPING as MODEL_FLAX_MAPPING, AutoFlaxLLM as AutoFlaxLLM
|
||||
try:
|
||||
if not is_tf_available(): raise openllm.exceptions.MissingDependencyError
|
||||
except openllm.exceptions.MissingDependencyError:
|
||||
|
||||
@@ -30,10 +30,18 @@ class BaseAutoLLMClass:
|
||||
_model_mapping: t.ClassVar[_LazyAutoMapping]
|
||||
|
||||
def __init__(self, *args: t.Any, **attrs: t.Any):
|
||||
raise EnvironmentError(f"Cannot instantiate {self.__class__.__name__} directly. Please use '{self.__class__.__name__}.Runner(model_name)' instead.")
|
||||
raise EnvironmentError(
|
||||
f"Cannot instantiate {self.__class__.__name__} directly. Please use '{self.__class__.__name__}.Runner(model_name)' instead."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def for_model(cls, model: str, /, model_id: str | None = None, model_version: str | None = None, llm_config: openllm.LLMConfig | None = None, ensure_available: bool = False,
|
||||
def for_model(cls,
|
||||
model: str,
|
||||
/,
|
||||
model_id: str | None = None,
|
||||
model_version: str | None = None,
|
||||
llm_config: openllm.LLMConfig | None = None,
|
||||
ensure_available: bool = False,
|
||||
**attrs: t.Any) -> openllm.LLM[t.Any, t.Any]:
|
||||
'''The lower level API for creating a LLM instance.
|
||||
|
||||
@@ -42,7 +50,10 @@ class BaseAutoLLMClass:
|
||||
>>> llm = openllm.AutoLLM.for_model("flan-t5")
|
||||
```
|
||||
'''
|
||||
llm = cls.infer_class_from_name(model).from_pretrained(model_id=model_id, model_version=model_version, llm_config=llm_config, **attrs)
|
||||
llm = cls.infer_class_from_name(model).from_pretrained(model_id=model_id,
|
||||
model_version=model_version,
|
||||
llm_config=llm_config,
|
||||
**attrs)
|
||||
if ensure_available: llm.ensure_model_id_exists()
|
||||
return llm
|
||||
|
||||
@@ -105,7 +116,9 @@ class _LazyAutoMapping(OrderedDict, ReprMixin):
|
||||
This OrderedDict values() and keys() returns the list instead, so you don't
|
||||
have to do list(mapping.values()) to get the list of values.
|
||||
"""
|
||||
def __init__(self, config_mapping: OrderedDict[LiteralString, LiteralString], model_mapping: OrderedDict[LiteralString, LiteralString]):
|
||||
|
||||
def __init__(self, config_mapping: OrderedDict[LiteralString, LiteralString],
|
||||
model_mapping: OrderedDict[LiteralString, LiteralString]):
|
||||
self._config_mapping = config_mapping
|
||||
self._reverse_config_mapping = {v: k for k, v in config_mapping.items()}
|
||||
self._model_mapping = model_mapping
|
||||
@@ -115,7 +128,8 @@ class _LazyAutoMapping(OrderedDict, ReprMixin):
|
||||
def __getitem__(self, key: type[openllm.LLMConfig]) -> type[openllm.LLM[t.Any, t.Any]]:
|
||||
if key in self._extra_content: return self._extra_content[key]
|
||||
model_type = self._reverse_config_mapping[key.__name__]
|
||||
if model_type in self._model_mapping: return self._load_attr_from_module(model_type, self._model_mapping[model_type])
|
||||
if model_type in self._model_mapping:
|
||||
return self._load_attr_from_module(model_type, self._model_mapping[model_type])
|
||||
# Maybe there was several model types associated with this config.
|
||||
model_types = [k for k, v in self._config_mapping.items() if v == key.__name__]
|
||||
for mtype in model_types:
|
||||
@@ -124,7 +138,8 @@ class _LazyAutoMapping(OrderedDict, ReprMixin):
|
||||
|
||||
def _load_attr_from_module(self, model_type: str, attr: str) -> t.Any:
|
||||
module_name = inflection.underscore(model_type)
|
||||
if module_name not in self._modules: self._modules[module_name] = importlib.import_module(f'.{module_name}', 'openllm.models')
|
||||
if module_name not in self._modules:
|
||||
self._modules[module_name] = importlib.import_module(f'.{module_name}', 'openllm.models')
|
||||
return getattribute_from_module(self._modules[module_name], attr)
|
||||
|
||||
def __len__(self) -> int:
|
||||
@@ -138,29 +153,32 @@ class _LazyAutoMapping(OrderedDict, ReprMixin):
|
||||
return ReprMixin.__repr__(self)
|
||||
|
||||
def __repr_args__(self) -> t.Generator[tuple[str, tuple[str, str]], t.Any, t.Any]:
|
||||
yield from ((key, (value, self._model_mapping[key])) for key, value in self._config_mapping.items() if key in self._model_mapping)
|
||||
yield from ((key, (value, self._model_mapping[key]))
|
||||
for key, value in self._config_mapping.items()
|
||||
if key in self._model_mapping)
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return bool(self.keys())
|
||||
|
||||
def keys(self) -> ConfigModelKeysView:
|
||||
return t.cast(
|
||||
'ConfigModelKeysView', [self._load_attr_from_module(key, name) for key, name in self._config_mapping.items() if key in self._model_mapping.keys()] + list(self._extra_content.keys())
|
||||
)
|
||||
return t.cast('ConfigModelKeysView', [
|
||||
self._load_attr_from_module(key, name)
|
||||
for key, name in self._config_mapping.items()
|
||||
if key in self._model_mapping.keys()
|
||||
] + list(self._extra_content.keys()))
|
||||
|
||||
def values(self) -> ConfigModelValuesView:
|
||||
return t.cast(
|
||||
'ConfigModelValuesView', [self._load_attr_from_module(key, name) for key, name in self._model_mapping.items() if key in self._config_mapping.keys()] + list(
|
||||
self._extra_content.values()
|
||||
)
|
||||
)
|
||||
return t.cast('ConfigModelValuesView', [
|
||||
self._load_attr_from_module(key, name)
|
||||
for key, name in self._model_mapping.items()
|
||||
if key in self._config_mapping.keys()
|
||||
] + list(self._extra_content.values()))
|
||||
|
||||
def items(self) -> ConfigModelItemsView:
|
||||
return t.cast(
|
||||
'ConfigModelItemsView',
|
||||
[(self._load_attr_from_module(key, self._config_mapping[key]),
|
||||
self._load_attr_from_module(key, self._model_mapping[key])) for key in self._model_mapping.keys() if key in self._config_mapping.keys()] + list(self._extra_content.items())
|
||||
)
|
||||
return t.cast('ConfigModelItemsView', [(self._load_attr_from_module(
|
||||
key, self._config_mapping[key]), self._load_attr_from_module(key, self._model_mapping[key]))
|
||||
for key in self._model_mapping.keys()
|
||||
if key in self._config_mapping.keys()] + list(self._extra_content.items()))
|
||||
|
||||
def __iter__(self) -> t.Iterator[type[openllm.LLMConfig]]:
|
||||
return iter(t.cast('SupportsIter[t.Iterator[type[openllm.LLMConfig]]]', self.keys()))
|
||||
@@ -172,7 +190,8 @@ class _LazyAutoMapping(OrderedDict, ReprMixin):
|
||||
|
||||
def register(self, key: t.Any, value: t.Any) -> None:
|
||||
if hasattr(key, '__name__') and key.__name__ in self._reverse_config_mapping:
|
||||
if self._reverse_config_mapping[key.__name__] in self._model_mapping.keys(): raise ValueError(f"'{key}' is already used by a OpenLLM model.")
|
||||
if self._reverse_config_mapping[key.__name__] in self._model_mapping.keys():
|
||||
raise ValueError(f"'{key}' is already used by a OpenLLM model.")
|
||||
self._extra_content[key] = value
|
||||
|
||||
__all__ = ['BaseAutoLLMClass', '_LazyAutoMapping']
|
||||
|
||||
@@ -7,9 +7,10 @@ from openllm_core.config import CONFIG_MAPPING_NAMES
|
||||
from .factory import BaseAutoLLMClass
|
||||
from .factory import _LazyAutoMapping
|
||||
|
||||
MODEL_MAPPING_NAMES = OrderedDict([('chatglm', 'ChatGLM'), ('dolly_v2', 'DollyV2'), ('falcon', 'Falcon'), ('flan_t5', 'FlanT5'), ('gpt_neox', 'GPTNeoX'), ('llama', 'Llama'), ('mpt', 'MPT'), (
|
||||
'opt', 'OPT'
|
||||
), ('stablelm', 'StableLM'), ('starcoder', 'StarCoder'), ('baichuan', 'Baichuan')])
|
||||
MODEL_MAPPING_NAMES = OrderedDict([('chatglm', 'ChatGLM'), ('dolly_v2', 'DollyV2'), ('falcon', 'Falcon'),
|
||||
('flan_t5', 'FlanT5'), ('gpt_neox', 'GPTNeoX'), ('llama', 'Llama'), ('mpt', 'MPT'),
|
||||
('opt', 'OPT'), ('stablelm', 'StableLM'), ('starcoder', 'StarCoder'),
|
||||
('baichuan', 'Baichuan')])
|
||||
MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
|
||||
|
||||
class AutoLLM(BaseAutoLLMClass):
|
||||
|
||||
@@ -7,9 +7,10 @@ from openllm_core.config import CONFIG_MAPPING_NAMES
|
||||
from .factory import BaseAutoLLMClass
|
||||
from .factory import _LazyAutoMapping
|
||||
|
||||
MODEL_VLLM_MAPPING_NAMES = OrderedDict([('baichuan', 'VLLMBaichuan'), ('dolly_v2', 'VLLMDollyV2'), ('falcon', 'VLLMFalcon'), ('gpt_neox', 'VLLMGPTNeoX'), ('mpt', 'VLLMMPT'), (
|
||||
'opt', 'VLLMOPT'
|
||||
), ('stablelm', 'VLLMStableLM'), ('starcoder', 'VLLMStarCoder'), ('llama', 'VLLMLlama')])
|
||||
MODEL_VLLM_MAPPING_NAMES = OrderedDict([('baichuan', 'VLLMBaichuan'), ('dolly_v2', 'VLLMDollyV2'),
|
||||
('falcon', 'VLLMFalcon'), ('gpt_neox', 'VLLMGPTNeoX'), ('mpt', 'VLLMMPT'),
|
||||
('opt', 'VLLMOPT'), ('stablelm', 'VLLMStableLM'),
|
||||
('starcoder', 'VLLMStarCoder'), ('llama', 'VLLMLlama')])
|
||||
MODEL_VLLM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_VLLM_MAPPING_NAMES)
|
||||
|
||||
class AutoVLLM(BaseAutoLLMClass):
|
||||
|
||||
@@ -11,5 +11,6 @@ class Baichuan(openllm.LLM['transformers.PreTrainedModel', 'transformers.PreTrai
|
||||
import torch
|
||||
inputs = self.tokenizer(prompt, return_tensors='pt').to(self.device)
|
||||
with torch.inference_mode(), torch.autocast('cuda', dtype=torch.float16): # type: ignore[attr-defined]
|
||||
outputs = self.model.generate(**inputs, generation_config=self.config.model_construct_env(**attrs).to_generation_config())
|
||||
outputs = self.model.generate(**inputs,
|
||||
generation_config=self.config.model_construct_env(**attrs).to_generation_config())
|
||||
return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
|
||||
@@ -14,7 +14,9 @@ class ChatGLM(openllm.LLM['transformers.PreTrainedModel', 'transformers.PreTrain
|
||||
self.model.eval()
|
||||
# Only use half precision if the model is not yet quantized
|
||||
if self.config.use_half_precision: self.model.half()
|
||||
return self.model.chat(self.tokenizer, prompt, generation_config=self.config.model_construct_env(**attrs).to_generation_config())
|
||||
return self.model.chat(self.tokenizer,
|
||||
prompt,
|
||||
generation_config=self.config.model_construct_env(**attrs).to_generation_config())
|
||||
|
||||
def embeddings(self, prompts: list[str]) -> openllm.LLMEmbeddings:
|
||||
import torch
|
||||
|
||||
@@ -10,29 +10,57 @@ from openllm_core.config.configuration_dolly_v2 import END_KEY
|
||||
from openllm_core.config.configuration_dolly_v2 import RESPONSE_KEY
|
||||
from openllm_core.config.configuration_dolly_v2 import get_special_token_id
|
||||
if t.TYPE_CHECKING: import torch, transformers, tensorflow as tf
|
||||
else: torch, transformers, tf = openllm.utils.LazyLoader('torch', globals(), 'torch'), openllm.utils.LazyLoader('transformers', globals(), 'transformers'), openllm.utils.LazyLoader('tf', globals(), 'tensorflow')
|
||||
else:
|
||||
torch, transformers, tf = openllm.utils.LazyLoader('torch', globals(), 'torch'), openllm.utils.LazyLoader(
|
||||
'transformers', globals(), 'transformers'), openllm.utils.LazyLoader('tf', globals(), 'tensorflow')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@overload
|
||||
def get_pipeline(model: transformers.PreTrainedModel, tokenizer: transformers.PreTrainedTokenizer, _init: t.Literal[True] = True, **attrs: t.Any) -> transformers.Pipeline:
|
||||
def get_pipeline(model: transformers.PreTrainedModel,
|
||||
tokenizer: transformers.PreTrainedTokenizer,
|
||||
_init: t.Literal[True] = True,
|
||||
**attrs: t.Any) -> transformers.Pipeline:
|
||||
...
|
||||
|
||||
@overload
|
||||
def get_pipeline(model: transformers.PreTrainedModel, tokenizer: transformers.PreTrainedTokenizer, _init: t.Literal[False] = ..., **attrs: t.Any) -> type[transformers.Pipeline]:
|
||||
def get_pipeline(model: transformers.PreTrainedModel,
|
||||
tokenizer: transformers.PreTrainedTokenizer,
|
||||
_init: t.Literal[False] = ...,
|
||||
**attrs: t.Any) -> type[transformers.Pipeline]:
|
||||
...
|
||||
|
||||
def get_pipeline(model: transformers.PreTrainedModel, tokenizer: transformers.PreTrainedTokenizer, _init: bool = False, **attrs: t.Any) -> type[transformers.Pipeline] | transformers.Pipeline:
|
||||
def get_pipeline(model: transformers.PreTrainedModel,
|
||||
tokenizer: transformers.PreTrainedTokenizer,
|
||||
_init: bool = False,
|
||||
**attrs: t.Any) -> type[transformers.Pipeline] | transformers.Pipeline:
|
||||
# Lazy loading the pipeline. See databricks' implementation on HuggingFace for more information.
|
||||
class InstructionTextGenerationPipeline(transformers.Pipeline):
|
||||
def __init__(self, *args: t.Any, do_sample: bool = True, max_new_tokens: int = 256, top_p: float = 0.92, top_k: int = 0, **kwargs: t.Any):
|
||||
super().__init__(*args, model=model, tokenizer=tokenizer, do_sample=do_sample, max_new_tokens=max_new_tokens, top_p=top_p, top_k=top_k, **kwargs)
|
||||
|
||||
def _sanitize_parameters(self, return_full_text: bool | None = None, **generate_kwargs: t.Any) -> tuple[dict[str, t.Any], dict[str, t.Any], dict[str, t.Any]]:
|
||||
def __init__(self,
|
||||
*args: t.Any,
|
||||
do_sample: bool = True,
|
||||
max_new_tokens: int = 256,
|
||||
top_p: float = 0.92,
|
||||
top_k: int = 0,
|
||||
**kwargs: t.Any):
|
||||
super().__init__(*args,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
do_sample=do_sample,
|
||||
max_new_tokens=max_new_tokens,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
**kwargs)
|
||||
|
||||
def _sanitize_parameters(self,
|
||||
return_full_text: bool | None = None,
|
||||
**generate_kwargs: t.Any) -> tuple[dict[str, t.Any], dict[str, t.Any], dict[str, t.Any]]:
|
||||
if t.TYPE_CHECKING: assert self.tokenizer is not None
|
||||
preprocess_params: dict[str, t.Any] = {}
|
||||
# newer versions of the tokenizer configure the response key as a special token. newer versions still may
|
||||
# append a newline to yield a single token. find whatever token is configured for the response key.
|
||||
tokenizer_response_key = next((token for token in self.tokenizer.additional_special_tokens if token.startswith(RESPONSE_KEY)), None)
|
||||
tokenizer_response_key = next(
|
||||
(token for token in self.tokenizer.additional_special_tokens if token.startswith(RESPONSE_KEY)), None)
|
||||
response_key_token_id = None
|
||||
end_key_token_id = None
|
||||
if tokenizer_response_key:
|
||||
@@ -56,7 +84,8 @@ def get_pipeline(model: transformers.PreTrainedModel, tokenizer: transformers.Pr
|
||||
inputs['instruction_text'] = input_
|
||||
return t.cast(t.Dict[str, t.Any], inputs)
|
||||
|
||||
def _forward(self, input_tensors: dict[str, t.Any], **generate_kwargs: t.Any) -> transformers.utils.generic.ModelOutput:
|
||||
def _forward(self, input_tensors: dict[str, t.Any],
|
||||
**generate_kwargs: t.Any) -> transformers.utils.generic.ModelOutput:
|
||||
if t.TYPE_CHECKING: assert self.tokenizer is not None
|
||||
input_ids, attention_mask = input_tensors['input_ids'], input_tensors.get('attention_mask', None)
|
||||
if input_ids.shape[1] == 0: input_ids, attention_mask, in_b = None, None, 1
|
||||
@@ -65,15 +94,20 @@ def get_pipeline(model: transformers.PreTrainedModel, tokenizer: transformers.Pr
|
||||
input_ids=input_ids.to(self.model.device) if input_ids is not None else None,
|
||||
attention_mask=attention_mask.to(self.model.device) if attention_mask is not None else None,
|
||||
pad_token_id=self.tokenizer.pad_token_id,
|
||||
**generate_kwargs
|
||||
)
|
||||
**generate_kwargs)
|
||||
out_b = generated_sequence.shape[0]
|
||||
if self.framework == 'pt': generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])
|
||||
elif self.framework == 'tf': generated_sequence = tf.reshape(generated_sequence, (in_b, out_b // in_b, *generated_sequence.shape[1:]))
|
||||
if self.framework == 'pt':
|
||||
generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])
|
||||
elif self.framework == 'tf':
|
||||
generated_sequence = tf.reshape(generated_sequence, (in_b, out_b // in_b, *generated_sequence.shape[1:]))
|
||||
instruction_text = input_tensors.pop('instruction_text')
|
||||
return {'generated_sequence': generated_sequence, 'input_ids': input_ids, 'instruction_text': instruction_text}
|
||||
|
||||
def postprocess(self, model_outputs: dict[str, t.Any], response_key_token_id: int, end_key_token_id: int, return_full_text: bool = False) -> list[dict[t.Literal['generated_text'], str]]:
|
||||
def postprocess(self,
|
||||
model_outputs: dict[str, t.Any],
|
||||
response_key_token_id: int,
|
||||
end_key_token_id: int,
|
||||
return_full_text: bool = False) -> list[dict[t.Literal['generated_text'], str]]:
|
||||
if t.TYPE_CHECKING: assert self.tokenizer is not None
|
||||
_generated_sequence, instruction_text = model_outputs['generated_sequence'][0], model_outputs['instruction_text']
|
||||
generated_sequence: list[list[int]] = _generated_sequence.numpy().tolist()
|
||||
@@ -89,7 +123,8 @@ def get_pipeline(model: transformers.PreTrainedModel, tokenizer: transformers.Pr
|
||||
response_pos = sequence.index(response_key_token_id)
|
||||
except ValueError:
|
||||
response_pos = None
|
||||
if response_pos is None: logger.warning('Could not find response key %s in: %s', response_key_token_id, sequence)
|
||||
if response_pos is None:
|
||||
logger.warning('Could not find response key %s in: %s', response_key_token_id, sequence)
|
||||
if response_pos:
|
||||
# Next find where "### End" is located. The model has been trained to end its responses with this
|
||||
# sequence (or actually, the token ID it maps to, since it is a special token). We may not find
|
||||
@@ -127,12 +162,20 @@ class DollyV2(openllm.LLM['transformers.Pipeline', 'transformers.PreTrainedToken
|
||||
|
||||
@property
|
||||
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
|
||||
return {'device_map': 'auto' if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None, 'torch_dtype': torch.bfloat16}, {}
|
||||
return {
|
||||
'device_map': 'auto' if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None,
|
||||
'torch_dtype': torch.bfloat16
|
||||
}, {}
|
||||
|
||||
def load_model(self, *args: t.Any, **attrs: t.Any) -> transformers.Pipeline:
|
||||
return get_pipeline(transformers.AutoModelForCausalLM.from_pretrained(self._bentomodel.path, *args, **attrs), self.tokenizer, _init=True, return_full_text=self.config.return_full_text)
|
||||
return get_pipeline(transformers.AutoModelForCausalLM.from_pretrained(self._bentomodel.path, *args, **attrs),
|
||||
self.tokenizer,
|
||||
_init=True,
|
||||
return_full_text=self.config.return_full_text)
|
||||
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[dict[t.Literal['generated_text'], str]]:
|
||||
llm_config = self.config.model_construct_env(**attrs)
|
||||
with torch.inference_mode():
|
||||
return self.model(prompt, return_full_text=llm_config.return_full_text, generation_config=llm_config.to_generation_config())
|
||||
return self.model(prompt,
|
||||
return_full_text=llm_config.return_full_text,
|
||||
generation_config=llm_config.to_generation_config())
|
||||
|
||||
@@ -3,32 +3,43 @@ import typing as t
|
||||
|
||||
import openllm
|
||||
if t.TYPE_CHECKING: import torch, transformers
|
||||
else: torch, transformers = openllm.utils.LazyLoader('torch', globals(), 'torch'), openllm.utils.LazyLoader('transformers', globals(), 'transformers')
|
||||
else:
|
||||
torch, transformers = openllm.utils.LazyLoader('torch', globals(),
|
||||
'torch'), openllm.utils.LazyLoader('transformers', globals(),
|
||||
'transformers')
|
||||
|
||||
class Falcon(openllm.LLM['transformers.PreTrainedModel', 'transformers.PreTrainedTokenizerBase']):
|
||||
__openllm_internal__ = True
|
||||
|
||||
@property
|
||||
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
|
||||
return {'torch_dtype': torch.bfloat16, 'device_map': 'auto' if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None}, {}
|
||||
return {
|
||||
'torch_dtype': torch.bfloat16,
|
||||
'device_map': 'auto' if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None
|
||||
}, {}
|
||||
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
|
||||
eos_token_id, inputs = attrs.pop('eos_token_id', self.tokenizer.eos_token_id), self.tokenizer(prompt, return_tensors='pt').to(self.device)
|
||||
eos_token_id, inputs = attrs.pop('eos_token_id',
|
||||
self.tokenizer.eos_token_id), self.tokenizer(prompt,
|
||||
return_tensors='pt').to(self.device)
|
||||
with torch.inference_mode(), torch.autocast('cuda', dtype=torch.float16): # type: ignore[attr-defined]
|
||||
return self.tokenizer.batch_decode(
|
||||
self.model.generate(
|
||||
input_ids=inputs['input_ids'],
|
||||
attention_mask=inputs['attention_mask'],
|
||||
generation_config=self.config.model_construct_env(eos_token_id=eos_token_id, **attrs).to_generation_config()
|
||||
),
|
||||
skip_special_tokens=True
|
||||
)
|
||||
return self.tokenizer.batch_decode(self.model.generate(
|
||||
input_ids=inputs['input_ids'],
|
||||
attention_mask=inputs['attention_mask'],
|
||||
generation_config=self.config.model_construct_env(eos_token_id=eos_token_id, **attrs).to_generation_config()),
|
||||
skip_special_tokens=True)
|
||||
|
||||
def generate_one(self, prompt: str, stop: list[str], **preprocess_generate_kwds: t.Any) -> list[dict[t.Literal['generated_text'], str]]:
|
||||
max_new_tokens, encoded_inputs = preprocess_generate_kwds.pop('max_new_tokens', 200), self.tokenizer(prompt, return_tensors='pt').to(self.device)
|
||||
src_len, stopping_criteria = encoded_inputs['input_ids'].shape[1], preprocess_generate_kwds.pop('stopping_criteria', openllm.StoppingCriteriaList([]))
|
||||
def generate_one(self, prompt: str, stop: list[str],
|
||||
**preprocess_generate_kwds: t.Any) -> list[dict[t.Literal['generated_text'], str]]:
|
||||
max_new_tokens, encoded_inputs = preprocess_generate_kwds.pop('max_new_tokens', 200), self.tokenizer(
|
||||
prompt, return_tensors='pt').to(self.device)
|
||||
src_len, stopping_criteria = encoded_inputs['input_ids'].shape[1], preprocess_generate_kwds.pop(
|
||||
'stopping_criteria', openllm.StoppingCriteriaList([]))
|
||||
stopping_criteria.append(openllm.StopSequenceCriteria(stop, self.tokenizer))
|
||||
result = self.tokenizer.decode(self.model.generate(encoded_inputs['input_ids'], max_new_tokens=max_new_tokens, stopping_criteria=stopping_criteria)[0].tolist()[src_len:])
|
||||
result = self.tokenizer.decode(
|
||||
self.model.generate(encoded_inputs['input_ids'],
|
||||
max_new_tokens=max_new_tokens,
|
||||
stopping_criteria=stopping_criteria)[0].tolist()[src_len:])
|
||||
# Inference API returns the stop sequence
|
||||
for stop_seq in stop:
|
||||
if result.endswith(stop_seq): result = result[:-len(stop_seq)]
|
||||
|
||||
@@ -12,9 +12,10 @@ class FlanT5(openllm.LLM['transformers.T5ForConditionalGeneration', 'transformer
|
||||
import torch
|
||||
with torch.inference_mode():
|
||||
return self.tokenizer.batch_decode(
|
||||
self.model.generate(**self.tokenizer(prompt, return_tensors='pt').to(self.device), do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config()),
|
||||
skip_special_tokens=True
|
||||
)
|
||||
self.model.generate(**self.tokenizer(prompt, return_tensors='pt').to(self.device),
|
||||
do_sample=True,
|
||||
generation_config=self.config.model_construct_env(**attrs).to_generation_config()),
|
||||
skip_special_tokens=True)
|
||||
|
||||
def embeddings(self, prompts: list[str]) -> openllm.LLMEmbeddings:
|
||||
import torch
|
||||
|
||||
@@ -9,18 +9,16 @@ if t.TYPE_CHECKING: import transformers
|
||||
class FlaxFlanT5(openllm.LLM['transformers.FlaxT5ForConditionalGeneration', 'transformers.T5TokenizerFast']):
|
||||
__openllm_internal__ = True
|
||||
|
||||
def sanitize_parameters(
|
||||
self,
|
||||
prompt: str,
|
||||
max_new_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
top_k: int | None = None,
|
||||
top_p: float | None = None,
|
||||
repetition_penalty: float | None = None,
|
||||
decoder_start_token_id: int | None = None,
|
||||
use_default_prompt_template: bool = True,
|
||||
**attrs: t.Any
|
||||
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
def sanitize_parameters(self,
|
||||
prompt: str,
|
||||
max_new_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
top_k: int | None = None,
|
||||
top_p: float | None = None,
|
||||
repetition_penalty: float | None = None,
|
||||
decoder_start_token_id: int | None = None,
|
||||
use_default_prompt_template: bool = True,
|
||||
**attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
if decoder_start_token_id is None: decoder_start_token_id = 0
|
||||
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {
|
||||
'max_new_tokens': max_new_tokens,
|
||||
@@ -34,13 +32,10 @@ class FlaxFlanT5(openllm.LLM['transformers.FlaxT5ForConditionalGeneration', 'tra
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
|
||||
# NOTE: decoder_start_token_id is extracted from https://huggingface.co/google/flan-t5-small/tree/main as it is required for encoder-decoder generation.
|
||||
decoder_start_token_id = attrs.pop('decoder_start_token_id', 0)
|
||||
return self.tokenizer.batch_decode(
|
||||
self.model.generate(
|
||||
self.tokenizer(prompt, return_tensors='np')['input_ids'],
|
||||
do_sample=True,
|
||||
generation_config=self.config.model_construct_env(**attrs).to_generation_config(),
|
||||
decoder_start_token_id=decoder_start_token_id
|
||||
).sequences,
|
||||
skip_special_tokens=True,
|
||||
clean_up_tokenization_spaces=True
|
||||
)
|
||||
return self.tokenizer.batch_decode(self.model.generate(
|
||||
self.tokenizer(prompt, return_tensors='np')['input_ids'],
|
||||
do_sample=True,
|
||||
generation_config=self.config.model_construct_env(**attrs).to_generation_config(),
|
||||
decoder_start_token_id=decoder_start_token_id).sequences,
|
||||
skip_special_tokens=True,
|
||||
clean_up_tokenization_spaces=True)
|
||||
|
||||
@@ -8,7 +8,8 @@ class TFFlanT5(openllm.LLM['transformers.TFT5ForConditionalGeneration', 'transfo
|
||||
__openllm_internal__ = True
|
||||
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
|
||||
return self.tokenizer.batch_decode(
|
||||
self.model.generate(self.tokenizer(prompt, return_tensors='tf').input_ids, do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config()),
|
||||
skip_special_tokens=True
|
||||
)
|
||||
return self.tokenizer.batch_decode(self.model.generate(
|
||||
self.tokenizer(prompt, return_tensors='tf').input_ids,
|
||||
do_sample=True,
|
||||
generation_config=self.config.model_construct_env(**attrs).to_generation_config()),
|
||||
skip_special_tokens=True)
|
||||
|
||||
@@ -25,11 +25,8 @@ class GPTNeoX(openllm.LLM['transformers.GPTNeoXForCausalLM', 'transformers.GPTNe
|
||||
import torch
|
||||
with torch.inference_mode():
|
||||
return self.tokenizer.batch_decode(
|
||||
self.model.generate(
|
||||
self.tokenizer(prompt, return_tensors='pt').to(self.device).input_ids,
|
||||
do_sample=True,
|
||||
generation_config=self.config.model_construct_env(**attrs).to_generation_config(),
|
||||
pad_token_id=self.tokenizer.eos_token_id,
|
||||
stopping_criteria=openllm.StoppingCriteriaList([openllm.StopOnTokens()])
|
||||
)
|
||||
)
|
||||
self.model.generate(self.tokenizer(prompt, return_tensors='pt').to(self.device).input_ids,
|
||||
do_sample=True,
|
||||
generation_config=self.config.model_construct_env(**attrs).to_generation_config(),
|
||||
pad_token_id=self.tokenizer.eos_token_id,
|
||||
stopping_criteria=openllm.StoppingCriteriaList([openllm.StopOnTokens()])))
|
||||
|
||||
@@ -23,13 +23,20 @@ class Llama(openllm.LLM['transformers.LlamaForCausalLM', 'transformers.LlamaToke
|
||||
mask = attention_mask.unsqueeze(-1).expand(data.size()).float()
|
||||
masked_embeddings = data * mask
|
||||
sum_embeddings, seq_length = torch.sum(masked_embeddings, dim=1), torch.sum(mask, dim=1)
|
||||
return openllm.LLMEmbeddings(embeddings=F.normalize(sum_embeddings / seq_length, p=2, dim=1).tolist(), num_tokens=int(torch.sum(attention_mask).item()))
|
||||
return openllm.LLMEmbeddings(embeddings=F.normalize(sum_embeddings / seq_length, p=2, dim=1).tolist(),
|
||||
num_tokens=int(torch.sum(attention_mask).item()))
|
||||
|
||||
def generate_one(self, prompt: str, stop: list[str], **preprocess_generate_kwds: t.Any) -> list[dict[t.Literal['generated_text'], str]]:
|
||||
max_new_tokens, encoded_inputs = preprocess_generate_kwds.pop('max_new_tokens', 200), self.tokenizer(prompt, return_tensors='pt').to(self.device)
|
||||
src_len, stopping_criteria = encoded_inputs['input_ids'].shape[1], preprocess_generate_kwds.pop('stopping_criteria', openllm.StoppingCriteriaList([]))
|
||||
def generate_one(self, prompt: str, stop: list[str],
|
||||
**preprocess_generate_kwds: t.Any) -> list[dict[t.Literal['generated_text'], str]]:
|
||||
max_new_tokens, encoded_inputs = preprocess_generate_kwds.pop('max_new_tokens', 200), self.tokenizer(
|
||||
prompt, return_tensors='pt').to(self.device)
|
||||
src_len, stopping_criteria = encoded_inputs['input_ids'].shape[1], preprocess_generate_kwds.pop(
|
||||
'stopping_criteria', openllm.StoppingCriteriaList([]))
|
||||
stopping_criteria.append(openllm.StopSequenceCriteria(stop, self.tokenizer))
|
||||
result = self.tokenizer.decode(self.model.generate(encoded_inputs['input_ids'], max_new_tokens=max_new_tokens, stopping_criteria=stopping_criteria)[0].tolist()[src_len:])
|
||||
result = self.tokenizer.decode(
|
||||
self.model.generate(encoded_inputs['input_ids'],
|
||||
max_new_tokens=max_new_tokens,
|
||||
stopping_criteria=stopping_criteria)[0].tolist()[src_len:])
|
||||
# Inference API returns the stop sequence
|
||||
for stop_seq in stop:
|
||||
if result.endswith(stop_seq): result = result[:-len(stop_seq)]
|
||||
|
||||
@@ -12,12 +12,15 @@ if t.TYPE_CHECKING:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def get_mpt_config(
|
||||
model_id_or_path: str, max_sequence_length: int, device: torch.device | str | int | None, device_map: str | None = None, trust_remote_code: bool = True
|
||||
) -> transformers.PretrainedConfig:
|
||||
def get_mpt_config(model_id_or_path: str,
|
||||
max_sequence_length: int,
|
||||
device: torch.device | str | int | None,
|
||||
device_map: str | None = None,
|
||||
trust_remote_code: bool = True) -> transformers.PretrainedConfig:
|
||||
import torch
|
||||
config = transformers.AutoConfig.from_pretrained(model_id_or_path, trust_remote_code=trust_remote_code)
|
||||
if hasattr(config, 'init_device') and device_map is None and isinstance(device, (str, torch.device)): config.init_device = str(device)
|
||||
if hasattr(config, 'init_device') and device_map is None and isinstance(device, (str, torch.device)):
|
||||
config.init_device = str(device)
|
||||
if hasattr(config, 'attn_config') and is_triton_available(): config.attn_config['attn_impl'] = 'triton'
|
||||
else:
|
||||
logger.debug(
|
||||
@@ -37,7 +40,10 @@ class MPT(openllm.LLM['transformers.PreTrainedModel', 'transformers.GPTNeoXToken
|
||||
@property
|
||||
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
|
||||
import torch
|
||||
return {'device_map': 'auto' if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None, 'torch_dtype': torch.bfloat16 if torch.cuda.is_available() else torch.float32}, {}
|
||||
return {
|
||||
'device_map': 'auto' if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None,
|
||||
'torch_dtype': torch.bfloat16 if torch.cuda.is_available() else torch.float32
|
||||
}, {}
|
||||
|
||||
def import_model(self, *args: t.Any, trust_remote_code: bool = True, **attrs: t.Any) -> bentoml.Model:
|
||||
import torch
|
||||
@@ -46,12 +52,24 @@ class MPT(openllm.LLM['transformers.PreTrainedModel', 'transformers.GPTNeoXToken
|
||||
torch_dtype = attrs.pop('torch_dtype', self.dtype)
|
||||
device_map = attrs.pop('device_map', None)
|
||||
attrs.pop('low_cpu_mem_usage', None)
|
||||
config = get_mpt_config(self.model_id, self.config.max_sequence_length, self.device, device_map=device_map, trust_remote_code=trust_remote_code)
|
||||
config = get_mpt_config(self.model_id,
|
||||
self.config.max_sequence_length,
|
||||
self.device,
|
||||
device_map=device_map,
|
||||
trust_remote_code=trust_remote_code)
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_id, **tokenizer_attrs)
|
||||
if tokenizer.pad_token_id is None: tokenizer.pad_token = tokenizer.eos_token
|
||||
model = transformers.AutoModelForCausalLM.from_pretrained(self.model_id, config=config, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code, device_map=device_map, **attrs)
|
||||
model = transformers.AutoModelForCausalLM.from_pretrained(self.model_id,
|
||||
config=config,
|
||||
torch_dtype=torch_dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
device_map=device_map,
|
||||
**attrs)
|
||||
try:
|
||||
return bentoml.transformers.save_model(self.tag, model, custom_objects={'tokenizer': tokenizer}, labels=generate_labels(self))
|
||||
return bentoml.transformers.save_model(self.tag,
|
||||
model,
|
||||
custom_objects={'tokenizer': tokenizer},
|
||||
labels=generate_labels(self))
|
||||
finally:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@@ -60,10 +78,18 @@ class MPT(openllm.LLM['transformers.PreTrainedModel', 'transformers.GPTNeoXToken
|
||||
torch_dtype = attrs.pop('torch_dtype', self.dtype)
|
||||
device_map = attrs.pop('device_map', None)
|
||||
trust_remote_code = attrs.pop('trust_remote_code', True)
|
||||
config = get_mpt_config(self._bentomodel.path, self.config.max_sequence_length, self.device, device_map=device_map, trust_remote_code=trust_remote_code,)
|
||||
model = transformers.AutoModelForCausalLM.from_pretrained(
|
||||
self._bentomodel.path, config=config, trust_remote_code=trust_remote_code, torch_dtype=torch_dtype, device_map=device_map, **attrs
|
||||
)
|
||||
config = get_mpt_config(self._bentomodel.path,
|
||||
self.config.max_sequence_length,
|
||||
self.device,
|
||||
device_map=device_map,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
model = transformers.AutoModelForCausalLM.from_pretrained(self._bentomodel.path,
|
||||
config=config,
|
||||
trust_remote_code=trust_remote_code,
|
||||
torch_dtype=torch_dtype,
|
||||
device_map=device_map,
|
||||
**attrs)
|
||||
model.tie_weights()
|
||||
return model
|
||||
|
||||
|
||||
@@ -16,29 +16,35 @@ class FlaxOPT(openllm.LLM['transformers.TFOPTForCausalLM', 'transformers.GPT2Tok
|
||||
__openllm_internal__ = True
|
||||
|
||||
def import_model(self, *args: t.Any, trust_remote_code: bool = False, **attrs: t.Any) -> bentoml.Model:
|
||||
config, tokenizer = transformers.AutoConfig.from_pretrained(self.model_id), transformers.AutoTokenizer.from_pretrained(self.model_id, **self.llm_parameters[-1])
|
||||
config, tokenizer = transformers.AutoConfig.from_pretrained(
|
||||
self.model_id), transformers.AutoTokenizer.from_pretrained(self.model_id, **self.llm_parameters[-1])
|
||||
tokenizer.pad_token_id = config.pad_token_id
|
||||
return bentoml.transformers.save_model(
|
||||
self.tag, transformers.FlaxAutoModelForCausalLM.from_pretrained(self.model_id, **attrs), custom_objects={'tokenizer': tokenizer}, labels=generate_labels(self)
|
||||
)
|
||||
return bentoml.transformers.save_model(self.tag,
|
||||
transformers.FlaxAutoModelForCausalLM.from_pretrained(
|
||||
self.model_id, **attrs),
|
||||
custom_objects={'tokenizer': tokenizer},
|
||||
labels=generate_labels(self))
|
||||
|
||||
def sanitize_parameters(
|
||||
self,
|
||||
prompt: str,
|
||||
max_new_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
top_k: int | None = None,
|
||||
num_return_sequences: int | None = None,
|
||||
repetition_penalty: float | None = None,
|
||||
use_default_prompt_template: bool = False,
|
||||
**attrs: t.Any
|
||||
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
def sanitize_parameters(self,
|
||||
prompt: str,
|
||||
max_new_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
top_k: int | None = None,
|
||||
num_return_sequences: int | None = None,
|
||||
repetition_penalty: float | None = None,
|
||||
use_default_prompt_template: bool = False,
|
||||
**attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {
|
||||
'max_new_tokens': max_new_tokens, 'temperature': temperature, 'top_k': top_k, 'num_return_sequences': num_return_sequences, 'repetition_penalty': repetition_penalty
|
||||
'max_new_tokens': max_new_tokens,
|
||||
'temperature': temperature,
|
||||
'top_k': top_k,
|
||||
'num_return_sequences': num_return_sequences,
|
||||
'repetition_penalty': repetition_penalty
|
||||
}, {}
|
||||
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
|
||||
return self.tokenizer.batch_decode(
|
||||
self.model.generate(**self.tokenizer(prompt, return_tensors='np'), do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config()).sequences,
|
||||
skip_special_tokens=True
|
||||
)
|
||||
return self.tokenizer.batch_decode(self.model.generate(**self.tokenizer(prompt, return_tensors='np'),
|
||||
do_sample=True,
|
||||
generation_config=self.config.model_construct_env(
|
||||
**attrs).to_generation_config()).sequences,
|
||||
skip_special_tokens=True)
|
||||
|
||||
@@ -19,6 +19,7 @@ class OPT(openllm.LLM['transformers.OPTForCausalLM', 'transformers.GPT2Tokenizer
|
||||
import torch
|
||||
with torch.inference_mode():
|
||||
return self.tokenizer.batch_decode(
|
||||
self.model.generate(**self.tokenizer(prompt, return_tensors='pt').to(self.device), do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config()),
|
||||
skip_special_tokens=True
|
||||
)
|
||||
self.model.generate(**self.tokenizer(prompt, return_tensors='pt').to(self.device),
|
||||
do_sample=True,
|
||||
generation_config=self.config.model_construct_env(**attrs).to_generation_config()),
|
||||
skip_special_tokens=True)
|
||||
|
||||
@@ -11,17 +11,18 @@ class TFOPT(openllm.LLM['transformers.TFOPTForCausalLM', 'transformers.GPT2Token
|
||||
|
||||
def import_model(self, *args: t.Any, trust_remote_code: bool = False, **attrs: t.Any) -> bentoml.Model:
|
||||
import transformers
|
||||
config, tokenizer = transformers.AutoConfig.from_pretrained(self.model_id), transformers.AutoTokenizer.from_pretrained(self.model_id, **self.llm_parameters[-1])
|
||||
config, tokenizer = transformers.AutoConfig.from_pretrained(
|
||||
self.model_id), transformers.AutoTokenizer.from_pretrained(self.model_id, **self.llm_parameters[-1])
|
||||
tokenizer.pad_token_id = config.pad_token_id
|
||||
return bentoml.transformers.save_model(
|
||||
self.tag,
|
||||
transformers.TFOPTForCausalLM.from_pretrained(self.model_id, trust_remote_code=trust_remote_code, **attrs),
|
||||
custom_objects={'tokenizer': tokenizer},
|
||||
labels=generate_labels(self)
|
||||
)
|
||||
return bentoml.transformers.save_model(self.tag,
|
||||
transformers.TFOPTForCausalLM.from_pretrained(
|
||||
self.model_id, trust_remote_code=trust_remote_code, **attrs),
|
||||
custom_objects={'tokenizer': tokenizer},
|
||||
labels=generate_labels(self))
|
||||
|
||||
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
|
||||
return self.tokenizer.batch_decode(
|
||||
self.model.generate(**self.tokenizer(prompt, return_tensors='tf'), do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config()),
|
||||
skip_special_tokens=True
|
||||
)
|
||||
self.model.generate(**self.tokenizer(prompt, return_tensors='tf'),
|
||||
do_sample=True,
|
||||
generation_config=self.config.model_construct_env(**attrs).to_generation_config()),
|
||||
skip_special_tokens=True)
|
||||
|
||||
@@ -10,16 +10,17 @@ class VLLMOPT(openllm.LLM['vllm.LLMEngine', 'transformers.GPT2Tokenizer']):
|
||||
__openllm_internal__ = True
|
||||
tokenizer_id = 'local'
|
||||
|
||||
def sanitize_parameters(
|
||||
self,
|
||||
prompt: str,
|
||||
max_new_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
top_k: int | None = None,
|
||||
num_return_sequences: int | None = None,
|
||||
use_default_prompt_template: bool = True,
|
||||
**attrs: t.Any
|
||||
) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
def sanitize_parameters(self,
|
||||
prompt: str,
|
||||
max_new_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
top_k: int | None = None,
|
||||
num_return_sequences: int | None = None,
|
||||
use_default_prompt_template: bool = True,
|
||||
**attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
|
||||
return process_prompt(prompt, DEFAULT_PROMPT_TEMPLATE, use_default_prompt_template, **attrs), {
|
||||
'max_new_tokens': max_new_tokens, 'temperature': temperature, 'top_k': top_k, 'num_return_sequences': num_return_sequences
|
||||
'max_new_tokens': max_new_tokens,
|
||||
'temperature': temperature,
|
||||
'top_k': top_k,
|
||||
'num_return_sequences': num_return_sequences
|
||||
}, {}
|
||||
|
||||
@@ -22,13 +22,10 @@ class StableLM(openllm.LLM['transformers.GPTNeoXForCausalLM', 'transformers.GPTN
|
||||
with torch.inference_mode():
|
||||
return [
|
||||
self.tokenizer.decode(
|
||||
self.model.generate(
|
||||
**self.tokenizer(prompt, return_tensors='pt').to(self.device),
|
||||
do_sample=True,
|
||||
generation_config=self.config.model_construct_env(**attrs).to_generation_config(),
|
||||
pad_token_id=self.tokenizer.eos_token_id,
|
||||
stopping_criteria=openllm.StoppingCriteriaList([openllm.StopOnTokens()])
|
||||
)[0],
|
||||
skip_special_tokens=True
|
||||
)
|
||||
self.model.generate(**self.tokenizer(prompt, return_tensors='pt').to(self.device),
|
||||
do_sample=True,
|
||||
generation_config=self.config.model_construct_env(**attrs).to_generation_config(),
|
||||
pad_token_id=self.tokenizer.eos_token_id,
|
||||
stopping_criteria=openllm.StoppingCriteriaList([openllm.StopOnTokens()]))[0],
|
||||
skip_special_tokens=True)
|
||||
]
|
||||
|
||||
@@ -18,17 +18,29 @@ class StarCoder(openllm.LLM['transformers.GPTBigCodeForCausalLM', 'transformers.
|
||||
@property
|
||||
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
|
||||
import torch
|
||||
return {'device_map': 'auto' if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None, 'torch_dtype': torch.float16 if torch.cuda.is_available() else torch.float32}, {}
|
||||
return {
|
||||
'device_map': 'auto' if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None,
|
||||
'torch_dtype': torch.float16 if torch.cuda.is_available() else torch.float32
|
||||
}, {}
|
||||
|
||||
def import_model(self, *args: t.Any, trust_remote_code: bool = False, **attrs: t.Any) -> bentoml.Model:
|
||||
import torch
|
||||
import transformers
|
||||
torch_dtype, device_map = attrs.pop('torch_dtype', torch.float16), attrs.pop('device_map', 'auto')
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_id, **self.llm_parameters[-1])
|
||||
tokenizer.add_special_tokens({'additional_special_tokens': [EOD, FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD], 'pad_token': EOD})
|
||||
model = transformers.AutoModelForCausalLM.from_pretrained(self.model_id, torch_dtype=torch_dtype, device_map=device_map, **attrs)
|
||||
tokenizer.add_special_tokens({
|
||||
'additional_special_tokens': [EOD, FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD],
|
||||
'pad_token': EOD
|
||||
})
|
||||
model = transformers.AutoModelForCausalLM.from_pretrained(self.model_id,
|
||||
torch_dtype=torch_dtype,
|
||||
device_map=device_map,
|
||||
**attrs)
|
||||
try:
|
||||
return bentoml.transformers.save_model(self.tag, model, custom_objects={'tokenizer': tokenizer}, labels=generate_labels(self))
|
||||
return bentoml.transformers.save_model(self.tag,
|
||||
model,
|
||||
custom_objects={'tokenizer': tokenizer},
|
||||
labels=generate_labels(self))
|
||||
finally:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@@ -41,17 +53,22 @@ class StarCoder(openllm.LLM['transformers.GPTBigCodeForCausalLM', 'transformers.
|
||||
self.tokenizer.encode(prompt, return_tensors='pt').to(self.device),
|
||||
do_sample=True,
|
||||
pad_token_id=self.tokenizer.eos_token_id,
|
||||
generation_config=self.config.model_construct_env(**attrs).to_generation_config()
|
||||
)
|
||||
generation_config=self.config.model_construct_env(**attrs).to_generation_config())
|
||||
# TODO: We will probably want to return the tokenizer here so that we can manually process this
|
||||
# return (skip_special_tokens=False, clean_up_tokenization_spaces=False))
|
||||
return self.tokenizer.batch_decode(result_tensor[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
||||
|
||||
def generate_one(self, prompt: str, stop: list[str], **preprocess_generate_kwds: t.Any) -> list[dict[t.Literal['generated_text'], str]]:
|
||||
max_new_tokens, encoded_inputs = preprocess_generate_kwds.pop('max_new_tokens', 200), self.tokenizer(prompt, return_tensors='pt').to(self.device)
|
||||
src_len, stopping_criteria = encoded_inputs['input_ids'].shape[1], preprocess_generate_kwds.pop('stopping_criteria', openllm.StoppingCriteriaList([]))
|
||||
def generate_one(self, prompt: str, stop: list[str],
|
||||
**preprocess_generate_kwds: t.Any) -> list[dict[t.Literal['generated_text'], str]]:
|
||||
max_new_tokens, encoded_inputs = preprocess_generate_kwds.pop('max_new_tokens', 200), self.tokenizer(
|
||||
prompt, return_tensors='pt').to(self.device)
|
||||
src_len, stopping_criteria = encoded_inputs['input_ids'].shape[1], preprocess_generate_kwds.pop(
|
||||
'stopping_criteria', openllm.StoppingCriteriaList([]))
|
||||
stopping_criteria.append(openllm.StopSequenceCriteria(stop, self.tokenizer))
|
||||
result = self.tokenizer.decode(self.model.generate(encoded_inputs['input_ids'], max_new_tokens=max_new_tokens, stopping_criteria=stopping_criteria)[0].tolist()[src_len:])
|
||||
result = self.tokenizer.decode(
|
||||
self.model.generate(encoded_inputs['input_ids'],
|
||||
max_new_tokens=max_new_tokens,
|
||||
stopping_criteria=stopping_criteria)[0].tolist()[src_len:])
|
||||
# Inference API returns the stop sequence
|
||||
for stop_seq in stop:
|
||||
if result.endswith(stop_seq): result = result[:-len(stop_seq)]
|
||||
|
||||
@@ -56,20 +56,34 @@ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
else:
|
||||
model_args, training_args = t.cast(t.Tuple[ModelArguments, TrainingArguments], parser.parse_args_into_dataclasses())
|
||||
|
||||
model, tokenizer = openllm.AutoLLM.for_model("falcon", model_id=model_args.model_id, quantize="int4", bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16, ensure_available=True).prepare_for_training(adapter_type="lora", lora_alpha=16, lora_dropout=0.1, r=16, bias="none", target_modules=["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"])
|
||||
model, tokenizer = openllm.AutoLLM.for_model("falcon",
|
||||
model_id=model_args.model_id,
|
||||
quantize="int4",
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_compute_dtype=torch.float16,
|
||||
ensure_available=True).prepare_for_training(adapter_type="lora",
|
||||
lora_alpha=16,
|
||||
lora_dropout=0.1,
|
||||
r=16,
|
||||
bias="none",
|
||||
target_modules=[
|
||||
"query_key_value", "dense",
|
||||
"dense_h_to_4h",
|
||||
"dense_4h_to_h"
|
||||
])
|
||||
model.config.use_cache = False
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
dataset = load_dataset(DATASET_NAME, split="train")
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
train_dataset=dataset,
|
||||
dataset_text_field="text",
|
||||
max_seq_length=model_args.max_sequence_length,
|
||||
tokenizer=tokenizer,
|
||||
args=dataclasses.replace(transformers.TrainingArguments(training_args.output_dir), **dataclasses.asdict(training_args)),
|
||||
)
|
||||
trainer = SFTTrainer(model=model,
|
||||
train_dataset=dataset,
|
||||
dataset_text_field="text",
|
||||
max_seq_length=model_args.max_sequence_length,
|
||||
tokenizer=tokenizer,
|
||||
args=dataclasses.replace(transformers.TrainingArguments(training_args.output_dir),
|
||||
**dataclasses.asdict(training_args)),
|
||||
)
|
||||
|
||||
# upcast layernorm in float32 for more stable training
|
||||
for name, module in trainer.model.named_modules():
|
||||
|
||||
@@ -75,10 +75,13 @@ def chunk(sample, chunk_length=2048):
|
||||
|
||||
# get max number of chunks for batch
|
||||
if batch_total_length >= chunk_length:
|
||||
batch_chunk_length = (batch_total_length//chunk_length) * chunk_length
|
||||
batch_chunk_length = (batch_total_length // chunk_length) * chunk_length
|
||||
|
||||
# Split by chunks of max_len.
|
||||
result = {k: [t[i:i + chunk_length] for i in range(0, batch_chunk_length, chunk_length)] for k, t in concatenated_examples.items()}
|
||||
result = {
|
||||
k: [t[i:i + chunk_length] for i in range(0, batch_chunk_length, chunk_length)]
|
||||
for k, t in concatenated_examples.items()
|
||||
}
|
||||
# add remainder to global variable for next batch
|
||||
remainder = {k: concatenated_examples[k][batch_chunk_length:] for k in concatenated_examples.keys()}
|
||||
# prepare labels
|
||||
@@ -98,33 +101,39 @@ def prepare_datasets(tokenizer, dataset_name=DATASET_NAME):
|
||||
print("Sample from dolly-v2 ds:", dataset[randint(0, len(dataset))]["text"])
|
||||
|
||||
# tokenize and chunk dataset
|
||||
lm_dataset = dataset.map(lambda sample: tokenizer(sample["text"]), batched=True, remove_columns=list(dataset.features)).map(partial(chunk, chunk_length=2048), batched=True,)
|
||||
lm_dataset = dataset.map(lambda sample: tokenizer(sample["text"]),
|
||||
batched=True,
|
||||
remove_columns=list(dataset.features)).map(partial(chunk, chunk_length=2048), batched=True)
|
||||
|
||||
# Print total number of samples
|
||||
print(f"Total number of samples: {len(lm_dataset)}")
|
||||
return lm_dataset
|
||||
|
||||
def prepare_for_int4_training(model_id: str, model_version: str | None = None, gradient_checkpointing: bool = True, bf16: bool = True,
|
||||
) -> tuple[peft.PeftModel, transformers.LlamaTokenizerFast]:
|
||||
def prepare_for_int4_training(model_id: str,
|
||||
model_version: str | None = None,
|
||||
gradient_checkpointing: bool = True,
|
||||
bf16: bool = True,
|
||||
) -> tuple[peft.PeftModel, transformers.LlamaTokenizerFast]:
|
||||
from peft.tuners.lora import LoraLayer
|
||||
|
||||
llm = openllm.AutoLLM.for_model(
|
||||
"llama",
|
||||
model_id=model_id,
|
||||
model_version=model_version,
|
||||
ensure_available=True,
|
||||
quantize="int4",
|
||||
bnb_4bit_compute_dtype=torch.bfloat16,
|
||||
use_cache=not gradient_checkpointing,
|
||||
device_map="auto",
|
||||
)
|
||||
llm = openllm.AutoLLM.for_model("llama",
|
||||
model_id=model_id,
|
||||
model_version=model_version,
|
||||
ensure_available=True,
|
||||
quantize="int4",
|
||||
bnb_4bit_compute_dtype=torch.bfloat16,
|
||||
use_cache=not gradient_checkpointing,
|
||||
device_map="auto",
|
||||
)
|
||||
print("Model summary:", llm.model)
|
||||
|
||||
# get lora target modules
|
||||
modules = find_all_linear_names(llm.model)
|
||||
print(f"Found {len(modules)} modules to quantize: {modules}")
|
||||
|
||||
model, tokenizer = llm.prepare_for_training(adapter_type="lora", use_gradient_checkpointing=gradient_checkpointing, target_modules=modules)
|
||||
model, tokenizer = llm.prepare_for_training(adapter_type="lora",
|
||||
use_gradient_checkpointing=gradient_checkpointing,
|
||||
target_modules=modules)
|
||||
|
||||
# pre-process the model by upcasting the layer norms in float 32 for
|
||||
for name, module in model.named_modules():
|
||||
@@ -177,15 +186,18 @@ def train_loop(model_args: ModelArguments, training_args: TrainingArguments):
|
||||
|
||||
transformers.set_seed(model_args.seed)
|
||||
|
||||
model, tokenizer = prepare_for_int4_training(model_args.model_id, gradient_checkpointing=training_args.gradient_checkpointing, bf16=training_args.bf16,)
|
||||
model, tokenizer = prepare_for_int4_training(model_args.model_id,
|
||||
gradient_checkpointing=training_args.gradient_checkpointing,
|
||||
bf16=training_args.bf16,
|
||||
)
|
||||
datasets = prepare_datasets(tokenizer)
|
||||
|
||||
trainer = transformers.Trainer(
|
||||
model=model,
|
||||
args=dataclasses.replace(transformers.TrainingArguments(training_args.output_dir), **dataclasses.asdict(training_args)),
|
||||
train_dataset=datasets,
|
||||
data_collator=transformers.default_data_collator,
|
||||
)
|
||||
trainer = transformers.Trainer(model=model,
|
||||
args=dataclasses.replace(transformers.TrainingArguments(training_args.output_dir),
|
||||
**dataclasses.asdict(training_args)),
|
||||
train_dataset=datasets,
|
||||
data_collator=transformers.default_data_collator,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
|
||||
@@ -200,10 +212,14 @@ def train_loop(model_args: ModelArguments, training_args: TrainingArguments):
|
||||
del model, trainer
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
model = peft.AutoPeftModelForCausalLM.from_pretrained(training_args.output_dir, low_cpu_mem_usage=True, torch_dtype=torch.float16)
|
||||
model = peft.AutoPeftModelForCausalLM.from_pretrained(training_args.output_dir,
|
||||
low_cpu_mem_usage=True,
|
||||
torch_dtype=torch.float16)
|
||||
# merge lora with base weights and save
|
||||
model = model.merge_and_unload()
|
||||
model.save_pretrained(os.path.join(os.getcwd(), "outputs", "merged_llama_lora"), safe_serialization=True, max_shard_size="2GB")
|
||||
model.save_pretrained(os.path.join(os.getcwd(), "outputs", "merged_llama_lora"),
|
||||
safe_serialization=True,
|
||||
max_shard_size="2GB")
|
||||
else:
|
||||
trainer.model.save_pretrained(os.path.join(training_args.output_dir, "lora"))
|
||||
|
||||
|
||||
@@ -26,13 +26,14 @@ if t.TYPE_CHECKING:
|
||||
|
||||
DEFAULT_MODEL_ID = "facebook/opt-6.7b"
|
||||
|
||||
def load_trainer(model: PeftModel, tokenizer: transformers.GPT2TokenizerFast, dataset_dict: t.Any, training_args: TrainingArguments):
|
||||
return transformers.Trainer(
|
||||
model=model,
|
||||
train_dataset=dataset_dict["train"],
|
||||
args=dataclasses.replace(transformers.TrainingArguments(training_args.output_dir), **dataclasses.asdict(training_args)),
|
||||
data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
|
||||
)
|
||||
def load_trainer(model: PeftModel, tokenizer: transformers.GPT2TokenizerFast, dataset_dict: t.Any,
|
||||
training_args: TrainingArguments):
|
||||
return transformers.Trainer(model=model,
|
||||
train_dataset=dataset_dict["train"],
|
||||
args=dataclasses.replace(transformers.TrainingArguments(training_args.output_dir),
|
||||
**dataclasses.asdict(training_args)),
|
||||
data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
|
||||
)
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TrainingArguments:
|
||||
@@ -57,7 +58,16 @@ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
else:
|
||||
model_args, training_args = t.cast(t.Tuple[ModelArguments, TrainingArguments], parser.parse_args_into_dataclasses())
|
||||
|
||||
model, tokenizer = openllm.AutoLLM.for_model("opt", model_id=model_args.model_id, quantize="int8", ensure_available=True).prepare_for_training(adapter_type="lora", r=16, lora_alpha=32, target_modules=["q_proj", "v_proj"], lora_dropout=0.05, bias="none")
|
||||
model, tokenizer = openllm.AutoLLM.for_model("opt",
|
||||
model_id=model_args.model_id,
|
||||
quantize="int8",
|
||||
ensure_available=True).prepare_for_training(
|
||||
adapter_type="lora",
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
target_modules=["q_proj", "v_proj"],
|
||||
lora_dropout=0.05,
|
||||
bias="none")
|
||||
|
||||
# ft on english_quotes
|
||||
data = load_dataset("Abirate/english_quotes")
|
||||
|
||||
@@ -64,10 +64,11 @@ def load_tokenizer(llm: openllm.LLM[t.Any, T], **tokenizer_attrs: t.Any) -> T:
|
||||
raise openllm.exceptions.OpenLLMException(
|
||||
"Bento model does not have tokenizer. Make sure to save"
|
||||
" the tokenizer within the model via 'custom_objects'."
|
||||
" For example: \"bentoml.transformers.save_model(..., custom_objects={'tokenizer': tokenizer})\""
|
||||
) from None
|
||||
" For example: \"bentoml.transformers.save_model(..., custom_objects={'tokenizer': tokenizer})\"") from None
|
||||
else:
|
||||
tokenizer = infer_tokenizers_from_llm(llm).from_pretrained(bentomodel_fs.getsyspath('/'), trust_remote_code=llm.__llm_trust_remote_code__, **tokenizer_attrs)
|
||||
tokenizer = infer_tokenizers_from_llm(llm).from_pretrained(bentomodel_fs.getsyspath('/'),
|
||||
trust_remote_code=llm.__llm_trust_remote_code__,
|
||||
**tokenizer_attrs)
|
||||
|
||||
if tokenizer.pad_token_id is None:
|
||||
if config.pad_token_id is not None: tokenizer.pad_token_id = config.pad_token_id
|
||||
@@ -77,12 +78,14 @@ def load_tokenizer(llm: openllm.LLM[t.Any, T], **tokenizer_attrs: t.Any) -> T:
|
||||
return tokenizer
|
||||
|
||||
class _Caller(t.Protocol[P]):
|
||||
|
||||
def __call__(self, llm: openllm.LLM[M, T], *args: P.args, **kwargs: P.kwargs) -> t.Any:
|
||||
...
|
||||
|
||||
_extras = ['get', 'import_model', 'save_pretrained', 'load_model']
|
||||
|
||||
def _make_dispatch_function(fn: str) -> _Caller[P]:
|
||||
|
||||
def caller(llm: openllm.LLM[M, T], *args: P.args, **kwargs: P.kwargs) -> t.Any:
|
||||
"""Generic function dispatch to correct serialisation submodules based on LLM runtime.
|
||||
|
||||
|
||||
@@ -6,4 +6,7 @@ FRAMEWORK_TO_AUTOCLASS_MAPPING = {
|
||||
'flax': ('FlaxAutoModelForCausalLM', 'FlaxAutoModelForSeq2SeqLM'),
|
||||
'vllm': ('AutoModelForCausalLM', 'AutoModelForSeq2SeqLM')
|
||||
}
|
||||
HUB_ATTRS = ['cache_dir', 'code_revision', 'force_download', 'local_files_only', 'proxies', 'resume_download', 'revision', 'subfolder', 'use_auth_token']
|
||||
HUB_ATTRS = [
|
||||
'cache_dir', 'code_revision', 'force_download', 'local_files_only', 'proxies', 'resume_download', 'revision',
|
||||
'subfolder', 'use_auth_token'
|
||||
]
|
||||
|
||||
@@ -13,7 +13,11 @@ if t.TYPE_CHECKING:
|
||||
|
||||
_conversion_strategy = {'pt': 'ggml'}
|
||||
|
||||
def import_model(llm: openllm.LLM[t.Any, t.Any], *decls: t.Any, trust_remote_code: bool = True, **attrs: t.Any,) -> bentoml.Model:
|
||||
def import_model(llm: openllm.LLM[t.Any, t.Any],
|
||||
*decls: t.Any,
|
||||
trust_remote_code: bool = True,
|
||||
**attrs: t.Any,
|
||||
) -> bentoml.Model:
|
||||
raise NotImplementedError('Currently work in progress.')
|
||||
|
||||
def get(llm: openllm.LLM[t.Any, t.Any], auto_import: bool = False) -> bentoml.Model:
|
||||
@@ -27,9 +31,12 @@ def get(llm: openllm.LLM[t.Any, t.Any], auto_import: bool = False) -> bentoml.Mo
|
||||
try:
|
||||
model = bentoml.models.get(llm.tag)
|
||||
if model.info.module not in ('openllm.serialisation.ggml', __name__):
|
||||
raise bentoml.exceptions.NotFound(f"Model {model.tag} was saved with module {model.info.module}, not loading with 'openllm.serialisation.transformers'.")
|
||||
raise bentoml.exceptions.NotFound(
|
||||
f"Model {model.tag} was saved with module {model.info.module}, not loading with 'openllm.serialisation.transformers'."
|
||||
)
|
||||
if 'runtime' in model.info.labels and model.info.labels['runtime'] != llm.runtime:
|
||||
raise openllm.exceptions.OpenLLMException(f"Model {model.tag} was saved with runtime {model.info.labels['runtime']}, not loading with {llm.runtime}.")
|
||||
raise openllm.exceptions.OpenLLMException(
|
||||
f"Model {model.tag} was saved with runtime {model.info.labels['runtime']}, not loading with {llm.runtime}.")
|
||||
return model
|
||||
except bentoml.exceptions.NotFound:
|
||||
if auto_import:
|
||||
|
||||
@@ -46,7 +46,11 @@ logger = logging.getLogger(__name__)
|
||||
__all__ = ['import_model', 'get', 'load_model', 'save_pretrained']
|
||||
|
||||
@inject
|
||||
def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool, _model_store: ModelStore = Provide[BentoMLContainer.model_store], **attrs: t.Any) -> bentoml.Model:
|
||||
def import_model(llm: openllm.LLM[M, T],
|
||||
*decls: t.Any,
|
||||
trust_remote_code: bool,
|
||||
_model_store: ModelStore = Provide[BentoMLContainer.model_store],
|
||||
**attrs: t.Any) -> bentoml.Model:
|
||||
"""Auto detect model type from given model_id and import it to bentoml's model store.
|
||||
|
||||
For all kwargs, it will be parsed into `transformers.AutoConfig.from_pretrained` first,
|
||||
@@ -67,80 +71,110 @@ def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool,
|
||||
config, hub_attrs, attrs = process_config(llm.model_id, trust_remote_code, **attrs)
|
||||
_, tokenizer_attrs = llm.llm_parameters
|
||||
quantize_method = llm._quantize_method
|
||||
safe_serialisation = openllm.utils.first_not_none(attrs.get('safe_serialization'), default=llm._serialisation_format == 'safetensors')
|
||||
safe_serialisation = openllm.utils.first_not_none(attrs.get('safe_serialization'),
|
||||
default=llm._serialisation_format == 'safetensors')
|
||||
# Disable safe serialization with vLLM
|
||||
if llm.__llm_implementation__ == 'vllm': safe_serialisation = False
|
||||
metadata: DictStrAny = {'safe_serialisation': safe_serialisation, '_quantize': quantize_method is not None and quantize_method}
|
||||
metadata: DictStrAny = {
|
||||
'safe_serialisation': safe_serialisation,
|
||||
'_quantize': quantize_method is not None and quantize_method
|
||||
}
|
||||
signatures: DictStrAny = {}
|
||||
|
||||
if quantize_method == 'gptq':
|
||||
if not openllm.utils.is_autogptq_available():
|
||||
raise openllm.exceptions.OpenLLMException("GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'")
|
||||
if llm.config['model_type'] != 'causal_lm': raise openllm.exceptions.OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
|
||||
raise openllm.exceptions.OpenLLMException(
|
||||
"GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'"
|
||||
)
|
||||
if llm.config['model_type'] != 'causal_lm':
|
||||
raise openllm.exceptions.OpenLLMException(
|
||||
f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
|
||||
signatures['generate'] = {'batchable': False}
|
||||
else:
|
||||
# this model might be called with --quantize int4, therefore we need to pop this out
|
||||
# since saving int4 is not yet supported
|
||||
if 'quantization_config' in attrs and getattr(attrs['quantization_config'], 'load_in_4bit', False): attrs.pop('quantization_config')
|
||||
if 'quantization_config' in attrs and getattr(attrs['quantization_config'], 'load_in_4bit', False):
|
||||
attrs.pop('quantization_config')
|
||||
if llm.__llm_implementation__ != 'flax': attrs['use_safetensors'] = safe_serialisation
|
||||
metadata['_framework'] = 'pt' if llm.__llm_implementation__ == 'vllm' else llm.__llm_implementation__
|
||||
|
||||
tokenizer = infer_tokenizers_from_llm(llm).from_pretrained(llm.model_id, trust_remote_code=trust_remote_code, **hub_attrs, **tokenizer_attrs)
|
||||
tokenizer = infer_tokenizers_from_llm(llm).from_pretrained(llm.model_id,
|
||||
trust_remote_code=trust_remote_code,
|
||||
**hub_attrs,
|
||||
**tokenizer_attrs)
|
||||
if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
external_modules: list[types.ModuleType] = [importlib.import_module(tokenizer.__module__)]
|
||||
imported_modules: list[types.ModuleType] = []
|
||||
bentomodel = bentoml.Model.create(
|
||||
llm.tag,
|
||||
module='openllm.serialisation.transformers',
|
||||
api_version='v1',
|
||||
options=ModelOptions(),
|
||||
context=openllm.utils.generate_context(framework_name='openllm'),
|
||||
labels=openllm.utils.generate_labels(llm),
|
||||
signatures=signatures if signatures else make_model_signatures(llm)
|
||||
)
|
||||
bentomodel = bentoml.Model.create(llm.tag,
|
||||
module='openllm.serialisation.transformers',
|
||||
api_version='v1',
|
||||
options=ModelOptions(),
|
||||
context=openllm.utils.generate_context(framework_name='openllm'),
|
||||
labels=openllm.utils.generate_labels(llm),
|
||||
signatures=signatures if signatures else make_model_signatures(llm))
|
||||
with openllm.utils.analytics.set_bentoml_tracking():
|
||||
try:
|
||||
bentomodel.enter_cloudpickle_context(external_modules, imported_modules)
|
||||
tokenizer.save_pretrained(bentomodel.path)
|
||||
if quantize_method == 'gptq':
|
||||
if not openllm.utils.is_autogptq_available():
|
||||
raise openllm.exceptions.OpenLLMException("GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'")
|
||||
if llm.config['model_type'] != 'causal_lm': raise openllm.exceptions.OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
|
||||
raise openllm.exceptions.OpenLLMException(
|
||||
"GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'"
|
||||
)
|
||||
if llm.config['model_type'] != 'causal_lm':
|
||||
raise openllm.exceptions.OpenLLMException(
|
||||
f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
|
||||
logger.debug('Saving model with GPTQ quantisation will require loading model into memory.')
|
||||
model = autogptq.AutoGPTQForCausalLM.from_quantized(
|
||||
llm.model_id,
|
||||
*decls,
|
||||
quantize_config=t.cast('autogptq.BaseQuantizeConfig', llm.quantization_config),
|
||||
trust_remote_code=trust_remote_code,
|
||||
use_safetensors=safe_serialisation,
|
||||
**hub_attrs,
|
||||
**attrs,
|
||||
)
|
||||
update_model(bentomodel, metadata={'_pretrained_class': model.__class__.__name__, '_framework': model.model.framework})
|
||||
model = autogptq.AutoGPTQForCausalLM.from_quantized(llm.model_id,
|
||||
*decls,
|
||||
quantize_config=t.cast('autogptq.BaseQuantizeConfig',
|
||||
llm.quantization_config),
|
||||
trust_remote_code=trust_remote_code,
|
||||
use_safetensors=safe_serialisation,
|
||||
**hub_attrs,
|
||||
**attrs,
|
||||
)
|
||||
update_model(bentomodel,
|
||||
metadata={
|
||||
'_pretrained_class': model.__class__.__name__,
|
||||
'_framework': model.model.framework
|
||||
})
|
||||
model.save_quantized(bentomodel.path, use_safetensors=safe_serialisation)
|
||||
else:
|
||||
architectures = getattr(config, 'architectures', [])
|
||||
if not architectures:
|
||||
raise RuntimeError('Failed to determine the architecture for this model. Make sure the `config.json` is valid and can be loaded with `transformers.AutoConfig`')
|
||||
raise RuntimeError(
|
||||
'Failed to determine the architecture for this model. Make sure the `config.json` is valid and can be loaded with `transformers.AutoConfig`'
|
||||
)
|
||||
architecture = architectures[0]
|
||||
update_model(bentomodel, metadata={'_pretrained_class': architecture})
|
||||
if llm._local:
|
||||
# possible local path
|
||||
logger.debug('Model will be loaded into memory to save to target store as it is from local path.')
|
||||
model = infer_autoclass_from_llm(llm, config).from_pretrained(llm.model_id, *decls, config=config, trust_remote_code=trust_remote_code, **hub_attrs, **attrs)
|
||||
model = infer_autoclass_from_llm(llm, config).from_pretrained(llm.model_id,
|
||||
*decls,
|
||||
config=config,
|
||||
trust_remote_code=trust_remote_code,
|
||||
**hub_attrs,
|
||||
**attrs)
|
||||
# for trust_remote_code to work
|
||||
bentomodel.enter_cloudpickle_context([importlib.import_module(model.__module__)], imported_modules)
|
||||
model.save_pretrained(bentomodel.path, max_shard_size='5GB', safe_serialization=safe_serialisation)
|
||||
else:
|
||||
# we will clone the all tings into the bentomodel path without loading model into memory
|
||||
snapshot_download(llm.model_id, local_dir=bentomodel.path, local_dir_use_symlinks=False, ignore_patterns=HfIgnore.ignore_patterns(llm))
|
||||
snapshot_download(llm.model_id,
|
||||
local_dir=bentomodel.path,
|
||||
local_dir_use_symlinks=False,
|
||||
ignore_patterns=HfIgnore.ignore_patterns(llm))
|
||||
except Exception:
|
||||
raise
|
||||
else:
|
||||
bentomodel.flush() # type: ignore[no-untyped-call]
|
||||
bentomodel.save(_model_store)
|
||||
openllm.utils.analytics.track(openllm.utils.analytics.ModelSaveEvent(module=bentomodel.info.module, model_size_in_kb=openllm.utils.calc_dir_size(bentomodel.path) / 1024))
|
||||
openllm.utils.analytics.track(
|
||||
openllm.utils.analytics.ModelSaveEvent(module=bentomodel.info.module,
|
||||
model_size_in_kb=openllm.utils.calc_dir_size(bentomodel.path) / 1024))
|
||||
finally:
|
||||
bentomodel.exit_cloudpickle_context(imported_modules)
|
||||
# NOTE: We need to free up the cache after importing the model
|
||||
@@ -158,13 +192,15 @@ def get(llm: openllm.LLM[M, T], auto_import: bool = False) -> bentoml.Model:
|
||||
'''
|
||||
try:
|
||||
model = bentoml.models.get(llm.tag)
|
||||
if model.info.module not in (
|
||||
'openllm.serialisation.transformers'
|
||||
'bentoml.transformers', 'bentoml._internal.frameworks.transformers', __name__
|
||||
): # NOTE: backward compatible with previous version of OpenLLM.
|
||||
raise bentoml.exceptions.NotFound(f"Model {model.tag} was saved with module {model.info.module}, not loading with 'openllm.serialisation.transformers'.")
|
||||
if model.info.module not in ('openllm.serialisation.transformers'
|
||||
'bentoml.transformers', 'bentoml._internal.frameworks.transformers',
|
||||
__name__): # NOTE: backward compatible with previous version of OpenLLM.
|
||||
raise bentoml.exceptions.NotFound(
|
||||
f"Model {model.tag} was saved with module {model.info.module}, not loading with 'openllm.serialisation.transformers'."
|
||||
)
|
||||
if 'runtime' in model.info.labels and model.info.labels['runtime'] != llm.runtime:
|
||||
raise openllm.exceptions.OpenLLMException(f"Model {model.tag} was saved with runtime {model.info.labels['runtime']}, not loading with {llm.runtime}.")
|
||||
raise openllm.exceptions.OpenLLMException(
|
||||
f"Model {model.tag} was saved with runtime {model.info.labels['runtime']}, not loading with {llm.runtime}.")
|
||||
return model
|
||||
except bentoml.exceptions.NotFound as err:
|
||||
if auto_import: return import_model(llm, trust_remote_code=llm.__llm_trust_remote_code__)
|
||||
@@ -177,44 +213,50 @@ def load_model(llm: openllm.LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M:
|
||||
If model is not found, it will raises a ``bentoml.exceptions.NotFound``.
|
||||
'''
|
||||
config, hub_attrs, attrs = process_config(llm.model_id, llm.__llm_trust_remote_code__, **attrs)
|
||||
safe_serialization = openllm.utils.first_not_none(
|
||||
t.cast(t.Optional[bool], llm._bentomodel.info.metadata.get('safe_serialisation', None)), attrs.pop('safe_serialization', None), default=llm._serialisation_format == 'safetensors'
|
||||
)
|
||||
safe_serialization = openllm.utils.first_not_none(t.cast(
|
||||
t.Optional[bool], llm._bentomodel.info.metadata.get('safe_serialisation', None)),
|
||||
attrs.pop('safe_serialization', None),
|
||||
default=llm._serialisation_format == 'safetensors')
|
||||
if '_quantize' in llm._bentomodel.info.metadata and llm._bentomodel.info.metadata['_quantize'] == 'gptq':
|
||||
if not openllm.utils.is_autogptq_available():
|
||||
raise openllm.exceptions.OpenLLMException("GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'")
|
||||
if llm.config['model_type'] != 'causal_lm': raise openllm.exceptions.OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
|
||||
return autogptq.AutoGPTQForCausalLM.from_quantized(
|
||||
llm._bentomodel.path,
|
||||
*decls,
|
||||
quantize_config=t.cast('autogptq.BaseQuantizeConfig', llm.quantization_config),
|
||||
trust_remote_code=llm.__llm_trust_remote_code__,
|
||||
use_safetensors=safe_serialization,
|
||||
**hub_attrs,
|
||||
**attrs
|
||||
)
|
||||
raise openllm.exceptions.OpenLLMException(
|
||||
"GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'"
|
||||
)
|
||||
if llm.config['model_type'] != 'causal_lm':
|
||||
raise openllm.exceptions.OpenLLMException(
|
||||
f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
|
||||
return autogptq.AutoGPTQForCausalLM.from_quantized(llm._bentomodel.path,
|
||||
*decls,
|
||||
quantize_config=t.cast('autogptq.BaseQuantizeConfig',
|
||||
llm.quantization_config),
|
||||
trust_remote_code=llm.__llm_trust_remote_code__,
|
||||
use_safetensors=safe_serialization,
|
||||
**hub_attrs,
|
||||
**attrs)
|
||||
|
||||
device_map = attrs.pop('device_map', 'auto' if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None)
|
||||
model = infer_autoclass_from_llm(llm, config).from_pretrained(
|
||||
llm._bentomodel.path, *decls, config=config, trust_remote_code=llm.__llm_trust_remote_code__, device_map=device_map, **hub_attrs, **attrs
|
||||
).eval()
|
||||
model = infer_autoclass_from_llm(llm, config).from_pretrained(llm._bentomodel.path,
|
||||
*decls,
|
||||
config=config,
|
||||
trust_remote_code=llm.__llm_trust_remote_code__,
|
||||
device_map=device_map,
|
||||
**hub_attrs,
|
||||
**attrs).eval()
|
||||
# BetterTransformer is currently only supported on PyTorch.
|
||||
if llm.bettertransformer and isinstance(model, transformers.PreTrainedModel): model = model.to_bettertransformer()
|
||||
if llm.__llm_implementation__ in {'pt', 'vllm'}: check_unintialised_params(model)
|
||||
return t.cast('M', model)
|
||||
|
||||
def save_pretrained(
|
||||
llm: openllm.LLM[M, T],
|
||||
save_directory: str,
|
||||
is_main_process: bool = True,
|
||||
state_dict: DictStrAny | None = None,
|
||||
save_function: t.Any | None = None,
|
||||
push_to_hub: bool = False,
|
||||
max_shard_size: int | str = '10GB',
|
||||
safe_serialization: bool = False,
|
||||
variant: str | None = None,
|
||||
**attrs: t.Any
|
||||
) -> None:
|
||||
def save_pretrained(llm: openllm.LLM[M, T],
|
||||
save_directory: str,
|
||||
is_main_process: bool = True,
|
||||
state_dict: DictStrAny | None = None,
|
||||
save_function: t.Any | None = None,
|
||||
push_to_hub: bool = False,
|
||||
max_shard_size: int | str = '10GB',
|
||||
safe_serialization: bool = False,
|
||||
variant: str | None = None,
|
||||
**attrs: t.Any) -> None:
|
||||
save_function = t.cast(t.Callable[..., None], openllm.utils.first_not_none(save_function, default=torch.save))
|
||||
model_save_attrs, tokenizer_save_attrs = openllm.utils.normalize_attrs_to_model_tokenizer_pair(**attrs)
|
||||
safe_serialization = safe_serialization or llm._serialisation_format == 'safetensors'
|
||||
@@ -222,25 +264,31 @@ def save_pretrained(
|
||||
if llm.__llm_implementation__ == 'vllm': safe_serialization = False
|
||||
if llm._quantize_method == 'gptq':
|
||||
if not openllm.utils.is_autogptq_available():
|
||||
raise openllm.exceptions.OpenLLMException("GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'")
|
||||
if llm.config['model_type'] != 'causal_lm': raise openllm.exceptions.OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
|
||||
if not openllm.utils.lenient_issubclass(llm.model, autogptq.modeling.BaseGPTQForCausalLM): raise ValueError(f'Model is not a BaseGPTQForCausalLM (type: {type(llm.model)})')
|
||||
t.cast('autogptq.modeling.BaseGPTQForCausalLM', llm.model).save_quantized(save_directory, use_safetensors=safe_serialization)
|
||||
raise openllm.exceptions.OpenLLMException(
|
||||
"GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'"
|
||||
)
|
||||
if llm.config['model_type'] != 'causal_lm':
|
||||
raise openllm.exceptions.OpenLLMException(
|
||||
f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})")
|
||||
if not openllm.utils.lenient_issubclass(llm.model, autogptq.modeling.BaseGPTQForCausalLM):
|
||||
raise ValueError(f'Model is not a BaseGPTQForCausalLM (type: {type(llm.model)})')
|
||||
t.cast('autogptq.modeling.BaseGPTQForCausalLM', llm.model).save_quantized(save_directory,
|
||||
use_safetensors=safe_serialization)
|
||||
elif openllm.utils.LazyType['vllm.LLMEngine']('vllm.LLMEngine').isinstance(llm.model):
|
||||
raise RuntimeError("vllm.LLMEngine cannot be serialisation directly. This happens when 'save_pretrained' is called directly after `openllm.AutoVLLM` is initialized.")
|
||||
raise RuntimeError(
|
||||
"vllm.LLMEngine cannot be serialisation directly. This happens when 'save_pretrained' is called directly after `openllm.AutoVLLM` is initialized."
|
||||
)
|
||||
elif isinstance(llm.model, transformers.Pipeline):
|
||||
llm.model.save_pretrained(save_directory, safe_serialization=safe_serialization)
|
||||
else:
|
||||
# We can safely cast here since it will be the PreTrainedModel protocol.
|
||||
t.cast('transformers.PreTrainedModel', llm.model).save_pretrained(
|
||||
save_directory,
|
||||
is_main_process=is_main_process,
|
||||
state_dict=state_dict,
|
||||
save_function=save_function,
|
||||
push_to_hub=push_to_hub,
|
||||
max_shard_size=max_shard_size,
|
||||
safe_serialization=safe_serialization,
|
||||
variant=variant,
|
||||
**model_save_attrs
|
||||
)
|
||||
t.cast('transformers.PreTrainedModel', llm.model).save_pretrained(save_directory,
|
||||
is_main_process=is_main_process,
|
||||
state_dict=state_dict,
|
||||
save_function=save_function,
|
||||
push_to_hub=push_to_hub,
|
||||
max_shard_size=max_shard_size,
|
||||
safe_serialization=safe_serialization,
|
||||
variant=variant,
|
||||
**model_save_attrs)
|
||||
llm.tokenizer.save_pretrained(save_directory, push_to_hub=push_to_hub, **tokenizer_save_attrs)
|
||||
|
||||
@@ -23,11 +23,14 @@ if t.TYPE_CHECKING:
|
||||
from openllm_core._typing_compat import M
|
||||
from openllm_core._typing_compat import T
|
||||
else:
|
||||
transformers, torch = openllm_core.utils.LazyLoader('transformers', globals(), 'transformers'), openllm_core.utils.LazyLoader('torch', globals(), 'torch')
|
||||
transformers, torch = openllm_core.utils.LazyLoader('transformers', globals(),
|
||||
'transformers'), openllm_core.utils.LazyLoader(
|
||||
'torch', globals(), 'torch')
|
||||
|
||||
_object_setattr = object.__setattr__
|
||||
|
||||
def process_config(model_id: str, trust_remote_code: bool, **attrs: t.Any) -> tuple[transformers.PretrainedConfig, DictStrAny, DictStrAny]:
|
||||
def process_config(model_id: str, trust_remote_code: bool,
|
||||
**attrs: t.Any) -> tuple[transformers.PretrainedConfig, DictStrAny, DictStrAny]:
|
||||
'''A helper function that correctly parse config and attributes for transformers.PretrainedConfig.
|
||||
|
||||
Args:
|
||||
@@ -44,19 +47,27 @@ def process_config(model_id: str, trust_remote_code: bool, **attrs: t.Any) -> tu
|
||||
if not isinstance(config, transformers.PretrainedConfig):
|
||||
copied_attrs = copy.deepcopy(attrs)
|
||||
if copied_attrs.get('torch_dtype', None) == 'auto': copied_attrs.pop('torch_dtype')
|
||||
config, attrs = transformers.AutoConfig.from_pretrained(model_id, return_unused_kwargs=True, trust_remote_code=trust_remote_code, **hub_attrs, **copied_attrs)
|
||||
config, attrs = transformers.AutoConfig.from_pretrained(model_id,
|
||||
return_unused_kwargs=True,
|
||||
trust_remote_code=trust_remote_code,
|
||||
**hub_attrs,
|
||||
**copied_attrs)
|
||||
return config, hub_attrs, attrs
|
||||
|
||||
def infer_tokenizers_from_llm(__llm: openllm.LLM[t.Any, T], /) -> T:
|
||||
__cls = getattr(transformers, openllm_core.utils.first_not_none(__llm.config['tokenizer_class'], default='AutoTokenizer'), None)
|
||||
if __cls is None: raise ValueError(f'Cannot infer correct tokenizer class for {__llm}. Make sure to unset `tokenizer_class`')
|
||||
__cls = getattr(transformers,
|
||||
openllm_core.utils.first_not_none(__llm.config['tokenizer_class'], default='AutoTokenizer'), None)
|
||||
if __cls is None:
|
||||
raise ValueError(f'Cannot infer correct tokenizer class for {__llm}. Make sure to unset `tokenizer_class`')
|
||||
return __cls
|
||||
|
||||
def infer_autoclass_from_llm(llm: openllm.LLM[M, T], config: transformers.PretrainedConfig, /) -> _BaseAutoModelClass:
|
||||
if llm.config['trust_remote_code']:
|
||||
autoclass = 'AutoModelForSeq2SeqLM' if llm.config['model_type'] == 'seq2seq_lm' else 'AutoModelForCausalLM'
|
||||
if not hasattr(config, 'auto_map'):
|
||||
raise ValueError(f'Invalid configuraiton for {llm.model_id}. ``trust_remote_code=True`` requires `transformers.PretrainedConfig` to contain a `auto_map` mapping')
|
||||
raise ValueError(
|
||||
f'Invalid configuraiton for {llm.model_id}. ``trust_remote_code=True`` requires `transformers.PretrainedConfig` to contain a `auto_map` mapping'
|
||||
)
|
||||
# in case this model doesn't use the correct auto class for model type, for example like chatglm
|
||||
# where it uses AutoModel instead of AutoModelForCausalLM. Then we fallback to AutoModel
|
||||
if autoclass not in config.auto_map: autoclass = 'AutoModel'
|
||||
@@ -69,14 +80,24 @@ def infer_autoclass_from_llm(llm: openllm.LLM[M, T], config: transformers.Pretra
|
||||
|
||||
def check_unintialised_params(model: torch.nn.Module) -> None:
|
||||
unintialized = [n for n, param in model.named_parameters() if param.data.device == torch.device('meta')]
|
||||
if len(unintialized) > 0: raise RuntimeError(f'Found the following unintialized parameters in {model}: {unintialized}')
|
||||
if len(unintialized) > 0:
|
||||
raise RuntimeError(f'Found the following unintialized parameters in {model}: {unintialized}')
|
||||
|
||||
def update_model(bentomodel: bentoml.Model, metadata: DictStrAny) -> bentoml.Model:
|
||||
based: DictStrAny = copy.deepcopy(bentomodel.info.metadata)
|
||||
based.update(metadata)
|
||||
_object_setattr(bentomodel, '_info', ModelInfo( # type: ignore[call-arg] # XXX: remove me once upstream is merged
|
||||
tag=bentomodel.info.tag, module=bentomodel.info.module, labels=bentomodel.info.labels, options=bentomodel.info.options.to_dict(), signatures=bentomodel.info.signatures, context=bentomodel.info.context, api_version=bentomodel.info.api_version, creation_time=bentomodel.info.creation_time, metadata=based
|
||||
))
|
||||
_object_setattr(
|
||||
bentomodel, '_info',
|
||||
ModelInfo( # type: ignore[call-arg] # XXX: remove me once upstream is merged
|
||||
tag=bentomodel.info.tag,
|
||||
module=bentomodel.info.module,
|
||||
labels=bentomodel.info.labels,
|
||||
options=bentomodel.info.options.to_dict(),
|
||||
signatures=bentomodel.info.signatures,
|
||||
context=bentomodel.info.context,
|
||||
api_version=bentomodel.info.api_version,
|
||||
creation_time=bentomodel.info.creation_time,
|
||||
metadata=based))
|
||||
return bentomodel
|
||||
|
||||
# NOTE: sync with bentoml/_internal/frameworks/transformers.py#make_default_signatures
|
||||
@@ -84,9 +105,13 @@ def make_model_signatures(llm: openllm.LLM[M, T]) -> ModelSignaturesType:
|
||||
infer_fn: tuple[str, ...] = ('__call__',)
|
||||
default_config = ModelSignature(batchable=False)
|
||||
if llm.__llm_implementation__ in {'pt', 'vllm'}:
|
||||
infer_fn += ('forward', 'generate', 'contrastive_search', 'greedy_search', 'sample', 'beam_search', 'beam_sample', 'group_beam_search', 'constrained_beam_search',)
|
||||
infer_fn += ('forward', 'generate', 'contrastive_search', 'greedy_search', 'sample', 'beam_search', 'beam_sample',
|
||||
'group_beam_search', 'constrained_beam_search',
|
||||
)
|
||||
elif llm.__llm_implementation__ == 'tf':
|
||||
infer_fn += ('predict', 'call', 'generate', 'compute_transition_scores', 'greedy_search', 'sample', 'beam_search', 'contrastive_search',)
|
||||
infer_fn += ('predict', 'call', 'generate', 'compute_transition_scores', 'greedy_search', 'sample', 'beam_search',
|
||||
'contrastive_search',
|
||||
)
|
||||
else:
|
||||
infer_fn += ('generate',)
|
||||
return {k: default_config for k in infer_fn}
|
||||
|
||||
@@ -25,7 +25,8 @@ class HfIgnore:
|
||||
def ignore_patterns(cls, llm: openllm.LLM[M, T]) -> list[str]:
|
||||
if llm.__llm_implementation__ == 'vllm': base = [cls.tf, cls.flax, cls.safetensors]
|
||||
elif llm.__llm_implementation__ == 'tf': base = [cls.flax, cls.pt]
|
||||
elif llm.__llm_implementation__ == 'flax': base = [cls.tf, cls.pt, cls.safetensors] # as of current, safetensors is not supported with flax
|
||||
elif llm.__llm_implementation__ == 'flax':
|
||||
base = [cls.tf, cls.pt, cls.safetensors] # as of current, safetensors is not supported with flax
|
||||
else:
|
||||
base = [cls.tf, cls.flax]
|
||||
if has_safetensors_weights(llm.model_id): base.append(cls.pt)
|
||||
|
||||
@@ -15,9 +15,11 @@ if t.TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def build_bento(
|
||||
model: str, model_id: str | None = None, quantize: t.Literal['int4', 'int8', 'gptq'] | None = None, runtime: t.Literal['ggml', 'transformers'] = 'transformers', cleanup: bool = False
|
||||
) -> t.Iterator[bentoml.Bento]:
|
||||
def build_bento(model: str,
|
||||
model_id: str | None = None,
|
||||
quantize: t.Literal['int4', 'int8', 'gptq'] | None = None,
|
||||
runtime: t.Literal['ggml', 'transformers'] = 'transformers',
|
||||
cleanup: bool = False) -> t.Iterator[bentoml.Bento]:
|
||||
logger.info('Building BentoML for %s', model)
|
||||
bento = openllm.build(model, model_id=model_id, quantize=quantize, runtime=runtime)
|
||||
yield bento
|
||||
@@ -26,7 +28,10 @@ def build_bento(
|
||||
bentoml.bentos.delete(bento.tag)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def build_container(bento: bentoml.Bento | str | bentoml.Tag, image_tag: str | None = None, cleanup: bool = False, **attrs: t.Any) -> t.Iterator[str]:
|
||||
def build_container(bento: bentoml.Bento | str | bentoml.Tag,
|
||||
image_tag: str | None = None,
|
||||
cleanup: bool = False,
|
||||
**attrs: t.Any) -> t.Iterator[str]:
|
||||
if isinstance(bento, bentoml.Bento): bento_tag = bento.tag
|
||||
else: bento_tag = bentoml.Tag.from_taglike(bento)
|
||||
if image_tag is None: image_tag = str(bento_tag)
|
||||
@@ -42,22 +47,23 @@ def build_container(bento: bentoml.Bento | str | bentoml.Tag, image_tag: str | N
|
||||
subprocess.check_output([executable, 'rmi', '-f', image_tag])
|
||||
|
||||
@contextlib.contextmanager
|
||||
def prepare(
|
||||
model: str,
|
||||
model_id: str | None = None,
|
||||
implementation: LiteralRuntime = 'pt',
|
||||
deployment_mode: t.Literal['container', 'local'] = 'local',
|
||||
clean_context: contextlib.ExitStack | None = None,
|
||||
cleanup: bool = True
|
||||
) -> t.Iterator[str]:
|
||||
def prepare(model: str,
|
||||
model_id: str | None = None,
|
||||
implementation: LiteralRuntime = 'pt',
|
||||
deployment_mode: t.Literal['container', 'local'] = 'local',
|
||||
clean_context: contextlib.ExitStack | None = None,
|
||||
cleanup: bool = True) -> t.Iterator[str]:
|
||||
if clean_context is None:
|
||||
clean_context = contextlib.ExitStack()
|
||||
cleanup = True
|
||||
llm = openllm.infer_auto_class(implementation).for_model(model, model_id=model_id, ensure_available=True)
|
||||
bento_tag = bentoml.Tag.from_taglike(f'{llm.llm_type}-service:{llm.tag.version}')
|
||||
if not bentoml.list(bento_tag): bento = clean_context.enter_context(build_bento(model, model_id=model_id, cleanup=cleanup))
|
||||
else: bento = bentoml.get(bento_tag)
|
||||
if not bentoml.list(bento_tag):
|
||||
bento = clean_context.enter_context(build_bento(model, model_id=model_id, cleanup=cleanup))
|
||||
else:
|
||||
bento = bentoml.get(bento_tag)
|
||||
container_name = f'openllm-{model}-{llm.llm_type}'.replace('-', '_')
|
||||
if deployment_mode == 'container': container_name = clean_context.enter_context(build_container(bento, image_tag=container_name, cleanup=cleanup))
|
||||
if deployment_mode == 'container':
|
||||
container_name = clean_context.enter_context(build_container(bento, image_tag=container_name, cleanup=cleanup))
|
||||
yield container_name
|
||||
if cleanup: clean_context.close()
|
||||
|
||||
@@ -19,9 +19,17 @@ if t.TYPE_CHECKING:
|
||||
from openllm_core._typing_compat import LiteralRuntime
|
||||
|
||||
def generate_labels(llm: openllm.LLM[t.Any, t.Any]) -> dict[str, t.Any]:
|
||||
return {'runtime': llm.runtime, 'framework': 'openllm', 'model_name': llm.config['model_name'], 'architecture': llm.config['architecture'], 'serialisation_format': llm._serialisation_format}
|
||||
return {
|
||||
'runtime': llm.runtime,
|
||||
'framework': 'openllm',
|
||||
'model_name': llm.config['model_name'],
|
||||
'architecture': llm.config['architecture'],
|
||||
'serialisation_format': llm._serialisation_format
|
||||
}
|
||||
|
||||
def infer_auto_class(implementation: LiteralRuntime) -> type[openllm.AutoLLM | openllm.AutoTFLLM | openllm.AutoFlaxLLM | openllm.AutoVLLM]:
|
||||
def infer_auto_class(
|
||||
implementation: LiteralRuntime
|
||||
) -> type[openllm.AutoLLM | openllm.AutoTFLLM | openllm.AutoFlaxLLM | openllm.AutoVLLM]:
|
||||
import openllm
|
||||
if implementation == 'tf': return openllm.AutoTFLLM
|
||||
elif implementation == 'flax': return openllm.AutoFlaxLLM
|
||||
@@ -29,7 +37,10 @@ def infer_auto_class(implementation: LiteralRuntime) -> type[openllm.AutoLLM | o
|
||||
elif implementation == 'vllm': return openllm.AutoVLLM
|
||||
else: raise RuntimeError(f"Unknown implementation: {implementation} (supported: 'pt', 'flax', 'tf', 'vllm')")
|
||||
|
||||
__all__ = ['generate_labels', 'infer_auto_class', 'dummy_flax_objects', 'dummy_pt_objects', 'dummy_tf_objects', 'dummy_vllm_objects']
|
||||
__all__ = [
|
||||
'generate_labels', 'infer_auto_class', 'dummy_flax_objects', 'dummy_pt_objects', 'dummy_tf_objects',
|
||||
'dummy_vllm_objects'
|
||||
]
|
||||
|
||||
def __dir__() -> t.Sequence[str]:
|
||||
return sorted(__all__)
|
||||
|
||||
@@ -16,19 +16,32 @@ env_strats = st.sampled_from([openllm.utils.EnvVarMixin(model_name) for model_na
|
||||
def model_settings(draw: st.DrawFn):
|
||||
'''Strategy for generating ModelSettings objects.'''
|
||||
kwargs: dict[str, t.Any] = {
|
||||
'default_id': st.text(min_size=1),
|
||||
'model_ids': st.lists(st.text(), min_size=1),
|
||||
'architecture': st.text(min_size=1),
|
||||
'url': st.text(),
|
||||
'requires_gpu': st.booleans(),
|
||||
'trust_remote_code': st.booleans(),
|
||||
'requirements': st.none() | st.lists(st.text(), min_size=1),
|
||||
'default_implementation': st.dictionaries(st.sampled_from(['cpu', 'nvidia.com/gpu']), st.sampled_from(['vllm', 'pt', 'tf', 'flax'])),
|
||||
'model_type': st.sampled_from(['causal_lm', 'seq2seq_lm']),
|
||||
'runtime': st.sampled_from(['transformers', 'ggml']),
|
||||
'name_type': st.sampled_from(['dasherize', 'lowercase']),
|
||||
'timeout': st.integers(min_value=3600),
|
||||
'workers_per_resource': st.one_of(st.integers(min_value=1), st.floats(min_value=0.1, max_value=1.0)),
|
||||
'default_id':
|
||||
st.text(min_size=1),
|
||||
'model_ids':
|
||||
st.lists(st.text(), min_size=1),
|
||||
'architecture':
|
||||
st.text(min_size=1),
|
||||
'url':
|
||||
st.text(),
|
||||
'requires_gpu':
|
||||
st.booleans(),
|
||||
'trust_remote_code':
|
||||
st.booleans(),
|
||||
'requirements':
|
||||
st.none() | st.lists(st.text(), min_size=1),
|
||||
'default_implementation':
|
||||
st.dictionaries(st.sampled_from(['cpu', 'nvidia.com/gpu']), st.sampled_from(['vllm', 'pt', 'tf', 'flax'])),
|
||||
'model_type':
|
||||
st.sampled_from(['causal_lm', 'seq2seq_lm']),
|
||||
'runtime':
|
||||
st.sampled_from(['transformers', 'ggml']),
|
||||
'name_type':
|
||||
st.sampled_from(['dasherize', 'lowercase']),
|
||||
'timeout':
|
||||
st.integers(min_value=3600),
|
||||
'workers_per_resource':
|
||||
st.one_of(st.integers(min_value=1), st.floats(min_value=0.1, max_value=1.0)),
|
||||
}
|
||||
return draw(st.builds(ModelSettings, **kwargs))
|
||||
|
||||
|
||||
@@ -24,19 +24,29 @@ from ._strategies._configuration import make_llm_config
|
||||
from ._strategies._configuration import model_settings
|
||||
|
||||
# XXX: @aarnphm fixes TypedDict behaviour in 3.11
|
||||
@pytest.mark.skipif(sys.version_info[:2] == (3, 11), reason='TypedDict in 3.11 behaves differently, so we need to fix this')
|
||||
@pytest.mark.skipif(sys.version_info[:2] == (3, 11),
|
||||
reason='TypedDict in 3.11 behaves differently, so we need to fix this')
|
||||
def test_missing_default():
|
||||
with pytest.raises(ValueError, match='Missing required fields *'):
|
||||
make_llm_config('MissingDefaultId', {'name_type': 'lowercase', 'requirements': ['bentoml']})
|
||||
with pytest.raises(ValueError, match='Missing required fields *'):
|
||||
make_llm_config('MissingModelId', {'default_id': 'huggingface/t5-tiny-testing', 'requirements': ['bentoml']})
|
||||
with pytest.raises(ValueError, match='Missing required fields *'):
|
||||
make_llm_config('MissingArchitecture', {'default_id': 'huggingface/t5-tiny-testing', 'model_ids': ['huggingface/t5-tiny-testing'], 'requirements': ['bentoml'],},)
|
||||
make_llm_config(
|
||||
'MissingArchitecture', {
|
||||
'default_id': 'huggingface/t5-tiny-testing',
|
||||
'model_ids': ['huggingface/t5-tiny-testing'],
|
||||
'requirements': ['bentoml'],
|
||||
},
|
||||
)
|
||||
|
||||
def test_forbidden_access():
|
||||
cl_ = make_llm_config(
|
||||
'ForbiddenAccess', {
|
||||
'default_id': 'huggingface/t5-tiny-testing', 'model_ids': ['huggingface/t5-tiny-testing', 'bentoml/t5-tiny-testing'], 'architecture': 'PreTrainedModel', 'requirements': ['bentoml'],
|
||||
'default_id': 'huggingface/t5-tiny-testing',
|
||||
'model_ids': ['huggingface/t5-tiny-testing', 'bentoml/t5-tiny-testing'],
|
||||
'architecture': 'PreTrainedModel',
|
||||
'requirements': ['bentoml'],
|
||||
},
|
||||
)
|
||||
|
||||
@@ -69,9 +79,16 @@ def test_config_derived_follow_attrs_protocol(gen_settings: ModelSettings):
|
||||
cl_ = make_llm_config('AttrsProtocolLLM', gen_settings)
|
||||
assert attr.has(cl_)
|
||||
|
||||
@given(model_settings(), st.integers(max_value=283473), st.floats(min_value=0.0, max_value=1.0), st.integers(max_value=283473), st.floats(min_value=0.0, max_value=1.0),)
|
||||
def test_complex_struct_dump(gen_settings: ModelSettings, field1: int, temperature: float, input_field1: int, input_temperature: float):
|
||||
cl_ = make_llm_config('ComplexLLM', gen_settings, fields=(('field1', 'float', field1),), generation_fields=(('temperature', temperature),),)
|
||||
@given(model_settings(), st.integers(max_value=283473), st.floats(min_value=0.0, max_value=1.0),
|
||||
st.integers(max_value=283473), st.floats(min_value=0.0, max_value=1.0),
|
||||
)
|
||||
def test_complex_struct_dump(gen_settings: ModelSettings, field1: int, temperature: float, input_field1: int,
|
||||
input_temperature: float):
|
||||
cl_ = make_llm_config('ComplexLLM',
|
||||
gen_settings,
|
||||
fields=(('field1', 'float', field1),),
|
||||
generation_fields=(('temperature', temperature),),
|
||||
)
|
||||
sent = cl_()
|
||||
assert sent.model_dump()['field1'] == field1
|
||||
assert sent.model_dump()['generation_config']['temperature'] == temperature
|
||||
@@ -94,7 +111,10 @@ def patch_env(**attrs: t.Any):
|
||||
yield
|
||||
|
||||
def test_struct_envvar():
|
||||
with patch_env(**{field_env_key('env_llm', 'field1'): '4', field_env_key('env_llm', 'temperature', suffix='generation'): '0.2',}):
|
||||
with patch_env(**{
|
||||
field_env_key('env_llm', 'field1'): '4',
|
||||
field_env_key('env_llm', 'temperature', suffix='generation'): '0.2',
|
||||
}):
|
||||
|
||||
class EnvLLM(openllm.LLMConfig):
|
||||
__config__ = {'default_id': 'asdfasdf', 'model_ids': ['asdf', 'asdfasdfads'], 'architecture': 'PreTrainedModel',}
|
||||
@@ -112,6 +132,7 @@ def test_struct_envvar():
|
||||
assert overwrite_default['temperature'] == 0.2
|
||||
|
||||
def test_struct_provided_fields():
|
||||
|
||||
class EnvLLM(openllm.LLMConfig):
|
||||
__config__ = {'default_id': 'asdfasdf', 'model_ids': ['asdf', 'asdfasdfads'], 'architecture': 'PreTrainedModel',}
|
||||
field1: int = 2
|
||||
@@ -127,11 +148,13 @@ def test_struct_envvar_with_overwrite_provided_env(monkeypatch: pytest.MonkeyPat
|
||||
with monkeypatch.context() as mk:
|
||||
mk.setenv(field_env_key('overwrite_with_env_available', 'field1'), str(4.0))
|
||||
mk.setenv(field_env_key('overwrite_with_env_available', 'temperature', suffix='generation'), str(0.2))
|
||||
sent = make_llm_config(
|
||||
'OverwriteWithEnvAvailable', {
|
||||
'default_id': 'asdfasdf', 'model_ids': ['asdf', 'asdfasdfads'], 'architecture': 'PreTrainedModel'
|
||||
}, fields=(('field1', 'float', 3.0),),
|
||||
).model_construct_env(field1=20.0, temperature=0.4)
|
||||
sent = make_llm_config('OverwriteWithEnvAvailable', {
|
||||
'default_id': 'asdfasdf',
|
||||
'model_ids': ['asdf', 'asdfasdfads'],
|
||||
'architecture': 'PreTrainedModel'
|
||||
},
|
||||
fields=(('field1', 'float', 3.0),),
|
||||
).model_construct_env(field1=20.0, temperature=0.4)
|
||||
assert sent.generation_config.temperature == 0.4
|
||||
assert sent.field1 == 20.0
|
||||
|
||||
|
||||
@@ -10,23 +10,37 @@ import openllm
|
||||
if t.TYPE_CHECKING:
|
||||
from openllm_core._typing_compat import LiteralRuntime
|
||||
|
||||
_FRAMEWORK_MAPPING = {'flan_t5': 'google/flan-t5-small', 'opt': 'facebook/opt-125m', 'baichuan': 'baichuan-inc/Baichuan-7B',}
|
||||
_PROMPT_MAPPING = {'qa': 'Answer the following yes/no question by reasoning step-by-step. Can you write a whole Haiku in a single tweet?',}
|
||||
_FRAMEWORK_MAPPING = {
|
||||
'flan_t5': 'google/flan-t5-small',
|
||||
'opt': 'facebook/opt-125m',
|
||||
'baichuan': 'baichuan-inc/Baichuan-7B',
|
||||
}
|
||||
_PROMPT_MAPPING = {
|
||||
'qa':
|
||||
'Answer the following yes/no question by reasoning step-by-step. Can you write a whole Haiku in a single tweet?',
|
||||
}
|
||||
|
||||
def parametrise_local_llm(model: str,) -> t.Generator[tuple[str, openllm.LLMRunner[t.Any, t.Any] | openllm.LLM[t.Any, t.Any]], None, None]:
|
||||
def parametrise_local_llm(
|
||||
model: str,) -> t.Generator[tuple[str, openllm.LLMRunner[t.Any, t.Any] | openllm.LLM[t.Any, t.Any]], None, None]:
|
||||
if model not in _FRAMEWORK_MAPPING: pytest.skip(f"'{model}' is not yet supported in framework testing.")
|
||||
runtime_impl: tuple[LiteralRuntime, ...] = tuple()
|
||||
if model in openllm.MODEL_MAPPING_NAMES: runtime_impl += ('pt',)
|
||||
if model in openllm.MODEL_FLAX_MAPPING_NAMES: runtime_impl += ('flax',)
|
||||
if model in openllm.MODEL_TF_MAPPING_NAMES: runtime_impl += ('tf',)
|
||||
for framework, prompt in itertools.product(runtime_impl, _PROMPT_MAPPING.keys()):
|
||||
llm = openllm.Runner(model, model_id=_FRAMEWORK_MAPPING[model], ensure_available=True, implementation=framework, init_local=True,)
|
||||
llm = openllm.Runner(model,
|
||||
model_id=_FRAMEWORK_MAPPING[model],
|
||||
ensure_available=True,
|
||||
implementation=framework,
|
||||
init_local=True,
|
||||
)
|
||||
yield prompt, llm
|
||||
|
||||
def pytest_generate_tests(metafunc: pytest.Metafunc) -> None:
|
||||
if os.getenv('GITHUB_ACTIONS') is None:
|
||||
if 'prompt' in metafunc.fixturenames and 'llm' in metafunc.fixturenames:
|
||||
metafunc.parametrize('prompt,llm', [(p, llm) for p, llm in parametrise_local_llm(metafunc.function.__name__[5:-15])])
|
||||
metafunc.parametrize('prompt,llm',
|
||||
[(p, llm) for p, llm in parametrise_local_llm(metafunc.function.__name__[5:-15])])
|
||||
|
||||
def pytest_sessionfinish(session: pytest.Session, exitstatus: int):
|
||||
# If no tests are collected, pytest exists with code 5, which makes the CI fail.
|
||||
|
||||
@@ -40,7 +40,13 @@ if t.TYPE_CHECKING:
|
||||
from openllm.client import BaseAsyncClient
|
||||
|
||||
class ResponseComparator(JSONSnapshotExtension):
|
||||
def serialize(self, data: SerializableData, *, exclude: PropertyFilter | None = None, matcher: PropertyMatcher | None = None,) -> SerializedData:
|
||||
|
||||
def serialize(self,
|
||||
data: SerializableData,
|
||||
*,
|
||||
exclude: PropertyFilter | None = None,
|
||||
matcher: PropertyMatcher | None = None,
|
||||
) -> SerializedData:
|
||||
if openllm.utils.LazyType(ListAny).isinstance(data):
|
||||
data = [d.unmarshaled for d in data]
|
||||
else:
|
||||
@@ -49,6 +55,7 @@ class ResponseComparator(JSONSnapshotExtension):
|
||||
return orjson.dumps(data, option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS).decode()
|
||||
|
||||
def matches(self, *, serialized_data: SerializableData, snapshot_data: SerializableData) -> bool:
|
||||
|
||||
def convert_data(data: SerializableData) -> openllm.GenerationOutput | t.Sequence[openllm.GenerationOutput]:
|
||||
try:
|
||||
data = orjson.loads(data)
|
||||
@@ -73,9 +80,11 @@ class ResponseComparator(JSONSnapshotExtension):
|
||||
return s == t
|
||||
|
||||
def eq_output(s: openllm.GenerationOutput, t: openllm.GenerationOutput) -> bool:
|
||||
return (len(s.responses) == len(t.responses) and all([_s == _t for _s, _t in zip(s.responses, t.responses)]) and eq_config(s.marshaled_config, t.marshaled_config))
|
||||
return (len(s.responses) == len(t.responses) and all([_s == _t for _s, _t in zip(s.responses, t.responses)]) and
|
||||
eq_config(s.marshaled_config, t.marshaled_config))
|
||||
|
||||
return len(serialized_data) == len(snapshot_data) and all([eq_output(s, t) for s, t in zip(serialized_data, snapshot_data)])
|
||||
return len(serialized_data) == len(snapshot_data) and all(
|
||||
[eq_output(s, t) for s, t in zip(serialized_data, snapshot_data)])
|
||||
|
||||
@pytest.fixture()
|
||||
def response_snapshot(snapshot: SnapshotAssertion):
|
||||
@@ -124,8 +133,14 @@ class LocalHandle(_Handle):
|
||||
return self.process.poll() is None
|
||||
|
||||
class HandleProtocol(t.Protocol):
|
||||
|
||||
@contextlib.contextmanager
|
||||
def __call__(*, model: str, model_id: str, image_tag: str, quantize: t.AnyStr | None = None,) -> t.Generator[_Handle, None, None]:
|
||||
def __call__(*,
|
||||
model: str,
|
||||
model_id: str,
|
||||
image_tag: str,
|
||||
quantize: t.AnyStr | None = None,
|
||||
) -> t.Generator[_Handle, None, None]:
|
||||
...
|
||||
|
||||
@attr.define(init=False)
|
||||
@@ -133,7 +148,9 @@ class DockerHandle(_Handle):
|
||||
container_name: str
|
||||
docker_client: docker.DockerClient
|
||||
|
||||
def __init__(self, docker_client: docker.DockerClient, container_name: str, port: int, deployment_mode: t.Literal['container', 'local'],):
|
||||
def __init__(self, docker_client: docker.DockerClient, container_name: str, port: int,
|
||||
deployment_mode: t.Literal['container', 'local'],
|
||||
):
|
||||
self.__attrs_init__(port, deployment_mode, container_name, docker_client)
|
||||
|
||||
def status(self) -> bool:
|
||||
@@ -141,16 +158,29 @@ class DockerHandle(_Handle):
|
||||
return container.status in ['running', 'created']
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _local_handle(
|
||||
model: str, model_id: str, image_tag: str, deployment_mode: t.Literal['container', 'local'], quantize: t.Literal['int8', 'int4', 'gptq'] | None = None, *, _serve_grpc: bool = False,
|
||||
):
|
||||
def _local_handle(model: str,
|
||||
model_id: str,
|
||||
image_tag: str,
|
||||
deployment_mode: t.Literal['container', 'local'],
|
||||
quantize: t.Literal['int8', 'int4', 'gptq'] | None = None,
|
||||
*,
|
||||
_serve_grpc: bool = False,
|
||||
):
|
||||
with openllm.utils.reserve_free_port() as port:
|
||||
pass
|
||||
|
||||
if not _serve_grpc:
|
||||
proc = openllm.start(model, model_id=model_id, quantize=quantize, additional_args=['--port', str(port)], __test__=True)
|
||||
proc = openllm.start(model,
|
||||
model_id=model_id,
|
||||
quantize=quantize,
|
||||
additional_args=['--port', str(port)],
|
||||
__test__=True)
|
||||
else:
|
||||
proc = openllm.start_grpc(model, model_id=model_id, quantize=quantize, additional_args=['--port', str(port)], __test__=True)
|
||||
proc = openllm.start_grpc(model,
|
||||
model_id=model_id,
|
||||
quantize=quantize,
|
||||
additional_args=['--port', str(port)],
|
||||
__test__=True)
|
||||
|
||||
yield LocalHandle(proc, port, deployment_mode)
|
||||
proc.terminate()
|
||||
@@ -164,9 +194,14 @@ def _local_handle(
|
||||
proc.stderr.close()
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _container_handle(
|
||||
model: str, model_id: str, image_tag: str, deployment_mode: t.Literal['container', 'local'], quantize: t.Literal['int8', 'int4', 'gptq'] | None = None, *, _serve_grpc: bool = False,
|
||||
):
|
||||
def _container_handle(model: str,
|
||||
model_id: str,
|
||||
image_tag: str,
|
||||
deployment_mode: t.Literal['container', 'local'],
|
||||
quantize: t.Literal['int8', 'int4', 'gptq'] | None = None,
|
||||
*,
|
||||
_serve_grpc: bool = False,
|
||||
):
|
||||
envvar = openllm.utils.EnvVarMixin(model)
|
||||
|
||||
with openllm.utils.reserve_free_port() as port, openllm.utils.reserve_free_port() as prom_port:
|
||||
@@ -191,11 +226,18 @@ def _container_handle(
|
||||
gpus = openllm.utils.device_count() or -1
|
||||
devs = [docker.types.DeviceRequest(count=gpus, capabilities=[['gpu']])] if gpus > 0 else None
|
||||
|
||||
container = client.containers.run(
|
||||
image_tag, command=args, name=container_name, environment=env, auto_remove=False, detach=True, device_requests=devs, ports={
|
||||
'3000/tcp': port, '3001/tcp': prom_port
|
||||
},
|
||||
)
|
||||
container = client.containers.run(image_tag,
|
||||
command=args,
|
||||
name=container_name,
|
||||
environment=env,
|
||||
auto_remove=False,
|
||||
detach=True,
|
||||
device_requests=devs,
|
||||
ports={
|
||||
'3000/tcp': port,
|
||||
'3001/tcp': prom_port
|
||||
},
|
||||
)
|
||||
|
||||
yield DockerHandle(client, container.name, port, deployment_mode)
|
||||
|
||||
|
||||
@@ -16,8 +16,11 @@ model = 'flan_t5'
|
||||
model_id = 'google/flan-t5-small'
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
def flan_t5_handle(handler: HandleProtocol, deployment_mode: t.Literal['container', 'local'], clean_context: contextlib.ExitStack,):
|
||||
with openllm.testing.prepare(model, model_id=model_id, deployment_mode=deployment_mode, clean_context=clean_context) as image_tag:
|
||||
def flan_t5_handle(handler: HandleProtocol, deployment_mode: t.Literal['container', 'local'],
|
||||
clean_context: contextlib.ExitStack,
|
||||
):
|
||||
with openllm.testing.prepare(model, model_id=model_id, deployment_mode=deployment_mode,
|
||||
clean_context=clean_context) as image_tag:
|
||||
with handler(model=model, model_id=model_id, image_tag=image_tag) as handle:
|
||||
yield handle
|
||||
|
||||
|
||||
@@ -16,8 +16,11 @@ model = 'opt'
|
||||
model_id = 'facebook/opt-125m'
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
def opt_125m_handle(handler: HandleProtocol, deployment_mode: t.Literal['container', 'local'], clean_context: contextlib.ExitStack,):
|
||||
with openllm.testing.prepare(model, model_id=model_id, deployment_mode=deployment_mode, clean_context=clean_context) as image_tag:
|
||||
def opt_125m_handle(handler: HandleProtocol, deployment_mode: t.Literal['container', 'local'],
|
||||
clean_context: contextlib.ExitStack,
|
||||
):
|
||||
with openllm.testing.prepare(model, model_id=model_id, deployment_mode=deployment_mode,
|
||||
clean_context=clean_context) as image_tag:
|
||||
with handler(model=model, model_id=model_id, image_tag=image_tag) as handle:
|
||||
yield handle
|
||||
|
||||
|
||||
@@ -15,7 +15,9 @@ if t.TYPE_CHECKING:
|
||||
HF_INTERNAL_T5_TESTING = 'hf-internal-testing/tiny-random-t5'
|
||||
|
||||
actions_xfail = functools.partial(
|
||||
pytest.mark.xfail, condition=os.getenv('GITHUB_ACTIONS') is not None, reason='Marking GitHub Actions to xfail due to flakiness and building environment not isolated.',
|
||||
pytest.mark.xfail,
|
||||
condition=os.getenv('GITHUB_ACTIONS') is not None,
|
||||
reason='Marking GitHub Actions to xfail due to flakiness and building environment not isolated.',
|
||||
)
|
||||
|
||||
@actions_xfail
|
||||
@@ -46,7 +48,9 @@ def test_general_build_from_local(tmp_path_factory: pytest.TempPathFactory):
|
||||
@pytest.fixture()
|
||||
def dockerfile_template(tmp_path_factory: pytest.TempPathFactory):
|
||||
file = tmp_path_factory.mktemp('dockerfiles') / 'Dockerfile.template'
|
||||
file.write_text("{% extends bento_base_template %}\n{% block SETUP_BENTO_ENTRYPOINT %}\n{{ super() }}\nRUN echo 'sanity from custom dockerfile'\n{% endblock %}")
|
||||
file.write_text(
|
||||
"{% extends bento_base_template %}\n{% block SETUP_BENTO_ENTRYPOINT %}\n{{ super() }}\nRUN echo 'sanity from custom dockerfile'\n{% endblock %}"
|
||||
)
|
||||
return file
|
||||
|
||||
@pytest.mark.usefixtures('dockerfile_template')
|
||||
|
||||
@@ -71,9 +71,11 @@ def test_nvidia_gpu_validate(monkeypatch: pytest.MonkeyPatch):
|
||||
mcls.setenv('CUDA_VISIBLE_DEVICES', '')
|
||||
assert len(NvidiaGpuResource.from_system()) >= 0 # TODO: real from_system tests
|
||||
|
||||
assert pytest.raises(ValueError, NvidiaGpuResource.validate, [*NvidiaGpuResource.from_system(), 1],).match('Input list should be all string type.')
|
||||
assert pytest.raises(ValueError, NvidiaGpuResource.validate, [*NvidiaGpuResource.from_system(), 1],
|
||||
).match('Input list should be all string type.')
|
||||
assert pytest.raises(ValueError, NvidiaGpuResource.validate, [-2]).match('Input list should be all string type.')
|
||||
assert pytest.raises(ValueError, NvidiaGpuResource.validate, ['GPU-5ebe9f43', 'GPU-ac33420d4628']).match('Failed to parse available GPUs UUID')
|
||||
assert pytest.raises(ValueError, NvidiaGpuResource.validate,
|
||||
['GPU-5ebe9f43', 'GPU-ac33420d4628']).match('Failed to parse available GPUs UUID')
|
||||
|
||||
def test_nvidia_gpu_from_spec(monkeypatch: pytest.MonkeyPatch):
|
||||
with monkeypatch.context() as mcls:
|
||||
|
||||
@@ -258,54 +258,54 @@ ignore_patterns = [
|
||||
]
|
||||
|
||||
[tool.yapf]
|
||||
ALIGN_CLOSING_BRACKET_WITH_VISUAL_INDENT = true
|
||||
ALLOW_MULTILINE_DICTIONARY_KEYS = false
|
||||
ALLOW_MULTILINE_LAMBDAS = false
|
||||
ALLOW_SPLIT_BEFORE_DEFAULT_OR_NAMED_ASSIGNS = false
|
||||
ALLOW_SPLIT_BEFORE_DICT_VALUE = false
|
||||
ARITHMETIC_PRECEDENCE_INDICATION = true
|
||||
BLANK_LINES_AROUND_TOP_LEVEL_DEFINITION = 1
|
||||
BLANK_LINES_BETWEEN_TOP_LEVEL_IMPORTS_AND_VARIABLES = 1
|
||||
BLANK_LINE_BEFORE_CLASS_DOCSTRING = false
|
||||
BLANK_LINE_BEFORE_MODULE_DOCSTRING = false
|
||||
BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = false
|
||||
COALESCE_BRACKETS = true
|
||||
COLUMN_LIMIT = 192
|
||||
CONTINUATION_ALIGN_STYLE = "SPACE"
|
||||
DEDENT_CLOSING_BRACKETS = true
|
||||
DISABLE_ENDING_COMMA_HEURISTIC = true
|
||||
EACH_DICT_ENTRY_ON_SEPARATE_LINE = true
|
||||
INDENT_BLANK_LINES = false
|
||||
INDENT_CLOSING_BRACKETS = false
|
||||
based_on_style = "google"
|
||||
INDENT_WIDTH = 2
|
||||
JOIN_MULTIPLE_LINES = true
|
||||
NO_SPACES_AROUND_SELECTED_BINARY_OPERATORS = true
|
||||
SPACES_AROUND_SUBSCRIPT_COLON = false
|
||||
SPACES_AROUND_DICT_DELIMITERS = false
|
||||
SPACES_AROUND_LIST_DELIMITERS = false
|
||||
SPACES_AROUND_POWER_OPERATOR = false
|
||||
SPACES_AROUND_TUPLE_DELIMITERS = false
|
||||
SPACE_BETWEEN_ENDING_COMMA_AND_CLOSING_BRACKET = false
|
||||
SPACE_INSIDE_BRACKETS = false
|
||||
SPLIT_ALL_COMMA_SEPARATED_VALUES = false
|
||||
SPLIT_ALL_TOP_LEVEL_COMMA_SEPARATED_VALUES = true
|
||||
SPLIT_ARGUMENTS_WHEN_COMMA_TERMINATED = false
|
||||
SPLIT_BEFORE_BITWISE_OPERATOR = false
|
||||
SPLIT_BEFORE_CLOSING_BRACKET = false
|
||||
SPLIT_BEFORE_DICT_SET_GENERATOR = false
|
||||
# similar to how rust format its expression
|
||||
SPLIT_BEFORE_DOT = true
|
||||
SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = false
|
||||
SPLIT_BEFORE_FIRST_ARGUMENT = false
|
||||
SPLIT_BEFORE_LOGICAL_OPERATOR = false
|
||||
SPLIT_BEFORE_NAMED_ASSIGNS = false
|
||||
SPLIT_COMPLEX_COMPREHENSION = true
|
||||
SPLIT_PENALTY_IMPORT_NAMES = 10000
|
||||
SPLIT_PENALTY_AFTER_OPENING_BRACKET = 350
|
||||
SPLIT_PENALTY_BEFORE_IF_EXPR = 10000
|
||||
SPLIT_PENALTY_COMPREHENSION = 2500
|
||||
SPLIT_PENALTY_FOR_ADDED_LINE_SPLIT = 5000
|
||||
COLUMN_LIMIT = 120
|
||||
USE_TABS = false
|
||||
BLANK_LINES_AROUND_TOP_LEVEL_DEFINITION = 1
|
||||
BLANK_LINES_BETWEEN_TOP_LEVEL_IMPORTS_AND_VARIABLES = 1
|
||||
DISABLE_ENDING_COMMA_HEURISTIC = true
|
||||
# DEDENT_CLOSING_BRACKETS = true
|
||||
# INDENT_CLOSING_BRACKETS = false
|
||||
# COALESCE_BRACKETS = true
|
||||
# EACH_DICT_ENTRY_ON_SEPARATE_LINE = true
|
||||
# ALIGN_CLOSING_BRACKET_WITH_VISUAL_INDENT = true
|
||||
# ALLOW_MULTILINE_DICTIONARY_KEYS = false
|
||||
# ALLOW_MULTILINE_LAMBDAS = false
|
||||
# ALLOW_SPLIT_BEFORE_DEFAULT_OR_NAMED_ASSIGNS = false
|
||||
# ALLOW_SPLIT_BEFORE_DICT_VALUE = false
|
||||
# ARITHMETIC_PRECEDENCE_INDICATION = true
|
||||
# BLANK_LINE_BEFORE_CLASS_DOCSTRING = false
|
||||
# BLANK_LINE_BEFORE_MODULE_DOCSTRING = false
|
||||
# BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = false
|
||||
# CONTINUATION_ALIGN_STYLE = "SPACE"
|
||||
# INDENT_BLANK_LINES = false
|
||||
# NO_SPACES_AROUND_SELECTED_BINARY_OPERATORS = true
|
||||
# SPACES_AROUND_SUBSCRIPT_COLON = false
|
||||
# SPACES_AROUND_DICT_DELIMITERS = false
|
||||
# SPACES_AROUND_LIST_DELIMITERS = false
|
||||
# SPACES_AROUND_POWER_OPERATOR = false
|
||||
# SPACES_AROUND_TUPLE_DELIMITERS = false
|
||||
# SPACE_BETWEEN_ENDING_COMMA_AND_CLOSING_BRACKET = false
|
||||
# SPACE_INSIDE_BRACKETS = false
|
||||
# SPLIT_ALL_COMMA_SEPARATED_VALUES = false
|
||||
# SPLIT_ALL_TOP_LEVEL_COMMA_SEPARATED_VALUES = true
|
||||
# SPLIT_ARGUMENTS_WHEN_COMMA_TERMINATED = false
|
||||
# SPLIT_BEFORE_BITWISE_OPERATOR = false
|
||||
# SPLIT_BEFORE_CLOSING_BRACKET = false
|
||||
# SPLIT_BEFORE_DICT_SET_GENERATOR = false
|
||||
# SPLIT_BEFORE_DOT = true
|
||||
# SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = false
|
||||
# SPLIT_BEFORE_FIRST_ARGUMENT = false
|
||||
# SPLIT_BEFORE_LOGICAL_OPERATOR = false
|
||||
# SPLIT_BEFORE_NAMED_ASSIGNS = false
|
||||
# SPLIT_COMPLEX_COMPREHENSION = true
|
||||
# SPLIT_PENALTY_IMPORT_NAMES = 10000
|
||||
# SPLIT_PENALTY_AFTER_OPENING_BRACKET = 350
|
||||
# SPLIT_PENALTY_BEFORE_IF_EXPR = 10000
|
||||
# SPLIT_PENALTY_COMPREHENSION = 2500
|
||||
# SPLIT_PENALTY_FOR_ADDED_LINE_SPLIT = 5000
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
addopts = ["-rfEX", "-pno:warnings", "--snapshot-warn-unused"]
|
||||
|
||||
@@ -23,13 +23,17 @@ class Classifier:
|
||||
'audience': 'Intended Audience',
|
||||
'typing': 'Typing',
|
||||
'language': 'Programming Language',
|
||||
}
|
||||
)
|
||||
})
|
||||
joiner: str = ' :: '
|
||||
|
||||
@staticmethod
|
||||
def status() -> dict[int, str]:
|
||||
return {v: status for v, status in zip(range(1, 8), ['1 - Planning', '2 - Pre-Alpha', '3 - Alpha', '4 - Beta', '5 - Production/Stable', '6 - Mature', '7 - Inactive'])}
|
||||
return {
|
||||
v: status for v, status in zip(range(1, 8), [
|
||||
'1 - Planning', '2 - Pre-Alpha', '3 - Alpha', '4 - Beta', '5 - Production/Stable', '6 - Mature',
|
||||
'7 - Inactive'
|
||||
])
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def apache() -> str:
|
||||
@@ -43,10 +47,14 @@ class Classifier:
|
||||
return cls_.joiner.join([cls_.identifier[identifier], *decls])
|
||||
|
||||
@staticmethod
|
||||
def create_python_classifier(implementation: list[str] | None = None, supported_version: list[str] | None = None) -> list[str]:
|
||||
def create_python_classifier(implementation: list[str] | None = None,
|
||||
supported_version: list[str] | None = None) -> list[str]:
|
||||
if supported_version is None: supported_version = ['3.8', '3.9', '3.10', '3.11', '3.12']
|
||||
if implementation is None: implementation = ['CPython', 'PyPy']
|
||||
base = [Classifier.create_classifier('language', 'Python'), Classifier.create_classifier('language', 'Python', '3'),]
|
||||
base = [
|
||||
Classifier.create_classifier('language', 'Python'),
|
||||
Classifier.create_classifier('language', 'Python', '3'),
|
||||
]
|
||||
base.append(Classifier.create_classifier('language', 'Python', '3', 'Only'))
|
||||
base.extend([Classifier.create_classifier('language', 'Python', version) for version in supported_version])
|
||||
base.extend([Classifier.create_classifier('language', 'Python', 'Implementation', impl) for impl in implementation])
|
||||
@@ -85,12 +93,18 @@ class Dependencies:
|
||||
|
||||
def to_str(self) -> str:
|
||||
deps: list[str] = []
|
||||
if self.lower_constraint is not None and self.upper_constraint is not None: dep = f'{self.name}{self.pypi_extensions}>={self.lower_constraint},<{self.upper_constraint}'
|
||||
elif self.lower_constraint is not None: dep = f'{self.name}{self.pypi_extensions}>={self.lower_constraint}'
|
||||
elif self.upper_constraint is not None: dep = f'{self.name}{self.pypi_extensions}<{self.upper_constraint}'
|
||||
elif self.subdirectory is not None: dep = f'{self.name}{self.pypi_extensions} @ git+https://github.com/{self.git_repo_url}.git#subdirectory={self.subdirectory}'
|
||||
elif self.branch is not None: dep = f'{self.name}{self.pypi_extensions} @ git+https://github.com/{self.git_repo_url}.git@{self.branch}'
|
||||
else: dep = f'{self.name}{self.pypi_extensions}'
|
||||
if self.lower_constraint is not None and self.upper_constraint is not None:
|
||||
dep = f'{self.name}{self.pypi_extensions}>={self.lower_constraint},<{self.upper_constraint}'
|
||||
elif self.lower_constraint is not None:
|
||||
dep = f'{self.name}{self.pypi_extensions}>={self.lower_constraint}'
|
||||
elif self.upper_constraint is not None:
|
||||
dep = f'{self.name}{self.pypi_extensions}<{self.upper_constraint}'
|
||||
elif self.subdirectory is not None:
|
||||
dep = f'{self.name}{self.pypi_extensions} @ git+https://github.com/{self.git_repo_url}.git#subdirectory={self.subdirectory}'
|
||||
elif self.branch is not None:
|
||||
dep = f'{self.name}{self.pypi_extensions} @ git+https://github.com/{self.git_repo_url}.git@{self.branch}'
|
||||
else:
|
||||
dep = f'{self.name}{self.pypi_extensions}'
|
||||
deps.append(dep)
|
||||
if self.platform: deps.append(self.platform_restriction(*self.platform))
|
||||
return ';'.join(deps)
|
||||
@@ -129,7 +143,9 @@ GPTQ_DEPS = ['auto-gptq[triton]']
|
||||
VLLM_DEPS = ['vllm>=0.1.4', 'ray']
|
||||
|
||||
_base_requirements: dict[str, t.Any] = {
|
||||
inflection.dasherize(name): config_cls.__openllm_requirements__ for name, config_cls in openllm.CONFIG_MAPPING.items() if config_cls.__openllm_requirements__
|
||||
inflection.dasherize(name): config_cls.__openllm_requirements__
|
||||
for name, config_cls in openllm.CONFIG_MAPPING.items()
|
||||
if config_cls.__openllm_requirements__
|
||||
}
|
||||
|
||||
# shallow copy from locals()
|
||||
@@ -137,7 +153,8 @@ _locals = locals().copy()
|
||||
|
||||
# NOTE: update this table when adding new external dependencies
|
||||
# sync with openllm.utils.OPTIONAL_DEPENDENCIES
|
||||
_base_requirements.update({v: _locals.get(f'{inflection.underscore(v).upper()}_DEPS') for v in openllm.utils.OPTIONAL_DEPENDENCIES})
|
||||
_base_requirements.update(
|
||||
{v: _locals.get(f'{inflection.underscore(v).upper()}_DEPS') for v in openllm.utils.OPTIONAL_DEPENDENCIES})
|
||||
|
||||
_base_requirements = {k: v for k, v in sorted(_base_requirements.items())}
|
||||
|
||||
@@ -161,8 +178,7 @@ def create_classifiers() -> Array:
|
||||
Classifier.create_classifier('audience', 'Developers'),
|
||||
Classifier.create_classifier('audience', 'Science/Research'),
|
||||
Classifier.create_classifier('audience', 'System Administrators'),
|
||||
Classifier.create_classifier('typing', 'Typed'),
|
||||
*Classifier.create_python_classifier(),
|
||||
Classifier.create_classifier('typing', 'Typed'), *Classifier.create_python_classifier(),
|
||||
])
|
||||
return arr.multiline(True)
|
||||
|
||||
@@ -171,7 +187,10 @@ def create_optional_table() -> Table:
|
||||
all_array.append(f"openllm[{','.join(_base_requirements)}]")
|
||||
|
||||
table = tomlkit.table(is_super_table=True)
|
||||
_base_requirements.update({'full': correct_style(all_array.multiline(True)), 'all': tomlkit.array('["openllm[full]"]')})
|
||||
_base_requirements.update({
|
||||
'full': correct_style(all_array.multiline(True)),
|
||||
'all': tomlkit.array('["openllm[full]"]')
|
||||
})
|
||||
table.update({k: v for k, v in sorted(_base_requirements.items())})
|
||||
table.add(tomlkit.nl())
|
||||
|
||||
@@ -209,22 +228,8 @@ def authors() -> Array:
|
||||
def keywords() -> Array:
|
||||
arr = correct_style(tomlkit.array())
|
||||
arr.extend([
|
||||
'MLOps',
|
||||
'AI',
|
||||
'BentoML',
|
||||
'Model Serving',
|
||||
'Model Deployment',
|
||||
'LLMOps',
|
||||
'Falcon',
|
||||
'Vicuna',
|
||||
'Llama 2',
|
||||
'Fine tuning',
|
||||
'Serverless',
|
||||
'Large Language Model',
|
||||
'Generative AI',
|
||||
'StableLM',
|
||||
'Alpaca',
|
||||
'PyTorch',
|
||||
'MLOps', 'AI', 'BentoML', 'Model Serving', 'Model Deployment', 'LLMOps', 'Falcon', 'Vicuna', 'Llama 2',
|
||||
'Fine tuning', 'Serverless', 'Large Language Model', 'Generative AI', 'StableLM', 'Alpaca', 'PyTorch',
|
||||
'Transformers'
|
||||
])
|
||||
return arr.multiline(True)
|
||||
@@ -234,8 +239,10 @@ def build_cli_extensions() -> Table:
|
||||
ext: dict[str, str] = {'openllm': 'openllm.cli.entrypoint:cli'}
|
||||
ext.update({
|
||||
f'openllm-{inflection.dasherize(ke)}': f'openllm.cli.extension.{ke}:cli' for ke in sorted([
|
||||
fname[:-3] for fname in os.listdir(os.path.abspath(os.path.join(ROOT, 'openllm-python', 'src', 'openllm', 'cli', 'extension'))) if fname.endswith('.py') and
|
||||
not fname.startswith('__')
|
||||
fname[:-3]
|
||||
for fname in os.listdir(
|
||||
os.path.abspath(os.path.join(ROOT, 'openllm-python', 'src', 'openllm', 'cli', 'extension')))
|
||||
if fname.endswith('.py') and not fname.startswith('__')
|
||||
])
|
||||
})
|
||||
table.update(ext)
|
||||
|
||||
@@ -16,10 +16,13 @@ _OWNER = 'bentoml'
|
||||
_REPO = 'openllm'
|
||||
|
||||
_gz_strategies: dict[t.Literal['macos_arm', 'macos_intel', 'linux_intel'], str] = {
|
||||
'macos_arm': 'aarch64-apple-darwin', 'macos_intel': 'x86_64-apple-darwin', 'linux_intel': 'x86_64-unknown-linux-musl'
|
||||
'macos_arm': 'aarch64-apple-darwin',
|
||||
'macos_intel': 'x86_64-apple-darwin',
|
||||
'linux_intel': 'x86_64-unknown-linux-musl'
|
||||
}
|
||||
|
||||
def determine_release_url(svn_url: str, tag: str, target: t.Literal['macos_arm', 'macos_intel', 'linux_intel', 'archive']) -> str:
|
||||
def determine_release_url(svn_url: str, tag: str, target: t.Literal['macos_arm', 'macos_intel', 'linux_intel',
|
||||
'archive']) -> str:
|
||||
if target == 'archive': return f'{svn_url}/archive/{tag}.tar.gz'
|
||||
return f"{svn_url}/releases/download/{tag}/openllm-{tag.replace('v', '')}-{_gz_strategies[target]}.tar.gz"
|
||||
|
||||
@@ -32,29 +35,28 @@ def main() -> int:
|
||||
_info = api.repos.get()
|
||||
release_tag = api.repos.get_latest_release().name
|
||||
|
||||
shadict: dict[str, t.Any] = {k: get_release_hash_command(determine_release_url(_info.svn_url, release_tag, k), release_tag)().strip() for k in _gz_strategies}
|
||||
shadict['archive'] = get_release_hash_command(determine_release_url(_info.svn_url, release_tag, 'archive'), release_tag)().strip()
|
||||
shadict: dict[str, t.Any] = {
|
||||
k: get_release_hash_command(determine_release_url(_info.svn_url, release_tag, k), release_tag)().strip()
|
||||
for k in _gz_strategies
|
||||
}
|
||||
shadict['archive'] = get_release_hash_command(determine_release_url(_info.svn_url, release_tag, 'archive'),
|
||||
release_tag)().strip()
|
||||
|
||||
ENVIRONMENT = Environment(
|
||||
extensions=['jinja2.ext.do', 'jinja2.ext.loopcontrols', 'jinja2.ext.debug'],
|
||||
trim_blocks=True,
|
||||
lstrip_blocks=True,
|
||||
loader=FileSystemLoader((ROOT / 'Formula').__fspath__(), followlinks=True)
|
||||
)
|
||||
ENVIRONMENT = Environment(extensions=['jinja2.ext.do', 'jinja2.ext.loopcontrols', 'jinja2.ext.debug'],
|
||||
trim_blocks=True,
|
||||
lstrip_blocks=True,
|
||||
loader=FileSystemLoader((ROOT / 'Formula').__fspath__(), followlinks=True))
|
||||
template_file = 'openllm.rb.j2'
|
||||
with (ROOT / 'Formula' / 'openllm.rb').open('w') as f:
|
||||
f.write(
|
||||
ENVIRONMENT.get_template(template_file, globals={
|
||||
'determine_release_url': determine_release_url
|
||||
}).render(
|
||||
shadict=shadict,
|
||||
__tag__=release_tag,
|
||||
__cmd__=fs.path.join(os.path.basename(os.path.dirname(__file__)), os.path.basename(__file__)),
|
||||
__template_file__=fs.path.join('Formula', template_file),
|
||||
__gz_extension__=_gz_strategies,
|
||||
**_info
|
||||
)
|
||||
)
|
||||
}).render(shadict=shadict,
|
||||
__tag__=release_tag,
|
||||
__cmd__=fs.path.join(os.path.basename(os.path.dirname(__file__)), os.path.basename(__file__)),
|
||||
__template_file__=fs.path.join('Formula', template_file),
|
||||
__gz_extension__=_gz_strategies,
|
||||
**_info))
|
||||
f.write('\n')
|
||||
return 0
|
||||
|
||||
|
||||
@@ -24,12 +24,14 @@ def process_annotations(annotations: str) -> str:
|
||||
else: return annotations
|
||||
|
||||
_value_docstring = {
|
||||
'default_id': '''Return the default model to use when using 'openllm start <model_id>'.
|
||||
'default_id':
|
||||
'''Return the default model to use when using 'openllm start <model_id>'.
|
||||
This could be one of the keys in 'self.model_ids' or custom users model.
|
||||
|
||||
This field is required when defining under '__config__'.
|
||||
''',
|
||||
'model_ids': '''A list of supported pretrained models tag for this given runnable.
|
||||
'model_ids':
|
||||
'''A list of supported pretrained models tag for this given runnable.
|
||||
|
||||
For example:
|
||||
For FLAN-T5 impl, this would be ["google/flan-t5-small", "google/flan-t5-base",
|
||||
@@ -37,7 +39,8 @@ _value_docstring = {
|
||||
|
||||
This field is required when defining under '__config__'.
|
||||
''',
|
||||
'architecture': '''The model architecture that is supported by this LLM.
|
||||
'architecture':
|
||||
'''The model architecture that is supported by this LLM.
|
||||
|
||||
Note that any model weights within this architecture generation can always be run and supported by this LLM.
|
||||
|
||||
@@ -47,29 +50,44 @@ _value_docstring = {
|
||||
```bash
|
||||
openllm start gpt-neox --model-id stabilityai/stablelm-tuned-alpha-3b
|
||||
```''',
|
||||
'default_implementation': '''The default runtime to run this LLM. By default, it will be PyTorch (pt) for most models. For some models, such as Llama, it will use `vllm` or `flax`.
|
||||
'default_implementation':
|
||||
'''The default runtime to run this LLM. By default, it will be PyTorch (pt) for most models. For some models, such as Llama, it will use `vllm` or `flax`.
|
||||
|
||||
It is a dictionary of key as the accelerator spec in k4s ('cpu', 'nvidia.com/gpu', 'amd.com/gpu', 'cloud-tpus.google.com/v2', ...) and the values as supported OpenLLM Runtime ('flax', 'tf', 'pt', 'vllm')
|
||||
''',
|
||||
'url': '''The resolved url for this LLMConfig.''',
|
||||
'requires_gpu': '''Determines if this model is only available on GPU. By default it supports GPU and fallback to CPU.''',
|
||||
'trust_remote_code': '''Whether to always trust remote code''',
|
||||
'service_name': """Generated service name for this LLMConfig. By default, it is 'generated_{model_name}_service.py'""",
|
||||
'requirements': '''The default PyPI requirements needed to run this given LLM. By default, we will depend on
|
||||
'url':
|
||||
'''The resolved url for this LLMConfig.''',
|
||||
'requires_gpu':
|
||||
'''Determines if this model is only available on GPU. By default it supports GPU and fallback to CPU.''',
|
||||
'trust_remote_code':
|
||||
'''Whether to always trust remote code''',
|
||||
'service_name':
|
||||
"""Generated service name for this LLMConfig. By default, it is 'generated_{model_name}_service.py'""",
|
||||
'requirements':
|
||||
'''The default PyPI requirements needed to run this given LLM. By default, we will depend on
|
||||
bentoml, torch, transformers.''',
|
||||
'bettertransformer': '''Whether to use BetterTransformer for this given LLM. This depends per model architecture. By default, we will use BetterTransformer for T5 and StableLM models, and set to False for every other models.''',
|
||||
'model_type': '''The model type for this given LLM. By default, it should be causal language modeling.
|
||||
'bettertransformer':
|
||||
'''Whether to use BetterTransformer for this given LLM. This depends per model architecture. By default, we will use BetterTransformer for T5 and StableLM models, and set to False for every other models.''',
|
||||
'model_type':
|
||||
'''The model type for this given LLM. By default, it should be causal language modeling.
|
||||
Currently supported 'causal_lm' or 'seq2seq_lm'
|
||||
''',
|
||||
'runtime': '''The runtime to use for this model. Possible values are `transformers` or `ggml`. See Llama for more information.''',
|
||||
'name_type': '''The default name typed for this model. "dasherize" will convert the name to lowercase and
|
||||
'runtime':
|
||||
'''The runtime to use for this model. Possible values are `transformers` or `ggml`. See Llama for more information.''',
|
||||
'name_type':
|
||||
'''The default name typed for this model. "dasherize" will convert the name to lowercase and
|
||||
replace spaces with dashes. "lowercase" will convert the name to lowercase. If this is not set, then both
|
||||
`model_name` and `start_name` must be specified.''',
|
||||
'model_name': '''The normalized version of __openllm_start_name__, determined by __openllm_name_type__''',
|
||||
'start_name': '''Default name to be used with `openllm start`''',
|
||||
'env': '''A EnvVarMixin instance for this LLMConfig.''',
|
||||
'timeout': '''The default timeout to be set for this given LLM.''',
|
||||
'workers_per_resource': '''The number of workers per resource. This is used to determine the number of workers to use for this model.
|
||||
'model_name':
|
||||
'''The normalized version of __openllm_start_name__, determined by __openllm_name_type__''',
|
||||
'start_name':
|
||||
'''Default name to be used with `openllm start`''',
|
||||
'env':
|
||||
'''A EnvVarMixin instance for this LLMConfig.''',
|
||||
'timeout':
|
||||
'''The default timeout to be set for this given LLM.''',
|
||||
'workers_per_resource':
|
||||
'''The number of workers per resource. This is used to determine the number of workers to use for this model.
|
||||
For example, if this is set to 0.5, then OpenLLM will use 1 worker per 2 resources. If this is set to 1, then
|
||||
OpenLLM will use 1 worker per resource. If this is set to 2, then OpenLLM will use 2 workers per resource.
|
||||
|
||||
@@ -78,8 +96,10 @@ _value_docstring = {
|
||||
|
||||
By default, it is set to 1.
|
||||
''',
|
||||
'fine_tune_strategies': '''The fine-tune strategies for this given LLM.''',
|
||||
'tokenizer_class': '''Optional tokenizer class for this given LLM. See Llama for example.''',
|
||||
'fine_tune_strategies':
|
||||
'''The fine-tune strategies for this given LLM.''',
|
||||
'tokenizer_class':
|
||||
'''Optional tokenizer class for this given LLM. See Llama for example.''',
|
||||
}
|
||||
|
||||
_transformed = {'fine_tune_strategies': 't.Dict[AdapterType, FineTuneConfig]'}
|
||||
@@ -88,53 +108,74 @@ def main() -> int:
|
||||
with _TARGET_FILE.open('r') as f:
|
||||
processed = f.readlines()
|
||||
|
||||
start_idx, end_idx = processed.index(' '*2 + START_COMMENT), processed.index(' '*2 + END_COMMENT)
|
||||
start_stub_idx, end_stub_idx = processed.index(' '*4 + START_SPECIAL_COMMENT), processed.index(' '*4 + END_SPECIAL_COMMENT)
|
||||
start_attrs_idx, end_attrs_idx = processed.index(' '*4 + START_ATTRS_COMMENT), processed.index(' '*4 + END_ATTRS_COMMENT)
|
||||
start_idx, end_idx = processed.index(' ' * 2 + START_COMMENT), processed.index(' ' * 2 + END_COMMENT)
|
||||
start_stub_idx, end_stub_idx = processed.index(' ' * 4 + START_SPECIAL_COMMENT), processed.index(' ' * 4 +
|
||||
END_SPECIAL_COMMENT)
|
||||
start_attrs_idx, end_attrs_idx = processed.index(' ' * 4 + START_ATTRS_COMMENT), processed.index(' ' * 4 +
|
||||
END_ATTRS_COMMENT)
|
||||
|
||||
# NOTE: inline stubs __config__ attrs representation
|
||||
special_attrs_lines: list[str] = []
|
||||
for keys, ForwardRef in codegen.get_annotations(ModelSettings).items():
|
||||
special_attrs_lines.append(f"{' ' * 4}{keys}: {_transformed.get(keys, process_annotations(ForwardRef.__forward_arg__))}\n")
|
||||
special_attrs_lines.append(
|
||||
f"{' ' * 4}{keys}: {_transformed.get(keys, process_annotations(ForwardRef.__forward_arg__))}\n")
|
||||
# NOTE: inline stubs for _ConfigAttr type stubs
|
||||
config_attr_lines: list[str] = []
|
||||
for keys, ForwardRef in codegen.get_annotations(ModelSettings).items():
|
||||
config_attr_lines.extend([
|
||||
' '*4 + line for line in [f'__openllm_{keys}__: {_transformed.get(keys, process_annotations(ForwardRef.__forward_arg__))} = Field(None)\n', f'"""{_value_docstring[keys]}"""\n',]
|
||||
' ' * 4 + line for line in [
|
||||
f'__openllm_{keys}__: {_transformed.get(keys, process_annotations(ForwardRef.__forward_arg__))} = Field(None)\n',
|
||||
f'"""{_value_docstring[keys]}"""\n',
|
||||
]
|
||||
])
|
||||
# NOTE: inline runtime __getitem__ overload process
|
||||
lines: list[str] = []
|
||||
lines.append(' '*2 + '# NOTE: ModelSettings arguments\n')
|
||||
lines.append(' ' * 2 + '# NOTE: ModelSettings arguments\n')
|
||||
for keys, ForwardRef in codegen.get_annotations(ModelSettings).items():
|
||||
lines.extend([
|
||||
' '*2 + line for line in ['@overload\n', f'def __getitem__(self, item: t.Literal["{keys}"]) -> {_transformed.get(keys, process_annotations(ForwardRef.__forward_arg__))}: ...\n',]
|
||||
' ' * 2 + line for line in [
|
||||
'@overload\n',
|
||||
f'def __getitem__(self, item: t.Literal["{keys}"]) -> {_transformed.get(keys, process_annotations(ForwardRef.__forward_arg__))}: ...\n',
|
||||
]
|
||||
])
|
||||
# special case variables: generation_class, extras, sampling_class
|
||||
lines.append(' '*2 + '# NOTE: generation_class, sampling_class and extras arguments\n')
|
||||
lines.append(' ' * 2 + '# NOTE: generation_class, sampling_class and extras arguments\n')
|
||||
lines.extend([
|
||||
' '*2 + line for line in [
|
||||
' ' * 2 + line for line in [
|
||||
'@overload\n',
|
||||
'def __getitem__(self, item: t.Literal["generation_class"]) -> t.Type[openllm_core.GenerationConfig]: ...\n',
|
||||
'@overload\n',
|
||||
'def __getitem__(self, item: t.Literal["sampling_class"]) -> t.Type[openllm_core.SamplingParams]: ...\n',
|
||||
'@overload\n',
|
||||
'def __getitem__(self, item: t.Literal["extras"]) -> t.Dict[str, t.Any]: ...\n',
|
||||
'@overload\n', 'def __getitem__(self, item: t.Literal["extras"]) -> t.Dict[str, t.Any]: ...\n',
|
||||
]
|
||||
])
|
||||
lines.append(' '*2 + '# NOTE: GenerationConfig arguments\n')
|
||||
lines.append(' ' * 2 + '# NOTE: GenerationConfig arguments\n')
|
||||
generation_config_anns = codegen.get_annotations(GenerationConfig)
|
||||
for keys, type_pep563 in generation_config_anns.items():
|
||||
lines.extend([' '*2 + line for line in ['@overload\n', f'def __getitem__(self, item: t.Literal["{keys}"]) -> {type_pep563}: ...\n']])
|
||||
lines.append(' '*2 + '# NOTE: SamplingParams arguments\n')
|
||||
lines.extend([
|
||||
' ' * 2 + line
|
||||
for line in ['@overload\n', f'def __getitem__(self, item: t.Literal["{keys}"]) -> {type_pep563}: ...\n']
|
||||
])
|
||||
lines.append(' ' * 2 + '# NOTE: SamplingParams arguments\n')
|
||||
for keys, type_pep563 in codegen.get_annotations(SamplingParams).items():
|
||||
if keys not in generation_config_anns: lines.extend([' '*2 + line for line in ['@overload\n', f'def __getitem__(self, item: t.Literal["{keys}"]) -> {type_pep563}: ...\n',]])
|
||||
lines.append(' '*2 + '# NOTE: PeftType arguments\n')
|
||||
if keys not in generation_config_anns:
|
||||
lines.extend([
|
||||
' ' * 2 + line
|
||||
for line in ['@overload\n', f'def __getitem__(self, item: t.Literal["{keys}"]) -> {type_pep563}: ...\n',]
|
||||
])
|
||||
lines.append(' ' * 2 + '# NOTE: PeftType arguments\n')
|
||||
for keys in PeftType._member_names_:
|
||||
lines.extend([' '*2 + line for line in ['@overload\n', f'def __getitem__(self, item: t.Literal["{keys.lower()}"]) -> dict[str, t.Any]: ...\n',]])
|
||||
lines.extend([
|
||||
' ' * 2 + line for line in
|
||||
['@overload\n', f'def __getitem__(self, item: t.Literal["{keys.lower()}"]) -> dict[str, t.Any]: ...\n',]
|
||||
])
|
||||
|
||||
processed = processed[:start_attrs_idx] + [' '*4 + START_ATTRS_COMMENT, *special_attrs_lines, ' '*4 + END_ATTRS_COMMENT] + processed[end_attrs_idx + 1:start_stub_idx] + [
|
||||
' '*4 + START_SPECIAL_COMMENT, *config_attr_lines, ' '*4 + END_SPECIAL_COMMENT
|
||||
] + processed[end_stub_idx + 1:start_idx] + [' '*2 + START_COMMENT, *lines, ' '*2 + END_COMMENT] + processed[end_idx + 1:]
|
||||
processed = processed[:start_attrs_idx] + [
|
||||
' ' * 4 + START_ATTRS_COMMENT, *special_attrs_lines, ' ' * 4 + END_ATTRS_COMMENT
|
||||
] + processed[end_attrs_idx + 1:start_stub_idx] + [
|
||||
' ' * 4 + START_SPECIAL_COMMENT, *config_attr_lines, ' ' * 4 + END_SPECIAL_COMMENT
|
||||
] + processed[end_stub_idx + 1:start_idx] + [' ' * 2 + START_COMMENT, *lines, ' ' * 2 + END_COMMENT
|
||||
] + processed[end_idx + 1:]
|
||||
with _TARGET_FILE.open('w') as f:
|
||||
f.writelines(processed)
|
||||
return 0
|
||||
|
||||
@@ -13,9 +13,16 @@ from openllm import CONFIG_MAPPING
|
||||
|
||||
if t.TYPE_CHECKING: from collections import OrderedDict
|
||||
|
||||
config_requirements = {k: [_.replace('-', '_') for _ in v.__openllm_requirements__] if v.__openllm_requirements__ else None for k, v in CONFIG_MAPPING.items()}
|
||||
_dependencies: dict[LiteralRuntime, str] = {k: v for k, v in zip(LiteralRuntime.__args__, ('torch', 'tensorflow', 'flax', 'vllm'))}
|
||||
_auto: dict[str, str] = {k: v for k, v in zip(LiteralRuntime.__args__, ('AutoLLM', 'AutoTFLLM', 'AutoFlaxLLM', 'AutoVLLM'))}
|
||||
config_requirements = {
|
||||
k: [_.replace('-', '_') for _ in v.__openllm_requirements__] if v.__openllm_requirements__ else None
|
||||
for k, v in CONFIG_MAPPING.items()
|
||||
}
|
||||
_dependencies: dict[LiteralRuntime, str] = {
|
||||
k: v for k, v in zip(LiteralRuntime.__args__, ('torch', 'tensorflow', 'flax', 'vllm'))
|
||||
}
|
||||
_auto: dict[str, str] = {
|
||||
k: v for k, v in zip(LiteralRuntime.__args__, ('AutoLLM', 'AutoTFLLM', 'AutoFlaxLLM', 'AutoVLLM'))
|
||||
}
|
||||
|
||||
def get_target_dummy_file(framework: LiteralRuntime) -> Path:
|
||||
return _ROOT / 'openllm-python' / 'src' / 'openllm' / 'utils' / f'dummy_{framework}_objects.py'
|
||||
@@ -28,23 +35,24 @@ def get_mapping(framework: LiteralRuntime) -> OrderedDict[t.Any, t.Any]:
|
||||
|
||||
def make_class_stub(model_name: str, framework: LiteralRuntime, indentation: int = 2, auto: bool = False) -> list[str]:
|
||||
_dep_list: list[str] = [
|
||||
f'"{v}"' for v in [_dependencies[framework], *(t.cast(t.List[str], config_requirements[model_name]) if model_name != '__default__' and config_requirements[model_name] else [])]
|
||||
f'"{v}"' for v in [
|
||||
_dependencies[framework], *(t.cast(t.List[str], config_requirements[model_name])
|
||||
if model_name != '__default__' and config_requirements[model_name] else [])
|
||||
]
|
||||
]
|
||||
if auto: cl_ = _auto[framework]
|
||||
else: cl_ = get_mapping(framework)[model_name]
|
||||
lines = [
|
||||
f'class {cl_}(metaclass=_DummyMetaclass):',
|
||||
' '*indentation + f"_backends=[{','.join(_dep_list)}]",
|
||||
' '*indentation + f"def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,[{','.join(_dep_list)}])"
|
||||
f'class {cl_}(metaclass=_DummyMetaclass):', ' ' * indentation + f"_backends=[{','.join(_dep_list)}]",
|
||||
' ' * indentation +
|
||||
f"def __init__(self,*param_decls:_t.Any,**attrs: _t.Any):_require_backends(self,[{','.join(_dep_list)}])"
|
||||
]
|
||||
return lines
|
||||
|
||||
def write_stub(framework: LiteralRuntime, _path: str) -> list[str]:
|
||||
base = [
|
||||
f'# This file is generated by {_path}. DO NOT EDIT MANUALLY!',
|
||||
f'# To update this, run ./{_path}',
|
||||
'from __future__ import annotations',
|
||||
'import typing as _t',
|
||||
f'# This file is generated by {_path}. DO NOT EDIT MANUALLY!', f'# To update this, run ./{_path}',
|
||||
'from __future__ import annotations', 'import typing as _t',
|
||||
'from openllm_core.utils import DummyMetaclass as _DummyMetaclass, require_backends as _require_backends',
|
||||
]
|
||||
base.extend([v for it in [make_class_stub(k, framework) for k in get_mapping(framework)] for v in it])
|
||||
@@ -52,7 +60,10 @@ def write_stub(framework: LiteralRuntime, _path: str) -> list[str]:
|
||||
base.extend(make_class_stub('__default__', framework, auto=True))
|
||||
# mapping and export
|
||||
_imports = [f'"{v}"' for v in get_mapping(framework).values()]
|
||||
base += [f'{mapping_names(framework)}:_t.Any=None', f"__all__:list[str]=[\"{mapping_names(framework)}\",\"{_auto[framework]}\",{','.join(_imports)}]\n"]
|
||||
base += [
|
||||
f'{mapping_names(framework)}:_t.Any=None',
|
||||
f"__all__:list[str]=[\"{mapping_names(framework)}\",\"{_auto[framework]}\",{','.join(_imports)}]\n"
|
||||
]
|
||||
return base
|
||||
|
||||
def main() -> int:
|
||||
|
||||
@@ -6,32 +6,29 @@ from pathlib import Path
|
||||
_TARGET_FILE = Path(__file__).parent.parent / 'openllm-python' / 'src' / 'openllm' / 'models' / '__init__.py'
|
||||
|
||||
def create_module_import() -> str:
|
||||
r = [f'"{p.name}"' for p in _TARGET_FILE.parent.glob('*/') if p.name not in ['__pycache__', '__init__.py', '.DS_Store']]
|
||||
r = [
|
||||
f'"{p.name}"' for p in _TARGET_FILE.parent.glob('*/')
|
||||
if p.name not in ['__pycache__', '__init__.py', '.DS_Store']
|
||||
]
|
||||
return f"_MODELS:set[str]={{{', '.join(sorted(r))}}}"
|
||||
|
||||
def create_stubs_import() -> list[str]:
|
||||
return [
|
||||
'if t.TYPE_CHECKING:from . import ' + ','.join([f'{p.name} as {p.name}' for p in sorted(_TARGET_FILE.parent.glob('*/')) if p.name not in {'__pycache__', '__init__.py', '.DS_Store'}]),
|
||||
'__lazy=LazyModule(__name__, globals()["__file__"], {k: [] for k in _MODELS})',
|
||||
'__all__=__lazy.__all__',
|
||||
'__dir__=__lazy.__dir__',
|
||||
'__getattr__=__lazy.__getattr__\n'
|
||||
'if t.TYPE_CHECKING:from . import ' + ','.join([
|
||||
f'{p.name} as {p.name}' for p in sorted(_TARGET_FILE.parent.glob('*/'))
|
||||
if p.name not in {'__pycache__', '__init__.py', '.DS_Store'}
|
||||
]), '__lazy=LazyModule(__name__, globals()["__file__"], {k: [] for k in _MODELS})', '__all__=__lazy.__all__',
|
||||
'__dir__=__lazy.__dir__', '__getattr__=__lazy.__getattr__\n'
|
||||
]
|
||||
|
||||
def main() -> int:
|
||||
_path = os.path.join(os.path.basename(os.path.dirname(__file__)), os.path.basename(__file__))
|
||||
with _TARGET_FILE.open('w') as f:
|
||||
f.writelines(
|
||||
'\n'.join([
|
||||
f'# This file is generated by {_path}. DO NOT EDIT MANUALLY!',
|
||||
f'# To update this, run ./{_path}',
|
||||
'from __future__ import annotations',
|
||||
'import typing as t',
|
||||
'from openllm_core.utils import LazyModule',
|
||||
create_module_import(),
|
||||
*create_stubs_import(),
|
||||
])
|
||||
)
|
||||
f.writelines('\n'.join([
|
||||
f'# This file is generated by {_path}. DO NOT EDIT MANUALLY!', f'# To update this, run ./{_path}',
|
||||
'from __future__ import annotations', 'import typing as t', 'from openllm_core.utils import LazyModule',
|
||||
create_module_import(), *create_stubs_import(),
|
||||
]))
|
||||
return 0
|
||||
|
||||
if __name__ == '__main__': raise SystemExit(main())
|
||||
|
||||
@@ -18,7 +18,11 @@ def main() -> int:
|
||||
|
||||
start_index, stop_index = readme.index(START_COMMENT), readme.index(END_COMMENT)
|
||||
formatted: dict[t.Literal['Model', 'Architecture', 'URL', 'Installation', 'Model Ids'], list[str | list[str]]] = {
|
||||
'Model': [], 'Architecture': [], 'URL': [], 'Model Ids': [], 'Installation': [],
|
||||
'Model': [],
|
||||
'Architecture': [],
|
||||
'URL': [],
|
||||
'Model Ids': [],
|
||||
'Installation': [],
|
||||
}
|
||||
max_install_len_div = 0
|
||||
for name, config_cls in openllm.CONFIG_MAPPING.items():
|
||||
@@ -38,7 +42,8 @@ def main() -> int:
|
||||
meta.extend([f'<th>{header}</th>\n' for header in formatted.keys() if header not in ('URL',)])
|
||||
meta += ['</tr>\n']
|
||||
# NOTE: rows
|
||||
for name, architecture, url, model_ids, installation in t.cast(t.Iterable[t.Tuple[str, str, str, t.List[str], str]], zip(*formatted.values())):
|
||||
for name, architecture, url, model_ids, installation in t.cast(t.Iterable[t.Tuple[str, str, str, t.List[str], str]],
|
||||
zip(*formatted.values())):
|
||||
meta += '<tr>\n'
|
||||
# configure architecture URL
|
||||
cfg_cls = openllm.CONFIG_MAPPING[name]
|
||||
|
||||
@@ -31,7 +31,8 @@ def main() -> int:
|
||||
color = 'ok' if float(total_rate) >= 95 else 'critical'
|
||||
lines.insert(0, f'\n')
|
||||
|
||||
lines.append(f'**Summary** | {100 if total_rate == 100 else total_rate}% ({total_statements_covered} / {total_statements})\n')
|
||||
lines.append(
|
||||
f'**Summary** | {100 if total_rate == 100 else total_rate}% ({total_statements_covered} / {total_statements})\n')
|
||||
|
||||
coverage_report = ROOT / 'coverage-report.md'
|
||||
with coverage_report.open('w', encoding='utf-8') as f:
|
||||
|
||||
Reference in New Issue
Block a user