chore: update local script and update service

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
Author: Aaron Pham
Date: 2024-03-15 20:29:49 +00:00
parent 4229a4de72
commit 824ff68818
5 changed files with 48 additions and 16 deletions

@@ -240,7 +240,7 @@ def start_command(
     bentomodel = bentoml.models.get(model_id.lower())
     model_id = bentomodel.path
   except (ValueError, bentoml.exceptions.NotFound):
-    pass
+    bentomodel = None
   config = transformers.AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code)
   for arch in config.architectures:
     if arch in openllm_core.AutoConfig._architecture_mappings:
@@ -250,6 +250,8 @@ def start_command(
     raise RuntimeError(f'Failed to determine config class for {model_id}')
   llm_config = openllm_core.AutoConfig.for_model(model_name).model_construct_env()
+  if serialisation is None:
+    serialisation = llm_config['serialisation']
   # TODO: support LoRA adapters
   os.environ.update({
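
Note on the hunks above: binding bentomodel = None in the except branch keeps the name defined when model_id is not found in the local BentoML model store, so later code can test it instead of relying on a possibly stale or unbound variable, and serialisation appears to fall back to the config default when the flag is not given. A minimal standalone sketch of the same fallback pattern; the resolve_model helper and the function framing are illustrative, not part of this commit:

import typing as t

import bentoml

def resolve_model(model_id: str) -> t.Tuple[str, t.Optional[bentoml.Model]]:
  # Prefer an entry from the local BentoML model store; otherwise treat
  # `model_id` as a remote reference and signal the miss with None.
  try:
    bentomodel = bentoml.models.get(model_id.lower())
    return bentomodel.path, bentomodel
  except (ValueError, bentoml.exceptions.NotFound):
    return model_id, None  # None instead of `pass`: callers can check for the miss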

@@ -11,7 +11,6 @@ from openllm_core.utils import (
 )
 from openllm_core._typing_compat import LiteralQuantise, LiteralSerialisation, LiteralDtype
 from openllm_core._schemas import GenerationOutput, GenerationInput
-from _bentoml_sdk.service import ServiceConfig
 Dtype = t.Union[LiteralDtype, t.Literal['auto', 'half', 'float']]
@@ -39,7 +38,6 @@ class LLM:
   quantise: t.Optional[LiteralQuantise] = attr.field(default=None)
   trust_remote_code: bool = attr.field(default=False)
   engine_args: t.Dict[str, t.Any] = attr.field(factory=dict, validator=check_engine_args)
-  service_config: t.Optional[ServiceConfig] = attr.field(factory=dict)
   _path: str = attr.field(
     init=False,
@@ -106,7 +104,6 @@ class LLM:
     quantise: LiteralQuantise | None = None,
     trust_remote_code: bool = False,
     llm_config: openllm_core.LLMConfig | None = None,
-    service_config: ServiceConfig | None = None,
     **engine_args: t.Any,
   ) -> LLM:
     return cls(
@@ -119,7 +116,6 @@
       dtype=dtype,
       engine_args=engine_args,
       trust_remote_code=trust_remote_code,
-      service_config=service_config,
     )

   @property
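
The hunks above drop ServiceConfig from the LLM attrs class entirely: the import, the attrs field, the from_model keyword, and the pass-through to the constructor all go away, so service-level configuration is no longer threaded through the LLM object. A hedged usage sketch of the trimmed from_model call; the exported name openllm.LLM, the positional model id, and the argument values are assumptions for illustration:

import openllm

llm = openllm.LLM.from_model(
  'microsoft/phi-2',  # illustrative model id, resolved the same way as before
  trust_remote_code=False,
  # keywords outside the explicit signature are collected into `engine_args`
  # and handed to the inference backend
  max_model_len=4096,
  gpu_memory_utilization=0.9,
)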

@@ -5,7 +5,7 @@ from starlette.requests import Request
 from starlette.responses import JSONResponse, StreamingResponse
 import openllm, bentoml, logging, openllm_core as core
 import _service_vars as svars, typing as t
-from openllm_core._typing_compat import Annotated, Unpack
+from openllm_core._typing_compat import Annotated
 from openllm_core._schemas import MessageParam, MessagesConverterInput
 from openllm_core.protocol.openai import ModelCard, ModelList, ChatCompletionRequest
 from _openllm_tiny._helpers import OpenAI, Error
@@ -43,19 +43,48 @@ class LLMService:
       quantise=svars.quantise,
       llm_config=llm_config,
       trust_remote_code=svars.trust_remote_code,
-      services_config=svars.services_config,
       max_model_len=svars.max_model_len,
       gpu_memory_utilization=svars.gpu_memory_utilization,
     )
     self.openai = OpenAI(self.llm)

   @core.utils.api(route='/v1/generate')
-  async def generate_v1(self, **parameters: Unpack[core.GenerationInputDict]) -> core.GenerationOutput:
-    return await self.llm.generate(**GenerationInput.from_dict(parameters).model_dump())
+  async def generate_v1(
+    self,
+    llm_config: t.Dict[str, t.Any],
+    prompt: str = 'What is the meaning of life?',
+    prompt_token_ids: t.Optional[t.List[int]] = None,
+    stop: t.Optional[t.List[str]] = None,
+    stop_token_ids: t.Optional[t.List[int]] = None,
+    request_id: t.Optional[str] = None,
+  ) -> core.GenerationOutput:
+    return await self.llm.generate(
+      prompt=prompt,
+      prompt_token_ids=prompt_token_ids,
+      llm_config=llm_config,
+      stop=stop,
+      stop_token_ids=stop_token_ids,
+      request_id=request_id,
+    )

   @core.utils.api(route='/v1/generate_stream')
-  async def generate_stream_v1(self, **parameters: Unpack[core.GenerationInputDict]) -> t.AsyncGenerator[str, None]:
-    async for generated in self.llm.generate_iterator(**GenerationInput.from_dict(parameters).model_dump()):
+  async def generate_stream_v1(
+    self,
+    llm_config: t.Dict[str, t.Any],
+    prompt: str = 'What is the meaning of life?',
+    prompt_token_ids: t.Optional[t.List[int]] = None,
+    stop: t.Optional[t.List[str]] = None,
+    stop_token_ids: t.Optional[t.List[int]] = None,
+    request_id: t.Optional[str] = None,
+  ) -> t.AsyncGenerator[str, None]:
+    async for generated in self.llm.generate_iterator(
+      prompt=prompt,
+      prompt_token_ids=prompt_token_ids,
+      llm_config=llm_config,
+      stop=stop,
+      stop_token_ids=stop_token_ids,
+      request_id=request_id,
+    ):
       yield f'data: {generated.model_dump_json()}\n\n'
     yield 'data: [DONE]\n\n'
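
With Unpack[core.GenerationInputDict] gone, /v1/generate and /v1/generate_stream now take their parameters explicitly, so the request body is a flat JSON object whose keys mirror the signatures above, and the streaming route keeps emitting data: ... server-sent events terminated by data: [DONE]. A hedged client sketch; the host/port and the llm_config contents are assumptions, only the field names follow the new signatures:

import json

import httpx

payload = {
  'prompt': 'What is the meaning of life?',
  'llm_config': {'max_new_tokens': 128},  # assumed key; depends on the model's LLMConfig
  'stop': ['\n\n'],
}

# one-shot generation
resp = httpx.post('http://localhost:3000/v1/generate', json=payload, timeout=None)
resp.raise_for_status()
print(resp.json())

# streaming: parse the SSE frames emitted by generate_stream_v1
with httpx.stream('POST', 'http://localhost:3000/v1/generate_stream', json=payload, timeout=None) as r:
  for line in r.iter_lines():
    if line.startswith('data: ') and line.strip() != 'data: [DONE]':
      print(json.loads(line[len('data: '):]))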