diff --git a/README.md b/README.md
index 48392bbf..02ac0208 100644
--- a/README.md
+++ b/README.md
@@ -74,7 +74,19 @@ python3 main.py
 
 That's it! No configuration required - exo will automatically discover the other device(s).
 
-Until the below is fixed, the only way to access inference is via peer handles. See how it's done in [this example for Llama 3](examples/llama3_distributed.py).
+The native way to access models running on exo is using the exo library with peer handles. See how in [this example for Llama 3](examples/llama3_distributed.py).
+
+exo also starts a ChatGPT-compatible API endpoint on http://localhost:8000. Note: this is currently only supported by tail nodes (i.e. nodes selected to be at the end of the ring topology). Example request:
+
+```
+curl http://localhost:8000/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -d '{
+     "model": "llama-3-70b",
+     "messages": [{"role": "user", "content": "What is the meaning of exo?"}],
+     "temperature": 0.7
+   }'
+```
 
 // A ChatGPT-like web interface will be available on each device on port 8000 http://localhost:8000 and Chat-GPT-compatible API on port 8001 (currently doesn't work see https://github.com/exo-explore/exo/issues/6).
 
diff --git a/exo/api/__init__.py b/exo/api/__init__.py
new file mode 100644
index 00000000..8854f281
--- /dev/null
+++ b/exo/api/__init__.py
@@ -0,0 +1 @@
+from exo.api.chatgpt_api import ChatGPTAPI
diff --git a/exo/api/chatgpt_api.py b/exo/api/chatgpt_api.py
new file mode 100644
index 00000000..36162728
--- /dev/null
+++ b/exo/api/chatgpt_api.py
@@ -0,0 +1,104 @@
+import uuid
+import time
+import asyncio
+from http.server import BaseHTTPRequestHandler, HTTPServer
+from typing import List
+from aiohttp import web
+from exo import DEBUG
+from exo.inference.shard import Shard
+from exo.orchestration import Node
+from exo.inference.mlx.sharded_utils import get_model_path, load_tokenizer
+
+shard_mappings = {
+    "llama-3-8b": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
+    "llama-3-70b": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
+}
+
+class Message:
+    def __init__(self, role: str, content: str):
+        self.role = role
+        self.content = content
+
+class ChatCompletionRequest:
+    def __init__(self, model: str, messages: List[Message], temperature: float):
+        self.model = model
+        self.messages = messages
+        self.temperature = temperature
+
+class ChatGPTAPI:
+    def __init__(self, node: Node):
+        self.node = node
+        self.app = web.Application()
+        self.app.router.add_post('/v1/chat/completions', self.handle_post)
+
+    async def handle_post(self, request):
+        data = await request.json()
+        messages = [Message(**msg) for msg in data['messages']]
+        chat_request = ChatCompletionRequest(data['model'], messages, data['temperature'])
+        prompt = " ".join([msg.content for msg in chat_request.messages if msg.role == "user"])
+        shard = shard_mappings.get(chat_request.model)
+        if not shard:
+            return web.json_response({'detail': f"Invalid model: {chat_request.model}. Supported: {list(shard_mappings.keys())}"}, status=400)
+        request_id = str(uuid.uuid4())
+
+        tokenizer = load_tokenizer(get_model_path(shard.model_id))
+        prompt = tokenizer.apply_chat_template(
+            chat_request.messages, tokenize=False, add_generation_prompt=True
+        )
+
+        if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=}")
+        try:
+            result = await self.node.process_prompt(shard, prompt, request_id=request_id)
+        except Exception as e:
+            pass # TODO
+            # return web.json_response({'detail': str(e)}, status=500)
+
+        # poll for the response. TODO: implement callback for specific request id
+        timeout = 90
+        start_time = time.time()
+        while time.time() - start_time < timeout:
+            print("poll")
+            try:
+                result, is_finished = await self.node.get_inference_result(request_id)
+            except Exception as e:
+                continue
+            await asyncio.sleep(0.1)
+            if is_finished:
+                return web.json_response({
+                    "id": f"chatcmpl-{request_id}",
+                    "object": "chat.completion",
+                    "created": int(time.time()),
+                    "model": chat_request.model,
+                    "usage": {
+                        "prompt_tokens": len(tokenizer.encode(prompt)),
+                        "completion_tokens": len(result),
+                        "total_tokens": len(tokenizer.encode(prompt)) + len(result)
+                    },
+                    "choices": [
+                        {
+                            "message": {
+                                "role": "assistant",
+                                "content": tokenizer.decode(result)
+                            },
+                            "logprobs": None,
+                            "finish_reason": "stop",
+                            "index": 0
+                        }
+                    ]
+                })
+
+        return web.json_response({'detail': "Response generation timed out"}, status=408)
+
+    async def run(self, host: str = "0.0.0.0", port: int = 8000):
+        runner = web.AppRunner(self.app)
+        await runner.setup()
+        site = web.TCPSite(runner, host, port)
+        await site.start()
+        print(f"Starting ChatGPT API server at {host}:{port}")
+
+# Usage example
+if __name__ == "__main__":
+    loop = asyncio.get_event_loop()
+    node = Node()  # Assuming Node is properly defined elsewhere
+    api = ChatGPTAPI(node)
+    loop.run_until_complete(api.run())
diff --git a/exo/orchestration/node.py b/exo/orchestration/node.py
index 3e02340b..8dac1085 100644
--- a/exo/orchestration/node.py
+++ b/exo/orchestration/node.py
@@ -14,11 +14,11 @@ class Node(ABC):
         pass
 
     @abstractmethod
-    async def process_prompt(self, shard: Shard, prompt: str) -> Optional[np.ndarray]:
+    async def process_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.ndarray]:
         pass
 
     @abstractmethod
-    async def process_tensor(self, shard: Shard, tensor: np.ndarray) -> Optional[np.ndarray]:
+    async def process_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> Optional[np.ndarray]:
         pass
 
     @abstractmethod
diff --git a/exo/orchestration/standard_node.py b/exo/orchestration/standard_node.py
index e292afed..5fd194d5 100644
--- a/exo/orchestration/standard_node.py
+++ b/exo/orchestration/standard_node.py
@@ -12,7 +12,7 @@ import asyncio
 import uuid
 
 class StandardNode(Node):
-    def __init__(self, id: str, server: Server, inference_engine: InferenceEngine, discovery: Discovery, partitioning_strategy: PartitioningStrategy = None, on_token: Callable[[List[int]], None] = None, max_generate_tokens: int = 50):
+    def __init__(self, id: str, server: Server, inference_engine: InferenceEngine, discovery: Discovery, partitioning_strategy: PartitioningStrategy = None, on_token: Callable[[List[int]], None] = None, max_generate_tokens: int = 256):
         self.id = id
         self.inference_engine = inference_engine
         self.server = server
@@ -50,7 +50,7 @@ class StandardNode(Node):
             return
 
         result, is_finished = await self.inference_engine.infer_prompt(self.get_current_shard(shard), prompt)
-        is_finished = is_finished or len(self.buffered_token_output[request_id]) >= self.max_generate_tokens
+        is_finished = is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
         if is_finished:
             self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
 
@@ -74,7 +74,7 @@ class StandardNode(Node):
         try:
             if DEBUG >= 1: print(f"[{request_id}] process_tensor: {tensor.size=} {tensor.shape=}")
             result, is_finished = await self.inference_engine.infer_tensor(self.get_current_shard(shard), tensor)
-            is_finished = is_finished or len(self.buffered_token_output[request_id]) >= self.max_generate_tokens
+            is_finished = is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
             if is_finished:
                 self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
 
diff --git a/main.py b/main.py
index 58363caa..4d385d9b 100644
--- a/main.py
+++ b/main.py
@@ -11,6 +11,8 @@ from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceE
 from exo.inference.shard import Shard
 from exo.networking.grpc.grpc_discovery import GRPCDiscovery
 from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
+from exo.api import ChatGPTAPI
+
 
 # parse args
 parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
@@ -20,6 +22,7 @@ parser.add_argument("--node-port", type=int, default=8080, help="Node port")
 parser.add_argument("--listen-port", type=int, default=5678, help="Listening port for discovery")
 parser.add_argument("--broadcast-port", type=int, default=5678, help="Broadcast port for discovery")
 parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of peers to wait to connect to before starting")
+parser.add_argument("--chatgpt-api-port", type=int, default=8000, help="ChatGPT API port")
 args = parser.parse_args()
 
 
@@ -32,6 +35,7 @@ node = StandardNode(args.node_id, None, inference_engine, discovery, partitionin
 
 server = GRPCServer(node, args.node_host, args.node_port)
 node.server = server
+api = ChatGPTAPI(node)
 
 async def shutdown(signal, loop):
     """Gracefully shutdown the server and close the asyncio loop."""
@@ -54,6 +58,7 @@ async def main():
         loop.add_signal_handler(s, handle_exit)
 
     await node.start(wait_for_peers=args.wait_for_peers)
+    asyncio.create_task(api.run(port=args.chatgpt_api_port))  # Start the API server as a non-blocking task
 
     await asyncio.Event().wait()
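A note on the `[0]` added in both `standard_node.py` hunks: `buffered_token_output` maps a request id to a `(token_ids, is_finished)` pair (visible in the lines that rewrite it as `(..., True)`), so the generation cap has to be measured against the token list rather than the pair, whose `len()` is always 2 and would never reach `max_generate_tokens`. The following is a minimal standalone sketch of that check, with the node state simplified to a plain dict and the helper name `update_is_finished` invented for illustration:

```
from typing import Dict, List, Tuple

# Per-request buffer of (generated token ids, finished flag), mirroring how
# standard_node.py uses buffered_token_output in this patch.
buffered_token_output: Dict[str, Tuple[List[int], bool]] = {}
max_generate_tokens = 256

def update_is_finished(request_id: str, is_finished: bool) -> bool:
    tokens, _ = buffered_token_output[request_id]
    # len() must be taken on the token list: len() of the (tokens, flag) pair is
    # always 2, so the max_generate_tokens cap would otherwise never trigger.
    is_finished = is_finished or len(tokens) >= max_generate_tokens
    if is_finished:
        buffered_token_output[request_id] = (tokens, True)
    return is_finished

# Quick check: 300 buffered tokens exceed the 256-token cap.
buffered_token_output["req-1"] = (list(range(300)), False)
print(update_is_finished("req-1", False))  # True
```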
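For completeness, here is a Python counterpart to the curl example in the README hunk: a minimal client sketch that posts a chat completion to the new `/v1/chat/completions` route and prints the assistant reply. It assumes a tail node is serving the API on `http://localhost:8000` (the default `--chatgpt-api-port`) with the `llama-3-70b` entry from `shard_mappings`, and it reuses `aiohttp`, which the server side of this patch already imports; the `ask` helper is illustrative only.

```
# Minimal client sketch for the ChatGPT-compatible endpoint added in this patch.
# Assumes a tail node is serving on localhost:8000 (default --chatgpt-api-port).
import asyncio
import aiohttp

async def ask(prompt: str, model: str = "llama-3-70b") -> str:
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0.7,
    }
    # The handler polls for up to 90 seconds before returning 408, so allow a
    # generous client-side timeout.
    timeout = aiohttp.ClientTimeout(total=120)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        async with session.post("http://localhost:8000/v1/chat/completions", json=payload) as resp:
            resp.raise_for_status()
            data = await resp.json()
            # Response shape matches what handle_post returns in chatgpt_api.py.
            return data["choices"][0]["message"]["content"]

if __name__ == "__main__":
    print(asyncio.run(ask("What is the meaning of exo?")))
```

As the README note added in this patch points out, only tail nodes currently expose the endpoint, so the client has to target a node at the end of the ring topology.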