Mirror of https://github.com/exo-explore/exo.git (synced 2026-01-28 15:52:56 -05:00)

Compare commits: v1.0 ... v0.0.15-al (442 commits)
Commit list: 442 commits, af734f1bf6 … 6b28ef0349 (authors, dates, and commit messages were not preserved in the extraction).
@@ -27,7 +27,7 @@ commands:
fi

# Start first instance
HF_HOME="$(pwd)/.hf_cache_node1" DEBUG_DISCOVERY=7 DEBUG=7 exo --inference-engine <<parameters.inference_engine>> \
EXO_HOME="$(pwd)/.exo_cache_node1" DEBUG_DISCOVERY=7 DEBUG=7 exo --inference-engine <<parameters.inference_engine>> \
--node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 8000 \
--chatgpt-api-response-timeout 900 --disable-tui > output1.log &
PID1=$!
@@ -35,7 +35,7 @@ commands:
TAIL1=$!

# Start second instance
HF_HOME="$(pwd)/.hf_cache_node2" DEBUG_DISCOVERY=7 DEBUG=7 exo --inference-engine <<parameters.inference_engine>> \
EXO_HOME="$(pwd)/.exo_cache_node2" DEBUG_DISCOVERY=7 DEBUG=7 exo --inference-engine <<parameters.inference_engine>> \
--node-id "node2" --listen-port 5679 --broadcast-port 5678 --chatgpt-api-port 8001 \
--chatgpt-api-response-timeout 900 --disable-tui > output2.log &
PID2=$!
@@ -254,6 +254,35 @@ jobs:
prompt: "Keep responses concise. Who was the king of pop?"
expected_output: "Michael Jackson"

chatgpt_api_integration_test_tinygrad_linux:
machine:
image: ubuntu-2204:current
resource_class: xlarge
steps:
- checkout
- run:
name: Set up Python
command: |
export DEBIAN_FRONTEND=noninteractive
export DEBCONF_NONINTERACTIVE_SEEN=true
sudo apt-get update
sudo add-apt-repository -y ppa:deadsnakes/ppa
sudo apt-get update
sudo apt-get install -y python3.12 python3.12-venv clang
python3.12 -m venv env
source env/bin/activate
- run:
name: Install dependencies
command: |
source env/bin/activate
pip install --upgrade pip
pip install .
- run_chatgpt_api_test:
inference_engine: tinygrad
model_id: llama-3.2-1b
prompt: "Keep responses concise. Who was the king of pop?"
expected_output: "Michael Jackson"

measure_pip_sizes:
macos:
xcode: "16.0.0"
@@ -342,5 +371,6 @@ workflows:
- discovery_integration_test
- chatgpt_api_integration_test_mlx
- chatgpt_api_integration_test_tinygrad
- chatgpt_api_integration_test_tinygrad_linux
- chatgpt_api_integration_test_dummy
- measure_pip_sizes
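The two instances started above expose ChatGPT-compatible APIs on ports 8000 and 8001, and the integration tests exercise them with the prompt configured further down. A rough sketch of the kind of request involved (the real `run_chatgpt_api_test` command is defined elsewhere in this config, so treat this as illustrative):

```sh
# Ask node1 (ChatGPT-compatible API on port 8000) the test prompt.
curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "llama-3.2-1b",
    "messages": [{"role": "user", "content": "Keep responses concise. Who was the king of pop?"}],
    "temperature": 0.7
  }'
```

The expected answer ("Michael Jackson") is then checked against the response body.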
.github/bench.py (vendored, new file, 401 lines)
@@ -0,0 +1,401 @@
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import time
|
||||
import json
|
||||
import os
|
||||
import boto3
|
||||
from typing import Dict, Any
|
||||
from datetime import datetime
|
||||
import subprocess
|
||||
import psutil
|
||||
import platform
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def check_system_state():
|
||||
print("\n=== System State Check ===", flush=True)
|
||||
|
||||
# Add macOS-specific checks
|
||||
try:
|
||||
# Check powermetrics with sudo
|
||||
try:
|
||||
power_metrics = subprocess.run(
|
||||
['sudo', 'powermetrics', '-n', '1', '-i', '1000', '--samplers', 'cpu_power'],
|
||||
capture_output=True, text=True
|
||||
)
|
||||
print("\nPower Metrics:", power_metrics.stdout, flush=True)
|
||||
except Exception as e:
|
||||
print(f"Error getting power metrics: {e}", flush=True)
|
||||
|
||||
# Check thermal state
|
||||
thermal_state = subprocess.run(['pmset', '-g', 'therm'], capture_output=True, text=True)
|
||||
print("\nThermal State:", thermal_state.stdout, flush=True)
|
||||
|
||||
# Check if running under Rosetta
|
||||
arch = subprocess.run(['arch'], capture_output=True, text=True)
|
||||
print("\nArchitecture:", arch.stdout, flush=True)
|
||||
|
||||
# Check MLX compilation mode - only if mlx is available
|
||||
try:
|
||||
import mlx.core as mx
|
||||
if hasattr(mx, 'build_info'):
|
||||
print("\nMLX Build Info:", mx.build_info(), flush=True)
|
||||
else:
|
||||
print("\nMLX Build Info: Not available in this version", flush=True)
|
||||
except ImportError:
|
||||
print("\nMLX: Not installed", flush=True)
|
||||
except Exception as e:
|
||||
print(f"\nError checking MLX: {e}", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in macOS checks: {e}", flush=True)
|
||||
|
||||
# CPU Info
|
||||
print("\nCPU Information:", flush=True)
|
||||
try:
|
||||
if platform.system() == 'Darwin' and platform.processor() == 'arm':
|
||||
# Use sysctl for Apple Silicon Macs
|
||||
cpu_info = subprocess.run(['sysctl', 'machdep.cpu'], capture_output=True, text=True)
|
||||
if cpu_info.returncode == 0:
|
||||
print(f"CPU Info (Apple Silicon):", cpu_info.stdout, flush=True)
|
||||
|
||||
# Parse powermetrics output for clearer CPU frequency display
|
||||
try:
|
||||
power_metrics = subprocess.run(
|
||||
['sudo', 'powermetrics', '-n', '1', '-i', '100', '--samplers', 'cpu_power'],
|
||||
capture_output=True, text=True
|
||||
)
|
||||
if power_metrics.returncode == 0:
|
||||
output = power_metrics.stdout
|
||||
print("\nDetailed CPU Frequency Information:")
|
||||
|
||||
# Extract cluster frequencies and max frequencies
|
||||
current_cluster = None
|
||||
max_freqs = {'E': 0, 'P0': 0, 'P1': 0}
|
||||
|
||||
for line in output.split('\n'):
|
||||
# Track which cluster we're processing
|
||||
if "E-Cluster" in line:
|
||||
current_cluster = 'E'
|
||||
elif "P0-Cluster" in line:
|
||||
current_cluster = 'P0'
|
||||
elif "P1-Cluster" in line:
|
||||
current_cluster = 'P1'
|
||||
|
||||
# Get current frequencies
|
||||
if "HW active frequency:" in line:
|
||||
freq = line.split(':')[1].strip()
|
||||
if freq != "0 MHz":
|
||||
print(f"Current {current_cluster}-Cluster Frequency: {freq}")
|
||||
|
||||
# Get max frequencies from residency lines
|
||||
if current_cluster and "active residency:" in line and "MHz:" in line:
|
||||
try:
|
||||
# Extract all frequency values
|
||||
freqs = []
|
||||
parts = line.split('MHz:')[:-1] # Skip last part as it's not a frequency
|
||||
for part in parts:
|
||||
freq_str = part.split()[-1]
|
||||
try:
|
||||
freq = float(freq_str)
|
||||
freqs.append(freq)
|
||||
except ValueError:
|
||||
continue
|
||||
if freqs:
|
||||
max_freqs[current_cluster] = max(max_freqs[current_cluster], max(freqs))
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# Print max frequencies
|
||||
print("\nMaximum Available Frequencies:")
|
||||
for cluster, max_freq in max_freqs.items():
|
||||
if max_freq > 0:
|
||||
print(f"{cluster}-Cluster Max: {max_freq:.0f} MHz")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error parsing powermetrics: {e}", flush=True)
|
||||
else:
|
||||
# Use psutil for other systems
|
||||
cpu_freq = psutil.cpu_freq()
|
||||
print(f"CPU Frequency - Current: {cpu_freq.current:.2f}MHz, Min: {cpu_freq.min:.2f}MHz, Max: {cpu_freq.max:.2f}MHz", flush=True)
|
||||
|
||||
print(f"\nCPU Usage per Core: {psutil.cpu_percent(percpu=True)}%", flush=True)
|
||||
|
||||
# Check if running in low power mode
|
||||
power_mode = subprocess.run(['pmset', '-g'], capture_output=True, text=True)
|
||||
print("\nPower Settings:", power_mode.stdout, flush=True)
|
||||
except Exception as e:
|
||||
print(f"Error getting CPU info: {e}", flush=True)
|
||||
|
||||
# Memory Info
|
||||
print("\nMemory Information:", flush=True)
|
||||
try:
|
||||
mem = psutil.virtual_memory()
|
||||
print(f"Total: {mem.total/1024/1024/1024:.2f}GB", flush=True)
|
||||
print(f"Available: {mem.available/1024/1024/1024:.2f}GB", flush=True)
|
||||
print(f"Used: {mem.used/1024/1024/1024:.2f}GB ({mem.percent}%)", flush=True)
|
||||
|
||||
# Check swap
|
||||
swap = psutil.swap_memory()
|
||||
print(f"Swap Used: {swap.used/1024/1024/1024:.2f}GB of {swap.total/1024/1024/1024:.2f}GB", flush=True)
|
||||
except Exception as e:
|
||||
print(f"Error getting memory info: {e}", flush=True)
|
||||
|
||||
# GPU Info
|
||||
print("\nGPU Information:", flush=True)
|
||||
try:
|
||||
# Check MLX GPU settings
|
||||
print("MLX Environment Variables:", flush=True)
|
||||
mlx_vars = {k: v for k, v in os.environ.items() if k.startswith('MLX')}
|
||||
print(json.dumps(mlx_vars, indent=2), flush=True)
|
||||
|
||||
# Check Metal GPU memory allocation
|
||||
gpu_mem = subprocess.run(['sysctl', 'iogpu'], capture_output=True, text=True)
|
||||
print("GPU Memory Settings:", gpu_mem.stdout, flush=True)
|
||||
except Exception as e:
|
||||
print(f"Error getting GPU info: {e}", flush=True)
|
||||
|
||||
# Process Priority
|
||||
print("\nProcess Priority Information:", flush=True)
|
||||
try:
|
||||
current_process = psutil.Process()
|
||||
print(f"Process Nice Value: {current_process.nice()}", flush=True)
|
||||
# Only try to get ionice if the platform supports it
|
||||
if hasattr(current_process, 'ionice'):
|
||||
print(f"Process IO Nice Value: {current_process.ionice()}", flush=True)
|
||||
except Exception as e:
|
||||
print(f"Error getting process priority info: {e}", flush=True)
|
||||
|
||||
# System Load
|
||||
print("\nSystem Load:", flush=True)
|
||||
try:
|
||||
load_avg = psutil.getloadavg()
|
||||
print(f"Load Average: {load_avg}", flush=True)
|
||||
|
||||
# Get top processes by CPU and Memory
|
||||
print("\nTop Processes by CPU Usage:", flush=True)
|
||||
processes = []
|
||||
for proc in psutil.process_iter(['pid', 'name', 'cpu_percent', 'memory_percent']):
|
||||
try:
|
||||
pinfo = proc.info
|
||||
if pinfo['cpu_percent'] is not None and pinfo['memory_percent'] is not None:
|
||||
processes.append(pinfo)
|
||||
except (psutil.NoSuchProcess, psutil.AccessDenied):
|
||||
continue
|
||||
|
||||
# Sort and display top 5 CPU-consuming processes
|
||||
sorted_by_cpu = sorted(processes, key=lambda x: x['cpu_percent'] or 0, reverse=True)[:5]
|
||||
for proc in sorted_by_cpu:
|
||||
print(f"PID: {proc['pid']}, Name: {proc['name']}, CPU: {proc['cpu_percent']}%, Memory: {proc['memory_percent']:.1f}%")
|
||||
except Exception as e:
|
||||
print(f"Error getting system load info: {e}", flush=True)
|
||||
|
||||
print("\n=== End System State Check ===\n", flush=True)
|
||||
|
||||
|
||||
def check_gpu_access():
|
||||
try:
|
||||
# Check if MLX can see the GPU
|
||||
import mlx.core as mx
|
||||
print("MLX device info:", mx.default_device())
|
||||
|
||||
# Check Metal device availability
|
||||
result = subprocess.run(['system_profiler', 'SPDisplaysDataType'], capture_output=True, text=True)
|
||||
print("GPU Info:", result.stdout)
|
||||
except Exception as e:
|
||||
print(f"Failed to check GPU access: {e}")
|
||||
|
||||
|
||||
async def measure_performance(api_endpoint: str, prompt: str, model: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Measures the performance of an API endpoint by sending a prompt and recording metrics.
|
||||
|
||||
Args:
|
||||
api_endpoint (str): The API endpoint URL.
|
||||
prompt (str): The prompt to send to the API.
model (str): The model name to request from the endpoint.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: A dictionary containing performance metrics or error information.
|
||||
"""
|
||||
|
||||
results = {
|
||||
'model': model,
|
||||
'run_id': os.environ.get('GITHUB_RUN_ID', 'unknown'),
|
||||
'branch': os.environ.get('GITHUB_REF_NAME', 'unknown'),
|
||||
'commit': os.environ.get('GITHUB_SHA', 'unknown'),
|
||||
'configuration': json.loads(os.environ.get('HARDWARE_CONFIG', '{}'))
|
||||
}
|
||||
|
||||
# Get token count
|
||||
session = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=600, connect=10, sock_read=600, sock_connect=10))
|
||||
try:
|
||||
response = await session.post(
|
||||
"http://localhost:52415/v1/chat/token/encode",
|
||||
json={
|
||||
"model": model,
|
||||
"messages": [{"role": "user", "content": prompt}]
|
||||
}
|
||||
)
|
||||
response.raise_for_status()
|
||||
token_data = await response.json()
|
||||
results['prompt_len'] = token_data['num_tokens']
|
||||
except Exception as e:
|
||||
await session.close()
|
||||
raise RuntimeError(f"Failed to get token count: {str(e)}")
|
||||
|
||||
# Measure completion performance
|
||||
try:
|
||||
start_time = time.time()
|
||||
response = await session.post(
|
||||
api_endpoint,
|
||||
json={
|
||||
"model": model,
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"temperature": 0,
|
||||
"stream": True
|
||||
}
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
first_token_time = None
|
||||
total_tokens = 0
|
||||
|
||||
async for line in response.content.iter_chunks():
|
||||
line = line[0].decode('utf-8').strip()
|
||||
if not line.startswith('data: '):
|
||||
continue
|
||||
|
||||
data = json.loads(line[6:]) # Skip 'data: ' prefix
|
||||
if content := data.get('choices', [{}])[0].get('delta', {}).get('content'):
|
||||
print(f"Received content: {content}", flush=True)
|
||||
if first_token_time is None:
|
||||
first_token_time = time.time()
|
||||
ttft = first_token_time - start_time
|
||||
results.update({
|
||||
'ttft': ttft,
|
||||
'prompt_tps': results['prompt_len'] / ttft
|
||||
})
|
||||
total_tokens += 1
|
||||
|
||||
total_time = time.time() - start_time
|
||||
results.update({
|
||||
'generation_tps': total_tokens / total_time,
|
||||
'response_len': total_tokens,
|
||||
'total_time': total_time
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Performance measurement failed: {str(e)}")
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
api_endpoint = "http://localhost:52415/v1/chat/completions"
|
||||
|
||||
# Define prompts
|
||||
prompt_warmup = "what is the capital of France?"
|
||||
prompt_essay = "write an essay about cats"
|
||||
|
||||
model = os.environ.get('model', 'llama-3.2-1b')
|
||||
# Warmup request
|
||||
print("\nPerforming warmup request...", flush=True)
|
||||
try:
|
||||
warmup_results = await measure_performance(api_endpoint, prompt_warmup, model)
|
||||
print("Warmup completed successfully", flush=True)
|
||||
except Exception as e:
|
||||
print(f"Warmup request failed: {e}", flush=True)
|
||||
|
||||
# Measure performance for the essay prompt
|
||||
print("\nMeasuring performance for the essay prompt...", flush=True)
|
||||
results = await measure_performance(api_endpoint, prompt_essay, model)
|
||||
|
||||
try:
|
||||
s3_client = boto3.client(
|
||||
's3',
|
||||
aws_access_key_id=os.environ.get('aws_access_key_id'),
|
||||
aws_secret_access_key=os.environ.get('aws_secret_key')
|
||||
)
|
||||
job_name = os.environ.get('GITHUB_JOB')
|
||||
|
||||
# Create S3 key with timestamp and commit info
|
||||
now = datetime.utcnow()
|
||||
timestamp = now.strftime('%H-%M-%S')
|
||||
commit_sha = os.environ.get('GITHUB_SHA', 'unknown')[:7]
|
||||
s3_key = f"{job_name}/{model}/{now.year}/{now.month}/{now.day}/{timestamp}_{commit_sha}.json"
|
||||
|
||||
# Upload to S3
|
||||
s3_client.put_object(
|
||||
Bucket='exo-benchmarks',
|
||||
Key=s3_key,
|
||||
Body=json.dumps(results),
|
||||
ContentType='application/json'
|
||||
)
|
||||
print(f"Performance metrics uploaded to S3: s3://exo-benchmarks/{s3_key}", flush=True)
|
||||
except Exception as e:
|
||||
print(f"Failed to upload metrics to S3: {e}", flush=True)
|
||||
|
||||
# Optionally print the metrics for visibility
|
||||
print("Performance metrics:", flush=True)
|
||||
print(json.dumps(results, indent=4), flush=True)
|
||||
|
||||
|
||||
def optimize_system_performance():
|
||||
"""Set optimal system performance settings before running benchmark."""
|
||||
try:
|
||||
# Try to set high performance power mode
|
||||
subprocess.run(['sudo', 'pmset', '-a', 'powermode', '2'], check=False)
|
||||
|
||||
# Ensure MLX uses performance cores and GPU
|
||||
os.environ['MLX_FORCE_P_CORES'] = '1'
|
||||
os.environ['MLX_METAL_PREWARM'] = '1'
|
||||
os.environ['MLX_USE_GPU'] = '1'
|
||||
|
||||
# Set process priority
|
||||
current_process = psutil.Process()
|
||||
try:
|
||||
# Set highest priority
|
||||
subprocess.run(['sudo', 'renice', '-n', '-20', '-p', str(current_process.pid)], check=False)
|
||||
|
||||
# Print current process state
|
||||
print("\nProcess State Before Benchmark:", flush=True)
|
||||
proc_info = subprocess.run(
|
||||
['ps', '-o', 'pid,ppid,user,%cpu,%mem,nice,stat,pri,command', '-p', str(current_process.pid)],
|
||||
capture_output=True, text=True
|
||||
)
|
||||
print(proc_info.stdout, flush=True)
|
||||
|
||||
# Verify power mode
|
||||
power_info = subprocess.run(['pmset', '-g'], capture_output=True, text=True)
|
||||
if 'powermode 0' in power_info.stdout:
|
||||
print("\nWarning: System still in normal power mode. Trying to set high performance mode again...", flush=True)
|
||||
subprocess.run(['sudo', 'pmset', '-a', 'powermode', '2'], check=False)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not set process priority: {e}", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not optimize system performance: {e}", flush=True)
|
||||
|
||||
# Print optimization status
|
||||
print("\nOptimization Settings:", flush=True)
|
||||
print("MLX Environment Variables:", flush=True)
|
||||
for var in ['MLX_FORCE_P_CORES', 'MLX_METAL_PREWARM', 'MLX_USE_GPU']:
|
||||
print(f"{var}: {os.environ.get(var, 'Not set')}", flush=True)
|
||||
|
||||
try:
|
||||
nice_value = psutil.Process().nice()
|
||||
print(f"Process Nice Value: {nice_value}", flush=True)
|
||||
if nice_value != -20:
|
||||
print("Warning: Process not running at highest priority", flush=True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
check_system_state()
|
||||
check_gpu_access()
|
||||
optimize_system_performance()
|
||||
asyncio.run(main())
|
||||
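bench.py is driven by the bench_job workflow further down, which exports `model`, the GitHub metadata variables, and the S3 credentials before running it against a node listening on localhost:52415. A rough sketch of an equivalent local invocation (the values below are placeholders, and an exo node must already be running):

```sh
# Hypothetical local run; in CI these variables come from GitHub Actions
# and repository secrets.
export model=llama-3.2-1b
export GITHUB_JOB=local-test
export GITHUB_RUN_ID=0 GITHUB_REF_NAME=main GITHUB_SHA=$(git rev-parse HEAD)
# aws_access_key_id / aws_secret_key are only needed for the S3 upload;
# without them the script still prints the metrics and logs the upload failure.
python .github/bench.py
```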
.github/bootstrap.sh (vendored, new executable file, 330 lines)
@@ -0,0 +1,330 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
command_exists() {
|
||||
command -v "$1" >/dev/null 2>&1
|
||||
}
|
||||
|
||||
log() {
|
||||
echo "[$(date '+%Y-%m-%d %H:%M:%S')] $1"
|
||||
}
|
||||
|
||||
if [ "$EUID" -eq 0 ]; then
|
||||
log "Please do not run as root. Run as regular user with sudo access."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Check for required arguments
|
||||
if [ -z "$1" ]; then
|
||||
log "Error: Runner token is required"
|
||||
log "Usage: $0 <runner-token> [tailscale-auth-key]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
RUNNER_TOKEN=$1
|
||||
TAILSCALE_AUTH_KEY=$2
|
||||
REPO="exo-explore/exo"
|
||||
|
||||
# Add sudoers configuration
|
||||
log "Configuring sudo access..."
|
||||
SUDOERS_CONTENT="$(whoami) ALL=(ALL) NOPASSWD: ALL"
|
||||
echo "$SUDOERS_CONTENT" | sudo tee /etc/sudoers.d/github-runner > /dev/null
|
||||
sudo chmod 440 /etc/sudoers.d/github-runner
|
||||
|
||||
log "Configuring privacy permissions..."
|
||||
sudo tccutil reset All
|
||||
sudo tccutil reset SystemPolicyAllFiles
|
||||
sudo tccutil reset SystemPolicyNetworkVolumes
|
||||
|
||||
# Configure power management for maximum performance
|
||||
log "Configuring power management..."
|
||||
sudo pmset -a powermode 2 # Force highest performance mode
|
||||
sudo pmset -a gpuswitch 2 # Force discrete/high-performance GPU
|
||||
sudo pmset -a lowpowermode 0
|
||||
sudo pmset -a lessbright 0
|
||||
sudo pmset -a disablesleep 1
|
||||
sudo pmset -a sleep 0
|
||||
sudo pmset -a hibernatemode 0
|
||||
sudo pmset -a autopoweroff 0
|
||||
sudo pmset -a standby 0
|
||||
sudo pmset -a powernap 0
|
||||
|
||||
# For Python specifically
|
||||
PYTHON_PATH="/opt/homebrew/bin/python3.12"
|
||||
sudo chmod 755 "$PYTHON_PATH"
|
||||
|
||||
# Add to firewall
|
||||
log "Configuring firewall access..."
|
||||
sudo /usr/libexec/ApplicationFirewall/socketfilterfw --add "$PYTHON_PATH"
|
||||
sudo /usr/libexec/ApplicationFirewall/socketfilterfw --unblock "$PYTHON_PATH"
|
||||
|
||||
# Set Homebrew paths based on architecture
|
||||
if [ "$(uname -p)" = "arm" ]; then
|
||||
BREW_PREFIX="/opt/homebrew"
|
||||
else
|
||||
BREW_PREFIX="/usr/local"
|
||||
fi
|
||||
|
||||
# Install Homebrew if not present
|
||||
if ! command_exists brew; then
|
||||
log "Installing Homebrew..."
|
||||
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
|
||||
echo 'eval "$(/opt/homebrew/bin/brew shellenv)"' >> ~/.zshrc
|
||||
eval "$(/opt/homebrew/bin/brew shellenv)"
|
||||
fi
|
||||
|
||||
# Install required packages
|
||||
log "Installing required packages..."
|
||||
export HOMEBREW_NO_AUTO_UPDATE=1
|
||||
brew install python@3.12 coreutils
|
||||
|
||||
# Optional Tailscale setup if auth key is provided
|
||||
if [ -n "$TAILSCALE_AUTH_KEY" ]; then
|
||||
log "Installing and configuring Tailscale..."
|
||||
brew install --quiet tailscale
|
||||
sudo brew services stop tailscale 2>/dev/null || true
|
||||
sudo rm -f /var/db/tailscale/tailscaled.state 2>/dev/null || true
|
||||
sudo brew services start tailscale
|
||||
sleep 2
|
||||
sudo tailscale up --authkey=$TAILSCALE_AUTH_KEY
|
||||
|
||||
# Enable SSH and Screen Sharing
|
||||
log "Enabling remote access services..."
|
||||
sudo launchctl load -w /System/Library/LaunchDaemons/ssh.plist
|
||||
sudo /System/Library/CoreServices/RemoteManagement/ARDAgent.app/Contents/Resources/kickstart \
|
||||
-activate \
|
||||
-configure -access -on \
|
||||
-configure -allowAccessFor -allUsers \
|
||||
-configure -restart -agent -privs -all
|
||||
|
||||
# Create launch daemon for remote access
|
||||
sudo bash -c 'cat > /Library/LaunchDaemons/com.remote.access.setup.plist' << 'EOL'
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
||||
<plist version="1.0">
|
||||
<dict>
|
||||
<key>Label</key>
|
||||
<string>com.remote.access.setup</string>
|
||||
<key>ProgramArguments</key>
|
||||
<array>
|
||||
<string>/bin/bash</string>
|
||||
<string>-c</string>
|
||||
<string>
|
||||
launchctl load -w /System/Library/LaunchDaemons/ssh.plist;
|
||||
/System/Library/CoreServices/RemoteManagement/ARDAgent.app/Contents/Resources/kickstart -activate -configure -access -on
|
||||
</string>
|
||||
</array>
|
||||
<key>RunAtLoad</key>
|
||||
<true/>
|
||||
</dict>
|
||||
</plist>
|
||||
EOL
|
||||
|
||||
sudo chmod 644 /Library/LaunchDaemons/com.remote.access.setup.plist
|
||||
sudo launchctl load -w /Library/LaunchDaemons/com.remote.access.setup.plist
|
||||
fi
|
||||
|
||||
# Configure GitHub Actions Runner
|
||||
log "Gathering system metadata..."
|
||||
MACHINE_NAME=$(scutil --get ComputerName)
|
||||
MACHINE_NAME="runner-$(echo -n "$MACHINE_NAME" | tr '[:upper:]' '[:lower:]' | tr -cd '[:alnum:]-')"
|
||||
|
||||
# Enhanced Apple Silicon detection
|
||||
MACHINE_INFO=$(system_profiler SPHardwareDataType)
|
||||
CHIP_FULL=$(echo "$MACHINE_INFO" | grep "Chip" | cut -d: -f2 | xargs)
|
||||
if [[ $CHIP_FULL =~ "Apple" ]]; then
|
||||
CHIP_MODEL=$(echo "$CHIP_FULL" | sed 's/^Apple //' | tr -d ' ' | tr '[:lower:]' '[:upper:]')
|
||||
GPU_CORES=$(ioreg -l | grep "gpu-core-count" | awk -F'= ' '{print $2}')
|
||||
if [ -z "$GPU_CORES" ]; then
|
||||
GPU_CORES="N/A"
|
||||
fi
|
||||
else
|
||||
CHIP_MODEL="Intel"
|
||||
GPU_CORES="N/A"
|
||||
fi
|
||||
|
||||
MEMORY=$(($(sysctl -n hw.memsize) / 1024 / 1024 / 1024))
|
||||
|
||||
# Set up GitHub Runner
|
||||
RUNNER_DIR="$HOME/actions-runner"
|
||||
|
||||
# Check if runner is already configured
|
||||
if [ -f "$RUNNER_DIR/.runner" ]; then
|
||||
log "Runner already configured. Stopping existing service..."
|
||||
sudo launchctl unload /Library/LaunchDaemons/com.github.runner.plist 2>/dev/null || true
|
||||
fi
|
||||
|
||||
# Create runner directory if it doesn't exist
|
||||
mkdir -p "$RUNNER_DIR"
|
||||
cd "$RUNNER_DIR"
|
||||
|
||||
CUSTOM_LABELS="self-hosted,macos,arm64,${CHIP_MODEL}_GPU${GPU_CORES}_${MEMORY}GB"
|
||||
|
||||
# Only download and extract if not already present or if forced
|
||||
if [ ! -f "$RUNNER_DIR/run.sh" ] || [ "${FORCE_SETUP:-false}" = "true" ]; then
|
||||
log "Downloading GitHub Actions runner..."
|
||||
RUNNER_VERSION=$(curl -s https://api.github.com/repos/actions/runner/releases/latest | grep '"tag_name":' | cut -d'"' -f4)
|
||||
curl -o actions-runner.tar.gz -L "https://github.com/actions/runner/releases/download/${RUNNER_VERSION}/actions-runner-osx-arm64-${RUNNER_VERSION#v}.tar.gz"
|
||||
tar xzf actions-runner.tar.gz
|
||||
rm actions-runner.tar.gz
|
||||
else
|
||||
log "Runner already downloaded, skipping download step"
|
||||
fi
|
||||
|
||||
log "Configuring runner with labels: $CUSTOM_LABELS"
|
||||
./config.sh --unattended \
|
||||
--url "https://github.com/${REPO}" \
|
||||
--token "${RUNNER_TOKEN}" \
|
||||
--name "${MACHINE_NAME}" \
|
||||
--labels "${CUSTOM_LABELS}" \
|
||||
--work "_work"
|
||||
|
||||
# Set optimal performance settings
|
||||
log "Configuring system for optimal performance..."
|
||||
|
||||
# Configure CPU performance
|
||||
log "Setting CPU performance controls..."
|
||||
# Disable timer coalescing
|
||||
sudo sysctl -w kern.timer.coalescing_enabled=0
|
||||
sudo sysctl -w kern.timer_coalesce_bg_scale=-5
|
||||
sudo sysctl -w kern.timer_resort_threshold_ns=0
|
||||
# Set minimum timer intervals
|
||||
sudo sysctl -w kern.wq_max_timer_interval_usecs=1000
|
||||
sudo sysctl -w kern.timer_coalesce_bg_ns_max=1000
|
||||
# Set minimum timer coalescing for all tiers
|
||||
sudo sysctl -w kern.timer_coalesce_tier0_scale=-5
|
||||
sudo sysctl -w kern.timer_coalesce_tier0_ns_max=1000
|
||||
sudo sysctl -w kern.timer_coalesce_tier1_scale=-5
|
||||
sudo sysctl -w kern.timer_coalesce_tier1_ns_max=1000
|
||||
sudo sysctl -w kern.timer_coalesce_tier2_scale=-5
|
||||
sudo sysctl -w kern.timer_coalesce_tier2_ns_max=1000
|
||||
sudo sysctl -w kern.timer_coalesce_tier3_scale=-5
|
||||
sudo sysctl -w kern.timer_coalesce_tier3_ns_max=1000
|
||||
sudo sysctl -w kern.timer_coalesce_tier4_scale=-5
|
||||
sudo sysctl -w kern.timer_coalesce_tier4_ns_max=1000
|
||||
# Disable QoS restrictions
|
||||
sudo sysctl -w net.qos.policy.restricted=0
|
||||
sudo sysctl -w net.qos.policy.restrict_avapps=0
|
||||
sudo sysctl -w net.qos.policy.wifi_enabled=0
|
||||
sudo sysctl -w net.qos.policy.capable_enabled=0
|
||||
# Set scheduler parameters
|
||||
sudo sysctl -w kern.sched_rt_avoid_cpu0=0
|
||||
sudo sysctl -w debug.sched=2
|
||||
sudo sysctl -w net.pktsched.netem.sched_output_ival_ms=1
|
||||
|
||||
# Clean up any existing runner services
|
||||
log "Cleaning up existing runner services..."
|
||||
for service in com.github.runner com.github.runner.monitor com.github.runner.cpuaffinity com.github.runner.affinity; do
|
||||
sudo launchctl bootout system/$service 2>/dev/null || true
|
||||
sudo rm -f /Library/LaunchDaemons/$service.plist
|
||||
done
|
||||
|
||||
# Create a simple runner service configuration
|
||||
sudo tee /Library/LaunchDaemons/com.github.runner.plist > /dev/null << EOF
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
||||
<plist version="1.0">
|
||||
<dict>
|
||||
<key>Label</key>
|
||||
<string>com.github.runner</string>
|
||||
<key>UserName</key>
|
||||
<string>$(whoami)</string>
|
||||
<key>GroupName</key>
|
||||
<string>staff</string>
|
||||
<key>WorkingDirectory</key>
|
||||
<string>$RUNNER_DIR</string>
|
||||
<key>ProgramArguments</key>
|
||||
<array>
|
||||
<string>$RUNNER_DIR/run.sh</string>
|
||||
</array>
|
||||
<key>RunAtLoad</key>
|
||||
<true/>
|
||||
<key>KeepAlive</key>
|
||||
<dict>
|
||||
<key>SuccessfulExit</key>
|
||||
<false/>
|
||||
<key>Crashed</key>
|
||||
<true/>
|
||||
</dict>
|
||||
<key>ProcessType</key>
|
||||
<string>Interactive</string>
|
||||
<key>LowPriorityIO</key>
|
||||
<false/>
|
||||
<key>AbandonProcessGroup</key>
|
||||
<false/>
|
||||
<key>EnableTransactions</key>
|
||||
<true/>
|
||||
<key>ThrottleInterval</key>
|
||||
<integer>0</integer>
|
||||
<key>HardResourceLimits</key>
|
||||
<dict>
|
||||
<key>NumberOfFiles</key>
|
||||
<integer>524288</integer>
|
||||
<key>MemoryLock</key>
|
||||
<integer>-1</integer>
|
||||
</dict>
|
||||
<key>SoftResourceLimits</key>
|
||||
<dict>
|
||||
<key>NumberOfFiles</key>
|
||||
<integer>524288</integer>
|
||||
<key>MemoryLock</key>
|
||||
<integer>-1</integer>
|
||||
</dict>
|
||||
<key>QOSClass</key>
|
||||
<string>User-Interactive</string>
|
||||
<key>StandardOutPath</key>
|
||||
<string>$RUNNER_DIR/_diag/runner.log</string>
|
||||
<key>StandardErrorPath</key>
|
||||
<string>$RUNNER_DIR/_diag/runner.err</string>
|
||||
<key>EnvironmentVariables</key>
|
||||
<dict>
|
||||
<key>PATH</key>
|
||||
<string>/usr/local/bin:/opt/homebrew/bin:/usr/bin:/bin:/usr/sbin:/sbin</string>
|
||||
</dict>
|
||||
<key>Nice</key>
|
||||
<integer>-20</integer>
|
||||
</dict>
|
||||
</plist>
|
||||
EOF
|
||||
|
||||
# Set proper permissions for the LaunchDaemon
|
||||
sudo chown root:wheel /Library/LaunchDaemons/com.github.runner.plist
|
||||
sudo chmod 644 /Library/LaunchDaemons/com.github.runner.plist
|
||||
|
||||
# Remove any existing service
|
||||
sudo launchctl bootout system/com.github.runner 2>/dev/null || true
|
||||
|
||||
# Load the new service using bootstrap
|
||||
sudo launchctl bootstrap system /Library/LaunchDaemons/com.github.runner.plist
|
||||
|
||||
# Add Runner.Listener permissions (after runner installation)
|
||||
RUNNER_PATH="$RUNNER_DIR/bin/Runner.Listener"
|
||||
sudo chmod 755 "$RUNNER_PATH"
|
||||
sudo /usr/libexec/ApplicationFirewall/socketfilterfw --add "$RUNNER_PATH"
|
||||
sudo /usr/libexec/ApplicationFirewall/socketfilterfw --unblock "$RUNNER_PATH"
|
||||
|
||||
# Create connection info file if Tailscale is configured
|
||||
if [ -n "$TAILSCALE_AUTH_KEY" ]; then
|
||||
TAILSCALE_IP=$(tailscale ip)
|
||||
cat > "$HOME/remote_access_info.txt" << EOL
|
||||
Mac Remote Access Information
|
||||
============================
|
||||
Computer Name: $MACHINE_NAME
|
||||
Username: $USER
|
||||
Tailscale IP: $TAILSCALE_IP
|
||||
|
||||
SSH Command: ssh $USER@$TAILSCALE_IP
|
||||
Screen Sharing: vnc://$TAILSCALE_IP
|
||||
EOL
|
||||
chmod 600 "$HOME/remote_access_info.txt"
|
||||
fi
|
||||
|
||||
log "Verifying runner service status..."
|
||||
if sudo launchctl list | grep com.github.runner > /dev/null; then
|
||||
log "GitHub Actions runner service is running successfully!"
|
||||
log "Runner labels: $CUSTOM_LABELS"
|
||||
[ -n "$TAILSCALE_AUTH_KEY" ] && log "Remote access details saved to: $HOME/remote_access_info.txt"
|
||||
else
|
||||
log "Error: Failed to start GitHub Actions runner service"
|
||||
exit 1
|
||||
fi
|
||||
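The script's own usage line is `bootstrap.sh <runner-token> [tailscale-auth-key]`, run as a regular user with sudo access. A minimal sketch of an invocation (both token values below are placeholders):

```sh
# Register this Mac as a self-hosted runner for exo-explore/exo;
# the Tailscale auth key is optional and only enables remote access setup.
./.github/bootstrap.sh AABBCC_EXAMPLE_RUNNER_TOKEN tskey-auth-example
```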
.github/optimize_performance.sh (vendored, new executable file, 95 lines)
@@ -0,0 +1,95 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
# Function to log with timestamp
|
||||
log() {
|
||||
echo "[$(date '+%Y-%m-%d %H:%M:%S')] $1"
|
||||
}
|
||||
|
||||
log "Applying comprehensive performance optimizations..."
|
||||
|
||||
# System-wide power management
|
||||
log "Configuring power management..."
|
||||
sudo pmset -a lessbright 0
|
||||
sudo pmset -a disablesleep 1
|
||||
sudo pmset -a sleep 0
|
||||
sudo pmset -a hibernatemode 0
|
||||
sudo pmset -a autopoweroff 0
|
||||
sudo pmset -a standby 0
|
||||
sudo pmset -a powernap 0
|
||||
sudo pmset -a proximitywake 0
|
||||
sudo pmset -a tcpkeepalive 1
|
||||
sudo pmset -a powermode 2
|
||||
sudo pmset -a gpuswitch 2
|
||||
sudo pmset -a displaysleep 0
|
||||
sudo pmset -a disksleep 0
|
||||
|
||||
# Memory and kernel optimizations
|
||||
log "Configuring memory and kernel settings..."
|
||||
sudo sysctl -w kern.memorystatus_purge_on_warning=0
|
||||
sudo sysctl -w kern.memorystatus_purge_on_critical=0
|
||||
sudo sysctl -w kern.timer.coalescing_enabled=0
|
||||
|
||||
# Metal and GPU optimizations
|
||||
log "Configuring Metal and GPU settings..."
|
||||
defaults write com.apple.CoreML MPSEnableGPUValidation -bool false
|
||||
defaults write com.apple.CoreML MPSEnableMetalValidation -bool false
|
||||
defaults write com.apple.CoreML MPSEnableGPUDebug -bool false
|
||||
defaults write com.apple.Metal GPUDebug -bool false
|
||||
defaults write com.apple.Metal GPUValidation -bool false
|
||||
defaults write com.apple.Metal MetalValidation -bool false
|
||||
defaults write com.apple.Metal MetalCaptureEnabled -bool false
|
||||
defaults write com.apple.Metal MTLValidationBehavior -string "Disabled"
|
||||
defaults write com.apple.Metal EnableMTLDebugLayer -bool false
|
||||
defaults write com.apple.Metal MTLDebugLevel -int 0
|
||||
defaults write com.apple.Metal PreferIntegratedGPU -bool false
|
||||
defaults write com.apple.Metal ForceMaximumPerformance -bool true
|
||||
defaults write com.apple.Metal MTLPreferredDeviceGPUFrame -bool true
|
||||
|
||||
# Create MPS cache directory with proper permissions
|
||||
sudo mkdir -p /tmp/mps_cache
|
||||
sudo chmod 777 /tmp/mps_cache
|
||||
|
||||
# Process and resource limits
|
||||
log "Configuring process limits..."
|
||||
sudo launchctl limit maxfiles 524288 524288
|
||||
ulimit -n 524288 || log "Warning: Could not set file descriptor limit"
|
||||
ulimit -c 0
|
||||
ulimit -l unlimited || log "Warning: Could not set memory lock limit"
|
||||
|
||||
# Export performance-related environment variables
|
||||
cat << 'EOF' > /tmp/performance_env.sh
|
||||
# Metal optimizations
|
||||
export MTL_DEBUG_LAYER=0
|
||||
export METAL_DEVICE_WRAPPER_TYPE=1
|
||||
export METAL_DEBUG_ERROR_MODE=0
|
||||
export METAL_FORCE_PERFORMANCE_MODE=1
|
||||
export METAL_DEVICE_PRIORITY=high
|
||||
export METAL_MAX_COMMAND_QUEUES=1024
|
||||
export METAL_LOAD_LIMIT=0
|
||||
export METAL_VALIDATION_ENABLED=0
|
||||
export METAL_ENABLE_VALIDATION_LAYER=0
|
||||
export OBJC_DEBUG_MISSING_POOLS=NO
|
||||
export MPS_CACHEDIR=/tmp/mps_cache
|
||||
|
||||
# MLX optimizations
|
||||
export MLX_USE_GPU=1
|
||||
export MLX_METAL_COMPILE_ASYNC=1
|
||||
export MLX_METAL_PREALLOCATE=1
|
||||
export MLX_METAL_MEMORY_GUARD=0
|
||||
export MLX_METAL_CACHE_KERNELS=1
|
||||
export MLX_PLACEMENT_POLICY=metal
|
||||
export MLX_METAL_VALIDATION=0
|
||||
export MLX_METAL_DEBUG=0
|
||||
export MLX_FORCE_P_CORES=1
|
||||
export MLX_METAL_MEMORY_BUDGET=0
|
||||
export MLX_METAL_PREWARM=1
|
||||
|
||||
# Python optimizations
|
||||
export PYTHONUNBUFFERED=1
|
||||
export PYTHONOPTIMIZE=2
|
||||
export PYTHONHASHSEED=0
|
||||
export PYTHONDONTWRITEBYTECODE=1
|
||||
EOF
|
||||
|
||||
log "Performance optimizations completed. Environment variables written to /tmp/performance_env.sh"
|
||||
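Since this script only writes its environment variables to `/tmp/performance_env.sh` rather than exporting them into the caller's shell, a benchmark shell has to source that file itself, as the workflow below does. A small illustrative sequence:

```sh
# Apply the system tweaks once, then pull the exported variables into
# the current shell before launching a benchmark.
./.github/optimize_performance.sh
source /tmp/performance_env.sh
echo "$MLX_FORCE_P_CORES $PYTHONOPTIMIZE"   # quick sanity check
```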
.github/workflows/bench_job.yml (vendored, new file, 207 lines)
@@ -0,0 +1,207 @@
|
||||
# This is the reusable workflow file
|
||||
name: Distributed Job Runner
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
config:
|
||||
required: true
|
||||
type: string
|
||||
model:
|
||||
required: true
|
||||
type: string
|
||||
calling_job_name:
|
||||
required: true
|
||||
type: string
|
||||
network_interface:
|
||||
required: true
|
||||
type: string
|
||||
jobs:
|
||||
generate-matrix:
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
matrix: ${{ steps.set-matrix.outputs.matrix }}
|
||||
steps:
|
||||
- id: set-matrix
|
||||
env:
|
||||
CONFIG: ${{ inputs.config }}
|
||||
run: |
|
||||
MATRIX=$(echo $CONFIG | jq -c '{cpu: [to_entries | .[] | .key as $k | range(.value) | $k]}')
|
||||
echo "matrix=$MATRIX" >> $GITHUB_OUTPUT
|
||||
|
||||
run-distributed-job:
|
||||
needs: generate-matrix
|
||||
strategy:
|
||||
matrix: ${{fromJson(needs.generate-matrix.outputs.matrix)}}
|
||||
runs-on: ['self-hosted', 'macOS', '${{ matrix.cpu }}']
|
||||
env:
|
||||
HARDWARE_CONFIG: ${{ inputs.config }}
|
||||
model: ${{ inputs.model }}
|
||||
# Add performance-related environment variables
|
||||
MTL_DEBUG_LAYER: 0
|
||||
METAL_VALIDATION_ENABLED: 0
|
||||
MLX_METAL_VALIDATION: 0
|
||||
MLX_METAL_DEBUG: 0
|
||||
MLX_FORCE_P_CORES: 1
|
||||
MLX_METAL_PREWARM: 1
|
||||
PYTHONOPTIMIZE: 2
|
||||
steps:
|
||||
- name: Cleanup workspace
|
||||
run: |
|
||||
sudo rm -rf "$GITHUB_WORKSPACE"
|
||||
sudo mkdir -p "$GITHUB_WORKSPACE"
|
||||
sudo chown -R $(whoami):$(id -g) "$GITHUB_WORKSPACE"
|
||||
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
export PATH="/usr/local/bin:/opt/homebrew/bin:$PATH"
|
||||
python3.12 -m venv .venv || {
|
||||
echo "Failed to find python3.12. Checking installation locations:"
|
||||
ls -l /usr/local/bin/python* /opt/homebrew/bin/python* 2>/dev/null || true
|
||||
exit 1
|
||||
}
|
||||
source .venv/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install -e .
|
||||
pip install boto3==1.35.76
|
||||
|
||||
- name: Apply Performance Optimizations
|
||||
run: |
|
||||
# Export performance-related environment variables
|
||||
cat << 'EOF' > /tmp/performance_env.sh
|
||||
# MLX and Metal optimizations
|
||||
export MTL_DEBUG_LAYER=0
|
||||
export METAL_VALIDATION_ENABLED=0
|
||||
export MLX_METAL_VALIDATION=0
|
||||
export MLX_METAL_DEBUG=0
|
||||
export MLX_FORCE_P_CORES=1
|
||||
export MLX_METAL_PREWARM=1
|
||||
export PYTHONOPTIMIZE=2
|
||||
EOF
|
||||
|
||||
# Source the performance environment variables
|
||||
source /tmp/performance_env.sh
|
||||
|
||||
# MLX Memory Settings
|
||||
./configure_mlx.sh
|
||||
|
||||
# Verify optimizations
|
||||
echo "Verifying performance settings..."
|
||||
env | grep -E "MLX_|METAL_|MTL_"
|
||||
|
||||
- name: Run exo
|
||||
env:
|
||||
aws_access_key_id: ${{ secrets.S3_EXO_BENCHMARKS_AWS_ACCESS_KEY_ID }}
|
||||
aws_secret_key: ${{ secrets.S3_EXO_BENCHMARKS_AWS_SECRET_ACCESS_KEY }}
|
||||
run: |
|
||||
# Source performance environment variables
|
||||
source /tmp/performance_env.sh
|
||||
|
||||
# Debug information
|
||||
echo "Current commit SHA: $GITHUB_SHA"
|
||||
git rev-parse HEAD
|
||||
git status
|
||||
|
||||
CALLING_JOB="${{ inputs.calling_job_name }}"
|
||||
UNIQUE_JOB_ID="${CALLING_JOB}_${model}_${GITHUB_RUN_ID}"
|
||||
ALL_NODE_IDS=$(for i in $(seq ${{ strategy.job-total }} -1 0); do echo -n "${UNIQUE_JOB_ID}_${i},"; done | sed 's/,$//')
|
||||
MY_NODE_ID="${UNIQUE_JOB_ID}_${{ strategy.job-index }}"
|
||||
|
||||
source .venv/bin/activate
|
||||
export PATH="/usr/local/bin:/opt/homebrew/bin:$PATH"
|
||||
|
||||
echo "=== Before starting exo ==="
|
||||
ps -eo pid,ppid,user,%cpu,%mem,nice,state,pri,command | head -1
|
||||
ps -eo pid,ppid,user,%cpu,%mem,nice,state,pri,command | grep -i python
|
||||
|
||||
echo "Starting exo daemon..."
|
||||
|
||||
echo "Power mode settings:"
|
||||
sudo pmset -g
|
||||
|
||||
# Start exo with explicit process control
|
||||
sudo taskpolicy -d default -g default -a -t 0 -l 0 .venv/bin/exo \
|
||||
--node-id="${MY_NODE_ID}" \
|
||||
--node-id-filter="${ALL_NODE_IDS}" \
|
||||
--interface-type-filter="${{ inputs.network_interface }}" \
|
||||
--disable-tui \
|
||||
--max-generate-tokens 250 \
|
||||
--chatgpt-api-response-timeout 900 \
|
||||
--chatgpt-api-port 52415 > output1.log 2>&1 &
|
||||
PID1=$!
|
||||
|
||||
echo "Exo process started with PID: $PID1"
|
||||
tail -f output1.log &
|
||||
TAIL1=$!
|
||||
|
||||
# Give process time to start
|
||||
sleep 2
|
||||
|
||||
# Set additional process priorities
|
||||
sudo renice -n -20 -p $PID1
|
||||
sudo taskpolicy -t 4 -p $PID1
|
||||
|
||||
echo "=== After starting exo ==="
|
||||
ps -eo pid,ppid,user,%cpu,%mem,nice,state,pri,command | head -1
|
||||
ps -eo pid,ppid,user,%cpu,%mem,nice,state,pri,command | grep $PID1
|
||||
|
||||
echo "Additional process details:"
|
||||
sudo powermetrics -n 1 -i 1000 --show-process-energy | grep -A 5 $PID1 || true
|
||||
|
||||
trap 'kill $TAIL1' EXIT
|
||||
trap 'kill $PID1' EXIT
|
||||
|
||||
echo "Waiting for all nodes to connect..."
|
||||
for i in {1..20}; do
|
||||
echo "Attempt $i: Checking node count..."
|
||||
nodes=$(curl -s http://localhost:52415/topology | jq ".nodes | length")
|
||||
echo "Current node count: $nodes"
|
||||
if [ "$nodes" -eq "${{ strategy.job-total }}" ]; then
|
||||
echo "All nodes connected successfully!"
|
||||
break
|
||||
fi
|
||||
if [ $i -eq 20 ]; then
|
||||
echo "ERROR: Failed to connect all nodes after 20 attempts. Expected ${{ strategy.job-total }} nodes, but got $nodes"
|
||||
exit 1
|
||||
fi
|
||||
sleep 5
|
||||
done
|
||||
|
||||
if ! kill -0 $PID1 2>/dev/null; then
|
||||
echo "ERROR: Instance (PID $PID1) died unexpectedly. Full log output:"
|
||||
cat output1.log
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ "${{ strategy.job-index }}" -eq "0" ]; then
|
||||
sleep 10
|
||||
echo "This is the primary node (index 0). Running benchmark..."
|
||||
GITHUB_JOB=$CALLING_JOB python .github/bench.py
|
||||
else
|
||||
echo "This is a secondary node (index ${{ strategy.job-index }}). Waiting for completion..."
|
||||
sleep 10
|
||||
while true; do
|
||||
echo "Checking if primary node is still running..."
|
||||
nodes=$(curl -s http://localhost:52415/topology | jq ".nodes | length")
|
||||
echo "Current node count: $nodes"
|
||||
if [ "$nodes" -lt "${{ strategy.job-total }}" ]; then
|
||||
echo "Primary node completed, exiting..."
|
||||
break
|
||||
fi
|
||||
sleep 5
|
||||
done
|
||||
fi
|
||||
|
||||
- name: Check Final System State
|
||||
if: always()
|
||||
run: |
|
||||
echo "=== Final System State ==="
|
||||
sudo pmset -g
|
||||
sudo powermetrics -n 1 -i 1000 --show-process-energy || true
|
||||
system_profiler SPDisplaysDataType
|
||||
sysctl iogpu
|
||||
ps -eo pid,ppid,user,%cpu,%mem,nice,state,command | grep -i python
|
||||
env | grep -E "MLX_|METAL_|MTL_"
|
||||
echo "=== End Final System State ==="
|
||||
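The `generate-matrix` step turns the `config` input (a map from hardware label to node count) into a flat `cpu` array with one matrix entry per node, so each entry claims one self-hosted runner with that label. A worked example of the same jq filter for the two-node M4 Pro config used in `benchmarks.yml` (run locally; output shown as a comment):

```sh
echo '{"M4PRO_GPU16_24GB": 2}' \
  | jq -c '{cpu: [to_entries | .[] | .key as $k | range(.value) | $k]}'
# -> {"cpu":["M4PRO_GPU16_24GB","M4PRO_GPU16_24GB"]}
```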
.github/workflows/benchmarks.yml (vendored, new file, 71 lines)
@@ -0,0 +1,71 @@
|
||||
name: Build and Test
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ '*' ]
|
||||
tags: [ '*' ]
|
||||
pull_request:
|
||||
branches: [ '*' ]
|
||||
|
||||
jobs:
|
||||
single-m4-pro:
|
||||
strategy:
|
||||
matrix:
|
||||
model: ['llama-3.2-1b', 'llama-3.2-3b', 'llama-3.1-8b']
|
||||
uses: ./.github/workflows/bench_job.yml
|
||||
with:
|
||||
config: '{"M4PRO_GPU16_24GB": 1}'
|
||||
model: ${{ matrix.model }}
|
||||
calling_job_name: 'single-m4-pro'
|
||||
network_interface: 'Ethernet'
|
||||
secrets: inherit
|
||||
|
||||
two-m4-pro-cluster:
|
||||
strategy:
|
||||
matrix:
|
||||
model: ['llama-3.2-1b', 'llama-3.2-3b', 'llama-3.1-8b']
|
||||
uses: ./.github/workflows/bench_job.yml
|
||||
with:
|
||||
config: '{"M4PRO_GPU16_24GB": 2}'
|
||||
model: ${{ matrix.model }}
|
||||
calling_job_name: 'two-m4-pro-cluster'
|
||||
network_interface: 'Ethernet'
|
||||
secrets: inherit
|
||||
|
||||
# two-m4-pro-cluster-thunderbolt:
|
||||
# strategy:
|
||||
# matrix:
|
||||
# model: ['llama-3.2-1b', 'llama-3.2-3b', 'llama-3.1-8b']
|
||||
# uses: ./.github/workflows/bench_job.yml
|
||||
# with:
|
||||
# config: '{"M4PRO_GPU16_24GB": 2}'
|
||||
# model: ${{ matrix.model }}
|
||||
# calling_job_name: 'two-m4-pro-cluster-thunderbolt'
|
||||
# network_interface: 'Thunderbolt'
|
||||
# secrets: inherit
|
||||
|
||||
three-m4-pro-cluster:
|
||||
strategy:
|
||||
matrix:
|
||||
model: ['llama-3.2-1b', 'llama-3.2-3b', 'llama-3.1-8b', 'llama-3.3-70b']
|
||||
fail-fast: false
|
||||
uses: ./.github/workflows/bench_job.yml
|
||||
with:
|
||||
config: '{"M4PRO_GPU16_24GB": 3}'
|
||||
model: ${{ matrix.model }}
|
||||
calling_job_name: 'three-m4-pro-cluster'
|
||||
network_interface: 'Ethernet'
|
||||
secrets: inherit
|
||||
|
||||
# test-m3-single-node:
|
||||
# strategy:
|
||||
# matrix:
|
||||
# model: ['llama-3.2-1b']
|
||||
# fail-fast: false
|
||||
# uses: ./.github/workflows/bench_job.yml
|
||||
# with:
|
||||
# config: '{"M3MAX_GPU40_128GB": 1}'
|
||||
# model: ${{ matrix.model }}
|
||||
# calling_job_name: 'test-m3-cluster'
|
||||
# network_interface: 'Ethernet'
|
||||
# secrets: inherit
|
||||
.gitignore (vendored, 2 lines added)
@@ -171,3 +171,5 @@ cython_debug/

**/*.xcodeproj/*
.aider*

exo/tinychat/images/*.png
56
README.md
56
README.md
@@ -18,14 +18,17 @@ exo: Run your own AI cluster at home with everyday devices. Maintained by [exo l
|
||||
[](https://dl.circleci.com/status-badge/redirect/circleci/TrkofJDoGzdQAeL6yVHKsg/4i5hJuafuwZYZQxbRAWS71/tree/main)
|
||||
[](https://www.gnu.org/licenses/gpl-3.0)
|
||||
|
||||
<a href="https://trendshift.io/repositories/11849" target="_blank"><img src="https://trendshift.io/api/badge/repositories/11849" alt="exo-explore%2Fexo | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||
|
||||
</div>
|
||||
|
||||
---
|
||||
|
||||
Forget expensive NVIDIA GPUs, unify your existing devices into one powerful GPU: iPhone, iPad, Android, Mac, Linux, pretty much any device!
|
||||
Unify your existing devices into one powerful GPU: iPhone, iPad, Android, Mac, NVIDIA, Raspberry Pi, pretty much any device!
|
||||
|
||||
<div align="center">
|
||||
<h2>Update: exo is hiring. See <a href="https://exolabs.net">here</a> for more details.</h2>
|
||||
<h2>Interested in running exo in your business? <a href="mailto:hello@exolabs.net">Contact us</a> to discuss.</h2>
|
||||
</div>
|
||||
|
||||
## Get Involved
|
||||
@@ -38,7 +41,7 @@ We also welcome contributions from the community. We have a list of bounties in
|
||||
|
||||
### Wide Model Support
|
||||
|
||||
exo supports different models including LLaMA ([MLX](exo/inference/mlx/models/llama.py) and [tinygrad](exo/inference/tinygrad/models/llama.py)), Mistral, LlaVA, Qwen and Deepseek.
|
||||
exo supports different models including LLaMA ([MLX](exo/inference/mlx/models/llama.py) and [tinygrad](exo/inference/tinygrad/models/llama.py)), Mistral, LlaVA, Qwen, and Deepseek.
|
||||
|
||||
### Dynamic Model Partitioning
|
||||
|
||||
@@ -58,7 +61,7 @@ Unlike other distributed inference frameworks, exo does not use a master-worker
|
||||
|
||||
Exo supports different [partitioning strategies](exo/topology/partitioning_strategy.py) to split up a model across devices. The default partitioning strategy is [ring memory weighted partitioning](exo/topology/ring_memory_weighted_partitioning_strategy.py). This runs an inference in a ring where each device runs a number of model layers proportional to the memory of the device.
|
||||
|
||||

|
||||

|
||||
|
||||
## Installation
|
||||
|
||||
@@ -100,13 +103,13 @@ source install.sh
|
||||
|
||||
- There are a number of things users have empirically found to improve performance on Apple Silicon Macs:
|
||||
|
||||
1. Upgrade to the latest version of MacOS 15.
|
||||
1. Upgrade to the latest version of macOS Sequoia.
|
||||
2. Run `./configure_mlx.sh`. This runs commands to optimize GPU memory allocation on Apple Silicon Macs.
## Documentation

### Example Usage on Multiple MacOS Devices
### Example Usage on Multiple macOS Devices

#### Device 1:

@@ -149,6 +152,18 @@ curl http://localhost:52415/v1/chat/completions \
}'
```

#### DeepSeek R1 (full 671B):

```sh
curl http://localhost:52415/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "deepseek-r1",
"messages": [{"role": "user", "content": "What is the meaning of exo?"}],
"temperature": 0.7
}'
```

#### Llava 1.5 7B (Vision Language Model):

```sh
@@ -177,9 +192,9 @@ curl http://localhost:52415/v1/chat/completions \
}'
```

### Example Usage on Multiple Heterogenous Devices (MacOS + Linux)
### Example Usage on Multiple Heterogeneous Devices (macOS + Linux)

#### Device 1 (MacOS):
#### Device 1 (macOS):

```sh
exo
@@ -210,9 +225,19 @@ exo run llama-3.2-3b --prompt "What is the meaning of exo?"

### Model Storage

Models by default are stored in `~/.cache/huggingface/hub`.
Models by default are stored in `~/.cache/exo/downloads`.

You can set a different model storage location by setting the `HF_HOME` env var.
You can set a different model storage location by setting the `EXO_HOME` env var.
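For example, to keep model downloads on a larger external volume you could launch exo with `EXO_HOME` pointing somewhere else; the path below is a placeholder, not a recommended location:

```sh
# Hypothetical example: store exo's model data under /Volumes/big-disk instead of ~/.cache/exo
EXO_HOME=/Volumes/big-disk exo
```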
## Model Downloading

Models are downloaded from Hugging Face. If you are running exo in a country with strict internet censorship, you may need to download the models manually and put them in the `~/.cache/exo/downloads` directory.

To download models from a proxy endpoint, set the `HF_ENDPOINT` environment variable. For example, to run exo with the huggingface mirror endpoint:

```sh
HF_ENDPOINT=https://hf-mirror.com exo
```

## Debugging

@@ -244,7 +269,7 @@ python3 format.py ./exo

## Known Issues

- On some versions of MacOS/Python, certificates are not installed properly which can lead to SSL errors (e.g. SSL error with huggingface.co). To fix this, run the Install Certificates command, usually:
- On certain versions of Python on macOS, certificates may not be installed correctly, potentially causing SSL errors (e.g., when accessing huggingface.co). To resolve this, run the `Install Certificates` command, typically as follows:

```sh
/Applications/Python 3.x/Install Certificates.command

@@ -261,8 +286,15 @@ exo supports the following inference engines:

- 🚧 [PyTorch](https://github.com/exo-explore/exo/pull/139)
- 🚧 [llama.cpp](https://github.com/exo-explore/exo/issues/167)

## Networking Modules
## Discovery Modules

- ✅ [GRPC](exo/networking/grpc)
- ✅ [UDP](exo/networking/udp)
- ✅ [Manual](exo/networking/manual)
- ✅ [Tailscale](exo/networking/tailscale)
- 🚧 [Radio](TODO)
- 🚧 [Bluetooth](TODO)

# Peer Networking Modules

- ✅ [GRPC](exo/networking/grpc)
- 🚧 [NCCL](TODO)
@@ -1,18 +1,43 @@
#!/bin/bash
#!/usr/bin/env bash

# Get the total memory in MB
TOTAL_MEM_MB=$(($(sysctl -n hw.memsize) / 1024 / 1024))

# Set WIRED_LIMIT_MB to 80%
WIRED_LIMIT_MB=$(($TOTAL_MEM_MB * 80 / 100))
# Set WIRED_LWM_MB to 70%
WIRED_LWM_MB=$(($TOTAL_MEM_MB * 70 / 100))
# Calculate 80% and TOTAL_MEM_GB-5GB in MB
EIGHTY_PERCENT=$(($TOTAL_MEM_MB * 80 / 100))
MINUS_5GB=$((($TOTAL_MEM_MB - 5120)))

# Calculate 70% and TOTAL_MEM_GB-8GB in MB
SEVENTY_PERCENT=$(($TOTAL_MEM_MB * 70 / 100))
MINUS_8GB=$((($TOTAL_MEM_MB - 8192)))

# Set WIRED_LIMIT_MB to higher value
if [ $EIGHTY_PERCENT -gt $MINUS_5GB ]; then
  WIRED_LIMIT_MB=$EIGHTY_PERCENT
else
  WIRED_LIMIT_MB=$MINUS_5GB
fi

# Set WIRED_LWM_MB to higher value
if [ $SEVENTY_PERCENT -gt $MINUS_8GB ]; then
  WIRED_LWM_MB=$SEVENTY_PERCENT
else
  WIRED_LWM_MB=$MINUS_8GB
fi

# Display the calculated values
echo "Total memory: $TOTAL_MEM_MB MB"
echo "Maximum limit (iogpu.wired_limit_mb): $WIRED_LIMIT_MB MB"
echo "Lower bound (iogpu.wired_lwm_mb): $WIRED_LWM_MB MB"

# Apply the values with sysctl
sudo sysctl -w iogpu.wired_limit_mb=$WIRED_LIMIT_MB
sudo sysctl -w iogpu.wired_lwm_mb=$WIRED_LWM_MB
# Apply the values with sysctl, but check if we're already root
if [ "$EUID" -eq 0 ]; then
  sysctl -w iogpu.wired_limit_mb=$WIRED_LIMIT_MB
  sysctl -w iogpu.wired_lwm_mb=$WIRED_LWM_MB
else
  # Try without sudo first, fall back to sudo if needed
  sysctl -w iogpu.wired_limit_mb=$WIRED_LIMIT_MB 2>/dev/null || \
    sudo sysctl -w iogpu.wired_limit_mb=$WIRED_LIMIT_MB
  sysctl -w iogpu.wired_lwm_mb=$WIRED_LWM_MB 2>/dev/null || \
    sudo sysctl -w iogpu.wired_lwm_mb=$WIRED_LWM_MB
fi
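To make the new limit calculation concrete: on a 16 GB Mac (`TOTAL_MEM_MB=16384`) the percentage branches win, giving `WIRED_LIMIT_MB=13107` (80%) and `WIRED_LWM_MB=11468` (70%); on a 128 GB Mac (`TOTAL_MEM_MB=131072`) the subtraction branches win, giving `WIRED_LIMIT_MB=125952` (total minus 5 GB) and `WIRED_LWM_MB=122880` (total minus 8 GB). In other words, small machines keep roughly the old percentage-based limits, while large machines now leave a fixed 5 to 8 GB of headroom for the OS instead of a proportionally larger slice. (The figures are illustrative arithmetic only.)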
BIN docs/exo-screenshot.jpg (new binary file, not shown; 295 KiB)

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:561ec71a226a154503b1d70553c9d57c7cd45dbb3bb0e1244ed5b00edbf0523d
size 479724

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3f57b11f2d3aefb3887c5266994c4b4335501830c77a6a53fa6901c8725d0f6c
size 55857
examples/function_calling.py (new file, 111 lines)
@@ -0,0 +1,111 @@
import json
import re
import requests


def get_current_weather(location: str, unit: str = "celsius"):
  """Mock weather data function"""
  # Hardcoded response for demo purposes
  return {
    "location": location,
    "temperature": 22 if unit == "celsius" else 72,
    "unit": unit,
    "forecast": "Sunny with light clouds"
  }


def try_parse_tool_calls(content: str):
  """Try parse the tool calls."""
  tool_calls = []
  offset = 0
  for i, m in enumerate(re.finditer(r"<tool_call>\n(.+)?\n</tool_call>", content)):
    if i == 0:
      offset = m.start()
    try:
      func = json.loads(m.group(1))
      tool_calls.append({"type": "function", "function": func})
      if isinstance(func["arguments"], str):
        func["arguments"] = json.loads(func["arguments"])
    except json.JSONDecodeError as e:
      print(f"Failed to parse tool calls: the content is {m.group(1)} and {e}")
      pass
  if tool_calls:
    if offset > 0 and content[:offset].strip():
      c = content[:offset]
    else:
      c = ""
    return {"role": "assistant", "content": c, "tool_calls": tool_calls}
  return {"role": "assistant", "content": re.sub(r"<\|im_end\|>$", "", content)}


def chat_completion(messages):
  """Send chat completion request to local server"""
  response = requests.post(
    "http://localhost:52415/v1/chat/completions",
    json={
      "model": "qwen-2.5-1.5b",
      "messages": messages,
      "tools": [{
        "type": "function",
        "function": {
          "name": "get_current_weather",
          "description": "Get the current weather in a given location",
          "parameters": {
            "type": "object",
            "properties": {
              "location": {
                "type": "string",
                "description": "The city and state, e.g. San Francisco, CA"
              },
              "unit": {
                "type": "string",
                "enum": ["celsius", "fahrenheit"]
              }
            },
            "required": ["location"]
          }
        }
      }],
      "tool_choice": "auto"
    }
  )
  return response.json()


def main():
  # Initial conversation
  messages = [{
    "role": "user",
    "content": "Hi there, what's the weather in Boston?"
  }]

  # Get initial response
  response = chat_completion(messages)
  print(f"First response: {response}")
  assistant_message = try_parse_tool_calls(response["choices"][0]["message"]["content"])
  messages.append(assistant_message)

  # If there are tool calls, execute them and continue conversation
  if "tool_calls" in assistant_message:
    for tool_call in assistant_message["tool_calls"]:
      if tool_call["function"]["name"] == "get_current_weather":
        args = tool_call["function"]["arguments"]
        weather_data = get_current_weather(**args)

        # Add tool response to messages
        messages.append({
          "role": "tool",
          "content": json.dumps(weather_data),
          "name": tool_call["function"]["name"]
        })

  # Get final response with weather data
  response = chat_completion(messages)
  print(f"Final response: {response}")
  messages.append({
    "role": "assistant",
    "content": response["choices"][0]["message"]["content"]
  })

  # Print full conversation
  for msg in messages:
    print(f"\n{msg['role'].upper()}: {msg['content']}")


if __name__ == "__main__":
  main()
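For illustration only, this is the shape of model output that `try_parse_tool_calls` above is written to handle; the payload is hypothetical, but it follows the `<tool_call>` convention the regex expects:

```python
# Hypothetical model output wrapped in the <tool_call> markers parsed above.
raw = (
  "Let me check that for you.\n"
  "<tool_call>\n"
  '{"name": "get_current_weather", "arguments": {"location": "Boston, MA"}}\n'
  "</tool_call>"
)
parsed = try_parse_tool_calls(raw)
# parsed["content"] keeps the text outside the markers,
# parsed["tool_calls"][0]["function"]["name"] == "get_current_weather"
print(parsed)
```

Anything outside the markers is kept as the assistant's visible content, while each JSON blob inside the markers becomes one entry in `tool_calls`.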
@@ -5,41 +5,56 @@ import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from transformers import AutoTokenizer
|
||||
from typing import List, Literal, Union, Dict
|
||||
from typing import List, Literal, Union, Dict, Optional
|
||||
from aiohttp import web
|
||||
import aiohttp_cors
|
||||
import traceback
|
||||
import signal
|
||||
from exo import DEBUG, VERSION
|
||||
from exo.download.download_progress import RepoProgressEvent
|
||||
from exo.helpers import PrefixDict, shutdown
|
||||
from exo.helpers import PrefixDict, shutdown, get_exo_images_dir
|
||||
from exo.inference.tokenizers import resolve_tokenizer
|
||||
from exo.orchestration import Node
|
||||
from exo.models import build_base_shard, model_cards, get_repo, pretty_name
|
||||
from exo.models import build_base_shard, build_full_shard, model_cards, get_repo, get_supported_models, get_pretty_name
|
||||
from typing import Callable, Optional
|
||||
from exo.download.hf.hf_shard_download import HFShardDownloader
|
||||
import shutil
|
||||
from exo.download.hf.hf_helpers import get_hf_home, get_repo_root
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import base64
|
||||
from io import BytesIO
|
||||
import platform
|
||||
from exo.download.download_progress import RepoProgressEvent
|
||||
from exo.download.new_shard_download import delete_model
|
||||
import tempfile
|
||||
from exo.apputil import create_animation_mp4
|
||||
from collections import defaultdict
|
||||
|
||||
if platform.system().lower() == "darwin" and platform.machine().lower() == "arm64":
|
||||
import mlx.core as mx
|
||||
else:
|
||||
import numpy as mx
|
||||
|
||||
|
||||
class Message:
|
||||
def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
|
||||
def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]], tools: Optional[List[Dict]] = None):
|
||||
self.role = role
|
||||
self.content = content
|
||||
self.tools = tools
|
||||
|
||||
def to_dict(self):
|
||||
return {"role": self.role, "content": self.content}
|
||||
|
||||
data = {"role": self.role, "content": self.content}
|
||||
if self.tools:
|
||||
data["tools"] = self.tools
|
||||
return data
|
||||
|
||||
|
||||
class ChatCompletionRequest:
|
||||
def __init__(self, model: str, messages: List[Message], temperature: float):
|
||||
def __init__(self, model: str, messages: List[Message], temperature: float, tools: Optional[List[Dict]] = None):
|
||||
self.model = model
|
||||
self.messages = messages
|
||||
self.temperature = temperature
|
||||
self.tools = tools
|
||||
|
||||
def to_dict(self):
|
||||
return {"model": self.model, "messages": [message.to_dict() for message in self.messages], "temperature": self.temperature}
|
||||
return {"model": self.model, "messages": [message.to_dict() for message in self.messages], "temperature": self.temperature, "tools": self.tools}
|
||||
|
||||
|
||||
def generate_completion(
|
||||
@@ -119,20 +134,32 @@ def remap_messages(messages: List[Message]) -> List[Message]:
|
||||
return remapped_messages
|
||||
|
||||
|
||||
def build_prompt(tokenizer, _messages: List[Message]):
|
||||
def build_prompt(tokenizer, _messages: List[Message], tools: Optional[List[Dict]] = None):
|
||||
messages = remap_messages(_messages)
|
||||
prompt = tokenizer.apply_chat_template([m.to_dict() for m in messages], tokenize=False, add_generation_prompt=True)
|
||||
for message in messages:
|
||||
if not isinstance(message.content, list):
|
||||
continue
|
||||
chat_template_args = {"conversation": [m.to_dict() for m in messages], "tokenize": False, "add_generation_prompt": True}
|
||||
if tools:
|
||||
chat_template_args["tools"] = tools
|
||||
|
||||
return prompt
|
||||
try:
|
||||
prompt = tokenizer.apply_chat_template(**chat_template_args)
|
||||
if DEBUG >= 3: print(f"!!! Prompt: {prompt}")
|
||||
return prompt
|
||||
except UnicodeEncodeError:
|
||||
# Handle Unicode encoding by ensuring everything is UTF-8
|
||||
chat_template_args["conversation"] = [
|
||||
{k: v.encode('utf-8').decode('utf-8') if isinstance(v, str) else v
|
||||
for k, v in m.to_dict().items()}
|
||||
for m in messages
|
||||
]
|
||||
prompt = tokenizer.apply_chat_template(**chat_template_args)
|
||||
if DEBUG >= 3: print(f"!!! Prompt (UTF-8 encoded): {prompt}")
|
||||
return prompt
|
||||
|
||||
|
||||
def parse_message(data: dict):
|
||||
if "role" not in data or "content" not in data:
|
||||
raise ValueError(f"Invalid message: {data}. Must have 'role' and 'content'")
|
||||
return Message(data["role"], data["content"])
|
||||
return Message(data["role"], data["content"], data.get("tools"))
|
||||
|
||||
|
||||
def parse_chat_request(data: dict, default_model: str):
|
||||
@@ -140,6 +167,7 @@ def parse_chat_request(data: dict, default_model: str):
|
||||
data.get("model", default_model),
|
||||
[parse_message(msg) for msg in data["messages"]],
|
||||
data.get("temperature", 0.0),
|
||||
data.get("tools", None),
|
||||
)
|
||||
|
||||
|
||||
@@ -149,8 +177,17 @@ class PromptSession:
|
||||
self.timestamp = timestamp
|
||||
self.prompt = prompt
|
||||
|
||||
|
||||
class ChatGPTAPI:
|
||||
def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None, default_model: Optional[str] = None):
|
||||
def __init__(
|
||||
self,
|
||||
node: Node,
|
||||
inference_engine_classname: str,
|
||||
response_timeout: int = 90,
|
||||
on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None,
|
||||
default_model: Optional[str] = None,
|
||||
system_prompt: Optional[str] = None
|
||||
):
|
||||
self.node = node
|
||||
self.inference_engine_classname = inference_engine_classname
|
||||
self.response_timeout = response_timeout
|
||||
@@ -160,6 +197,12 @@ class ChatGPTAPI:
|
||||
self.prev_token_lens: Dict[str, int] = {}
|
||||
self.stream_tasks: Dict[str, asyncio.Task] = {}
|
||||
self.default_model = default_model or "llama-3.2-1b"
|
||||
self.token_queues = defaultdict(asyncio.Queue)
|
||||
|
||||
# Get the callback system and register our handler
|
||||
self.token_callback = node.on_token.register("chatgpt-api-token-handler")
|
||||
self.token_callback.on_next(lambda _request_id, tokens, is_finished: asyncio.create_task(self.handle_tokens(_request_id, tokens, is_finished)))
|
||||
self.system_prompt = system_prompt
|
||||
|
||||
cors = aiohttp_cors.setup(self.app)
|
||||
cors_options = aiohttp_cors.ResourceOptions(
|
||||
@@ -174,6 +217,7 @@ class ChatGPTAPI:
|
||||
cors.add(self.app.router.add_post("/v1/chat/token/encode", self.handle_post_chat_token_encode), {"*": cors_options})
|
||||
cors.add(self.app.router.add_post("/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
|
||||
cors.add(self.app.router.add_post("/v1/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
|
||||
cors.add(self.app.router.add_post("/v1/image/generations", self.handle_post_image_generations), {"*": cors_options})
|
||||
cors.add(self.app.router.add_get("/v1/download/progress", self.handle_get_download_progress), {"*": cors_options})
|
||||
cors.add(self.app.router.add_get("/modelpool", self.handle_model_support), {"*": cors_options})
|
||||
cors.add(self.app.router.add_get("/healthcheck", self.handle_healthcheck), {"*": cors_options})
|
||||
@@ -182,18 +226,25 @@ class ChatGPTAPI:
|
||||
cors.add(self.app.router.add_get("/initial_models", self.handle_get_initial_models), {"*": cors_options})
|
||||
cors.add(self.app.router.add_post("/create_animation", self.handle_create_animation), {"*": cors_options})
|
||||
cors.add(self.app.router.add_post("/download", self.handle_post_download), {"*": cors_options})
|
||||
cors.add(self.app.router.add_get("/v1/topology", self.handle_get_topology), {"*": cors_options})
|
||||
cors.add(self.app.router.add_get("/topology", self.handle_get_topology), {"*": cors_options})
|
||||
|
||||
# Add static routes
|
||||
if "__compiled__" not in globals():
|
||||
self.static_dir = Path(__file__).parent.parent/"tinychat"
|
||||
self.app.router.add_get("/", self.handle_root)
|
||||
self.app.router.add_static("/", self.static_dir, name="static")
|
||||
|
||||
# Always add images route, regardless of compilation status
|
||||
self.images_dir = get_exo_images_dir()
|
||||
self.images_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.app.router.add_static('/images/', self.images_dir, name='static_images')
|
||||
|
||||
self.app.middlewares.append(self.timeout_middleware)
|
||||
self.app.middlewares.append(self.log_request)
|
||||
|
||||
async def handle_quit(self, request):
|
||||
if DEBUG>=1: print("Received quit signal")
|
||||
if DEBUG >= 1: print("Received quit signal")
|
||||
response = web.json_response({"detail": "Quit signal received"}, status=200)
|
||||
await response.prepare(request)
|
||||
await response.write_eof()
|
||||
@@ -223,58 +274,22 @@ class ChatGPTAPI:
|
||||
|
||||
async def handle_model_support(self, request):
|
||||
try:
|
||||
response = web.StreamResponse(
|
||||
status=200,
|
||||
reason='OK',
|
||||
headers={
|
||||
'Content-Type': 'text/event-stream',
|
||||
'Cache-Control': 'no-cache',
|
||||
'Connection': 'keep-alive',
|
||||
}
|
||||
)
|
||||
await response.prepare(request)
|
||||
|
||||
for model_name, pretty in pretty_name.items():
|
||||
if model_name in model_cards:
|
||||
model_info = model_cards[model_name]
|
||||
|
||||
if self.inference_engine_classname in model_info.get("repo", {}):
|
||||
shard = build_base_shard(model_name, self.inference_engine_classname)
|
||||
if shard:
|
||||
downloader = HFShardDownloader(quick_check=True)
|
||||
downloader.current_shard = shard
|
||||
downloader.current_repo_id = get_repo(shard.model_id, self.inference_engine_classname)
|
||||
status = await downloader.get_shard_download_status()
|
||||
|
||||
download_percentage = status.get("overall") if status else None
|
||||
total_size = status.get("total_size") if status else None
|
||||
total_downloaded = status.get("total_downloaded") if status else False
|
||||
|
||||
model_data = {
|
||||
model_name: {
|
||||
"name": pretty,
|
||||
"downloaded": download_percentage == 100 if download_percentage is not None else False,
|
||||
"download_percentage": download_percentage,
|
||||
"total_size": total_size,
|
||||
"total_downloaded": total_downloaded
|
||||
}
|
||||
}
|
||||
|
||||
await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
|
||||
|
||||
await response.write(b"data: [DONE]\n\n")
|
||||
return response
|
||||
response = web.StreamResponse(status=200, reason='OK', headers={ 'Content-Type': 'text/event-stream', 'Cache-Control': 'no-cache', 'Connection': 'keep-alive' })
|
||||
await response.prepare(request)
|
||||
async for path, s in self.node.shard_downloader.get_shard_download_status(self.inference_engine_classname):
|
||||
model_data = { s.shard.model_id: { "downloaded": s.downloaded_bytes == s.total_bytes, "download_percentage": 100 if s.downloaded_bytes == s.total_bytes else 100 * float(s.downloaded_bytes) / float(s.total_bytes), "total_size": s.total_bytes, "total_downloaded": s.downloaded_bytes } }
|
||||
await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
|
||||
await response.write(b"data: [DONE]\n\n")
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in handle_model_support: {str(e)}")
|
||||
traceback.print_exc()
|
||||
return web.json_response(
|
||||
{"detail": f"Server error: {str(e)}"},
|
||||
status=500
|
||||
)
|
||||
print(f"Error in handle_model_support: {str(e)}")
|
||||
traceback.print_exc()
|
||||
return web.json_response({"detail": f"Server error: {str(e)}"}, status=500)
|
||||
|
||||
async def handle_get_models(self, request):
|
||||
return web.json_response([{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()])
|
||||
models_list = [{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()]
|
||||
return web.json_response({"object": "list", "data": models_list})
|
||||
|
||||
async def handle_post_chat_token_encode(self, request):
|
||||
data = await request.json()
|
||||
@@ -287,7 +302,7 @@ class ChatGPTAPI:
|
||||
shard = build_base_shard(model, self.inference_engine_classname)
|
||||
messages = [parse_message(msg) for msg in data.get("messages", [])]
|
||||
tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
|
||||
prompt = build_prompt(tokenizer, messages)
|
||||
prompt = build_prompt(tokenizer, messages, data.get("tools", None))
|
||||
tokens = tokenizer.encode(prompt)
|
||||
return web.json_response({
|
||||
"length": len(prompt),
|
||||
@@ -300,6 +315,7 @@ class ChatGPTAPI:
|
||||
progress_data = {}
|
||||
for node_id, progress_event in self.node.node_download_progress.items():
|
||||
if isinstance(progress_event, RepoProgressEvent):
|
||||
if progress_event.status != "in_progress": continue
|
||||
progress_data[node_id] = progress_event.to_dict()
|
||||
else:
|
||||
print(f"Unknown progress event type: {type(progress_event)}. {progress_event}")
|
||||
@@ -307,13 +323,13 @@ class ChatGPTAPI:
|
||||
|
||||
async def handle_post_chat_completions(self, request):
|
||||
data = await request.json()
|
||||
if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}")
|
||||
if DEBUG >= 2: print(f"[ChatGPTAPI] Handling chat completions request from {request.remote}: {data}")
|
||||
stream = data.get("stream", False)
|
||||
chat_request = parse_chat_request(data, self.default_model)
|
||||
if chat_request.model and chat_request.model.startswith("gpt-"): # to be compatible with ChatGPT tools, point all gpt- model requests to default model
|
||||
chat_request.model = self.default_model
|
||||
if not chat_request.model or chat_request.model not in model_cards:
|
||||
if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(model_cards.keys())}. Defaulting to {self.default_model}")
|
||||
if DEBUG >= 1: print(f"[ChatGPTAPI] Invalid model: {chat_request.model}. Supported: {list(model_cards.keys())}. Defaulting to {self.default_model}")
|
||||
chat_request.model = self.default_model
|
||||
shard = build_base_shard(chat_request.model, self.inference_engine_classname)
|
||||
if not shard:
|
||||
@@ -324,37 +340,26 @@ class ChatGPTAPI:
|
||||
)
|
||||
|
||||
tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
|
||||
if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")
|
||||
if DEBUG >= 4: print(f"[ChatGPTAPI] Resolved tokenizer: {tokenizer}")
|
||||
|
||||
prompt = build_prompt(tokenizer, chat_request.messages)
|
||||
# Add system prompt if set
|
||||
if self.system_prompt and not any(msg.role == "system" for msg in chat_request.messages):
|
||||
chat_request.messages.insert(0, Message("system", self.system_prompt))
|
||||
|
||||
prompt = build_prompt(tokenizer, chat_request.messages, chat_request.tools)
|
||||
request_id = str(uuid.uuid4())
|
||||
if self.on_chat_completion_request:
|
||||
try:
|
||||
self.on_chat_completion_request(request_id, chat_request, prompt)
|
||||
except Exception as e:
|
||||
if DEBUG >= 2: traceback.print_exc()
|
||||
# request_id = None
|
||||
# match = self.prompts.find_longest_prefix(prompt)
|
||||
# if match and len(prompt) > len(match[1].prompt):
|
||||
# if DEBUG >= 2:
|
||||
# print(f"Prompt for request starts with previous prompt {len(match[1].prompt)} of {len(prompt)}: {match[1].prompt}")
|
||||
# request_id = match[1].request_id
|
||||
# self.prompts.add(prompt, PromptSession(request_id=request_id, timestamp=int(time.time()), prompt=prompt))
|
||||
# # remove the matching prefix from the prompt
|
||||
# prompt = prompt[len(match[1].prompt):]
|
||||
# else:
|
||||
# request_id = str(uuid.uuid4())
|
||||
# self.prompts.add(prompt, PromptSession(request_id=request_id, timestamp=int(time.time()), prompt=prompt))
|
||||
|
||||
callback_id = f"chatgpt-api-wait-response-{request_id}"
|
||||
callback = self.node.on_token.register(callback_id)
|
||||
|
||||
if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=}")
|
||||
if DEBUG >= 2: print(f"[ChatGPTAPI] Processing prompt: {request_id=} {shard=} {prompt=}")
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, request_id=request_id))), timeout=self.response_timeout)
|
||||
|
||||
if DEBUG >= 2: print(f"Waiting for response to finish. timeout={self.response_timeout}s")
|
||||
if DEBUG >= 2: print(f"[ChatGPTAPI] Waiting for response to finish. timeout={self.response_timeout}s")
|
||||
|
||||
if stream:
|
||||
response = web.StreamResponse(
|
||||
@@ -367,62 +372,74 @@ class ChatGPTAPI:
|
||||
)
|
||||
await response.prepare(request)
|
||||
|
||||
async def stream_result(_request_id: str, tokens: List[int], is_finished: bool):
|
||||
prev_last_tokens_len = self.prev_token_lens.get(_request_id, 0)
|
||||
self.prev_token_lens[_request_id] = max(prev_last_tokens_len, len(tokens))
|
||||
new_tokens = tokens[prev_last_tokens_len:]
|
||||
finish_reason = None
|
||||
eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if hasattr(tokenizer, "_tokenizer") and isinstance(tokenizer._tokenizer,
|
||||
AutoTokenizer) else getattr(tokenizer, "eos_token_id", None)
|
||||
if len(new_tokens) > 0 and new_tokens[-1] == eos_token_id:
|
||||
new_tokens = new_tokens[:-1]
|
||||
if is_finished:
|
||||
finish_reason = "stop"
|
||||
if is_finished and not finish_reason:
|
||||
finish_reason = "length"
|
||||
try:
|
||||
# Stream tokens while waiting for inference to complete
|
||||
while True:
|
||||
if DEBUG >= 2: print(f"[ChatGPTAPI] Waiting for token from queue: {request_id=}")
|
||||
tokens, is_finished = await asyncio.wait_for(
|
||||
self.token_queues[request_id].get(),
|
||||
timeout=self.response_timeout
|
||||
)
|
||||
if DEBUG >= 2: print(f"[ChatGPTAPI] Got token from queue: {request_id=} {tokens=} {is_finished=}")
|
||||
|
||||
eos_token_id = None
|
||||
if not eos_token_id and hasattr(tokenizer, "eos_token_id"): eos_token_id = tokenizer.eos_token_id
|
||||
if not eos_token_id and hasattr(tokenizer, "_tokenizer"): eos_token_id = tokenizer.special_tokens_map.get("eos_token_id")
|
||||
|
||||
finish_reason = None
|
||||
if is_finished: finish_reason = "stop" if tokens[-1] == eos_token_id else "length"
|
||||
if DEBUG >= 2: print(f"{eos_token_id=} {tokens[-1]=} {finish_reason=}")
|
||||
|
||||
completion = generate_completion(
|
||||
chat_request,
|
||||
tokenizer,
|
||||
prompt,
|
||||
request_id,
|
||||
tokens,
|
||||
stream,
|
||||
finish_reason,
|
||||
"chat.completion",
|
||||
)
|
||||
|
||||
completion = generate_completion(
|
||||
chat_request,
|
||||
tokenizer,
|
||||
prompt,
|
||||
request_id,
|
||||
new_tokens,
|
||||
stream,
|
||||
finish_reason,
|
||||
"chat.completion",
|
||||
)
|
||||
if DEBUG >= 2: print(f"Streaming completion: {completion}")
|
||||
try:
|
||||
await response.write(f"data: {json.dumps(completion)}\n\n".encode())
|
||||
except Exception as e:
|
||||
if DEBUG >= 2: print(f"Error streaming completion: {e}")
|
||||
if DEBUG >= 2: traceback.print_exc()
|
||||
|
||||
def on_result(_request_id: str, tokens: List[int], is_finished: bool):
|
||||
if _request_id == request_id: self.stream_tasks[_request_id] = asyncio.create_task(stream_result(_request_id, tokens, is_finished))
|
||||
if is_finished:
|
||||
break
|
||||
|
||||
return _request_id == request_id and is_finished
|
||||
await response.write_eof()
|
||||
return response
|
||||
|
||||
_, tokens, _ = await callback.wait(on_result, timeout=self.response_timeout)
|
||||
if request_id in self.stream_tasks: # in case there is still a stream task running, wait for it to complete
|
||||
if DEBUG >= 2: print("Pending stream task. Waiting for stream task to complete.")
|
||||
try:
|
||||
await asyncio.wait_for(self.stream_tasks[request_id], timeout=30)
|
||||
except asyncio.TimeoutError:
|
||||
print("WARNING: Stream task timed out. This should not happen.")
|
||||
await response.write_eof()
|
||||
return response
|
||||
except asyncio.TimeoutError:
|
||||
if DEBUG >= 2: print(f"[ChatGPTAPI] Timeout waiting for token: {request_id=}")
|
||||
return web.json_response({"detail": "Response generation timed out"}, status=408)
|
||||
|
||||
except Exception as e:
|
||||
if DEBUG >= 2:
|
||||
print(f"[ChatGPTAPI] Error processing prompt: {e}")
|
||||
traceback.print_exc()
|
||||
return web.json_response(
|
||||
{"detail": f"Error processing prompt: {str(e)}"},
|
||||
status=500
|
||||
)
|
||||
|
||||
finally:
|
||||
# Clean up the queue for this request
|
||||
if request_id in self.token_queues:
|
||||
if DEBUG >= 2: print(f"[ChatGPTAPI] Cleaning up token queue: {request_id=}")
|
||||
del self.token_queues[request_id]
|
||||
else:
|
||||
_, tokens, _ = await callback.wait(
|
||||
lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished,
|
||||
timeout=self.response_timeout,
|
||||
)
|
||||
|
||||
tokens = []
|
||||
while True:
|
||||
_tokens, is_finished = await asyncio.wait_for(self.token_queues[request_id].get(), timeout=self.response_timeout)
|
||||
tokens.extend(_tokens)
|
||||
if is_finished:
|
||||
break
|
||||
finish_reason = "length"
|
||||
eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(getattr(tokenizer, "_tokenizer", None), AutoTokenizer) else tokenizer.eos_token_id
|
||||
eos_token_id = None
|
||||
if not eos_token_id and hasattr(tokenizer, "eos_token_id"): eos_token_id = tokenizer.eos_token_id
|
||||
if not eos_token_id and hasattr(tokenizer, "_tokenizer"): eos_token_id = tokenizer.special_tokens_map.get("eos_token_id")
|
||||
if DEBUG >= 2: print(f"Checking if end of tokens result {tokens[-1]=} is {eos_token_id=}")
|
||||
if tokens[-1] == eos_token_id:
|
||||
tokens = tokens[:-1]
|
||||
finish_reason = "stop"
|
||||
|
||||
return web.json_response(generate_completion(chat_request, tokenizer, prompt, request_id, tokens, stream, finish_reason, "chat.completion"))
|
||||
@@ -431,73 +448,119 @@ class ChatGPTAPI:
|
||||
except Exception as e:
|
||||
if DEBUG >= 2: traceback.print_exc()
|
||||
return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
|
||||
finally:
|
||||
deregistered_callback = self.node.on_token.deregister(callback_id)
|
||||
if DEBUG >= 2: print(f"Deregister {callback_id=} {deregistered_callback=}")
|
||||
|
||||
async def handle_delete_model(self, request):
|
||||
async def handle_post_image_generations(self, request):
|
||||
data = await request.json()
|
||||
|
||||
if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}")
|
||||
stream = data.get("stream", False)
|
||||
model = data.get("model", "")
|
||||
prompt = data.get("prompt", "")
|
||||
image_url = data.get("image_url", "")
|
||||
if DEBUG >= 2: print(f"model: {model}, prompt: {prompt}, stream: {stream}")
|
||||
shard = build_base_shard(model, self.inference_engine_classname)
|
||||
if DEBUG >= 2: print(f"shard: {shard}")
|
||||
if not shard:
|
||||
return web.json_response({"error": f"Unsupported model: {model} with inference engine {self.inference_engine_classname}"}, status=400)
|
||||
|
||||
request_id = str(uuid.uuid4())
|
||||
callback_id = f"chatgpt-api-wait-response-{request_id}"
|
||||
callback = self.node.on_token.register(callback_id)
|
||||
try:
|
||||
model_name = request.match_info.get('model_name')
|
||||
if DEBUG >= 2: print(f"Attempting to delete model: {model_name}")
|
||||
|
||||
if not model_name or model_name not in model_cards:
|
||||
return web.json_response(
|
||||
{"detail": f"Invalid model name: {model_name}"},
|
||||
status=400
|
||||
)
|
||||
|
||||
shard = build_base_shard(model_name, self.inference_engine_classname)
|
||||
if not shard:
|
||||
return web.json_response(
|
||||
{"detail": "Could not build shard for model"},
|
||||
status=400
|
||||
)
|
||||
|
||||
repo_id = get_repo(shard.model_id, self.inference_engine_classname)
|
||||
if DEBUG >= 2: print(f"Repo ID for model: {repo_id}")
|
||||
|
||||
# Get the HF cache directory using the helper function
|
||||
hf_home = get_hf_home()
|
||||
cache_dir = get_repo_root(repo_id)
|
||||
|
||||
if DEBUG >= 2: print(f"Looking for model files in: {cache_dir}")
|
||||
|
||||
if os.path.exists(cache_dir):
|
||||
if DEBUG >= 2: print(f"Found model files at {cache_dir}, deleting...")
|
||||
try:
|
||||
shutil.rmtree(cache_dir)
|
||||
return web.json_response({
|
||||
"status": "success",
|
||||
"message": f"Model {model_name} deleted successfully",
|
||||
"path": str(cache_dir)
|
||||
})
|
||||
except Exception as e:
|
||||
return web.json_response({
|
||||
"detail": f"Failed to delete model files: {str(e)}"
|
||||
}, status=500)
|
||||
if image_url != "" and image_url != None:
|
||||
img = self.base64_decode(image_url)
|
||||
else:
|
||||
return web.json_response({
|
||||
"detail": f"Model files not found at {cache_dir}"
|
||||
}, status=404)
|
||||
img = None
|
||||
await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, request_id=request_id, inference_state={"image": img}))), timeout=self.response_timeout)
|
||||
|
||||
response = web.StreamResponse(status=200, reason='OK', headers={
|
||||
'Content-Type': 'application/octet-stream',
|
||||
"Cache-Control": "no-cache",
|
||||
})
|
||||
await response.prepare(request)
|
||||
|
||||
def get_progress_bar(current_step, total_steps, bar_length=50):
|
||||
# Calculate the percentage of completion
|
||||
percent = float(current_step)/total_steps
|
||||
# Calculate the number of hashes to display
|
||||
arrow = '-'*int(round(percent*bar_length) - 1) + '>'
|
||||
spaces = ' '*(bar_length - len(arrow))
|
||||
|
||||
# Create the progress bar string
|
||||
progress_bar = f'Progress: [{arrow}{spaces}] {int(percent * 100)}% ({current_step}/{total_steps})'
|
||||
return progress_bar
|
||||
|
||||
async def stream_image(_request_id: str, result, is_finished: bool):
|
||||
if isinstance(result, list):
|
||||
await response.write(json.dumps({'progress': get_progress_bar((result[0]), (result[1]))}).encode('utf-8') + b'\n')
|
||||
|
||||
elif isinstance(result, np.ndarray):
|
||||
try:
|
||||
im = Image.fromarray(np.array(result))
|
||||
# Save the image to a file
|
||||
image_filename = f"{_request_id}.png"
|
||||
image_path = self.images_dir/image_filename
|
||||
im.save(image_path)
|
||||
|
||||
# Get URL for the saved image
|
||||
try:
|
||||
image_url = request.app.router['static_images'].url_for(filename=image_filename)
|
||||
base_url = f"{request.scheme}://{request.host}"
|
||||
full_image_url = base_url + str(image_url)
|
||||
|
||||
await response.write(json.dumps({'images': [{'url': str(full_image_url), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n')
|
||||
except KeyError as e:
|
||||
if DEBUG >= 2: print(f"Error getting image URL: {e}")
|
||||
# Fallback to direct file path if URL generation fails
|
||||
await response.write(json.dumps({'images': [{'url': str(image_path), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n')
|
||||
|
||||
if is_finished:
|
||||
await response.write_eof()
|
||||
|
||||
except Exception as e:
|
||||
if DEBUG >= 2: print(f"Error processing image: {e}")
|
||||
if DEBUG >= 2: traceback.print_exc()
|
||||
await response.write(json.dumps({'error': str(e)}).encode('utf-8') + b'\n')
|
||||
|
||||
stream_task = None
|
||||
|
||||
def on_result(_request_id: str, result, is_finished: bool):
|
||||
nonlocal stream_task
|
||||
stream_task = asyncio.create_task(stream_image(_request_id, result, is_finished))
|
||||
return _request_id == request_id and is_finished
|
||||
|
||||
await callback.wait(on_result, timeout=self.response_timeout*10)
|
||||
|
||||
if stream_task:
|
||||
# Wait for the stream task to complete before returning
|
||||
await stream_task
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in handle_delete_model: {str(e)}")
|
||||
traceback.print_exc()
|
||||
return web.json_response({
|
||||
"detail": f"Server error: {str(e)}"
|
||||
}, status=500)
|
||||
if DEBUG >= 2: traceback.print_exc()
|
||||
return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
|
||||
|
||||
async def handle_delete_model(self, request):
|
||||
model_id = request.match_info.get('model_name')
|
||||
try:
|
||||
if await delete_model(model_id, self.inference_engine_classname): return web.json_response({"status": "success", "message": f"Model {model_id} deleted successfully"})
|
||||
else: return web.json_response({"detail": f"Model {model_id} files not found"}, status=404)
|
||||
except Exception as e:
|
||||
if DEBUG >= 2: traceback.print_exc()
|
||||
return web.json_response({"detail": f"Error deleting model: {str(e)}"}, status=500)
|
||||
|
||||
async def handle_get_initial_models(self, request):
|
||||
model_data = {}
|
||||
for model_name, pretty in pretty_name.items():
|
||||
model_data[model_name] = {
|
||||
"name": pretty,
|
||||
"downloaded": None, # Initially unknown
|
||||
"download_percentage": None, # Change from 0 to null
|
||||
"total_size": None,
|
||||
"total_downloaded": None,
|
||||
"loading": True # Add loading state
|
||||
}
|
||||
for model_id in get_supported_models([[self.inference_engine_classname]]):
|
||||
model_data[model_id] = {
|
||||
"name": get_pretty_name(model_id),
|
||||
"downloaded": None, # Initially unknown
|
||||
"download_percentage": None, # Change from 0 to null
|
||||
"total_size": None,
|
||||
"total_downloaded": None,
|
||||
"loading": True # Add loading state
|
||||
}
|
||||
return web.json_response(model_data)
|
||||
|
||||
async def handle_create_animation(self, request):
|
||||
@@ -523,17 +586,9 @@ class ChatGPTAPI:
|
||||
if DEBUG >= 2: print(f"Animation temp directory: {tmp_dir}, output file: {output_path}, directory exists: {tmp_dir.exists()}, directory permissions: {oct(tmp_dir.stat().st_mode)[-3:]}")
|
||||
|
||||
# Create the animation
|
||||
create_animation_mp4(
|
||||
replacement_image_path,
|
||||
output_path,
|
||||
device_name,
|
||||
prompt_text
|
||||
)
|
||||
create_animation_mp4(replacement_image_path, output_path, device_name, prompt_text)
|
||||
|
||||
return web.json_response({
|
||||
"status": "success",
|
||||
"output_path": output_path
|
||||
})
|
||||
return web.json_response({"status": "success", "output_path": output_path})
|
||||
|
||||
except Exception as e:
|
||||
if DEBUG >= 2: traceback.print_exc()
|
||||
@@ -545,14 +600,11 @@ class ChatGPTAPI:
|
||||
model_name = data.get("model")
|
||||
if not model_name: return web.json_response({"error": "model parameter is required"}, status=400)
|
||||
if model_name not in model_cards: return web.json_response({"error": f"Invalid model: {model_name}. Supported models: {list(model_cards.keys())}"}, status=400)
|
||||
shard = build_base_shard(model_name, self.inference_engine_classname)
|
||||
shard = build_full_shard(model_name, self.inference_engine_classname)
|
||||
if not shard: return web.json_response({"error": f"Could not build shard for model {model_name}"}, status=400)
|
||||
asyncio.create_task(self.node.inference_engine.ensure_shard(shard))
|
||||
asyncio.create_task(self.node.inference_engine.shard_downloader.ensure_shard(shard, self.inference_engine_classname))
|
||||
|
||||
return web.json_response({
|
||||
"status": "success",
|
||||
"message": f"Download started for model: {model_name}"
|
||||
})
|
||||
return web.json_response({"status": "success", "message": f"Download started for model: {model_name}"})
|
||||
except Exception as e:
|
||||
if DEBUG >= 2: traceback.print_exc()
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
@@ -566,13 +618,28 @@ class ChatGPTAPI:
|
||||
return web.json_response({})
|
||||
except Exception as e:
|
||||
if DEBUG >= 2: traceback.print_exc()
|
||||
return web.json_response(
|
||||
{"detail": f"Error getting topology: {str(e)}"},
|
||||
status=500
|
||||
)
|
||||
return web.json_response({"detail": f"Error getting topology: {str(e)}"}, status=500)
|
||||
|
||||
async def handle_tokens(self, request_id: str, tokens: List[int], is_finished: bool):
|
||||
await self.token_queues[request_id].put((tokens, is_finished))
|
||||
|
||||
async def run(self, host: str = "0.0.0.0", port: int = 52415):
|
||||
runner = web.AppRunner(self.app)
|
||||
await runner.setup()
|
||||
site = web.TCPSite(runner, host, port)
|
||||
await site.start()
|
||||
|
||||
def base64_decode(self, base64_string):
|
||||
#decode and reshape image
|
||||
if base64_string.startswith('data:image'):
|
||||
base64_string = base64_string.split(',')[1]
|
||||
image_data = base64.b64decode(base64_string)
|
||||
img = Image.open(BytesIO(image_data))
|
||||
W, H = (dim - dim%64 for dim in (img.width, img.height))
|
||||
if W != img.width or H != img.height:
|
||||
if DEBUG >= 2: print(f"Warning: image shape is not divisible by 64, downsampling to {W}x{H}")
|
||||
img = img.resize((W, H), Image.NEAREST) # use desired downsampling filter
|
||||
img = mx.array(np.array(img))
|
||||
img = (img[:, :, :3].astype(mx.float32)/255)*2 - 1
|
||||
img = img[None]
|
||||
return img
|
||||
|
||||
@@ -2,6 +2,7 @@ from PIL import Image, ImageDraw, ImageFont, ImageFilter
|
||||
import os
|
||||
import numpy as np
|
||||
import cv2
|
||||
import sys
|
||||
|
||||
def draw_rounded_rectangle(draw, coords, radius, fill):
|
||||
left, top, right, bottom = coords
|
||||
@@ -80,14 +81,20 @@ def create_animation_mp4(
|
||||
font = ImageFont.load_default()
|
||||
promptfont = ImageFont.load_default()
|
||||
|
||||
# Get the base directory for images when running as a bundled app
|
||||
if hasattr(sys, '_MEIPASS'):
|
||||
base_dir = os.path.join(sys._MEIPASS, "exo", "apputil", "baseimages")
|
||||
else:
|
||||
base_dir = os.path.join(os.path.dirname(__file__), "baseimages")
|
||||
|
||||
# Process first frame
|
||||
base_img = Image.open(os.path.join(os.path.dirname(__file__), "baseimages", "image1.png"))
|
||||
base_img = Image.open(os.path.join(base_dir, "image1.png"))
|
||||
draw = ImageDraw.Draw(base_img)
|
||||
draw_centered_text_rounded(draw, device_name, font, device_coords)
|
||||
frames.extend([crop_image(base_img)] * 30) # 1 second at 30fps
|
||||
|
||||
# Process second frame with typing animation
|
||||
base_img2 = Image.open(os.path.join(os.path.dirname(__file__), "baseimages", "image2.png"))
|
||||
base_img2 = Image.open(os.path.join(base_dir, "image2.png"))
|
||||
for i in range(len(prompt_text) + 1):
|
||||
current_frame = base_img2.copy()
|
||||
draw = ImageDraw.Draw(current_frame)
|
||||
@@ -101,7 +108,7 @@ def create_animation_mp4(
|
||||
|
||||
# Create blur sequence
|
||||
replacement_img = Image.open(replacement_image_path)
|
||||
base_img = Image.open(os.path.join(os.path.dirname(__file__), "baseimages", "image3.png"))
|
||||
base_img = Image.open(os.path.join(base_dir, "image3.png"))
|
||||
blur_steps = [int(80 * (1 - i/8)) for i in range(9)]
|
||||
|
||||
for i, blur_amount in enumerate(blur_steps):
|
||||
@@ -123,7 +130,7 @@ def create_animation_mp4(
|
||||
frames.extend([crop_image(new_frame)] * 15) # 0.5 seconds at 30fps
|
||||
|
||||
# Create and add final frame (image4)
|
||||
final_base = Image.open(os.path.join(os.path.dirname(__file__), "baseimages", "image4.png"))
|
||||
final_base = Image.open(os.path.join(base_dir, "image4.png"))
|
||||
draw = ImageDraw.Draw(final_base)
|
||||
|
||||
draw_centered_text_rounded(draw, device_name, font, device_coords)
|
||||
@@ -158,4 +165,4 @@ def create_animation_mp4(
|
||||
out.write(frame_array)
|
||||
|
||||
out.release()
|
||||
print(f"Video saved successfully to {output_path}")
|
||||
print(f"Video saved successfully to {output_path}")
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import Dict, Callable, Coroutine, Any, Literal
|
||||
from exo.inference.shard import Shard
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
|
||||
@@ -14,11 +15,12 @@ class RepoFileProgressEvent:
|
||||
speed: int
|
||||
eta: timedelta
|
||||
status: Literal["not_started", "in_progress", "complete"]
|
||||
start_time: float
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"repo_id": self.repo_id, "repo_revision": self.repo_revision, "file_path": self.file_path, "downloaded": self.downloaded, "downloaded_this_session": self.downloaded_this_session,
|
||||
"total": self.total, "speed": self.speed, "eta": self.eta.total_seconds(), "status": self.status
|
||||
"total": self.total, "speed": self.speed, "eta": self.eta.total_seconds(), "status": self.status, "start_time": self.start_time
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@@ -29,6 +31,7 @@ class RepoFileProgressEvent:
|
||||
|
||||
@dataclass
|
||||
class RepoProgressEvent:
|
||||
shard: Shard
|
||||
repo_id: str
|
||||
repo_revision: str
|
||||
completed_files: int
|
||||
@@ -43,7 +46,7 @@ class RepoProgressEvent:
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"repo_id": self.repo_id, "repo_revision": self.repo_revision, "completed_files": self.completed_files, "total_files": self.total_files, "downloaded_bytes": self.downloaded_bytes,
|
||||
"shard": self.shard.to_dict(), "repo_id": self.repo_id, "repo_revision": self.repo_revision, "completed_files": self.completed_files, "total_files": self.total_files, "downloaded_bytes": self.downloaded_bytes,
|
||||
"downloaded_bytes_this_session": self.downloaded_bytes_this_session, "total_bytes": self.total_bytes, "overall_speed": self.overall_speed, "overall_eta": self.overall_eta.total_seconds(),
|
||||
"file_progress": {k: v.to_dict()
|
||||
for k, v in self.file_progress.items()}, "status": self.status
|
||||
@@ -53,6 +56,7 @@ class RepoProgressEvent:
|
||||
def from_dict(cls, data):
|
||||
if 'overall_eta' in data: data['overall_eta'] = timedelta(seconds=data['overall_eta'])
|
||||
if 'file_progress' in data: data['file_progress'] = {k: RepoFileProgressEvent.from_dict(v) for k, v in data['file_progress'].items()}
|
||||
if 'shard' in data: data['shard'] = Shard.from_dict(data['shard'])
|
||||
|
||||
return cls(**data)
|
||||
|
||||
|
||||
@@ -1,36 +1,16 @@
|
||||
import aiofiles.os as aios
|
||||
from typing import Union
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import shutil
|
||||
from urllib.parse import urljoin
|
||||
from typing import Callable, Optional, Coroutine, Any, Dict, List, Union, Literal
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Callable, Optional, Dict, List, Union
|
||||
from fnmatch import fnmatch
|
||||
from pathlib import Path
|
||||
from typing import Generator, Iterable, TypeVar, TypedDict
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
|
||||
from exo.helpers import DEBUG, is_frozen
|
||||
from exo.download.download_progress import RepoProgressEvent, RepoFileProgressEvent, RepoProgressCallback, RepoFileProgressCallback
|
||||
from typing import Generator, Iterable, TypeVar
|
||||
from exo.helpers import DEBUG
|
||||
from exo.inference.shard import Shard
|
||||
import aiofiles
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
async def get_local_snapshot_dir(repo_id: str, revision: str = "main") -> Optional[Path]:
|
||||
refs_dir = get_repo_root(repo_id)/"refs"
|
||||
refs_file = refs_dir/revision
|
||||
if await aios.path.exists(refs_file):
|
||||
async with aiofiles.open(refs_file, 'r') as f:
|
||||
commit_hash = (await f.read()).strip()
|
||||
snapshot_dir = get_repo_root(repo_id)/"snapshots"/commit_hash
|
||||
return snapshot_dir
|
||||
return None
|
||||
|
||||
|
||||
def filter_repo_objects(
|
||||
items: Iterable[T],
|
||||
*,
|
||||
@@ -48,14 +28,12 @@ def filter_repo_objects(
|
||||
ignore_patterns = [_add_wildcard_to_directories(p) for p in ignore_patterns]
|
||||
|
||||
if key is None:
|
||||
|
||||
def _identity(item: T) -> str:
|
||||
if isinstance(item, str):
|
||||
return item
|
||||
if isinstance(item, Path):
|
||||
return str(item)
|
||||
raise ValueError(f"Please provide `key` argument in `filter_repo_objects`: `{item}` is not a string.")
|
||||
|
||||
key = _identity
|
||||
|
||||
for item in items:
|
||||
@@ -66,22 +44,18 @@ def filter_repo_objects(
|
||||
continue
|
||||
yield item
|
||||
|
||||
|
||||
def _add_wildcard_to_directories(pattern: str) -> str:
|
||||
if pattern[-1] == "/":
|
||||
return pattern + "*"
|
||||
return pattern
|
||||
|
||||
|
||||
def get_hf_endpoint() -> str:
|
||||
return os.environ.get('HF_ENDPOINT', "https://huggingface.co")
|
||||
|
||||
|
||||
def get_hf_home() -> Path:
|
||||
"""Get the Hugging Face home directory."""
|
||||
return Path(os.environ.get("HF_HOME", Path.home()/".cache"/"huggingface"))
|
||||
|
||||
|
||||
async def get_hf_token():
|
||||
"""Retrieve the Hugging Face token from the user's HF_HOME directory."""
|
||||
token_path = get_hf_home()/"token"
|
||||
@@ -90,7 +64,6 @@ async def get_hf_token():
|
||||
return (await f.read()).strip()
|
||||
return None
|
||||
|
||||
|
||||
async def get_auth_headers():
|
||||
"""Get authentication headers if a token is available."""
|
||||
token = await get_hf_token()
|
||||
@@ -98,321 +71,6 @@ async def get_auth_headers():
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
return {}
|
||||
|
||||
|
||||
def get_repo_root(repo_id: str) -> Path:
|
||||
"""Get the root directory for a given repo ID in the Hugging Face cache."""
|
||||
sanitized_repo_id = str(repo_id).replace("/", "--")
|
||||
return get_hf_home()/"hub"/f"models--{sanitized_repo_id}"
|
||||
|
||||
async def move_models_to_hf(seed_dir: Union[str, Path]):
|
||||
"""Move model in resources folder of app to .cache/huggingface/hub"""
|
||||
source_dir = Path(seed_dir)
|
||||
dest_dir = get_hf_home()/"hub"
|
||||
await aios.makedirs(dest_dir, exist_ok=True)
|
||||
for path in source_dir.iterdir():
|
||||
if path.is_dir() and path.name.startswith("models--"):
|
||||
dest_path = dest_dir / path.name
|
||||
if await aios.path.exists(dest_path):
|
||||
print('Skipping moving model to .cache directory')
|
||||
else:
|
||||
try:
|
||||
await aios.rename(str(path), str(dest_path))
|
||||
except Exception as e:
|
||||
print(f'Error moving model to .cache: {e}')
|
||||
|
||||
|
||||
|
||||
async def fetch_file_list(session, repo_id, revision, path=""):
|
||||
api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
|
||||
url = f"{api_url}/{path}" if path else api_url
|
||||
|
||||
headers = await get_auth_headers()
|
||||
async with session.get(url, headers=headers) as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
files = []
|
||||
for item in data:
|
||||
if item["type"] == "file":
|
||||
files.append({"path": item["path"], "size": item["size"]})
|
||||
elif item["type"] == "directory":
|
||||
subfiles = await fetch_file_list(session, repo_id, revision, item["path"])
|
||||
files.extend(subfiles)
|
||||
return files
|
||||
else:
|
||||
raise Exception(f"Failed to fetch file list: {response.status}")
|
||||
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=60), retry=retry_if_exception_type((aiohttp.ClientError, asyncio.TimeoutError, aiohttp.ClientResponseError)), reraise=True
|
||||
)
|
||||
async def download_file(
|
||||
session: aiohttp.ClientSession, repo_id: str, revision: str, file_path: str, save_directory: str, progress_callback: Optional[RepoFileProgressCallback] = None, use_range_request: bool = True
|
||||
):
|
||||
base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
|
||||
url = urljoin(base_url, file_path)
|
||||
local_path = os.path.join(save_directory, file_path)
|
||||
|
||||
await aios.makedirs(os.path.dirname(local_path), exist_ok=True)
|
||||
|
||||
# Check if file already exists and get its size
|
||||
local_file_size = await aios.path.getsize(local_path) if await aios.path.exists(local_path) else 0
|
||||
|
||||
headers = await get_auth_headers()
|
||||
if use_range_request:
|
||||
headers["Range"] = f"bytes={local_file_size}-"
|
||||
|
||||
async with session.get(url, headers=headers) as response:
|
||||
total_size = int(response.headers.get('Content-Length', 0))
|
||||
downloaded_size = local_file_size
|
||||
downloaded_this_session = 0
|
||||
mode = 'ab' if use_range_request else 'wb'
|
||||
percentage = await get_file_download_percentage(
|
||||
session,
|
||||
repo_id,
|
||||
revision,
|
||||
file_path,
|
||||
Path(save_directory)
|
||||
)
|
||||
|
||||
if percentage == 100:
|
||||
if DEBUG >= 2: print(f"File already downloaded: {file_path}")
|
||||
if progress_callback:
|
||||
await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, total_size, 0, total_size, 0, timedelta(0), "complete"))
|
||||
return
|
||||
|
||||
if response.status == 200:
|
||||
# File doesn't support range requests or we're not using them, start from beginning
|
||||
mode = 'wb'
|
||||
downloaded_size = 0
|
||||
elif response.status == 206:
|
||||
# Partial content, resume download
|
||||
content_range = response.headers.get('Content-Range', '')
|
||||
try:
|
||||
total_size = int(content_range.split('/')[-1])
|
||||
except ValueError:
|
||||
if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
|
||||
return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
|
||||
elif response.status == 416:
|
||||
# Range not satisfiable, get the actual file size
|
||||
content_range = response.headers.get('Content-Range', '')
|
||||
try:
|
||||
total_size = int(content_range.split('/')[-1])
|
||||
if downloaded_size == total_size:
|
||||
if DEBUG >= 2: print(f"File fully downloaded on first pass: {file_path}")
|
||||
if progress_callback:
|
||||
await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
|
||||
return
|
||||
except ValueError:
|
||||
if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
|
||||
return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
|
||||
else:
|
||||
raise aiohttp.ClientResponseError(response.request_info, response.history, status=response.status, message=f"Failed to download {file_path}: {response.status}")
|
||||
|
||||
if downloaded_size == total_size:
|
||||
print(f"File already downloaded: {file_path}")
|
||||
if progress_callback:
|
||||
await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
|
||||
return
|
||||
|
||||
DOWNLOAD_CHUNK_SIZE = 32768
|
||||
start_time = datetime.now()
|
||||
async with aiofiles.open(local_path, mode) as f:
|
||||
async for chunk in response.content.iter_chunked(DOWNLOAD_CHUNK_SIZE):
|
||||
await f.write(chunk)
|
||||
downloaded_size += len(chunk)
|
||||
downloaded_this_session += len(chunk)
|
||||
if progress_callback and total_size:
|
||||
elapsed_time = (datetime.now() - start_time).total_seconds()
|
||||
speed = int(downloaded_this_session/elapsed_time) if elapsed_time > 0 else 0
|
||||
remaining_size = total_size - downloaded_size
|
||||
eta = timedelta(seconds=remaining_size/speed) if speed > 0 else timedelta(0)
|
||||
status = "in_progress" if downloaded_size < total_size else "complete"
|
||||
if DEBUG >= 8: print(f"HF repo file download progress: {file_path=} {elapsed_time=} {speed=} Downloaded={downloaded_size}/{total_size} {remaining_size=} {eta=} {status=}")
|
||||
await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, speed, eta, status))
|
||||
if DEBUG >= 2: print(f"Downloaded: {file_path}")
|
||||
|
||||
|
||||
async def resolve_revision_to_commit_hash(repo_id: str, revision: str) -> str:
  repo_root = get_repo_root(repo_id)
  refs_dir = repo_root/"refs"
  refs_file = refs_dir/revision

  # Check if we have a cached commit hash
  if await aios.path.exists(refs_file):
    async with aiofiles.open(refs_file, 'r') as f:
      commit_hash = (await f.read()).strip()
      if DEBUG >= 2: print(f"Commit hash is already cached at {refs_file}: {commit_hash}")
      return commit_hash

  # Fetch the commit hash for the given revision
  async with aiohttp.ClientSession() as session:
    api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/revision/{revision}"
    headers = await get_auth_headers()
    async with session.get(api_url, headers=headers) as response:
      if response.status != 200:
        raise Exception(f"Failed to fetch revision info from {api_url}: {response.status}")
      revision_info = await response.json()
      commit_hash = revision_info['sha']

  # Cache the commit hash
  await aios.makedirs(refs_dir, exist_ok=True)
  async with aiofiles.open(refs_file, 'w') as f:
    await f.write(commit_hash)

  return commit_hash

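The refs/ and snapshots/ layout written here mirrors the Hugging Face hub cache: refs/<revision> stores a commit hash, and that hash names the snapshots/<commit_hash> directory the files actually live in. An illustrative lookup (the repo id is hypothetical):

# Illustrative only: how a cached ref resolves to a snapshot directory.
repo_root = get_repo_root("my-org/my-model")        # e.g. .../hub/models--my-org--my-model
commit_hash = (repo_root/"refs"/"main").read_text().strip()
snapshot_dir = repo_root/"snapshots"/commit_hash    # files for that exact revision live here
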
async def download_repo_files(
|
||||
repo_id: str,
|
||||
revision: str = "main",
|
||||
progress_callback: Optional[RepoProgressCallback] = None,
|
||||
allow_patterns: Optional[Union[List[str], str]] = None,
|
||||
ignore_patterns: Optional[Union[List[str], str]] = None,
|
||||
max_parallel_downloads: int = 4
|
||||
) -> Path:
|
||||
repo_root = get_repo_root(repo_id)
|
||||
snapshots_dir = repo_root/"snapshots"
|
||||
cachedreqs_dir = repo_root/"cachedreqs"
|
||||
|
||||
# Ensure directories exist
|
||||
await aios.makedirs(snapshots_dir, exist_ok=True)
|
||||
await aios.makedirs(cachedreqs_dir, exist_ok=True)
|
||||
|
||||
# Resolve revision to commit hash
|
||||
commit_hash = await resolve_revision_to_commit_hash(repo_id, revision)
|
||||
|
||||
# Set up the snapshot directory
|
||||
snapshot_dir = snapshots_dir/commit_hash
|
||||
await aios.makedirs(snapshot_dir, exist_ok=True)
|
||||
|
||||
# Set up the cached file list directory
|
||||
cached_file_list_dir = cachedreqs_dir/commit_hash
|
||||
await aios.makedirs(cached_file_list_dir, exist_ok=True)
|
||||
cached_file_list_path = cached_file_list_dir/"fetch_file_list.json"
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# Check if we have a cached file list
|
||||
if await aios.path.exists(cached_file_list_path):
|
||||
async with aiofiles.open(cached_file_list_path, 'r') as f:
|
||||
file_list = json.loads(await f.read())
|
||||
if DEBUG >= 2: print(f"Using cached file list from {cached_file_list_path}")
|
||||
else:
|
||||
file_list = await fetch_file_list(session, repo_id, revision)
|
||||
# Cache the file list
|
||||
async with aiofiles.open(cached_file_list_path, 'w') as f:
|
||||
await f.write(json.dumps(file_list))
|
||||
if DEBUG >= 2: print(f"Cached file list at {cached_file_list_path}")
|
||||
|
||||
filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, key=lambda x: x["path"]))
|
||||
total_files = len(filtered_file_list)
|
||||
total_bytes = sum(file["size"] for file in filtered_file_list)
|
||||
file_progress: Dict[str, RepoFileProgressEvent] = {
|
||||
file["path"]: RepoFileProgressEvent(repo_id, revision, file["path"], 0, 0, file["size"], 0, timedelta(0), "not_started")
|
||||
for file in filtered_file_list
|
||||
}
|
||||
start_time = datetime.now()
|
||||
|
||||
async def download_with_progress(file_info, progress_state):
|
||||
local_path = snapshot_dir/file_info["path"]
|
||||
if await aios.path.exists(local_path) and (await aios.stat(local_path)).st_size == file_info["size"]:
|
||||
if DEBUG >= 2: print(f"File already fully downloaded: {file_info['path']}")
|
||||
progress_state['completed_files'] += 1
|
||||
progress_state['downloaded_bytes'] += file_info["size"]
|
||||
file_progress[file_info["path"]] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], 0, file_info["size"], 0, timedelta(0), "complete")
|
||||
if progress_callback:
|
||||
elapsed_time = (datetime.now() - start_time).total_seconds()
|
||||
overall_speed = int(progress_state['downloaded_bytes_this_session']/elapsed_time) if elapsed_time > 0 else 0
|
||||
remaining_bytes = total_bytes - progress_state['downloaded_bytes']
|
||||
overall_eta = timedelta(seconds=remaining_bytes/overall_speed) if overall_speed > 0 else timedelta(seconds=0)
|
||||
status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
|
||||
await progress_callback(
|
||||
RepoProgressEvent(
|
||||
repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed,
|
||||
overall_eta, file_progress, status
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
async def file_progress_callback(event: RepoFileProgressEvent):
|
||||
progress_state['downloaded_bytes'] += event.downloaded - file_progress[event.file_path].downloaded
|
||||
progress_state['downloaded_bytes_this_session'] += event.downloaded_this_session - file_progress[event.file_path].downloaded_this_session
|
||||
file_progress[event.file_path] = event
|
||||
if progress_callback:
|
||||
elapsed_time = (datetime.now() - start_time).total_seconds()
|
||||
overall_speed = int(progress_state['downloaded_bytes_this_session']/elapsed_time) if elapsed_time > 0 else 0
|
||||
remaining_bytes = total_bytes - progress_state['downloaded_bytes']
|
||||
overall_eta = timedelta(seconds=remaining_bytes/overall_speed) if overall_speed > 0 else timedelta(seconds=0)
|
||||
status = "in_progress" if progress_state['downloaded_bytes'] < total_bytes else "complete"
|
||||
await progress_callback(
|
||||
RepoProgressEvent(
|
||||
repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed,
|
||||
overall_eta, file_progress, status
|
||||
)
|
||||
)
|
||||
|
||||
await download_file(session, repo_id, revision, file_info["path"], snapshot_dir, file_progress_callback)
|
||||
progress_state['completed_files'] += 1
|
||||
file_progress[
|
||||
file_info["path"]
|
||||
] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], file_progress[file_info["path"]].downloaded_this_session, file_info["size"], 0, timedelta(0), "complete")
|
||||
if progress_callback:
|
||||
elapsed_time = (datetime.now() - start_time).total_seconds()
|
||||
overall_speed = int(progress_state['downloaded_bytes_this_session']/elapsed_time) if elapsed_time > 0 else 0
|
||||
remaining_bytes = total_bytes - progress_state['downloaded_bytes']
|
||||
overall_eta = timedelta(seconds=remaining_bytes/overall_speed) if overall_speed > 0 else timedelta(seconds=0)
|
||||
status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
|
||||
await progress_callback(
|
||||
RepoProgressEvent(
|
||||
repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed,
|
||||
overall_eta, file_progress, status
|
||||
)
|
||||
)
|
||||
|
||||
progress_state = {'completed_files': 0, 'downloaded_bytes': 0, 'downloaded_bytes_this_session': 0}
|
||||
|
||||
semaphore = asyncio.Semaphore(max_parallel_downloads)
|
||||
|
||||
async def download_with_semaphore(file_info):
|
||||
async with semaphore:
|
||||
await download_with_progress(file_info, progress_state)
|
||||
|
||||
tasks = [asyncio.create_task(download_with_semaphore(file_info)) for file_info in filtered_file_list]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
return snapshot_dir
|
||||
|
||||
|
||||
async def get_weight_map(repo_id: str, revision: str = "main") -> Optional[Dict[str, str]]:
|
||||
"""
|
||||
Retrieve the weight map from the model.safetensors.index.json file.
|
||||
|
||||
Args:
|
||||
repo_id (str): The Hugging Face repository ID.
|
||||
revision (str): The revision of the repository to use.
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, str]]: The weight map if it exists, otherwise None.
|
||||
"""
|
||||
|
||||
# Download the index file
|
||||
await download_repo_files(repo_id=repo_id, revision=revision, allow_patterns="model.safetensors.index.json")
|
||||
|
||||
# Check if the file exists
|
||||
repo_root = get_repo_root(repo_id)
|
||||
commit_hash = await resolve_revision_to_commit_hash(repo_id, revision)
|
||||
snapshot_dir = repo_root/"snapshots"/commit_hash
|
||||
index_file = next((f for f in await aios.listdir(snapshot_dir) if f.endswith("model.safetensors.index.json")), None)
|
||||
|
||||
if index_file:
|
||||
index_file_path = snapshot_dir/index_file
|
||||
if await aios.path.exists(index_file_path):
|
||||
async with aiofiles.open(index_file_path, 'r') as f:
|
||||
index_data = json.loads(await f.read())
|
||||
return index_data.get("weight_map")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def extract_layer_num(tensor_name: str) -> Optional[int]:
|
||||
# This is a simple example and might need to be adjusted based on the actual naming convention
|
||||
parts = tensor_name.split('.')
|
||||
@@ -421,7 +79,6 @@ def extract_layer_num(tensor_name: str) -> Optional[int]:
|
||||
return int(part)
|
||||
return None
|
||||
|
||||
|
||||
def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
|
||||
default_patterns = set(["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt"])
|
||||
shard_specific_patterns = set()
|
||||
@@ -437,67 +94,5 @@ def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
|
||||
shard_specific_patterns.add(sorted_file_names[-1])
|
||||
else:
|
||||
shard_specific_patterns = set(["*.safetensors"])
|
||||
if DEBUG >= 2: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
|
||||
if DEBUG >= 3: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
|
||||
return list(default_patterns | shard_specific_patterns)
|
||||
|
||||
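As a concrete illustration of get_allow_patterns (the weight map and shard here are hypothetical): a shard covering only the early layers should yield the default metadata patterns plus just the safetensors file(s) that hold those layers.

# Hypothetical example, not taken from a real model card.
weight_map = {
  "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
  "model.layers.15.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
}
shard = Shard(model_id="example-model", start_layer=0, end_layer=7, n_layers=16)
print(get_allow_patterns(weight_map, shard))
# expected: ['*.json', '*.py', 'tokenizer.model', '*.tiktoken', '*.txt', 'model-00001-of-00002.safetensors'] (order may vary)
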
async def get_file_download_percentage(
|
||||
session: aiohttp.ClientSession,
|
||||
repo_id: str,
|
||||
revision: str,
|
||||
file_path: str,
|
||||
snapshot_dir: Path,
|
||||
) -> float:
|
||||
"""
|
||||
Calculate the download percentage for a file by comparing local and remote sizes.
|
||||
"""
|
||||
try:
|
||||
local_path = snapshot_dir / file_path
|
||||
if not await aios.path.exists(local_path):
|
||||
return 0
|
||||
|
||||
# Get local file size first
|
||||
local_size = await aios.path.getsize(local_path)
|
||||
if local_size == 0:
|
||||
return 0
|
||||
|
||||
# Check remote size
|
||||
base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
|
||||
url = urljoin(base_url, file_path)
|
||||
headers = await get_auth_headers()
|
||||
|
||||
# Use HEAD request with redirect following for all files
|
||||
async with session.head(url, headers=headers, allow_redirects=True) as response:
|
||||
if response.status != 200:
|
||||
if DEBUG >= 2:
|
||||
print(f"Failed to get remote file info for {file_path}: {response.status}")
|
||||
return 0
|
||||
|
||||
remote_size = int(response.headers.get('Content-Length', 0))
|
||||
|
||||
if remote_size == 0:
|
||||
if DEBUG >= 2:
|
||||
print(f"Remote size is 0 for {file_path}")
|
||||
return 0
|
||||
|
||||
# Only return 100% if sizes match exactly
|
||||
if local_size == remote_size:
|
||||
return 100.0
|
||||
|
||||
# Calculate percentage based on sizes
|
||||
return (local_size / remote_size) * 100 if remote_size > 0 else 0
|
||||
|
||||
except Exception as e:
|
||||
if DEBUG >= 2:
|
||||
print(f"Error checking file download status for {file_path}: {e}")
|
||||
return 0
|
||||
|
||||
async def has_hf_home_read_access() -> bool:
|
||||
hf_home = get_hf_home()
|
||||
try: return await aios.access(hf_home, os.R_OK)
|
||||
except OSError: return False
|
||||
|
||||
async def has_hf_home_write_access() -> bool:
|
||||
hf_home = get_hf_home()
|
||||
try: return await aios.access(hf_home, os.W_OK)
|
||||
except OSError: return False
|
||||
|
||||
|
||||
@@ -1,167 +0,0 @@
|
||||
import asyncio
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple, Optional, Union
|
||||
from exo.inference.shard import Shard
|
||||
from exo.download.shard_download import ShardDownloader
|
||||
from exo.download.download_progress import RepoProgressEvent
|
||||
from exo.download.hf.hf_helpers import (
|
||||
download_repo_files, RepoProgressEvent, get_weight_map,
|
||||
get_allow_patterns, get_repo_root, fetch_file_list,
|
||||
get_local_snapshot_dir, get_file_download_percentage,
|
||||
filter_repo_objects
|
||||
)
|
||||
from exo.helpers import AsyncCallbackSystem, DEBUG
|
||||
from exo.models import model_cards, get_repo
|
||||
import aiohttp
|
||||
from aiofiles import os as aios
|
||||
|
||||
|
||||
class HFShardDownloader(ShardDownloader):
|
||||
def __init__(self, quick_check: bool = False, max_parallel_downloads: int = 4):
|
||||
self.quick_check = quick_check
|
||||
self.max_parallel_downloads = max_parallel_downloads
|
||||
self.active_downloads: Dict[Shard, asyncio.Task] = {}
|
||||
self.completed_downloads: Dict[Shard, Path] = {}
|
||||
self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()
|
||||
self.current_shard: Optional[Shard] = None
|
||||
self.current_repo_id: Optional[str] = None
|
||||
self.revision: str = "main"
|
||||
|
||||
async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
|
||||
self.current_shard = shard
|
||||
self.current_repo_id = get_repo(shard.model_id, inference_engine_name)
|
||||
repo_name = get_repo(shard.model_id, inference_engine_name)
|
||||
if shard in self.completed_downloads:
|
||||
return self.completed_downloads[shard]
|
||||
if self.quick_check:
|
||||
repo_root = get_repo_root(repo_name)
|
||||
snapshots_dir = repo_root/"snapshots"
|
||||
if snapshots_dir.exists():
|
||||
visible_dirs = [d for d in snapshots_dir.iterdir() if not d.name.startswith('.')]
|
||||
if visible_dirs:
|
||||
most_recent_dir = max(visible_dirs, key=lambda x: x.stat().st_mtime)
|
||||
return most_recent_dir
|
||||
|
||||
# If a download on this shard is already in progress, keep that one
|
||||
for active_shard in self.active_downloads:
|
||||
if active_shard == shard:
|
||||
if DEBUG >= 2: print(f"Download already in progress for {shard}. Keeping that one.")
|
||||
return await self.active_downloads[shard]
|
||||
|
||||
# Cancel any downloads for this model_id on a different shard
|
||||
existing_active_shards = [active_shard for active_shard in self.active_downloads.keys() if active_shard.model_id == shard.model_id]
|
||||
for active_shard in existing_active_shards:
|
||||
if DEBUG >= 2: print(f"Cancelling download for {active_shard} (replacing with {shard})")
|
||||
task = self.active_downloads[active_shard]
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass # This is expected when cancelling a task
|
||||
except Exception as e:
|
||||
if DEBUG >= 2: print(f"Error in cancelling download {active_shard}: {e}")
|
||||
traceback.print_exc()
|
||||
self.active_downloads = {active_shard: task for active_shard, task in self.active_downloads.items() if active_shard.model_id != shard.model_id}
|
||||
|
||||
# Start new download
|
||||
download_task = asyncio.create_task(self._download_shard(shard, repo_name))
|
||||
self.active_downloads[shard] = download_task
|
||||
try:
|
||||
path = await download_task
|
||||
self.completed_downloads[shard] = path
|
||||
return path
|
||||
finally:
|
||||
# Ensure the task is removed even if an exception occurs
|
||||
print(f"Removing download task for {shard}: {shard in self.active_downloads}")
|
||||
if shard in self.active_downloads:
|
||||
self.active_downloads.pop(shard)
|
||||
|
||||
async def _download_shard(self, shard: Shard, repo_name: str) -> Path:
|
||||
async def wrapped_progress_callback(event: RepoProgressEvent):
|
||||
self._on_progress.trigger_all(shard, event)
|
||||
|
||||
weight_map = await get_weight_map(repo_name)
|
||||
allow_patterns = get_allow_patterns(weight_map, shard)
|
||||
|
||||
return await download_repo_files(repo_name, progress_callback=wrapped_progress_callback, allow_patterns=allow_patterns, max_parallel_downloads=self.max_parallel_downloads)
|
||||
|
||||
@property
|
||||
def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
|
||||
return self._on_progress
|
||||
|
||||
async def get_shard_download_status(self) -> Optional[Dict[str, Union[float, int]]]:
|
||||
if not self.current_shard or not self.current_repo_id:
|
||||
if DEBUG >= 2:
|
||||
print(f"No current shard or repo_id set: {self.current_shard=} {self.current_repo_id=}")
|
||||
return None
|
||||
|
||||
try:
|
||||
# If no snapshot directory exists, return None - no need to check remote files
|
||||
snapshot_dir = await get_local_snapshot_dir(self.current_repo_id, self.revision)
|
||||
if not snapshot_dir:
|
||||
if DEBUG >= 2:
|
||||
print(f"No snapshot directory found for {self.current_repo_id}")
|
||||
return None
|
||||
|
||||
# Get the weight map to know what files we need
|
||||
weight_map = await get_weight_map(self.current_repo_id, self.revision)
|
||||
if not weight_map:
|
||||
if DEBUG >= 2:
|
||||
print(f"No weight map found for {self.current_repo_id}")
|
||||
return None
|
||||
|
||||
# Get all files needed for this shard
|
||||
patterns = get_allow_patterns(weight_map, self.current_shard)
|
||||
|
||||
# Check download status for all relevant files
|
||||
status = {}
|
||||
total_bytes = 0
|
||||
downloaded_bytes = 0
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
file_list = await fetch_file_list(session, self.current_repo_id, self.revision)
|
||||
relevant_files = list(
|
||||
filter_repo_objects(
|
||||
file_list, allow_patterns=patterns, key=lambda x: x["path"]))
|
||||
|
||||
for file in relevant_files:
|
||||
file_size = file["size"]
|
||||
total_bytes += file_size
|
||||
|
||||
percentage = await get_file_download_percentage(
|
||||
session,
|
||||
self.current_repo_id,
|
||||
self.revision,
|
||||
file["path"],
|
||||
snapshot_dir,
|
||||
)
|
||||
status[file["path"]] = percentage
|
||||
downloaded_bytes += (file_size * (percentage / 100))
|
||||
|
||||
# Add overall progress weighted by file size
|
||||
if total_bytes > 0:
|
||||
status["overall"] = (downloaded_bytes / total_bytes) * 100
|
||||
else:
|
||||
status["overall"] = 0
|
||||
|
||||
# Add total size in bytes
|
||||
status["total_size"] = total_bytes
|
||||
if status["overall"] != 100:
|
||||
status["total_downloaded"] = downloaded_bytes
|
||||
|
||||
|
||||
if DEBUG >= 2:
|
||||
print(f"Download calculation for {self.current_repo_id}:")
|
||||
print(f"Total bytes: {total_bytes}")
|
||||
print(f"Downloaded bytes: {downloaded_bytes}")
|
||||
for file in relevant_files:
|
||||
print(f"File {file['path']}: size={file['size']}, percentage={status[file['path']]}")
|
||||
|
||||
return status
|
||||
|
||||
except Exception as e:
|
||||
if DEBUG >= 2:
|
||||
print(f"Error getting shard download status: {e}")
|
||||
traceback.print_exc()
|
||||
return None
|
||||
exo/download/new_shard_download.py (new file, 307 lines)
@@ -0,0 +1,307 @@
|
||||
from exo.inference.shard import Shard
from exo.models import get_repo
from pathlib import Path
from exo.download.hf.hf_helpers import get_hf_endpoint, get_auth_headers, filter_repo_objects, get_allow_patterns
from exo.download.shard_download import ShardDownloader
from exo.download.download_progress import RepoProgressEvent, RepoFileProgressEvent
from exo.helpers import AsyncCallbackSystem, DEBUG
from exo.models import get_supported_models, build_full_shard
import os
import aiofiles.os as aios
import aiohttp
import aiofiles
from urllib.parse import urljoin
from typing import Callable, Union, Tuple, Dict, List, Optional, Literal, AsyncIterator
import time
from datetime import timedelta
import asyncio
import json
import traceback
import shutil
import tempfile
import hashlib

def exo_home() -> Path:
  return Path(os.environ.get("EXO_HOME", Path.home()/".cache"/"exo"))

def exo_tmp() -> Path:
  return Path(tempfile.gettempdir())/"exo"

async def ensure_exo_home() -> Path:
  await aios.makedirs(exo_home(), exist_ok=True)
  return exo_home()

async def ensure_exo_tmp() -> Path:
  await aios.makedirs(exo_tmp(), exist_ok=True)
  return exo_tmp()

async def has_exo_home_read_access() -> bool:
  try: return await aios.access(exo_home(), os.R_OK)
  except OSError: return False

async def has_exo_home_write_access() -> bool:
  try: return await aios.access(exo_home(), os.W_OK)
  except OSError: return False

async def ensure_downloads_dir() -> Path:
  downloads_dir = exo_home()/"downloads"
  await aios.makedirs(downloads_dir, exist_ok=True)
  return downloads_dir

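Usage note (illustrative): the cache root can be redirected with the EXO_HOME environment variable before any downloads start. The path below is hypothetical.

# Hedged sketch: point exo's download cache at a custom location.
os.environ["EXO_HOME"] = "/mnt/big-disk/exo-cache"
print(exo_home())               # -> /mnt/big-disk/exo-cache
print(exo_home()/"downloads")   # where ensure_downloads_dir() will create the cache
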
async def delete_model(model_id: str, inference_engine_name: str) -> bool:
  repo_id = get_repo(model_id, inference_engine_name)
  model_dir = await ensure_downloads_dir()/repo_id.replace("/", "--")
  if not await aios.path.exists(model_dir): return False
  await asyncio.to_thread(shutil.rmtree, model_dir, ignore_errors=False)
  return True

async def seed_models(seed_dir: Union[str, Path]):
  """Move models bundled in the app's resources folder into the exo downloads directory."""
  source_dir = Path(seed_dir)
  dest_dir = await ensure_downloads_dir()
  for path in source_dir.iterdir():
    if path.is_dir() and path.name.startswith("models--"):
      dest_path = dest_dir/path.name
      if await aios.path.exists(dest_path): print('Skipping moving model to .cache directory')
      else:
        try: await aios.rename(str(path), str(dest_path))
        except:
          print(f"Error seeding model {path} to {dest_path}")
          traceback.print_exc()

async def fetch_file_list_with_cache(repo_id: str, revision: str = "main") -> List[Dict[str, Union[str, int]]]:
  cache_file = (await ensure_exo_tmp())/f"{repo_id.replace('/', '--')}--{revision}--file_list.json"
  if await aios.path.exists(cache_file):
    async with aiofiles.open(cache_file, 'r') as f: return json.loads(await f.read())
  file_list = await fetch_file_list_with_retry(repo_id, revision)
  await aios.makedirs(cache_file.parent, exist_ok=True)
  async with aiofiles.open(cache_file, 'w') as f: await f.write(json.dumps(file_list))
  return file_list

async def fetch_file_list_with_retry(repo_id: str, revision: str = "main", path: str = "") -> List[Dict[str, Union[str, int]]]:
  n_attempts = 30
  for attempt in range(n_attempts):
    try: return await _fetch_file_list(repo_id, revision, path)
    except Exception as e:
      if attempt == n_attempts - 1: raise e
      await asyncio.sleep(min(8, 0.1 * (2 ** attempt)))

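The retry delay above is a capped exponential backoff: 0.1 * 2**attempt, clamped to 8 seconds, so early attempts wait 0.1 s, 0.2 s, 0.4 s, ... and later attempts wait the full 8 s. A quick illustrative check of the schedule:

# Sleep schedule used by the retry helpers (illustrative computation only).
delays = [min(8, 0.1 * (2 ** attempt)) for attempt in range(30)]
print(delays[:8])        # [0.1, 0.2, 0.4, 0.8, 1.6, 3.2, 6.4, 8]
print(sum(delays[:-1]))  # ~188.7 s of sleeping in the worst case (the last attempt raises instead of sleeping)
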
async def _fetch_file_list(repo_id: str, revision: str = "main", path: str = "") -> List[Dict[str, Union[str, int]]]:
  api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
  url = f"{api_url}/{path}" if path else api_url

  headers = await get_auth_headers()
  async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=30, connect=10, sock_read=30, sock_connect=10)) as session:
    async with session.get(url, headers=headers) as response:
      if response.status == 200:
        data = await response.json()
        files = []
        for item in data:
          if item["type"] == "file":
            files.append({"path": item["path"], "size": item["size"]})
          elif item["type"] == "directory":
            subfiles = await _fetch_file_list(repo_id, revision, item["path"])
            files.extend(subfiles)
        return files
      else:
        raise Exception(f"Failed to fetch file list: {response.status}")

async def calc_hash(path: Path, type: Literal["sha1", "sha256"] = "sha1") -> str:
  hash = hashlib.sha1() if type == "sha1" else hashlib.sha256()
  if type == "sha1":
    header = f"blob {(await aios.stat(path)).st_size}\0".encode()
    hash.update(header)
  async with aiofiles.open(path, 'rb') as f:
    while chunk := await f.read(8 * 1024 * 1024):
      hash.update(chunk)
  return hash.hexdigest()

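The sha1 branch reproduces git's blob hashing (sha1 over "blob <size>\0" followed by the file contents), which appears to be why the result can be compared against a git-style ETag; the sha256 branch is a plain digest of the contents, as used for LFS files. An illustrative cross-check with hashlib (the file name is hypothetical):

# Hedged sketch: calc_hash(path, "sha1") should equal git's blob hash of the same file.
data = Path("config.json").read_bytes()
expected = hashlib.sha1(f"blob {len(data)}\0".encode() + data).hexdigest()
# assert await calc_hash(Path("config.json"), "sha1") == expected
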
async def file_meta(repo_id: str, revision: str, path: str) -> Tuple[int, str]:
  url = urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", path)
  headers = await get_auth_headers()
  async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as session:
    async with session.head(url, headers=headers) as r:
      content_length = int(r.headers.get('x-linked-size') or r.headers.get('content-length') or 0)
      etag = r.headers.get('X-Linked-ETag') or r.headers.get('ETag') or r.headers.get('Etag')
      assert content_length > 0, f"No content length for {url}"
      assert etag is not None, f"No remote hash for {url}"
      if (etag[0] == '"' and etag[-1] == '"') or (etag[0] == "'" and etag[-1] == "'"): etag = etag[1:-1]
      return content_length, etag

async def download_file_with_retry(repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Callable[[int, int], None] = lambda _, __: None) -> Path:
  n_attempts = 30
  for attempt in range(n_attempts):
    try: return await _download_file(repo_id, revision, path, target_dir, on_progress)
    except Exception as e:
      if isinstance(e, FileNotFoundError) or attempt == n_attempts - 1: raise e
      print(f"Download error on attempt {attempt}/{n_attempts} for {repo_id=} {revision=} {path=} {target_dir=}")
      traceback.print_exc()
      await asyncio.sleep(min(8, 0.1 * (2 ** attempt)))

async def _download_file(repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Callable[[int, int], None] = lambda _, __: None) -> Path:
  if await aios.path.exists(target_dir/path): return target_dir/path
  await aios.makedirs((target_dir/path).parent, exist_ok=True)
  length, etag = await file_meta(repo_id, revision, path)
  remote_hash = etag[:-5] if etag.endswith("-gzip") else etag
  partial_path = target_dir/f"{path}.partial"
  resume_byte_pos = (await aios.stat(partial_path)).st_size if (await aios.path.exists(partial_path)) else None
  if resume_byte_pos != length:
    url = urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", path)
    headers = await get_auth_headers()
    if resume_byte_pos: headers['Range'] = f'bytes={resume_byte_pos}-'
    n_read = resume_byte_pos or 0
    async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as session:
      async with session.get(url, headers=headers, timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as r:
        if r.status == 404: raise FileNotFoundError(f"File not found: {url}")
        assert r.status in [200, 206], f"Failed to download {path} from {url}: {r.status}"
        async with aiofiles.open(partial_path, 'ab' if resume_byte_pos else 'wb') as f:
          while chunk := await r.content.read(8 * 1024 * 1024): on_progress(n_read := n_read + await f.write(chunk), length)

  final_hash = await calc_hash(partial_path, type="sha256" if len(remote_hash) == 64 else "sha1")
  integrity = final_hash == remote_hash
  if not integrity:
    try: await aios.remove(partial_path)
    except Exception as e: print(f"Error removing partial file {partial_path}: {e}")
    raise Exception(f"Downloaded file {target_dir/path} has hash {final_hash} but remote hash is {remote_hash}")
  await aios.rename(partial_path, target_dir/path)
  return target_dir/path

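A minimal usage sketch of the two helpers above (repo and file names are hypothetical): the callback receives cumulative bytes written and the expected total, and the file only appears under its final name after the hash check passes.

# Hedged usage sketch of download_file_with_retry with a simple progress printout.
async def example_download():
  target_dir = (await ensure_downloads_dir())/"my-org--my-model"
  def show(curr: int, total: int): print(f"\r{curr}/{total} bytes", end="")
  path = await download_file_with_retry("my-org/my-model", "main", "config.json", target_dir, show)
  print(f"\nsaved to {path}")
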
def calculate_repo_progress(shard: Shard, repo_id: str, revision: str, file_progress: Dict[str, RepoFileProgressEvent], all_start_time: float) -> RepoProgressEvent:
  all_total_bytes = sum([p.total for p in file_progress.values()])
  all_downloaded_bytes = sum([p.downloaded for p in file_progress.values()])
  all_downloaded_bytes_this_session = sum([p.downloaded_this_session for p in file_progress.values()])
  elapsed_time = time.time() - all_start_time
  all_speed = all_downloaded_bytes_this_session / elapsed_time if elapsed_time > 0 else 0
  all_eta = timedelta(seconds=(all_total_bytes - all_downloaded_bytes) / all_speed) if all_speed > 0 else timedelta(seconds=0)
  status = "complete" if all(p.status == "complete" for p in file_progress.values()) else "in_progress" if any(p.status == "in_progress" for p in file_progress.values()) else "not_started"
  return RepoProgressEvent(shard, repo_id, revision, len([p for p in file_progress.values() if p.downloaded == p.total]), len(file_progress), all_downloaded_bytes, all_downloaded_bytes_this_session, all_total_bytes, all_speed, all_eta, file_progress, status)

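The aggregation is plain arithmetic over the per-file events; a worked example with hypothetical numbers:

# Two 100 MB files, 150 MB already on disk, 50 MB of it fetched this session over 10 s.
all_total_bytes = 200_000_000
all_downloaded_bytes = 150_000_000
all_downloaded_bytes_this_session = 50_000_000
elapsed_time = 10.0
all_speed = all_downloaded_bytes_this_session / elapsed_time    # 5_000_000 B/s
eta_seconds = (all_total_bytes - all_downloaded_bytes) / all_speed  # 10.0 s for the remaining 50 MB
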
async def get_weight_map(repo_id: str, revision: str = "main") -> Dict[str, str]:
  target_dir = (await ensure_exo_tmp())/repo_id.replace("/", "--")
  index_file = await download_file_with_retry(repo_id, revision, "model.safetensors.index.json", target_dir)
  async with aiofiles.open(index_file, 'r') as f: index_data = json.loads(await f.read())
  return index_data.get("weight_map")

async def resolve_allow_patterns(shard: Shard, inference_engine_classname: str) -> List[str]:
  try:
    weight_map = await get_weight_map(get_repo(shard.model_id, inference_engine_classname))
    return get_allow_patterns(weight_map, shard)
  except:
    if DEBUG >= 1: print(f"Error getting weight map for {shard.model_id=} and inference engine {inference_engine_classname}")
    if DEBUG >= 1: traceback.print_exc()
    return ["*"]

async def get_downloaded_size(path: Path) -> int:
  partial_path = path.with_suffix(path.suffix + ".partial")
  if await aios.path.exists(path): return (await aios.stat(path)).st_size
  if await aios.path.exists(partial_path): return (await aios.stat(partial_path)).st_size
  return 0

async def download_shard(shard: Shard, inference_engine_classname: str, on_progress: AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]], max_parallel_downloads: int = 8, skip_download: bool = False) -> tuple[Path, RepoProgressEvent]:
|
||||
if DEBUG >= 2 and not skip_download: print(f"Downloading {shard.model_id=} for {inference_engine_classname}")
|
||||
repo_id = get_repo(shard.model_id, inference_engine_classname)
|
||||
revision = "main"
|
||||
target_dir = await ensure_downloads_dir()/repo_id.replace("/", "--")
|
||||
if not skip_download: await aios.makedirs(target_dir, exist_ok=True)
|
||||
|
||||
if repo_id is None:
|
||||
raise ValueError(f"No repo found for {shard.model_id=} and inference engine {inference_engine_classname}")
|
||||
|
||||
allow_patterns = await resolve_allow_patterns(shard, inference_engine_classname)
|
||||
if DEBUG >= 2: print(f"Downloading {shard.model_id=} with {allow_patterns=}")
|
||||
|
||||
all_start_time = time.time()
|
||||
file_list = await fetch_file_list_with_cache(repo_id, revision)
|
||||
filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, key=lambda x: x["path"]))
|
||||
file_progress: Dict[str, RepoFileProgressEvent] = {}
|
||||
def on_progress_wrapper(file: dict, curr_bytes: int, total_bytes: int):
|
||||
start_time = file_progress[file["path"]].start_time if file["path"] in file_progress else time.time()
|
||||
downloaded_this_session = file_progress[file["path"]].downloaded_this_session + (curr_bytes - file_progress[file["path"]].downloaded) if file["path"] in file_progress else curr_bytes
|
||||
speed = downloaded_this_session / (time.time() - start_time) if time.time() - start_time > 0 else 0
|
||||
eta = timedelta(seconds=(total_bytes - curr_bytes) / speed) if speed > 0 else timedelta(seconds=0)
|
||||
file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], curr_bytes, downloaded_this_session, total_bytes, speed, eta, "complete" if curr_bytes == total_bytes else "in_progress", start_time)
|
||||
on_progress.trigger_all(shard, calculate_repo_progress(shard, repo_id, revision, file_progress, all_start_time))
|
||||
if DEBUG >= 6: print(f"Downloading {file['path']} {curr_bytes}/{total_bytes} {speed} {eta}")
|
||||
for file in filtered_file_list:
|
||||
downloaded_bytes = await get_downloaded_size(target_dir/file["path"])
|
||||
file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], downloaded_bytes, 0, file["size"], 0, timedelta(0), "complete" if downloaded_bytes == file["size"] else "not_started", time.time())
|
||||
|
||||
semaphore = asyncio.Semaphore(max_parallel_downloads)
|
||||
async def download_with_semaphore(file):
|
||||
async with semaphore:
|
||||
await download_file_with_retry(repo_id, revision, file["path"], target_dir, lambda curr_bytes, total_bytes: on_progress_wrapper(file, curr_bytes, total_bytes))
|
||||
if not skip_download: await asyncio.gather(*[download_with_semaphore(file) for file in filtered_file_list])
|
||||
final_repo_progress = calculate_repo_progress(shard, repo_id, revision, file_progress, all_start_time)
|
||||
on_progress.trigger_all(shard, final_repo_progress)
|
||||
if gguf := next((f for f in filtered_file_list if f["path"].endswith(".gguf")), None):
|
||||
return target_dir/gguf["path"], final_repo_progress
|
||||
else:
|
||||
return target_dir, final_repo_progress
|
||||
|
||||
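For orientation, download_shard can also be driven directly; a hedged sketch that mirrors how NewShardDownloader below calls it (the model id and engine name follow the test file further down, the callback key is arbitrary):

# Illustrative direct use of download_shard with a logging progress callback.
async def example_shard_download():
  on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()
  on_progress.register("log").on_next(lambda shard, event: print(shard.model_id, event.status))
  shard = build_full_shard("llama-3.2-1b", "MLXDynamicShardInferenceEngine")
  path, progress = await download_shard(shard, "MLXDynamicShardInferenceEngine", on_progress)
  print(path, progress.status)
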
def new_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:
  return SingletonShardDownloader(CachedShardDownloader(NewShardDownloader(max_parallel_downloads)))

class SingletonShardDownloader(ShardDownloader):
  def __init__(self, shard_downloader: ShardDownloader):
    self.shard_downloader = shard_downloader
    self.active_downloads: Dict[Shard, asyncio.Task] = {}

  @property
  def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
    return self.shard_downloader.on_progress

  async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
    if shard not in self.active_downloads: self.active_downloads[shard] = asyncio.create_task(self.shard_downloader.ensure_shard(shard, inference_engine_name))
    try: return await self.active_downloads[shard]
    finally:
      if shard in self.active_downloads and self.active_downloads[shard].done(): del self.active_downloads[shard]

  async def get_shard_download_status(self, inference_engine_name: str) -> AsyncIterator[tuple[Path, RepoProgressEvent]]:
    async for path, status in self.shard_downloader.get_shard_download_status(inference_engine_name):
      yield path, status

class CachedShardDownloader(ShardDownloader):
  def __init__(self, shard_downloader: ShardDownloader):
    self.shard_downloader = shard_downloader
    self.cache: Dict[tuple[str, Shard], Path] = {}

  @property
  def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
    return self.shard_downloader.on_progress

  async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
    if (inference_engine_name, shard) in self.cache:
      if DEBUG >= 2: print(f"ensure_shard cache hit {shard=} for {inference_engine_name}")
      return self.cache[(inference_engine_name, shard)]
    if DEBUG >= 2: print(f"ensure_shard cache miss {shard=} for {inference_engine_name}")
    target_dir = await self.shard_downloader.ensure_shard(shard, inference_engine_name)
    self.cache[(inference_engine_name, shard)] = target_dir
    return target_dir

  async def get_shard_download_status(self, inference_engine_name: str) -> AsyncIterator[tuple[Path, RepoProgressEvent]]:
    async for path, status in self.shard_downloader.get_shard_download_status(inference_engine_name):
      yield path, status

class NewShardDownloader(ShardDownloader):
  def __init__(self, max_parallel_downloads: int = 8):
    self.max_parallel_downloads = max_parallel_downloads
    self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()

  @property
  def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
    return self._on_progress

  async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
    target_dir, _ = await download_shard(shard, inference_engine_name, self.on_progress, max_parallel_downloads=self.max_parallel_downloads)
    return target_dir

  async def get_shard_download_status(self, inference_engine_name: str) -> AsyncIterator[tuple[Path, RepoProgressEvent]]:
    if DEBUG >= 2: print("Getting shard download status for", inference_engine_name)
    tasks = [download_shard(build_full_shard(model_id, inference_engine_name), inference_engine_name, self.on_progress, skip_download=True) for model_id in get_supported_models([[inference_engine_name]])]
    for task in asyncio.as_completed(tasks):
      try:
        path, progress = await task
        yield (path, progress)
      except Exception as e:
        print("Error downloading shard:", e)

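The three classes compose as decorators around one engine-facing interface: NewShardDownloader does the actual work, CachedShardDownloader memoizes the resulting path per (engine, shard), and SingletonShardDownloader collapses concurrent requests for the same shard into a single task. An illustrative sketch of what a caller sees (model id and engine name as in the test file below):

# Hedged sketch: the factory wires Singleton(Cached(New(...))) together.
async def example_ensure():
  downloader = new_shard_downloader()
  downloader.on_progress.register("ui").on_next(lambda shard, event: print(event.status))
  shard = build_full_shard("llama-3.2-1b", "MLXDynamicShardInferenceEngine")
  # Concurrent calls for the same shard share one download task; a later call hits the in-memory cache.
  path_a, path_b = await asyncio.gather(
    downloader.ensure_shard(shard, "MLXDynamicShardInferenceEngine"),
    downloader.ensure_shard(shard, "MLXDynamicShardInferenceEngine"),
  )
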
@@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Tuple, Dict
|
||||
from typing import Optional, Tuple, Dict, AsyncIterator
|
||||
from pathlib import Path
|
||||
from exo.inference.shard import Shard
|
||||
from exo.download.download_progress import RepoProgressEvent
|
||||
@@ -27,7 +27,7 @@ class ShardDownloader(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_shard_download_status(self) -> Optional[Dict[str, float]]:
|
||||
async def get_shard_download_status(self, inference_engine_name: str) -> AsyncIterator[tuple[Path, RepoProgressEvent]]:
|
||||
"""Get the download status of shards.
|
||||
|
||||
Returns:
|
||||
@@ -45,5 +45,5 @@ class NoopShardDownloader(ShardDownloader):
|
||||
def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
|
||||
return AsyncCallbackSystem()
|
||||
|
||||
async def get_shard_download_status(self) -> Optional[Dict[str, float]]:
|
||||
return None
|
||||
async def get_shard_download_status(self, inference_engine_name: str) -> AsyncIterator[tuple[Path, RepoProgressEvent]]:
|
||||
if False: yield
|
||||
|
||||
exo/download/test_new_shard_download.py (new file, 14 lines)
@@ -0,0 +1,14 @@
|
||||
from exo.download.new_shard_download import NewShardDownloader
from exo.inference.shard import Shard
import asyncio

async def test_new_shard_download():
  shard_downloader = NewShardDownloader()
  shard_downloader.on_progress.register("test").on_next(lambda shard, event: print(shard, event))
  await shard_downloader.ensure_shard(Shard(model_id="llama-3.2-1b", start_layer=0, end_layer=0, n_layers=16), "MLXDynamicShardInferenceEngine")
  async for path, shard_status in shard_downloader.get_shard_download_status("MLXDynamicShardInferenceEngine"):
    print("Shard download status:", path, shard_status)

if __name__ == "__main__":
  asyncio.run(test_new_shard_download())

@@ -7,12 +7,14 @@ import random
|
||||
import platform
|
||||
import psutil
|
||||
import uuid
|
||||
import netifaces
|
||||
from scapy.all import get_if_addr, get_if_list
|
||||
import re
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
import json
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import traceback
|
||||
|
||||
DEBUG = int(os.getenv("DEBUG", default="0"))
|
||||
DEBUG_DISCOVERY = int(os.getenv("DEBUG_DISCOVERY", default="0"))
|
||||
@@ -229,28 +231,29 @@ def pretty_print_bytes_per_second(bytes_per_second: int) -> str:
|
||||
|
||||
|
||||
def get_all_ip_addresses_and_interfaces():
|
||||
try:
|
||||
ip_addresses = []
|
||||
for interface in netifaces.interfaces():
|
||||
ifaddresses = netifaces.ifaddresses(interface)
|
||||
if netifaces.AF_INET in ifaddresses:
|
||||
for link in ifaddresses[netifaces.AF_INET]:
|
||||
ip = link['addr']
|
||||
ip_addresses.append((ip, interface))
|
||||
for interface in get_if_list():
|
||||
try:
|
||||
ip = get_if_addr(interface)
|
||||
if ip.startswith("0.0."): continue
|
||||
simplified_interface = re.sub(r'^\\Device\\NPF_', '', interface)
|
||||
ip_addresses.append((ip, simplified_interface))
|
||||
except:
|
||||
if DEBUG >= 1: print(f"Failed to get IP address for interface {interface}")
|
||||
if DEBUG >= 1: traceback.print_exc()
|
||||
if not ip_addresses:
|
||||
if DEBUG >= 1: print("Failed to get any IP addresses. Defaulting to localhost.")
|
||||
return [("localhost", "lo")]
|
||||
return list(set(ip_addresses))
|
||||
except:
|
||||
if DEBUG >= 1: print("Failed to get all IP addresses. Defaulting to localhost.")
|
||||
return [("localhost", "lo")]
|
||||
|
||||
|
||||
|
||||
async def get_macos_interface_type(ifname: str) -> Optional[Tuple[int, str]]:
|
||||
try:
|
||||
# Use the shared subprocess_pool
|
||||
output = await asyncio.get_running_loop().run_in_executor(subprocess_pool, lambda: subprocess.run(
|
||||
['system_profiler', 'SPNetworkDataType', '-json'],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
close_fds=True
|
||||
).stdout)
|
||||
output = await asyncio.get_running_loop().run_in_executor(
|
||||
subprocess_pool, lambda: subprocess.run(['system_profiler', 'SPNetworkDataType', '-json'], capture_output=True, text=True, close_fds=True).stdout
|
||||
)
|
||||
|
||||
data = json.loads(output)
|
||||
|
||||
@@ -276,6 +279,7 @@ async def get_macos_interface_type(ifname: str) -> Optional[Tuple[int, str]]:
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def get_interface_priority_and_type(ifname: str) -> Tuple[int, str]:
|
||||
# On macOS, try to get interface type using networksetup
|
||||
if psutil.MACOS:
|
||||
@@ -283,8 +287,7 @@ async def get_interface_priority_and_type(ifname: str) -> Tuple[int, str]:
|
||||
if macos_type is not None: return macos_type
|
||||
|
||||
# Local container/virtual interfaces
|
||||
if (ifname.startswith(('docker', 'br-', 'veth', 'cni', 'flannel', 'calico', 'weave')) or
|
||||
'bridge' in ifname):
|
||||
if (ifname.startswith(('docker', 'br-', 'veth', 'cni', 'flannel', 'calico', 'weave')) or 'bridge' in ifname):
|
||||
return (7, "Container Virtual")
|
||||
|
||||
# Loopback interface
|
||||
@@ -310,6 +313,7 @@ async def get_interface_priority_and_type(ifname: str) -> Tuple[int, str]:
|
||||
# Other physical interfaces
|
||||
return (2, "Other")
|
||||
|
||||
|
||||
async def shutdown(signal, loop, server):
|
||||
"""Gracefully shutdown the server and close the asyncio loop."""
|
||||
print(f"Received exit signal {signal.name}...")
|
||||
@@ -325,4 +329,44 @@ async def shutdown(signal, loop, server):
|
||||
def is_frozen():
|
||||
return getattr(sys, 'frozen', False) or os.path.basename(sys.executable) == "exo" \
|
||||
or ('Contents/MacOS' in str(os.path.dirname(sys.executable))) \
|
||||
or '__nuitka__' in globals() or getattr(sys, '__compiled__', False)
|
||||
or '__nuitka__' in globals() or getattr(sys, '__compiled__', False)
|
||||
|
||||
async def get_mac_system_info() -> Tuple[str, str, int]:
|
||||
"""Get Mac system information using system_profiler."""
|
||||
try:
|
||||
output = await asyncio.get_running_loop().run_in_executor(
|
||||
subprocess_pool,
|
||||
lambda: subprocess.check_output(["system_profiler", "SPHardwareDataType"]).decode("utf-8")
|
||||
)
|
||||
|
||||
model_line = next((line for line in output.split("\n") if "Model Name" in line), None)
|
||||
model_id = model_line.split(": ")[1] if model_line else "Unknown Model"
|
||||
|
||||
chip_line = next((line for line in output.split("\n") if "Chip" in line), None)
|
||||
chip_id = chip_line.split(": ")[1] if chip_line else "Unknown Chip"
|
||||
|
||||
memory_line = next((line for line in output.split("\n") if "Memory" in line), None)
|
||||
memory_str = memory_line.split(": ")[1] if memory_line else "Unknown Memory"
|
||||
memory_units = memory_str.split()
|
||||
memory_value = int(memory_units[0])
|
||||
memory = memory_value * 1024 if memory_units[1] == "GB" else memory_value
|
||||
|
||||
return model_id, chip_id, memory
|
||||
except Exception as e:
|
||||
if DEBUG >= 2: print(f"Error getting Mac system info: {e}")
|
||||
return "Unknown Model", "Unknown Chip", 0
|
||||
|
||||
def get_exo_home() -> Path:
|
||||
if psutil.WINDOWS: docs_folder = Path(os.environ["USERPROFILE"])/"Documents"
|
||||
else: docs_folder = Path.home()/"Documents"
|
||||
if not docs_folder.exists(): docs_folder.mkdir(exist_ok=True)
|
||||
exo_folder = docs_folder/"Exo"
|
||||
if not exo_folder.exists(): exo_folder.mkdir(exist_ok=True)
|
||||
return exo_folder
|
||||
|
||||
|
||||
def get_exo_images_dir() -> Path:
|
||||
exo_home = get_exo_home()
|
||||
images_dir = exo_home/"Images"
|
||||
if not images_dir.exists(): images_dir.mkdir(exist_ok=True)
|
||||
return images_dir
|
||||
|
||||
@@ -16,25 +16,25 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
|
||||
resp_full = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt)
|
||||
token_full = await inference_engine_1.sample(resp_full)
|
||||
|
||||
next_resp_full = await inference_engine_1.infer_tensor(
|
||||
next_resp_full, _ = await inference_engine_1.infer_tensor(
|
||||
"A",
|
||||
shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32),
|
||||
input_data=token_full,
|
||||
)
|
||||
|
||||
resp1 = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), prompt=prompt)
|
||||
resp2 = await inference_engine_2.infer_tensor(
|
||||
resp1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), prompt=prompt)
|
||||
resp2, _ = await inference_engine_2.infer_tensor(
|
||||
"B",
|
||||
shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),
|
||||
input_data=resp1,
|
||||
)
|
||||
token2 = await inference_engine_2.sample(resp2)
|
||||
resp3 = await inference_engine_1.infer_tensor(
|
||||
resp3, _ = await inference_engine_1.infer_tensor(
|
||||
"B",
|
||||
shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32),
|
||||
input_data=token2,
|
||||
)
|
||||
resp4 = await inference_engine_2.infer_tensor(
|
||||
resp4, _ = await inference_engine_2.infer_tensor(
|
||||
"B",
|
||||
shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),
|
||||
input_data=resp3,
|
||||
|
||||
@@ -25,9 +25,9 @@ class DummyInferenceEngine(InferenceEngine):
|
||||
async def decode(self, shard: Shard, tokens: np.ndarray) -> str:
|
||||
return self.tokenizer.decode(tokens)
|
||||
|
||||
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
|
||||
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
|
||||
await self.ensure_shard(shard)
|
||||
return input_data + 1 if self.shard.is_last_layer() else input_data
|
||||
return input_data + 1 if self.shard.is_last_layer() else input_data, None
|
||||
|
||||
async def ensure_shard(self, shard: Shard):
|
||||
if self.shard == shard: return
|
||||
|
||||
@@ -5,6 +5,7 @@ from exo.helpers import DEBUG # Make sure to import DEBUG
|
||||
from typing import Tuple, Optional
|
||||
from abc import ABC, abstractmethod
|
||||
from .shard import Shard
|
||||
from exo.download.shard_download import ShardDownloader
|
||||
|
||||
|
||||
class InferenceEngine(ABC):
|
||||
@@ -13,7 +14,7 @@ class InferenceEngine(ABC):
|
||||
@abstractmethod
|
||||
async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
async def sample(self, x: np.ndarray) -> np.ndarray:
|
||||
pass
|
||||
@@ -23,7 +24,7 @@ class InferenceEngine(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
|
||||
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -32,18 +33,23 @@ class InferenceEngine(ABC):
|
||||
|
||||
async def save_checkpoint(self, shard: Shard, path: str):
|
||||
pass
|
||||
|
||||
|
||||
async def save_session(self, key, value):
|
||||
self.session[key] = value
|
||||
|
||||
|
||||
async def clear_session(self):
|
||||
self.session.empty()
|
||||
|
||||
async def infer_prompt(self, request_id: str, shard: Shard, prompt: str) -> np.ndarray:
|
||||
|
||||
async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
|
||||
tokens = await self.encode(shard, prompt)
|
||||
x = tokens.reshape(1, -1)
|
||||
output_data = await self.infer_tensor(request_id, shard, x)
|
||||
return output_data
|
||||
if shard.model_id != 'stable-diffusion-2-1-base':
|
||||
x = tokens.reshape(1, -1)
|
||||
else:
|
||||
x = tokens
|
||||
output_data, inference_state = await self.infer_tensor(request_id, shard, x, inference_state)
|
||||
|
||||
return output_data, inference_state
|
||||
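The signature change threads an optional inference_state dict through every call, so multi-step pipelines (the stable-diffusion model added in this change set) can carry sampler state between invocations, while plain text models simply return None for it. A hedged sketch of how a caller consumes the pair; engine, shard and prompt are placeholders:

# Illustrative only: unpack (output, state) and feed the state back into the next call.
async def run_request(engine, shard, prompt: str):
  output, state = await engine.infer_prompt("req-1", shard, prompt)
  while isinstance(state, dict) and not state.get("is_finished", False):
    output, state = await engine.infer_tensor("req-1", shard, output, state)
  return output
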
|
||||
|
||||
inference_engine_classes = {
|
||||
"mlx": "MLXDynamicShardInferenceEngine",
|
||||
@@ -51,7 +57,8 @@ inference_engine_classes = {
|
||||
"dummy": "DummyInferenceEngine",
|
||||
}
|
||||
|
||||
def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDownloader'):
|
||||
|
||||
def get_inference_engine(inference_engine_name: str, shard_downloader: ShardDownloader):
|
||||
if DEBUG >= 2:
|
||||
print(f"get_inference_engine called with: {inference_engine_name}")
|
||||
if inference_engine_name == "mlx":
|
||||
|
||||
exo/inference/mlx/models/StableDiffusionPipeline.py (new file, 307 lines)
@@ -0,0 +1,307 @@
|
||||
# Adapted from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/__init__.py
|
||||
|
||||
import time
|
||||
from typing import Optional, Tuple
|
||||
import inspect
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from pathlib import Path
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from .sd_models.vae import ModelArgs as VAEArgs
|
||||
from .sd_models.vae import Autoencoder
|
||||
from .sd_models.tokenizer import load_tokenizer
|
||||
from .sd_models.clip import CLIPTextModel
|
||||
from .sd_models.clip import ModelArgs as CLIPArgs
|
||||
from .sd_models.unet import UNetConfig, UNetModel
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from exo.inference.shard import Shard
|
||||
|
||||
@dataclass
|
||||
class DiffusionConfig:
|
||||
beta_schedule: str = "scaled_linear"
|
||||
beta_start: float = 0.00085
|
||||
beta_end: float = 0.012
|
||||
num_train_steps: int = 1000
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, params):
|
||||
return cls(**{k: v for k, v in params.items() if k in inspect.signature(cls).parameters})
|
||||
|
||||
|
||||
#Sampler
|
||||
def _linspace(a, b, num):
|
||||
x = mx.arange(0, num) / (num - 1)
|
||||
return (b - a) * x + a
|
||||
|
||||
|
||||
def _interp(y, x_new):
|
||||
"""Interpolate the function defined by (arange(0, len(y)), y) at positions x_new."""
|
||||
x_low = x_new.astype(mx.int32)
|
||||
x_high = mx.minimum(x_low + 1, len(y) - 1)
|
||||
|
||||
y_low = y[x_low]
|
||||
y_high = y[x_high]
|
||||
delta_x = x_new - x_low
|
||||
y_new = y_low * (1 - delta_x) + delta_x * y_high
|
||||
|
||||
return y_new
|
||||
|
||||
class SimpleEulerSampler:
|
||||
"""A simple Euler integrator that can be used to sample from our diffusion models.
|
||||
|
||||
The method ``step()`` performs one Euler step from x_t to x_t_prev.
|
||||
"""
|
||||
|
||||
def __init__(self, config: DiffusionConfig):
|
||||
# Compute the noise schedule
|
||||
if config.beta_schedule == "linear":
|
||||
betas = _linspace(
|
||||
config.beta_start, config.beta_end, config.num_train_steps
|
||||
)
|
||||
elif config.beta_schedule == "scaled_linear":
|
||||
betas = _linspace(
|
||||
config.beta_start**0.5, config.beta_end**0.5, config.num_train_steps
|
||||
).square()
|
||||
else:
|
||||
raise NotImplementedError(f"{config.beta_schedule} is not implemented.")
|
||||
|
||||
alphas = 1 - betas
|
||||
alphas_cumprod = mx.cumprod(alphas)
|
||||
|
||||
self._sigmas = mx.concatenate(
|
||||
[mx.zeros(1), ((1 - alphas_cumprod) / alphas_cumprod).sqrt()]
|
||||
)
|
||||
|
||||
@property
|
||||
def max_time(self):
|
||||
return len(self._sigmas) - 1
|
||||
|
||||
def sample_prior(self, shape, dtype=mx.float32, key=None):
|
||||
noise = mx.random.normal(shape, key=key)
|
||||
return (
|
||||
noise * self._sigmas[-1] * (self._sigmas[-1].square() + 1).rsqrt()
|
||||
).astype(dtype)
|
||||
|
||||
def add_noise(self, x, t, key=None):
|
||||
noise = mx.random.normal(x.shape, key=key)
|
||||
s = self.sigmas(t)
|
||||
return (x + noise * s) * (s.square() + 1).rsqrt()
|
||||
|
||||
def sigmas(self, t):
|
||||
return _interp(self._sigmas, t)
|
||||
|
||||
def timesteps(self, num_steps: int, start_time=None, dtype=mx.float32):
|
||||
start_time = start_time or (len(self._sigmas) - 1)
|
||||
assert 0 < start_time <= (len(self._sigmas) - 1)
|
||||
steps = _linspace(start_time, 0, num_steps + 1).astype(dtype)
|
||||
return list(zip(steps, steps[1:]))
|
||||
|
||||
def current_timestep(self, step, total_steps, start_time=None):
|
||||
if step < total_steps:
|
||||
steps = self.timesteps(total_steps, start_time)
|
||||
return steps[step]
|
||||
else:
|
||||
return mx.array(0),mx.array(0)
|
||||
|
||||
def step(self, eps_pred, x_t, t, t_prev):
|
||||
sigma = self.sigmas(t).astype(eps_pred.dtype)
|
||||
sigma_prev = self.sigmas(t_prev).astype(eps_pred.dtype)
|
||||
|
||||
dt = sigma_prev - sigma
|
||||
x_t_prev = (sigma.square() + 1).sqrt() * x_t + eps_pred * dt
|
||||
|
||||
x_t_prev = x_t_prev * (sigma_prev.square() + 1).rsqrt()
|
||||
|
||||
return x_t_prev
|
||||
|
||||
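Written out, with sigma = sigmas(t) and sigma' = sigmas(t_prev), the update performed by step() above is

  x_{t_prev} = ( sqrt(sigma^2 + 1) * x_t + eps_pred * (sigma' - sigma) ) / sqrt(sigma'^2 + 1)

i.e. un-scale x_t, take one Euler step of size dt = sigma' - sigma along the predicted noise direction, then re-scale by the new sigma, matching the scaled parameterisation used in sample_prior() and add_noise().
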
@dataclass
|
||||
class ShardConfig:
|
||||
model_id:str
|
||||
start_layer:int
|
||||
end_layer:int
|
||||
n_layers:int
|
||||
|
||||
@dataclass
|
||||
class StableDiffusionConfig:
|
||||
model_type:str
|
||||
vae:VAEArgs
|
||||
text_encoder:CLIPArgs
|
||||
scheduler:DiffusionConfig
|
||||
unet:UNetConfig
|
||||
shard:ShardConfig
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, params):
|
||||
return cls(**{k: v for k, v in params.items() if k in inspect.signature(cls).parameters})
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(StableDiffusionConfig):
|
||||
shard:Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
|
||||
|
||||
def __post_init__(self):
|
||||
if isinstance(self.shard, dict):
|
||||
self.shard = Shard(**self.shard)
|
||||
|
||||
if not isinstance(self.shard, Shard):
|
||||
raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.model_type = config.model_type
|
||||
self.config = config
|
||||
self.model_path = config.vae['path'].split('/vae')[0]
|
||||
self.shard = config.shard
|
||||
self.shard_clip, self.shard_encoder, self.shard_unet, self.shard_decoder = model_shards(config.shard)
|
||||
self.config_clip=CLIPArgs.from_dict(config.text_encoder['config'])
|
||||
if self.shard_clip.start_layer != -1:
|
||||
self.text_encoder = CLIPTextModel(self.config_clip, shard=self.shard_clip)
|
||||
else:
|
||||
self.text_encoder = nn.Identity()
|
||||
self.tokenizer = load_tokenizer(Path(self.model_path), "vocab.json", "merges.txt")
|
||||
self.diffusion_config = DiffusionConfig.from_dict(config.scheduler['config'])
|
||||
self.sampler = SimpleEulerSampler(self.diffusion_config)
|
||||
if self.shard_unet.start_layer!=-1:
|
||||
self.config_unet = UNetConfig.from_dict(config.unet['config'])
|
||||
self.unet = UNetModel(self.config_unet, self.shard_unet)
|
||||
else:
|
||||
self.unet = nn.Identity()
|
||||
self.config_vae=VAEArgs.from_dict(config.vae['config'])
|
||||
if self.shard_encoder.start_layer != -1:
|
||||
self.encoder=Autoencoder(self.config_vae, self.shard_encoder, "vae_encoder")
|
||||
else:
|
||||
self.encoder = nn.Identity()
|
||||
if self.shard_decoder.start_layer != -1:
|
||||
self.decoder=Autoencoder(self.config_vae, self.shard_decoder, "vae_decoder")
|
||||
else:
|
||||
self.decoder = nn.Identity()
|
||||
|
||||
def __call__(self, x, step=0, cfg_weight: float = 7.5, total_steps=50, conditioning=None,
             mask=None, residual=None, x_t_prev=None, is_finished=False, is_step_finished=False,
             image=None, strength=0.65, start_step=None):
|
||||
t, t_prev = self.sampler.current_timestep(step=step, total_steps=total_steps, start_time=start_step)
|
||||
is_finished = False
|
||||
is_step_finished = False
|
||||
if t.item()==1000:
|
||||
if self.shard_clip.start_layer == 0:
|
||||
conditioning = x
|
||||
if self.shard_clip.start_layer != -1:
|
||||
conditioning, mask= self.text_encoder(conditioning,mask)
|
||||
seed = int(time.time())
|
||||
mx.random.seed(seed)
|
||||
if image is None:
|
||||
if self.shard_encoder.is_last_layer():
|
||||
x = self.sampler.sample_prior((1, *(64, 64), self.config_vae.latent_channels_in), dtype=mx.float32)
|
||||
x_t_prev=x
|
||||
start_step = self.sampler.max_time
|
||||
else:
|
||||
if self.shard_encoder.start_layer != -1:
|
||||
image= self.encoder.encode(image)
|
||||
if self.shard_encoder.is_last_layer():
|
||||
start_step = self.sampler.max_time*strength
|
||||
total_steps = int(total_steps*strength)
|
||||
image = mx.broadcast_to(image, (1,) + image.shape[1:])
|
||||
x_t_prev=self.sampler.add_noise(image, mx.array(start_step))
|
||||
image = None
|
||||
t, t_prev = self.sampler.current_timestep(step=step, total_steps=total_steps, start_time=start_step)
|
||||
# Perform the denoising loop
|
||||
if self.shard_unet.start_layer != -1:
|
||||
with tqdm(total=total_steps,initial=step+1) as pbar:
|
||||
if step<total_steps:
|
||||
x = x_t_prev
|
||||
if self.shard_unet.is_first_layer():
|
||||
x_t_unet = mx.concatenate([x] * 2, axis=0) if cfg_weight> 1 else x
|
||||
else:
|
||||
x_t_unet = x
|
||||
t_unet = mx.broadcast_to(t, [len(x_t_unet)])
|
||||
x, residual= self.unet(x_t_unet, t_unet, encoder_x=conditioning, residuals=residual)
|
||||
if self.shard_unet.is_last_layer():
|
||||
if cfg_weight > 1:
|
||||
eps_text, eps_neg = x.split(2)
|
||||
eps_pred = eps_neg + cfg_weight * (eps_text - eps_neg)
|
||||
x = self.sampler.step(eps_pred, x_t_prev, t, t_prev)
|
||||
x_t_prev=x
|
||||
mx.eval(x)
|
||||
|
||||
if self.shard_decoder.is_last_layer():
|
||||
is_step_finished=True
|
||||
if self.shard_decoder.start_layer != -1:
|
||||
x=self.decoder.decode(x)
|
||||
if self.shard_decoder.is_last_layer():
|
||||
x = mx.clip(x / 2 + 0.5, 0, 1)
|
||||
B, H, W, C = x.shape
|
||||
x = x.reshape(1, B // 1, H, W, C).transpose(0, 2, 1, 3, 4)
|
||||
x = x.reshape(1 * H, B // 1 * W, C)
|
||||
x = (x * 255).astype(mx.uint8)
|
||||
if t_prev.item() ==0:
|
||||
is_finished=True
|
||||
mx.eval(x)
|
||||
|
||||
return x, {
  'conditioning': conditioning, 'mask': mask, 'residual': residual, 'x_t_prev': x_t_prev,
  'is_finished': is_finished, 'is_step_finished': is_step_finished, 'step': step,
  'total_steps': total_steps, 'start_step': start_step, 'image': image
}
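# Note (not part of the diff): the keys of the returned dict deliberately mirror
# __call__'s keyword arguments, so a hypothetical driver can thread state from one
# denoising step into the next, e.g.:
#   x, state = model(tokens, step=0, total_steps=50)
#   x, state = model(x, step=1, total_steps=state['total_steps'],
#                    conditioning=state['conditioning'], mask=state['mask'],
#                    residual=state['residual'], x_t_prev=state['x_t_prev'],
#                    start_step=state['start_step'], image=state['image'])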
|
||||
|
||||
|
||||
def load(self):
|
||||
if self.shard_encoder.start_layer != -1:
|
||||
vae_weights = mx.load(self.config_vae.weight_files[0])
|
||||
vae_weights = self.encoder.sanitize(vae_weights)
|
||||
self.encoder.load_weights(list(vae_weights.items()), strict=True)
|
||||
if self.shard_decoder.start_layer != -1:
|
||||
vae_weights = mx.load(self.config_vae.weight_files[0])
|
||||
vae_weights = self.decoder.sanitize(vae_weights)
|
||||
self.decoder.load_weights(list(vae_weights.items()), strict=True)
|
||||
if self.shard_clip.start_layer != -1:
|
||||
clip_weights = mx.load(self.config_clip.weight_files[0])
|
||||
clip_weights = self.text_encoder.sanitize(clip_weights)
|
||||
self.text_encoder.load_weights(list(clip_weights.items()), strict=True)
|
||||
if self.shard_unet.start_layer !=-1:
|
||||
unet_weights = mx.load(self.config_unet.weight_files[0])
|
||||
unet_weights = self.unet.sanitize(unet_weights)
|
||||
self.unet.load_weights(list(unet_weights.items()), strict=True)
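# Note (not part of the diff): each branch above follows the same schematic
# three-step pattern (names here are placeholders):
#   weights = mx.load(weight_file)                          # dict[str, mx.array]
#   weights = submodel.sanitize(weights)                    # shard-local key filtering
#   submodel.load_weights(list(weights.items()), strict=True)
# strict=True makes a mis-sharded checkpoint fail loudly instead of silently skipping
# missing parameters.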
|
||||
|
||||
def model_shards(shard:ShardConfig):
|
||||
def create_shard(shard, model_ranges):
|
||||
start_layer = shard.start_layer
|
||||
end_layer = shard.end_layer
|
||||
|
||||
shards = {}
|
||||
|
||||
for model_name, (range_start, range_end) in model_ranges.items():
|
||||
if start_layer < range_end and end_layer >= range_start:
|
||||
# Calculate the overlap with the model range
|
||||
overlap_start = max(start_layer, range_start)
|
||||
overlap_end = min(end_layer, range_end - 1)
|
||||
|
||||
# Adjust the layers relative to the model's range
|
||||
relative_start = overlap_start - range_start
|
||||
relative_end = overlap_end - range_start
|
||||
shards[model_name] = Shard(model_name, relative_start, relative_end, range_end - range_start)
|
||||
else:
|
||||
# If no overlap, create a zero-layer shard
|
||||
shards[model_name] = Shard(model_name, -1, -1, range_end - range_start)
|
||||
|
||||
return shards
|
||||
|
||||
# Define the ranges for different models
|
||||
model_ranges = {
|
||||
'clip': (0, 12),
|
||||
'vae_encoder':(12,17),
|
||||
'unet':(17,26),
|
||||
'vae_decoder': (26, 31)  # layer range for the VAE decoder
|
||||
}
|
||||
|
||||
# Call the function and get the shards for all models
|
||||
shards = create_shard(shard, model_ranges)
|
||||
|
||||
# Access individual shards
|
||||
shard_clip = shards['clip']
|
||||
shard_encoder = shards['vae_encoder']
|
||||
shard_unet = shards['unet']
|
||||
shard_decoder = shards['vae_decoder']
|
||||
|
||||
return shard_clip, shard_encoder, shard_unet, shard_decoder
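# --- Illustrative sketch (not part of the diff) ---------------------------------------
# With the ranges above, a hypothetical ShardConfig covering every global layer (0-31)
# yields real sub-shards for all four components, while a component with no overlap
# comes back as a zero-layer shard (start_layer=-1, end_layer=-1):
full = ShardConfig(model_id="stable-diffusion-2-1", start_layer=0, end_layer=31, n_layers=32)
shard_clip, shard_encoder, shard_unet, shard_decoder = model_shards(full)
# clip -> relative layers 0-11 (12 total), vae_encoder -> 0-4 (5),
# unet -> 0-8 (9), vae_decoder -> 0-4 (5)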
exo/inference/mlx/models/deepseek_v3.py (new file, 134 lines)
@@ -0,0 +1,134 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_lm.models.cache import KVCache
|
||||
from mlx_lm.models.deepseek_v3 import (
|
||||
ModelArgs as V3ModelArgs,
|
||||
DeepseekV3DecoderLayer,
|
||||
)
|
||||
from .base import IdentityBlock
|
||||
from exo.inference.shard import Shard
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(V3ModelArgs):
|
||||
shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
|
||||
|
||||
def __post_init__(self):
|
||||
if isinstance(self.shard, Shard):
|
||||
return
|
||||
if not isinstance(self.shard, dict):
|
||||
raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
|
||||
|
||||
self.shard = Shard(**self.shard)
|
||||
|
||||
|
||||
class DeepseekV3Model(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = config
|
||||
self.num_hidden_layers = config.num_hidden_layers
|
||||
self.vocab_size = config.vocab_size
|
||||
if self.args.shard.is_first_layer():
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
|
||||
self.layers = []
|
||||
for i in range(self.num_hidden_layers):
|
||||
if self.args.shard.start_layer <= i <= self.args.shard.end_layer:
|
||||
self.layers.append(DeepseekV3DecoderLayer(config, i))
|
||||
else:
|
||||
self.layers.append(IdentityBlock())
|
||||
|
||||
if self.args.shard.is_last_layer():
|
||||
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
cache: Optional[KVCache] = None,
|
||||
) -> mx.array:
|
||||
if self.args.shard.is_first_layer():
|
||||
h = self.embed_tokens(x)
|
||||
else:
|
||||
h = x
|
||||
|
||||
mask = None
|
||||
T = h.shape[1]
|
||||
if T > 1:
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
|
||||
mask = mask.astype(h.dtype)
|
||||
|
||||
if cache is None:
|
||||
cache = [None]*len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, c)
|
||||
|
||||
if self.args.shard.is_last_layer():
|
||||
h = self.norm(h)
|
||||
return h
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = config
|
||||
self.model_type = config.model_type
|
||||
self.model = DeepseekV3Model(config)
|
||||
if self.args.shard.is_last_layer():
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache: Optional[KVCache] = None,
|
||||
):
|
||||
out = self.model(inputs, cache)
|
||||
if self.args.shard.is_last_layer():
|
||||
return self.lm_head(out)
|
||||
return out
|
||||
|
||||
def sanitize(self, weights):
|
||||
shard_state_dict = {}
|
||||
|
||||
for key, value in weights.items():
|
||||
if key.startswith('model.layers.'):
|
||||
layer_num = int(key.split('.')[2])
|
||||
if self.args.shard.start_layer <= layer_num <= self.args.shard.end_layer:
|
||||
shard_state_dict[key] = value
|
||||
elif self.args.shard.is_first_layer() and key.startswith('model.embed_tokens'):
|
||||
shard_state_dict[key] = value
|
||||
elif self.args.shard.is_last_layer() and (key.startswith('model.norm') or key.startswith('lm_head')):
|
||||
shard_state_dict[key] = value
|
||||
|
||||
for l in range(self.args.num_hidden_layers):
|
||||
prefix = f"model.layers.{l}"
|
||||
for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
|
||||
for k in ["weight", "scales", "biases"]:
|
||||
expert_key = f"{prefix}.mlp.experts.0.{m}.{k}"
|
||||
if expert_key in shard_state_dict:
|
||||
to_join = [
|
||||
shard_state_dict.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}")
|
||||
for e in range(self.args.n_routed_experts)
|
||||
]
|
||||
shard_state_dict[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)
|
||||
|
||||
return shard_state_dict
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return (
|
||||
self.args.qk_nope_head_dim + self.args.qk_rope_head_dim,
|
||||
self.args.v_head_dim,
|
||||
)
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
exo/inference/mlx/models/phi3.py (new file, 117 lines)
@@ -0,0 +1,117 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_lm.models.base import create_attention_mask
|
||||
from mlx_lm.models.phi3 import TransformerBlock, ModelArgs
|
||||
|
||||
from ...shard import Shard
|
||||
from .base import IdentityBlock
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(ModelArgs):
|
||||
shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
if isinstance(self.shard, Shard):
|
||||
return
|
||||
if not isinstance(self.shard, dict):
|
||||
raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
|
||||
|
||||
self.shard = Shard(**self.shard)
|
||||
|
||||
class Phi3Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
assert self.vocab_size > 0
|
||||
|
||||
if self.args.shard.is_first_layer():
|
||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
|
||||
self.layers = []
|
||||
for i in range(self.num_hidden_layers):
|
||||
if self.args.shard.start_layer <= i <= self.args.shard.end_layer:
|
||||
self.layers.append(TransformerBlock(args=args))
|
||||
else:
|
||||
self.layers.append(IdentityBlock())
|
||||
|
||||
if self.args.shard.is_last_layer():
|
||||
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
if self.args.shard.is_first_layer():
|
||||
h = self.embed_tokens(inputs)
|
||||
else:
|
||||
h = inputs
|
||||
|
||||
mask = None
|
||||
if h.shape[1] > 1:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, c)
|
||||
|
||||
if self.args.shard.is_last_layer():
|
||||
h = self.norm(h)
|
||||
return h
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.model_type = args.model_type
|
||||
self.model = Phi3Model(args)
|
||||
if self.args.shard.is_last_layer():
|
||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, cache)
|
||||
if self.args.shard.is_last_layer():
|
||||
out = self.lm_head(out)
|
||||
return out
|
||||
|
||||
def sanitize(self, weights):
|
||||
shard_state_dict = {}
|
||||
|
||||
for key, value in weights.items():
|
||||
if "self_attn.rope.inv_freq" in key:
|
||||
continue
|
||||
if key.startswith('model.layers.'):
|
||||
layer_num = int(key.split('.')[2])
|
||||
if self.args.shard.start_layer <= layer_num <= self.args.shard.end_layer:
|
||||
shard_state_dict[key] = value
|
||||
elif self.args.shard.is_first_layer() and key.startswith('model.embed_tokens'):
|
||||
shard_state_dict[key] = value
|
||||
elif self.args.shard.is_last_layer() and (key.startswith('lm_head') or key.startswith('model.norm')):
|
||||
shard_state_dict[key] = value
|
||||
|
||||
return shard_state_dict
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.hidden_size // self.args.num_attention_heads
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
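# --- Illustrative sketch (not part of the diff) ---------------------------------------
# How the key filtering in sanitize() behaves for a hypothetical middle shard
# (start_layer=8, end_layer=15): only decoder layers 8-15 survive, while embeddings,
# the final norm and lm_head are dropped because the shard is neither first nor last.
def _keeps_key(key, start_layer=8, end_layer=15, is_first=False, is_last=False):
  if key.startswith('model.layers.'):
    return start_layer <= int(key.split('.')[2]) <= end_layer
  if key.startswith('model.embed_tokens'):
    return is_first
  return is_last  # model.norm / lm_head

print(_keeps_key('model.layers.9.self_attn.qkv_proj.weight'))  # True
print(_keeps_key('model.embed_tokens.weight'))                 # False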
@@ -9,13 +9,12 @@ from mlx_lm.models.qwen2 import TransformerBlock, ModelArgs
|
||||
from ...shard import Shard
|
||||
from .base import IdentityBlock
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(ModelArgs):
|
||||
shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__() # Ensure parent initializations are respected
|
||||
super().__post_init__()
|
||||
|
||||
if isinstance(self.shard, Shard):
|
||||
return
|
||||
@@ -24,7 +23,6 @@ class ModelArgs(ModelArgs):
|
||||
|
||||
self.shard = Shard(**self.shard)
|
||||
|
||||
|
||||
class Qwen2Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
@@ -32,14 +30,17 @@ class Qwen2Model(nn.Module):
|
||||
self.vocab_size = args.vocab_size
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
assert self.vocab_size > 0
|
||||
if self.args.shard.is_first_layer():
|
||||
|
||||
if self.args.shard.is_first_layer() or (self.args.shard.is_last_layer() and args.tie_word_embeddings):
|
||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
|
||||
self.layers = []
|
||||
for i in range(self.num_hidden_layers):
|
||||
if self.args.shard.start_layer <= i <= self.args.shard.end_layer:
|
||||
self.layers.append(TransformerBlock(args=args))
|
||||
else:
|
||||
self.layers.append(IdentityBlock())
|
||||
|
||||
if self.args.shard.is_last_layer():
|
||||
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
exo/inference/mlx/models/sd_models/clip.py (new file, 191 lines)
@@ -0,0 +1,191 @@
|
||||
# Adapted from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/clip.py
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from dataclasses import field, dataclass
|
||||
from exo.inference.shard import Shard
|
||||
from exo.inference.mlx.models.base import IdentityBlock
|
||||
|
||||
_ACTIVATIONS = {"quick_gelu": nn.gelu_fast_approx, "gelu": nn.gelu}
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class CLIPTextModelConfig:
|
||||
num_layers: int = 23
|
||||
model_dims: int = 1024
|
||||
num_heads: int = 16
|
||||
max_length: int = 77
|
||||
vocab_size: int = 49408
|
||||
projection_dim: Optional[int] = None
|
||||
hidden_act: str = "quick_gelu"
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, config):
|
||||
return ModelArgs(
|
||||
num_layers=config["num_hidden_layers"],
|
||||
model_dims=config["hidden_size"],
|
||||
num_heads=config["num_attention_heads"],
|
||||
max_length=config["max_position_embeddings"],
|
||||
vocab_size=config["vocab_size"],
|
||||
projection_dim=config["projection_dim"] if "WithProjection" in config['architectures'][0] else None,
|
||||
hidden_act=config.get("hidden_act", "quick_gelu"),
|
||||
weight_files=config.get("weight_files", [])
|
||||
)
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(CLIPTextModelConfig):
|
||||
shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
|
||||
weight_files: List[str] = field(default_factory=lambda: [])
|
||||
def __post_init__(self):
|
||||
if isinstance(self.shard, dict):
|
||||
self.shard = Shard(**self.shard)
|
||||
|
||||
if not isinstance(self.shard, Shard):
|
||||
raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
|
||||
|
||||
if not self.shard.is_first_layer():
|
||||
self.vision_config = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class CLIPOutput:
|
||||
pooled_output: Optional[mx.array] = None
|
||||
last_hidden_state: Optional[mx.array] = None
|
||||
hidden_states: Optional[List[mx.array]] = None
|
||||
|
||||
|
||||
class CLIPEncoderLayer(nn.Module):
|
||||
"""The transformer encoder layer from CLIP."""
|
||||
|
||||
def __init__(self, model_dims: int, num_heads: int, activation: str):
|
||||
super().__init__()
|
||||
|
||||
self.layer_norm1 = nn.LayerNorm(model_dims)
|
||||
self.layer_norm2 = nn.LayerNorm(model_dims)
|
||||
|
||||
self.attention = nn.MultiHeadAttention(model_dims, num_heads)
|
||||
self.attention.query_proj.bias = mx.zeros(model_dims)
|
||||
self.attention.key_proj.bias = mx.zeros(model_dims)
|
||||
self.attention.value_proj.bias = mx.zeros(model_dims)
|
||||
self.attention.out_proj.bias = mx.zeros(model_dims)
|
||||
|
||||
self.linear1 = nn.Linear(model_dims, 4 * model_dims)
|
||||
self.linear2 = nn.Linear(4 * model_dims, model_dims)
|
||||
|
||||
self.act = _ACTIVATIONS[activation]
|
||||
|
||||
def __call__(self, x, attn_mask=None):
|
||||
|
||||
y = self.layer_norm1(x)
|
||||
y = self.attention(y, y, y, attn_mask)
|
||||
x = y + x
|
||||
|
||||
y = self.layer_norm2(x)
|
||||
y = self.linear1(y)
|
||||
y = self.act(y)
|
||||
y = self.linear2(y)
|
||||
x = y + x
|
||||
return x
|
||||
|
||||
|
||||
class CLIPTextModel(nn.Module):
|
||||
"""Implements the text encoder transformer from CLIP."""
|
||||
|
||||
def __init__(self, config: CLIPTextModelConfig, shard: Shard):
|
||||
super().__init__()
|
||||
|
||||
self.shard = shard
|
||||
self.layers_range = range(self.shard.start_layer*2, self.shard.end_layer*2+2)
|
||||
if self.shard.is_first_layer():
|
||||
self.token_embedding = nn.Embedding(config.vocab_size, config.model_dims)
|
||||
self.position_embedding = nn.Embedding(config.max_length, config.model_dims)
|
||||
self.layers = []
|
||||
for i in range(math.ceil(config.num_layers/2)):
|
||||
if 2*i in self.layers_range:
|
||||
self.layers.append(CLIPEncoderLayer(config.model_dims, config.num_heads, config.hidden_act))
|
||||
if 2*i+1 in self.layers_range and 2*i+1 < config.num_layers:
|
||||
self.layers.append(CLIPEncoderLayer(config.model_dims, config.num_heads, config.hidden_act))
|
||||
else:
|
||||
self.layers.append(IdentityBlock())
|
||||
if self.shard.is_last_layer():
|
||||
self.final_layer_norm = nn.LayerNorm(config.model_dims)
|
||||
|
||||
if config.projection_dim is not None:
|
||||
self.text_projection = nn.Linear(
|
||||
config.model_dims, config.projection_dim, bias=False
|
||||
)
|
||||
|
||||
def _get_mask(self, N, dtype):
|
||||
indices = mx.arange(N)
|
||||
mask = indices[:, None] < indices[None]
|
||||
mask = mask.astype(dtype) * (-6e4 if dtype == mx.float16 else -1e9)
|
||||
return mask
|
||||
|
||||
def __call__(self, x, mask=None):
|
||||
# Extract some shapes
|
||||
if self.shard.is_first_layer():
|
||||
B, N = x.shape
|
||||
eos_tokens = x.argmax(-1)
|
||||
|
||||
# Compute the embeddings
|
||||
x = self.token_embedding(x)
|
||||
|
||||
x = x + self.position_embedding.weight[:N]
|
||||
# Compute the features from the transformer
|
||||
mask = self._get_mask(N, x.dtype)
|
||||
|
||||
for l in self.layers:
|
||||
x = l(x, mask)
|
||||
# Apply the final layernorm and return
|
||||
|
||||
if self.shard.is_last_layer():
|
||||
x = self.final_layer_norm(x)
|
||||
|
||||
|
||||
|
||||
return x, mask
|
||||
def sanitize(self, weights):
|
||||
sanitized_weights = {}
|
||||
for key, value in weights.items():
|
||||
if "position_ids" in key:
|
||||
continue
|
||||
if key.startswith("text_model."):
|
||||
key = key[11:]
|
||||
if key.startswith("embeddings."):
|
||||
key = key[11:]
|
||||
if key.startswith("encoder."):
|
||||
key = key[8:]
|
||||
|
||||
# Map attention layers
|
||||
if "self_attn." in key:
|
||||
key = key.replace("self_attn.", "attention.")
|
||||
if "q_proj." in key:
|
||||
key = key.replace("q_proj.", "query_proj.")
|
||||
if "k_proj." in key:
|
||||
key = key.replace("k_proj.", "key_proj.")
|
||||
if "v_proj." in key:
|
||||
key = key.replace("v_proj.", "value_proj.")
|
||||
|
||||
# Map ffn layers
|
||||
if "mlp.fc1" in key:
|
||||
key = key.replace("mlp.fc1", "linear1")
|
||||
if "mlp.fc2" in key:
|
||||
key = key.replace("mlp.fc2", "linear2")
|
||||
|
||||
if key.startswith("layers."):
|
||||
layer_num = int(key.split(".")[1])
|
||||
if layer_num not in self.layers_range:
|
||||
continue
|
||||
if not self.shard.is_first_layer() and "embedding" in key:
|
||||
continue
|
||||
if not self.shard.is_last_layer() and key.startswith("final_layer_norm"):
|
||||
continue
|
||||
if not self.shard.is_last_layer() and key.startswith("text_projection"):
|
||||
continue
|
||||
sanitized_weights[key] = value
|
||||
return sanitized_weights
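# --- Illustrative sketch (not part of the diff) ---------------------------------------
# The renaming sanitize() performs, shown on one hypothetical Hugging Face checkpoint
# key (the method above strips prefixes by slicing; str.replace is an equivalent here):
hf_key = "text_model.encoder.layers.3.self_attn.q_proj.weight"
mlx_key = (hf_key.replace("text_model.", "").replace("encoder.", "")
           .replace("self_attn.", "attention.").replace("q_proj.", "query_proj."))
print(mlx_key)  # layers.3.attention.query_proj.weight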
exo/inference/mlx/models/sd_models/tokenizer.py (new file, 131 lines)
@@ -0,0 +1,131 @@
|
||||
# adapted from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/tokenizer.py
|
||||
|
||||
import regex
|
||||
import json
|
||||
import glob
|
||||
|
||||
|
||||
class Tokenizer:
|
||||
"""A simple port of CLIPTokenizer from https://github.com/huggingface/transformers/ ."""
|
||||
|
||||
def __init__(self, bpe_ranks, vocab):
|
||||
self.bpe_ranks = bpe_ranks
|
||||
self.vocab = vocab
|
||||
self.pat = regex.compile(
|
||||
r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
|
||||
regex.IGNORECASE,
|
||||
)
|
||||
|
||||
self._cache = {self.bos: self.bos, self.eos: self.eos}
|
||||
|
||||
@property
|
||||
def bos(self):
|
||||
return "<|startoftext|>"
|
||||
|
||||
@property
|
||||
def bos_token(self):
|
||||
return self.vocab[self.bos]
|
||||
|
||||
@property
|
||||
def eos(self):
|
||||
return "<|endoftext|>"
|
||||
|
||||
@property
|
||||
def eos_token(self):
|
||||
return self.vocab[self.eos]
|
||||
|
||||
def bpe(self, text):
|
||||
if text in self._cache:
|
||||
return self._cache[text]
|
||||
|
||||
unigrams = list(text[:-1]) + [text[-1] + "</w>"]
|
||||
unique_bigrams = set(zip(unigrams, unigrams[1:]))
|
||||
|
||||
if not unique_bigrams:
|
||||
return unigrams
|
||||
|
||||
# In every iteration try to merge the two most likely bigrams. If none
|
||||
# was merged we are done.
|
||||
#
|
||||
# Ported from https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/tokenization_clip.py
|
||||
while unique_bigrams:
|
||||
bigram = min(
|
||||
unique_bigrams, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))
|
||||
)
|
||||
if bigram not in self.bpe_ranks:
|
||||
break
|
||||
|
||||
new_unigrams = []
|
||||
skip = False
|
||||
for a, b in zip(unigrams, unigrams[1:]):
|
||||
if skip:
|
||||
skip = False
|
||||
continue
|
||||
|
||||
if (a, b) == bigram:
|
||||
new_unigrams.append(a + b)
|
||||
skip = True
|
||||
|
||||
else:
|
||||
new_unigrams.append(a)
|
||||
|
||||
if not skip:
|
||||
new_unigrams.append(b)
|
||||
|
||||
unigrams = new_unigrams
|
||||
unique_bigrams = set(zip(unigrams, unigrams[1:]))
|
||||
|
||||
self._cache[text] = unigrams
|
||||
|
||||
return unigrams
|
||||
|
||||
def tokenize(self, text, prepend_bos=True, append_eos=True):
|
||||
if isinstance(text, list):
|
||||
return [self.tokenize(t, prepend_bos, append_eos) for t in text]
|
||||
|
||||
# Lower case cleanup and split according to self.pat. Hugging Face does
|
||||
# a much more thorough job here but this should suffice for 95% of
|
||||
# cases.
|
||||
clean_text = regex.sub(r"\s+", " ", text.lower())
|
||||
tokens = regex.findall(self.pat, clean_text)
|
||||
|
||||
# Split the tokens according to the byte-pair merge file
|
||||
bpe_tokens = [ti for t in tokens for ti in self.bpe(t)]
|
||||
|
||||
# Map to token ids and return
|
||||
tokens = [self.vocab[t] for t in bpe_tokens]
|
||||
if prepend_bos:
|
||||
tokens = [self.bos_token] + tokens
|
||||
if append_eos:
|
||||
tokens.append(self.eos_token)
|
||||
|
||||
return tokens
|
||||
|
||||
def encode(self, prompt):
|
||||
tokens = [self.tokenize(prompt)]
|
||||
negative_text = ""
|
||||
if negative_text is not None:
|
||||
tokens += [self.tokenize(negative_text)]
|
||||
lengths = [len(t) for t in tokens]
|
||||
N = max(lengths)
|
||||
tokens = [t + [0] * (N - len(t)) for t in tokens]
|
||||
return tokens
|
||||
|
||||
def load_tokenizer(
|
||||
model_path: str,
|
||||
vocab_key: str = "tokenizer_vocab",
|
||||
merges_key: str = "tokenizer_merges",
|
||||
):
|
||||
|
||||
vocab_file = glob.glob(str(model_path/"tokenizer"/vocab_key))[0]
|
||||
with open(vocab_file, encoding="utf-8") as f:
|
||||
vocab = json.load(f)
|
||||
|
||||
merges_file = glob.glob(str(model_path/"tokenizer"/merges_key))[0]
|
||||
with open(merges_file, encoding="utf-8") as f:
|
||||
bpe_merges = f.read().strip().split("\n")[1 : 49152 - 256 - 2 + 1]
|
||||
bpe_merges = [tuple(m.split()) for m in bpe_merges]
|
||||
bpe_ranks = dict(map(reversed, enumerate(bpe_merges)))
|
||||
|
||||
return Tokenizer(bpe_ranks, vocab)
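# --- Illustrative sketch (not part of the diff) ---------------------------------------
# A toy run of the byte-pair merging implemented in Tokenizer.bpe, using a made-up
# two-rule merge table instead of a real merges.txt:
toy_ranks = {("l", "o"): 0, ("lo", "w</w>"): 1}
toy_vocab = {"<|startoftext|>": 0, "<|endoftext|>": 1, "low</w>": 2}
toy_tok = Tokenizer(toy_ranks, toy_vocab)
print(toy_tok.bpe("low"))       # ['low</w>']  ("l"+"o" merge first, then "lo"+"w</w>")
print(toy_tok.tokenize("low"))  # [0, 2, 1]    (bos, "low</w>", eos)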
exo/inference/mlx/models/sd_models/unet.py (new file, 629 lines)
@@ -0,0 +1,629 @@
|
||||
# Adapted from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/unet.py
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Tuple, Optional, List
|
||||
from exo.inference.shard import Shard
|
||||
|
||||
@dataclass
|
||||
class UNetConfig:
|
||||
in_channels: int = 4
|
||||
out_channels: int = 4
|
||||
conv_in_kernel: int = 3
|
||||
conv_out_kernel: int = 3
|
||||
block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
|
||||
layers_per_block: Tuple[int] = (2, 2, 2, 2)
|
||||
mid_block_layers: int = 2
|
||||
transformer_layers_per_block: Tuple[int] = (1, 1, 1, 1)
|
||||
num_attention_heads: Tuple[int] = (5, 10, 20, 20)
|
||||
cross_attention_dim: Tuple[int] = (1024,) * 4
|
||||
norm_num_groups: int = 32
|
||||
down_block_types: Tuple[str] = (
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D",
|
||||
)
|
||||
up_block_types: Tuple[str] = (
|
||||
"UpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
)
|
||||
addition_embed_type: Optional[str] = None
|
||||
addition_time_embed_dim: Optional[int] = None
|
||||
projection_class_embeddings_input_dim: Optional[int] = None
|
||||
weight_files: List[str] = field(default_factory=lambda: [])
|
||||
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls,config):
|
||||
n_blocks = len(config['block_out_channels'])
|
||||
return UNetConfig(
|
||||
in_channels=config["in_channels"],
|
||||
out_channels=config["out_channels"],
|
||||
block_out_channels=config["block_out_channels"],
|
||||
layers_per_block=[config["layers_per_block"]] * n_blocks,
|
||||
transformer_layers_per_block=config.get(
|
||||
"transformer_layers_per_block", (1,) * 4
|
||||
),
|
||||
num_attention_heads=(
|
||||
[config["attention_head_dim"]] * n_blocks
|
||||
if isinstance(config["attention_head_dim"], int)
|
||||
else config["attention_head_dim"]
|
||||
),
|
||||
cross_attention_dim=[config["cross_attention_dim"]] * n_blocks,
|
||||
norm_num_groups=config["norm_num_groups"],
|
||||
down_block_types=config["down_block_types"],
|
||||
up_block_types=config["up_block_types"][::-1],
|
||||
addition_embed_type=config.get("addition_embed_type", None),
|
||||
addition_time_embed_dim=config.get("addition_time_embed_dim", None),
|
||||
projection_class_embeddings_input_dim=config.get(
|
||||
"projection_class_embeddings_input_dim", None
|
||||
),
|
||||
weight_files=config.get("weight_files", [])
|
||||
|
||||
)
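# --- Illustrative sketch (not part of the diff) ---------------------------------------
# from_dict() consumes a diffusers-style unet config.json. A trimmed, hypothetical
# example of the values it maps (real checkpoints carry more fields):
_example_unet_cfg = {
  "in_channels": 4, "out_channels": 4,
  "block_out_channels": [320, 640, 1280, 1280],
  "layers_per_block": 2, "attention_head_dim": [5, 10, 20, 20],
  "cross_attention_dim": 1024, "norm_num_groups": 32,
  "down_block_types": ["CrossAttnDownBlock2D"] * 3 + ["DownBlock2D"],
  "up_block_types": ["UpBlock2D"] + ["CrossAttnUpBlock2D"] * 3,
}
_example = UNetConfig.from_dict(_example_unet_cfg)
# layers_per_block is expanded to one entry per block, cross_attention_dim is repeated
# for every block, and up_block_types is reversed to match the build order of up_blocks.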
|
||||
|
||||
|
||||
def upsample_nearest(x, scale: int = 2):
|
||||
B, H, W, C = x.shape
|
||||
x = mx.broadcast_to(x[:, :, None, :, None, :], (B, H, scale, W, scale, C))
|
||||
x = x.reshape(B, H * scale, W * scale, C)
|
||||
|
||||
return x
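# --- Illustrative sketch (not part of the diff) ---------------------------------------
# upsample_nearest doubles the spatial dimensions by broadcasting (channels-last layout):
import mlx.core as mx

_x = mx.arange(4).reshape(1, 2, 2, 1)
print(upsample_nearest(_x).shape)  # (1, 4, 4, 1)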
|
||||
|
||||
|
||||
class TimestepEmbedding(nn.Module):
|
||||
def __init__(self, in_channels: int, time_embed_dim: int):
|
||||
super().__init__()
|
||||
|
||||
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
|
||||
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
|
||||
|
||||
def __call__(self, x):
|
||||
x = self.linear_1(x)
|
||||
x = nn.silu(x)
|
||||
x = self.linear_2(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model_dims: int,
|
||||
num_heads: int,
|
||||
hidden_dims: Optional[int] = None,
|
||||
memory_dims: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.norm1 = nn.LayerNorm(model_dims)
|
||||
self.attn1 = nn.MultiHeadAttention(model_dims, num_heads)
|
||||
self.attn1.out_proj.bias = mx.zeros(model_dims)
|
||||
|
||||
memory_dims = memory_dims or model_dims
|
||||
self.norm2 = nn.LayerNorm(model_dims)
|
||||
self.attn2 = nn.MultiHeadAttention(
|
||||
model_dims, num_heads, key_input_dims=memory_dims
|
||||
)
|
||||
self.attn2.out_proj.bias = mx.zeros(model_dims)
|
||||
|
||||
hidden_dims = hidden_dims or 4 * model_dims
|
||||
self.norm3 = nn.LayerNorm(model_dims)
|
||||
self.linear1 = nn.Linear(model_dims, hidden_dims)
|
||||
self.linear2 = nn.Linear(model_dims, hidden_dims)
|
||||
self.linear3 = nn.Linear(hidden_dims, model_dims)
|
||||
|
||||
def __call__(self, x, memory, attn_mask, memory_mask):
|
||||
# Self attention
|
||||
y = self.norm1(x)
|
||||
y = self.attn1(y, y, y, attn_mask)
|
||||
x = x + y
|
||||
|
||||
# Cross attention
|
||||
y = self.norm2(x)
|
||||
y = self.attn2(y, memory, memory, memory_mask)
|
||||
x = x + y
|
||||
|
||||
# FFN
|
||||
y = self.norm3(x)
|
||||
y_a = self.linear1(y)
|
||||
y_b = self.linear2(y)
|
||||
y = y_a * nn.gelu(y_b)
|
||||
y = self.linear3(y)
|
||||
x = x + y
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Transformer2D(nn.Module):
|
||||
"""A transformer model for inputs with 2 spatial dimensions."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
model_dims: int,
|
||||
encoder_dims: int,
|
||||
num_heads: int,
|
||||
num_layers: int = 1,
|
||||
norm_num_groups: int = 32,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.norm = nn.GroupNorm(norm_num_groups, in_channels, pytorch_compatible=True)
|
||||
self.proj_in = nn.Linear(in_channels, model_dims)
|
||||
self.transformer_blocks = [
|
||||
TransformerBlock(model_dims, num_heads, memory_dims=encoder_dims)
|
||||
for i in range(num_layers)
|
||||
]
|
||||
self.proj_out = nn.Linear(model_dims, in_channels)
|
||||
|
||||
def __call__(self, x, encoder_x, attn_mask, encoder_attn_mask):
|
||||
# Save the input to add to the output
|
||||
input_x = x
|
||||
dtype = x.dtype
|
||||
|
||||
# Perform the input norm and projection
|
||||
B, H, W, C = x.shape
|
||||
x = self.norm(x.astype(mx.float32)).astype(dtype).reshape(B, -1, C)
|
||||
x = self.proj_in(x)
|
||||
|
||||
# Apply the transformer
|
||||
for block in self.transformer_blocks:
|
||||
x = block(x, encoder_x, attn_mask, encoder_attn_mask)
|
||||
|
||||
# Apply the output projection and reshape
|
||||
x = self.proj_out(x)
|
||||
x = x.reshape(B, H, W, C)
|
||||
|
||||
return x + input_x
|
||||
|
||||
|
||||
class ResnetBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: Optional[int] = None,
|
||||
groups: int = 32,
|
||||
temb_channels: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
out_channels = out_channels or in_channels
|
||||
|
||||
self.norm1 = nn.GroupNorm(groups, in_channels, pytorch_compatible=True)
|
||||
self.conv1 = nn.Conv2d(
|
||||
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
if temb_channels is not None:
|
||||
self.time_emb_proj = nn.Linear(temb_channels, out_channels)
|
||||
self.norm2 = nn.GroupNorm(groups, out_channels, pytorch_compatible=True)
|
||||
self.conv2 = nn.Conv2d(
|
||||
out_channels, out_channels, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
|
||||
if in_channels != out_channels:
|
||||
self.conv_shortcut = nn.Linear(in_channels, out_channels)
|
||||
|
||||
def __call__(self, x, temb=None):
|
||||
dtype = x.dtype
|
||||
|
||||
if temb is not None:
|
||||
temb = self.time_emb_proj(nn.silu(temb))
|
||||
y = self.norm1(x.astype(mx.float32)).astype(dtype)
|
||||
|
||||
y = nn.silu(y)
|
||||
|
||||
y = self.conv1(y)
|
||||
|
||||
|
||||
if temb is not None:
|
||||
y = y + temb[:, None, None, :]
|
||||
y = self.norm2(y.astype(mx.float32)).astype(dtype)
|
||||
y = nn.silu(y)
|
||||
y = self.conv2(y)
|
||||
|
||||
x = y + (x if "conv_shortcut" not in self else self.conv_shortcut(x))
|
||||
return x
|
||||
|
||||
|
||||
class UNetBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
temb_channels: int,
|
||||
prev_out_channels: Optional[int] = None,
|
||||
num_layers: int = 1,
|
||||
transformer_layers_per_block: int = 1,
|
||||
num_attention_heads: int = 8,
|
||||
cross_attention_dim=1280,
|
||||
resnet_groups: int = 32,
|
||||
add_downsample=True,
|
||||
add_upsample=True,
|
||||
add_cross_attention=True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Prepare the in channels list for the resnets
|
||||
if prev_out_channels is None:
|
||||
in_channels_list = [in_channels] + [out_channels] * (num_layers - 1)
|
||||
else:
|
||||
in_channels_list = [prev_out_channels] + [out_channels] * (num_layers - 1)
|
||||
res_channels_list = [out_channels] * (num_layers - 1) + [in_channels]
|
||||
in_channels_list = [
|
||||
a + b for a, b in zip(in_channels_list, res_channels_list)
|
||||
]
|
||||
|
||||
# Add resnet blocks that also process the time embedding
|
||||
self.resnets = [
|
||||
ResnetBlock2D(
|
||||
in_channels=ic,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
groups=resnet_groups,
|
||||
)
|
||||
for ic in in_channels_list
|
||||
]
|
||||
|
||||
# Add optional cross attention layers
|
||||
if add_cross_attention:
|
||||
self.attentions = [
|
||||
Transformer2D(
|
||||
in_channels=out_channels,
|
||||
model_dims=out_channels,
|
||||
num_heads=num_attention_heads,
|
||||
num_layers=transformer_layers_per_block,
|
||||
encoder_dims=cross_attention_dim,
|
||||
)
|
||||
for i in range(num_layers)
|
||||
]
|
||||
|
||||
# Add an optional downsampling layer
|
||||
if add_downsample:
|
||||
self.downsample = nn.Conv2d(
|
||||
out_channels, out_channels, kernel_size=3, stride=2, padding=1
|
||||
)
|
||||
|
||||
# or upsampling layer
|
||||
if add_upsample:
|
||||
self.upsample = nn.Conv2d(
|
||||
out_channels, out_channels, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x,
|
||||
encoder_x=None,
|
||||
temb=None,
|
||||
attn_mask=None,
|
||||
encoder_attn_mask=None,
|
||||
residual_hidden_states=None,
|
||||
):
|
||||
output_states = []
|
||||
|
||||
for i in range(len(self.resnets)):
|
||||
if residual_hidden_states is not None:
|
||||
x = mx.concatenate([x, residual_hidden_states.pop()], axis=-1)
|
||||
|
||||
x = self.resnets[i](x, temb)
|
||||
|
||||
if "attentions" in self:
|
||||
x = self.attentions[i](x, encoder_x, attn_mask, encoder_attn_mask)
|
||||
|
||||
output_states.append(x)
|
||||
|
||||
if "downsample" in self:
|
||||
x = self.downsample(x)
|
||||
output_states.append(x)
|
||||
|
||||
if "upsample" in self:
|
||||
x = self.upsample(upsample_nearest(x))
|
||||
output_states.append(x)
|
||||
|
||||
return x, output_states
|
||||
|
||||
|
||||
class UNetModel(nn.Module):
|
||||
"""The conditional 2D UNet model that actually performs the denoising."""
|
||||
|
||||
def __init__(self, config: UNetConfig, shard: Shard):
|
||||
super().__init__()
|
||||
self.shard = shard
|
||||
self.start_layer = shard.start_layer
|
||||
self.end_layer = shard.end_layer
|
||||
self.layers_range = list(range(self.start_layer, self.end_layer+1))
|
||||
if shard.is_first_layer():
|
||||
self.conv_in = nn.Conv2d(
|
||||
config.in_channels,
|
||||
config.block_out_channels[0],
|
||||
config.conv_in_kernel,
|
||||
padding=(config.conv_in_kernel - 1) // 2,
|
||||
)
|
||||
|
||||
self.timesteps = nn.SinusoidalPositionalEncoding(
|
||||
config.block_out_channels[0],
|
||||
max_freq=1,
|
||||
min_freq=math.exp(
|
||||
-math.log(10000) + 2 * math.log(10000) / config.block_out_channels[0]
|
||||
),
|
||||
scale=1.0,
|
||||
cos_first=True,
|
||||
full_turns=False,
|
||||
)
|
||||
self.time_embedding = TimestepEmbedding(
|
||||
config.block_out_channels[0],
|
||||
config.block_out_channels[0] * 4,
|
||||
)
|
||||
|
||||
if config.addition_embed_type == "text_time":
|
||||
self.add_time_proj = nn.SinusoidalPositionalEncoding(
|
||||
config.addition_time_embed_dim,
|
||||
max_freq=1,
|
||||
min_freq=math.exp(
|
||||
-math.log(10000)
|
||||
+ 2 * math.log(10000) / config.addition_time_embed_dim
|
||||
),
|
||||
scale=1.0,
|
||||
cos_first=True,
|
||||
full_turns=False,
|
||||
)
|
||||
self.add_embedding = TimestepEmbedding(
|
||||
config.projection_class_embeddings_input_dim,
|
||||
config.block_out_channels[0] * 4,
|
||||
)
|
||||
|
||||
# Make the downsampling blocks
|
||||
block_channels = [config.block_out_channels[0]] + list(
|
||||
config.block_out_channels
|
||||
)
|
||||
self.down_blocks = []
|
||||
|
||||
for i, (in_channels, out_channels) in enumerate(zip(block_channels, block_channels[1:])):
|
||||
if i in self.layers_range:
|
||||
self.down_blocks.append(
|
||||
UNetBlock2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=config.block_out_channels[0] * 4,
|
||||
num_layers=config.layers_per_block[i],
|
||||
transformer_layers_per_block=config.transformer_layers_per_block[i],
|
||||
num_attention_heads=config.num_attention_heads[i],
|
||||
cross_attention_dim=config.cross_attention_dim[i],
|
||||
resnet_groups=config.norm_num_groups,
|
||||
add_downsample=(i < len(config.block_out_channels) - 1),
|
||||
add_upsample=False,
|
||||
add_cross_attention="CrossAttn" in config.down_block_types[i],
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.down_blocks.append(nn.Identity())
|
||||
|
||||
|
||||
# Make the middle block
|
||||
if 4 in self.layers_range:
|
||||
self.mid_blocks = [
|
||||
ResnetBlock2D(
|
||||
in_channels=config.block_out_channels[-1],
|
||||
out_channels=config.block_out_channels[-1],
|
||||
temb_channels=config.block_out_channels[0] * 4,
|
||||
groups=config.norm_num_groups,
|
||||
),
|
||||
Transformer2D(
|
||||
in_channels=config.block_out_channels[-1],
|
||||
model_dims=config.block_out_channels[-1],
|
||||
num_heads=config.num_attention_heads[-1],
|
||||
num_layers=config.transformer_layers_per_block[-1],
|
||||
encoder_dims=config.cross_attention_dim[-1],
|
||||
),
|
||||
ResnetBlock2D(
|
||||
in_channels=config.block_out_channels[-1],
|
||||
out_channels=config.block_out_channels[-1],
|
||||
temb_channels=config.block_out_channels[0] * 4,
|
||||
groups=config.norm_num_groups,
|
||||
),
|
||||
]
|
||||
|
||||
# Make the upsampling blocks
|
||||
block_channels = (
|
||||
[config.block_out_channels[0]]
|
||||
+ list(config.block_out_channels)
|
||||
+ [config.block_out_channels[-1]]
|
||||
)
|
||||
|
||||
total_items = len(block_channels) - 3
|
||||
reversed_channels = list(reversed(list(zip(block_channels, block_channels[1:], block_channels[2:]))))
|
||||
|
||||
self.up_blocks = []
|
||||
for rev_i, (in_channels, out_channels, prev_out_channels) in enumerate(reversed_channels):
|
||||
i = total_items - rev_i
|
||||
if rev_i+5 in self.layers_range:
|
||||
self.up_blocks.append(
|
||||
UNetBlock2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=config.block_out_channels[0] * 4,
|
||||
prev_out_channels=prev_out_channels,
|
||||
num_layers=config.layers_per_block[i] + 1,
|
||||
transformer_layers_per_block=config.transformer_layers_per_block[i],
|
||||
num_attention_heads=config.num_attention_heads[i],
|
||||
cross_attention_dim=config.cross_attention_dim[i],
|
||||
resnet_groups=config.norm_num_groups,
|
||||
add_downsample=False,
|
||||
add_upsample=(i > 0),
|
||||
add_cross_attention="CrossAttn" in config.up_block_types[i],
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.up_blocks.append(nn.Identity())
|
||||
|
||||
|
||||
if shard.is_last_layer():
|
||||
self.conv_norm_out = nn.GroupNorm(
|
||||
config.norm_num_groups,
|
||||
config.block_out_channels[0],
|
||||
pytorch_compatible=True,
|
||||
)
|
||||
self.conv_out = nn.Conv2d(
|
||||
config.block_out_channels[0],
|
||||
config.out_channels,
|
||||
config.conv_out_kernel,
|
||||
padding=(config.conv_out_kernel - 1) // 2,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x,
|
||||
timestep,
|
||||
encoder_x,
|
||||
attn_mask=None,
|
||||
encoder_attn_mask=None,
|
||||
text_time=None,
|
||||
residuals=None,
|
||||
):
|
||||
# Compute the time embeddings
|
||||
|
||||
temb = self.timesteps(timestep).astype(x.dtype)
|
||||
temb = self.time_embedding(temb)
|
||||
|
||||
# Add the extra text_time conditioning
|
||||
if text_time is not None:
|
||||
text_emb, time_ids = text_time
|
||||
emb = self.add_time_proj(time_ids).flatten(1).astype(x.dtype)
|
||||
emb = mx.concatenate([text_emb, emb], axis=-1)
|
||||
emb = self.add_embedding(emb)
|
||||
temb = temb + emb
|
||||
|
||||
if self.shard.is_first_layer():
|
||||
# Preprocess the input
|
||||
x = self.conv_in(x)
|
||||
residuals = [x]
|
||||
# Run the downsampling part of the unet
|
||||
|
||||
for i in range(len(self.down_blocks)):
|
||||
if i in self.layers_range:
|
||||
x, res = self.down_blocks[i](
|
||||
x,
|
||||
encoder_x=encoder_x,
|
||||
temb=temb,
|
||||
attn_mask=attn_mask,
|
||||
encoder_attn_mask=encoder_attn_mask,
|
||||
)
|
||||
residuals.extend(res)
|
||||
else:
|
||||
x= self.down_blocks[i](x)
|
||||
|
||||
if 4 in self.layers_range:
|
||||
# Run the middle part of the unet
|
||||
x = self.mid_blocks[0](x, temb)
|
||||
x = self.mid_blocks[1](x, encoder_x, attn_mask, encoder_attn_mask)
|
||||
x = self.mid_blocks[2](x, temb)
|
||||
|
||||
# Run the upsampling part of the unet
|
||||
for i in range(len(self.up_blocks)):
|
||||
if i+5 in self.layers_range:
|
||||
x, _ = self.up_blocks[i](
|
||||
x,
|
||||
encoder_x=encoder_x,
|
||||
temb=temb,
|
||||
attn_mask=attn_mask,
|
||||
encoder_attn_mask=encoder_attn_mask,
|
||||
residual_hidden_states=residuals,
|
||||
)
|
||||
else:
|
||||
x= self.up_blocks[i](x)
|
||||
|
||||
# Postprocess the output
|
||||
if self.shard.is_last_layer():
|
||||
dtype = x.dtype
|
||||
x = self.conv_norm_out(x.astype(mx.float32)).astype(dtype)
|
||||
x = nn.silu(x)
|
||||
x = self.conv_out(x)
|
||||
|
||||
return x, residuals
|
||||
def sanitize(self, weights):
|
||||
sanitized_weights = {}
|
||||
for key, value in weights.items():
|
||||
k1=""
|
||||
k2=""
|
||||
if "downsamplers" in key:
|
||||
key = key.replace("downsamplers.0.conv", "downsample")
|
||||
if "upsamplers" in key:
|
||||
key = key.replace("upsamplers.0.conv", "upsample")
|
||||
|
||||
# Map the mid block
|
||||
if "mid_block.resnets.0" in key:
|
||||
key = key.replace("mid_block.resnets.0", "mid_blocks.0")
|
||||
if "mid_block.attentions.0" in key:
|
||||
key = key.replace("mid_block.attentions.0", "mid_blocks.1")
|
||||
if "mid_block.resnets.1" in key:
|
||||
key = key.replace("mid_block.resnets.1", "mid_blocks.2")
|
||||
|
||||
# Map attention layers
|
||||
if "to_k" in key:
|
||||
key = key.replace("to_k", "key_proj")
|
||||
if "to_out.0" in key:
|
||||
key = key.replace("to_out.0", "out_proj")
|
||||
if "to_q" in key:
|
||||
key = key.replace("to_q", "query_proj")
|
||||
if "to_v" in key:
|
||||
key = key.replace("to_v", "value_proj")
|
||||
|
||||
# Map transformer ffn
|
||||
if "ff.net.2" in key:
|
||||
key = key.replace("ff.net.2", "linear3")
|
||||
if "ff.net.0" in key:
|
||||
k1 = key.replace("ff.net.0.proj", "linear1")
|
||||
k2 = key.replace("ff.net.0.proj", "linear2")
|
||||
v1, v2 = mx.split(value, 2)
|
||||
|
||||
|
||||
if "conv_shortcut.weight" in key:
|
||||
value = value.squeeze()
|
||||
|
||||
# Transform the weights from 1x1 convs to linear
|
||||
if len(value.shape) == 4 and ("proj_in" in key or "proj_out" in key):
|
||||
value = value.squeeze()
|
||||
|
||||
if len(value.shape) == 4:
|
||||
value = value.transpose(0, 2, 3, 1)
|
||||
value = value.reshape(-1).reshape(value.shape)
|
||||
|
||||
if key.startswith("conv_in") :
|
||||
if 0 not in self.layers_range:
|
||||
continue
|
||||
|
||||
if key.startswith("down_blocks"):
|
||||
layer_num = int(key.split(".")[1])
|
||||
if layer_num not in self.layers_range:
|
||||
continue
|
||||
|
||||
if key.startswith("mid_block"):
|
||||
if 4 not in self.layers_range:
|
||||
continue
|
||||
|
||||
if key.startswith("up_blocks"):
|
||||
layer_num = int(key.split(".")[1])
|
||||
if (layer_num+5) not in self.layers_range:
|
||||
continue
|
||||
|
||||
if key.startswith("conv_out") or key.startswith("conv_norm_out"):
|
||||
if 8 not in self.layers_range:
|
||||
continue
|
||||
|
||||
if len(k1)>0:
|
||||
sanitized_weights[k1] = v1
|
||||
sanitized_weights[k2] = v2
|
||||
else:
|
||||
sanitized_weights[key] = value
|
||||
|
||||
|
||||
return sanitized_weights
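# --- Illustrative sketch (not part of the diff) ---------------------------------------
# Two of the weight transforms sanitize() applies, shown on toy tensors:
import mlx.core as mx

# 1x1 convolution kernels used as projections (proj_in / proj_out) are squeezed into
# plain linear weights.
conv_w = mx.zeros((320, 320, 1, 1))   # hypothetical proj_in weight
linear_w = conv_w.squeeze()           # -> (320, 320)

# The fused GEGLU projection ff.net.0.proj is split into the linear1 / linear2 halves
# consumed by TransformerBlock above.
fused = mx.zeros((2560, 320))         # hypothetical fused weight
w1, w2 = mx.split(fused, 2)           # each (1280, 320)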
exo/inference/mlx/models/sd_models/vae.py (new file, 429 lines)
@@ -0,0 +1,429 @@
|
||||
# Adapted from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/vae.py
|
||||
|
||||
import math
|
||||
from typing import List
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .unet import ResnetBlock2D, upsample_nearest
|
||||
from dataclasses import dataclass, field
|
||||
from exo.inference.shard import Shard
|
||||
from typing import Tuple
|
||||
import inspect
|
||||
from ..base import IdentityBlock
|
||||
|
||||
@dataclass
|
||||
class AutoencoderConfig:
|
||||
in_channels: int = 3
|
||||
out_channels: int = 3
|
||||
latent_channels_out: int = 8
|
||||
latent_channels_in: int = 4
|
||||
block_out_channels: Tuple[int] = (128, 256, 512, 512)
|
||||
layers_per_block: int = 2
|
||||
norm_num_groups: int = 32
|
||||
scaling_factor: float = 0.18215
|
||||
weight_files: List[str] = field(default_factory=lambda: [])
|
||||
@classmethod
|
||||
def from_dict(cls, params):
|
||||
return cls(**{k: v for k, v in params.items() if k in inspect.signature(cls).parameters})
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(AutoencoderConfig):
|
||||
shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
|
||||
|
||||
def __post_init__(self):
|
||||
if isinstance(self.shard, dict):
|
||||
self.shard = Shard(**self.shard)
|
||||
|
||||
if not isinstance(self.shard, Shard):
|
||||
raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
|
||||
|
||||
if not self.shard.is_first_layer():
|
||||
self.vision_config = None
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
"""A single head unmasked attention for use with the VAE."""
|
||||
|
||||
def __init__(self, dims: int, norm_groups: int = 32):
|
||||
super().__init__()
|
||||
|
||||
self.group_norm = nn.GroupNorm(norm_groups, dims, pytorch_compatible=True)
|
||||
self.query_proj = nn.Linear(dims, dims)
|
||||
self.key_proj = nn.Linear(dims, dims)
|
||||
self.value_proj = nn.Linear(dims, dims)
|
||||
self.out_proj = nn.Linear(dims, dims)
|
||||
|
||||
def __call__(self, x):
|
||||
B, H, W, C = x.shape
|
||||
|
||||
y = self.group_norm(x)
|
||||
|
||||
queries = self.query_proj(y).reshape(B, H * W, C)
|
||||
keys = self.key_proj(y).reshape(B, H * W, C)
|
||||
values = self.value_proj(y).reshape(B, H * W, C)
|
||||
|
||||
scale = 1 / math.sqrt(queries.shape[-1])
|
||||
scores = (queries * scale) @ keys.transpose(0, 2, 1)
|
||||
attn = mx.softmax(scores, axis=-1)
|
||||
y = (attn @ values).reshape(B, H, W, C)
|
||||
|
||||
y = self.out_proj(y)
|
||||
x = x + y
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class EncoderDecoderBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
num_layers: int = 1,
|
||||
resnet_groups: int = 32,
|
||||
add_downsample=True,
|
||||
add_upsample=True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Add the resnet blocks
|
||||
self.resnets = [
|
||||
ResnetBlock2D(
|
||||
in_channels=in_channels if i == 0 else out_channels,
|
||||
out_channels=out_channels,
|
||||
groups=resnet_groups,
|
||||
)
|
||||
for i in range(num_layers)
|
||||
]
|
||||
|
||||
# Add an optional downsampling layer
|
||||
if add_downsample:
|
||||
self.downsample = nn.Conv2d(
|
||||
out_channels, out_channels, kernel_size=3, stride=2, padding=0
|
||||
)
|
||||
|
||||
# or upsampling layer
|
||||
if add_upsample:
|
||||
self.upsample = nn.Conv2d(
|
||||
out_channels, out_channels, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
for resnet in self.resnets:
|
||||
x = resnet(x)
|
||||
if "downsample" in self:
|
||||
x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)])
|
||||
x = self.downsample(x)
|
||||
|
||||
if "upsample" in self:
|
||||
x = self.upsample(upsample_nearest(x))
|
||||
return x
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
"""Implements the encoder side of the Autoencoder."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
latent_channels_out: int,
|
||||
block_out_channels: List[int] = [64],
|
||||
layers_per_block: int = 2,
|
||||
resnet_groups: int = 32,
|
||||
layers_range: List[int] = [],
|
||||
shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
|
||||
):
|
||||
super().__init__()
|
||||
self.layers_range = layers_range
|
||||
self.shard = shard
|
||||
if self.shard.is_first_layer():
|
||||
self.conv_in = nn.Conv2d(
|
||||
in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
|
||||
channels = [block_out_channels[0]] + list(block_out_channels)
|
||||
self.down_blocks = []
|
||||
current_layer = 1
|
||||
for i, (in_channels, out_channels) in enumerate(zip(channels, channels[1:])):
|
||||
if current_layer in self.layers_range:
|
||||
self.down_blocks.append(
|
||||
EncoderDecoderBlock2D(
|
||||
in_channels,
|
||||
out_channels,
|
||||
num_layers=layers_per_block,
|
||||
resnet_groups=resnet_groups,
|
||||
add_downsample=i < len(block_out_channels) - 1,
|
||||
add_upsample=False,
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.down_blocks.append(IdentityBlock())
|
||||
current_layer += 1
|
||||
|
||||
if self.shard.is_last_layer():
|
||||
self.mid_blocks = [
|
||||
ResnetBlock2D(
|
||||
in_channels=block_out_channels[-1],
|
||||
out_channels=block_out_channels[-1],
|
||||
groups=resnet_groups,
|
||||
),
|
||||
Attention(block_out_channels[-1], resnet_groups),
|
||||
ResnetBlock2D(
|
||||
in_channels=block_out_channels[-1],
|
||||
out_channels=block_out_channels[-1],
|
||||
groups=resnet_groups,
|
||||
),
|
||||
]
|
||||
|
||||
self.conv_norm_out = nn.GroupNorm(
|
||||
resnet_groups, block_out_channels[-1], pytorch_compatible=True
|
||||
)
|
||||
self.conv_out = nn.Conv2d(block_out_channels[-1], latent_channels_out, 3, padding=1)
|
||||
|
||||
def __call__(self, x):
|
||||
if self.shard.is_first_layer():
|
||||
x = self.conv_in(x)
|
||||
|
||||
for l in self.down_blocks:
|
||||
x = l(x)
|
||||
|
||||
if self.shard.is_last_layer():
|
||||
x = self.mid_blocks[0](x)
|
||||
x = self.mid_blocks[1](x)
|
||||
x = self.mid_blocks[2](x)
|
||||
|
||||
x = self.conv_norm_out(x)
|
||||
x = nn.silu(x)
|
||||
x = self.conv_out(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
"""Implements the decoder side of the Autoencoder."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
shard: Shard,
|
||||
layer_range: List[int],
|
||||
block_out_channels: List[int] = [64],
|
||||
layers_per_block: int = 2,
|
||||
resnet_groups: int = 32,
|
||||
):
|
||||
super().__init__()
|
||||
self.out_channels = out_channels
|
||||
self.layers_range = layer_range
|
||||
if 0 in layer_range:
|
||||
self.conv_in = nn.Conv2d(
|
||||
in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
|
||||
if 0 in layer_range:
|
||||
self.mid_blocks = [
|
||||
ResnetBlock2D(
|
||||
in_channels=block_out_channels[-1],
|
||||
out_channels=block_out_channels[-1],
|
||||
groups=resnet_groups,
|
||||
),
|
||||
Attention(block_out_channels[-1], resnet_groups),
|
||||
ResnetBlock2D(
|
||||
in_channels=block_out_channels[-1],
|
||||
out_channels=block_out_channels[-1],
|
||||
groups=resnet_groups,
|
||||
),
|
||||
]
|
||||
|
||||
channels = list(reversed(block_out_channels))
|
||||
channels = [channels[0]] + channels
|
||||
|
||||
self.up_blocks = []
|
||||
current_layer = 1
|
||||
|
||||
for i, (in_channels, out_channels) in enumerate(zip(channels, channels[1:])):
|
||||
if current_layer in layer_range:
|
||||
self.up_blocks.append(
|
||||
EncoderDecoderBlock2D(
|
||||
in_channels,
|
||||
out_channels,
|
||||
num_layers=layers_per_block,
|
||||
resnet_groups=resnet_groups,
|
||||
add_downsample=False,
|
||||
add_upsample=i < len(block_out_channels) - 1,
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.up_blocks.append(IdentityBlock())
|
||||
current_layer += 1
|
||||
if 4 in layer_range:
|
||||
self.conv_norm_out = nn.GroupNorm(
|
||||
resnet_groups, block_out_channels[0], pytorch_compatible=True
|
||||
)
|
||||
self.conv_out = nn.Conv2d(block_out_channels[0], self.out_channels, 3, padding=1)
|
||||
|
||||
|
||||
def __call__(self, x):
|
||||
if 0 in self.layers_range:
|
||||
x = self.conv_in(x)
|
||||
x = self.mid_blocks[0](x)
|
||||
x = self.mid_blocks[1](x)
|
||||
x = self.mid_blocks[2](x)
|
||||
|
||||
for l in self.up_blocks:
|
||||
x = l(x)
|
||||
if 4 in self.layers_range:
|
||||
x = self.conv_norm_out(x)
|
||||
x = nn.silu(x)
|
||||
x = self.conv_out(x)
|
||||
return x
|
||||
|
||||
|
||||
class Autoencoder(nn.Module):
|
||||
"""The autoencoder that allows us to perform diffusion in the latent space."""
|
||||
|
||||
def __init__(self, config: AutoencoderConfig, shard: Shard, model_shard: str):
|
||||
super().__init__()
|
||||
self.shard = shard
|
||||
self.start_layer = shard.start_layer
|
||||
self.end_layer = shard.end_layer
|
||||
self.layers_range = list(range(self.start_layer, self.end_layer+1))
|
||||
self.latent_channels = config.latent_channels_in
|
||||
self.scaling_factor = config.scaling_factor
|
||||
self.model_shard = model_shard
|
||||
if self.model_shard == "vae_encoder":
|
||||
self.encoder = Encoder(
|
||||
config.in_channels,
|
||||
config.latent_channels_out,
|
||||
config.block_out_channels,
|
||||
config.layers_per_block,
|
||||
resnet_groups=config.norm_num_groups,
|
||||
layers_range=self.layers_range,
|
||||
shard=shard
|
||||
)
|
||||
if self.shard.is_last_layer():
|
||||
self.quant_proj = nn.Linear(
|
||||
config.latent_channels_out, config.latent_channels_out
|
||||
)
|
||||
if self.model_shard == "vae_decoder":
|
||||
self.decoder = Decoder(
|
||||
config.latent_channels_in,
|
||||
config.out_channels,
|
||||
shard,
|
||||
self.layers_range,
|
||||
config.block_out_channels,
|
||||
config.layers_per_block + 1,
|
||||
resnet_groups=config.norm_num_groups,
|
||||
)
|
||||
if self.shard.is_first_layer():
|
||||
self.post_quant_proj = nn.Linear(
|
||||
config.latent_channels_in, config.latent_channels_in
|
||||
)
|
||||
|
||||
def decode(self, z):
|
||||
if self.shard.is_first_layer():
|
||||
z = z / self.scaling_factor
|
||||
z=self.post_quant_proj(z)
|
||||
return self.decoder(z)
|
||||
|
||||
def encode(self, x):
|
||||
x = self.encoder(x)
|
||||
if self.shard.is_last_layer():
|
||||
x = self.quant_proj(x)
|
||||
mean, logvar = x.split(2, axis=-1)
|
||||
mean = mean * self.scaling_factor
|
||||
logvar = logvar + 2 * math.log(self.scaling_factor)
|
||||
x = mean
|
||||
return x
|
||||
|
||||
def __call__(self, x, key=None):
|
||||
mean, logvar = self.encode(x)
|
||||
z = mx.random.normal(mean.shape, key=key) * mx.exp(0.5 * logvar) + mean
|
||||
x_hat = self.decode(z)
|
||||
|
||||
return dict(x_hat=x_hat, z=z, mean=mean, logvar=logvar)
|
||||
|
||||
def sanitize(self, weights):
|
||||
shard = self.shard
|
||||
layers = self.layers_range
|
||||
sanitized_weights = {}
|
||||
for key, value in weights.items():
|
||||
|
||||
if "downsamplers" in key:
|
||||
key = key.replace("downsamplers.0.conv", "downsample")
|
||||
if "upsamplers" in key:
|
||||
key = key.replace("upsamplers.0.conv", "upsample")
|
||||
|
||||
# Map attention layers
|
||||
if "key" in key:
|
||||
key = key.replace("key", "key_proj")
|
||||
if "proj_attn" in key:
|
||||
key = key.replace("proj_attn", "out_proj")
|
||||
if "query" in key:
|
||||
key = key.replace("query", "query_proj")
|
||||
if "value" in key:
|
||||
key = key.replace("value", "value_proj")
|
||||
|
||||
# Map the mid block
|
||||
if "mid_block.resnets.0" in key:
|
||||
key = key.replace("mid_block.resnets.0", "mid_blocks.0")
|
||||
if "mid_block.attentions.0" in key:
|
||||
key = key.replace("mid_block.attentions.0", "mid_blocks.1")
|
||||
if "mid_block.resnets.1" in key:
|
||||
key = key.replace("mid_block.resnets.1", "mid_blocks.2")
|
||||
|
||||
# Map the quant/post_quant layers
|
||||
if "quant_conv" in key:
|
||||
key = key.replace("quant_conv", "quant_proj")
|
||||
value = value.squeeze()
|
||||
|
||||
# Map the conv_shortcut to linear
|
||||
if "conv_shortcut.weight" in key:
|
||||
value = value.squeeze()
|
||||
|
||||
if len(value.shape) == 4:
|
||||
value = value.transpose(0, 2, 3, 1)
|
||||
value = value.reshape(-1).reshape(value.shape)
|
||||
|
||||
|
||||
if "post_quant_conv" in key :
|
||||
key = key.replace("quant_conv", "quant_proj")
|
||||
value = value.squeeze()
|
||||
|
||||
if 'decoder' in key and self.model_shard == "vae_decoder":
|
||||
if key.startswith("decoder.mid_blocks."):
|
||||
if 0 in layers:
|
||||
sanitized_weights[key] = value
|
||||
if "conv_in" in key and 0 in layers:
|
||||
sanitized_weights[key] = value
|
||||
if key.startswith("decoder.up_blocks."):
|
||||
layer_num = int(key.split(".")[2])+1
|
||||
if layer_num in layers:
|
||||
sanitized_weights[key] = value
|
||||
if key.startswith("decoder.conv_norm_out") and 4 in layers:
|
||||
sanitized_weights[key] = value
|
||||
if key.startswith("decoder.conv_out") and 4 in layers:
|
||||
sanitized_weights[key] = value
|
||||
if self.model_shard == "vae_decoder":
|
||||
if key.startswith("post_quant_proj") and 0 in layers:
|
||||
sanitized_weights[key] = value
|
||||
if self.model_shard == "vae_encoder":
|
||||
if key.startswith("encoder."):
|
||||
if "conv_in" in key and shard.is_first_layer():
|
||||
sanitized_weights[key] = value
|
||||
if key.startswith("encoder.down_blocks."):
|
||||
layer_num = int(key.split(".")[2])+1
|
||||
if layer_num in layers:
|
||||
sanitized_weights[key] = value
|
||||
if key.startswith("encoder.mid_blocks.") and shard.is_last_layer():
|
||||
sanitized_weights[key] = value
|
||||
if "conv_norm_out" in key and shard.is_last_layer():
|
||||
sanitized_weights[key] = value
|
||||
if "conv_out" in key and shard.is_last_layer():
|
||||
sanitized_weights[key] = value
|
||||
if key.startswith("quant_proj") and shard.is_last_layer():
|
||||
sanitized_weights[key] = value
|
||||
return sanitized_weights
|
||||
|
||||
7  exo/inference/mlx/perf_improvements.md (new file)
@@ -0,0 +1,7 @@
# Perf improvements

Target: 460 tok/sec
- removing sample goes from 369 -> 402
- performance degrades as we generate more tokens
- make mlx inference engine synchronous, removing thread pool executor: 402 -> 413
- remove self.on_opaque_status.trigger_all: 413 -> 418
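The notes above boil down to keeping every MLX call on a single dedicated thread so the asyncio event loop never blocks on kernel evaluation. A minimal sketch of that pattern, assuming MLX is installed; only the `_mlx_thread`/`_eval_mlx` names are taken from the engine diff below, the rest is illustrative:

import asyncio
from concurrent.futures import ThreadPoolExecutor

import mlx.core as mx


class SingleThreadMLXRunner:
  """Illustrative only: funnel all MLX work through one worker thread."""

  def __init__(self):
    self._mlx_thread = ThreadPoolExecutor(max_workers=1, thread_name_prefix="mlx")

  async def _eval_mlx(self, *arrays):
    # mx.eval materializes lazy MLX graphs; running it off the event loop
    # keeps request handling responsive while kernels execute.
    await asyncio.get_running_loop().run_in_executor(self._mlx_thread, mx.eval, *arrays)

  async def run(self, fn, *args):
    # Every model call lands on the same single worker, so MLX ops never interleave.
    return await asyncio.get_running_loop().run_in_executor(self._mlx_thread, lambda: fn(*args))

The engine below follows the same shape, with inference, sampling and weight loading all dispatched to `self._mlx_thread`.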
@@ -1,151 +1,179 @@
|
||||
import numpy as np
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx_lm.sample_utils import top_p_sampling
|
||||
from mlx_lm.sample_utils import make_sampler
|
||||
import mlx.optimizers as optim
|
||||
from ..inference_engine import InferenceEngine
|
||||
from .sharded_utils import load_shard, get_image_from_str
|
||||
from .losses import loss_fns
|
||||
from .sharded_utils import load_model_shard, resolve_tokenizer
|
||||
from .losses import loss_fns
|
||||
from ..shard import Shard
|
||||
from typing import Dict, Optional, Tuple
|
||||
from exo.download.shard_download import ShardDownloader
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from functools import partial
|
||||
from collections import OrderedDict
|
||||
from mlx_lm.models.cache import make_prompt_cache
|
||||
|
||||
def sample_logits(
|
||||
logits: mx.array,
|
||||
temp: float = 0.0,
|
||||
top_p: float = 1.0,
|
||||
logit_bias: Optional[Dict[int, float]] = None
|
||||
) -> Tuple[mx.array, float]:
|
||||
if logit_bias:
|
||||
indices = mx.array(list(logit_bias.keys()))
|
||||
values = mx.array(list(logit_bias.values()))
|
||||
logits[:, indices] += values
|
||||
|
||||
if temp == 0:
|
||||
token = mx.argmax(logits, axis=-1)
|
||||
else:
|
||||
if top_p > 0 and top_p < 1.0:
|
||||
token = top_p_sampling(logits, top_p, temp)
|
||||
else:
|
||||
token = mx.random.categorical(logits*(1/temp))
|
||||
|
||||
return token
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
class MLXDynamicShardInferenceEngine(InferenceEngine):
|
||||
def __init__(self, shard_downloader: ShardDownloader):
|
||||
self.shard = None
|
||||
self.shard_downloader = shard_downloader
|
||||
self.executor = ThreadPoolExecutor(max_workers=1)
|
||||
self.caches = OrderedDict()
|
||||
self.sampler_params: tuple[float, float] = (0.0, 0.0, 0.0, 1)
|
||||
self.sampler = make_sampler(*self.sampler_params)
|
||||
self._mlx_thread = ThreadPoolExecutor(max_workers=1, thread_name_prefix="mlx")
|
||||
self._tokenizer_thread = ThreadPoolExecutor(max_workers=1, thread_name_prefix="tokenizer")
|
||||
self.session = {}
|
||||
self._shard_lock = asyncio.Lock()
|
||||
|
||||
async def _eval_mlx(self, *args):
|
||||
await asyncio.get_running_loop().run_in_executor(self._mlx_thread, mx.eval, *args)
|
||||
|
||||
async def poll_state(self, request_id: str, max_caches=2):
|
||||
if request_id in self.caches:
|
||||
self.caches.move_to_end(request_id)
|
||||
else:
|
||||
newcache = await asyncio.get_running_loop().run_in_executor(self.executor, make_prompt_cache, self.model)
|
||||
newcache = make_prompt_cache(self.model)
|
||||
if len(self.caches) > max_caches:
|
||||
self.caches.popitem(last=False)
|
||||
self.caches[request_id] = newcache
|
||||
return {"cache": self.caches[request_id]}
|
||||
|
||||
async def sample(self, x, temp: float = 0.0, top_p: float = 1.0) -> np.ndarray:
|
||||
y = mx.array(x)
|
||||
logits = y[:, -1, :]
|
||||
out = np.array(sample_logits(logits, temp=temp, top_p=top_p), dtype=int)
|
||||
return out
|
||||
async def sample(self, x: np.ndarray, temp: float = 0.0, top_p: float = 1.0) -> np.ndarray:
|
||||
if (temp, top_p, 0.0, 1) != self.sampler_params:
|
||||
self.sampler_params = (temp, top_p, 0.0, 1)
|
||||
self.sampler = make_sampler(*self.sampler_params)
|
||||
logits = mx.array(x)
|
||||
logits = logits[:, -1, :]
|
||||
logprobs = logits - mx.logsumexp(logits, keepdims=True)
|
||||
result = self.sampler(logprobs)
|
||||
await self._eval_mlx(result)
|
||||
return np.asarray(result, dtype=int)
|
||||
|
||||
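The reworked `sample` above rebuilds the sampler only when `(temp, top_p)` actually change, and it hands the sampler log-probabilities of the last position rather than raw logits. A rough numpy illustration of that normalization step (illustrative only; the engine itself uses `mx.logsumexp` and `mlx_lm`'s `make_sampler`):

import numpy as np

def to_logprobs(logits: np.ndarray) -> np.ndarray:
  # Log-softmax over the vocabulary: logits - logsumexp(logits).
  m = logits.max(axis=-1, keepdims=True)
  lse = m + np.log(np.exp(logits - m).sum(axis=-1, keepdims=True))
  return logits - lse

last_position = np.array([[2.0, 0.5, -1.0]])   # made-up logits for one token
print(to_logprobs(last_position))              # argmax is unchanged, so greedy decoding behaves the same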
async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
|
||||
await self.ensure_shard(shard)
|
||||
tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt)
|
||||
return np.array(tokens)
|
||||
return np.asarray(
|
||||
await asyncio.get_running_loop().run_in_executor(
|
||||
self._tokenizer_thread,
|
||||
self.tokenizer.encode,
|
||||
prompt
|
||||
)
|
||||
)
|
||||
|
||||
async def decode(self, shard: Shard, tokens) -> str:
|
||||
await self.ensure_shard(shard)
|
||||
tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)
|
||||
return tokens
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
self._tokenizer_thread,
|
||||
self.tokenizer.decode,
|
||||
tokens
|
||||
)
|
||||
|
||||
async def save_checkpoint(self, shard: Shard, path: str):
|
||||
await self.ensure_shard(shard)
|
||||
await asyncio.get_running_loop().run_in_executor(self.executor, self.model.save_weights, path)
|
||||
await asyncio.get_running_loop().run_in_executor(self._mlx_thread, lambda: self.model.save_weights(path))
|
||||
|
||||
async def load_checkpoint(self, shard: Shard, path: str):
|
||||
await self.ensure_shard(shard)
|
||||
await asyncio.get_running_loop().run_in_executor(self.executor, self.model.load_weights, path)
|
||||
|
||||
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
|
||||
await asyncio.get_running_loop().run_in_executor(self._mlx_thread, lambda: self.model.load_weights(path))
|
||||
|
||||
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
|
||||
await self.ensure_shard(shard)
|
||||
loop = asyncio.get_running_loop()
|
||||
state = await self.poll_state(request_id)
|
||||
state = await self.poll_state(request_id) if self.model.model_type != 'StableDiffusionPipeline' else {}
|
||||
x = mx.array(input_data)
|
||||
output_data: np.ndarray = np.array(await loop.run_in_executor(self.executor, lambda: self.model(x, **state)))
|
||||
return output_data
|
||||
|
||||
if self.model.model_type != 'StableDiffusionPipeline':
|
||||
output_data = await asyncio.get_running_loop().run_in_executor(
|
||||
self._mlx_thread,
|
||||
lambda: self.model(x, **state, **(inference_state or {}))
|
||||
)
|
||||
inference_state = None
|
||||
else:
|
||||
result = await asyncio.get_running_loop().run_in_executor(
|
||||
self._mlx_thread,
|
||||
lambda: self.model(x, **state, **(inference_state or {}))
|
||||
)
|
||||
output_data, inference_state = result
|
||||
|
||||
await self._eval_mlx(output_data)
|
||||
output_data = await asyncio.get_running_loop().run_in_executor(
|
||||
self._mlx_thread,
|
||||
lambda: np.array(output_data, copy=False)
|
||||
)
|
||||
return output_data, inference_state
|
||||
|
||||
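`infer_tensor` now returns a `(np.ndarray, Optional[dict])` pair so that pipeline models such as the Stable Diffusion path can thread opaque state between calls, while plain LLM shards get `None` back. A usage sketch under the assumption that an `engine` and `shard` already exist:

import numpy as np

async def two_step_inference(engine, shard):
  x = np.zeros((1, 1), dtype=np.int64)                       # placeholder input tokens
  out, state = await engine.infer_tensor("req-1", shard, x)
  # Pipelines pass the returned state straight back in; LLM shards simply see state=None.
  out, state = await engine.infer_tensor("req-1", shard, x, inference_state=state)
  return out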
async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce"):
|
||||
await self.ensure_shard(shard)
|
||||
await self.save_session('loss', loss_fns[loss])
|
||||
loop = asyncio.get_running_loop()
|
||||
#print(f"evaluate in <- {inputs}")
|
||||
x = mx.array(inputs)
|
||||
y = mx.array(targets)
|
||||
l = mx.array(lengths)
|
||||
score = await loop.run_in_executor(self.executor, self.session['loss'], self.model, x, y, l)
|
||||
#print(f"evaluate out -> {score}")
|
||||
|
||||
score = await asyncio.get_running_loop().run_in_executor(
|
||||
self._mlx_thread,
|
||||
lambda: self.session['loss'](self.model, x, y, l)
|
||||
)
|
||||
return score
|
||||
|
||||
async def ensure_train(self, shard: Shard, loss: str, opt=optim.SGD, lr=1e-5, trainable_layers=['input_layernorm', 'gate_proj']):
|
||||
await self.ensure_shard(shard)
|
||||
|
||||
if 'train_layers' not in self.session or self.session['train_layers'] != trainable_layers:
|
||||
await self.save_session('train_layers', trainable_layers)
|
||||
self.model.freeze()
|
||||
self.model.apply_to_modules(lambda k, v: v.unfreeze() if any(lambda: k.endswith(i) for i in trainable_layers) else None)
|
||||
def freeze_unfreeze():
|
||||
self.model.freeze()
|
||||
self.model.apply_to_modules(
|
||||
lambda k, v: v.unfreeze() if any(k.endswith(layer_name) for layer_name in trainable_layers) else None
|
||||
)
|
||||
await asyncio.get_running_loop().run_in_executor(self._mlx_thread, freeze_unfreeze)
|
||||
|
||||
if 'lossname' not in self.session or 'LVaG' not in self.session or self.session['lossname'] != loss:
|
||||
await self.save_session('lossname', loss)
|
||||
await self.save_session('LVaG', nn.value_and_grad(self.model, loss_fns[loss]))
|
||||
|
||||
if 'opt' not in self.session:
|
||||
await self.save_session('opt', opt(lr))
|
||||
return True
|
||||
|
||||
async def train(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce", opt=optim.SGD, lr=1e-5):
|
||||
loop = asyncio.get_running_loop()
|
||||
nothin = await self.ensure_train(shard, loss, opt, lr)
|
||||
await self.ensure_train(shard, loss, opt, lr)
|
||||
|
||||
def train_step(inp, tar, lng):
|
||||
lval, grad = self.session['LVaG'](self.model, inp, tar, lng)
|
||||
gradlayers = grad['model']['layers']
|
||||
self.session['opt'].update(self.model, grad)
|
||||
mx.eval(self.model.parameters(), self.session['opt'].state, lval)
|
||||
return lval, gradlayers
|
||||
return lval, gradlayers, (self.model.parameters(), self.session['opt'].state, lval)
|
||||
|
||||
x = mx.array(inputs)
|
||||
y = mx.array(targets)
|
||||
l = mx.array(lengths)
|
||||
score, gradients, eval_args = await asyncio.get_running_loop().run_in_executor(
|
||||
self._mlx_thread,
|
||||
lambda: train_step(x, y, l)
|
||||
)
|
||||
await self._eval_mlx(*eval_args)
|
||||
|
||||
score, gradients = await loop.run_in_executor(self.executor, train_step, x, y, l)
|
||||
#print(f"{score=}")
|
||||
|
||||
layers = [{k: v["weight"] for k,v in l.items() if 'weight' in v} for l in gradients if l]
|
||||
#print(layers[0])
|
||||
|
||||
return score, np.array(layers[0]['input_layernorm'])
|
||||
layers = [{k: v["weight"] for k, v in layer.items() if 'weight' in v} for layer in gradients if layer]
|
||||
first_layer = np.array(layers[0]['input_layernorm'], copy=False)
|
||||
await self._eval_mlx(first_layer)
|
||||
return score, first_layer
|
||||
|
||||
async def ensure_shard(self, shard: Shard):
|
||||
if self.shard == shard:
|
||||
return
|
||||
|
||||
model_path = await self.shard_downloader.ensure_shard(shard, self.__class__.__name__)
|
||||
|
||||
if self.shard != shard:
|
||||
|
||||
def load_shard_wrapper():
|
||||
return asyncio.run(load_shard(model_path, shard))
|
||||
|
||||
model_shard, self.tokenizer = await asyncio.get_running_loop().run_in_executor(self.executor, load_shard_wrapper)
|
||||
self.shard = shard
|
||||
self.model = model_shard
|
||||
self.caches = OrderedDict()
|
||||
self.session = {}
|
||||
async with self._shard_lock:
|
||||
if self.shard == shard: return
|
||||
model_path = await self.shard_downloader.ensure_shard(shard, self.__class__.__name__)
|
||||
if self.shard != shard:
|
||||
model_shard = await asyncio.get_running_loop().run_in_executor(
|
||||
self._mlx_thread,
|
||||
lambda: load_model_shard(model_path, shard, lazy=False)
|
||||
)
|
||||
if hasattr(model_shard, "tokenizer"):
|
||||
self.tokenizer = model_shard.tokenizer
|
||||
else:
|
||||
self.tokenizer = await resolve_tokenizer(model_path)
|
||||
self.shard = shard
|
||||
self.model = model_shard
|
||||
self.caches = OrderedDict()
|
||||
self.session = {}
|
||||
|
||||
async def cleanup(self):
|
||||
self._mlx_thread.shutdown(wait=True)
|
||||
|
||||
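`ensure_shard` above now takes an `asyncio.Lock` and re-checks the shard after acquiring it, so concurrent requests for the same shard trigger only one load. The pattern in isolation, with a hypothetical `load_fn` standing in for the MLX loader:

import asyncio

class ShardHolder:
  def __init__(self):
    self.shard = None
    self.model = None
    self._shard_lock = asyncio.Lock()

  async def ensure(self, shard, load_fn):
    if self.shard == shard:              # fast path, no lock taken
      return
    async with self._shard_lock:         # serialize concurrent loaders
      if self.shard == shard:            # re-check: another task may have loaded it already
        return
      self.model = await load_fn(shard)  # load_fn is a hypothetical async loader
      self.shard = shard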
@@ -62,8 +62,16 @@ def _get_classes(config: dict):
|
||||
|
||||
def load_config(model_path: Path) -> dict:
|
||||
try:
|
||||
with open(model_path/"config.json", "r") as f:
|
||||
config = json.load(f)
|
||||
config_path = model_path / "config.json"
|
||||
if config_path.exists():
|
||||
with open(config_path, "r") as f:
|
||||
config = json.load(f)
|
||||
return config
|
||||
|
||||
model_index_path = model_path / "model_index.json"
|
||||
if model_index_path.exists():
|
||||
config = load_model_index(model_path, model_index_path)
|
||||
return config
|
||||
except FileNotFoundError:
|
||||
logging.error(f"Config file not found in {model_path}")
|
||||
raise
|
||||
@@ -110,6 +118,24 @@ def load_model_shard(
|
||||
# Try weight for back-compat
|
||||
weight_files = glob.glob(str(model_path/"weight*.safetensors"))
|
||||
|
||||
model_class, model_args_class = _get_classes(config=config)
|
||||
|
||||
class ShardedModel(model_class):
|
||||
def __init__(self, args):
|
||||
super().__init__(args)
|
||||
self.shard = Shard(args.shard.model_id, args.shard.start_layer, args.shard.end_layer, args.shard.n_layers)
|
||||
|
||||
def __call__(self, x, *args, **kwargs):
|
||||
y = super().__call__(x, *args, **kwargs)
|
||||
return y
|
||||
|
||||
model_args = model_args_class.from_dict(config)
|
||||
model = ShardedModel(model_args)
|
||||
|
||||
if config.get("model_index", False):
|
||||
model.load()
|
||||
return model
|
||||
|
||||
if not weight_files:
|
||||
logging.error(f"No safetensors found in {model_path}")
|
||||
raise FileNotFoundError(f"No safetensors found in {model_path}")
|
||||
@@ -129,19 +155,7 @@ def load_model_shard(
|
||||
|
||||
weights.update(mx.load(wf))
|
||||
|
||||
model_class, model_args_class = _get_classes(config=config)
|
||||
|
||||
class ShardedModel(model_class):
|
||||
def __init__(self, args):
|
||||
super().__init__(args)
|
||||
self.shard = Shard(args.shard.model_id, args.shard.start_layer, args.shard.end_layer, args.shard.n_layers)
|
||||
|
||||
def __call__(self, x, *args, **kwargs):
|
||||
y = super().__call__(x, *args, **kwargs)
|
||||
return y
|
||||
|
||||
model_args = model_args_class.from_dict(config)
|
||||
model = ShardedModel(model_args)
|
||||
|
||||
|
||||
if hasattr(model, "sanitize"):
|
||||
weights = model.sanitize(weights)
|
||||
@@ -186,6 +200,9 @@ async def load_shard(
|
||||
processor.eos_token_id = processor.tokenizer.eos_token_id
|
||||
processor.encode = processor.tokenizer.encode
|
||||
return model, processor
|
||||
elif hasattr(model, "tokenizer"):
|
||||
tokenizer = model.tokenizer
|
||||
return model, tokenizer
|
||||
else:
|
||||
tokenizer = await resolve_tokenizer(model_path)
|
||||
return model, tokenizer
|
||||
@@ -214,3 +231,27 @@ async def get_image_from_str(_image_str: str):
|
||||
return img
|
||||
else:
|
||||
raise ValueError("Invalid image_str format. Must be a URL or a base64 encoded image.")
|
||||
|
||||
# loading a combined config for all models in the index
|
||||
def load_model_index(model_path: Path, model_index_path: Path):
|
||||
models_config = {}
|
||||
with open(model_index_path, "r") as f:
|
||||
model_index = json.load(f)
|
||||
models_config["model_index"] = True
|
||||
models_config["model_type"] = model_index["_class_name"]
|
||||
models_config["models"] = {}
|
||||
for model in model_index.keys():
|
||||
model_config_path = glob.glob(str(model_path / model / "*config.json"))
|
||||
if len(model_config_path)>0:
|
||||
with open(model_config_path[0], "r") as f:
|
||||
model_config = { }
|
||||
model_config["model_type"] = model
|
||||
model_config["config"] = json.load(f)
|
||||
model_config["path"] = model_path / model
|
||||
if model_config["path"]/"*model.safetensors":
|
||||
model_config["config"].update({"weight_files": list(glob.glob(str(model_config["path"]/"*model.safetensors")))})
|
||||
model_config["path"] = str(model_path / model)
|
||||
m = {}
|
||||
m[model] = model_config
|
||||
models_config.update(m)
|
||||
return models_config
|
||||
|
||||
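`load_model_index` above flattens a diffusers-style `model_index.json` into a single config dict, attaching each sub-model's own config, path and safetensors files. A hypothetical index and the rough shape of the result (paths and model names are invented):

# model_index.json might look like:
#   { "_class_name": "StableDiffusionPipeline", "vae": ["diffusers", "AutoencoderKL"], "unet": [...], ... }
# load_model_index would then return roughly:
combined_config = {
  "model_index": True,
  "model_type": "StableDiffusionPipeline",
  "models": {},
  "vae": {
    "model_type": "vae",
    "config": {"...": "...", "weight_files": ["<model_path>/vae/model.safetensors"]},
    "path": "<model_path>/vae",
  },
  # ...one entry like "vae" for every sub-model directory that has a *config.json
}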
81  exo/inference/mlx/test_non_blocking.py (new file)
@@ -0,0 +1,81 @@
|
||||
import asyncio
|
||||
import time
|
||||
import numpy as np
|
||||
from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
|
||||
from exo.download.new_shard_download import NewShardDownloader
|
||||
from exo.inference.shard import Shard
|
||||
from exo.models import build_base_shard
|
||||
from collections import deque
|
||||
from statistics import mean, median
|
||||
|
||||
async def test_non_blocking():
|
||||
# Setup
|
||||
shard_downloader = NewShardDownloader()
|
||||
engine = MLXDynamicShardInferenceEngine(shard_downloader)
|
||||
_shard = build_base_shard("llama-3.1-8b", "MLXDynamicShardInferenceEngine")
|
||||
shard = Shard(_shard.model_id, _shard.start_layer, _shard.n_layers - 1, _shard.n_layers)
|
||||
await engine.ensure_shard(shard)
|
||||
|
||||
queue = asyncio.Queue()
|
||||
measurements = deque(maxlen=1000000)
|
||||
running = True
|
||||
|
||||
async def mlx_worker():
|
||||
try:
|
||||
start_time = time.time()
|
||||
count = 0
|
||||
while running and (time.time() - start_time) < 5: # Hard time limit
|
||||
start = time.perf_counter_ns()
|
||||
await engine.infer_prompt("req1", shard, "test prompt")
|
||||
duration = (time.perf_counter_ns() - start) / 1_000_000 # Convert to ms
|
||||
count += 1
|
||||
print(f"MLX operation {count} took: {duration:.3f}ms")
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
finally:
|
||||
print(f"\nTotal MLX operations completed: {count}")
|
||||
print(f"Average rate: {count/5:.1f} ops/second")
|
||||
|
||||
async def latency_producer():
|
||||
try:
|
||||
start_time = time.perf_counter_ns()
|
||||
count = 0
|
||||
while running:
|
||||
await queue.put(time.perf_counter_ns())
|
||||
count += 1
|
||||
await asyncio.sleep(0) # Yield to event loop without delay
|
||||
duration = (time.perf_counter_ns() - start_time) / 1e9 # Convert to seconds
|
||||
print(f"\nProducer iterations: {count}")
|
||||
print(f"Producer rate: {count/duration:.1f} iterations/second")
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
async def latency_consumer():
|
||||
try:
|
||||
while running:
|
||||
timestamp = await queue.get()
|
||||
latency = (time.perf_counter_ns() - timestamp) / 1_000_000 # Convert to ms
|
||||
measurements.append(latency)
|
||||
queue.task_done()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
tasks = [
|
||||
asyncio.create_task(mlx_worker()),
|
||||
asyncio.create_task(latency_producer()),
|
||||
asyncio.create_task(latency_consumer())
|
||||
]
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(asyncio.gather(*tasks), timeout=6)
|
||||
except asyncio.TimeoutError:
|
||||
print("\nTest timed out")
|
||||
finally:
|
||||
running = False
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
print(f"\nFinal measurement count: {len(measurements)}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(test_non_blocking())
|
||||
@@ -1,22 +1,16 @@
import pytest
import json
import numpy as np
from exo.inference.dummy_inference_engine import DummyInferenceEngine
from exo.inference.shard import Shard


class MockShardDownloader:
  async def ensure_shard(self, shard):
    pass


@pytest.mark.asyncio
async def test_dummy_inference_specific():
  engine = DummyInferenceEngine(MockShardDownloader())
  engine = DummyInferenceEngine()
  test_shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)
  test_prompt = "This is a test prompt"

  result = await engine.infer_prompt("test_request", test_shard, test_prompt)
  result, _ = await engine.infer_prompt("test_request", test_shard, test_prompt)

  print(f"Inference result shape: {result.shape}")

@@ -26,20 +20,20 @@ async def test_dummy_inference_specific():
@pytest.mark.asyncio
async def test_dummy_inference_engine():
  # Initialize the DummyInferenceEngine
  engine = DummyInferenceEngine(MockShardDownloader())
  engine = DummyInferenceEngine()

  # Create a test shard
  shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)

  # Test infer_prompt
  output = await engine.infer_prompt("test_id", shard, "Test prompt")
  output, _ = await engine.infer_prompt("test_id", shard, "Test prompt")

  assert isinstance(output, np.ndarray), "Output should be a numpy array"
  assert output.ndim == 2, "Output should be 2-dimensional"

  # Test infer_tensor
  input_tensor = np.array([[1, 2, 3]])
  output = await engine.infer_tensor("test_id", shard, input_tensor)
  output, _ = await engine.infer_tensor("test_id", shard, input_tensor)

  assert isinstance(output, np.ndarray), "Output should be a numpy array"
  assert output.ndim == 2, "Output should be 2-dimensional"

@@ -1,6 +1,6 @@
|
||||
from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
|
||||
from exo.download.hf.hf_shard_download import HFShardDownloader
|
||||
from exo.inference.inference_engine import InferenceEngine
|
||||
from exo.download.new_shard_download import NewShardDownloader
|
||||
from exo.inference.shard import Shard
|
||||
from exo.helpers import DEBUG
|
||||
import os
|
||||
@@ -11,30 +11,30 @@ import numpy as np
|
||||
# An inference engine should work the same for any number of Shards, as long as the Shards are continuous.
|
||||
async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str, n_layers: int):
|
||||
prompt = "In a single word only, what is the last name of the current president of the USA?"
|
||||
resp_full = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers), prompt=prompt)
|
||||
resp_full, _ = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers), prompt=prompt)
|
||||
token_full = await inference_engine_1.sample(resp_full)
|
||||
token_full = token_full.reshape(1, -1)
|
||||
next_resp_full = await inference_engine_1.infer_tensor(
|
||||
next_resp_full, _ = await inference_engine_1.infer_tensor(
|
||||
"A",
|
||||
shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers),
|
||||
input_data=token_full,
|
||||
)
|
||||
|
||||
pp = n_layers // 2
|
||||
resp1 = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers), prompt=prompt)
|
||||
resp2 = await inference_engine_2.infer_tensor(
|
||||
resp1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers), prompt=prompt)
|
||||
resp2, _ = await inference_engine_2.infer_tensor(
|
||||
"B",
|
||||
shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=n_layers - 1, n_layers=n_layers),
|
||||
input_data=resp1,
|
||||
)
|
||||
tokens2 = await inference_engine_1.sample(resp2)
|
||||
tokens2 = tokens2.reshape(1, -1)
|
||||
resp3 = await inference_engine_1.infer_tensor(
|
||||
resp3, _ = await inference_engine_1.infer_tensor(
|
||||
"B",
|
||||
shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers),
|
||||
input_data=tokens2,
|
||||
)
|
||||
resp4 = await inference_engine_2.infer_tensor(
|
||||
resp4, _ = await inference_engine_2.infer_tensor(
|
||||
"B",
|
||||
shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=n_layers - 1, n_layers=n_layers),
|
||||
input_data=resp3,
|
||||
@@ -44,13 +44,11 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
|
||||
assert np.array_equal(next_resp_full, resp4)
|
||||
|
||||
|
||||
asyncio.run(test_inference_engine(MLXDynamicShardInferenceEngine(HFShardDownloader()), MLXDynamicShardInferenceEngine(HFShardDownloader()), "llama-3.2-1b", 16))
|
||||
asyncio.run(test_inference_engine(MLXDynamicShardInferenceEngine(NewShardDownloader()), MLXDynamicShardInferenceEngine(NewShardDownloader()), "llama-3.2-1b", 16))
|
||||
|
||||
if os.getenv("RUN_TINYGRAD", default="0") == "1":
|
||||
import tinygrad
|
||||
import os
|
||||
from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
|
||||
tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
|
||||
asyncio.run(
|
||||
test_inference_engine(TinygradDynamicShardInferenceEngine(HFShardDownloader()), TinygradDynamicShardInferenceEngine(HFShardDownloader()), "llama-3-8b", 32)
|
||||
)
|
||||
asyncio.run(test_inference_engine(TinygradDynamicShardInferenceEngine(NewShardDownloader()), TinygradDynamicShardInferenceEngine(NewShardDownloader()), "llama-3.2-1b", 32))
|
||||
|
||||
@@ -15,7 +15,7 @@ from .stateful_model import make_prompt_state
|
||||
from .losses import length_masked_ce_loss
|
||||
from collections import OrderedDict
|
||||
import asyncio
|
||||
|
||||
from typing import Optional
|
||||
Tensor.no_grad = True
|
||||
# default settings
|
||||
TEMPERATURE = int(os.getenv("TEMPERATURE", 0.85))
|
||||
@@ -61,12 +61,13 @@ def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=No
|
||||
|
||||
return model
|
||||
|
||||
_executor = ThreadPoolExecutor(max_workers=1) # singleton so tinygrad always runs on the same thread
|
||||
class TinygradDynamicShardInferenceEngine(InferenceEngine):
|
||||
def __init__(self, shard_downloader: ShardDownloader):
|
||||
self.shard = None
|
||||
self.shard_downloader = shard_downloader
|
||||
self.executor = ThreadPoolExecutor(max_workers=1)
|
||||
self.states = OrderedDict()
|
||||
self.executor = _executor
|
||||
|
||||
def poll_state(self, x, request_id: str, max_states=2):
|
||||
if request_id not in self.states:
|
||||
@@ -79,8 +80,8 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
|
||||
return {"start_pos": state.start, "cache": state.cache}
|
||||
|
||||
async def sample(self, x: np.ndarray, temp=TEMPERATURE, top_p: float = 0.0) -> np.ndarray:
|
||||
logits = x[:, -1, :]
|
||||
def sample_wrapper():
|
||||
logits = x[:, -1, :]
|
||||
return sample_logits(Tensor(logits).flatten(), temp, 0, 0.8, top_p, 0.0).realize().numpy().astype(int)
|
||||
return await asyncio.get_running_loop().run_in_executor(self.executor, sample_wrapper)
|
||||
|
||||
@@ -104,7 +105,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
|
||||
state_dict = await asyncio.get_running_loop().run_in_executor(self.executor, get_state_dict, self.model)
|
||||
safe_save(state_dict, path)
|
||||
|
||||
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
|
||||
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
|
||||
await self.ensure_shard(shard)
|
||||
def wrap_infer():
|
||||
x = Tensor(input_data)
|
||||
@@ -112,9 +113,9 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
|
||||
state = self.poll_state(h, request_id)
|
||||
out = self.model.forward(h, **state)
|
||||
self.states[request_id].start += x.shape[1]
|
||||
return out.realize()
|
||||
return out.numpy()
|
||||
output_data = await asyncio.get_running_loop().run_in_executor(self.executor, wrap_infer)
|
||||
return output_data.numpy()
|
||||
return output_data, inference_state
|
||||
|
||||
async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss=length_masked_ce_loss):
|
||||
def step(x, y, l):
|
||||
|
||||
@@ -322,6 +322,6 @@ def fix_bf16(weights: Dict[Any, Tensor]):
|
||||
}
|
||||
if getenv("SUPPORT_BF16", 1):
|
||||
# TODO: without casting to float16, 70B llama OOM on tinybox.
|
||||
return {k: v.cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}
|
||||
return {k: v.cast(dtypes.float32).cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}
|
||||
# TODO: check if device supports bf16
|
||||
return {k: v.llvm_bf16_cast(dtypes.half).to(v.device) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
import traceback
|
||||
from aiofiles import os as aios
|
||||
from os import PathLike
|
||||
from pathlib import Path
|
||||
from aiofiles import os as aios
|
||||
from typing import Union
|
||||
from transformers import AutoTokenizer, AutoProcessor
|
||||
import numpy as np
|
||||
from exo.download.hf.hf_helpers import get_local_snapshot_dir
|
||||
from exo.helpers import DEBUG
|
||||
from exo.download.new_shard_download import ensure_downloads_dir
|
||||
|
||||
|
||||
class DummyTokenizer:
|
||||
@@ -14,7 +13,7 @@ class DummyTokenizer:
|
||||
self.eos_token_id = 69
|
||||
self.vocab_size = 1000
|
||||
|
||||
def apply_chat_template(self, messages, tokenize=True, add_generation_prompt=True):
|
||||
def apply_chat_template(self, conversation, tokenize=True, add_generation_prompt=True, tools=None, **kwargs):
|
||||
return "dummy_tokenized_prompt"
|
||||
|
||||
def encode(self, text):
|
||||
@@ -24,25 +23,25 @@ class DummyTokenizer:
|
||||
return "dummy" * len(tokens)
|
||||
|
||||
|
||||
async def resolve_tokenizer(model_id: str):
|
||||
if model_id == "dummy":
|
||||
async def resolve_tokenizer(repo_id: Union[str, PathLike]):
|
||||
if repo_id == "dummy":
|
||||
return DummyTokenizer()
|
||||
local_path = await get_local_snapshot_dir(model_id)
|
||||
local_path = await ensure_downloads_dir()/str(repo_id).replace("/", "--")
|
||||
if DEBUG >= 2: print(f"Checking if local path exists to load tokenizer from local {local_path=}")
|
||||
try:
|
||||
if local_path and await aios.path.exists(local_path):
|
||||
if DEBUG >= 2: print(f"Resolving tokenizer for {model_id=} from {local_path=}")
|
||||
if DEBUG >= 2: print(f"Resolving tokenizer for {repo_id=} from {local_path=}")
|
||||
return await _resolve_tokenizer(local_path)
|
||||
except:
|
||||
if DEBUG >= 5: print(f"Local check for {local_path=} failed. Resolving tokenizer for {model_id=} normally...")
|
||||
if DEBUG >= 5: print(f"Local check for {local_path=} failed. Resolving tokenizer for {repo_id=} normally...")
|
||||
if DEBUG >= 5: traceback.print_exc()
|
||||
return await _resolve_tokenizer(model_id)
|
||||
return await _resolve_tokenizer(repo_id)
|
||||
|
||||
|
||||
async def _resolve_tokenizer(model_id_or_local_path: Union[str, PathLike]):
|
||||
async def _resolve_tokenizer(repo_id_or_local_path: Union[str, PathLike]):
|
||||
try:
|
||||
if DEBUG >= 4: print(f"Trying AutoProcessor for {model_id_or_local_path}")
|
||||
processor = AutoProcessor.from_pretrained(model_id_or_local_path, use_fast=True if "Mistral-Large" in f"{model_id_or_local_path}" else False, trust_remote_code=True)
|
||||
if DEBUG >= 4: print(f"Trying AutoProcessor for {repo_id_or_local_path}")
|
||||
processor = AutoProcessor.from_pretrained(repo_id_or_local_path, use_fast=True if "Mistral-Large" in f"{repo_id_or_local_path}" else False, trust_remote_code=True)
|
||||
if not hasattr(processor, 'eos_token_id'):
|
||||
processor.eos_token_id = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).eos_token_id
|
||||
if not hasattr(processor, 'encode'):
|
||||
@@ -51,14 +50,14 @@ async def _resolve_tokenizer(model_id_or_local_path: Union[str, PathLike]):
|
||||
processor.decode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).decode
|
||||
return processor
|
||||
except Exception as e:
|
||||
if DEBUG >= 4: print(f"Failed to load processor for {model_id_or_local_path}. Error: {e}")
|
||||
if DEBUG >= 4: print(f"Failed to load processor for {repo_id_or_local_path}. Error: {e}")
|
||||
if DEBUG >= 4: print(traceback.format_exc())
|
||||
|
||||
try:
|
||||
if DEBUG >= 4: print(f"Trying AutoTokenizer for {model_id_or_local_path}")
|
||||
return AutoTokenizer.from_pretrained(model_id_or_local_path, trust_remote_code=True)
|
||||
if DEBUG >= 4: print(f"Trying AutoTokenizer for {repo_id_or_local_path}")
|
||||
return AutoTokenizer.from_pretrained(repo_id_or_local_path, trust_remote_code=True)
|
||||
except Exception as e:
|
||||
if DEBUG >= 4: print(f"Failed to load tokenizer for {model_id_or_local_path}. Falling back to tinygrad tokenizer. Error: {e}")
|
||||
if DEBUG >= 4: print(f"Failed to load tokenizer for {repo_id_or_local_path}. Falling back to tinygrad tokenizer. Error: {e}")
|
||||
if DEBUG >= 4: print(traceback.format_exc())
|
||||
|
||||
raise ValueError(f"[TODO] Unsupported model: {model_id_or_local_path}")
|
||||
raise ValueError(f"[TODO] Unsupported model: {repo_id_or_local_path}")
|
||||
|
||||
206  exo/main.py
@@ -3,20 +3,15 @@ import asyncio
|
||||
import atexit
|
||||
import signal
|
||||
import json
|
||||
import logging
|
||||
import platform
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
import uuid
|
||||
import numpy as np
|
||||
from functools import partial
|
||||
from tqdm import tqdm
|
||||
from tqdm.asyncio import tqdm_asyncio
|
||||
from exo.train.dataset import load_dataset, iterate_batches, compose
|
||||
from exo.train.dataset import load_dataset, iterate_batches
|
||||
from exo.networking.manual.manual_discovery import ManualDiscovery
|
||||
from exo.networking.manual.network_topology_config import NetworkTopology
|
||||
from exo.orchestration.node import Node
|
||||
from exo.networking.grpc.grpc_server import GRPCServer
|
||||
from exo.networking.udp.udp_discovery import UDPDiscovery
|
||||
@@ -24,15 +19,41 @@ from exo.networking.tailscale.tailscale_discovery import TailscaleDiscovery
|
||||
from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle
|
||||
from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
|
||||
from exo.api import ChatGPTAPI
|
||||
from exo.download.shard_download import ShardDownloader, RepoProgressEvent, NoopShardDownloader
|
||||
from exo.download.hf.hf_shard_download import HFShardDownloader
|
||||
from exo.download.shard_download import ShardDownloader, NoopShardDownloader
|
||||
from exo.download.download_progress import RepoProgressEvent
|
||||
from exo.download.new_shard_download import new_shard_downloader, has_exo_home_read_access, has_exo_home_write_access, ensure_exo_home, seed_models
|
||||
from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses_and_interfaces, terminal_link, shutdown
|
||||
from exo.inference.shard import Shard
|
||||
from exo.inference.inference_engine import get_inference_engine, InferenceEngine
|
||||
from exo.inference.inference_engine import get_inference_engine
|
||||
from exo.inference.tokenizers import resolve_tokenizer
|
||||
from exo.models import build_base_shard, get_repo
|
||||
from exo.viz.topology_viz import TopologyViz
|
||||
from exo.download.hf.hf_helpers import has_hf_home_read_access, has_hf_home_write_access, get_hf_home, move_models_to_hf
|
||||
import uvloop
|
||||
import concurrent.futures
|
||||
import resource
|
||||
import psutil
|
||||
|
||||
# TODO: figure out why this is happening
|
||||
os.environ["GRPC_VERBOSITY"] = "error"
|
||||
os.environ["TRANSFORMERS_VERBOSITY"] = "error"
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "true"
|
||||
|
||||
# Configure uvloop for maximum performance
|
||||
def configure_uvloop():
|
||||
uvloop.install()
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
# Increase file descriptor limits on Unix systems
|
||||
if not psutil.WINDOWS:
|
||||
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
|
||||
try: resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
|
||||
except ValueError:
|
||||
try: resource.setrlimit(resource.RLIMIT_NOFILE, (8192, hard))
|
||||
except ValueError: pass
|
||||
|
||||
loop.set_default_executor(concurrent.futures.ThreadPoolExecutor(max_workers=min(32, (os.cpu_count() or 1) * 4)))
|
||||
return loop
|
||||
|
||||
# parse args
|
||||
parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
|
||||
@@ -51,15 +72,14 @@ parser.add_argument("--node-port", type=int, default=None, help="Node port")
|
||||
parser.add_argument("--models-seed-dir", type=str, default=None, help="Model seed directory")
|
||||
parser.add_argument("--listen-port", type=int, default=5678, help="Listening port for discovery")
|
||||
parser.add_argument("--download-quick-check", action="store_true", help="Quick check local path for model shards download")
|
||||
parser.add_argument("--max-parallel-downloads", type=int, default=4, help="Max parallel downloads for model shards download")
|
||||
parser.add_argument("--prometheus-client-port", type=int, default=None, help="Prometheus client port")
|
||||
parser.add_argument("--max-parallel-downloads", type=int, default=8, help="Max parallel downloads for model shards download")
|
||||
parser.add_argument("--broadcast-port", type=int, default=5678, help="Broadcast port for discovery")
|
||||
parser.add_argument("--discovery-module", type=str, choices=["udp", "tailscale", "manual"], default="udp", help="Discovery module to use")
|
||||
parser.add_argument("--discovery-timeout", type=int, default=30, help="Discovery timeout in seconds")
|
||||
parser.add_argument("--discovery-config-path", type=str, default=None, help="Path to discovery config json file")
|
||||
parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of peers to wait to connect to before starting")
|
||||
parser.add_argument("--chatgpt-api-port", type=int, default=52415, help="ChatGPT API port")
|
||||
parser.add_argument("--chatgpt-api-response-timeout", type=int, default=90, help="ChatGPT API response timeout in seconds")
|
||||
parser.add_argument("--chatgpt-api-response-timeout", type=int, default=900, help="ChatGPT API response timeout in seconds")
|
||||
parser.add_argument("--max-generate-tokens", type=int, default=10000, help="Max tokens to generate in each request")
|
||||
parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use (mlx, tinygrad, or dummy)")
|
||||
parser.add_argument("--disable-tui", action=argparse.BooleanOptionalAction, help="Disable TUI")
|
||||
@@ -69,6 +89,8 @@ parser.add_argument("--default-temp", type=float, help="Default token sampling t
|
||||
parser.add_argument("--tailscale-api-key", type=str, default=None, help="Tailscale API key")
|
||||
parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name")
|
||||
parser.add_argument("--node-id-filter", type=str, default=None, help="Comma separated list of allowed node IDs (only for UDP and Tailscale discovery)")
|
||||
parser.add_argument("--interface-type-filter", type=str, default=None, help="Comma separated list of allowed interface types (only for UDP discovery)")
|
||||
parser.add_argument("--system-prompt", type=str, default=None, help="System prompt for the ChatGPT API")
|
||||
args = parser.parse_args()
|
||||
print(f"Selected inference engine: {args.inference_engine}")
|
||||
|
||||
@@ -77,8 +99,7 @@ print_yellow_exo()
|
||||
system_info = get_system_info()
|
||||
print(f"Detected system: {system_info}")
|
||||
|
||||
shard_downloader: ShardDownloader = HFShardDownloader(quick_check=args.download_quick_check,
|
||||
max_parallel_downloads=args.max_parallel_downloads) if args.inference_engine != "dummy" else NoopShardDownloader()
|
||||
shard_downloader: ShardDownloader = new_shard_downloader(args.max_parallel_downloads) if args.inference_engine != "dummy" else NoopShardDownloader()
|
||||
inference_engine_name = args.inference_engine or ("mlx" if system_info == "Apple Silicon Mac" else "tinygrad")
|
||||
print(f"Inference engine name after selection: {inference_engine_name}")
|
||||
|
||||
@@ -100,8 +121,9 @@ if DEBUG >= 0:
|
||||
for chatgpt_api_endpoint in chatgpt_api_endpoints:
|
||||
print(f" - {terminal_link(chatgpt_api_endpoint)}")
|
||||
|
||||
# Convert node-id-filter to list if provided
|
||||
# Convert node-id-filter and interface-type-filter to lists if provided
|
||||
allowed_node_ids = args.node_id_filter.split(',') if args.node_id_filter else None
|
||||
allowed_interface_types = args.interface_type_filter.split(',') if args.interface_type_filter else None
|
||||
|
||||
if args.discovery_module == "udp":
|
||||
discovery = UDPDiscovery(
|
||||
@@ -111,7 +133,8 @@ if args.discovery_module == "udp":
|
||||
args.broadcast_port,
|
||||
lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities),
|
||||
discovery_timeout=args.discovery_timeout,
|
||||
allowed_node_ids=allowed_node_ids
|
||||
allowed_node_ids=allowed_node_ids,
|
||||
allowed_interface_types=allowed_interface_types
|
||||
)
|
||||
elif args.discovery_module == "tailscale":
|
||||
discovery = TailscaleDiscovery(
|
||||
@@ -133,62 +156,73 @@ node = Node(
|
||||
None,
|
||||
inference_engine,
|
||||
discovery,
|
||||
shard_downloader,
|
||||
partitioning_strategy=RingMemoryWeightedPartitioningStrategy(),
|
||||
max_generate_tokens=args.max_generate_tokens,
|
||||
topology_viz=topology_viz,
|
||||
shard_downloader=shard_downloader,
|
||||
default_sample_temperature=args.default_temp
|
||||
)
|
||||
server = GRPCServer(node, args.node_host, args.node_port)
|
||||
node.server = server
|
||||
api = ChatGPTAPI(
|
||||
node,
|
||||
inference_engine.__class__.__name__,
|
||||
node.inference_engine.__class__.__name__,
|
||||
response_timeout=args.chatgpt_api_response_timeout,
|
||||
on_chat_completion_request=lambda req_id, __, prompt: topology_viz.update_prompt(req_id, prompt) if topology_viz else None,
|
||||
default_model=args.default_model
|
||||
default_model=args.default_model,
|
||||
system_prompt=args.system_prompt
|
||||
)
|
||||
node.on_token.register("update_topology_viz").on_next(
|
||||
lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") else None
|
||||
)
|
||||
|
||||
def preemptively_start_download(request_id: str, opaque_status: str):
|
||||
buffered_token_output = {}
|
||||
def update_topology_viz(req_id, tokens, __):
|
||||
if not topology_viz: return
|
||||
if not node.inference_engine.shard: return
|
||||
if node.inference_engine.shard.model_id == 'stable-diffusion-2-1-base': return
|
||||
if req_id in buffered_token_output: buffered_token_output[req_id].extend(tokens)
|
||||
else: buffered_token_output[req_id] = tokens
|
||||
topology_viz.update_prompt_output(req_id, node.inference_engine.tokenizer.decode(buffered_token_output[req_id]))
|
||||
node.on_token.register("update_topology_viz").on_next(update_topology_viz)
|
||||
def update_prompt_viz(request_id, opaque_status: str):
|
||||
if not topology_viz: return
|
||||
try:
|
||||
status = json.loads(opaque_status)
|
||||
if status.get("type") == "node_status" and status.get("status") == "start_process_prompt":
|
||||
current_shard = node.get_current_shard(Shard.from_dict(status.get("shard")))
|
||||
if DEBUG >= 2: print(f"Preemptively starting download for {current_shard}")
|
||||
asyncio.create_task(shard_downloader.ensure_shard(current_shard, inference_engine.__class__.__name__))
|
||||
if status.get("type") != "node_status" or status.get("status") != "start_process_prompt": return
|
||||
topology_viz.update_prompt(request_id, status.get("prompt", "corrupted prompt (this should never happen)"))
|
||||
except Exception as e:
|
||||
if DEBUG >= 2:
|
||||
print(f"Failed to update prompt viz: {e}")
|
||||
traceback.print_exc()
|
||||
node.on_opaque_status.register("update_prompt_viz").on_next(update_prompt_viz)
|
||||
|
||||
def preemptively_load_shard(request_id: str, opaque_status: str):
|
||||
try:
|
||||
status = json.loads(opaque_status)
|
||||
if status.get("type") != "node_status" or status.get("status") != "start_process_prompt": return
|
||||
current_shard = node.get_current_shard(Shard.from_dict(status.get("shard")))
|
||||
if DEBUG >= 2: print(f"Preemptively starting download for {current_shard}")
|
||||
asyncio.create_task(node.inference_engine.ensure_shard(current_shard))
|
||||
except Exception as e:
|
||||
if DEBUG >= 2:
|
||||
print(f"Failed to preemptively start download: {e}")
|
||||
traceback.print_exc()
|
||||
node.on_opaque_status.register("preemptively_load_shard").on_next(preemptively_load_shard)
|
||||
|
||||
|
||||
node.on_opaque_status.register("start_download").on_next(preemptively_start_download)
|
||||
|
||||
if args.prometheus_client_port:
|
||||
from exo.stats.metrics import start_metrics_server
|
||||
start_metrics_server(node, args.prometheus_client_port)
|
||||
|
||||
last_broadcast_time = 0
|
||||
|
||||
|
||||
last_events: dict[str, tuple[float, RepoProgressEvent]] = {}
|
||||
def throttled_broadcast(shard: Shard, event: RepoProgressEvent):
|
||||
global last_broadcast_time
|
||||
global last_events
|
||||
current_time = time.time()
|
||||
if event.status == "complete" or current_time - last_broadcast_time >= 0.1:
|
||||
last_broadcast_time = current_time
|
||||
asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "progress": event.to_dict()})))
|
||||
|
||||
|
||||
if event.status == "not_started": return
|
||||
last_event = last_events.get(shard.model_id)
|
||||
if last_event and last_event[1].status == "complete" and event.status == "complete": return
|
||||
if last_event and last_event[0] == event.status and current_time - last_event[0] < 0.2: return
|
||||
last_events[shard.model_id] = (current_time, event)
|
||||
asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "progress": event.to_dict()})))
|
||||
shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast)
|
||||
|
||||
async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_name: str, prompt: str):
|
||||
inference_class = inference_engine.__class__.__name__
|
||||
async def run_model_cli(node: Node, model_name: str, prompt: str):
|
||||
inference_class = node.inference_engine.__class__.__name__
|
||||
shard = build_base_shard(model_name, inference_class)
|
||||
if not shard:
|
||||
print(f"Error: Unsupported model '{model_name}' for inference engine {inference_engine.__class__.__name__}")
|
||||
print(f"Error: Unsupported model '{model_name}' for inference engine {inference_class}")
|
||||
return
|
||||
tokenizer = await resolve_tokenizer(get_repo(shard.model_id, inference_class))
|
||||
request_id = str(uuid.uuid4())
|
||||
@@ -202,7 +236,11 @@ async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_nam
|
||||
print(f"Processing prompt: {prompt}")
|
||||
await node.process_prompt(shard, prompt, request_id=request_id)
|
||||
|
||||
_, tokens, _ = await callback.wait(lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished, timeout=300)
|
||||
tokens = []
|
||||
def on_token(_request_id, _tokens, _is_finished):
|
||||
tokens.extend(_tokens)
|
||||
return _request_id == request_id and _is_finished
|
||||
await callback.wait(on_token, timeout=300)
|
||||
|
||||
print("\nGenerated response:")
|
||||
print(tokenizer.decode(tokens))
|
||||
@@ -221,7 +259,7 @@ def clean_path(path):
|
||||
async def hold_outstanding(node: Node):
|
||||
while node.outstanding_requests:
|
||||
await asyncio.sleep(.5)
|
||||
return
|
||||
return
|
||||
|
||||
async def run_iter(node: Node, shard: Shard, train: bool, data, batch_size=1):
|
||||
losses = []
|
||||
@@ -232,14 +270,14 @@ async def run_iter(node: Node, shard: Shard, train: bool, data, batch_size=1):
|
||||
tokens.append(np.sum(lengths))
|
||||
total_tokens = np.sum(tokens)
|
||||
total_loss = np.sum(losses) / total_tokens
|
||||
|
||||
|
||||
return total_loss, total_tokens
|
||||
|
||||
async def eval_model_cli(node: Node, inference_engine: InferenceEngine, model_name, dataloader, batch_size, num_batches=-1):
|
||||
inference_class = inference_engine.__class__.__name__
|
||||
async def eval_model_cli(node: Node, model_name, dataloader, batch_size, num_batches=-1):
|
||||
inference_class = node.inference_engine.__class__.__name__
|
||||
shard = build_base_shard(model_name, inference_class)
|
||||
if not shard:
|
||||
print(f"Error: Unsupported model '{model_name}' for inference engine {inference_engine.__class__.__name__}")
|
||||
print(f"Error: Unsupported model '{model_name}' for inference engine {inference_class}")
|
||||
return
|
||||
tokenizer = await resolve_tokenizer(get_repo(shard.model_id, inference_class))
|
||||
train, val, test = dataloader(tokenizer.encode)
|
||||
@@ -249,11 +287,11 @@ async def eval_model_cli(node: Node, inference_engine: InferenceEngine, model_na
|
||||
print("Waiting for outstanding tasks")
|
||||
await hold_outstanding(node)
|
||||
|
||||
async def train_model_cli(node: Node, inference_engine: InferenceEngine, model_name, dataloader, batch_size, iters, save_interval=0, checkpoint_dir=None):
|
||||
inference_class = inference_engine.__class__.__name__
|
||||
async def train_model_cli(node: Node, model_name, dataloader, batch_size, iters, save_interval=0, checkpoint_dir=None):
|
||||
inference_class = node.inference_engine.__class__.__name__
|
||||
shard = build_base_shard(model_name, inference_class)
|
||||
if not shard:
|
||||
print(f"Error: Unsupported model '{model_name}' for inference engine {inference_engine.__class__.__name__}")
|
||||
print(f"Error: Unsupported model '{model_name}' for inference engine {inference_class}")
|
||||
return
|
||||
tokenizer = await resolve_tokenizer(get_repo(shard.model_id, inference_class))
|
||||
train, val, test = dataloader(tokenizer.encode)
|
||||
@@ -268,28 +306,30 @@ async def train_model_cli(node: Node, inference_engine: InferenceEngine, model_n
|
||||
await hold_outstanding(node)
|
||||
await hold_outstanding(node)
|
||||
|
||||
|
||||
async def main():
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
# Check HuggingFace directory permissions
|
||||
hf_home, has_read, has_write = get_hf_home(), await has_hf_home_read_access(), await has_hf_home_write_access()
|
||||
if DEBUG >= 1: print(f"Model storage directory: {hf_home}")
|
||||
async def check_exo_home():
|
||||
home, has_read, has_write = await ensure_exo_home(), await has_exo_home_read_access(), await has_exo_home_write_access()
|
||||
if DEBUG >= 1: print(f"exo home directory: {home}")
|
||||
print(f"{has_read=}, {has_write=}")
|
||||
if not has_read or not has_write:
|
||||
print(f"""
|
||||
WARNING: Limited permissions for model storage directory: {hf_home}.
|
||||
WARNING: Limited permissions for exo home directory: {home}.
|
||||
This may prevent model downloads from working correctly.
|
||||
{"❌ No read access" if not has_read else ""}
|
||||
{"❌ No write access" if not has_write else ""}
|
||||
""")
|
||||
|
||||
|
||||
async def main():
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
try: await check_exo_home()
|
||||
except Exception as e: print(f"Error checking exo home directory: {e}")
|
||||
|
||||
if not args.models_seed_dir is None:
|
||||
try:
|
||||
models_seed_dir = clean_path(args.models_seed_dir)
|
||||
await move_models_to_hf(models_seed_dir)
|
||||
await seed_models(models_seed_dir)
|
||||
except Exception as e:
|
||||
print(f"Error moving models to .cache/huggingface: {e}")
|
||||
print(f"Error seeding models: {e}")
|
||||
|
||||
def restore_cursor():
|
||||
if platform.system() != "Windows":
|
||||
@@ -313,7 +353,7 @@ async def main():
|
||||
if not model_name:
|
||||
print("Error: Model name is required when using 'run' command or --run-model")
|
||||
return
|
||||
await run_model_cli(node, inference_engine, model_name, args.prompt)
|
||||
await run_model_cli(node, model_name, args.prompt)
|
||||
elif args.command == "eval" or args.command == 'train':
|
||||
model_name = args.model_name
|
||||
dataloader = lambda tok: load_dataset(args.data, preprocess=lambda item: tok(item)
|
||||
@@ -322,35 +362,31 @@ async def main():
|
||||
if not model_name:
|
||||
print("Error: Much like a human, I can't evaluate anything without a model")
|
||||
return
|
||||
await eval_model_cli(node, inference_engine, model_name, dataloader, args.batch_size)
|
||||
await eval_model_cli(node, model_name, dataloader, args.batch_size)
|
||||
else:
|
||||
if not model_name:
|
||||
print("Error: This train ain't leaving the station without a model")
|
||||
return
|
||||
await train_model_cli(node, inference_engine, model_name, dataloader, args.batch_size, args.iters, save_interval=args.save_every, checkpoint_dir=args.save_checkpoint_dir)
|
||||
|
||||
await train_model_cli(node, model_name, dataloader, args.batch_size, args.iters, save_interval=args.save_every, checkpoint_dir=args.save_checkpoint_dir)
|
||||
|
||||
else:
|
||||
asyncio.create_task(api.run(port=args.chatgpt_api_port)) # Start the API server as a non-blocking task
|
||||
await asyncio.Event().wait()
|
||||
|
||||
|
||||
if args.wait_for_peers > 0:
|
||||
print("Cooldown to allow peers to exit gracefully")
|
||||
for i in tqdm(range(50)):
|
||||
await asyncio.sleep(.1)
|
||||
|
||||
|
||||
def run():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
loop.run_until_complete(main())
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("Received keyboard interrupt. Shutting down...")
|
||||
finally:
|
||||
loop.run_until_complete(shutdown(signal.SIGTERM, loop, node.server))
|
||||
loop.close()
|
||||
|
||||
loop = None
|
||||
try:
|
||||
loop = configure_uvloop()
|
||||
loop.run_until_complete(main())
|
||||
except KeyboardInterrupt:
|
||||
print("\nShutdown requested... exiting")
|
||||
finally:
|
||||
if loop: loop.close()
|
||||
|
||||
if __name__ == "__main__":
|
||||
run()
|
||||
|
||||
exo/models.py (107 changed lines)
@@ -88,18 +88,55 @@ model_cards = {
|
||||
### deepseek
|
||||
"deepseek-coder-v2-lite": { "layers": 27, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", }, },
|
||||
"deepseek-coder-v2.5": { "layers": 60, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", }, },
|
||||
"deepseek-v3": { "layers": 61, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-V3-4bit", }, },
|
||||
"deepseek-v3-3bit": { "layers": 61, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-V3-3bit", }, },
|
||||
"deepseek-r1": { "layers": 61, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-4bit", }, },
|
||||
"deepseek-r1-3bit": { "layers": 61, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-3bit", }, },
|
||||
### deepseek distills
|
||||
"deepseek-r1-distill-qwen-1.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/deepseek-r1-distill-qwen-1.5b", }, },
|
||||
"deepseek-r1-distill-qwen-1.5b-3bit": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-3bit", }, },
|
||||
"deepseek-r1-distill-qwen-1.5b-6bit": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-6bit", }, },
|
||||
"deepseek-r1-distill-qwen-1.5b-8bit": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-8bit", }, },
|
||||
"deepseek-r1-distill-qwen-1.5b-bf16": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-bf16", }, },
|
||||
"deepseek-r1-distill-qwen-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-7B-4bit", }, },
|
||||
"deepseek-r1-distill-qwen-7b-3bit": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-7B-3bit", }, },
|
||||
"deepseek-r1-distill-qwen-7b-6bit": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-7B-6bit", }, },
|
||||
"deepseek-r1-distill-qwen-7b-8bit": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-7B-8bit", }, },
|
||||
"deepseek-r1-distill-qwen-7b-bf16": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-7B-bf16", }, },
|
||||
"deepseek-r1-distill-qwen-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-14B-4bit", }, },
|
||||
"deepseek-r1-distill-qwen-14b-3bit": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-14B-3bit", }, },
|
||||
"deepseek-r1-distill-qwen-14b-6bit": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-14B-6bit", }, },
|
||||
"deepseek-r1-distill-qwen-14b-8bit": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-14B-8bit", }, },
|
||||
"deepseek-r1-distill-qwen-14b-bf16": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-14B-bf16", }, },
|
||||
"deepseek-r1-distill-qwen-32b": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-32B-4bit", }, },
|
||||
"deepseek-r1-distill-qwen-32b-3bit": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-32B-3bit", }, },
|
||||
"deepseek-r1-distill-qwen-32b-6bit": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-32B-6bit", }, },
|
||||
"deepseek-r1-distill-qwen-32b-8bit": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-32B-MLX-8Bit", }, },
|
||||
"deepseek-r1-distill-qwen-32b-bf16": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-32B-bf16", }, },
|
||||
"deepseek-r1-distill-llama-8b": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-8B-4bit", }, },
|
||||
"deepseek-r1-distill-llama-8b-3bit": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-8B-3bit", }, },
|
||||
"deepseek-r1-distill-llama-8b-6bit": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-8B-6bit", }, },
|
||||
"deepseek-r1-distill-llama-8b-8bit": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-8B-8bit", }, },
|
||||
"deepseek-r1-distill-llama-8b-bf16": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-8B-bf16", }, },
|
||||
"deepseek-r1-distill-llama-70b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-70B-4bit", }, },
|
||||
"deepseek-r1-distill-llama-70b-3bit": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-70B-3bit", }, },
|
||||
"deepseek-r1-distill-llama-70b-6bit": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-70B-6bit", }, },
|
||||
"deepseek-r1-distill-llama-70b-8bit": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-70B-8bit", }, },
|
||||
### llava
|
||||
"llava-1.5-7b-hf": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "llava-hf/llava-1.5-7b-hf", }, },
|
||||
### qwen
|
||||
"qwen-2.5-0.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-0.5B-Instruct-4bit", }, },
|
||||
"qwen-2.5-1.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-1.5B-Instruct-4bit", }, },
|
||||
"qwen-2.5-coder-1.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-1.5B-Instruct-4bit", }, },
|
||||
"qwen-2.5-3b": { "layers": 36, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-3B-Instruct-4bit", }, },
|
||||
"qwen-2.5-coder-3b": { "layers": 36, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-3B-Instruct-4bit", }, },
|
||||
"qwen-2.5-coder-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-7B-Instruct-4bit", }, },
|
||||
"qwen-2.5-coder-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-14B-Instruct-4bit", }, },
|
||||
"qwen-2.5-coder-32b": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-32B-Instruct-4bit", }, },
|
||||
"qwen-2.5-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-7B-Instruct-4bit", }, },
|
||||
"qwen-2.5-coder-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-7B-Instruct-4bit", }, },
|
||||
"qwen-2.5-math-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Math-7B-Instruct-4bit", }, },
|
||||
"qwen-2.5-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-14B-Instruct-4bit", }, },
|
||||
"qwen-2.5-coder-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-14B-Instruct-4bit", }, },
|
||||
"qwen-2.5-32b": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-32B-Instruct-4bit", }, },
|
||||
"qwen-2.5-coder-32b": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-32B-Instruct-4bit", }, },
|
||||
"qwen-2.5-72b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-72B-Instruct-4bit", }, },
|
||||
"qwen-2.5-math-72b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Math-72B-Instruct-4bit", }, },
|
||||
### nemotron
|
||||
@@ -108,6 +145,11 @@ model_cards = {
|
||||
# gemma
|
||||
"gemma2-9b": { "layers": 42, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/gemma-2-9b-it-4bit", }, },
|
||||
"gemma2-27b": { "layers": 46, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/gemma-2-27b-it-4bit", }, },
|
||||
# stable diffusion
|
||||
"stable-diffusion-2-1-base": { "layers": 31, "repo": { "MLXDynamicShardInferenceEngine": "stabilityai/stable-diffusion-2-1-base" } },
|
||||
# phi
|
||||
"phi-3.5-mini": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Phi-3.5-mini-instruct-4bit", }, },
|
||||
"phi-4": { "layers": 40, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/phi-4-4bit", }, },
|
||||
# dummy
|
||||
"dummy": { "layers": 8, "repo": { "DummyInferenceEngine": "dummy", }, },
|
||||
}
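Each model_cards entry maps a short model id to a layer count plus a per-engine repo lookup keyed by the inference engine class name. A minimal sketch of what one more entry would look like (the id and repo path below are hypothetical, purely to show the shape, not part of the real registry):

# Hypothetical entry to illustrate the shape of a model card; the id and repo are made up.
model_cards["my-model-4bit"] = {
  "layers": 32,
  "repo": {
    "MLXDynamicShardInferenceEngine": "mlx-community/My-Model-4bit",  # engine class name -> HF repo
  },
}
pretty_name["my-model-4bit"] = "My Model (4-bit)"  # optional display name, see the dict below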
|
||||
@@ -132,24 +174,70 @@ pretty_name = {
|
||||
"mistral-large": "Mistral Large",
|
||||
"deepseek-coder-v2-lite": "Deepseek Coder V2 Lite",
|
||||
"deepseek-coder-v2.5": "Deepseek Coder V2.5",
|
||||
"deepseek-v3": "Deepseek V3 (4-bit)",
|
||||
"deepseek-v3-3bit": "Deepseek V3 (3-bit)",
|
||||
"deepseek-r1": "Deepseek R1 (4-bit)",
|
||||
"deepseek-r1-3bit": "Deepseek R1 (3-bit)",
|
||||
"llava-1.5-7b-hf": "LLaVa 1.5 7B (Vision Model)",
|
||||
"qwen-2.5-0.5b": "Qwen 2.5 0.5B",
|
||||
"qwen-2.5-1.5b": "Qwen 2.5 1.5B",
|
||||
"qwen-2.5-coder-1.5b": "Qwen 2.5 Coder 1.5B",
|
||||
"qwen-2.5-3b": "Qwen 2.5 3B",
|
||||
"qwen-2.5-coder-3b": "Qwen 2.5 Coder 3B",
|
||||
"qwen-2.5-coder-7b": "Qwen 2.5 Coder 7B",
|
||||
"qwen-2.5-coder-14b": "Qwen 2.5 Coder 14B",
|
||||
"qwen-2.5-coder-32b": "Qwen 2.5 Coder 32B",
|
||||
"qwen-2.5-7b": "Qwen 2.5 7B",
|
||||
"qwen-2.5-coder-7b": "Qwen 2.5 Coder 7B",
|
||||
"qwen-2.5-math-7b": "Qwen 2.5 7B (Math)",
|
||||
"qwen-2.5-14b": "Qwen 2.5 14B",
|
||||
"qwen-2.5-coder-14b": "Qwen 2.5 Coder 14B",
|
||||
"qwen-2.5-32b": "Qwen 2.5 32B",
|
||||
"qwen-2.5-coder-32b": "Qwen 2.5 Coder 32B",
|
||||
"qwen-2.5-72b": "Qwen 2.5 72B",
|
||||
"qwen-2.5-math-72b": "Qwen 2.5 72B (Math)",
|
||||
"phi-3.5-mini": "Phi-3.5 Mini",
|
||||
"phi-4": "Phi-4",
|
||||
"llama-3-8b": "Llama 3 8B",
|
||||
"llama-3-70b": "Llama 3 70B",
|
||||
"stable-diffusion-2-1-base": "Stable Diffusion 2.1",
|
||||
"deepseek-r1-distill-qwen-1.5b": "DeepSeek R1 Distill Qwen 1.5B",
|
||||
"deepseek-r1-distill-qwen-1.5b-3bit": "DeepSeek R1 Distill Qwen 1.5B (3-bit)",
|
||||
"deepseek-r1-distill-qwen-1.5b-6bit": "DeepSeek R1 Distill Qwen 1.5B (6-bit)",
|
||||
"deepseek-r1-distill-qwen-1.5b-8bit": "DeepSeek R1 Distill Qwen 1.5B (8-bit)",
|
||||
"deepseek-r1-distill-qwen-1.5b-bf16": "DeepSeek R1 Distill Qwen 1.5B (BF16)",
|
||||
"deepseek-r1-distill-qwen-7b": "DeepSeek R1 Distill Qwen 7B",
|
||||
"deepseek-r1-distill-qwen-7b-3bit": "DeepSeek R1 Distill Qwen 7B (3-bit)",
|
||||
"deepseek-r1-distill-qwen-7b-6bit": "DeepSeek R1 Distill Qwen 7B (6-bit)",
|
||||
"deepseek-r1-distill-qwen-7b-8bit": "DeepSeek R1 Distill Qwen 7B (8-bit)",
|
||||
"deepseek-r1-distill-qwen-7b-bf16": "DeepSeek R1 Distill Qwen 7B (BF16)",
|
||||
"deepseek-r1-distill-qwen-14b": "DeepSeek R1 Distill Qwen 14B",
|
||||
"deepseek-r1-distill-qwen-14b-3bit": "DeepSeek R1 Distill Qwen 14B (3-bit)",
|
||||
"deepseek-r1-distill-qwen-14b-6bit": "DeepSeek R1 Distill Qwen 14B (6-bit)",
|
||||
"deepseek-r1-distill-qwen-14b-8bit": "DeepSeek R1 Distill Qwen 14B (8-bit)",
|
||||
"deepseek-r1-distill-qwen-14b-bf16": "DeepSeek R1 Distill Qwen 14B (BF16)",
|
||||
"deepseek-r1-distill-qwen-32b": "DeepSeek R1 Distill Qwen 32B",
|
||||
"deepseek-r1-distill-qwen-32b-3bit": "DeepSeek R1 Distill Qwen 32B (3-bit)",
|
||||
"deepseek-r1-distill-qwen-32b-8bit": "DeepSeek R1 Distill Qwen 32B (8-bit)",
|
||||
"deepseek-r1-distill-qwen-32b-bf16": "DeepSeek R1 Distill Qwen 32B (BF16)",
|
||||
"deepseek-r1-distill-llama-8b-8bit": "DeepSeek R1 Distill Llama 8B (8-bit)",
|
||||
"deepseek-r1-distill-llama-70b-6bit": "DeepSeek R1 Distill Llama 70B (6-bit)",
|
||||
"deepseek-r1-distill-llama-70b-8bit": "DeepSeek R1 Distill Llama 70B (8-bit)",
|
||||
"deepseek-r1-distill-llama-8b": "DeepSeek R1 Distill Llama 8B",
|
||||
"deepseek-r1-distill-llama-8b-3bit": "DeepSeek R1 Distill Llama 8B (3-bit)",
|
||||
"deepseek-r1-distill-llama-8b-6bit": "DeepSeek R1 Distill Llama 8B (6-bit)",
|
||||
"deepseek-r1-distill-llama-8b-8bit": "DeepSeek R1 Distill Llama 8B (8-bit)",
|
||||
"deepseek-r1-distill-llama-8b-bf16": "DeepSeek R1 Distill Llama 8B (BF16)",
|
||||
"deepseek-r1-distill-llama-70b": "DeepSeek R1 Distill Llama 70B",
|
||||
"deepseek-r1-distill-llama-70b-3bit": "DeepSeek R1 Distill Llama 70B (3-bit)",
|
||||
"deepseek-r1-distill-llama-70b-6bit": "DeepSeek R1 Distill Llama 70B (6-bit)",
|
||||
"deepseek-r1-distill-llama-70b-8bit": "DeepSeek R1 Distill Llama 70B (8-bit)",
|
||||
"deepseek-r1-distill-qwen-32b-6bit": "DeepSeek R1 Distill Qwen 32B (6-bit)",
|
||||
}
|
||||
|
||||
def get_repo(model_id: str, inference_engine_classname: str) -> Optional[str]:
|
||||
return model_cards.get(model_id, {}).get("repo", {}).get(inference_engine_classname, None)
|
||||
|
||||
def get_pretty_name(model_id: str) -> Optional[str]:
|
||||
return pretty_name.get(model_id, None)
|
||||
|
||||
def build_base_shard(model_id: str, inference_engine_classname: str) -> Optional[Shard]:
|
||||
repo = get_repo(model_id, inference_engine_classname)
|
||||
n_layers = model_cards.get(model_id, {}).get("layers", 0)
|
||||
@@ -157,7 +245,12 @@ def build_base_shard(model_id: str, inference_engine_classname: str) -> Optional
|
||||
return None
|
||||
return Shard(model_id, 0, 0, n_layers)
|
||||
|
||||
def get_supported_models(supported_inference_engine_lists: List[List[str]]) -> List[str]:
|
||||
def build_full_shard(model_id: str, inference_engine_classname: str) -> Optional[Shard]:
|
||||
base_shard = build_base_shard(model_id, inference_engine_classname)
|
||||
if base_shard is None: return None
|
||||
return Shard(base_shard.model_id, 0, base_shard.n_layers - 1, base_shard.n_layers)
|
||||
|
||||
def get_supported_models(supported_inference_engine_lists: Optional[List[List[str]]] = None) -> List[str]:
|
||||
if not supported_inference_engine_lists:
|
||||
return list(model_cards.keys())
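Taken together, these helpers turn a model id into a repo and a shard for a given engine. A quick usage sketch, assuming "llama-3-8b" (one of the ids in the registry) and the MLX engine class name used throughout the cards above; the actual values depend on model_cards:

# Sketch of the lookup helpers above.
engine = "MLXDynamicShardInferenceEngine"
repo = get_repo("llama-3-8b", engine)          # HF repo string, or None if unsupported
base = build_base_shard("llama-3-8b", engine)  # Shard(model_id, 0, 0, n_layers)
full = build_full_shard("llama-3-8b", engine)  # Shard covering layers 0..n_layers-1
print(repo, base, full)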
|
||||
|
||||
|
||||
@@ -11,6 +11,13 @@ from exo.inference.shard import Shard
|
||||
from exo.topology.topology import Topology
|
||||
from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
|
||||
from exo.helpers import DEBUG
|
||||
import json
|
||||
import platform
|
||||
|
||||
if platform.system().lower() == "darwin" and platform.machine().lower() == "arm64":
|
||||
import mlx.core as mx
|
||||
else:
|
||||
import numpy as mx
|
||||
|
||||
|
||||
class GRPCPeerHandle(PeerHandle):
|
||||
@@ -21,6 +28,19 @@ class GRPCPeerHandle(PeerHandle):
|
||||
self._device_capabilities = device_capabilities
|
||||
self.channel = None
|
||||
self.stub = None
|
||||
self.channel_options = [
|
||||
("grpc.max_metadata_size", 64 * 1024 * 1024),
|
||||
("grpc.max_receive_message_length", 256 * 1024 * 1024),
|
||||
("grpc.max_send_message_length", 256 * 1024 * 1024),
|
||||
("grpc.max_concurrent_streams", 100),
|
||||
("grpc.http2.min_time_between_pings_ms", 10000),
|
||||
("grpc.keepalive_time_ms", 20000),
|
||||
("grpc.keepalive_timeout_ms", 10000),
|
||||
("grpc.keepalive_permit_without_calls", 1),
|
||||
("grpc.http2.max_pings_without_data", 0),
|
||||
("grpc.tcp_nodelay", 1),
|
||||
("grpc.optimization_target", "throughput"),
|
||||
]
|
||||
|
||||
def id(self) -> str:
|
||||
return self._id
|
||||
@@ -36,11 +56,11 @@ class GRPCPeerHandle(PeerHandle):
|
||||
|
||||
async def connect(self):
|
||||
if self.channel is None:
|
||||
self.channel = grpc.aio.insecure_channel(self.address, options=[
|
||||
("grpc.max_metadata_size", 32*1024*1024),
|
||||
('grpc.max_receive_message_length', 32*1024*1024),
|
||||
('grpc.max_send_message_length', 32*1024*1024)
|
||||
])
|
||||
self.channel = grpc.aio.insecure_channel(
|
||||
self.address,
|
||||
options=self.channel_options,
|
||||
compression=grpc.Compression.Gzip
|
||||
)
|
||||
self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
|
||||
await self.channel.channel_ready()
|
||||
|
||||
@@ -54,7 +74,13 @@ class GRPCPeerHandle(PeerHandle):
|
||||
self.stub = None
|
||||
|
||||
async def _ensure_connected(self):
|
||||
if not await self.is_connected(): await asyncio.wait_for(self.connect(), timeout=5)
|
||||
if not await self.is_connected():
|
||||
try:
|
||||
await asyncio.wait_for(self.connect(), timeout=10.0)
|
||||
except asyncio.TimeoutError:
|
||||
if DEBUG >= 2: print(f"Connection timeout for {self._id}@{self.address}")
|
||||
await self.disconnect()
|
||||
raise
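A small illustrative sketch of how a caller might lean on _ensure_connected before issuing an RPC; the wrapper function and its name are made up here, only the peer methods come from the class above:

import asyncio

# Illustrative wrapper (not part of GRPCPeerHandle): treat a connect timeout as unhealthy.
async def checked_health(peer) -> bool:
  try:
    await peer._ensure_connected()  # connects with a 10s timeout, disconnects and re-raises on failure
    return await peer.health_check()
  except asyncio.TimeoutError:
    return False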
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
try:
|
||||
@@ -71,7 +97,7 @@ class GRPCPeerHandle(PeerHandle):
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.array]:
|
||||
async def send_prompt(self, shard: Shard, prompt: str, inference_state: Optional[dict] = None, request_id: Optional[str] = None) -> Optional[np.array]:
|
||||
request = node_service_pb2.PromptRequest(
|
||||
prompt=prompt,
|
||||
shard=node_service_pb2.Shard(
|
||||
@@ -81,15 +107,11 @@ class GRPCPeerHandle(PeerHandle):
|
||||
n_layers=shard.n_layers,
|
||||
),
|
||||
request_id=request_id,
|
||||
inference_state=None if inference_state is None else self.serialize_inference_state(inference_state)
|
||||
)
|
||||
response = await self.stub.SendPrompt(request)
|
||||
await self.stub.SendPrompt(request)
|
||||
|
||||
if not response.tensor_data or not response.shape or not response.dtype:
|
||||
return None
|
||||
|
||||
return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
|
||||
|
||||
async def send_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> Optional[np.array]:
|
||||
async def send_tensor(self, shard: Shard, tensor: np.ndarray, inference_state: Optional[dict] = None, request_id: Optional[str] = None) -> Optional[np.array]:
|
||||
request = node_service_pb2.TensorRequest(
|
||||
shard=node_service_pb2.Shard(
|
||||
model_id=shard.model_id,
|
||||
@@ -99,6 +121,7 @@ class GRPCPeerHandle(PeerHandle):
|
||||
),
|
||||
tensor=node_service_pb2.Tensor(tensor_data=tensor.tobytes(), shape=tensor.shape, dtype=str(tensor.dtype)),
|
||||
request_id=request_id,
|
||||
inference_state=None if inference_state is None else self.serialize_inference_state(inference_state)
|
||||
)
|
||||
response = await self.stub.SendTensor(request)
|
||||
|
||||
@@ -106,7 +129,7 @@ class GRPCPeerHandle(PeerHandle):
|
||||
return None
|
||||
|
||||
return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
|
||||
|
||||
|
||||
async def send_example(self, shard: Shard, example: np.ndarray, target: np.ndarray, length: np.ndarray, train: bool, request_id: Optional[str] = None) -> Optional[np.array]:
|
||||
request = node_service_pb2.ExampleRequest(
|
||||
shard=node_service_pb2.Shard(
|
||||
@@ -128,7 +151,7 @@ class GRPCPeerHandle(PeerHandle):
|
||||
return loss, grads
|
||||
else:
|
||||
return loss
|
||||
|
||||
|
||||
async def send_loss(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> Optional[np.array]:
|
||||
request = node_service_pb2.TensorRequest(
|
||||
shard=node_service_pb2.Shard(
|
||||
@@ -147,26 +170,13 @@ class GRPCPeerHandle(PeerHandle):
|
||||
|
||||
return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
|
||||
|
||||
async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
|
||||
request = node_service_pb2.GetInferenceResultRequest(request_id=request_id)
|
||||
response = await self.stub.GetInferenceResult(request)
|
||||
if response.tensor is None:
|
||||
return None, response.is_finished
|
||||
return (
|
||||
np.frombuffer(response.tensor.tensor_data, dtype=np.dtype(response.tensor.dtype)).reshape(response.tensor.shape),
|
||||
response.is_finished,
|
||||
)
|
||||
|
||||
async def collect_topology(self, visited: set[str], max_depth: int) -> Topology:
|
||||
request = node_service_pb2.CollectTopologyRequest(visited=visited, max_depth=max_depth)
|
||||
response = await self.stub.CollectTopology(request)
|
||||
topology = Topology()
|
||||
for node_id, capabilities in response.nodes.items():
|
||||
device_capabilities = DeviceCapabilities(
|
||||
model=capabilities.model,
|
||||
chip=capabilities.chip,
|
||||
memory=capabilities.memory,
|
||||
flops=DeviceFlops(fp16=capabilities.flops.fp16, fp32=capabilities.flops.fp32, int8=capabilities.flops.int8)
|
||||
model=capabilities.model, chip=capabilities.chip, memory=capabilities.memory, flops=DeviceFlops(fp16=capabilities.flops.fp16, fp32=capabilities.flops.fp32, int8=capabilities.flops.int8)
|
||||
)
|
||||
topology.update_node(node_id, device_capabilities)
|
||||
for node_id, peer_connections in response.peer_graph.items():
|
||||
@@ -175,9 +185,35 @@ class GRPCPeerHandle(PeerHandle):
|
||||
return topology
|
||||
|
||||
async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
|
||||
request = node_service_pb2.SendResultRequest(request_id=request_id, result=result, is_finished=is_finished)
|
||||
tensor = None
|
||||
if isinstance(result, np.ndarray):
|
||||
tensor = node_service_pb2.Tensor(tensor_data=result.tobytes(), shape=result.shape, dtype=str(result.dtype))
|
||||
result = []
|
||||
request = node_service_pb2.SendResultRequest(request_id=request_id, result=result, tensor=tensor, is_finished=is_finished)
|
||||
await self.stub.SendResult(request)
|
||||
|
||||
async def send_opaque_status(self, request_id: str, status: str) -> None:
|
||||
request = node_service_pb2.SendOpaqueStatusRequest(request_id=request_id, status=status)
|
||||
await self.stub.SendOpaqueStatus(request)
|
||||
|
||||
def serialize_inference_state(self, inference_state: dict) -> node_service_pb2.InferenceState:
  proto_inference_state = node_service_pb2.InferenceState()
  other_data = {}
  for k, v in inference_state.items():
    if isinstance(v, mx.array):
      np_array = np.array(v)
      tensor_data = node_service_pb2.Tensor(tensor_data=np_array.tobytes(), shape=list(np_array.shape), dtype=str(np_array.dtype))
      proto_inference_state.tensor_data[k].CopyFrom(tensor_data)
    elif isinstance(v, list) and all(isinstance(item, mx.array) for item in v):
      tensor_list = node_service_pb2.TensorList()
      for tensor in v:
        np_array = np.array(tensor)
        tensor_data = node_service_pb2.Tensor(tensor_data=np_array.tobytes(), shape=list(np_array.shape), dtype=str(np_array.dtype))
        tensor_list.tensors.append(tensor_data)
      proto_inference_state.tensor_list_data[k].CopyFrom(tensor_list)
    else:
      # For non-tensor data, we'll still use JSON
      other_data[k] = v
  if other_data:
    proto_inference_state.other_data_json = json.dumps(other_data)
  return proto_inference_state
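A rough round-trip sketch for the serializer above (assuming the mx/np imports at the top of this file; "peer" is a placeholder name for an existing GRPCPeerHandle instance): mx.array values land in tensor_data, lists of arrays in tensor_list_data, and everything else in other_data_json.

# Sketch only: "peer" stands in for an already-constructed GRPCPeerHandle.
state = {
  "cache": mx.array(np.zeros((2, 4), dtype=np.float32)),  # -> proto.tensor_data["cache"]
  "step": 3,                                               # -> JSON in proto.other_data_json
}
proto = peer.serialize_inference_state(state)
# On the receiving side, GRPCServer.deserialize_inference_state rebuilds the same dict.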
|
||||
|
||||
@@ -3,11 +3,19 @@ from concurrent import futures
|
||||
import numpy as np
|
||||
from asyncio import CancelledError
|
||||
|
||||
import platform
|
||||
|
||||
from . import node_service_pb2
|
||||
from . import node_service_pb2_grpc
|
||||
from exo import DEBUG
|
||||
from exo.inference.shard import Shard
|
||||
from exo.orchestration import Node
|
||||
import json
|
||||
|
||||
if platform.system().lower() == "darwin" and platform.machine().lower() == "arm64":
|
||||
import mlx.core as mx
|
||||
else:
|
||||
import numpy as mx
|
||||
|
||||
|
||||
class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
|
||||
@@ -19,11 +27,19 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
|
||||
|
||||
async def start(self) -> None:
|
||||
self.server = grpc.aio.server(
|
||||
futures.ThreadPoolExecutor(max_workers=10),
|
||||
futures.ThreadPoolExecutor(max_workers=32),
|
||||
options=[
|
||||
("grpc.max_metadata_size", 32*1024*1024),
|
||||
("grpc.max_send_message_length", 128*1024*1024),
|
||||
("grpc.max_receive_message_length", 128*1024*1024),
|
||||
("grpc.max_send_message_length", 256*1024*1024),
|
||||
("grpc.max_receive_message_length", 256*1024*1024),
|
||||
("grpc.keepalive_time_ms", 10000),
|
||||
("grpc.keepalive_timeout_ms", 5000),
|
||||
("grpc.http2.max_pings_without_data", 0),
|
||||
("grpc.http2.min_time_between_pings_ms", 10000),
|
||||
("grpc.http2.min_ping_interval_without_data_ms", 5000),
|
||||
("grpc.max_concurrent_streams", 100),
|
||||
("grpc.tcp_nodelay", 1),
|
||||
("grpc.optimization_target", "throughput"),
|
||||
],
|
||||
)
|
||||
node_service_pb2_grpc.add_NodeServiceServicer_to_server(self, self.server)
|
||||
@@ -50,7 +66,8 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
|
||||
)
|
||||
prompt = request.prompt
|
||||
request_id = request.request_id
|
||||
result = await self.node.process_prompt(shard, prompt, request_id)
|
||||
inference_state = None if request.inference_state is None else self.deserialize_inference_state(request.inference_state)
|
||||
result = await self.node.process_prompt(shard, prompt, request_id, inference_state)
|
||||
if DEBUG >= 5: print(f"SendPrompt {shard=} {prompt=} {request_id=} result: {result}")
|
||||
tensor_data = result.tobytes() if result is not None else None
|
||||
return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
|
||||
@@ -65,11 +82,13 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
|
||||
tensor = np.frombuffer(request.tensor.tensor_data, dtype=np.dtype(request.tensor.dtype)).reshape(request.tensor.shape)
|
||||
request_id = request.request_id
|
||||
|
||||
result = await self.node.process_tensor(shard, tensor, request_id)
|
||||
inference_state = None if request.inference_state is None else self.deserialize_inference_state(request.inference_state)
|
||||
|
||||
result = await self.node.process_tensor(shard, tensor, request_id, inference_state)
|
||||
if DEBUG >= 5: print(f"SendTensor tensor {shard=} {tensor=} {request_id=} result: {result}")
|
||||
tensor_data = result.tobytes() if result is not None else None
|
||||
return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
|
||||
|
||||
|
||||
async def SendExample(self, request, context):
|
||||
shard = Shard(
|
||||
model_id=request.shard.model_id,
|
||||
@@ -91,7 +110,7 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
|
||||
else:
|
||||
loss = await self.node.process_example(shard, example, target, length, train, request_id)
|
||||
return node_service_pb2.Loss(loss=loss, grads=None)
|
||||
|
||||
|
||||
async def CollectTopology(self, request, context):
|
||||
max_depth = request.max_depth
|
||||
visited = set(request.visited)
|
||||
@@ -107,12 +126,7 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
|
||||
for node_id, cap in topology.nodes.items()
|
||||
}
|
||||
peer_graph = {
|
||||
node_id: node_service_pb2.PeerConnections(
|
||||
connections=[
|
||||
node_service_pb2.PeerConnection(to_id=conn.to_id, description=conn.description)
|
||||
for conn in connections
|
||||
]
|
||||
)
|
||||
node_id: node_service_pb2.PeerConnections(connections=[node_service_pb2.PeerConnection(to_id=conn.to_id, description=conn.description) for conn in connections])
|
||||
for node_id, connections in topology.peer_graph.items()
|
||||
}
|
||||
if DEBUG >= 5: print(f"CollectTopology {max_depth=} {visited=} {nodes=} {peer_graph=}")
|
||||
@@ -122,7 +136,11 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
|
||||
request_id = request.request_id
|
||||
result = request.result
|
||||
is_finished = request.is_finished
|
||||
img = request.tensor
|
||||
if DEBUG >= 5: print(f"Received SendResult request: {request_id=} {result=} {is_finished=}")
|
||||
result = list(result)
|
||||
if len(img.tensor_data) > 0:
|
||||
result = np.frombuffer(img.tensor_data, dtype=np.dtype(img.dtype)).reshape(img.shape)
|
||||
self.node.on_token.trigger_all(request_id, result, is_finished)
|
||||
return node_service_pb2.Empty()
|
||||
|
||||
@@ -135,3 +153,19 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
|
||||
|
||||
async def HealthCheck(self, request, context):
|
||||
return node_service_pb2.HealthCheckResponse(is_healthy=True)
|
||||
|
||||
def deserialize_inference_state(self, inference_state_proto: node_service_pb2.InferenceState) -> dict:
|
||||
inference_state = {}
|
||||
|
||||
for k, tensor_data in inference_state_proto.tensor_data.items():
|
||||
np_array = np.frombuffer(tensor_data.tensor_data, dtype=tensor_data.dtype).reshape(tensor_data.shape)
|
||||
inference_state[k] = mx.array(np_array)
|
||||
|
||||
for k, tensor_list in inference_state_proto.tensor_list_data.items():
|
||||
inference_state[k] = [mx.array(np.frombuffer(tensor.tensor_data, dtype=tensor.dtype).reshape(tensor.shape)) for tensor in tensor_list.tensors]
|
||||
|
||||
if inference_state_proto.other_data_json:
|
||||
other_data = json.loads(inference_state_proto.other_data_json)
|
||||
inference_state.update(other_data)
|
||||
|
||||
return inference_state
|
||||
|
||||
@@ -6,7 +6,6 @@ service NodeService {
|
||||
rpc SendPrompt (PromptRequest) returns (Tensor) {}
|
||||
rpc SendTensor (TensorRequest) returns (Tensor) {}
|
||||
rpc SendExample (ExampleRequest) returns (Loss) {}
|
||||
rpc GetInferenceResult (GetInferenceResultRequest) returns (InferenceResult) {}
|
||||
rpc CollectTopology (CollectTopologyRequest) returns (Topology) {}
|
||||
rpc SendResult (SendResultRequest) returns (Empty) {}
|
||||
rpc SendOpaqueStatus (SendOpaqueStatusRequest) returns (Empty) {}
|
||||
@@ -24,12 +23,14 @@ message PromptRequest {
|
||||
Shard shard = 1;
|
||||
string prompt = 2;
|
||||
optional string request_id = 3;
|
||||
optional InferenceState inference_state = 4;
|
||||
}
|
||||
|
||||
message TensorRequest {
|
||||
Shard shard = 1;
|
||||
Tensor tensor = 2;
|
||||
optional string request_id = 3;
|
||||
optional InferenceState inference_state = 4;
|
||||
}
|
||||
|
||||
message ExampleRequest {
|
||||
@@ -45,15 +46,6 @@ message Loss {
|
||||
float loss = 1;
|
||||
optional Tensor grads = 2;
|
||||
}
|
||||
|
||||
message GetInferenceResultRequest {
|
||||
string request_id = 1;
|
||||
}
|
||||
|
||||
message InferenceResult {
|
||||
optional Tensor tensor = 1;
|
||||
bool is_finished = 2;
|
||||
}
|
||||
|
||||
message Tensor {
|
||||
bytes tensor_data = 1;
|
||||
@@ -61,6 +53,16 @@ message Tensor {
|
||||
string dtype = 3;
|
||||
}
|
||||
|
||||
message TensorList {
|
||||
repeated Tensor tensors = 1;
|
||||
}
|
||||
|
||||
message InferenceState {
|
||||
map<string, Tensor> tensor_data = 1;
|
||||
map<string, TensorList> tensor_list_data = 2;
|
||||
string other_data_json = 3;
|
||||
}
|
||||
|
||||
message CollectTopologyRequest {
|
||||
repeated string visited = 1;
|
||||
int32 max_depth = 2;
|
||||
@@ -96,7 +98,8 @@ message DeviceCapabilities {
|
||||
message SendResultRequest {
|
||||
string request_id = 1;
|
||||
repeated int32 result = 2;
|
||||
bool is_finished = 3;
|
||||
optional Tensor tensor = 3;
|
||||
bool is_finished = 4;
|
||||
}
|
||||
|
||||
message SendOpaqueStatusRequest {
|
||||
|
||||
@@ -24,59 +24,67 @@ _sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12node_service.proto\x12\x0cnode_service\"S\n\x05Shard\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x13\n\x0bstart_layer\x18\x02 \x01(\x05\x12\x11\n\tend_layer\x18\x03 \x01(\x05\x12\x10\n\x08n_layers\x18\x04 \x01(\x05\"k\n\rPromptRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12\x0e\n\x06prompt\x18\x02 \x01(\t\x12\x17\n\nrequest_id\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\r\n\x0b_request_id\"\x81\x01\n\rTensorRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12$\n\x06tensor\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12\x17\n\nrequest_id\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\r\n\x0b_request_id\"\xde\x01\n\x0e\x45xampleRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12%\n\x07\x65xample\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12$\n\x06target\x18\x03 \x01(\x0b\x32\x14.node_service.Tensor\x12$\n\x06length\x18\x04 \x01(\x0b\x32\x14.node_service.Tensor\x12\r\n\x05train\x18\x05 \x01(\x08\x12\x17\n\nrequest_id\x18\x06 \x01(\tH\x00\x88\x01\x01\x42\r\n\x0b_request_id\"H\n\x04Loss\x12\x0c\n\x04loss\x18\x01 \x01(\x02\x12(\n\x05grads\x18\x02 \x01(\x0b\x32\x14.node_service.TensorH\x00\x88\x01\x01\x42\x08\n\x06_grads\"/\n\x19GetInferenceResultRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\"\\\n\x0fInferenceResult\x12)\n\x06tensor\x18\x01 \x01(\x0b\x32\x14.node_service.TensorH\x00\x88\x01\x01\x12\x13\n\x0bis_finished\x18\x02 \x01(\x08\x42\t\n\x07_tensor\";\n\x06Tensor\x12\x13\n\x0btensor_data\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05\x64type\x18\x03 \x01(\t\"<\n\x16\x43ollectTopologyRequest\x12\x0f\n\x07visited\x18\x01 \x03(\t\x12\x11\n\tmax_depth\x18\x02 \x01(\x05\"\x98\x02\n\x08Topology\x12\x30\n\x05nodes\x18\x01 \x03(\x0b\x32!.node_service.Topology.NodesEntry\x12\x39\n\npeer_graph\x18\x02 \x03(\x0b\x32%.node_service.Topology.PeerGraphEntry\x1aN\n\nNodesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12/\n\x05value\x18\x02 \x01(\x0b\x32 .node_service.DeviceCapabilities:\x02\x38\x01\x1aO\n\x0ePeerGraphEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12,\n\x05value\x18\x02 \x01(\x0b\x32\x1d.node_service.PeerConnections:\x02\x38\x01\"I\n\x0ePeerConnection\x12\r\n\x05to_id\x18\x01 \x01(\t\x12\x18\n\x0b\x64\x65scription\x18\x02 \x01(\tH\x00\x88\x01\x01\x42\x0e\n\x0c_description\"D\n\x0fPeerConnections\x12\x31\n\x0b\x63onnections\x18\x01 \x03(\x0b\x32\x1c.node_service.PeerConnection\"7\n\x0b\x44\x65viceFlops\x12\x0c\n\x04\x66p32\x18\x01 \x01(\x01\x12\x0c\n\x04\x66p16\x18\x02 \x01(\x01\x12\x0c\n\x04int8\x18\x03 \x01(\x01\"k\n\x12\x44\x65viceCapabilities\x12\r\n\x05model\x18\x01 \x01(\t\x12\x0c\n\x04\x63hip\x18\x02 \x01(\t\x12\x0e\n\x06memory\x18\x03 \x01(\x05\x12(\n\x05\x66lops\x18\x04 \x01(\x0b\x32\x19.node_service.DeviceFlops\"L\n\x11SendResultRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06result\x18\x02 \x03(\x05\x12\x13\n\x0bis_finished\x18\x03 \x01(\x08\"=\n\x17SendOpaqueStatusRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06status\x18\x02 \x01(\t\"\x14\n\x12HealthCheckRequest\")\n\x13HealthCheckResponse\x12\x12\n\nis_healthy\x18\x01 
\x01(\x08\"\x07\n\x05\x45mpty2\xf7\x04\n\x0bNodeService\x12\x41\n\nSendPrompt\x12\x1b.node_service.PromptRequest\x1a\x14.node_service.Tensor\"\x00\x12\x41\n\nSendTensor\x12\x1b.node_service.TensorRequest\x1a\x14.node_service.Tensor\"\x00\x12\x41\n\x0bSendExample\x12\x1c.node_service.ExampleRequest\x1a\x12.node_service.Loss\"\x00\x12^\n\x12GetInferenceResult\x12\'.node_service.GetInferenceResultRequest\x1a\x1d.node_service.InferenceResult\"\x00\x12Q\n\x0f\x43ollectTopology\x12$.node_service.CollectTopologyRequest\x1a\x16.node_service.Topology\"\x00\x12\x44\n\nSendResult\x12\x1f.node_service.SendResultRequest\x1a\x13.node_service.Empty\"\x00\x12P\n\x10SendOpaqueStatus\x12%.node_service.SendOpaqueStatusRequest\x1a\x13.node_service.Empty\"\x00\x12T\n\x0bHealthCheck\x12 .node_service.HealthCheckRequest\x1a!.node_service.HealthCheckResponse\"\x00\x62\x06proto3')
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12node_service.proto\x12\x0cnode_service\"S\n\x05Shard\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x13\n\x0bstart_layer\x18\x02 \x01(\x05\x12\x11\n\tend_layer\x18\x03 \x01(\x05\x12\x10\n\x08n_layers\x18\x04 \x01(\x05\"\xbb\x01\n\rPromptRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12\x0e\n\x06prompt\x18\x02 \x01(\t\x12\x17\n\nrequest_id\x18\x03 \x01(\tH\x00\x88\x01\x01\x12:\n\x0finference_state\x18\x04 \x01(\x0b\x32\x1c.node_service.InferenceStateH\x01\x88\x01\x01\x42\r\n\x0b_request_idB\x12\n\x10_inference_state\"\xd1\x01\n\rTensorRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12$\n\x06tensor\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12\x17\n\nrequest_id\x18\x03 \x01(\tH\x00\x88\x01\x01\x12:\n\x0finference_state\x18\x04 \x01(\x0b\x32\x1c.node_service.InferenceStateH\x01\x88\x01\x01\x42\r\n\x0b_request_idB\x12\n\x10_inference_state\"\xde\x01\n\x0e\x45xampleRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12%\n\x07\x65xample\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12$\n\x06target\x18\x03 \x01(\x0b\x32\x14.node_service.Tensor\x12$\n\x06length\x18\x04 \x01(\x0b\x32\x14.node_service.Tensor\x12\r\n\x05train\x18\x05 \x01(\x08\x12\x17\n\nrequest_id\x18\x06 \x01(\tH\x00\x88\x01\x01\x42\r\n\x0b_request_id\"H\n\x04Loss\x12\x0c\n\x04loss\x18\x01 \x01(\x02\x12(\n\x05grads\x18\x02 \x01(\x0b\x32\x14.node_service.TensorH\x00\x88\x01\x01\x42\x08\n\x06_grads\";\n\x06Tensor\x12\x13\n\x0btensor_data\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05\x64type\x18\x03 \x01(\t\"3\n\nTensorList\x12%\n\x07tensors\x18\x01 \x03(\x0b\x32\x14.node_service.Tensor\"\xd2\x02\n\x0eInferenceState\x12\x41\n\x0btensor_data\x18\x01 \x03(\x0b\x32,.node_service.InferenceState.TensorDataEntry\x12J\n\x10tensor_list_data\x18\x02 \x03(\x0b\x32\x30.node_service.InferenceState.TensorListDataEntry\x12\x17\n\x0fother_data_json\x18\x03 \x01(\t\x1aG\n\x0fTensorDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor:\x02\x38\x01\x1aO\n\x13TensorListDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\'\n\x05value\x18\x02 \x01(\x0b\x32\x18.node_service.TensorList:\x02\x38\x01\"<\n\x16\x43ollectTopologyRequest\x12\x0f\n\x07visited\x18\x01 \x03(\t\x12\x11\n\tmax_depth\x18\x02 \x01(\x05\"\x98\x02\n\x08Topology\x12\x30\n\x05nodes\x18\x01 \x03(\x0b\x32!.node_service.Topology.NodesEntry\x12\x39\n\npeer_graph\x18\x02 \x03(\x0b\x32%.node_service.Topology.PeerGraphEntry\x1aN\n\nNodesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12/\n\x05value\x18\x02 \x01(\x0b\x32 .node_service.DeviceCapabilities:\x02\x38\x01\x1aO\n\x0ePeerGraphEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12,\n\x05value\x18\x02 \x01(\x0b\x32\x1d.node_service.PeerConnections:\x02\x38\x01\"I\n\x0ePeerConnection\x12\r\n\x05to_id\x18\x01 \x01(\t\x12\x18\n\x0b\x64\x65scription\x18\x02 \x01(\tH\x00\x88\x01\x01\x42\x0e\n\x0c_description\"D\n\x0fPeerConnections\x12\x31\n\x0b\x63onnections\x18\x01 \x03(\x0b\x32\x1c.node_service.PeerConnection\"7\n\x0b\x44\x65viceFlops\x12\x0c\n\x04\x66p32\x18\x01 \x01(\x01\x12\x0c\n\x04\x66p16\x18\x02 \x01(\x01\x12\x0c\n\x04int8\x18\x03 \x01(\x01\"k\n\x12\x44\x65viceCapabilities\x12\r\n\x05model\x18\x01 \x01(\t\x12\x0c\n\x04\x63hip\x18\x02 \x01(\t\x12\x0e\n\x06memory\x18\x03 \x01(\x05\x12(\n\x05\x66lops\x18\x04 \x01(\x0b\x32\x19.node_service.DeviceFlops\"\x82\x01\n\x11SendResultRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06result\x18\x02 
\x03(\x05\x12)\n\x06tensor\x18\x03 \x01(\x0b\x32\x14.node_service.TensorH\x00\x88\x01\x01\x12\x13\n\x0bis_finished\x18\x04 \x01(\x08\x42\t\n\x07_tensor\"=\n\x17SendOpaqueStatusRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06status\x18\x02 \x01(\t\"\x14\n\x12HealthCheckRequest\")\n\x13HealthCheckResponse\x12\x12\n\nis_healthy\x18\x01 \x01(\x08\"\x07\n\x05\x45mpty2\x97\x04\n\x0bNodeService\x12\x41\n\nSendPrompt\x12\x1b.node_service.PromptRequest\x1a\x14.node_service.Tensor\"\x00\x12\x41\n\nSendTensor\x12\x1b.node_service.TensorRequest\x1a\x14.node_service.Tensor\"\x00\x12\x41\n\x0bSendExample\x12\x1c.node_service.ExampleRequest\x1a\x12.node_service.Loss\"\x00\x12Q\n\x0f\x43ollectTopology\x12$.node_service.CollectTopologyRequest\x1a\x16.node_service.Topology\"\x00\x12\x44\n\nSendResult\x12\x1f.node_service.SendResultRequest\x1a\x13.node_service.Empty\"\x00\x12P\n\x10SendOpaqueStatus\x12%.node_service.SendOpaqueStatusRequest\x1a\x13.node_service.Empty\"\x00\x12T\n\x0bHealthCheck\x12 .node_service.HealthCheckRequest\x1a!.node_service.HealthCheckResponse\"\x00\x62\x06proto3')
|
||||
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'node_service_pb2', _globals)
|
||||
if not _descriptor._USE_C_DESCRIPTORS:
|
||||
DESCRIPTOR._loaded_options = None
|
||||
_globals['_INFERENCESTATE_TENSORDATAENTRY']._loaded_options = None
|
||||
_globals['_INFERENCESTATE_TENSORDATAENTRY']._serialized_options = b'8\001'
|
||||
_globals['_INFERENCESTATE_TENSORLISTDATAENTRY']._loaded_options = None
|
||||
_globals['_INFERENCESTATE_TENSORLISTDATAENTRY']._serialized_options = b'8\001'
|
||||
_globals['_TOPOLOGY_NODESENTRY']._loaded_options = None
|
||||
_globals['_TOPOLOGY_NODESENTRY']._serialized_options = b'8\001'
|
||||
_globals['_TOPOLOGY_PEERGRAPHENTRY']._loaded_options = None
|
||||
_globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_options = b'8\001'
|
||||
_globals['_SHARD']._serialized_start=36
|
||||
_globals['_SHARD']._serialized_end=119
|
||||
_globals['_PROMPTREQUEST']._serialized_start=121
|
||||
_globals['_PROMPTREQUEST']._serialized_end=228
|
||||
_globals['_TENSORREQUEST']._serialized_start=231
|
||||
_globals['_TENSORREQUEST']._serialized_end=360
|
||||
_globals['_EXAMPLEREQUEST']._serialized_start=363
|
||||
_globals['_EXAMPLEREQUEST']._serialized_end=585
|
||||
_globals['_LOSS']._serialized_start=587
|
||||
_globals['_LOSS']._serialized_end=659
|
||||
_globals['_GETINFERENCERESULTREQUEST']._serialized_start=661
|
||||
_globals['_GETINFERENCERESULTREQUEST']._serialized_end=708
|
||||
_globals['_INFERENCERESULT']._serialized_start=710
|
||||
_globals['_INFERENCERESULT']._serialized_end=802
|
||||
_globals['_TENSOR']._serialized_start=804
|
||||
_globals['_TENSOR']._serialized_end=863
|
||||
_globals['_COLLECTTOPOLOGYREQUEST']._serialized_start=865
|
||||
_globals['_COLLECTTOPOLOGYREQUEST']._serialized_end=925
|
||||
_globals['_TOPOLOGY']._serialized_start=928
|
||||
_globals['_TOPOLOGY']._serialized_end=1208
|
||||
_globals['_TOPOLOGY_NODESENTRY']._serialized_start=1049
|
||||
_globals['_TOPOLOGY_NODESENTRY']._serialized_end=1127
|
||||
_globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_start=1129
|
||||
_globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_end=1208
|
||||
_globals['_PEERCONNECTION']._serialized_start=1210
|
||||
_globals['_PEERCONNECTION']._serialized_end=1283
|
||||
_globals['_PEERCONNECTIONS']._serialized_start=1285
|
||||
_globals['_PEERCONNECTIONS']._serialized_end=1353
|
||||
_globals['_DEVICEFLOPS']._serialized_start=1355
|
||||
_globals['_DEVICEFLOPS']._serialized_end=1410
|
||||
_globals['_DEVICECAPABILITIES']._serialized_start=1412
|
||||
_globals['_DEVICECAPABILITIES']._serialized_end=1519
|
||||
_globals['_SENDRESULTREQUEST']._serialized_start=1521
|
||||
_globals['_SENDRESULTREQUEST']._serialized_end=1597
|
||||
_globals['_SENDOPAQUESTATUSREQUEST']._serialized_start=1599
|
||||
_globals['_SENDOPAQUESTATUSREQUEST']._serialized_end=1660
|
||||
_globals['_HEALTHCHECKREQUEST']._serialized_start=1662
|
||||
_globals['_HEALTHCHECKREQUEST']._serialized_end=1682
|
||||
_globals['_HEALTHCHECKRESPONSE']._serialized_start=1684
|
||||
_globals['_HEALTHCHECKRESPONSE']._serialized_end=1725
|
||||
_globals['_EMPTY']._serialized_start=1727
|
||||
_globals['_EMPTY']._serialized_end=1734
|
||||
_globals['_NODESERVICE']._serialized_start=1737
|
||||
_globals['_NODESERVICE']._serialized_end=2368
|
||||
_globals['_PROMPTREQUEST']._serialized_start=122
|
||||
_globals['_PROMPTREQUEST']._serialized_end=309
|
||||
_globals['_TENSORREQUEST']._serialized_start=312
|
||||
_globals['_TENSORREQUEST']._serialized_end=521
|
||||
_globals['_EXAMPLEREQUEST']._serialized_start=524
|
||||
_globals['_EXAMPLEREQUEST']._serialized_end=746
|
||||
_globals['_LOSS']._serialized_start=748
|
||||
_globals['_LOSS']._serialized_end=820
|
||||
_globals['_TENSOR']._serialized_start=822
|
||||
_globals['_TENSOR']._serialized_end=881
|
||||
_globals['_TENSORLIST']._serialized_start=883
|
||||
_globals['_TENSORLIST']._serialized_end=934
|
||||
_globals['_INFERENCESTATE']._serialized_start=937
|
||||
_globals['_INFERENCESTATE']._serialized_end=1275
|
||||
_globals['_INFERENCESTATE_TENSORDATAENTRY']._serialized_start=1123
|
||||
_globals['_INFERENCESTATE_TENSORDATAENTRY']._serialized_end=1194
|
||||
_globals['_INFERENCESTATE_TENSORLISTDATAENTRY']._serialized_start=1196
|
||||
_globals['_INFERENCESTATE_TENSORLISTDATAENTRY']._serialized_end=1275
|
||||
_globals['_COLLECTTOPOLOGYREQUEST']._serialized_start=1277
|
||||
_globals['_COLLECTTOPOLOGYREQUEST']._serialized_end=1337
|
||||
_globals['_TOPOLOGY']._serialized_start=1340
|
||||
_globals['_TOPOLOGY']._serialized_end=1620
|
||||
_globals['_TOPOLOGY_NODESENTRY']._serialized_start=1461
|
||||
_globals['_TOPOLOGY_NODESENTRY']._serialized_end=1539
|
||||
_globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_start=1541
|
||||
_globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_end=1620
|
||||
_globals['_PEERCONNECTION']._serialized_start=1622
|
||||
_globals['_PEERCONNECTION']._serialized_end=1695
|
||||
_globals['_PEERCONNECTIONS']._serialized_start=1697
|
||||
_globals['_PEERCONNECTIONS']._serialized_end=1765
|
||||
_globals['_DEVICEFLOPS']._serialized_start=1767
|
||||
_globals['_DEVICEFLOPS']._serialized_end=1822
|
||||
_globals['_DEVICECAPABILITIES']._serialized_start=1824
|
||||
_globals['_DEVICECAPABILITIES']._serialized_end=1931
|
||||
_globals['_SENDRESULTREQUEST']._serialized_start=1934
|
||||
_globals['_SENDRESULTREQUEST']._serialized_end=2064
|
||||
_globals['_SENDOPAQUESTATUSREQUEST']._serialized_start=2066
|
||||
_globals['_SENDOPAQUESTATUSREQUEST']._serialized_end=2127
|
||||
_globals['_HEALTHCHECKREQUEST']._serialized_start=2129
|
||||
_globals['_HEALTHCHECKREQUEST']._serialized_end=2149
|
||||
_globals['_HEALTHCHECKRESPONSE']._serialized_start=2151
|
||||
_globals['_HEALTHCHECKRESPONSE']._serialized_end=2192
|
||||
_globals['_EMPTY']._serialized_start=2194
|
||||
_globals['_EMPTY']._serialized_end=2201
|
||||
_globals['_NODESERVICE']._serialized_start=2204
|
||||
_globals['_NODESERVICE']._serialized_end=2739
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
|
||||
@@ -49,11 +49,6 @@ class NodeServiceStub(object):
|
||||
request_serializer=node__service__pb2.ExampleRequest.SerializeToString,
|
||||
response_deserializer=node__service__pb2.Loss.FromString,
|
||||
_registered_method=True)
|
||||
self.GetInferenceResult = channel.unary_unary(
|
||||
'/node_service.NodeService/GetInferenceResult',
|
||||
request_serializer=node__service__pb2.GetInferenceResultRequest.SerializeToString,
|
||||
response_deserializer=node__service__pb2.InferenceResult.FromString,
|
||||
_registered_method=True)
|
||||
self.CollectTopology = channel.unary_unary(
|
||||
'/node_service.NodeService/CollectTopology',
|
||||
request_serializer=node__service__pb2.CollectTopologyRequest.SerializeToString,
|
||||
@@ -97,12 +92,6 @@ class NodeServiceServicer(object):
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def GetInferenceResult(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def CollectTopology(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
@@ -145,11 +134,6 @@ def add_NodeServiceServicer_to_server(servicer, server):
|
||||
request_deserializer=node__service__pb2.ExampleRequest.FromString,
|
||||
response_serializer=node__service__pb2.Loss.SerializeToString,
|
||||
),
|
||||
'GetInferenceResult': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.GetInferenceResult,
|
||||
request_deserializer=node__service__pb2.GetInferenceResultRequest.FromString,
|
||||
response_serializer=node__service__pb2.InferenceResult.SerializeToString,
|
||||
),
|
||||
'CollectTopology': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.CollectTopology,
|
||||
request_deserializer=node__service__pb2.CollectTopologyRequest.FromString,
|
||||
@@ -262,33 +246,6 @@ class NodeService(object):
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def GetInferenceResult(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/node_service.NodeService/GetInferenceResult',
|
||||
node__service__pb2.GetInferenceResultRequest.SerializeToString,
|
||||
node__service__pb2.InferenceResult.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def CollectTopology(request,
|
||||
target,
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import os
|
||||
import asyncio
|
||||
from exo.networking.discovery import Discovery
|
||||
from typing import Dict, List, Callable
|
||||
from typing import Dict, List, Callable, Optional
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from exo.networking.discovery import Discovery
|
||||
from exo.topology.device_capabilities import DeviceCapabilities
|
||||
from exo.networking.manual.network_topology_config import NetworkTopology, PeerConfig
|
||||
from exo.helpers import DEBUG_DISCOVERY
|
||||
@@ -13,28 +15,25 @@ class ManualDiscovery(Discovery):
|
||||
self,
|
||||
network_config_path: str,
|
||||
node_id: str,
|
||||
create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle],
|
||||
create_peer_handle: Callable[[str, str, str, DeviceCapabilities], PeerHandle],
|
||||
):
|
||||
self.topology = NetworkTopology.from_path(network_config_path)
|
||||
self.network_config_path = network_config_path
|
||||
self.node_id = node_id
|
||||
self.create_peer_handle = create_peer_handle
|
||||
|
||||
if node_id not in self.topology.peers:
|
||||
raise ValueError(
|
||||
f"Node ID {node_id} not found in network config file {network_config_path}. Please run with `node_id` set to one of the keys in the config file: {[k for k, _ in self.topology.peers]}"
|
||||
)
|
||||
|
||||
self.listen_task = None
|
||||
|
||||
self.known_peers: Dict[str, PeerHandle] = {}
|
||||
self.peers_in_network: Dict[str, PeerConfig] = self.topology.peers
|
||||
self.peers_in_network.pop(node_id)
|
||||
|
||||
self._cached_peers: Dict[str, PeerConfig] = {}
|
||||
self._last_modified_time: Optional[float] = None
|
||||
self._file_executor = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
async def start(self) -> None:
|
||||
self.listen_task = asyncio.create_task(self.task_find_peers_from_config())
|
||||
|
||||
async def stop(self) -> None:
|
||||
if self.listen_task:
|
||||
self.listen_task.cancel()
|
||||
if self.listen_task: self.listen_task.cancel()
|
||||
self._file_executor.shutdown(wait=True)
|
||||
|
||||
async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
|
||||
if wait_for_peers > 0:
|
||||
@@ -47,7 +46,9 @@ class ManualDiscovery(Discovery):
|
||||
async def task_find_peers_from_config(self):
|
||||
if DEBUG_DISCOVERY >= 2: print("Starting task to find peers from config...")
|
||||
while True:
|
||||
for peer_id, peer_config in self.peers_in_network.items():
|
||||
peers_from_config = await self._get_peers()
|
||||
new_known_peers = {}
|
||||
for peer_id, peer_config in peers_from_config.items():
|
||||
try:
|
||||
if DEBUG_DISCOVERY >= 2: print(f"Checking peer {peer_id=} at {peer_config.address}:{peer_config.port}")
|
||||
peer = self.known_peers.get(peer_id)
|
||||
@@ -57,15 +58,44 @@ class ManualDiscovery(Discovery):
|
||||
is_healthy = await peer.health_check()
|
||||
if is_healthy:
|
||||
if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is healthy.")
|
||||
self.known_peers[peer_id] = peer
|
||||
else:
|
||||
if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy.")
|
||||
try:
|
||||
del self.known_peers[peer_id]
|
||||
except KeyError:
|
||||
pass
|
||||
new_known_peers[peer_id] = peer
|
||||
elif DEBUG_DISCOVERY >= 2:
|
||||
print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy. Removing.")
|
||||
except Exception as e:
|
||||
if DEBUG_DISCOVERY >= 2: print(f"Exception occured when attempting to add {peer_id=}: {e}")
|
||||
await asyncio.sleep(1.0)
|
||||
if DEBUG_DISCOVERY >= 2: print(f"Exception occurred when attempting to add {peer_id=}: {e}")
|
||||
self.known_peers = new_known_peers
|
||||
await asyncio.sleep(5.0)
|
||||
|
||||
if DEBUG_DISCOVERY >= 2: print(f"Current known peers: {[peer.id() for peer in self.known_peers.values()]}")
|
||||
|
||||
async def _get_peers(self):
  try:
    loop = asyncio.get_running_loop()
    current_mtime = await loop.run_in_executor(self._file_executor, os.path.getmtime, self.network_config_path)

    if (self._cached_peers is not None and self._last_modified_time is not None and current_mtime <= self._last_modified_time):
      return self._cached_peers

    topology = await loop.run_in_executor(self._file_executor, NetworkTopology.from_path, self.network_config_path)

    if self.node_id not in topology.peers:
      raise ValueError(
        f"Node ID {self.node_id} not found in network config file "
        f"{self.network_config_path}. Please run with `node_id` set to "
        f"one of the keys in the config file: {[k for k, _ in topology.peers]}"
      )

    peers_in_network = topology.peers
    peers_in_network.pop(self.node_id)

    self._cached_peers = peers_in_network
    self._last_modified_time = current_mtime

    return peers_in_network

  except Exception as e:
    if DEBUG_DISCOVERY >= 2:
      print(f"Error when loading network config file from {self.network_config_path}. "
            f"Please update the config file in order to successfully discover peers. "
            f"Exception: {e}")
    return self._cached_peers
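For reference, the network config file that _get_peers re-reads has the same shape as the fixtures used in the tests below; a placeholder example (the peer id, address, and port are made up):

# Placeholder config contents (stored as JSON on disk); same structure as the test fixture below.
example_config = {
  "peers": {
    "node1": {
      "address": "localhost",
      "port": 50051,
      "device_capabilities": {
        "model": "Unknown Model",
        "chip": "Unknown Chip",
        "memory": 0,
        "flops": {"fp32": 0, "fp16": 0, "int8": 0},
      },
    },
  },
}
# ManualDiscovery only reloads the file when its modification time changes.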
|
||||
|
||||
@@ -29,4 +29,4 @@
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
import asyncio
|
||||
import unittest
|
||||
from unittest import mock
|
||||
@@ -14,8 +15,12 @@ class TestSingleNodeManualDiscovery(unittest.IsolatedAsyncioTestCase):
|
||||
async def asyncSetUp(self):
|
||||
self.peer1 = mock.AsyncMock()
|
||||
self.peer1.connect = mock.AsyncMock()
|
||||
self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1)
|
||||
_ = self.discovery1.start()
|
||||
self.discovery1 = ManualDiscovery(
|
||||
root_path,
|
||||
"node1",
|
||||
create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1,
|
||||
)
|
||||
await self.discovery1.start()
|
||||
|
||||
async def asyncTearDown(self):
|
||||
await self.discovery1.stop()
|
||||
@@ -33,8 +38,16 @@ class TestManualDiscovery(unittest.IsolatedAsyncioTestCase):
self.peer2 = mock.AsyncMock()
self.peer1.connect = mock.AsyncMock()
self.peer2.connect = mock.AsyncMock()
self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1)
self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer2)
self.discovery1 = ManualDiscovery(
root_path,
"node1",
create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1,
)
self.discovery2 = ManualDiscovery(
root_path,
"node2",
create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer2,
)
await self.discovery1.start()
await self.discovery2.start()

@@ -63,8 +76,16 @@ class TestManualDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase):
self.server2 = GRPCServer(self.node2, config.peers["node2"].address, config.peers["node2"].port)
await self.server1.start()
await self.server2.start()
self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities))
self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities))
self.discovery1 = ManualDiscovery(
root_path,
"node1",
create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities),
)
self.discovery2 = ManualDiscovery(
root_path,
"node2",
create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities),
)
await self.discovery1.start()
await self.discovery2.start()

@@ -98,6 +119,63 @@ class TestManualDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase):
self.assertFalse(await peers1[0].is_connected())
self.assertFalse(await peers2[0].is_connected())

async def test_dynamic_config_update(self):
initial_peers = await self.discovery1.discover_peers(wait_for_peers=1)
self.assertEqual(len(initial_peers), 1)

# Save original config for cleanup
with open(root_path, "r") as f:
original_config = json.load(f)

try:
updated_config = {
"peers": {
**original_config["peers"],
"node3": {
"address": "localhost",
"port": 50053,
"device_capabilities": {
"model": "Unknown Model",
"chip": "Unknown Chip",
"memory": 0,
"flops": {"fp32": 0, "fp16": 0, "int8": 0},
},
},
}
}

with open(root_path, "w") as f:
json.dump(updated_config, f, indent=2)

node3 = mock.AsyncMock(spec=Node)
server3 = GRPCServer(node3, "localhost", 50053)
await server3.start()

try:
# Wait for the config to be reloaded
await asyncio.sleep(1.5)

updated_peers = await self.discovery1.discover_peers(wait_for_peers=2)
self.assertEqual(len(updated_peers), 2)

for peer in updated_peers:
await peer.connect()
self.assertTrue(await peer.is_connected())

finally:
await server3.stop()

finally:
# Restore the original config file
with open(root_path, "w") as f:
json.dump(original_config, f, indent=2)

# Wait for the config to be reloaded again
await asyncio.sleep(1.5)

updated_peers = await self.discovery1.discover_peers(wait_for_peers=1)
self.assertEqual(len(updated_peers), 1)


if __name__ == "__main__":
asyncio.run(unittest.main())

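The dynamic-update test above rewrites the JSON network config on disk and waits for ManualDiscovery to pick it up on the next reload. For reference, a config with a single extra peer can be produced like this; the filename is illustrative and should match whatever root_path points at in the tests:

import json

config = {
  "peers": {
    "node3": {
      "address": "localhost",
      "port": 50053,
      "device_capabilities": {
        "model": "Unknown Model",
        "chip": "Unknown Chip",
        "memory": 0,
        "flops": {"fp32": 0, "fp16": 0, "int8": 0},
      },
    },
  },
}

with open("test_network_topology_config.json", "w") as f:
  json.dump(config, f, indent=2)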
@@ -51,10 +51,6 @@ class PeerHandle(ABC):
async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
pass

@abstractmethod
async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
pass

@abstractmethod
async def collect_topology(self, visited: set[str], max_depth: int) -> Topology:
pass

@@ -40,7 +40,7 @@ class TailscaleDiscovery(Discovery):
self.update_task = None

async def start(self):
self.device_capabilities = device_capabilities()
self.device_capabilities = await device_capabilities()
self.discovery_task = asyncio.create_task(self.task_discover_peers())
self.cleanup_task = asyncio.create_task(self.task_cleanup_peers())
self.update_task = asyncio.create_task(self.task_update_device_posture_attributes())

@@ -3,7 +3,7 @@ import json
import socket
import time
import traceback
from typing import List, Dict, Callable, Tuple, Coroutine
from typing import List, Dict, Callable, Tuple, Coroutine, Optional
from exo.networking.discovery import Discovery
from exo.networking.peer_handle import PeerHandle
from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES
@@ -23,15 +23,29 @@ class ListenProtocol(asyncio.DatagramProtocol):
asyncio.create_task(self.on_message(data, addr))


def get_broadcast_address(ip_addr: str) -> str:
try:
# Split IP into octets and create broadcast address for the subnet
ip_parts = ip_addr.split('.')
return f"{ip_parts[0]}.{ip_parts[1]}.{ip_parts[2]}.255"
except:
return "255.255.255.255"


class BroadcastProtocol(asyncio.DatagramProtocol):
def __init__(self, message: str, broadcast_port: int):
def __init__(self, message: str, broadcast_port: int, source_ip: str):
self.message = message
self.broadcast_port = broadcast_port
self.source_ip = source_ip

def connection_made(self, transport):
sock = transport.get_extra_info("socket")
sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
transport.sendto(self.message.encode("utf-8"), ("<broadcast>", self.broadcast_port))
# Try both subnet-specific and global broadcast
broadcast_addr = get_broadcast_address(self.source_ip)
transport.sendto(self.message.encode("utf-8"), (broadcast_addr, self.broadcast_port))
if broadcast_addr != "255.255.255.255":
transport.sendto(self.message.encode("utf-8"), ("255.255.255.255", self.broadcast_port))


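get_broadcast_address above assumes a /24 subnet and falls back to the limited broadcast address. Where the actual prefix length is known, the standard-library ipaddress module can compute the subnet broadcast directly; a small sketch, with the /24 default only mirroring the assumption above:

import ipaddress

def broadcast_for(ip_addr: str, prefix_len: int = 24) -> str:
  # Build the interface's network and return its broadcast address,
  # falling back to the limited broadcast address on bad input.
  try:
    network = ipaddress.ip_network(f"{ip_addr}/{prefix_len}", strict=False)
    return str(network.broadcast_address)
  except ValueError:
    return "255.255.255.255"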
class UDPDiscovery(Discovery):
@@ -45,7 +59,8 @@ class UDPDiscovery(Discovery):
broadcast_interval: int = 2.5,
discovery_timeout: int = 30,
device_capabilities: DeviceCapabilities = UNKNOWN_DEVICE_CAPABILITIES,
allowed_node_ids: List[str] = None,
allowed_node_ids: Optional[List[str]] = None,
allowed_interface_types: Optional[List[str]] = None,
):
self.node_id = node_id
self.node_port = node_port
@@ -56,13 +71,14 @@ class UDPDiscovery(Discovery):
self.discovery_timeout = discovery_timeout
self.device_capabilities = device_capabilities
self.allowed_node_ids = allowed_node_ids
self.allowed_interface_types = allowed_interface_types
self.known_peers: Dict[str, Tuple[PeerHandle, float, float, int]] = {}
self.broadcast_task = None
self.listen_task = None
self.cleanup_task = None

async def start(self):
self.device_capabilities = device_capabilities()
self.device_capabilities = await device_capabilities()
self.broadcast_task = asyncio.create_task(self.task_broadcast_presence())
self.listen_task = asyncio.create_task(self.task_listen_for_peers())
self.cleanup_task = asyncio.create_task(self.task_cleanup_peers())
@@ -82,11 +98,7 @@ class UDPDiscovery(Discovery):
return [peer_handle for peer_handle, _, _, _ in self.known_peers.values()]

async def task_broadcast_presence(self):
if DEBUG_DISCOVERY >= 2: print("Starting task_broadcast_presence...")

while True:
# Explicitly broadcasting on all assigned ips since broadcasting on `0.0.0.0` on MacOS does not broadcast over
# the Thunderbolt bridge when other connection modalities exist such as WiFi or Ethernet
for addr, interface_name in get_all_ip_addresses_and_interfaces():
interface_priority, interface_type = await get_interface_priority_and_type(interface_name)
message = json.dumps({
@@ -94,16 +106,26 @@ class UDPDiscovery(Discovery):
"node_id": self.node_id,
"grpc_port": self.node_port,
"device_capabilities": self.device_capabilities.to_dict(),
"priority": interface_priority, # TODO: Prioritise interfaces based on bandwidth, latency, and jitter e.g. prioritise Thunderbolt over WiFi.
"priority": interface_priority,
"interface_name": interface_name,
"interface_type": interface_type,
})
if DEBUG_DISCOVERY >= 3: print(f"Broadcasting presence at ({addr} - {interface_name} - {interface_priority}): {message}")

transport = None
try:
transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(lambda: BroadcastProtocol(message, self.broadcast_port), local_addr=(addr, 0), family=socket.AF_INET)
if DEBUG_DISCOVERY >= 3: print(f"Broadcasting presence at ({addr} - {interface_name} - {interface_priority})")
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
try:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
except AttributeError:
pass
sock.bind((addr, 0))

transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(
lambda: BroadcastProtocol(message, self.broadcast_port, addr),
sock=sock
)
except Exception as e:
print(f"Error in broadcast presence ({addr} - {interface_name} - {interface_priority}): {e}")
finally:
@@ -111,7 +133,7 @@ class UDPDiscovery(Discovery):
try: transport.close()
except Exception as e:
if DEBUG_DISCOVERY >= 2: print(f"Error closing transport: {e}")
if DEBUG_DISCOVERY >= 2: traceback.print_exc()

await asyncio.sleep(self.broadcast_interval)

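The broadcast loop above now builds and binds its own broadcast-capable UDP socket per interface address before handing it to asyncio. A minimal standalone version of that pattern (the message, address, and port are placeholders):

import asyncio
import socket

async def broadcast_once(message: str, local_addr: str, broadcast_port: int):
  # Pre-configure a UDP socket bound to one interface address, then let
  # asyncio wrap it as a datagram transport.
  sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
  sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
  sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
  sock.bind((local_addr, 0))
  loop = asyncio.get_running_loop()
  transport, _ = await loop.create_datagram_endpoint(asyncio.DatagramProtocol, sock=sock)
  try:
    transport.sendto(message.encode("utf-8"), ("255.255.255.255", broadcast_port))
  finally:
    transport.close()

# e.g. asyncio.run(broadcast_once("hello", "192.168.1.5", 5678))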
async def on_listen_message(self, data, addr):
@@ -147,6 +169,12 @@ class UDPDiscovery(Discovery):
peer_prio = message["priority"]
peer_interface_name = message["interface_name"]
peer_interface_type = message["interface_type"]

# Skip if interface type is not in allowed list
if self.allowed_interface_types and peer_interface_type not in self.allowed_interface_types:
if DEBUG_DISCOVERY >= 2: print(f"Ignoring peer {peer_id} as its interface type {peer_interface_type} is not in the allowed interface types list")
return

device_capabilities = DeviceCapabilities(**message["device_capabilities"])

if peer_id not in self.known_peers or self.known_peers[peer_id][0].addr() != f"{peer_host}:{peer_port}":

@@ -8,14 +8,14 @@ from typing import List, Dict, Optional, Tuple, Union, Set
|
||||
from exo.networking import Discovery, PeerHandle, Server
|
||||
from exo.inference.inference_engine import InferenceEngine, Shard
|
||||
from exo.topology.topology import Topology
|
||||
from exo.topology.device_capabilities import device_capabilities
|
||||
from exo.topology.device_capabilities import device_capabilities, UNKNOWN_DEVICE_CAPABILITIES
|
||||
from exo.topology.partitioning_strategy import Partition, PartitioningStrategy, map_partitions_to_shards
|
||||
from exo import DEBUG
|
||||
from exo.helpers import AsyncCallbackSystem
|
||||
from exo.viz.topology_viz import TopologyViz
|
||||
from exo.download.hf.hf_helpers import RepoProgressEvent
|
||||
from exo.download.download_progress import RepoProgressEvent
|
||||
from exo.inference.inference_engine import get_inference_engine, InferenceEngine
|
||||
from exo.download.hf.hf_shard_download import HFShardDownloader
|
||||
from exo.download.shard_download import ShardDownloader
|
||||
|
||||
class Node:
|
||||
def __init__(
|
||||
@@ -24,20 +24,21 @@ class Node:
|
||||
server: Server,
|
||||
inference_engine: InferenceEngine,
|
||||
discovery: Discovery,
|
||||
shard_downloader: ShardDownloader,
|
||||
partitioning_strategy: PartitioningStrategy = None,
|
||||
max_generate_tokens: int = 1024,
|
||||
default_sample_temperature: float = 0.0,
|
||||
topology_viz: Optional[TopologyViz] = None,
|
||||
shard_downloader: Optional[HFShardDownloader] = None,
|
||||
):
|
||||
self.id = _id
|
||||
self.inference_engine = inference_engine
|
||||
self.server = server
|
||||
self.discovery = discovery
|
||||
self.shard_downloader = shard_downloader
|
||||
self.partitioning_strategy = partitioning_strategy
|
||||
self.peers: List[PeerHandle] = {}
|
||||
self.topology: Topology = Topology()
|
||||
self.device_capabilities = device_capabilities()
|
||||
self.device_capabilities = UNKNOWN_DEVICE_CAPABILITIES
|
||||
self.buffered_token_output: Dict[str, Tuple[List[int], bool]] = {}
|
||||
self.buffered_logits: Dict[str, List[np.ndarray]] = {}
|
||||
self.buffered_inputs: Dict[str, List[np.ndarray]] = {}
|
||||
@@ -52,10 +53,10 @@ class Node:
|
||||
self._on_opaque_status.register("node_status").on_next(self.on_node_status)
|
||||
self.node_download_progress: Dict[str, RepoProgressEvent] = {}
|
||||
self.topology_inference_engines_pool: List[List[str]] = []
|
||||
self.shard_downloader = shard_downloader
|
||||
self.outstanding_requests = {}
|
||||
|
||||
async def start(self, wait_for_peers: int = 0) -> None:
|
||||
self.device_capabilities = await device_capabilities()
|
||||
await self.server.start()
|
||||
await self.discovery.start()
|
||||
await self.update_peers(wait_for_peers)
|
||||
@@ -70,25 +71,28 @@ class Node:
def on_node_status(self, request_id, opaque_status):
try:
status_data = json.loads(opaque_status)
if status_data.get("type", "") == "supported_inference_engines":
status_type = status_data.get("type", "")
if status_type == "supported_inference_engines":
node_id = status_data.get("node_id")
engines = status_data.get("engines", [])
self.topology_inference_engines_pool.append(engines)
if status_data.get("type", "") == "node_status":
elif status_type == "node_status":
if status_data.get("status", "").startswith("start_"):
self.current_topology.active_node_id = status_data.get("node_id")
elif status_data.get("status", "").startswith("end_"):
if status_data.get("node_id") == self.current_topology.active_node_id:
self.current_topology.active_node_id = None

download_progress = None
if status_data.get("type", "") == "download_progress":
if status_type == "download_progress":
if DEBUG >= 8: print(f"Download progress from {status_data.get('node_id')}: {status_data.get('progress')}")
download_progress = RepoProgressEvent.from_dict(status_data.get('progress'))
self.node_download_progress[status_data.get('node_id')] = download_progress

if self.topology_viz:
self.topology_viz.update_visualization(self.topology, self.partitioning_strategy.partition(self.topology), self.id, self.node_download_progress)
except Exception as e:
if DEBUG >= 1: print(f"Error updating visualization: {e}")
if DEBUG >= 1: print(f"Error on_node_status: {e}")
if DEBUG >= 1: traceback.print_exc()

def get_supported_inference_engines(self):
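on_node_status above dispatches on the "type" field of the opaque status JSON that peers broadcast. For illustration, the kind of payload it handles looks roughly like this (field values are examples only):

import json

opaque_status = json.dumps({
  "type": "node_status",
  "node_id": "node1",
  "status": "start_process_prompt",
})
# A download update instead carries type "download_progress" plus a
# "progress" dict, and an engines announcement uses type
# "supported_inference_engines" with a "node_id" and an "engines" list.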
@@ -107,44 +111,58 @@ class Node:
def get_topology_inference_engines(self) -> List[List[str]]:
return self.topology_inference_engines_pool

token_count = 0
first_token_time = 0
async def process_inference_result(
self,
shard,
result: np.ndarray,
request_id: Optional[str] = None,
inference_state: Optional[dict] = None,
):
if request_id not in self.buffered_token_output:
self.buffered_token_output[request_id] = ([], False)
is_finished = len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
if shard.is_last_layer() and not is_finished:
token = await self.inference_engine.sample(result, temp=self.default_sample_temperature)
await self.inference_engine.ensure_shard(shard)
self.buffered_token_output[request_id][0].append(token.item())
is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id or is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
asyncio.create_task(self.broadcast_result(request_id, *self.buffered_token_output[request_id]))
forward = token.reshape(1, -1)
self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
asyncio.create_task(self.broadcast_result(request_id, self.buffered_token_output[request_id][0], is_finished))
if shard.model_id != 'stable-diffusion-2-1-base':
if request_id not in self.buffered_token_output:
self.buffered_token_output[request_id] = ([], False)
is_finished = len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
if shard.is_last_layer() and not is_finished:
token = await self.inference_engine.sample(result, temp=self.default_sample_temperature)
await self.inference_engine.ensure_shard(shard)
self.buffered_token_output[request_id][0].append(token.item())
is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id or is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
forward = token.reshape(1, -1)
intermediate_result = [self.buffered_token_output[request_id][0][-1]]
else:
forward = result
else:
await self.inference_engine.ensure_shard(shard)
is_finished = inference_state.get("is_finished", False)
intermediate_result, inference_state = self.handle_stable_diffusion(inference_state, result)
forward = result
if shard.is_last_layer():
self.trigger_on_token_callbacks(request_id, intermediate_result, is_finished)
asyncio.create_task(self.broadcast_result(request_id, intermediate_result, is_finished))

if is_finished:
self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
if shard.model_id != 'stable-diffusion-2-1-base':
self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
self.outstanding_requests.pop(request_id)
else:
self.outstanding_requests[request_id] = "waiting"
asyncio.create_task(self.forward_tensor(shard, forward, request_id, self.get_partition_index(offset = 1)))
asyncio.create_task(self.forward_tensor(shard, forward, request_id, self.get_partition_index(offset = 1), inference_state))

return np.array(self.buffered_token_output[request_id][0]) if shard.model_id != 'stable-diffusion-2-1-base' else intermediate_result

return np.array(self.buffered_token_output[request_id][0])

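The bookkeeping above keeps, per request, a growing token list plus a finished flag, and stops once the sampler returns the tokenizer's EOS id or the max-token budget is hit. A compact sketch of just that termination rule, with illustrative names:

from typing import Dict, List, Tuple

buffered_token_output: Dict[str, Tuple[List[int], bool]] = {}

def append_token(request_id: str, token: int, eos_token_id: int, max_generate_tokens: int) -> bool:
  # Returns True when generation for this request should stop.
  tokens, _ = buffered_token_output.setdefault(request_id, ([], False))
  tokens.append(token)
  finished = token == eos_token_id or len(tokens) >= max_generate_tokens
  buffered_token_output[request_id] = (tokens, finished)
  return finished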
async def process_prompt(
|
||||
self,
|
||||
base_shard: Shard,
|
||||
prompt: str,
|
||||
request_id: Optional[str] = None,
|
||||
inference_state: Optional[dict] = {},
|
||||
) -> Optional[np.ndarray]:
|
||||
shard = self.get_current_shard(base_shard)
|
||||
start_time = time.perf_counter_ns()
|
||||
asyncio.create_task(
|
||||
self.broadcast_opaque_status(
|
||||
request_id,
|
||||
@@ -160,7 +178,7 @@ class Node:
|
||||
)
|
||||
)
|
||||
start_time = time.perf_counter_ns()
|
||||
resp = await self._process_prompt(base_shard, prompt, request_id)
|
||||
resp = await self._process_prompt(base_shard, prompt, request_id, inference_state)
|
||||
end_time = time.perf_counter_ns()
|
||||
elapsed_time_ns = end_time - start_time
|
||||
asyncio.create_task(
|
||||
@@ -175,27 +193,26 @@ class Node:
|
||||
"prompt": prompt,
|
||||
"request_id": request_id,
|
||||
"elapsed_time_ns": elapsed_time_ns,
|
||||
"result_size": resp.size if resp is not None else 0,
|
||||
}),
|
||||
)
|
||||
)
|
||||
return resp
|
||||
if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=} {elapsed_time_ns=}")
|
||||
|
||||
async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.ndarray]:
|
||||
async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[dict] = None) -> Optional[np.ndarray]:
|
||||
if request_id is None:
|
||||
request_id = str(uuid.uuid4())
|
||||
shard = self.get_current_shard(base_shard)
|
||||
|
||||
if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=}")
|
||||
|
||||
if not shard.is_first_layer():
|
||||
if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=}")
|
||||
self.outstanding_requests[request_id] = "waiting"
|
||||
resp = await self.forward_prompt(shard, prompt, request_id, 0)
|
||||
resp = await self.forward_prompt(shard, prompt, request_id, 0, inference_state)
|
||||
return None
|
||||
else:
|
||||
self.outstanding_requests[request_id] = "processing"
|
||||
result = await self.inference_engine.infer_prompt(request_id, shard, prompt)
|
||||
ret = await self.process_inference_result(shard, result, request_id)
|
||||
result, inference_state = await self.inference_engine.infer_prompt(request_id, shard, prompt, inference_state)
|
||||
ret = await self.process_inference_result(shard, result, request_id, inference_state)
|
||||
return result
|
||||
|
||||
async def enqueue_example(
|
||||
@@ -308,7 +325,7 @@ class Node:
|
||||
loss, grad = await self.inference_engine.train(request_id, shard, example, target, length)
|
||||
else:
|
||||
self.outstanding_requests[request_id] = "preprocessing"
|
||||
step = await self.inference_engine.infer_tensor(request_id, shard, example)
|
||||
step, _ = await self.inference_engine.infer_tensor(request_id, shard, example)
|
||||
self.outstanding_requests[request_id] = "waiting"
|
||||
loss, backgrad = await self.forward_example(shard, step, target, length, train, request_id, self.get_partition_index(offset = 1))
|
||||
self.outstanding_requests[request_id] = "training"
|
||||
@@ -324,7 +341,7 @@ class Node:
|
||||
loss = await self.inference_engine.evaluate(request_id, shard, example, target, length)
|
||||
else:
|
||||
self.outstanding_requests[request_id] = "preprocessing"
|
||||
step = await self.inference_engine.infer_tensor(request_id, shard, example)
|
||||
step, _ = await self.inference_engine.infer_tensor(request_id, shard, example)
|
||||
self.outstanding_requests[request_id] = "waiting"
|
||||
loss = await self.forward_example(shard, step, target, length, train, request_id, self.get_partition_index(offset = 1))
|
||||
self.outstanding_requests.pop(request_id)
|
||||
@@ -340,65 +357,35 @@ class Node:
|
||||
base_shard: Shard,
|
||||
tensor: np.ndarray,
|
||||
request_id: Optional[str] = None,
|
||||
inference_state: Optional[dict] = None,
|
||||
) -> Optional[np.ndarray]:
|
||||
shard = self.get_current_shard(base_shard)
|
||||
asyncio.create_task(
|
||||
self.broadcast_opaque_status(
|
||||
request_id,
|
||||
json.dumps({
|
||||
"type": "node_status",
|
||||
"node_id": self.id,
|
||||
"status": "start_process_tensor",
|
||||
"base_shard": base_shard.to_dict(),
|
||||
"shard": shard.to_dict(),
|
||||
"tensor_size": tensor.size,
|
||||
"tensor_shape": tensor.shape,
|
||||
"request_id": request_id,
|
||||
}),
|
||||
)
|
||||
)
|
||||
start_time = time.perf_counter_ns()
|
||||
resp = await self._process_tensor(shard, tensor, request_id)
|
||||
resp = await self._process_tensor(shard, tensor, request_id, inference_state)
|
||||
end_time = time.perf_counter_ns()
|
||||
elapsed_time_ns = end_time - start_time
|
||||
asyncio.create_task(
|
||||
self.broadcast_opaque_status(
|
||||
request_id,
|
||||
json.dumps({
|
||||
"type": "node_status",
|
||||
"node_id": self.id,
|
||||
"status": "end_process_tensor",
|
||||
"base_shard": base_shard.to_dict(),
|
||||
"shard": shard.to_dict(),
|
||||
"request_id": request_id,
|
||||
"elapsed_time_ns": elapsed_time_ns,
|
||||
"result_size": resp.size if resp is not None else 0,
|
||||
}),
|
||||
)
|
||||
)
|
||||
return resp
|
||||
if DEBUG >= 2: print(f"[{request_id}] process_tensor: {base_shard=} {shard=} {tensor.size=} {tensor.shape=} {elapsed_time_ns=}")
|
||||
|
||||
async def _process_tensor(
|
||||
self,
|
||||
base_shard: Shard,
|
||||
tensor: np.ndarray,
|
||||
request_id: Optional[str] = None,
|
||||
inference_state: Optional[dict] = None,
|
||||
) -> Optional[np.ndarray]:
|
||||
if request_id is None:
|
||||
request_id = str(uuid.uuid4())
|
||||
shard = self.get_current_shard(base_shard)
|
||||
|
||||
if DEBUG >= 1: print(f"[{request_id}] process_tensor: {tensor.size=} {tensor.shape=}")
|
||||
try:
|
||||
self.outstanding_requests[request_id] = "processing"
|
||||
result = await self.inference_engine.infer_tensor(request_id, shard, tensor)
|
||||
ret = await self.process_inference_result(shard, result, request_id)
|
||||
result, inference_state = await self.inference_engine.infer_tensor(request_id, shard, tensor, inference_state)
|
||||
ret = await self.process_inference_result(shard, result, request_id, inference_state)
|
||||
return ret
|
||||
except Exception as e:
|
||||
self.outstanding_requests.pop(request_id)
|
||||
print(f"Error processing tensor for shard {shard}: {e}")
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
async def forward_example(
|
||||
self,
|
||||
@@ -427,19 +414,20 @@ class Node:
|
||||
prompt: str,
|
||||
request_id: str,
|
||||
target_index: int,
|
||||
inference_state: Optional[dict] = None,
|
||||
) -> None:
|
||||
if DEBUG >= 1: print(f"target partition index: {target_index}")
|
||||
target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
|
||||
next_shard = self.get_current_shard(base_shard, target_index)
|
||||
if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. next shard: {next_shard}")
|
||||
if target_id == self.id:
|
||||
await self.process_prompt(next_shard, prompt, request_id)
|
||||
await self.process_prompt(next_shard, prompt, request_id, inference_state)
|
||||
else:
|
||||
target_peer = next((p for p in self.peers if p.id() == target_id), None)
|
||||
if not target_peer:
|
||||
raise ValueError(f"Peer for {target_index} not found")
|
||||
if DEBUG >= 1: print(f"Sending prompt to {target_peer.id()}: {prompt}")
|
||||
await target_peer.send_prompt(next_shard, prompt, request_id=request_id)
|
||||
await target_peer.send_prompt(next_shard, prompt, request_id=request_id, inference_state=inference_state)
|
||||
|
||||
async def forward_tensor(
|
||||
self,
|
||||
@@ -447,19 +435,20 @@ class Node:
|
||||
tensor: np.ndarray,
|
||||
request_id: str,
|
||||
target_index: int,
|
||||
inference_state: Optional[dict] = None,
|
||||
) -> None:
|
||||
if DEBUG >= 1: print(f"target partition index: {target_index}")
|
||||
target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
|
||||
next_shard = self.get_current_shard(base_shard, target_index)
|
||||
if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. target shard: {next_shard}")
|
||||
if target_id == self.id:
|
||||
await self.process_tensor(next_shard, tensor, request_id)
|
||||
await self.process_tensor(next_shard, tensor, request_id, inference_state)
|
||||
else:
|
||||
target_peer = next((p for p in self.peers if p.id() == target_id), None)
|
||||
if not target_peer:
|
||||
raise ValueError(f"Peer for {target_index} not found")
|
||||
if DEBUG >= 1: print(f"Sending tensor to {target_peer.id()}: {tensor}")
|
||||
await target_peer.send_tensor(next_shard, tensor, request_id=request_id)
|
||||
await target_peer.send_tensor(next_shard, tensor, request_id=request_id, inference_state=inference_state)
|
||||
|
||||
def get_partition_index(self, offset: int = 0):
|
||||
if not self.partitioning_strategy:
|
||||
@@ -542,18 +531,13 @@ class Node:
|
||||
try:
|
||||
did_peers_change = await self.update_peers()
|
||||
if DEBUG >= 2: print(f"{did_peers_change=}")
|
||||
await self.collect_topology(set())
|
||||
if did_peers_change:
|
||||
await self.collect_topology(set())
|
||||
await self.select_best_inference_engine()
|
||||
except Exception as e:
|
||||
print(f"Error collecting topology: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
|
||||
if request_id not in self.buffered_token_output:
|
||||
return None, False
|
||||
return np.array(self.buffered_token_output[request_id][0]), self.buffered_token_output[request_id][1]
|
||||
|
||||
async def collect_topology(self, visited: set[str], max_depth: int = 4) -> Topology:
|
||||
next_topology = Topology()
|
||||
next_topology.update_node(self.id, self.device_capabilities)
|
||||
@@ -598,10 +582,11 @@ class Node:
|
||||
return self._on_opaque_status
|
||||
|
||||
def trigger_on_token_callbacks(self, request_id: str, tokens: List[int], is_finished: bool) -> None:
|
||||
if DEBUG >= 2: print(f"Triggering all on_token callbacks with {request_id=} num_tokens={len(tokens)} {is_finished=}")
|
||||
if DEBUG >= 2: print(f"Triggering all on_token callbacks with {request_id=} {tokens=} {is_finished=}")
|
||||
self.on_token.trigger_all(request_id, tokens, is_finished)
|
||||
|
||||
async def broadcast_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
|
||||
if DEBUG >= 2: print(f"Broadcasting result: {request_id=} {result=} {is_finished=}")
|
||||
async def send_result_to_peer(peer):
|
||||
try:
|
||||
await asyncio.wait_for(peer.send_result(request_id, result, is_finished), timeout=15.0)
|
||||
@@ -632,3 +617,12 @@ class Node:
|
||||
@property
|
||||
def current_topology(self) -> Topology:
|
||||
return self.topology
|
||||
|
||||
def handle_stable_diffusion(self, inference_state, result):
if inference_state['is_step_finished']:
inference_state['step'] += 1
progress = [inference_state['step'], inference_state['total_steps']]
intermediate_result = result
if progress[0] == progress[1]:
intermediate_result = result
return intermediate_result, inference_state

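handle_stable_diffusion drives the image pipeline by counting denoising steps in the shared inference_state dict. The keys it relies on look roughly like this (the values shown are illustrative):

inference_state = {
  "step": 0,                 # completed denoising steps so far
  "total_steps": 50,         # steps requested for this generation
  "is_step_finished": True,  # set by the engine when a step completes
  "is_finished": False,      # set once the final image is ready
}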
@@ -1,10 +1,11 @@
|
||||
import unittest
|
||||
from unittest.mock import Mock, AsyncMock
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from .node import Node
|
||||
from exo.networking.peer_handle import PeerHandle
|
||||
|
||||
from exo.download.shard_download import NoopShardDownloader
|
||||
|
||||
class TestNode(unittest.IsolatedAsyncioTestCase):
|
||||
def setUp(self):
|
||||
@@ -21,7 +22,7 @@ class TestNode(unittest.IsolatedAsyncioTestCase):
|
||||
mock_peer2.id.return_value = "peer2"
|
||||
self.mock_discovery.discover_peers = AsyncMock(return_value=[mock_peer1, mock_peer2])
|
||||
|
||||
self.node = Node("test_node", self.mock_server, self.mock_inference_engine, "localhost", 50051, self.mock_discovery)
|
||||
self.node = Node("test_node", self.mock_server, self.mock_inference_engine, "localhost", 50051, self.mock_discovery, NoopShardDownloader())
|
||||
|
||||
async def asyncSetUp(self):
|
||||
await self.node.start()
|
||||
@@ -55,3 +56,11 @@ class TestNode(unittest.IsolatedAsyncioTestCase):
|
||||
await self.node.process_tensor(input_tensor, None)
|
||||
|
||||
self.node.inference_engine.process_shard.assert_called_once_with(input_tensor)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_node_capabilities():
|
||||
node = Node()
|
||||
await node.initialize()
|
||||
caps = await node.get_device_capabilities()
|
||||
assert caps is not None
|
||||
assert caps.model != ""
|
||||
|
||||
exo/orchestration/tracing.py (new file, 166 lines)
@@ -0,0 +1,166 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Any
|
||||
from opentelemetry import trace, context
|
||||
from opentelemetry.trace import Status, StatusCode, SpanContext
|
||||
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
|
||||
from contextlib import contextmanager
|
||||
import time
|
||||
from threading import Lock
|
||||
|
||||
@dataclass
|
||||
class TraceContext:
|
||||
request_id: str
|
||||
sequence_number: int
|
||||
current_span: Optional[trace.Span] = None
|
||||
trace_parent: Optional[str] = None
|
||||
token_group_span: Optional[trace.Span] = None
|
||||
token_count: int = 0
|
||||
token_group_size: int = 10 # Default group size
|
||||
request_span: Optional[trace.Span] = None # Track the main request span
|
||||
|
||||
class Tracer:
|
||||
def __init__(self):
|
||||
self.tracer = trace.get_tracer("exo")
|
||||
self.contexts: Dict[str, TraceContext] = {}
|
||||
self._lock = Lock()
|
||||
self.propagator = TraceContextTextMapPropagator()
|
||||
|
||||
def get_context(self, request_id: str) -> Optional[TraceContext]:
|
||||
with self._lock:
|
||||
return self.contexts.get(request_id)
|
||||
|
||||
def set_context(self, request_id: str, context: TraceContext):
|
||||
with self._lock:
|
||||
self.contexts[request_id] = context
|
||||
|
||||
def inject_context(self, span: trace.Span) -> str:
|
||||
"""Inject current span context into carrier for propagation"""
|
||||
carrier = {}
|
||||
ctx = trace.set_span_in_context(span)
|
||||
self.propagator.inject(carrier, context=ctx)
|
||||
return carrier.get("traceparent", "")
|
||||
|
||||
def extract_context(self, trace_parent: str) -> Optional[context.Context]:
|
||||
"""Extract span context from carrier"""
|
||||
if not trace_parent:
|
||||
return None
|
||||
carrier = {"traceparent": trace_parent}
|
||||
return self.propagator.extract(carrier)
|
||||
|
||||
def create_context_from_parent(self, request_id: str, trace_parent: str, sequence_number: int = 0) -> TraceContext:
|
||||
"""Create a new context with the given trace parent"""
|
||||
parent_ctx = self.extract_context(trace_parent)
|
||||
if parent_ctx:
|
||||
# Create a new request span that links to the parent context
|
||||
request_span = self.tracer.start_span(
|
||||
"request",
|
||||
context=parent_ctx,
|
||||
attributes={
|
||||
"request_id": request_id,
|
||||
"sequence_number": sequence_number
|
||||
}
|
||||
)
|
||||
return TraceContext(
|
||||
request_id=request_id,
|
||||
sequence_number=sequence_number,
|
||||
request_span=request_span,
|
||||
current_span=request_span,
|
||||
trace_parent=trace_parent
|
||||
)
|
||||
return TraceContext(request_id=request_id, sequence_number=sequence_number)
|
||||
|
||||
def handle_token(self, context: TraceContext, token: int, is_finished: bool = False):
|
||||
"""Handle token generation and manage token group spans"""
|
||||
context.token_count += 1
|
||||
|
||||
# Start a new token group span if needed
|
||||
if not context.token_group_span and context.request_span:
|
||||
group_number = (context.token_count - 1) // context.token_group_size + 1
|
||||
|
||||
# Create token group span as child of request span
|
||||
parent_ctx = trace.set_span_in_context(context.request_span)
|
||||
context.token_group_span = self.tracer.start_span(
|
||||
f"token_group_{group_number}",
|
||||
context=parent_ctx,
|
||||
attributes={
|
||||
"request_id": context.request_id,
|
||||
"group.number": group_number,
|
||||
"group.start_token": context.token_count,
|
||||
"group.max_tokens": context.token_group_size
|
||||
}
|
||||
)
|
||||
|
||||
# Add token to current group span
|
||||
if context.token_group_span:
|
||||
relative_pos = ((context.token_count - 1) % context.token_group_size) + 1
|
||||
context.token_group_span.set_attribute(f"token.{relative_pos}", token)
|
||||
context.token_group_span.set_attribute("token.count", relative_pos)
|
||||
|
||||
# End current group span if we've reached the group size or if generation is finished
|
||||
if context.token_count % context.token_group_size == 0 or is_finished:
|
||||
context.token_group_span.set_attribute("token.final_count", relative_pos)
|
||||
context.token_group_span.end()
|
||||
context.token_group_span = None
|
||||
|
||||
@contextmanager
|
||||
def start_span(self, name: str, context: TraceContext, extra_attributes: Optional[Dict[str, Any]] = None):
|
||||
"""Start a new span with proper parent context"""
|
||||
attributes = {
|
||||
"request_id": context.request_id,
|
||||
"sequence_number": context.sequence_number
|
||||
}
|
||||
if extra_attributes:
|
||||
attributes.update(extra_attributes)
|
||||
|
||||
# Use request span as parent if available
|
||||
parent_ctx = None
|
||||
if context.request_span:
|
||||
parent_ctx = trace.set_span_in_context(context.request_span)
|
||||
elif context.trace_parent:
|
||||
parent_ctx = self.extract_context(context.trace_parent)
|
||||
if parent_ctx and not context.request_span:
|
||||
# Create a new request span that links to the parent context
|
||||
context.request_span = self.tracer.start_span(
|
||||
"request",
|
||||
context=parent_ctx,
|
||||
attributes={
|
||||
"request_id": context.request_id,
|
||||
"sequence_number": context.sequence_number
|
||||
}
|
||||
)
|
||||
parent_ctx = trace.set_span_in_context(context.request_span)
|
||||
elif context.current_span:
|
||||
parent_ctx = trace.set_span_in_context(context.current_span)
|
||||
|
||||
# Create span with parent context if it exists
|
||||
if parent_ctx:
|
||||
span = self.tracer.start_span(
|
||||
name,
|
||||
context=parent_ctx,
|
||||
attributes=attributes
|
||||
)
|
||||
else:
|
||||
span = self.tracer.start_span(
|
||||
name,
|
||||
attributes=attributes
|
||||
)
|
||||
|
||||
# Update context with current span
|
||||
prev_span = context.current_span
|
||||
context.current_span = span
|
||||
|
||||
try:
|
||||
start_time = time.perf_counter()
|
||||
yield span
|
||||
duration = time.perf_counter() - start_time
|
||||
span.set_attribute("duration_s", duration)
|
||||
span.set_status(Status(StatusCode.OK))
|
||||
except Exception as e:
|
||||
span.set_status(Status(StatusCode.ERROR, str(e)))
|
||||
raise
|
||||
finally:
|
||||
span.end()
|
||||
context.current_span = prev_span
|
||||
|
||||
# Global tracer instance
|
||||
tracer = Tracer()
|
||||
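The new Tracer ties OpenTelemetry spans to request ids and groups generated tokens into spans of ten once a request span exists. A minimal usage sketch against the API defined above (span names, ids, and token values are examples; an OpenTelemetry SDK TracerProvider is assumed to be configured elsewhere, otherwise the spans are no-ops):

from exo.orchestration.tracing import tracer, TraceContext

ctx = TraceContext(request_id="req-123", sequence_number=0)
with tracer.start_span("process_prompt", ctx) as span:
  span.set_attribute("prompt_len", 42)

for token in (101, 102, 103):
  tracer.handle_token(ctx, token)

# To continue the same trace on another node, the W3C traceparent can be
# propagated and re-attached:
# trace_parent = tracer.inject_context(span)
# remote_ctx = tracer.create_context_from_parent("req-123", trace_parent)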
@@ -1,27 +0,0 @@
|
||||
version: '3.8'
|
||||
|
||||
services:
|
||||
prometheus:
|
||||
image: prom/prometheus:latest
|
||||
container_name: prometheus
|
||||
volumes:
|
||||
- ./prometheus.yml:/etc/prometheus/prometheus.yml
|
||||
command:
|
||||
- '--config.file=/etc/prometheus/prometheus.yml'
|
||||
ports:
|
||||
- "9090:9090"
|
||||
networks:
|
||||
- monitoring
|
||||
|
||||
grafana:
|
||||
image: grafana/grafana:latest
|
||||
container_name: grafana
|
||||
ports:
|
||||
- "3000:3000"
|
||||
networks:
|
||||
- monitoring
|
||||
depends_on:
|
||||
- prometheus
|
||||
|
||||
networks:
|
||||
monitoring:
|
||||
@@ -1,29 +0,0 @@
|
||||
from exo.orchestration import Node
|
||||
from prometheus_client import start_http_server, Counter, Histogram
|
||||
import json
|
||||
|
||||
# Create metrics to track time spent and requests made.
|
||||
PROCESS_PROMPT_COUNTER = Counter("process_prompt_total", "Total number of prompts processed", ["node_id"])
|
||||
PROCESS_TENSOR_COUNTER = Counter("process_tensor_total", "Total number of tensors processed", ["node_id"])
|
||||
PROCESS_TENSOR_TIME = Histogram("process_tensor_seconds", "Time spent processing tensor", ["node_id"])
|
||||
|
||||
|
||||
def start_metrics_server(node: Node, port: int):
|
||||
start_http_server(port)
|
||||
|
||||
def _on_opaque_status(request_id, opaque_status: str):
|
||||
status_data = json.loads(opaque_status)
|
||||
_type = status_data.get("type", "")
|
||||
node_id = status_data.get("node_id", "")
|
||||
if _type != "node_status":
|
||||
return
|
||||
status = status_data.get("status", "")
|
||||
|
||||
if status == "end_process_prompt":
|
||||
PROCESS_PROMPT_COUNTER.labels(node_id=node_id).inc()
|
||||
elif status == "end_process_tensor":
|
||||
elapsed_time_ns = status_data.get("elapsed_time_ns", 0)
|
||||
PROCESS_TENSOR_COUNTER.labels(node_id=node_id).inc()
|
||||
PROCESS_TENSOR_TIME.labels(node_id=node_id).observe(elapsed_time_ns/1e9) # Convert ns to seconds
|
||||
|
||||
node.on_opaque_status.register("stats").on_next(_on_opaque_status)
|
||||
@@ -1,7 +0,0 @@
|
||||
global:
|
||||
scrape_interval: 15s
|
||||
|
||||
scrape_configs:
|
||||
- job_name: 'exo-node'
|
||||
static_configs:
|
||||
- targets: ['host.docker.internal:8005']
|
||||
@@ -139,6 +139,14 @@ main {
|
||||
padding: 0.5rem 1rem;
|
||||
border-radius: 20px;
|
||||
}
|
||||
|
||||
@media(max-width: 1482px) {
|
||||
.messages {
|
||||
padding-left: 16px;
|
||||
padding-right: 16px;
|
||||
}
|
||||
}
|
||||
|
||||
.message-role-assistant {
|
||||
background-color: var(--secondary-color);
|
||||
margin-right: auto;
|
||||
@@ -149,6 +157,12 @@ main {
|
||||
background-color: var(--primary-color);
|
||||
color: #000;
|
||||
}
|
||||
|
||||
.message-role-user p {
|
||||
white-space: pre-wrap;
|
||||
word-wrap: break-word;
|
||||
}
|
||||
|
||||
.download-progress {
|
||||
position: fixed;
|
||||
bottom: 11rem;
|
||||
@@ -654,4 +668,179 @@ main {
|
||||
|
||||
.model-download-button i {
|
||||
font-size: 0.9em;
|
||||
}
|
||||
}
|
||||
|
||||
.topology-section {
|
||||
margin-bottom: 30px;
|
||||
padding: 15px;
|
||||
background: rgba(255, 255, 255, 0.05);
|
||||
border-radius: 8px;
|
||||
}
|
||||
|
||||
.topology-visualization {
|
||||
min-height: 150px;
|
||||
position: relative;
|
||||
margin-top: 10px;
|
||||
}
|
||||
|
||||
.topology-loading {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 10px;
|
||||
color: #666;
|
||||
font-size: 0.9em;
|
||||
}
|
||||
|
||||
.topology-node {
|
||||
padding: 8px;
|
||||
background: rgba(255, 255, 255, 0.05);
|
||||
border-radius: 4px;
|
||||
margin: 4px 0;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 4px;
|
||||
}
|
||||
|
||||
.node-info {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 6px;
|
||||
font-size: 0.9em;
|
||||
}
|
||||
|
||||
.topology-node .status {
|
||||
width: 6px;
|
||||
height: 6px;
|
||||
border-radius: 50%;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.topology-node .status.active {
|
||||
background: #4CAF50;
|
||||
}
|
||||
|
||||
.topology-node .status.inactive {
|
||||
background: #666;
|
||||
}
|
||||
|
||||
.node-details {
|
||||
padding-left: 12px;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 2px;
|
||||
font-size: 0.8em;
|
||||
opacity: 0.6;
|
||||
}
|
||||
|
||||
.node-details span {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.peer-connections {
|
||||
margin-top: 8px;
|
||||
padding-left: 12px;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 4px;
|
||||
}
|
||||
|
||||
.peer-connection {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
font-size: 0.85em;
|
||||
color: #a0a0a0;
|
||||
}
|
||||
|
||||
.peer-connection i {
|
||||
font-size: 0.8em;
|
||||
color: #666;
|
||||
}
|
||||
|
||||
.thinking-block {
|
||||
background-color: rgba(255, 255, 255, 0.05);
|
||||
border-radius: 8px;
|
||||
margin: 8px 0;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.thinking-header {
|
||||
background-color: rgba(255, 255, 255, 0.1);
|
||||
padding: 8px 12px;
|
||||
font-size: 0.9em;
|
||||
color: #a0a0a0;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
}
|
||||
|
||||
.thinking-content {
|
||||
padding: 12px;
|
||||
white-space: pre-wrap;
|
||||
}
|
||||
|
||||
@keyframes thinking-spin {
|
||||
to { transform: rotate(360deg); }
|
||||
}
|
||||
|
||||
.thinking-header.thinking::before {
|
||||
content: '';
|
||||
width: 12px;
|
||||
height: 12px;
|
||||
border: 2px solid #a0a0a0;
|
||||
border-top-color: transparent;
|
||||
border-radius: 50%;
|
||||
animation: thinking-spin 1s linear infinite;
|
||||
}
|
||||
|
||||
.model-group {
|
||||
margin-bottom: 12px;
|
||||
}
|
||||
|
||||
.model-group-header,
|
||||
.model-subgroup-header {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
padding: 8px 12px;
|
||||
background-color: var(--primary-bg-color);
|
||||
border-radius: 6px;
|
||||
cursor: pointer;
|
||||
transition: all 0.2s ease;
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
|
||||
.model-group-header:hover,
|
||||
.model-subgroup-header:hover {
|
||||
background-color: var(--secondary-color-transparent);
|
||||
}
|
||||
|
||||
.model-group-content {
|
||||
padding-left: 12px;
|
||||
}
|
||||
|
||||
.model-subgroup {
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
|
||||
.model-subgroup-header {
|
||||
font-size: 0.9em;
|
||||
background-color: rgba(255, 255, 255, 0.05);
|
||||
}
|
||||
|
||||
.model-subgroup-content {
|
||||
padding-left: 12px;
|
||||
}
|
||||
|
||||
.group-header-content {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
}
|
||||
|
||||
.model-count {
|
||||
font-size: 0.8em;
|
||||
color: var(--secondary-color-transparent);
|
||||
font-family: monospace;
|
||||
}
|
||||
|
||||
@@ -22,75 +22,125 @@
|
||||
<link href="/static/unpkg.com/@highlightjs/cdn-assets@11.9.0/styles/vs2015.min.css" rel="stylesheet"/>
|
||||
<link href="/index.css" rel="stylesheet"/>
|
||||
<link href="/common.css" rel="stylesheet"/>
|
||||
<script src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
</head>
|
||||
<body>
|
||||
<main x-data="state" x-init="console.log(endpoint)">
|
||||
<div class="sidebar">
|
||||
<!-- Add topology section -->
|
||||
<div class="topology-section">
|
||||
<h2 class="megrim-regular">Network Topology</h2>
|
||||
<div class="topology-visualization"
|
||||
x-init="initTopology()"
|
||||
x-ref="topologyViz">
|
||||
<!-- Loading indicator for topology -->
|
||||
<div class="topology-loading" x-show="!topology">
|
||||
<i class="fas fa-spinner fa-spin"></i>
|
||||
<span>Loading topology...</span>
|
||||
</div>
|
||||
<!-- Topology visualization will be rendered here -->
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<h2 class="megrim-regular" style="margin-bottom: 20px;">Models</h2>
|
||||
|
||||
<div style="display: flex; align-items: center; margin-bottom: 10px;">
|
||||
<label style="margin-right: 5px;">
|
||||
<input type="checkbox" x-model="showDownloadedOnly" style="margin-right: 5px;">
|
||||
Downloaded only
|
||||
</label>
|
||||
</div>
|
||||
|
||||
<!-- Loading indicator -->
|
||||
<div class="loading-container" x-show="Object.keys(models).length === 0">
|
||||
<i class="fas fa-spinner fa-spin"></i>
|
||||
<span>Loading models...</span>
|
||||
</div>
|
||||
|
||||
<template x-for="(model, key) in models" :key="key">
|
||||
<div class="model-option"
|
||||
:class="{ 'selected': cstate.selectedModel === key }"
|
||||
@click="cstate.selectedModel = key">
|
||||
<div class="model-header">
|
||||
<div class="model-name" x-text="model.name"></div>
|
||||
<button
|
||||
@click.stop="deleteModel(key, model)"
|
||||
class="model-delete-button"
|
||||
x-show="model.download_percentage > 0">
|
||||
<i class="fas fa-trash"></i>
|
||||
</button>
|
||||
</div>
|
||||
<div class="model-info">
|
||||
<div class="model-progress">
|
||||
<template x-if="model.loading">
|
||||
<span><i class="fas fa-spinner fa-spin"></i> Checking download status...</span>
|
||||
</template>
|
||||
<div class="model-progress-info">
|
||||
<template x-if="!model.loading && model.download_percentage != null">
|
||||
<span>
|
||||
<!-- Check if there's an active download for this model -->
|
||||
<template x-if="downloadProgress?.some(p =>
|
||||
p.repo_id && p.repo_id.toLowerCase().includes(key.toLowerCase()) && !p.isComplete
|
||||
)">
|
||||
<i class="fas fa-circle-notch fa-spin"></i>
|
||||
</template>
|
||||
<span x-text="model.downloaded ? 'Downloaded' : `${Math.round(model.download_percentage)}% downloaded`"></span>
|
||||
</span>
|
||||
</template>
|
||||
<template x-if="!model.loading && (model.download_percentage === null || model.download_percentage < 100) && !downloadProgress?.some(p => !p.isComplete)">
|
||||
<button
|
||||
@click.stop="handleDownload(key)"
|
||||
class="model-download-button">
|
||||
<i class="fas fa-download"></i>
|
||||
<span x-text="(model.download_percentage > 0 && model.download_percentage < 100) ? 'Continue Downloading' : 'Download'"></span>
|
||||
</button>
|
||||
</template>
|
||||
</div>
|
||||
|
||||
<!-- Group models by prefix -->
|
||||
<template x-for="[mainPrefix, subGroups] in Object.entries(groupModelsByPrefix(models))" :key="mainPrefix">
|
||||
<div class="model-group">
|
||||
<div class="model-group-header" @click="toggleGroup(mainPrefix)">
|
||||
<div class="group-header-content">
|
||||
<span x-text="mainPrefix"></span>
|
||||
<span class="model-count" x-text="getGroupCounts(Object.values(subGroups).flatMap(group => Object.values(group)))"></span>
|
||||
</div>
|
||||
<template x-if="model.total_size">
|
||||
<div class="model-size" x-text="model.total_downloaded ?
|
||||
`${formatBytes(model.total_downloaded)} / ${formatBytes(model.total_size)}` :
|
||||
formatBytes(model.total_size)">
|
||||
<i class="fas" :class="isGroupExpanded(mainPrefix) ? 'fa-chevron-down' : 'fa-chevron-right'"></i>
|
||||
</div>
|
||||
|
||||
<div class="model-group-content" x-show="isGroupExpanded(mainPrefix)" x-transition>
|
||||
<template x-for="[subPrefix, groupModels] in Object.entries(subGroups)" :key="subPrefix">
|
||||
<div class="model-subgroup">
|
||||
<div class="model-subgroup-header" @click.stop="toggleGroup(mainPrefix, subPrefix)">
|
||||
<div class="group-header-content">
|
||||
<span x-text="subPrefix"></span>
|
||||
<span class="model-count" x-text="getGroupCounts(groupModels)"></span>
|
||||
</div>
|
||||
<i class="fas" :class="isGroupExpanded(mainPrefix, subPrefix) ? 'fa-chevron-down' : 'fa-chevron-right'"></i>
|
||||
</div>
|
||||
|
||||
<div class="model-subgroup-content" x-show="isGroupExpanded(mainPrefix, subPrefix)" x-transition>
|
||||
<template x-for="(model, key) in groupModels" :key="key">
|
||||
<div class="model-option"
|
||||
:class="{ 'selected': cstate.selectedModel === key }"
|
||||
@click="cstate.selectedModel = key">
|
||||
<div class="model-header">
|
||||
<div class="model-name" x-text="model.name"></div>
|
||||
<button
|
||||
@click.stop="deleteModel(key, model)"
|
||||
class="model-delete-button"
|
||||
x-show="model.download_percentage > 0">
|
||||
<i class="fas fa-trash"></i>
|
||||
</button>
|
||||
</div>
|
||||
<div class="model-info">
|
||||
<div class="model-progress">
|
||||
<template x-if="model.loading">
|
||||
<span><i class="fas fa-spinner fa-spin"></i> Checking download status...</span>
|
||||
</template>
|
||||
<div class="model-progress-info">
|
||||
<template x-if="!model.loading && model.download_percentage != null">
|
||||
<span>
|
||||
<template x-if="downloadProgress?.some(p =>
|
||||
p.repo_id && p.repo_id.toLowerCase().includes(key.toLowerCase()) && !p.isComplete
|
||||
)">
|
||||
<i class="fas fa-circle-notch fa-spin"></i>
|
||||
</template>
|
||||
<span x-text="model.downloaded ? 'Downloaded' : `${Math.round(model.download_percentage)}% downloaded`"></span>
|
||||
</span>
|
||||
</template>
|
||||
<template x-if="!model.loading && (model.download_percentage === null || model.download_percentage < 100) && !downloadProgress?.some(p => !p.isComplete)">
|
||||
<button
|
||||
@click.stop="handleDownload(key)"
|
||||
class="model-download-button">
|
||||
<i class="fas fa-download"></i>
|
||||
<span x-text="(model.download_percentage > 0 && model.download_percentage < 100) ? 'Continue Downloading' : 'Download'"></span>
|
||||
</button>
|
||||
</template>
|
||||
</div>
|
||||
</div>
|
||||
<template x-if="model.total_size">
|
||||
<div class="model-size" x-text="model.total_downloaded ?
|
||||
`${formatBytes(model.total_downloaded)} / ${formatBytes(model.total_size)}` :
|
||||
formatBytes(model.total_size)">
|
||||
</div>
|
||||
</template>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
</div>
|
||||
</div>
|
||||
<!-- Error Toast -->
|
||||
<div x-show="errorMessage !== null" x-transition.opacity class="toast">
|
||||
<div class="toast-header">
|
||||
<span class="toast-error-message" x-text="errorMessage?.basic || ''"></span>
|
||||
<div class="toast-header-buttons">
|
||||
<button @click="errorExpanded = !errorExpanded; if (errorTimeout) { clearTimeout(errorTimeout); errorTimeout = null; }"
|
||||
class="toast-expand-button"
|
||||
<button @click="errorExpanded = !errorExpanded; if (errorTimeout) { clearTimeout(errorTimeout); errorTimeout = null; }"
|
||||
class="toast-expand-button"
|
||||
x-show="errorMessage?.stack">
|
||||
<span x-text="errorExpanded ? 'Hide Details' : 'Show Details'"></span>
|
||||
</button>
|
||||
@@ -119,8 +169,8 @@
|
||||
" x-show="home === 0" x-transition="">
|
||||
<h1 class="title megrim-regular">tinychat</h1>
|
||||
<template x-if="histories.length">
|
||||
<button
|
||||
@click="if(confirm('Are you sure you want to clear all history?')) clearAllHistory();"
|
||||
<button
|
||||
@click="if(confirm('Are you sure you want to clear all history?')) clearAllHistory();"
|
||||
class="clear-history-button">
|
||||
<i class="fas fa-trash"></i> Clear All History
|
||||
</button>
|
||||
@@ -149,7 +199,7 @@
|
||||
otx = $event.changedTouches[0].clientX;
|
||||
" class="history" x-data="{ otx: 0, trigger: 75 }">
|
||||
<h3 x-text="new Date(_state.time).toLocaleString()"></h3>
|
||||
<p x-text="$truncate(_state.messages[0].content, 80)"></p>
|
||||
<p x-text="$truncate(_state.messages[0].content, 76)"></p>
|
||||
<!-- delete button -->
|
||||
<button @click.stop="removeHistory(_state);" class="history-delete-button">
|
||||
<i class="fas fa-trash"></i>
|
||||
@@ -162,62 +212,101 @@
|
||||
</template>
|
||||
</div>
|
||||
</div>
|
||||
<button
|
||||
</div>
|
||||
<button
|
||||
@click="
|
||||
home = 0;
|
||||
cstate = { time: null, messages: [], selectedModel: cstate.selectedModel };
|
||||
time_till_first = 0;
|
||||
tokens_per_second = 0;
|
||||
total_tokens = 0;
|
||||
"
|
||||
"
|
||||
class="back-button"
|
||||
x-show="home === 2">
|
||||
<i class="fas fa-arrow-left"></i>
|
||||
Back to Chats
|
||||
</button>
|
||||
<div class="messages" x-init="
|
||||
$watch('cstate', value => {
|
||||
$el.innerHTML = '';
|
||||
value.messages.forEach(({ role, content }) => {
|
||||
const div = document.createElement('div');
|
||||
div.className = `message message-role-${role}`;
|
||||
try {
|
||||
div.innerHTML = DOMPurify.sanitize(marked.parse(content));
|
||||
} catch (e) {
|
||||
console.log(content);
|
||||
console.error(e);
|
||||
<div class="messages"
|
||||
x-init="
|
||||
$watch('cstate', (value) => {
|
||||
$el.innerHTML = '';
|
||||
|
||||
value.messages.forEach((msg) => {
|
||||
const div = document.createElement('div');
|
||||
div.className = `message message-role-${msg.role}`;
|
||||
|
||||
try {
|
||||
// If there's an embedded generated image
|
||||
if (msg.content.includes('![Generated Image]')) {
|
||||
const imageUrlMatch = msg.content.match(/\((.*?)\)/);
|
||||
if (imageUrlMatch) {
|
||||
const imageUrl = imageUrlMatch[1];
|
||||
const img = document.createElement('img');
|
||||
img.src = imageUrl;
|
||||
img.alt = 'Generated Image';
|
||||
|
||||
img.onclick = async () => {
|
||||
try {
|
||||
const response = await fetch(img.src);
|
||||
const blob = await response.blob();
|
||||
const file = new File([blob], 'image.png', { type: 'image/png' });
|
||||
handleImageUpload({ target: { files: [file] } });
|
||||
} catch (error) {
|
||||
console.error('Error fetching image:', error);
|
||||
}
|
||||
};
|
||||
div.appendChild(img);
|
||||
} else {
|
||||
// fallback if markdown is malformed
|
||||
div.textContent = msg.content;
|
||||
}
|
||||
} else {
|
||||
// Otherwise, transform message text (including streamed think blocks).
|
||||
div.innerHTML = transformMessageContent(msg);
|
||||
// Render math after content is inserted
|
||||
MathJax.typesetPromise([div]);
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('Error rendering message:', e);
|
||||
div.textContent = msg.content; // fallback
|
||||
}
|
||||
|
||||
// add a clipboard button to all code blocks
|
||||
const codeBlocks = div.querySelectorAll('.hljs');
|
||||
codeBlocks.forEach(codeBlock => {
|
||||
const button = document.createElement('button');
|
||||
button.className = 'clipboard-button';
|
||||
button.innerHTML = '<i class=\'fas fa-clipboard\'></i>';
|
||||
button.onclick = () => {
|
||||
// navigator.clipboard.writeText(codeBlock.textContent);
|
||||
const range = document.createRange();
|
||||
range.setStartBefore(codeBlock);
|
||||
range.setEndAfter(codeBlock);
|
||||
window.getSelection()?.removeAllRanges();
|
||||
window.getSelection()?.addRange(range);
|
||||
document.execCommand('copy');
|
||||
window.getSelection()?.removeAllRanges();
|
||||
// Add a clipboard button to code blocks
|
||||
const codeBlocks = div.querySelectorAll('.hljs');
|
||||
codeBlocks.forEach((codeBlock) => {
|
||||
const button = document.createElement('button');
|
||||
button.className = 'clipboard-button';
|
||||
button.innerHTML = '<i class=\'fas fa-clipboard\'></i>';
|
||||
|
||||
button.innerHTML = '<i class=\'fas fa-check\'></i>';
|
||||
setTimeout(() => button.innerHTML = '<i class=\'fas fa-clipboard\'></i>', 1000);
|
||||
};
|
||||
codeBlock.appendChild(button);
|
||||
});
|
||||
button.onclick = () => {
|
||||
const range = document.createRange();
|
||||
range.setStartBefore(codeBlock);
|
||||
range.setEndAfter(codeBlock);
|
||||
window.getSelection()?.removeAllRanges();
|
||||
window.getSelection()?.addRange(range);
|
||||
document.execCommand('copy');
|
||||
window.getSelection()?.removeAllRanges();
|
||||
|
||||
$el.appendChild(div);
|
||||
button.innerHTML = '<i class=\'fas fa-check\'></i>';
|
||||
setTimeout(() => {
|
||||
button.innerHTML = '<i class=\'fas fa-clipboard\'></i>';
|
||||
}, 1000);
|
||||
};
|
||||
|
||||
codeBlock.appendChild(button);
|
||||
});
|
||||
|
||||
$el.scrollTo({ top: $el.scrollHeight, behavior: 'smooth' });
|
||||
$el.appendChild(div);
|
||||
});
|
||||
" x-intersect="
|
||||
|
||||
// Scroll to bottom after rendering
|
||||
$el.scrollTo({ top: $el.scrollHeight, behavior: 'smooth' });
|
||||
" x-ref="messages" x-show="home === 2" x-transition="">
|
||||
});
|
||||
"
|
||||
x-ref="messages"
|
||||
x-show="home === 2"
|
||||
x-transition=""
|
||||
>
|
||||
</div>
|
||||
|
||||
<!-- Download Progress Section -->
|
||||
@@ -232,7 +321,7 @@
|
||||
<p><strong>Model:</strong> <span x-text="progress.repo_id + '@' + progress.repo_revision"></span></p>
|
||||
<p><strong>Status:</strong> <span x-text="progress.status"></span></p>
|
||||
<div class="progress-bar-container">
|
||||
<div class="progress-bar"
|
||||
<div class="progress-bar"
|
||||
:class="progress.isComplete ? 'complete' : 'in-progress'"
|
||||
:style="`width: ${progress.percentage}%;`">
|
||||
</div>
|
||||
@@ -253,6 +342,10 @@
|
||||
<div class="input-container">
|
||||
<div class="input-performance">
|
||||
<span class="input-performance-point">
|
||||
<p class="monospace" x-text="models[cstate.selectedModel]?.name || cstate.selectedModel"></p>
|
||||
<p class="megrim-regular">-</p>
|
||||
</span>
|
||||
<span class="input-performance-point">
|
||||
<p class="monospace" x-text="(time_till_first / 1000).toFixed(2)"></p>
|
||||
<p class="megrim-regular">SEC TO FIRST TOKEN</p>
|
||||
</span>
|
||||
@@ -266,7 +359,7 @@
|
||||
</span>
|
||||
</div>
|
||||
<div class="input">
|
||||
<button @click="$refs.imageUpload.click()" class="image-input-button" x-show="cstate.selectedModel === 'llava-1.5-7b-hf'">
|
||||
<button @click="$refs.imageUpload.click()" class="image-input-button" x-show="cstate.selectedModel === 'llava-1.5-7b-hf' || cstate.selectedModel === 'stable-diffusion-2-1-base'">
|
||||
<i class="fas fa-image"></i>
|
||||
</button>
|
||||
<input @change="$data.handleImageUpload($event)" accept="image/*" id="image-upload" style="display: none;" type="file" x-ref="imageUpload"/>
|
||||
@@ -276,10 +369,10 @@
|
||||
<i class="fas fa-times"></i>
|
||||
</button>
|
||||
</div>
|
||||
<textarea
|
||||
:disabled="generating || (downloadProgress?.length > 0 && downloadProgress.some(p => !p.isComplete))"
|
||||
:placeholder="
|
||||
generating ? 'Generating...' :
|
||||
(downloadProgress?.length > 0 && downloadProgress.some(p => !p.isComplete)) ? 'Download in progress...' :
|
||||
'Say something'
|
||||
"
|
||||
@@ -311,13 +404,51 @@
|
||||
});
|
||||
"
|
||||
x-ref="inputForm"></textarea>
|
||||
<button
|
||||
:disabled="generating || (downloadProgress?.length > 0 && downloadProgress.some(p => !p.isComplete))"
|
||||
@click="await handleSend()"
|
||||
class="input-button">
|
||||
<i :class="generating ? 'fa-spinner fa-spin' : 'fa-paper-plane'" class="fas"></i>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</main>
|
||||
|
||||
<script>
|
||||
/**
|
||||
* Transform a single message's content into HTML, preserving <think> blocks.
|
||||
* Ensure LaTeX expressions are properly delimited for MathJax.
|
||||
*/
|
||||
function transformMessageContent(message) {
|
||||
let text = message.content;
|
||||
console.log('Processing message content:', text);
|
||||
|
||||
// First replace think blocks
|
||||
text = text.replace(
|
||||
/<think>([\s\S]*?)(?:<\/think>|$)/g,
|
||||
(match, body) => {
|
||||
console.log('Found think block with content:', body);
|
||||
const isComplete = match.includes('</think>');
|
||||
const spinnerClass = isComplete ? '' : ' thinking';
|
||||
const parsedBody = DOMPurify.sanitize(marked.parse(body));
|
||||
return `
|
||||
<div class='thinking-block'>
|
||||
<div class='thinking-header${spinnerClass}'>Thinking...</div>
|
||||
<div class='thinking-content'>${parsedBody}</div>
|
||||
</div>`;
|
||||
}
|
||||
);
|
||||
|
||||
// Add backslashes to parentheses and brackets for LaTeX
|
||||
text = text
|
||||
.replace(/\((?=\s*[\d\\])/g, '\\(') // Add backslash before opening parentheses
|
||||
.replace(/\)(?!\w)/g, '\\)') // Add backslash before closing parentheses
|
||||
.replace(/\[(?=\s*[\d\\])/g, '\\[') // Add backslash before opening brackets
|
||||
.replace(/\](?!\w)/g, '\\]') // Add backslash before closing brackets
|
||||
.replace(/\[[\s\n]*\\boxed/g, '\\[\\boxed') // Ensure boxed expressions are properly delimited
|
||||
.replace(/\\!/g, '\\\\!'); // Preserve LaTeX spacing commands
|
||||
|
||||
return DOMPurify.sanitize(marked.parse(text));
|
||||
}
|
||||
</script>
|
||||
</body>
|
||||
|
||||
@@ -5,7 +5,7 @@ document.addEventListener("alpine:init", () => {
|
||||
time: null,
|
||||
messages: [],
|
||||
selectedModel: 'llama-3.2-1b',
|
||||
},
|
||||
},
|
||||
|
||||
// historical state
|
||||
histories: JSON.parse(localStorage.getItem("histories")) || [],
|
||||
@@ -13,7 +13,7 @@ document.addEventListener("alpine:init", () => {
|
||||
home: 0,
|
||||
generating: false,
|
||||
endpoint: `${window.location.origin}/v1`,
|
||||
|
||||
|
||||
// Initialize error message structure
|
||||
errorMessage: null,
|
||||
errorExpanded: false,
|
||||
@@ -39,6 +39,15 @@ document.addEventListener("alpine:init", () => {
|
||||
// Add models state alongside existing state
|
||||
models: {},
|
||||
|
||||
// Show only models available locally
|
||||
showDownloadedOnly: false,
|
||||
|
||||
topology: null,
|
||||
topologyInterval: null,
|
||||
|
||||
// Add these new properties
|
||||
expandedGroups: {},
|
||||
|
||||
init() {
|
||||
// Clean up any pending messages
|
||||
localStorage.removeItem("pendingMessage");
|
||||
@@ -48,7 +57,7 @@ document.addEventListener("alpine:init", () => {
|
||||
|
||||
// Start polling for download progress
|
||||
this.startDownloadProgressPolling();
|
||||
|
||||
|
||||
// Start model polling with the new pattern
|
||||
this.startModelPolling();
|
||||
},
|
||||
@@ -69,12 +78,12 @@ document.addEventListener("alpine:init", () => {
|
||||
while (true) {
|
||||
try {
|
||||
await this.populateSelector();
|
||||
// Wait 5 seconds before next poll
|
||||
await new Promise(resolve => setTimeout(resolve, 5000));
|
||||
// Wait 15 seconds before next poll
|
||||
await new Promise(resolve => setTimeout(resolve, 15000));
|
||||
} catch (error) {
|
||||
console.error('Model polling error:', error);
|
||||
// If there's an error, wait before retrying
|
||||
await new Promise(resolve => setTimeout(resolve, 5000));
|
||||
await new Promise(resolve => setTimeout(resolve, 15000));
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -82,14 +91,14 @@ document.addEventListener("alpine:init", () => {
|
||||
async populateSelector() {
|
||||
return new Promise((resolve, reject) => {
|
||||
const evtSource = new EventSource(`${window.location.origin}/modelpool`);
|
||||
|
||||
|
||||
evtSource.onmessage = (event) => {
|
||||
if (event.data === "[DONE]") {
|
||||
evtSource.close();
|
||||
resolve();
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
const modelData = JSON.parse(event.data);
|
||||
// Update existing model data while preserving other properties
|
||||
Object.entries(modelData).forEach(([modelName, data]) => {
|
||||
@@ -102,7 +111,7 @@ document.addEventListener("alpine:init", () => {
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
|
||||
evtSource.onerror = (error) => {
|
||||
console.error('EventSource failed:', error);
|
||||
evtSource.close();
|
||||
@@ -228,53 +237,110 @@ document.addEventListener("alpine:init", () => {
|
||||
};
|
||||
}
|
||||
});
|
||||
const containsImage = apiMessages.some(msg => Array.isArray(msg.content) && msg.content.some(item => item.type === 'image_url'));
|
||||
if (containsImage) {
|
||||
// Map all messages with string content to object with type text
|
||||
apiMessages = apiMessages.map(msg => {
|
||||
if (typeof msg.content === 'string') {
|
||||
return {
|
||||
...msg,
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: msg.content
|
||||
}
|
||||
]
|
||||
};
|
||||
}
|
||||
return msg;
|
||||
|
||||
if (this.cstate.selectedModel === "stable-diffusion-2-1-base") {
|
||||
// Send a request to the image generation endpoint
|
||||
console.log(apiMessages[apiMessages.length - 1].content)
|
||||
console.log(this.cstate.selectedModel)
|
||||
console.log(this.endpoint)
|
||||
const response = await fetch(`${this.endpoint}/image/generations`, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
"model": 'stable-diffusion-2-1-base',
|
||||
"prompt": apiMessages[apiMessages.length - 1].content,
|
||||
"image_url": this.imageUrl
|
||||
}),
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
// start receiving server sent events
|
||||
let gottenFirstChunk = false;
|
||||
for await (
|
||||
const chunk of this.openaiChatCompletion(this.cstate.selectedModel, apiMessages)
|
||||
) {
|
||||
if (!gottenFirstChunk) {
|
||||
this.cstate.messages.push({ role: "assistant", content: "" });
|
||||
gottenFirstChunk = true;
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error("Failed to fetch");
|
||||
}
|
||||
|
||||
// add chunk to the last message
|
||||
this.cstate.messages[this.cstate.messages.length - 1].content += chunk;
|
||||
|
||||
// calculate performance tracking
|
||||
tokens += 1;
|
||||
this.total_tokens += 1;
|
||||
if (start_time === 0) {
|
||||
start_time = Date.now();
|
||||
this.time_till_first = start_time - prefill_start;
|
||||
} else {
|
||||
const diff = Date.now() - start_time;
|
||||
if (diff > 0) {
|
||||
this.tokens_per_second = tokens / (diff / 1000);
|
||||
const reader = response.body.getReader();
|
||||
let done = false;
|
||||
let gottenFirstChunk = false;
|
||||
|
||||
while (!done) {
|
||||
const { value, done: readerDone } = await reader.read();
|
||||
done = readerDone;
|
||||
const decoder = new TextDecoder();
|
||||
|
||||
if (value) {
|
||||
// Assume non-binary data (text) comes first
|
||||
const chunk = decoder.decode(value, { stream: true });
|
||||
const parsed = JSON.parse(chunk);
|
||||
console.log(parsed)
|
||||
|
||||
if (parsed.progress) {
|
||||
if (!gottenFirstChunk) {
|
||||
this.cstate.messages.push({ role: "assistant", content: "" });
|
||||
gottenFirstChunk = true;
|
||||
}
|
||||
this.cstate.messages[this.cstate.messages.length - 1].content = parsed.progress;
|
||||
}
|
||||
else if (parsed.images) {
|
||||
if (!gottenFirstChunk) {
|
||||
this.cstate.messages.push({ role: "assistant", content: "" });
|
||||
gottenFirstChunk = true;
|
||||
}
|
||||
const imageUrl = parsed.images[0].url;
|
||||
console.log(imageUrl)
|
||||
this.cstate.messages[this.cstate.messages.length - 1].content = `![Generated Image](${imageUrl})`;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
else{
|
||||
const containsImage = apiMessages.some(msg => Array.isArray(msg.content) && msg.content.some(item => item.type === 'image_url'));
|
||||
if (containsImage) {
|
||||
// Map all messages with string content to object with type text
|
||||
apiMessages = apiMessages.map(msg => {
|
||||
if (typeof msg.content === 'string') {
|
||||
return {
|
||||
...msg,
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: msg.content
|
||||
}
|
||||
]
|
||||
};
|
||||
}
|
||||
return msg;
|
||||
});
|
||||
}
|
||||
|
||||
console.log(apiMessages)
|
||||
//start receiving server sent events
|
||||
let gottenFirstChunk = false;
|
||||
for await (
|
||||
const chunk of this.openaiChatCompletion(this.cstate.selectedModel, apiMessages)
|
||||
) {
|
||||
if (!gottenFirstChunk) {
|
||||
this.cstate.messages.push({ role: "assistant", content: "" });
|
||||
gottenFirstChunk = true;
|
||||
}
|
||||
|
||||
// add chunk to the last message
|
||||
this.cstate.messages[this.cstate.messages.length - 1].content += chunk;
|
||||
|
||||
// calculate performance tracking
|
||||
tokens += 1;
|
||||
this.total_tokens += 1;
|
||||
if (start_time === 0) {
|
||||
start_time = Date.now();
|
||||
this.time_till_first = start_time - prefill_start;
|
||||
} else {
|
||||
const diff = Date.now() - start_time;
|
||||
if (diff > 0) {
|
||||
this.tokens_per_second = tokens / (diff / 1000);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Clean the cstate before adding it to histories
|
||||
const cleanedCstate = JSON.parse(JSON.stringify(this.cstate));
|
||||
cleanedCstate.messages = cleanedCstate.messages.map(msg => {
|
||||
@@ -333,8 +399,6 @@ document.addEventListener("alpine:init", () => {
|
||||
},
|
||||
|
||||
async *openaiChatCompletion(model, messages) {
|
||||
// stream response
|
||||
console.log("model", model)
|
||||
const response = await fetch(`${this.endpoint}/chat/completions`, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
@@ -357,19 +421,17 @@ document.addEventListener("alpine:init", () => {
|
||||
|
||||
const reader = response.body.pipeThrough(new TextDecoderStream())
|
||||
.pipeThrough(new EventSourceParserStream()).getReader();
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) {
|
||||
break;
|
||||
}
|
||||
if (done) break;
|
||||
|
||||
if (value.type === "event") {
|
||||
const json = JSON.parse(value.data);
|
||||
if (json.choices) {
|
||||
const choice = json.choices[0];
|
||||
if (choice.finish_reason === "stop") {
|
||||
break;
|
||||
}
|
||||
yield choice.delta.content;
|
||||
if (choice.finish_reason === "stop") break;
|
||||
if (choice.delta.content) yield choice.delta.content;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -452,7 +514,7 @@ document.addEventListener("alpine:init", () => {
|
||||
stack: error.stack || ""
|
||||
};
|
||||
this.errorExpanded = false;
|
||||
|
||||
|
||||
if (this.errorTimeout) {
|
||||
clearTimeout(this.errorTimeout);
|
||||
}
|
||||
@@ -467,10 +529,10 @@ document.addEventListener("alpine:init", () => {
|
||||
|
||||
async deleteModel(modelName, model) {
|
||||
const downloadedSize = model.total_downloaded || 0;
|
||||
const sizeMessage = downloadedSize > 0 ?
|
||||
`This will free up ${this.formatBytes(downloadedSize)} of space.` :
|
||||
'This will remove any partially downloaded files.';
|
||||
|
||||
|
||||
if (!confirm(`Are you sure you want to delete ${model.name}? ${sizeMessage}`)) {
|
||||
return;
|
||||
}
|
||||
@@ -484,7 +546,7 @@ document.addEventListener("alpine:init", () => {
|
||||
});
|
||||
|
||||
const data = await response.json();
|
||||
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(data.detail || 'Failed to delete model');
|
||||
}
|
||||
@@ -543,7 +605,127 @@ document.addEventListener("alpine:init", () => {
|
||||
console.error('Error starting download:', error);
|
||||
this.setError(error);
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
async fetchTopology() {
|
||||
try {
|
||||
const response = await fetch(`${this.endpoint}/topology`);
|
||||
if (!response.ok) throw new Error('Failed to fetch topology');
|
||||
return await response.json();
|
||||
} catch (error) {
|
||||
console.error('Topology fetch error:', error);
|
||||
return null;
|
||||
}
|
||||
},
|
||||
|
||||
initTopology() {
|
||||
// Initial fetch
|
||||
this.updateTopology();
|
||||
|
||||
// Set up periodic updates
|
||||
this.topologyInterval = setInterval(() => this.updateTopology(), 5000);
|
||||
|
||||
// Cleanup on page unload
|
||||
window.addEventListener('beforeunload', () => {
|
||||
if (this.topologyInterval) {
|
||||
clearInterval(this.topologyInterval);
|
||||
}
|
||||
});
|
||||
},
|
||||
|
||||
async updateTopology() {
|
||||
const topologyData = await this.fetchTopology();
|
||||
if (!topologyData) return;
|
||||
|
||||
const vizElement = this.$refs.topologyViz;
|
||||
vizElement.innerHTML = ''; // Clear existing visualization
|
||||
|
||||
// Helper function to truncate node ID
|
||||
const truncateNodeId = (id) => id.substring(0, 8);
|
||||
|
||||
// Create nodes from object
|
||||
Object.entries(topologyData.nodes).forEach(([nodeId, node]) => {
|
||||
const nodeElement = document.createElement('div');
|
||||
nodeElement.className = 'topology-node';
|
||||
|
||||
// Get peer connections for this node
|
||||
const peerConnections = topologyData.peer_graph[nodeId] || [];
|
||||
const peerConnectionsHtml = peerConnections.map(peer => `
|
||||
<div class="peer-connection">
|
||||
<i class="fas fa-arrow-right"></i>
|
||||
<span>To ${truncateNodeId(peer.to_id)}: ${peer.description}</span>
|
||||
</div>
|
||||
`).join('');
|
||||
|
||||
nodeElement.innerHTML = `
|
||||
<div class="node-info">
|
||||
<span class="status ${nodeId === topologyData.active_node_id ? 'active' : 'inactive'}"></span>
|
||||
<span>${node.model} [${truncateNodeId(nodeId)}]</span>
|
||||
</div>
|
||||
<div class="node-details">
|
||||
<span>${node.chip}</span>
|
||||
<span>${(node.memory / 1024).toFixed(1)}GB RAM</span>
|
||||
<span>${node.flops.fp32.toFixed(1)} TF</span>
|
||||
</div>
|
||||
<div class="peer-connections">
|
||||
${peerConnectionsHtml}
|
||||
</div>
|
||||
`;
|
||||
vizElement.appendChild(nodeElement);
|
||||
});
|
||||
},
|
||||
|
||||
// Add these helper methods
|
||||
countDownloadedModels(models) {
|
||||
return Object.values(models).filter(model => model.downloaded).length;
|
||||
},
|
||||
|
||||
getGroupCounts(groupModels) {
|
||||
const total = Object.keys(groupModels).length;
|
||||
const downloaded = this.countDownloadedModels(groupModels);
|
||||
return `[${downloaded}/${total}]`;
|
||||
},
|
||||
|
||||
// Update the existing groupModelsByPrefix method to include counts
|
||||
groupModelsByPrefix(models) {
|
||||
const groups = {};
|
||||
const filteredModels = this.showDownloadedOnly ?
|
||||
Object.fromEntries(Object.entries(models).filter(([, model]) => model.downloaded)) :
|
||||
models;
|
||||
|
||||
Object.entries(filteredModels).forEach(([key, model]) => {
|
||||
const parts = key.split('-');
|
||||
const mainPrefix = parts[0].toUpperCase();
|
||||
|
||||
let subPrefix;
|
||||
if (parts.length === 2) {
|
||||
subPrefix = parts[1].toUpperCase();
|
||||
} else if (parts.length > 2) {
|
||||
subPrefix = parts[1].toUpperCase();
|
||||
} else {
|
||||
subPrefix = 'OTHER';
|
||||
}
|
||||
|
||||
if (!groups[mainPrefix]) {
|
||||
groups[mainPrefix] = {};
|
||||
}
|
||||
if (!groups[mainPrefix][subPrefix]) {
|
||||
groups[mainPrefix][subPrefix] = {};
|
||||
}
|
||||
groups[mainPrefix][subPrefix][key] = model;
|
||||
});
|
||||
return groups;
|
||||
},
|
||||
|
||||
toggleGroup(prefix, subPrefix = null) {
|
||||
const key = subPrefix ? `${prefix}-${subPrefix}` : prefix;
|
||||
this.expandedGroups[key] = !this.expandedGroups[key];
|
||||
},
|
||||
|
||||
isGroupExpanded(prefix, subPrefix = null) {
|
||||
const key = subPrefix ? `${prefix}-${subPrefix}` : prefix;
|
||||
return this.expandedGroups[key] || false;
|
||||
},
|
||||
}));
|
||||
});
|
||||
|
||||
|
||||
@@ -3,6 +3,8 @@ from pydantic import BaseModel
|
||||
from exo import DEBUG
|
||||
import subprocess
|
||||
import psutil
|
||||
import asyncio
|
||||
from exo.helpers import get_mac_system_info, subprocess_pool
|
||||
|
||||
TFLOPS = 1.00
|
||||
|
||||
@@ -144,11 +146,13 @@ CHIP_FLOPS.update({f"{key} LAPTOP GPU": value for key, value in CHIP_FLOPS.items
|
||||
CHIP_FLOPS.update({f"{key} Laptop GPU": value for key, value in CHIP_FLOPS.items()})
|
||||
|
||||
|
||||
def device_capabilities() -> DeviceCapabilities:
|
||||
async def device_capabilities() -> DeviceCapabilities:
|
||||
if psutil.MACOS:
|
||||
return mac_device_capabilities()
|
||||
return await mac_device_capabilities()
|
||||
elif psutil.LINUX:
|
||||
return linux_device_capabilities()
|
||||
return await linux_device_capabilities()
|
||||
elif psutil.WINDOWS:
|
||||
return await windows_device_capabilities()
|
||||
else:
|
||||
return DeviceCapabilities(
|
||||
model="Unknown Device",
|
||||
@@ -158,27 +162,18 @@ def device_capabilities() -> DeviceCapabilities:
|
||||
)
|
||||
|
||||
|
||||
def mac_device_capabilities() -> DeviceCapabilities:
|
||||
# Fetch the model of the Mac using system_profiler
|
||||
model = subprocess.check_output(["system_profiler", "SPHardwareDataType"]).decode("utf-8")
|
||||
model_line = next((line for line in model.split("\n") if "Model Name" in line), None)
|
||||
model_id = model_line.split(": ")[1] if model_line else "Unknown Model"
|
||||
chip_line = next((line for line in model.split("\n") if "Chip" in line), None)
|
||||
chip_id = chip_line.split(": ")[1] if chip_line else "Unknown Chip"
|
||||
memory_line = next((line for line in model.split("\n") if "Memory" in line), None)
|
||||
memory_str = memory_line.split(": ")[1] if memory_line else "Unknown Memory"
|
||||
memory_units = memory_str.split()
|
||||
memory_value = int(memory_units[0])
|
||||
if memory_units[1] == "GB":
|
||||
memory = memory_value*1024
|
||||
else:
|
||||
memory = memory_value
|
||||
|
||||
# Assuming static values for other attributes for demonstration
|
||||
return DeviceCapabilities(model=model_id, chip=chip_id, memory=memory, flops=CHIP_FLOPS.get(chip_id, DeviceFlops(fp32=0, fp16=0, int8=0)))
|
||||
async def mac_device_capabilities() -> DeviceCapabilities:
|
||||
model_id, chip_id, memory = await get_mac_system_info()
|
||||
|
||||
return DeviceCapabilities(
|
||||
model=model_id,
|
||||
chip=chip_id,
|
||||
memory=memory,
|
||||
flops=CHIP_FLOPS.get(chip_id, DeviceFlops(fp32=0, fp16=0, int8=0))
|
||||
)
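With device_capabilities now a coroutine, callers outside an event loop have to drive it with asyncio. A minimal usage sketch, assuming only the async signature above and the import path used in the tests further down:

import asyncio

from exo.topology.device_capabilities import device_capabilities

async def main():
    caps = await device_capabilities()
    # memory is reported in MB, flops in the TFLOPS-scaled units used above
    print(f"{caps.model} / {caps.chip}: {caps.memory} MB, fp32 {caps.flops.fp32:.2f}")

if __name__ == "__main__":
    asyncio.run(main())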
|
||||
|
||||
|
||||
def linux_device_capabilities() -> DeviceCapabilities:
|
||||
async def linux_device_capabilities() -> DeviceCapabilities:
|
||||
import psutil
|
||||
from tinygrad import Device
|
||||
|
||||
@@ -194,6 +189,8 @@ def linux_device_capabilities() -> DeviceCapabilities:
|
||||
|
||||
if DEBUG >= 2: print(f"NVIDIA device {gpu_name=} {gpu_memory_info=}")
|
||||
|
||||
pynvml.nvmlShutdown()
|
||||
|
||||
return DeviceCapabilities(
|
||||
model=f"Linux Box ({gpu_name})",
|
||||
chip=gpu_name,
|
||||
@@ -201,13 +198,24 @@ def linux_device_capabilities() -> DeviceCapabilities:
|
||||
flops=CHIP_FLOPS.get(gpu_name, DeviceFlops(fp32=0, fp16=0, int8=0)),
|
||||
)
|
||||
elif Device.DEFAULT == "AMD":
|
||||
# TODO AMD support
|
||||
# For AMD GPUs, pyrsmi is the way (Official python package for rocm-smi)
|
||||
from pyrsmi import rocml
|
||||
|
||||
rocml.smi_initialize()
|
||||
gpu_name = rocml.smi_get_device_name(0).upper()
|
||||
gpu_memory_info = rocml.smi_get_device_memory_total(0)
|
||||
|
||||
if DEBUG >= 2: print(f"AMD device {gpu_name=} {gpu_memory_info=}")
|
||||
|
||||
rocml.smi_shutdown()
|
||||
|
||||
return DeviceCapabilities(
|
||||
model="Linux Box (AMD)",
|
||||
chip="Unknown AMD",
|
||||
memory=psutil.virtual_memory().total // 2**20,
|
||||
model="Linux Box ({gpu_name})",
|
||||
chip=gpu_name,
|
||||
memory=gpu_memory_info // 2**20,
|
||||
flops=DeviceFlops(fp32=0, fp16=0, int8=0),
|
||||
)
|
||||
|
||||
else:
|
||||
return DeviceCapabilities(
|
||||
model=f"Linux Box (Device: {Device.DEFAULT})",
|
||||
@@ -215,3 +223,74 @@ def linux_device_capabilities() -> DeviceCapabilities:
|
||||
memory=psutil.virtual_memory().total // 2**20,
|
||||
flops=DeviceFlops(fp32=0, fp16=0, int8=0),
|
||||
)
|
||||
|
||||
|
||||
def windows_device_capabilities() -> DeviceCapabilities:
|
||||
import psutil
|
||||
|
||||
def get_gpu_info():
|
||||
import win32com.client # install pywin32
|
||||
|
||||
wmiObj = win32com.client.GetObject("winmgmts:\\\\.\\root\\cimv2")
|
||||
gpus = wmiObj.ExecQuery("SELECT * FROM Win32_VideoController")
|
||||
|
||||
gpu_info = []
|
||||
for gpu in gpus:
|
||||
info = {
|
||||
"Name": gpu.Name,
|
||||
"AdapterRAM": gpu.AdapterRAM, # Bug in this property, returns -ve for VRAM > 4GB (uint32 overflow)
|
||||
"DriverVersion": gpu.DriverVersion,
|
||||
"VideoProcessor": gpu.VideoProcessor
|
||||
}
|
||||
gpu_info.append(info)
|
||||
|
||||
return gpu_info
|
||||
|
||||
gpus_info = get_gpu_info()
|
||||
gpu_names = [gpu['Name'] for gpu in gpus_info]
|
||||
|
||||
contains_nvidia = any('nvidia' in gpu_name.lower() for gpu_name in gpu_names)
|
||||
contains_amd = any('amd' in gpu_name.lower() for gpu_name in gpu_names)
|
||||
|
||||
if contains_nvidia:
|
||||
import pynvml
|
||||
|
||||
pynvml.nvmlInit()
|
||||
handle = pynvml.nvmlDeviceGetHandleByIndex(0)
|
||||
gpu_raw_name = pynvml.nvmlDeviceGetName(handle).upper()
|
||||
gpu_name = gpu_raw_name.rsplit(" ", 1)[0] if gpu_raw_name.endswith("GB") else gpu_raw_name
|
||||
gpu_memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
|
||||
|
||||
if DEBUG >= 2: print(f"NVIDIA device {gpu_name=} {gpu_memory_info=}")
|
||||
|
||||
return DeviceCapabilities(
|
||||
model=f"Windows Box ({gpu_name})",
|
||||
chip=gpu_name,
|
||||
memory=gpu_memory_info.total // 2**20,
|
||||
flops=CHIP_FLOPS.get(gpu_name, DeviceFlops(fp32=0, fp16=0, int8=0)),
|
||||
)
|
||||
elif contains_amd:
|
||||
# For AMD GPUs, pyrsmi is the way (Official python package for rocm-smi)
|
||||
from pyrsmi import rocml
|
||||
|
||||
rocml.smi_initialize()
|
||||
gpu_name = rocml.smi_get_device_name(0).upper()
|
||||
gpu_memory_info = rocml.smi_get_device_memory_total(0)
|
||||
|
||||
if DEBUG >= 2: print(f"AMD device {gpu_name=} {gpu_memory_info=}")
|
||||
|
||||
rocml.smi_shutdown()
|
||||
|
||||
return DeviceCapabilities(
|
||||
model="Windows Box ({gpu_name})",
|
||||
chip=gpu_name,
|
||||
memory=gpu_memory_info.total // 2**20,
|
||||
flops=DeviceFlops(fp32=0, fp16=0, int8=0),
|
||||
)
|
||||
else:
|
||||
return DeviceCapabilities(
|
||||
model=f"Windows Box (Device: Unknown)",
|
||||
chip=f"Unknown Chip (Device(s): {gpu_names})",
|
||||
memory=psutil.virtual_memory().total // 2**20,
|
||||
flops=DeviceFlops(fp32=0, fp16=0, int8=0),
|
||||
)
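The AdapterRAM caveat noted in get_gpu_info above comes from WMI exposing the field as 32 bits; a hedged helper, illustrative only, that maps negative readings back to their unsigned bit pattern:

def adapter_ram_bytes(raw):
    # Win32_VideoController.AdapterRAM is a uint32, so cards with more than
    # 4 GB of VRAM can surface as negative values through COM. Reinterpreting
    # the bits as unsigned at least restores the raw (still 4 GB-capped) field.
    if raw is None:
        return 0
    return raw & 0xFFFFFFFF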
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
from typing import List, Dict
|
||||
from dataclasses import dataclass
|
||||
from .topology import Topology
|
||||
from exo.inference.shard import Shard
|
||||
from exo.topology.device_capabilities import device_capabilities
|
||||
import asyncio
|
||||
|
||||
|
||||
# Partitions shard-space into pieces of contiguous shards, represented by floating point range [start, end) between 0 and 1
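As a hedged illustration of that comment (class and field names here are hypothetical, not the module's actual API), an even split of shard-space across nodes could look like:

from dataclasses import dataclass
from typing import List

@dataclass
class EvenPartition:  # illustrative stand-in for the real Partition class
    node_id: str
    start: float  # inclusive, in [0, 1)
    end: float    # exclusive

def split_evenly(node_ids: List[str]) -> List[EvenPartition]:
    # Give each node an equal contiguous slice of [0, 1).
    n = len(node_ids)
    return [EvenPartition(node_id, i / n, (i + 1) / n) for i, node_id in enumerate(node_ids)]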
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import unittest
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
from exo.topology.device_capabilities import mac_device_capabilities, DeviceCapabilities, DeviceFlops, TFLOPS
|
||||
from exo.topology.device_capabilities import mac_device_capabilities, DeviceCapabilities, DeviceFlops, TFLOPS, device_capabilities
|
||||
|
||||
|
||||
class TestMacDeviceCapabilities(unittest.TestCase):
|
||||
@patch("subprocess.check_output")
|
||||
def test_mac_device_capabilities_pro(self, mock_check_output):
|
||||
@pytest.mark.asyncio
|
||||
@patch("subprocess.check_output")
|
||||
async def test_mac_device_capabilities_pro(mock_check_output):
|
||||
# Mock the subprocess output
|
||||
mock_check_output.return_value = b"""
|
||||
Hardware:
|
||||
@@ -27,20 +27,19 @@ Activation Lock Status: Enabled
|
||||
"""
|
||||
|
||||
# Call the function
|
||||
result = mac_device_capabilities()
|
||||
result = await mac_device_capabilities()
|
||||
|
||||
# Check the results
|
||||
self.assertIsInstance(result, DeviceCapabilities)
|
||||
self.assertEqual(result.model, "MacBook Pro")
|
||||
self.assertEqual(result.chip, "Apple M3 Max")
|
||||
self.assertEqual(result.memory, 131072) # 16 GB in MB
|
||||
self.assertEqual(
|
||||
str(result),
|
||||
"Model: MacBook Pro. Chip: Apple M3 Max. Memory: 131072MB. Flops: 14.20 TFLOPS, fp16: 28.40 TFLOPS, int8: 56.80 TFLOPS",
|
||||
)
|
||||
assert isinstance(result, DeviceCapabilities)
|
||||
assert result.model == "MacBook Pro"
|
||||
assert result.chip == "Apple M3 Max"
|
||||
assert result.memory == 131072 # 128 GB in MB
|
||||
assert str(result) == "Model: MacBook Pro. Chip: Apple M3 Max. Memory: 131072MB. Flops: 14.20 TFLOPS, fp16: 28.40 TFLOPS, int8: 56.80 TFLOPS"
|
||||
|
||||
@patch("subprocess.check_output")
|
||||
def test_mac_device_capabilities_air(self, mock_check_output):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("subprocess.check_output")
|
||||
async def test_mac_device_capabilities_air(mock_check_output):
|
||||
# Mock the subprocess output
|
||||
mock_check_output.return_value = b"""
|
||||
Hardware:
|
||||
@@ -62,30 +61,34 @@ Activation Lock Status: Disabled
|
||||
"""
|
||||
|
||||
# Call the function
|
||||
result = mac_device_capabilities()
|
||||
result = await mac_device_capabilities()
|
||||
|
||||
# Check the results
|
||||
self.assertIsInstance(result, DeviceCapabilities)
|
||||
self.assertEqual(result.model, "MacBook Air")
|
||||
self.assertEqual(result.chip, "Apple M2")
|
||||
self.assertEqual(result.memory, 8192) # 8 GB in MB
|
||||
assert isinstance(result, DeviceCapabilities)
|
||||
assert result.model == "MacBook Air"
|
||||
assert result.chip == "Apple M2"
|
||||
assert result.memory == 8192 # 8 GB in MB
|
||||
|
||||
@unittest.skip("Unskip this test when running on a MacBook Pro, Apple M3 Max, 128GB")
|
||||
def test_mac_device_capabilities_real(self):
|
||||
|
||||
@pytest.mark.skip(reason="Unskip this test when running on a MacBook Pro, Apple M3 Max, 128GB")
|
||||
@pytest.mark.asyncio
|
||||
async def test_mac_device_capabilities_real():
|
||||
# Call the function without mocking
|
||||
result = mac_device_capabilities()
|
||||
result = await mac_device_capabilities()
|
||||
|
||||
# Check the results
|
||||
self.assertIsInstance(result, DeviceCapabilities)
|
||||
self.assertEqual(result.model, "MacBook Pro")
|
||||
self.assertEqual(result.chip, "Apple M3 Max")
|
||||
self.assertEqual(result.memory, 131072) # 128 GB in MB
|
||||
self.assertEqual(result.flops, DeviceFlops(fp32=14.20*TFLOPS, fp16=28.40*TFLOPS, int8=56.80*TFLOPS))
|
||||
self.assertEqual(
|
||||
str(result),
|
||||
"Model: MacBook Pro. Chip: Apple M3 Max. Memory: 131072MB. Flops: 14.20 TFLOPS, fp16: 28.40 TFLOPS, int8: 56.80 TFLOPS",
|
||||
)
|
||||
assert isinstance(result, DeviceCapabilities)
|
||||
assert result.model == "MacBook Pro"
|
||||
assert result.chip == "Apple M3 Max"
|
||||
assert result.memory == 131072 # 128 GB in MB
|
||||
assert result.flops == DeviceFlops(fp32=14.20*TFLOPS, fp16=28.40*TFLOPS, int8=56.80*TFLOPS)
|
||||
assert str(result) == "Model: MacBook Pro. Chip: Apple M3 Max. Memory: 131072MB. Flops: 14.20 TFLOPS, fp16: 28.40 TFLOPS, int8: 56.80 TFLOPS"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@pytest.mark.asyncio
|
||||
async def test_device_capabilities():
|
||||
caps = await device_capabilities()
|
||||
assert caps.model != ""
|
||||
assert caps.chip != ""
|
||||
assert caps.memory > 0
|
||||
assert caps.flops is not None
|
||||
|
||||
@@ -5,7 +5,7 @@ from exo.viz.topology_viz import TopologyViz
|
||||
from exo.topology.topology import Topology
|
||||
from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
|
||||
from exo.topology.partitioning_strategy import Partition
|
||||
from exo.download.hf.hf_helpers import RepoProgressEvent, RepoFileProgressEvent
|
||||
from exo.download.download_progress import RepoProgressEvent
|
||||
|
||||
|
||||
def create_hf_repo_progress_event(
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import List, Optional, Tuple, Dict
|
||||
from exo.helpers import exo_text, pretty_print_bytes, pretty_print_bytes_per_second
|
||||
from exo.topology.topology import Topology
|
||||
from exo.topology.partitioning_strategy import Partition
|
||||
from exo.download.hf.hf_helpers import RepoProgressEvent
|
||||
from exo.download.download_progress import RepoProgressEvent
|
||||
from exo.topology.device_capabilities import UNKNOWN_DEVICE_CAPABILITIES
|
||||
from rich.console import Console, Group
|
||||
from rich.text import Text
|
||||
@@ -51,17 +51,11 @@ class TopologyViz:
|
||||
self.refresh()
|
||||
|
||||
def update_prompt(self, request_id: str, prompt: Optional[str] = None):
|
||||
if request_id in self.requests:
|
||||
self.requests[request_id] = [prompt, self.requests[request_id][1]]
|
||||
else:
|
||||
self.requests[request_id] = [prompt, ""]
|
||||
self.requests[request_id] = [prompt, self.requests.get(request_id, ["", ""])[1]]
|
||||
self.refresh()
|
||||
|
||||
def update_prompt_output(self, request_id: str, output: Optional[str] = None):
|
||||
if request_id in self.requests:
|
||||
self.requests[request_id] = [self.requests[request_id][0], output]
|
||||
else:
|
||||
self.requests[request_id] = ["", output]
|
||||
self.requests[request_id] = [self.requests.get(request_id, ["", ""])[0], output]
|
||||
self.refresh()
|
||||
|
||||
def refresh(self):
|
||||
@@ -91,36 +85,96 @@ class TopologyViz:
|
||||
content = []
|
||||
requests = list(self.requests.values())[-3:] # Get the 3 most recent requests
|
||||
max_width = self.console.width - 6 # Full width minus padding and icon
|
||||
max_lines = 13 # Maximum number of lines for the entire panel content
|
||||
|
||||
# Calculate available height for content
|
||||
panel_height = 15 # Fixed panel height
|
||||
available_lines = panel_height - 2 # Subtract 2 for panel borders
|
||||
lines_per_request = available_lines // len(requests) if requests else 0
|
||||
|
||||
for (prompt, output) in reversed(requests):
|
||||
prompt_icon, output_icon = "💬️", "🤖"
|
||||
|
||||
# Equal space allocation for prompt and output
|
||||
max_prompt_lines = lines_per_request // 2
|
||||
max_output_lines = lines_per_request - max_prompt_lines - 1 # -1 for spacing
|
||||
|
||||
# Process prompt
|
||||
prompt_lines = prompt.split('\n')
|
||||
if len(prompt_lines) > max_lines // 2:
|
||||
prompt_lines = prompt_lines[:max_lines//2 - 1] + ['...']
|
||||
prompt_lines = []
|
||||
for line in prompt.split('\n'):
|
||||
words = line.split()
|
||||
current_line = []
|
||||
current_length = 0
|
||||
|
||||
for word in words:
|
||||
if current_length + len(word) + 1 <= max_width:
|
||||
current_line.append(word)
|
||||
current_length += len(word) + 1
|
||||
else:
|
||||
if current_line:
|
||||
prompt_lines.append(' '.join(current_line))
|
||||
current_line = [word]
|
||||
current_length = len(word)
|
||||
|
||||
if current_line:
|
||||
prompt_lines.append(' '.join(current_line))
|
||||
|
||||
# Truncate prompt if needed
|
||||
if len(prompt_lines) > max_prompt_lines:
|
||||
prompt_lines = prompt_lines[:max_prompt_lines]
|
||||
if prompt_lines:
|
||||
last_line = prompt_lines[-1]
|
||||
if len(last_line) + 4 <= max_width:
|
||||
prompt_lines[-1] = last_line + " ..."
|
||||
else:
|
||||
prompt_lines[-1] = last_line[:max_width-4] + " ..."
|
||||
|
||||
prompt_text = Text(f"{prompt_icon} ", style="bold bright_blue")
|
||||
prompt_text.append('\n'.join(line[:max_width] for line in prompt_lines), style="white")
|
||||
|
||||
# Process output
|
||||
output_lines = output.split('\n')
|
||||
remaining_lines = max_lines - len(prompt_lines) - 2 # -2 for spacing
|
||||
if len(output_lines) > remaining_lines:
|
||||
output_lines = output_lines[:remaining_lines - 1] + ['...']
|
||||
output_text = Text(f"\n{output_icon} ", style="bold bright_magenta")
|
||||
output_text.append('\n'.join(line[:max_width] for line in output_lines), style="white")
|
||||
|
||||
prompt_text.append('\n'.join(prompt_lines), style="white")
|
||||
content.append(prompt_text)
|
||||
content.append(output_text)
|
||||
|
||||
# Process output with similar word wrapping
|
||||
if output: # Only process output if it exists
|
||||
output_lines = []
|
||||
for line in output.split('\n'):
|
||||
words = line.split()
|
||||
current_line = []
|
||||
current_length = 0
|
||||
|
||||
for word in words:
|
||||
if current_length + len(word) + 1 <= max_width:
|
||||
current_line.append(word)
|
||||
current_length += len(word) + 1
|
||||
else:
|
||||
if current_line:
|
||||
output_lines.append(' '.join(current_line))
|
||||
current_line = [word]
|
||||
current_length = len(word)
|
||||
|
||||
if current_line:
|
||||
output_lines.append(' '.join(current_line))
|
||||
|
||||
# Truncate output if needed
|
||||
if len(output_lines) > max_output_lines:
|
||||
output_lines = output_lines[:max_output_lines]
|
||||
if output_lines:
|
||||
last_line = output_lines[-1]
|
||||
if len(last_line) + 4 <= max_width:
|
||||
output_lines[-1] = last_line + " ..."
|
||||
else:
|
||||
output_lines[-1] = last_line[:max_width-4] + " ..."
|
||||
|
||||
output_text = Text(f"{output_icon} ", style="bold bright_magenta")
|
||||
output_text.append('\n'.join(output_lines), style="white")
|
||||
content.append(output_text)
|
||||
|
||||
content.append(Text()) # Empty line between entries
|
||||
|
||||
return Panel(
|
||||
Group(*content),
|
||||
title="",
|
||||
border_style="cyan",
|
||||
height=15, # Increased height to accommodate multiple lines
|
||||
expand=True # Allow the panel to expand to full width
|
||||
height=panel_height,
|
||||
expand=True
|
||||
)
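The prompt and output paths above apply the same greedy word-wrapping loop; a minimal standalone sketch of that loop, with a hypothetical helper name:

def wrap_words(text: str, max_width: int) -> list[str]:
    # Greedy wrap: pack whole words onto a line until the next word would
    # push it past max_width, then start a new line.
    wrapped = []
    for line in text.split('\n'):
        current, length = [], 0
        for word in line.split():
            if length + len(word) + 1 <= max_width:
                current.append(word)
                length += len(word) + 1
            else:
                if current:
                    wrapped.append(' '.join(current))
                current, length = [word], len(word)
        if current:
            wrapped.append(' '.join(current))
    return wrapped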
|
||||
|
||||
def _generate_main_layout(self) -> str:
|
||||
|
||||
@@ -1,50 +0,0 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
from exo.download.hf.hf_helpers import download_all_files, RepoProgressEvent
|
||||
|
||||
DEFAULT_ALLOW_PATTERNS = [
|
||||
"*.json",
|
||||
"*.py",
|
||||
"tokenizer.model",
|
||||
"*.tiktoken",
|
||||
"*.txt",
|
||||
"*.safetensors",
|
||||
]
|
||||
# Always ignore `.git` and `.cache/huggingface` folders in commits
|
||||
DEFAULT_IGNORE_PATTERNS = [
|
||||
".git",
|
||||
".git/*",
|
||||
"*/.git",
|
||||
"**/.git/**",
|
||||
".cache/huggingface",
|
||||
".cache/huggingface/*",
|
||||
"*/.cache/huggingface",
|
||||
"**/.cache/huggingface/**",
|
||||
]
|
||||
|
||||
|
||||
async def main(repo_id, revision="main", allow_patterns=None, ignore_patterns=None):
|
||||
async def progress_callback(event: RepoProgressEvent):
|
||||
print(f"Overall Progress: {event.completed_files}/{event.total_files} files, {event.downloaded_bytes}/{event.total_bytes} bytes")
|
||||
print(f"Estimated time remaining: {event.overall_eta}")
|
||||
print("File Progress:")
|
||||
for file_path, progress in event.file_progress.items():
|
||||
status_icon = {'not_started': '⚪', 'in_progress': '🔵', 'complete': '✅'}[progress.status]
|
||||
eta_str = str(progress.eta)
|
||||
print(f"{status_icon} {file_path}: {progress.downloaded}/{progress.total} bytes, "
|
||||
f"Speed: {progress.speed:.2f} B/s, ETA: {eta_str}")
|
||||
print("\n")
|
||||
|
||||
await download_all_files(repo_id, revision, progress_callback, allow_patterns, ignore_patterns)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Download files from a Hugging Face model repository.")
|
||||
parser.add_argument("--repo-id", required=True, help="The repository ID (e.g., 'meta-llama/Meta-Llama-3.1-8B-Instruct')")
|
||||
parser.add_argument("--revision", default="main", help="The revision to download (branch, tag, or commit hash)")
|
||||
parser.add_argument("--allow-patterns", nargs="*", default=None, help="Patterns of files to allow (e.g., '*.json' '*.safetensors')")
|
||||
parser.add_argument("--ignore-patterns", nargs="*", default=None, help="Patterns of files to ignore (e.g., '.*')")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
asyncio.run(main(args.repo_id, args.revision, args.allow_patterns, args.ignore_patterns))
|
||||
@@ -74,9 +74,9 @@ def gen_diff(table_old, table_new):
|
||||
|
||||
def create_json_report(table, is_diff=False):
|
||||
timestamp = datetime.now(timezone.utc).isoformat()
|
||||
commit_sha = os.environ.get('CIRCLE_SHA1', 'unknown')
|
||||
branch = os.environ.get('CIRCLE_BRANCH', 'unknown')
|
||||
pr_number = os.environ.get('CIRCLE_PR_NUMBER', '')
|
||||
commit_sha = os.environ.get('GITHUB_SHA', 'unknown')
|
||||
branch = os.environ.get('GITHUB_REF_NAME', 'unknown')
|
||||
pr_number = os.environ.get('GITHUB_EVENT_NUMBER', '')
|
||||
|
||||
if is_diff:
|
||||
files = [{
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#!/bin/bash
|
||||
#!/usr/bin/env bash
|
||||
|
||||
if command -v python3.12 &>/dev/null; then
|
||||
echo "Python 3.12 is installed, proceeding with python3.12..."
|
||||
|
||||
@@ -6,6 +6,9 @@ import pkgutil
|
||||
|
||||
def run():
|
||||
site_packages = site.getsitepackages()[0]
|
||||
base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
baseimages_dir = os.path.join(base_dir, "exo", "apputil", "baseimages")
|
||||
|
||||
command = [
|
||||
f"{sys.executable}", "-m", "nuitka", "exo/main.py",
|
||||
"--company-name=exolabs",
|
||||
@@ -15,7 +18,8 @@ def run():
|
||||
"--standalone",
|
||||
"--output-filename=exo",
|
||||
"--python-flag=no_site",
|
||||
"--onefile"
|
||||
"--onefile",
|
||||
f"--include-data-dir={baseimages_dir}=exo/apputil/baseimages"
|
||||
]
|
||||
|
||||
if sys.platform == "darwin":
|
||||
@@ -23,7 +27,7 @@ def run():
|
||||
"--macos-app-name=exo",
|
||||
"--macos-app-mode=gui",
|
||||
"--macos-app-version=0.0.1",
|
||||
"--macos-signed-app-name=com.exolabs.exo",
|
||||
"--macos-signed-app-name=net.exolabs.exo",
|
||||
"--include-distribution-meta=mlx",
|
||||
"--include-module=mlx._reprlib_fix",
|
||||
"--include-module=mlx._os_warning",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#!/bin/bash
|
||||
#!/usr/bin/env bash
|
||||
source ./install.sh
|
||||
pushd exo/networking/grpc
|
||||
python3 -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. node_service.proto
|
||||
|
||||
setup.py
@@ -1,5 +1,6 @@
|
||||
import sys
|
||||
import platform
|
||||
import subprocess
|
||||
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
@@ -11,7 +12,6 @@ install_requires = [
|
||||
"grpcio==1.68.0",
|
||||
"grpcio-tools==1.68.0",
|
||||
"Jinja2==3.1.4",
|
||||
"netifaces==0.11.0",
|
||||
"numpy==2.0.0",
|
||||
"nuitka==2.5.1",
|
||||
"nvidia-ml-py==12.560.30",
|
||||
@@ -23,27 +23,60 @@ install_requires = [
|
||||
"pydantic==2.9.2",
|
||||
"requests==2.32.3",
|
||||
"rich==13.7.1",
|
||||
"tenacity==9.0.0",
|
||||
"scapy==2.6.1",
|
||||
"tqdm==4.66.4",
|
||||
"transformers==4.46.3",
|
||||
"uuid==1.30",
|
||||
"tinygrad @ git+https://github.com/tinygrad/tinygrad.git@3b26e51fcebfc6576f4e0f99693e6f1406d61d79",
|
||||
"uvloop==0.21.0",
|
||||
"tinygrad @ git+https://github.com/tinygrad/tinygrad.git@ec120ce6b9ce8e4ff4b5692566a683ef240e8bc8",
|
||||
]
|
||||
|
||||
extras_require = {
|
||||
"formatting": [
|
||||
"yapf==0.40.2",
|
||||
],
|
||||
"formatting": ["yapf==0.40.2",],
|
||||
"apple_silicon": [
|
||||
"mlx==0.20.0",
|
||||
"mlx-lm==0.19.3",
|
||||
"mlx==0.22.0",
|
||||
"mlx-lm==0.21.1",
|
||||
],
|
||||
"windows": ["pywin32==308",],
|
||||
"nvidia-gpu": ["nvidia-ml-py==12.560.30",],
|
||||
"amd-gpu": ["pyrsmi==0.2.0"],
|
||||
}
|
||||
|
||||
# Check if running on macOS with Apple Silicon
|
||||
if sys.platform.startswith("darwin") and platform.machine() == "arm64":
|
||||
install_requires.extend(extras_require["apple_silicon"])
|
||||
|
||||
# Check if running Windows
|
||||
if sys.platform.startswith("win32"):
|
||||
install_requires.extend(extras_require["windows"])
|
||||
|
||||
|
||||
def _add_gpu_requires():
|
||||
global install_requires
|
||||
# Add Nvidia-GPU
|
||||
try:
|
||||
out = subprocess.run(['nvidia-smi', '--query-gpu=name', '--format=csv,noheader'], shell=True, text=True, capture_output=True, check=False)
|
||||
if out.returncode == 0:
|
||||
install_requires.extend(extras_require["nvidia-gpu"])
|
||||
except subprocess.CalledProcessError:
|
||||
pass
|
||||
|
||||
# Add AMD-GPU
|
||||
# This will mostly work only on Linux, amd/rocm-smi is not yet supported on Windows
|
||||
try:
|
||||
out = subprocess.run(['amd-smi', 'list', '--csv'], shell=True, text=True, capture_output=True, check=False)
|
||||
if out.returncode == 0:
|
||||
install_requires.extend(extras_require["amd-gpu"])
|
||||
except:
|
||||
out = subprocess.run(['rocm-smi', 'list', '--csv'], shell=True, text=True, capture_output=True, check=False)
|
||||
if out.returncode == 0:
|
||||
install_requires.extend(extras_require["amd-gpu"])
|
||||
finally:
|
||||
pass
|
||||
|
||||
|
||||
_add_gpu_requires()
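The probe above shells out to nvidia-smi, amd-smi and rocm-smi to decide which GPU extras to install; a hedged alternative sketch that only checks whether those binaries exist on PATH (illustrative, with a hypothetical helper name):

import shutil

def has_gpu_tool(*names: str) -> bool:
    # True if any of the given CLI tools is available on PATH.
    return any(shutil.which(name) is not None for name in names)

# e.g. has_gpu_tool('nvidia-smi') for NVIDIA, has_gpu_tool('amd-smi', 'rocm-smi') for AMD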
|
||||
|
||||
setup(
|
||||
name="exo",
|
||||
version="0.0.1",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#!/bin/bash
|
||||
#!/usr/bin/env bash
|
||||
|
||||
echo "Starting node 1"
|
||||
DEBUG_DISCOVERY=7 DEBUG=7 python3 main.py --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 52415 --chatgpt-api-response-timeout 900 > output1.log 2>&1 &
|
||||
|
||||
@@ -1,26 +0,0 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
# Add the project root to the Python path
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.insert(0, project_root)
|
||||
|
||||
import asyncio
|
||||
from exo.download.hf.hf_helpers import get_weight_map
|
||||
|
||||
async def test_get_weight_map():
|
||||
repo_ids = [
|
||||
"mlx-community/quantized-gemma-2b",
|
||||
"mlx-community/Meta-Llama-3.1-8B-4bit",
|
||||
"mlx-community/Meta-Llama-3.1-70B-4bit",
|
||||
"mlx-community/Meta-Llama-3.1-405B-4bit",
|
||||
]
|
||||
for repo_id in repo_ids:
|
||||
weight_map = await get_weight_map(repo_id)
|
||||
assert weight_map is not None, "Weight map should not be None"
|
||||
assert isinstance(weight_map, dict), "Weight map should be a dictionary"
|
||||
assert len(weight_map) > 0, "Weight map should not be empty"
|
||||
print(f"OK: {repo_id}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(test_get_weight_map())
|
||||
@@ -24,7 +24,7 @@ def test_tokenizer(name, tokenizer, verbose=False):
|
||||
strip_tokens = lambda s: s.lstrip(tokenizer.decode([tokenizer.bos_token_id])).rstrip(tokenizer.decode([tokenizer.eos_token_id]))
|
||||
assert text == strip_tokens(decoded) == strip_tokens(reconstructed)
|
||||
|
||||
ignore = ["TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", "mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", "llava-hf/llava-1.5-7b-hf", "mlx-community/Qwen*", "dummy", "mlx-community/Meta-Llama-3.1-405B-Instruct-8bit"]
|
||||
ignore = ["TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", "mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", "llava-hf/llava-1.5-7b-hf", "mlx-community/Qwen*", "dummy", "mlx-community/Meta-Llama-3.1-405B-Instruct-8bit", "mlx-community/Phi-3.5-mini-instruct-4bit", "mlx-community/phi-4-4bit", "stabilityai/stable-diffusion-2-1-base"]
|
||||
ignore_pattern = re.compile(r"^(" + "|".join(model.replace("*", ".*") for model in ignore) + r")")
|
||||
models = []
|
||||
for model_id in model_cards:
|
||||
@@ -37,5 +37,6 @@ verbose = os.environ.get("VERBOSE", "0").lower() == "1"
|
||||
for m in models:
|
||||
# TODO: figure out why use_fast=False is giving inconsistent behaviour (no spaces decoding invididual tokens) for Mistral-Large-Instruct-2407-4bit
|
||||
# test_tokenizer(m, AutoProcessor.from_pretrained(m, use_fast=False), verbose)
|
||||
test_tokenizer(m, AutoProcessor.from_pretrained(m, use_fast=True), verbose)
|
||||
test_tokenizer(m, AutoTokenizer.from_pretrained(m), verbose)
|
||||
if m not in ["mlx-community/DeepSeek-R1-4bit", "mlx-community/DeepSeek-R1-3bit", "mlx-community/DeepSeek-V3-4bit", "mlx-community/DeepSeek-V3-3bit"]:
|
||||
test_tokenizer(m, AutoProcessor.from_pretrained(m, use_fast=True, trust_remote_code=True), verbose)
|
||||
test_tokenizer(m, AutoTokenizer.from_pretrained(m, trust_remote_code=True), verbose)
|
||||
|
||||