mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-24 02:07:17 -05:00
Compare commits
7 Commits
v0.0.2-alp
...
v0.0.3-alp
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d9a836f152 | ||
|
|
29244c6369 | ||
|
|
8c191050a2 | ||
|
|
7b1656140e | ||
|
|
fe50d4d34d | ||
|
|
03aa6cecf1 | ||
|
|
178cc4d961 |
@@ -18,6 +18,8 @@ exo: Run your own AI cluster at home with everyday devices. Maintained by [exo l
|
||||
[](https://dl.circleci.com/status-badge/redirect/circleci/TrkofJDoGzdQAeL6yVHKsg/4i5hJuafuwZYZQxbRAWS71/tree/main)
|
||||
[](https://www.gnu.org/licenses/gpl-3.0)
|
||||
|
||||
<a href="https://trendshift.io/repositories/11849" target="_blank"><img src="https://trendshift.io/api/badge/repositories/11849" alt="exo-explore%2Fexo | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||
|
||||
</div>
|
||||
|
||||
---
|
||||
|
||||
@@ -160,7 +160,7 @@ class PromptSession:
|
||||
self.prompt = prompt
|
||||
|
||||
class ChatGPTAPI:
|
||||
def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None, default_model: Optional[str] = None):
|
||||
def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None, default_model: Optional[str] = None, system_prompt: Optional[str] = None):
|
||||
self.node = node
|
||||
self.inference_engine_classname = inference_engine_classname
|
||||
self.response_timeout = response_timeout
|
||||
@@ -170,6 +170,7 @@ class ChatGPTAPI:
|
||||
self.prev_token_lens: Dict[str, int] = {}
|
||||
self.stream_tasks: Dict[str, asyncio.Task] = {}
|
||||
self.default_model = default_model or "llama-3.2-1b"
|
||||
self.system_prompt = system_prompt
|
||||
|
||||
cors = aiohttp_cors.setup(self.app)
|
||||
cors_options = aiohttp_cors.ResourceOptions(
|
||||
@@ -244,7 +245,7 @@ class ChatGPTAPI:
|
||||
)
|
||||
await response.prepare(request)
|
||||
|
||||
for model_name, pretty in pretty_name.items():
|
||||
async def process_model(model_name, pretty):
|
||||
if model_name in model_cards:
|
||||
model_info = model_cards[model_name]
|
||||
|
||||
@@ -272,6 +273,12 @@ class ChatGPTAPI:
|
||||
|
||||
await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
|
||||
|
||||
# Process all models in parallel
|
||||
await asyncio.gather(*[
|
||||
process_model(model_name, pretty)
|
||||
for model_name, pretty in pretty_name.items()
|
||||
])
|
||||
|
||||
await response.write(b"data: [DONE]\n\n")
|
||||
return response
|
||||
|
||||
@@ -336,6 +343,10 @@ class ChatGPTAPI:
|
||||
tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
|
||||
if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")
|
||||
|
||||
# Add system prompt if set
|
||||
if self.system_prompt and not any(msg.role == "system" for msg in chat_request.messages):
|
||||
chat_request.messages.insert(0, Message("system", self.system_prompt))
|
||||
|
||||
prompt = build_prompt(tokenizer, chat_request.messages, chat_request.tools)
|
||||
request_id = str(uuid.uuid4())
|
||||
if self.on_chat_completion_request:
|
||||
@@ -557,7 +568,7 @@ class ChatGPTAPI:
|
||||
if model_name not in model_cards: return web.json_response({"error": f"Invalid model: {model_name}. Supported models: {list(model_cards.keys())}"}, status=400)
|
||||
shard = build_base_shard(model_name, self.inference_engine_classname)
|
||||
if not shard: return web.json_response({"error": f"Could not build shard for model {model_name}"}, status=400)
|
||||
asyncio.create_task(self.node.inference_engine.ensure_shard(shard))
|
||||
asyncio.create_task(self.node.inference_engine.shard_downloader.ensure_shard(shard, self.inference_engine_classname))
|
||||
|
||||
return web.json_response({
|
||||
"status": "success",
|
||||
|
||||
@@ -69,6 +69,7 @@ parser.add_argument("--default-temp", type=float, help="Default token sampling t
|
||||
parser.add_argument("--tailscale-api-key", type=str, default=None, help="Tailscale API key")
|
||||
parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name")
|
||||
parser.add_argument("--node-id-filter", type=str, default=None, help="Comma separated list of allowed node IDs (only for UDP and Tailscale discovery)")
|
||||
parser.add_argument("--system-prompt", type=str, default=None, help="System prompt for the ChatGPT API")
|
||||
args = parser.parse_args()
|
||||
print(f"Selected inference engine: {args.inference_engine}")
|
||||
|
||||
@@ -146,7 +147,8 @@ api = ChatGPTAPI(
|
||||
inference_engine.__class__.__name__,
|
||||
response_timeout=args.chatgpt_api_response_timeout,
|
||||
on_chat_completion_request=lambda req_id, __, prompt: topology_viz.update_prompt(req_id, prompt) if topology_viz else None,
|
||||
default_model=args.default_model
|
||||
default_model=args.default_model,
|
||||
system_prompt=args.system_prompt
|
||||
)
|
||||
node.on_token.register("update_topology_viz").on_next(
|
||||
lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") else None
|
||||
|
||||
Reference in New Issue
Block a user