mirror of
https://github.com/bentoml/OpenLLM.git
synced 2026-06-12 02:20:32 -04:00
add local run
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import typer
|
||||
from openllm_next.model import app as model_app
|
||||
from openllm_next.repo import app as repo_app
|
||||
from openllm_next.serve import serve as serve_serve
|
||||
from openllm_next.serve import serve as local_serve, run as local_run
|
||||
|
||||
|
||||
app = typer.Typer()
|
||||
@@ -12,7 +12,12 @@ app.add_typer(model_app, name="model")
|
||||
|
||||
@app.command()
def serve(model: str):
    # CLI entry point: serve `model` locally in the foreground.
    # `serve_serve` and `local_serve` are two aliases of the same
    # `openllm_next.serve.serve` function; calling both would start the
    # server twice, so only the single call is kept.
    local_serve(model)
|
||||
|
||||
|
||||
@app.command()
def run(model: str) -> None:
    # CLI entry point: hand the model name straight to the local `run`
    # implementation imported from `openllm_next.serve`.
    local_run(model)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
accelerator_details = {
|
||||
ACCELERATOR_SPEC = {
|
||||
"nvidia-gtx-1650": {"model": "GTX 1650", "memory_size": 4.0},
|
||||
"nvidia-gtx-1060": {"model": "GTX 1060", "memory_size": 6.0},
|
||||
"nvidia-gtx-1080-ti": {"model": "GTX 1080 Ti", "memory_size": 11.0},
|
||||
@@ -432,9 +432,6 @@ def serve(model: str, tag: str = "latest", force_rebuild: bool = False):
|
||||
envs = {}
|
||||
if len(bento_info.get("envs", [])) > 0:
|
||||
for env in bento_info["envs"]:
|
||||
if env["name"] == "CLLAMA_MODEL":
|
||||
envs[env["name"]] = f"{model}:{tag}"
|
||||
continue
|
||||
if env["name"] in os.environ:
|
||||
value = os.environ.get(env["name"])
|
||||
questionary.print(f"Using environment value for {env['name']}")
|
||||
|
||||
@@ -63,19 +63,31 @@ class BentoInfo(TypedDict):
|
||||
bento_yaml: dict
|
||||
|
||||
|
||||
def run_command(cmd, cwd=None, env=None, copy_env=True, bg=False):
    """Run *cmd* as a subprocess, in the foreground or in the background.

    Args:
        cmd: Command and arguments as a list of strings.
        cwd: Working directory for the child process, or None to inherit.
        env: Extra environment variables for the child process.
        copy_env: When true, the child gets ``os.environ`` overlaid with
            *env*; otherwise it gets only *env*.
        bg: When true, start the process without waiting for it, discard its
            stdout/stderr, and skip the shell-style echo of the command.

    Returns:
        The ``subprocess.Popen`` handle when *bg* is true and the process was
        started; ``None`` in every other case (foreground run, or failure).
    """
    env = env or {}
    merged_env = {**os.environ, **env} if copy_env else env

    if not bg:
        # Echo what is about to run in a copy-pastable shell form. Only the
        # caller-supplied variables are shown, not the inherited environment.
        questionary.print("\n")
        if cwd:
            questionary.print(f"$ cd {cwd}", style="bold")
        for k, v in env.items():
            questionary.print(f"$ export {k}={shlex.quote(v)}", style="bold")
        questionary.print(f"$ {' '.join(cmd)}", style="bold")

    try:
        if bg:
            return subprocess.Popen(
                cmd,
                cwd=cwd,
                env=merged_env,
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            )
        subprocess.run(cmd, cwd=cwd, env=merged_env, check=True)
    except subprocess.CalledProcessError:
        questionary.print("Command failed", style=ERROR_STYLE)
    except OSError as exc:
        # The process could not be started at all (e.g. executable missing);
        # report it the same way as a non-zero exit instead of a traceback.
        questionary.print(f"Command failed: {exc}", style=ERROR_STYLE)
    return None
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import typer
|
||||
import asyncio
|
||||
import questionary
|
||||
from openllm_next.common import ERROR_STYLE, run_command
|
||||
from openllm_next.model import _get_bento_info
|
||||
@@ -6,7 +7,7 @@ from openllm_next.model import _get_bento_info
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
def _serve_model(model: str):
|
||||
def _serve_model(model: str, bg: bool = False):
|
||||
if ":" not in model:
|
||||
model = f"{model}:latest"
|
||||
bento_info = _get_bento_info(model)
|
||||
@@ -15,12 +16,60 @@ def _serve_model(model: str):
|
||||
return
|
||||
cmd = ["bentoml", "serve", model]
|
||||
env = {
|
||||
"CLLAMA_MODEL": model,
|
||||
"BENTOML_HOME": bento_info["model"]["repo"]["path"] + "/bentoml",
|
||||
}
|
||||
run_command(cmd, env=env)
|
||||
return run_command(cmd, env=env, bg=bg)
|
||||
|
||||
|
||||
@app.command()
def serve(model: str):
    # Foreground serving: the background flag is spelled out at its default
    # so the contrast with the background launch used by `run` is visible.
    _serve_model(model, bg=False)
|
||||
|
||||
|
||||
async def _run_model(model: str, timeout: int = 600):
    """Start a background model server and run an interactive chat against it.

    The server is launched through ``_serve_model(model, bg=True)`` and polled
    on ``http://localhost:3000/readyz`` about once per second, for at most
    *timeout* attempts. Once ready, user input is read from stdin and the
    streamed reply is printed until Ctrl-C or end-of-input. The server process
    is always stopped on the way out.

    Args:
        model: Model name, passed through to ``_serve_model``.
        timeout: Maximum number of readiness polls; also passed to the HTTP
            client as its ``timeout`` argument.
    """
    server_proc = _serve_model(model, bg=True)
    if server_proc is None:
        # No process was started. An explicit check is used instead of
        # `assert`, which is stripped when Python runs with -O.
        questionary.print("Model failed to load", style="red")
        return

    import subprocess

    import bentoml

    try:
        questionary.print("Model loading...", style="green")
        client = None
        for _ in range(timeout):
            if server_proc.poll() is not None:
                # The server exited on its own; further polling is pointless.
                break
            try:
                candidate = bentoml.AsyncHTTPClient(
                    "http://localhost:3000", timeout=timeout
                )
                resp = await candidate.request("GET", "/readyz")
                if resp.status_code == 200:
                    client = candidate
                    break
            except bentoml.exceptions.BentoMLException:
                pass
            # Pace every unsuccessful attempt, including non-200 responses,
            # so the loop neither spins nor exhausts its attempts instantly.
            await asyncio.sleep(1)

        if client is None:
            questionary.print("Model failed to load", style="red")
            return

        questionary.print("Model is ready", style="green")
        messages = []
        while True:
            try:
                message = input("user: ")
                messages.append(dict(role="user", content=message))
                # flush=True: the reply is streamed without newlines, so it
                # would otherwise sit in the stdout buffer until the final
                # print().
                print("assistant: ", end="", flush=True)
                assistant_message = ""
                async for text in client.chat(messages=messages):  # type: ignore
                    assistant_message += text
                    print(text, end="", flush=True)
                messages.append(dict(role="assistant", content=assistant_message))
                print()
            except (KeyboardInterrupt, EOFError):
                # Ctrl-C or end of input (Ctrl-D / closed stdin) ends the chat.
                break
    finally:
        questionary.print("\nStopping model server...", style="green")
        server_proc.terminate()
        try:
            # Reap the child so "Stopped" is only reported once it has exited.
            server_proc.wait(timeout=10)
        except subprocess.TimeoutExpired:
            server_proc.kill()
            server_proc.wait()
        questionary.print("Stopped model server", style="green")
|
||||
|
||||
|
||||
@app.command()
def run(model: str):
    # Drive the async chat session to completion on a fresh event loop.
    chat_session = _run_model(model)
    asyncio.run(chat_session)
|
||||
|
||||
Reference in New Issue
Block a user