Mirror of https://github.com/exo-explore/exo.git (synced 2026-01-04 12:10:00 -05:00)

Compare commits: v0.0.4-alp...v1.0 (39 commits)
SHA1: 3c9efe103d, 627bfcae7c, d9a836f152, 29244c6369, 8c191050a2, 7b1656140e, fe50d4d34d, 03aa6cecf1, 178cc4d961, a174c78004, b003292b89, 1dfd058c23, 2eadaa2c0d, 637446ffa9, a31f9e6c20, 18acb97b42, b066c944f3, 0e34ce2169, 90de7eada9, 8d24df2b4b, e5eb3259a5, 2e8227fccb, 98118babae, 496a3b49f5, aba1bed5ed, e08522ee97, 4eb6a6a74a, 94a5e908b0, fdc3b5ac02, 185b1e375c, 078b807654, 188ac445c9, 456fbdd2b0, 41df9ce1d7, c609c05e40, ba8c514974, cde912deef, 154e0f58e4, 6c82365ee2
README.md (14 lines changed)
@@ -18,6 +18,8 @@ exo: Run your own AI cluster at home with everyday devices. Maintained by [exo l
[](https://dl.circleci.com/status-badge/redirect/circleci/TrkofJDoGzdQAeL6yVHKsg/4i5hJuafuwZYZQxbRAWS71/tree/main)
[](https://www.gnu.org/licenses/gpl-3.0)

<a href="https://trendshift.io/repositories/11849" target="_blank"><img src="https://trendshift.io/api/badge/repositories/11849" alt="exo-explore%2Fexo | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>

</div>

---
@@ -38,7 +40,7 @@ We also welcome contributions from the community. We have a list of bounties in

### Wide Model Support

exo supports different models including LLaMA ([MLX](exo/inference/mlx/models/llama.py) and [tinygrad](exo/inference/tinygrad/models/llama.py)), Mistral, LlaVA, Qwen and Deepseek.
exo supports different models including LLaMA ([MLX](exo/inference/mlx/models/llama.py) and [tinygrad](exo/inference/tinygrad/models/llama.py)), Mistral, LlaVA, Qwen, and Deepseek.

### Dynamic Model Partitioning
@@ -100,13 +102,13 @@ source install.sh

- There are a number of things users have empirically found to improve performance on Apple Silicon Macs:

1. Upgrade to the latest version of MacOS 15.
1. Upgrade to the latest version of macOS Sequoia.
2. Run `./configure_mlx.sh`. This runs commands to optimize GPU memory allocation on Apple Silicon Macs.

## Documentation

### Example Usage on Multiple MacOS Devices
### Example Usage on Multiple macOS Devices

#### Device 1:

@@ -177,9 +179,9 @@ curl http://localhost:52415/v1/chat/completions \
}'
```

### Example Usage on Multiple Heterogenous Devices (MacOS + Linux)
### Example Usage on Multiple Heterogenous Devices (macOS + Linux)

#### Device 1 (MacOS):
#### Device 1 (macOS):

```sh
exo
@@ -244,7 +246,7 @@ python3 format.py ./exo

## Known Issues

- On some versions of MacOS/Python, certificates are not installed properly which can lead to SSL errors (e.g. SSL error with huggingface.co). To fix this, run the Install Certificates command, usually:
- On certain versions of Python on macOS, certificates may not be installed correctly, potentially causing SSL errors (e.g., when accessing huggingface.co). To resolve this, run the `Install Certificates` command, typically as follows:

```sh
/Applications/Python 3.x/Install Certificates.command
@@ -1,4 +1,4 @@
#!/bin/bash
#!/usr/bin/env bash

# Get the total memory in MB
TOTAL_MEM_MB=$(($(sysctl -n hw.memsize) / 1024 / 1024))
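The hunk above shows only a shebang fix and the total-memory calculation; the rest of the script is not part of this diff. If this is the GPU-memory tuning script the README points at (`./configure_mlx.sh`), a rough, hypothetical sketch of how that value can be used to relax GPU memory limits on Apple Silicon would look like the following; the `iogpu` sysctls exist on recent macOS, but the percentages are illustrative, not the script's actual settings:

```sh
#!/usr/bin/env bash
# Hypothetical sketch: derive GPU wired-memory limits from total RAM.
TOTAL_MEM_MB=$(($(sysctl -n hw.memsize) / 1024 / 1024))
WIRED_LIMIT_MB=$((TOTAL_MEM_MB * 80 / 100))  # let the GPU wire up to ~80% of RAM
WIRED_LWM_MB=$((TOTAL_MEM_MB * 70 / 100))    # low-water mark below the hard limit
sudo sysctl -w iogpu.wired_limit_mb="$WIRED_LIMIT_MB"
sudo sysctl -w iogpu.wired_lwm_mb="$WIRED_LWM_MB"
```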
examples/function_calling.py (new file, 111 lines)
@@ -0,0 +1,111 @@
import json
import re
import requests

def get_current_weather(location: str, unit: str = "celsius"):
  """Mock weather data function"""
  # Hardcoded response for demo purposes
  return {
    "location": location,
    "temperature": 22 if unit == "celsius" else 72,
    "unit": unit,
    "forecast": "Sunny with light clouds"
  }

def try_parse_tool_calls(content: str):
  """Try parse the tool calls."""
  tool_calls = []
  offset = 0
  for i, m in enumerate(re.finditer(r"<tool_call>\n(.+)?\n</tool_call>", content)):
    if i == 0:
      offset = m.start()
    try:
      func = json.loads(m.group(1))
      tool_calls.append({"type": "function", "function": func})
      if isinstance(func["arguments"], str):
        func["arguments"] = json.loads(func["arguments"])
    except json.JSONDecodeError as e:
      print(f"Failed to parse tool calls: the content is {m.group(1)} and {e}")
      pass
  if tool_calls:
    if offset > 0 and content[:offset].strip():
      c = content[:offset]
    else:
      c = ""
    return {"role": "assistant", "content": c, "tool_calls": tool_calls}
  return {"role": "assistant", "content": re.sub(r"<\|im_end\|>$", "", content)}

def chat_completion(messages):
  """Send chat completion request to local server"""
  response = requests.post(
    "http://localhost:52415/v1/chat/completions",
    json={
      "model": "qwen-2.5-1.5b",
      "messages": messages,
      "tools": [{
        "type": "function",
        "function": {
          "name": "get_current_weather",
          "description": "Get the current weather in a given location",
          "parameters": {
            "type": "object",
            "properties": {
              "location": {
                "type": "string",
                "description": "The city and state, e.g. San Francisco, CA"
              },
              "unit": {
                "type": "string",
                "enum": ["celsius", "fahrenheit"]
              }
            },
            "required": ["location"]
          }
        }
      }],
      "tool_choice": "auto"
    }
  )
  return response.json()

def main():
  # Initial conversation
  messages = [{
    "role": "user",
    "content": "Hi there, what's the weather in Boston?"
  }]

  # Get initial response
  response = chat_completion(messages)
  print(f"First response: {response}")
  assistant_message = try_parse_tool_calls(response["choices"][0]["message"]["content"])
  messages.append(assistant_message)

  # If there are tool calls, execute them and continue conversation
  if "tool_calls" in assistant_message:
    for tool_call in assistant_message["tool_calls"]:
      if tool_call["function"]["name"] == "get_current_weather":
        args = tool_call["function"]["arguments"]
        weather_data = get_current_weather(**args)

        # Add tool response to messages
        messages.append({
          "role": "tool",
          "content": json.dumps(weather_data),
          "name": tool_call["function"]["name"]
        })

    # Get final response with weather data
    response = chat_completion(messages)
    print(f"Final response: {response}")
    messages.append({
      "role": "assistant",
      "content": response["choices"][0]["message"]["content"]
    })

  # Print full conversation
  for msg in messages:
    print(f"\n{msg['role'].upper()}: {msg['content']}")

if __name__ == "__main__":
  main()
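One way to try this example, assuming an exo node is already serving the ChatGPT-compatible API on localhost:52415 (the port hard-coded in the script above) and the command is run from the repository root:

```sh
# Terminal 1: start a node
exo

# Terminal 2: run the function-calling example against it
python3 examples/function_calling.py
```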
@@ -5,7 +5,7 @@ import json
import os
from pathlib import Path
from transformers import AutoTokenizer
from typing import List, Literal, Union, Dict
from typing import List, Literal, Union, Dict, Optional
from aiohttp import web
import aiohttp_cors
import traceback
@@ -23,23 +23,28 @@ from exo.download.hf.hf_helpers import get_hf_home, get_repo_root
from exo.apputil import create_animation_mp4

class Message:
  def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
  def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]], tools: Optional[List[Dict]] = None):
    self.role = role
    self.content = content
    self.tools = tools

  def to_dict(self):
    return {"role": self.role, "content": self.content}
    data = {"role": self.role, "content": self.content}
    if self.tools:
      data["tools"] = self.tools
    return data


class ChatCompletionRequest:
  def __init__(self, model: str, messages: List[Message], temperature: float):
  def __init__(self, model: str, messages: List[Message], temperature: float, tools: Optional[List[Dict]] = None):
    self.model = model
    self.messages = messages
    self.temperature = temperature
    self.tools = tools

  def to_dict(self):
    return {"model": self.model, "messages": [message.to_dict() for message in self.messages], "temperature": self.temperature}
    return {"model": self.model, "messages": [message.to_dict() for message in self.messages], "temperature": self.temperature, "tools": self.tools}


def generate_completion(
@@ -119,20 +124,24 @@ def remap_messages(messages: List[Message]) -> List[Message]:
  return remapped_messages


def build_prompt(tokenizer, _messages: List[Message]):
def build_prompt(tokenizer, _messages: List[Message], tools: Optional[List[Dict]] = None):
  messages = remap_messages(_messages)
  prompt = tokenizer.apply_chat_template([m.to_dict() for m in messages], tokenize=False, add_generation_prompt=True)
  for message in messages:
    if not isinstance(message.content, list):
      continue
  chat_template_args = {
    "conversation": [m.to_dict() for m in messages],
    "tokenize": False,
    "add_generation_prompt": True
  }
  if tools: chat_template_args["tools"] = tools

  prompt = tokenizer.apply_chat_template(**chat_template_args)
  print(f"!!! Prompt: {prompt}")
  return prompt

def parse_message(data: dict):
  if "role" not in data or "content" not in data:
    raise ValueError(f"Invalid message: {data}. Must have 'role' and 'content'")
  return Message(data["role"], data["content"])
  return Message(data["role"], data["content"], data.get("tools"))


def parse_chat_request(data: dict, default_model: str):

@@ -140,6 +149,7 @@ def parse_chat_request(data: dict, default_model: str):
    data.get("model", default_model),
    [parse_message(msg) for msg in data["messages"]],
    data.get("temperature", 0.0),
    data.get("tools", None),
  )
@@ -150,7 +160,7 @@ class PromptSession:
    self.prompt = prompt

class ChatGPTAPI:
  def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None, default_model: Optional[str] = None):
  def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None, default_model: Optional[str] = None, system_prompt: Optional[str] = None):
    self.node = node
    self.inference_engine_classname = inference_engine_classname
    self.response_timeout = response_timeout

@@ -160,6 +170,7 @@ class ChatGPTAPI:
    self.prev_token_lens: Dict[str, int] = {}
    self.stream_tasks: Dict[str, asyncio.Task] = {}
    self.default_model = default_model or "llama-3.2-1b"
    self.system_prompt = system_prompt

    cors = aiohttp_cors.setup(self.app)
    cors_options = aiohttp_cors.ResourceOptions(
@@ -234,7 +245,7 @@ class ChatGPTAPI:
    )
    await response.prepare(request)

    for model_name, pretty in pretty_name.items():
    async def process_model(model_name, pretty):
      if model_name in model_cards:
        model_info = model_cards[model_name]

@@ -262,6 +273,12 @@ class ChatGPTAPI:

        await response.write(f"data: {json.dumps(model_data)}\n\n".encode())

    # Process all models in parallel
    await asyncio.gather(*[
      process_model(model_name, pretty)
      for model_name, pretty in pretty_name.items()
    ])

    await response.write(b"data: [DONE]\n\n")
    return response
@@ -274,7 +291,8 @@ class ChatGPTAPI:
    )

  async def handle_get_models(self, request):
    return web.json_response([{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()])
    models_list = [{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()]
    return web.json_response({"object": "list", "data": models_list})

  async def handle_post_chat_token_encode(self, request):
    data = await request.json()
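The models listing now returns an OpenAI-style envelope (`{"object": "list", "data": [...]}`) instead of a bare array. The route registration is not part of this hunk, so the path below is an assumption based on the API's OpenAI-compatible layout:

```sh
# Assumed route; adjust if handle_get_models is mounted at a different path.
curl http://localhost:52415/v1/models
```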
@@ -287,7 +305,7 @@ class ChatGPTAPI:
    shard = build_base_shard(model, self.inference_engine_classname)
    messages = [parse_message(msg) for msg in data.get("messages", [])]
    tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
    prompt = build_prompt(tokenizer, messages)
    prompt = build_prompt(tokenizer, messages, data.get("tools", None))
    tokens = tokenizer.encode(prompt)
    return web.json_response({
      "length": len(prompt),
@@ -326,7 +344,11 @@ class ChatGPTAPI:
    tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
    if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")

    prompt = build_prompt(tokenizer, chat_request.messages)
    # Add system prompt if set
    if self.system_prompt and not any(msg.role == "system" for msg in chat_request.messages):
      chat_request.messages.insert(0, Message("system", self.system_prompt))

    prompt = build_prompt(tokenizer, chat_request.messages, chat_request.tools)
    request_id = str(uuid.uuid4())
    if self.on_chat_completion_request:
      try:
@@ -547,7 +569,7 @@ class ChatGPTAPI:
    if model_name not in model_cards: return web.json_response({"error": f"Invalid model: {model_name}. Supported models: {list(model_cards.keys())}"}, status=400)
    shard = build_base_shard(model_name, self.inference_engine_classname)
    if not shard: return web.json_response({"error": f"Could not build shard for model {model_name}"}, status=400)
    asyncio.create_task(self.node.inference_engine.ensure_shard(shard))
    asyncio.create_task(self.node.inference_engine.shard_downloader.ensure_shard(shard, self.inference_engine_classname))

    return web.json_response({
      "status": "success",
@@ -14,7 +14,7 @@ class DummyTokenizer:
    self.eos_token_id = 69
    self.vocab_size = 1000

  def apply_chat_template(self, messages, tokenize=True, add_generation_prompt=True):
  def apply_chat_template(self, conversation, tokenize=True, add_generation_prompt=True, tools=None, **kwargs):
    return "dummy_tokenized_prompt"

  def encode(self, text):
@@ -69,6 +69,7 @@ parser.add_argument("--default-temp", type=float, help="Default token sampling t
parser.add_argument("--tailscale-api-key", type=str, default=None, help="Tailscale API key")
parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name")
parser.add_argument("--node-id-filter", type=str, default=None, help="Comma separated list of allowed node IDs (only for UDP and Tailscale discovery)")
parser.add_argument("--system-prompt", type=str, default=None, help="System prompt for the ChatGPT API")
args = parser.parse_args()
print(f"Selected inference engine: {args.inference_engine}")
@@ -146,7 +147,8 @@ api = ChatGPTAPI(
  inference_engine.__class__.__name__,
  response_timeout=args.chatgpt_api_response_timeout,
  on_chat_completion_request=lambda req_id, __, prompt: topology_viz.update_prompt(req_id, prompt) if topology_viz else None,
  default_model=args.default_model
  default_model=args.default_model,
  system_prompt=args.system_prompt
)
node.on_token.register("update_topology_viz").on_next(
  lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") else None
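With the new argument wired through to `ChatGPTAPI`, a node can be started with a default system prompt that is prepended to any chat request lacking a system message. A minimal invocation, assuming the `exo` entrypoint forwards this flag (the prompt text is illustrative):

```sh
exo --system-prompt "You are a concise assistant."
```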
@@ -92,14 +92,17 @@ model_cards = {
  "llava-1.5-7b-hf": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "llava-hf/llava-1.5-7b-hf", }, },
  ### qwen
  "qwen-2.5-0.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-0.5B-Instruct-4bit", }, },
  "qwen-2.5-1.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-1.5B-Instruct-4bit", }, },
  "qwen-2.5-coder-1.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-1.5B-Instruct-4bit", }, },
  "qwen-2.5-3b": { "layers": 36, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-3B-Instruct-4bit", }, },
  "qwen-2.5-coder-3b": { "layers": 36, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-3B-Instruct-4bit", }, },
  "qwen-2.5-coder-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-7B-Instruct-4bit", }, },
  "qwen-2.5-coder-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-14B-Instruct-4bit", }, },
  "qwen-2.5-coder-32b": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-32B-Instruct-4bit", }, },
  "qwen-2.5-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-7B-Instruct-4bit", }, },
  "qwen-2.5-coder-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-7B-Instruct-4bit", }, },
  "qwen-2.5-math-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Math-7B-Instruct-4bit", }, },
  "qwen-2.5-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-14B-Instruct-4bit", }, },
  "qwen-2.5-coder-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-14B-Instruct-4bit", }, },
  "qwen-2.5-32b": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-32B-Instruct-4bit", }, },
  "qwen-2.5-coder-32b": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-32B-Instruct-4bit", }, },
  "qwen-2.5-72b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-72B-Instruct-4bit", }, },
  "qwen-2.5-math-72b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Math-72B-Instruct-4bit", }, },
  ### nemotron

@@ -133,14 +136,17 @@ pretty_name = {
  "deepseek-coder-v2-lite": "Deepseek Coder V2 Lite",
  "deepseek-coder-v2.5": "Deepseek Coder V2.5",
  "llava-1.5-7b-hf": "LLaVa 1.5 7B (Vision Model)",
  "qwen-2.5-1.5b": "Qwen 2.5 1.5B",
  "qwen-2.5-coder-1.5b": "Qwen 2.5 Coder 1.5B",
  "qwen-2.5-3b": "Qwen 2.5 3B",
  "qwen-2.5-coder-3b": "Qwen 2.5 Coder 3B",
  "qwen-2.5-coder-7b": "Qwen 2.5 Coder 7B",
  "qwen-2.5-coder-14b": "Qwen 2.5 Coder 14B",
  "qwen-2.5-coder-32b": "Qwen 2.5 Coder 32B",
  "qwen-2.5-7b": "Qwen 2.5 7B",
  "qwen-2.5-coder-7b": "Qwen 2.5 Coder 7B",
  "qwen-2.5-math-7b": "Qwen 2.5 7B (Math)",
  "qwen-2.5-14b": "Qwen 2.5 14B",
  "qwen-2.5-coder-14b": "Qwen 2.5 Coder 14B",
  "qwen-2.5-32b": "Qwen 2.5 32B",
  "qwen-2.5-coder-32b": "Qwen 2.5 Coder 32B",
  "qwen-2.5-72b": "Qwen 2.5 72B",
  "qwen-2.5-math-72b": "Qwen 2.5 72B (Math)",
  "llama-3-8b": "Llama 3 8B",
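Any of the newly listed Qwen 2.5 IDs can be requested through the chat completions endpoint shown earlier in the README diff; a minimal request, assuming a node is running locally on the default port:

```sh
curl http://localhost:52415/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "qwen-2.5-3b",
    "messages": [{"role": "user", "content": "Say hello in one sentence."}],
    "temperature": 0.7
  }'
```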
@@ -1,7 +1,9 @@
import os
import asyncio
from exo.networking.discovery import Discovery
from typing import Dict, List, Callable
from typing import Dict, List, Callable, Optional
from concurrent.futures import ThreadPoolExecutor

from exo.networking.discovery import Discovery
from exo.topology.device_capabilities import DeviceCapabilities
from exo.networking.manual.network_topology_config import NetworkTopology, PeerConfig
from exo.helpers import DEBUG_DISCOVERY
@@ -13,28 +15,25 @@ class ManualDiscovery(Discovery):
    self,
    network_config_path: str,
    node_id: str,
    create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle],
    create_peer_handle: Callable[[str, str, str, DeviceCapabilities], PeerHandle],
  ):
    self.topology = NetworkTopology.from_path(network_config_path)
    self.network_config_path = network_config_path
    self.node_id = node_id
    self.create_peer_handle = create_peer_handle

    if node_id not in self.topology.peers:
      raise ValueError(
        f"Node ID {node_id} not found in network config file {network_config_path}. Please run with `node_id` set to one of the keys in the config file: {[k for k, _ in self.topology.peers]}"
      )

    self.listen_task = None

    self.known_peers: Dict[str, PeerHandle] = {}
    self.peers_in_network: Dict[str, PeerConfig] = self.topology.peers
    self.peers_in_network.pop(node_id)

    self._cached_peers: Dict[str, PeerConfig] = {}
    self._last_modified_time: Optional[float] = None
    self._file_executor = ThreadPoolExecutor(max_workers=1)

  async def start(self) -> None:
    self.listen_task = asyncio.create_task(self.task_find_peers_from_config())

  async def stop(self) -> None:
    if self.listen_task:
      self.listen_task.cancel()
    if self.listen_task: self.listen_task.cancel()
    self._file_executor.shutdown(wait=True)

  async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
    if wait_for_peers > 0:
@@ -47,7 +46,9 @@ class ManualDiscovery(Discovery):
  async def task_find_peers_from_config(self):
    if DEBUG_DISCOVERY >= 2: print("Starting task to find peers from config...")
    while True:
      for peer_id, peer_config in self.peers_in_network.items():
      peers_from_config = await self._get_peers()
      new_known_peers = {}
      for peer_id, peer_config in peers_from_config.items():
        try:
          if DEBUG_DISCOVERY >= 2: print(f"Checking peer {peer_id=} at {peer_config.address}:{peer_config.port}")
          peer = self.known_peers.get(peer_id)
@@ -57,15 +58,44 @@ class ManualDiscovery(Discovery):
          is_healthy = await peer.health_check()
          if is_healthy:
            if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is healthy.")
            self.known_peers[peer_id] = peer
          else:
            if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy.")
            try:
              del self.known_peers[peer_id]
            except KeyError:
              pass
            new_known_peers[peer_id] = peer
          elif DEBUG_DISCOVERY >= 2:
            print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy. Removing.")
        except Exception as e:
          if DEBUG_DISCOVERY >= 2: print(f"Exception occured when attempting to add {peer_id=}: {e}")
      self.known_peers = new_known_peers
      await asyncio.sleep(1.0)

      if DEBUG_DISCOVERY >= 2: print(f"Current known peers: {[peer.id() for peer in self.known_peers.values()]}")

  async def _get_peers(self):
    try:
      loop = asyncio.get_running_loop()
      current_mtime = await loop.run_in_executor(self._file_executor, os.path.getmtime, self.network_config_path)

      if (self._cached_peers is not None and self._last_modified_time is not None and current_mtime <= self._last_modified_time):
        return self._cached_peers

      topology = await loop.run_in_executor(self._file_executor, NetworkTopology.from_path, self.network_config_path)

      if self.node_id not in topology.peers:
        raise ValueError(
          f"Node ID {self.node_id} not found in network config file "
          f"{self.network_config_path}. Please run with `node_id` set to "
          f"one of the keys in the config file: {[k for k, _ in topology.peers]}"
        )

      peers_in_network = topology.peers
      peers_in_network.pop(self.node_id)

      self._cached_peers = peers_in_network
      self._last_modified_time = current_mtime

      return peers_in_network

    except Exception as e:
      if DEBUG_DISCOVERY >= 2:
        print(f"Error when loading network config file from {self.network_config_path}. "
              f"Please update the config file in order to successfully discover peers. "
              f"Exception: {e}")
      return self._cached_peers
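For reference, the config file that `_get_peers` reloads maps peer IDs to an address, port, and device_capabilities block; the field names follow the test further down, while the file name, node IDs, and values here are illustrative. Written as a shell heredoc so it can be pasted directly:

```sh
# Illustrative manual-discovery config (field names taken from the test below).
cat > topology.json <<'EOF'
{
  "peers": {
    "node1": {
      "address": "localhost",
      "port": 50051,
      "device_capabilities": {
        "model": "Unknown Model",
        "chip": "Unknown Chip",
        "memory": 0,
        "flops": {"fp32": 0, "fp16": 0, "int8": 0}
      }
    },
    "node2": {
      "address": "localhost",
      "port": 50052,
      "device_capabilities": {
        "model": "Unknown Model",
        "chip": "Unknown Chip",
        "memory": 0,
        "flops": {"fp32": 0, "fp16": 0, "int8": 0}
      }
    }
  }
}
EOF
```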
@@ -29,4 +29,4 @@
}
}
}
}
}
@@ -1,3 +1,4 @@
import json
import asyncio
import unittest
from unittest import mock
@@ -14,8 +15,12 @@ class TestSingleNodeManualDiscovery(unittest.IsolatedAsyncioTestCase):
  async def asyncSetUp(self):
    self.peer1 = mock.AsyncMock()
    self.peer1.connect = mock.AsyncMock()
    self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1)
    _ = self.discovery1.start()
    self.discovery1 = ManualDiscovery(
      root_path,
      "node1",
      create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1,
    )
    await self.discovery1.start()

  async def asyncTearDown(self):
    await self.discovery1.stop()
@@ -33,8 +38,16 @@ class TestManualDiscovery(unittest.IsolatedAsyncioTestCase):
    self.peer2 = mock.AsyncMock()
    self.peer1.connect = mock.AsyncMock()
    self.peer2.connect = mock.AsyncMock()
    self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1)
    self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer2)
    self.discovery1 = ManualDiscovery(
      root_path,
      "node1",
      create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1,
    )
    self.discovery2 = ManualDiscovery(
      root_path,
      "node2",
      create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer2,
    )
    await self.discovery1.start()
    await self.discovery2.start()
@@ -63,8 +76,16 @@ class TestManualDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase):
    self.server2 = GRPCServer(self.node2, config.peers["node2"].address, config.peers["node2"].port)
    await self.server1.start()
    await self.server2.start()
    self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities))
    self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities))
    self.discovery1 = ManualDiscovery(
      root_path,
      "node1",
      create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities),
    )
    self.discovery2 = ManualDiscovery(
      root_path,
      "node2",
      create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities),
    )
    await self.discovery1.start()
    await self.discovery2.start()
@@ -98,6 +119,63 @@ class TestManualDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase):
    self.assertFalse(await peers1[0].is_connected())
    self.assertFalse(await peers2[0].is_connected())

  async def test_dynamic_config_update(self):
    initial_peers = await self.discovery1.discover_peers(wait_for_peers=1)
    self.assertEqual(len(initial_peers), 1)

    # Save original config for cleanup
    with open(root_path, "r") as f:
      original_config = json.load(f)

    try:
      updated_config = {
        "peers": {
          **original_config["peers"],
          "node3": {
            "address": "localhost",
            "port": 50053,
            "device_capabilities": {
              "model": "Unknown Model",
              "chip": "Unknown Chip",
              "memory": 0,
              "flops": {"fp32": 0, "fp16": 0, "int8": 0},
            },
          },
        }
      }

      with open(root_path, "w") as f:
        json.dump(updated_config, f, indent=2)

      node3 = mock.AsyncMock(spec=Node)
      server3 = GRPCServer(node3, "localhost", 50053)
      await server3.start()

      try:
        # Wait for the config to be reloaded
        await asyncio.sleep(1.5)

        updated_peers = await self.discovery1.discover_peers(wait_for_peers=2)
        self.assertEqual(len(updated_peers), 2)

        for peer in updated_peers:
          await peer.connect()
          self.assertTrue(await peer.is_connected())

      finally:
        await server3.stop()

    finally:
      # Restore the original config file
      with open(root_path, "w") as f:
        json.dump(original_config, f, indent=2)

      # Wait for the config to be reloaded again
      await asyncio.sleep(1.5)

      updated_peers = await self.discovery1.discover_peers(wait_for_peers=1)
      self.assertEqual(len(updated_peers), 1)


if __name__ == "__main__":
  asyncio.run(unittest.main())
@@ -1,4 +1,4 @@
#!/bin/bash
#!/usr/bin/env bash

if command -v python3.12 &>/dev/null; then
  echo "Python 3.12 is installed, proceeding with python3.12..."
@@ -1,4 +1,4 @@
#!/bin/bash
#!/usr/bin/env bash
source ./install.sh
pushd exo/networking/grpc
python3 -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. node_service.proto
@@ -1,4 +1,4 @@
#!/bin/bash
#!/usr/bin/env bash

echo "Starting node 1"
DEBUG_DISCOVERY=7 DEBUG=7 python3 main.py --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 52415 --chatgpt-api-response-timeout 900 > output1.log 2>&1 &