mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-02-23 10:16:06 -05:00
fix(breaking): remove embeddings and update client implementation (#500)
This commit is contained in:
@@ -15,7 +15,7 @@ from . import exceptions as exceptions, utils as utils
|
||||
|
||||
from openllm_core._configuration import GenerationConfig as GenerationConfig, LLMConfig as LLMConfig, SamplingParams as SamplingParams
|
||||
from openllm_core._strategies import CascadingResourceStrategy as CascadingResourceStrategy, get_resource as get_resource
|
||||
from openllm_core._schema import EmbeddingsOutput as EmbeddingsOutput, GenerationInput as GenerationInput, GenerationOutput as GenerationOutput, HfAgentInput as HfAgentInput, MetadataOutput as MetadataOutput, unmarshal_vllm_outputs as unmarshal_vllm_outputs
|
||||
from openllm_core._schema import GenerationInput as GenerationInput, GenerationOutput as GenerationOutput, HfAgentInput as HfAgentInput, MetadataOutput as MetadataOutput, unmarshal_vllm_outputs as unmarshal_vllm_outputs
|
||||
from openllm_core.config import AutoConfig as AutoConfig, CONFIG_MAPPING as CONFIG_MAPPING, CONFIG_MAPPING_NAMES as CONFIG_MAPPING_NAMES, BaichuanConfig as BaichuanConfig, ChatGLMConfig as ChatGLMConfig, DollyV2Config as DollyV2Config, FalconConfig as FalconConfig, FlanT5Config as FlanT5Config, GPTNeoXConfig as GPTNeoXConfig, LlamaConfig as LlamaConfig, MPTConfig as MPTConfig, OPTConfig as OPTConfig, StableLMConfig as StableLMConfig, StarCoderConfig as StarCoderConfig
|
||||
|
||||
if openllm_core.utils.DEBUG:
|
||||
@@ -45,8 +45,7 @@ _import_structure: dict[str, list[str]] = {
|
||||
"serialisation": ["ggml", "transformers"],
|
||||
"cli._sdk": ["start", "start_grpc", "build", "import_model", "list_models"],
|
||||
"_quantisation": ["infer_quantisation_config"],
|
||||
"_embeddings": ["GenericEmbeddingRunnable"],
|
||||
"_llm": ["LLM", "Runner", "LLMRunner", "LLMRunnable", "EmbeddingsOutput"],
|
||||
"_llm": ["LLM", "Runner", "LLMRunner", "LLMRunnable"],
|
||||
"_generation": ["StopSequenceCriteria", "StopOnTokens", "LogitsProcessorList", "StoppingCriteriaList", "prepare_logits_processor"],
|
||||
"models.auto": ["MODEL_MAPPING_NAMES", "MODEL_FLAX_MAPPING_NAMES", "MODEL_TF_MAPPING_NAMES", "MODEL_VLLM_MAPPING_NAMES"],
|
||||
"models.chatglm": [],
|
||||
@@ -66,9 +65,8 @@ COMPILED = _Path(__file__).suffix in (".pyd", ".so")
|
||||
if _t.TYPE_CHECKING:
|
||||
from . import bundle as bundle, cli as cli, client as client, models as models, playground as playground, serialisation as serialisation, testing as testing
|
||||
from ._generation import LogitsProcessorList as LogitsProcessorList, StopOnTokens as StopOnTokens, StoppingCriteriaList as StoppingCriteriaList, StopSequenceCriteria as StopSequenceCriteria, prepare_logits_processor as prepare_logits_processor
|
||||
from ._llm import LLM as LLM, EmbeddingsOutput as EmbeddingsOutput, LLMRunnable as LLMRunnable, LLMRunner as LLMRunner, Runner as Runner
|
||||
from ._llm import LLM as LLM, LLMRunnable as LLMRunnable, LLMRunner as LLMRunner, Runner as Runner
|
||||
from ._quantisation import infer_quantisation_config as infer_quantisation_config
|
||||
from ._embeddings import GenericEmbeddingRunnable as GenericEmbeddingRunnable
|
||||
from .cli._sdk import build as build, import_model as import_model, list_models as list_models, start as start, start_grpc as start_grpc
|
||||
from .models.auto import MODEL_FLAX_MAPPING_NAMES as MODEL_FLAX_MAPPING_NAMES, MODEL_MAPPING_NAMES as MODEL_MAPPING_NAMES, MODEL_TF_MAPPING_NAMES as MODEL_TF_MAPPING_NAMES, MODEL_VLLM_MAPPING_NAMES as MODEL_VLLM_MAPPING_NAMES
|
||||
from .serialisation import ggml as ggml, transformers as transformers
|
||||
@@ -182,7 +180,7 @@ else:
|
||||
from .models.opt import TFOPT as TFOPT
|
||||
|
||||
# NOTE: update this to sys.modules[__name__] once mypy_extensions can recognize __spec__
|
||||
__lazy = openllm_core.utils.LazyModule(__name__, globals()["__file__"], _import_structure, extra_objects={"COMPILED": COMPILED, "__openllm_migration__": {"LLMEmbeddings": "EmbeddingsOutput"}})
|
||||
__lazy = openllm_core.utils.LazyModule(__name__, globals()["__file__"], _import_structure, extra_objects={"COMPILED": COMPILED})
|
||||
__all__ = __lazy.__all__
|
||||
__dir__ = __lazy.__dir__
|
||||
__getattr__ = __lazy.__getattr__
|
||||
|
||||
@@ -1,79 +0,0 @@
|
||||
# See https://github.com/bentoml/sentence-embedding-bento for more information.
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
|
||||
import transformers
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
import bentoml
|
||||
import openllm
|
||||
|
||||
from bentoml._internal.frameworks.transformers import API_VERSION
|
||||
from bentoml._internal.frameworks.transformers import MODULE_NAME
|
||||
from bentoml._internal.models.model import ModelOptions
|
||||
from bentoml._internal.models.model import ModelSignature
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
import torch
|
||||
|
||||
_GENERIC_EMBEDDING_ID = 'sentence-transformers/all-MiniLM-L6-v2'
|
||||
_BENTOMODEL_ID = 'sentence-transformers--all-MiniLM-L6-v2'
|
||||
|
||||
def get_or_download(ids: str = _BENTOMODEL_ID) -> bentoml.Model:
|
||||
try:
|
||||
return bentoml.transformers.get(ids)
|
||||
except bentoml.exceptions.NotFound:
|
||||
model_signatures = {
|
||||
k: ModelSignature(batchable=False)
|
||||
for k in ('forward', 'generate', 'contrastive_search', 'greedy_search', 'sample', 'beam_search', 'beam_sample', 'group_beam_search', 'constrained_beam_search', '__call__')
|
||||
}
|
||||
with bentoml.models.create(ids,
|
||||
module=MODULE_NAME,
|
||||
api_version=API_VERSION,
|
||||
options=ModelOptions(),
|
||||
context=openllm.utils.generate_context(framework_name='transformers'),
|
||||
labels={
|
||||
'runtime': 'pt',
|
||||
'framework': 'openllm'
|
||||
},
|
||||
signatures=model_signatures) as bentomodel:
|
||||
snapshot_download(_GENERIC_EMBEDDING_ID,
|
||||
local_dir=bentomodel.path,
|
||||
local_dir_use_symlinks=False,
|
||||
ignore_patterns=['*.safetensors', '*.h5', '*.ot', '*.pdf', '*.md', '.gitattributes', 'LICENSE.txt'])
|
||||
return bentomodel
|
||||
|
||||
class GenericEmbeddingRunnable(bentoml.Runnable):
|
||||
SUPPORTED_RESOURCES = ('nvidia.com/gpu', 'cpu')
|
||||
SUPPORTS_CPU_MULTI_THREADING = True
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.device = 'cuda' if openllm.utils.device_count() > 0 else 'cpu'
|
||||
self._bentomodel = get_or_download()
|
||||
self.tokenizer = transformers.AutoTokenizer.from_pretrained(self._bentomodel.path)
|
||||
self.model = transformers.AutoModel.from_pretrained(self._bentomodel.path)
|
||||
self.model.to(self.device)
|
||||
|
||||
@bentoml.Runnable.method(batchable=True, batch_dim=0)
|
||||
def encode(self, sentences: list[str]) -> t.Sequence[openllm.EmbeddingsOutput]:
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
encoded_input = self.tokenizer(sentences, padding=True, truncation=True, return_tensors='pt').to(self.device)
|
||||
attention_mask = encoded_input['attention_mask']
|
||||
# Compute token embeddings
|
||||
with torch.no_grad():
|
||||
model_output = self.model(**encoded_input)
|
||||
# Perform pooling and normalize
|
||||
sentence_embeddings = F.normalize(self.mean_pooling(model_output, attention_mask), p=2, dim=1)
|
||||
return [openllm.EmbeddingsOutput(embeddings=sentence_embeddings.cpu().numpy(), num_tokens=int(torch.sum(attention_mask).item()))]
|
||||
|
||||
@staticmethod
|
||||
def mean_pooling(model_output: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
|
||||
import torch
|
||||
# Mean Pooling - Take attention mask into account for correct averaging
|
||||
token_embeddings = model_output[0] # First element of model_output contains all token embeddings
|
||||
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
|
||||
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
|
||||
|
||||
__all__ = ['GenericEmbeddingRunnable']
|
||||
@@ -21,7 +21,6 @@ import openllm_core
|
||||
from bentoml._internal.models.model import ModelSignature
|
||||
from openllm_core._configuration import FineTuneConfig
|
||||
from openllm_core._configuration import LLMConfig
|
||||
from openllm_core._schema import EmbeddingsOutput
|
||||
from openllm_core._typing_compat import AdaptersMapping
|
||||
from openllm_core._typing_compat import AdaptersTuple
|
||||
from openllm_core._typing_compat import AdapterType
|
||||
@@ -165,16 +164,6 @@ class LLMFunction(abc.ABC):
|
||||
'''
|
||||
raise NotImplementedError
|
||||
|
||||
def embeddings(self, prompts: list[str]) -> EmbeddingsOutput:
|
||||
'''The implementation for generating text embeddings from given prompt.
|
||||
|
||||
It takes the prompt and output the embeddings for this given LLM.
|
||||
|
||||
Returns:
|
||||
The embeddings for the given prompt.
|
||||
'''
|
||||
raise NotImplementedError
|
||||
|
||||
class LLMSerialisation(abc.ABC, t.Generic[M, T]):
|
||||
def import_model(self, *args: t.Any, trust_remote_code: bool, **attrs: t.Any) -> bentoml.Model:
|
||||
'''Import both model and tokenizer weights into as a BentoML models.
|
||||
@@ -261,8 +250,6 @@ class LLMInterface(LLMFunction, LLMSerialisation[M, T], abc.ABC):
|
||||
__llm_adapter_map__: t.Optional[ResolvedAdaptersMapping]
|
||||
'''A reference to the the cached LoRA adapter mapping.'''
|
||||
|
||||
__llm_supports_embeddings__: bool
|
||||
'''A boolean to determine whether models does implement ``LLM.embeddings``.'''
|
||||
__llm_supports_generate__: bool
|
||||
'''A boolean to determine whether models does implement ``LLM.generate``.'''
|
||||
__llm_supports_generate_one__: bool
|
||||
@@ -338,10 +325,6 @@ class LLM(LLMInterface[M, T], ReprMixin):
|
||||
def __getitem__(self, item: t.Literal['adapter_map']) -> ResolvedAdaptersMapping | None:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal['supports_embeddings']) -> bool:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __getitem__(self, item: t.Literal['supports_generate']) -> bool:
|
||||
...
|
||||
@@ -876,18 +859,16 @@ class LLM(LLMInterface[M, T], ReprMixin):
|
||||
raise RuntimeError(f'Failed to locate {self._bentomodel}:{err}') from None
|
||||
|
||||
generate_sig = ModelSignature.from_dict(ModelSignatureDict(batchable=False))
|
||||
embeddings_sig = ModelSignature.from_dict(ModelSignatureDict(batchable=True, batch_dim=0))
|
||||
generate_iterator_sig = ModelSignature.from_dict(ModelSignatureDict(batchable=False))
|
||||
|
||||
# NOTE: returning the two langchain API's to the runner
|
||||
return llm_runner_class(self)(llm_runnable_class(self, embeddings_sig, generate_sig, generate_iterator_sig),
|
||||
return llm_runner_class(self)(llm_runnable_class(self, generate_sig, generate_iterator_sig),
|
||||
name=self.runner_name,
|
||||
embedded=False,
|
||||
models=models,
|
||||
max_batch_size=max_batch_size,
|
||||
max_latency_ms=max_latency_ms,
|
||||
method_configs=bentoml_cattr.unstructure({
|
||||
'embeddings': embeddings_sig,
|
||||
'__call__': generate_sig,
|
||||
'generate': generate_sig,
|
||||
'generate_one': generate_sig,
|
||||
@@ -970,14 +951,14 @@ class LLM(LLMInterface[M, T], ReprMixin):
|
||||
past_key_values = out = token = None
|
||||
finish_reason = None
|
||||
for i in range(config['max_new_tokens']):
|
||||
torch.cuda.synchronize()
|
||||
if torch.cuda.is_available(): torch.cuda.synchronize()
|
||||
if i == 0: # prefill
|
||||
out = self.model(torch.as_tensor([input_ids], device=self.device), use_cache=True)
|
||||
else: # decoding
|
||||
out = self.model(torch.as_tensor([[token]], device=self.device), use_cache=True, past_key_values=past_key_values)
|
||||
logits = out.logits
|
||||
past_key_values = out.past_key_values
|
||||
torch.cuda.synchronize()
|
||||
if torch.cuda.is_available(): torch.cuda.synchronize()
|
||||
|
||||
if logits_processor:
|
||||
if config['repetition_penalty'] > 1.0:
|
||||
@@ -1139,7 +1120,7 @@ class SetAdapterOutput(t.TypedDict):
|
||||
success: bool
|
||||
message: str
|
||||
|
||||
def llm_runnable_class(self: LLM[M, T], embeddings_sig: ModelSignature, generate_sig: ModelSignature, generate_iterator_sig: ModelSignature) -> type[LLMRunnable[M, T]]:
|
||||
def llm_runnable_class(self: LLM[M, T], generate_sig: ModelSignature, generate_iterator_sig: ModelSignature) -> type[LLMRunnable[M, T]]:
|
||||
class _Runnable(bentoml.Runnable):
|
||||
SUPPORTED_RESOURCES = ('nvidia.com/gpu', 'amd.com/gpu', 'cpu')
|
||||
SUPPORTS_CPU_MULTI_THREADING = True
|
||||
@@ -1159,10 +1140,6 @@ def llm_runnable_class(self: LLM[M, T], embeddings_sig: ModelSignature, generate
|
||||
if adapter_name != 'default': self.model.set_adapter(adapter_name)
|
||||
logger.info('Successfully apply LoRA layer %s', adapter_name)
|
||||
|
||||
@bentoml.Runnable.method(**method_signature(embeddings_sig)) # type: ignore
|
||||
def embeddings(__self: _Runnable, prompt: str | list[str]) -> t.Sequence[EmbeddingsOutput]:
|
||||
return [self.embeddings([prompt] if isinstance(prompt, str) else prompt)]
|
||||
|
||||
@bentoml.Runnable.method(**method_signature(generate_sig)) # type: ignore
|
||||
def __call__(__self: _Runnable, prompt: str, **attrs: t.Any) -> list[t.Any]:
|
||||
prompt, attrs, _ = self.sanitize_parameters(prompt, **attrs)
|
||||
@@ -1303,18 +1280,6 @@ def llm_runner_class(self: LLM[M, T]) -> type[LLMRunner[M, T]]:
|
||||
prompt, generate_kwargs, postprocess_kwargs = self.sanitize_parameters(prompt, **kwargs)
|
||||
return self.postprocess_generate(prompt, __self.generate.run(prompt, **generate_kwargs), **postprocess_kwargs)
|
||||
|
||||
def _wrapped_embeddings_run(__self: LLMRunner[M, T], prompt: str | list[str]) -> EmbeddingsOutput:
|
||||
'''``llm.embed`` is a light wrapper around runner.embeedings.run().
|
||||
|
||||
Usage:
|
||||
|
||||
```python
|
||||
runner = openllm.Runner('llama', backend='pt')
|
||||
runner.embed("What is the meaning of life?")
|
||||
```
|
||||
'''
|
||||
return __self.embeddings.run([prompt] if isinstance(prompt, str) else prompt)
|
||||
|
||||
def _wrapped_repr_keys(_: LLMRunner[M, T]) -> set[str]:
|
||||
return {'config', 'llm_type', 'runner_methods', 'backend', 'llm_tag'}
|
||||
|
||||
@@ -1325,6 +1290,14 @@ def llm_runner_class(self: LLM[M, T]) -> type[LLMRunner[M, T]]:
|
||||
yield 'backend', self.__llm_backend__
|
||||
yield 'llm_tag', self.tag
|
||||
|
||||
if self._prompt_template: prompt_template = self._prompt_template.to_string()
|
||||
elif hasattr(self.config, 'default_prompt_template'): prompt_template = self.config.default_prompt_template
|
||||
else: prompt_template = None
|
||||
|
||||
if self._system_message: system_message = self._system_message
|
||||
elif hasattr(self.config, 'default_system_message'): system_message = self.config.default_system_message
|
||||
else: system_message = None
|
||||
|
||||
return types.new_class(self.__class__.__name__ + 'Runner', (bentoml.Runner,),
|
||||
exec_body=lambda ns: ns.update({
|
||||
'llm_type': self.llm_type,
|
||||
@@ -1336,17 +1309,15 @@ def llm_runner_class(self: LLM[M, T]) -> type[LLMRunner[M, T]]:
|
||||
'peft_adapters': property(fget=available_adapters),
|
||||
'download_model': self.save_pretrained,
|
||||
'__call__': _wrapped_generate_run,
|
||||
'embed': _wrapped_embeddings_run,
|
||||
'__module__': self.__module__,
|
||||
'__doc__': self.config['env'].start_docstring,
|
||||
'__repr__': ReprMixin.__repr__,
|
||||
'__repr_keys__': property(_wrapped_repr_keys),
|
||||
'__repr_args__': _wrapped_repr_args,
|
||||
'supports_embeddings': self['supports_embeddings'],
|
||||
'supports_hf_agent': self['supports_generate_one'],
|
||||
'has_adapters': self._adapters_mapping is not None,
|
||||
'prompt_template': self._prompt_template.to_string() if self._prompt_template else self.config.default_prompt_template,
|
||||
'system_message': self._system_message if self._system_message else self.config.default_system_message,
|
||||
'prompt_template': prompt_template,
|
||||
'system_message': system_message,
|
||||
}))
|
||||
|
||||
__all__ = ['LLMRunner', 'LLMRunnable', 'Runner', 'LLM', 'llm_runner_class', 'llm_runnable_class', 'EmbeddingsOutput']
|
||||
__all__ = ['LLMRunner', 'LLMRunnable', 'Runner', 'LLM', 'llm_runner_class', 'llm_runnable_class']
|
||||
|
||||
@@ -18,11 +18,6 @@ if t.TYPE_CHECKING:
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
from bentoml._internal.runner.runner import AbstractRunner
|
||||
from bentoml._internal.runner.runner import RunnerMethod
|
||||
from openllm_core._typing_compat import TypeAlias
|
||||
_EmbeddingMethod: TypeAlias = RunnerMethod[t.Union[bentoml.Runnable, openllm.LLMRunnable[t.Any, t.Any]], [t.List[str]], t.Sequence[openllm.EmbeddingsOutput]]
|
||||
|
||||
# The following warnings from bitsandbytes, and probably not that important for users to see
|
||||
warnings.filterwarnings('ignore', message='MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization')
|
||||
warnings.filterwarnings('ignore', message='MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization')
|
||||
@@ -33,14 +28,7 @@ model_id = svars.model_id
|
||||
adapter_map = svars.adapter_map
|
||||
llm_config = openllm.AutoConfig.for_model(model)
|
||||
runner = openllm.Runner(model, llm_config=llm_config, model_id=model_id, ensure_available=False, adapter_map=orjson.loads(adapter_map))
|
||||
generic_embedding_runner = bentoml.Runner(openllm.GenericEmbeddingRunnable, # XXX: remove arg-type once bentoml.Runner is correct set with type
|
||||
name='llm-generic-embedding',
|
||||
scheduling_strategy=openllm_core.CascadingResourceStrategy,
|
||||
max_batch_size=32,
|
||||
max_latency_ms=300)
|
||||
runners: list[AbstractRunner] = [runner]
|
||||
if not runner.supports_embeddings: runners.append(generic_embedding_runner)
|
||||
svc = bentoml.Service(name=f"llm-{llm_config['start_name']}-service", runners=runners)
|
||||
svc = bentoml.Service(name=f"llm-{llm_config['start_name']}-service", runners=[runner])
|
||||
|
||||
_JsonInput = bentoml.io.JSON.from_sample({'prompt': '', 'llm_config': llm_config.model_dump(flatten=True), 'adapter_name': None})
|
||||
|
||||
@@ -184,7 +172,6 @@ async def chat_completion_v1(input_dict: dict[str, t.Any], ctx: bentoml.Context)
|
||||
'model_name': llm_config['model_name'],
|
||||
'backend': runner.backend,
|
||||
'configuration': llm_config.model_dump(flatten=True),
|
||||
'supports_embeddings': runner.supports_embeddings,
|
||||
'supports_hf_agent': runner.supports_hf_agent,
|
||||
'prompt_template': runner.prompt_template,
|
||||
'system_message': runner.system_message,
|
||||
@@ -195,27 +182,11 @@ def metadata_v1(_: str) -> openllm.MetadataOutput:
|
||||
backend=llm_config['env']['backend_value'],
|
||||
model_id=runner.llm.model_id,
|
||||
configuration=llm_config.model_dump_json().decode(),
|
||||
supports_embeddings=runner.supports_embeddings,
|
||||
supports_hf_agent=runner.supports_hf_agent,
|
||||
prompt_template=runner.prompt_template,
|
||||
system_message=runner.system_message,
|
||||
)
|
||||
|
||||
@svc.api(route='/v1/embeddings',
|
||||
input=bentoml.io.JSON.from_sample(['Hey Jude, welcome to the jungle!', 'What is the meaning of life?']),
|
||||
output=bentoml.io.JSON.from_sample({
|
||||
'embeddings': [
|
||||
0.007917795330286026, -0.014421648345887661, 0.00481307040899992, 0.007331526838243008, -0.0066398633643984795, 0.00945580005645752, 0.0087016262114048, -0.010709521360695362,
|
||||
0.012635177001357079, 0.010541186667978764, -0.00730888033285737, -0.001783102168701589, 0.02339819073677063, -0.010825827717781067, -0.015888236463069916,
|
||||
0.01876218430697918, 0.0076906150206923485, 0.0009032754460349679, -0.010024012066423893, 0.01090280432254076, -0.008668390102684498, 0.02070549875497818,
|
||||
0.0014594447566196322, -0.018775740638375282, -0.014814382418990135, 0.01796768605709076
|
||||
],
|
||||
'num_tokens': 20
|
||||
}))
|
||||
async def embeddings_v1(phrases: list[str]) -> list[openllm.EmbeddingsOutput]:
|
||||
embed_call: _EmbeddingMethod = runner.embeddings if runner.supports_embeddings else generic_embedding_runner.encode # type: ignore[type-arg,assignment,valid-type]
|
||||
return await embed_call.async_run(phrases)
|
||||
|
||||
if runner.supports_hf_agent:
|
||||
|
||||
async def hf_agent(request: Request) -> Response:
|
||||
|
||||
@@ -21,7 +21,6 @@ bentomodel = openllm.import_model("falcon", model_id='tiiuae/falcon-7b-instruct'
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import functools
|
||||
import http.client
|
||||
import inspect
|
||||
import itertools
|
||||
import logging
|
||||
@@ -112,7 +111,8 @@ if t.TYPE_CHECKING:
|
||||
|
||||
from bentoml._internal.bento import BentoStore
|
||||
from bentoml._internal.container import DefaultBuilder
|
||||
from openllm_core._schema import EmbeddingsOutput
|
||||
from openllm_client._schemas import Response
|
||||
from openllm_client._schemas import StreamResponse
|
||||
from openllm_core._typing_compat import LiteralContainerRegistry
|
||||
from openllm_core._typing_compat import LiteralContainerVersionStrategy
|
||||
else:
|
||||
@@ -130,17 +130,20 @@ OPENLLM_FIGLET = '''\
|
||||
'''
|
||||
|
||||
ServeCommand = t.Literal['serve', 'serve-grpc']
|
||||
|
||||
@attr.define
|
||||
class GlobalOptions:
|
||||
cloud_context: str | None = attr.field(default=None)
|
||||
|
||||
def with_options(self, **attrs: t.Any) -> Self:
|
||||
return attr.evolve(self, **attrs)
|
||||
|
||||
GrpType = t.TypeVar('GrpType', bound=click.Group)
|
||||
|
||||
_object_setattr = object.__setattr__
|
||||
|
||||
_EXT_FOLDER = os.path.abspath(os.path.join(os.path.dirname(__file__), 'extension'))
|
||||
|
||||
class Extensions(click.MultiCommand):
|
||||
def list_commands(self, ctx: click.Context) -> list[str]:
|
||||
return sorted([filename[:-3] for filename in os.listdir(_EXT_FOLDER) if filename.endswith('.py') and not filename.startswith('__')])
|
||||
@@ -151,6 +154,7 @@ class Extensions(click.MultiCommand):
|
||||
except ImportError:
|
||||
return None
|
||||
return mod.cli
|
||||
|
||||
class OpenLLMCommandGroup(BentoMLCommandGroup):
|
||||
NUMBER_OF_COMMON_PARAMS = 5 # parameters in common_params + 1 faked group option header
|
||||
|
||||
@@ -284,10 +288,12 @@ class OpenLLMCommandGroup(BentoMLCommandGroup):
|
||||
if rows:
|
||||
with formatter.section(_('Extensions')):
|
||||
formatter.write_dl(rows)
|
||||
|
||||
@click.group(cls=OpenLLMCommandGroup, context_settings=termui.CONTEXT_SETTINGS, name='openllm')
|
||||
@click.version_option(
|
||||
None, '--version', '-v', message=f"%(prog)s, %(version)s (compiled: {'yes' if openllm.COMPILED else 'no'})\nPython ({platform.python_implementation()}) {platform.python_version()}"
|
||||
)
|
||||
@click.version_option(None,
|
||||
'--version',
|
||||
'-v',
|
||||
message=f"%(prog)s, %(version)s (compiled: {'yes' if openllm.COMPILED else 'no'})\nPython ({platform.python_implementation()}) {platform.python_version()}")
|
||||
def cli() -> None:
|
||||
'''\b
|
||||
██████╗ ██████╗ ███████╗███╗ ██╗██╗ ██╗ ███╗ ███╗
|
||||
@@ -301,6 +307,7 @@ def cli() -> None:
|
||||
An open platform for operating large language models in production.
|
||||
Fine-tune, serve, deploy, and monitor any LLMs with ease.
|
||||
'''
|
||||
|
||||
@cli.group(cls=OpenLLMCommandGroup, context_settings=termui.CONTEXT_SETTINGS, name='start', aliases=['start-http'])
|
||||
def start_command() -> None:
|
||||
'''Start any LLM as a REST server.
|
||||
@@ -310,6 +317,7 @@ def start_command() -> None:
|
||||
$ openllm <start|start-http> <model_name> --<options> ...
|
||||
```
|
||||
'''
|
||||
|
||||
@cli.group(cls=OpenLLMCommandGroup, context_settings=termui.CONTEXT_SETTINGS, name='start-grpc')
|
||||
def start_grpc_command() -> None:
|
||||
'''Start any LLM as a gRPC server.
|
||||
@@ -319,6 +327,7 @@ def start_grpc_command() -> None:
|
||||
$ openllm start-grpc <model_name> --<options> ...
|
||||
```
|
||||
'''
|
||||
|
||||
_start_mapping = {
|
||||
'start': {
|
||||
key: start_command_factory(start_command, key, _context_settings=termui.CONTEXT_SETTINGS) for key in CONFIG_MAPPING
|
||||
@@ -327,6 +336,7 @@ _start_mapping = {
|
||||
key: start_command_factory(start_grpc_command, key, _context_settings=termui.CONTEXT_SETTINGS, _serve_grpc=True) for key in CONFIG_MAPPING
|
||||
}
|
||||
}
|
||||
|
||||
@cli.command(name='import', aliases=['download'])
|
||||
@model_name_argument
|
||||
@click.argument('model_id', type=click.STRING, default=None, metavar='Optional[REMOTE_REPO/MODEL_ID | /path/to/local/model]', required=False)
|
||||
@@ -337,17 +347,9 @@ _start_mapping = {
|
||||
@machine_option
|
||||
@backend_option
|
||||
@serialisation_option
|
||||
def import_command(
|
||||
model_name: str,
|
||||
model_id: str | None,
|
||||
converter: str | None,
|
||||
model_version: str | None,
|
||||
output: LiteralOutput,
|
||||
machine: bool,
|
||||
backend: LiteralBackend,
|
||||
quantize: LiteralQuantise | None,
|
||||
serialisation: LiteralSerialisation | None,
|
||||
) -> bentoml.Model:
|
||||
def import_command(model_name: str, model_id: str | None, converter: str | None, model_version: str | None, output: LiteralOutput, machine: bool, backend: LiteralBackend,
|
||||
quantize: LiteralQuantise | None, serialisation: LiteralSerialisation | None,
|
||||
) -> bentoml.Model:
|
||||
"""Setup LLM interactively.
|
||||
|
||||
It accepts two positional arguments: `model_name` and `model_id`. The first name determine
|
||||
@@ -402,7 +404,13 @@ def import_command(
|
||||
_serialisation = openllm_core.utils.first_not_none(serialisation, default=llm_config['serialisation'])
|
||||
env = EnvVarMixin(model_name, backend=llm_config.default_backend(), model_id=model_id, quantize=quantize)
|
||||
backend = first_not_none(backend, default=env['backend_value'])
|
||||
llm = infer_auto_class(backend).for_model(model_name, model_id=env['model_id_value'], llm_config=llm_config, model_version=model_version, ensure_available=False, quantize=env['quantize_value'], serialisation=_serialisation)
|
||||
llm = infer_auto_class(backend).for_model(model_name,
|
||||
model_id=env['model_id_value'],
|
||||
llm_config=llm_config,
|
||||
model_version=model_version,
|
||||
ensure_available=False,
|
||||
quantize=env['quantize_value'],
|
||||
serialisation=_serialisation)
|
||||
_previously_saved = False
|
||||
try:
|
||||
_ref = openllm.serialisation.get(llm)
|
||||
@@ -434,66 +442,40 @@ def import_command(
|
||||
@workers_per_resource_option(factory=click, build=True)
|
||||
@cog.optgroup.group(cls=cog.MutuallyExclusiveOptionGroup, name='Optimisation options')
|
||||
@quantize_option(factory=cog.optgroup, build=True)
|
||||
@click.option(
|
||||
'--enable-features',
|
||||
multiple=True,
|
||||
nargs=1,
|
||||
metavar='FEATURE[,FEATURE]',
|
||||
help='Enable additional features for building this LLM Bento. Available: {}'.format(', '.join(OPTIONAL_DEPENDENCIES))
|
||||
)
|
||||
@click.option(
|
||||
'--adapter-id',
|
||||
default=None,
|
||||
multiple=True,
|
||||
metavar='[PATH | [remote/][adapter_name:]adapter_id][, ...]',
|
||||
help="Optional adapters id to be included within the Bento. Note that if you are using relative path, '--build-ctx' must be passed."
|
||||
)
|
||||
@click.option('--enable-features',
|
||||
multiple=True,
|
||||
nargs=1,
|
||||
metavar='FEATURE[,FEATURE]',
|
||||
help='Enable additional features for building this LLM Bento. Available: {}'.format(', '.join(OPTIONAL_DEPENDENCIES)))
|
||||
@click.option('--adapter-id',
|
||||
default=None,
|
||||
multiple=True,
|
||||
metavar='[PATH | [remote/][adapter_name:]adapter_id][, ...]',
|
||||
help="Optional adapters id to be included within the Bento. Note that if you are using relative path, '--build-ctx' must be passed.")
|
||||
@click.option('--build-ctx', help='Build context. This is required if --adapter-id uses relative path', default=None)
|
||||
@model_version_option
|
||||
@click.option('--dockerfile-template', default=None, type=click.File(), help='Optional custom dockerfile template to be used with this BentoLLM.')
|
||||
@serialisation_option
|
||||
@container_registry_option
|
||||
@click.option(
|
||||
'--container-version-strategy', type=click.Choice(['release', 'latest', 'nightly']), default='release', help="Default container version strategy for the image from '--container-registry'"
|
||||
)
|
||||
@click.option('--container-version-strategy',
|
||||
type=click.Choice(['release', 'latest', 'nightly']),
|
||||
default='release',
|
||||
help="Default container version strategy for the image from '--container-registry'")
|
||||
@cog.optgroup.group(cls=cog.MutuallyExclusiveOptionGroup, name='Utilities options')
|
||||
@cog.optgroup.option(
|
||||
'--containerize',
|
||||
default=False,
|
||||
is_flag=True,
|
||||
type=click.BOOL,
|
||||
help="Whether to containerize the Bento after building. '--containerize' is the shortcut of 'openllm build && bentoml containerize'."
|
||||
)
|
||||
@cog.optgroup.option('--containerize',
|
||||
default=False,
|
||||
is_flag=True,
|
||||
type=click.BOOL,
|
||||
help="Whether to containerize the Bento after building. '--containerize' is the shortcut of 'openllm build && bentoml containerize'.")
|
||||
@cog.optgroup.option('--push', default=False, is_flag=True, type=click.BOOL, help="Whether to push the result bento to BentoCloud. Make sure to login with 'bentoml cloud login' first.")
|
||||
@click.option('--force-push', default=False, is_flag=True, type=click.BOOL, help='Whether to force push.')
|
||||
@click.pass_context
|
||||
def build_command(
|
||||
ctx: click.Context,
|
||||
/,
|
||||
model_name: str,
|
||||
model_id: str | None,
|
||||
bento_version: str | None,
|
||||
overwrite: bool,
|
||||
output: LiteralOutput,
|
||||
quantize: LiteralQuantise | None,
|
||||
enable_features: tuple[str, ...] | None,
|
||||
workers_per_resource: float | None,
|
||||
adapter_id: tuple[str, ...],
|
||||
build_ctx: str | None,
|
||||
backend: LiteralBackend,
|
||||
system_message: str | None,
|
||||
prompt_template_file: t.IO[t.Any] | None,
|
||||
machine: bool,
|
||||
model_version: str | None,
|
||||
dockerfile_template: t.TextIO | None,
|
||||
containerize: bool,
|
||||
push: bool,
|
||||
serialisation: LiteralSerialisation | None,
|
||||
container_registry: LiteralContainerRegistry,
|
||||
container_version_strategy: LiteralContainerVersionStrategy,
|
||||
force_push: bool,
|
||||
**attrs: t.Any,
|
||||
) -> bentoml.Bento:
|
||||
def build_command(ctx: click.Context, /, model_name: str, model_id: str | None, bento_version: str | None, overwrite: bool, output: LiteralOutput, quantize: LiteralQuantise | None,
|
||||
enable_features: tuple[str, ...] | None, workers_per_resource: float | None, adapter_id: tuple[str, ...], build_ctx: str | None, backend: LiteralBackend,
|
||||
system_message: str | None, prompt_template_file: t.IO[t.Any] | None, machine: bool, model_version: str | None, dockerfile_template: t.TextIO | None, containerize: bool,
|
||||
push: bool, serialisation: LiteralSerialisation | None, container_registry: LiteralContainerRegistry, container_version_strategy: LiteralContainerVersionStrategy,
|
||||
force_push: bool, **attrs: t.Any,
|
||||
) -> bentoml.Bento:
|
||||
'''Package a given models into a Bento.
|
||||
|
||||
\b
|
||||
@@ -530,7 +512,16 @@ def build_command(
|
||||
if system_message: os.environ['OPENLLM_SYSTEM_MESSAGE'] = system_message
|
||||
if prompt_template: os.environ['OPENLLM_PROMPT_TEMPLATE'] = prompt_template
|
||||
|
||||
llm = infer_auto_class(env['backend_value']).for_model(model_name, model_id=env['model_id_value'], prompt_template=prompt_template, system_message=system_message, llm_config=llm_config, ensure_available=True, model_version=model_version, quantize=env['quantize_value'], serialisation=_serialisation, **attrs)
|
||||
llm = infer_auto_class(env['backend_value']).for_model(model_name,
|
||||
model_id=env['model_id_value'],
|
||||
prompt_template=prompt_template,
|
||||
system_message=system_message,
|
||||
llm_config=llm_config,
|
||||
ensure_available=True,
|
||||
model_version=model_version,
|
||||
quantize=env['quantize_value'],
|
||||
serialisation=_serialisation,
|
||||
**attrs)
|
||||
|
||||
labels = dict(llm.identifying_params)
|
||||
labels.update({'_type': llm.llm_type, '_framework': env['backend_value']})
|
||||
@@ -575,18 +566,16 @@ def build_command(
|
||||
raise bentoml.exceptions.NotFound(f'Rebuilding existing Bento {bento_tag}') from None
|
||||
_previously_built = True
|
||||
except bentoml.exceptions.NotFound:
|
||||
bento = bundle.create_bento(
|
||||
bento_tag,
|
||||
llm_fs,
|
||||
llm,
|
||||
workers_per_resource=workers_per_resource,
|
||||
adapter_map=adapter_map,
|
||||
quantize=quantize,
|
||||
extra_dependencies=enable_features,
|
||||
dockerfile_template=dockerfile_template_path,
|
||||
container_registry=container_registry,
|
||||
container_version_strategy=container_version_strategy
|
||||
)
|
||||
bento = bundle.create_bento(bento_tag,
|
||||
llm_fs,
|
||||
llm,
|
||||
workers_per_resource=workers_per_resource,
|
||||
adapter_map=adapter_map,
|
||||
quantize=quantize,
|
||||
extra_dependencies=enable_features,
|
||||
dockerfile_template=dockerfile_template_path,
|
||||
container_registry=container_registry,
|
||||
container_version_strategy=container_version_strategy)
|
||||
except Exception as err:
|
||||
raise err from None
|
||||
|
||||
@@ -596,12 +585,11 @@ def build_command(
|
||||
termui.echo('\n' + OPENLLM_FIGLET, fg='white')
|
||||
if not _previously_built: termui.echo(f'Successfully built {bento}.', fg='green')
|
||||
elif not overwrite: termui.echo(f"'{model_name}' already has a Bento built [{bento}]. To overwrite it pass '--overwrite'.", fg='yellow')
|
||||
termui.echo(
|
||||
'📖 Next steps:\n\n' + f"* Push to BentoCloud with 'bentoml push':\n\t$ bentoml push {bento.tag}\n\n" +
|
||||
f"* Containerize your Bento with 'bentoml containerize':\n\t$ bentoml containerize {bento.tag} --opt progress=plain\n\n" +
|
||||
"\tTip: To enable additional BentoML features for 'containerize', use '--enable-features=FEATURE[,FEATURE]' [see 'bentoml containerize -h' for more advanced usage]\n",
|
||||
fg='blue',
|
||||
)
|
||||
termui.echo('📖 Next steps:\n\n' + f"* Push to BentoCloud with 'bentoml push':\n\t$ bentoml push {bento.tag}\n\n" +
|
||||
f"* Containerize your Bento with 'bentoml containerize':\n\t$ bentoml containerize {bento.tag} --opt progress=plain\n\n" +
|
||||
"\tTip: To enable additional BentoML features for 'containerize', use '--enable-features=FEATURE[,FEATURE]' [see 'bentoml containerize -h' for more advanced usage]\n",
|
||||
fg='blue',
|
||||
)
|
||||
elif output == 'json':
|
||||
termui.echo(orjson.dumps(bento.info.to_dict(), option=orjson.OPT_INDENT_2).decode())
|
||||
else:
|
||||
@@ -688,7 +676,7 @@ def models_command(ctx: click.Context, output: LiteralOutput, show_available: bo
|
||||
data: list[str | tuple[str, str, list[str], str, tuple[LiteralBackend, ...]]] = []
|
||||
for m, v in json_data.items():
|
||||
data.extend([(m, v['architecture'], v['model_id'], v['installation'], v['backend'])])
|
||||
column_widths = [int(termui.COLUMNS / 12), int(termui.COLUMNS / 6), int(termui.COLUMNS / 4), int(termui.COLUMNS / 6), int(termui.COLUMNS / 4)]
|
||||
column_widths = [int(termui.COLUMNS / 12), int(termui.COLUMNS / 6), int(termui.COLUMNS / 4), int(termui.COLUMNS / 6), int(termui.COLUMNS / 4)]
|
||||
|
||||
if len(data) == 0 and len(failed_initialized) > 0:
|
||||
termui.echo('Exception found while parsing models:\n', fg='yellow')
|
||||
@@ -716,14 +704,17 @@ def models_command(ctx: click.Context, output: LiteralOutput, show_available: bo
|
||||
if show_available: json_data['local'] = local_models
|
||||
termui.echo(orjson.dumps(json_data, option=orjson.OPT_INDENT_2,).decode(), fg='white')
|
||||
ctx.exit(0)
|
||||
|
||||
@cli.command()
|
||||
@model_name_argument(required=False)
|
||||
@click.option('-y', '--yes', '--assume-yes', is_flag=True, help='Skip confirmation when deleting a specific model')
|
||||
@click.option('--include-bentos/--no-include-bentos', is_flag=True, default=False, help='Whether to also include pruning bentos.')
|
||||
@inject
|
||||
def prune_command(
|
||||
model_name: str | None, yes: bool, include_bentos: bool, model_store: ModelStore = Provide[BentoMLContainer.model_store], bento_store: BentoStore = Provide[BentoMLContainer.bento_store]
|
||||
) -> None:
|
||||
def prune_command(model_name: str | None,
|
||||
yes: bool,
|
||||
include_bentos: bool,
|
||||
model_store: ModelStore = Provide[BentoMLContainer.model_store],
|
||||
bento_store: BentoStore = Provide[BentoMLContainer.bento_store]) -> None:
|
||||
'''Remove all saved models, (and optionally bentos) built with OpenLLM locally.
|
||||
|
||||
\b
|
||||
@@ -744,6 +735,7 @@ def prune_command(
|
||||
if delete_confirmed:
|
||||
store.delete(store_item.tag)
|
||||
termui.echo(f"{store_item} deleted from {'model' if isinstance(store, ModelStore) else 'bento'} store.", fg='yellow')
|
||||
|
||||
def parsing_instruction_callback(ctx: click.Context, param: click.Parameter, value: list[str] | str | None) -> tuple[str, bool | str] | list[str] | str | None:
|
||||
if value is None:
|
||||
return value
|
||||
@@ -762,6 +754,7 @@ def parsing_instruction_callback(ctx: click.Context, param: click.Parameter, val
|
||||
return key, values[0]
|
||||
else:
|
||||
raise click.BadParameter(f'Invalid option format: {value}')
|
||||
|
||||
def shared_client_options(f: _AnyCallable | None = None, output_value: t.Literal['json', 'porcelain', 'pretty'] = 'pretty') -> t.Callable[[FC], FC]:
|
||||
options = [
|
||||
click.option('--endpoint', type=click.STRING, help='OpenLLM Server endpoint, i.e: http://localhost:3000', envvar='OPENLLM_ENDPOINT', default='http://localhost:3000',
|
||||
@@ -770,20 +763,19 @@ def shared_client_options(f: _AnyCallable | None = None, output_value: t.Literal
|
||||
output_option(default_value=output_value),
|
||||
]
|
||||
return compose(*options)(f) if f is not None else compose(*options)
|
||||
|
||||
@cli.command()
|
||||
@click.argument('task', type=click.STRING, metavar='TASK')
|
||||
@shared_client_options
|
||||
@click.option('--agent', type=click.Choice(['hf']), default='hf', help='Whether to interact with Agents from given Server endpoint.', show_default=True)
|
||||
@click.option('--remote', is_flag=True, default=False, help='Whether or not to use remote tools (inference endpoints) instead of local ones.', show_default=True)
|
||||
@click.option(
|
||||
'--opt',
|
||||
help="Define prompt options. "
|
||||
"(format: ``--opt text='I love this' --opt audio:./path/to/audio --opt image:/path/to/file``)",
|
||||
required=False,
|
||||
multiple=True,
|
||||
callback=opt_callback,
|
||||
metavar='ARG=VALUE[,ARG=VALUE]'
|
||||
)
|
||||
@click.option('--opt',
|
||||
help="Define prompt options. "
|
||||
"(format: ``--opt text='I love this' --opt audio:./path/to/audio --opt image:/path/to/file``)",
|
||||
required=False,
|
||||
multiple=True,
|
||||
callback=opt_callback,
|
||||
metavar='ARG=VALUE[,ARG=VALUE]')
|
||||
def instruct_command(endpoint: str, timeout: int, agent: LiteralString, output: LiteralOutput, remote: bool, task: str, _memoized: DictStrAny, **attrs: t.Any) -> str:
|
||||
'''Instruct agents interactively for given tasks, from a terminal.
|
||||
|
||||
@@ -795,66 +787,37 @@ def instruct_command(endpoint: str, timeout: int, agent: LiteralString, output:
|
||||
```
|
||||
'''
|
||||
raise click.ClickException("'instruct' is currently disabled")
|
||||
client = openllm.client.HTTPClient(endpoint, timeout=timeout)
|
||||
# client = openllm.client.HTTPClient(endpoint, timeout=timeout)
|
||||
#
|
||||
# try:
|
||||
# client.call('metadata')
|
||||
# except http.client.BadStatusLine:
|
||||
# raise click.ClickException(f'{endpoint} is neither a HTTP server nor reachable.') from None
|
||||
# if agent == 'hf':
|
||||
# _memoized = {k: v[0] for k, v in _memoized.items() if v}
|
||||
# client._hf_agent.set_stream(logger.info)
|
||||
# if output != 'porcelain': termui.echo(f"Sending the following prompt ('{task}') with the following vars: {_memoized}", fg='magenta')
|
||||
# result = client.ask_agent(task, agent_type=agent, return_code=False, remote=remote, **_memoized)
|
||||
# if output == 'json': termui.echo(orjson.dumps(result, option=orjson.OPT_INDENT_2).decode(), fg='white')
|
||||
# else: termui.echo(result, fg='white')
|
||||
# return result
|
||||
# else:
|
||||
# raise click.BadOptionUsage('agent', f'Unknown agent type {agent}')
|
||||
|
||||
try:
|
||||
client.call('metadata')
|
||||
except http.client.BadStatusLine:
|
||||
raise click.ClickException(f'{endpoint} is neither a HTTP server nor reachable.') from None
|
||||
if agent == 'hf':
|
||||
_memoized = {k: v[0] for k, v in _memoized.items() if v}
|
||||
client._hf_agent.set_stream(logger.info)
|
||||
if output != 'porcelain': termui.echo(f"Sending the following prompt ('{task}') with the following vars: {_memoized}", fg='magenta')
|
||||
result = client.ask_agent(task, agent_type=agent, return_code=False, remote=remote, **_memoized)
|
||||
if output == 'json': termui.echo(orjson.dumps(result, option=orjson.OPT_INDENT_2).decode(), fg='white')
|
||||
else: termui.echo(result, fg='white')
|
||||
return result
|
||||
else:
|
||||
raise click.BadOptionUsage('agent', f'Unknown agent type {agent}')
|
||||
@cli.command()
|
||||
@shared_client_options(output_value='json')
|
||||
@click.option('--server-type', type=click.Choice(['grpc', 'http']), help='Server type', default='http', show_default=True)
|
||||
@click.argument('text', type=click.STRING, nargs=-1)
|
||||
@machine_option
|
||||
@click.pass_context
|
||||
def embed_command(
|
||||
ctx: click.Context, text: tuple[str, ...], endpoint: str, timeout: int, server_type: t.Literal['http', 'grpc'], output: LiteralOutput, machine: bool
|
||||
) -> EmbeddingsOutput | None:
|
||||
'''Get embeddings interactively, from a terminal.
|
||||
|
||||
\b
|
||||
```bash
|
||||
$ openllm embed --endpoint http://12.323.2.1:3000 "What is the meaning of life?" "How many stars are there in the sky?"
|
||||
```
|
||||
'''
|
||||
client = openllm.client.HTTPClient(endpoint, timeout=timeout) if server_type == 'http' else openllm.client.GrpcClient(endpoint, timeout=timeout)
|
||||
try:
|
||||
gen_embed = client.embed(text)
|
||||
except ValueError:
|
||||
raise click.ClickException(f'Endpoint {endpoint} does not support embeddings.') from None
|
||||
if machine: return gen_embed
|
||||
elif output == 'pretty':
|
||||
termui.echo('Generated embeddings: ', fg='magenta', nl=False)
|
||||
termui.echo(gen_embed.embeddings, fg='white')
|
||||
termui.echo('\nNumber of tokens: ', fg='magenta', nl=False)
|
||||
termui.echo(gen_embed.num_tokens, fg='white')
|
||||
elif output == 'json':
|
||||
termui.echo(orjson.dumps(bentoml_cattr.unstructure(gen_embed), option=orjson.OPT_INDENT_2).decode(), fg='white')
|
||||
else:
|
||||
termui.echo(gen_embed.embeddings, fg='white')
|
||||
ctx.exit(0)
|
||||
@cli.command()
|
||||
@shared_client_options(output_value='porcelain')
|
||||
@click.option('--server-type', type=click.Choice(['grpc', 'http']), help='Server type', default='http', show_default=True)
|
||||
@click.option('--stream/--no-stream', type=click.BOOL, is_flag=True, default=True, help='Whether to stream the response.')
|
||||
@click.argument('prompt', type=click.STRING)
|
||||
@click.option(
|
||||
'--sampling-params', help='Define query options. (format: ``--opt temperature=0.8 --opt=top_k:12)', required=False, multiple=True, callback=opt_callback, metavar='ARG=VALUE[,ARG=VALUE]'
|
||||
)
|
||||
@click.option('--sampling-params',
|
||||
help='Define query options. (format: ``--opt temperature=0.8 --opt=top_k:12)',
|
||||
required=False,
|
||||
multiple=True,
|
||||
callback=opt_callback,
|
||||
metavar='ARG=VALUE[,ARG=VALUE]')
|
||||
@click.pass_context
|
||||
def query_command(
|
||||
ctx: click.Context, /, prompt: str, endpoint: str, timeout: int, stream: bool, server_type: t.Literal['http', 'grpc'], output: LiteralOutput, _memoized: DictStrAny, **attrs: t.Any
|
||||
) -> None:
|
||||
def query_command(ctx: click.Context, /, prompt: str, endpoint: str, timeout: int, stream: bool, server_type: t.Literal['http', 'grpc'], output: LiteralOutput, _memoized: DictStrAny,
|
||||
**attrs: t.Any) -> None:
|
||||
'''Ask a LLM interactively, from a terminal.
|
||||
|
||||
\b
|
||||
@@ -870,24 +833,32 @@ def query_command(
|
||||
if output != 'porcelain':
|
||||
termui.echo('==Input==\n', fg='white')
|
||||
termui.echo(f'{prompt}', fg=input_fg)
|
||||
fn = client.generate_stream if stream else client.generate
|
||||
res = fn(prompt, **{**client._config(), **_memoized})
|
||||
if output == 'pretty':
|
||||
termui.echo('\n\n==Responses==\n', fg='white')
|
||||
if stream:
|
||||
for it in res: termui.echo(it.text, fg=generated_fg, nl=False)
|
||||
else: termui.echo(res.responses[0], fg=generated_fg)
|
||||
elif output == 'json':
|
||||
if stream:
|
||||
for it in res: termui.echo(orjson.dumps(bentoml_cattr.unstructure(it), option=orjson.OPT_INDENT_2).decode(), fg='white')
|
||||
else: termui.echo(orjson.dumps(bentoml_cattr.unstructure(res), option=orjson.OPT_INDENT_2).decode(), fg='white')
|
||||
else: # noqa: PLR5501
|
||||
if stream:
|
||||
for it in res: termui.echo(it.text, fg=generated_fg, nl=False)
|
||||
else: termui.echo(res.responses, fg='white')
|
||||
|
||||
if stream:
|
||||
stream_res: t.Iterator[StreamResponse] = client.generate_stream(prompt, **{**client._config(), **_memoized})
|
||||
if output == 'pretty':
|
||||
termui.echo('\n\n==Responses==\n', fg='white')
|
||||
for it in stream_res:
|
||||
termui.echo(it.text, fg=generated_fg, nl=False)
|
||||
elif output == 'json':
|
||||
for it in stream_res:
|
||||
termui.echo(orjson.dumps(bentoml_cattr.unstructure(it), option=orjson.OPT_INDENT_2).decode(), fg='white')
|
||||
else:
|
||||
for it in stream_res:
|
||||
termui.echo(it.text, fg=generated_fg, nl=False)
|
||||
else:
|
||||
res: Response = client.generate(prompt, **{**client._config(), **_memoized})
|
||||
if output == 'pretty':
|
||||
termui.echo('\n\n==Responses==\n', fg='white')
|
||||
termui.echo(res.responses[0], fg=generated_fg)
|
||||
elif output == 'json':
|
||||
termui.echo(orjson.dumps(bentoml_cattr.unstructure(res), option=orjson.OPT_INDENT_2).decode(), fg='white')
|
||||
else:
|
||||
termui.echo(res.responses, fg='white')
|
||||
ctx.exit(0)
|
||||
|
||||
@cli.group(cls=Extensions, hidden=True, name='extension')
|
||||
def extension_command() -> None:
|
||||
'''Extension for OpenLLM CLI.'''
|
||||
|
||||
if __name__ == '__main__': cli()
|
||||
|
||||
@@ -4,11 +4,6 @@
|
||||
client = openllm.client.HTTPClient("http://localhost:8080")
|
||||
client.query("What is the difference between gather and scatter?")
|
||||
```
|
||||
|
||||
If the server has embedding supports, use it via `client.embed`:
|
||||
```python
|
||||
client.embed("What is the difference between gather and scatter?")
|
||||
```
|
||||
'''
|
||||
from __future__ import annotations
|
||||
import typing as t
|
||||
|
||||
@@ -15,17 +15,3 @@ class ChatGLM(openllm.LLM['transformers.PreTrainedModel', 'transformers.PreTrain
|
||||
# Only use half precision if the model is not yet quantized
|
||||
if self.config.use_half_precision: self.model.half()
|
||||
return self.model.chat(self.tokenizer, prompt, generation_config=self.config.model_construct_env(**attrs).to_generation_config())
|
||||
|
||||
def embeddings(self, prompts: list[str]) -> openllm.EmbeddingsOutput:
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
embeddings: list[list[float]] = []
|
||||
num_tokens = 0
|
||||
for prompt in prompts:
|
||||
input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.device)
|
||||
with torch.inference_mode():
|
||||
outputs = self.model(input_ids, output_hidden_states=True)
|
||||
data = F.normalize(torch.mean(outputs.hidden_states[-1].transpose(0, 1), dim=0), p=2, dim=0)
|
||||
embeddings.append(data.tolist())
|
||||
num_tokens += len(input_ids[0])
|
||||
return openllm.EmbeddingsOutput(embeddings=embeddings, num_tokens=num_tokens)
|
||||
|
||||
@@ -15,17 +15,3 @@ class FlanT5(openllm.LLM['transformers.T5ForConditionalGeneration', 'transformer
|
||||
do_sample=True,
|
||||
generation_config=self.config.model_construct_env(**attrs).to_generation_config()),
|
||||
skip_special_tokens=True)
|
||||
|
||||
def embeddings(self, prompts: list[str]) -> openllm.EmbeddingsOutput:
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
embeddings: list[list[float]] = []
|
||||
num_tokens = 0
|
||||
for prompt in prompts:
|
||||
input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.device)
|
||||
with torch.inference_mode():
|
||||
outputs = self.model(input_ids, decoder_input_ids=input_ids)
|
||||
data = F.normalize(torch.mean(outputs.encoder_last_hidden_state[0], dim=0), p=2, dim=0)
|
||||
embeddings.append(data.tolist())
|
||||
num_tokens += len(input_ids[0])
|
||||
return openllm.EmbeddingsOutput(embeddings=embeddings, num_tokens=num_tokens)
|
||||
|
||||
@@ -12,15 +12,3 @@ class Llama(openllm.LLM['transformers.LlamaForCausalLM', 'transformers.LlamaToke
|
||||
def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
|
||||
import torch
|
||||
return {'torch_dtype': torch.float16 if torch.cuda.is_available() else torch.float32}, {}
|
||||
|
||||
def embeddings(self, prompts: list[str]) -> openllm.EmbeddingsOutput:
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
encoding = self.tokenizer(prompts, padding=True, return_tensors='pt').to(self.device)
|
||||
input_ids, attention_mask = encoding['input_ids'], encoding['attention_mask']
|
||||
with torch.inference_mode():
|
||||
data = self.model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True).hidden_states[-1]
|
||||
mask = attention_mask.unsqueeze(-1).expand(data.size()).float()
|
||||
masked_embeddings = data * mask
|
||||
sum_embeddings, seq_length = torch.sum(masked_embeddings, dim=1), torch.sum(mask, dim=1)
|
||||
return openllm.EmbeddingsOutput(embeddings=F.normalize(sum_embeddings / seq_length, p=2, dim=1).tolist(), num_tokens=int(torch.sum(attention_mask).item()))
|
||||
|
||||
Reference in New Issue
Block a user