add local run

This commit is contained in:
bojiang
2024-05-21 04:31:53 +08:00
parent 10ae5c64a9
commit 284c88acf7
5 changed files with 82 additions and 19 deletions

View File

@@ -1,7 +1,7 @@
import typer
from openllm_next.model import app as model_app
from openllm_next.repo import app as repo_app
from openllm_next.serve import serve as serve_serve
from openllm_next.serve import serve as local_serve, run as local_run
app = typer.Typer()
@@ -12,7 +12,12 @@ app.add_typer(model_app, name="model")
@app.command()
def serve(model: str):
serve_serve(model)
local_serve(model)
@app.command()
def run(model: str):
local_run(model)
if __name__ == "__main__":

View File

@@ -1,4 +1,4 @@
accelerator_details = {
ACCELERATOR_SPEC = {
"nvidia-gtx-1650": {"model": "GTX 1650", "memory_size": 4.0},
"nvidia-gtx-1060": {"model": "GTX 1060", "memory_size": 6.0},
"nvidia-gtx-1080-ti": {"model": "GTX 1080 Ti", "memory_size": 11.0},

View File

@@ -432,9 +432,6 @@ def serve(model: str, tag: str = "latest", force_rebuild: bool = False):
envs = {}
if len(bento_info.get("envs", [])) > 0:
for env in bento_info["envs"]:
if env["name"] == "CLLAMA_MODEL":
envs[env["name"]] = f"{model}:{tag}"
continue
if env["name"] in os.environ:
value = os.environ.get(env["name"])
questionary.print(f"Using environment value for {env['name']}")

View File

@@ -63,19 +63,31 @@ class BentoInfo(TypedDict):
bento_yaml: dict
def run_command(cmd, cwd=None, env=None, copy_env=True):
questionary.print("\n")
def run_command(cmd, cwd=None, env=None, copy_env=True, bg=False):
env = env or {}
if cwd:
questionary.print(f"$ cd {cwd}", style="bold")
if env:
for k, v in env.items():
questionary.print(f"$ export {k}={shlex.quote(v)}", style="bold")
if copy_env:
env = {**os.environ, **env}
questionary.print(f"$ {' '.join(cmd)}", style="bold")
merged_env = {**os.environ, **env}
else:
merged_env = env
if not bg:
questionary.print("\n")
if cwd:
questionary.print(f"$ cd {cwd}", style="bold")
if env:
for k, v in env.items():
questionary.print(f"$ export {k}={shlex.quote(v)}", style="bold")
questionary.print(f"$ {' '.join(cmd)}", style="bold")
try:
subprocess.run(cmd, cwd=cwd, env=env, check=True)
if bg:
return subprocess.Popen(
cmd,
cwd=cwd,
env=merged_env,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)
else:
subprocess.run(cmd, cwd=cwd, env=merged_env, check=True)
except subprocess.CalledProcessError:
questionary.print("Command failed", style=ERROR_STYLE)
return

View File

@@ -1,4 +1,5 @@
import typer
import asyncio
import questionary
from openllm_next.common import ERROR_STYLE, run_command
from openllm_next.model import _get_bento_info
@@ -6,7 +7,7 @@ from openllm_next.model import _get_bento_info
app = typer.Typer()
def _serve_model(model: str):
def _serve_model(model: str, bg: bool = False):
if ":" not in model:
model = f"{model}:latest"
bento_info = _get_bento_info(model)
@@ -15,12 +16,60 @@ def _serve_model(model: str):
return
cmd = ["bentoml", "serve", model]
env = {
"CLLAMA_MODEL": model,
"BENTOML_HOME": bento_info["model"]["repo"]["path"] + "/bentoml",
}
run_command(cmd, env=env)
return run_command(cmd, env=env, bg=bg)
@app.command()
def serve(model: str):
_serve_model(model)
async def _run_model(model: str, timeout: int = 600):
server_proc = _serve_model(model, bg=True)
assert server_proc is not None
import bentoml
try:
questionary.print("Model loading...", style="green")
for _ in range(timeout):
try:
client = bentoml.AsyncHTTPClient(
"http://localhost:3000", timeout=timeout
)
resp = await client.request("GET", "/readyz")
if resp.status_code == 200:
break
except bentoml.exceptions.BentoMLException:
await asyncio.sleep(1)
else:
questionary.print("Model failed to load", style="red")
return
questionary.print("Model is ready", style="green")
messages = []
while True:
try:
message = input("uesr: ")
messages.append(dict(role="user", content=message))
print("assistant: ", end="")
assistant_message = ""
async for text in client.chat(messages=messages): # type: ignore
assistant_message += text
print(text, end="")
messages.append(dict(role="assistant", content=assistant_message))
print()
except KeyboardInterrupt:
break
finally:
questionary.print("\nStopping model server...", style="green")
server_proc.terminate()
questionary.print("Stopped model server", style="green")
@app.command()
def run(model: str):
asyncio.run(_run_model(model))