fix(build): check for parity (#508)

This commit is contained in:
Aaron Pham
2023-10-16 17:33:47 -04:00
committed by GitHub
parent cb4b5acf63
commit d59a8860df
7 changed files with 41 additions and 18 deletions

View File

@@ -15,7 +15,7 @@ from . import exceptions as exceptions, utils as utils
from openllm_core._configuration import GenerationConfig as GenerationConfig, LLMConfig as LLMConfig, SamplingParams as SamplingParams
from openllm_core._strategies import CascadingResourceStrategy as CascadingResourceStrategy, get_resource as get_resource
from openllm_core._schema import GenerationInput as GenerationInput, GenerationOutput as GenerationOutput, HfAgentInput as HfAgentInput, MetadataOutput as MetadataOutput, unmarshal_vllm_outputs as unmarshal_vllm_outputs
from openllm_core._schema import GenerateInput as GenerateInput, GenerateOutput as GenerateOutput, GenerationOutput as GenerationOutput, HfAgentInput as HfAgentInput, MetadataOutput as MetadataOutput
from openllm_core.config import AutoConfig as AutoConfig, CONFIG_MAPPING as CONFIG_MAPPING, CONFIG_MAPPING_NAMES as CONFIG_MAPPING_NAMES, BaichuanConfig as BaichuanConfig, ChatGLMConfig as ChatGLMConfig, DollyV2Config as DollyV2Config, FalconConfig as FalconConfig, FlanT5Config as FlanT5Config, GPTNeoXConfig as GPTNeoXConfig, LlamaConfig as LlamaConfig, MPTConfig as MPTConfig, OPTConfig as OPTConfig, StableLMConfig as StableLMConfig, StarCoderConfig as StarCoderConfig
if openllm_core.utils.DEBUG:

View File

@@ -36,7 +36,7 @@ svc = bentoml.Service(name=f"llm-{llm_config['start_name']}-service", runners=[r
_JsonInput = bentoml.io.JSON.from_sample({'prompt': '', 'llm_config': llm_config.model_dump(flatten=True), 'adapter_name': None})
@svc.api(route='/v1/generate', input=_JsonInput, output=bentoml.io.JSON.from_sample({'responses': [], 'configuration': llm_config.model_dump(flatten=True)}))
async def generate_v1(input_dict: dict[str, t.Any]) -> openllm.GenerationOutput:
async def generate_v1(input_dict: dict[str, t.Any]) -> openllm.GenerateOutput:
echo = input_dict.pop('echo', False)
qa_inputs = openllm.GenerationInput.from_llm_config(llm_config)(**input_dict)
config = qa_inputs.llm_config.model_dump()
@@ -46,7 +46,7 @@ async def generate_v1(input_dict: dict[str, t.Any]) -> openllm.GenerationOutput:
if responses is None: raise ValueError("'responses' should not be None.")
else:
responses = await runner.generate.async_run(qa_inputs.prompt, adapter_name=qa_inputs.adapter_name, **config)
return openllm.GenerationOutput(responses=responses, configuration=config)
return openllm.GenerateOutput(responses=responses, configuration=config)
@svc.api(route='/v1/generate_stream', input=_JsonInput, output=bentoml.io.Text(content_type='text/event-stream'))
async def generate_stream_v1(input_dict: dict[str, t.Any]) -> t.AsyncGenerator[str, None]:

View File

@@ -522,6 +522,8 @@ def build_command(ctx: click.Context, /, model_name: str, model_id: str | None,
quantize=env['quantize_value'],
serialisation=_serialisation,
**attrs)
# FIX: This is a patch for _service_vars injection
if 'OPENLLM_MODEL_ID' not in os.environ: os.environ['OPENLLM_MODEL_ID'] = llm.model_id
labels = dict(llm.identifying_params)
labels.update({'_type': llm.llm_type, '_framework': env['backend_value']})