Compare commits


39 Commits

Author SHA1 Message Date
Alex Cheema
3c9efe103d Merge pull request #590 from metaspartan/fix-models-api
Fix the /v1/models API to output proper OpenAI compatible endpoint
2025-01-07 02:32:06 +00:00
Carsen Klock
627bfcae7c Fix the /v1/models API to output proper OpenAI compatible endpoint
Modify the `/v1/models` API to output a proper OpenAI compatible endpoint with an object and a `data` object containing the models list.

* Change the `handle_get_models` method in `exo/api/chatgpt_api.py` to wrap the models list in an object with a `data` field.
* Add an `object` field with the value "list" to the response format.
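
For context, a minimal client-side sketch of the response shape this change produces, assuming an exo node is serving the API on localhost:52415 as in the examples further down:

```python
# Illustrative check of the new OpenAI-compatible /v1/models shape; assumes a
# local exo node on port 52415, as used elsewhere in this changeset.
import requests

resp = requests.get("http://localhost:52415/v1/models").json()
assert resp["object"] == "list"
for model in resp["data"]:
    print(model["id"], model.get("ready"))
```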

---

For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/metaspartan/exo?shareId=XXXX-XXXX-XXXX-XXXX).
2025-01-06 01:20:30 -07:00
Alex Cheema
d9a836f152 Merge pull request #588 from exo-explore/betterdl
better download
2025-01-05 02:35:04 +00:00
Alex Cheema
29244c6369 fix args for ensure_shard 2025-01-05 02:33:25 +00:00
Alex Cheema
8c191050a2 download status in parallel, support async ensure shard with using shard_downloader instead 2025-01-05 02:31:59 +00:00
Alex Cheema
7b1656140e Merge pull request #585 from pepebruari/main
Add --system-prompt to exo cli
2025-01-03 23:49:50 +00:00
pepebruari
fe50d4d34d Add --system-prompt to exo cli 2025-01-03 16:16:22 -05:00
Alex Cheema
03aa6cecf1 Merge pull request #584 from exo-explore/AlexCheema-patch-1
add trending badge to README.md
2024-12-31 17:51:10 +00:00
Alex Cheema
178cc4d961 add trending badge to README.md 2024-12-31 17:50:29 +00:00
Alex Cheema
a174c78004 Merge pull request #383 from ianpaul10/feat/manual-disc-follow-up
Support changing manual configuration while running
2024-12-28 11:57:25 +00:00
Ian Paul
b003292b89 formatting and fixing tests after rebasing 2024-12-28 12:31:15 +07:00
Ian Paul
1dfd058c23 rm unecessary lock 2024-12-28 12:13:34 +07:00
Ian Paul
2eadaa2c0d rm redundant cleanup task 2024-12-28 12:13:34 +07:00
Ian Paul
637446ffa9 rm redundant typing 2024-12-28 12:13:34 +07:00
Ian Paul
a31f9e6c20 fix test warnings 2024-12-28 12:13:34 +07:00
Ian Paul
18acb97b42 make popping from dict threadsafe 2024-12-28 12:11:51 +07:00
Ian Paul
b066c944f3 make all I/O ops in manual_discovery.py run inside a ThreadPoolExecutor 2024-12-28 12:11:51 +07:00
Ian Paul
0e34ce2169 patch after rebasing to main 2024-12-28 12:11:51 +07:00
Ian Paul
90de7eada9 changes after rebase 2024-12-28 12:11:51 +07:00
Ian Paul
8d24df2b4b fix test runtime warning 2024-12-28 12:11:50 +07:00
Ian Paul
e5eb3259a5 handle when a peer is removed from config, so the known_peers dict gets updated accordingly 2024-12-28 12:11:21 +07:00
Ian Paul
2e8227fccb handle intermediate state for when config is being updated 2024-12-28 12:11:21 +07:00
Ian Paul
98118babae allow update to manual discovery file
re-load manual discovery file for each runthrough of the peer network, allowing incremental updates to the peer file even when exo is running
2024-12-28 12:11:21 +07:00
Alex Cheema
496a3b49f5 Merge pull request #561 from VerisimilitudeX/patch-1
Improved clarity, fixed typos, added macOS/Linux examples, and enhanc…
2024-12-27 17:06:00 +00:00
Alex Cheema
aba1bed5ed Merge pull request #575 from exo-explore/fixtok
Revert "Merge pull request #573 from damho1104/feature/add-exaone-3.5…
2024-12-27 16:36:34 +00:00
Alex Cheema
e08522ee97 Revert "Merge pull request #573 from damho1104/feature/add-exaone-3.5-model"
This reverts commit 4eb6a6a74a, reversing
changes made to fdc3b5ac02.
2024-12-27 16:35:54 +00:00
Alex Cheema
4eb6a6a74a Merge pull request #573 from damho1104/feature/add-exaone-3.5-model
Add exaone-3.5-2.4b, exaone-3.5-7.8b
2024-12-27 12:36:09 +00:00
damho.lee
94a5e908b0 add exaone-3.5 LLM Model 2024-12-24 20:57:11 +09:00
Alex Cheema
fdc3b5ac02 Merge pull request #571 from exo-explore/function_calling
add chatgpt-api-compatible tools for function calling
2024-12-24 02:08:48 +00:00
Alex Cheema
185b1e375c fix names in dummy tokenizer 2024-12-24 02:08:20 +00:00
Alex Cheema
078b807654 fix names of qwen models 2024-12-24 02:06:13 +00:00
Alex Cheema
188ac445c9 function calling example with weather tool 2024-12-24 01:57:17 +00:00
Alex Cheema
456fbdd2b0 add chatgpt-api-compatible tools for function calling 2024-12-24 01:51:55 +00:00
Alex Cheema
41df9ce1d7 Merge pull request #570 from exo-explore/moreqwen
add qwen-2.5-1.5b, qwen-2.5-3b, qwen-2.5-32b
2024-12-24 01:51:26 +00:00
Alex Cheema
c609c05e40 add qwen-2.5-1.5b, qwen-2.5-3b, qwen-2.5-32b 2024-12-24 01:50:12 +00:00
Alex Cheema
ba8c514974 Merge pull request #569 from deftdawg/env_bash
Use `#!/usr/bin/env bash` for better portability
2024-12-22 23:25:38 +00:00
DeftDawg
cde912deef - Use #!/usr/bin/env bash instead of #!/bin/bash for better portability 2024-12-22 01:14:54 -05:00
Piyush Acharya
154e0f58e4 Implement suggestiond 2024-12-21 19:40:53 -08:00
Piyush Acharya
6c82365ee2 Improved clarity, fixed typos, added macOS/Linux examples, and enhanced installation/debugging instructions 2024-12-17 18:02:34 -08:00
13 changed files with 315 additions and 64 deletions

View File

@@ -18,6 +18,8 @@ exo: Run your own AI cluster at home with everyday devices. Maintained by [exo l
[![Tests](https://dl.circleci.com/status-badge/img/circleci/TrkofJDoGzdQAeL6yVHKsg/4i5hJuafuwZYZQxbRAWS71/tree/main.svg?style=svg)](https://dl.circleci.com/status-badge/redirect/circleci/TrkofJDoGzdQAeL6yVHKsg/4i5hJuafuwZYZQxbRAWS71/tree/main)
[![License: GPL v3](https://img.shields.io/badge/License-GPLv3-blue.svg)](https://www.gnu.org/licenses/gpl-3.0)
<a href="https://trendshift.io/repositories/11849" target="_blank"><img src="https://trendshift.io/api/badge/repositories/11849" alt="exo-explore%2Fexo | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
</div>
---
@@ -38,7 +40,7 @@ We also welcome contributions from the community. We have a list of bounties in
### Wide Model Support
exo supports different models including LLaMA ([MLX](exo/inference/mlx/models/llama.py) and [tinygrad](exo/inference/tinygrad/models/llama.py)), Mistral, LlaVA, Qwen and Deepseek.
exo supports different models including LLaMA ([MLX](exo/inference/mlx/models/llama.py) and [tinygrad](exo/inference/tinygrad/models/llama.py)), Mistral, LlaVA, Qwen, and Deepseek.
### Dynamic Model Partitioning
@@ -100,13 +102,13 @@ source install.sh
- There are a number of things users have empirically found to improve performance on Apple Silicon Macs:
1. Upgrade to the latest version of MacOS 15.
1. Upgrade to the latest version of macOS Sequoia.
2. Run `./configure_mlx.sh`. This runs commands to optimize GPU memory allocation on Apple Silicon Macs.
## Documentation
### Example Usage on Multiple MacOS Devices
### Example Usage on Multiple macOS Devices
#### Device 1:
@@ -177,9 +179,9 @@ curl http://localhost:52415/v1/chat/completions \
}'
```
### Example Usage on Multiple Heterogenous Devices (MacOS + Linux)
### Example Usage on Multiple Heterogenous Devices (macOS + Linux)
#### Device 1 (MacOS):
#### Device 1 (macOS):
```sh
exo
@@ -244,7 +246,7 @@ python3 format.py ./exo
## Known Issues
- On some versions of MacOS/Python, certificates are not installed properly which can lead to SSL errors (e.g. SSL error with huggingface.co). To fix this, run the Install Certificates command, usually:
- On certain versions of Python on macOS, certificates may not be installed correctly, potentially causing SSL errors (e.g., when accessing huggingface.co). To resolve this, run the `Install Certificates` command, typically as follows:
```sh
/Applications/Python 3.x/Install Certificates.command

View File

@@ -1,4 +1,4 @@
#!/bin/bash
#!/usr/bin/env bash
# Get the total memory in MB
TOTAL_MEM_MB=$(($(sysctl -n hw.memsize) / 1024 / 1024))

View File

@@ -0,0 +1,111 @@
import json
import re
import requests
def get_current_weather(location: str, unit: str = "celsius"):
"""Mock weather data function"""
# Hardcoded response for demo purposes
return {
"location": location,
"temperature": 22 if unit == "celsius" else 72,
"unit": unit,
"forecast": "Sunny with light clouds"
}
def try_parse_tool_calls(content: str):
"""Try parse the tool calls."""
tool_calls = []
offset = 0
for i, m in enumerate(re.finditer(r"<tool_call>\n(.+)?\n</tool_call>", content)):
if i == 0:
offset = m.start()
try:
func = json.loads(m.group(1))
tool_calls.append({"type": "function", "function": func})
if isinstance(func["arguments"], str):
func["arguments"] = json.loads(func["arguments"])
except json.JSONDecodeError as e:
print(f"Failed to parse tool calls: the content is {m.group(1)} and {e}")
pass
if tool_calls:
if offset > 0 and content[:offset].strip():
c = content[:offset]
else:
c = ""
return {"role": "assistant", "content": c, "tool_calls": tool_calls}
return {"role": "assistant", "content": re.sub(r"<\|im_end\|>$", "", content)}
def chat_completion(messages):
"""Send chat completion request to local server"""
response = requests.post(
"http://localhost:52415/v1/chat/completions",
json={
"model": "qwen-2.5-1.5b",
"messages": messages,
"tools": [{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"]
}
},
"required": ["location"]
}
}
}],
"tool_choice": "auto"
}
)
return response.json()
def main():
# Initial conversation
messages = [{
"role": "user",
"content": "Hi there, what's the weather in Boston?"
}]
# Get initial response
response = chat_completion(messages)
print(f"First response: {response}")
assistant_message = try_parse_tool_calls(response["choices"][0]["message"]["content"])
messages.append(assistant_message)
# If there are tool calls, execute them and continue conversation
if "tool_calls" in assistant_message:
for tool_call in assistant_message["tool_calls"]:
if tool_call["function"]["name"] == "get_current_weather":
args = tool_call["function"]["arguments"]
weather_data = get_current_weather(**args)
# Add tool response to messages
messages.append({
"role": "tool",
"content": json.dumps(weather_data),
"name": tool_call["function"]["name"]
})
# Get final response with weather data
response = chat_completion(messages)
print(f"Final response: {response}")
messages.append({
"role": "assistant",
"content": response["choices"][0]["message"]["content"]
})
# Print full conversation
for msg in messages:
print(f"\n{msg['role'].upper()}: {msg['content']}")
if __name__ == "__main__":
main()

View File

@@ -5,7 +5,7 @@ import json
import os
from pathlib import Path
from transformers import AutoTokenizer
from typing import List, Literal, Union, Dict
from typing import List, Literal, Union, Dict, Optional
from aiohttp import web
import aiohttp_cors
import traceback
@@ -23,23 +23,28 @@ from exo.download.hf.hf_helpers import get_hf_home, get_repo_root
from exo.apputil import create_animation_mp4
class Message:
def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]], tools: Optional[List[Dict]] = None):
self.role = role
self.content = content
self.tools = tools
def to_dict(self):
return {"role": self.role, "content": self.content}
data = {"role": self.role, "content": self.content}
if self.tools:
data["tools"] = self.tools
return data
class ChatCompletionRequest:
def __init__(self, model: str, messages: List[Message], temperature: float):
def __init__(self, model: str, messages: List[Message], temperature: float, tools: Optional[List[Dict]] = None):
self.model = model
self.messages = messages
self.temperature = temperature
self.tools = tools
def to_dict(self):
return {"model": self.model, "messages": [message.to_dict() for message in self.messages], "temperature": self.temperature}
return {"model": self.model, "messages": [message.to_dict() for message in self.messages], "temperature": self.temperature, "tools": self.tools}
def generate_completion(
@@ -119,20 +124,24 @@ def remap_messages(messages: List[Message]) -> List[Message]:
return remapped_messages
def build_prompt(tokenizer, _messages: List[Message]):
def build_prompt(tokenizer, _messages: List[Message], tools: Optional[List[Dict]] = None):
messages = remap_messages(_messages)
prompt = tokenizer.apply_chat_template([m.to_dict() for m in messages], tokenize=False, add_generation_prompt=True)
for message in messages:
if not isinstance(message.content, list):
continue
chat_template_args = {
"conversation": [m.to_dict() for m in messages],
"tokenize": False,
"add_generation_prompt": True
}
if tools: chat_template_args["tools"] = tools
prompt = tokenizer.apply_chat_template(**chat_template_args)
print(f"!!! Prompt: {prompt}")
return prompt
def parse_message(data: dict):
if "role" not in data or "content" not in data:
raise ValueError(f"Invalid message: {data}. Must have 'role' and 'content'")
return Message(data["role"], data["content"])
return Message(data["role"], data["content"], data.get("tools"))
def parse_chat_request(data: dict, default_model: str):
@@ -140,6 +149,7 @@ def parse_chat_request(data: dict, default_model: str):
data.get("model", default_model),
[parse_message(msg) for msg in data["messages"]],
data.get("temperature", 0.0),
data.get("tools", None),
)
@@ -150,7 +160,7 @@ class PromptSession:
self.prompt = prompt
class ChatGPTAPI:
def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None, default_model: Optional[str] = None):
def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None, default_model: Optional[str] = None, system_prompt: Optional[str] = None):
self.node = node
self.inference_engine_classname = inference_engine_classname
self.response_timeout = response_timeout
@@ -160,6 +170,7 @@ class ChatGPTAPI:
self.prev_token_lens: Dict[str, int] = {}
self.stream_tasks: Dict[str, asyncio.Task] = {}
self.default_model = default_model or "llama-3.2-1b"
self.system_prompt = system_prompt
cors = aiohttp_cors.setup(self.app)
cors_options = aiohttp_cors.ResourceOptions(
@@ -234,7 +245,7 @@ class ChatGPTAPI:
)
await response.prepare(request)
for model_name, pretty in pretty_name.items():
async def process_model(model_name, pretty):
if model_name in model_cards:
model_info = model_cards[model_name]
@@ -262,6 +273,12 @@ class ChatGPTAPI:
await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
# Process all models in parallel
await asyncio.gather(*[
process_model(model_name, pretty)
for model_name, pretty in pretty_name.items()
])
await response.write(b"data: [DONE]\n\n")
return response
@@ -274,7 +291,8 @@ class ChatGPTAPI:
)
async def handle_get_models(self, request):
return web.json_response([{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()])
models_list = [{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()]
return web.json_response({"object": "list", "data": models_list})
async def handle_post_chat_token_encode(self, request):
data = await request.json()
@@ -287,7 +305,7 @@ class ChatGPTAPI:
shard = build_base_shard(model, self.inference_engine_classname)
messages = [parse_message(msg) for msg in data.get("messages", [])]
tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
prompt = build_prompt(tokenizer, messages)
prompt = build_prompt(tokenizer, messages, data.get("tools", None))
tokens = tokenizer.encode(prompt)
return web.json_response({
"length": len(prompt),
@@ -326,7 +344,11 @@ class ChatGPTAPI:
tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")
prompt = build_prompt(tokenizer, chat_request.messages)
# Add system prompt if set
if self.system_prompt and not any(msg.role == "system" for msg in chat_request.messages):
chat_request.messages.insert(0, Message("system", self.system_prompt))
prompt = build_prompt(tokenizer, chat_request.messages, chat_request.tools)
request_id = str(uuid.uuid4())
if self.on_chat_completion_request:
try:
@@ -547,7 +569,7 @@ class ChatGPTAPI:
if model_name not in model_cards: return web.json_response({"error": f"Invalid model: {model_name}. Supported models: {list(model_cards.keys())}"}, status=400)
shard = build_base_shard(model_name, self.inference_engine_classname)
if not shard: return web.json_response({"error": f"Could not build shard for model {model_name}"}, status=400)
asyncio.create_task(self.node.inference_engine.ensure_shard(shard))
asyncio.create_task(self.node.inference_engine.shard_downloader.ensure_shard(shard, self.inference_engine_classname))
return web.json_response({
"status": "success",

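A minimal sketch of the system-prompt injection added in the diff above; the helper name is illustrative and not part of the change, but the condition mirrors the new code: a configured system prompt is only prepended when the request carries no system message.

```python
# Illustrative sketch of the --system-prompt behavior, not the exact implementation.
def apply_system_prompt(messages, system_prompt):
    # Only inject when the server has a system prompt and the request has none.
    if system_prompt and not any(m["role"] == "system" for m in messages):
        messages.insert(0, {"role": "system", "content": system_prompt})
    return messages

msgs = [{"role": "user", "content": "Hi there, what's the weather in Boston?"}]
print(apply_system_prompt(msgs, "You are a concise assistant."))
```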
View File

@@ -14,7 +14,7 @@ class DummyTokenizer:
self.eos_token_id = 69
self.vocab_size = 1000
def apply_chat_template(self, messages, tokenize=True, add_generation_prompt=True):
def apply_chat_template(self, conversation, tokenize=True, add_generation_prompt=True, tools=None, **kwargs):
return "dummy_tokenized_prompt"
def encode(self, text):

View File

@@ -69,6 +69,7 @@ parser.add_argument("--default-temp", type=float, help="Default token sampling t
parser.add_argument("--tailscale-api-key", type=str, default=None, help="Tailscale API key")
parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name")
parser.add_argument("--node-id-filter", type=str, default=None, help="Comma separated list of allowed node IDs (only for UDP and Tailscale discovery)")
parser.add_argument("--system-prompt", type=str, default=None, help="System prompt for the ChatGPT API")
args = parser.parse_args()
print(f"Selected inference engine: {args.inference_engine}")
@@ -146,7 +147,8 @@ api = ChatGPTAPI(
inference_engine.__class__.__name__,
response_timeout=args.chatgpt_api_response_timeout,
on_chat_completion_request=lambda req_id, __, prompt: topology_viz.update_prompt(req_id, prompt) if topology_viz else None,
default_model=args.default_model
default_model=args.default_model,
system_prompt=args.system_prompt
)
node.on_token.register("update_topology_viz").on_next(
lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") else None

View File

@@ -92,14 +92,17 @@ model_cards = {
"llava-1.5-7b-hf": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "llava-hf/llava-1.5-7b-hf", }, },
### qwen
"qwen-2.5-0.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-0.5B-Instruct-4bit", }, },
"qwen-2.5-1.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-1.5B-Instruct-4bit", }, },
"qwen-2.5-coder-1.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-1.5B-Instruct-4bit", }, },
"qwen-2.5-3b": { "layers": 36, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-3B-Instruct-4bit", }, },
"qwen-2.5-coder-3b": { "layers": 36, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-3B-Instruct-4bit", }, },
"qwen-2.5-coder-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-7B-Instruct-4bit", }, },
"qwen-2.5-coder-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-14B-Instruct-4bit", }, },
"qwen-2.5-coder-32b": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-32B-Instruct-4bit", }, },
"qwen-2.5-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-7B-Instruct-4bit", }, },
"qwen-2.5-coder-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-7B-Instruct-4bit", }, },
"qwen-2.5-math-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Math-7B-Instruct-4bit", }, },
"qwen-2.5-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-14B-Instruct-4bit", }, },
"qwen-2.5-coder-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-14B-Instruct-4bit", }, },
"qwen-2.5-32b": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-32B-Instruct-4bit", }, },
"qwen-2.5-coder-32b": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-32B-Instruct-4bit", }, },
"qwen-2.5-72b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-72B-Instruct-4bit", }, },
"qwen-2.5-math-72b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Math-72B-Instruct-4bit", }, },
### nemotron
@@ -133,14 +136,17 @@ pretty_name = {
"deepseek-coder-v2-lite": "Deepseek Coder V2 Lite",
"deepseek-coder-v2.5": "Deepseek Coder V2.5",
"llava-1.5-7b-hf": "LLaVa 1.5 7B (Vision Model)",
"qwen-2.5-1.5b": "Qwen 2.5 1.5B",
"qwen-2.5-coder-1.5b": "Qwen 2.5 Coder 1.5B",
"qwen-2.5-3b": "Qwen 2.5 3B",
"qwen-2.5-coder-3b": "Qwen 2.5 Coder 3B",
"qwen-2.5-coder-7b": "Qwen 2.5 Coder 7B",
"qwen-2.5-coder-14b": "Qwen 2.5 Coder 14B",
"qwen-2.5-coder-32b": "Qwen 2.5 Coder 32B",
"qwen-2.5-7b": "Qwen 2.5 7B",
"qwen-2.5-coder-7b": "Qwen 2.5 Coder 7B",
"qwen-2.5-math-7b": "Qwen 2.5 7B (Math)",
"qwen-2.5-14b": "Qwen 2.5 14B",
"qwen-2.5-coder-14b": "Qwen 2.5 Coder 14B",
"qwen-2.5-32b": "Qwen 2.5 32B",
"qwen-2.5-coder-32b": "Qwen 2.5 Coder 32B",
"qwen-2.5-72b": "Qwen 2.5 72B",
"qwen-2.5-math-72b": "Qwen 2.5 72B (Math)",
"llama-3-8b": "Llama 3 8B",

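A minimal sketch of requesting one of the newly listed Qwen models through the chat completions API, assuming a local node on the default port used elsewhere in this changeset:

```python
# Illustrative request using one of the newly added model ids; the endpoint,
# port, and response shape follow the function-calling example in this changeset.
import requests

resp = requests.post(
    "http://localhost:52415/v1/chat/completions",
    json={
        "model": "qwen-2.5-3b",
        "messages": [{"role": "user", "content": "Say hello in one sentence."}],
        "temperature": 0.0,
    },
)
print(resp.json()["choices"][0]["message"]["content"])
```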
View File

@@ -1,7 +1,9 @@
import os
import asyncio
from exo.networking.discovery import Discovery
from typing import Dict, List, Callable
from typing import Dict, List, Callable, Optional
from concurrent.futures import ThreadPoolExecutor
from exo.networking.discovery import Discovery
from exo.topology.device_capabilities import DeviceCapabilities
from exo.networking.manual.network_topology_config import NetworkTopology, PeerConfig
from exo.helpers import DEBUG_DISCOVERY
@@ -13,28 +15,25 @@ class ManualDiscovery(Discovery):
self,
network_config_path: str,
node_id: str,
create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle],
create_peer_handle: Callable[[str, str, str, DeviceCapabilities], PeerHandle],
):
self.topology = NetworkTopology.from_path(network_config_path)
self.network_config_path = network_config_path
self.node_id = node_id
self.create_peer_handle = create_peer_handle
if node_id not in self.topology.peers:
raise ValueError(
f"Node ID {node_id} not found in network config file {network_config_path}. Please run with `node_id` set to one of the keys in the config file: {[k for k, _ in self.topology.peers]}"
)
self.listen_task = None
self.known_peers: Dict[str, PeerHandle] = {}
self.peers_in_network: Dict[str, PeerConfig] = self.topology.peers
self.peers_in_network.pop(node_id)
self._cached_peers: Dict[str, PeerConfig] = {}
self._last_modified_time: Optional[float] = None
self._file_executor = ThreadPoolExecutor(max_workers=1)
async def start(self) -> None:
self.listen_task = asyncio.create_task(self.task_find_peers_from_config())
async def stop(self) -> None:
if self.listen_task:
self.listen_task.cancel()
if self.listen_task: self.listen_task.cancel()
self._file_executor.shutdown(wait=True)
async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
if wait_for_peers > 0:
@@ -47,7 +46,9 @@ class ManualDiscovery(Discovery):
async def task_find_peers_from_config(self):
if DEBUG_DISCOVERY >= 2: print("Starting task to find peers from config...")
while True:
for peer_id, peer_config in self.peers_in_network.items():
peers_from_config = await self._get_peers()
new_known_peers = {}
for peer_id, peer_config in peers_from_config.items():
try:
if DEBUG_DISCOVERY >= 2: print(f"Checking peer {peer_id=} at {peer_config.address}:{peer_config.port}")
peer = self.known_peers.get(peer_id)
@@ -57,15 +58,44 @@ class ManualDiscovery(Discovery):
is_healthy = await peer.health_check()
if is_healthy:
if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is healthy.")
self.known_peers[peer_id] = peer
else:
if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy.")
try:
del self.known_peers[peer_id]
except KeyError:
pass
new_known_peers[peer_id] = peer
elif DEBUG_DISCOVERY >= 2:
print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy. Removing.")
except Exception as e:
if DEBUG_DISCOVERY >= 2: print(f"Exception occured when attempting to add {peer_id=}: {e}")
self.known_peers = new_known_peers
await asyncio.sleep(1.0)
if DEBUG_DISCOVERY >= 2: print(f"Current known peers: {[peer.id() for peer in self.known_peers.values()]}")
async def _get_peers(self):
try:
loop = asyncio.get_running_loop()
current_mtime = await loop.run_in_executor(self._file_executor, os.path.getmtime, self.network_config_path)
if (self._cached_peers is not None and self._last_modified_time is not None and current_mtime <= self._last_modified_time):
return self._cached_peers
topology = await loop.run_in_executor(self._file_executor, NetworkTopology.from_path, self.network_config_path)
if self.node_id not in topology.peers:
raise ValueError(
f"Node ID {self.node_id} not found in network config file "
f"{self.network_config_path}. Please run with `node_id` set to "
f"one of the keys in the config file: {[k for k, _ in topology.peers]}"
)
peers_in_network = topology.peers
peers_in_network.pop(self.node_id)
self._cached_peers = peers_in_network
self._last_modified_time = current_mtime
return peers_in_network
except Exception as e:
if DEBUG_DISCOVERY >= 2:
print(f"Error when loading network config file from {self.network_config_path}. "
f"Please update the config file in order to successfully discover peers. "
f"Exception: {e}")
return self._cached_peers

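For reference, a hedged sketch of the kind of peers config file this reload logic watches. The field names come from the `test_dynamic_config_update` test in the diff that follows; the `node1` entry, its port, and the file name are assumptions. Because `_get_peers` re-reads the file when its mtime changes, edits like this are picked up while exo is running.

```python
# Hypothetical manual-discovery config; structure follows the test below.
import json

config = {
    "peers": {
        "node1": {
            "address": "localhost",
            "port": 50051,  # assumed port; not specified in this changeset
            "device_capabilities": {
                "model": "Unknown Model",
                "chip": "Unknown Chip",
                "memory": 0,
                "flops": {"fp32": 0, "fp16": 0, "int8": 0},
            },
        },
        "node3": {
            "address": "localhost",
            "port": 50053,
            "device_capabilities": {
                "model": "Unknown Model",
                "chip": "Unknown Chip",
                "memory": 0,
                "flops": {"fp32": 0, "fp16": 0, "int8": 0},
            },
        },
    }
}

with open("manual_network_config.json", "w") as f:  # hypothetical path
    json.dump(config, f, indent=2)
```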
View File

@@ -29,4 +29,4 @@
}
}
}
}
}

View File

@@ -1,3 +1,4 @@
import json
import asyncio
import unittest
from unittest import mock
@@ -14,8 +15,12 @@ class TestSingleNodeManualDiscovery(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self):
self.peer1 = mock.AsyncMock()
self.peer1.connect = mock.AsyncMock()
self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1)
_ = self.discovery1.start()
self.discovery1 = ManualDiscovery(
root_path,
"node1",
create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1,
)
await self.discovery1.start()
async def asyncTearDown(self):
await self.discovery1.stop()
@@ -33,8 +38,16 @@ class TestManualDiscovery(unittest.IsolatedAsyncioTestCase):
self.peer2 = mock.AsyncMock()
self.peer1.connect = mock.AsyncMock()
self.peer2.connect = mock.AsyncMock()
self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1)
self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer2)
self.discovery1 = ManualDiscovery(
root_path,
"node1",
create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1,
)
self.discovery2 = ManualDiscovery(
root_path,
"node2",
create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer2,
)
await self.discovery1.start()
await self.discovery2.start()
@@ -63,8 +76,16 @@ class TestManualDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase):
self.server2 = GRPCServer(self.node2, config.peers["node2"].address, config.peers["node2"].port)
await self.server1.start()
await self.server2.start()
self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities))
self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities))
self.discovery1 = ManualDiscovery(
root_path,
"node1",
create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities),
)
self.discovery2 = ManualDiscovery(
root_path,
"node2",
create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities),
)
await self.discovery1.start()
await self.discovery2.start()
@@ -98,6 +119,63 @@ class TestManualDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase):
self.assertFalse(await peers1[0].is_connected())
self.assertFalse(await peers2[0].is_connected())
async def test_dynamic_config_update(self):
initial_peers = await self.discovery1.discover_peers(wait_for_peers=1)
self.assertEqual(len(initial_peers), 1)
# Save original config for cleanup
with open(root_path, "r") as f:
original_config = json.load(f)
try:
updated_config = {
"peers": {
**original_config["peers"],
"node3": {
"address": "localhost",
"port": 50053,
"device_capabilities": {
"model": "Unknown Model",
"chip": "Unknown Chip",
"memory": 0,
"flops": {"fp32": 0, "fp16": 0, "int8": 0},
},
},
}
}
with open(root_path, "w") as f:
json.dump(updated_config, f, indent=2)
node3 = mock.AsyncMock(spec=Node)
server3 = GRPCServer(node3, "localhost", 50053)
await server3.start()
try:
# Wait for the config to be reloaded
await asyncio.sleep(1.5)
updated_peers = await self.discovery1.discover_peers(wait_for_peers=2)
self.assertEqual(len(updated_peers), 2)
for peer in updated_peers:
await peer.connect()
self.assertTrue(await peer.is_connected())
finally:
await server3.stop()
finally:
# Restore the original config file
with open(root_path, "w") as f:
json.dump(original_config, f, indent=2)
# Wait for the config to be reloaded again
await asyncio.sleep(1.5)
updated_peers = await self.discovery1.discover_peers(wait_for_peers=1)
self.assertEqual(len(updated_peers), 1)
if __name__ == "__main__":
asyncio.run(unittest.main())

View File

@@ -1,4 +1,4 @@
#!/bin/bash
#!/usr/bin/env bash
if command -v python3.12 &>/dev/null; then
echo "Python 3.12 is installed, proceeding with python3.12..."

View File

@@ -1,4 +1,4 @@
#!/bin/bash
#!/usr/bin/env bash
source ./install.sh
pushd exo/networking/grpc
python3 -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. node_service.proto

View File

@@ -1,4 +1,4 @@
#!/bin/bash
#!/usr/bin/env bash
echo "Starting node 1"
DEBUG_DISCOVERY=7 DEBUG=7 python3 main.py --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 52415 --chatgpt-api-response-timeout 900 > output1.log 2>&1 &