Compare commits

..

7 Commits

Author SHA1 Message Date
Alex Cheema
d9a836f152 Merge pull request #588 from exo-explore/betterdl
better download
2025-01-05 02:35:04 +00:00
Alex Cheema
29244c6369 fix args for ensure_shard 2025-01-05 02:33:25 +00:00
Alex Cheema
8c191050a2 download status in parallel, support async ensure shard with using shard_downloader instead 2025-01-05 02:31:59 +00:00
Alex Cheema
7b1656140e Merge pull request #585 from pepebruari/main
Add --system-prompt to exo cli
2025-01-03 23:49:50 +00:00
pepebruari
fe50d4d34d Add --system-prompt to exo cli 2025-01-03 16:16:22 -05:00
Alex Cheema
03aa6cecf1 Merge pull request #584 from exo-explore/AlexCheema-patch-1
add trending badge to README.md
2024-12-31 17:51:10 +00:00
Alex Cheema
178cc4d961 add trending badge to README.md 2024-12-31 17:50:29 +00:00
3 changed files with 19 additions and 4 deletions

View File

@@ -18,6 +18,8 @@ exo: Run your own AI cluster at home with everyday devices. Maintained by [exo l
[![Tests](https://dl.circleci.com/status-badge/img/circleci/TrkofJDoGzdQAeL6yVHKsg/4i5hJuafuwZYZQxbRAWS71/tree/main.svg?style=svg)](https://dl.circleci.com/status-badge/redirect/circleci/TrkofJDoGzdQAeL6yVHKsg/4i5hJuafuwZYZQxbRAWS71/tree/main)
[![License: GPL v3](https://img.shields.io/badge/License-GPLv3-blue.svg)](https://www.gnu.org/licenses/gpl-3.0)
<a href="https://trendshift.io/repositories/11849" target="_blank"><img src="https://trendshift.io/api/badge/repositories/11849" alt="exo-explore%2Fexo | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
</div>
---

View File

@@ -160,7 +160,7 @@ class PromptSession:
self.prompt = prompt
class ChatGPTAPI:
def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None, default_model: Optional[str] = None):
def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None, default_model: Optional[str] = None, system_prompt: Optional[str] = None):
self.node = node
self.inference_engine_classname = inference_engine_classname
self.response_timeout = response_timeout
@@ -170,6 +170,7 @@ class ChatGPTAPI:
self.prev_token_lens: Dict[str, int] = {}
self.stream_tasks: Dict[str, asyncio.Task] = {}
self.default_model = default_model or "llama-3.2-1b"
self.system_prompt = system_prompt
cors = aiohttp_cors.setup(self.app)
cors_options = aiohttp_cors.ResourceOptions(
@@ -244,7 +245,7 @@ class ChatGPTAPI:
)
await response.prepare(request)
for model_name, pretty in pretty_name.items():
async def process_model(model_name, pretty):
if model_name in model_cards:
model_info = model_cards[model_name]
@@ -272,6 +273,12 @@ class ChatGPTAPI:
await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
# Process all models in parallel
await asyncio.gather(*[
process_model(model_name, pretty)
for model_name, pretty in pretty_name.items()
])
await response.write(b"data: [DONE]\n\n")
return response
@@ -336,6 +343,10 @@ class ChatGPTAPI:
tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")
# Add system prompt if set
if self.system_prompt and not any(msg.role == "system" for msg in chat_request.messages):
chat_request.messages.insert(0, Message("system", self.system_prompt))
prompt = build_prompt(tokenizer, chat_request.messages, chat_request.tools)
request_id = str(uuid.uuid4())
if self.on_chat_completion_request:
@@ -557,7 +568,7 @@ class ChatGPTAPI:
if model_name not in model_cards: return web.json_response({"error": f"Invalid model: {model_name}. Supported models: {list(model_cards.keys())}"}, status=400)
shard = build_base_shard(model_name, self.inference_engine_classname)
if not shard: return web.json_response({"error": f"Could not build shard for model {model_name}"}, status=400)
asyncio.create_task(self.node.inference_engine.ensure_shard(shard))
asyncio.create_task(self.node.inference_engine.shard_downloader.ensure_shard(shard, self.inference_engine_classname))
return web.json_response({
"status": "success",

View File

@@ -69,6 +69,7 @@ parser.add_argument("--default-temp", type=float, help="Default token sampling t
parser.add_argument("--tailscale-api-key", type=str, default=None, help="Tailscale API key")
parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name")
parser.add_argument("--node-id-filter", type=str, default=None, help="Comma separated list of allowed node IDs (only for UDP and Tailscale discovery)")
parser.add_argument("--system-prompt", type=str, default=None, help="System prompt for the ChatGPT API")
args = parser.parse_args()
print(f"Selected inference engine: {args.inference_engine}")
@@ -146,7 +147,8 @@ api = ChatGPTAPI(
inference_engine.__class__.__name__,
response_timeout=args.chatgpt_api_response_timeout,
on_chat_completion_request=lambda req_id, __, prompt: topology_viz.update_prompt(req_id, prompt) if topology_viz else None,
default_model=args.default_model
default_model=args.default_model,
system_prompt=args.system_prompt
)
node.on_token.register("update_topology_viz").on_next(
lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") else None