fix(build): check for parity (#508)
@@ -15,7 +15,7 @@ from . import exceptions as exceptions, utils as utils
 from openllm_core._configuration import GenerationConfig as GenerationConfig, LLMConfig as LLMConfig, SamplingParams as SamplingParams
 from openllm_core._strategies import CascadingResourceStrategy as CascadingResourceStrategy, get_resource as get_resource
-from openllm_core._schema import GenerationInput as GenerationInput, GenerationOutput as GenerationOutput, HfAgentInput as HfAgentInput, MetadataOutput as MetadataOutput, unmarshal_vllm_outputs as unmarshal_vllm_outputs
+from openllm_core._schema import GenerateInput as GenerateInput, GenerateOutput as GenerateOutput, GenerationOutput as GenerationOutput, HfAgentInput as HfAgentInput, MetadataOutput as MetadataOutput
 from openllm_core.config import AutoConfig as AutoConfig, CONFIG_MAPPING as CONFIG_MAPPING, CONFIG_MAPPING_NAMES as CONFIG_MAPPING_NAMES, BaichuanConfig as BaichuanConfig, ChatGLMConfig as ChatGLMConfig, DollyV2Config as DollyV2Config, FalconConfig as FalconConfig, FlanT5Config as FlanT5Config, GPTNeoXConfig as GPTNeoXConfig, LlamaConfig as LlamaConfig, MPTConfig as MPTConfig, OPTConfig as OPTConfig, StableLMConfig as StableLMConfig, StarCoderConfig as StarCoderConfig
 
 if openllm_core.utils.DEBUG:
@@ -36,7 +36,7 @@ svc = bentoml.Service(name=f"llm-{llm_config['start_name']}-service", runners=[r
 _JsonInput = bentoml.io.JSON.from_sample({'prompt': '', 'llm_config': llm_config.model_dump(flatten=True), 'adapter_name': None})
 
 @svc.api(route='/v1/generate', input=_JsonInput, output=bentoml.io.JSON.from_sample({'responses': [], 'configuration': llm_config.model_dump(flatten=True)}))
-async def generate_v1(input_dict: dict[str, t.Any]) -> openllm.GenerationOutput:
+async def generate_v1(input_dict: dict[str, t.Any]) -> openllm.GenerateOutput:
   echo = input_dict.pop('echo', False)
   qa_inputs = openllm.GenerationInput.from_llm_config(llm_config)(**input_dict)
   config = qa_inputs.llm_config.model_dump()
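
For context, a minimal client sketch against the /v1/generate endpoint declared in this hunk. The JSON keys mirror the _JsonInput sample above; the host, port, prompt text, and empty llm_config overrides are placeholder assumptions, not part of the commit.

# Sketch only: assumes the service is reachable on localhost:3000 (BentoML's
# default HTTP port). Prompt and overrides are placeholders.
import requests

payload = {
  'prompt': 'What is the capital of France?',  # same key as the _JsonInput sample
  'llm_config': {},                            # per-request overrides merged into llm_config
  'adapter_name': None,                        # optional adapter, None as in the sample
}
resp = requests.post('http://localhost:3000/v1/generate', json=payload, timeout=60)
resp.raise_for_status()
data = resp.json()
print(data['responses'], data['configuration'])
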
@@ -46,7 +46,7 @@ async def generate_v1(input_dict: dict[str, t.Any]) -> openllm.GenerationOutput:
     if responses is None: raise ValueError("'responses' should not be None.")
   else:
     responses = await runner.generate.async_run(qa_inputs.prompt, adapter_name=qa_inputs.adapter_name, **config)
-  return openllm.GenerationOutput(responses=responses, configuration=config)
+  return openllm.GenerateOutput(responses=responses, configuration=config)
 
 @svc.api(route='/v1/generate_stream', input=_JsonInput, output=bentoml.io.Text(content_type='text/event-stream'))
 async def generate_stream_v1(input_dict: dict[str, t.Any]) -> t.AsyncGenerator[str, None]:
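
Similarly, a hedged sketch of consuming the /v1/generate_stream endpoint declared at the end of this hunk. Its output descriptor is bentoml.io.Text(content_type='text/event-stream'), so the response body can be read as streamed text; the URL and prompt are again placeholders.

# Sketch only: reads the streamed text body chunk by chunk with minimal error handling.
import requests

with requests.post('http://localhost:3000/v1/generate_stream',
                   json={'prompt': 'Tell me a joke.', 'llm_config': {}, 'adapter_name': None},
                   stream=True, timeout=60) as resp:
  resp.raise_for_status()
  for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
    if chunk:
      print(chunk, end='', flush=True)
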
@@ -522,6 +522,8 @@ def build_command(ctx: click.Context, /, model_name: str, model_id: str | None,
     quantize=env['quantize_value'],
     serialisation=_serialisation,
     **attrs)
+  # FIX: This is a patch for _service_vars injection
+  if 'OPENLLM_MODEL_ID' not in os.environ: os.environ['OPENLLM_MODEL_ID'] = llm.model_id
 
   labels = dict(llm.identifying_params)
   labels.update({'_type': llm.llm_type, '_framework': env['backend_value']})
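
The two added lines only inject OPENLLM_MODEL_ID when the variable is not already present, so a value exported by the user wins over the build-time patch. A standalone sketch of the same guard pattern (the model id value here is a placeholder, not taken from the commit):

import os

model_id = 'facebook/opt-125m'  # placeholder value for illustration
# Mirror of the added guard: only set the variable when it is not already set.
if 'OPENLLM_MODEL_ID' not in os.environ:
  os.environ['OPENLLM_MODEL_ID'] = model_id
# Equivalent one-liner: os.environ.setdefault('OPENLLM_MODEL_ID', model_id)
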