Mirror of https://github.com/bentoml/OpenLLM.git (synced 2026-02-18 14:47:30 -05:00)
perf: unify LLM interface (#518)
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
Signed-off-by: paperspace <29749331+aarnphm@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@@ -22,7 +22,6 @@ def model_settings(draw: st.DrawFn):
  'url': st.text(),
  'trust_remote_code': st.booleans(),
  'requirements': st.none() | st.lists(st.text(), min_size=1),
  'default_backend': st.dictionaries(st.sampled_from(['cpu', 'nvidia.com/gpu']), st.sampled_from(['vllm', 'pt', 'tf', 'flax'])),
  'model_type': st.sampled_from(['causal_lm', 'seq2seq_lm']),
  'name_type': st.sampled_from(['dasherize', 'lowercase']),
  'timeout': st.integers(min_value=3600),
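For readers unfamiliar with hypothesis, a minimal sketch of how field strategies like the ones above can be drawn as a single settings dictionary; the st.composite/st.fixed_dictionaries wrapper and the tiny_model_settings name are assumptions for illustration, not the actual model_settings implementation:

import hypothesis.strategies as st

@st.composite
def tiny_model_settings(draw: st.DrawFn):
  # Hypothetical reduced version of model_settings: each value is a strategy,
  # and fixed_dictionaries draws one concrete value per key.
  return draw(st.fixed_dictionaries({
    'trust_remote_code': st.booleans(),
    'model_type': st.sampled_from(['causal_lm', 'seq2seq_lm']),
    'timeout': st.integers(min_value=3600),
  }))

Calling tiny_model_settings() then yields a strategy that can be fed to hypothesis' @given in a test.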
@@ -10,15 +10,12 @@ import openllm
if t.TYPE_CHECKING:
  from openllm_core._typing_compat import LiteralBackend

-_MODELING_MAPPING = {'flan_t5': 'google/flan-t5-small', 'opt': 'facebook/opt-125m', 'baichuan': 'baichuan-inc/Baichuan-7B',}
-_PROMPT_MAPPING = {'qa': 'Answer the following yes/no question by reasoning step-by-step. Can you write a whole Haiku in a single tweet?',}
+_MODELING_MAPPING = {'flan_t5': 'google/flan-t5-small', 'opt': 'facebook/opt-125m', 'baichuan': 'baichuan-inc/Baichuan-7B'}
+_PROMPT_MAPPING = {'qa': 'Answer the following yes/no question by reasoning step-by-step. Can you write a whole Haiku in a single tweet?'}

def parametrise_local_llm(model: str,) -> t.Generator[tuple[str, openllm.LLMRunner[t.Any, t.Any] | openllm.LLM[t.Any, t.Any]], None, None]:
  if model not in _MODELING_MAPPING: pytest.skip(f"'{model}' is not yet supported in framework testing.")
-  backends: tuple[LiteralBackend, ...] = tuple()
-  if model in openllm.MODEL_MAPPING_NAMES: backends += ('pt',)
-  if model in openllm.MODEL_FLAX_MAPPING_NAMES: backends += ('flax',)
-  if model in openllm.MODEL_TF_MAPPING_NAMES: backends += ('tf',)
+  backends: tuple[LiteralBackend, ...] = ('pt',)
  for backend, prompt in itertools.product(backends, _PROMPT_MAPPING.keys()):
    yield prompt, openllm.Runner(model, model_id=_MODELING_MAPPING[model], ensure_available=True, backend=backend, init_local=True)
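The hunk above hard-codes the backend tuple to ('pt',), replacing the per-framework lookup against MODEL_MAPPING_NAMES, MODEL_FLAX_MAPPING_NAMES and MODEL_TF_MAPPING_NAMES. A hedged sketch of how a test could consume the generator; the pytest wrapper and test name are assumptions, while parametrise_local_llm and _PROMPT_MAPPING are the objects defined above:

import pytest

@pytest.mark.parametrize('model', ['flan_t5', 'opt'])
def test_local_runners_initialise(model: str):
  # parametrise_local_llm skips unsupported models and yields (prompt, runner)
  # pairs; this only checks that each runner was constructed and initialised.
  for prompt, runner in parametrise_local_llm(model):
    assert prompt in _PROMPT_MAPPING
    assert runner is not None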
@@ -21,7 +21,8 @@ from syrupy.extensions.json import JSONSnapshotExtension
import openllm

-from openllm._llm import normalise_model_name
+from bentoml._internal.types import LazyType
+from openllm._llm import self
from openllm_core._typing_compat import DictStrAny
from openllm_core._typing_compat import ListAny
from openllm_core._typing_compat import LiteralQuantise
@@ -37,12 +38,11 @@ if t.TYPE_CHECKING:
from syrupy.types import SerializableData
from syrupy.types import SerializedData

from openllm._configuration import GenerationConfig
from openllm.client import BaseAsyncClient

class ResponseComparator(JSONSnapshotExtension):
  def serialize(self, data: SerializableData, *, exclude: PropertyFilter | None = None, matcher: PropertyMatcher | None = None,) -> SerializedData:
-    if openllm.utils.LazyType(ListAny).isinstance(data):
+    if LazyType(ListAny).isinstance(data):
      data = [d.unmarshaled for d in data]
    else:
      data = data.unmarshaled
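serialize() above unwraps attrs-style wrappers through their .unmarshaled view before the orjson dump at the top of the next hunk. A standalone sketch of that flow with a stand-in object; only the .unmarshaled attribute name and the orjson options come from this diff, and the Stub class is invented for illustration:

import orjson

class Stub:
  # stand-in for a response wrapper exposing an .unmarshaled dict view
  def __init__(self, payload: dict):
    self.unmarshaled = payload

def serialize_sketch(data):
  data = [d.unmarshaled for d in data] if isinstance(data, list) else data.unmarshaled
  return orjson.dumps(data, option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS).decode()

print(serialize_sketch([Stub({'text': 'hi'}), Stub({'text': 'there'})]))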
@@ -50,31 +50,28 @@ class ResponseComparator(JSONSnapshotExtension):
    return orjson.dumps(data, option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS).decode()

  def matches(self, *, serialized_data: SerializableData, snapshot_data: SerializableData) -> bool:
-    def convert_data(data: SerializableData) -> openllm.GenerateOutput | t.Sequence[openllm.GenerateOutput]:
+    def convert_data(data: SerializableData) -> openllm.GenerationOutput | t.Sequence[openllm.GenerationOutput]:
      try:
        data = orjson.loads(data)
      except orjson.JSONDecodeError as err:
        raise ValueError(f'Failed to decode JSON data: {data}') from err
-      if openllm.utils.LazyType(DictStrAny).isinstance(data):
-        return openllm.GenerateOutput(**data)
-      elif openllm.utils.LazyType(ListAny).isinstance(data):
-        return [openllm.GenerateOutput(**d) for d in data]
+      if LazyType(DictStrAny).isinstance(data):
+        return openllm.GenerationOutput(**data)
+      elif LazyType(ListAny).isinstance(data):
+        return [openllm.GenerationOutput(**d) for d in data]
      else:
        raise NotImplementedError(f'Data {data} has unsupported type.')

    serialized_data = convert_data(serialized_data)
    snapshot_data = convert_data(snapshot_data)

-    if openllm.utils.LazyType(ListAny).isinstance(serialized_data):
+    if LazyType(ListAny).isinstance(serialized_data):
      serialized_data = [serialized_data]
-    if openllm.utils.LazyType(ListAny).isinstance(snapshot_data):
+    if LazyType(ListAny).isinstance(snapshot_data):
      snapshot_data = [snapshot_data]

-    def eq_config(s: GenerationConfig, t: GenerationConfig) -> bool:
-      return s == t
-
-    def eq_output(s: openllm.GenerateOutput, t: openllm.GenerateOutput) -> bool:
-      return (len(s.responses) == len(t.responses) and all([_s == _t for _s, _t in zip(s.responses, t.responses)]) and eq_config(s.marshaled_config, t.marshaled_config))
+    def eq_output(s: openllm.GenerationOutput, t: openllm.GenerationOutput) -> bool:
+      return len(s.outputs) == len(t.outputs)

    return len(serialized_data) == len(snapshot_data) and all([eq_output(s, t) for s, t in zip(serialized_data, snapshot_data)])
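The rewritten matches() no longer compares generation configs or response text; it only checks that serialized and snapshot data contain the same number of GenerationOutput entries with matching output counts. A hedged, dependency-light sketch of that comparison using plain dicts instead of openllm.GenerationOutput:

import orjson

def matches_sketch(serialized: bytes | str, snapshot: bytes | str) -> bool:
  # Decode both payloads, normalise to lists, then compare only lengths,
  # mirroring the simplified eq_output above.
  a, b = orjson.loads(serialized), orjson.loads(snapshot)
  a = a if isinstance(a, list) else [a]
  b = b if isinstance(b, list) else [b]
  return len(a) == len(b) and all(len(x['outputs']) == len(y['outputs']) for x, y in zip(a, b))

assert matches_sketch('{"outputs": [1, 2]}', '{"outputs": ["a", "b"]}')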
@@ -168,7 +165,7 @@ def _container_handle(model: str, model_id: str, image_tag: str, deployment_mode
  with openllm.utils.reserve_free_port() as port, openllm.utils.reserve_free_port() as prom_port:
    pass
-  container_name = f'openllm-{model}-{normalise_model_name(model_id)}'.replace('-', '_')
+  container_name = f'openllm-{model}-{self(model_id)}'.replace('-', '_')
  client = docker.from_env()
  try:
    container = client.containers.get(container_name)
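The fixture above derives a container name from the model and model_id, then looks it up with the Docker SDK. A minimal sketch of that lookup-or-fallback pattern; the except branch is not visible in this hunk, so the NotFound handling and the example container name are assumptions:

import docker
from docker.errors import NotFound

client = docker.from_env()
container_name = 'openllm_flan_t5_google_flan_t5_small'  # illustrative value only
try:
  container = client.containers.get(container_name)
except NotFound:
  container = None  # the real fixture presumably starts a fresh container here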
@@ -24,7 +24,7 @@ actions_xfail = functools.partial(pytest.mark.xfail,
def test_general_build_with_internal_testing():
  bento_store = BentoMLContainer.bento_store.get()

-  llm = openllm.AutoLLM.for_model('flan-t5', model_id=HF_INTERNAL_T5_TESTING)
+  llm = openllm.LLM(model_id=HF_INTERNAL_T5_TESTING, serialisation='legacy')
  bento = openllm.build('flan-t5', model_id=HF_INTERNAL_T5_TESTING)

  assert llm.llm_type == bento.info.labels['_type']
@@ -36,7 +36,8 @@ def test_general_build_with_internal_testing():
@actions_xfail
def test_general_build_from_local(tmp_path_factory: pytest.TempPathFactory):
  local_path = tmp_path_factory.mktemp('local_t5')
-  llm = openllm.AutoLLM.for_model('flan-t5', model_id=HF_INTERNAL_T5_TESTING, ensure_available=True)
+  llm = openllm.LLM(model_id=HF_INTERNAL_T5_TESTING, serialisation='legacy')
+  llm.save_pretrained()

  if isinstance(llm.model, transformers.Pipeline):
    llm.model.save_pretrained(str(local_path))
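Both build tests above swap the model-specific openllm.AutoLLM.for_model constructor for the unified openllm.LLM class, which is the interface change this commit is named after. A hedged before/after sketch; every call is taken verbatim from the hunks above, and the HF_INTERNAL_T5_TESTING value shown here is an assumption since the real constant is defined elsewhere in the test module:

import openllm

HF_INTERNAL_T5_TESTING = 'hf-internal-testing/tiny-random-t5'  # assumed value for illustration

# Old interface, removed in this commit:
#   llm = openllm.AutoLLM.for_model('flan-t5', model_id=HF_INTERNAL_T5_TESTING)

# New unified interface, as used in the tests above:
llm = openllm.LLM(model_id=HF_INTERNAL_T5_TESTING, serialisation='legacy')
bento = openllm.build('flan-t5', model_id=HF_INTERNAL_T5_TESTING)
assert llm.llm_type == bento.info.labels['_type']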