Mirror of https://github.com/exo-explore/exo.git

revive the chatgpt api endpoint on :8000

README.md (14 lines changed)
@@ -74,7 +74,19 @@ python3 main.py

 That's it! No configuration required - exo will automatically discover the other device(s).

-Until the below is fixed, the only way to access inference is via peer handles. See how it's done in [this example for Llama 3](examples/llama3_distributed.py).
+The native way to access models running on exo is using the exo library with peer handles. See how in [this example for Llama 3](examples/llama3_distributed.py).
+
+exo also starts a ChatGPT-compatible API endpoint on http://localhost:8000. Note: this is currently only supported by tail nodes (i.e. nodes selected to be at the end of the ring topology). Example request:
+
+```
+curl http://localhost:8000/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model": "llama-3-70b",
+    "messages": [{"role": "user", "content": "What is the meaning of exo?"}],
+    "temperature": 0.7
+  }'
+```

-// A ChatGPT-like web interface will be available on each device on port 8000 http://localhost:8000 and Chat-GPT-compatible API on port 8001 (currently doesn't work see https://github.com/exo-explore/exo/issues/6).
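
The curl example above maps one-to-one onto the handler added in exo/api/chatgpt_api.py below. For readers working from Python, here is a minimal client sketch making the same request; it assumes the third-party `requests` package and an exo node already serving on localhost:8000, and is not part of the commit:

```
import requests  # third-party HTTP client, assumed installed

# Same payload as the curl example in the README hunk above.
response = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "llama-3-70b",
        "messages": [{"role": "user", "content": "What is the meaning of exo?"}],
        "temperature": 0.7,
    },
    timeout=120,  # generation on a large model can take a while; the server itself gives up after 90s
)
response.raise_for_status()
# The handler returns an OpenAI-style chat.completion body; the reply text lives here:
print(response.json()["choices"][0]["message"]["content"])
```
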

exo/api/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
+from exo.api.chatgpt_api import ChatGPTAPI

exo/api/chatgpt_api.py (new file, 104 lines)
@@ -0,0 +1,104 @@
+import uuid
+import time
+import asyncio
+from http.server import BaseHTTPRequestHandler, HTTPServer
+from typing import List
+from aiohttp import web
+from exo import DEBUG
+from exo.inference.shard import Shard
+from exo.orchestration import Node
+from exo.inference.mlx.sharded_utils import get_model_path, load_tokenizer
+
+shard_mappings = {
+    "llama-3-8b": Shard(model_id="mlx-community/Meta-Llama-3-8B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=32),
+    "llama-3-70b": Shard(model_id="mlx-community/Meta-Llama-3-70B-Instruct-4bit", start_layer=0, end_layer=0, n_layers=80),
+}
+
+class Message:
+    def __init__(self, role: str, content: str):
+        self.role = role
+        self.content = content
+
+class ChatCompletionRequest:
+    def __init__(self, model: str, messages: List[Message], temperature: float):
+        self.model = model
+        self.messages = messages
+        self.temperature = temperature
+
+class ChatGPTAPI:
+    def __init__(self, node: Node):
+        self.node = node
+        self.app = web.Application()
+        self.app.router.add_post('/v1/chat/completions', self.handle_post)
+
+    async def handle_post(self, request):
+        data = await request.json()
+        messages = [Message(**msg) for msg in data['messages']]
+        chat_request = ChatCompletionRequest(data['model'], messages, data['temperature'])
+        prompt = " ".join([msg.content for msg in chat_request.messages if msg.role == "user"])
+        shard = shard_mappings.get(chat_request.model)
+        if not shard:
+            return web.json_response({'detail': f"Invalid model: {chat_request.model}. Supported: {list(shard_mappings.keys())}"}, status=400)
+        request_id = str(uuid.uuid4())
+
+        tokenizer = load_tokenizer(get_model_path(shard.model_id))
+        prompt = tokenizer.apply_chat_template(
+            chat_request.messages, tokenize=False, add_generation_prompt=True
+        )
+
+        if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=}")
+        try:
+            result = await self.node.process_prompt(shard, prompt, request_id=request_id)
+        except Exception as e:
+            pass # TODO
+            # return web.json_response({'detail': str(e)}, status=500)
+
+        # poll for the response. TODO: implement callback for specific request id
+        timeout = 90
+        start_time = time.time()
+        while time.time() - start_time < timeout:
+            print("poll")
+            try:
+                result, is_finished = await self.node.get_inference_result(request_id)
+            except Exception as e:
+                continue
+            await asyncio.sleep(0.1)
+            if is_finished:
+                return web.json_response({
+                    "id": f"chatcmpl-{request_id}",
+                    "object": "chat.completion",
+                    "created": int(time.time()),
+                    "model": chat_request.model,
+                    "usage": {
+                        "prompt_tokens": len(tokenizer.encode(prompt)),
+                        "completion_tokens": len(result),
+                        "total_tokens": len(tokenizer.encode(prompt)) + len(result)
+                    },
+                    "choices": [
+                        {
+                            "message": {
+                                "role": "assistant",
+                                "content": tokenizer.decode(result)
+                            },
+                            "logprobs": None,
+                            "finish_reason": "stop",
+                            "index": 0
+                        }
+                    ]
+                })
+
+        return web.json_response({'detail': "Response generation timed out"}, status=408)
+
+    async def run(self, host: str = "0.0.0.0", port: int = 8000):
+        runner = web.AppRunner(self.app)
+        await runner.setup()
+        site = web.TCPSite(runner, host, port)
+        await site.start()
+        print(f"Starting ChatGPT API server at {host}:{port}")
+
+# Usage example
+if __name__ == "__main__":
+    loop = asyncio.get_event_loop()
+    node = Node()  # Assuming Node is properly defined elsewhere
+    api = ChatGPTAPI(node)
+    loop.run_until_complete(api.run())
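
handle_post above waits for completion by polling node.get_inference_result every 0.1 seconds; the in-code TODO mentions a per-request callback instead. A minimal sketch of what that could look like, purely illustrative: the ResultWaiter class and its register/on_result methods are hypothetical and not part of exo:

```
import asyncio
from typing import Dict, List, Tuple


class ResultWaiter:
    """Hypothetical helper: resolve a future when a given request_id finishes."""

    def __init__(self):
        # request_id -> future resolving to (tokens, is_finished)
        self._futures: Dict[str, "asyncio.Future[Tuple[List[int], bool]]"] = {}

    def register(self, request_id: str) -> "asyncio.Future[Tuple[List[int], bool]]":
        fut = asyncio.get_event_loop().create_future()
        self._futures[request_id] = fut
        return fut

    def on_result(self, request_id: str, tokens: List[int], is_finished: bool) -> None:
        # A node-side token callback could call this as generation progresses.
        fut = self._futures.get(request_id)
        if is_finished and fut is not None and not fut.done():
            fut.set_result((tokens, True))


# Inside handle_post the polling loop would then reduce to roughly:
#   fut = waiter.register(request_id)
#   tokens, _ = await asyncio.wait_for(fut, timeout=90)
#   ... build the chat.completion response from `tokens` ...
```

The hunks that follow thread the new request_id parameter through the abstract Node interface and StandardNode, and raise max_generate_tokens from 50 to 256.
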
@@ -14,11 +14,11 @@ class Node(ABC):
         pass

     @abstractmethod
-    async def process_prompt(self, shard: Shard, prompt: str) -> Optional[np.ndarray]:
+    async def process_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.ndarray]:
         pass

     @abstractmethod
-    async def process_tensor(self, shard: Shard, tensor: np.ndarray) -> Optional[np.ndarray]:
+    async def process_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> Optional[np.ndarray]:
         pass

     @abstractmethod
@@ -12,7 +12,7 @@ import asyncio
 import uuid

 class StandardNode(Node):
-    def __init__(self, id: str, server: Server, inference_engine: InferenceEngine, discovery: Discovery, partitioning_strategy: PartitioningStrategy = None, on_token: Callable[[List[int]], None] = None, max_generate_tokens: int = 50):
+    def __init__(self, id: str, server: Server, inference_engine: InferenceEngine, discovery: Discovery, partitioning_strategy: PartitioningStrategy = None, on_token: Callable[[List[int]], None] = None, max_generate_tokens: int = 256):
         self.id = id
         self.inference_engine = inference_engine
         self.server = server
@@ -50,7 +50,7 @@ class StandardNode(Node):
            return

        result, is_finished = await self.inference_engine.infer_prompt(self.get_current_shard(shard), prompt)
-       is_finished = is_finished or len(self.buffered_token_output[request_id]) >= self.max_generate_tokens
+       is_finished = is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
        if is_finished:
            self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)

@@ -74,7 +74,7 @@ class StandardNode(Node):
        try:
            if DEBUG >= 1: print(f"[{request_id}] process_tensor: {tensor.size=} {tensor.shape=}")
            result, is_finished = await self.inference_engine.infer_tensor(self.get_current_shard(shard), tensor)
-           is_finished = is_finished or len(self.buffered_token_output[request_id]) >= self.max_generate_tokens
+           is_finished = is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
            if is_finished:
                self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)

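
The repeated buffered_token_output change above is easy to misread as a no-op. A small illustration of the difference, assuming (as the tuple assignment in the hunk implies) that the buffer maps request_id to a (tokens, is_finished) pair:

```
# Assumed shape, implied by the assignment in the hunk above:
# buffered_token_output[request_id] == (generated_token_ids, is_finished)
buffered_token_output = {"req-1": ([101, 2009, 318], False)}

entry = buffered_token_output["req-1"]
print(len(entry))     # 2 -> old check: tuple length, always 2, never reaches max_generate_tokens
print(len(entry[0]))  # 3 -> new check: number of tokens generated so far
```
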

main.py (5 lines changed)
@@ -11,6 +11,8 @@ from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceE
 from exo.inference.shard import Shard
 from exo.networking.grpc.grpc_discovery import GRPCDiscovery
 from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
+from exo.api import ChatGPTAPI
+

 # parse args
 parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
@@ -20,6 +22,7 @@ parser.add_argument("--node-port", type=int, default=8080, help="Node port")
 parser.add_argument("--listen-port", type=int, default=5678, help="Listening port for discovery")
 parser.add_argument("--broadcast-port", type=int, default=5678, help="Broadcast port for discovery")
 parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of peers to wait to connect to before starting")
+parser.add_argument("--chatgpt-api-port", type=int, default=8000, help="ChatGPT API port")
 args = parser.parse_args()


@@ -32,6 +35,7 @@ node = StandardNode(args.node_id, None, inference_engine, discovery, partitionin
 server = GRPCServer(node, args.node_host, args.node_port)
 node.server = server

+api = ChatGPTAPI(node)

 async def shutdown(signal, loop):
     """Gracefully shutdown the server and close the asyncio loop."""
@@ -54,6 +58,7 @@ async def main():
         loop.add_signal_handler(s, handle_exit)

     await node.start(wait_for_peers=args.wait_for_peers)
+    asyncio.create_task(api.run(port=args.chatgpt_api_port))  # Start the API server as a non-blocking task

     await asyncio.Event().wait()
