Mirror of https://github.com/bentoml/OpenLLM.git (synced 2026-01-16 19:37:49 -05:00)
chore: update local script and update service
Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
@@ -240,7 +240,7 @@ def start_command(
    bentomodel = bentoml.models.get(model_id.lower())
    model_id = bentomodel.path
  except (ValueError, bentoml.exceptions.NotFound):
    pass
    bentomodel = None
  config = transformers.AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code)
  for arch in config.architectures:
    if arch in openllm_core.AutoConfig._architecture_mappings:
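The change in this hunk makes the store-lookup fallback explicit: rather than pass, bentomodel is recorded as None when the id is not in the local BentoML model store, so later code can tell a store-backed model from a plain HuggingFace id. Below is a self-contained sketch of the resolution flow this hunk belongs to; the body of the "if arch in ..." branch (taking model_name from the mapping) is inferred from the later AutoConfig.for_model(model_name) call and is not shown in this diff.

import bentoml, transformers, openllm_core

def resolve_model(model_id: str, trust_remote_code: bool = False):
  # Prefer a model that already lives in the local BentoML store.
  try:
    bentomodel = bentoml.models.get(model_id.lower())
    model_id = bentomodel.path
  except (ValueError, bentoml.exceptions.NotFound):
    bentomodel = None  # fall back to treating model_id as a remote HuggingFace id
  # Map the transformers architecture onto an OpenLLM config name.
  config = transformers.AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code)
  for arch in config.architectures:
    if arch in openllm_core.AutoConfig._architecture_mappings:
      model_name = openllm_core.AutoConfig._architecture_mappings[arch]  # inferred step, not in the diff
      break
  else:
    raise RuntimeError(f'Failed to determine config class for {model_id}')
  return bentomodel, model_id, model_name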
@@ -250,6 +250,8 @@ def start_command(
    raise RuntimeError(f'Failed to determine config class for {model_id}')

  llm_config = openllm_core.AutoConfig.for_model(model_name).model_construct_env()
  if serialisation is None:
    serialisation = llm_config['serialisation']

  # TODO: support LoRA adapters
  os.environ.update({
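The serialisation flag now defaults from the resolved model config when it is not passed on the command line. A tiny sketch of that rule follows; the 'safetensors' and 'legacy' literals are example values only, not taken from this diff.

# Sketch of the default-resolution rule added above (names reused from the hunk).
def resolve_serialisation(cli_value, llm_config):
  # An explicit CLI value wins; otherwise fall back to the model's own configured default.
  return cli_value if cli_value is not None else llm_config['serialisation']

assert resolve_serialisation('legacy', {'serialisation': 'safetensors'}) == 'legacy'
assert resolve_serialisation(None, {'serialisation': 'safetensors'}) == 'safetensors'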
@@ -11,7 +11,6 @@ from openllm_core.utils import (
)
from openllm_core._typing_compat import LiteralQuantise, LiteralSerialisation, LiteralDtype
from openllm_core._schemas import GenerationOutput, GenerationInput
from _bentoml_sdk.service import ServiceConfig

Dtype = t.Union[LiteralDtype, t.Literal['auto', 'half', 'float']]
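The Dtype alias widens the model-level LiteralDtype with the CLI-friendly spellings 'auto', 'half' and 'float'. Below is a small sketch of validating user input against such a union; the members assumed for LiteralDtype here ('float16', 'bfloat16', 'float32') are illustrative and not taken from this diff.

import typing as t

LiteralDtype = t.Literal['float16', 'bfloat16', 'float32']  # assumed members, for illustration only
Dtype = t.Union[LiteralDtype, t.Literal['auto', 'half', 'float']]

def check_dtype(value: str) -> Dtype:
  # Accept either a concrete torch-style dtype or one of the convenience spellings.
  allowed = set(t.get_args(LiteralDtype)) | {'auto', 'half', 'float'}
  if value not in allowed:
    raise ValueError(f'unsupported dtype: {value!r} (expected one of {sorted(allowed)})')
  return value  # type: ignore[return-value]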
@@ -39,7 +38,6 @@ class LLM:
  quantise: t.Optional[LiteralQuantise] = attr.field(default=None)
  trust_remote_code: bool = attr.field(default=False)
  engine_args: t.Dict[str, t.Any] = attr.field(factory=dict, validator=check_engine_args)
  service_config: t.Optional[ServiceConfig] = attr.field(factory=dict)

  _path: str = attr.field(
    init=False,
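The LLM attributes above follow a common attrs idiom: scalar options with plain defaults, plus a dict field built per instance via factory= and guarded by a validator. A standalone sketch of that idiom follows; the check_engine_args body here is a stand-in, since the real validator is not shown in this diff.

import typing as t
import attr

def check_engine_args(_inst, _attr, value: t.Dict[str, t.Any]) -> None:
  # Stand-in validator: the hunk only shows that one is attached, not what it checks.
  if not isinstance(value, dict):
    raise TypeError('engine_args must be a dict')

@attr.define
class EngineSettings:
  quantise: t.Optional[str] = attr.field(default=None)
  trust_remote_code: bool = attr.field(default=False)
  # factory=dict gives each instance its own dict; the validator runs at init time.
  engine_args: t.Dict[str, t.Any] = attr.field(factory=dict, validator=check_engine_args)

settings = EngineSettings(engine_args={'max_model_len': 4096})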
@@ -106,7 +104,6 @@ class LLM:
    quantise: LiteralQuantise | None = None,
    trust_remote_code: bool = False,
    llm_config: openllm_core.LLMConfig | None = None,
    service_config: ServiceConfig | None = None,
    **engine_args: t.Any,
  ) -> LLM:
    return cls(
@@ -119,7 +116,6 @@ class LLM:
      dtype=dtype,
      engine_args=engine_args,
      trust_remote_code=trust_remote_code,
      service_config=service_config,
    )

  @property
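The two hunks above show the tail of a constructor-style classmethod: explicit keyword options followed by a **engine_args catch-all that is forwarded into cls(...). A minimal sketch of the same pattern follows; the classmethod name from_model and the example model id are hypothetical, as the diff does not show them.

import typing as t
import attr

@attr.define
class TinyLLM:
  model_id: str
  dtype: str = 'auto'
  trust_remote_code: bool = False
  engine_args: t.Dict[str, t.Any] = attr.field(factory=dict)

  @classmethod
  def from_model(cls, model_id: str, *, dtype: str = 'auto', trust_remote_code: bool = False, **engine_args: t.Any) -> 'TinyLLM':
    # Anything not recognised as a first-class option is collected and handed to the engine.
    return cls(model_id=model_id, dtype=dtype, trust_remote_code=trust_remote_code, engine_args=engine_args)

llm = TinyLLM.from_model('microsoft/Phi-3-mini-4k-instruct', max_model_len=2048, gpu_memory_utilization=0.9)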
@@ -5,7 +5,7 @@ from starlette.requests import Request
from starlette.responses import JSONResponse, StreamingResponse
import openllm, bentoml, logging, openllm_core as core
import _service_vars as svars, typing as t
from openllm_core._typing_compat import Annotated, Unpack
from openllm_core._typing_compat import Annotated
from openllm_core._schemas import MessageParam, MessagesConverterInput
from openllm_core.protocol.openai import ModelCard, ModelList, ChatCompletionRequest
from _openllm_tiny._helpers import OpenAI, Error
@@ -43,19 +43,48 @@ class LLMService:
      quantise=svars.quantise,
      llm_config=llm_config,
      trust_remote_code=svars.trust_remote_code,
      services_config=svars.services_config,
      max_model_len=svars.max_model_len,
      gpu_memory_utilization=svars.gpu_memory_utilization,
    )
    self.openai = OpenAI(self.llm)

  @core.utils.api(route='/v1/generate')
  async def generate_v1(self, **parameters: Unpack[core.GenerationInputDict]) -> core.GenerationOutput:
    return await self.llm.generate(**GenerationInput.from_dict(parameters).model_dump())
  async def generate_v1(
    self,
    llm_config: t.Dict[str, t.Any],
    prompt: str = 'What is the meaning of life?',
    prompt_token_ids: t.Optional[t.List[int]] = None,
    stop: t.Optional[t.List[str]] = None,
    stop_token_ids: t.Optional[t.List[int]] = None,
    request_id: t.Optional[str] = None,
  ) -> core.GenerationOutput:
    return await self.llm.generate(
      prompt=prompt,
      prompt_token_ids=prompt_token_ids,
      llm_config=llm_config,
      stop=stop,
      stop_token_ids=stop_token_ids,
      request_id=request_id,
    )

  @core.utils.api(route='/v1/generate_stream')
  async def generate_stream_v1(self, **parameters: Unpack[core.GenerationInputDict]) -> t.AsyncGenerator[str, None]:
    async for generated in self.llm.generate_iterator(**GenerationInput.from_dict(parameters).model_dump()):
  async def generate_stream_v1(
    self,
    llm_config: t.Dict[str, t.Any],
    prompt: str = 'What is the meaning of life?',
    prompt_token_ids: t.Optional[t.List[int]] = None,
    stop: t.Optional[t.List[str]] = None,
    stop_token_ids: t.Optional[t.List[int]] = None,
    request_id: t.Optional[str] = None,
  ) -> t.AsyncGenerator[str, None]:
    async for generated in self.llm.generate_iterator(
      prompt=prompt,
      prompt_token_ids=prompt_token_ids,
      llm_config=llm_config,
      stop=stop,
      stop_token_ids=stop_token_ids,
      request_id=request_id,
    ):
      yield f'data: {generated.model_dump_json()}\n\n'
    yield 'data: [DONE]\n\n'
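In this hunk the one-line signatures taking **parameters: Unpack[core.GenerationInputDict] are the pre-change versions of generate_v1 and generate_stream_v1; the commit replaces them with explicit parameters so the HTTP schema of both routes is spelled out directly. Below is a hedged client sketch against those routes; the base URL and the empty llm_config payload are assumptions, not guaranteed by this diff.

import json
import requests

base = 'http://localhost:3000'  # assumed default BentoML serving address
payload = {
  'prompt': 'What is the meaning of life?',
  'llm_config': {},            # required field in the new signature; {} assumed acceptable
  'stop': ['\n\n'],            # optional, mirrors the stop parameter
}

# One-shot generation.
resp = requests.post(f'{base}/v1/generate', json=payload, timeout=300)
resp.raise_for_status()
print(resp.json())

# Streaming generation: the route yields SSE-style lines and finishes with 'data: [DONE]'.
with requests.post(f'{base}/v1/generate_stream', json=payload, stream=True, timeout=300) as stream:
  for line in stream.iter_lines(decode_unicode=True):
    if not line or not line.startswith('data: '):
      continue
    data = line[len('data: '):]
    if data == '[DONE]':
      break
    print(json.loads(data))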