revive the chatgpt api endpoint on :8000

This commit is contained in:
Alex Cheema
2024-07-16 00:17:23 -07:00
parent 1d5c28aed4
commit f2895cbcee
6 changed files with 128 additions and 6 deletions

View File

@@ -74,7 +74,19 @@ python3 main.py
That's it! No configuration required - exo will automatically discover the other device(s).
-Until the below is fixed, the only way to access inference is via peer handles. See how it's done in [this example for Llama 3](examples/llama3_distributed.py).
+The native way to access models running on exo is to use the exo library with peer handles. See how it's done in [this example for Llama 3](examples/llama3_distributed.py).
exo also starts a ChatGPT-compatible API endpoint on http://localhost:8000. Note: this is currently only supported by tail nodes (i.e. nodes selected to be at the end of the ring topology). Example request:
```sh
curl http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "llama-3-70b",
"messages": [{"role": "user", "content": "What is the meaning of exo?"}],
"temperature": 0.7
}'
```
-// A ChatGPT-like web interface will be available on each device on port 8000 http://localhost:8000 and Chat-GPT-compatible API on port 8001 (currently doesn't work see https://github.com/exo-explore/exo/issues/6).
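For scripting against the endpoint, the same request can be made from Python. A minimal sketch, assuming the `requests` package is installed and a node is serving on localhost; model names must match the keys in `shard_mappings` from `exo/api/chatgpt_api.py` below:

```python
import requests

# Same request as the curl example above; the server replies with an
# OpenAI-style chat.completion object once generation finishes.
response = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "llama-3-70b",
        "messages": [{"role": "user", "content": "What is the meaning of exo?"}],
        "temperature": 0.7,
    },
    timeout=120,  # the server itself polls for up to 90s before returning 408
)
response.raise_for_status()
print(response.json()["choices"][0]["message"]["content"])
```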

1
exo/api/__init__.py Normal file
View File

@@ -0,0 +1 @@
from exo.api.chatgpt_api import ChatGPTAPI

104
exo/api/chatgpt_api.py Normal file
View File

@@ -0,0 +1,104 @@
import uuid
import time
import asyncio
from typing import List
from aiohttp import web
from exo import DEBUG
from exo.inference.shard import Shard
from exo.orchestration import Node
from exo.inference.mlx.sharded_utils import get_model_path, load_tokenizer
shard_mappings = {
    "llama-3-8b": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
    "llama-3-70b": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
}

class Message:
    def __init__(self, role: str, content: str):
        self.role = role
        self.content = content

class ChatCompletionRequest:
    def __init__(self, model: str, messages: List[Message], temperature: float):
        self.model = model
        self.messages = messages
        self.temperature = temperature

class ChatGPTAPI:
    def __init__(self, node: Node):
        self.node = node
        self.app = web.Application()
        self.app.router.add_post('/v1/chat/completions', self.handle_post)
    async def handle_post(self, request):
        data = await request.json()
        messages = [Message(**msg) for msg in data['messages']]
        chat_request = ChatCompletionRequest(data['model'], messages, data.get('temperature', 0.7))  # default assumed; OpenAI treats temperature as optional
        shard = shard_mappings.get(chat_request.model)
        if not shard:
            return web.json_response({'detail': f"Invalid model: {chat_request.model}. Supported: {list(shard_mappings.keys())}"}, status=400)
        request_id = str(uuid.uuid4())
        tokenizer = load_tokenizer(get_model_path(shard.model_id))
        # apply_chat_template expects plain role/content dicts, not Message objects
        prompt = tokenizer.apply_chat_template(
            [{"role": m.role, "content": m.content} for m in chat_request.messages],
            tokenize=False, add_generation_prompt=True
        )
        if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=}")
        try:
            await self.node.process_prompt(shard, prompt, request_id=request_id)
        except Exception as e:
            return web.json_response({'detail': str(e)}, status=500)

        # Poll for the response. TODO: implement callback for specific request id
        timeout = 90
        start_time = time.time()
        while time.time() - start_time < timeout:
            try:
                result, is_finished = await self.node.get_inference_result(request_id)
            except Exception:
                # Result may not be registered yet; back off briefly and retry.
                await asyncio.sleep(0.1)
                continue
            if is_finished:
                return web.json_response({
                    "id": f"chatcmpl-{request_id}",
                    "object": "chat.completion",
                    "created": int(time.time()),
                    "model": chat_request.model,
                    "usage": {
                        "prompt_tokens": len(tokenizer.encode(prompt)),
                        "completion_tokens": len(result),
                        "total_tokens": len(tokenizer.encode(prompt)) + len(result)
                    },
                    "choices": [
                        {
                            "message": {
                                "role": "assistant",
                                "content": tokenizer.decode(result)
                            },
                            "logprobs": None,
                            "finish_reason": "stop",
                            "index": 0
                        }
                    ]
                })
            await asyncio.sleep(0.1)

        return web.json_response({'detail': "Response generation timed out"}, status=408)
    async def run(self, host: str = "0.0.0.0", port: int = 8000):
        runner = web.AppRunner(self.app)
        await runner.setup()
        site = web.TCPSite(runner, host, port)
        await site.start()
        print(f"ChatGPT API server running at http://{host}:{port}")
# Usage example
if __name__ == "__main__":
    loop = asyncio.get_event_loop()
    node = Node()  # placeholder: Node is abstract; use a concrete implementation such as StandardNode
    api = ChatGPTAPI(node)
    loop.run_until_complete(api.run())
    loop.run_forever()  # api.run() returns once the server starts, so keep the loop alive
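The polling loop above carries a TODO about a callback for a specific request id, which would avoid waking every 100ms. A minimal sketch of that pattern using `asyncio.Event` (a hypothetical helper, not part of this commit; `deliver` would be wired into the node's token-output path):

```python
import asyncio
from typing import Dict, List

class ResultWaiter:
    """Hypothetical helper: await a request id instead of polling for it."""
    def __init__(self) -> None:
        self._events: Dict[str, asyncio.Event] = {}
        self._results: Dict[str, List[int]] = {}

    def deliver(self, request_id: str, tokens: List[int]) -> None:
        # Called when generation for request_id finishes.
        self._results[request_id] = tokens
        self._events.setdefault(request_id, asyncio.Event()).set()

    async def wait(self, request_id: str, timeout: float = 90.0) -> List[int]:
        # Raises asyncio.TimeoutError if nothing is delivered in time.
        event = self._events.setdefault(request_id, asyncio.Event())
        await asyncio.wait_for(event.wait(), timeout)
        return self._results.pop(request_id)
```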

View File

@@ -14,11 +14,11 @@ class Node(ABC):
pass
@abstractmethod
-async def process_prompt(self, shard: Shard, prompt: str) -> Optional[np.ndarray]:
+async def process_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.ndarray]:
pass
@abstractmethod
-async def process_tensor(self, shard: Shard, tensor: np.ndarray) -> Optional[np.ndarray]:
+async def process_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> Optional[np.ndarray]:
pass
@abstractmethod

View File

@@ -12,7 +12,7 @@ import asyncio
import uuid
class StandardNode(Node):
-def __init__(self, id: str, server: Server, inference_engine: InferenceEngine, discovery: Discovery, partitioning_strategy: PartitioningStrategy = None, on_token: Callable[[List[int]], None] = None, max_generate_tokens: int = 50):
+def __init__(self, id: str, server: Server, inference_engine: InferenceEngine, discovery: Discovery, partitioning_strategy: PartitioningStrategy = None, on_token: Callable[[List[int]], None] = None, max_generate_tokens: int = 256):
self.id = id
self.inference_engine = inference_engine
self.server = server
@@ -50,7 +50,7 @@ class StandardNode(Node):
return
result, is_finished = await self.inference_engine.infer_prompt(self.get_current_shard(shard), prompt)
-is_finished = is_finished or len(self.buffered_token_output[request_id]) >= self.max_generate_tokens
+is_finished = is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
if is_finished:
self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
@@ -74,7 +74,7 @@ class StandardNode(Node):
try:
if DEBUG >= 1: print(f"[{request_id}] process_tensor: {tensor.size=} {tensor.shape=}")
result, is_finished = await self.inference_engine.infer_tensor(self.get_current_shard(shard), tensor)
-is_finished = is_finished or len(self.buffered_token_output[request_id]) >= self.max_generate_tokens
+is_finished = is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
if is_finished:
self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
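The `[0]` added in both hunks matters because `buffered_token_output` maps a request id to a `(tokens, is_finished)` pair, as the surrounding lines show. A small standalone sketch of the difference (structure assumed from this diff):

```python
# buffered_token_output[request_id] holds a (token_list, is_finished) pair.
buffered_token_output = {"req-1": ([42, 7, 99], False)}
max_generate_tokens = 256

# Old check: len() of the 2-tuple is always 2, so the cap never triggered.
print(len(buffered_token_output["req-1"]))       # 2
# Fixed check: len() of the buffered token list, compared against the cap.
print(len(buffered_token_output["req-1"][0]))    # 3
print(len(buffered_token_output["req-1"][0]) >= max_generate_tokens)  # False
```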

View File

@@ -11,6 +11,8 @@ from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceE
from exo.inference.shard import Shard
from exo.networking.grpc.grpc_discovery import GRPCDiscovery
from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
from exo.api import ChatGPTAPI
# parse args
parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
@@ -20,6 +22,7 @@ parser.add_argument("--node-port", type=int, default=8080, help="Node port")
parser.add_argument("--listen-port", type=int, default=5678, help="Listening port for discovery")
parser.add_argument("--broadcast-port", type=int, default=5678, help="Broadcast port for discovery")
parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of peers to wait to connect to before starting")
parser.add_argument("--chatgpt-api-port", type=int, default=8000, help="ChatGPT API port")
args = parser.parse_args()
@@ -32,6 +35,7 @@ node = StandardNode(args.node_id, None, inference_engine, discovery, partitionin
server = GRPCServer(node, args.node_host, args.node_port)
node.server = server
api = ChatGPTAPI(node)
async def shutdown(signal, loop):
"""Gracefully shutdown the server and close the asyncio loop."""
@@ -54,6 +58,7 @@ async def main():
loop.add_signal_handler(s, handle_exit)
await node.start(wait_for_peers=args.wait_for_peers)
asyncio.create_task(api.run(port=args.chatgpt_api_port)) # Start the API server as a non-blocking task
await asyncio.Event().wait()