diff --git a/openllm_next/__main__.py b/openllm_next/__main__.py
index 70e13f4d..9da5c740 100644
--- a/openllm_next/__main__.py
+++ b/openllm_next/__main__.py
@@ -1,7 +1,7 @@
 import typer
 from openllm_next.model import app as model_app
 from openllm_next.repo import app as repo_app
-from openllm_next.serve import serve as serve_serve
+from openllm_next.serve import serve as local_serve, run as local_run
 
 app = typer.Typer()
 
@@ -12,7 +12,12 @@ app.add_typer(model_app, name="model")
 
 @app.command()
 def serve(model: str):
-    serve_serve(model)
+    local_serve(model)
+
+
+@app.command()
+def run(model: str):
+    local_run(model)
 
 
 if __name__ == "__main__":
diff --git a/openllm_next/spec.py b/openllm_next/accelerator_spec.py
similarity index 98%
rename from openllm_next/spec.py
rename to openllm_next/accelerator_spec.py
index 0023a2d0..798b4c31 100644
--- a/openllm_next/spec.py
+++ b/openllm_next/accelerator_spec.py
@@ -1,4 +1,4 @@
-accelerator_details = {
+ACCELERATOR_SPEC = {
     "nvidia-gtx-1650": {"model": "GTX 1650", "memory_size": 4.0},
     "nvidia-gtx-1060": {"model": "GTX 1060", "memory_size": 6.0},
     "nvidia-gtx-1080-ti": {"model": "GTX 1080 Ti", "memory_size": 11.0},
diff --git a/openllm_next/aws.py b/openllm_next/aws.py
index 35c3c887..7df1c93a 100644
--- a/openllm_next/aws.py
+++ b/openllm_next/aws.py
@@ -432,9 +432,6 @@ def serve(model: str, tag: str = "latest", force_rebuild: bool = False):
     envs = {}
     if len(bento_info.get("envs", [])) > 0:
         for env in bento_info["envs"]:
-            if env["name"] == "CLLAMA_MODEL":
-                envs[env["name"]] = f"{model}:{tag}"
-                continue
             if env["name"] in os.environ:
                 value = os.environ.get(env["name"])
                 questionary.print(f"Using environment value for {env['name']}")
diff --git a/openllm_next/common.py b/openllm_next/common.py
index 5241362c..10eec89f 100644
--- a/openllm_next/common.py
+++ b/openllm_next/common.py
@@ -63,19 +63,31 @@ class BentoInfo(TypedDict):
     bento_yaml: dict
 
 
-def run_command(cmd, cwd=None, env=None, copy_env=True):
-    questionary.print("\n")
+def run_command(cmd, cwd=None, env=None, copy_env=True, bg=False):
     env = env or {}
-    if cwd:
-        questionary.print(f"$ cd {cwd}", style="bold")
-    if env:
-        for k, v in env.items():
-            questionary.print(f"$ export {k}={shlex.quote(v)}", style="bold")
     if copy_env:
-        env = {**os.environ, **env}
-    questionary.print(f"$ {' '.join(cmd)}", style="bold")
+        merged_env = {**os.environ, **env}
+    else:
+        merged_env = env
+    if not bg:
+        questionary.print("\n")
+        if cwd:
+            questionary.print(f"$ cd {cwd}", style="bold")
+        if env:
+            for k, v in env.items():
+                questionary.print(f"$ export {k}={shlex.quote(v)}", style="bold")
+        questionary.print(f"$ {' '.join(cmd)}", style="bold")
     try:
-        subprocess.run(cmd, cwd=cwd, env=env, check=True)
+        if bg:
+            return subprocess.Popen(
+                cmd,
+                cwd=cwd,
+                env=merged_env,
+                stdout=subprocess.DEVNULL,
+                stderr=subprocess.DEVNULL,
+            )
+        else:
+            subprocess.run(cmd, cwd=cwd, env=merged_env, check=True)
     except subprocess.CalledProcessError:
         questionary.print("Command failed", style=ERROR_STYLE)
         return
diff --git a/openllm_next/serve.py b/openllm_next/serve.py
index 36c51d35..9a5384d8 100644
--- a/openllm_next/serve.py
+++ b/openllm_next/serve.py
@@ -1,4 +1,5 @@
 import typer
+import asyncio
 import questionary
 from openllm_next.common import ERROR_STYLE, run_command
 from openllm_next.model import _get_bento_info
@@ -6,7 +7,7 @@ from openllm_next.model import _get_bento_info
 app = typer.Typer()
 
 
-def _serve_model(model: str):
+def _serve_model(model: str, bg: bool = False):
     if ":" not in model:
         model = f"{model}:latest"
     bento_info = _get_bento_info(model)
@@ -15,12 +16,60 @@ def _serve_model(model: str):
         return
     cmd = ["bentoml", "serve", model]
     env = {
-        "CLLAMA_MODEL": model,
         "BENTOML_HOME": bento_info["model"]["repo"]["path"] + "/bentoml",
     }
-    run_command(cmd, env=env)
+    return run_command(cmd, env=env, bg=bg)
 
 
 @app.command()
 def serve(model: str):
     _serve_model(model)
+
+
+async def _run_model(model: str, timeout: int = 600):
+    server_proc = _serve_model(model, bg=True)
+    assert server_proc is not None
+
+    import bentoml
+
+    try:
+        questionary.print("Model loading...", style="green")
+        for _ in range(timeout):
+            try:
+                client = bentoml.AsyncHTTPClient(
+                    "http://localhost:3000", timeout=timeout
+                )
+                resp = await client.request("GET", "/readyz")
+                if resp.status_code == 200:
+                    break
+            except bentoml.exceptions.BentoMLException:
+                await asyncio.sleep(1)
+        else:
+            questionary.print("Model failed to load", style="red")
+            return
+
+        questionary.print("Model is ready", style="green")
+        messages = []
+        while True:
+            try:
+                message = input("user: ")
+                messages.append(dict(role="user", content=message))
+                print("assistant: ", end="")
+                assistant_message = ""
+                async for text in client.chat(messages=messages):  # type: ignore
+                    assistant_message += text
+                    print(text, end="")
+                messages.append(dict(role="assistant", content=assistant_message))
+                print()
+
+            except KeyboardInterrupt:
+                break
+    finally:
+        questionary.print("\nStopping model server...", style="green")
+        server_proc.terminate()
+        questionary.print("Stopped model server", style="green")
+
+
+@app.command()
+def run(model: str):
+    asyncio.run(_run_model(model))