Compare commits

...

82 Commits

Author SHA1 Message Date
Alex Cheema
410d901505 Merge pull request #613 from samiamjidkhan/dmg-backend
image and text mode fix
2025-01-21 13:12:08 +00:00
Sami Khan
5c4ce5392c image and text mode fix 2025-01-21 04:33:54 -05:00
Alex Cheema
819ec7626e Merge pull request #611 from exo-explore/fixbuildname
fix scripts/build_exo.py: com.exolabs.exo -> net.exolabs.exo
2025-01-21 05:36:34 +00:00
Alex Cheema
ba5bb3e171 fix scripts/build_exo.py: com.exolabs.exo -> net.exolabs.exo 2025-01-21 05:36:02 +00:00
Alex Cheema
f4bbcf4c8f Merge pull request #607 from tensorsofthewall/smol_fix
Fixes for cross-platform operability
2025-01-21 02:21:18 +00:00
Sandesh Bharadwaj
b9eccedc3d Formatting 2025-01-17 05:40:42 -05:00
Sandesh Bharadwaj
5f06aa2759 Replace netifaces (unmaintained,outdated) with scapy + add dependencies for previous fixes 2025-01-17 05:37:01 -05:00
Sandesh Bharadwaj
349b5344eb Minor fix for Shard typing 2025-01-16 14:36:46 -05:00
Sandesh Bharadwaj
df3624d27a Add AMD GPU querying + Windows device capabilities 2025-01-14 20:37:02 -05:00
Sandesh Bharadwaj
6737e36e23 Fixed MLX import blocking native Windows execution of exo. (Not Final) 2025-01-14 20:35:21 -05:00
Alex Cheema
c260689a06 Merge pull request #602 from exo-explore/fixexodir
fix exo folder
2025-01-12 03:46:14 +00:00
Alex Cheema
fcc699a55f fix 2025-01-12 03:40:59 +00:00
Alex Cheema
e7b98f5ae5 fix unit tests 2025-01-12 03:35:24 +00:00
Alex Cheema
ffe78f6d0b fix dummy test 2025-01-12 03:30:06 +00:00
Alex Cheema
ce5041ee1b types 2025-01-12 03:24:42 +00:00
Alex Cheema
9b2c01c873 ensure dir exists 2025-01-12 03:15:49 +00:00
Alex Cheema
2aed3f3518 handle inference_state properly 2025-01-12 03:13:17 +00:00
Alex Cheema
2af5ee02e4 fix exo folder 2025-01-12 03:10:11 +00:00
Alex Cheema
b5cbcbc7a2 Merge pull request #474 from pranav4501/stable-stable-diffusion-mlx
Stable diffusion mlx
2025-01-12 02:57:21 +00:00
Alex Cheema
5f3d000a7b Merge branch 'main' into stable-stable-diffusion-mlx 2025-01-12 02:56:34 +00:00
Alex Cheema
bd2e8e7a5a Merge pull request #598 from exo-explore/fixphitest
typo in phi test
2025-01-08 22:09:38 +00:00
Alex Cheema
40696b21f7 typo in phi test 2025-01-08 22:09:04 +00:00
Alex Cheema
4937fb3df8 Merge pull request #597 from exo-explore/tuioverflow
Tui overflow
2025-01-08 16:40:16 +00:00
Alex Cheema
2d631ea53d Merge pull request #596 from exo-explore/phi4
add phi 3.5, phi 4
2025-01-08 16:39:32 +00:00
Alex Cheema
2846a9122f tok tests 2025-01-08 16:39:11 +00:00
Alex Cheema
553ccce728 fix prompt and output overflow in tui 2025-01-08 16:36:56 +00:00
Alex Cheema
c587593364 add phi 3.5, phi 4 2025-01-08 16:19:43 +00:00
Alex Cheema
3c9efe103d Merge pull request #590 from metaspartan/fix-models-api
Fix the /v1/models API to output proper OpenAI compatible endpoint
2025-01-07 02:32:06 +00:00
Carsen Klock
627bfcae7c Fix the /v1/models API to output proper OpenAI compatible endpoint
Modify the `/v1/models` API to output a proper OpenAI compatible endpoint with an object and a `data` object containing the models list.

* Change the `handle_get_models` method in `exo/api/chatgpt_api.py` to wrap the models list in an object with a `data` field.
* Add an `object` field with the value "list" to the response format.

---

For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/metaspartan/exo?shareId=XXXX-XXXX-XXXX-XXXX).
2025-01-06 01:20:30 -07:00
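For reference, a minimal sketch of the reshaped `/v1/models` response described in the commit above (assumes a local exo node on the default API port 52415; field values are illustrative, not taken from a live instance):

```sh
# Query the models endpoint of a local exo node (default API port 52415).
curl http://localhost:52415/v1/models
# Expected shape after this change (values illustrative):
# {"object": "list", "data": [{"id": "llama-3.2-1b", "object": "model", "owned_by": "exo", "ready": true}, ...]}
```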
Alex Cheema
d9a836f152 Merge pull request #588 from exo-explore/betterdl
better download
2025-01-05 02:35:04 +00:00
Alex Cheema
29244c6369 fix args for ensure_shard 2025-01-05 02:33:25 +00:00
Alex Cheema
8c191050a2 download status in parallel, support async ensure shard with using shard_downloader instead 2025-01-05 02:31:59 +00:00
Alex Cheema
7b1656140e Merge pull request #585 from pepebruari/main
Add --system-prompt to exo cli
2025-01-03 23:49:50 +00:00
pepebruari
fe50d4d34d Add --system-prompt to exo cli 2025-01-03 16:16:22 -05:00
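A hedged usage sketch of the flag added in this commit; per the API change further down in this diff, the prompt is injected as a system message only when a request does not already include one:

```sh
# Start exo with a default system prompt (sketch; flag name taken from this commit,
# prompt text is illustrative).
exo --system-prompt "You are a concise assistant."
```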
Alex Cheema
03aa6cecf1 Merge pull request #584 from exo-explore/AlexCheema-patch-1
add trending badge to README.md
2024-12-31 17:51:10 +00:00
Alex Cheema
178cc4d961 add trending badge to README.md 2024-12-31 17:50:29 +00:00
Pranav Veldurthi
b13e368368 fix inference engine 2024-12-30 19:41:19 -05:00
Pranav Veldurthi
9986fb86d4 remove prints and fix download progress for SD 2024-12-30 19:07:37 -05:00
Pranav Veldurthi
3475be9e9e Remove build 2024-12-30 18:39:17 -05:00
Pranav Veldurthi
fff8a1a690 fix inference engine for inference state 2024-12-30 18:36:53 -05:00
Pranav Veldurthi
54605299b8 Merge Latest 2024-12-30 18:36:23 -05:00
Alex Cheema
a174c78004 Merge pull request #383 from ianpaul10/feat/manual-disc-follow-up
Support changing manual configuration while running
2024-12-28 11:57:25 +00:00
Ian Paul
b003292b89 formatting and fixing tests after rebasing 2024-12-28 12:31:15 +07:00
Ian Paul
1dfd058c23 rm unecessary lock 2024-12-28 12:13:34 +07:00
Ian Paul
2eadaa2c0d rm redundant cleanup task 2024-12-28 12:13:34 +07:00
Ian Paul
637446ffa9 rm redundant typing 2024-12-28 12:13:34 +07:00
Ian Paul
a31f9e6c20 fix test warnings 2024-12-28 12:13:34 +07:00
Ian Paul
18acb97b42 make popping from dict threadsafe 2024-12-28 12:11:51 +07:00
Ian Paul
b066c944f3 make all I/O ops in manual_discovery.py run inside a ThreadPoolExecutor 2024-12-28 12:11:51 +07:00
Ian Paul
0e34ce2169 patch after rebasing to main 2024-12-28 12:11:51 +07:00
Ian Paul
90de7eada9 changes after rebase 2024-12-28 12:11:51 +07:00
Ian Paul
8d24df2b4b fix test runtime warning 2024-12-28 12:11:50 +07:00
Ian Paul
e5eb3259a5 handle when a peer is removed from config, so the known_peers dict gets updated accordingly 2024-12-28 12:11:21 +07:00
Ian Paul
2e8227fccb handle intermediate state for when config is being updated 2024-12-28 12:11:21 +07:00
Ian Paul
98118babae allow update to manual discovery file
re-load manual discovery file for each runthrough of the peer network, allowing incremental updates to the peer file even when exo is running
2024-12-28 12:11:21 +07:00
Alex Cheema
496a3b49f5 Merge pull request #561 from VerisimilitudeX/patch-1
Improved clarity, fixed typos, added macOS/Linux examples, and enhanc…
2024-12-27 17:06:00 +00:00
Alex Cheema
aba1bed5ed Merge pull request #575 from exo-explore/fixtok
Revert "Merge pull request #573 from damho1104/feature/add-exaone-3.5…
2024-12-27 16:36:34 +00:00
Alex Cheema
e08522ee97 Revert "Merge pull request #573 from damho1104/feature/add-exaone-3.5-model"
This reverts commit 4eb6a6a74a, reversing
changes made to fdc3b5ac02.
2024-12-27 16:35:54 +00:00
Alex Cheema
4eb6a6a74a Merge pull request #573 from damho1104/feature/add-exaone-3.5-model
Add exaone-3.5-2.4b, exaone-3.5-7.8b
2024-12-27 12:36:09 +00:00
damho.lee
94a5e908b0 add exaone-3.5 LLM Model 2024-12-24 20:57:11 +09:00
Alex Cheema
fdc3b5ac02 Merge pull request #571 from exo-explore/function_calling
add chatgpt-api-compatible tools for function calling
2024-12-24 02:08:48 +00:00
Alex Cheema
185b1e375c fix names in dummy tokenizer 2024-12-24 02:08:20 +00:00
Alex Cheema
078b807654 fix names of qwen models 2024-12-24 02:06:13 +00:00
Alex Cheema
188ac445c9 function calling example with weather tool 2024-12-24 01:57:17 +00:00
Alex Cheema
456fbdd2b0 add chatgpt-api-compatible tools for function calling 2024-12-24 01:51:55 +00:00
Alex Cheema
41df9ce1d7 Merge pull request #570 from exo-explore/moreqwen
add qwen-2.5-1.5b, qwen-2.5-3b, qwen-2.5-32b
2024-12-24 01:51:26 +00:00
Alex Cheema
c609c05e40 add qwen-2.5-1.5b, qwen-2.5-3b, qwen-2.5-32b 2024-12-24 01:50:12 +00:00
Alex Cheema
ba8c514974 Merge pull request #569 from deftdawg/env_bash
Use `#!/usr/bin/env bash` for better portability
2024-12-22 23:25:38 +00:00
DeftDawg
cde912deef - Use #!/usr/bin/env bash instead of #!/bin/bash for better portability 2024-12-22 01:14:54 -05:00
Piyush Acharya
154e0f58e4 Implement suggestiond 2024-12-21 19:40:53 -08:00
Piyush Acharya
6c82365ee2 Improved clarity, fixed typos, added macOS/Linux examples, and enhanced installation/debugging instructions 2024-12-17 18:02:34 -08:00
Pranav Veldurthi
5c0cd1839b Update strength image to image gen 2024-12-16 18:40:36 -05:00
Pranav Veldurthi
0f10244900 Merge latest 2024-12-04 22:52:48 -05:00
Pranav Veldurthi
686e139508 Merge Latest 2024-12-04 22:52:25 -05:00
Pranav Veldurthi
ca0caad0ae Image to image generation 2024-12-04 22:40:12 -05:00
Pranav Veldurthi
4b8c4a795f Images stored in system 2024-12-01 19:31:51 -05:00
Pranav Veldurthi
497756f7c8 merge latest main 2024-11-25 17:50:33 -05:00
Pranav Veldurthi
4874295b34 Image streaming while generation 2024-11-20 18:08:54 -05:00
Alex Cheema
fece3f0cef gitignore tinychat pngs 2024-11-20 10:01:06 +04:00
Alex Cheema
38ee815107 static images dir 2024-11-20 09:55:36 +04:00
Pranav Veldurthi
3d5746f16f Merge 2024-11-19 23:17:21 -05:00
Pranav Veldurthi
6b28ef0349 Stable stable diffusion mlx 2024-11-19 23:13:22 -05:00
45 changed files with 3043 additions and 464 deletions

.gitignore
View File

@@ -171,3 +171,5 @@ cython_debug/
**/*.xcodeproj/*
.aider*
exo/tinychat/images/*.png

View File

@@ -18,6 +18,8 @@ exo: Run your own AI cluster at home with everyday devices. Maintained by [exo l
[![Tests](https://dl.circleci.com/status-badge/img/circleci/TrkofJDoGzdQAeL6yVHKsg/4i5hJuafuwZYZQxbRAWS71/tree/main.svg?style=svg)](https://dl.circleci.com/status-badge/redirect/circleci/TrkofJDoGzdQAeL6yVHKsg/4i5hJuafuwZYZQxbRAWS71/tree/main)
[![License: GPL v3](https://img.shields.io/badge/License-GPLv3-blue.svg)](https://www.gnu.org/licenses/gpl-3.0)
<a href="https://trendshift.io/repositories/11849" target="_blank"><img src="https://trendshift.io/api/badge/repositories/11849" alt="exo-explore%2Fexo | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
</div>
---
@@ -38,7 +40,7 @@ We also welcome contributions from the community. We have a list of bounties in
### Wide Model Support
exo supports different models including LLaMA ([MLX](exo/inference/mlx/models/llama.py) and [tinygrad](exo/inference/tinygrad/models/llama.py)), Mistral, LlaVA, Qwen and Deepseek.
exo supports different models including LLaMA ([MLX](exo/inference/mlx/models/llama.py) and [tinygrad](exo/inference/tinygrad/models/llama.py)), Mistral, LlaVA, Qwen, and Deepseek.
### Dynamic Model Partitioning
@@ -100,13 +102,13 @@ source install.sh
- There are a number of things users have empirically found to improve performance on Apple Silicon Macs:
1. Upgrade to the latest version of MacOS 15.
1. Upgrade to the latest version of macOS Sequoia.
2. Run `./configure_mlx.sh`. This runs commands to optimize GPU memory allocation on Apple Silicon Macs.
## Documentation
### Example Usage on Multiple MacOS Devices
### Example Usage on Multiple macOS Devices
#### Device 1:
@@ -177,9 +179,9 @@ curl http://localhost:52415/v1/chat/completions \
}'
```
### Example Usage on Multiple Heterogenous Devices (MacOS + Linux)
### Example Usage on Multiple Heterogenous Devices (macOS + Linux)
#### Device 1 (MacOS):
#### Device 1 (macOS):
```sh
exo
@@ -244,7 +246,7 @@ python3 format.py ./exo
## Known Issues
- On some versions of MacOS/Python, certificates are not installed properly which can lead to SSL errors (e.g. SSL error with huggingface.co). To fix this, run the Install Certificates command, usually:
- On certain versions of Python on macOS, certificates may not be installed correctly, potentially causing SSL errors (e.g., when accessing huggingface.co). To resolve this, run the `Install Certificates` command, typically as follows:
```sh
/Applications/Python 3.x/Install Certificates.command

View File

@@ -1,4 +1,4 @@
#!/bin/bash
#!/usr/bin/env bash
# Get the total memory in MB
TOTAL_MEM_MB=$(($(sysctl -n hw.memsize) / 1024 / 1024))

View File

@@ -0,0 +1,111 @@
import json
import re
import requests
def get_current_weather(location: str, unit: str = "celsius"):
"""Mock weather data function"""
# Hardcoded response for demo purposes
return {
"location": location,
"temperature": 22 if unit == "celsius" else 72,
"unit": unit,
"forecast": "Sunny with light clouds"
}
def try_parse_tool_calls(content: str):
"""Try parse the tool calls."""
tool_calls = []
offset = 0
for i, m in enumerate(re.finditer(r"<tool_call>\n(.+)?\n</tool_call>", content)):
if i == 0:
offset = m.start()
try:
func = json.loads(m.group(1))
tool_calls.append({"type": "function", "function": func})
if isinstance(func["arguments"], str):
func["arguments"] = json.loads(func["arguments"])
except json.JSONDecodeError as e:
print(f"Failed to parse tool calls: the content is {m.group(1)} and {e}")
pass
if tool_calls:
if offset > 0 and content[:offset].strip():
c = content[:offset]
else:
c = ""
return {"role": "assistant", "content": c, "tool_calls": tool_calls}
return {"role": "assistant", "content": re.sub(r"<\|im_end\|>$", "", content)}
def chat_completion(messages):
"""Send chat completion request to local server"""
response = requests.post(
"http://localhost:52415/v1/chat/completions",
json={
"model": "qwen-2.5-1.5b",
"messages": messages,
"tools": [{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"]
}
},
"required": ["location"]
}
}
}],
"tool_choice": "auto"
}
)
return response.json()
def main():
# Initial conversation
messages = [{
"role": "user",
"content": "Hi there, what's the weather in Boston?"
}]
# Get initial response
response = chat_completion(messages)
print(f"First response: {response}")
assistant_message = try_parse_tool_calls(response["choices"][0]["message"]["content"])
messages.append(assistant_message)
# If there are tool calls, execute them and continue conversation
if "tool_calls" in assistant_message:
for tool_call in assistant_message["tool_calls"]:
if tool_call["function"]["name"] == "get_current_weather":
args = tool_call["function"]["arguments"]
weather_data = get_current_weather(**args)
# Add tool response to messages
messages.append({
"role": "tool",
"content": json.dumps(weather_data),
"name": tool_call["function"]["name"]
})
# Get final response with weather data
response = chat_completion(messages)
print(f"Final response: {response}")
messages.append({
"role": "assistant",
"content": response["choices"][0]["message"]["content"]
})
# Print full conversation
for msg in messages:
print(f"\n{msg['role'].upper()}: {msg['content']}")
if __name__ == "__main__":
main()

View File

@@ -5,41 +5,58 @@ import json
import os
from pathlib import Path
from transformers import AutoTokenizer
from typing import List, Literal, Union, Dict
from typing import List, Literal, Union, Dict, Optional
from aiohttp import web
import aiohttp_cors
import traceback
import signal
from exo import DEBUG, VERSION
from exo.download.download_progress import RepoProgressEvent
from exo.helpers import PrefixDict, shutdown
from exo.helpers import PrefixDict, shutdown, get_exo_images_dir
from exo.inference.tokenizers import resolve_tokenizer
from exo.orchestration import Node
from exo.models import build_base_shard, model_cards, get_repo, pretty_name
from typing import Callable, Optional
from PIL import Image
import numpy as np
import base64
from io import BytesIO
import platform
if platform.system().lower() == "darwin" and platform.machine().lower() == "arm64":
import mlx.core as mx
else:
import numpy as mx
import tempfile
from exo.download.hf.hf_shard_download import HFShardDownloader
import shutil
from exo.download.hf.hf_helpers import get_hf_home, get_repo_root
from exo.apputil import create_animation_mp4
class Message:
def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]], tools: Optional[List[Dict]] = None):
self.role = role
self.content = content
self.tools = tools
def to_dict(self):
return {"role": self.role, "content": self.content}
data = {"role": self.role, "content": self.content}
if self.tools:
data["tools"] = self.tools
return data
class ChatCompletionRequest:
def __init__(self, model: str, messages: List[Message], temperature: float):
def __init__(self, model: str, messages: List[Message], temperature: float, tools: Optional[List[Dict]] = None):
self.model = model
self.messages = messages
self.temperature = temperature
self.tools = tools
def to_dict(self):
return {"model": self.model, "messages": [message.to_dict() for message in self.messages], "temperature": self.temperature}
return {"model": self.model, "messages": [message.to_dict() for message in self.messages], "temperature": self.temperature, "tools": self.tools}
def generate_completion(
@@ -119,20 +136,32 @@ def remap_messages(messages: List[Message]) -> List[Message]:
return remapped_messages
def build_prompt(tokenizer, _messages: List[Message]):
def build_prompt(tokenizer, _messages: List[Message], tools: Optional[List[Dict]] = None):
messages = remap_messages(_messages)
prompt = tokenizer.apply_chat_template([m.to_dict() for m in messages], tokenize=False, add_generation_prompt=True)
for message in messages:
if not isinstance(message.content, list):
continue
chat_template_args = {"conversation": [m.to_dict() for m in messages], "tokenize": False, "add_generation_prompt": True}
if tools:
chat_template_args["tools"] = tools
return prompt
try:
prompt = tokenizer.apply_chat_template(**chat_template_args)
if DEBUG >= 3: print(f"!!! Prompt: {prompt}")
return prompt
except UnicodeEncodeError:
# Handle Unicode encoding by ensuring everything is UTF-8
chat_template_args["conversation"] = [
{k: v.encode('utf-8').decode('utf-8') if isinstance(v, str) else v
for k, v in m.to_dict().items()}
for m in messages
]
prompt = tokenizer.apply_chat_template(**chat_template_args)
if DEBUG >= 3: print(f"!!! Prompt (UTF-8 encoded): {prompt}")
return prompt
def parse_message(data: dict):
if "role" not in data or "content" not in data:
raise ValueError(f"Invalid message: {data}. Must have 'role' and 'content'")
return Message(data["role"], data["content"])
return Message(data["role"], data["content"], data.get("tools"))
def parse_chat_request(data: dict, default_model: str):
@@ -140,6 +169,7 @@ def parse_chat_request(data: dict, default_model: str):
data.get("model", default_model),
[parse_message(msg) for msg in data["messages"]],
data.get("temperature", 0.0),
data.get("tools", None),
)
@@ -149,8 +179,17 @@ class PromptSession:
self.timestamp = timestamp
self.prompt = prompt
class ChatGPTAPI:
def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None, default_model: Optional[str] = None):
def __init__(
self,
node: Node,
inference_engine_classname: str,
response_timeout: int = 90,
on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None,
default_model: Optional[str] = None,
system_prompt: Optional[str] = None
):
self.node = node
self.inference_engine_classname = inference_engine_classname
self.response_timeout = response_timeout
@@ -160,6 +199,7 @@ class ChatGPTAPI:
self.prev_token_lens: Dict[str, int] = {}
self.stream_tasks: Dict[str, asyncio.Task] = {}
self.default_model = default_model or "llama-3.2-1b"
self.system_prompt = system_prompt
cors = aiohttp_cors.setup(self.app)
cors_options = aiohttp_cors.ResourceOptions(
@@ -174,6 +214,7 @@ class ChatGPTAPI:
cors.add(self.app.router.add_post("/v1/chat/token/encode", self.handle_post_chat_token_encode), {"*": cors_options})
cors.add(self.app.router.add_post("/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
cors.add(self.app.router.add_post("/v1/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
cors.add(self.app.router.add_post("/v1/image/generations", self.handle_post_image_generations), {"*": cors_options})
cors.add(self.app.router.add_get("/v1/download/progress", self.handle_get_download_progress), {"*": cors_options})
cors.add(self.app.router.add_get("/modelpool", self.handle_model_support), {"*": cors_options})
cors.add(self.app.router.add_get("/healthcheck", self.handle_healthcheck), {"*": cors_options})
@@ -184,16 +225,22 @@ class ChatGPTAPI:
cors.add(self.app.router.add_post("/download", self.handle_post_download), {"*": cors_options})
cors.add(self.app.router.add_get("/topology", self.handle_get_topology), {"*": cors_options})
# Add static routes
if "__compiled__" not in globals():
self.static_dir = Path(__file__).parent.parent/"tinychat"
self.app.router.add_get("/", self.handle_root)
self.app.router.add_static("/", self.static_dir, name="static")
# Always add images route, regardless of compilation status
self.images_dir = get_exo_images_dir()
self.images_dir.mkdir(parents=True, exist_ok=True)
self.app.router.add_static('/images/', self.images_dir, name='static_images')
self.app.middlewares.append(self.timeout_middleware)
self.app.middlewares.append(self.log_request)
async def handle_quit(self, request):
if DEBUG>=1: print("Received quit signal")
if DEBUG >= 1: print("Received quit signal")
response = web.json_response({"detail": "Quit signal received"}, status=200)
await response.prepare(request)
await response.write_eof()
@@ -223,58 +270,52 @@ class ChatGPTAPI:
async def handle_model_support(self, request):
try:
response = web.StreamResponse(
status=200,
reason='OK',
headers={
'Content-Type': 'text/event-stream',
'Cache-Control': 'no-cache',
'Connection': 'keep-alive',
}
)
await response.prepare(request)
response = web.StreamResponse(status=200, reason='OK', headers={
'Content-Type': 'text/event-stream',
'Cache-Control': 'no-cache',
'Connection': 'keep-alive',
})
await response.prepare(request)
for model_name, pretty in pretty_name.items():
if model_name in model_cards:
model_info = model_cards[model_name]
async def process_model(model_name, pretty):
if model_name in model_cards:
model_info = model_cards[model_name]
if self.inference_engine_classname in model_info.get("repo", {}):
shard = build_base_shard(model_name, self.inference_engine_classname)
if shard:
downloader = HFShardDownloader(quick_check=True)
downloader.current_shard = shard
downloader.current_repo_id = get_repo(shard.model_id, self.inference_engine_classname)
status = await downloader.get_shard_download_status()
if self.inference_engine_classname in model_info.get("repo", {}):
shard = build_base_shard(model_name, self.inference_engine_classname)
if shard:
downloader = HFShardDownloader(quick_check=True)
downloader.current_shard = shard
downloader.current_repo_id = get_repo(shard.model_id, self.inference_engine_classname)
status = await downloader.get_shard_download_status()
download_percentage = status.get("overall") if status else None
total_size = status.get("total_size") if status else None
total_downloaded = status.get("total_downloaded") if status else False
download_percentage = status.get("overall") if status else None
total_size = status.get("total_size") if status else None
total_downloaded = status.get("total_downloaded") if status else False
model_data = {
model_name: {
"name": pretty,
"downloaded": download_percentage == 100 if download_percentage is not None else False,
"download_percentage": download_percentage,
"total_size": total_size,
"total_downloaded": total_downloaded
}
}
model_data = {
model_name: {
"name": pretty, "downloaded": download_percentage == 100 if download_percentage is not None else False, "download_percentage": download_percentage, "total_size": total_size,
"total_downloaded": total_downloaded
}
}
await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
await response.write(b"data: [DONE]\n\n")
return response
# Process all models in parallel
await asyncio.gather(*[process_model(model_name, pretty) for model_name, pretty in pretty_name.items()])
await response.write(b"data: [DONE]\n\n")
return response
except Exception as e:
print(f"Error in handle_model_support: {str(e)}")
traceback.print_exc()
return web.json_response(
{"detail": f"Server error: {str(e)}"},
status=500
)
print(f"Error in handle_model_support: {str(e)}")
traceback.print_exc()
return web.json_response({"detail": f"Server error: {str(e)}"}, status=500)
async def handle_get_models(self, request):
return web.json_response([{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()])
models_list = [{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()]
return web.json_response({"object": "list", "data": models_list})
async def handle_post_chat_token_encode(self, request):
data = await request.json()
@@ -287,7 +328,7 @@ class ChatGPTAPI:
shard = build_base_shard(model, self.inference_engine_classname)
messages = [parse_message(msg) for msg in data.get("messages", [])]
tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
prompt = build_prompt(tokenizer, messages)
prompt = build_prompt(tokenizer, messages, data.get("tools", None))
tokens = tokenizer.encode(prompt)
return web.json_response({
"length": len(prompt),
@@ -326,7 +367,11 @@ class ChatGPTAPI:
tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")
prompt = build_prompt(tokenizer, chat_request.messages)
# Add system prompt if set
if self.system_prompt and not any(msg.role == "system" for msg in chat_request.messages):
chat_request.messages.insert(0, Message("system", self.system_prompt))
prompt = build_prompt(tokenizer, chat_request.messages, chat_request.tools)
request_id = str(uuid.uuid4())
if self.on_chat_completion_request:
try:
@@ -435,23 +480,109 @@ class ChatGPTAPI:
deregistered_callback = self.node.on_token.deregister(callback_id)
if DEBUG >= 2: print(f"Deregister {callback_id=} {deregistered_callback=}")
async def handle_post_image_generations(self, request):
data = await request.json()
if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}")
stream = data.get("stream", False)
model = data.get("model", "")
prompt = data.get("prompt", "")
image_url = data.get("image_url", "")
if DEBUG >= 2: print(f"model: {model}, prompt: {prompt}, stream: {stream}")
shard = build_base_shard(model, self.inference_engine_classname)
if DEBUG >= 2: print(f"shard: {shard}")
if not shard:
return web.json_response({"error": f"Unsupported model: {model} with inference engine {self.inference_engine_classname}"}, status=400)
request_id = str(uuid.uuid4())
callback_id = f"chatgpt-api-wait-response-{request_id}"
callback = self.node.on_token.register(callback_id)
try:
if image_url != "" and image_url != None:
img = self.base64_decode(image_url)
else:
img = None
await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, request_id=request_id, inference_state={"image": img}))), timeout=self.response_timeout)
response = web.StreamResponse(status=200, reason='OK', headers={
'Content-Type': 'application/octet-stream',
"Cache-Control": "no-cache",
})
await response.prepare(request)
def get_progress_bar(current_step, total_steps, bar_length=50):
# Calculate the percentage of completion
percent = float(current_step)/total_steps
# Calculate the number of hashes to display
arrow = '-'*int(round(percent*bar_length) - 1) + '>'
spaces = ' '*(bar_length - len(arrow))
# Create the progress bar string
progress_bar = f'Progress: [{arrow}{spaces}] {int(percent * 100)}% ({current_step}/{total_steps})'
return progress_bar
async def stream_image(_request_id: str, result, is_finished: bool):
if isinstance(result, list):
await response.write(json.dumps({'progress': get_progress_bar((result[0]), (result[1]))}).encode('utf-8') + b'\n')
elif isinstance(result, np.ndarray):
try:
im = Image.fromarray(np.array(result))
# Save the image to a file
image_filename = f"{_request_id}.png"
image_path = self.images_dir/image_filename
im.save(image_path)
# Get URL for the saved image
try:
image_url = request.app.router['static_images'].url_for(filename=image_filename)
base_url = f"{request.scheme}://{request.host}"
full_image_url = base_url + str(image_url)
await response.write(json.dumps({'images': [{'url': str(full_image_url), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n')
except KeyError as e:
if DEBUG >= 2: print(f"Error getting image URL: {e}")
# Fallback to direct file path if URL generation fails
await response.write(json.dumps({'images': [{'url': str(image_path), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n')
if is_finished:
await response.write_eof()
except Exception as e:
if DEBUG >= 2: print(f"Error processing image: {e}")
if DEBUG >= 2: traceback.print_exc()
await response.write(json.dumps({'error': str(e)}).encode('utf-8') + b'\n')
stream_task = None
def on_result(_request_id: str, result, is_finished: bool):
nonlocal stream_task
stream_task = asyncio.create_task(stream_image(_request_id, result, is_finished))
return _request_id == request_id and is_finished
await callback.wait(on_result, timeout=self.response_timeout*10)
if stream_task:
# Wait for the stream task to complete before returning
await stream_task
return response
except Exception as e:
if DEBUG >= 2: traceback.print_exc()
return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
async def handle_delete_model(self, request):
try:
model_name = request.match_info.get('model_name')
if DEBUG >= 2: print(f"Attempting to delete model: {model_name}")
if not model_name or model_name not in model_cards:
return web.json_response(
{"detail": f"Invalid model name: {model_name}"},
status=400
)
return web.json_response({"detail": f"Invalid model name: {model_name}"}, status=400)
shard = build_base_shard(model_name, self.inference_engine_classname)
if not shard:
return web.json_response(
{"detail": "Could not build shard for model"},
status=400
)
return web.json_response({"detail": "Could not build shard for model"}, status=400)
repo_id = get_repo(shard.model_id, self.inference_engine_classname)
if DEBUG >= 2: print(f"Repo ID for model: {repo_id}")
@@ -466,38 +597,28 @@ class ChatGPTAPI:
if DEBUG >= 2: print(f"Found model files at {cache_dir}, deleting...")
try:
shutil.rmtree(cache_dir)
return web.json_response({
"status": "success",
"message": f"Model {model_name} deleted successfully",
"path": str(cache_dir)
})
return web.json_response({"status": "success", "message": f"Model {model_name} deleted successfully", "path": str(cache_dir)})
except Exception as e:
return web.json_response({
"detail": f"Failed to delete model files: {str(e)}"
}, status=500)
return web.json_response({"detail": f"Failed to delete model files: {str(e)}"}, status=500)
else:
return web.json_response({
"detail": f"Model files not found at {cache_dir}"
}, status=404)
return web.json_response({"detail": f"Model files not found at {cache_dir}"}, status=404)
except Exception as e:
print(f"Error in handle_delete_model: {str(e)}")
traceback.print_exc()
return web.json_response({
"detail": f"Server error: {str(e)}"
}, status=500)
print(f"Error in handle_delete_model: {str(e)}")
traceback.print_exc()
return web.json_response({"detail": f"Server error: {str(e)}"}, status=500)
async def handle_get_initial_models(self, request):
model_data = {}
for model_name, pretty in pretty_name.items():
model_data[model_name] = {
"name": pretty,
"downloaded": None, # Initially unknown
"download_percentage": None, # Change from 0 to null
"total_size": None,
"total_downloaded": None,
"loading": True # Add loading state
}
model_data[model_name] = {
"name": pretty,
"downloaded": None, # Initially unknown
"download_percentage": None, # Change from 0 to null
"total_size": None,
"total_downloaded": None,
"loading": True # Add loading state
}
return web.json_response(model_data)
async def handle_create_animation(self, request):
@@ -523,17 +644,9 @@ class ChatGPTAPI:
if DEBUG >= 2: print(f"Animation temp directory: {tmp_dir}, output file: {output_path}, directory exists: {tmp_dir.exists()}, directory permissions: {oct(tmp_dir.stat().st_mode)[-3:]}")
# Create the animation
create_animation_mp4(
replacement_image_path,
output_path,
device_name,
prompt_text
)
create_animation_mp4(replacement_image_path, output_path, device_name, prompt_text)
return web.json_response({
"status": "success",
"output_path": output_path
})
return web.json_response({"status": "success", "output_path": output_path})
except Exception as e:
if DEBUG >= 2: traceback.print_exc()
@@ -547,12 +660,9 @@ class ChatGPTAPI:
if model_name not in model_cards: return web.json_response({"error": f"Invalid model: {model_name}. Supported models: {list(model_cards.keys())}"}, status=400)
shard = build_base_shard(model_name, self.inference_engine_classname)
if not shard: return web.json_response({"error": f"Could not build shard for model {model_name}"}, status=400)
asyncio.create_task(self.node.inference_engine.ensure_shard(shard))
asyncio.create_task(self.node.inference_engine.shard_downloader.ensure_shard(shard, self.inference_engine_classname))
return web.json_response({
"status": "success",
"message": f"Download started for model: {model_name}"
})
return web.json_response({"status": "success", "message": f"Download started for model: {model_name}"})
except Exception as e:
if DEBUG >= 2: traceback.print_exc()
return web.json_response({"error": str(e)}, status=500)
@@ -566,13 +676,25 @@ class ChatGPTAPI:
return web.json_response({})
except Exception as e:
if DEBUG >= 2: traceback.print_exc()
return web.json_response(
{"detail": f"Error getting topology: {str(e)}"},
status=500
)
return web.json_response({"detail": f"Error getting topology: {str(e)}"}, status=500)
async def run(self, host: str = "0.0.0.0", port: int = 52415):
runner = web.AppRunner(self.app)
await runner.setup()
site = web.TCPSite(runner, host, port)
await site.start()
def base64_decode(self, base64_string):
#decode and reshape image
if base64_string.startswith('data:image'):
base64_string = base64_string.split(',')[1]
image_data = base64.b64decode(base64_string)
img = Image.open(BytesIO(image_data))
W, H = (dim - dim%64 for dim in (img.width, img.height))
if W != img.width or H != img.height:
if DEBUG >= 2: print(f"Warning: image shape is not divisible by 64, downsampling to {W}x{H}")
img = img.resize((W, H), Image.NEAREST) # use desired downsampling filter
img = mx.array(np.array(img))
img = (img[:, :, :3].astype(mx.float32)/255)*2 - 1
img = img[None]
return img

View File

@@ -303,6 +303,10 @@ async def download_repo_files(
await f.write(json.dumps(file_list))
if DEBUG >= 2: print(f"Cached file list at {cached_file_list_path}")
model_index_exists = any(file["path"] == "model_index.json" for file in file_list)
if model_index_exists:
allow_patterns = ["**/*.json", "**/*.txt", "**/*model.safetensors", "*.json"]
filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, key=lambda x: x["path"]))
total_files = len(filtered_file_list)
total_bytes = sum(file["size"] for file in filtered_file_list)

View File

@@ -104,15 +104,19 @@ class HFShardDownloader(ShardDownloader):
print(f"No snapshot directory found for {self.current_repo_id}")
return None
if not await aios.path.exists(snapshot_dir/"model_index.json"):
# Get the weight map to know what files we need
weight_map = await get_weight_map(self.current_repo_id, self.revision)
if not weight_map:
if DEBUG >= 2:
print(f"No weight map found for {self.current_repo_id}")
return None
weight_map = await get_weight_map(self.current_repo_id, self.revision)
if not weight_map:
if DEBUG >= 2:
print(f"No weight map found for {self.current_repo_id}")
return None
# Get all files needed for this shard
patterns = get_allow_patterns(weight_map, self.current_shard)
else:
patterns = ["**/*.json", "**/*.txt", "**/*model.safetensors", "*.json"]
# Get all files needed for this shard
patterns = get_allow_patterns(weight_map, self.current_shard)
# Check download status for all relevant files
status = {}

View File

@@ -7,7 +7,8 @@ import random
import platform
import psutil
import uuid
import netifaces
from scapy.all import get_if_addr, get_if_list
import re
import subprocess
from pathlib import Path
import tempfile
@@ -231,26 +232,26 @@ def pretty_print_bytes_per_second(bytes_per_second: int) -> str:
def get_all_ip_addresses_and_interfaces():
try:
ip_addresses = []
for interface in netifaces.interfaces():
ifaddresses = netifaces.ifaddresses(interface)
if netifaces.AF_INET in ifaddresses:
for link in ifaddresses[netifaces.AF_INET]:
ip = link['addr']
ip_addresses.append((ip, interface))
for interface in get_if_list():
ip = get_if_addr(interface)
# Include all addresses, including loopback
# Filter out link-local addresses
if not ip.startswith('169.254.') and not ip.startswith('0.0.'):
# Remove "\\Device\\NPF_" prefix from interface name
simplified_interface = re.sub(r'^\\Device\\NPF_', '', interface)
ip_addresses.append((ip, simplified_interface))
return list(set(ip_addresses))
except:
if DEBUG >= 1: print("Failed to get all IP addresses. Defaulting to localhost.")
return [("localhost", "lo")]
async def get_macos_interface_type(ifname: str) -> Optional[Tuple[int, str]]:
try:
# Use the shared subprocess_pool
output = await asyncio.get_running_loop().run_in_executor(subprocess_pool, lambda: subprocess.run(
['system_profiler', 'SPNetworkDataType', '-json'],
capture_output=True,
text=True,
close_fds=True
).stdout)
output = await asyncio.get_running_loop().run_in_executor(
subprocess_pool, lambda: subprocess.run(['system_profiler', 'SPNetworkDataType', '-json'], capture_output=True, text=True, close_fds=True).stdout
)
data = json.loads(output)
@@ -276,6 +277,7 @@ async def get_macos_interface_type(ifname: str) -> Optional[Tuple[int, str]]:
return None
async def get_interface_priority_and_type(ifname: str) -> Tuple[int, str]:
# On macOS, try to get interface type using networksetup
if psutil.MACOS:
@@ -283,8 +285,7 @@ async def get_interface_priority_and_type(ifname: str) -> Tuple[int, str]:
if macos_type is not None: return macos_type
# Local container/virtual interfaces
if (ifname.startswith(('docker', 'br-', 'veth', 'cni', 'flannel', 'calico', 'weave')) or
'bridge' in ifname):
if (ifname.startswith(('docker', 'br-', 'veth', 'cni', 'flannel', 'calico', 'weave')) or 'bridge' in ifname):
return (7, "Container Virtual")
# Loopback interface
@@ -310,6 +311,7 @@ async def get_interface_priority_and_type(ifname: str) -> Tuple[int, str]:
# Other physical interfaces
return (2, "Other")
async def shutdown(signal, loop, server):
"""Gracefully shutdown the server and close the asyncio loop."""
print(f"Received exit signal {signal.name}...")
@@ -325,4 +327,20 @@ async def shutdown(signal, loop, server):
def is_frozen():
return getattr(sys, 'frozen', False) or os.path.basename(sys.executable) == "exo" \
or ('Contents/MacOS' in str(os.path.dirname(sys.executable))) \
or '__nuitka__' in globals() or getattr(sys, '__compiled__', False)
or '__nuitka__' in globals() or getattr(sys, '__compiled__', False)
def get_exo_home() -> Path:
if psutil.WINDOWS: docs_folder = Path(os.environ["USERPROFILE"])/"Documents"
else: docs_folder = Path.home()/"Documents"
if not docs_folder.exists(): docs_folder.mkdir(exist_ok=True)
exo_folder = docs_folder/"Exo"
if not exo_folder.exists(): exo_folder.mkdir(exist_ok=True)
return exo_folder
def get_exo_images_dir() -> Path:
exo_home = get_exo_home()
images_dir = exo_home/"Images"
if not images_dir.exists(): images_dir.mkdir(exist_ok=True)
return images_dir

View File

@@ -16,25 +16,25 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
resp_full = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt)
token_full = await inference_engine_1.sample(resp_full)
next_resp_full = await inference_engine_1.infer_tensor(
next_resp_full, _ = await inference_engine_1.infer_tensor(
"A",
shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32),
input_data=token_full,
)
resp1 = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), prompt=prompt)
resp2 = await inference_engine_2.infer_tensor(
resp1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), prompt=prompt)
resp2, _ = await inference_engine_2.infer_tensor(
"B",
shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),
input_data=resp1,
)
token2 = await inference_engine_2.sample(resp2)
resp3 = await inference_engine_1.infer_tensor(
resp3, _ = await inference_engine_1.infer_tensor(
"B",
shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32),
input_data=token2,
)
resp4 = await inference_engine_2.infer_tensor(
resp4, _ = await inference_engine_2.infer_tensor(
"B",
shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),
input_data=resp3,

View File

@@ -25,9 +25,9 @@ class DummyInferenceEngine(InferenceEngine):
async def decode(self, shard: Shard, tokens: np.ndarray) -> str:
return self.tokenizer.decode(tokens)
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
await self.ensure_shard(shard)
return input_data + 1 if self.shard.is_last_layer() else input_data
return input_data + 1 if self.shard.is_last_layer() else input_data, None
async def ensure_shard(self, shard: Shard):
if self.shard == shard: return

View File

@@ -5,6 +5,7 @@ from exo.helpers import DEBUG # Make sure to import DEBUG
from typing import Tuple, Optional
from abc import ABC, abstractmethod
from .shard import Shard
from exo.download.shard_download import ShardDownloader
class InferenceEngine(ABC):
@@ -13,7 +14,7 @@ class InferenceEngine(ABC):
@abstractmethod
async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
pass
@abstractmethod
async def sample(self, x: np.ndarray) -> np.ndarray:
pass
@@ -23,7 +24,7 @@ class InferenceEngine(ABC):
pass
@abstractmethod
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
pass
@abstractmethod
@@ -32,18 +33,23 @@ class InferenceEngine(ABC):
async def save_checkpoint(self, shard: Shard, path: str):
pass
async def save_session(self, key, value):
self.session[key] = value
async def clear_session(self):
self.session.empty()
async def infer_prompt(self, request_id: str, shard: Shard, prompt: str) -> np.ndarray:
async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
tokens = await self.encode(shard, prompt)
x = tokens.reshape(1, -1)
output_data = await self.infer_tensor(request_id, shard, x)
return output_data
if shard.model_id != 'stable-diffusion-2-1-base':
x = tokens.reshape(1, -1)
else:
x = tokens
output_data, inference_state = await self.infer_tensor(request_id, shard, x, inference_state)
return output_data, inference_state
inference_engine_classes = {
"mlx": "MLXDynamicShardInferenceEngine",
@@ -51,7 +57,8 @@ inference_engine_classes = {
"dummy": "DummyInferenceEngine",
}
def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDownloader'):
def get_inference_engine(inference_engine_name: str, shard_downloader: ShardDownloader):
if DEBUG >= 2:
print(f"get_inference_engine called with: {inference_engine_name}")
if inference_engine_name == "mlx":

View File

@@ -0,0 +1,307 @@
# Adapted from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/__init__.py
import time
from typing import Optional, Tuple
import inspect
import mlx.core as mx
import mlx.nn as nn
from pathlib import Path
from tqdm import tqdm
from .sd_models.vae import ModelArgs as VAEArgs
from .sd_models.vae import Autoencoder
from .sd_models.tokenizer import load_tokenizer
from .sd_models.clip import CLIPTextModel
from .sd_models.clip import ModelArgs as CLIPArgs
from .sd_models.unet import UNetConfig, UNetModel
from dataclasses import dataclass, field
from exo.inference.shard import Shard
@dataclass
class DiffusionConfig:
beta_schedule: str = "scaled_linear"
beta_start: float = 0.00085
beta_end: float = 0.012
num_train_steps: int = 1000
@classmethod
def from_dict(cls, params):
return cls(**{k: v for k, v in params.items() if k in inspect.signature(cls).parameters})
#Sampler
def _linspace(a, b, num):
x = mx.arange(0, num) / (num - 1)
return (b - a) * x + a
def _interp(y, x_new):
"""Interpolate the function defined by (arange(0, len(y)), y) at positions x_new."""
x_low = x_new.astype(mx.int32)
x_high = mx.minimum(x_low + 1, len(y) - 1)
y_low = y[x_low]
y_high = y[x_high]
delta_x = x_new - x_low
y_new = y_low * (1 - delta_x) + delta_x * y_high
return y_new
class SimpleEulerSampler:
"""A simple Euler integrator that can be used to sample from our diffusion models.
The method ``step()`` performs one Euler step from x_t to x_t_prev.
"""
def __init__(self, config: DiffusionConfig):
# Compute the noise schedule
if config.beta_schedule == "linear":
betas = _linspace(
config.beta_start, config.beta_end, config.num_train_steps
)
elif config.beta_schedule == "scaled_linear":
betas = _linspace(
config.beta_start**0.5, config.beta_end**0.5, config.num_train_steps
).square()
else:
raise NotImplementedError(f"{config.beta_schedule} is not implemented.")
alphas = 1 - betas
alphas_cumprod = mx.cumprod(alphas)
self._sigmas = mx.concatenate(
[mx.zeros(1), ((1 - alphas_cumprod) / alphas_cumprod).sqrt()]
)
@property
def max_time(self):
return len(self._sigmas) - 1
def sample_prior(self, shape, dtype=mx.float32, key=None):
noise = mx.random.normal(shape, key=key)
return (
noise * self._sigmas[-1] * (self._sigmas[-1].square() + 1).rsqrt()
).astype(dtype)
def add_noise(self, x, t, key=None):
noise = mx.random.normal(x.shape, key=key)
s = self.sigmas(t)
return (x + noise * s) * (s.square() + 1).rsqrt()
def sigmas(self, t):
return _interp(self._sigmas, t)
def timesteps(self, num_steps: int, start_time=None, dtype=mx.float32):
start_time = start_time or (len(self._sigmas) - 1)
assert 0 < start_time <= (len(self._sigmas) - 1)
steps = _linspace(start_time, 0, num_steps + 1).astype(dtype)
return list(zip(steps, steps[1:]))
def current_timestep(self, step, total_steps, start_time=None):
if step < total_steps:
steps = self.timesteps(total_steps, start_time)
return steps[step]
else:
return mx.array(0),mx.array(0)
def step(self, eps_pred, x_t, t, t_prev):
sigma = self.sigmas(t).astype(eps_pred.dtype)
sigma_prev = self.sigmas(t_prev).astype(eps_pred.dtype)
dt = sigma_prev - sigma
x_t_prev = (sigma.square() + 1).sqrt() * x_t + eps_pred * dt
x_t_prev = x_t_prev * (sigma_prev.square() + 1).rsqrt()
return x_t_prev
@dataclass
class ShardConfig:
model_id:str
start_layer:int
end_layer:int
n_layers:int
@dataclass
class StableDiffusionConfig:
model_type:str
vae:VAEArgs
text_encoder:CLIPArgs
scheduler:DiffusionConfig
unet:UNetConfig
shard:ShardConfig
@classmethod
def from_dict(cls, params):
return cls(**{k: v for k, v in params.items() if k in inspect.signature(cls).parameters})
@dataclass
class ModelArgs(StableDiffusionConfig):
shard:Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
def __post_init__(self):
if isinstance(self.shard, dict):
self.shard = Shard(**self.shard)
if not isinstance(self.shard, Shard):
raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
class Model(nn.Module):
def __init__(self, config):
super().__init__()
self.model_type = config.model_type
self.config = config
self.model_path = config.vae['path'].split('/vae')[0]
self.shard = config.shard
self.shard_clip, self.shard_encoder, self.shard_unet, self.shard_decoder = model_shards(config.shard)
self.config_clip=CLIPArgs.from_dict(config.text_encoder['config'])
if self.shard_clip.start_layer != -1:
self.text_encoder = CLIPTextModel(self.config_clip, shard=self.shard_clip)
else:
self.text_encoder = nn.Identity()
self.tokenizer = load_tokenizer(Path(self.model_path), "vocab.json", "merges.txt")
self.diffusion_config = DiffusionConfig.from_dict(config.scheduler['config'])
self.sampler = SimpleEulerSampler(self.diffusion_config)
if self.shard_unet.start_layer!=-1:
self.config_unet = UNetConfig.from_dict(config.unet['config'])
self.unet = UNetModel(self.config_unet, self.shard_unet)
else:
self.unet = nn.Identity()
self.config_vae=VAEArgs.from_dict(config.vae['config'])
if self.shard_encoder.start_layer != -1:
self.encoder=Autoencoder(self.config_vae, self.shard_encoder, "vae_encoder")
else:
self.encoder = nn.Identity()
if self.shard_decoder.start_layer != -1:
self.decoder=Autoencoder(self.config_vae, self.shard_decoder, "vae_decoder")
else:
self.decoder = nn.Identity()
def __call__(self,x, step= 0, cfg_weight: float = 7.5,total_steps=50,conditioning=None,mask=None,residual=None,x_t_prev=None,is_finished=False,is_step_finished=False, image=None, strength=0.65, start_step=None):
t, t_prev = self.sampler.current_timestep(step=step, total_steps=total_steps, start_time=start_step)
is_finished = False
is_step_finished = False
if t.item()==1000:
if self.shard_clip.start_layer == 0:
conditioning = x
if self.shard_clip.start_layer != -1:
conditioning, mask= self.text_encoder(conditioning,mask)
seed = int(time.time())
mx.random.seed(seed)
if image is None:
if self.shard_encoder.is_last_layer():
x = self.sampler.sample_prior((1, *(64, 64), self.config_vae.latent_channels_in), dtype=mx.float32)
x_t_prev=x
start_step = self.sampler.max_time
else:
if self.shard_encoder.start_layer != -1:
image= self.encoder.encode(image)
if self.shard_encoder.is_last_layer():
start_step = self.sampler.max_time*strength
total_steps = int(total_steps*strength)
image = mx.broadcast_to(image, (1,) + image.shape[1:])
x_t_prev=self.sampler.add_noise(image, mx.array(start_step))
image = None
t, t_prev = self.sampler.current_timestep(step=step, total_steps=total_steps, start_time=start_step)
# Perform the denoising loop
if self.shard_unet.start_layer != -1:
with tqdm(total=total_steps,initial=step+1) as pbar:
if step<total_steps:
x = x_t_prev
if self.shard_unet.is_first_layer():
x_t_unet = mx.concatenate([x] * 2, axis=0) if cfg_weight> 1 else x
else:
x_t_unet = x
t_unet = mx.broadcast_to(t, [len(x_t_unet)])
x, residual= self.unet(x_t_unet, t_unet, encoder_x=conditioning, residuals=residual)
if self.shard_unet.is_last_layer():
if cfg_weight > 1:
eps_text, eps_neg = x.split(2)
eps_pred = eps_neg + cfg_weight * (eps_text - eps_neg)
x = self.sampler.step(eps_pred, x_t_prev, t, t_prev)
x_t_prev=x
mx.eval(x)
if self.shard_decoder.is_last_layer():
is_step_finished=True
if self.shard_decoder.start_layer != -1:
x=self.decoder.decode(x)
if self.shard_decoder.is_last_layer():
x = mx.clip(x / 2 + 0.5, 0, 1)
B, H, W, C = x.shape
x = x.reshape(1, B // 1, H, W, C).transpose(0, 2, 1, 3, 4)
x = x.reshape(1 * H, B // 1 * W, C)
x = (x * 255).astype(mx.uint8)
if t_prev.item() ==0:
is_finished=True
mx.eval(x)
return x, {'conditioning':conditioning, 'mask':mask,'residual':residual,'x_t_prev':x_t_prev,'is_finished':is_finished,'is_step_finished':is_step_finished, 'step':step, 'total_steps':total_steps, 'start_step':start_step, 'image':image}
def load(self):
if self.shard_encoder.start_layer != -1:
vae_weights = mx.load(self.config_vae.weight_files[0])
vae_weights = self.encoder.sanitize(vae_weights)
self.encoder.load_weights(list(vae_weights.items()), strict=True)
if self.shard_decoder.start_layer != -1:
vae_weights = mx.load(self.config_vae.weight_files[0])
vae_weights = self.decoder.sanitize(vae_weights)
self.decoder.load_weights(list(vae_weights.items()), strict=True)
if self.shard_clip.start_layer != -1:
clip_weights = mx.load(self.config_clip.weight_files[0])
clip_weights = self.text_encoder.sanitize(clip_weights)
self.text_encoder.load_weights(list(clip_weights.items()), strict=True)
if self.shard_unet.start_layer !=-1:
unet_weights = mx.load(self.config_unet.weight_files[0])
unet_weights = self.unet.sanitize(unet_weights)
self.unet.load_weights(list(unet_weights.items()), strict=True)
def model_shards(shard:ShardConfig):
def create_shard(shard, model_ranges):
start_layer = shard.start_layer
end_layer = shard.end_layer
shards = {}
for model_name, (range_start, range_end) in model_ranges.items():
if start_layer < range_end and end_layer >= range_start:
# Calculate the overlap with the model range
overlap_start = max(start_layer, range_start)
overlap_end = min(end_layer, range_end - 1)
# Adjust the layers relative to the model's range
relative_start = overlap_start - range_start
relative_end = overlap_end - range_start
shards[model_name] = Shard(model_name, relative_start, relative_end, range_end - range_start)
else:
# If no overlap, create a zero-layer shard
shards[model_name] = Shard(model_name, -1, -1, range_end - range_start)
return shards
# Define the ranges for different models
model_ranges = {
'clip': (0, 12),
'vae_encoder':(12,17),
'unet':(17,26),
'vae_decoder': (26, 31) # Example range for unet
}
# Call the function and get the shards for all models
shards = create_shard(shard, model_ranges)
# Access individual shards
shard_clip = shards['clip']
shard_encoder = shards['vae_encoder']
shard_unet = shards['unet']
shard_decoder = shards['vae_decoder']
return shard_clip, shard_encoder, shard_unet, shard_decoder

View File

@@ -0,0 +1,117 @@
from dataclasses import dataclass, field
import mlx.core as mx
import mlx.nn as nn
from mlx_lm.models.base import create_attention_mask
from mlx_lm.models.phi3 import TransformerBlock, ModelArgs
from ...shard import Shard
from .base import IdentityBlock
@dataclass
class ModelArgs(ModelArgs):
shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
def __post_init__(self):
super().__post_init__()
if isinstance(self.shard, Shard):
return
if not isinstance(self.shard, dict):
raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
self.shard = Shard(**self.shard)
class Phi3Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
assert self.vocab_size > 0
if self.args.shard.is_first_layer():
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = []
for i in range(self.num_hidden_layers):
if self.args.shard.start_layer <= i <= self.args.shard.end_layer:
self.layers.append(TransformerBlock(args=args))
else:
self.layers.append(IdentityBlock())
if self.args.shard.is_last_layer():
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,
inputs: mx.array,
cache=None,
):
if self.args.shard.is_first_layer():
h = self.embed_tokens(inputs)
else:
h = inputs
mask = None
if h.shape[1] > 1:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
if self.args.shard.is_last_layer():
h = self.norm(h)
return h
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.model = Phi3Model(args)
if self.args.shard.is_last_layer():
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
cache=None,
):
out = self.model(inputs, cache)
if self.args.shard.is_last_layer():
out = self.lm_head(out)
return out
def sanitize(self, weights):
shard_state_dict = {}
for key, value in weights.items():
if "self_attn.rope.inv_freq" in key:
continue
if key.startswith('model.layers.'):
layer_num = int(key.split('.')[2])
if self.args.shard.start_layer <= layer_num <= self.args.shard.end_layer:
shard_state_dict[key] = value
elif self.args.shard.is_first_layer() and key.startswith('model.embed_tokens'):
shard_state_dict[key] = value
elif self.args.shard.is_last_layer() and (key.startswith('lm_head') or key.startswith('model.norm')):
shard_state_dict[key] = value
return shard_state_dict
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@@ -9,13 +9,12 @@ from mlx_lm.models.qwen2 import TransformerBlock, ModelArgs
from ...shard import Shard
from .base import IdentityBlock
@dataclass
class ModelArgs(ModelArgs):
shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
def __post_init__(self):
super().__post_init__() # Ensure parent initializations are respected
super().__post_init__()
if isinstance(self.shard, Shard):
return
@@ -24,7 +23,6 @@ class ModelArgs(ModelArgs):
self.shard = Shard(**self.shard)
class Qwen2Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
@@ -32,14 +30,17 @@ class Qwen2Model(nn.Module):
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
assert self.vocab_size > 0
if self.args.shard.is_first_layer():
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = []
for i in range(self.num_hidden_layers):
if self.args.shard.start_layer <= i <= self.args.shard.end_layer:
self.layers.append(TransformerBlock(args=args))
else:
self.layers.append(IdentityBlock())
if self.args.shard.is_last_layer():
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

View File

@@ -0,0 +1,191 @@
# Adapted from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/clip.py
import math
from dataclasses import dataclass
from typing import List, Optional
import mlx.core as mx
import mlx.nn as nn
from dataclasses import field, dataclass
from exo.inference.shard import Shard
from exo.inference.mlx.models.base import IdentityBlock
_ACTIVATIONS = {"quick_gelu": nn.gelu_fast_approx, "gelu": nn.gelu}
@dataclass
class CLIPTextModelConfig:
num_layers: int = 23
model_dims: int = 1024
num_heads: int = 16
max_length: int = 77
vocab_size: int = 49408
projection_dim: Optional[int] = None
hidden_act: str = "quick_gelu"
@classmethod
def from_dict(cls, config):
return ModelArgs(
num_layers=config["num_hidden_layers"],
model_dims=config["hidden_size"],
num_heads=config["num_attention_heads"],
max_length=config["max_position_embeddings"],
vocab_size=config["vocab_size"],
projection_dim=config["projection_dim"] if "WithProjection" in config['architectures'][0] else None,
hidden_act=config.get("hidden_act", "quick_gelu"),
weight_files=config.get("weight_files", [])
)
@dataclass
class ModelArgs(CLIPTextModelConfig):
shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
weight_files: List[str] = field(default_factory=lambda: [])
def __post_init__(self):
if isinstance(self.shard, dict):
self.shard = Shard(**self.shard)
if not isinstance(self.shard, Shard):
raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
if not self.shard.is_first_layer():
self.vision_config = None
@dataclass
class CLIPOutput:
pooled_output: Optional[mx.array] = None
last_hidden_state: Optional[mx.array] = None
hidden_states: Optional[List[mx.array]] = None
class CLIPEncoderLayer(nn.Module):
"""The transformer encoder layer from CLIP."""
def __init__(self, model_dims: int, num_heads: int, activation: str):
super().__init__()
self.layer_norm1 = nn.LayerNorm(model_dims)
self.layer_norm2 = nn.LayerNorm(model_dims)
self.attention = nn.MultiHeadAttention(model_dims, num_heads)
self.attention.query_proj.bias = mx.zeros(model_dims)
self.attention.key_proj.bias = mx.zeros(model_dims)
self.attention.value_proj.bias = mx.zeros(model_dims)
self.attention.out_proj.bias = mx.zeros(model_dims)
self.linear1 = nn.Linear(model_dims, 4 * model_dims)
self.linear2 = nn.Linear(4 * model_dims, model_dims)
self.act = _ACTIVATIONS[activation]
def __call__(self, x, attn_mask=None):
y = self.layer_norm1(x)
y = self.attention(y, y, y, attn_mask)
x = y + x
y = self.layer_norm2(x)
y = self.linear1(y)
y = self.act(y)
y = self.linear2(y)
x = y + x
return x
class CLIPTextModel(nn.Module):
"""Implements the text encoder transformer from CLIP."""
def __init__(self, config: CLIPTextModelConfig, shard: Shard):
super().__init__()
self.shard = shard
self.layers_range = range(self.shard.start_layer*2, self.shard.end_layer*2+2)
if self.shard.is_first_layer():
self.token_embedding = nn.Embedding(config.vocab_size, config.model_dims)
self.position_embedding = nn.Embedding(config.max_length, config.model_dims)
self.layers = []
for i in range(math.ceil(config.num_layers/2)):
if 2*i in self.layers_range:
self.layers.append(CLIPEncoderLayer(config.model_dims, config.num_heads, config.hidden_act))
if 2*i+1 in self.layers_range and 2*i+1 < config.num_layers:
self.layers.append(CLIPEncoderLayer(config.model_dims, config.num_heads, config.hidden_act))
else:
self.layers.append(IdentityBlock())
if self.shard.is_last_layer():
self.final_layer_norm = nn.LayerNorm(config.model_dims)
if config.projection_dim is not None:
self.text_projection = nn.Linear(
config.model_dims, config.projection_dim, bias=False
)
def _get_mask(self, N, dtype):
indices = mx.arange(N)
mask = indices[:, None] < indices[None]
mask = mask.astype(dtype) * (-6e4 if dtype == mx.float16 else -1e9)
return mask
def __call__(self, x, mask=None):
# Extract some shapes
if self.shard.is_first_layer():
B, N = x.shape
eos_tokens = x.argmax(-1)
# Compute the embeddings
x = self.token_embedding(x)
x = x + self.position_embedding.weight[:N]
# Compute the features from the transformer
mask = self._get_mask(N, x.dtype)
for l in self.layers:
x = l(x, mask)
# Apply the final layernorm and return
if self.shard.is_last_layer():
x = self.final_layer_norm(x)
return x, mask
def sanitize(self, weights):
sanitized_weights = {}
for key, value in weights.items():
if "position_ids" in key:
continue
if key.startswith("text_model."):
key = key[11:]
if key.startswith("embeddings."):
key = key[11:]
if key.startswith("encoder."):
key = key[8:]
# Map attention layers
if "self_attn." in key:
key = key.replace("self_attn.", "attention.")
if "q_proj." in key:
key = key.replace("q_proj.", "query_proj.")
if "k_proj." in key:
key = key.replace("k_proj.", "key_proj.")
if "v_proj." in key:
key = key.replace("v_proj.", "value_proj.")
# Map ffn layers
if "mlp.fc1" in key:
key = key.replace("mlp.fc1", "linear1")
if "mlp.fc2" in key:
key = key.replace("mlp.fc2", "linear2")
if key.startswith("layers."):
layer_num = int(key.split(".")[1])
if layer_num not in self.layers_range:
continue
if not self.shard.is_first_layer() and "embedding" in key:
continue
if not self.shard.is_last_layer() and key.startswith("final_layer_norm"):
continue
if not self.shard.is_last_layer() and key.startswith("text_projection"):
continue
sanitized_weights[key] = value
return sanitized_weights
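Note the 2x mapping in CLIPTextModel.__init__: each logical pipeline layer of the 'clip' shard covers two transformer encoder layers (layers_range runs from start_layer*2 to end_layer*2 + 1), capped at the 23 encoder layers of the SD 2.1 text encoder. The sketch below (a hypothetical helper, illustrative only) makes that mapping explicit.

def clip_encoder_layers(start_layer: int, end_layer: int, num_layers: int = 23) -> list:
  # Encoder layers owned by a clip shard, following layers_range above.
  owned = range(start_layer * 2, end_layer * 2 + 2)
  return [i for i in owned if i < num_layers]

# A shard spanning pipeline layers 0-11 would own all 23 encoder layers:
print(len(clip_encoder_layers(0, 11)))  # 23
# A shard spanning only layers 0-5 would own encoder layers 0-11:
print(clip_encoder_layers(0, 5))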

View File

@@ -0,0 +1,131 @@
# adapted from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/tokenizer.py
import regex
import json
import glob
class Tokenizer:
"""A simple port of CLIPTokenizer from https://github.com/huggingface/transformers/ ."""
def __init__(self, bpe_ranks, vocab):
self.bpe_ranks = bpe_ranks
self.vocab = vocab
self.pat = regex.compile(
r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
regex.IGNORECASE,
)
self._cache = {self.bos: self.bos, self.eos: self.eos}
@property
def bos(self):
return "<|startoftext|>"
@property
def bos_token(self):
return self.vocab[self.bos]
@property
def eos(self):
return "<|endoftext|>"
@property
def eos_token(self):
return self.vocab[self.eos]
def bpe(self, text):
if text in self._cache:
return self._cache[text]
unigrams = list(text[:-1]) + [text[-1] + "</w>"]
unique_bigrams = set(zip(unigrams, unigrams[1:]))
if not unique_bigrams:
return unigrams
# In every iteration try to merge the two most likely bigrams. If none
# was merged we are done.
#
# Ported from https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/tokenization_clip.py
while unique_bigrams:
bigram = min(
unique_bigrams, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))
)
if bigram not in self.bpe_ranks:
break
new_unigrams = []
skip = False
for a, b in zip(unigrams, unigrams[1:]):
if skip:
skip = False
continue
if (a, b) == bigram:
new_unigrams.append(a + b)
skip = True
else:
new_unigrams.append(a)
if not skip:
new_unigrams.append(b)
unigrams = new_unigrams
unique_bigrams = set(zip(unigrams, unigrams[1:]))
self._cache[text] = unigrams
return unigrams
def tokenize(self, text, prepend_bos=True, append_eos=True):
if isinstance(text, list):
return [self.tokenize(t, prepend_bos, append_eos) for t in text]
# Lower case cleanup and split according to self.pat. Hugging Face does
# a much more thorough job here but this should suffice for 95% of
# cases.
clean_text = regex.sub(r"\s+", " ", text.lower())
tokens = regex.findall(self.pat, clean_text)
# Split the tokens according to the byte-pair merge file
bpe_tokens = [ti for t in tokens for ti in self.bpe(t)]
# Map to token ids and return
tokens = [self.vocab[t] for t in bpe_tokens]
if prepend_bos:
tokens = [self.bos_token] + tokens
if append_eos:
tokens.append(self.eos_token)
return tokens
def encode(self, prompt):
tokens = [self.tokenize(prompt)]
negative_text = ""
if negative_text is not None:
tokens += [self.tokenize(negative_text)]
lengths = [len(t) for t in tokens]
N = max(lengths)
tokens = [t + [0] * (N - len(t)) for t in tokens]
return tokens
def load_tokenizer(
model_path: str,
vocab_key: str = "tokenizer_vocab",
merges_key: str = "tokenizer_merges",
):
vocab_file = glob.glob(str(model_path/"tokenizer"/vocab_key))[0]
with open(vocab_file, encoding="utf-8") as f:
vocab = json.load(f)
merges_file = glob.glob(str(model_path/"tokenizer"/merges_key))[0]
with open(merges_file, encoding="utf-8") as f:
bpe_merges = f.read().strip().split("\n")[1 : 49152 - 256 - 2 + 1]
bpe_merges = [tuple(m.split()) for m in bpe_merges]
bpe_ranks = dict(map(reversed, enumerate(bpe_merges)))
return Tokenizer(bpe_ranks, vocab)
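The bpe method above repeatedly merges the highest-ranked adjacent pair until no known merge remains. Below is a toy, self-contained replay of that loop using a made-up two-entry rank table rather than the real CLIP merges file.

toy_ranks = {("l", "o"): 0, ("lo", "w</w>"): 1}  # made-up ranks, lower = merged first
unigrams = ["l", "o", "w</w>"]
while True:
  bigrams = set(zip(unigrams, unigrams[1:]))
  if not bigrams:
    break
  best = min(bigrams, key=lambda pair: toy_ranks.get(pair, float("inf")))
  if best not in toy_ranks:
    break
  merged, skip = [], False
  for a, b in zip(unigrams, unigrams[1:]):
    if skip:
      skip = False
      continue
    if (a, b) == best:
      merged.append(a + b)
      skip = True
    else:
      merged.append(a)
  if not skip:
    merged.append(b)
  unigrams = merged
print(unigrams)  # ['low</w>']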

View File

@@ -0,0 +1,629 @@
# Adapted from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/unet.py
import math
from typing import Optional
import mlx.core as mx
import mlx.nn as nn
from dataclasses import dataclass, field
from typing import Tuple, Optional, List
from exo.inference.shard import Shard
@dataclass
class UNetConfig:
in_channels: int = 4
out_channels: int = 4
conv_in_kernel: int = 3
conv_out_kernel: int = 3
block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
layers_per_block: Tuple[int] = (2, 2, 2, 2)
mid_block_layers: int = 2
transformer_layers_per_block: Tuple[int] = (1, 1, 1, 1)
num_attention_heads: Tuple[int] = (5, 10, 20, 20)
cross_attention_dim: Tuple[int] = (1024,) * 4
norm_num_groups: int = 32
down_block_types: Tuple[str] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
)
up_block_types: Tuple[str] = (
"UpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
)
addition_embed_type: Optional[str] = None
addition_time_embed_dim: Optional[int] = None
projection_class_embeddings_input_dim: Optional[int] = None
weight_files: List[str] = field(default_factory=lambda: [])
@classmethod
def from_dict(cls,config):
n_blocks = len(config['block_out_channels'])
return UNetConfig(
in_channels=config["in_channels"],
out_channels=config["out_channels"],
block_out_channels=config["block_out_channels"],
layers_per_block=[config["layers_per_block"]] * n_blocks,
transformer_layers_per_block=config.get(
"transformer_layers_per_block", (1,) * 4
),
num_attention_heads=(
[config["attention_head_dim"]] * n_blocks
if isinstance(config["attention_head_dim"], int)
else config["attention_head_dim"]
),
cross_attention_dim=[config["cross_attention_dim"]] * n_blocks,
norm_num_groups=config["norm_num_groups"],
down_block_types=config["down_block_types"],
up_block_types=config["up_block_types"][::-1],
addition_embed_type=config.get("addition_embed_type", None),
addition_time_embed_dim=config.get("addition_time_embed_dim", None),
projection_class_embeddings_input_dim=config.get(
"projection_class_embeddings_input_dim", None
),
weight_files=config.get("weight_files", [])
)
def upsample_nearest(x, scale: int = 2):
B, H, W, C = x.shape
x = mx.broadcast_to(x[:, :, None, :, None, :], (B, H, scale, W, scale, C))
x = x.reshape(B, H * scale, W * scale, C)
return x
class TimestepEmbedding(nn.Module):
def __init__(self, in_channels: int, time_embed_dim: int):
super().__init__()
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
def __call__(self, x):
x = self.linear_1(x)
x = nn.silu(x)
x = self.linear_2(x)
return x
class TransformerBlock(nn.Module):
def __init__(
self,
model_dims: int,
num_heads: int,
hidden_dims: Optional[int] = None,
memory_dims: Optional[int] = None,
):
super().__init__()
self.norm1 = nn.LayerNorm(model_dims)
self.attn1 = nn.MultiHeadAttention(model_dims, num_heads)
self.attn1.out_proj.bias = mx.zeros(model_dims)
memory_dims = memory_dims or model_dims
self.norm2 = nn.LayerNorm(model_dims)
self.attn2 = nn.MultiHeadAttention(
model_dims, num_heads, key_input_dims=memory_dims
)
self.attn2.out_proj.bias = mx.zeros(model_dims)
hidden_dims = hidden_dims or 4 * model_dims
self.norm3 = nn.LayerNorm(model_dims)
self.linear1 = nn.Linear(model_dims, hidden_dims)
self.linear2 = nn.Linear(model_dims, hidden_dims)
self.linear3 = nn.Linear(hidden_dims, model_dims)
def __call__(self, x, memory, attn_mask, memory_mask):
# Self attention
y = self.norm1(x)
y = self.attn1(y, y, y, attn_mask)
x = x + y
# Cross attention
y = self.norm2(x)
y = self.attn2(y, memory, memory, memory_mask)
x = x + y
# FFN
y = self.norm3(x)
y_a = self.linear1(y)
y_b = self.linear2(y)
y = y_a * nn.gelu(y_b)
y = self.linear3(y)
x = x + y
return x
class Transformer2D(nn.Module):
"""A transformer model for inputs with 2 spatial dimensions."""
def __init__(
self,
in_channels: int,
model_dims: int,
encoder_dims: int,
num_heads: int,
num_layers: int = 1,
norm_num_groups: int = 32,
):
super().__init__()
self.norm = nn.GroupNorm(norm_num_groups, in_channels, pytorch_compatible=True)
self.proj_in = nn.Linear(in_channels, model_dims)
self.transformer_blocks = [
TransformerBlock(model_dims, num_heads, memory_dims=encoder_dims)
for i in range(num_layers)
]
self.proj_out = nn.Linear(model_dims, in_channels)
def __call__(self, x, encoder_x, attn_mask, encoder_attn_mask):
# Save the input to add to the output
input_x = x
dtype = x.dtype
# Perform the input norm and projection
B, H, W, C = x.shape
x = self.norm(x.astype(mx.float32)).astype(dtype).reshape(B, -1, C)
x = self.proj_in(x)
# Apply the transformer
for block in self.transformer_blocks:
x = block(x, encoder_x, attn_mask, encoder_attn_mask)
# Apply the output projection and reshape
x = self.proj_out(x)
x = x.reshape(B, H, W, C)
return x + input_x
class ResnetBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: Optional[int] = None,
groups: int = 32,
temb_channels: Optional[int] = None,
):
super().__init__()
out_channels = out_channels or in_channels
self.norm1 = nn.GroupNorm(groups, in_channels, pytorch_compatible=True)
self.conv1 = nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=1, padding=1
)
if temb_channels is not None:
self.time_emb_proj = nn.Linear(temb_channels, out_channels)
self.norm2 = nn.GroupNorm(groups, out_channels, pytorch_compatible=True)
self.conv2 = nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=1, padding=1
)
if in_channels != out_channels:
self.conv_shortcut = nn.Linear(in_channels, out_channels)
def __call__(self, x, temb=None):
dtype = x.dtype
if temb is not None:
temb = self.time_emb_proj(nn.silu(temb))
y = self.norm1(x.astype(mx.float32)).astype(dtype)
y = nn.silu(y)
y = self.conv1(y)
if temb is not None:
y = y + temb[:, None, None, :]
y = self.norm2(y.astype(mx.float32)).astype(dtype)
y = nn.silu(y)
y = self.conv2(y)
x = y + (x if "conv_shortcut" not in self else self.conv_shortcut(x))
return x
class UNetBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
temb_channels: int,
prev_out_channels: Optional[int] = None,
num_layers: int = 1,
transformer_layers_per_block: int = 1,
num_attention_heads: int = 8,
cross_attention_dim=1280,
resnet_groups: int = 32,
add_downsample=True,
add_upsample=True,
add_cross_attention=True,
):
super().__init__()
# Prepare the in channels list for the resnets
if prev_out_channels is None:
in_channels_list = [in_channels] + [out_channels] * (num_layers - 1)
else:
in_channels_list = [prev_out_channels] + [out_channels] * (num_layers - 1)
res_channels_list = [out_channels] * (num_layers - 1) + [in_channels]
in_channels_list = [
a + b for a, b in zip(in_channels_list, res_channels_list)
]
# Add resnet blocks that also process the time embedding
self.resnets = [
ResnetBlock2D(
in_channels=ic,
out_channels=out_channels,
temb_channels=temb_channels,
groups=resnet_groups,
)
for ic in in_channels_list
]
# Add optional cross attention layers
if add_cross_attention:
self.attentions = [
Transformer2D(
in_channels=out_channels,
model_dims=out_channels,
num_heads=num_attention_heads,
num_layers=transformer_layers_per_block,
encoder_dims=cross_attention_dim,
)
for i in range(num_layers)
]
# Add an optional downsampling layer
if add_downsample:
self.downsample = nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=2, padding=1
)
# or upsampling layer
if add_upsample:
self.upsample = nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=1, padding=1
)
def __call__(
self,
x,
encoder_x=None,
temb=None,
attn_mask=None,
encoder_attn_mask=None,
residual_hidden_states=None,
):
output_states = []
for i in range(len(self.resnets)):
if residual_hidden_states is not None:
x = mx.concatenate([x, residual_hidden_states.pop()], axis=-1)
x = self.resnets[i](x, temb)
if "attentions" in self:
x = self.attentions[i](x, encoder_x, attn_mask, encoder_attn_mask)
output_states.append(x)
if "downsample" in self:
x = self.downsample(x)
output_states.append(x)
if "upsample" in self:
x = self.upsample(upsample_nearest(x))
output_states.append(x)
return x, output_states
class UNetModel(nn.Module):
"""The conditional 2D UNet model that actually performs the denoising."""
def __init__(self, config: UNetConfig, shard: Shard):
super().__init__()
self.shard = shard
self.start_layer = shard.start_layer
self.end_layer = shard.end_layer
self.layers_range = list(range(self.start_layer, self.end_layer+1))
if shard.is_first_layer():
self.conv_in = nn.Conv2d(
config.in_channels,
config.block_out_channels[0],
config.conv_in_kernel,
padding=(config.conv_in_kernel - 1) // 2,
)
self.timesteps = nn.SinusoidalPositionalEncoding(
config.block_out_channels[0],
max_freq=1,
min_freq=math.exp(
-math.log(10000) + 2 * math.log(10000) / config.block_out_channels[0]
),
scale=1.0,
cos_first=True,
full_turns=False,
)
self.time_embedding = TimestepEmbedding(
config.block_out_channels[0],
config.block_out_channels[0] * 4,
)
if config.addition_embed_type == "text_time":
self.add_time_proj = nn.SinusoidalPositionalEncoding(
config.addition_time_embed_dim,
max_freq=1,
min_freq=math.exp(
-math.log(10000)
+ 2 * math.log(10000) / config.addition_time_embed_dim
),
scale=1.0,
cos_first=True,
full_turns=False,
)
self.add_embedding = TimestepEmbedding(
config.projection_class_embeddings_input_dim,
config.block_out_channels[0] * 4,
)
# Make the downsampling blocks
block_channels = [config.block_out_channels[0]] + list(
config.block_out_channels
)
self.down_blocks = []
for i, (in_channels, out_channels) in enumerate(zip(block_channels, block_channels[1:])):
if i in self.layers_range:
self.down_blocks.append(
UNetBlock2D(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=config.block_out_channels[0] * 4,
num_layers=config.layers_per_block[i],
transformer_layers_per_block=config.transformer_layers_per_block[i],
num_attention_heads=config.num_attention_heads[i],
cross_attention_dim=config.cross_attention_dim[i],
resnet_groups=config.norm_num_groups,
add_downsample=(i < len(config.block_out_channels) - 1),
add_upsample=False,
add_cross_attention="CrossAttn" in config.down_block_types[i],
)
)
else:
self.down_blocks.append(nn.Identity())
# Make the middle block
if 4 in self.layers_range:
self.mid_blocks = [
ResnetBlock2D(
in_channels=config.block_out_channels[-1],
out_channels=config.block_out_channels[-1],
temb_channels=config.block_out_channels[0] * 4,
groups=config.norm_num_groups,
),
Transformer2D(
in_channels=config.block_out_channels[-1],
model_dims=config.block_out_channels[-1],
num_heads=config.num_attention_heads[-1],
num_layers=config.transformer_layers_per_block[-1],
encoder_dims=config.cross_attention_dim[-1],
),
ResnetBlock2D(
in_channels=config.block_out_channels[-1],
out_channels=config.block_out_channels[-1],
temb_channels=config.block_out_channels[0] * 4,
groups=config.norm_num_groups,
),
]
# Make the upsampling blocks
block_channels = (
[config.block_out_channels[0]]
+ list(config.block_out_channels)
+ [config.block_out_channels[-1]]
)
total_items = len(block_channels) - 3
reversed_channels = list(reversed(list(zip(block_channels, block_channels[1:], block_channels[2:]))))
self.up_blocks = []
for rev_i, (in_channels, out_channels, prev_out_channels) in enumerate(reversed_channels):
i = total_items - rev_i
if rev_i+5 in self.layers_range:
self.up_blocks.append(
UNetBlock2D(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=config.block_out_channels[0] * 4,
prev_out_channels=prev_out_channels,
num_layers=config.layers_per_block[i] + 1,
transformer_layers_per_block=config.transformer_layers_per_block[i],
num_attention_heads=config.num_attention_heads[i],
cross_attention_dim=config.cross_attention_dim[i],
resnet_groups=config.norm_num_groups,
add_downsample=False,
add_upsample=(i > 0),
add_cross_attention="CrossAttn" in config.up_block_types[i],
)
)
else:
self.up_blocks.append(nn.Identity())
if shard.is_last_layer():
self.conv_norm_out = nn.GroupNorm(
config.norm_num_groups,
config.block_out_channels[0],
pytorch_compatible=True,
)
self.conv_out = nn.Conv2d(
config.block_out_channels[0],
config.out_channels,
config.conv_out_kernel,
padding=(config.conv_out_kernel - 1) // 2,
)
def __call__(
self,
x,
timestep,
encoder_x,
attn_mask=None,
encoder_attn_mask=None,
text_time=None,
residuals=None,
):
# Compute the time embeddings
temb = self.timesteps(timestep).astype(x.dtype)
temb = self.time_embedding(temb)
# Add the extra text_time conditioning
if text_time is not None:
text_emb, time_ids = text_time
emb = self.add_time_proj(time_ids).flatten(1).astype(x.dtype)
emb = mx.concatenate([text_emb, emb], axis=-1)
emb = self.add_embedding(emb)
temb = temb + emb
if self.shard.is_first_layer():
# Preprocess the input
x = self.conv_in(x)
residuals = [x]
# Run the downsampling part of the unet
for i in range(len(self.down_blocks)):
if i in self.layers_range:
x, res = self.down_blocks[i](
x,
encoder_x=encoder_x,
temb=temb,
attn_mask=attn_mask,
encoder_attn_mask=encoder_attn_mask,
)
residuals.extend(res)
else:
x = self.down_blocks[i](x)
if 4 in self.layers_range:
# Run the middle part of the unet
x = self.mid_blocks[0](x, temb)
x = self.mid_blocks[1](x, encoder_x, attn_mask, encoder_attn_mask)
x = self.mid_blocks[2](x, temb)
# Run the upsampling part of the unet
for i in range(len(self.up_blocks)):
if i+5 in self.layers_range:
x, _ = self.up_blocks[i](
x,
encoder_x=encoder_x,
temb=temb,
attn_mask=attn_mask,
encoder_attn_mask=encoder_attn_mask,
residual_hidden_states=residuals,
)
else:
x = self.up_blocks[i](x)
# Postprocess the output
if self.shard.is_last_layer():
dtype = x.dtype
x = self.conv_norm_out(x.astype(mx.float32)).astype(dtype)
x = nn.silu(x)
x = self.conv_out(x)
return x, residuals
def sanitize(self, weights):
sanitized_weights = {}
for key, value in weights.items():
k1=""
k2=""
if "downsamplers" in key:
key = key.replace("downsamplers.0.conv", "downsample")
if "upsamplers" in key:
key = key.replace("upsamplers.0.conv", "upsample")
# Map the mid block
if "mid_block.resnets.0" in key:
key = key.replace("mid_block.resnets.0", "mid_blocks.0")
if "mid_block.attentions.0" in key:
key = key.replace("mid_block.attentions.0", "mid_blocks.1")
if "mid_block.resnets.1" in key:
key = key.replace("mid_block.resnets.1", "mid_blocks.2")
# Map attention layers
if "to_k" in key:
key = key.replace("to_k", "key_proj")
if "to_out.0" in key:
key = key.replace("to_out.0", "out_proj")
if "to_q" in key:
key = key.replace("to_q", "query_proj")
if "to_v" in key:
key = key.replace("to_v", "value_proj")
# Map transformer ffn
if "ff.net.2" in key:
key = key.replace("ff.net.2", "linear3")
if "ff.net.0" in key:
k1 = key.replace("ff.net.0.proj", "linear1")
k2 = key.replace("ff.net.0.proj", "linear2")
v1, v2 = mx.split(value, 2)
if "conv_shortcut.weight" in key:
value = value.squeeze()
# Transform the weights from 1x1 convs to linear
if len(value.shape) == 4 and ("proj_in" in key or "proj_out" in key):
value = value.squeeze()
if len(value.shape) == 4:
value = value.transpose(0, 2, 3, 1)
value = value.reshape(-1).reshape(value.shape)
if key.startswith("conv_in") :
if 0 not in self.layers_range:
continue
if key.startswith("down_blocks"):
layer_num = int(key.split(".")[1])
if layer_num not in self.layers_range:
continue
if key.startswith("mid_block"):
if 4 not in self.layers_range:
continue
if key.startswith("up_blocks"):
layer_num = int(key.split(".")[1])
if (layer_num+5) not in self.layers_range:
continue
if key.startswith("conv_out") or key.startswith("conv_norm_out"):
if 8 not in self.layers_range:
continue
if len(k1)>0:
sanitized_weights[k1] = v1
sanitized_weights[k2] = v2
else:
sanitized_weights[key] = value
return sanitized_weights
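Reading the guards in UNetModel.__init__ and sanitize together, the 'unet' shard uses nine logical layer indices: 0-3 for the down blocks (0 also covers conv_in on the first shard), 4 for the mid block, and 5-8 for the up blocks (8 also covers conv_norm_out/conv_out on the last shard). The mapping below is inferred from those guards and is illustrative only.

UNET_LAYER_ROLES = {
  0: "down_block_0 (+ conv_in on the first shard)",
  1: "down_block_1",
  2: "down_block_2",
  3: "down_block_3",
  4: "mid_block",
  5: "up_block_0",
  6: "up_block_1",
  7: "up_block_2",
  8: "up_block_3 (+ conv_norm_out / conv_out on the last shard)",
}

def describe_unet_shard(start_layer: int, end_layer: int) -> list:
  return [UNET_LAYER_ROLES[i] for i in range(start_layer, end_layer + 1)]

# The 'unet' range (17, 26) in the pipeline table spans 9 layers, i.e. all of 0-8 here.
print(describe_unet_shard(0, 8))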

View File

@@ -0,0 +1,429 @@
# Adapted from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/vae.py
import math
from typing import List
import mlx.core as mx
import mlx.nn as nn
from .unet import ResnetBlock2D, upsample_nearest
from dataclasses import dataclass, field
from exo.inference.shard import Shard
from typing import Tuple
import inspect
from ..base import IdentityBlock
@dataclass
class AutoencoderConfig:
in_channels: int = 3
out_channels: int = 3
latent_channels_out: int = 8
latent_channels_in: int = 4
block_out_channels: Tuple[int] = (128, 256, 512, 512)
layers_per_block: int = 2
norm_num_groups: int = 32
scaling_factor: float = 0.18215
weight_files: List[str] = field(default_factory=lambda: [])
@classmethod
def from_dict(cls, params):
return cls(**{k: v for k, v in params.items() if k in inspect.signature(cls).parameters})
@dataclass
class ModelArgs(AutoencoderConfig):
shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
def __post_init__(self):
if isinstance(self.shard, dict):
self.shard = Shard(**self.shard)
if not isinstance(self.shard, Shard):
raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
if not self.shard.is_first_layer():
self.vision_config = None
class Attention(nn.Module):
"""A single head unmasked attention for use with the VAE."""
def __init__(self, dims: int, norm_groups: int = 32):
super().__init__()
self.group_norm = nn.GroupNorm(norm_groups, dims, pytorch_compatible=True)
self.query_proj = nn.Linear(dims, dims)
self.key_proj = nn.Linear(dims, dims)
self.value_proj = nn.Linear(dims, dims)
self.out_proj = nn.Linear(dims, dims)
def __call__(self, x):
B, H, W, C = x.shape
y = self.group_norm(x)
queries = self.query_proj(y).reshape(B, H * W, C)
keys = self.key_proj(y).reshape(B, H * W, C)
values = self.value_proj(y).reshape(B, H * W, C)
scale = 1 / math.sqrt(queries.shape[-1])
scores = (queries * scale) @ keys.transpose(0, 2, 1)
attn = mx.softmax(scores, axis=-1)
y = (attn @ values).reshape(B, H, W, C)
y = self.out_proj(y)
x = x + y
return x
class EncoderDecoderBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
num_layers: int = 1,
resnet_groups: int = 32,
add_downsample=True,
add_upsample=True,
):
super().__init__()
# Add the resnet blocks
self.resnets = [
ResnetBlock2D(
in_channels=in_channels if i == 0 else out_channels,
out_channels=out_channels,
groups=resnet_groups,
)
for i in range(num_layers)
]
# Add an optional downsampling layer
if add_downsample:
self.downsample = nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=2, padding=0
)
# or upsampling layer
if add_upsample:
self.upsample = nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=1, padding=1
)
def __call__(self, x):
for resnet in self.resnets:
x = resnet(x)
if "downsample" in self:
x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)])
x = self.downsample(x)
if "upsample" in self:
x = self.upsample(upsample_nearest(x))
return x
class Encoder(nn.Module):
"""Implements the encoder side of the Autoencoder."""
def __init__(
self,
in_channels: int,
latent_channels_out: int,
block_out_channels: List[int] = [64],
layers_per_block: int = 2,
resnet_groups: int = 32,
layers_range: List[int] = [],
shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
):
super().__init__()
self.layers_range = layers_range
self.shard = shard
if self.shard.is_first_layer():
self.conv_in = nn.Conv2d(
in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1
)
channels = [block_out_channels[0]] + list(block_out_channels)
self.down_blocks = []
current_layer = 1
for i, (in_channels, out_channels) in enumerate(zip(channels, channels[1:])):
if current_layer in self.layers_range:
self.down_blocks.append(
EncoderDecoderBlock2D(
in_channels,
out_channels,
num_layers=layers_per_block,
resnet_groups=resnet_groups,
add_downsample=i < len(block_out_channels) - 1,
add_upsample=False,
)
)
else:
self.down_blocks.append(IdentityBlock())
current_layer += 1
if self.shard.is_last_layer():
self.mid_blocks = [
ResnetBlock2D(
in_channels=block_out_channels[-1],
out_channels=block_out_channels[-1],
groups=resnet_groups,
),
Attention(block_out_channels[-1], resnet_groups),
ResnetBlock2D(
in_channels=block_out_channels[-1],
out_channels=block_out_channels[-1],
groups=resnet_groups,
),
]
self.conv_norm_out = nn.GroupNorm(
resnet_groups, block_out_channels[-1], pytorch_compatible=True
)
self.conv_out = nn.Conv2d(block_out_channels[-1], latent_channels_out, 3, padding=1)
def __call__(self, x):
if self.shard.is_first_layer():
x = self.conv_in(x)
for l in self.down_blocks:
x = l(x)
if self.shard.is_last_layer():
x = self.mid_blocks[0](x)
x = self.mid_blocks[1](x)
x = self.mid_blocks[2](x)
x = self.conv_norm_out(x)
x = nn.silu(x)
x = self.conv_out(x)
return x
class Decoder(nn.Module):
"""Implements the decoder side of the Autoencoder."""
def __init__(
self,
in_channels: int,
out_channels: int,
shard: Shard,
layer_range: List[int],
block_out_channels: List[int] = [64],
layers_per_block: int = 2,
resnet_groups: int = 32,
):
super().__init__()
self.out_channels = out_channels
self.layers_range = layer_range
if 0 in layer_range:
self.conv_in = nn.Conv2d(
in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1
)
if 0 in layer_range:
self.mid_blocks = [
ResnetBlock2D(
in_channels=block_out_channels[-1],
out_channels=block_out_channels[-1],
groups=resnet_groups,
),
Attention(block_out_channels[-1], resnet_groups),
ResnetBlock2D(
in_channels=block_out_channels[-1],
out_channels=block_out_channels[-1],
groups=resnet_groups,
),
]
channels = list(reversed(block_out_channels))
channels = [channels[0]] + channels
self.up_blocks = []
current_layer = 1
for i, (in_channels, out_channels) in enumerate(zip(channels, channels[1:])):
if current_layer in layer_range:
self.up_blocks.append(
EncoderDecoderBlock2D(
in_channels,
out_channels,
num_layers=layers_per_block,
resnet_groups=resnet_groups,
add_downsample=False,
add_upsample=i < len(block_out_channels) - 1,
)
)
else:
self.up_blocks.append(IdentityBlock())
current_layer += 1
if 4 in layer_range:
self.conv_norm_out = nn.GroupNorm(
resnet_groups, block_out_channels[0], pytorch_compatible=True
)
self.conv_out = nn.Conv2d(block_out_channels[0], self.out_channels, 3, padding=1)
def __call__(self, x):
if 0 in self.layers_range:
x = self.conv_in(x)
x = self.mid_blocks[0](x)
x = self.mid_blocks[1](x)
x = self.mid_blocks[2](x)
for l in self.up_blocks:
x = l(x)
if 4 in self.layers_range:
x = self.conv_norm_out(x)
x = nn.silu(x)
x = self.conv_out(x)
return x
class Autoencoder(nn.Module):
"""The autoencoder that allows us to perform diffusion in the latent space."""
def __init__(self, config: AutoencoderConfig, shard: Shard, model_shard: str):
super().__init__()
self.shard = shard
self.start_layer = shard.start_layer
self.end_layer = shard.end_layer
self.layers_range = list(range(self.start_layer, self.end_layer+1))
self.latent_channels = config.latent_channels_in
self.scaling_factor = config.scaling_factor
self.model_shard = model_shard
if self.model_shard == "vae_encoder":
self.encoder = Encoder(
config.in_channels,
config.latent_channels_out,
config.block_out_channels,
config.layers_per_block,
resnet_groups=config.norm_num_groups,
layers_range=self.layers_range,
shard=shard
)
if self.shard.is_last_layer():
self.quant_proj = nn.Linear(
config.latent_channels_out, config.latent_channels_out
)
if self.model_shard == "vae_decoder":
self.decoder = Decoder(
config.latent_channels_in,
config.out_channels,
shard,
self.layers_range,
config.block_out_channels,
config.layers_per_block + 1,
resnet_groups=config.norm_num_groups,
)
if self.shard.is_first_layer():
self.post_quant_proj = nn.Linear(
config.latent_channels_in, config.latent_channels_in
)
def decode(self, z):
if self.shard.is_first_layer():
z = z / self.scaling_factor
z = self.post_quant_proj(z)
return self.decoder(z)
def encode(self, x):
x = self.encoder(x)
if self.shard.is_last_layer():
x = self.quant_proj(x)
mean, logvar = x.split(2, axis=-1)
mean = mean * self.scaling_factor
logvar = logvar + 2 * math.log(self.scaling_factor)
x = mean
return x
def __call__(self, x, key=None):
mean, logvar = self.encode(x)
z = mx.random.normal(mean.shape, key=key) * mx.exp(0.5 * logvar) + mean
x_hat = self.decode(z)
return dict(x_hat=x_hat, z=z, mean=mean, logvar=logvar)
def sanitize(self, weights):
shard = self.shard
layers = self.layers_range
sanitized_weights = {}
for key, value in weights.items():
if "downsamplers" in key:
key = key.replace("downsamplers.0.conv", "downsample")
if "upsamplers" in key:
key = key.replace("upsamplers.0.conv", "upsample")
# Map attention layers
if "key" in key:
key = key.replace("key", "key_proj")
if "proj_attn" in key:
key = key.replace("proj_attn", "out_proj")
if "query" in key:
key = key.replace("query", "query_proj")
if "value" in key:
key = key.replace("value", "value_proj")
# Map the mid block
if "mid_block.resnets.0" in key:
key = key.replace("mid_block.resnets.0", "mid_blocks.0")
if "mid_block.attentions.0" in key:
key = key.replace("mid_block.attentions.0", "mid_blocks.1")
if "mid_block.resnets.1" in key:
key = key.replace("mid_block.resnets.1", "mid_blocks.2")
# Map the quant/post_quant layers
if "quant_conv" in key:
key = key.replace("quant_conv", "quant_proj")
value = value.squeeze()
# Map the conv_shortcut to linear
if "conv_shortcut.weight" in key:
value = value.squeeze()
if len(value.shape) == 4:
value = value.transpose(0, 2, 3, 1)
value = value.reshape(-1).reshape(value.shape)
if "post_quant_conv" in key :
key = key.replace("quant_conv", "quant_proj")
value = value.squeeze()
if 'decoder' in key and self.model_shard == "vae_decoder":
if key.startswith("decoder.mid_blocks."):
if 0 in layers:
sanitized_weights[key] = value
if "conv_in" in key and 0 in layers:
sanitized_weights[key] = value
if key.startswith("decoder.up_blocks."):
layer_num = int(key.split(".")[2])+1
if layer_num in layers:
sanitized_weights[key] = value
if key.startswith("decoder.conv_norm_out") and 4 in layers:
sanitized_weights[key] = value
if key.startswith("decoder.conv_out") and 4 in layers:
sanitized_weights[key] = value
if self.model_shard == "vae_decoder":
if key.startswith("post_quant_proj") and 0 in layers:
sanitized_weights[key] = value
if self.model_shard == "vae_encoder":
if key.startswith("encoder."):
if "conv_in" in key and shard.is_first_layer():
sanitized_weights[key] = value
if key.startswith("encoder.down_blocks."):
layer_num = int(key.split(".")[2])+1
if layer_num in layers:
sanitized_weights[key] = value
if key.startswith("encoder.mid_blocks.") and shard.is_last_layer():
sanitized_weights[key] = value
if "conv_norm_out" in key and shard.is_last_layer():
sanitized_weights[key] = value
if "conv_out" in key and shard.is_last_layer():
sanitized_weights[key] = value
if key.startswith("quant_proj") and shard.is_last_layer():
sanitized_weights[key] = value
return sanitized_weights
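For reference, Autoencoder.encode above splits the 8 projected latent channels into a mean/logvar pair and rescales both by scaling_factor (0.18215). A small numpy-only sketch of that arithmetic follows; it is an assumption-level mirror of the code above, not exo code.

import numpy as np

def encode_latents(x: np.ndarray, scaling_factor: float = 0.18215):
  # x has latent_channels_out (8) channels in the last axis; split into 4 + 4.
  mean, logvar = np.split(x, 2, axis=-1)
  mean = mean * scaling_factor
  logvar = logvar + 2 * np.log(scaling_factor)
  return mean, logvar

mean, logvar = encode_latents(np.random.randn(1, 64, 64, 8).astype(np.float32))
print(mean.shape, logvar.shape)  # (1, 64, 64, 4) (1, 64, 64, 4)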

View File

@@ -77,13 +77,17 @@ class MLXDynamicShardInferenceEngine(InferenceEngine):
await self.ensure_shard(shard)
await asyncio.get_running_loop().run_in_executor(self.executor, self.model.load_weights, path)
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
await self.ensure_shard(shard)
loop = asyncio.get_running_loop()
state = await self.poll_state(request_id)
state = await self.poll_state(request_id) if self.model.model_type != 'StableDiffusionPipeline' else {}
x = mx.array(input_data)
output_data: np.ndarray = np.array(await loop.run_in_executor(self.executor, lambda: self.model(x, **state)))
return output_data
if self.model.model_type != 'StableDiffusionPipeline':
output_data = await loop.run_in_executor(self.executor, lambda: self.model(x, **state, **(inference_state or {})))
else:
output_data, inference_state = await loop.run_in_executor(self.executor, lambda: self.model(x, **state, **(inference_state or {})))
output_data = np.array(output_data)
return output_data, inference_state
async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce"):
await self.ensure_shard(shard)

View File

@@ -62,8 +62,16 @@ def _get_classes(config: dict):
def load_config(model_path: Path) -> dict:
try:
with open(model_path/"config.json", "r") as f:
config = json.load(f)
config_path = model_path / "config.json"
if config_path.exists():
with open(config_path, "r") as f:
config = json.load(f)
return config
model_index_path = model_path / "model_index.json"
if model_index_path.exists():
config = load_model_index(model_path, model_index_path)
return config
except FileNotFoundError:
logging.error(f"Config file not found in {model_path}")
raise
@@ -110,6 +118,24 @@ def load_model_shard(
# Try weight for back-compat
weight_files = glob.glob(str(model_path/"weight*.safetensors"))
model_class, model_args_class = _get_classes(config=config)
class ShardedModel(model_class):
def __init__(self, args):
super().__init__(args)
self.shard = Shard(args.shard.model_id, args.shard.start_layer, args.shard.end_layer, args.shard.n_layers)
def __call__(self, x, *args, **kwargs):
y = super().__call__(x, *args, **kwargs)
return y
model_args = model_args_class.from_dict(config)
model = ShardedModel(model_args)
if config.get("model_index", False):
model.load()
return model
if not weight_files:
logging.error(f"No safetensors found in {model_path}")
raise FileNotFoundError(f"No safetensors found in {model_path}")
@@ -129,19 +155,7 @@ def load_model_shard(
weights.update(mx.load(wf))
model_class, model_args_class = _get_classes(config=config)
class ShardedModel(model_class):
def __init__(self, args):
super().__init__(args)
self.shard = Shard(args.shard.model_id, args.shard.start_layer, args.shard.end_layer, args.shard.n_layers)
def __call__(self, x, *args, **kwargs):
y = super().__call__(x, *args, **kwargs)
return y
model_args = model_args_class.from_dict(config)
model = ShardedModel(model_args)
if hasattr(model, "sanitize"):
weights = model.sanitize(weights)
@@ -186,6 +200,9 @@ async def load_shard(
processor.eos_token_id = processor.tokenizer.eos_token_id
processor.encode = processor.tokenizer.encode
return model, processor
elif hasattr(model, "tokenizer"):
tokenizer = model.tokenizer
return model, tokenizer
else:
tokenizer = await resolve_tokenizer(model_path)
return model, tokenizer
@@ -214,3 +231,27 @@ async def get_image_from_str(_image_str: str):
return img
else:
raise ValueError("Invalid image_str format. Must be a URL or a base64 encoded image.")
# loading a combined config for all models in the index
def load_model_index(model_path: Path, model_index_path: Path):
models_config = {}
with open(model_index_path, "r") as f:
model_index = json.load(f)
models_config["model_index"] = True
models_config["model_type"] = model_index["_class_name"]
models_config["models"] = {}
for model in model_index.keys():
model_config_path = glob.glob(str(model_path / model / "*config.json"))
if len(model_config_path) > 0:
with open(model_config_path[0], "r") as f:
model_config = { }
model_config["model_type"] = model
model_config["config"] = json.load(f)
model_config["path"] = model_path / model
if model_config["path"]/"*model.safetensors":
model_config["config"].update({"weight_files": list(glob.glob(str(model_config["path"]/"*model.safetensors")))})
model_config["path"] = str(model_path / model)
m = {}
m[model] = model_config
models_config.update(m)
return models_config
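load_model_index flattens model_index.json plus each sub-model's config into a single dict. The structure below is an assumed illustration of its output for stable-diffusion-2-1-base; keys and values are indicative, not verbatim.

combined_config = {
  "model_index": True,
  "model_type": "StableDiffusionPipeline",  # taken from _class_name in model_index.json
  "models": {},
  # one top-level entry per sub-model directory that has a *config.json:
  "text_encoder": {
    "model_type": "text_encoder",
    "config": {"hidden_size": 1024, "num_hidden_layers": 23, "weight_files": ["..."]},
    "path": "/path/to/stable-diffusion-2-1-base/text_encoder",
  },
  # ... plus "unet", "vae", "tokenizer", and the other entries from model_index.json
}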

View File

@@ -1,22 +1,16 @@
import pytest
import json
import numpy as np
from exo.inference.dummy_inference_engine import DummyInferenceEngine
from exo.inference.shard import Shard
class MockShardDownloader:
async def ensure_shard(self, shard):
pass
@pytest.mark.asyncio
async def test_dummy_inference_specific():
engine = DummyInferenceEngine(MockShardDownloader())
engine = DummyInferenceEngine()
test_shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)
test_prompt = "This is a test prompt"
result = await engine.infer_prompt("test_request", test_shard, test_prompt)
result, _ = await engine.infer_prompt("test_request", test_shard, test_prompt)
print(f"Inference result shape: {result.shape}")
@@ -26,20 +20,20 @@ async def test_dummy_inference_specific():
@pytest.mark.asyncio
async def test_dummy_inference_engine():
# Initialize the DummyInferenceEngine
engine = DummyInferenceEngine(MockShardDownloader())
engine = DummyInferenceEngine()
# Create a test shard
shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)
# Test infer_prompt
output = await engine.infer_prompt("test_id", shard, "Test prompt")
output, _ = await engine.infer_prompt("test_id", shard, "Test prompt")
assert isinstance(output, np.ndarray), "Output should be a numpy array"
assert output.ndim == 2, "Output should be 2-dimensional"
# Test infer_tensor
input_tensor = np.array([[1, 2, 3]])
output = await engine.infer_tensor("test_id", shard, input_tensor)
output, _ = await engine.infer_tensor("test_id", shard, input_tensor)
assert isinstance(output, np.ndarray), "Output should be a numpy array"
assert output.ndim == 2, "Output should be 2-dimensional"
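These tests also document the new engine contract: infer_prompt and infer_tensor now return an (output, inference_state) tuple. A hedged end-to-end sketch using the dummy engine is shown below; the shard values mirror the test above, and the printed state is engine-dependent.

import asyncio
import numpy as np
from exo.inference.dummy_inference_engine import DummyInferenceEngine
from exo.inference.shard import Shard

async def main():
  engine = DummyInferenceEngine()
  shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)
  out, state = await engine.infer_tensor("req-1", shard, np.array([[1, 2, 3]]))
  print(out.shape, state)  # 2-D numpy output plus whatever state the engine returns

asyncio.run(main())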

View File

@@ -11,30 +11,30 @@ import numpy as np
# An inference engine should work the same for any number of Shards, as long as the Shards are contiguous.
async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str, n_layers: int):
prompt = "In a single word only, what is the last name of the current president of the USA?"
resp_full = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers), prompt=prompt)
resp_full, _ = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers), prompt=prompt)
token_full = await inference_engine_1.sample(resp_full)
token_full = token_full.reshape(1, -1)
next_resp_full = await inference_engine_1.infer_tensor(
next_resp_full, _ = await inference_engine_1.infer_tensor(
"A",
shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers),
input_data=token_full,
)
pp = n_layers // 2
resp1 = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers), prompt=prompt)
resp2 = await inference_engine_2.infer_tensor(
resp1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers), prompt=prompt)
resp2, _ = await inference_engine_2.infer_tensor(
"B",
shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=n_layers - 1, n_layers=n_layers),
input_data=resp1,
)
tokens2 = await inference_engine_1.sample(resp2)
tokens2 = tokens2.reshape(1, -1)
resp3 = await inference_engine_1.infer_tensor(
resp3, _ = await inference_engine_1.infer_tensor(
"B",
shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers),
input_data=tokens2,
)
resp4 = await inference_engine_2.infer_tensor(
resp4, _ = await inference_engine_2.infer_tensor(
"B",
shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=n_layers - 1, n_layers=n_layers),
input_data=resp3,

View File

@@ -15,7 +15,7 @@ from .stateful_model import make_prompt_state
from .losses import length_masked_ce_loss
from collections import OrderedDict
import asyncio
from typing import Optional
Tensor.no_grad = True
# default settings
TEMPERATURE = int(os.getenv("TEMPERATURE", 0.85))
@@ -104,7 +104,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
state_dict = await asyncio.get_running_loop().run_in_executor(self.executor, get_state_dict, self.model)
safe_save(state_dict, path)
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
await self.ensure_shard(shard)
def wrap_infer():
x = Tensor(input_data)
@@ -114,7 +114,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
self.states[request_id].start += x.shape[1]
return out.realize()
output_data = await asyncio.get_running_loop().run_in_executor(self.executor, wrap_infer)
return output_data.numpy()
return output_data.numpy(), inference_state
async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss=length_masked_ce_loss):
def step(x, y, l):

View File

@@ -14,7 +14,7 @@ class DummyTokenizer:
self.eos_token_id = 69
self.vocab_size = 1000
def apply_chat_template(self, messages, tokenize=True, add_generation_prompt=True):
def apply_chat_template(self, conversation, tokenize=True, add_generation_prompt=True, tools=None, **kwargs):
return "dummy_tokenized_prompt"
def encode(self, text):

View File

@@ -69,6 +69,7 @@ parser.add_argument("--default-temp", type=float, help="Default token sampling t
parser.add_argument("--tailscale-api-key", type=str, default=None, help="Tailscale API key")
parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name")
parser.add_argument("--node-id-filter", type=str, default=None, help="Comma separated list of allowed node IDs (only for UDP and Tailscale discovery)")
parser.add_argument("--system-prompt", type=str, default=None, help="System prompt for the ChatGPT API")
args = parser.parse_args()
print(f"Selected inference engine: {args.inference_engine}")
@@ -146,10 +147,11 @@ api = ChatGPTAPI(
inference_engine.__class__.__name__,
response_timeout=args.chatgpt_api_response_timeout,
on_chat_completion_request=lambda req_id, __, prompt: topology_viz.update_prompt(req_id, prompt) if topology_viz else None,
default_model=args.default_model
default_model=args.default_model,
system_prompt=args.system_prompt
)
node.on_token.register("update_topology_viz").on_next(
lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") else None
lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") and inference_engine.shard.model_id != 'stable-diffusion-2-1-base' else None
)
def preemptively_start_download(request_id: str, opaque_status: str):

View File

@@ -92,14 +92,17 @@ model_cards = {
"llava-1.5-7b-hf": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "llava-hf/llava-1.5-7b-hf", }, },
### qwen
"qwen-2.5-0.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-0.5B-Instruct-4bit", }, },
"qwen-2.5-1.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-1.5B-Instruct-4bit", }, },
"qwen-2.5-coder-1.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-1.5B-Instruct-4bit", }, },
"qwen-2.5-3b": { "layers": 36, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-3B-Instruct-4bit", }, },
"qwen-2.5-coder-3b": { "layers": 36, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-3B-Instruct-4bit", }, },
"qwen-2.5-coder-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-7B-Instruct-4bit", }, },
"qwen-2.5-coder-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-14B-Instruct-4bit", }, },
"qwen-2.5-coder-32b": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-32B-Instruct-4bit", }, },
"qwen-2.5-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-7B-Instruct-4bit", }, },
"qwen-2.5-coder-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-7B-Instruct-4bit", }, },
"qwen-2.5-math-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Math-7B-Instruct-4bit", }, },
"qwen-2.5-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-14B-Instruct-4bit", }, },
"qwen-2.5-coder-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-14B-Instruct-4bit", }, },
"qwen-2.5-32b": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-32B-Instruct-4bit", }, },
"qwen-2.5-coder-32b": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-32B-Instruct-4bit", }, },
"qwen-2.5-72b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-72B-Instruct-4bit", }, },
"qwen-2.5-math-72b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Math-72B-Instruct-4bit", }, },
### nemotron
@@ -108,6 +111,11 @@ model_cards = {
# gemma
"gemma2-9b": { "layers": 42, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/gemma-2-9b-it-4bit", }, },
"gemma2-27b": { "layers": 46, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/gemma-2-27b-it-4bit", }, },
# stable diffusion
"stable-diffusion-2-1-base": { "layers": 31, "repo": { "MLXDynamicShardInferenceEngine": "stabilityai/stable-diffusion-2-1-base" } },
# phi
"phi-3.5-mini": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Phi-3.5-mini-instruct-4bit", }, },
"phi-4": { "layers": 40, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/phi-4-4bit", }, },
# dummy
"dummy": { "layers": 8, "repo": { "DummyInferenceEngine": "dummy", }, },
}
@@ -133,18 +141,24 @@ pretty_name = {
"deepseek-coder-v2-lite": "Deepseek Coder V2 Lite",
"deepseek-coder-v2.5": "Deepseek Coder V2.5",
"llava-1.5-7b-hf": "LLaVa 1.5 7B (Vision Model)",
"qwen-2.5-1.5b": "Qwen 2.5 1.5B",
"qwen-2.5-coder-1.5b": "Qwen 2.5 Coder 1.5B",
"qwen-2.5-3b": "Qwen 2.5 3B",
"qwen-2.5-coder-3b": "Qwen 2.5 Coder 3B",
"qwen-2.5-coder-7b": "Qwen 2.5 Coder 7B",
"qwen-2.5-coder-14b": "Qwen 2.5 Coder 14B",
"qwen-2.5-coder-32b": "Qwen 2.5 Coder 32B",
"qwen-2.5-7b": "Qwen 2.5 7B",
"qwen-2.5-coder-7b": "Qwen 2.5 Coder 7B",
"qwen-2.5-math-7b": "Qwen 2.5 7B (Math)",
"qwen-2.5-14b": "Qwen 2.5 14B",
"qwen-2.5-coder-14b": "Qwen 2.5 Coder 14B",
"qwen-2.5-32b": "Qwen 2.5 32B",
"qwen-2.5-coder-32b": "Qwen 2.5 Coder 32B",
"qwen-2.5-72b": "Qwen 2.5 72B",
"qwen-2.5-math-72b": "Qwen 2.5 72B (Math)",
"phi-3.5-mini": "Phi-3.5 Mini",
"phi-4": "Phi-4",
"llama-3-8b": "Llama 3 8B",
"llama-3-70b": "Llama 3 70B",
"stable-diffusion-2-1-base": "Stable Diffusion 2.1",
}
def get_repo(model_id: str, inference_engine_classname: str) -> Optional[str]:
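A hedged usage sketch of the catalogue above, assuming get_repo simply looks the model up in model_cards for the given engine class; the exo.models import path is an assumption, since file names are not shown in this view.

from exo.models import get_repo, pretty_name  # import path assumed

print(get_repo("phi-4", "MLXDynamicShardInferenceEngine"))  # expected: mlx-community/phi-4-4bit
print(pretty_name["stable-diffusion-2-1-base"])             # expected: Stable Diffusion 2.1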

View File

@@ -11,6 +11,13 @@ from exo.inference.shard import Shard
from exo.topology.topology import Topology
from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
from exo.helpers import DEBUG
import json
import platform
if platform.system().lower() == "darwin" and platform.machine().lower() == "arm64":
import mlx.core as mx
else:
import numpy as mx
class GRPCPeerHandle(PeerHandle):
@@ -36,11 +43,9 @@ class GRPCPeerHandle(PeerHandle):
async def connect(self):
if self.channel is None:
self.channel = grpc.aio.insecure_channel(self.address, options=[
("grpc.max_metadata_size", 32*1024*1024),
('grpc.max_receive_message_length', 32*1024*1024),
('grpc.max_send_message_length', 32*1024*1024)
])
self.channel = grpc.aio.insecure_channel(
self.address, options=[("grpc.max_metadata_size", 32*1024*1024), ('grpc.max_receive_message_length', 32*1024*1024), ('grpc.max_send_message_length', 32*1024*1024)]
)
self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
await self.channel.channel_ready()
@@ -71,7 +76,7 @@ class GRPCPeerHandle(PeerHandle):
traceback.print_exc()
return False
async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.array]:
async def send_prompt(self, shard: Shard, prompt: str, inference_state: Optional[dict] = None, request_id: Optional[str] = None) -> Optional[np.array]:
request = node_service_pb2.PromptRequest(
prompt=prompt,
shard=node_service_pb2.Shard(
@@ -81,6 +86,7 @@ class GRPCPeerHandle(PeerHandle):
n_layers=shard.n_layers,
),
request_id=request_id,
inference_state=None if inference_state is None else self.serialize_inference_state(inference_state)
)
response = await self.stub.SendPrompt(request)
@@ -89,7 +95,7 @@ class GRPCPeerHandle(PeerHandle):
return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
async def send_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> Optional[np.array]:
async def send_tensor(self, shard: Shard, tensor: np.ndarray, inference_state: Optional[dict] = None, request_id: Optional[str] = None) -> Optional[np.array]:
request = node_service_pb2.TensorRequest(
shard=node_service_pb2.Shard(
model_id=shard.model_id,
@@ -99,6 +105,7 @@ class GRPCPeerHandle(PeerHandle):
),
tensor=node_service_pb2.Tensor(tensor_data=tensor.tobytes(), shape=tensor.shape, dtype=str(tensor.dtype)),
request_id=request_id,
inference_state=None if inference_state is None else self.serialize_inference_state(inference_state)
)
response = await self.stub.SendTensor(request)
@@ -106,7 +113,7 @@ class GRPCPeerHandle(PeerHandle):
return None
return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
async def send_example(self, shard: Shard, example: np.ndarray, target: np.ndarray, length: np.ndarray, train: bool, request_id: Optional[str] = None) -> Optional[np.array]:
request = node_service_pb2.ExampleRequest(
shard=node_service_pb2.Shard(
@@ -128,7 +135,7 @@ class GRPCPeerHandle(PeerHandle):
return loss, grads
else:
return loss
async def send_loss(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> Optional[np.array]:
request = node_service_pb2.TensorRequest(
shard=node_service_pb2.Shard(
@@ -163,10 +170,7 @@ class GRPCPeerHandle(PeerHandle):
topology = Topology()
for node_id, capabilities in response.nodes.items():
device_capabilities = DeviceCapabilities(
model=capabilities.model,
chip=capabilities.chip,
memory=capabilities.memory,
flops=DeviceFlops(fp16=capabilities.flops.fp16, fp32=capabilities.flops.fp32, int8=capabilities.flops.int8)
model=capabilities.model, chip=capabilities.chip, memory=capabilities.memory, flops=DeviceFlops(fp16=capabilities.flops.fp16, fp32=capabilities.flops.fp32, int8=capabilities.flops.int8)
)
topology.update_node(node_id, device_capabilities)
for node_id, peer_connections in response.peer_graph.items():
@@ -175,9 +179,35 @@ class GRPCPeerHandle(PeerHandle):
return topology
async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
request = node_service_pb2.SendResultRequest(request_id=request_id, result=result, is_finished=is_finished)
tensor = None
if isinstance(result, np.ndarray):
tensor = node_service_pb2.Tensor(tensor_data=result.tobytes(), shape=result.shape, dtype=str(result.dtype))
result = []
request = node_service_pb2.SendResultRequest(request_id=request_id, result=result, tensor=tensor, is_finished=is_finished)
await self.stub.SendResult(request)
async def send_opaque_status(self, request_id: str, status: str) -> None:
request = node_service_pb2.SendOpaqueStatusRequest(request_id=request_id, status=status)
await self.stub.SendOpaqueStatus(request)
def serialize_inference_state(self, inference_state: dict) -> node_service_pb2.InferenceState:
proto_inference_state = node_service_pb2.InferenceState()
other_data = {}
for k, v in inference_state.items():
if isinstance(v, mx.array):
np_array = np.array(v)
tensor_data = node_service_pb2.Tensor(tensor_data=np_array.tobytes(), shape=list(np_array.shape), dtype=str(np_array.dtype))
proto_inference_state.tensor_data[k].CopyFrom(tensor_data)
elif isinstance(v, list) and all(isinstance(item, mx.array) for item in v):
tensor_list = node_service_pb2.TensorList()
for tensor in v:
np_array = np.array(tensor)
tensor_data = node_service_pb2.Tensor(tensor_data=np_array.tobytes(), shape=list(np_array.shape), dtype=str(np_array.dtype))
tensor_list.tensors.append(tensor_data)
proto_inference_state.tensor_list_data[k].CopyFrom(tensor_list)
else:
# For non-tensor data, we'll still use JSON
other_data[k] = v
if other_data:
proto_inference_state.other_data_json = json.dumps(other_data)
return proto_inference_state
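
For reference, a minimal sketch of the new InferenceState wire format that serialize_inference_state produces and the server-side deserialize_inference_state (further down) consumes: tensors go into the tensor_data map, everything else rides in other_data_json. The key names and values here are illustrative, and it assumes the regenerated node_service_pb2 module is importable.

import json
import numpy as np
from exo.networking.grpc import node_service_pb2

# Pack one tensor plus some plain JSON-serializable fields, mirroring serialize_inference_state.
latents = np.zeros((1, 4, 8, 8), dtype=np.float32)
state = node_service_pb2.InferenceState()
state.tensor_data["latents"].CopyFrom(
  node_service_pb2.Tensor(tensor_data=latents.tobytes(), shape=list(latents.shape), dtype=str(latents.dtype))
)
state.other_data_json = json.dumps({"step": 0, "total_steps": 50, "is_step_finished": False})

# Unpack it the same way deserialize_inference_state does on the receiving node.
decoded = {k: np.frombuffer(t.tensor_data, dtype=t.dtype).reshape(t.shape) for k, t in state.tensor_data.items()}
decoded.update(json.loads(state.other_data_json))
assert decoded["latents"].shape == (1, 4, 8, 8) and decoded["total_steps"] == 50
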

View File

@@ -3,11 +3,19 @@ from concurrent import futures
import numpy as np
from asyncio import CancelledError
import platform
from . import node_service_pb2
from . import node_service_pb2_grpc
from exo import DEBUG
from exo.inference.shard import Shard
from exo.orchestration import Node
import json
if platform.system().lower() == "darwin" and platform.machine().lower() == "arm64":
import mlx.core as mx
else:
import numpy as mx
class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
@@ -50,7 +58,8 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
)
prompt = request.prompt
request_id = request.request_id
result = await self.node.process_prompt(shard, prompt, request_id)
inference_state = self.deserialize_inference_state(request.inference_state) if request.HasField("inference_state") else None
result = await self.node.process_prompt(shard, prompt, request_id, inference_state)
if DEBUG >= 5: print(f"SendPrompt {shard=} {prompt=} {request_id=} result: {result}")
tensor_data = result.tobytes() if result is not None else None
return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
@@ -65,11 +74,13 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
tensor = np.frombuffer(request.tensor.tensor_data, dtype=np.dtype(request.tensor.dtype)).reshape(request.tensor.shape)
request_id = request.request_id
result = await self.node.process_tensor(shard, tensor, request_id)
inference_state = self.deserialize_inference_state(request.inference_state) if request.HasField("inference_state") else None
result = await self.node.process_tensor(shard, tensor, request_id, inference_state)
if DEBUG >= 5: print(f"SendTensor tensor {shard=} {tensor=} {request_id=} result: {result}")
tensor_data = result.tobytes() if result is not None else None
return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
async def SendExample(self, request, context):
shard = Shard(
model_id=request.shard.model_id,
@@ -91,7 +102,7 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
else:
loss = await self.node.process_example(shard, example, target, length, train, request_id)
return node_service_pb2.Loss(loss=loss, grads=None)
async def CollectTopology(self, request, context):
max_depth = request.max_depth
visited = set(request.visited)
@@ -107,12 +118,7 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
for node_id, cap in topology.nodes.items()
}
peer_graph = {
node_id: node_service_pb2.PeerConnections(
connections=[
node_service_pb2.PeerConnection(to_id=conn.to_id, description=conn.description)
for conn in connections
]
)
node_id: node_service_pb2.PeerConnections(connections=[node_service_pb2.PeerConnection(to_id=conn.to_id, description=conn.description) for conn in connections])
for node_id, connections in topology.peer_graph.items()
}
if DEBUG >= 5: print(f"CollectTopology {max_depth=} {visited=} {nodes=} {peer_graph=}")
@@ -122,7 +128,11 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
request_id = request.request_id
result = request.result
is_finished = request.is_finished
img = request.tensor
if DEBUG >= 5: print(f"Received SendResult request: {request_id=} {result=} {is_finished=}")
result = list(result)
if len(img.tensor_data) > 0:
result = np.frombuffer(img.tensor_data, dtype=np.dtype(img.dtype)).reshape(img.shape)
self.node.on_token.trigger_all(request_id, result, is_finished)
return node_service_pb2.Empty()
@@ -135,3 +145,19 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
async def HealthCheck(self, request, context):
return node_service_pb2.HealthCheckResponse(is_healthy=True)
def deserialize_inference_state(self, inference_state_proto: node_service_pb2.InferenceState) -> dict:
inference_state = {}
for k, tensor_data in inference_state_proto.tensor_data.items():
np_array = np.frombuffer(tensor_data.tensor_data, dtype=tensor_data.dtype).reshape(tensor_data.shape)
inference_state[k] = mx.array(np_array)
for k, tensor_list in inference_state_proto.tensor_list_data.items():
inference_state[k] = [mx.array(np.frombuffer(tensor.tensor_data, dtype=tensor.dtype).reshape(tensor.shape)) for tensor in tensor_list.tensors]
if inference_state_proto.other_data_json:
other_data = json.loads(inference_state_proto.other_data_json)
inference_state.update(other_data)
return inference_state

View File

@@ -24,12 +24,14 @@ message PromptRequest {
Shard shard = 1;
string prompt = 2;
optional string request_id = 3;
optional InferenceState inference_state = 4;
}
message TensorRequest {
Shard shard = 1;
Tensor tensor = 2;
optional string request_id = 3;
optional InferenceState inference_state = 4;
}
message ExampleRequest {
@@ -61,6 +63,16 @@ message Tensor {
string dtype = 3;
}
message TensorList {
repeated Tensor tensors = 1;
}
message InferenceState {
map<string, Tensor> tensor_data = 1;
map<string, TensorList> tensor_list_data = 2;
string other_data_json = 3;
}
message CollectTopologyRequest {
repeated string visited = 1;
int32 max_depth = 2;
@@ -96,7 +108,8 @@ message DeviceCapabilities {
message SendResultRequest {
string request_id = 1;
repeated int32 result = 2;
bool is_finished = 3;
optional Tensor tensor = 3;
bool is_finished = 4;
}
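
A small sketch of how the updated SendResultRequest is populated on each path, matching GRPCPeerHandle.send_result above: token ids go in result, while an image result travels in the new optional tensor field. Values are illustrative and the regenerated node_service_pb2 module is assumed.

import numpy as np
from exo.networking.grpc import node_service_pb2

# Token path: a list of generated token ids, no tensor attached.
token_req = node_service_pb2.SendResultRequest(request_id="req-1", result=[1, 2, 3], is_finished=False)

# Image path (e.g. stable diffusion): result stays empty and the image rides in `tensor`.
img = np.zeros((512, 512, 3), dtype=np.float32)
image_req = node_service_pb2.SendResultRequest(
  request_id="req-1",
  result=[],
  tensor=node_service_pb2.Tensor(tensor_data=img.tobytes(), shape=img.shape, dtype=str(img.dtype)),
  is_finished=True,
)
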
message SendOpaqueStatusRequest {

View File

File diff suppressed because one or more lines are too long

View File

@@ -3,7 +3,7 @@
import grpc
import warnings
from . import node_service_pb2 as node__service__pb2
from exo.networking.grpc import node_service_pb2 as exo_dot_networking_dot_grpc_dot_node__service__pb2
GRPC_GENERATED_VERSION = '1.68.0'
GRPC_VERSION = grpc.__version__
@@ -18,7 +18,7 @@ except ImportError:
if _version_not_supported:
raise RuntimeError(
f'The grpc package installed is at version {GRPC_VERSION},'
+ f' but the generated code in node_service_pb2_grpc.py depends on'
+ f' but the generated code in exo/networking/grpc/node_service_pb2_grpc.py depends on'
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
@@ -36,43 +36,43 @@ class NodeServiceStub(object):
"""
self.SendPrompt = channel.unary_unary(
'/node_service.NodeService/SendPrompt',
request_serializer=node__service__pb2.PromptRequest.SerializeToString,
response_deserializer=node__service__pb2.Tensor.FromString,
request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.PromptRequest.SerializeToString,
response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.FromString,
_registered_method=True)
self.SendTensor = channel.unary_unary(
'/node_service.NodeService/SendTensor',
request_serializer=node__service__pb2.TensorRequest.SerializeToString,
response_deserializer=node__service__pb2.Tensor.FromString,
request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.TensorRequest.SerializeToString,
response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.FromString,
_registered_method=True)
self.SendExample = channel.unary_unary(
'/node_service.NodeService/SendExample',
request_serializer=node__service__pb2.ExampleRequest.SerializeToString,
response_deserializer=node__service__pb2.Loss.FromString,
request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.ExampleRequest.SerializeToString,
response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Loss.FromString,
_registered_method=True)
self.GetInferenceResult = channel.unary_unary(
'/node_service.NodeService/GetInferenceResult',
request_serializer=node__service__pb2.GetInferenceResultRequest.SerializeToString,
response_deserializer=node__service__pb2.InferenceResult.FromString,
request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.GetInferenceResultRequest.SerializeToString,
response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.InferenceResult.FromString,
_registered_method=True)
self.CollectTopology = channel.unary_unary(
'/node_service.NodeService/CollectTopology',
request_serializer=node__service__pb2.CollectTopologyRequest.SerializeToString,
response_deserializer=node__service__pb2.Topology.FromString,
request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.CollectTopologyRequest.SerializeToString,
response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Topology.FromString,
_registered_method=True)
self.SendResult = channel.unary_unary(
'/node_service.NodeService/SendResult',
request_serializer=node__service__pb2.SendResultRequest.SerializeToString,
response_deserializer=node__service__pb2.Empty.FromString,
request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.SendResultRequest.SerializeToString,
response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.FromString,
_registered_method=True)
self.SendOpaqueStatus = channel.unary_unary(
'/node_service.NodeService/SendOpaqueStatus',
request_serializer=node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
response_deserializer=node__service__pb2.Empty.FromString,
request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.FromString,
_registered_method=True)
self.HealthCheck = channel.unary_unary(
'/node_service.NodeService/HealthCheck',
request_serializer=node__service__pb2.HealthCheckRequest.SerializeToString,
response_deserializer=node__service__pb2.HealthCheckResponse.FromString,
request_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckRequest.SerializeToString,
response_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckResponse.FromString,
_registered_method=True)
@@ -132,43 +132,43 @@ def add_NodeServiceServicer_to_server(servicer, server):
rpc_method_handlers = {
'SendPrompt': grpc.unary_unary_rpc_method_handler(
servicer.SendPrompt,
request_deserializer=node__service__pb2.PromptRequest.FromString,
response_serializer=node__service__pb2.Tensor.SerializeToString,
request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.PromptRequest.FromString,
response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.SerializeToString,
),
'SendTensor': grpc.unary_unary_rpc_method_handler(
servicer.SendTensor,
request_deserializer=node__service__pb2.TensorRequest.FromString,
response_serializer=node__service__pb2.Tensor.SerializeToString,
request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.TensorRequest.FromString,
response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.SerializeToString,
),
'SendExample': grpc.unary_unary_rpc_method_handler(
servicer.SendExample,
request_deserializer=node__service__pb2.ExampleRequest.FromString,
response_serializer=node__service__pb2.Loss.SerializeToString,
request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.ExampleRequest.FromString,
response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Loss.SerializeToString,
),
'GetInferenceResult': grpc.unary_unary_rpc_method_handler(
servicer.GetInferenceResult,
request_deserializer=node__service__pb2.GetInferenceResultRequest.FromString,
response_serializer=node__service__pb2.InferenceResult.SerializeToString,
request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.GetInferenceResultRequest.FromString,
response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.InferenceResult.SerializeToString,
),
'CollectTopology': grpc.unary_unary_rpc_method_handler(
servicer.CollectTopology,
request_deserializer=node__service__pb2.CollectTopologyRequest.FromString,
response_serializer=node__service__pb2.Topology.SerializeToString,
request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.CollectTopologyRequest.FromString,
response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Topology.SerializeToString,
),
'SendResult': grpc.unary_unary_rpc_method_handler(
servicer.SendResult,
request_deserializer=node__service__pb2.SendResultRequest.FromString,
response_serializer=node__service__pb2.Empty.SerializeToString,
request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.SendResultRequest.FromString,
response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.SerializeToString,
),
'SendOpaqueStatus': grpc.unary_unary_rpc_method_handler(
servicer.SendOpaqueStatus,
request_deserializer=node__service__pb2.SendOpaqueStatusRequest.FromString,
response_serializer=node__service__pb2.Empty.SerializeToString,
request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.SendOpaqueStatusRequest.FromString,
response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.SerializeToString,
),
'HealthCheck': grpc.unary_unary_rpc_method_handler(
servicer.HealthCheck,
request_deserializer=node__service__pb2.HealthCheckRequest.FromString,
response_serializer=node__service__pb2.HealthCheckResponse.SerializeToString,
request_deserializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckRequest.FromString,
response_serializer=exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
@@ -196,8 +196,8 @@ class NodeService(object):
request,
target,
'/node_service.NodeService/SendPrompt',
node__service__pb2.PromptRequest.SerializeToString,
node__service__pb2.Tensor.FromString,
exo_dot_networking_dot_grpc_dot_node__service__pb2.PromptRequest.SerializeToString,
exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.FromString,
options,
channel_credentials,
insecure,
@@ -223,8 +223,8 @@ class NodeService(object):
request,
target,
'/node_service.NodeService/SendTensor',
node__service__pb2.TensorRequest.SerializeToString,
node__service__pb2.Tensor.FromString,
exo_dot_networking_dot_grpc_dot_node__service__pb2.TensorRequest.SerializeToString,
exo_dot_networking_dot_grpc_dot_node__service__pb2.Tensor.FromString,
options,
channel_credentials,
insecure,
@@ -250,8 +250,8 @@ class NodeService(object):
request,
target,
'/node_service.NodeService/SendExample',
node__service__pb2.ExampleRequest.SerializeToString,
node__service__pb2.Loss.FromString,
exo_dot_networking_dot_grpc_dot_node__service__pb2.ExampleRequest.SerializeToString,
exo_dot_networking_dot_grpc_dot_node__service__pb2.Loss.FromString,
options,
channel_credentials,
insecure,
@@ -277,8 +277,8 @@ class NodeService(object):
request,
target,
'/node_service.NodeService/GetInferenceResult',
node__service__pb2.GetInferenceResultRequest.SerializeToString,
node__service__pb2.InferenceResult.FromString,
exo_dot_networking_dot_grpc_dot_node__service__pb2.GetInferenceResultRequest.SerializeToString,
exo_dot_networking_dot_grpc_dot_node__service__pb2.InferenceResult.FromString,
options,
channel_credentials,
insecure,
@@ -304,8 +304,8 @@ class NodeService(object):
request,
target,
'/node_service.NodeService/CollectTopology',
node__service__pb2.CollectTopologyRequest.SerializeToString,
node__service__pb2.Topology.FromString,
exo_dot_networking_dot_grpc_dot_node__service__pb2.CollectTopologyRequest.SerializeToString,
exo_dot_networking_dot_grpc_dot_node__service__pb2.Topology.FromString,
options,
channel_credentials,
insecure,
@@ -331,8 +331,8 @@ class NodeService(object):
request,
target,
'/node_service.NodeService/SendResult',
node__service__pb2.SendResultRequest.SerializeToString,
node__service__pb2.Empty.FromString,
exo_dot_networking_dot_grpc_dot_node__service__pb2.SendResultRequest.SerializeToString,
exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
@@ -358,8 +358,8 @@ class NodeService(object):
request,
target,
'/node_service.NodeService/SendOpaqueStatus',
node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
node__service__pb2.Empty.FromString,
exo_dot_networking_dot_grpc_dot_node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
exo_dot_networking_dot_grpc_dot_node__service__pb2.Empty.FromString,
options,
channel_credentials,
insecure,
@@ -385,8 +385,8 @@ class NodeService(object):
request,
target,
'/node_service.NodeService/HealthCheck',
node__service__pb2.HealthCheckRequest.SerializeToString,
node__service__pb2.HealthCheckResponse.FromString,
exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckRequest.SerializeToString,
exo_dot_networking_dot_grpc_dot_node__service__pb2.HealthCheckResponse.FromString,
options,
channel_credentials,
insecure,

View File

@@ -1,7 +1,9 @@
import os
import asyncio
from exo.networking.discovery import Discovery
from typing import Dict, List, Callable
from typing import Dict, List, Callable, Optional
from concurrent.futures import ThreadPoolExecutor
from exo.networking.discovery import Discovery
from exo.topology.device_capabilities import DeviceCapabilities
from exo.networking.manual.network_topology_config import NetworkTopology, PeerConfig
from exo.helpers import DEBUG_DISCOVERY
@@ -13,28 +15,25 @@ class ManualDiscovery(Discovery):
self,
network_config_path: str,
node_id: str,
create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle],
create_peer_handle: Callable[[str, str, str, DeviceCapabilities], PeerHandle],
):
self.topology = NetworkTopology.from_path(network_config_path)
self.network_config_path = network_config_path
self.node_id = node_id
self.create_peer_handle = create_peer_handle
if node_id not in self.topology.peers:
raise ValueError(
f"Node ID {node_id} not found in network config file {network_config_path}. Please run with `node_id` set to one of the keys in the config file: {[k for k, _ in self.topology.peers]}"
)
self.listen_task = None
self.known_peers: Dict[str, PeerHandle] = {}
self.peers_in_network: Dict[str, PeerConfig] = self.topology.peers
self.peers_in_network.pop(node_id)
self._cached_peers: Dict[str, PeerConfig] = {}
self._last_modified_time: Optional[float] = None
self._file_executor = ThreadPoolExecutor(max_workers=1)
async def start(self) -> None:
self.listen_task = asyncio.create_task(self.task_find_peers_from_config())
async def stop(self) -> None:
if self.listen_task:
self.listen_task.cancel()
if self.listen_task: self.listen_task.cancel()
self._file_executor.shutdown(wait=True)
async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
if wait_for_peers > 0:
@@ -47,7 +46,9 @@ class ManualDiscovery(Discovery):
async def task_find_peers_from_config(self):
if DEBUG_DISCOVERY >= 2: print("Starting task to find peers from config...")
while True:
for peer_id, peer_config in self.peers_in_network.items():
peers_from_config = await self._get_peers()
new_known_peers = {}
for peer_id, peer_config in peers_from_config.items():
try:
if DEBUG_DISCOVERY >= 2: print(f"Checking peer {peer_id=} at {peer_config.address}:{peer_config.port}")
peer = self.known_peers.get(peer_id)
@@ -57,15 +58,44 @@ class ManualDiscovery(Discovery):
is_healthy = await peer.health_check()
if is_healthy:
if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is healthy.")
self.known_peers[peer_id] = peer
else:
if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy.")
try:
del self.known_peers[peer_id]
except KeyError:
pass
new_known_peers[peer_id] = peer
elif DEBUG_DISCOVERY >= 2:
print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy. Removing.")
except Exception as e:
if DEBUG_DISCOVERY >= 2: print(f"Exception occured when attempting to add {peer_id=}: {e}")
self.known_peers = new_known_peers
await asyncio.sleep(1.0)
if DEBUG_DISCOVERY >= 2: print(f"Current known peers: {[peer.id() for peer in self.known_peers.values()]}")
async def _get_peers(self):
try:
loop = asyncio.get_running_loop()
current_mtime = await loop.run_in_executor(self._file_executor, os.path.getmtime, self.network_config_path)
if (self._cached_peers is not None and self._last_modified_time is not None and current_mtime <= self._last_modified_time):
return self._cached_peers
topology = await loop.run_in_executor(self._file_executor, NetworkTopology.from_path, self.network_config_path)
if self.node_id not in topology.peers:
raise ValueError(
f"Node ID {self.node_id} not found in network config file "
f"{self.network_config_path}. Please run with `node_id` set to "
f"one of the keys in the config file: {[k for k, _ in topology.peers]}"
)
peers_in_network = topology.peers
peers_in_network.pop(self.node_id)
self._cached_peers = peers_in_network
self._last_modified_time = current_mtime
return peers_in_network
except Exception as e:
if DEBUG_DISCOVERY >= 2:
print(f"Error when loading network config file from {self.network_config_path}. "
f"Please update the config file in order to successfully discover peers. "
f"Exception: {e}")
return self._cached_peers
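
The reload logic above boils down to a small pattern: stat the file off the event loop, re-parse only when the mtime has moved forward, otherwise serve the cached parse. A standalone sketch of that pattern, with hypothetical names (CachedJsonConfig is not part of exo):

import asyncio
import json
import os
from concurrent.futures import ThreadPoolExecutor

class CachedJsonConfig:
  def __init__(self, path: str):
    self.path = path
    self._cache = None
    self._mtime = None
    self._executor = ThreadPoolExecutor(max_workers=1)

  async def get(self):
    loop = asyncio.get_running_loop()
    mtime = await loop.run_in_executor(self._executor, os.path.getmtime, self.path)
    if self._cache is not None and self._mtime is not None and mtime <= self._mtime:
      return self._cache  # file unchanged: serve the cached parse
    def _load():
      with open(self.path) as f:
        return json.load(f)
    self._cache = await loop.run_in_executor(self._executor, _load)
    self._mtime = mtime
    return self._cache
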

View File

@@ -29,4 +29,4 @@
}
}
}
}
}

View File

@@ -1,3 +1,4 @@
import json
import asyncio
import unittest
from unittest import mock
@@ -14,8 +15,12 @@ class TestSingleNodeManualDiscovery(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self):
self.peer1 = mock.AsyncMock()
self.peer1.connect = mock.AsyncMock()
self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1)
_ = self.discovery1.start()
self.discovery1 = ManualDiscovery(
root_path,
"node1",
create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1,
)
await self.discovery1.start()
async def asyncTearDown(self):
await self.discovery1.stop()
@@ -33,8 +38,16 @@ class TestManualDiscovery(unittest.IsolatedAsyncioTestCase):
self.peer2 = mock.AsyncMock()
self.peer1.connect = mock.AsyncMock()
self.peer2.connect = mock.AsyncMock()
self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1)
self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer2)
self.discovery1 = ManualDiscovery(
root_path,
"node1",
create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1,
)
self.discovery2 = ManualDiscovery(
root_path,
"node2",
create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer2,
)
await self.discovery1.start()
await self.discovery2.start()
@@ -63,8 +76,16 @@ class TestManualDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase):
self.server2 = GRPCServer(self.node2, config.peers["node2"].address, config.peers["node2"].port)
await self.server1.start()
await self.server2.start()
self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities))
self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities))
self.discovery1 = ManualDiscovery(
root_path,
"node1",
create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities),
)
self.discovery2 = ManualDiscovery(
root_path,
"node2",
create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities),
)
await self.discovery1.start()
await self.discovery2.start()
@@ -98,6 +119,63 @@ class TestManualDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase):
self.assertFalse(await peers1[0].is_connected())
self.assertFalse(await peers2[0].is_connected())
async def test_dynamic_config_update(self):
initial_peers = await self.discovery1.discover_peers(wait_for_peers=1)
self.assertEqual(len(initial_peers), 1)
# Save original config for cleanup
with open(root_path, "r") as f:
original_config = json.load(f)
try:
updated_config = {
"peers": {
**original_config["peers"],
"node3": {
"address": "localhost",
"port": 50053,
"device_capabilities": {
"model": "Unknown Model",
"chip": "Unknown Chip",
"memory": 0,
"flops": {"fp32": 0, "fp16": 0, "int8": 0},
},
},
}
}
with open(root_path, "w") as f:
json.dump(updated_config, f, indent=2)
node3 = mock.AsyncMock(spec=Node)
server3 = GRPCServer(node3, "localhost", 50053)
await server3.start()
try:
# Wait for the config to be reloaded
await asyncio.sleep(1.5)
updated_peers = await self.discovery1.discover_peers(wait_for_peers=2)
self.assertEqual(len(updated_peers), 2)
for peer in updated_peers:
await peer.connect()
self.assertTrue(await peer.is_connected())
finally:
await server3.stop()
finally:
# Restore the original config file
with open(root_path, "w") as f:
json.dump(original_config, f, indent=2)
# Wait for the config to be reloaded again
await asyncio.sleep(1.5)
updated_peers = await self.discovery1.discover_peers(wait_for_peers=1)
self.assertEqual(len(updated_peers), 1)
if __name__ == "__main__":
asyncio.run(unittest.main())

View File

@@ -112,37 +112,49 @@ class Node:
shard,
result: np.ndarray,
request_id: Optional[str] = None,
inference_state: Optional[dict] = None,
):
if request_id not in self.buffered_token_output:
self.buffered_token_output[request_id] = ([], False)
is_finished = len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
if shard.is_last_layer() and not is_finished:
token = await self.inference_engine.sample(result, temp=self.default_sample_temperature)
await self.inference_engine.ensure_shard(shard)
self.buffered_token_output[request_id][0].append(token.item())
is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id or is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
asyncio.create_task(self.broadcast_result(request_id, *self.buffered_token_output[request_id]))
forward = token.reshape(1, -1)
self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
asyncio.create_task(self.broadcast_result(request_id, self.buffered_token_output[request_id][0], is_finished))
if shard.model_id != 'stable-diffusion-2-1-base':
if request_id not in self.buffered_token_output:
self.buffered_token_output[request_id] = ([], False)
is_finished = len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
if shard.is_last_layer() and not is_finished:
token = await self.inference_engine.sample(result, temp=self.default_sample_temperature)
await self.inference_engine.ensure_shard(shard)
self.buffered_token_output[request_id][0].append(token.item())
is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id or is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
asyncio.create_task(self.broadcast_result(request_id, *self.buffered_token_output[request_id]))
forward = token.reshape(1, -1)
intermediate_result = self.buffered_token_output[request_id][0]
else:
forward = result
else:
await self.inference_engine.ensure_shard(shard)
is_finished = inference_state.get("is_finished", False)
intermediate_result, inference_state = self.handle_stable_diffusion(inference_state, result)
forward = result
if shard.is_last_layer():
self.trigger_on_token_callbacks(request_id, intermediate_result, is_finished)
asyncio.create_task(self.broadcast_result(request_id, intermediate_result, is_finished))
if is_finished:
self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
if shard.model_id != 'stable-diffusion-2-1-base':
self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
self.outstanding_requests.pop(request_id)
else:
self.outstanding_requests[request_id] = "waiting"
asyncio.create_task(self.forward_tensor(shard, forward, request_id, self.get_partition_index(offset = 1)))
asyncio.create_task(self.forward_tensor(shard, forward, request_id, self.get_partition_index(offset = 1), inference_state))
return np.array(self.buffered_token_output[request_id][0]) if shard.model_id != 'stable-diffusion-2-1-base' else intermediate_result
return np.array(self.buffered_token_output[request_id][0])
async def process_prompt(
self,
base_shard: Shard,
prompt: str,
request_id: Optional[str] = None,
inference_state: Optional[dict] = {},
) -> Optional[np.ndarray]:
shard = self.get_current_shard(base_shard)
asyncio.create_task(
@@ -160,7 +172,7 @@ class Node:
)
)
start_time = time.perf_counter_ns()
resp = await self._process_prompt(base_shard, prompt, request_id)
resp = await self._process_prompt(base_shard, prompt, request_id, inference_state)
end_time = time.perf_counter_ns()
elapsed_time_ns = end_time - start_time
asyncio.create_task(
@@ -181,7 +193,7 @@ class Node:
)
return resp
async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.ndarray]:
async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[dict] = None) -> Optional[np.ndarray]:
if request_id is None:
request_id = str(uuid.uuid4())
shard = self.get_current_shard(base_shard)
@@ -190,12 +202,12 @@ class Node:
if not shard.is_first_layer():
if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=}")
self.outstanding_requests[request_id] = "waiting"
resp = await self.forward_prompt(shard, prompt, request_id, 0)
resp = await self.forward_prompt(shard, prompt, request_id, 0, inference_state)
return None
else:
self.outstanding_requests[request_id] = "processing"
result = await self.inference_engine.infer_prompt(request_id, shard, prompt)
ret = await self.process_inference_result(shard, result, request_id)
result, inference_state = await self.inference_engine.infer_prompt(request_id, shard, prompt, inference_state)
ret = await self.process_inference_result(shard, result, request_id, inference_state)
return result
async def enqueue_example(
@@ -308,7 +320,7 @@ class Node:
loss, grad = await self.inference_engine.train(request_id, shard, example, target, length)
else:
self.outstanding_requests[request_id] = "preprocessing"
step = await self.inference_engine.infer_tensor(request_id, shard, example)
step, _ = await self.inference_engine.infer_tensor(request_id, shard, example)
self.outstanding_requests[request_id] = "waiting"
loss, backgrad = await self.forward_example(shard, step, target, length, train, request_id, self.get_partition_index(offset = 1))
self.outstanding_requests[request_id] = "training"
@@ -324,7 +336,7 @@ class Node:
loss = await self.inference_engine.evaluate(request_id, shard, example, target, length)
else:
self.outstanding_requests[request_id] = "preprocessing"
step = await self.inference_engine.infer_tensor(request_id, shard, example)
step, _ = await self.inference_engine.infer_tensor(request_id, shard, example)
self.outstanding_requests[request_id] = "waiting"
loss = await self.forward_example(shard, step, target, length, train, request_id, self.get_partition_index(offset = 1))
self.outstanding_requests.pop(request_id)
@@ -340,6 +352,7 @@ class Node:
base_shard: Shard,
tensor: np.ndarray,
request_id: Optional[str] = None,
inference_state: Optional[dict] = None,
) -> Optional[np.ndarray]:
shard = self.get_current_shard(base_shard)
asyncio.create_task(
@@ -358,7 +371,7 @@ class Node:
)
)
start_time = time.perf_counter_ns()
resp = await self._process_tensor(shard, tensor, request_id)
resp = await self._process_tensor(shard, tensor, request_id, inference_state)
end_time = time.perf_counter_ns()
elapsed_time_ns = end_time - start_time
asyncio.create_task(
@@ -383,6 +396,7 @@ class Node:
base_shard: Shard,
tensor: np.ndarray,
request_id: Optional[str] = None,
inference_state: Optional[dict] = None,
) -> Optional[np.ndarray]:
if request_id is None:
request_id = str(uuid.uuid4())
@@ -391,8 +405,8 @@ class Node:
if DEBUG >= 1: print(f"[{request_id}] process_tensor: {tensor.size=} {tensor.shape=}")
try:
self.outstanding_requests[request_id] = "processing"
result = await self.inference_engine.infer_tensor(request_id, shard, tensor)
ret = await self.process_inference_result(shard, result, request_id)
result, inference_state = await self.inference_engine.infer_tensor(request_id, shard, tensor, inference_state)
ret = await self.process_inference_result(shard, result, request_id, inference_state)
return ret
except Exception as e:
self.outstanding_requests.pop(request_id)
@@ -427,19 +441,20 @@ class Node:
prompt: str,
request_id: str,
target_index: int,
inference_state: Optional[dict] = None,
) -> None:
if DEBUG >= 1: print(f"target partition index: {target_index}")
target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
next_shard = self.get_current_shard(base_shard, target_index)
if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. next shard: {next_shard}")
if target_id == self.id:
await self.process_prompt(next_shard, prompt, request_id)
await self.process_prompt(next_shard, prompt, request_id, inference_state)
else:
target_peer = next((p for p in self.peers if p.id() == target_id), None)
if not target_peer:
raise ValueError(f"Peer for {target_index} not found")
if DEBUG >= 1: print(f"Sending prompt to {target_peer.id()}: {prompt}")
await target_peer.send_prompt(next_shard, prompt, request_id=request_id)
await target_peer.send_prompt(next_shard, prompt, request_id=request_id, inference_state=inference_state)
async def forward_tensor(
self,
@@ -447,19 +462,20 @@ class Node:
tensor: np.ndarray,
request_id: str,
target_index: int,
inference_state: Optional[dict] = None,
) -> None:
if DEBUG >= 1: print(f"target partition index: {target_index}")
target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
next_shard = self.get_current_shard(base_shard, target_index)
if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. target shard: {next_shard}")
if target_id == self.id:
await self.process_tensor(next_shard, tensor, request_id)
await self.process_tensor(next_shard, tensor, request_id, inference_state)
else:
target_peer = next((p for p in self.peers if p.id() == target_id), None)
if not target_peer:
raise ValueError(f"Peer for {target_index} not found")
if DEBUG >= 1: print(f"Sending tensor to {target_peer.id()}: {tensor}")
await target_peer.send_tensor(next_shard, tensor, request_id=request_id)
await target_peer.send_tensor(next_shard, tensor, request_id=request_id, inference_state=inference_state)
def get_partition_index(self, offset: int = 0):
if not self.partitioning_strategy:
@@ -632,3 +648,12 @@ class Node:
@property
def current_topology(self) -> Topology:
return self.topology
def handle_stable_diffusion(self, inference_state, result):
# Advance the diffusion step counter whenever the engine reports a finished denoising step.
if inference_state['is_step_finished']:
inference_state['step'] += 1
progress = [inference_state['step'], inference_state['total_steps']]
# The latest decoded tensor is surfaced as the intermediate result; on the final step it is the finished image.
intermediate_result = result
if progress[0] == progress[1]:
intermediate_result = result
return intermediate_result, inference_state
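
For context, a sketch of the inference_state dict that the stable-diffusion path threads through process_prompt/process_tensor and into handle_stable_diffusion. The keys are the ones read above; the values are illustrative and are owned and updated by the inference engine per request.

inference_state = {
  "step": 0,                  # current denoising step, advanced when is_step_finished is True
  "total_steps": 50,          # total denoising steps for this generation
  "is_step_finished": False,  # set by the engine after each completed step
  "is_finished": False,       # set once the final image tensor has been produced
}
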

View File

@@ -182,7 +182,25 @@
const div = document.createElement('div');
div.className = `message message-role-${role}`;
try {
div.innerHTML = DOMPurify.sanitize(marked.parse(content));
if (content.includes('![Generated Image]')) {
const imageUrl = content.match(/\((.*?)\)/)[1];
const img = document.createElement('img');
img.src = imageUrl;
img.alt = 'Generated Image';
img.onclick = async () => {
try {
const response = await fetch(img.src);
const blob = await response.blob();
const file = new File([blob], 'image.png', { type: 'image/png' });
handleImageUpload({ target: { files: [file] } });
} catch (error) {
console.error('Error fetching image:', error);
}
};
div.appendChild(img);
} else {
div.innerHTML = DOMPurify.sanitize(marked.parse(content));
}
} catch (e) {
console.log(content);
console.error(e);
@@ -266,7 +284,7 @@
</span>
</div>
<div class="input">
<button @click="$refs.imageUpload.click()" class="image-input-button" x-show="cstate.selectedModel === 'llava-1.5-7b-hf'">
<button @click="$refs.imageUpload.click()" class="image-input-button" x-show="cstate.selectedModel === 'llava-1.5-7b-hf' || cstate.selectedModel === 'stable-diffusion-2-1-base'">
<i class="fas fa-image"></i>
</button>
<input @change="$data.handleImageUpload($event)" accept="image/*" id="image-upload" style="display: none;" type="file" x-ref="imageUpload"/>

View File

@@ -228,53 +228,110 @@ document.addEventListener("alpine:init", () => {
};
}
});
const containsImage = apiMessages.some(msg => Array.isArray(msg.content) && msg.content.some(item => item.type === 'image_url'));
if (containsImage) {
// Map all messages with string content to object with type text
apiMessages = apiMessages.map(msg => {
if (typeof msg.content === 'string') {
return {
...msg,
content: [
{
type: "text",
text: msg.content
}
]
};
}
return msg;
if (this.cstate.selectedModel === "stable-diffusion-2-1-base") {
// Send a request to the image generation endpoint
console.log(apiMessages[apiMessages.length - 1].content)
console.log(this.cstate.selectedModel)
console.log(this.endpoint)
const response = await fetch(`${this.endpoint}/image/generations`, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
"model": 'stable-diffusion-2-1-base',
"prompt": apiMessages[apiMessages.length - 1].content,
"image_url": this.imageUrl
}),
});
}
// start receiving server sent events
let gottenFirstChunk = false;
for await (
const chunk of this.openaiChatCompletion(this.cstate.selectedModel, apiMessages)
) {
if (!gottenFirstChunk) {
this.cstate.messages.push({ role: "assistant", content: "" });
gottenFirstChunk = true;
if (!response.ok) {
throw new Error("Failed to fetch");
}
// add chunk to the last message
this.cstate.messages[this.cstate.messages.length - 1].content += chunk;
// calculate performance tracking
tokens += 1;
this.total_tokens += 1;
if (start_time === 0) {
start_time = Date.now();
this.time_till_first = start_time - prefill_start;
} else {
const diff = Date.now() - start_time;
if (diff > 0) {
this.tokens_per_second = tokens / (diff / 1000);
const reader = response.body.getReader();
let done = false;
let gottenFirstChunk = false;
while (!done) {
const { value, done: readerDone } = await reader.read();
done = readerDone;
const decoder = new TextDecoder();
if (value) {
// Assume non-binary data (text) comes first
const chunk = decoder.decode(value, { stream: true });
const parsed = JSON.parse(chunk);
console.log(parsed)
if (parsed.progress) {
if (!gottenFirstChunk) {
this.cstate.messages.push({ role: "assistant", content: "" });
gottenFirstChunk = true;
}
this.cstate.messages[this.cstate.messages.length - 1].content = parsed.progress;
}
else if (parsed.images) {
if (!gottenFirstChunk) {
this.cstate.messages.push({ role: "assistant", content: "" });
gottenFirstChunk = true;
}
const imageUrl = parsed.images[0].url;
console.log(imageUrl)
this.cstate.messages[this.cstate.messages.length - 1].content = `![Generated Image](${imageUrl}?t=${Date.now()})`;
}
}
}
}
else{
const containsImage = apiMessages.some(msg => Array.isArray(msg.content) && msg.content.some(item => item.type === 'image_url'));
if (containsImage) {
// Map all messages with string content to object with type text
apiMessages = apiMessages.map(msg => {
if (typeof msg.content === 'string') {
return {
...msg,
content: [
{
type: "text",
text: msg.content
}
]
};
}
return msg;
});
}
console.log(apiMessages)
//start receiving server sent events
let gottenFirstChunk = false;
for await (
const chunk of this.openaiChatCompletion(this.cstate.selectedModel, apiMessages)
) {
if (!gottenFirstChunk) {
this.cstate.messages.push({ role: "assistant", content: "" });
gottenFirstChunk = true;
}
// add chunk to the last message
this.cstate.messages[this.cstate.messages.length - 1].content += chunk;
// calculate performance tracking
tokens += 1;
this.total_tokens += 1;
if (start_time === 0) {
start_time = Date.now();
this.time_till_first = start_time - prefill_start;
} else {
const diff = Date.now() - start_time;
if (diff > 0) {
this.tokens_per_second = tokens / (diff / 1000);
}
}
}
}
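
The browser code above talks to the node's image endpoint; a rough Python equivalent of the same request/stream handling is sketched below. The /v1 prefix and port 52415 are assumptions taken from the default local setup, and each streamed chunk is assumed to be a complete JSON object, exactly as index.js assumes.

import json
import requests

# Assumed endpoint; mirrors what index.js posts to `${endpoint}/image/generations`.
url = "http://localhost:52415/v1/image/generations"
resp = requests.post(
  url,
  json={"model": "stable-diffusion-2-1-base", "prompt": "a watercolor fox", "image_url": None},
  stream=True,
)
resp.raise_for_status()
for chunk in resp.iter_content(chunk_size=None):
  if not chunk:
    continue
  parsed = json.loads(chunk)  # assumes chunk boundaries align with JSON objects, like index.js does
  if "progress" in parsed:
    print("progress:", parsed["progress"])
  elif "images" in parsed:
    print("image url:", parsed["images"][0]["url"])
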
// Clean the cstate before adding it to histories
const cleanedCstate = JSON.parse(JSON.stringify(this.cstate));
cleanedCstate.messages = cleanedCstate.messages.map(msg => {

View File

@@ -149,6 +149,8 @@ def device_capabilities() -> DeviceCapabilities:
return mac_device_capabilities()
elif psutil.LINUX:
return linux_device_capabilities()
elif psutil.WINDOWS:
return windows_device_capabilities()
else:
return DeviceCapabilities(
model="Unknown Device",
@@ -194,6 +196,8 @@ def linux_device_capabilities() -> DeviceCapabilities:
if DEBUG >= 2: print(f"NVIDIA device {gpu_name=} {gpu_memory_info=}")
pynvml.nvmlShutdown()
return DeviceCapabilities(
model=f"Linux Box ({gpu_name})",
chip=gpu_name,
@@ -201,13 +205,24 @@ def linux_device_capabilities() -> DeviceCapabilities:
flops=CHIP_FLOPS.get(gpu_name, DeviceFlops(fp32=0, fp16=0, int8=0)),
)
elif Device.DEFAULT == "AMD":
# TODO AMD support
# For AMD GPUs, pyrsmi is the way (Official python package for rocm-smi)
from pyrsmi import rocml
rocml.smi_initialize()
gpu_name = rocml.smi_get_device_name(0).upper()
gpu_memory_info = rocml.smi_get_device_memory_total(0)
if DEBUG >= 2: print(f"AMD device {gpu_name=} {gpu_memory_info=}")
rocml.smi_shutdown()
return DeviceCapabilities(
model="Linux Box (AMD)",
chip="Unknown AMD",
memory=psutil.virtual_memory().total // 2**20,
model="Linux Box ({gpu_name})",
chip={gpu_name},
memory=gpu_memory_info.total // 2**20,
flops=DeviceFlops(fp32=0, fp16=0, int8=0),
)
else:
return DeviceCapabilities(
model=f"Linux Box (Device: {Device.DEFAULT})",
@@ -215,3 +230,74 @@ def linux_device_capabilities() -> DeviceCapabilities:
memory=psutil.virtual_memory().total // 2**20,
flops=DeviceFlops(fp32=0, fp16=0, int8=0),
)
def windows_device_capabilities() -> DeviceCapabilities:
import psutil
def get_gpu_info():
import win32com.client # install pywin32
wmiObj = win32com.client.GetObject("winmgmts:\\\\.\\root\\cimv2")
gpus = wmiObj.ExecQuery("SELECT * FROM Win32_VideoController")
gpu_info = []
for gpu in gpus:
info = {
"Name": gpu.Name,
"AdapterRAM": gpu.AdapterRAM, # Bug in this property, returns -ve for VRAM > 4GB (uint32 overflow)
"DriverVersion": gpu.DriverVersion,
"VideoProcessor": gpu.VideoProcessor
}
gpu_info.append(info)
return gpu_info
gpus_info = get_gpu_info()
gpu_names = [gpu['Name'] for gpu in gpus_info]
contains_nvidia = any('nvidia' in gpu_name.lower() for gpu_name in gpu_names)
contains_amd = any('amd' in gpu_name.lower() for gpu_name in gpu_names)
if contains_nvidia:
import pynvml
pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)
gpu_raw_name = pynvml.nvmlDeviceGetName(handle).upper()
gpu_name = gpu_raw_name.rsplit(" ", 1)[0] if gpu_raw_name.endswith("GB") else gpu_raw_name
gpu_memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
if DEBUG >= 2: print(f"NVIDIA device {gpu_name=} {gpu_memory_info=}")
return DeviceCapabilities(
model=f"Windows Box ({gpu_name})",
chip=gpu_name,
memory=gpu_memory_info.total // 2**20,
flops=CHIP_FLOPS.get(gpu_name, DeviceFlops(fp32=0, fp16=0, int8=0)),
)
elif contains_amd:
# For AMD GPUs, pyrsmi is the way (Official python package for rocm-smi)
from pyrsmi import rocml
rocml.smi_initialize()
gpu_name = rocml.smi_get_device_name(0).upper()
gpu_memory_info = rocml.smi_get_device_memory_total(0)
if DEBUG >= 2: print(f"AMD device {gpu_name=} {gpu_memory_info=}")
rocml.smi_shutdown()
return DeviceCapabilities(
model="Windows Box ({gpu_name})",
chip={gpu_name},
memory=gpu_memory_info.total // 2**20,
flops=DeviceFlops(fp32=0, fp16=0, int8=0),
)
else:
return DeviceCapabilities(
model=f"Windows Box (Device: Unknown)",
chip=f"Unknown Chip (Device(s): {gpu_names})",
memory=psutil.virtual_memory().total // 2**20,
flops=DeviceFlops(fp32=0, fp16=0, int8=0),
)
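
On the WMI caveat flagged in get_gpu_info above (AdapterRAM is a uint32 that comes back signed, so cards with more than 4 GB of VRAM read negative): a small sketch of normalizing that value. Note it can only recover the low 32 bits, so it still cannot report more than 4 GiB.

def adapter_ram_mb(raw_adapter_ram) -> int:
  """Reinterpret Win32_VideoController.AdapterRAM as unsigned 32-bit, in MiB (illustrative helper)."""
  if raw_adapter_ram is None:
    return 0
  return (int(raw_adapter_ram) & 0xFFFFFFFF) // 2**20
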

View File

@@ -91,25 +91,70 @@ class TopologyViz:
content = []
requests = list(self.requests.values())[-3:] # Get the 3 most recent requests
max_width = self.console.width - 6 # Full width minus padding and icon
max_lines = 13 # Maximum number of lines for the entire panel content
# Calculate available height for content
panel_height = 15 # Fixed panel height
available_lines = panel_height - 2 # Subtract 2 for panel borders
lines_per_entry = available_lines // len(requests) if requests else 0
for (prompt, output) in reversed(requests):
prompt_icon, output_icon = "💬️", "🤖"
# Process prompt
prompt_lines = prompt.split('\n')
if len(prompt_lines) > max_lines // 2:
prompt_lines = prompt_lines[:max_lines//2 - 1] + ['...']
prompt_text = Text(f"{prompt_icon} ", style="bold bright_blue")
prompt_text.append('\n'.join(line[:max_width] for line in prompt_lines), style="white")
# Calculate max lines for prompt and output
max_prompt_lines = lines_per_entry // 3 # Allocate 1/3 for prompt
max_output_lines = lines_per_entry - max_prompt_lines - 1 # Remaining space minus spacing
# Process prompt
prompt_lines = []
for line in prompt.split('\n'):
words = line.split()
current_line = []
current_length = 0
for word in words:
if current_length + len(word) + 1 <= max_width:
current_line.append(word)
current_length += len(word) + 1
else:
if current_line:
prompt_lines.append(' '.join(current_line))
current_line = [word]
current_length = len(word)
if current_line:
prompt_lines.append(' '.join(current_line))
if len(prompt_lines) > max_prompt_lines:
prompt_lines = prompt_lines[:max_prompt_lines - 1] + ['...']
prompt_text = Text(f"{prompt_icon} ", style="bold bright_blue")
prompt_text.append('\n'.join(prompt_lines), style="white")
# Process output - same word-aware wrapping
output_lines = []
for line in output.split('\n'):
words = line.split()
current_line = []
current_length = 0
for word in words:
if current_length + len(word) + 1 <= max_width:
current_line.append(word)
current_length += len(word) + 1
else:
if current_line:
output_lines.append(' '.join(current_line))
current_line = [word]
current_length = len(word)
if current_line:
output_lines.append(' '.join(current_line))
if len(output_lines) > max_output_lines:
output_lines = output_lines[:max_output_lines - 1] + ['...']
# Process output
output_lines = output.split('\n')
remaining_lines = max_lines - len(prompt_lines) - 2 # -2 for spacing
if len(output_lines) > remaining_lines:
output_lines = output_lines[:remaining_lines - 1] + ['...']
output_text = Text(f"\n{output_icon} ", style="bold bright_magenta")
output_text.append('\n'.join(line[:max_width] for line in output_lines), style="white")
output_text.append('\n'.join(output_lines), style="white")
content.append(prompt_text)
content.append(output_text)
@@ -119,8 +164,8 @@ class TopologyViz:
Group(*content),
title="",
border_style="cyan",
height=15, # Increased height to accommodate multiple lines
expand=True # Allow the panel to expand to full width
height=panel_height,
expand=True
)
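
The word-aware wrapping added above can also be expressed with the standard library; a sketch of equivalent behaviour (wrap each line to max_width, then clip to max_lines with a trailing '...'), shown only as a comparison, not the committed implementation:

import textwrap

def wrap_and_clip(text: str, max_width: int, max_lines: int) -> list[str]:
  lines = []
  for raw_line in text.split('\n'):
    # textwrap.wrap drops empty input lines, so keep them explicitly
    lines.extend(textwrap.wrap(raw_line, width=max_width) or [''])
  if len(lines) > max_lines:
    lines = lines[:max_lines - 1] + ['...']
  return lines
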
def _generate_main_layout(self) -> str:

View File

@@ -1,4 +1,4 @@
#!/bin/bash
#!/usr/bin/env bash
if command -v python3.12 &>/dev/null; then
echo "Python 3.12 is installed, proceeding with python3.12..."

View File

@@ -23,7 +23,7 @@ def run():
"--macos-app-name=exo",
"--macos-app-mode=gui",
"--macos-app-version=0.0.1",
"--macos-signed-app-name=com.exolabs.exo",
"--macos-signed-app-name=net.exolabs.exo",
"--include-distribution-meta=mlx",
"--include-module=mlx._reprlib_fix",
"--include-module=mlx._os_warning",

View File

@@ -1,4 +1,4 @@
#!/bin/bash
#!/usr/bin/env bash
source ./install.sh
pushd exo/networking/grpc
python3 -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. node_service.proto
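
Related: the regenerated node_service_pb2_grpc.py shown earlier imports node_service_pb2 by its full package path (the exo_dot_networking_dot_grpc_dot_... alias), which is what protoc emits when the proto is compiled with its package-qualified path. A sketch of an equivalent invocation via grpcio-tools from the repository root, as an alternative to the pushd-based command above:

# Run from the repository root; writes exo/networking/grpc/node_service_pb2*.py
from grpc_tools import protoc

protoc.main([
  "protoc",
  "-I.",
  "--python_out=.",
  "--grpc_python_out=.",
  "exo/networking/grpc/node_service.proto",
])
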

View File

@@ -1,5 +1,6 @@
import sys
import platform
import subprocess
from setuptools import find_packages, setup
@@ -11,7 +12,6 @@ install_requires = [
"grpcio==1.68.0",
"grpcio-tools==1.68.0",
"Jinja2==3.1.4",
"netifaces==0.11.0",
"numpy==2.0.0",
"nuitka==2.5.1",
"nvidia-ml-py==12.560.30",
@@ -23,6 +23,7 @@ install_requires = [
"pydantic==2.9.2",
"requests==2.32.3",
"rich==13.7.1",
"scapy==2.6.1",
"tenacity==9.0.0",
"tqdm==4.66.4",
"transformers==4.46.3",
@@ -31,19 +32,47 @@ install_requires = [
]
extras_require = {
"formatting": [
"yapf==0.40.2",
],
"apple_silicon": [
"formatting": ["yapf==0.40.2",], "apple_silicon": [
"mlx==0.20.0",
"mlx-lm==0.19.3",
],
], "windows": ["pywin32==308",], "nvidia-gpu": ["nvidia-ml-py==12.560.30",], "amd-gpu": ["pyrsmi==0.2.0"]
}
# Check if running on macOS with Apple Silicon
if sys.platform.startswith("darwin") and platform.machine() == "arm64":
install_requires.extend(extras_require["apple_silicon"])
# Check if running Windows
if sys.platform.startswith("win32"):
install_requires.extend(extras_require["windows"])
def _add_gpu_requires():
global install_requires
# Add Nvidia-GPU
try:
out = subprocess.run(['nvidia-smi', '--query-gpu=name', '--format=csv,noheader'], shell=True, text=True, capture_output=True, check=False)
if out.returncode == 0:
install_requires.extend(extras_require["nvidia-gpu"])
except subprocess.CalledProcessError:
pass
# Add AMD-GPU
# This will mostly work only on Linux, amd/rocm-smi is not yet supported on Windows
try:
out = subprocess.run(['amd-smi', 'list', '--csv'], shell=True, text=True, capture_output=True, check=False)
if out.returncode == 0:
install_requires.extend(extras_require["amd-gpu"])
except:
out = subprocess.run(['rocm-smi', 'list', '--csv'], shell=True, text=True, capture_output=True, check=False)
if out.returncode == 0:
install_requires.extend(extras_require["amd-gpu"])
finally:
pass
_add_gpu_requires()
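
As a side note on the probes above: with shell=True, a list argument on POSIX only passes the first element to the shell, so the extra nvidia-smi/amd-smi flags are effectively dropped there (detection still works because the bare command exits 0). A sketch of the same probe without the shell, purely illustrative and not the committed setup.py logic; tool_reports_gpu is a hypothetical helper, and install_requires/extras_require are the lists defined earlier in setup.py.

import shutil
import subprocess

def tool_reports_gpu(cmd: list[str]) -> bool:
  """Return True if the CLI tool exists and exits 0 (hypothetical helper, not in exo)."""
  if shutil.which(cmd[0]) is None:
    return False
  return subprocess.run(cmd, text=True, capture_output=True, check=False).returncode == 0

if tool_reports_gpu(["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"]):
  install_requires.extend(extras_require["nvidia-gpu"])
elif tool_reports_gpu(["amd-smi", "list", "--csv"]) or tool_reports_gpu(["rocm-smi", "list", "--csv"]):
  install_requires.extend(extras_require["amd-gpu"])
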
setup(
name="exo",
version="0.0.1",

View File

@@ -1,4 +1,4 @@
#!/bin/bash
#!/usr/bin/env bash
echo "Starting node 1"
DEBUG_DISCOVERY=7 DEBUG=7 python3 main.py --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 52415 --chatgpt-api-response-timeout 900 > output1.log 2>&1 &

View File

@@ -24,7 +24,7 @@ def test_tokenizer(name, tokenizer, verbose=False):
strip_tokens = lambda s: s.lstrip(tokenizer.decode([tokenizer.bos_token_id])).rstrip(tokenizer.decode([tokenizer.eos_token_id]))
assert text == strip_tokens(decoded) == strip_tokens(reconstructed)
ignore = ["TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", "mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", "llava-hf/llava-1.5-7b-hf", "mlx-community/Qwen*", "dummy", "mlx-community/Meta-Llama-3.1-405B-Instruct-8bit"]
ignore = ["TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", "mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", "llava-hf/llava-1.5-7b-hf", "mlx-community/Qwen*", "dummy", "mlx-community/Meta-Llama-3.1-405B-Instruct-8bit", "mlx-community/Phi-3.5-mini-instruct-4bit", "mlx-community/phi-4-4bit", "stabilityai/stable-diffusion-2-1-base"]
ignore_pattern = re.compile(r"^(" + "|".join(model.replace("*", ".*") for model in ignore) + r")")
models = []
for model_id in model_cards: