Compare commits

...

442 Commits

Author SHA1 Message Date
Alex Cheema
af734f1bf6 Merge pull request #737 from exo-explore/handlegzipdownload
handle -gzip suffix in etag for integrity check fixes #633
2025-02-25 22:10:05 +00:00
Alex Cheema
ee095766d9 handle -gzip suffix in etag for integrity check fixes #633 2025-02-25 22:08:15 +00:00
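Context for the -gzip ETag fix: some CDNs return an ETag like "<hash>-gzip" when serving compressed content, which breaks a naive comparison against the local file hash. A minimal sketch of that kind of normalization, with hypothetical helper names (not the actual exo code):

import hashlib

def normalize_etag(etag: str) -> str:
    # Strip quotes, a weak-validator prefix, and a trailing "-gzip" marker
    # that some CDNs append when serving compressed content.
    etag = etag.strip('"')
    if etag.startswith("W/"):
        etag = etag[2:].strip('"')
    if etag.endswith("-gzip"):
        etag = etag[: -len("-gzip")]
    return etag

def file_matches_etag(path: str, etag: str) -> bool:
    # Hypothetical integrity check: compare the file's MD5 to the normalized ETag.
    md5 = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            md5.update(chunk)
    return md5.hexdigest() == normalize_etag(etag)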
Alex Cheema
a605e233ad Merge pull request #709 from exo-explore/notice
update notice in README
2025-02-18 11:43:14 +00:00
Alex Cheema
f9a1e5342b update notice in README 2025-02-18 11:41:09 +00:00
Alex Cheema
7a374a74cd Merge pull request #708 from exo-explore/notice
add notice to README
2025-02-17 22:55:44 +00:00
Alex Cheema
5a00899d73 Merge pull request #705 from cadenmackenzie/addingModelNameInputContainer
adding current model name to input container information
2025-02-17 22:55:29 +00:00
Alex Cheema
cb4bee2694 add notice to README 2025-02-17 22:54:56 +00:00
Caden MacKenzie
9078d094b9 adding current model name to input container information 2025-02-16 18:34:38 -08:00
Alex Cheema
ed70d47cfd Merge pull request #702 from exo-explore/alwayslogdownloaderror
make max_parallel_downloads configurable, increase download chunk size to 8MB
2025-02-14 21:27:12 +00:00
Alex Cheema
477e3a5e4c make max_parallel_downloads configurable, increase download chunk size to 8MB 2025-02-14 21:26:41 +00:00
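A rough sketch of what a configurable parallel-download limit with 8 MB chunks can look like, using aiohttp; the environment variable and function names are assumptions, not the actual downloader:

import asyncio, os
import aiohttp

MAX_PARALLEL_DOWNLOADS = int(os.environ.get("MAX_PARALLEL_DOWNLOADS", "4"))  # assumed knob name
CHUNK_SIZE = 8 * 1024 * 1024  # 8 MB chunks

async def download_file(session: aiohttp.ClientSession, url: str, dest: str, sem: asyncio.Semaphore) -> None:
    async with sem:  # cap the number of concurrent downloads
        async with session.get(url) as resp:
            resp.raise_for_status()
            with open(dest, "wb") as f:
                async for chunk in resp.content.iter_chunked(CHUNK_SIZE):
                    f.write(chunk)

async def download_all(urls_and_dests):
    sem = asyncio.Semaphore(MAX_PARALLEL_DOWNLOADS)
    async with aiohttp.ClientSession() as session:
        await asyncio.gather(*(download_file(session, u, d, sem) for u, d in urls_and_dests))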
Alex Cheema
be3b9ee973 Merge pull request #698 from exo-explore/alwayslogdownloaderror
always log download errors. some people e.g cant access huggingface
2025-02-13 22:56:33 +00:00
Alex Cheema
b4e6f8acad always log download errors. some people eg cant access huggingface which causes confusion 2025-02-13 22:55:09 +00:00
Alex Cheema
de99da7c75 Merge pull request #684 from divinity76/patch-1
workaround f16 cast ambiguity
2025-02-08 12:45:10 +00:00
Alex Cheema
76d1bd95f5 Merge pull request #688 from exo-explore/readmeupdate
apt-get debian noninteractive in circleci
2025-02-08 02:41:19 +00:00
Alex Cheema
928214d479 apt-get debian noninteractive in circleci 2025-02-08 02:40:51 +00:00
Alex Cheema
ce34a886c2 Merge pull request #687 from exo-explore/readmeupdate
README updates
2025-02-08 02:15:50 +00:00
Alex Cheema
d8c3aed0cc update discovery / peer networking modules 2025-02-08 02:15:13 +00:00
Alex Cheema
2c982d9295 update README to better reflect support for other devices like NVIDIA and Pi's 2025-02-08 02:13:04 +00:00
divinity76
5fe241ec61 code-breaking typo
oops
2025-02-06 19:02:02 +01:00
divinity76
05ff20fa89 workaround f16 cast ambiguity
For unknown reasons, without this change I get the error below when trying to run "Llama 3.2 1B". FWIW, I do not know the performance impact of this change. I still can't get exo fully running, but this lets me get further (before hitting a second issue with VRAM allocation; a story for another day, I suppose).


error: 
Failed to fetch completions: Error processing prompt (see logs with DEBUG>=2): Nvrtc Error 6, NVRTC_ERROR_COMPILATION
<null>(18): error: more than one user-defined conversion from "nv_bfloat16" to "half" applies:
            function "__half::__half(float)" (declared at line 214 of /usr/include/cuda_fp16.hpp)
            function "__half::__half(short)" (declared at line 227 of /usr/include/cuda_fp16.hpp)
            function "__half::__half(unsigned short)" (declared at line 228 of /usr/include/cuda_fp16.hpp)
            function "__half::__half(int)" (declared at line 229 of /usr/include/cuda_fp16.hpp)
            function "__half::__half(unsigned int)" (declared at line 230 of /usr/include/cuda_fp16.hpp)
            function "__half::__half(long long)" (declared at line 231 of /usr/include/cuda_fp16.hpp)
            function "__half::__half(unsigned long long)" (declared at line 232 of /usr/include/cuda_fp16.hpp)
    *((half4*)((data0+(alu0+(gidx1<<14)+(lidx0<<11)+alu1)))) = make_half4(((half)(val0)),((half)(val1)),((half)(val2)),((half)(val3)));
                                                                                 ^

<null>(18): error: more than one user-defined conversion from "nv_bfloat16" to "half" applies:
            function "__half::__half(float)" (declared at line 214 of /usr/include/cuda_fp16.hpp)
            function "__half::__half(short)" (declared at line 227 of /usr/include/cuda_fp16.hpp)
            function "__half::__half(unsigned short)" (declared at line 228 of /usr/include/cuda_fp16.hpp)
            function "__half::__half(int)" (declared at line 229 of /usr/include/cuda_fp16.hpp)
            function "__half::__half(unsigned int)" (declared at line 230 of /usr/include/cuda_fp16.hpp)
            function "__half::__half(long long)" (declared at line 231 of /usr/include/cuda_fp16.hpp)
            function "__half::__half(unsigned long long)" (declared at line 232 of /usr/include/cuda_fp16.hpp)
    *((half4*)((data0+(alu0+(gidx1<<14)+(lidx0<<11)+alu1)))) = make_half4(((half)(val0)),((half)(val1)),((half)(val2)),((half)(val3)));
                                                                                                ^

<null>(18): error: more than one user-defined conversion from "nv_bfloat16" to "half" applies:
            function "__half::__half(float)" (declared at line 214 of /usr/include/cuda_fp16.hpp)
            function "__half::__half(short)" (declared at line 227 of /usr/include/cuda_fp16.hpp)
            function "__half::__half(unsigned short)" (declared at line 228 of /usr/include/cuda_fp16.hpp)
            function "__half::__half(int)" (declared at line 229 of /usr/include/cuda_fp16.hpp)
            function "__half::__half(unsigned int)" (declared at line 230 of /usr/include/cuda_fp16.hpp)
            function "__half::__half(long long)" (declared at line 231 of /usr/include/cuda_fp16.hpp)
            function "__half::__half(unsigned long long)" (declared at line 232 of /usr/include/cuda_fp16.hpp)
    *((half4*)((data0+(alu0+(gidx1<<14)+(lidx0<<11)+alu1)))) = make_half4(((half)(val0)),((half)(val1)),((half)(val2)),((half)(val3)));
                                                                                                               ^

<null>(18): error: more than one user-defined conversion from "nv_bfloat16" to "half" applies:
            function "__half::__half(float)" (declared at line 214 of /usr/include/cuda_fp16.hpp)
            function "__half::__half(short)" (declared at line 227 of /usr/include/cuda_fp16.hpp)
            function "__half::__half(unsigned short)" (declared at line 228 of /usr/include/cuda_fp16.hpp)
            function "__half::__half(int)" (declared at line 229 of /usr/include/cuda_fp16.hpp)
            function "__half::__half(unsigned int)" (declared at line 230 of /usr/include/cuda_fp16.hpp)
            function "__half::__half(long long)" (declared at line 231 of /usr/include/cuda_fp16.hpp)
            function "__half::__half(unsigned long long)" (declared at line 232 of /usr/include/cuda_fp16.hpp)
    *((half4*)((data0+(alu0+(gidx1<<14)+(lidx0<<11)+alu1)))) = make_half4(((half)(val0)),((half)(val1)),((half)(val2)),((half)(val3)));
                                                                                                                              ^

4 errors detected in the compilation of "<null>".
2025-02-06 18:54:15 +01:00
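The errors above come from NVRTC refusing an implicit nv_bfloat16 -> half conversion because several __half constructors apply. One common way to sidestep that kind of ambiguity is to route the cast through float32 so only one conversion applies; a hedged tinygrad-level sketch of the idea (the merged workaround may differ):

from tinygrad import Tensor, dtypes

def bf16_to_half_safely(x: Tensor) -> Tensor:
    # Cast bfloat16 -> float32 -> float16 so the generated kernel only ever
    # performs an unambiguous conversion; backend support for bfloat16 varies.
    return x.cast(dtypes.float32).cast(dtypes.float16)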
Alex Cheema
b5fc4bc288 Merge pull request #675 from exo-explore/rmtenacity
remove tenacity dependency, implement simple retry logic instead
2025-02-03 21:58:08 +00:00
Alex Cheema
5157d80a46 remove tenacity dependency, implement simple retry logic instead 2025-02-03 21:56:38 +00:00
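A minimal sketch of the kind of retry-with-exponential-backoff helper that can replace tenacity; names and defaults are assumptions, not the actual implementation:

import asyncio
import random

async def retry(coro_fn, *args, attempts: int = 5, base_delay: float = 0.5, **kwargs):
    # Retry an async callable with exponential backoff plus a little jitter.
    for attempt in range(attempts):
        try:
            return await coro_fn(*args, **kwargs)
        except Exception:
            if attempt == attempts - 1:
                raise
            await asyncio.sleep(base_delay * (2 ** attempt) + random.uniform(0, 0.1))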
Alex Cheema
75914b4de8 Merge pull request #669 from pavel-rodionov/feature-local-models
Add toggle to show only models downloaded locally
2025-02-03 21:45:27 +00:00
Rodionov Pavel
d084dbe574 Add toggle to show only models downloaded locally 2025-02-01 23:45:19 -08:00
Alex Cheema
1a77a52d71 Merge pull request #666 from exo-explore/patchmanualdiscovery
patch for manual discovery, set known_peers
2025-02-01 23:07:21 +00:00
Alex Cheema
72329ba984 patch for manual discovery, set known_peers 2025-02-01 23:06:57 +00:00
Alex Cheema
f663b0afa2 Merge pull request #665 from exo-explore/resumedownload
add model downloading section to README
2025-02-01 20:23:58 +00:00
Alex Cheema
51b5c2ca9b add model downloading section to README 2025-02-01 20:23:05 +00:00
Alex Cheema
9a1f0a85e6 Merge pull request #664 from exo-explore/resumedownload
resumable downloads with integrity checks
2025-02-01 18:34:36 +00:00
Alex Cheema
2c0d17c336 beautiful download 2025-02-01 17:29:19 +00:00
Alex Cheema
7034ee0fcb resumable downloads with integrity checks 2025-02-01 13:22:51 +00:00
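Resumable downloads typically combine an HTTP Range request starting at the current partial-file size with an integrity check afterwards; a hedged aiohttp sketch (hypothetical helper, not the exo code):

import os
import aiohttp

async def resume_download(session: aiohttp.ClientSession, url: str, dest: str) -> None:
    # Resume from however many bytes are already on disk.
    offset = os.path.getsize(dest) if os.path.exists(dest) else 0
    headers = {"Range": f"bytes={offset}-"} if offset else {}
    async with session.get(url, headers=headers) as resp:
        if resp.status == 416:  # requested range not satisfiable: file already complete
            return
        resp.raise_for_status()
        mode = "ab" if resp.status == 206 else "wb"  # 206 means the server honored the Range
        with open(dest, mode) as f:
            async for chunk in resp.content.iter_chunked(1 << 20):
                f.write(chunk)
    # A separate integrity check (e.g. comparing the file hash to the remote ETag) would run after this.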
Alex Cheema
7a75fb09b2 Merge pull request #660 from exo-explore/robustdownload
cleanup tmp files on failed download
2025-01-30 20:25:15 +00:00
Alex Cheema
0bebf8dfde fix indent 2025-01-30 20:21:28 +00:00
Alex Cheema
55c4385db5 cleanup tmp files on failed download 2025-01-30 20:11:06 +00:00
Alex Cheema
90690a7d10 Merge pull request #647 from deftdawg/patch-1
Add 4-bit to the end of DeepSeek V3/R1 model descriptions
2025-01-30 19:49:38 +00:00
Alex Cheema
130d998d36 Merge pull request #659 from exo-explore/robustdownload
ensure exo dir on start, retry with exp backoff on file downloads
2025-01-30 19:49:00 +00:00
Alex Cheema
788c49784c retry fetch_file_list also 2025-01-30 19:45:12 +00:00
Alex Cheema
6b1c8635fc ensure exo dir on start, retry with exp backoff on file downloads 2025-01-30 19:40:35 +00:00
Alex Cheema
24c410c19c Merge pull request #653 from exo-explore/tinyfixes
Tiny fixes
2025-01-29 19:08:05 +00:00
Alex Cheema
f6ed830ba6 Merge pull request #651 from exo-explore/parallelise_model_loadin
parallelise model loading
2025-01-29 19:07:25 +00:00
Alex Cheema
e6b4f2993c fix prompt output spacing in tui 2025-01-29 19:01:30 +00:00
DeftDawg
a25e02c913 Add 4-bit to the end of DeepSeek V3/R1 model descriptions 2025-01-29 14:00:13 -05:00
Alex Cheema
3675804f4d throttle repo progress events and only send them out if something changed 2025-01-29 18:55:54 +00:00
Alex Cheema
96f1aecb05 only in_progress if any given file is in_progress 2025-01-29 18:43:43 +00:00
Alex Cheema
23a5030604 even if part of a file is downloaded it may not be in_progress 2025-01-29 18:39:23 +00:00
Alex Cheema
31b56e862f make a singleton thread pool executor for tinygrad since we always want it to run on the same thread 2025-01-29 18:37:09 +00:00
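The singleton executor mentioned above keeps all tinygrad work on one thread while staying non-blocking for asyncio; roughly (a sketch with assumed names):

import asyncio
from concurrent.futures import ThreadPoolExecutor

# One shared single-thread executor so every tinygrad call runs on the same thread.
_tinygrad_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="tinygrad")

async def run_on_tinygrad_thread(fn, *args):
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(_tinygrad_executor, fn, *args)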
Alex Cheema
9f6c688d62 update tinygrad 2025-01-29 18:06:38 +00:00
Alex Cheema
4887be5103 parallelise model loading 2025-01-29 02:32:59 +00:00
Alex Cheema
75091e206b Merge pull request #650 from exo-explore/chatgpttimeout
increase chatgpt api response timeout to 900 seconds
2025-01-29 02:03:52 +00:00
Alex Cheema
141de0d011 increase chatgpt api response timeout to 900 seconds 2025-01-29 02:03:00 +00:00
Alex Cheema
263b18a31e Merge pull request #649 from eclecticc/amd_fix
Fix AMD device capabilities fields
2025-01-29 02:01:06 +00:00
Nirav Patel
9cf6818f10 Fix AMD device capabilities fields 2025-01-28 16:58:58 -08:00
Alex Cheema
837ed5d980 Merge pull request #648 from exo-explore/modelasyncload
Fixes
2025-01-28 23:39:11 +00:00
Alex Cheema
9c1bea97e8 fix embed_tokens for last layer in qwen models 2025-01-28 23:09:45 +00:00
Alex Cheema
af171f06fa propagate prompts to other nodes so they can display them, cleaner prompt/output output 2025-01-28 21:50:49 +00:00
Alex Cheema
edfa53a4c2 Merge pull request #646 from exo-explore/modelasyncload
make sure mlx stuff is on separate thread non blocking
2025-01-28 18:56:19 +00:00
Alex Cheema
4a5b80a958 make sure mlx stuff is on separate thread non blocking 2025-01-28 18:56:00 +00:00
Alex Cheema
92d1bc01de Merge pull request #645 from exo-explore/modelasyncload
load mlx model shard on mlx thread so it doesnt block
2025-01-28 18:49:47 +00:00
Alex Cheema
6662d5668c load mlx model shard on mlx thread so it doesnt block 2025-01-28 18:49:19 +00:00
Alex Cheema
a0d673fa3a Merge pull request #640 from exo-explore/simpledownload
Simple download
2025-01-27 19:38:11 +00:00
Alex Cheema
7c649085a1 fix eta/speed for resuming an existing download, using the session downloaded bytes 2025-01-27 19:23:18 +00:00
Alex Cheema
90e0e2761f ignore not_started progress updates 2025-01-27 06:05:59 +00:00
Alex Cheema
265586f7b4 set timeout on get too 2025-01-27 06:05:40 +00:00
Alex Cheema
4748bb7dc7 increase file download timeout to 30min 2025-01-27 05:49:17 +00:00
Alex Cheema
ae770db4f3 increase download chunks to 1MB 2025-01-27 05:37:50 +00:00
Alex Cheema
82f75d0ccf increase hf download http timeout 15 mins for large downloads 2025-01-27 05:20:30 +00:00
Alex Cheema
295f41c5cc increase bench job timeout to give enough time to download 2025-01-27 05:03:35 +00:00
Alex Cheema
19a27c5bfd HF_HOME -> EXO_HOME 2025-01-27 02:59:23 +00:00
Alex Cheema
d7ca9b7732 show each node id in the tinychat topology viz 2025-01-27 02:20:22 +00:00
Alex Cheema
b349e48b0d fix visual bug where frontend would show the full hf repo size, but in some cases that includes redundant files so we should use the model index in those cases too 2025-01-27 02:13:05 +00:00
Alex Cheema
21586063f6 use llama-3.2-1b in tinygrad test 2025-01-27 01:35:33 +00:00
Alex Cheema
277d63d860 special case when a model doesnt have a model index file, then use wildcard for allow_patterns 2025-01-27 01:26:15 +00:00
Alex Cheema
74379ef671 log download logs with DEBUG>=6 very verbose 2025-01-27 01:11:54 +00:00
Alex Cheema
3c7bd48aa3 get rid of some more hf bloat 2025-01-27 01:08:46 +00:00
Alex Cheema
1df023023e remove a lot of hf bloat 2025-01-27 01:06:47 +00:00
Alex Cheema
b89495f444 rewrite ShardDownloader, simplify significantly 2025-01-27 00:37:57 +00:00
Alex Cheema
903950f64e Merge pull request #638 from exo-explore/deepseekv3fix
add exception for mlx-community/DeepSeek-R1-3bit and mlx-community/DeepSeek-V3-3bit in tokenizers test
2025-01-26 20:33:22 +00:00
Alex Cheema
a3766f538a add exception for mlx-community/DeepSeek-R1-3bit and mlx-community/DeepSeek-V3-3bit in tokenizers test 2025-01-26 20:32:48 +00:00
Alex Cheema
9711d632e0 Merge pull request #637 from exo-explore/deepseekv3fix
fix post_init deepseek v3
2025-01-26 20:31:53 +00:00
Alex Cheema
82ef086010 add deepseek-v3-3bit and deepseek-r1-3bit 2025-01-26 20:31:28 +00:00
Alex Cheema
55ea366932 fix post_init deepseek v3 2025-01-26 20:27:31 +00:00
Alex Cheema
63318983de Merge pull request #631 from sigseg5/main
Some adaptivity fixes in tinychat
2025-01-26 19:20:58 +00:00
sigseg5
fb841a1f50 Adjust truncate size in history list for text without any spaces 2025-01-26 00:38:58 +03:00
sigseg5
4512366580 Fix bubble behavior when user passes long text without any spaces 2025-01-26 00:02:17 +03:00
sigseg5
9525c0e7a7 Add adaptive padding for user and assistant messages on width <= 1480px 2025-01-26 00:01:54 +03:00
Alex Cheema
66f73768cc Merge pull request #627 from exo-explore/deepseek
Deepseek, tinychat group models, latex formatting, thinking boxes
2025-01-24 18:14:57 +00:00
Alex Cheema
fdd05baddb fix tokenizer tests 2025-01-24 18:13:36 +00:00
Alex Cheema
59174bdc62 we have a lot of models so group them nicely 2025-01-24 18:02:00 +00:00
Alex Cheema
cfdaaef8e6 handle thinking outputs nicely, format latex beautifully 2025-01-24 17:49:25 +00:00
Alex Cheema
d8ffa59dba add deepseek v1, v3 and all the distills 2025-01-24 16:39:38 +00:00
Alex Cheema
aa1ce21f82 Merge pull request #625 from eltociear/patch-1
chore: update manual_discovery.py
2025-01-23 16:51:32 +00:00
Ikko Eltociear Ashimine
4fb01f516d chore: update manual_discovery.py
occured -> occurred
2025-01-24 00:18:42 +09:00
Alex Cheema
a635b23044 Merge pull request #619 from exo-explore/runners2
fix readme images
2025-01-23 02:18:33 +00:00
Alex Cheema
ad0e0d02d8 fix readme images 2025-01-23 02:17:58 +00:00
Alex Cheema
2644fd02c8 Merge pull request #617 from exo-explore/runners2
Lots of fixes and QoL improvements.
2025-01-23 02:05:17 +00:00
Alex Cheema
88ac12df6c install clang test 2025-01-23 01:55:14 +00:00
Alex Cheema
dfd9d3eb48 linux install 2025-01-23 01:44:57 +00:00
Alex Cheema
200ff4d713 linux install 2025-01-23 01:43:00 +00:00
Alex Cheema
b2764f177f linux install 2025-01-23 01:40:59 +00:00
Alex Cheema
e57fa1dfa0 xlarge 2025-01-23 01:40:13 +00:00
Alex Cheema
209163c595 add linux tinygrad test 2025-01-23 01:38:10 +00:00
Alex Cheema
495987b50b beef up the instance 2025-01-23 01:37:38 +00:00
Alex Cheema
8484eb4165 fix config 2025-01-23 01:37:01 +00:00
Alex Cheema
790c08afd4 add linux tinygrad test 2025-01-23 01:31:44 +00:00
Alex Cheema
a8a9e3ffa1 explicitly enable TOKENIZERS_PARALLELISM=true 2025-01-23 01:26:27 +00:00
Alex Cheema
5c9bcb8620 set GRPC_VERBOSITY=error; TRANSFORMERS_VERBOSITY=error 2025-01-23 01:22:19 +00:00
Alex Cheema
d54e19c20a runners back 2025-01-23 00:55:52 +00:00
Alex Cheema
cc78738e24 remove kern scan intervals 2025-01-23 00:49:32 +00:00
Alex Cheema
2391051c11 remove kern.timer.scan_interval from bootstrap.sh 2025-01-23 00:41:40 +00:00
Alex Cheema
112dea1582 add back the benchmarks baby 2025-01-23 00:15:54 +00:00
Alex Cheema
dc5cdc4d78 add back opaque 2025-01-22 23:59:39 +00:00
Alex Cheema
f8db4e131e fix check for sd2.1 2025-01-22 23:53:42 +00:00
Alex Cheema
bbb6856988 fix check for sd2.1 2025-01-22 23:51:09 +00:00
Alex Cheema
9ba8bbbcf8 fix filter to include 169.254.* since thats what mac uses for ethernet 2025-01-22 23:47:43 +00:00
Alex Cheema
8ab9977f01 fix stable diffusion case for tui, make mlx run on its own thread again and non-blocking 2025-01-22 23:22:53 +00:00
Alex Cheema
3a4bae0dab fix issue with eos_token_id 2025-01-22 22:58:09 +00:00
Alex Cheema
87d1271d33 fix stream: false completion 2025-01-22 22:46:04 +00:00
Alex Cheema
55d1846f5e clean up DEBUG=2 logs, a few fixes for token 2025-01-22 22:27:02 +00:00
Alex Cheema
9954ce8e4d fix treating token as a list 2025-01-22 22:13:13 +00:00
Alex Cheema
09e12d8673 temporarily disable github runner benchmarks 2025-01-22 22:00:13 +00:00
Alex Cheema
98d6e986bd add back .circleci 2025-01-22 21:58:46 +00:00
Alex Cheema
d80324fe20 disable test-m3-single-node 2025-01-22 21:58:40 +00:00
Alex Cheema
97f3bad38f fix peer_handle 2025-01-22 21:07:49 +00:00
Alex Cheema
461e4f37cb Merge remote-tracking branch 'origin/main' into runners2 2025-01-22 21:06:12 +00:00
Alex Cheema
07ceb19f0a Merge pull request #614 from samiamjidkhan/main
animation fix
2025-01-22 14:59:54 +00:00
Sami Khan
27b4577f38 directory for images 2025-01-22 05:47:25 -05:00
Sami Khan
a70943f8d2 base images for animation 2025-01-22 05:46:38 -05:00
Alex Cheema
410d901505 Merge pull request #613 from samiamjidkhan/dmg-backend
image and text mode fix
2025-01-21 13:12:08 +00:00
Sami Khan
5c4ce5392c image and text mode fix 2025-01-21 04:33:54 -05:00
Alex Cheema
819ec7626e Merge pull request #611 from exo-explore/fixbuildname
fix scripts/build_exo.py: com.exolabs.exo -> net.exolabs.exo
2025-01-21 05:36:34 +00:00
Alex Cheema
ba5bb3e171 fix scripts/build_exo.py: com.exolabs.exo -> net.exolabs.exo 2025-01-21 05:36:02 +00:00
Alex Cheema
f4bbcf4c8f Merge pull request #607 from tensorsofthewall/smol_fix
Fixes for cross-platform operability
2025-01-21 02:21:18 +00:00
Alex Cheema
6b8cd0577e fix some issues with results 2025-01-20 16:30:16 +00:00
Alex Cheema
218c1e79d9 Merge branch 'main' into runners2 2025-01-20 16:12:55 +00:00
Sandesh Bharadwaj
b9eccedc3d Formatting 2025-01-17 05:40:42 -05:00
Sandesh Bharadwaj
5f06aa2759 Replace netifaces (unmaintained,outdated) with scapy + add dependencies for previous fixes 2025-01-17 05:37:01 -05:00
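For reference, scapy can cover the interface-enumeration duty that netifaces served; a hedged sketch, not necessarily how exo queries interfaces:

from scapy.all import get_if_list, get_if_addr

def list_interface_addresses():
    # Map each interface name to its IPv4 address, skipping interfaces without one.
    addrs = {}
    for iface in get_if_list():
        try:
            ip = get_if_addr(iface)
        except Exception:
            continue
        if ip and ip != "0.0.0.0":
            addrs[iface] = ip
    return addrs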
Sandesh Bharadwaj
349b5344eb Minor fix for Shard typing 2025-01-16 14:36:46 -05:00
Sandesh Bharadwaj
df3624d27a Add AMD GPU querying + Windows device capabilities 2025-01-14 20:37:02 -05:00
Sandesh Bharadwaj
6737e36e23 Fixed MLX import blocking native Windows execution of exo. (Not Final) 2025-01-14 20:35:21 -05:00
Alex Cheema
c260689a06 Merge pull request #602 from exo-explore/fixexodir
fix exo folder
2025-01-12 03:46:14 +00:00
Alex Cheema
fcc699a55f fix 2025-01-12 03:40:59 +00:00
Alex Cheema
e7b98f5ae5 fix unit tests 2025-01-12 03:35:24 +00:00
Alex Cheema
ffe78f6d0b fix dummy test 2025-01-12 03:30:06 +00:00
Alex Cheema
ce5041ee1b types 2025-01-12 03:24:42 +00:00
Alex Cheema
9b2c01c873 ensure dir exists 2025-01-12 03:15:49 +00:00
Alex Cheema
2aed3f3518 handle inference_state properly 2025-01-12 03:13:17 +00:00
Alex Cheema
2af5ee02e4 fix exo folder 2025-01-12 03:10:11 +00:00
Alex Cheema
b5cbcbc7a2 Merge pull request #474 from pranav4501/stable-stable-diffusion-mlx
Stable diffusion mlx
2025-01-12 02:57:21 +00:00
Alex Cheema
5f3d000a7b Merge branch 'main' into stable-stable-diffusion-mlx 2025-01-12 02:56:34 +00:00
Alex Cheema
bd2e8e7a5a Merge pull request #598 from exo-explore/fixphitest
typo in phi test
2025-01-08 22:09:38 +00:00
Alex Cheema
40696b21f7 typo in phi test 2025-01-08 22:09:04 +00:00
Alex Cheema
4937fb3df8 Merge pull request #597 from exo-explore/tuioverflow
Tui overflow
2025-01-08 16:40:16 +00:00
Alex Cheema
2d631ea53d Merge pull request #596 from exo-explore/phi4
add phi 3.5, phi 4
2025-01-08 16:39:32 +00:00
Alex Cheema
2846a9122f tok tests 2025-01-08 16:39:11 +00:00
Alex Cheema
553ccce728 fix prompt and output overflow in tui 2025-01-08 16:36:56 +00:00
Alex Cheema
c587593364 add phi 3.5, phi 4 2025-01-08 16:19:43 +00:00
Alex Cheema
3c9efe103d Merge pull request #590 from metaspartan/fix-models-api
Fix the /v1/models API to output proper OpenAI compatible endpoint
2025-01-07 02:32:06 +00:00
Carsen Klock
627bfcae7c Fix the /v1/models API to output proper OpenAI compatible endpoint
Modify the `/v1/models` API to output a proper OpenAI compatible endpoint with an object and a `data` object containing the models list.

* Change the `handle_get_models` method in `exo/api/chatgpt_api.py` to wrap the models list in an object with a `data` field.
* Add an `object` field with the value "list" to the response format.

---

For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/metaspartan/exo?shareId=XXXX-XXXX-XXXX-XXXX).
2025-01-06 01:20:30 -07:00
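The change described above produces the standard OpenAI list shape: an object with object="list" and a data array of models. A small illustrative sketch; per-model fields beyond "id" and "object" are assumptions, not exo's exact output:

def build_models_response(model_ids):
    # OpenAI-compatible /v1/models payload: a "list" object wrapping a data array.
    return {
        "object": "list",
        "data": [{"id": m, "object": "model", "owned_by": "exo"} for m in model_ids],
    }

# e.g. build_models_response(["llama-3.2-1b", "qwen-2.5-3b"])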
Alex Cheema
d9a836f152 Merge pull request #588 from exo-explore/betterdl
better download
2025-01-05 02:35:04 +00:00
Alex Cheema
29244c6369 fix args for ensure_shard 2025-01-05 02:33:25 +00:00
Alex Cheema
8c191050a2 download status in parallel, support async ensure shard with using shard_downloader instead 2025-01-05 02:31:59 +00:00
Alex Cheema
7b1656140e Merge pull request #585 from pepebruari/main
Add --system-prompt to exo cli
2025-01-03 23:49:50 +00:00
pepebruari
fe50d4d34d Add --system-prompt to exo cli 2025-01-03 16:16:22 -05:00
Alex Cheema
03aa6cecf1 Merge pull request #584 from exo-explore/AlexCheema-patch-1
add trending badge to README.md
2024-12-31 17:51:10 +00:00
Alex Cheema
178cc4d961 add trending badge to README.md 2024-12-31 17:50:29 +00:00
Pranav Veldurthi
b13e368368 fix inference engine 2024-12-30 19:41:19 -05:00
Pranav Veldurthi
9986fb86d4 remove prints and fix download progress for SD 2024-12-30 19:07:37 -05:00
Pranav Veldurthi
3475be9e9e Remove build 2024-12-30 18:39:17 -05:00
Pranav Veldurthi
fff8a1a690 fix inference engine for inference state 2024-12-30 18:36:53 -05:00
Pranav Veldurthi
54605299b8 Merge Latest 2024-12-30 18:36:23 -05:00
Alex Cheema
a174c78004 Merge pull request #383 from ianpaul10/feat/manual-disc-follow-up
Support changing manual configuration while running
2024-12-28 11:57:25 +00:00
Ian Paul
b003292b89 formatting and fixing tests after rebasing 2024-12-28 12:31:15 +07:00
Ian Paul
1dfd058c23 rm unecessary lock 2024-12-28 12:13:34 +07:00
Ian Paul
2eadaa2c0d rm redundant cleanup task 2024-12-28 12:13:34 +07:00
Ian Paul
637446ffa9 rm redundant typing 2024-12-28 12:13:34 +07:00
Ian Paul
a31f9e6c20 fix test warnings 2024-12-28 12:13:34 +07:00
Ian Paul
18acb97b42 make popping from dict threadsafe 2024-12-28 12:11:51 +07:00
Ian Paul
b066c944f3 make all I/O ops in manual_discovery.py run inside a ThreadPoolExecutor 2024-12-28 12:11:51 +07:00
Ian Paul
0e34ce2169 patch after rebasing to main 2024-12-28 12:11:51 +07:00
Ian Paul
90de7eada9 changes after rebase 2024-12-28 12:11:51 +07:00
Ian Paul
8d24df2b4b fix test runtime warning 2024-12-28 12:11:50 +07:00
Ian Paul
e5eb3259a5 handle when a peer is removed from config, so the known_peers dict gets updated accordingly 2024-12-28 12:11:21 +07:00
Ian Paul
2e8227fccb handle intermediate state for when config is being updated 2024-12-28 12:11:21 +07:00
Ian Paul
98118babae allow update to manual discovery file
re-load manual discovery file for each runthrough of the peer network, allowing incremental updates to the peer file even when exo is running
2024-12-28 12:11:21 +07:00
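A minimal sketch of re-reading the peer config on every discovery pass so edits take effect while exo is running; the file layout and loop shape here are hypothetical:

import asyncio
import json

async def discovery_loop(path: str, interval: float = 5.0):
    known_peers: dict[str, dict] = {}
    while True:
        try:
            with open(path) as f:
                peers = json.load(f).get("peers", {})  # hypothetical {"peers": {id: {...}}} layout
        except (OSError, json.JSONDecodeError):
            peers = known_peers  # keep the last good config on a bad read
        # Connect to newly added peers, drop peers removed from the file.
        for peer_id in peers.keys() - known_peers.keys():
            print(f"adding peer {peer_id}")
        for peer_id in known_peers.keys() - peers.keys():
            print(f"removing peer {peer_id}")
        known_peers = peers
        await asyncio.sleep(interval)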
Alex Cheema
496a3b49f5 Merge pull request #561 from VerisimilitudeX/patch-1
Improved clarity, fixed typos, added macOS/Linux examples, and enhanc…
2024-12-27 17:06:00 +00:00
Alex Cheema
aba1bed5ed Merge pull request #575 from exo-explore/fixtok
Revert "Merge pull request #573 from damho1104/feature/add-exaone-3.5…
2024-12-27 16:36:34 +00:00
Alex Cheema
e08522ee97 Revert "Merge pull request #573 from damho1104/feature/add-exaone-3.5-model"
This reverts commit 4eb6a6a74a, reversing
changes made to fdc3b5ac02.
2024-12-27 16:35:54 +00:00
Alex Cheema
4eb6a6a74a Merge pull request #573 from damho1104/feature/add-exaone-3.5-model
Add exaone-3.5-2.4b, exaone-3.5-7.8b
2024-12-27 12:36:09 +00:00
damho.lee
94a5e908b0 add exaone-3.5 LLM Model 2024-12-24 20:57:11 +09:00
Alex Cheema
fdc3b5ac02 Merge pull request #571 from exo-explore/function_calling
add chatgpt-api-compatible tools for function calling
2024-12-24 02:08:48 +00:00
Alex Cheema
185b1e375c fix names in dummy tokenizer 2024-12-24 02:08:20 +00:00
Alex Cheema
078b807654 fix names of qwen models 2024-12-24 02:06:13 +00:00
Alex Cheema
188ac445c9 function calling example with weather tool 2024-12-24 01:57:17 +00:00
Alex Cheema
456fbdd2b0 add chatgpt-api-compatible tools for function calling 2024-12-24 01:51:55 +00:00
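The weather-tool example follows the ChatGPT-API tools format; a hedged request sketch against exo's chat completions endpoint (the port and model name are taken from elsewhere on this page, the tool definition is made up):

import json
import urllib.request

payload = {
    "model": "llama-3.2-1b",
    "messages": [{"role": "user", "content": "What is the weather in London?"}],
    "tools": [{
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }],
}
req = urllib.request.Request(
    "http://localhost:52415/v1/chat/completions",  # port as used in bench.py below
    data=json.dumps(payload).encode(),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.load(resp))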
Alex Cheema
41df9ce1d7 Merge pull request #570 from exo-explore/moreqwen
add qwen-2.5-1.5b, qwen-2.5-3b, qwen-2.5-32b
2024-12-24 01:51:26 +00:00
Alex Cheema
c609c05e40 add qwen-2.5-1.5b, qwen-2.5-3b, qwen-2.5-32b 2024-12-24 01:50:12 +00:00
Alex Cheema
ba8c514974 Merge pull request #569 from deftdawg/env_bash
Use `#!/usr/bin/env bash` for better portability
2024-12-22 23:25:38 +00:00
DeftDawg
cde912deef - Use #!/usr/bin/env bash instead of #!/bin/bash for better portability 2024-12-22 01:14:54 -05:00
Piyush Acharya
154e0f58e4 Implement suggestiond 2024-12-21 19:40:53 -08:00
Piyush Acharya
6c82365ee2 Improved clarity, fixed typos, added macOS/Linux examples, and enhanced installation/debugging instructions 2024-12-17 18:02:34 -08:00
Alex Cheema
023ddc207e support different network interface tests 2024-12-17 21:03:00 +00:00
Alex Cheema
2f0b543a1e add peer connection info to tinychat 2024-12-17 17:37:40 +00:00
Alex Cheema
7ac4004392 change it back to collecting topology periodically even if peers dont change 2024-12-17 17:32:18 +00:00
Alex Cheema
198308b1eb more robust udp broadcast 2024-12-17 17:28:55 +00:00
Alex Cheema
1f108a06ff remove test sleep 2024-12-17 16:47:05 +00:00
Alex Cheema
3a58576f8c make sure this is actually doing something 2024-12-17 16:22:22 +00:00
Alex Cheema
0a07223074 switch to uvloop (faster asyncio event loop) and optimise grpc settings 2024-12-17 16:10:56 +00:00
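Switching asyncio over to uvloop is a one-line change at startup; for reference:

import asyncio
import uvloop

# Use uvloop's faster event loop implementation for all asyncio code.
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())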
Alex Cheema
58f0a0f547 optimise grpc parameters 2024-12-17 14:50:52 +00:00
Pranav Veldurthi
5c0cd1839b Update strength image to image gen 2024-12-16 18:40:36 -05:00
Alex Cheema
e2474c3f15 fail if we never get the desired node count 2024-12-16 21:59:02 +00:00
Alex Cheema
1b14be6013 make device_capabilities async running on a thread pool 2024-12-16 21:17:30 +00:00
Alex Cheema
036224f877 add topology to tinychat ui 2024-12-16 21:17:12 +00:00
Alex Cheema
b17faa8199 dont broadcast every single process_tensor 2024-12-16 20:54:38 +00:00
Alex Cheema
35d90d947c Merge remote-tracking branch 'origin/main' into runners 2024-12-16 20:04:03 +00:00
Alex Cheema
8d94b8ae12 trigger test 2024-12-16 20:03:22 +00:00
Alex Cheema
99a70f1045 Merge commit: trigger test 2024-12-16 20:01:23 +00:00
Alex Cheema
bd0febe35f Merge commit: trigger test 2024-12-16 20:01:09 +00:00
Alex Cheema
34ecbbe01c Merge commit: trigger test 2024-12-16 20:00:50 +00:00
Alex Cheema
427d0718b3 Merge commit: trigger test 2024-12-16 20:00:39 +00:00
Alex Cheema
b49c4ca0e5 Merge commit: trigger test 2024-12-16 20:00:21 +00:00
Alex Cheema
41eaaec5a9 Merge commit: trigger test 2024-12-16 20:00:10 +00:00
Alex Cheema
bf1aafdea7 Merge commit: trigger test 2024-12-16 19:59:51 +00:00
Alex Cheema
bfa06ee9f3 Merge commit: trigger test 2024-12-16 19:59:39 +00:00
Alex Cheema
c0534b67c3 Merge commit: trigger test 2024-12-16 19:59:08 +00:00
Alex Cheema
063964aab3 remove redundant sample_logits, put back opaque status for process_prompt so we have a way of preemptively starting downloads 2024-12-16 19:50:36 +00:00
Alex Cheema
804ad4705a upgrade mlx 2024-12-16 19:50:33 +00:00
Alex Cheema
c9ded9ba96 optimise networking, remove bloat 2024-12-16 19:50:29 +00:00
Alex Cheema
64365d684f one two and three m4 pro clusters 2024-12-16 19:50:24 +00:00
Alex Cheema
9397464fad add commit to results 2024-12-16 19:50:19 +00:00
Nel Nibcord
08912d1b64 Only collect topology if peers changed 2024-12-16 19:50:18 +00:00
Alex Cheema
06c2e236b8 rip out stats bloat 2024-12-16 19:50:17 +00:00
Alex Cheema
cb4615c95d fix SendNewToken 2024-12-16 19:50:14 +00:00
Alex Cheema
f55a53ae7e one token at a time 2024-12-16 19:49:52 +00:00
Gary
25b4af70e0 Merge branch 'main' into runners 2024-12-14 20:48:58 +00:00
Alex Cheema
a93092105c set max-generate-tokens to 250 2024-12-14 19:10:03 +00:00
Alex Cheema
0c6ab35333 increase timeout of http request in bench.py up to 10 mins 2024-12-14 18:33:41 +00:00
Alex Cheema
e5d54c77a9 add llama-3.3-70b to 3 M4 Pro cluster 2024-12-12 18:51:26 +00:00
Alex Cheema
2ff4638122 Merge remote-tracking branch 'origin/main' into runners 2024-12-12 17:14:40 +00:00
Alex Cheema
b6f2385c41 run llama-3.1-8b on 3 m4 pro cluster 2024-12-12 15:13:10 +00:00
Alex Cheema
9472ab0d2c t 2024-12-12 15:05:55 +00:00
Alex Cheema
dbb7ad3c08 run with three m4 pro 2024-12-12 14:36:18 +00:00
Alex Cheema
2abe57be21 grasping at straws 2024-12-12 12:03:20 +00:00
Alex Cheema
eeecdcb409 try a different taskpolicy 2024-12-12 11:45:01 +00:00
Alex Cheema
f9f76129a1 better bench system info 2024-12-12 11:34:37 +00:00
Alex Cheema
8c6d37d9b8 m4 cluster test 2024-12-12 11:13:13 +00:00
Alex Cheema
1194db6e65 m3 2024-12-12 00:02:20 +00:00
Alex Cheema
8cb7327da2 re-enable m4 cluster run 2024-12-12 00:01:14 +00:00
Alex Cheema
bba0aa0877 single node test 20 2024-12-11 22:58:44 +00:00
Alex Cheema
279354a1fd single node test 19 2024-12-11 22:58:38 +00:00
Alex Cheema
92e2b74902 single node test 18 2024-12-11 22:58:33 +00:00
Alex Cheema
76196b8c2f single node test 17 2024-12-11 22:58:27 +00:00
Alex Cheema
8408c8499f single node test 16 2024-12-11 22:58:21 +00:00
Alex Cheema
c65d1d9141 single node test 15 2024-12-11 22:58:16 +00:00
Alex Cheema
0bd44c0f78 single node test 14 2024-12-11 22:58:10 +00:00
Alex Cheema
f22bc99f2c single node test 13 2024-12-11 22:58:04 +00:00
Alex Cheema
3fda05aa39 single node test 12 2024-12-11 22:57:58 +00:00
Alex Cheema
6c322ac070 single node test 11 2024-12-11 22:57:53 +00:00
Alex Cheema
c5c27a32af single node test 10 2024-12-11 22:57:47 +00:00
Alex Cheema
9f1393dc7f single node test 9 2024-12-11 22:57:42 +00:00
Alex Cheema
32ff3ef9af single node test 8 2024-12-11 22:57:36 +00:00
Alex Cheema
b23c3fdaad single node test 7 2024-12-11 22:57:31 +00:00
Alex Cheema
8b47a9d017 single node test 6 2024-12-11 22:57:25 +00:00
Alex Cheema
f89b85b3f2 single node test 5 2024-12-11 22:57:19 +00:00
Alex Cheema
6f097c9321 single node test 4 2024-12-11 22:57:14 +00:00
Alex Cheema
fb7a0defe1 single node test 3 2024-12-11 22:57:08 +00:00
Alex Cheema
fe506a53d9 single node test 2 2024-12-11 22:57:02 +00:00
Alex Cheema
3f6ef1c763 single node test 1 2024-12-11 22:56:56 +00:00
Alex Cheema
e63c224c71 testtt 2024-12-11 22:53:02 +00:00
Alex Cheema
20e3065e57 les goh 2024-12-11 22:49:29 +00:00
Alex Cheema
83892d5b7e t 2024-12-11 22:45:59 +00:00
Alex Cheema
83470a98b4 t 2024-12-11 22:42:02 +00:00
Alex Cheema
92edfa5efc t 2024-12-11 22:40:47 +00:00
Alex Cheema
225dcba788 t 2024-12-11 22:37:11 +00:00
Alex Cheema
6249bee793 tes 2024-12-11 22:35:30 +00:00
Alex Cheema
741c31836e test 2024-12-11 22:27:10 +00:00
Alex Cheema
d0b7f1b4bb t 2024-12-11 22:11:01 +00:00
Alex Cheema
90677415c7 t 2024-12-11 22:01:29 +00:00
Alex Cheema
6cf2af39e8 t 2024-12-11 21:55:24 +00:00
Alex Cheema
5a1a0f5fd2 t 2024-12-11 21:45:53 +00:00
Alex Cheema
dd3fd279dc t 2024-12-11 21:42:01 +00:00
Alex Cheema
61c09631c0 t 2024-12-11 21:40:47 +00:00
Alex Cheema
e698ef6ab1 t 2024-12-11 21:39:27 +00:00
Alex Cheema
26351e719d t 2024-12-11 21:36:59 +00:00
Alex Cheema
5dee5e55fe t 2024-12-11 21:33:03 +00:00
Alex Cheema
6acfb81860 t 2024-12-11 20:28:07 +00:00
Alex Cheema
b1142d4ff4 t 2024-12-11 19:39:58 +00:00
Alex Cheema
a932afc01c oi 2024-12-11 19:30:28 +00:00
Alex Cheema
cdae702673 t 2024-12-11 19:24:43 +00:00
Alex Cheema
d95f40b6c8 a 2024-12-11 19:07:36 +00:00
Alex Cheema
97ffb83e86 t 2024-12-11 19:01:24 +00:00
Alex Cheema
9a11e27c93 ttt 2024-12-11 18:54:51 +00:00
Alex Cheema
d6c2146dd9 t 2024-12-11 18:34:35 +00:00
Alex Cheema
63da9fc194 a 2024-12-11 18:30:02 +00:00
Alex Cheema
7c0c5ef7fc ttttttt 2024-12-11 18:23:59 +00:00
Alex Cheema
739b7d178e tttttt 2024-12-11 18:02:22 +00:00
Alex Cheema
cacf50cd57 tttt 2024-12-11 18:00:28 +00:00
Alex Cheema
0904cda3ac ttt 2024-12-11 17:58:59 +00:00
Alex Cheema
6bb38939ec tt 2024-12-11 17:56:22 +00:00
Alex Cheema
1dbe11caf9 t 2024-12-11 17:54:41 +00:00
Alex Cheema
8d9e3b88d3 t 2024-12-11 17:52:07 +00:00
Alex Cheema
9dd33d37f2 t 2024-12-11 17:44:14 +00:00
Alex Cheema
a4bb4bb6ac update bootstrap 2024-12-11 17:37:38 +00:00
Alex Cheema
7b99cb4a12 t 2024-12-11 17:30:50 +00:00
Alex Cheema
9848a45da5 TT 2024-12-11 17:27:53 +00:00
Alex Cheema
378975813c t 2024-12-11 17:15:39 +00:00
Alex Cheema
e680e8a1ed fix name 2024-12-11 17:07:45 +00:00
Alex Cheema
7b2282d300 run without debug flag 2024-12-11 17:07:19 +00:00
Alex Cheema
3b1ea1933b use .venv exo 2024-12-11 17:02:58 +00:00
Alex Cheema
668766fc4b t 2024-12-11 16:55:57 +00:00
Alex Cheema
e501eeaf91 tweak install 2024-12-11 16:52:07 +00:00
Alex Cheema
41902f716f tweaks 2024-12-11 16:40:21 +00:00
Alex Cheema
b7bab80ec8 test2 2024-12-11 16:36:50 +00:00
Alex Cheema
6169996c70 test 2024-12-11 16:35:26 +00:00
Alex Cheema
bbb58460f8 Test on m4 2024-12-11 16:29:52 +00:00
Alex Cheema
cff03fc6c5 perf diag 2024-12-11 16:19:47 +00:00
Alex Cheema
f7122d400d add system_status check to bench 2024-12-11 16:13:53 +00:00
Alex Cheema
c938efb531 t 2024-12-11 16:06:14 +00:00
Alex Cheema
e2d3a90832 runner-token typo 2024-12-11 15:47:10 +00:00
Alex Cheema
ba96413a63 bootstrap script tweaks 2024-12-11 15:45:05 +00:00
Alex Cheema
cb40eb23ce more robust configure_mlx.sh 2024-12-11 15:38:45 +00:00
Alex Cheema
afe71c01da check gpu usage 2024-12-11 15:28:57 +00:00
Alex Cheema
a84cba4e3a Merge remote-tracking branch 'origin/main' into runners 2024-12-11 15:22:35 +00:00
Alex Cheema
23158a42ad add branch name to results 2024-12-11 12:59:55 +00:00
Alex Cheema
18e7919971 test 30 2024-12-11 12:55:05 +00:00
Alex Cheema
0e32a625d7 test 29 2024-12-11 12:54:59 +00:00
Alex Cheema
04bc163fea test 28 2024-12-11 12:54:52 +00:00
Alex Cheema
949055dec0 test 27 2024-12-11 12:54:45 +00:00
Alex Cheema
070b163cc7 test 26 2024-12-11 12:54:38 +00:00
Alex Cheema
fc26ad4006 test 25 2024-12-11 12:54:27 +00:00
Alex Cheema
5d3be3c6ed test 24 2024-12-11 12:54:20 +00:00
Alex Cheema
23dd5de3ae test 23 2024-12-11 12:54:14 +00:00
Alex Cheema
6030b39964 test 22 2024-12-11 12:54:08 +00:00
Alex Cheema
4f4ac0fa52 test 21 2024-12-11 12:54:01 +00:00
Alex Cheema
16d9839071 test {i} 2024-12-11 12:53:55 +00:00
Alex Cheema
8269b4b190 t 2024-12-11 12:38:51 +00:00
Alex Cheema
1e869a0f15 trigger test 2024-12-10 02:04:52 +00:00
Alex Cheema
5a4d128db6 trigger test 2024-12-09 08:02:29 +00:00
Alex Cheema
8a5d212cfc test 20 2024-12-08 23:38:30 +00:00
Alex Cheema
53edb8508b test 19 2024-12-08 23:38:24 +00:00
Alex Cheema
29d9df04bf test 18 2024-12-08 23:38:18 +00:00
Alex Cheema
4d6af6e6ca test 17 2024-12-08 23:38:13 +00:00
Alex Cheema
8c7c156f57 test 16 2024-12-08 23:38:07 +00:00
Alex Cheema
310843487f test 15 2024-12-08 23:38:01 +00:00
Alex Cheema
a4b221d0a0 test 14 2024-12-08 23:37:55 +00:00
Alex Cheema
286db875de test 13 2024-12-08 23:37:49 +00:00
Alex Cheema
d714e40f62 test 12 2024-12-08 23:37:43 +00:00
Alex Cheema
e78ef75531 test 11 2024-12-08 23:37:37 +00:00
Alex Cheema
38eaecf087 test 10 2024-12-08 23:37:31 +00:00
Alex Cheema
3cf28f8452 test 9 2024-12-08 23:37:26 +00:00
Alex Cheema
9ba8bbdd70 test 8 2024-12-08 23:37:20 +00:00
Alex Cheema
af6048e373 test 7 2024-12-08 23:37:14 +00:00
Alex Cheema
d93b8e8948 test 6 2024-12-08 23:37:08 +00:00
Alex Cheema
b69cb49a46 test 5 2024-12-08 23:37:02 +00:00
Alex Cheema
cc74b1f9b3 test 4 2024-12-08 23:36:57 +00:00
Alex Cheema
e78a52de5f test 3 2024-12-08 23:36:51 +00:00
Alex Cheema
f6c2c37c4b test 2 2024-12-08 23:36:45 +00:00
Alex Cheema
314a5d9781 test 1 2024-12-08 23:36:22 +00:00
Alex Cheema
b4e885bbd2 test range 2024-12-08 23:36:14 +00:00
Alex Cheema
bd9d11861b sleep before bench 2024-12-08 23:24:46 +00:00
Alex Cheema
571b26c50e allowed interface types 2024-12-08 23:20:08 +00:00
Glen
b21681931d remove 2024-12-08 23:13:10 +00:00
Alex Cheema
f584e86d8e get rid of lfs stuff 2024-12-08 22:55:19 +00:00
Alex Cheema
fd05bca1c8 lfs 2024-12-08 22:46:49 +00:00
Alex Cheema
cbac4d6a3e git version 2024-12-08 22:44:32 +00:00
Alex Cheema
b0977f97ab t 2024-12-08 22:43:23 +00:00
Glen
1716f637f7 test 2024-12-08 22:32:03 +00:00
Glen
903a5aabf7 fix 2024-12-08 22:26:44 +00:00
Glen
b4f86496ea bootstrap 2024-12-08 22:23:28 +00:00
Alex Cheema
8e57f3385c trigger test 2024-12-08 22:14:23 +00:00
Alex Cheema
3ccbdf19de add DEBUG_DISCOVERY 2024-12-08 22:07:48 +00:00
Alex Cheema
3687ba18df bench logs 2024-12-08 22:02:39 +00:00
Alex Cheema
6bb7c11bbb enable debug 2024-12-08 21:54:24 +00:00
Glen
c8f93721c5 model matrix 2024-12-08 21:14:36 +00:00
Alex Cheema
fb8d87025f t 2024-12-08 21:02:42 +00:00
Alex Cheema
87865f0cd9 list exo processes before test, warmup req in bench 2024-12-08 20:58:44 +00:00
Glen
755dd477dd jobname 2024-12-08 20:37:50 +00:00
Alex Cheema
fb44eb086c simplify bench 2024-12-08 20:30:07 +00:00
Alex Cheema
be8cbc0f56 trigger test 2024-12-08 19:28:55 +00:00
Glen
fe8074929f fix 2024-12-08 19:08:47 +00:00
Glen
c3c80c61c9 name 2024-12-08 19:02:53 +00:00
Glen
c138de0875 job_name 2024-12-08 18:56:37 +00:00
Glen
38bd00390c fix 2024-12-08 18:32:38 +00:00
Glen
732ba915aa new_conf 2024-12-08 18:32:06 +00:00
Glen
785710355f aws 2024-12-07 19:28:54 +00:00
Glen
320892dccc maxtok 2024-12-07 19:28:54 +00:00
Glen
6dae3a4719 conf 2024-12-07 19:28:54 +00:00
Glen
7b77ef000e flush 2024-12-07 19:28:54 +00:00
Glen
6c08b32350 nodebug 2024-12-07 19:28:54 +00:00
Glen
4dd617ad37 shorter 2024-12-07 19:28:54 +00:00
Glen
acdee16aee debug 2024-12-07 19:28:54 +00:00
Glen
9fc33587da path 2024-12-07 19:28:54 +00:00
Glen
f087c0ac99 fix 2024-12-07 19:28:54 +00:00
Glen
16b126d890 fix 2024-12-07 19:28:54 +00:00
Glen
faf0aaedba jq 2024-12-07 19:28:54 +00:00
Glen
4cac1bb151 quotes 2024-12-07 19:28:54 +00:00
Glen
cb3c1477bb fix 2024-12-07 19:28:54 +00:00
Glen
19a7d5a5cf fix 2024-12-07 19:28:54 +00:00
Glen
f7e0348f62 activate 2024-12-07 19:28:54 +00:00
Glen
c3dfac60a6 debug 2024-12-07 19:28:54 +00:00
Glen
64954aacfe fixed 2024-12-07 19:28:54 +00:00
Glen
ccc5415cc6 try 2024-12-07 19:28:54 +00:00
Glen
1dcc731b43 fix 2024-12-07 19:28:54 +00:00
Glen
3662ec402a fix 2024-12-07 19:28:54 +00:00
Glen
0739dc9564 fix 2024-12-07 19:28:54 +00:00
Glen
d16280ddfc debug 2024-12-07 19:28:54 +00:00
Glen
f9c23617a7 fix3 2024-12-07 19:28:54 +00:00
Glen
ce2ccddc93 fix2 2024-12-07 19:28:54 +00:00
Glen
1af28cb5a1 fix 2024-12-07 19:28:54 +00:00
Glen
6b61fc6660 tweak python install 2024-12-07 19:28:54 +00:00
Glen
bdf417f25e tweak 2024-12-07 19:28:54 +00:00
Glen
d154d37ac4 add exo run 2024-12-07 19:28:54 +00:00
Glen
90fd5c13a4 matrix 2024-12-07 19:28:54 +00:00
Glen
7d223a0095 matrix 2024-12-07 19:28:54 +00:00
Glen
cb3d89eb48 test runner 2024-12-07 19:28:54 +00:00
Glen
8302fd0aae test runner 2024-12-07 19:28:54 +00:00
Alex Cheema
deb80d2577 clang for tinygrad 2024-12-07 19:28:54 +00:00
Alex Cheema
976e5f2fdb disable mlx test for now..plan to run this on a self-hosted runner 2024-12-07 19:28:54 +00:00
Alex Cheema
9dc76ef03b tooonygrad 2024-12-07 19:28:54 +00:00
Alex Cheema
32cd1f1d72 give this a goh 2024-12-07 19:28:54 +00:00
Alex Cheema
6b54188140 cond 2024-12-07 19:28:54 +00:00
Alex Cheema
58bcf5b429 check discovery on integration tests too 2024-12-07 19:28:54 +00:00
Alex Cheema
3c0297c3e9 more robust discovery log check 2024-12-07 19:28:54 +00:00
Alex Cheema
8d433e6579 run tinygrad and discovery integratrion tests on linux 2024-12-07 19:28:54 +00:00
Alex Cheema
676125bfe6 job 2024-12-07 19:28:54 +00:00
Alex Cheema
902e0d35e1 github env vars 2024-12-07 19:28:54 +00:00
Alex Cheema
972aea446c macos 15 2024-12-07 19:28:53 +00:00
Alex Cheema
0d0338f871 migrate from circleci to github actions 2024-12-07 19:28:53 +00:00
Pranav Veldurthi
0f10244900 Merge latest 2024-12-04 22:52:48 -05:00
Pranav Veldurthi
686e139508 Merge Latest 2024-12-04 22:52:25 -05:00
Pranav Veldurthi
ca0caad0ae Image to image generation 2024-12-04 22:40:12 -05:00
Alex Cheema
f94c9067e2 trigger test 2024-12-04 03:09:12 +00:00
Alex Cheema
f0bb515d1d trigger test 2024-12-02 11:20:21 +00:00
Alex Cheema
71db641fe4 trigger test 2024-12-02 04:11:43 +00:00
Pranav Veldurthi
4b8c4a795f Images stored in system 2024-12-01 19:31:51 -05:00
Alex Cheema
f339f74fe3 trigger test 2024-12-01 17:39:53 +00:00
Alex Cheema
7dc0a7467b trigger test 2024-12-01 14:31:23 +00:00
Pranav Veldurthi
497756f7c8 merge latest main 2024-11-25 17:50:33 -05:00
Pranav Veldurthi
4874295b34 Image streaming while generation 2024-11-20 18:08:54 -05:00
Alex Cheema
fece3f0cef gitignore tinychat pngs 2024-11-20 10:01:06 +04:00
Alex Cheema
38ee815107 static images dir 2024-11-20 09:55:36 +04:00
Pranav Veldurthi
3d5746f16f Merge 2024-11-19 23:17:21 -05:00
Pranav Veldurthi
6b28ef0349 Stable stable diffusion mlx 2024-11-19 23:13:22 -05:00
79 changed files with 6059 additions and 1788 deletions

.circleci/config.yml

@@ -27,7 +27,7 @@ commands:
fi
# Start first instance
HF_HOME="$(pwd)/.hf_cache_node1" DEBUG_DISCOVERY=7 DEBUG=7 exo --inference-engine <<parameters.inference_engine>> \
EXO_HOME="$(pwd)/.exo_cache_node1" DEBUG_DISCOVERY=7 DEBUG=7 exo --inference-engine <<parameters.inference_engine>> \
--node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 8000 \
--chatgpt-api-response-timeout 900 --disable-tui > output1.log &
PID1=$!
@@ -35,7 +35,7 @@ commands:
TAIL1=$!
# Start second instance
HF_HOME="$(pwd)/.hf_cache_node2" DEBUG_DISCOVERY=7 DEBUG=7 exo --inference-engine <<parameters.inference_engine>> \
EXO_HOME="$(pwd)/.exo_cache_node2" DEBUG_DISCOVERY=7 DEBUG=7 exo --inference-engine <<parameters.inference_engine>> \
--node-id "node2" --listen-port 5679 --broadcast-port 5678 --chatgpt-api-port 8001 \
--chatgpt-api-response-timeout 900 --disable-tui > output2.log &
PID2=$!
@@ -254,6 +254,35 @@ jobs:
prompt: "Keep responses concise. Who was the king of pop?"
expected_output: "Michael Jackson"
chatgpt_api_integration_test_tinygrad_linux:
machine:
image: ubuntu-2204:current
resource_class: xlarge
steps:
- checkout
- run:
name: Set up Python
command: |
export DEBIAN_FRONTEND=noninteractive
export DEBCONF_NONINTERACTIVE_SEEN=true
sudo apt-get update
sudo add-apt-repository -y ppa:deadsnakes/ppa
sudo apt-get update
sudo apt-get install -y python3.12 python3.12-venv clang
python3.12 -m venv env
source env/bin/activate
- run:
name: Install dependencies
command: |
source env/bin/activate
pip install --upgrade pip
pip install .
- run_chatgpt_api_test:
inference_engine: tinygrad
model_id: llama-3.2-1b
prompt: "Keep responses concise. Who was the king of pop?"
expected_output: "Michael Jackson"
measure_pip_sizes:
macos:
xcode: "16.0.0"
@@ -342,5 +371,6 @@ workflows:
- discovery_integration_test
- chatgpt_api_integration_test_mlx
- chatgpt_api_integration_test_tinygrad
- chatgpt_api_integration_test_tinygrad_linux
- chatgpt_api_integration_test_dummy
- measure_pip_sizes

.github/bench.py (new file, 401 lines added)

@@ -0,0 +1,401 @@
import aiohttp
import asyncio
import time
import json
import os
import boto3
from typing import Dict, Any
from datetime import datetime
import subprocess
import psutil
import platform
from pathlib import Path
def check_system_state():
print("\n=== System State Check ===", flush=True)
# Add macOS-specific checks
try:
# Check powermetrics with sudo
try:
power_metrics = subprocess.run(
['sudo', 'powermetrics', '-n', '1', '-i', '1000', '--samplers', 'cpu_power'],
capture_output=True, text=True
)
print("\nPower Metrics:", power_metrics.stdout, flush=True)
except Exception as e:
print(f"Error getting power metrics: {e}", flush=True)
# Check thermal state
thermal_state = subprocess.run(['pmset', '-g', 'therm'], capture_output=True, text=True)
print("\nThermal State:", thermal_state.stdout, flush=True)
# Check if running under Rosetta
arch = subprocess.run(['arch'], capture_output=True, text=True)
print("\nArchitecture:", arch.stdout, flush=True)
# Check MLX compilation mode - only if mlx is available
try:
import mlx.core as mx
if hasattr(mx, 'build_info'):
print("\nMLX Build Info:", mx.build_info(), flush=True)
else:
print("\nMLX Build Info: Not available in this version", flush=True)
except ImportError:
print("\nMLX: Not installed", flush=True)
except Exception as e:
print(f"\nError checking MLX: {e}", flush=True)
except Exception as e:
print(f"Error in macOS checks: {e}", flush=True)
# CPU Info
print("\nCPU Information:", flush=True)
try:
if platform.system() == 'Darwin' and platform.processor() == 'arm':
# Use sysctl for Apple Silicon Macs
cpu_info = subprocess.run(['sysctl', 'machdep.cpu'], capture_output=True, text=True)
if cpu_info.returncode == 0:
print(f"CPU Info (Apple Silicon):", cpu_info.stdout, flush=True)
# Parse powermetrics output for clearer CPU frequency display
try:
power_metrics = subprocess.run(
['sudo', 'powermetrics', '-n', '1', '-i', '100', '--samplers', 'cpu_power'],
capture_output=True, text=True
)
if power_metrics.returncode == 0:
output = power_metrics.stdout
print("\nDetailed CPU Frequency Information:")
# Extract cluster frequencies and max frequencies
current_cluster = None
max_freqs = {'E': 0, 'P0': 0, 'P1': 0}
for line in output.split('\n'):
# Track which cluster we're processing
if "E-Cluster" in line:
current_cluster = 'E'
elif "P0-Cluster" in line:
current_cluster = 'P0'
elif "P1-Cluster" in line:
current_cluster = 'P1'
# Get current frequencies
if "HW active frequency:" in line:
freq = line.split(':')[1].strip()
if freq != "0 MHz":
print(f"Current {current_cluster}-Cluster Frequency: {freq}")
# Get max frequencies from residency lines
if current_cluster and "active residency:" in line and "MHz:" in line:
try:
# Extract all frequency values
freqs = []
parts = line.split('MHz:')[:-1] # Skip last part as it's not a frequency
for part in parts:
freq_str = part.split()[-1]
try:
freq = float(freq_str)
freqs.append(freq)
except ValueError:
continue
if freqs:
max_freqs[current_cluster] = max(max_freqs[current_cluster], max(freqs))
except Exception:
continue
# Print max frequencies
print("\nMaximum Available Frequencies:")
for cluster, max_freq in max_freqs.items():
if max_freq > 0:
print(f"{cluster}-Cluster Max: {max_freq:.0f} MHz")
except Exception as e:
print(f"Error parsing powermetrics: {e}", flush=True)
else:
# Use psutil for other systems
cpu_freq = psutil.cpu_freq()
print(f"CPU Frequency - Current: {cpu_freq.current:.2f}MHz, Min: {cpu_freq.min:.2f}MHz, Max: {cpu_freq.max:.2f}MHz", flush=True)
print(f"\nCPU Usage per Core: {psutil.cpu_percent(percpu=True)}%", flush=True)
# Check if running in low power mode
power_mode = subprocess.run(['pmset', '-g'], capture_output=True, text=True)
print("\nPower Settings:", power_mode.stdout, flush=True)
except Exception as e:
print(f"Error getting CPU info: {e}", flush=True)
# Memory Info
print("\nMemory Information:", flush=True)
try:
mem = psutil.virtual_memory()
print(f"Total: {mem.total/1024/1024/1024:.2f}GB", flush=True)
print(f"Available: {mem.available/1024/1024/1024:.2f}GB", flush=True)
print(f"Used: {mem.used/1024/1024/1024:.2f}GB ({mem.percent}%)", flush=True)
# Check swap
swap = psutil.swap_memory()
print(f"Swap Used: {swap.used/1024/1024/1024:.2f}GB of {swap.total/1024/1024/1024:.2f}GB", flush=True)
except Exception as e:
print(f"Error getting memory info: {e}", flush=True)
# GPU Info
print("\nGPU Information:", flush=True)
try:
# Check MLX GPU settings
print("MLX Environment Variables:", flush=True)
mlx_vars = {k: v for k, v in os.environ.items() if k.startswith('MLX')}
print(json.dumps(mlx_vars, indent=2), flush=True)
# Check Metal GPU memory allocation
gpu_mem = subprocess.run(['sysctl', 'iogpu'], capture_output=True, text=True)
print("GPU Memory Settings:", gpu_mem.stdout, flush=True)
except Exception as e:
print(f"Error getting GPU info: {e}", flush=True)
# Process Priority
print("\nProcess Priority Information:", flush=True)
try:
current_process = psutil.Process()
print(f"Process Nice Value: {current_process.nice()}", flush=True)
# Only try to get ionice if the platform supports it
if hasattr(current_process, 'ionice'):
print(f"Process IO Nice Value: {current_process.ionice()}", flush=True)
except Exception as e:
print(f"Error getting process priority info: {e}", flush=True)
# System Load
print("\nSystem Load:", flush=True)
try:
load_avg = psutil.getloadavg()
print(f"Load Average: {load_avg}", flush=True)
# Get top processes by CPU and Memory
print("\nTop Processes by CPU Usage:", flush=True)
processes = []
for proc in psutil.process_iter(['pid', 'name', 'cpu_percent', 'memory_percent']):
try:
pinfo = proc.info
if pinfo['cpu_percent'] is not None and pinfo['memory_percent'] is not None:
processes.append(pinfo)
except (psutil.NoSuchProcess, psutil.AccessDenied):
continue
# Sort and display top 5 CPU-consuming processes
sorted_by_cpu = sorted(processes, key=lambda x: x['cpu_percent'] or 0, reverse=True)[:5]
for proc in sorted_by_cpu:
print(f"PID: {proc['pid']}, Name: {proc['name']}, CPU: {proc['cpu_percent']}%, Memory: {proc['memory_percent']:.1f}%")
except Exception as e:
print(f"Error getting system load info: {e}", flush=True)
print("\n=== End System State Check ===\n", flush=True)
def check_gpu_access():
try:
# Check if MLX can see the GPU
import mlx.core as mx
print("MLX device info:", mx.default_device())
# Check Metal device availability
result = subprocess.run(['system_profiler', 'SPDisplaysDataType'], capture_output=True, text=True)
print("GPU Info:", result.stdout)
except Exception as e:
print(f"Failed to check GPU access: {e}")
async def measure_performance(api_endpoint: str, prompt: str, model: str) -> Dict[str, Any]:
"""
Measures the performance of an API endpoint by sending a prompt and recording metrics.
Args:
api_endpoint (str): The API endpoint URL.
prompt (str): The prompt to send to the API.
Returns:
Dict[str, Any]: A dictionary containing performance metrics or error information.
"""
results = {
'model': model,
'run_id': os.environ.get('GITHUB_RUN_ID', 'unknown'),
'branch': os.environ.get('GITHUB_REF_NAME', 'unknown'),
'commit': os.environ.get('GITHUB_SHA', 'unknown'),
'configuration': json.loads(os.environ.get('HARDWARE_CONFIG', '{}'))
}
# Get token count
session = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=600, connect=10, sock_read=600, sock_connect=10))
try:
response = await session.post(
"http://localhost:52415/v1/chat/token/encode",
json={
"model": model,
"messages": [{"role": "user", "content": prompt}]
}
)
response.raise_for_status()
token_data = await response.json()
results['prompt_len'] = token_data['num_tokens']
except Exception as e:
await session.close()
raise RuntimeError(f"Failed to get token count: {str(e)}")
# Measure completion performance
try:
start_time = time.time()
response = await session.post(
api_endpoint,
json={
"model": model,
"messages": [{"role": "user", "content": prompt}],
"temperature": 0,
"stream": True
}
)
response.raise_for_status()
first_token_time = None
total_tokens = 0
async for line in response.content.iter_chunks():
line = line[0].decode('utf-8').strip()
if not line.startswith('data: '):
continue
data = json.loads(line[6:]) # Skip 'data: ' prefix
if content := data.get('choices', [{}])[0].get('delta', {}).get('content'):
print(f"Received content: {content}", flush=True)
if first_token_time is None:
first_token_time = time.time()
ttft = first_token_time - start_time
results.update({
'ttft': ttft,
'prompt_tps': results['prompt_len'] / ttft
})
total_tokens += 1
total_time = time.time() - start_time
results.update({
'generation_tps': total_tokens / total_time,
'response_len': total_tokens,
'total_time': total_time
})
except Exception as e:
raise RuntimeError(f"Performance measurement failed: {str(e)}")
finally:
await session.close()
return results
async def main() -> None:
api_endpoint = "http://localhost:52415/v1/chat/completions"
# Define prompts
prompt_warmup = "what is the capital of France?"
prompt_essay = "write an essay about cats"
model = os.environ.get('model', 'llama-3.2-1b')
# Warmup request
print("\nPerforming warmup request...", flush=True)
try:
warmup_results = await measure_performance(api_endpoint, prompt_warmup, model)
print("Warmup completed successfully", flush=True)
except Exception as e:
print(f"Warmup request failed: {e}", flush=True)
# Measure performance for the essay prompt
print("\nMeasuring performance for the essay prompt...", flush=True)
results = await measure_performance(api_endpoint, prompt_essay, model)
try:
s3_client = boto3.client(
's3',
aws_access_key_id=os.environ.get('aws_access_key_id'),
aws_secret_access_key=os.environ.get('aws_secret_key')
)
job_name = os.environ.get('GITHUB_JOB')
# Create S3 key with timestamp and commit info
now = datetime.utcnow()
timestamp = now.strftime('%H-%M-%S')
commit_sha = os.environ.get('GITHUB_SHA', 'unknown')[:7]
s3_key = f"{job_name}/{model}/{now.year}/{now.month}/{now.day}/{timestamp}_{commit_sha}.json"
# Upload to S3
s3_client.put_object(
Bucket='exo-benchmarks',
Key=s3_key,
Body=json.dumps(results),
ContentType='application/json'
)
print(f"Performance metrics uploaded to S3: s3://exo-benchmarks/{s3_key}", flush=True)
except Exception as e:
print(f"Failed to upload metrics to S3: {e}", flush=True)
# Optionally print the metrics for visibility
print("Performance metrics:", flush=True)
print(json.dumps(results, indent=4), flush=True)
def optimize_system_performance():
"""Set optimal system performance settings before running benchmark."""
try:
# Try to set high performance power mode
subprocess.run(['sudo', 'pmset', '-a', 'powermode', '2'], check=False)
# Ensure MLX uses performance cores and GPU
os.environ['MLX_FORCE_P_CORES'] = '1'
os.environ['MLX_METAL_PREWARM'] = '1'
os.environ['MLX_USE_GPU'] = '1'
# Set process priority
current_process = psutil.Process()
try:
# Set highest priority
subprocess.run(['sudo', 'renice', '-n', '-20', '-p', str(current_process.pid)], check=False)
# Print current process state
print("\nProcess State Before Benchmark:", flush=True)
proc_info = subprocess.run(
['ps', '-o', 'pid,ppid,user,%cpu,%mem,nice,stat,pri,command', '-p', str(current_process.pid)],
capture_output=True, text=True
)
print(proc_info.stdout, flush=True)
# Verify power mode
power_info = subprocess.run(['pmset', '-g'], capture_output=True, text=True)
if 'powermode 0' in power_info.stdout:
print("\nWarning: System still in normal power mode. Trying to set high performance mode again...", flush=True)
subprocess.run(['sudo', 'pmset', '-a', 'powermode', '2'], check=False)
except Exception as e:
print(f"Warning: Could not set process priority: {e}", flush=True)
except Exception as e:
print(f"Warning: Could not optimize system performance: {e}", flush=True)
# Print optimization status
print("\nOptimization Settings:", flush=True)
print("MLX Environment Variables:", flush=True)
for var in ['MLX_FORCE_P_CORES', 'MLX_METAL_PREWARM', 'MLX_USE_GPU']:
print(f"{var}: {os.environ.get(var, 'Not set')}", flush=True)
try:
nice_value = psutil.Process().nice()
print(f"Process Nice Value: {nice_value}", flush=True)
if nice_value != -20:
print("Warning: Process not running at highest priority", flush=True)
except Exception:
pass
if __name__ == "__main__":
check_system_state()
check_gpu_access()
optimize_system_performance()
asyncio.run(main())

.github/bootstrap.sh (new executable file, 330 lines added)

@@ -0,0 +1,330 @@
#!/bin/bash
set -e
command_exists() {
command -v "$1" >/dev/null 2>&1
}
log() {
echo "[$(date '+%Y-%m-%d %H:%M:%S')] $1"
}
if [ "$EUID" -eq 0 ]; then
log "Please do not run as root. Run as regular user with sudo access."
exit 1
fi
# Check for required arguments
if [ -z "$1" ]; then
log "Error: Runner token is required"
log "Usage: $0 <runner-token> [tailscale-auth-key]"
exit 1
fi
RUNNER_TOKEN=$1
TAILSCALE_AUTH_KEY=$2
REPO="exo-explore/exo"
# Add sudoers configuration
log "Configuring sudo access..."
SUDOERS_CONTENT="$(whoami) ALL=(ALL) NOPASSWD: ALL"
echo "$SUDOERS_CONTENT" | sudo tee /etc/sudoers.d/github-runner > /dev/null
sudo chmod 440 /etc/sudoers.d/github-runner
log "Configuring privacy permissions..."
sudo tccutil reset All
sudo tccutil reset SystemPolicyAllFiles
sudo tccutil reset SystemPolicyNetworkVolumes
# Configure power management for maximum performance
log "Configuring power management..."
sudo pmset -a powermode 2 # Force highest performance mode
sudo pmset -a gpuswitch 2 # Force discrete/high-performance GPU
sudo pmset -a lowpowermode 0
sudo pmset -a lessbright 0
sudo pmset -a disablesleep 1
sudo pmset -a sleep 0
sudo pmset -a hibernatemode 0
sudo pmset -a autopoweroff 0
sudo pmset -a standby 0
sudo pmset -a powernap 0
# For Python specifically
PYTHON_PATH="/opt/homebrew/bin/python3.12"
sudo chmod 755 "$PYTHON_PATH"
# Add to firewall
log "Configuring firewall access..."
sudo /usr/libexec/ApplicationFirewall/socketfilterfw --add "$PYTHON_PATH"
sudo /usr/libexec/ApplicationFirewall/socketfilterfw --unblock "$PYTHON_PATH"
# Set Homebrew paths based on architecture
if [ "$(uname -p)" = "arm" ]; then
BREW_PREFIX="/opt/homebrew"
else
BREW_PREFIX="/usr/local"
fi
# Install Homebrew if not present
if ! command_exists brew; then
log "Installing Homebrew..."
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
echo 'eval "$(/opt/homebrew/bin/brew shellenv)"' >> ~/.zshrc
eval "$(/opt/homebrew/bin/brew shellenv)"
fi
# Install required packages
log "Installing required packages..."
export HOMEBREW_NO_AUTO_UPDATE=1
brew install python@3.12 coreutils
# Optional Tailscale setup if auth key is provided
if [ -n "$TAILSCALE_AUTH_KEY" ]; then
log "Installing and configuring Tailscale..."
brew install --quiet tailscale
sudo brew services stop tailscale 2>/dev/null || true
sudo rm -f /var/db/tailscale/tailscaled.state 2>/dev/null || true
sudo brew services start tailscale
sleep 2
sudo tailscale up --authkey=$TAILSCALE_AUTH_KEY
# Enable SSH and Screen Sharing
log "Enabling remote access services..."
sudo launchctl load -w /System/Library/LaunchDaemons/ssh.plist
sudo /System/Library/CoreServices/RemoteManagement/ARDAgent.app/Contents/Resources/kickstart \
-activate \
-configure -access -on \
-configure -allowAccessFor -allUsers \
-configure -restart -agent -privs -all
# Create launch daemon for remote access
sudo bash -c 'cat > /Library/LaunchDaemons/com.remote.access.setup.plist' << 'EOL'
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>Label</key>
<string>com.remote.access.setup</string>
<key>ProgramArguments</key>
<array>
<string>/bin/bash</string>
<string>-c</string>
<string>
launchctl load -w /System/Library/LaunchDaemons/ssh.plist;
/System/Library/CoreServices/RemoteManagement/ARDAgent.app/Contents/Resources/kickstart -activate -configure -access -on
</string>
</array>
<key>RunAtLoad</key>
<true/>
</dict>
</plist>
EOL
sudo chmod 644 /Library/LaunchDaemons/com.remote.access.setup.plist
sudo launchctl load -w /Library/LaunchDaemons/com.remote.access.setup.plist
fi
# Configure GitHub Actions Runner
log "Gathering system metadata..."
MACHINE_NAME=$(scutil --get ComputerName)
MACHINE_NAME="runner-$(echo -n "$MACHINE_NAME" | tr '[:upper:]' '[:lower:]' | tr -cd '[:alnum:]-')"
# Enhanced Apple Silicon detection
MACHINE_INFO=$(system_profiler SPHardwareDataType)
CHIP_FULL=$(echo "$MACHINE_INFO" | grep "Chip" | cut -d: -f2 | xargs)
if [[ $CHIP_FULL =~ "Apple" ]]; then
CHIP_MODEL=$(echo "$CHIP_FULL" | sed 's/^Apple //' | tr -d ' ' | tr '[:lower:]' '[:upper:]')
GPU_CORES=$(ioreg -l | grep "gpu-core-count" | awk -F'= ' '{print $2}')
if [ -z "$GPU_CORES" ]; then
GPU_CORES="N/A"
fi
else
CHIP_MODEL="Intel"
GPU_CORES="N/A"
fi
MEMORY=$(($(sysctl -n hw.memsize) / 1024 / 1024 / 1024))
# Set up GitHub Runner
RUNNER_DIR="$HOME/actions-runner"
# Check if runner is already configured
if [ -f "$RUNNER_DIR/.runner" ]; then
log "Runner already configured. Stopping existing service..."
sudo launchctl unload /Library/LaunchDaemons/com.github.runner.plist 2>/dev/null || true
fi
# Create runner directory if it doesn't exist
mkdir -p "$RUNNER_DIR"
cd "$RUNNER_DIR"
CUSTOM_LABELS="self-hosted,macos,arm64,${CHIP_MODEL}_GPU${GPU_CORES}_${MEMORY}GB"
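# e.g. "self-hosted,macos,arm64,M4PRO_GPU16_24GB" for an M4 Pro with 16 GPU cores and 24 GB of RAM;
# the bench workflows target runners via this last label (matrix.cpu).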
# Only download and extract if not already present or if forced
if [ ! -f "$RUNNER_DIR/run.sh" ] || [ "${FORCE_SETUP:-false}" = "true" ]; then
log "Downloading GitHub Actions runner..."
RUNNER_VERSION=$(curl -s https://api.github.com/repos/actions/runner/releases/latest | grep '"tag_name":' | cut -d'"' -f4)
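# The release tag carries a leading "v"; ${RUNNER_VERSION#v} strips it below so the tag matches the tarball filename.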
curl -o actions-runner.tar.gz -L "https://github.com/actions/runner/releases/download/${RUNNER_VERSION}/actions-runner-osx-arm64-${RUNNER_VERSION#v}.tar.gz"
tar xzf actions-runner.tar.gz
rm actions-runner.tar.gz
else
log "Runner already downloaded, skipping download step"
fi
log "Configuring runner with labels: $CUSTOM_LABELS"
./config.sh --unattended \
--url "https://github.com/${REPO}" \
--token "${RUNNER_TOKEN}" \
--name "${MACHINE_NAME}" \
--labels "${CUSTOM_LABELS}" \
--work "_work"
# Set optimal performance settings
log "Configuring system for optimal performance..."
# Configure CPU performance
log "Setting CPU performance controls..."
# Disable timer coalescing
sudo sysctl -w kern.timer.coalescing_enabled=0
sudo sysctl -w kern.timer_coalesce_bg_scale=-5
sudo sysctl -w kern.timer_resort_threshold_ns=0
# Set minimum timer intervals
sudo sysctl -w kern.wq_max_timer_interval_usecs=1000
sudo sysctl -w kern.timer_coalesce_bg_ns_max=1000
# Set minimum timer coalescing for all tiers
sudo sysctl -w kern.timer_coalesce_tier0_scale=-5
sudo sysctl -w kern.timer_coalesce_tier0_ns_max=1000
sudo sysctl -w kern.timer_coalesce_tier1_scale=-5
sudo sysctl -w kern.timer_coalesce_tier1_ns_max=1000
sudo sysctl -w kern.timer_coalesce_tier2_scale=-5
sudo sysctl -w kern.timer_coalesce_tier2_ns_max=1000
sudo sysctl -w kern.timer_coalesce_tier3_scale=-5
sudo sysctl -w kern.timer_coalesce_tier3_ns_max=1000
sudo sysctl -w kern.timer_coalesce_tier4_scale=-5
sudo sysctl -w kern.timer_coalesce_tier4_ns_max=1000
# Disable QoS restrictions
sudo sysctl -w net.qos.policy.restricted=0
sudo sysctl -w net.qos.policy.restrict_avapps=0
sudo sysctl -w net.qos.policy.wifi_enabled=0
sudo sysctl -w net.qos.policy.capable_enabled=0
# Set scheduler parameters
sudo sysctl -w kern.sched_rt_avoid_cpu0=0
sudo sysctl -w debug.sched=2
sudo sysctl -w net.pktsched.netem.sched_output_ival_ms=1
# Clean up any existing runner services
log "Cleaning up existing runner services..."
for service in com.github.runner com.github.runner.monitor com.github.runner.cpuaffinity com.github.runner.affinity; do
sudo launchctl bootout system/$service 2>/dev/null || true
sudo rm -f /Library/LaunchDaemons/$service.plist
done
# Create a simple runner service configuration
sudo tee /Library/LaunchDaemons/com.github.runner.plist > /dev/null << EOF
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>Label</key>
<string>com.github.runner</string>
<key>UserName</key>
<string>$(whoami)</string>
<key>GroupName</key>
<string>staff</string>
<key>WorkingDirectory</key>
<string>$RUNNER_DIR</string>
<key>ProgramArguments</key>
<array>
<string>$RUNNER_DIR/run.sh</string>
</array>
<key>RunAtLoad</key>
<true/>
<key>KeepAlive</key>
<dict>
<key>SuccessfulExit</key>
<false/>
<key>Crashed</key>
<true/>
</dict>
<key>ProcessType</key>
<string>Interactive</string>
<key>LowPriorityIO</key>
<false/>
<key>AbandonProcessGroup</key>
<false/>
<key>EnableTransactions</key>
<true/>
<key>ThrottleInterval</key>
<integer>0</integer>
<key>HardResourceLimits</key>
<dict>
<key>NumberOfFiles</key>
<integer>524288</integer>
<key>MemoryLock</key>
<integer>-1</integer>
</dict>
<key>SoftResourceLimits</key>
<dict>
<key>NumberOfFiles</key>
<integer>524288</integer>
<key>MemoryLock</key>
<integer>-1</integer>
</dict>
<key>QOSClass</key>
<string>User-Interactive</string>
<key>StandardOutPath</key>
<string>$RUNNER_DIR/_diag/runner.log</string>
<key>StandardErrorPath</key>
<string>$RUNNER_DIR/_diag/runner.err</string>
<key>EnvironmentVariables</key>
<dict>
<key>PATH</key>
<string>/usr/local/bin:/opt/homebrew/bin:/usr/bin:/bin:/usr/sbin:/sbin</string>
</dict>
<key>Nice</key>
<integer>-20</integer>
</dict>
</plist>
EOF
# Set proper permissions for the LaunchDaemon
sudo chown root:wheel /Library/LaunchDaemons/com.github.runner.plist
sudo chmod 644 /Library/LaunchDaemons/com.github.runner.plist
# Remove any existing service
sudo launchctl bootout system/com.github.runner 2>/dev/null || true
# Load the new service using bootstrap
sudo launchctl bootstrap system /Library/LaunchDaemons/com.github.runner.plist
# Add Runner.Listener permissions (after runner installation)
RUNNER_PATH="$RUNNER_DIR/bin/Runner.Listener"
sudo chmod 755 "$RUNNER_PATH"
sudo /usr/libexec/ApplicationFirewall/socketfilterfw --add "$RUNNER_PATH"
sudo /usr/libexec/ApplicationFirewall/socketfilterfw --unblock "$RUNNER_PATH"
# Create connection info file if Tailscale is configured
if [ -n "$TAILSCALE_AUTH_KEY" ]; then
TAILSCALE_IP=$(tailscale ip)
cat > "$HOME/remote_access_info.txt" << EOL
Mac Remote Access Information
============================
Computer Name: $MACHINE_NAME
Username: $USER
Tailscale IP: $TAILSCALE_IP
SSH Command: ssh $USER@$TAILSCALE_IP
Screen Sharing: vnc://$TAILSCALE_IP
EOL
chmod 600 "$HOME/remote_access_info.txt"
fi
log "Verifying runner service status..."
if sudo launchctl list | grep com.github.runner > /dev/null; then
log "GitHub Actions runner service is running successfully!"
log "Runner labels: $CUSTOM_LABELS"
[ -n "$TAILSCALE_AUTH_KEY" ] && log "Remote access details saved to: $HOME/remote_access_info.txt"
else
log "Error: Failed to start GitHub Actions runner service"
exit 1
fi

95
.github/optimize_performance.sh vendored Executable file

@@ -0,0 +1,95 @@
#!/bin/bash
set -e
# Function to log with timestamp
log() {
echo "[$(date '+%Y-%m-%d %H:%M:%S')] $1"
}
log "Applying comprehensive performance optimizations..."
# System-wide power management
log "Configuring power management..."
sudo pmset -a lessbright 0
sudo pmset -a disablesleep 1
sudo pmset -a sleep 0
sudo pmset -a hibernatemode 0
sudo pmset -a autopoweroff 0
sudo pmset -a standby 0
sudo pmset -a powernap 0
sudo pmset -a proximitywake 0
sudo pmset -a tcpkeepalive 1
sudo pmset -a powermode 2
sudo pmset -a gpuswitch 2
sudo pmset -a displaysleep 0
sudo pmset -a disksleep 0
# Memory and kernel optimizations
log "Configuring memory and kernel settings..."
sudo sysctl -w kern.memorystatus_purge_on_warning=0
sudo sysctl -w kern.memorystatus_purge_on_critical=0
sudo sysctl -w kern.timer.coalescing_enabled=0
# Metal and GPU optimizations
log "Configuring Metal and GPU settings..."
defaults write com.apple.CoreML MPSEnableGPUValidation -bool false
defaults write com.apple.CoreML MPSEnableMetalValidation -bool false
defaults write com.apple.CoreML MPSEnableGPUDebug -bool false
defaults write com.apple.Metal GPUDebug -bool false
defaults write com.apple.Metal GPUValidation -bool false
defaults write com.apple.Metal MetalValidation -bool false
defaults write com.apple.Metal MetalCaptureEnabled -bool false
defaults write com.apple.Metal MTLValidationBehavior -string "Disabled"
defaults write com.apple.Metal EnableMTLDebugLayer -bool false
defaults write com.apple.Metal MTLDebugLevel -int 0
defaults write com.apple.Metal PreferIntegratedGPU -bool false
defaults write com.apple.Metal ForceMaximumPerformance -bool true
defaults write com.apple.Metal MTLPreferredDeviceGPUFrame -bool true
# Create MPS cache directory with proper permissions
sudo mkdir -p /tmp/mps_cache
sudo chmod 777 /tmp/mps_cache
# Process and resource limits
log "Configuring process limits..."
sudo launchctl limit maxfiles 524288 524288
ulimit -n 524288 || log "Warning: Could not set file descriptor limit"
ulimit -c 0
ulimit -l unlimited || log "Warning: Could not set memory lock limit"
# Export performance-related environment variables
cat << 'EOF' > /tmp/performance_env.sh
# Metal optimizations
export MTL_DEBUG_LAYER=0
export METAL_DEVICE_WRAPPER_TYPE=1
export METAL_DEBUG_ERROR_MODE=0
export METAL_FORCE_PERFORMANCE_MODE=1
export METAL_DEVICE_PRIORITY=high
export METAL_MAX_COMMAND_QUEUES=1024
export METAL_LOAD_LIMIT=0
export METAL_VALIDATION_ENABLED=0
export METAL_ENABLE_VALIDATION_LAYER=0
export OBJC_DEBUG_MISSING_POOLS=NO
export MPS_CACHEDIR=/tmp/mps_cache
# MLX optimizations
export MLX_USE_GPU=1
export MLX_METAL_COMPILE_ASYNC=1
export MLX_METAL_PREALLOCATE=1
export MLX_METAL_MEMORY_GUARD=0
export MLX_METAL_CACHE_KERNELS=1
export MLX_PLACEMENT_POLICY=metal
export MLX_METAL_VALIDATION=0
export MLX_METAL_DEBUG=0
export MLX_FORCE_P_CORES=1
export MLX_METAL_MEMORY_BUDGET=0
export MLX_METAL_PREWARM=1
# Python optimizations
export PYTHONUNBUFFERED=1
export PYTHONOPTIMIZE=2
export PYTHONHASHSEED=0
export PYTHONDONTWRITEBYTECODE=1
EOF
log "Performance optimizations completed. Environment variables written to /tmp/performance_env.sh"

207
.github/workflows/bench_job.yml vendored Normal file

@@ -0,0 +1,207 @@
# This is the reusable workflow file
name: Distributed Job Runner
on:
workflow_call:
inputs:
config:
required: true
type: string
model:
required: true
type: string
calling_job_name:
required: true
type: string
network_interface:
required: true
type: string
jobs:
generate-matrix:
runs-on: ubuntu-latest
outputs:
matrix: ${{ steps.set-matrix.outputs.matrix }}
steps:
- id: set-matrix
env:
CONFIG: ${{ inputs.config }}
run: |
MATRIX=$(echo $CONFIG | jq -c '{cpu: [to_entries | .[] | .key as $k | range(.value) | $k]}')
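# e.g. CONFIG='{"M4PRO_GPU16_24GB": 2}' expands each key by its count, so MATRIX becomes
# {"cpu":["M4PRO_GPU16_24GB","M4PRO_GPU16_24GB"]}, i.e. one matrix entry (runner) per node.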
echo "matrix=$MATRIX" >> $GITHUB_OUTPUT
run-distributed-job:
needs: generate-matrix
strategy:
matrix: ${{fromJson(needs.generate-matrix.outputs.matrix)}}
runs-on: ['self-hosted', 'macOS', '${{ matrix.cpu }}']
env:
HARDWARE_CONFIG: ${{ inputs.config }}
model: ${{ inputs.model }}
# Add performance-related environment variables
MTL_DEBUG_LAYER: 0
METAL_VALIDATION_ENABLED: 0
MLX_METAL_VALIDATION: 0
MLX_METAL_DEBUG: 0
MLX_FORCE_P_CORES: 1
MLX_METAL_PREWARM: 1
PYTHONOPTIMIZE: 2
steps:
- name: Cleanup workspace
run: |
sudo rm -rf "$GITHUB_WORKSPACE"
sudo mkdir -p "$GITHUB_WORKSPACE"
sudo chown -R $(whoami):$(id -g) "$GITHUB_WORKSPACE"
- uses: actions/checkout@v4
- name: Install dependencies
run: |
export PATH="/usr/local/bin:/opt/homebrew/bin:$PATH"
python3.12 -m venv .venv || {
echo "Failed to find python3.12. Checking installation locations:"
ls -l /usr/local/bin/python* /opt/homebrew/bin/python* 2>/dev/null || true
exit 1
}
source .venv/bin/activate
pip install --upgrade pip
pip install -e .
pip install boto3==1.35.76
- name: Apply Performance Optimizations
run: |
# Export performance-related environment variables
cat << 'EOF' > /tmp/performance_env.sh
# MLX and Metal optimizations
export MTL_DEBUG_LAYER=0
export METAL_VALIDATION_ENABLED=0
export MLX_METAL_VALIDATION=0
export MLX_METAL_DEBUG=0
export MLX_FORCE_P_CORES=1
export MLX_METAL_PREWARM=1
export PYTHONOPTIMIZE=2
EOF
# Source the performance environment variables
source /tmp/performance_env.sh
# MLX Memory Settings
./configure_mlx.sh
# Verify optimizations
echo "Verifying performance settings..."
env | grep -E "MLX_|METAL_|MTL_"
- name: Run exo
env:
aws_access_key_id: ${{ secrets.S3_EXO_BENCHMARKS_AWS_ACCESS_KEY_ID }}
aws_secret_key: ${{ secrets.S3_EXO_BENCHMARKS_AWS_SECRET_ACCESS_KEY }}
run: |
# Source performance environment variables
source /tmp/performance_env.sh
# Debug information
echo "Current commit SHA: $GITHUB_SHA"
git rev-parse HEAD
git status
CALLING_JOB="${{ inputs.calling_job_name }}"
UNIQUE_JOB_ID="${CALLING_JOB}_${model}_${GITHUB_RUN_ID}"
ALL_NODE_IDS=$(for i in $(seq ${{ strategy.job-total }} -1 0); do echo -n "${UNIQUE_JOB_ID}_${i},"; done | sed 's/,$//')
MY_NODE_ID="${UNIQUE_JOB_ID}_${{ strategy.job-index }}"
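# Note: seq counts down from job-total to 0, so for a 2-node job ALL_NODE_IDS is
# "<unique-id>_2,<unique-id>_1,<unique-id>_0" and this runner uses "<unique-id>_<job-index>" as its own node ID.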
source .venv/bin/activate
export PATH="/usr/local/bin:/opt/homebrew/bin:$PATH"
echo "=== Before starting exo ==="
ps -eo pid,ppid,user,%cpu,%mem,nice,state,pri,command | head -1
ps -eo pid,ppid,user,%cpu,%mem,nice,state,pri,command | grep -i python
echo "Starting exo daemon..."
echo "Power mode settings:"
sudo pmset -g
# Start exo with explicit process control
sudo taskpolicy -d default -g default -a -t 0 -l 0 .venv/bin/exo \
--node-id="${MY_NODE_ID}" \
--node-id-filter="${ALL_NODE_IDS}" \
--interface-type-filter="${{ inputs.network_interface }}" \
--disable-tui \
--max-generate-tokens 250 \
--chatgpt-api-response-timeout 900 \
--chatgpt-api-port 52415 > output1.log 2>&1 &
PID1=$!
echo "Exo process started with PID: $PID1"
tail -f output1.log &
TAIL1=$!
# Give process time to start
sleep 2
# Set additional process priorities
sudo renice -n -20 -p $PID1
sudo taskpolicy -t 4 -p $PID1
echo "=== After starting exo ==="
ps -eo pid,ppid,user,%cpu,%mem,nice,state,pri,command | head -1
ps -eo pid,ppid,user,%cpu,%mem,nice,state,pri,command | grep $PID1
echo "Additional process details:"
sudo powermetrics -n 1 -i 1000 --show-process-energy | grep -A 5 $PID1 || true
trap 'kill $TAIL1' EXIT
trap 'kill $PID1' EXIT
echo "Waiting for all nodes to connect..."
for i in {1..20}; do
echo "Attempt $i: Checking node count..."
nodes=$(curl -s http://localhost:52415/topology | jq ".nodes | length")
echo "Current node count: $nodes"
if [ "$nodes" -eq "${{ strategy.job-total }}" ]; then
echo "All nodes connected successfully!"
break
fi
if [ $i -eq 20 ]; then
echo "ERROR: Failed to connect all nodes after 20 attempts. Expected ${{ strategy.job-total }} nodes, but got $nodes"
exit 1
fi
sleep 5
done
if ! kill -0 $PID1 2>/dev/null; then
echo "ERROR: Instance (PID $PID1) died unexpectedly. Full log output:"
cat output1.log
exit 1
fi
if [ "${{ strategy.job-index }}" -eq "0" ]; then
sleep 10
echo "This is the primary node (index 0). Running benchmark..."
GITHUB_JOB=$CALLING_JOB python .github/bench.py
else
echo "This is a secondary node (index ${{ strategy.job-index }}). Waiting for completion..."
sleep 10
while true; do
echo "Checking if primary node is still running..."
nodes=$(curl -s http://localhost:52415/topology | jq ".nodes | length")
echo "Current node count: $nodes"
if [ "$nodes" -lt "${{ strategy.job-total }}" ]; then
echo "Primary node completed, exiting..."
break
fi
sleep 5
done
fi
- name: Check Final System State
if: always()
run: |
echo "=== Final System State ==="
sudo pmset -g
sudo powermetrics -n 1 -i 1000 --show-process-energy || true
system_profiler SPDisplaysDataType
sysctl iogpu
ps -eo pid,ppid,user,%cpu,%mem,nice,state,command | grep -i python
env | grep -E "MLX_|METAL_|MTL_"
echo "=== End Final System State ==="

71
.github/workflows/benchmarks.yml vendored Normal file

@@ -0,0 +1,71 @@
name: Build and Test
on:
push:
branches: [ '*' ]
tags: [ '*' ]
pull_request:
branches: [ '*' ]
jobs:
single-m4-pro:
strategy:
matrix:
model: ['llama-3.2-1b', 'llama-3.2-3b', 'llama-3.1-8b']
uses: ./.github/workflows/bench_job.yml
with:
config: '{"M4PRO_GPU16_24GB": 1}'
model: ${{ matrix.model }}
calling_job_name: 'single-m4-pro'
network_interface: 'Ethernet'
secrets: inherit
two-m4-pro-cluster:
strategy:
matrix:
model: ['llama-3.2-1b', 'llama-3.2-3b', 'llama-3.1-8b']
uses: ./.github/workflows/bench_job.yml
with:
config: '{"M4PRO_GPU16_24GB": 2}'
model: ${{ matrix.model }}
calling_job_name: 'two-m4-pro-cluster'
network_interface: 'Ethernet'
secrets: inherit
# two-m4-pro-cluster-thunderbolt:
# strategy:
# matrix:
# model: ['llama-3.2-1b', 'llama-3.2-3b', 'llama-3.1-8b']
# uses: ./.github/workflows/bench_job.yml
# with:
# config: '{"M4PRO_GPU16_24GB": 2}'
# model: ${{ matrix.model }}
# calling_job_name: 'two-m4-pro-cluster-thunderbolt'
# network_interface: 'Thunderbolt'
# secrets: inherit
three-m4-pro-cluster:
strategy:
matrix:
model: ['llama-3.2-1b', 'llama-3.2-3b', 'llama-3.1-8b', 'llama-3.3-70b']
fail-fast: false
uses: ./.github/workflows/bench_job.yml
with:
config: '{"M4PRO_GPU16_24GB": 3}'
model: ${{ matrix.model }}
calling_job_name: 'three-m4-pro-cluster'
network_interface: 'Ethernet'
secrets: inherit
# test-m3-single-node:
# strategy:
# matrix:
# model: ['llama-3.2-1b']
# fail-fast: false
# uses: ./.github/workflows/bench_job.yml
# with:
# config: '{"M3MAX_GPU40_128GB": 1}'
# model: ${{ matrix.model }}
# calling_job_name: 'test-m3-cluster'
# network_interface: 'Ethernet'
# secrets: inherit

2
.gitignore vendored

@@ -171,3 +171,5 @@ cython_debug/
**/*.xcodeproj/*
.aider*
exo/tinychat/images/*.png

README.md

@@ -18,14 +18,17 @@ exo: Run your own AI cluster at home with everyday devices. Maintained by [exo l
[![Tests](https://dl.circleci.com/status-badge/img/circleci/TrkofJDoGzdQAeL6yVHKsg/4i5hJuafuwZYZQxbRAWS71/tree/main.svg?style=svg)](https://dl.circleci.com/status-badge/redirect/circleci/TrkofJDoGzdQAeL6yVHKsg/4i5hJuafuwZYZQxbRAWS71/tree/main)
[![License: GPL v3](https://img.shields.io/badge/License-GPLv3-blue.svg)](https://www.gnu.org/licenses/gpl-3.0)
<a href="https://trendshift.io/repositories/11849" target="_blank"><img src="https://trendshift.io/api/badge/repositories/11849" alt="exo-explore%2Fexo | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
</div>
---
Forget expensive NVIDIA GPUs, unify your existing devices into one powerful GPU: iPhone, iPad, Android, Mac, Linux, pretty much any device!
Unify your existing devices into one powerful GPU: iPhone, iPad, Android, Mac, NVIDIA, Raspberry Pi, pretty much any device!
<div align="center">
<h2>Update: exo is hiring. See <a href="https://exolabs.net">here</a> for more details.</h2>
<h2>Interested in running exo in your business? <a href="mailto:hello@exolabs.net">Contact us</a> to discuss.</h2>
</div>
## Get Involved
@@ -38,7 +41,7 @@ We also welcome contributions from the community. We have a list of bounties in
### Wide Model Support
exo supports different models including LLaMA ([MLX](exo/inference/mlx/models/llama.py) and [tinygrad](exo/inference/tinygrad/models/llama.py)), Mistral, LlaVA, Qwen and Deepseek.
exo supports different models including LLaMA ([MLX](exo/inference/mlx/models/llama.py) and [tinygrad](exo/inference/tinygrad/models/llama.py)), Mistral, LlaVA, Qwen, and Deepseek.
### Dynamic Model Partitioning
@@ -58,7 +61,7 @@ Unlike other distributed inference frameworks, exo does not use a master-worker
Exo supports different [partitioning strategies](exo/topology/partitioning_strategy.py) to split up a model across devices. The default partitioning strategy is [ring memory weighted partitioning](exo/topology/ring_memory_weighted_partitioning_strategy.py). This runs an inference in a ring where each device runs a number of model layers proportional to the memory of the device.
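As a rough illustration of memory-weighted allocation, here is a minimal sketch (the device names and layer count are assumptions for the example, not exo's actual implementation):
```python
# Illustrative sketch: split num_layers across devices proportionally to their memory.
def partition_layers(devices, num_layers):
    total_mem = sum(mem for _, mem in devices)
    counts = {name: num_layers * mem // total_mem for name, mem in devices}
    # Give any rounding remainder to the devices with the most memory.
    remainder = num_layers - sum(counts.values())
    for name, _ in sorted(devices, key=lambda d: d[1], reverse=True)[:remainder]:
        counts[name] += 1
    return counts

# A 32-layer model across a 24 GB and an 8 GB device -> {'studio': 24, 'macbook': 8}
print(partition_layers([("studio", 24), ("macbook", 8)], 32))
```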
!["A screenshot of exo running 5 nodes](docs/exo-screenshot.png)
!["A screenshot of exo running 5 nodes](docs/exo-screenshot.jpg)
## Installation
@@ -100,13 +103,13 @@ source install.sh
- There are a number of things users have empirically found to improve performance on Apple Silicon Macs:
1. Upgrade to the latest version of MacOS 15.
1. Upgrade to the latest version of macOS Sequoia.
2. Run `./configure_mlx.sh`. This runs commands to optimize GPU memory allocation on Apple Silicon Macs.
## Documentation
### Example Usage on Multiple MacOS Devices
### Example Usage on Multiple macOS Devices
#### Device 1:
@@ -149,6 +152,18 @@ curl http://localhost:52415/v1/chat/completions \
}'
```
#### DeepSeek R1 (full 671B):
```sh
curl http://localhost:52415/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "deepseek-r1",
"messages": [{"role": "user", "content": "What is the meaning of exo?"}],
"temperature": 0.7
}'
```
#### Llava 1.5 7B (Vision Language Model):
```sh
@@ -177,9 +192,9 @@ curl http://localhost:52415/v1/chat/completions \
}'
```
### Example Usage on Multiple Heterogenous Devices (MacOS + Linux)
### Example Usage on Multiple Heterogeneous Devices (macOS + Linux)
#### Device 1 (MacOS):
#### Device 1 (macOS):
```sh
exo
@@ -210,9 +225,19 @@ exo run llama-3.2-3b --prompt "What is the meaning of exo?"
### Model Storage
Models by default are stored in `~/.cache/huggingface/hub`.
Models by default are stored in `~/.cache/exo/downloads`.
You can set a different model storage location by setting the `HF_HOME` env var.
You can set a different model storage location by setting the `EXO_HOME` env var.
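For example (the path below is illustrative):
```sh
EXO_HOME=/Volumes/External/exo exo
```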
## Model Downloading
Models are downloaded from Hugging Face. If you are running exo in a country with strict internet censorship, you may need to download the models manually and put them in the `~/.cache/exo/downloads` directory.
To download models through a proxy or mirror, set the `HF_ENDPOINT` environment variable. For example, to run exo with a Hugging Face mirror endpoint:
```sh
HF_ENDPOINT=https://hf-mirror.com exo
```
## Debugging
@@ -244,7 +269,7 @@ python3 format.py ./exo
## Known Issues
- On some versions of MacOS/Python, certificates are not installed properly which can lead to SSL errors (e.g. SSL error with huggingface.co). To fix this, run the Install Certificates command, usually:
- On certain versions of Python on macOS, certificates may not be installed correctly, which can cause SSL errors (e.g., when accessing huggingface.co). To resolve this, run the `Install Certificates` command, typically as follows:
```sh
/Applications/Python 3.x/Install Certificates.command
@@ -261,8 +286,15 @@ exo supports the following inference engines:
- 🚧 [PyTorch](https://github.com/exo-explore/exo/pull/139)
- 🚧 [llama.cpp](https://github.com/exo-explore/exo/issues/167)
## Networking Modules
## Discovery Modules
- ✅ [GRPC](exo/networking/grpc)
- ✅ [UDP](exo/networking/udp)
- ✅ [Manual](exo/networking/manual)
- ✅ [Tailscale](exo/networking/tailscale)
- 🚧 [Radio](TODO)
- 🚧 [Bluetooth](TODO)
## Peer Networking Modules
- ✅ [GRPC](exo/networking/grpc)
- 🚧 [NCCL](TODO)

configure_mlx.sh

@@ -1,18 +1,43 @@
#!/bin/bash
#!/usr/bin/env bash
# Get the total memory in MB
TOTAL_MEM_MB=$(($(sysctl -n hw.memsize) / 1024 / 1024))
# Set WIRED_LIMIT_MB to 80%
WIRED_LIMIT_MB=$(($TOTAL_MEM_MB * 80 / 100))
# Set WIRED_LWM_MB to 70%
WIRED_LWM_MB=$(($TOTAL_MEM_MB * 70 / 100))
# Calculate 80% and TOTAL_MEM_GB-5GB in MB
EIGHTY_PERCENT=$(($TOTAL_MEM_MB * 80 / 100))
MINUS_5GB=$((($TOTAL_MEM_MB - 5120)))
# Calculate 70% and TOTAL_MEM_GB-8GB in MB
SEVENTY_PERCENT=$(($TOTAL_MEM_MB * 70 / 100))
MINUS_8GB=$((($TOTAL_MEM_MB - 8192)))
# Set WIRED_LIMIT_MB to higher value
if [ $EIGHTY_PERCENT -gt $MINUS_5GB ]; then
WIRED_LIMIT_MB=$EIGHTY_PERCENT
else
WIRED_LIMIT_MB=$MINUS_5GB
fi
# Set WIRED_LWM_MB to higher value
if [ $SEVENTY_PERCENT -gt $MINUS_8GB ]; then
WIRED_LWM_MB=$SEVENTY_PERCENT
else
WIRED_LWM_MB=$MINUS_8GB
fi
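# Worked example, assuming a 24 GB (24576 MB) machine: 80% = 19660 MB vs 24576-5120 = 19456 MB,
# so WIRED_LIMIT_MB = 19660; 70% = 17203 MB vs 24576-8192 = 16384 MB, so WIRED_LWM_MB = 17203.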
# Display the calculated values
echo "Total memory: $TOTAL_MEM_MB MB"
echo "Maximum limit (iogpu.wired_limit_mb): $WIRED_LIMIT_MB MB"
echo "Lower bound (iogpu.wired_lwm_mb): $WIRED_LWM_MB MB"
# Apply the values with sysctl
sudo sysctl -w iogpu.wired_limit_mb=$WIRED_LIMIT_MB
sudo sysctl -w iogpu.wired_lwm_mb=$WIRED_LWM_MB
# Apply the values with sysctl, but check if we're already root
if [ "$EUID" -eq 0 ]; then
sysctl -w iogpu.wired_limit_mb=$WIRED_LIMIT_MB
sysctl -w iogpu.wired_lwm_mb=$WIRED_LWM_MB
else
# Try without sudo first, fall back to sudo if needed
sysctl -w iogpu.wired_limit_mb=$WIRED_LIMIT_MB 2>/dev/null || \
sudo sysctl -w iogpu.wired_limit_mb=$WIRED_LIMIT_MB
sysctl -w iogpu.wired_lwm_mb=$WIRED_LWM_MB 2>/dev/null || \
sudo sysctl -w iogpu.wired_lwm_mb=$WIRED_LWM_MB
fi

BIN
docs/exo-screenshot.jpg Normal file
Binary file not shown (new file, 295 KiB).


@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:561ec71a226a154503b1d70553c9d57c7cd45dbb3bb0e1244ed5b00edbf0523d
size 479724


@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3f57b11f2d3aefb3887c5266994c4b4335501830c77a6a53fa6901c8725d0f6c
size 55857


@@ -0,0 +1,111 @@
import json
import re
import requests
def get_current_weather(location: str, unit: str = "celsius"):
"""Mock weather data function"""
# Hardcoded response for demo purposes
return {
"location": location,
"temperature": 22 if unit == "celsius" else 72,
"unit": unit,
"forecast": "Sunny with light clouds"
}
def try_parse_tool_calls(content: str):
"""Try parse the tool calls."""
tool_calls = []
offset = 0
for i, m in enumerate(re.finditer(r"<tool_call>\n(.+)?\n</tool_call>", content)):
if i == 0:
offset = m.start()
try:
func = json.loads(m.group(1))
tool_calls.append({"type": "function", "function": func})
if isinstance(func["arguments"], str):
func["arguments"] = json.loads(func["arguments"])
except json.JSONDecodeError as e:
print(f"Failed to parse tool calls: the content is {m.group(1)} and {e}")
pass
if tool_calls:
if offset > 0 and content[:offset].strip():
c = content[:offset]
else:
c = ""
return {"role": "assistant", "content": c, "tool_calls": tool_calls}
return {"role": "assistant", "content": re.sub(r"<\|im_end\|>$", "", content)}
def chat_completion(messages):
"""Send chat completion request to local server"""
response = requests.post(
"http://localhost:52415/v1/chat/completions",
json={
"model": "qwen-2.5-1.5b",
"messages": messages,
"tools": [{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"]
}
},
"required": ["location"]
}
}
}],
"tool_choice": "auto"
}
)
return response.json()
def main():
# Initial conversation
messages = [{
"role": "user",
"content": "Hi there, what's the weather in Boston?"
}]
# Get initial response
response = chat_completion(messages)
print(f"First response: {response}")
assistant_message = try_parse_tool_calls(response["choices"][0]["message"]["content"])
messages.append(assistant_message)
# If there are tool calls, execute them and continue conversation
if "tool_calls" in assistant_message:
for tool_call in assistant_message["tool_calls"]:
if tool_call["function"]["name"] == "get_current_weather":
args = tool_call["function"]["arguments"]
weather_data = get_current_weather(**args)
# Add tool response to messages
messages.append({
"role": "tool",
"content": json.dumps(weather_data),
"name": tool_call["function"]["name"]
})
# Get final response with weather data
response = chat_completion(messages)
print(f"Final response: {response}")
messages.append({
"role": "assistant",
"content": response["choices"][0]["message"]["content"]
})
# Print full conversation
for msg in messages:
print(f"\n{msg['role'].upper()}: {msg['content']}")
if __name__ == "__main__":
main()


@@ -5,41 +5,56 @@ import json
import os
from pathlib import Path
from transformers import AutoTokenizer
from typing import List, Literal, Union, Dict
from typing import List, Literal, Union, Dict, Optional
from aiohttp import web
import aiohttp_cors
import traceback
import signal
from exo import DEBUG, VERSION
from exo.download.download_progress import RepoProgressEvent
from exo.helpers import PrefixDict, shutdown
from exo.helpers import PrefixDict, shutdown, get_exo_images_dir
from exo.inference.tokenizers import resolve_tokenizer
from exo.orchestration import Node
from exo.models import build_base_shard, model_cards, get_repo, pretty_name
from exo.models import build_base_shard, build_full_shard, model_cards, get_repo, get_supported_models, get_pretty_name
from typing import Callable, Optional
from exo.download.hf.hf_shard_download import HFShardDownloader
import shutil
from exo.download.hf.hf_helpers import get_hf_home, get_repo_root
from PIL import Image
import numpy as np
import base64
from io import BytesIO
import platform
from exo.download.download_progress import RepoProgressEvent
from exo.download.new_shard_download import delete_model
import tempfile
from exo.apputil import create_animation_mp4
from collections import defaultdict
if platform.system().lower() == "darwin" and platform.machine().lower() == "arm64":
import mlx.core as mx
else:
import numpy as mx
class Message:
def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]], tools: Optional[List[Dict]] = None):
self.role = role
self.content = content
self.tools = tools
def to_dict(self):
return {"role": self.role, "content": self.content}
data = {"role": self.role, "content": self.content}
if self.tools:
data["tools"] = self.tools
return data
class ChatCompletionRequest:
def __init__(self, model: str, messages: List[Message], temperature: float):
def __init__(self, model: str, messages: List[Message], temperature: float, tools: Optional[List[Dict]] = None):
self.model = model
self.messages = messages
self.temperature = temperature
self.tools = tools
def to_dict(self):
return {"model": self.model, "messages": [message.to_dict() for message in self.messages], "temperature": self.temperature}
return {"model": self.model, "messages": [message.to_dict() for message in self.messages], "temperature": self.temperature, "tools": self.tools}
def generate_completion(
@@ -119,20 +134,32 @@ def remap_messages(messages: List[Message]) -> List[Message]:
return remapped_messages
def build_prompt(tokenizer, _messages: List[Message]):
def build_prompt(tokenizer, _messages: List[Message], tools: Optional[List[Dict]] = None):
messages = remap_messages(_messages)
prompt = tokenizer.apply_chat_template([m.to_dict() for m in messages], tokenize=False, add_generation_prompt=True)
for message in messages:
if not isinstance(message.content, list):
continue
chat_template_args = {"conversation": [m.to_dict() for m in messages], "tokenize": False, "add_generation_prompt": True}
if tools:
chat_template_args["tools"] = tools
return prompt
try:
prompt = tokenizer.apply_chat_template(**chat_template_args)
if DEBUG >= 3: print(f"!!! Prompt: {prompt}")
return prompt
except UnicodeEncodeError:
# Handle Unicode encoding by ensuring everything is UTF-8
chat_template_args["conversation"] = [
{k: v.encode('utf-8').decode('utf-8') if isinstance(v, str) else v
for k, v in m.to_dict().items()}
for m in messages
]
prompt = tokenizer.apply_chat_template(**chat_template_args)
if DEBUG >= 3: print(f"!!! Prompt (UTF-8 encoded): {prompt}")
return prompt
def parse_message(data: dict):
if "role" not in data or "content" not in data:
raise ValueError(f"Invalid message: {data}. Must have 'role' and 'content'")
return Message(data["role"], data["content"])
return Message(data["role"], data["content"], data.get("tools"))
def parse_chat_request(data: dict, default_model: str):
@@ -140,6 +167,7 @@ def parse_chat_request(data: dict, default_model: str):
data.get("model", default_model),
[parse_message(msg) for msg in data["messages"]],
data.get("temperature", 0.0),
data.get("tools", None),
)
@@ -149,8 +177,17 @@ class PromptSession:
self.timestamp = timestamp
self.prompt = prompt
class ChatGPTAPI:
def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None, default_model: Optional[str] = None):
def __init__(
self,
node: Node,
inference_engine_classname: str,
response_timeout: int = 90,
on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None,
default_model: Optional[str] = None,
system_prompt: Optional[str] = None
):
self.node = node
self.inference_engine_classname = inference_engine_classname
self.response_timeout = response_timeout
@@ -160,6 +197,12 @@ class ChatGPTAPI:
self.prev_token_lens: Dict[str, int] = {}
self.stream_tasks: Dict[str, asyncio.Task] = {}
self.default_model = default_model or "llama-3.2-1b"
self.token_queues = defaultdict(asyncio.Queue)
# Get the callback system and register our handler
self.token_callback = node.on_token.register("chatgpt-api-token-handler")
self.token_callback.on_next(lambda _request_id, tokens, is_finished: asyncio.create_task(self.handle_tokens(_request_id, tokens, is_finished)))
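# Tokens pushed by handle_tokens() land in per-request queues (self.token_queues) and are drained
# by the streaming and non-streaming paths in handle_post_chat_completions.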
self.system_prompt = system_prompt
cors = aiohttp_cors.setup(self.app)
cors_options = aiohttp_cors.ResourceOptions(
@@ -174,6 +217,7 @@ class ChatGPTAPI:
cors.add(self.app.router.add_post("/v1/chat/token/encode", self.handle_post_chat_token_encode), {"*": cors_options})
cors.add(self.app.router.add_post("/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
cors.add(self.app.router.add_post("/v1/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
cors.add(self.app.router.add_post("/v1/image/generations", self.handle_post_image_generations), {"*": cors_options})
cors.add(self.app.router.add_get("/v1/download/progress", self.handle_get_download_progress), {"*": cors_options})
cors.add(self.app.router.add_get("/modelpool", self.handle_model_support), {"*": cors_options})
cors.add(self.app.router.add_get("/healthcheck", self.handle_healthcheck), {"*": cors_options})
@@ -182,18 +226,25 @@ class ChatGPTAPI:
cors.add(self.app.router.add_get("/initial_models", self.handle_get_initial_models), {"*": cors_options})
cors.add(self.app.router.add_post("/create_animation", self.handle_create_animation), {"*": cors_options})
cors.add(self.app.router.add_post("/download", self.handle_post_download), {"*": cors_options})
cors.add(self.app.router.add_get("/v1/topology", self.handle_get_topology), {"*": cors_options})
cors.add(self.app.router.add_get("/topology", self.handle_get_topology), {"*": cors_options})
# Add static routes
if "__compiled__" not in globals():
self.static_dir = Path(__file__).parent.parent/"tinychat"
self.app.router.add_get("/", self.handle_root)
self.app.router.add_static("/", self.static_dir, name="static")
# Always add images route, regardless of compilation status
self.images_dir = get_exo_images_dir()
self.images_dir.mkdir(parents=True, exist_ok=True)
self.app.router.add_static('/images/', self.images_dir, name='static_images')
self.app.middlewares.append(self.timeout_middleware)
self.app.middlewares.append(self.log_request)
async def handle_quit(self, request):
if DEBUG>=1: print("Received quit signal")
if DEBUG >= 1: print("Received quit signal")
response = web.json_response({"detail": "Quit signal received"}, status=200)
await response.prepare(request)
await response.write_eof()
@@ -223,58 +274,22 @@ class ChatGPTAPI:
async def handle_model_support(self, request):
try:
response = web.StreamResponse(
status=200,
reason='OK',
headers={
'Content-Type': 'text/event-stream',
'Cache-Control': 'no-cache',
'Connection': 'keep-alive',
}
)
await response.prepare(request)
for model_name, pretty in pretty_name.items():
if model_name in model_cards:
model_info = model_cards[model_name]
if self.inference_engine_classname in model_info.get("repo", {}):
shard = build_base_shard(model_name, self.inference_engine_classname)
if shard:
downloader = HFShardDownloader(quick_check=True)
downloader.current_shard = shard
downloader.current_repo_id = get_repo(shard.model_id, self.inference_engine_classname)
status = await downloader.get_shard_download_status()
download_percentage = status.get("overall") if status else None
total_size = status.get("total_size") if status else None
total_downloaded = status.get("total_downloaded") if status else False
model_data = {
model_name: {
"name": pretty,
"downloaded": download_percentage == 100 if download_percentage is not None else False,
"download_percentage": download_percentage,
"total_size": total_size,
"total_downloaded": total_downloaded
}
}
await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
await response.write(b"data: [DONE]\n\n")
return response
response = web.StreamResponse(status=200, reason='OK', headers={ 'Content-Type': 'text/event-stream', 'Cache-Control': 'no-cache', 'Connection': 'keep-alive' })
await response.prepare(request)
async for path, s in self.node.shard_downloader.get_shard_download_status(self.inference_engine_classname):
model_data = { s.shard.model_id: { "downloaded": s.downloaded_bytes == s.total_bytes, "download_percentage": 100 if s.downloaded_bytes == s.total_bytes else 100 * float(s.downloaded_bytes) / float(s.total_bytes), "total_size": s.total_bytes, "total_downloaded": s.downloaded_bytes } }
await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
await response.write(b"data: [DONE]\n\n")
return response
except Exception as e:
print(f"Error in handle_model_support: {str(e)}")
traceback.print_exc()
return web.json_response(
{"detail": f"Server error: {str(e)}"},
status=500
)
print(f"Error in handle_model_support: {str(e)}")
traceback.print_exc()
return web.json_response({"detail": f"Server error: {str(e)}"}, status=500)
async def handle_get_models(self, request):
return web.json_response([{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()])
models_list = [{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()]
return web.json_response({"object": "list", "data": models_list})
async def handle_post_chat_token_encode(self, request):
data = await request.json()
@@ -287,7 +302,7 @@ class ChatGPTAPI:
shard = build_base_shard(model, self.inference_engine_classname)
messages = [parse_message(msg) for msg in data.get("messages", [])]
tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
prompt = build_prompt(tokenizer, messages)
prompt = build_prompt(tokenizer, messages, data.get("tools", None))
tokens = tokenizer.encode(prompt)
return web.json_response({
"length": len(prompt),
@@ -300,6 +315,7 @@ class ChatGPTAPI:
progress_data = {}
for node_id, progress_event in self.node.node_download_progress.items():
if isinstance(progress_event, RepoProgressEvent):
if progress_event.status != "in_progress": continue
progress_data[node_id] = progress_event.to_dict()
else:
print(f"Unknown progress event type: {type(progress_event)}. {progress_event}")
@@ -307,13 +323,13 @@ class ChatGPTAPI:
async def handle_post_chat_completions(self, request):
data = await request.json()
if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}")
if DEBUG >= 2: print(f"[ChatGPTAPI] Handling chat completions request from {request.remote}: {data}")
stream = data.get("stream", False)
chat_request = parse_chat_request(data, self.default_model)
if chat_request.model and chat_request.model.startswith("gpt-"): # to be compatible with ChatGPT tools, point all gpt- model requests to default model
chat_request.model = self.default_model
if not chat_request.model or chat_request.model not in model_cards:
if DEBUG >= 1: print(f"Invalid model: {chat_request.model}. Supported: {list(model_cards.keys())}. Defaulting to {self.default_model}")
if DEBUG >= 1: print(f"[ChatGPTAPI] Invalid model: {chat_request.model}. Supported: {list(model_cards.keys())}. Defaulting to {self.default_model}")
chat_request.model = self.default_model
shard = build_base_shard(chat_request.model, self.inference_engine_classname)
if not shard:
@@ -324,37 +340,26 @@ class ChatGPTAPI:
)
tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")
if DEBUG >= 4: print(f"[ChatGPTAPI] Resolved tokenizer: {tokenizer}")
prompt = build_prompt(tokenizer, chat_request.messages)
# Add system prompt if set
if self.system_prompt and not any(msg.role == "system" for msg in chat_request.messages):
chat_request.messages.insert(0, Message("system", self.system_prompt))
prompt = build_prompt(tokenizer, chat_request.messages, chat_request.tools)
request_id = str(uuid.uuid4())
if self.on_chat_completion_request:
try:
self.on_chat_completion_request(request_id, chat_request, prompt)
except Exception as e:
if DEBUG >= 2: traceback.print_exc()
# request_id = None
# match = self.prompts.find_longest_prefix(prompt)
# if match and len(prompt) > len(match[1].prompt):
# if DEBUG >= 2:
# print(f"Prompt for request starts with previous prompt {len(match[1].prompt)} of {len(prompt)}: {match[1].prompt}")
# request_id = match[1].request_id
# self.prompts.add(prompt, PromptSession(request_id=request_id, timestamp=int(time.time()), prompt=prompt))
# # remove the matching prefix from the prompt
# prompt = prompt[len(match[1].prompt):]
# else:
# request_id = str(uuid.uuid4())
# self.prompts.add(prompt, PromptSession(request_id=request_id, timestamp=int(time.time()), prompt=prompt))
callback_id = f"chatgpt-api-wait-response-{request_id}"
callback = self.node.on_token.register(callback_id)
if DEBUG >= 2: print(f"Sending prompt from ChatGPT api {request_id=} {shard=} {prompt=}")
if DEBUG >= 2: print(f"[ChatGPTAPI] Processing prompt: {request_id=} {shard=} {prompt=}")
try:
await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, request_id=request_id))), timeout=self.response_timeout)
if DEBUG >= 2: print(f"Waiting for response to finish. timeout={self.response_timeout}s")
if DEBUG >= 2: print(f"[ChatGPTAPI] Waiting for response to finish. timeout={self.response_timeout}s")
if stream:
response = web.StreamResponse(
@@ -367,62 +372,74 @@ class ChatGPTAPI:
)
await response.prepare(request)
async def stream_result(_request_id: str, tokens: List[int], is_finished: bool):
prev_last_tokens_len = self.prev_token_lens.get(_request_id, 0)
self.prev_token_lens[_request_id] = max(prev_last_tokens_len, len(tokens))
new_tokens = tokens[prev_last_tokens_len:]
finish_reason = None
eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if hasattr(tokenizer, "_tokenizer") and isinstance(tokenizer._tokenizer,
AutoTokenizer) else getattr(tokenizer, "eos_token_id", None)
if len(new_tokens) > 0 and new_tokens[-1] == eos_token_id:
new_tokens = new_tokens[:-1]
if is_finished:
finish_reason = "stop"
if is_finished and not finish_reason:
finish_reason = "length"
try:
# Stream tokens while waiting for inference to complete
while True:
if DEBUG >= 2: print(f"[ChatGPTAPI] Waiting for token from queue: {request_id=}")
tokens, is_finished = await asyncio.wait_for(
self.token_queues[request_id].get(),
timeout=self.response_timeout
)
if DEBUG >= 2: print(f"[ChatGPTAPI] Got token from queue: {request_id=} {tokens=} {is_finished=}")
eos_token_id = None
if not eos_token_id and hasattr(tokenizer, "eos_token_id"): eos_token_id = tokenizer.eos_token_id
if not eos_token_id and hasattr(tokenizer, "_tokenizer"): eos_token_id = tokenizer.special_tokens_map.get("eos_token_id")
finish_reason = None
if is_finished: finish_reason = "stop" if tokens[-1] == eos_token_id else "length"
if DEBUG >= 2: print(f"{eos_token_id=} {tokens[-1]=} {finish_reason=}")
completion = generate_completion(
chat_request,
tokenizer,
prompt,
request_id,
tokens,
stream,
finish_reason,
"chat.completion",
)
completion = generate_completion(
chat_request,
tokenizer,
prompt,
request_id,
new_tokens,
stream,
finish_reason,
"chat.completion",
)
if DEBUG >= 2: print(f"Streaming completion: {completion}")
try:
await response.write(f"data: {json.dumps(completion)}\n\n".encode())
except Exception as e:
if DEBUG >= 2: print(f"Error streaming completion: {e}")
if DEBUG >= 2: traceback.print_exc()
def on_result(_request_id: str, tokens: List[int], is_finished: bool):
if _request_id == request_id: self.stream_tasks[_request_id] = asyncio.create_task(stream_result(_request_id, tokens, is_finished))
if is_finished:
break
return _request_id == request_id and is_finished
await response.write_eof()
return response
_, tokens, _ = await callback.wait(on_result, timeout=self.response_timeout)
if request_id in self.stream_tasks: # in case there is still a stream task running, wait for it to complete
if DEBUG >= 2: print("Pending stream task. Waiting for stream task to complete.")
try:
await asyncio.wait_for(self.stream_tasks[request_id], timeout=30)
except asyncio.TimeoutError:
print("WARNING: Stream task timed out. This should not happen.")
await response.write_eof()
return response
except asyncio.TimeoutError:
if DEBUG >= 2: print(f"[ChatGPTAPI] Timeout waiting for token: {request_id=}")
return web.json_response({"detail": "Response generation timed out"}, status=408)
except Exception as e:
if DEBUG >= 2:
print(f"[ChatGPTAPI] Error processing prompt: {e}")
traceback.print_exc()
return web.json_response(
{"detail": f"Error processing prompt: {str(e)}"},
status=500
)
finally:
# Clean up the queue for this request
if request_id in self.token_queues:
if DEBUG >= 2: print(f"[ChatGPTAPI] Cleaning up token queue: {request_id=}")
del self.token_queues[request_id]
else:
_, tokens, _ = await callback.wait(
lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished,
timeout=self.response_timeout,
)
tokens = []
while True:
_tokens, is_finished = await asyncio.wait_for(self.token_queues[request_id].get(), timeout=self.response_timeout)
tokens.extend(_tokens)
if is_finished:
break
finish_reason = "length"
eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(getattr(tokenizer, "_tokenizer", None), AutoTokenizer) else tokenizer.eos_token_id
eos_token_id = None
if not eos_token_id and hasattr(tokenizer, "eos_token_id"): eos_token_id = tokenizer.eos_token_id
if not eos_token_id and hasattr(tokenizer, "_tokenizer"): eos_token_id = tokenizer.special_tokens_map.get("eos_token_id")
if DEBUG >= 2: print(f"Checking if end of tokens result {tokens[-1]=} is {eos_token_id=}")
if tokens[-1] == eos_token_id:
tokens = tokens[:-1]
finish_reason = "stop"
return web.json_response(generate_completion(chat_request, tokenizer, prompt, request_id, tokens, stream, finish_reason, "chat.completion"))
@@ -431,73 +448,119 @@ class ChatGPTAPI:
except Exception as e:
if DEBUG >= 2: traceback.print_exc()
return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
finally:
deregistered_callback = self.node.on_token.deregister(callback_id)
if DEBUG >= 2: print(f"Deregister {callback_id=} {deregistered_callback=}")
async def handle_delete_model(self, request):
async def handle_post_image_generations(self, request):
data = await request.json()
if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}")
stream = data.get("stream", False)
model = data.get("model", "")
prompt = data.get("prompt", "")
image_url = data.get("image_url", "")
if DEBUG >= 2: print(f"model: {model}, prompt: {prompt}, stream: {stream}")
shard = build_base_shard(model, self.inference_engine_classname)
if DEBUG >= 2: print(f"shard: {shard}")
if not shard:
return web.json_response({"error": f"Unsupported model: {model} with inference engine {self.inference_engine_classname}"}, status=400)
request_id = str(uuid.uuid4())
callback_id = f"chatgpt-api-wait-response-{request_id}"
callback = self.node.on_token.register(callback_id)
try:
model_name = request.match_info.get('model_name')
if DEBUG >= 2: print(f"Attempting to delete model: {model_name}")
if not model_name or model_name not in model_cards:
return web.json_response(
{"detail": f"Invalid model name: {model_name}"},
status=400
)
shard = build_base_shard(model_name, self.inference_engine_classname)
if not shard:
return web.json_response(
{"detail": "Could not build shard for model"},
status=400
)
repo_id = get_repo(shard.model_id, self.inference_engine_classname)
if DEBUG >= 2: print(f"Repo ID for model: {repo_id}")
# Get the HF cache directory using the helper function
hf_home = get_hf_home()
cache_dir = get_repo_root(repo_id)
if DEBUG >= 2: print(f"Looking for model files in: {cache_dir}")
if os.path.exists(cache_dir):
if DEBUG >= 2: print(f"Found model files at {cache_dir}, deleting...")
try:
shutil.rmtree(cache_dir)
return web.json_response({
"status": "success",
"message": f"Model {model_name} deleted successfully",
"path": str(cache_dir)
})
except Exception as e:
return web.json_response({
"detail": f"Failed to delete model files: {str(e)}"
}, status=500)
if image_url != "" and image_url != None:
img = self.base64_decode(image_url)
else:
return web.json_response({
"detail": f"Model files not found at {cache_dir}"
}, status=404)
img = None
await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, request_id=request_id, inference_state={"image": img}))), timeout=self.response_timeout)
response = web.StreamResponse(status=200, reason='OK', headers={
'Content-Type': 'application/octet-stream',
"Cache-Control": "no-cache",
})
await response.prepare(request)
def get_progress_bar(current_step, total_steps, bar_length=50):
# Calculate the percentage of completion
percent = float(current_step)/total_steps
# Calculate the number of hashes to display
arrow = '-'*int(round(percent*bar_length) - 1) + '>'
spaces = ' '*(bar_length - len(arrow))
# Create the progress bar string
progress_bar = f'Progress: [{arrow}{spaces}] {int(percent * 100)}% ({current_step}/{total_steps})'
return progress_bar
async def stream_image(_request_id: str, result, is_finished: bool):
if isinstance(result, list):
await response.write(json.dumps({'progress': get_progress_bar((result[0]), (result[1]))}).encode('utf-8') + b'\n')
elif isinstance(result, np.ndarray):
try:
im = Image.fromarray(np.array(result))
# Save the image to a file
image_filename = f"{_request_id}.png"
image_path = self.images_dir/image_filename
im.save(image_path)
# Get URL for the saved image
try:
image_url = request.app.router['static_images'].url_for(filename=image_filename)
base_url = f"{request.scheme}://{request.host}"
full_image_url = base_url + str(image_url)
await response.write(json.dumps({'images': [{'url': str(full_image_url), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n')
except KeyError as e:
if DEBUG >= 2: print(f"Error getting image URL: {e}")
# Fallback to direct file path if URL generation fails
await response.write(json.dumps({'images': [{'url': str(image_path), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n')
if is_finished:
await response.write_eof()
except Exception as e:
if DEBUG >= 2: print(f"Error processing image: {e}")
if DEBUG >= 2: traceback.print_exc()
await response.write(json.dumps({'error': str(e)}).encode('utf-8') + b'\n')
stream_task = None
def on_result(_request_id: str, result, is_finished: bool):
nonlocal stream_task
stream_task = asyncio.create_task(stream_image(_request_id, result, is_finished))
return _request_id == request_id and is_finished
await callback.wait(on_result, timeout=self.response_timeout*10)
if stream_task:
# Wait for the stream task to complete before returning
await stream_task
return response
except Exception as e:
print(f"Error in handle_delete_model: {str(e)}")
traceback.print_exc()
return web.json_response({
"detail": f"Server error: {str(e)}"
}, status=500)
if DEBUG >= 2: traceback.print_exc()
return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
async def handle_delete_model(self, request):
model_id = request.match_info.get('model_name')
try:
if await delete_model(model_id, self.inference_engine_classname): return web.json_response({"status": "success", "message": f"Model {model_id} deleted successfully"})
else: return web.json_response({"detail": f"Model {model_id} files not found"}, status=404)
except Exception as e:
if DEBUG >= 2: traceback.print_exc()
return web.json_response({"detail": f"Error deleting model: {str(e)}"}, status=500)
async def handle_get_initial_models(self, request):
model_data = {}
for model_name, pretty in pretty_name.items():
model_data[model_name] = {
"name": pretty,
"downloaded": None, # Initially unknown
"download_percentage": None, # Change from 0 to null
"total_size": None,
"total_downloaded": None,
"loading": True # Add loading state
}
for model_id in get_supported_models([[self.inference_engine_classname]]):
model_data[model_id] = {
"name": get_pretty_name(model_id),
"downloaded": None, # Initially unknown
"download_percentage": None, # Change from 0 to null
"total_size": None,
"total_downloaded": None,
"loading": True # Add loading state
}
return web.json_response(model_data)
async def handle_create_animation(self, request):
@@ -523,17 +586,9 @@ class ChatGPTAPI:
if DEBUG >= 2: print(f"Animation temp directory: {tmp_dir}, output file: {output_path}, directory exists: {tmp_dir.exists()}, directory permissions: {oct(tmp_dir.stat().st_mode)[-3:]}")
# Create the animation
create_animation_mp4(
replacement_image_path,
output_path,
device_name,
prompt_text
)
create_animation_mp4(replacement_image_path, output_path, device_name, prompt_text)
return web.json_response({
"status": "success",
"output_path": output_path
})
return web.json_response({"status": "success", "output_path": output_path})
except Exception as e:
if DEBUG >= 2: traceback.print_exc()
@@ -545,14 +600,11 @@ class ChatGPTAPI:
model_name = data.get("model")
if not model_name: return web.json_response({"error": "model parameter is required"}, status=400)
if model_name not in model_cards: return web.json_response({"error": f"Invalid model: {model_name}. Supported models: {list(model_cards.keys())}"}, status=400)
shard = build_base_shard(model_name, self.inference_engine_classname)
shard = build_full_shard(model_name, self.inference_engine_classname)
if not shard: return web.json_response({"error": f"Could not build shard for model {model_name}"}, status=400)
asyncio.create_task(self.node.inference_engine.ensure_shard(shard))
asyncio.create_task(self.node.inference_engine.shard_downloader.ensure_shard(shard, self.inference_engine_classname))
return web.json_response({
"status": "success",
"message": f"Download started for model: {model_name}"
})
return web.json_response({"status": "success", "message": f"Download started for model: {model_name}"})
except Exception as e:
if DEBUG >= 2: traceback.print_exc()
return web.json_response({"error": str(e)}, status=500)
@@ -566,13 +618,28 @@ class ChatGPTAPI:
return web.json_response({})
except Exception as e:
if DEBUG >= 2: traceback.print_exc()
return web.json_response(
{"detail": f"Error getting topology: {str(e)}"},
status=500
)
return web.json_response({"detail": f"Error getting topology: {str(e)}"}, status=500)
async def handle_tokens(self, request_id: str, tokens: List[int], is_finished: bool):
await self.token_queues[request_id].put((tokens, is_finished))
async def run(self, host: str = "0.0.0.0", port: int = 52415):
runner = web.AppRunner(self.app)
await runner.setup()
site = web.TCPSite(runner, host, port)
await site.start()
def base64_decode(self, base64_string):
# Decode and reshape the image
if base64_string.startswith('data:image'):
base64_string = base64_string.split(',')[1]
image_data = base64.b64decode(base64_string)
img = Image.open(BytesIO(image_data))
W, H = (dim - dim%64 for dim in (img.width, img.height))
if W != img.width or H != img.height:
if DEBUG >= 2: print(f"Warning: image shape is not divisible by 64, downsampling to {W}x{H}")
img = img.resize((W, H), Image.NEAREST)  # NEAREST keeps the downsample cheap; substitute another PIL filter if quality matters
img = mx.array(np.array(img))
img = (img[:, :, :3].astype(mx.float32)/255)*2 - 1
img = img[None]
return img
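# Illustrative note (not in the original source): a hypothetical 1023x769 input is downsampled
# to 960x768 so both sides are divisible by 64, pixel values are rescaled from [0, 255] to
# [-1, 1], and a leading batch axis is added, giving an mx.array of shape (1, 768, 960, 3).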

View File

@@ -2,6 +2,7 @@ from PIL import Image, ImageDraw, ImageFont, ImageFilter
import os
import numpy as np
import cv2
import sys
def draw_rounded_rectangle(draw, coords, radius, fill):
left, top, right, bottom = coords
@@ -80,14 +81,20 @@ def create_animation_mp4(
font = ImageFont.load_default()
promptfont = ImageFont.load_default()
# Get the base directory for images when running as a bundled app
if hasattr(sys, '_MEIPASS'):
base_dir = os.path.join(sys._MEIPASS, "exo", "apputil", "baseimages")
else:
base_dir = os.path.join(os.path.dirname(__file__), "baseimages")
# Process first frame
base_img = Image.open(os.path.join(os.path.dirname(__file__), "baseimages", "image1.png"))
base_img = Image.open(os.path.join(base_dir, "image1.png"))
draw = ImageDraw.Draw(base_img)
draw_centered_text_rounded(draw, device_name, font, device_coords)
frames.extend([crop_image(base_img)] * 30) # 1 second at 30fps
# Process second frame with typing animation
base_img2 = Image.open(os.path.join(os.path.dirname(__file__), "baseimages", "image2.png"))
base_img2 = Image.open(os.path.join(base_dir, "image2.png"))
for i in range(len(prompt_text) + 1):
current_frame = base_img2.copy()
draw = ImageDraw.Draw(current_frame)
@@ -101,7 +108,7 @@ def create_animation_mp4(
# Create blur sequence
replacement_img = Image.open(replacement_image_path)
base_img = Image.open(os.path.join(os.path.dirname(__file__), "baseimages", "image3.png"))
base_img = Image.open(os.path.join(base_dir, "image3.png"))
blur_steps = [int(80 * (1 - i/8)) for i in range(9)]
for i, blur_amount in enumerate(blur_steps):
@@ -123,7 +130,7 @@ def create_animation_mp4(
frames.extend([crop_image(new_frame)] * 15) # 0.5 seconds at 30fps
# Create and add final frame (image4)
final_base = Image.open(os.path.join(os.path.dirname(__file__), "baseimages", "image4.png"))
final_base = Image.open(os.path.join(base_dir, "image4.png"))
draw = ImageDraw.Draw(final_base)
draw_centered_text_rounded(draw, device_name, font, device_coords)
@@ -158,4 +165,4 @@ def create_animation_mp4(
out.write(frame_array)
out.release()
print(f"Video saved successfully to {output_path}")
print(f"Video saved successfully to {output_path}")

View File

@@ -1,4 +1,5 @@
from typing import Dict, Callable, Coroutine, Any, Literal
from exo.inference.shard import Shard
from dataclasses import dataclass
from datetime import timedelta
@@ -14,11 +15,12 @@ class RepoFileProgressEvent:
speed: int
eta: timedelta
status: Literal["not_started", "in_progress", "complete"]
start_time: float
def to_dict(self):
return {
"repo_id": self.repo_id, "repo_revision": self.repo_revision, "file_path": self.file_path, "downloaded": self.downloaded, "downloaded_this_session": self.downloaded_this_session,
"total": self.total, "speed": self.speed, "eta": self.eta.total_seconds(), "status": self.status
"total": self.total, "speed": self.speed, "eta": self.eta.total_seconds(), "status": self.status, "start_time": self.start_time
}
@classmethod
@@ -29,6 +31,7 @@ class RepoFileProgressEvent:
@dataclass
class RepoProgressEvent:
shard: Shard
repo_id: str
repo_revision: str
completed_files: int
@@ -43,7 +46,7 @@ class RepoProgressEvent:
def to_dict(self):
return {
"repo_id": self.repo_id, "repo_revision": self.repo_revision, "completed_files": self.completed_files, "total_files": self.total_files, "downloaded_bytes": self.downloaded_bytes,
"shard": self.shard.to_dict(), "repo_id": self.repo_id, "repo_revision": self.repo_revision, "completed_files": self.completed_files, "total_files": self.total_files, "downloaded_bytes": self.downloaded_bytes,
"downloaded_bytes_this_session": self.downloaded_bytes_this_session, "total_bytes": self.total_bytes, "overall_speed": self.overall_speed, "overall_eta": self.overall_eta.total_seconds(),
"file_progress": {k: v.to_dict()
for k, v in self.file_progress.items()}, "status": self.status
@@ -53,6 +56,7 @@ class RepoProgressEvent:
def from_dict(cls, data):
if 'overall_eta' in data: data['overall_eta'] = timedelta(seconds=data['overall_eta'])
if 'file_progress' in data: data['file_progress'] = {k: RepoFileProgressEvent.from_dict(v) for k, v in data['file_progress'].items()}
if 'shard' in data: data['shard'] = Shard.from_dict(data['shard'])
return cls(**data)

View File

@@ -1,36 +1,16 @@
import aiofiles.os as aios
from typing import Union
import asyncio
import aiohttp
import json
import os
import sys
import shutil
from urllib.parse import urljoin
from typing import Callable, Optional, Coroutine, Any, Dict, List, Union, Literal
from datetime import datetime, timedelta
from typing import Callable, Optional, Dict, List, Union
from fnmatch import fnmatch
from pathlib import Path
from typing import Generator, Iterable, TypeVar, TypedDict
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
from exo.helpers import DEBUG, is_frozen
from exo.download.download_progress import RepoProgressEvent, RepoFileProgressEvent, RepoProgressCallback, RepoFileProgressCallback
from typing import Generator, Iterable, TypeVar
from exo.helpers import DEBUG
from exo.inference.shard import Shard
import aiofiles
T = TypeVar("T")
async def get_local_snapshot_dir(repo_id: str, revision: str = "main") -> Optional[Path]:
refs_dir = get_repo_root(repo_id)/"refs"
refs_file = refs_dir/revision
if await aios.path.exists(refs_file):
async with aiofiles.open(refs_file, 'r') as f:
commit_hash = (await f.read()).strip()
snapshot_dir = get_repo_root(repo_id)/"snapshots"/commit_hash
return snapshot_dir
return None
def filter_repo_objects(
items: Iterable[T],
*,
@@ -48,14 +28,12 @@ def filter_repo_objects(
ignore_patterns = [_add_wildcard_to_directories(p) for p in ignore_patterns]
if key is None:
def _identity(item: T) -> str:
if isinstance(item, str):
return item
if isinstance(item, Path):
return str(item)
raise ValueError(f"Please provide `key` argument in `filter_repo_objects`: `{item}` is not a string.")
key = _identity
for item in items:
@@ -66,22 +44,18 @@ def filter_repo_objects(
continue
yield item
def _add_wildcard_to_directories(pattern: str) -> str:
if pattern[-1] == "/":
return pattern + "*"
return pattern
def get_hf_endpoint() -> str:
return os.environ.get('HF_ENDPOINT', "https://huggingface.co")
def get_hf_home() -> Path:
"""Get the Hugging Face home directory."""
return Path(os.environ.get("HF_HOME", Path.home()/".cache"/"huggingface"))
async def get_hf_token():
"""Retrieve the Hugging Face token from the user's HF_HOME directory."""
token_path = get_hf_home()/"token"
@@ -90,7 +64,6 @@ async def get_hf_token():
return (await f.read()).strip()
return None
async def get_auth_headers():
"""Get authentication headers if a token is available."""
token = await get_hf_token()
@@ -98,321 +71,6 @@ async def get_auth_headers():
return {"Authorization": f"Bearer {token}"}
return {}
def get_repo_root(repo_id: str) -> Path:
"""Get the root directory for a given repo ID in the Hugging Face cache."""
sanitized_repo_id = str(repo_id).replace("/", "--")
return get_hf_home()/"hub"/f"models--{sanitized_repo_id}"
async def move_models_to_hf(seed_dir: Union[str, Path]):
"""Move model in resources folder of app to .cache/huggingface/hub"""
source_dir = Path(seed_dir)
dest_dir = get_hf_home()/"hub"
await aios.makedirs(dest_dir, exist_ok=True)
for path in source_dir.iterdir():
if path.is_dir() and path.name.startswith("models--"):
dest_path = dest_dir / path.name
if await aios.path.exists(dest_path):
print('Skipping moving model to .cache directory')
else:
try:
await aios.rename(str(path), str(dest_path))
except Exception as e:
print(f'Error moving model to .cache: {e}')
async def fetch_file_list(session, repo_id, revision, path=""):
api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
url = f"{api_url}/{path}" if path else api_url
headers = await get_auth_headers()
async with session.get(url, headers=headers) as response:
if response.status == 200:
data = await response.json()
files = []
for item in data:
if item["type"] == "file":
files.append({"path": item["path"], "size": item["size"]})
elif item["type"] == "directory":
subfiles = await fetch_file_list(session, repo_id, revision, item["path"])
files.extend(subfiles)
return files
else:
raise Exception(f"Failed to fetch file list: {response.status}")
@retry(
stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=60), retry=retry_if_exception_type((aiohttp.ClientError, asyncio.TimeoutError, aiohttp.ClientResponseError)), reraise=True
)
async def download_file(
session: aiohttp.ClientSession, repo_id: str, revision: str, file_path: str, save_directory: str, progress_callback: Optional[RepoFileProgressCallback] = None, use_range_request: bool = True
):
base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
url = urljoin(base_url, file_path)
local_path = os.path.join(save_directory, file_path)
await aios.makedirs(os.path.dirname(local_path), exist_ok=True)
# Check if file already exists and get its size
local_file_size = await aios.path.getsize(local_path) if await aios.path.exists(local_path) else 0
headers = await get_auth_headers()
if use_range_request:
headers["Range"] = f"bytes={local_file_size}-"
async with session.get(url, headers=headers) as response:
total_size = int(response.headers.get('Content-Length', 0))
downloaded_size = local_file_size
downloaded_this_session = 0
mode = 'ab' if use_range_request else 'wb'
percentage = await get_file_download_percentage(
session,
repo_id,
revision,
file_path,
Path(save_directory)
)
if percentage == 100:
if DEBUG >= 2: print(f"File already downloaded: {file_path}")
if progress_callback:
await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, total_size, 0, total_size, 0, timedelta(0), "complete"))
return
if response.status == 200:
# File doesn't support range requests or we're not using them, start from beginning
mode = 'wb'
downloaded_size = 0
elif response.status == 206:
# Partial content, resume download
content_range = response.headers.get('Content-Range', '')
try:
total_size = int(content_range.split('/')[-1])
except ValueError:
if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
elif response.status == 416:
# Range not satisfiable, get the actual file size
content_range = response.headers.get('Content-Range', '')
try:
total_size = int(content_range.split('/')[-1])
if downloaded_size == total_size:
if DEBUG >= 2: print(f"File fully downloaded on first pass: {file_path}")
if progress_callback:
await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
return
except ValueError:
if DEBUG >= 1: print(f"Failed to parse Content-Range header: {content_range}. Starting download from scratch...")
return await download_file(session, repo_id, revision, file_path, save_directory, progress_callback, use_range_request=False)
else:
raise aiohttp.ClientResponseError(response.request_info, response.history, status=response.status, message=f"Failed to download {file_path}: {response.status}")
if downloaded_size == total_size:
print(f"File already downloaded: {file_path}")
if progress_callback:
await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, 0, timedelta(0), "complete"))
return
DOWNLOAD_CHUNK_SIZE = 32768
start_time = datetime.now()
async with aiofiles.open(local_path, mode) as f:
async for chunk in response.content.iter_chunked(DOWNLOAD_CHUNK_SIZE):
await f.write(chunk)
downloaded_size += len(chunk)
downloaded_this_session += len(chunk)
if progress_callback and total_size:
elapsed_time = (datetime.now() - start_time).total_seconds()
speed = int(downloaded_this_session/elapsed_time) if elapsed_time > 0 else 0
remaining_size = total_size - downloaded_size
eta = timedelta(seconds=remaining_size/speed) if speed > 0 else timedelta(0)
status = "in_progress" if downloaded_size < total_size else "complete"
if DEBUG >= 8: print(f"HF repo file download progress: {file_path=} {elapsed_time=} {speed=} Downloaded={downloaded_size}/{total_size} {remaining_size=} {eta=} {status=}")
await progress_callback(RepoFileProgressEvent(repo_id, revision, file_path, downloaded_size, downloaded_this_session, total_size, speed, eta, status))
if DEBUG >= 2: print(f"Downloaded: {file_path}")
async def resolve_revision_to_commit_hash(repo_id: str, revision: str) -> str:
repo_root = get_repo_root(repo_id)
refs_dir = repo_root/"refs"
refs_file = refs_dir/revision
# Check if we have a cached commit hash
if await aios.path.exists(refs_file):
async with aiofiles.open(refs_file, 'r') as f:
commit_hash = (await f.read()).strip()
if DEBUG >= 2: print(f"Commit hash is already cached at {refs_file}: {commit_hash}")
return commit_hash
# Fetch the commit hash for the given revision
async with aiohttp.ClientSession() as session:
api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/revision/{revision}"
headers = await get_auth_headers()
async with session.get(api_url, headers=headers) as response:
if response.status != 200:
raise Exception(f"Failed to fetch revision info from {api_url}: {response.status}")
revision_info = await response.json()
commit_hash = revision_info['sha']
# Cache the commit hash
await aios.makedirs(refs_dir, exist_ok=True)
async with aiofiles.open(refs_file, 'w') as f:
await f.write(commit_hash)
return commit_hash
async def download_repo_files(
repo_id: str,
revision: str = "main",
progress_callback: Optional[RepoProgressCallback] = None,
allow_patterns: Optional[Union[List[str], str]] = None,
ignore_patterns: Optional[Union[List[str], str]] = None,
max_parallel_downloads: int = 4
) -> Path:
repo_root = get_repo_root(repo_id)
snapshots_dir = repo_root/"snapshots"
cachedreqs_dir = repo_root/"cachedreqs"
# Ensure directories exist
await aios.makedirs(snapshots_dir, exist_ok=True)
await aios.makedirs(cachedreqs_dir, exist_ok=True)
# Resolve revision to commit hash
commit_hash = await resolve_revision_to_commit_hash(repo_id, revision)
# Set up the snapshot directory
snapshot_dir = snapshots_dir/commit_hash
await aios.makedirs(snapshot_dir, exist_ok=True)
# Set up the cached file list directory
cached_file_list_dir = cachedreqs_dir/commit_hash
await aios.makedirs(cached_file_list_dir, exist_ok=True)
cached_file_list_path = cached_file_list_dir/"fetch_file_list.json"
async with aiohttp.ClientSession() as session:
# Check if we have a cached file list
if await aios.path.exists(cached_file_list_path):
async with aiofiles.open(cached_file_list_path, 'r') as f:
file_list = json.loads(await f.read())
if DEBUG >= 2: print(f"Using cached file list from {cached_file_list_path}")
else:
file_list = await fetch_file_list(session, repo_id, revision)
# Cache the file list
async with aiofiles.open(cached_file_list_path, 'w') as f:
await f.write(json.dumps(file_list))
if DEBUG >= 2: print(f"Cached file list at {cached_file_list_path}")
filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, key=lambda x: x["path"]))
total_files = len(filtered_file_list)
total_bytes = sum(file["size"] for file in filtered_file_list)
file_progress: Dict[str, RepoFileProgressEvent] = {
file["path"]: RepoFileProgressEvent(repo_id, revision, file["path"], 0, 0, file["size"], 0, timedelta(0), "not_started")
for file in filtered_file_list
}
start_time = datetime.now()
async def download_with_progress(file_info, progress_state):
local_path = snapshot_dir/file_info["path"]
if await aios.path.exists(local_path) and (await aios.stat(local_path)).st_size == file_info["size"]:
if DEBUG >= 2: print(f"File already fully downloaded: {file_info['path']}")
progress_state['completed_files'] += 1
progress_state['downloaded_bytes'] += file_info["size"]
file_progress[file_info["path"]] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], 0, file_info["size"], 0, timedelta(0), "complete")
if progress_callback:
elapsed_time = (datetime.now() - start_time).total_seconds()
overall_speed = int(progress_state['downloaded_bytes_this_session']/elapsed_time) if elapsed_time > 0 else 0
remaining_bytes = total_bytes - progress_state['downloaded_bytes']
overall_eta = timedelta(seconds=remaining_bytes/overall_speed) if overall_speed > 0 else timedelta(seconds=0)
status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
await progress_callback(
RepoProgressEvent(
repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed,
overall_eta, file_progress, status
)
)
return
async def file_progress_callback(event: RepoFileProgressEvent):
progress_state['downloaded_bytes'] += event.downloaded - file_progress[event.file_path].downloaded
progress_state['downloaded_bytes_this_session'] += event.downloaded_this_session - file_progress[event.file_path].downloaded_this_session
file_progress[event.file_path] = event
if progress_callback:
elapsed_time = (datetime.now() - start_time).total_seconds()
overall_speed = int(progress_state['downloaded_bytes_this_session']/elapsed_time) if elapsed_time > 0 else 0
remaining_bytes = total_bytes - progress_state['downloaded_bytes']
overall_eta = timedelta(seconds=remaining_bytes/overall_speed) if overall_speed > 0 else timedelta(seconds=0)
status = "in_progress" if progress_state['downloaded_bytes'] < total_bytes else "complete"
await progress_callback(
RepoProgressEvent(
repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed,
overall_eta, file_progress, status
)
)
await download_file(session, repo_id, revision, file_info["path"], snapshot_dir, file_progress_callback)
progress_state['completed_files'] += 1
file_progress[
file_info["path"]
] = RepoFileProgressEvent(repo_id, revision, file_info["path"], file_info["size"], file_progress[file_info["path"]].downloaded_this_session, file_info["size"], 0, timedelta(0), "complete")
if progress_callback:
elapsed_time = (datetime.now() - start_time).total_seconds()
overall_speed = int(progress_state['downloaded_bytes_this_session']/elapsed_time) if elapsed_time > 0 else 0
remaining_bytes = total_bytes - progress_state['downloaded_bytes']
overall_eta = timedelta(seconds=remaining_bytes/overall_speed) if overall_speed > 0 else timedelta(seconds=0)
status = "in_progress" if progress_state['completed_files'] < total_files else "complete"
await progress_callback(
RepoProgressEvent(
repo_id, revision, progress_state['completed_files'], total_files, progress_state['downloaded_bytes'], progress_state['downloaded_bytes_this_session'], total_bytes, overall_speed,
overall_eta, file_progress, status
)
)
progress_state = {'completed_files': 0, 'downloaded_bytes': 0, 'downloaded_bytes_this_session': 0}
semaphore = asyncio.Semaphore(max_parallel_downloads)
async def download_with_semaphore(file_info):
async with semaphore:
await download_with_progress(file_info, progress_state)
tasks = [asyncio.create_task(download_with_semaphore(file_info)) for file_info in filtered_file_list]
await asyncio.gather(*tasks)
return snapshot_dir
async def get_weight_map(repo_id: str, revision: str = "main") -> Optional[Dict[str, str]]:
"""
Retrieve the weight map from the model.safetensors.index.json file.
Args:
repo_id (str): The Hugging Face repository ID.
revision (str): The revision of the repository to use.
Returns:
Optional[Dict[str, str]]: The weight map if it exists, otherwise None.
"""
# Download the index file
await download_repo_files(repo_id=repo_id, revision=revision, allow_patterns="model.safetensors.index.json")
# Check if the file exists
repo_root = get_repo_root(repo_id)
commit_hash = await resolve_revision_to_commit_hash(repo_id, revision)
snapshot_dir = repo_root/"snapshots"/commit_hash
index_file = next((f for f in await aios.listdir(snapshot_dir) if f.endswith("model.safetensors.index.json")), None)
if index_file:
index_file_path = snapshot_dir/index_file
if await aios.path.exists(index_file_path):
async with aiofiles.open(index_file_path, 'r') as f:
index_data = json.loads(await f.read())
return index_data.get("weight_map")
return None
def extract_layer_num(tensor_name: str) -> Optional[int]:
# This is a simple example and might need to be adjusted based on the actual naming convention
parts = tensor_name.split('.')
@@ -421,7 +79,6 @@ def extract_layer_num(tensor_name: str) -> Optional[int]:
return int(part)
return None
def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
default_patterns = set(["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt"])
shard_specific_patterns = set()
@@ -437,67 +94,5 @@ def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
shard_specific_patterns.add(sorted_file_names[-1])
else:
shard_specific_patterns = set(["*.safetensors"])
if DEBUG >= 2: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
if DEBUG >= 3: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
return list(default_patterns | shard_specific_patterns)
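# Illustrative note (not in the original source): for a hypothetical shard whose layers all map
# to model-00001-of-00004.safetensors and model-00002-of-00004.safetensors in the weight map,
# the returned patterns would be roughly the default set plus those two filenames, so only the
# safetensors files the shard needs are downloaded alongside the tokenizer/config files.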
async def get_file_download_percentage(
session: aiohttp.ClientSession,
repo_id: str,
revision: str,
file_path: str,
snapshot_dir: Path,
) -> float:
"""
Calculate the download percentage for a file by comparing local and remote sizes.
"""
try:
local_path = snapshot_dir / file_path
if not await aios.path.exists(local_path):
return 0
# Get local file size first
local_size = await aios.path.getsize(local_path)
if local_size == 0:
return 0
# Check remote size
base_url = f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/"
url = urljoin(base_url, file_path)
headers = await get_auth_headers()
# Use HEAD request with redirect following for all files
async with session.head(url, headers=headers, allow_redirects=True) as response:
if response.status != 200:
if DEBUG >= 2:
print(f"Failed to get remote file info for {file_path}: {response.status}")
return 0
remote_size = int(response.headers.get('Content-Length', 0))
if remote_size == 0:
if DEBUG >= 2:
print(f"Remote size is 0 for {file_path}")
return 0
# Only return 100% if sizes match exactly
if local_size == remote_size:
return 100.0
# Calculate percentage based on sizes
return (local_size / remote_size) * 100 if remote_size > 0 else 0
except Exception as e:
if DEBUG >= 2:
print(f"Error checking file download status for {file_path}: {e}")
return 0
async def has_hf_home_read_access() -> bool:
hf_home = get_hf_home()
try: return await aios.access(hf_home, os.R_OK)
except OSError: return False
async def has_hf_home_write_access() -> bool:
hf_home = get_hf_home()
try: return await aios.access(hf_home, os.W_OK)
except OSError: return False

View File

@@ -1,167 +0,0 @@
import asyncio
import traceback
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Union
from exo.inference.shard import Shard
from exo.download.shard_download import ShardDownloader
from exo.download.download_progress import RepoProgressEvent
from exo.download.hf.hf_helpers import (
download_repo_files, RepoProgressEvent, get_weight_map,
get_allow_patterns, get_repo_root, fetch_file_list,
get_local_snapshot_dir, get_file_download_percentage,
filter_repo_objects
)
from exo.helpers import AsyncCallbackSystem, DEBUG
from exo.models import model_cards, get_repo
import aiohttp
from aiofiles import os as aios
class HFShardDownloader(ShardDownloader):
def __init__(self, quick_check: bool = False, max_parallel_downloads: int = 4):
self.quick_check = quick_check
self.max_parallel_downloads = max_parallel_downloads
self.active_downloads: Dict[Shard, asyncio.Task] = {}
self.completed_downloads: Dict[Shard, Path] = {}
self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()
self.current_shard: Optional[Shard] = None
self.current_repo_id: Optional[str] = None
self.revision: str = "main"
async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
self.current_shard = shard
self.current_repo_id = get_repo(shard.model_id, inference_engine_name)
repo_name = get_repo(shard.model_id, inference_engine_name)
if shard in self.completed_downloads:
return self.completed_downloads[shard]
if self.quick_check:
repo_root = get_repo_root(repo_name)
snapshots_dir = repo_root/"snapshots"
if snapshots_dir.exists():
visible_dirs = [d for d in snapshots_dir.iterdir() if not d.name.startswith('.')]
if visible_dirs:
most_recent_dir = max(visible_dirs, key=lambda x: x.stat().st_mtime)
return most_recent_dir
# If a download on this shard is already in progress, keep that one
for active_shard in self.active_downloads:
if active_shard == shard:
if DEBUG >= 2: print(f"Download already in progress for {shard}. Keeping that one.")
return await self.active_downloads[shard]
# Cancel any downloads for this model_id on a different shard
existing_active_shards = [active_shard for active_shard in self.active_downloads.keys() if active_shard.model_id == shard.model_id]
for active_shard in existing_active_shards:
if DEBUG >= 2: print(f"Cancelling download for {active_shard} (replacing with {shard})")
task = self.active_downloads[active_shard]
task.cancel()
try:
await task
except asyncio.CancelledError:
pass # This is expected when cancelling a task
except Exception as e:
if DEBUG >= 2: print(f"Error in cancelling download {active_shard}: {e}")
traceback.print_exc()
self.active_downloads = {active_shard: task for active_shard, task in self.active_downloads.items() if active_shard.model_id != shard.model_id}
# Start new download
download_task = asyncio.create_task(self._download_shard(shard, repo_name))
self.active_downloads[shard] = download_task
try:
path = await download_task
self.completed_downloads[shard] = path
return path
finally:
# Ensure the task is removed even if an exception occurs
print(f"Removing download task for {shard}: {shard in self.active_downloads}")
if shard in self.active_downloads:
self.active_downloads.pop(shard)
async def _download_shard(self, shard: Shard, repo_name: str) -> Path:
async def wrapped_progress_callback(event: RepoProgressEvent):
self._on_progress.trigger_all(shard, event)
weight_map = await get_weight_map(repo_name)
allow_patterns = get_allow_patterns(weight_map, shard)
return await download_repo_files(repo_name, progress_callback=wrapped_progress_callback, allow_patterns=allow_patterns, max_parallel_downloads=self.max_parallel_downloads)
@property
def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
return self._on_progress
async def get_shard_download_status(self) -> Optional[Dict[str, Union[float, int]]]:
if not self.current_shard or not self.current_repo_id:
if DEBUG >= 2:
print(f"No current shard or repo_id set: {self.current_shard=} {self.current_repo_id=}")
return None
try:
# If no snapshot directory exists, return None - no need to check remote files
snapshot_dir = await get_local_snapshot_dir(self.current_repo_id, self.revision)
if not snapshot_dir:
if DEBUG >= 2:
print(f"No snapshot directory found for {self.current_repo_id}")
return None
# Get the weight map to know what files we need
weight_map = await get_weight_map(self.current_repo_id, self.revision)
if not weight_map:
if DEBUG >= 2:
print(f"No weight map found for {self.current_repo_id}")
return None
# Get all files needed for this shard
patterns = get_allow_patterns(weight_map, self.current_shard)
# Check download status for all relevant files
status = {}
total_bytes = 0
downloaded_bytes = 0
async with aiohttp.ClientSession() as session:
file_list = await fetch_file_list(session, self.current_repo_id, self.revision)
relevant_files = list(
filter_repo_objects(
file_list, allow_patterns=patterns, key=lambda x: x["path"]))
for file in relevant_files:
file_size = file["size"]
total_bytes += file_size
percentage = await get_file_download_percentage(
session,
self.current_repo_id,
self.revision,
file["path"],
snapshot_dir,
)
status[file["path"]] = percentage
downloaded_bytes += (file_size * (percentage / 100))
# Add overall progress weighted by file size
if total_bytes > 0:
status["overall"] = (downloaded_bytes / total_bytes) * 100
else:
status["overall"] = 0
# Add total size in bytes
status["total_size"] = total_bytes
if status["overall"] != 100:
status["total_downloaded"] = downloaded_bytes
if DEBUG >= 2:
print(f"Download calculation for {self.current_repo_id}:")
print(f"Total bytes: {total_bytes}")
print(f"Downloaded bytes: {downloaded_bytes}")
for file in relevant_files:
print(f"File {file['path']}: size={file['size']}, percentage={status[file['path']]}")
return status
except Exception as e:
if DEBUG >= 2:
print(f"Error getting shard download status: {e}")
traceback.print_exc()
return None

View File

@@ -0,0 +1,307 @@
from exo.inference.shard import Shard
from exo.models import get_repo
from pathlib import Path
from exo.download.hf.hf_helpers import get_hf_endpoint, get_auth_headers, filter_repo_objects, get_allow_patterns
from exo.download.shard_download import ShardDownloader
from exo.download.download_progress import RepoProgressEvent, RepoFileProgressEvent
from exo.helpers import AsyncCallbackSystem, DEBUG
from exo.models import get_supported_models, build_full_shard
import os
import aiofiles.os as aios
import aiohttp
import aiofiles
from urllib.parse import urljoin
from typing import Callable, Union, Tuple, Dict, List, Optional, Literal, AsyncIterator
import time
from datetime import timedelta
import asyncio
import json
import traceback
import shutil
import tempfile
import hashlib
def exo_home() -> Path:
return Path(os.environ.get("EXO_HOME", Path.home()/".cache"/"exo"))
def exo_tmp() -> Path:
return Path(tempfile.gettempdir())/"exo"
async def ensure_exo_home() -> Path:
await aios.makedirs(exo_home(), exist_ok=True)
return exo_home()
async def ensure_exo_tmp() -> Path:
await aios.makedirs(exo_tmp(), exist_ok=True)
return exo_tmp()
async def has_exo_home_read_access() -> bool:
try: return await aios.access(exo_home(), os.R_OK)
except OSError: return False
async def has_exo_home_write_access() -> bool:
try: return await aios.access(exo_home(), os.W_OK)
except OSError: return False
async def ensure_downloads_dir() -> Path:
downloads_dir = exo_home()/"downloads"
await aios.makedirs(downloads_dir, exist_ok=True)
return downloads_dir
async def delete_model(model_id: str, inference_engine_name: str) -> bool:
repo_id = get_repo(model_id, inference_engine_name)
model_dir = await ensure_downloads_dir()/repo_id.replace("/", "--")
if not await aios.path.exists(model_dir): return False
await asyncio.to_thread(shutil.rmtree, model_dir, ignore_errors=False)
return True
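# Illustrative note (not in the original source): with the default EXO_HOME, a hypothetical
# repo "someorg/some-model" lives under ~/.cache/exo/downloads/someorg--some-model, and
# delete_model removes that directory tree (returning False if it was never downloaded).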
async def seed_models(seed_dir: Union[str, Path]):
"""Move model in resources folder of app to .cache/huggingface/hub"""
source_dir = Path(seed_dir)
dest_dir = await ensure_downloads_dir()
for path in source_dir.iterdir():
if path.is_dir() and path.name.startswith("models--"):
dest_path = dest_dir/path.name
if await aios.path.exists(dest_path): print('Skipping moving model to .cache directory')
else:
try: await aios.rename(str(path), str(dest_path))
except:
print(f"Error seeding model {path} to {dest_path}")
traceback.print_exc()
async def fetch_file_list_with_cache(repo_id: str, revision: str = "main") -> List[Dict[str, Union[str, int]]]:
cache_file = (await ensure_exo_tmp())/f"{repo_id.replace('/', '--')}--{revision}--file_list.json"
if await aios.path.exists(cache_file):
async with aiofiles.open(cache_file, 'r') as f: return json.loads(await f.read())
file_list = await fetch_file_list_with_retry(repo_id, revision)
await aios.makedirs(cache_file.parent, exist_ok=True)
async with aiofiles.open(cache_file, 'w') as f: await f.write(json.dumps(file_list))
return file_list
async def fetch_file_list_with_retry(repo_id: str, revision: str = "main", path: str = "") -> List[Dict[str, Union[str, int]]]:
n_attempts = 30
for attempt in range(n_attempts):
try: return await _fetch_file_list(repo_id, revision, path)
except Exception as e:
if attempt == n_attempts - 1: raise e
await asyncio.sleep(min(8, 0.1 * (2 ** attempt)))
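# Illustrative note (not in the original source): the backoff above sleeps
# min(8, 0.1 * 2**attempt) seconds between attempts, i.e. roughly
# 0.1, 0.2, 0.4, 0.8, 1.6, 3.2, 6.4, then a flat 8s for the remaining retries,
# for up to n_attempts = 30 tries before the last exception is re-raised.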
async def _fetch_file_list(repo_id: str, revision: str = "main", path: str = "") -> List[Dict[str, Union[str, int]]]:
api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
url = f"{api_url}/{path}" if path else api_url
headers = await get_auth_headers()
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=30, connect=10, sock_read=30, sock_connect=10)) as session:
async with session.get(url, headers=headers) as response:
if response.status == 200:
data = await response.json()
files = []
for item in data:
if item["type"] == "file":
files.append({"path": item["path"], "size": item["size"]})
elif item["type"] == "directory":
subfiles = await _fetch_file_list(repo_id, revision, item["path"])
files.extend(subfiles)
return files
else:
raise Exception(f"Failed to fetch file list: {response.status}")
async def calc_hash(path: Path, type: Literal["sha1", "sha256"] = "sha1") -> str:
hash = hashlib.sha1() if type == "sha1" else hashlib.sha256()
if type == "sha1":
header = f"blob {(await aios.stat(path)).st_size}\0".encode()
hash.update(header)
async with aiofiles.open(path, 'rb') as f:
while chunk := await f.read(8 * 1024 * 1024):
hash.update(chunk)
return hash.hexdigest()
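# Illustrative note (not in the original source): for the "sha1" variant the digest is computed
# over b"blob <size>\0" followed by the file contents, which is the git blob object format; the
# "sha256" variant hashes the raw contents. The assumption here is that this lets the result be
# compared against the remote ETag, which is a git blob sha1 for small files and a sha256 for
# LFS-backed files.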
async def file_meta(repo_id: str, revision: str, path: str) -> Tuple[int, str]:
url = urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", path)
headers = await get_auth_headers()
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as session:
async with session.head(url, headers=headers) as r:
content_length = int(r.headers.get('x-linked-size') or r.headers.get('content-length') or 0)
etag = r.headers.get('X-Linked-ETag') or r.headers.get('ETag') or r.headers.get('Etag')
assert content_length > 0, f"No content length for {url}"
assert etag is not None, f"No remote hash for {url}"
if (etag[0] == '"' and etag[-1] == '"') or (etag[0] == "'" and etag[-1] == "'"): etag = etag[1:-1]
return content_length, etag
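# Illustrative note (not in the original source): file_meta returns (content_length, etag) with
# any surrounding quotes stripped, e.g. a hypothetical header ETag: "0123abcd" yields the string
# 0123abcd. The size is taken from x-linked-size when present (typically set for LFS-backed
# files), otherwise from content-length.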
async def download_file_with_retry(repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Callable[[int, int], None] = lambda _, __: None) -> Path:
n_attempts = 30
for attempt in range(n_attempts):
try: return await _download_file(repo_id, revision, path, target_dir, on_progress)
except Exception as e:
if isinstance(e, FileNotFoundError) or attempt == n_attempts - 1: raise e
print(f"Download error on attempt {attempt}/{n_attempts} for {repo_id=} {revision=} {path=} {target_dir=}")
traceback.print_exc()
await asyncio.sleep(min(8, 0.1 * (2 ** attempt)))
async def _download_file(repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Callable[[int, int], None] = lambda _, __: None) -> Path:
if await aios.path.exists(target_dir/path): return target_dir/path
await aios.makedirs((target_dir/path).parent, exist_ok=True)
length, etag = await file_meta(repo_id, revision, path)
remote_hash = etag[:-5] if etag.endswith("-gzip") else etag
partial_path = target_dir/f"{path}.partial"
resume_byte_pos = (await aios.stat(partial_path)).st_size if (await aios.path.exists(partial_path)) else None
if resume_byte_pos != length:
url = urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", path)
headers = await get_auth_headers()
if resume_byte_pos: headers['Range'] = f'bytes={resume_byte_pos}-'
n_read = resume_byte_pos or 0
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as session:
async with session.get(url, headers=headers, timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as r:
if r.status == 404: raise FileNotFoundError(f"File not found: {url}")
assert r.status in [200, 206], f"Failed to download {path} from {url}: {r.status}"
async with aiofiles.open(partial_path, 'ab' if resume_byte_pos else 'wb') as f:
while chunk := await r.content.read(8 * 1024 * 1024): on_progress(n_read := n_read + await f.write(chunk), length)
final_hash = await calc_hash(partial_path, type="sha256" if len(remote_hash) == 64 else "sha1")
integrity = final_hash == remote_hash
if not integrity:
try: await aios.remove(partial_path)
except Exception as e: print(f"Error removing partial file {partial_path}: {e}")
raise Exception(f"Downloaded file {target_dir/path} has hash {final_hash} but remote hash is {remote_hash}")
await aios.rename(partial_path, target_dir/path)
return target_dir/path
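# Hedged usage sketch (repo id, file and directory are hypothetical):
#   path = await download_file_with_retry(
#       "someorg/some-model", "main", "config.json", Path("/tmp/exo-demo"),
#       on_progress=lambda done, total: print(f"{done}/{total} bytes"))
# The file streams into config.json.partial, is hash-checked against the remote ETag (with any
# -gzip suffix stripped), then renamed into place; rerunning the call resumes from the existing
# .partial file via a Range request.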
def calculate_repo_progress(shard: Shard, repo_id: str, revision: str, file_progress: Dict[str, RepoFileProgressEvent], all_start_time: float) -> RepoProgressEvent:
all_total_bytes = sum([p.total for p in file_progress.values()])
all_downloaded_bytes = sum([p.downloaded for p in file_progress.values()])
all_downloaded_bytes_this_session = sum([p.downloaded_this_session for p in file_progress.values()])
elapsed_time = time.time() - all_start_time
all_speed = all_downloaded_bytes_this_session / elapsed_time if elapsed_time > 0 else 0
all_eta = timedelta(seconds=(all_total_bytes - all_downloaded_bytes) / all_speed) if all_speed > 0 else timedelta(seconds=0)
status = "complete" if all(p.status == "complete" for p in file_progress.values()) else "in_progress" if any(p.status == "in_progress" for p in file_progress.values()) else "not_started"
return RepoProgressEvent(shard, repo_id, revision, len([p for p in file_progress.values() if p.downloaded == p.total]), len(file_progress), all_downloaded_bytes, all_downloaded_bytes_this_session, all_total_bytes, all_speed, all_eta, file_progress, status)
async def get_weight_map(repo_id: str, revision: str = "main") -> Dict[str, str]:
target_dir = (await ensure_exo_tmp())/repo_id.replace("/", "--")
index_file = await download_file_with_retry(repo_id, revision, "model.safetensors.index.json", target_dir)
async with aiofiles.open(index_file, 'r') as f: index_data = json.loads(await f.read())
return index_data.get("weight_map")
async def resolve_allow_patterns(shard: Shard, inference_engine_classname: str) -> List[str]:
try:
weight_map = await get_weight_map(get_repo(shard.model_id, inference_engine_classname))
return get_allow_patterns(weight_map, shard)
except:
if DEBUG >= 1: print(f"Error getting weight map for {shard.model_id=} and inference engine {inference_engine_classname}")
if DEBUG >= 1: traceback.print_exc()
return ["*"]
async def get_downloaded_size(path: Path) -> int:
partial_path = path.with_suffix(path.suffix + ".partial")
if await aios.path.exists(path): return (await aios.stat(path)).st_size
if await aios.path.exists(partial_path): return (await aios.stat(partial_path)).st_size
return 0
async def download_shard(shard: Shard, inference_engine_classname: str, on_progress: AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]], max_parallel_downloads: int = 8, skip_download: bool = False) -> tuple[Path, RepoProgressEvent]:
if DEBUG >= 2 and not skip_download: print(f"Downloading {shard.model_id=} for {inference_engine_classname}")
repo_id = get_repo(shard.model_id, inference_engine_classname)
revision = "main"
target_dir = await ensure_downloads_dir()/repo_id.replace("/", "--")
if not skip_download: await aios.makedirs(target_dir, exist_ok=True)
if repo_id is None:
raise ValueError(f"No repo found for {shard.model_id=} and inference engine {inference_engine_classname}")
allow_patterns = await resolve_allow_patterns(shard, inference_engine_classname)
if DEBUG >= 2: print(f"Downloading {shard.model_id=} with {allow_patterns=}")
all_start_time = time.time()
file_list = await fetch_file_list_with_cache(repo_id, revision)
filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, key=lambda x: x["path"]))
file_progress: Dict[str, RepoFileProgressEvent] = {}
def on_progress_wrapper(file: dict, curr_bytes: int, total_bytes: int):
start_time = file_progress[file["path"]].start_time if file["path"] in file_progress else time.time()
downloaded_this_session = file_progress[file["path"]].downloaded_this_session + (curr_bytes - file_progress[file["path"]].downloaded) if file["path"] in file_progress else curr_bytes
speed = downloaded_this_session / (time.time() - start_time) if time.time() - start_time > 0 else 0
eta = timedelta(seconds=(total_bytes - curr_bytes) / speed) if speed > 0 else timedelta(seconds=0)
file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], curr_bytes, downloaded_this_session, total_bytes, speed, eta, "complete" if curr_bytes == total_bytes else "in_progress", start_time)
on_progress.trigger_all(shard, calculate_repo_progress(shard, repo_id, revision, file_progress, all_start_time))
if DEBUG >= 6: print(f"Downloading {file['path']} {curr_bytes}/{total_bytes} {speed} {eta}")
for file in filtered_file_list:
downloaded_bytes = await get_downloaded_size(target_dir/file["path"])
file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], downloaded_bytes, 0, file["size"], 0, timedelta(0), "complete" if downloaded_bytes == file["size"] else "not_started", time.time())
semaphore = asyncio.Semaphore(max_parallel_downloads)
async def download_with_semaphore(file):
async with semaphore:
await download_file_with_retry(repo_id, revision, file["path"], target_dir, lambda curr_bytes, total_bytes: on_progress_wrapper(file, curr_bytes, total_bytes))
if not skip_download: await asyncio.gather(*[download_with_semaphore(file) for file in filtered_file_list])
final_repo_progress = calculate_repo_progress(shard, repo_id, revision, file_progress, all_start_time)
on_progress.trigger_all(shard, final_repo_progress)
if gguf := next((f for f in filtered_file_list if f["path"].endswith(".gguf")), None):
return target_dir/gguf["path"], final_repo_progress
else:
return target_dir, final_repo_progress
def new_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:
return SingletonShardDownloader(CachedShardDownloader(NewShardDownloader(max_parallel_downloads)))
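# Illustrative note (not in the original source): the three wrappers compose as
# Singleton(Cached(New)). SingletonShardDownloader deduplicates concurrent ensure_shard calls
# for the same shard into a single task, CachedShardDownloader memoizes the resolved path per
# (inference engine, shard) so repeat calls skip the download path entirely, and
# NewShardDownloader performs the actual download via download_shard.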
class SingletonShardDownloader(ShardDownloader):
def __init__(self, shard_downloader: ShardDownloader):
self.shard_downloader = shard_downloader
self.active_downloads: Dict[Shard, asyncio.Task] = {}
@property
def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
return self.shard_downloader.on_progress
async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
if shard not in self.active_downloads: self.active_downloads[shard] = asyncio.create_task(self.shard_downloader.ensure_shard(shard, inference_engine_name))
try: return await self.active_downloads[shard]
finally:
if shard in self.active_downloads and self.active_downloads[shard].done(): del self.active_downloads[shard]
async def get_shard_download_status(self, inference_engine_name: str) -> AsyncIterator[tuple[Path, RepoProgressEvent]]:
async for path, status in self.shard_downloader.get_shard_download_status(inference_engine_name):
yield path, status
class CachedShardDownloader(ShardDownloader):
def __init__(self, shard_downloader: ShardDownloader):
self.shard_downloader = shard_downloader
self.cache: Dict[tuple[str, Shard], Path] = {}
@property
def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
return self.shard_downloader.on_progress
async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
if (inference_engine_name, shard) in self.cache:
if DEBUG >= 2: print(f"ensure_shard cache hit {shard=} for {inference_engine_name}")
return self.cache[(inference_engine_name, shard)]
if DEBUG >= 2: print(f"ensure_shard cache miss {shard=} for {inference_engine_name}")
target_dir = await self.shard_downloader.ensure_shard(shard, inference_engine_name)
self.cache[(inference_engine_name, shard)] = target_dir
return target_dir
async def get_shard_download_status(self, inference_engine_name: str) -> AsyncIterator[tuple[Path, RepoProgressEvent]]:
async for path, status in self.shard_downloader.get_shard_download_status(inference_engine_name):
yield path, status
class NewShardDownloader(ShardDownloader):
def __init__(self, max_parallel_downloads: int = 8):
self.max_parallel_downloads = max_parallel_downloads
self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()
@property
def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
return self._on_progress
async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
target_dir, _ = await download_shard(shard, inference_engine_name, self.on_progress, max_parallel_downloads=self.max_parallel_downloads)
return target_dir
async def get_shard_download_status(self, inference_engine_name: str) -> AsyncIterator[tuple[Path, RepoProgressEvent]]:
if DEBUG >= 2: print("Getting shard download status for", inference_engine_name)
tasks = [download_shard(build_full_shard(model_id, inference_engine_name), inference_engine_name, self.on_progress, skip_download=True) for model_id in get_supported_models([[inference_engine_name]])]
for task in asyncio.as_completed(tasks):
try:
path, progress = await task
yield (path, progress)
except Exception as e:
print("Error downloading shard:", e)

View File

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Optional, Tuple, Dict
from typing import Optional, Tuple, Dict, AsyncIterator
from pathlib import Path
from exo.inference.shard import Shard
from exo.download.download_progress import RepoProgressEvent
@@ -27,7 +27,7 @@ class ShardDownloader(ABC):
pass
@abstractmethod
async def get_shard_download_status(self) -> Optional[Dict[str, float]]:
async def get_shard_download_status(self, inference_engine_name: str) -> AsyncIterator[tuple[Path, RepoProgressEvent]]:
"""Get the download status of shards.
Returns:
@@ -45,5 +45,5 @@ class NoopShardDownloader(ShardDownloader):
def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
return AsyncCallbackSystem()
async def get_shard_download_status(self) -> Optional[Dict[str, float]]:
return None
async def get_shard_download_status(self, inference_engine_name: str) -> AsyncIterator[tuple[Path, RepoProgressEvent]]:
if False: yield

View File

@@ -0,0 +1,14 @@
from exo.download.new_shard_download import NewShardDownloader
from exo.inference.shard import Shard
import asyncio
async def test_new_shard_download():
shard_downloader = NewShardDownloader()
shard_downloader.on_progress.register("test").on_next(lambda shard, event: print(shard, event))
await shard_downloader.ensure_shard(Shard(model_id="llama-3.2-1b", start_layer=0, end_layer=0, n_layers=16), "MLXDynamicShardInferenceEngine")
async for path, shard_status in shard_downloader.get_shard_download_status("MLXDynamicShardInferenceEngine"):
print("Shard download status:", path, shard_status)
if __name__ == "__main__":
asyncio.run(test_new_shard_download())

View File

@@ -7,12 +7,14 @@ import random
import platform
import psutil
import uuid
import netifaces
from scapy.all import get_if_addr, get_if_list
import re
import subprocess
from pathlib import Path
import tempfile
import json
from concurrent.futures import ThreadPoolExecutor
import traceback
DEBUG = int(os.getenv("DEBUG", default="0"))
DEBUG_DISCOVERY = int(os.getenv("DEBUG_DISCOVERY", default="0"))
@@ -229,28 +231,29 @@ def pretty_print_bytes_per_second(bytes_per_second: int) -> str:
def get_all_ip_addresses_and_interfaces():
try:
ip_addresses = []
for interface in netifaces.interfaces():
ifaddresses = netifaces.ifaddresses(interface)
if netifaces.AF_INET in ifaddresses:
for link in ifaddresses[netifaces.AF_INET]:
ip = link['addr']
ip_addresses.append((ip, interface))
for interface in get_if_list():
try:
ip = get_if_addr(interface)
if ip.startswith("0.0."): continue
simplified_interface = re.sub(r'^\\Device\\NPF_', '', interface)
ip_addresses.append((ip, simplified_interface))
except:
if DEBUG >= 1: print(f"Failed to get IP address for interface {interface}")
if DEBUG >= 1: traceback.print_exc()
if not ip_addresses:
if DEBUG >= 1: print("Failed to get any IP addresses. Defaulting to localhost.")
return [("localhost", "lo")]
return list(set(ip_addresses))
except:
if DEBUG >= 1: print("Failed to get all IP addresses. Defaulting to localhost.")
return [("localhost", "lo")]
async def get_macos_interface_type(ifname: str) -> Optional[Tuple[int, str]]:
try:
# Use the shared subprocess_pool
output = await asyncio.get_running_loop().run_in_executor(subprocess_pool, lambda: subprocess.run(
['system_profiler', 'SPNetworkDataType', '-json'],
capture_output=True,
text=True,
close_fds=True
).stdout)
output = await asyncio.get_running_loop().run_in_executor(
subprocess_pool, lambda: subprocess.run(['system_profiler', 'SPNetworkDataType', '-json'], capture_output=True, text=True, close_fds=True).stdout
)
data = json.loads(output)
@@ -276,6 +279,7 @@ async def get_macos_interface_type(ifname: str) -> Optional[Tuple[int, str]]:
return None
async def get_interface_priority_and_type(ifname: str) -> Tuple[int, str]:
# On macOS, try to get interface type using networksetup
if psutil.MACOS:
@@ -283,8 +287,7 @@ async def get_interface_priority_and_type(ifname: str) -> Tuple[int, str]:
if macos_type is not None: return macos_type
# Local container/virtual interfaces
if (ifname.startswith(('docker', 'br-', 'veth', 'cni', 'flannel', 'calico', 'weave')) or
'bridge' in ifname):
if (ifname.startswith(('docker', 'br-', 'veth', 'cni', 'flannel', 'calico', 'weave')) or 'bridge' in ifname):
return (7, "Container Virtual")
# Loopback interface
@@ -310,6 +313,7 @@ async def get_interface_priority_and_type(ifname: str) -> Tuple[int, str]:
# Other physical interfaces
return (2, "Other")
async def shutdown(signal, loop, server):
"""Gracefully shutdown the server and close the asyncio loop."""
print(f"Received exit signal {signal.name}...")
@@ -325,4 +329,44 @@ async def shutdown(signal, loop, server):
def is_frozen():
return getattr(sys, 'frozen', False) or os.path.basename(sys.executable) == "exo" \
or ('Contents/MacOS' in str(os.path.dirname(sys.executable))) \
or '__nuitka__' in globals() or getattr(sys, '__compiled__', False)
or '__nuitka__' in globals() or getattr(sys, '__compiled__', False)
async def get_mac_system_info() -> Tuple[str, str, int]:
"""Get Mac system information using system_profiler."""
try:
output = await asyncio.get_running_loop().run_in_executor(
subprocess_pool,
lambda: subprocess.check_output(["system_profiler", "SPHardwareDataType"]).decode("utf-8")
)
model_line = next((line for line in output.split("\n") if "Model Name" in line), None)
model_id = model_line.split(": ")[1] if model_line else "Unknown Model"
chip_line = next((line for line in output.split("\n") if "Chip" in line), None)
chip_id = chip_line.split(": ")[1] if chip_line else "Unknown Chip"
memory_line = next((line for line in output.split("\n") if "Memory" in line), None)
memory_str = memory_line.split(": ")[1] if memory_line else "Unknown Memory"
memory_units = memory_str.split()
memory_value = int(memory_units[0])
memory = memory_value * 1024 if memory_units[1] == "GB" else memory_value
return model_id, chip_id, memory
except Exception as e:
if DEBUG >= 2: print(f"Error getting Mac system info: {e}")
return "Unknown Model", "Unknown Chip", 0
def get_exo_home() -> Path:
if psutil.WINDOWS: docs_folder = Path(os.environ["USERPROFILE"])/"Documents"
else: docs_folder = Path.home()/"Documents"
if not docs_folder.exists(): docs_folder.mkdir(exist_ok=True)
exo_folder = docs_folder/"Exo"
if not exo_folder.exists(): exo_folder.mkdir(exist_ok=True)
return exo_folder
def get_exo_images_dir() -> Path:
exo_home = get_exo_home()
images_dir = exo_home/"Images"
if not images_dir.exists(): images_dir.mkdir(exist_ok=True)
return images_dir
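# Illustrative note (not in the original source): on macOS/Linux these resolve to
# ~/Documents/Exo and ~/Documents/Exo/Images; on Windows the base is
# %USERPROFILE%\Documents\Exo. Both directories are created on first use.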

View File

@@ -16,25 +16,25 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
resp_full = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt)
token_full = await inference_engine_1.sample(resp_full)
next_resp_full = await inference_engine_1.infer_tensor(
next_resp_full, _ = await inference_engine_1.infer_tensor(
"A",
shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32),
input_data=token_full,
)
resp1 = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), prompt=prompt)
resp2 = await inference_engine_2.infer_tensor(
resp1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), prompt=prompt)
resp2, _ = await inference_engine_2.infer_tensor(
"B",
shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),
input_data=resp1,
)
token2 = await inference_engine_2.sample(resp2)
resp3 = await inference_engine_1.infer_tensor(
resp3, _ = await inference_engine_1.infer_tensor(
"B",
shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32),
input_data=token2,
)
resp4 = await inference_engine_2.infer_tensor(
resp4, _ = await inference_engine_2.infer_tensor(
"B",
shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),
input_data=resp3,

View File

@@ -25,9 +25,9 @@ class DummyInferenceEngine(InferenceEngine):
async def decode(self, shard: Shard, tokens: np.ndarray) -> str:
return self.tokenizer.decode(tokens)
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
await self.ensure_shard(shard)
return input_data + 1 if self.shard.is_last_layer() else input_data
return input_data + 1 if self.shard.is_last_layer() else input_data, None
async def ensure_shard(self, shard: Shard):
if self.shard == shard: return

View File

@@ -5,6 +5,7 @@ from exo.helpers import DEBUG # Make sure to import DEBUG
from typing import Tuple, Optional
from abc import ABC, abstractmethod
from .shard import Shard
from exo.download.shard_download import ShardDownloader
class InferenceEngine(ABC):
@@ -13,7 +14,7 @@ class InferenceEngine(ABC):
@abstractmethod
async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
pass
@abstractmethod
async def sample(self, x: np.ndarray) -> np.ndarray:
pass
@@ -23,7 +24,7 @@ class InferenceEngine(ABC):
pass
@abstractmethod
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
pass
@abstractmethod
@@ -32,18 +33,23 @@ class InferenceEngine(ABC):
async def save_checkpoint(self, shard: Shard, path: str):
pass
async def save_session(self, key, value):
self.session[key] = value
async def clear_session(self):
self.session.clear()
async def infer_prompt(self, request_id: str, shard: Shard, prompt: str) -> np.ndarray:
async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
tokens = await self.encode(shard, prompt)
x = tokens.reshape(1, -1)
output_data = await self.infer_tensor(request_id, shard, x)
return output_data
if shard.model_id != 'stable-diffusion-2-1-base':
x = tokens.reshape(1, -1)
else:
x = tokens
output_data, inference_state = await self.infer_tensor(request_id, shard, x, inference_state)
return output_data, inference_state
inference_engine_classes = {
"mlx": "MLXDynamicShardInferenceEngine",
@@ -51,7 +57,8 @@ inference_engine_classes = {
"dummy": "DummyInferenceEngine",
}
def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDownloader'):
def get_inference_engine(inference_engine_name: str, shard_downloader: ShardDownloader):
if DEBUG >= 2:
print(f"get_inference_engine called with: {inference_engine_name}")
if inference_engine_name == "mlx":

View File

@@ -0,0 +1,307 @@
# Adapted from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/__init__.py
import time
from typing import Optional, Tuple
import inspect
import mlx.core as mx
import mlx.nn as nn
from pathlib import Path
from tqdm import tqdm
from .sd_models.vae import ModelArgs as VAEArgs
from .sd_models.vae import Autoencoder
from .sd_models.tokenizer import load_tokenizer
from .sd_models.clip import CLIPTextModel
from .sd_models.clip import ModelArgs as CLIPArgs
from .sd_models.unet import UNetConfig, UNetModel
from dataclasses import dataclass, field
from exo.inference.shard import Shard
@dataclass
class DiffusionConfig:
beta_schedule: str = "scaled_linear"
beta_start: float = 0.00085
beta_end: float = 0.012
num_train_steps: int = 1000
@classmethod
def from_dict(cls, params):
return cls(**{k: v for k, v in params.items() if k in inspect.signature(cls).parameters})
# Sampler
def _linspace(a, b, num):
x = mx.arange(0, num) / (num - 1)
return (b - a) * x + a
def _interp(y, x_new):
"""Interpolate the function defined by (arange(0, len(y)), y) at positions x_new."""
x_low = x_new.astype(mx.int32)
x_high = mx.minimum(x_low + 1, len(y) - 1)
y_low = y[x_low]
y_high = y[x_high]
delta_x = x_new - x_low
y_new = y_low * (1 - delta_x) + delta_x * y_high
return y_new
class SimpleEulerSampler:
"""A simple Euler integrator that can be used to sample from our diffusion models.
The method ``step()`` performs one Euler step from x_t to x_t_prev.
"""
def __init__(self, config: DiffusionConfig):
# Compute the noise schedule
if config.beta_schedule == "linear":
betas = _linspace(
config.beta_start, config.beta_end, config.num_train_steps
)
elif config.beta_schedule == "scaled_linear":
betas = _linspace(
config.beta_start**0.5, config.beta_end**0.5, config.num_train_steps
).square()
else:
raise NotImplementedError(f"{config.beta_schedule} is not implemented.")
alphas = 1 - betas
alphas_cumprod = mx.cumprod(alphas)
self._sigmas = mx.concatenate(
[mx.zeros(1), ((1 - alphas_cumprod) / alphas_cumprod).sqrt()]
)
@property
def max_time(self):
return len(self._sigmas) - 1
def sample_prior(self, shape, dtype=mx.float32, key=None):
noise = mx.random.normal(shape, key=key)
return (
noise * self._sigmas[-1] * (self._sigmas[-1].square() + 1).rsqrt()
).astype(dtype)
def add_noise(self, x, t, key=None):
noise = mx.random.normal(x.shape, key=key)
s = self.sigmas(t)
return (x + noise * s) * (s.square() + 1).rsqrt()
def sigmas(self, t):
return _interp(self._sigmas, t)
def timesteps(self, num_steps: int, start_time=None, dtype=mx.float32):
start_time = start_time or (len(self._sigmas) - 1)
assert 0 < start_time <= (len(self._sigmas) - 1)
steps = _linspace(start_time, 0, num_steps + 1).astype(dtype)
return list(zip(steps, steps[1:]))
def current_timestep(self, step, total_steps, start_time=None):
if step < total_steps:
steps = self.timesteps(total_steps, start_time)
return steps[step]
else:
return mx.array(0),mx.array(0)
def step(self, eps_pred, x_t, t, t_prev):
sigma = self.sigmas(t).astype(eps_pred.dtype)
sigma_prev = self.sigmas(t_prev).astype(eps_pred.dtype)
dt = sigma_prev - sigma
x_t_prev = (sigma.square() + 1).sqrt() * x_t + eps_pred * dt
x_t_prev = x_t_prev * (sigma_prev.square() + 1).rsqrt()
return x_t_prev
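# Minimal usage sketch for the sampler above (illustrative only; the shapes and the
# zero eps_pred are placeholders, not how the pipeline actually drives it):
#
#   sampler = SimpleEulerSampler(DiffusionConfig())
#   x = sampler.sample_prior((1, 64, 64, 4))
#   t, t_prev = sampler.timesteps(num_steps=50)[0]
#   eps_pred = mx.zeros(x.shape)  # would come from the UNet
#   x = sampler.step(eps_pred, x, t, t_prev)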
@dataclass
class ShardConfig:
model_id:str
start_layer:int
end_layer:int
n_layers:int
@dataclass
class StableDiffusionConfig:
model_type:str
vae:VAEArgs
text_encoder:CLIPArgs
scheduler:DiffusionConfig
unet:UNetConfig
shard:ShardConfig
@classmethod
def from_dict(cls, params):
return cls(**{k: v for k, v in params.items() if k in inspect.signature(cls).parameters})
@dataclass
class ModelArgs(StableDiffusionConfig):
shard:Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
def __post_init__(self):
if isinstance(self.shard, dict):
self.shard = Shard(**self.shard)
if not isinstance(self.shard, Shard):
raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
class Model(nn.Module):
def __init__(self, config):
super().__init__()
self.model_type = config.model_type
self.config = config
self.model_path = config.vae['path'].split('/vae')[0]
self.shard = config.shard
self.shard_clip, self.shard_encoder, self.shard_unet, self.shard_decoder = model_shards(config.shard)
self.config_clip=CLIPArgs.from_dict(config.text_encoder['config'])
if self.shard_clip.start_layer != -1:
self.text_encoder = CLIPTextModel(self.config_clip, shard=self.shard_clip)
else:
self.text_encoder = nn.Identity()
self.tokenizer = load_tokenizer(Path(self.model_path), "vocab.json", "merges.txt")
self.diffusion_config = DiffusionConfig.from_dict(config.scheduler['config'])
self.sampler = SimpleEulerSampler(self.diffusion_config)
if self.shard_unet.start_layer!=-1:
self.config_unet = UNetConfig.from_dict(config.unet['config'])
self.unet = UNetModel(self.config_unet, self.shard_unet)
else:
self.unet = nn.Identity()
self.config_vae=VAEArgs.from_dict(config.vae['config'])
if self.shard_encoder.start_layer != -1:
self.encoder=Autoencoder(self.config_vae, self.shard_encoder, "vae_encoder")
else:
self.encoder = nn.Identity()
if self.shard_decoder.start_layer != -1:
self.decoder=Autoencoder(self.config_vae, self.shard_decoder, "vae_decoder")
else:
self.decoder = nn.Identity()
def __call__(self,x, step= 0, cfg_weight: float = 7.5,total_steps=50,conditioning=None,mask=None,residual=None,x_t_prev=None,is_finished=False,is_step_finished=False, image=None, strength=0.65, start_step=None):
t, t_prev = self.sampler.current_timestep(step=step, total_steps=total_steps, start_time=start_step)
is_finished = False
is_step_finished = False
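# t equals the sampler's max_time (1000) only on the first call for a request, so text conditioning and the initial latents are set up inside this branch.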
if t.item()==1000:
if self.shard_clip.start_layer == 0:
conditioning = x
if self.shard_clip.start_layer != -1:
conditioning, mask= self.text_encoder(conditioning,mask)
seed = int(time.time())
mx.random.seed(seed)
if image is None:
if self.shard_encoder.is_last_layer():
x = self.sampler.sample_prior((1, *(64, 64), self.config_vae.latent_channels_in), dtype=mx.float32)
x_t_prev=x
start_step = self.sampler.max_time
else:
if self.shard_encoder.start_layer != -1:
image= self.encoder.encode(image)
if self.shard_encoder.is_last_layer():
start_step = self.sampler.max_time*strength
total_steps = int(total_steps*strength)
image = mx.broadcast_to(image, (1,) + image.shape[1:])
x_t_prev=self.sampler.add_noise(image, mx.array(start_step))
image = None
t, t_prev = self.sampler.current_timestep(step=step, total_steps=total_steps, start_time=start_step)
# Perform the denoising loop
if self.shard_unet.start_layer != -1:
with tqdm(total=total_steps,initial=step+1) as pbar:
if step<total_steps:
x = x_t_prev
if self.shard_unet.is_first_layer():
x_t_unet = mx.concatenate([x] * 2, axis=0) if cfg_weight> 1 else x
else:
x_t_unet = x
t_unet = mx.broadcast_to(t, [len(x_t_unet)])
x, residual= self.unet(x_t_unet, t_unet, encoder_x=conditioning, residuals=residual)
if self.shard_unet.is_last_layer():
if cfg_weight > 1:
eps_text, eps_neg = x.split(2)
eps_pred = eps_neg + cfg_weight * (eps_text - eps_neg)
x = self.sampler.step(eps_pred, x_t_prev, t, t_prev)
x_t_prev=x
mx.eval(x)
if self.shard_decoder.is_last_layer():
is_step_finished=True
if self.shard_decoder.start_layer != -1:
x=self.decoder.decode(x)
if self.shard_decoder.is_last_layer():
x = mx.clip(x / 2 + 0.5, 0, 1)
B, H, W, C = x.shape
x = x.reshape(1, B // 1, H, W, C).transpose(0, 2, 1, 3, 4)
x = x.reshape(1 * H, B // 1 * W, C)
x = (x * 255).astype(mx.uint8)
if t_prev.item() ==0:
is_finished=True
mx.eval(x)
return x, {'conditioning':conditioning, 'mask':mask,'residual':residual,'x_t_prev':x_t_prev,'is_finished':is_finished,'is_step_finished':is_step_finished, 'step':step, 'total_steps':total_steps, 'start_step':start_step, 'image':image}
def load(self):
if self.shard_encoder.start_layer != -1:
vae_weights = mx.load(self.config_vae.weight_files[0])
vae_weights = self.encoder.sanitize(vae_weights)
self.encoder.load_weights(list(vae_weights.items()), strict=True)
if self.shard_decoder.start_layer != -1:
vae_weights = mx.load(self.config_vae.weight_files[0])
vae_weights = self.decoder.sanitize(vae_weights)
self.decoder.load_weights(list(vae_weights.items()), strict=True)
if self.shard_clip.start_layer != -1:
clip_weights = mx.load(self.config_clip.weight_files[0])
clip_weights = self.text_encoder.sanitize(clip_weights)
self.text_encoder.load_weights(list(clip_weights.items()), strict=True)
if self.shard_unet.start_layer !=-1:
unet_weights = mx.load(self.config_unet.weight_files[0])
unet_weights = self.unet.sanitize(unet_weights)
self.unet.load_weights(list(unet_weights.items()), strict=True)
def model_shards(shard:ShardConfig):
def create_shard(shard, model_ranges):
start_layer = shard.start_layer
end_layer = shard.end_layer
shards = {}
for model_name, (range_start, range_end) in model_ranges.items():
if start_layer < range_end and end_layer >= range_start:
# Calculate the overlap with the model range
overlap_start = max(start_layer, range_start)
overlap_end = min(end_layer, range_end - 1)
# Adjust the layers relative to the model's range
relative_start = overlap_start - range_start
relative_end = overlap_end - range_start
shards[model_name] = Shard(model_name, relative_start, relative_end, range_end - range_start)
else:
# If no overlap, create a zero-layer shard
shards[model_name] = Shard(model_name, -1, -1, range_end - range_start)
return shards
# Define the ranges for different models
model_ranges = {
'clip': (0, 12),
'vae_encoder': (12, 17),
'unet': (17, 26),
'vae_decoder': (26, 31)  # layer range for the VAE decoder
}
# Call the function and get the shards for all models
shards = create_shard(shard, model_ranges)
# Access individual shards
shard_clip = shards['clip']
shard_encoder = shards['vae_encoder']
shard_unet = shards['unet']
shard_decoder = shards['vae_decoder']
return shard_clip, shard_encoder, shard_unet, shard_decoder
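# Example (hypothetical split): a node assigned layers 10-20 of the 31-layer pipeline
# gets CLIP layers 10-11, all 5 VAE-encoder layers, UNet layers 0-3, and a
# placeholder (-1, -1) VAE-decoder shard, since the decoder lives on another node:
#
#   clip, enc, unet, dec = model_shards(Shard("stable-diffusion-2-1-base", 10, 20, 31))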

View File

@@ -0,0 +1,134 @@
from dataclasses import dataclass, field
from typing import Optional
import mlx.core as mx
import mlx.nn as nn
from mlx_lm.models.cache import KVCache
from mlx_lm.models.deepseek_v3 import (
ModelArgs as V3ModelArgs,
DeepseekV3DecoderLayer,
)
from .base import IdentityBlock
from exo.inference.shard import Shard
@dataclass
class ModelArgs(V3ModelArgs):
shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
def __post_init__(self):
if isinstance(self.shard, Shard):
return
if not isinstance(self.shard, dict):
raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
self.shard = Shard(**self.shard)
class DeepseekV3Model(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.args = config
self.num_hidden_layers = config.num_hidden_layers
self.vocab_size = config.vocab_size
if self.args.shard.is_first_layer():
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = []
for i in range(self.num_hidden_layers):
if self.args.shard.start_layer <= i <= self.args.shard.end_layer:
self.layers.append(DeepseekV3DecoderLayer(config, i))
else:
self.layers.append(IdentityBlock())
if self.args.shard.is_last_layer():
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def __call__(
self,
x: mx.array,
cache: Optional[KVCache] = None,
) -> mx.array:
if self.args.shard.is_first_layer():
h = self.embed_tokens(x)
else:
h = x
mask = None
T = h.shape[1]
if T > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
mask = mask.astype(h.dtype)
if cache is None:
cache = [None]*len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
if self.args.shard.is_last_layer():
h = self.norm(h)
return h
class Model(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.args = config
self.model_type = config.model_type
self.model = DeepseekV3Model(config)
if self.args.shard.is_last_layer():
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
cache: Optional[KVCache] = None,
):
out = self.model(inputs, cache)
if self.args.shard.is_last_layer():
return self.lm_head(out)
return out
def sanitize(self, weights):
shard_state_dict = {}
for key, value in weights.items():
if key.startswith('model.layers.'):
layer_num = int(key.split('.')[2])
if self.args.shard.start_layer <= layer_num <= self.args.shard.end_layer:
shard_state_dict[key] = value
elif self.args.shard.is_first_layer() and key.startswith('model.embed_tokens'):
shard_state_dict[key] = value
elif self.args.shard.is_last_layer() and (key.startswith('model.norm') or key.startswith('lm_head')):
shard_state_dict[key] = value
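# Stack the per-expert gate/down/up projection weights into the fused switch_mlp tensors expected by the mlx_lm DeepSeek-V3 layers.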
for l in range(self.args.num_hidden_layers):
prefix = f"model.layers.{l}"
for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
for k in ["weight", "scales", "biases"]:
expert_key = f"{prefix}.mlp.experts.0.{m}.{k}"
if expert_key in shard_state_dict:
to_join = [
shard_state_dict.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}")
for e in range(self.args.n_routed_experts)
]
shard_state_dict[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)
return shard_state_dict
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return (
self.args.qk_nope_head_dim + self.args.qk_rope_head_dim,
self.args.v_head_dim,
)
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@@ -0,0 +1,117 @@
from dataclasses import dataclass, field
import mlx.core as mx
import mlx.nn as nn
from mlx_lm.models.base import create_attention_mask
from mlx_lm.models.phi3 import TransformerBlock, ModelArgs
from ...shard import Shard
from .base import IdentityBlock
@dataclass
class ModelArgs(ModelArgs):
shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
def __post_init__(self):
super().__post_init__()
if isinstance(self.shard, Shard):
return
if not isinstance(self.shard, dict):
raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
self.shard = Shard(**self.shard)
class Phi3Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
assert self.vocab_size > 0
if self.args.shard.is_first_layer():
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = []
for i in range(self.num_hidden_layers):
if self.args.shard.start_layer <= i <= self.args.shard.end_layer:
self.layers.append(TransformerBlock(args=args))
else:
self.layers.append(IdentityBlock())
if self.args.shard.is_last_layer():
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,
inputs: mx.array,
cache=None,
):
if self.args.shard.is_first_layer():
h = self.embed_tokens(inputs)
else:
h = inputs
mask = None
if h.shape[1] > 1:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
if self.args.shard.is_last_layer():
h = self.norm(h)
return h
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.model = Phi3Model(args)
if self.args.shard.is_last_layer():
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
cache=None,
):
out = self.model(inputs, cache)
if self.args.shard.is_last_layer():
out = self.lm_head(out)
return out
def sanitize(self, weights):
shard_state_dict = {}
for key, value in weights.items():
if "self_attn.rope.inv_freq" in key:
continue
if key.startswith('model.layers.'):
layer_num = int(key.split('.')[2])
if self.args.shard.start_layer <= layer_num <= self.args.shard.end_layer:
shard_state_dict[key] = value
elif self.args.shard.is_first_layer() and key.startswith('model.embed_tokens'):
shard_state_dict[key] = value
elif self.args.shard.is_last_layer() and (key.startswith('lm_head') or key.startswith('model.norm')):
shard_state_dict[key] = value
return shard_state_dict
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@@ -9,13 +9,12 @@ from mlx_lm.models.qwen2 import TransformerBlock, ModelArgs
from ...shard import Shard
from .base import IdentityBlock
@dataclass
class ModelArgs(ModelArgs):
shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
def __post_init__(self):
super().__post_init__() # Ensure parent initializations are respected
super().__post_init__()
if isinstance(self.shard, Shard):
return
@@ -24,7 +23,6 @@ class ModelArgs(ModelArgs):
self.shard = Shard(**self.shard)
class Qwen2Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
@@ -32,14 +30,17 @@ class Qwen2Model(nn.Module):
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
assert self.vocab_size > 0
if self.args.shard.is_first_layer():
if self.args.shard.is_first_layer() or (self.args.shard.is_last_layer() and args.tie_word_embeddings):
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = []
for i in range(self.num_hidden_layers):
if self.args.shard.start_layer <= i <= self.args.shard.end_layer:
self.layers.append(TransformerBlock(args=args))
else:
self.layers.append(IdentityBlock())
if self.args.shard.is_last_layer():
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

View File

@@ -0,0 +1,191 @@
# Adapted from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/clip.py
import math
from dataclasses import dataclass
from typing import List, Optional
import mlx.core as mx
import mlx.nn as nn
from dataclasses import field, dataclass
from exo.inference.shard import Shard
from exo.inference.mlx.models.base import IdentityBlock
_ACTIVATIONS = {"quick_gelu": nn.gelu_fast_approx, "gelu": nn.gelu}
@dataclass
class CLIPTextModelConfig:
num_layers: int = 23
model_dims: int = 1024
num_heads: int = 16
max_length: int = 77
vocab_size: int = 49408
projection_dim: Optional[int] = None
hidden_act: str = "quick_gelu"
@classmethod
def from_dict(cls, config):
return ModelArgs(
num_layers=config["num_hidden_layers"],
model_dims=config["hidden_size"],
num_heads=config["num_attention_heads"],
max_length=config["max_position_embeddings"],
vocab_size=config["vocab_size"],
projection_dim=config["projection_dim"] if "WithProjection" in config['architectures'][0] else None,
hidden_act=config.get("hidden_act", "quick_gelu"),
weight_files=config.get("weight_files", [])
)
@dataclass
class ModelArgs(CLIPTextModelConfig):
shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
weight_files: List[str] = field(default_factory=lambda: [])
def __post_init__(self):
if isinstance(self.shard, dict):
self.shard = Shard(**self.shard)
if not isinstance(self.shard, Shard):
raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
if not self.shard.is_first_layer():
self.vision_config = None
@dataclass
class CLIPOutput:
pooled_output: Optional[mx.array] = None
last_hidden_state: Optional[mx.array] = None
hidden_states: Optional[List[mx.array]] = None
class CLIPEncoderLayer(nn.Module):
"""The transformer encoder layer from CLIP."""
def __init__(self, model_dims: int, num_heads: int, activation: str):
super().__init__()
self.layer_norm1 = nn.LayerNorm(model_dims)
self.layer_norm2 = nn.LayerNorm(model_dims)
self.attention = nn.MultiHeadAttention(model_dims, num_heads)
self.attention.query_proj.bias = mx.zeros(model_dims)
self.attention.key_proj.bias = mx.zeros(model_dims)
self.attention.value_proj.bias = mx.zeros(model_dims)
self.attention.out_proj.bias = mx.zeros(model_dims)
self.linear1 = nn.Linear(model_dims, 4 * model_dims)
self.linear2 = nn.Linear(4 * model_dims, model_dims)
self.act = _ACTIVATIONS[activation]
def __call__(self, x, attn_mask=None):
y = self.layer_norm1(x)
y = self.attention(y, y, y, attn_mask)
x = y + x
y = self.layer_norm2(x)
y = self.linear1(y)
y = self.act(y)
y = self.linear2(y)
x = y + x
return x
class CLIPTextModel(nn.Module):
"""Implements the text encoder transformer from CLIP."""
def __init__(self, config: CLIPTextModelConfig, shard: Shard):
super().__init__()
self.shard = shard
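# Each shard layer index covers a pair of CLIP encoder layers (2*i and 2*i+1), so expand the shard range to raw encoder-layer indices.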
self.layers_range = range(self.shard.start_layer*2, self.shard.end_layer*2+2)
if self.shard.is_first_layer():
self.token_embedding = nn.Embedding(config.vocab_size, config.model_dims)
self.position_embedding = nn.Embedding(config.max_length, config.model_dims)
self.layers = []
for i in range(math.ceil(config.num_layers/2)):
if 2*i in self.layers_range:
self.layers.append(CLIPEncoderLayer(config.model_dims, config.num_heads, config.hidden_act))
if 2*i+1 in self.layers_range and 2*i+1 < config.num_layers:
self.layers.append(CLIPEncoderLayer(config.model_dims, config.num_heads, config.hidden_act))
else:
self.layers.append(IdentityBlock())
if self.shard.is_last_layer():
self.final_layer_norm = nn.LayerNorm(config.model_dims)
if config.projection_dim is not None:
self.text_projection = nn.Linear(
config.model_dims, config.projection_dim, bias=False
)
def _get_mask(self, N, dtype):
indices = mx.arange(N)
mask = indices[:, None] < indices[None]
mask = mask.astype(dtype) * (-6e4 if dtype == mx.float16 else -1e9)
return mask
def __call__(self, x, mask=None):
# Extract some shapes
if self.shard.is_first_layer():
B, N = x.shape
eos_tokens = x.argmax(-1)
# Compute the embeddings
x = self.token_embedding(x)
x = x + self.position_embedding.weight[:N]
# Compute the features from the transformer
mask = self._get_mask(N, x.dtype)
for l in self.layers:
x = l(x, mask)
# Apply the final layernorm and return
if self.shard.is_last_layer():
x = self.final_layer_norm(x)
return x, mask
def sanitize(self, weights):
sanitized_weights = {}
for key, value in weights.items():
if "position_ids" in key:
continue
if key.startswith("text_model."):
key = key[11:]
if key.startswith("embeddings."):
key = key[11:]
if key.startswith("encoder."):
key = key[8:]
# Map attention layers
if "self_attn." in key:
key = key.replace("self_attn.", "attention.")
if "q_proj." in key:
key = key.replace("q_proj.", "query_proj.")
if "k_proj." in key:
key = key.replace("k_proj.", "key_proj.")
if "v_proj." in key:
key = key.replace("v_proj.", "value_proj.")
# Map ffn layers
if "mlp.fc1" in key:
key = key.replace("mlp.fc1", "linear1")
if "mlp.fc2" in key:
key = key.replace("mlp.fc2", "linear2")
if key.startswith("layers."):
layer_num = int(key.split(".")[1])
if layer_num not in self.layers_range:
continue
if not self.shard.is_first_layer() and "embedding" in key:
continue
if not self.shard.is_last_layer() and key.startswith("final_layer_norm"):
continue
if not self.shard.is_last_layer() and key.startswith("text_projection"):
continue
sanitized_weights[key] = value
return sanitized_weights

View File

@@ -0,0 +1,131 @@
# adapted from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/tokenizer.py
import regex
import json
import glob
from pathlib import Path
class Tokenizer:
"""A simple port of CLIPTokenizer from https://github.com/huggingface/transformers/ ."""
def __init__(self, bpe_ranks, vocab):
self.bpe_ranks = bpe_ranks
self.vocab = vocab
self.pat = regex.compile(
r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
regex.IGNORECASE,
)
self._cache = {self.bos: self.bos, self.eos: self.eos}
@property
def bos(self):
return "<|startoftext|>"
@property
def bos_token(self):
return self.vocab[self.bos]
@property
def eos(self):
return "<|endoftext|>"
@property
def eos_token(self):
return self.vocab[self.eos]
def bpe(self, text):
if text in self._cache:
return self._cache[text]
unigrams = list(text[:-1]) + [text[-1] + "</w>"]
unique_bigrams = set(zip(unigrams, unigrams[1:]))
if not unique_bigrams:
return unigrams
# In every iteration try to merge the two most likely bigrams. If none
# was merged we are done.
#
# Ported from https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/tokenization_clip.py
while unique_bigrams:
bigram = min(
unique_bigrams, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))
)
if bigram not in self.bpe_ranks:
break
new_unigrams = []
skip = False
for a, b in zip(unigrams, unigrams[1:]):
if skip:
skip = False
continue
if (a, b) == bigram:
new_unigrams.append(a + b)
skip = True
else:
new_unigrams.append(a)
if not skip:
new_unigrams.append(b)
unigrams = new_unigrams
unique_bigrams = set(zip(unigrams, unigrams[1:]))
self._cache[text] = unigrams
return unigrams
def tokenize(self, text, prepend_bos=True, append_eos=True):
if isinstance(text, list):
return [self.tokenize(t, prepend_bos, append_eos) for t in text]
# Lower case cleanup and split according to self.pat. Hugging Face does
# a much more thorough job here but this should suffice for 95% of
# cases.
clean_text = regex.sub(r"\s+", " ", text.lower())
tokens = regex.findall(self.pat, clean_text)
# Split the tokens according to the byte-pair merge file
bpe_tokens = [ti for t in tokens for ti in self.bpe(t)]
# Map to token ids and return
tokens = [self.vocab[t] for t in bpe_tokens]
if prepend_bos:
tokens = [self.bos_token] + tokens
if append_eos:
tokens.append(self.eos_token)
return tokens
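# encode() returns two padded token sequences, the prompt and an empty negative prompt, which the pipeline uses for classifier-free guidance.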
def encode(self, prompt):
tokens = [self.tokenize(prompt)]
negative_text = ""
if negative_text is not None:
tokens += [self.tokenize(negative_text)]
lengths = [len(t) for t in tokens]
N = max(lengths)
tokens = [t + [0] * (N - len(t)) for t in tokens]
return tokens
def load_tokenizer(
model_path: Path,
vocab_key: str = "tokenizer_vocab",
merges_key: str = "tokenizer_merges",
):
vocab_file = glob.glob(str(model_path/"tokenizer"/vocab_key))[0]
with open(vocab_file, encoding="utf-8") as f:
vocab = json.load(f)
merges_file = glob.glob(str(model_path/"tokenizer"/merges_key))[0]
with open(merges_file, encoding="utf-8") as f:
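# Skip the header line and keep 49152 - 256 - 2 merge rules (the CLIP vocab size minus the 256 byte tokens and the two special tokens).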
bpe_merges = f.read().strip().split("\n")[1 : 49152 - 256 - 2 + 1]
bpe_merges = [tuple(m.split()) for m in bpe_merges]
bpe_ranks = dict(map(reversed, enumerate(bpe_merges)))
return Tokenizer(bpe_ranks, vocab)

View File

@@ -0,0 +1,629 @@
# Adapted from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/unet.py
import math
from typing import Optional
import mlx.core as mx
import mlx.nn as nn
from dataclasses import dataclass, field
from typing import Tuple, Optional, List
from exo.inference.shard import Shard
@dataclass
class UNetConfig:
in_channels: int = 4
out_channels: int = 4
conv_in_kernel: int = 3
conv_out_kernel: int = 3
block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
layers_per_block: Tuple[int] = (2, 2, 2, 2)
mid_block_layers: int = 2
transformer_layers_per_block: Tuple[int] = (1, 1, 1, 1)
num_attention_heads: Tuple[int] = (5, 10, 20, 20)
cross_attention_dim: Tuple[int] = (1024,) * 4
norm_num_groups: int = 32
down_block_types: Tuple[str] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
)
up_block_types: Tuple[str] = (
"UpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
)
addition_embed_type: Optional[str] = None
addition_time_embed_dim: Optional[int] = None
projection_class_embeddings_input_dim: Optional[int] = None
weight_files: List[str] = field(default_factory=lambda: [])
@classmethod
def from_dict(cls,config):
n_blocks = len(config['block_out_channels'])
return UNetConfig(
in_channels=config["in_channels"],
out_channels=config["out_channels"],
block_out_channels=config["block_out_channels"],
layers_per_block=[config["layers_per_block"]] * n_blocks,
transformer_layers_per_block=config.get(
"transformer_layers_per_block", (1,) * 4
),
num_attention_heads=(
[config["attention_head_dim"]] * n_blocks
if isinstance(config["attention_head_dim"], int)
else config["attention_head_dim"]
),
cross_attention_dim=[config["cross_attention_dim"]] * n_blocks,
norm_num_groups=config["norm_num_groups"],
down_block_types=config["down_block_types"],
up_block_types=config["up_block_types"][::-1],
addition_embed_type=config.get("addition_embed_type", None),
addition_time_embed_dim=config.get("addition_time_embed_dim", None),
projection_class_embeddings_input_dim=config.get(
"projection_class_embeddings_input_dim", None
),
weight_files=config.get("weight_files", [])
)
def upsample_nearest(x, scale: int = 2):
B, H, W, C = x.shape
x = mx.broadcast_to(x[:, :, None, :, None, :], (B, H, scale, W, scale, C))
x = x.reshape(B, H * scale, W * scale, C)
return x
class TimestepEmbedding(nn.Module):
def __init__(self, in_channels: int, time_embed_dim: int):
super().__init__()
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
def __call__(self, x):
x = self.linear_1(x)
x = nn.silu(x)
x = self.linear_2(x)
return x
class TransformerBlock(nn.Module):
def __init__(
self,
model_dims: int,
num_heads: int,
hidden_dims: Optional[int] = None,
memory_dims: Optional[int] = None,
):
super().__init__()
self.norm1 = nn.LayerNorm(model_dims)
self.attn1 = nn.MultiHeadAttention(model_dims, num_heads)
self.attn1.out_proj.bias = mx.zeros(model_dims)
memory_dims = memory_dims or model_dims
self.norm2 = nn.LayerNorm(model_dims)
self.attn2 = nn.MultiHeadAttention(
model_dims, num_heads, key_input_dims=memory_dims
)
self.attn2.out_proj.bias = mx.zeros(model_dims)
hidden_dims = hidden_dims or 4 * model_dims
self.norm3 = nn.LayerNorm(model_dims)
self.linear1 = nn.Linear(model_dims, hidden_dims)
self.linear2 = nn.Linear(model_dims, hidden_dims)
self.linear3 = nn.Linear(hidden_dims, model_dims)
def __call__(self, x, memory, attn_mask, memory_mask):
# Self attention
y = self.norm1(x)
y = self.attn1(y, y, y, attn_mask)
x = x + y
# Cross attention
y = self.norm2(x)
y = self.attn2(y, memory, memory, memory_mask)
x = x + y
# FFN
y = self.norm3(x)
y_a = self.linear1(y)
y_b = self.linear2(y)
y = y_a * nn.gelu(y_b)
y = self.linear3(y)
x = x + y
return x
class Transformer2D(nn.Module):
"""A transformer model for inputs with 2 spatial dimensions."""
def __init__(
self,
in_channels: int,
model_dims: int,
encoder_dims: int,
num_heads: int,
num_layers: int = 1,
norm_num_groups: int = 32,
):
super().__init__()
self.norm = nn.GroupNorm(norm_num_groups, in_channels, pytorch_compatible=True)
self.proj_in = nn.Linear(in_channels, model_dims)
self.transformer_blocks = [
TransformerBlock(model_dims, num_heads, memory_dims=encoder_dims)
for i in range(num_layers)
]
self.proj_out = nn.Linear(model_dims, in_channels)
def __call__(self, x, encoder_x, attn_mask, encoder_attn_mask):
# Save the input to add to the output
input_x = x
dtype = x.dtype
# Perform the input norm and projection
B, H, W, C = x.shape
x = self.norm(x.astype(mx.float32)).astype(dtype).reshape(B, -1, C)
x = self.proj_in(x)
# Apply the transformer
for block in self.transformer_blocks:
x = block(x, encoder_x, attn_mask, encoder_attn_mask)
# Apply the output projection and reshape
x = self.proj_out(x)
x = x.reshape(B, H, W, C)
return x + input_x
class ResnetBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: Optional[int] = None,
groups: int = 32,
temb_channels: Optional[int] = None,
):
super().__init__()
out_channels = out_channels or in_channels
self.norm1 = nn.GroupNorm(groups, in_channels, pytorch_compatible=True)
self.conv1 = nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=1, padding=1
)
if temb_channels is not None:
self.time_emb_proj = nn.Linear(temb_channels, out_channels)
self.norm2 = nn.GroupNorm(groups, out_channels, pytorch_compatible=True)
self.conv2 = nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=1, padding=1
)
if in_channels != out_channels:
self.conv_shortcut = nn.Linear(in_channels, out_channels)
def __call__(self, x, temb=None):
dtype = x.dtype
if temb is not None:
temb = self.time_emb_proj(nn.silu(temb))
y = self.norm1(x.astype(mx.float32)).astype(dtype)
y = nn.silu(y)
y = self.conv1(y)
if temb is not None:
y = y + temb[:, None, None, :]
y = self.norm2(y.astype(mx.float32)).astype(dtype)
y = nn.silu(y)
y = self.conv2(y)
x = y + (x if "conv_shortcut" not in self else self.conv_shortcut(x))
return x
class UNetBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
temb_channels: int,
prev_out_channels: Optional[int] = None,
num_layers: int = 1,
transformer_layers_per_block: int = 1,
num_attention_heads: int = 8,
cross_attention_dim=1280,
resnet_groups: int = 32,
add_downsample=True,
add_upsample=True,
add_cross_attention=True,
):
super().__init__()
# Prepare the in channels list for the resnets
if prev_out_channels is None:
in_channels_list = [in_channels] + [out_channels] * (num_layers - 1)
else:
in_channels_list = [prev_out_channels] + [out_channels] * (num_layers - 1)
res_channels_list = [out_channels] * (num_layers - 1) + [in_channels]
in_channels_list = [
a + b for a, b in zip(in_channels_list, res_channels_list)
]
# Add resnet blocks that also process the time embedding
self.resnets = [
ResnetBlock2D(
in_channels=ic,
out_channels=out_channels,
temb_channels=temb_channels,
groups=resnet_groups,
)
for ic in in_channels_list
]
# Add optional cross attention layers
if add_cross_attention:
self.attentions = [
Transformer2D(
in_channels=out_channels,
model_dims=out_channels,
num_heads=num_attention_heads,
num_layers=transformer_layers_per_block,
encoder_dims=cross_attention_dim,
)
for i in range(num_layers)
]
# Add an optional downsampling layer
if add_downsample:
self.downsample = nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=2, padding=1
)
# or upsampling layer
if add_upsample:
self.upsample = nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=1, padding=1
)
def __call__(
self,
x,
encoder_x=None,
temb=None,
attn_mask=None,
encoder_attn_mask=None,
residual_hidden_states=None,
):
output_states = []
for i in range(len(self.resnets)):
if residual_hidden_states is not None:
x = mx.concatenate([x, residual_hidden_states.pop()], axis=-1)
x = self.resnets[i](x, temb)
if "attentions" in self:
x = self.attentions[i](x, encoder_x, attn_mask, encoder_attn_mask)
output_states.append(x)
if "downsample" in self:
x = self.downsample(x)
output_states.append(x)
if "upsample" in self:
x = self.upsample(upsample_nearest(x))
output_states.append(x)
return x, output_states
class UNetModel(nn.Module):
"""The conditional 2D UNet model that actually performs the denoising."""
def __init__(self, config: UNetConfig, shard: Shard):
super().__init__()
self.shard = shard
self.start_layer = shard.start_layer
self.end_layer = shard.end_layer
self.layers_range = list(range(self.start_layer, self.end_layer+1))
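# Global UNet shard layout assumed here: layers 0-3 are the down blocks, 4 is the mid block, 5-8 are the up blocks, and layer 8 also owns conv_norm_out/conv_out.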
if shard.is_first_layer():
self.conv_in = nn.Conv2d(
config.in_channels,
config.block_out_channels[0],
config.conv_in_kernel,
padding=(config.conv_in_kernel - 1) // 2,
)
self.timesteps = nn.SinusoidalPositionalEncoding(
config.block_out_channels[0],
max_freq=1,
min_freq=math.exp(
-math.log(10000) + 2 * math.log(10000) / config.block_out_channels[0]
),
scale=1.0,
cos_first=True,
full_turns=False,
)
self.time_embedding = TimestepEmbedding(
config.block_out_channels[0],
config.block_out_channels[0] * 4,
)
if config.addition_embed_type == "text_time":
self.add_time_proj = nn.SinusoidalPositionalEncoding(
config.addition_time_embed_dim,
max_freq=1,
min_freq=math.exp(
-math.log(10000)
+ 2 * math.log(10000) / config.addition_time_embed_dim
),
scale=1.0,
cos_first=True,
full_turns=False,
)
self.add_embedding = TimestepEmbedding(
config.projection_class_embeddings_input_dim,
config.block_out_channels[0] * 4,
)
# Make the downsampling blocks
block_channels = [config.block_out_channels[0]] + list(
config.block_out_channels
)
self.down_blocks = []
for i, (in_channels, out_channels) in enumerate(zip(block_channels, block_channels[1:])):
if i in self.layers_range:
self.down_blocks.append(
UNetBlock2D(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=config.block_out_channels[0] * 4,
num_layers=config.layers_per_block[i],
transformer_layers_per_block=config.transformer_layers_per_block[i],
num_attention_heads=config.num_attention_heads[i],
cross_attention_dim=config.cross_attention_dim[i],
resnet_groups=config.norm_num_groups,
add_downsample=(i < len(config.block_out_channels) - 1),
add_upsample=False,
add_cross_attention="CrossAttn" in config.down_block_types[i],
)
)
else:
self.down_blocks.append(nn.Identity())
# Make the middle block
if 4 in self.layers_range:
self.mid_blocks = [
ResnetBlock2D(
in_channels=config.block_out_channels[-1],
out_channels=config.block_out_channels[-1],
temb_channels=config.block_out_channels[0] * 4,
groups=config.norm_num_groups,
),
Transformer2D(
in_channels=config.block_out_channels[-1],
model_dims=config.block_out_channels[-1],
num_heads=config.num_attention_heads[-1],
num_layers=config.transformer_layers_per_block[-1],
encoder_dims=config.cross_attention_dim[-1],
),
ResnetBlock2D(
in_channels=config.block_out_channels[-1],
out_channels=config.block_out_channels[-1],
temb_channels=config.block_out_channels[0] * 4,
groups=config.norm_num_groups,
),
]
# Make the upsampling blocks
block_channels = (
[config.block_out_channels[0]]
+ list(config.block_out_channels)
+ [config.block_out_channels[-1]]
)
total_items = len(block_channels) - 3
reversed_channels = list(reversed(list(zip(block_channels, block_channels[1:], block_channels[2:]))))
self.up_blocks = []
for rev_i, (in_channels, out_channels, prev_out_channels) in enumerate(reversed_channels):
i = total_items - rev_i
if rev_i+5 in self.layers_range:
self.up_blocks.append(
UNetBlock2D(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=config.block_out_channels[0] * 4,
prev_out_channels=prev_out_channels,
num_layers=config.layers_per_block[i] + 1,
transformer_layers_per_block=config.transformer_layers_per_block[i],
num_attention_heads=config.num_attention_heads[i],
cross_attention_dim=config.cross_attention_dim[i],
resnet_groups=config.norm_num_groups,
add_downsample=False,
add_upsample=(i > 0),
add_cross_attention="CrossAttn" in config.up_block_types[i],
)
)
else:
self.up_blocks.append(nn.Identity())
if shard.is_last_layer():
self.conv_norm_out = nn.GroupNorm(
config.norm_num_groups,
config.block_out_channels[0],
pytorch_compatible=True,
)
self.conv_out = nn.Conv2d(
config.block_out_channels[0],
config.out_channels,
config.conv_out_kernel,
padding=(config.conv_out_kernel - 1) // 2,
)
def __call__(
self,
x,
timestep,
encoder_x,
attn_mask=None,
encoder_attn_mask=None,
text_time=None,
residuals=None,
):
# Compute the time embeddings
temb = self.timesteps(timestep).astype(x.dtype)
temb = self.time_embedding(temb)
# Add the extra text_time conditioning
if text_time is not None:
text_emb, time_ids = text_time
emb = self.add_time_proj(time_ids).flatten(1).astype(x.dtype)
emb = mx.concatenate([text_emb, emb], axis=-1)
emb = self.add_embedding(emb)
temb = temb + emb
if self.shard.is_first_layer():
# Preprocess the input
x = self.conv_in(x)
residuals = [x]
# Run the downsampling part of the unet
for i in range(len(self.down_blocks)):
if i in self.layers_range:
x, res = self.down_blocks[i](
x,
encoder_x=encoder_x,
temb=temb,
attn_mask=attn_mask,
encoder_attn_mask=encoder_attn_mask,
)
residuals.extend(res)
else:
x= self.down_blocks[i](x)
if 4 in self.layers_range:
# Run the middle part of the unet
x = self.mid_blocks[0](x, temb)
x = self.mid_blocks[1](x, encoder_x, attn_mask, encoder_attn_mask)
x = self.mid_blocks[2](x, temb)
# Run the upsampling part of the unet
for i in range(len(self.up_blocks)):
if i+5 in self.layers_range:
x, _ = self.up_blocks[i](
x,
encoder_x=encoder_x,
temb=temb,
attn_mask=attn_mask,
encoder_attn_mask=encoder_attn_mask,
residual_hidden_states=residuals,
)
else:
x= self.up_blocks[i](x)
# Postprocess the output
if self.shard.is_last_layer():
dtype = x.dtype
x = self.conv_norm_out(x.astype(mx.float32)).astype(dtype)
x = nn.silu(x)
x = self.conv_out(x)
return x, residuals
def sanitize(self, weights):
sanitized_weights = {}
for key, value in weights.items():
k1=""
k2=""
if "downsamplers" in key:
key = key.replace("downsamplers.0.conv", "downsample")
if "upsamplers" in key:
key = key.replace("upsamplers.0.conv", "upsample")
# Map the mid block
if "mid_block.resnets.0" in key:
key = key.replace("mid_block.resnets.0", "mid_blocks.0")
if "mid_block.attentions.0" in key:
key = key.replace("mid_block.attentions.0", "mid_blocks.1")
if "mid_block.resnets.1" in key:
key = key.replace("mid_block.resnets.1", "mid_blocks.2")
# Map attention layers
if "to_k" in key:
key = key.replace("to_k", "key_proj")
if "to_out.0" in key:
key = key.replace("to_out.0", "out_proj")
if "to_q" in key:
key = key.replace("to_q", "query_proj")
if "to_v" in key:
key = key.replace("to_v", "value_proj")
# Map transformer ffn
if "ff.net.2" in key:
key = key.replace("ff.net.2", "linear3")
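# ff.net.0.proj fuses the two GEGLU projections; split it into the separate linear1/linear2 weights used above.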
if "ff.net.0" in key:
k1 = key.replace("ff.net.0.proj", "linear1")
k2 = key.replace("ff.net.0.proj", "linear2")
v1, v2 = mx.split(value, 2)
if "conv_shortcut.weight" in key:
value = value.squeeze()
# Transform the weights from 1x1 convs to linear
if len(value.shape) == 4 and ("proj_in" in key or "proj_out" in key):
value = value.squeeze()
if len(value.shape) == 4:
value = value.transpose(0, 2, 3, 1)
value = value.reshape(-1).reshape(value.shape)
if key.startswith("conv_in") :
if 0 not in self.layers_range:
continue
if key.startswith("down_blocks"):
layer_num = int(key.split(".")[1])
if layer_num not in self.layers_range:
continue
if key.startswith("mid_block"):
if 4 not in self.layers_range:
continue
if key.startswith("up_blocks"):
layer_num = int(key.split(".")[1])
if (layer_num+5) not in self.layers_range:
continue
if key.startswith("conv_out") or key.startswith("conv_norm_out"):
if 8 not in self.layers_range:
continue
if len(k1)>0:
sanitized_weights[k1] = v1
sanitized_weights[k2] = v2
else:
sanitized_weights[key] = value
return sanitized_weights

View File

@@ -0,0 +1,429 @@
# Adapted from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/vae.py
import math
from typing import List
import mlx.core as mx
import mlx.nn as nn
from .unet import ResnetBlock2D, upsample_nearest
from dataclasses import dataclass, field
from exo.inference.shard import Shard
from typing import Tuple
import inspect
from ..base import IdentityBlock
@dataclass
class AutoencoderConfig:
in_channels: int = 3
out_channels: int = 3
latent_channels_out: int = 8
latent_channels_in: int = 4
block_out_channels: Tuple[int] = (128, 256, 512, 512)
layers_per_block: int = 2
norm_num_groups: int = 32
scaling_factor: float = 0.18215
weight_files: List[str] = field(default_factory=lambda: [])
@classmethod
def from_dict(cls, params):
return cls(**{k: v for k, v in params.items() if k in inspect.signature(cls).parameters})
@dataclass
class ModelArgs(AutoencoderConfig):
shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
def __post_init__(self):
if isinstance(self.shard, dict):
self.shard = Shard(**self.shard)
if not isinstance(self.shard, Shard):
raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
if not self.shard.is_first_layer():
self.vision_config = None
class Attention(nn.Module):
"""A single head unmasked attention for use with the VAE."""
def __init__(self, dims: int, norm_groups: int = 32):
super().__init__()
self.group_norm = nn.GroupNorm(norm_groups, dims, pytorch_compatible=True)
self.query_proj = nn.Linear(dims, dims)
self.key_proj = nn.Linear(dims, dims)
self.value_proj = nn.Linear(dims, dims)
self.out_proj = nn.Linear(dims, dims)
def __call__(self, x):
B, H, W, C = x.shape
y = self.group_norm(x)
queries = self.query_proj(y).reshape(B, H * W, C)
keys = self.key_proj(y).reshape(B, H * W, C)
values = self.value_proj(y).reshape(B, H * W, C)
scale = 1 / math.sqrt(queries.shape[-1])
scores = (queries * scale) @ keys.transpose(0, 2, 1)
attn = mx.softmax(scores, axis=-1)
y = (attn @ values).reshape(B, H, W, C)
y = self.out_proj(y)
x = x + y
return x
class EncoderDecoderBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
num_layers: int = 1,
resnet_groups: int = 32,
add_downsample=True,
add_upsample=True,
):
super().__init__()
# Add the resnet blocks
self.resnets = [
ResnetBlock2D(
in_channels=in_channels if i == 0 else out_channels,
out_channels=out_channels,
groups=resnet_groups,
)
for i in range(num_layers)
]
# Add an optional downsampling layer
if add_downsample:
self.downsample = nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=2, padding=0
)
# or upsampling layer
if add_upsample:
self.upsample = nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=1, padding=1
)
def __call__(self, x):
for resnet in self.resnets:
x = resnet(x)
if "downsample" in self:
x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)])
x = self.downsample(x)
if "upsample" in self:
x = self.upsample(upsample_nearest(x))
return x
class Encoder(nn.Module):
"""Implements the encoder side of the Autoencoder."""
def __init__(
self,
in_channels: int,
latent_channels_out: int,
block_out_channels: List[int] = [64],
layers_per_block: int = 2,
resnet_groups: int = 32,
layers_range: List[int] = [],
shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
):
super().__init__()
self.layers_range = layers_range
self.shard = shard
if self.shard.is_first_layer():
self.conv_in = nn.Conv2d(
in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1
)
channels = [block_out_channels[0]] + list(block_out_channels)
self.down_blocks = []
current_layer = 1
for i, (in_channels, out_channels) in enumerate(zip(channels, channels[1:])):
if current_layer in self.layers_range:
self.down_blocks.append(
EncoderDecoderBlock2D(
in_channels,
out_channels,
num_layers=layers_per_block,
resnet_groups=resnet_groups,
add_downsample=i < len(block_out_channels) - 1,
add_upsample=False,
)
)
else:
self.down_blocks.append(IdentityBlock())
current_layer += 1
if self.shard.is_last_layer():
self.mid_blocks = [
ResnetBlock2D(
in_channels=block_out_channels[-1],
out_channels=block_out_channels[-1],
groups=resnet_groups,
),
Attention(block_out_channels[-1], resnet_groups),
ResnetBlock2D(
in_channels=block_out_channels[-1],
out_channels=block_out_channels[-1],
groups=resnet_groups,
),
]
self.conv_norm_out = nn.GroupNorm(
resnet_groups, block_out_channels[-1], pytorch_compatible=True
)
self.conv_out = nn.Conv2d(block_out_channels[-1], latent_channels_out, 3, padding=1)
def __call__(self, x):
if self.shard.is_first_layer():
x = self.conv_in(x)
for l in self.down_blocks:
x = l(x)
if self.shard.is_last_layer():
x = self.mid_blocks[0](x)
x = self.mid_blocks[1](x)
x = self.mid_blocks[2](x)
x = self.conv_norm_out(x)
x = nn.silu(x)
x = self.conv_out(x)
return x
class Decoder(nn.Module):
"""Implements the decoder side of the Autoencoder."""
def __init__(
self,
in_channels: int,
out_channels: int,
shard: Shard,
layer_range: List[int],
block_out_channels: List[int] = [64],
layers_per_block: int = 2,
resnet_groups: int = 32,
):
super().__init__()
self.out_channels = out_channels
self.layers_range = layer_range
if 0 in layer_range:
self.conv_in = nn.Conv2d(
in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1
)
if 0 in layer_range:
self.mid_blocks = [
ResnetBlock2D(
in_channels=block_out_channels[-1],
out_channels=block_out_channels[-1],
groups=resnet_groups,
),
Attention(block_out_channels[-1], resnet_groups),
ResnetBlock2D(
in_channels=block_out_channels[-1],
out_channels=block_out_channels[-1],
groups=resnet_groups,
),
]
channels = list(reversed(block_out_channels))
channels = [channels[0]] + channels
self.up_blocks = []
current_layer = 1
for i, (in_channels, out_channels) in enumerate(zip(channels, channels[1:])):
if current_layer in layer_range:
self.up_blocks.append(
EncoderDecoderBlock2D(
in_channels,
out_channels,
num_layers=layers_per_block,
resnet_groups=resnet_groups,
add_downsample=False,
add_upsample=i < len(block_out_channels) - 1,
)
)
else:
self.up_blocks.append(IdentityBlock())
current_layer += 1
if 4 in layer_range:
self.conv_norm_out = nn.GroupNorm(
resnet_groups, block_out_channels[0], pytorch_compatible=True
)
self.conv_out = nn.Conv2d(block_out_channels[0], self.out_channels, 3, padding=1)
def __call__(self, x):
if 0 in self.layers_range:
x = self.conv_in(x)
x = self.mid_blocks[0](x)
x = self.mid_blocks[1](x)
x = self.mid_blocks[2](x)
for l in self.up_blocks:
x = l(x)
if 4 in self.layers_range:
x = self.conv_norm_out(x)
x = nn.silu(x)
x = self.conv_out(x)
return x
class Autoencoder(nn.Module):
"""The autoencoder that allows us to perform diffusion in the latent space."""
def __init__(self, config: AutoencoderConfig, shard: Shard, model_shard: str):
super().__init__()
self.shard = shard
self.start_layer = shard.start_layer
self.end_layer = shard.end_layer
self.layers_range = list(range(self.start_layer, self.end_layer+1))
self.latent_channels = config.latent_channels_in
self.scaling_factor = config.scaling_factor
self.model_shard = model_shard
if self.model_shard == "vae_encoder":
self.encoder = Encoder(
config.in_channels,
config.latent_channels_out,
config.block_out_channels,
config.layers_per_block,
resnet_groups=config.norm_num_groups,
layers_range=self.layers_range,
shard=shard
)
if self.shard.is_last_layer():
self.quant_proj = nn.Linear(
config.latent_channels_out, config.latent_channels_out
)
if self.model_shard == "vae_decoder":
self.decoder = Decoder(
config.latent_channels_in,
config.out_channels,
shard,
self.layers_range,
config.block_out_channels,
config.layers_per_block + 1,
resnet_groups=config.norm_num_groups,
)
if self.shard.is_first_layer():
self.post_quant_proj = nn.Linear(
config.latent_channels_in, config.latent_channels_in
)
def decode(self, z):
if self.shard.is_first_layer():
z = z / self.scaling_factor
z=self.post_quant_proj(z)
return self.decoder(z)
def encode(self, x):
x = self.encoder(x)
if self.shard.is_last_layer():
x = self.quant_proj(x)
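# Use the posterior mean (no sampling) as the latent, rescaled by the VAE scaling factor.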
mean, logvar = x.split(2, axis=-1)
mean = mean * self.scaling_factor
logvar = logvar + 2 * math.log(self.scaling_factor)
x = mean
return x
def __call__(self, x, key=None):
mean, logvar = self.encode(x)
z = mx.random.normal(mean.shape, key=key) * mx.exp(0.5 * logvar) + mean
x_hat = self.decode(z)
return dict(x_hat=x_hat, z=z, mean=mean, logvar=logvar)
def sanitize(self, weights):
shard = self.shard
layers = self.layers_range
sanitized_weights = {}
for key, value in weights.items():
if "downsamplers" in key:
key = key.replace("downsamplers.0.conv", "downsample")
if "upsamplers" in key:
key = key.replace("upsamplers.0.conv", "upsample")
# Map attention layers
if "key" in key:
key = key.replace("key", "key_proj")
if "proj_attn" in key:
key = key.replace("proj_attn", "out_proj")
if "query" in key:
key = key.replace("query", "query_proj")
if "value" in key:
key = key.replace("value", "value_proj")
# Map the mid block
if "mid_block.resnets.0" in key:
key = key.replace("mid_block.resnets.0", "mid_blocks.0")
if "mid_block.attentions.0" in key:
key = key.replace("mid_block.attentions.0", "mid_blocks.1")
if "mid_block.resnets.1" in key:
key = key.replace("mid_block.resnets.1", "mid_blocks.2")
# Map the quant/post_quant layers
if "quant_conv" in key:
key = key.replace("quant_conv", "quant_proj")
value = value.squeeze()
# Map the conv_shortcut to linear
if "conv_shortcut.weight" in key:
value = value.squeeze()
if len(value.shape) == 4:
value = value.transpose(0, 2, 3, 1)
value = value.reshape(-1).reshape(value.shape)
if "post_quant_conv" in key :
key = key.replace("quant_conv", "quant_proj")
value = value.squeeze()
if 'decoder' in key and self.model_shard == "vae_decoder":
if key.startswith("decoder.mid_blocks."):
if 0 in layers:
sanitized_weights[key] = value
if "conv_in" in key and 0 in layers:
sanitized_weights[key] = value
if key.startswith("decoder.up_blocks."):
layer_num = int(key.split(".")[2])+1
if layer_num in layers:
sanitized_weights[key] = value
if key.startswith("decoder.conv_norm_out") and 4 in layers:
sanitized_weights[key] = value
if key.startswith("decoder.conv_out") and 4 in layers:
sanitized_weights[key] = value
if self.model_shard == "vae_decoder":
if key.startswith("post_quant_proj") and 0 in layers:
sanitized_weights[key] = value
if self.model_shard == "vae_encoder":
if key.startswith("encoder."):
if "conv_in" in key and shard.is_first_layer():
sanitized_weights[key] = value
if key.startswith("encoder.down_blocks."):
layer_num = int(key.split(".")[2])+1
if layer_num in layers:
sanitized_weights[key] = value
if key.startswith("encoder.mid_blocks.") and shard.is_last_layer():
sanitized_weights[key] = value
if "conv_norm_out" in key and shard.is_last_layer():
sanitized_weights[key] = value
if "conv_out" in key and shard.is_last_layer():
sanitized_weights[key] = value
if key.startswith("quant_proj") and shard.is_last_layer():
sanitized_weights[key] = value
return sanitized_weights

View File

@@ -0,0 +1,7 @@
# Perf improvements
Target: 460 tok/sec
- removing sample goes from 369 -> 402
- performance degrades as we generate more tokens
- make mlx inference engine synchronous, removing thread pool executor: 402 -> 413
- remove self.on_opaque_status.trigger_all: 413 -> 418
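A rough way to reproduce these numbers against the InferenceEngine API in this PR (hypothetical harness, not part of the repo; `measure_tok_per_sec`, the request id and `n_tokens` are placeholder names):

import time

async def measure_tok_per_sec(engine, shard, prompt, n_tokens=256):
    # Prime with the prompt, then time the token-by-token decode loop.
    out, state = await engine.infer_prompt("bench", shard, prompt)
    tok = await engine.sample(out)
    start = time.perf_counter()
    for _ in range(n_tokens):
        out, state = await engine.infer_tensor("bench", shard, tok, state)
        tok = await engine.sample(out)
    return n_tokens / (time.perf_counter() - start)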

View File

@@ -1,151 +1,179 @@
import numpy as np
import mlx.core as mx
import mlx.nn as nn
from mlx_lm.sample_utils import top_p_sampling
from mlx_lm.sample_utils import make_sampler
import mlx.optimizers as optim
from ..inference_engine import InferenceEngine
from .sharded_utils import load_shard, get_image_from_str
from .losses import loss_fns
from .sharded_utils import load_model_shard, resolve_tokenizer
from .losses import loss_fns
from ..shard import Shard
from typing import Dict, Optional, Tuple
from exo.download.shard_download import ShardDownloader
import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from collections import OrderedDict
from mlx_lm.models.cache import make_prompt_cache
def sample_logits(
logits: mx.array,
temp: float = 0.0,
top_p: float = 1.0,
logit_bias: Optional[Dict[int, float]] = None
) -> Tuple[mx.array, float]:
if logit_bias:
indices = mx.array(list(logit_bias.keys()))
values = mx.array(list(logit_bias.values()))
logits[:, indices] += values
if temp == 0:
token = mx.argmax(logits, axis=-1)
else:
if top_p > 0 and top_p < 1.0:
token = top_p_sampling(logits, top_p, temp)
else:
token = mx.random.categorical(logits*(1/temp))
return token
from concurrent.futures import ThreadPoolExecutor
class MLXDynamicShardInferenceEngine(InferenceEngine):
def __init__(self, shard_downloader: ShardDownloader):
self.shard = None
self.shard_downloader = shard_downloader
self.executor = ThreadPoolExecutor(max_workers=1)
self.caches = OrderedDict()
self.sampler_params: tuple[float, float, float, int] = (0.0, 0.0, 0.0, 1)
self.sampler = make_sampler(*self.sampler_params)
self._mlx_thread = ThreadPoolExecutor(max_workers=1, thread_name_prefix="mlx")
self._tokenizer_thread = ThreadPoolExecutor(max_workers=1, thread_name_prefix="tokenizer")
self.session = {}
self._shard_lock = asyncio.Lock()
async def _eval_mlx(self, *args):
await asyncio.get_running_loop().run_in_executor(self._mlx_thread, mx.eval, *args)
async def poll_state(self, request_id: str, max_caches=2):
if request_id in self.caches:
self.caches.move_to_end(request_id)
else:
newcache = await asyncio.get_running_loop().run_in_executor(self.executor, make_prompt_cache, self.model)
newcache = make_prompt_cache(self.model)
if len(self.caches) > max_caches:
self.caches.popitem(last=False)
self.caches[request_id] = newcache
return {"cache": self.caches[request_id]}
async def sample(self, x, temp: float = 0.0, top_p: float = 1.0) -> np.ndarray:
y = mx.array(x)
logits = y[:, -1, :]
out = np.array(sample_logits(logits, temp=temp, top_p=top_p), dtype=int)
return out
async def sample(self, x: np.ndarray, temp: float = 0.0, top_p: float = 1.0) -> np.ndarray:
if (temp, top_p, 0.0, 1) != self.sampler_params:
self.sampler_params = (temp, top_p, 0.0, 1)
self.sampler = make_sampler(*self.sampler_params)
logits = mx.array(x)
logits = logits[:, -1, :]
logprobs = logits - mx.logsumexp(logits, keepdims=True)
result = self.sampler(logprobs)
await self._eval_mlx(result)
return np.asarray(result, dtype=int)
async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
await self.ensure_shard(shard)
tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt)
return np.array(tokens)
return np.asarray(
await asyncio.get_running_loop().run_in_executor(
self._tokenizer_thread,
self.tokenizer.encode,
prompt
)
)
async def decode(self, shard: Shard, tokens) -> str:
await self.ensure_shard(shard)
tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)
return tokens
return await asyncio.get_running_loop().run_in_executor(
self._tokenizer_thread,
self.tokenizer.decode,
tokens
)
async def save_checkpoint(self, shard: Shard, path: str):
await self.ensure_shard(shard)
await asyncio.get_running_loop().run_in_executor(self.executor, self.model.save_weights, path)
await asyncio.get_running_loop().run_in_executor(self._mlx_thread, lambda: self.model.save_weights(path))
async def load_checkpoint(self, shard: Shard, path: str):
await self.ensure_shard(shard)
await asyncio.get_running_loop().run_in_executor(self.executor, self.model.load_weights, path)
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
await asyncio.get_running_loop().run_in_executor(self._mlx_thread, lambda: self.model.load_weights(path))
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
await self.ensure_shard(shard)
loop = asyncio.get_running_loop()
state = await self.poll_state(request_id)
state = await self.poll_state(request_id) if self.model.model_type != 'StableDiffusionPipeline' else {}
x = mx.array(input_data)
output_data: np.ndarray = np.array(await loop.run_in_executor(self.executor, lambda: self.model(x, **state)))
return output_data
if self.model.model_type != 'StableDiffusionPipeline':
output_data = await asyncio.get_running_loop().run_in_executor(
self._mlx_thread,
lambda: self.model(x, **state, **(inference_state or {}))
)
inference_state = None
else:
result = await asyncio.get_running_loop().run_in_executor(
self._mlx_thread,
lambda: self.model(x, **state, **(inference_state or {}))
)
output_data, inference_state = result
await self._eval_mlx(output_data)
output_data = await asyncio.get_running_loop().run_in_executor(
self._mlx_thread,
lambda: np.array(output_data, copy=False)
)
return output_data, inference_state
async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce"):
await self.ensure_shard(shard)
await self.save_session('loss', loss_fns[loss])
loop = asyncio.get_running_loop()
#print(f"evaluate in <- {inputs}")
x = mx.array(inputs)
y = mx.array(targets)
l = mx.array(lengths)
score = await loop.run_in_executor(self.executor, self.session['loss'], self.model, x, y, l)
#print(f"evaluate out -> {score}")
score = await asyncio.get_running_loop().run_in_executor(
self._mlx_thread,
lambda: self.session['loss'](self.model, x, y, l)
)
return score
async def ensure_train(self, shard: Shard, loss: str, opt=optim.SGD, lr=1e-5, trainable_layers=['input_layernorm', 'gate_proj']):
await self.ensure_shard(shard)
if 'train_layers' not in self.session or self.session['train_layers'] != trainable_layers:
await self.save_session('train_layers', trainable_layers)
self.model.freeze()
self.model.apply_to_modules(lambda k, v: v.unfreeze() if any(lambda: k.endswith(i) for i in trainable_layers) else None)
def freeze_unfreeze():
self.model.freeze()
self.model.apply_to_modules(
lambda k, v: v.unfreeze() if any(k.endswith(layer_name) for layer_name in trainable_layers) else None
)
await asyncio.get_running_loop().run_in_executor(self._mlx_thread, freeze_unfreeze)
if 'lossname' not in self.session or 'LVaG' not in self.session or self.session['lossname'] != loss:
await self.save_session('lossname', loss)
await self.save_session('LVaG', nn.value_and_grad(self.model, loss_fns[loss]))
if 'opt' not in self.session:
await self.save_session('opt', opt(lr))
return True
async def train(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce", opt=optim.SGD, lr=1e-5):
loop = asyncio.get_running_loop()
nothin = await self.ensure_train(shard, loss, opt, lr)
await self.ensure_train(shard, loss, opt, lr)
def train_step(inp, tar, lng):
lval, grad = self.session['LVaG'](self.model, inp, tar, lng)
gradlayers = grad['model']['layers']
self.session['opt'].update(self.model, grad)
mx.eval(self.model.parameters(), self.session['opt'].state, lval)
return lval, gradlayers
return lval, gradlayers, (self.model.parameters(), self.session['opt'].state, lval)
x = mx.array(inputs)
y = mx.array(targets)
l = mx.array(lengths)
score, gradients, eval_args = await asyncio.get_running_loop().run_in_executor(
self._mlx_thread,
lambda: train_step(x, y, l)
)
await self._eval_mlx(*eval_args)
score, gradients = await loop.run_in_executor(self.executor, train_step, x, y, l)
#print(f"{score=}")
layers = [{k: v["weight"] for k,v in l.items() if 'weight' in v} for l in gradients if l]
#print(layers[0])
return score, np.array(layers[0]['input_layernorm'])
layers = [{k: v["weight"] for k, v in layer.items() if 'weight' in v} for layer in gradients if layer]
first_layer = np.array(layers[0]['input_layernorm'], copy=False)
await self._eval_mlx(first_layer)
return score, first_layer
async def ensure_shard(self, shard: Shard):
if self.shard == shard:
return
model_path = await self.shard_downloader.ensure_shard(shard, self.__class__.__name__)
if self.shard != shard:
def load_shard_wrapper():
return asyncio.run(load_shard(model_path, shard))
model_shard, self.tokenizer = await asyncio.get_running_loop().run_in_executor(self.executor, load_shard_wrapper)
self.shard = shard
self.model = model_shard
self.caches = OrderedDict()
self.session = {}
async with self._shard_lock:
if self.shard == shard: return
model_path = await self.shard_downloader.ensure_shard(shard, self.__class__.__name__)
if self.shard != shard:
model_shard = await asyncio.get_running_loop().run_in_executor(
self._mlx_thread,
lambda: load_model_shard(model_path, shard, lazy=False)
)
if hasattr(model_shard, "tokenizer"):
self.tokenizer = model_shard.tokenizer
else:
self.tokenizer = await resolve_tokenizer(model_path)
self.shard = shard
self.model = model_shard
self.caches = OrderedDict()
self.session = {}
async def cleanup(self):
self._mlx_thread.shutdown(wait=True)

View File

@@ -62,8 +62,16 @@ def _get_classes(config: dict):
def load_config(model_path: Path) -> dict:
try:
with open(model_path/"config.json", "r") as f:
config = json.load(f)
config_path = model_path / "config.json"
if config_path.exists():
with open(config_path, "r") as f:
config = json.load(f)
return config
model_index_path = model_path / "model_index.json"
if model_index_path.exists():
config = load_model_index(model_path, model_index_path)
return config
except FileNotFoundError:
logging.error(f"Config file not found in {model_path}")
raise
@@ -110,6 +118,24 @@ def load_model_shard(
# Try weight for back-compat
weight_files = glob.glob(str(model_path/"weight*.safetensors"))
model_class, model_args_class = _get_classes(config=config)
class ShardedModel(model_class):
def __init__(self, args):
super().__init__(args)
self.shard = Shard(args.shard.model_id, args.shard.start_layer, args.shard.end_layer, args.shard.n_layers)
def __call__(self, x, *args, **kwargs):
y = super().__call__(x, *args, **kwargs)
return y
model_args = model_args_class.from_dict(config)
model = ShardedModel(model_args)
if config.get("model_index", False):
model.load()
return model
if not weight_files:
logging.error(f"No safetensors found in {model_path}")
raise FileNotFoundError(f"No safetensors found in {model_path}")
@@ -129,19 +155,7 @@ def load_model_shard(
weights.update(mx.load(wf))
model_class, model_args_class = _get_classes(config=config)
class ShardedModel(model_class):
def __init__(self, args):
super().__init__(args)
self.shard = Shard(args.shard.model_id, args.shard.start_layer, args.shard.end_layer, args.shard.n_layers)
def __call__(self, x, *args, **kwargs):
y = super().__call__(x, *args, **kwargs)
return y
model_args = model_args_class.from_dict(config)
model = ShardedModel(model_args)
if hasattr(model, "sanitize"):
weights = model.sanitize(weights)
@@ -186,6 +200,9 @@ async def load_shard(
processor.eos_token_id = processor.tokenizer.eos_token_id
processor.encode = processor.tokenizer.encode
return model, processor
elif hasattr(model, "tokenizer"):
tokenizer = model.tokenizer
return model, tokenizer
else:
tokenizer = await resolve_tokenizer(model_path)
return model, tokenizer
@@ -214,3 +231,27 @@ async def get_image_from_str(_image_str: str):
return img
else:
raise ValueError("Invalid image_str format. Must be a URL or a base64 encoded image.")
# loading a combined config for all models in the index
def load_model_index(model_path: Path, model_index_path: Path):
models_config = {}
with open(model_index_path, "r") as f:
model_index = json.load(f)
models_config["model_index"] = True
models_config["model_type"] = model_index["_class_name"]
models_config["models"] = {}
for model in model_index.keys():
model_config_path = glob.glob(str(model_path / model / "*config.json"))
if len(model_config_path)>0:
with open(model_config_path[0], "r") as f:
model_config = { }
model_config["model_type"] = model
model_config["config"] = json.load(f)
model_config["path"] = model_path / model
if model_config["path"]/"*model.safetensors":
model_config["config"].update({"weight_files": list(glob.glob(str(model_config["path"]/"*model.safetensors")))})
model_config["path"] = str(model_path / model)
m = {}
m[model] = model_config
models_config.update(m)
return models_config

View File

@@ -0,0 +1,81 @@
import asyncio
import time
import numpy as np
from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
from exo.download.new_shard_download import NewShardDownloader
from exo.inference.shard import Shard
from exo.models import build_base_shard
from collections import deque
from statistics import mean, median
async def test_non_blocking():
# Setup
shard_downloader = NewShardDownloader()
engine = MLXDynamicShardInferenceEngine(shard_downloader)
_shard = build_base_shard("llama-3.1-8b", "MLXDynamicShardInferenceEngine")
shard = Shard(_shard.model_id, _shard.start_layer, _shard.n_layers - 1, _shard.n_layers)
await engine.ensure_shard(shard)
queue = asyncio.Queue()
measurements = deque(maxlen=1000000)
running = True
async def mlx_worker():
try:
start_time = time.time()
count = 0
while running and (time.time() - start_time) < 5: # Hard time limit
start = time.perf_counter_ns()
await engine.infer_prompt("req1", shard, "test prompt")
duration = (time.perf_counter_ns() - start) / 1_000_000 # Convert to ms
count += 1
print(f"MLX operation {count} took: {duration:.3f}ms")
except asyncio.CancelledError:
pass
finally:
print(f"\nTotal MLX operations completed: {count}")
print(f"Average rate: {count/5:.1f} ops/second")
async def latency_producer():
try:
start_time = time.perf_counter_ns()
count = 0
while running:
await queue.put(time.perf_counter_ns())
count += 1
await asyncio.sleep(0) # Yield to event loop without delay
duration = (time.perf_counter_ns() - start_time) / 1e9 # Convert to seconds
print(f"\nProducer iterations: {count}")
print(f"Producer rate: {count/duration:.1f} iterations/second")
except asyncio.CancelledError:
pass
async def latency_consumer():
try:
while running:
timestamp = await queue.get()
latency = (time.perf_counter_ns() - timestamp) / 1_000_000 # Convert to ms
measurements.append(latency)
queue.task_done()
except asyncio.CancelledError:
pass
tasks = [
asyncio.create_task(mlx_worker()),
asyncio.create_task(latency_producer()),
asyncio.create_task(latency_consumer())
]
try:
await asyncio.wait_for(asyncio.gather(*tasks), timeout=6)
except asyncio.TimeoutError:
print("\nTest timed out")
finally:
running = False
for task in tasks:
task.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
print(f"\nFinal measurement count: {len(measurements)}")
if __name__ == "__main__":
asyncio.run(test_non_blocking())

View File

@@ -1,22 +1,16 @@
import pytest
import json
import numpy as np
from exo.inference.dummy_inference_engine import DummyInferenceEngine
from exo.inference.shard import Shard
class MockShardDownloader:
async def ensure_shard(self, shard):
pass
@pytest.mark.asyncio
async def test_dummy_inference_specific():
engine = DummyInferenceEngine(MockShardDownloader())
engine = DummyInferenceEngine()
test_shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)
test_prompt = "This is a test prompt"
result = await engine.infer_prompt("test_request", test_shard, test_prompt)
result, _ = await engine.infer_prompt("test_request", test_shard, test_prompt)
print(f"Inference result shape: {result.shape}")
@@ -26,20 +20,20 @@ async def test_dummy_inference_specific():
@pytest.mark.asyncio
async def test_dummy_inference_engine():
# Initialize the DummyInferenceEngine
engine = DummyInferenceEngine(MockShardDownloader())
engine = DummyInferenceEngine()
# Create a test shard
shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)
# Test infer_prompt
output = await engine.infer_prompt("test_id", shard, "Test prompt")
output, _ = await engine.infer_prompt("test_id", shard, "Test prompt")
assert isinstance(output, np.ndarray), "Output should be a numpy array"
assert output.ndim == 2, "Output should be 2-dimensional"
# Test infer_tensor
input_tensor = np.array([[1, 2, 3]])
output = await engine.infer_tensor("test_id", shard, input_tensor)
output, _ = await engine.infer_tensor("test_id", shard, input_tensor)
assert isinstance(output, np.ndarray), "Output should be a numpy array"
assert output.ndim == 2, "Output should be 2-dimensional"

View File

@@ -1,6 +1,6 @@
from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
from exo.download.hf.hf_shard_download import HFShardDownloader
from exo.inference.inference_engine import InferenceEngine
from exo.download.new_shard_download import NewShardDownloader
from exo.inference.shard import Shard
from exo.helpers import DEBUG
import os
@@ -11,30 +11,30 @@ import numpy as np
# An inference engine should work the same for any number of Shards, as long as the Shards are continuous.
async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str, n_layers: int):
prompt = "In a single word only, what is the last name of the current president of the USA?"
resp_full = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers), prompt=prompt)
resp_full, _ = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers), prompt=prompt)
token_full = await inference_engine_1.sample(resp_full)
token_full = token_full.reshape(1, -1)
next_resp_full = await inference_engine_1.infer_tensor(
next_resp_full, _ = await inference_engine_1.infer_tensor(
"A",
shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers),
input_data=token_full,
)
pp = n_layers // 2
resp1 = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers), prompt=prompt)
resp2 = await inference_engine_2.infer_tensor(
resp1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers), prompt=prompt)
resp2, _ = await inference_engine_2.infer_tensor(
"B",
shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=n_layers - 1, n_layers=n_layers),
input_data=resp1,
)
tokens2 = await inference_engine_1.sample(resp2)
tokens2 = tokens2.reshape(1, -1)
resp3 = await inference_engine_1.infer_tensor(
resp3, _ = await inference_engine_1.infer_tensor(
"B",
shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers),
input_data=tokens2,
)
resp4 = await inference_engine_2.infer_tensor(
resp4, _ = await inference_engine_2.infer_tensor(
"B",
shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=n_layers - 1, n_layers=n_layers),
input_data=resp3,
@@ -44,13 +44,11 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
assert np.array_equal(next_resp_full, resp4)
asyncio.run(test_inference_engine(MLXDynamicShardInferenceEngine(HFShardDownloader()), MLXDynamicShardInferenceEngine(HFShardDownloader()), "llama-3.2-1b", 16))
asyncio.run(test_inference_engine(MLXDynamicShardInferenceEngine(NewShardDownloader()), MLXDynamicShardInferenceEngine(NewShardDownloader()), "llama-3.2-1b", 16))
if os.getenv("RUN_TINYGRAD", default="0") == "1":
import tinygrad
import os
from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
asyncio.run(
test_inference_engine(TinygradDynamicShardInferenceEngine(HFShardDownloader()), TinygradDynamicShardInferenceEngine(HFShardDownloader()), "llama-3-8b", 32)
)
asyncio.run(test_inference_engine(TinygradDynamicShardInferenceEngine(NewShardDownloader()), TinygradDynamicShardInferenceEngine(NewShardDownloader()), "llama-3.2-1b", 32))

View File

@@ -15,7 +15,7 @@ from .stateful_model import make_prompt_state
from .losses import length_masked_ce_loss
from collections import OrderedDict
import asyncio
from typing import Optional
Tensor.no_grad = True
# default settings
TEMPERATURE = int(os.getenv("TEMPERATURE", 0.85))
@@ -61,12 +61,13 @@ def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=No
return model
_executor = ThreadPoolExecutor(max_workers=1) # singleton so tinygrad always runs on the same thread
class TinygradDynamicShardInferenceEngine(InferenceEngine):
def __init__(self, shard_downloader: ShardDownloader):
self.shard = None
self.shard_downloader = shard_downloader
self.executor = ThreadPoolExecutor(max_workers=1)
self.states = OrderedDict()
self.executor = _executor
def poll_state(self, x, request_id: str, max_states=2):
if request_id not in self.states:
@@ -79,8 +80,8 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
return {"start_pos": state.start, "cache": state.cache}
async def sample(self, x: np.ndarray, temp=TEMPERATURE, top_p: float = 0.0) -> np.ndarray:
logits = x[:, -1, :]
def sample_wrapper():
logits = x[:, -1, :]
return sample_logits(Tensor(logits).flatten(), temp, 0, 0.8, top_p, 0.0).realize().numpy().astype(int)
return await asyncio.get_running_loop().run_in_executor(self.executor, sample_wrapper)
@@ -104,7 +105,7 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
state_dict = await asyncio.get_running_loop().run_in_executor(self.executor, get_state_dict, self.model)
safe_save(state_dict, path)
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray) -> np.ndarray:
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
await self.ensure_shard(shard)
def wrap_infer():
x = Tensor(input_data)
@@ -112,9 +113,9 @@ class TinygradDynamicShardInferenceEngine(InferenceEngine):
state = self.poll_state(h, request_id)
out = self.model.forward(h, **state)
self.states[request_id].start += x.shape[1]
return out.realize()
return out.numpy()
output_data = await asyncio.get_running_loop().run_in_executor(self.executor, wrap_infer)
return output_data.numpy()
return output_data, inference_state
async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss=length_masked_ce_loss):
def step(x, y, l):

View File

@@ -322,6 +322,6 @@ def fix_bf16(weights: Dict[Any, Tensor]):
}
if getenv("SUPPORT_BF16", 1):
# TODO: without casting to float16, 70B llama OOM on tinybox.
return {k: v.cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}
return {k: v.cast(dtypes.float32).cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}
# TODO: check if device supports bf16
return {k: v.llvm_bf16_cast(dtypes.half).to(v.device) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}

View File

@@ -1,12 +1,11 @@
import traceback
from aiofiles import os as aios
from os import PathLike
from pathlib import Path
from aiofiles import os as aios
from typing import Union
from transformers import AutoTokenizer, AutoProcessor
import numpy as np
from exo.download.hf.hf_helpers import get_local_snapshot_dir
from exo.helpers import DEBUG
from exo.download.new_shard_download import ensure_downloads_dir
class DummyTokenizer:
@@ -14,7 +13,7 @@ class DummyTokenizer:
self.eos_token_id = 69
self.vocab_size = 1000
def apply_chat_template(self, messages, tokenize=True, add_generation_prompt=True):
def apply_chat_template(self, conversation, tokenize=True, add_generation_prompt=True, tools=None, **kwargs):
return "dummy_tokenized_prompt"
def encode(self, text):
@@ -24,25 +23,25 @@ class DummyTokenizer:
return "dummy" * len(tokens)
async def resolve_tokenizer(model_id: str):
if model_id == "dummy":
async def resolve_tokenizer(repo_id: Union[str, PathLike]):
if repo_id == "dummy":
return DummyTokenizer()
local_path = await get_local_snapshot_dir(model_id)
local_path = await ensure_downloads_dir()/str(repo_id).replace("/", "--")
if DEBUG >= 2: print(f"Checking if local path exists to load tokenizer from local {local_path=}")
try:
if local_path and await aios.path.exists(local_path):
if DEBUG >= 2: print(f"Resolving tokenizer for {model_id=} from {local_path=}")
if DEBUG >= 2: print(f"Resolving tokenizer for {repo_id=} from {local_path=}")
return await _resolve_tokenizer(local_path)
except:
if DEBUG >= 5: print(f"Local check for {local_path=} failed. Resolving tokenizer for {model_id=} normally...")
if DEBUG >= 5: print(f"Local check for {local_path=} failed. Resolving tokenizer for {repo_id=} normally...")
if DEBUG >= 5: traceback.print_exc()
return await _resolve_tokenizer(model_id)
return await _resolve_tokenizer(repo_id)
async def _resolve_tokenizer(model_id_or_local_path: Union[str, PathLike]):
async def _resolve_tokenizer(repo_id_or_local_path: Union[str, PathLike]):
try:
if DEBUG >= 4: print(f"Trying AutoProcessor for {model_id_or_local_path}")
processor = AutoProcessor.from_pretrained(model_id_or_local_path, use_fast=True if "Mistral-Large" in f"{model_id_or_local_path}" else False, trust_remote_code=True)
if DEBUG >= 4: print(f"Trying AutoProcessor for {repo_id_or_local_path}")
processor = AutoProcessor.from_pretrained(repo_id_or_local_path, use_fast=True if "Mistral-Large" in f"{repo_id_or_local_path}" else False, trust_remote_code=True)
if not hasattr(processor, 'eos_token_id'):
processor.eos_token_id = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).eos_token_id
if not hasattr(processor, 'encode'):
@@ -51,14 +50,14 @@ async def _resolve_tokenizer(model_id_or_local_path: Union[str, PathLike]):
processor.decode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).decode
return processor
except Exception as e:
if DEBUG >= 4: print(f"Failed to load processor for {model_id_or_local_path}. Error: {e}")
if DEBUG >= 4: print(f"Failed to load processor for {repo_id_or_local_path}. Error: {e}")
if DEBUG >= 4: print(traceback.format_exc())
try:
if DEBUG >= 4: print(f"Trying AutoTokenizer for {model_id_or_local_path}")
return AutoTokenizer.from_pretrained(model_id_or_local_path, trust_remote_code=True)
if DEBUG >= 4: print(f"Trying AutoTokenizer for {repo_id_or_local_path}")
return AutoTokenizer.from_pretrained(repo_id_or_local_path, trust_remote_code=True)
except Exception as e:
if DEBUG >= 4: print(f"Failed to load tokenizer for {model_id_or_local_path}. Falling back to tinygrad tokenizer. Error: {e}")
if DEBUG >= 4: print(f"Failed to load tokenizer for {repo_id_or_local_path}. Falling back to tinygrad tokenizer. Error: {e}")
if DEBUG >= 4: print(traceback.format_exc())
raise ValueError(f"[TODO] Unsupported model: {model_id_or_local_path}")
raise ValueError(f"[TODO] Unsupported model: {repo_id_or_local_path}")

View File

@@ -3,20 +3,15 @@ import asyncio
import atexit
import signal
import json
import logging
import platform
import os
import sys
import time
import traceback
import uuid
import numpy as np
from functools import partial
from tqdm import tqdm
from tqdm.asyncio import tqdm_asyncio
from exo.train.dataset import load_dataset, iterate_batches, compose
from exo.train.dataset import load_dataset, iterate_batches
from exo.networking.manual.manual_discovery import ManualDiscovery
from exo.networking.manual.network_topology_config import NetworkTopology
from exo.orchestration.node import Node
from exo.networking.grpc.grpc_server import GRPCServer
from exo.networking.udp.udp_discovery import UDPDiscovery
@@ -24,15 +19,41 @@ from exo.networking.tailscale.tailscale_discovery import TailscaleDiscovery
from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle
from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
from exo.api import ChatGPTAPI
from exo.download.shard_download import ShardDownloader, RepoProgressEvent, NoopShardDownloader
from exo.download.hf.hf_shard_download import HFShardDownloader
from exo.download.shard_download import ShardDownloader, NoopShardDownloader
from exo.download.download_progress import RepoProgressEvent
from exo.download.new_shard_download import new_shard_downloader, has_exo_home_read_access, has_exo_home_write_access, ensure_exo_home, seed_models
from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses_and_interfaces, terminal_link, shutdown
from exo.inference.shard import Shard
from exo.inference.inference_engine import get_inference_engine, InferenceEngine
from exo.inference.inference_engine import get_inference_engine
from exo.inference.tokenizers import resolve_tokenizer
from exo.models import build_base_shard, get_repo
from exo.viz.topology_viz import TopologyViz
from exo.download.hf.hf_helpers import has_hf_home_read_access, has_hf_home_write_access, get_hf_home, move_models_to_hf
import uvloop
import concurrent.futures
import resource
import psutil
# TODO: figure out why this is happening
os.environ["GRPC_VERBOSITY"] = "error"
os.environ["TRANSFORMERS_VERBOSITY"] = "error"
os.environ["TOKENIZERS_PARALLELISM"] = "true"
# Configure uvloop for maximum performance
def configure_uvloop():
uvloop.install()
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# Increase file descriptor limits on Unix systems
if not psutil.WINDOWS:
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
try: resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
except ValueError:
try: resource.setrlimit(resource.RLIMIT_NOFILE, (8192, hard))
except ValueError: pass
loop.set_default_executor(concurrent.futures.ThreadPoolExecutor(max_workers=min(32, (os.cpu_count() or 1) * 4)))
return loop
# parse args
parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
@@ -51,15 +72,14 @@ parser.add_argument("--node-port", type=int, default=None, help="Node port")
parser.add_argument("--models-seed-dir", type=str, default=None, help="Model seed directory")
parser.add_argument("--listen-port", type=int, default=5678, help="Listening port for discovery")
parser.add_argument("--download-quick-check", action="store_true", help="Quick check local path for model shards download")
parser.add_argument("--max-parallel-downloads", type=int, default=4, help="Max parallel downloads for model shards download")
parser.add_argument("--prometheus-client-port", type=int, default=None, help="Prometheus client port")
parser.add_argument("--max-parallel-downloads", type=int, default=8, help="Max parallel downloads for model shards download")
parser.add_argument("--broadcast-port", type=int, default=5678, help="Broadcast port for discovery")
parser.add_argument("--discovery-module", type=str, choices=["udp", "tailscale", "manual"], default="udp", help="Discovery module to use")
parser.add_argument("--discovery-timeout", type=int, default=30, help="Discovery timeout in seconds")
parser.add_argument("--discovery-config-path", type=str, default=None, help="Path to discovery config json file")
parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of peers to wait to connect to before starting")
parser.add_argument("--chatgpt-api-port", type=int, default=52415, help="ChatGPT API port")
parser.add_argument("--chatgpt-api-response-timeout", type=int, default=90, help="ChatGPT API response timeout in seconds")
parser.add_argument("--chatgpt-api-response-timeout", type=int, default=900, help="ChatGPT API response timeout in seconds")
parser.add_argument("--max-generate-tokens", type=int, default=10000, help="Max tokens to generate in each request")
parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use (mlx, tinygrad, or dummy)")
parser.add_argument("--disable-tui", action=argparse.BooleanOptionalAction, help="Disable TUI")
@@ -69,6 +89,8 @@ parser.add_argument("--default-temp", type=float, help="Default token sampling t
parser.add_argument("--tailscale-api-key", type=str, default=None, help="Tailscale API key")
parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name")
parser.add_argument("--node-id-filter", type=str, default=None, help="Comma separated list of allowed node IDs (only for UDP and Tailscale discovery)")
parser.add_argument("--interface-type-filter", type=str, default=None, help="Comma separated list of allowed interface types (only for UDP discovery)")
parser.add_argument("--system-prompt", type=str, default=None, help="System prompt for the ChatGPT API")
args = parser.parse_args()
print(f"Selected inference engine: {args.inference_engine}")
@@ -77,8 +99,7 @@ print_yellow_exo()
system_info = get_system_info()
print(f"Detected system: {system_info}")
shard_downloader: ShardDownloader = HFShardDownloader(quick_check=args.download_quick_check,
max_parallel_downloads=args.max_parallel_downloads) if args.inference_engine != "dummy" else NoopShardDownloader()
shard_downloader: ShardDownloader = new_shard_downloader(args.max_parallel_downloads) if args.inference_engine != "dummy" else NoopShardDownloader()
inference_engine_name = args.inference_engine or ("mlx" if system_info == "Apple Silicon Mac" else "tinygrad")
print(f"Inference engine name after selection: {inference_engine_name}")
@@ -100,8 +121,9 @@ if DEBUG >= 0:
for chatgpt_api_endpoint in chatgpt_api_endpoints:
print(f" - {terminal_link(chatgpt_api_endpoint)}")
# Convert node-id-filter to list if provided
# Convert node-id-filter and interface-type-filter to lists if provided
allowed_node_ids = args.node_id_filter.split(',') if args.node_id_filter else None
allowed_interface_types = args.interface_type_filter.split(',') if args.interface_type_filter else None
if args.discovery_module == "udp":
discovery = UDPDiscovery(
@@ -111,7 +133,8 @@ if args.discovery_module == "udp":
args.broadcast_port,
lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities),
discovery_timeout=args.discovery_timeout,
allowed_node_ids=allowed_node_ids
allowed_node_ids=allowed_node_ids,
allowed_interface_types=allowed_interface_types
)
elif args.discovery_module == "tailscale":
discovery = TailscaleDiscovery(
@@ -133,62 +156,73 @@ node = Node(
None,
inference_engine,
discovery,
shard_downloader,
partitioning_strategy=RingMemoryWeightedPartitioningStrategy(),
max_generate_tokens=args.max_generate_tokens,
topology_viz=topology_viz,
shard_downloader=shard_downloader,
default_sample_temperature=args.default_temp
)
server = GRPCServer(node, args.node_host, args.node_port)
node.server = server
api = ChatGPTAPI(
node,
inference_engine.__class__.__name__,
node.inference_engine.__class__.__name__,
response_timeout=args.chatgpt_api_response_timeout,
on_chat_completion_request=lambda req_id, __, prompt: topology_viz.update_prompt(req_id, prompt) if topology_viz else None,
default_model=args.default_model
default_model=args.default_model,
system_prompt=args.system_prompt
)
node.on_token.register("update_topology_viz").on_next(
lambda req_id, tokens, __: topology_viz.update_prompt_output(req_id, inference_engine.tokenizer.decode(tokens)) if topology_viz and hasattr(inference_engine, "tokenizer") else None
)
def preemptively_start_download(request_id: str, opaque_status: str):
buffered_token_output = {}
def update_topology_viz(req_id, tokens, __):
if not topology_viz: return
if not node.inference_engine.shard: return
if node.inference_engine.shard.model_id == 'stable-diffusion-2-1-base': return
if req_id in buffered_token_output: buffered_token_output[req_id].extend(tokens)
else: buffered_token_output[req_id] = tokens
topology_viz.update_prompt_output(req_id, node.inference_engine.tokenizer.decode(buffered_token_output[req_id]))
node.on_token.register("update_topology_viz").on_next(update_topology_viz)
def update_prompt_viz(request_id, opaque_status: str):
if not topology_viz: return
try:
status = json.loads(opaque_status)
if status.get("type") == "node_status" and status.get("status") == "start_process_prompt":
current_shard = node.get_current_shard(Shard.from_dict(status.get("shard")))
if DEBUG >= 2: print(f"Preemptively starting download for {current_shard}")
asyncio.create_task(shard_downloader.ensure_shard(current_shard, inference_engine.__class__.__name__))
if status.get("type") != "node_status" or status.get("status") != "start_process_prompt": return
topology_viz.update_prompt(request_id, status.get("prompt", "corrupted prompt (this should never happen)"))
except Exception as e:
if DEBUG >= 2:
print(f"Failed to update prompt viz: {e}")
traceback.print_exc()
node.on_opaque_status.register("update_prompt_viz").on_next(update_prompt_viz)
def preemptively_load_shard(request_id: str, opaque_status: str):
try:
status = json.loads(opaque_status)
if status.get("type") != "node_status" or status.get("status") != "start_process_prompt": return
current_shard = node.get_current_shard(Shard.from_dict(status.get("shard")))
if DEBUG >= 2: print(f"Preemptively starting download for {current_shard}")
asyncio.create_task(node.inference_engine.ensure_shard(current_shard))
except Exception as e:
if DEBUG >= 2:
print(f"Failed to preemptively start download: {e}")
traceback.print_exc()
node.on_opaque_status.register("preemptively_load_shard").on_next(preemptively_load_shard)
node.on_opaque_status.register("start_download").on_next(preemptively_start_download)
if args.prometheus_client_port:
from exo.stats.metrics import start_metrics_server
start_metrics_server(node, args.prometheus_client_port)
last_broadcast_time = 0
last_events: dict[str, tuple[float, RepoProgressEvent]] = {}
def throttled_broadcast(shard: Shard, event: RepoProgressEvent):
global last_broadcast_time
global last_events
current_time = time.time()
if event.status == "complete" or current_time - last_broadcast_time >= 0.1:
last_broadcast_time = current_time
asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "progress": event.to_dict()})))
if event.status == "not_started": return
last_event = last_events.get(shard.model_id)
if last_event and last_event[1].status == "complete" and event.status == "complete": return
if last_event and last_event[0] == event.status and current_time - last_event[0] < 0.2: return
last_events[shard.model_id] = (current_time, event)
asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "progress": event.to_dict()})))
shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast)
async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_name: str, prompt: str):
inference_class = inference_engine.__class__.__name__
async def run_model_cli(node: Node, model_name: str, prompt: str):
inference_class = node.inference_engine.__class__.__name__
shard = build_base_shard(model_name, inference_class)
if not shard:
print(f"Error: Unsupported model '{model_name}' for inference engine {inference_engine.__class__.__name__}")
print(f"Error: Unsupported model '{model_name}' for inference engine {inference_class}")
return
tokenizer = await resolve_tokenizer(get_repo(shard.model_id, inference_class))
request_id = str(uuid.uuid4())
@@ -202,7 +236,11 @@ async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_nam
print(f"Processing prompt: {prompt}")
await node.process_prompt(shard, prompt, request_id=request_id)
_, tokens, _ = await callback.wait(lambda _request_id, tokens, is_finished: _request_id == request_id and is_finished, timeout=300)
tokens = []
def on_token(_request_id, _tokens, _is_finished):
tokens.extend(_tokens)
return _request_id == request_id and _is_finished
await callback.wait(on_token, timeout=300)
print("\nGenerated response:")
print(tokenizer.decode(tokens))
@@ -221,7 +259,7 @@ def clean_path(path):
async def hold_outstanding(node: Node):
while node.outstanding_requests:
await asyncio.sleep(.5)
return
return
async def run_iter(node: Node, shard: Shard, train: bool, data, batch_size=1):
losses = []
@@ -232,14 +270,14 @@ async def run_iter(node: Node, shard: Shard, train: bool, data, batch_size=1):
tokens.append(np.sum(lengths))
total_tokens = np.sum(tokens)
total_loss = np.sum(losses) / total_tokens
return total_loss, total_tokens
async def eval_model_cli(node: Node, inference_engine: InferenceEngine, model_name, dataloader, batch_size, num_batches=-1):
inference_class = inference_engine.__class__.__name__
async def eval_model_cli(node: Node, model_name, dataloader, batch_size, num_batches=-1):
inference_class = node.inference_engine.__class__.__name__
shard = build_base_shard(model_name, inference_class)
if not shard:
print(f"Error: Unsupported model '{model_name}' for inference engine {inference_engine.__class__.__name__}")
print(f"Error: Unsupported model '{model_name}' for inference engine {inference_class}")
return
tokenizer = await resolve_tokenizer(get_repo(shard.model_id, inference_class))
train, val, test = dataloader(tokenizer.encode)
@@ -249,11 +287,11 @@ async def eval_model_cli(node: Node, inference_engine: InferenceEngine, model_na
print("Waiting for outstanding tasks")
await hold_outstanding(node)
async def train_model_cli(node: Node, inference_engine: InferenceEngine, model_name, dataloader, batch_size, iters, save_interval=0, checkpoint_dir=None):
inference_class = inference_engine.__class__.__name__
async def train_model_cli(node: Node, model_name, dataloader, batch_size, iters, save_interval=0, checkpoint_dir=None):
inference_class = node.inference_engine.__class__.__name__
shard = build_base_shard(model_name, inference_class)
if not shard:
print(f"Error: Unsupported model '{model_name}' for inference engine {inference_engine.__class__.__name__}")
print(f"Error: Unsupported model '{model_name}' for inference engine {inference_class}")
return
tokenizer = await resolve_tokenizer(get_repo(shard.model_id, inference_class))
train, val, test = dataloader(tokenizer.encode)
@@ -268,28 +306,30 @@ async def train_model_cli(node: Node, inference_engine: InferenceEngine, model_n
await hold_outstanding(node)
await hold_outstanding(node)
async def main():
loop = asyncio.get_running_loop()
# Check HuggingFace directory permissions
hf_home, has_read, has_write = get_hf_home(), await has_hf_home_read_access(), await has_hf_home_write_access()
if DEBUG >= 1: print(f"Model storage directory: {hf_home}")
async def check_exo_home():
home, has_read, has_write = await ensure_exo_home(), await has_exo_home_read_access(), await has_exo_home_write_access()
if DEBUG >= 1: print(f"exo home directory: {home}")
print(f"{has_read=}, {has_write=}")
if not has_read or not has_write:
print(f"""
WARNING: Limited permissions for model storage directory: {hf_home}.
WARNING: Limited permissions for exo home directory: {home}.
This may prevent model downloads from working correctly.
{"❌ No read access" if not has_read else ""}
{"❌ No write access" if not has_write else ""}
""")
async def main():
loop = asyncio.get_running_loop()
try: await check_exo_home()
except Exception as e: print(f"Error checking exo home directory: {e}")
if not args.models_seed_dir is None:
try:
models_seed_dir = clean_path(args.models_seed_dir)
await move_models_to_hf(models_seed_dir)
await seed_models(models_seed_dir)
except Exception as e:
print(f"Error moving models to .cache/huggingface: {e}")
print(f"Error seeding models: {e}")
def restore_cursor():
if platform.system() != "Windows":
@@ -313,7 +353,7 @@ async def main():
if not model_name:
print("Error: Model name is required when using 'run' command or --run-model")
return
await run_model_cli(node, inference_engine, model_name, args.prompt)
await run_model_cli(node, model_name, args.prompt)
elif args.command == "eval" or args.command == 'train':
model_name = args.model_name
dataloader = lambda tok: load_dataset(args.data, preprocess=lambda item: tok(item)
@@ -322,35 +362,31 @@ async def main():
if not model_name:
print("Error: Much like a human, I can't evaluate anything without a model")
return
await eval_model_cli(node, inference_engine, model_name, dataloader, args.batch_size)
await eval_model_cli(node, model_name, dataloader, args.batch_size)
else:
if not model_name:
print("Error: This train ain't leaving the station without a model")
return
await train_model_cli(node, inference_engine, model_name, dataloader, args.batch_size, args.iters, save_interval=args.save_every, checkpoint_dir=args.save_checkpoint_dir)
await train_model_cli(node, model_name, dataloader, args.batch_size, args.iters, save_interval=args.save_every, checkpoint_dir=args.save_checkpoint_dir)
else:
asyncio.create_task(api.run(port=args.chatgpt_api_port)) # Start the API server as a non-blocking task
await asyncio.Event().wait()
if args.wait_for_peers > 0:
print("Cooldown to allow peers to exit gracefully")
for i in tqdm(range(50)):
await asyncio.sleep(.1)
def run():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(main())
except KeyboardInterrupt:
print("Received keyboard interrupt. Shutting down...")
finally:
loop.run_until_complete(shutdown(signal.SIGTERM, loop, node.server))
loop.close()
loop = None
try:
loop = configure_uvloop()
loop.run_until_complete(main())
except KeyboardInterrupt:
print("\nShutdown requested... exiting")
finally:
if loop: loop.close()
if __name__ == "__main__":
run()

View File

@@ -88,18 +88,55 @@ model_cards = {
### deepseek
"deepseek-coder-v2-lite": { "layers": 27, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", }, },
"deepseek-coder-v2.5": { "layers": 60, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", }, },
"deepseek-v3": { "layers": 61, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-V3-4bit", }, },
"deepseek-v3-3bit": { "layers": 61, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-V3-3bit", }, },
"deepseek-r1": { "layers": 61, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-4bit", }, },
"deepseek-r1-3bit": { "layers": 61, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-3bit", }, },
### deepseek distills
"deepseek-r1-distill-qwen-1.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/deepseek-r1-distill-qwen-1.5b", }, },
"deepseek-r1-distill-qwen-1.5b-3bit": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-3bit", }, },
"deepseek-r1-distill-qwen-1.5b-6bit": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-6bit", }, },
"deepseek-r1-distill-qwen-1.5b-8bit": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-8bit", }, },
"deepseek-r1-distill-qwen-1.5b-bf16": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-bf16", }, },
"deepseek-r1-distill-qwen-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-7B-4bit", }, },
"deepseek-r1-distill-qwen-7b-3bit": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-7B-3bit", }, },
"deepseek-r1-distill-qwen-7b-6bit": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-7B-6bit", }, },
"deepseek-r1-distill-qwen-7b-8bit": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-7B-8bit", }, },
"deepseek-r1-distill-qwen-7b-bf16": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-7B-bf16", }, },
"deepseek-r1-distill-qwen-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-14B-4bit", }, },
"deepseek-r1-distill-qwen-14b-3bit": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-14B-3bit", }, },
"deepseek-r1-distill-qwen-14b-6bit": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-14B-6bit", }, },
"deepseek-r1-distill-qwen-14b-8bit": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-14B-8bit", }, },
"deepseek-r1-distill-qwen-14b-bf16": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-14B-bf16", }, },
"deepseek-r1-distill-qwen-32b": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-32B-4bit", }, },
"deepseek-r1-distill-qwen-32b-3bit": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-32B-3bit", }, },
"deepseek-r1-distill-qwen-32b-6bit": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-32B-6bit", }, },
"deepseek-r1-distill-qwen-32b-8bit": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-32B-MLX-8Bit", }, },
"deepseek-r1-distill-qwen-32b-bf16": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-32B-bf16", }, },
"deepseek-r1-distill-llama-8b": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-8B-4bit", }, },
"deepseek-r1-distill-llama-8b-3bit": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-8B-3bit", }, },
"deepseek-r1-distill-llama-8b-6bit": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-8B-6bit", }, },
"deepseek-r1-distill-llama-8b-8bit": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-8B-8bit", }, },
"deepseek-r1-distill-llama-8b-bf16": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-8B-bf16", }, },
"deepseek-r1-distill-llama-70b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-70B-4bit", }, },
"deepseek-r1-distill-llama-70b-3bit": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-70B-3bit", }, },
"deepseek-r1-distill-llama-70b-6bit": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-70B-6bit", }, },
"deepseek-r1-distill-llama-70b-8bit": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-70B-8bit", }, },
### llava
"llava-1.5-7b-hf": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "llava-hf/llava-1.5-7b-hf", }, },
### qwen
"qwen-2.5-0.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-0.5B-Instruct-4bit", }, },
"qwen-2.5-1.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-1.5B-Instruct-4bit", }, },
"qwen-2.5-coder-1.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-1.5B-Instruct-4bit", }, },
"qwen-2.5-3b": { "layers": 36, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-3B-Instruct-4bit", }, },
"qwen-2.5-coder-3b": { "layers": 36, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-3B-Instruct-4bit", }, },
"qwen-2.5-coder-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-7B-Instruct-4bit", }, },
"qwen-2.5-coder-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-14B-Instruct-4bit", }, },
"qwen-2.5-coder-32b": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-32B-Instruct-4bit", }, },
"qwen-2.5-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-7B-Instruct-4bit", }, },
"qwen-2.5-coder-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-7B-Instruct-4bit", }, },
"qwen-2.5-math-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Math-7B-Instruct-4bit", }, },
"qwen-2.5-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-14B-Instruct-4bit", }, },
"qwen-2.5-coder-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-14B-Instruct-4bit", }, },
"qwen-2.5-32b": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-32B-Instruct-4bit", }, },
"qwen-2.5-coder-32b": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-32B-Instruct-4bit", }, },
"qwen-2.5-72b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-72B-Instruct-4bit", }, },
"qwen-2.5-math-72b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Math-72B-Instruct-4bit", }, },
### nemotron
@@ -108,6 +145,11 @@ model_cards = {
# gemma
"gemma2-9b": { "layers": 42, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/gemma-2-9b-it-4bit", }, },
"gemma2-27b": { "layers": 46, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/gemma-2-27b-it-4bit", }, },
# stable diffusion
"stable-diffusion-2-1-base": { "layers": 31, "repo": { "MLXDynamicShardInferenceEngine": "stabilityai/stable-diffusion-2-1-base" } },
# phi
"phi-3.5-mini": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Phi-3.5-mini-instruct-4bit", }, },
"phi-4": { "layers": 40, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/phi-4-4bit", }, },
# dummy
"dummy": { "layers": 8, "repo": { "DummyInferenceEngine": "dummy", }, },
}
@@ -132,24 +174,70 @@ pretty_name = {
"mistral-large": "Mistral Large",
"deepseek-coder-v2-lite": "Deepseek Coder V2 Lite",
"deepseek-coder-v2.5": "Deepseek Coder V2.5",
"deepseek-v3": "Deepseek V3 (4-bit)",
"deepseek-v3-3bit": "Deepseek V3 (3-bit)",
"deepseek-r1": "Deepseek R1 (4-bit)",
"deepseek-r1-3bit": "Deepseek R1 (3-bit)",
"llava-1.5-7b-hf": "LLaVa 1.5 7B (Vision Model)",
"qwen-2.5-0.5b": "Qwen 2.5 0.5B",
"qwen-2.5-1.5b": "Qwen 2.5 1.5B",
"qwen-2.5-coder-1.5b": "Qwen 2.5 Coder 1.5B",
"qwen-2.5-3b": "Qwen 2.5 3B",
"qwen-2.5-coder-3b": "Qwen 2.5 Coder 3B",
"qwen-2.5-coder-7b": "Qwen 2.5 Coder 7B",
"qwen-2.5-coder-14b": "Qwen 2.5 Coder 14B",
"qwen-2.5-coder-32b": "Qwen 2.5 Coder 32B",
"qwen-2.5-7b": "Qwen 2.5 7B",
"qwen-2.5-coder-7b": "Qwen 2.5 Coder 7B",
"qwen-2.5-math-7b": "Qwen 2.5 7B (Math)",
"qwen-2.5-14b": "Qwen 2.5 14B",
"qwen-2.5-coder-14b": "Qwen 2.5 Coder 14B",
"qwen-2.5-32b": "Qwen 2.5 32B",
"qwen-2.5-coder-32b": "Qwen 2.5 Coder 32B",
"qwen-2.5-72b": "Qwen 2.5 72B",
"qwen-2.5-math-72b": "Qwen 2.5 72B (Math)",
"phi-3.5-mini": "Phi-3.5 Mini",
"phi-4": "Phi-4",
"llama-3-8b": "Llama 3 8B",
"llama-3-70b": "Llama 3 70B",
"stable-diffusion-2-1-base": "Stable Diffusion 2.1",
"deepseek-r1-distill-qwen-1.5b": "DeepSeek R1 Distill Qwen 1.5B",
"deepseek-r1-distill-qwen-1.5b-3bit": "DeepSeek R1 Distill Qwen 1.5B (3-bit)",
"deepseek-r1-distill-qwen-1.5b-6bit": "DeepSeek R1 Distill Qwen 1.5B (6-bit)",
"deepseek-r1-distill-qwen-1.5b-8bit": "DeepSeek R1 Distill Qwen 1.5B (8-bit)",
"deepseek-r1-distill-qwen-1.5b-bf16": "DeepSeek R1 Distill Qwen 1.5B (BF16)",
"deepseek-r1-distill-qwen-7b": "DeepSeek R1 Distill Qwen 7B",
"deepseek-r1-distill-qwen-7b-3bit": "DeepSeek R1 Distill Qwen 7B (3-bit)",
"deepseek-r1-distill-qwen-7b-6bit": "DeepSeek R1 Distill Qwen 7B (6-bit)",
"deepseek-r1-distill-qwen-7b-8bit": "DeepSeek R1 Distill Qwen 7B (8-bit)",
"deepseek-r1-distill-qwen-7b-bf16": "DeepSeek R1 Distill Qwen 7B (BF16)",
"deepseek-r1-distill-qwen-14b": "DeepSeek R1 Distill Qwen 14B",
"deepseek-r1-distill-qwen-14b-3bit": "DeepSeek R1 Distill Qwen 14B (3-bit)",
"deepseek-r1-distill-qwen-14b-6bit": "DeepSeek R1 Distill Qwen 14B (6-bit)",
"deepseek-r1-distill-qwen-14b-8bit": "DeepSeek R1 Distill Qwen 14B (8-bit)",
"deepseek-r1-distill-qwen-14b-bf16": "DeepSeek R1 Distill Qwen 14B (BF16)",
"deepseek-r1-distill-qwen-32b": "DeepSeek R1 Distill Qwen 32B",
"deepseek-r1-distill-qwen-32b-3bit": "DeepSeek R1 Distill Qwen 32B (3-bit)",
"deepseek-r1-distill-qwen-32b-8bit": "DeepSeek R1 Distill Qwen 32B (8-bit)",
"deepseek-r1-distill-qwen-32b-bf16": "DeepSeek R1 Distill Qwen 32B (BF16)",
"deepseek-r1-distill-llama-8b-8bit": "DeepSeek R1 Distill Llama 8B (8-bit)",
"deepseek-r1-distill-llama-70b-6bit": "DeepSeek R1 Distill Llama 70B (6-bit)",
"deepseek-r1-distill-llama-70b-8bit": "DeepSeek R1 Distill Llama 70B (8-bit)",
"deepseek-r1-distill-llama-8b": "DeepSeek R1 Distill Llama 8B",
"deepseek-r1-distill-llama-8b-3bit": "DeepSeek R1 Distill Llama 8B (3-bit)",
"deepseek-r1-distill-llama-8b-6bit": "DeepSeek R1 Distill Llama 8B (6-bit)",
"deepseek-r1-distill-llama-8b-8bit": "DeepSeek R1 Distill Llama 8B (8-bit)",
"deepseek-r1-distill-llama-8b-bf16": "DeepSeek R1 Distill Llama 8B (BF16)",
"deepseek-r1-distill-llama-70b": "DeepSeek R1 Distill Llama 70B",
"deepseek-r1-distill-llama-70b-3bit": "DeepSeek R1 Distill Llama 70B (3-bit)",
"deepseek-r1-distill-llama-70b-6bit": "DeepSeek R1 Distill Llama 70B (6-bit)",
"deepseek-r1-distill-llama-70b-8bit": "DeepSeek R1 Distill Llama 70B (8-bit)",
"deepseek-r1-distill-qwen-32b-6bit": "DeepSeek R1 Distill Qwen 32B (6-bit)",
}
def get_repo(model_id: str, inference_engine_classname: str) -> Optional[str]:
return model_cards.get(model_id, {}).get("repo", {}).get(inference_engine_classname, None)
def get_pretty_name(model_id: str) -> Optional[str]:
return pretty_name.get(model_id, None)
def build_base_shard(model_id: str, inference_engine_classname: str) -> Optional[Shard]:
repo = get_repo(model_id, inference_engine_classname)
n_layers = model_cards.get(model_id, {}).get("layers", 0)
@@ -157,7 +245,12 @@ def build_base_shard(model_id: str, inference_engine_classname: str) -> Optional
return None
return Shard(model_id, 0, 0, n_layers)
def get_supported_models(supported_inference_engine_lists: List[List[str]]) -> List[str]:
def build_full_shard(model_id: str, inference_engine_classname: str) -> Optional[Shard]:
base_shard = build_base_shard(model_id, inference_engine_classname)
if base_shard is None: return None
return Shard(base_shard.model_id, 0, base_shard.n_layers - 1, base_shard.n_layers)
def get_supported_models(supported_inference_engine_lists: Optional[List[List[str]]] = None) -> List[str]:
if not supported_inference_engine_lists:
return list(model_cards.keys())
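For orientation, a minimal sketch of how these helpers compose. The inference-engine classname below is only an illustrative placeholder; the real keys come from the "repo" maps inside model_cards.

engine = "MLXDynamicShardInferenceEngine"   # illustrative classname, not necessarily a real key
model_id = "llama-3-8b"

print(get_pretty_name(model_id))            # "Llama 3 8B"
repo = get_repo(model_id, engine)           # repo string for this engine, or None
base = build_base_shard(model_id, engine)   # Shard(model_id, 0, 0, n_layers), or None
full = build_full_shard(model_id, engine)   # Shard(model_id, 0, n_layers - 1, n_layers), or None
all_ids = get_supported_models()            # no filter given -> every key in model_cards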

View File

@@ -11,6 +11,13 @@ from exo.inference.shard import Shard
from exo.topology.topology import Topology
from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
from exo.helpers import DEBUG
import json
import platform
if platform.system().lower() == "darwin" and platform.machine().lower() == "arm64":
import mlx.core as mx
else:
import numpy as mx
class GRPCPeerHandle(PeerHandle):
@@ -21,6 +28,19 @@ class GRPCPeerHandle(PeerHandle):
self._device_capabilities = device_capabilities
self.channel = None
self.stub = None
self.channel_options = [
("grpc.max_metadata_size", 64 * 1024 * 1024),
("grpc.max_receive_message_length", 256 * 1024 * 1024),
("grpc.max_send_message_length", 256 * 1024 * 1024),
("grpc.max_concurrent_streams", 100),
("grpc.http2.min_time_between_pings_ms", 10000),
("grpc.keepalive_time_ms", 20000),
("grpc.keepalive_timeout_ms", 10000),
("grpc.keepalive_permit_without_calls", 1),
("grpc.http2.max_pings_without_data", 0),
("grpc.tcp_nodelay", 1),
("grpc.optimization_target", "throughput"),
]
def id(self) -> str:
return self._id
@@ -36,11 +56,11 @@ class GRPCPeerHandle(PeerHandle):
async def connect(self):
if self.channel is None:
self.channel = grpc.aio.insecure_channel(self.address, options=[
("grpc.max_metadata_size", 32*1024*1024),
('grpc.max_receive_message_length', 32*1024*1024),
('grpc.max_send_message_length', 32*1024*1024)
])
self.channel = grpc.aio.insecure_channel(
self.address,
options=self.channel_options,
compression=grpc.Compression.Gzip
)
self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
await self.channel.channel_ready()
@@ -54,7 +74,13 @@ class GRPCPeerHandle(PeerHandle):
self.stub = None
async def _ensure_connected(self):
if not await self.is_connected(): await asyncio.wait_for(self.connect(), timeout=5)
if not await self.is_connected():
try:
await asyncio.wait_for(self.connect(), timeout=10.0)
except asyncio.TimeoutError:
if DEBUG >= 2: print(f"Connection timeout for {self._id}@{self.address}")
await self.disconnect()
raise
async def health_check(self) -> bool:
try:
@@ -71,7 +97,7 @@ class GRPCPeerHandle(PeerHandle):
traceback.print_exc()
return False
async def send_prompt(self, shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.array]:
async def send_prompt(self, shard: Shard, prompt: str, inference_state: Optional[dict] = None, request_id: Optional[str] = None) -> Optional[np.array]:
request = node_service_pb2.PromptRequest(
prompt=prompt,
shard=node_service_pb2.Shard(
@@ -81,15 +107,11 @@ class GRPCPeerHandle(PeerHandle):
n_layers=shard.n_layers,
),
request_id=request_id,
inference_state=None if inference_state is None else self.serialize_inference_state(inference_state)
)
response = await self.stub.SendPrompt(request)
await self.stub.SendPrompt(request)
if not response.tensor_data or not response.shape or not response.dtype:
return None
return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
async def send_tensor(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> Optional[np.array]:
async def send_tensor(self, shard: Shard, tensor: np.ndarray, inference_state: Optional[dict] = None, request_id: Optional[str] = None) -> Optional[np.array]:
request = node_service_pb2.TensorRequest(
shard=node_service_pb2.Shard(
model_id=shard.model_id,
@@ -99,6 +121,7 @@ class GRPCPeerHandle(PeerHandle):
),
tensor=node_service_pb2.Tensor(tensor_data=tensor.tobytes(), shape=tensor.shape, dtype=str(tensor.dtype)),
request_id=request_id,
inference_state=None if inference_state is None else self.serialize_inference_state(inference_state)
)
response = await self.stub.SendTensor(request)
@@ -106,7 +129,7 @@ class GRPCPeerHandle(PeerHandle):
return None
return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
async def send_example(self, shard: Shard, example: np.ndarray, target: np.ndarray, length: np.ndarray, train: bool, request_id: Optional[str] = None) -> Optional[np.array]:
request = node_service_pb2.ExampleRequest(
shard=node_service_pb2.Shard(
@@ -128,7 +151,7 @@ class GRPCPeerHandle(PeerHandle):
return loss, grads
else:
return loss
async def send_loss(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> Optional[np.array]:
request = node_service_pb2.TensorRequest(
shard=node_service_pb2.Shard(
@@ -147,26 +170,13 @@ class GRPCPeerHandle(PeerHandle):
return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
request = node_service_pb2.GetInferenceResultRequest(request_id=request_id)
response = await self.stub.GetInferenceResult(request)
if response.tensor is None:
return None, response.is_finished
return (
np.frombuffer(response.tensor.tensor_data, dtype=np.dtype(response.tensor.dtype)).reshape(response.tensor.shape),
response.is_finished,
)
async def collect_topology(self, visited: set[str], max_depth: int) -> Topology:
request = node_service_pb2.CollectTopologyRequest(visited=visited, max_depth=max_depth)
response = await self.stub.CollectTopology(request)
topology = Topology()
for node_id, capabilities in response.nodes.items():
device_capabilities = DeviceCapabilities(
model=capabilities.model,
chip=capabilities.chip,
memory=capabilities.memory,
flops=DeviceFlops(fp16=capabilities.flops.fp16, fp32=capabilities.flops.fp32, int8=capabilities.flops.int8)
model=capabilities.model, chip=capabilities.chip, memory=capabilities.memory, flops=DeviceFlops(fp16=capabilities.flops.fp16, fp32=capabilities.flops.fp32, int8=capabilities.flops.int8)
)
topology.update_node(node_id, device_capabilities)
for node_id, peer_connections in response.peer_graph.items():
@@ -175,9 +185,35 @@ class GRPCPeerHandle(PeerHandle):
return topology
async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
request = node_service_pb2.SendResultRequest(request_id=request_id, result=result, is_finished=is_finished)
tensor = None
if isinstance(result, np.ndarray):
tensor = node_service_pb2.Tensor(tensor_data=result.tobytes(), shape=result.shape, dtype=str(result.dtype))
result = []
request = node_service_pb2.SendResultRequest(request_id=request_id, result=result, tensor=tensor, is_finished=is_finished)
await self.stub.SendResult(request)
async def send_opaque_status(self, request_id: str, status: str) -> None:
request = node_service_pb2.SendOpaqueStatusRequest(request_id=request_id, status=status)
await self.stub.SendOpaqueStatus(request)
def serialize_inference_state(self, inference_state: dict) -> node_service_pb2.InferenceState:
proto_inference_state = node_service_pb2.InferenceState()
other_data = {}
for k, v in inference_state.items():
if isinstance(v, mx.array):
np_array = np.array(v)
tensor_data = node_service_pb2.Tensor(tensor_data=np_array.tobytes(), shape=list(np_array.shape), dtype=str(np_array.dtype))
proto_inference_state.tensor_data[k].CopyFrom(tensor_data)
elif isinstance(v, list) and all(isinstance(item, mx.array) for item in v):
tensor_list = node_service_pb2.TensorList()
for tensor in v:
np_array = np.array(tensor)
tensor_data = node_service_pb2.Tensor(tensor_data=np_array.tobytes(), shape=list(np_array.shape), dtype=str(np_array.dtype))
tensor_list.tensors.append(tensor_data)
proto_inference_state.tensor_list_data[k].CopyFrom(tensor_list)
else:
# For non-tensor data, we'll still use JSON
other_data[k] = v
if other_data:
proto_inference_state.other_data_json = json.dumps(other_data)
return proto_inference_state
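The Tensor messages used throughout these RPCs carry raw bytes plus a shape and a dtype string; a minimal, self-contained sketch of that round trip using only NumPy:

import numpy as np

original = np.arange(6, dtype=np.float32).reshape(2, 3)
# Encode the way the peer handle does: raw bytes + shape + dtype string.
tensor_data, shape, dtype = original.tobytes(), list(original.shape), str(original.dtype)
# Decode the way the receiving side does.
restored = np.frombuffer(tensor_data, dtype=np.dtype(dtype)).reshape(shape)
assert np.array_equal(original, restored)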

View File

@@ -3,11 +3,19 @@ from concurrent import futures
import numpy as np
from asyncio import CancelledError
import platform
from . import node_service_pb2
from . import node_service_pb2_grpc
from exo import DEBUG
from exo.inference.shard import Shard
from exo.orchestration import Node
import json
if platform.system().lower() == "darwin" and platform.machine().lower() == "arm64":
import mlx.core as mx
else:
import numpy as mx
class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
@@ -19,11 +27,19 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
async def start(self) -> None:
self.server = grpc.aio.server(
futures.ThreadPoolExecutor(max_workers=10),
futures.ThreadPoolExecutor(max_workers=32),
options=[
("grpc.max_metadata_size", 32*1024*1024),
("grpc.max_send_message_length", 128*1024*1024),
("grpc.max_receive_message_length", 128*1024*1024),
("grpc.max_send_message_length", 256*1024*1024),
("grpc.max_receive_message_length", 256*1024*1024),
("grpc.keepalive_time_ms", 10000),
("grpc.keepalive_timeout_ms", 5000),
("grpc.http2.max_pings_without_data", 0),
("grpc.http2.min_time_between_pings_ms", 10000),
("grpc.http2.min_ping_interval_without_data_ms", 5000),
("grpc.max_concurrent_streams", 100),
("grpc.tcp_nodelay", 1),
("grpc.optimization_target", "throughput"),
],
)
node_service_pb2_grpc.add_NodeServiceServicer_to_server(self, self.server)
@@ -50,7 +66,8 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
)
prompt = request.prompt
request_id = request.request_id
result = await self.node.process_prompt(shard, prompt, request_id)
inference_state = None if request.inference_state is None else self.deserialize_inference_state(request.inference_state)
result = await self.node.process_prompt(shard, prompt, request_id, inference_state)
if DEBUG >= 5: print(f"SendPrompt {shard=} {prompt=} {request_id=} result: {result}")
tensor_data = result.tobytes() if result is not None else None
return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
@@ -65,11 +82,13 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
tensor = np.frombuffer(request.tensor.tensor_data, dtype=np.dtype(request.tensor.dtype)).reshape(request.tensor.shape)
request_id = request.request_id
result = await self.node.process_tensor(shard, tensor, request_id)
inference_state = None if request.inference_state is None else self.deserialize_inference_state(request.inference_state)
result = await self.node.process_tensor(shard, tensor, request_id, inference_state)
if DEBUG >= 5: print(f"SendTensor tensor {shard=} {tensor=} {request_id=} result: {result}")
tensor_data = result.tobytes() if result is not None else None
return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
async def SendExample(self, request, context):
shard = Shard(
model_id=request.shard.model_id,
@@ -91,7 +110,7 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
else:
loss = await self.node.process_example(shard, example, target, length, train, request_id)
return node_service_pb2.Loss(loss=loss, grads=None)
async def CollectTopology(self, request, context):
max_depth = request.max_depth
visited = set(request.visited)
@@ -107,12 +126,7 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
for node_id, cap in topology.nodes.items()
}
peer_graph = {
node_id: node_service_pb2.PeerConnections(
connections=[
node_service_pb2.PeerConnection(to_id=conn.to_id, description=conn.description)
for conn in connections
]
)
node_id: node_service_pb2.PeerConnections(connections=[node_service_pb2.PeerConnection(to_id=conn.to_id, description=conn.description) for conn in connections])
for node_id, connections in topology.peer_graph.items()
}
if DEBUG >= 5: print(f"CollectTopology {max_depth=} {visited=} {nodes=} {peer_graph=}")
@@ -122,7 +136,11 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
request_id = request.request_id
result = request.result
is_finished = request.is_finished
img = request.tensor
if DEBUG >= 5: print(f"Received SendResult request: {request_id=} {result=} {is_finished=}")
result = list(result)
if len(img.tensor_data) > 0:
result = np.frombuffer(img.tensor_data, dtype=np.dtype(img.dtype)).reshape(img.shape)
self.node.on_token.trigger_all(request_id, result, is_finished)
return node_service_pb2.Empty()
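A hedged sketch of the two request shapes this handler now accepts: token ids in result, or an image-style payload in the new optional tensor field. The import path assumes the regenerated stubs live under exo.networking.grpc.

import numpy as np
from exo.networking.grpc import node_service_pb2  # assumed module location

# Token stream: ids go in `result`, `tensor` stays unset.
token_req = node_service_pb2.SendResultRequest(request_id="req-1", result=[1, 2, 3], is_finished=False)

# Image-style result (e.g. the stable-diffusion path): bytes go in `tensor`, `result` stays empty.
img = np.zeros((64, 64, 3), dtype=np.float32)
img_req = node_service_pb2.SendResultRequest(
    request_id="req-2",
    tensor=node_service_pb2.Tensor(tensor_data=img.tobytes(), shape=list(img.shape), dtype=str(img.dtype)),
    is_finished=True,
)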
@@ -135,3 +153,19 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
async def HealthCheck(self, request, context):
return node_service_pb2.HealthCheckResponse(is_healthy=True)
def deserialize_inference_state(self, inference_state_proto: node_service_pb2.InferenceState) -> dict:
inference_state = {}
for k, tensor_data in inference_state_proto.tensor_data.items():
np_array = np.frombuffer(tensor_data.tensor_data, dtype=tensor_data.dtype).reshape(tensor_data.shape)
inference_state[k] = mx.array(np_array)
for k, tensor_list in inference_state_proto.tensor_list_data.items():
inference_state[k] = [mx.array(np.frombuffer(tensor.tensor_data, dtype=tensor.dtype).reshape(tensor.shape)) for tensor in tensor_list.tensors]
if inference_state_proto.other_data_json:
other_data = json.loads(inference_state_proto.other_data_json)
inference_state.update(other_data)
return inference_state
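The serialize/deserialize pair above splits inference_state three ways: single arrays into tensor_data, lists of arrays into tensor_list_data, and everything else into other_data_json. A proto-free sketch of that split, with plain dicts standing in for the InferenceState message and NumPy standing in for mx (as the module itself does off Apple silicon):

import json
import numpy as np

def split_inference_state(state: dict) -> dict:
    tensor_data, tensor_list_data, other = {}, {}, {}
    for k, v in state.items():
        if isinstance(v, np.ndarray):
            tensor_data[k] = v
        elif isinstance(v, list) and all(isinstance(x, np.ndarray) for x in v):
            tensor_list_data[k] = v
        else:
            other[k] = v  # non-tensor data rides along as JSON
    return {"tensor_data": tensor_data, "tensor_list_data": tensor_list_data,
            "other_data_json": json.dumps(other)}

state = {"cache": np.zeros((1, 4)), "step_outputs": [np.ones(2)], "is_finished": False}
print(split_inference_state(state)["other_data_json"])  # {"is_finished": false}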

View File

@@ -6,7 +6,6 @@ service NodeService {
rpc SendPrompt (PromptRequest) returns (Tensor) {}
rpc SendTensor (TensorRequest) returns (Tensor) {}
rpc SendExample (ExampleRequest) returns (Loss) {}
rpc GetInferenceResult (GetInferenceResultRequest) returns (InferenceResult) {}
rpc CollectTopology (CollectTopologyRequest) returns (Topology) {}
rpc SendResult (SendResultRequest) returns (Empty) {}
rpc SendOpaqueStatus (SendOpaqueStatusRequest) returns (Empty) {}
@@ -24,12 +23,14 @@ message PromptRequest {
Shard shard = 1;
string prompt = 2;
optional string request_id = 3;
optional InferenceState inference_state = 4;
}
message TensorRequest {
Shard shard = 1;
Tensor tensor = 2;
optional string request_id = 3;
optional InferenceState inference_state = 4;
}
message ExampleRequest {
@@ -45,15 +46,6 @@ message Loss {
float loss = 1;
optional Tensor grads = 2;
}
message GetInferenceResultRequest {
string request_id = 1;
}
message InferenceResult {
optional Tensor tensor = 1;
bool is_finished = 2;
}
message Tensor {
bytes tensor_data = 1;
@@ -61,6 +53,16 @@ message Tensor {
string dtype = 3;
}
message TensorList {
repeated Tensor tensors = 1;
}
message InferenceState {
map<string, Tensor> tensor_data = 1;
map<string, TensorList> tensor_list_data = 2;
string other_data_json = 3;
}
message CollectTopologyRequest {
repeated string visited = 1;
int32 max_depth = 2;
@@ -96,7 +98,8 @@ message DeviceCapabilities {
message SendResultRequest {
string request_id = 1;
repeated int32 result = 2;
bool is_finished = 3;
optional Tensor tensor = 3;
bool is_finished = 4;
}
message SendOpaqueStatusRequest {

View File

@@ -24,59 +24,67 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12node_service.proto\x12\x0cnode_service\"S\n\x05Shard\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x13\n\x0bstart_layer\x18\x02 \x01(\x05\x12\x11\n\tend_layer\x18\x03 \x01(\x05\x12\x10\n\x08n_layers\x18\x04 \x01(\x05\"k\n\rPromptRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12\x0e\n\x06prompt\x18\x02 \x01(\t\x12\x17\n\nrequest_id\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\r\n\x0b_request_id\"\x81\x01\n\rTensorRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12$\n\x06tensor\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12\x17\n\nrequest_id\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\r\n\x0b_request_id\"\xde\x01\n\x0e\x45xampleRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12%\n\x07\x65xample\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12$\n\x06target\x18\x03 \x01(\x0b\x32\x14.node_service.Tensor\x12$\n\x06length\x18\x04 \x01(\x0b\x32\x14.node_service.Tensor\x12\r\n\x05train\x18\x05 \x01(\x08\x12\x17\n\nrequest_id\x18\x06 \x01(\tH\x00\x88\x01\x01\x42\r\n\x0b_request_id\"H\n\x04Loss\x12\x0c\n\x04loss\x18\x01 \x01(\x02\x12(\n\x05grads\x18\x02 \x01(\x0b\x32\x14.node_service.TensorH\x00\x88\x01\x01\x42\x08\n\x06_grads\"/\n\x19GetInferenceResultRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\"\\\n\x0fInferenceResult\x12)\n\x06tensor\x18\x01 \x01(\x0b\x32\x14.node_service.TensorH\x00\x88\x01\x01\x12\x13\n\x0bis_finished\x18\x02 \x01(\x08\x42\t\n\x07_tensor\";\n\x06Tensor\x12\x13\n\x0btensor_data\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05\x64type\x18\x03 \x01(\t\"<\n\x16\x43ollectTopologyRequest\x12\x0f\n\x07visited\x18\x01 \x03(\t\x12\x11\n\tmax_depth\x18\x02 \x01(\x05\"\x98\x02\n\x08Topology\x12\x30\n\x05nodes\x18\x01 \x03(\x0b\x32!.node_service.Topology.NodesEntry\x12\x39\n\npeer_graph\x18\x02 \x03(\x0b\x32%.node_service.Topology.PeerGraphEntry\x1aN\n\nNodesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12/\n\x05value\x18\x02 \x01(\x0b\x32 .node_service.DeviceCapabilities:\x02\x38\x01\x1aO\n\x0ePeerGraphEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12,\n\x05value\x18\x02 \x01(\x0b\x32\x1d.node_service.PeerConnections:\x02\x38\x01\"I\n\x0ePeerConnection\x12\r\n\x05to_id\x18\x01 \x01(\t\x12\x18\n\x0b\x64\x65scription\x18\x02 \x01(\tH\x00\x88\x01\x01\x42\x0e\n\x0c_description\"D\n\x0fPeerConnections\x12\x31\n\x0b\x63onnections\x18\x01 \x03(\x0b\x32\x1c.node_service.PeerConnection\"7\n\x0b\x44\x65viceFlops\x12\x0c\n\x04\x66p32\x18\x01 \x01(\x01\x12\x0c\n\x04\x66p16\x18\x02 \x01(\x01\x12\x0c\n\x04int8\x18\x03 \x01(\x01\"k\n\x12\x44\x65viceCapabilities\x12\r\n\x05model\x18\x01 \x01(\t\x12\x0c\n\x04\x63hip\x18\x02 \x01(\t\x12\x0e\n\x06memory\x18\x03 \x01(\x05\x12(\n\x05\x66lops\x18\x04 \x01(\x0b\x32\x19.node_service.DeviceFlops\"L\n\x11SendResultRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06result\x18\x02 \x03(\x05\x12\x13\n\x0bis_finished\x18\x03 \x01(\x08\"=\n\x17SendOpaqueStatusRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06status\x18\x02 \x01(\t\"\x14\n\x12HealthCheckRequest\")\n\x13HealthCheckResponse\x12\x12\n\nis_healthy\x18\x01 
\x01(\x08\"\x07\n\x05\x45mpty2\xf7\x04\n\x0bNodeService\x12\x41\n\nSendPrompt\x12\x1b.node_service.PromptRequest\x1a\x14.node_service.Tensor\"\x00\x12\x41\n\nSendTensor\x12\x1b.node_service.TensorRequest\x1a\x14.node_service.Tensor\"\x00\x12\x41\n\x0bSendExample\x12\x1c.node_service.ExampleRequest\x1a\x12.node_service.Loss\"\x00\x12^\n\x12GetInferenceResult\x12\'.node_service.GetInferenceResultRequest\x1a\x1d.node_service.InferenceResult\"\x00\x12Q\n\x0f\x43ollectTopology\x12$.node_service.CollectTopologyRequest\x1a\x16.node_service.Topology\"\x00\x12\x44\n\nSendResult\x12\x1f.node_service.SendResultRequest\x1a\x13.node_service.Empty\"\x00\x12P\n\x10SendOpaqueStatus\x12%.node_service.SendOpaqueStatusRequest\x1a\x13.node_service.Empty\"\x00\x12T\n\x0bHealthCheck\x12 .node_service.HealthCheckRequest\x1a!.node_service.HealthCheckResponse\"\x00\x62\x06proto3')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12node_service.proto\x12\x0cnode_service\"S\n\x05Shard\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x13\n\x0bstart_layer\x18\x02 \x01(\x05\x12\x11\n\tend_layer\x18\x03 \x01(\x05\x12\x10\n\x08n_layers\x18\x04 \x01(\x05\"\xbb\x01\n\rPromptRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12\x0e\n\x06prompt\x18\x02 \x01(\t\x12\x17\n\nrequest_id\x18\x03 \x01(\tH\x00\x88\x01\x01\x12:\n\x0finference_state\x18\x04 \x01(\x0b\x32\x1c.node_service.InferenceStateH\x01\x88\x01\x01\x42\r\n\x0b_request_idB\x12\n\x10_inference_state\"\xd1\x01\n\rTensorRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12$\n\x06tensor\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12\x17\n\nrequest_id\x18\x03 \x01(\tH\x00\x88\x01\x01\x12:\n\x0finference_state\x18\x04 \x01(\x0b\x32\x1c.node_service.InferenceStateH\x01\x88\x01\x01\x42\r\n\x0b_request_idB\x12\n\x10_inference_state\"\xde\x01\n\x0e\x45xampleRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12%\n\x07\x65xample\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12$\n\x06target\x18\x03 \x01(\x0b\x32\x14.node_service.Tensor\x12$\n\x06length\x18\x04 \x01(\x0b\x32\x14.node_service.Tensor\x12\r\n\x05train\x18\x05 \x01(\x08\x12\x17\n\nrequest_id\x18\x06 \x01(\tH\x00\x88\x01\x01\x42\r\n\x0b_request_id\"H\n\x04Loss\x12\x0c\n\x04loss\x18\x01 \x01(\x02\x12(\n\x05grads\x18\x02 \x01(\x0b\x32\x14.node_service.TensorH\x00\x88\x01\x01\x42\x08\n\x06_grads\";\n\x06Tensor\x12\x13\n\x0btensor_data\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05\x64type\x18\x03 \x01(\t\"3\n\nTensorList\x12%\n\x07tensors\x18\x01 \x03(\x0b\x32\x14.node_service.Tensor\"\xd2\x02\n\x0eInferenceState\x12\x41\n\x0btensor_data\x18\x01 \x03(\x0b\x32,.node_service.InferenceState.TensorDataEntry\x12J\n\x10tensor_list_data\x18\x02 \x03(\x0b\x32\x30.node_service.InferenceState.TensorListDataEntry\x12\x17\n\x0fother_data_json\x18\x03 \x01(\t\x1aG\n\x0fTensorDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor:\x02\x38\x01\x1aO\n\x13TensorListDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\'\n\x05value\x18\x02 \x01(\x0b\x32\x18.node_service.TensorList:\x02\x38\x01\"<\n\x16\x43ollectTopologyRequest\x12\x0f\n\x07visited\x18\x01 \x03(\t\x12\x11\n\tmax_depth\x18\x02 \x01(\x05\"\x98\x02\n\x08Topology\x12\x30\n\x05nodes\x18\x01 \x03(\x0b\x32!.node_service.Topology.NodesEntry\x12\x39\n\npeer_graph\x18\x02 \x03(\x0b\x32%.node_service.Topology.PeerGraphEntry\x1aN\n\nNodesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12/\n\x05value\x18\x02 \x01(\x0b\x32 .node_service.DeviceCapabilities:\x02\x38\x01\x1aO\n\x0ePeerGraphEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12,\n\x05value\x18\x02 \x01(\x0b\x32\x1d.node_service.PeerConnections:\x02\x38\x01\"I\n\x0ePeerConnection\x12\r\n\x05to_id\x18\x01 \x01(\t\x12\x18\n\x0b\x64\x65scription\x18\x02 \x01(\tH\x00\x88\x01\x01\x42\x0e\n\x0c_description\"D\n\x0fPeerConnections\x12\x31\n\x0b\x63onnections\x18\x01 \x03(\x0b\x32\x1c.node_service.PeerConnection\"7\n\x0b\x44\x65viceFlops\x12\x0c\n\x04\x66p32\x18\x01 \x01(\x01\x12\x0c\n\x04\x66p16\x18\x02 \x01(\x01\x12\x0c\n\x04int8\x18\x03 \x01(\x01\"k\n\x12\x44\x65viceCapabilities\x12\r\n\x05model\x18\x01 \x01(\t\x12\x0c\n\x04\x63hip\x18\x02 \x01(\t\x12\x0e\n\x06memory\x18\x03 \x01(\x05\x12(\n\x05\x66lops\x18\x04 \x01(\x0b\x32\x19.node_service.DeviceFlops\"\x82\x01\n\x11SendResultRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06result\x18\x02 
\x03(\x05\x12)\n\x06tensor\x18\x03 \x01(\x0b\x32\x14.node_service.TensorH\x00\x88\x01\x01\x12\x13\n\x0bis_finished\x18\x04 \x01(\x08\x42\t\n\x07_tensor\"=\n\x17SendOpaqueStatusRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06status\x18\x02 \x01(\t\"\x14\n\x12HealthCheckRequest\")\n\x13HealthCheckResponse\x12\x12\n\nis_healthy\x18\x01 \x01(\x08\"\x07\n\x05\x45mpty2\x97\x04\n\x0bNodeService\x12\x41\n\nSendPrompt\x12\x1b.node_service.PromptRequest\x1a\x14.node_service.Tensor\"\x00\x12\x41\n\nSendTensor\x12\x1b.node_service.TensorRequest\x1a\x14.node_service.Tensor\"\x00\x12\x41\n\x0bSendExample\x12\x1c.node_service.ExampleRequest\x1a\x12.node_service.Loss\"\x00\x12Q\n\x0f\x43ollectTopology\x12$.node_service.CollectTopologyRequest\x1a\x16.node_service.Topology\"\x00\x12\x44\n\nSendResult\x12\x1f.node_service.SendResultRequest\x1a\x13.node_service.Empty\"\x00\x12P\n\x10SendOpaqueStatus\x12%.node_service.SendOpaqueStatusRequest\x1a\x13.node_service.Empty\"\x00\x12T\n\x0bHealthCheck\x12 .node_service.HealthCheckRequest\x1a!.node_service.HealthCheckResponse\"\x00\x62\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'node_service_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
DESCRIPTOR._loaded_options = None
_globals['_INFERENCESTATE_TENSORDATAENTRY']._loaded_options = None
_globals['_INFERENCESTATE_TENSORDATAENTRY']._serialized_options = b'8\001'
_globals['_INFERENCESTATE_TENSORLISTDATAENTRY']._loaded_options = None
_globals['_INFERENCESTATE_TENSORLISTDATAENTRY']._serialized_options = b'8\001'
_globals['_TOPOLOGY_NODESENTRY']._loaded_options = None
_globals['_TOPOLOGY_NODESENTRY']._serialized_options = b'8\001'
_globals['_TOPOLOGY_PEERGRAPHENTRY']._loaded_options = None
_globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_options = b'8\001'
_globals['_SHARD']._serialized_start=36
_globals['_SHARD']._serialized_end=119
_globals['_PROMPTREQUEST']._serialized_start=121
_globals['_PROMPTREQUEST']._serialized_end=228
_globals['_TENSORREQUEST']._serialized_start=231
_globals['_TENSORREQUEST']._serialized_end=360
_globals['_EXAMPLEREQUEST']._serialized_start=363
_globals['_EXAMPLEREQUEST']._serialized_end=585
_globals['_LOSS']._serialized_start=587
_globals['_LOSS']._serialized_end=659
_globals['_GETINFERENCERESULTREQUEST']._serialized_start=661
_globals['_GETINFERENCERESULTREQUEST']._serialized_end=708
_globals['_INFERENCERESULT']._serialized_start=710
_globals['_INFERENCERESULT']._serialized_end=802
_globals['_TENSOR']._serialized_start=804
_globals['_TENSOR']._serialized_end=863
_globals['_COLLECTTOPOLOGYREQUEST']._serialized_start=865
_globals['_COLLECTTOPOLOGYREQUEST']._serialized_end=925
_globals['_TOPOLOGY']._serialized_start=928
_globals['_TOPOLOGY']._serialized_end=1208
_globals['_TOPOLOGY_NODESENTRY']._serialized_start=1049
_globals['_TOPOLOGY_NODESENTRY']._serialized_end=1127
_globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_start=1129
_globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_end=1208
_globals['_PEERCONNECTION']._serialized_start=1210
_globals['_PEERCONNECTION']._serialized_end=1283
_globals['_PEERCONNECTIONS']._serialized_start=1285
_globals['_PEERCONNECTIONS']._serialized_end=1353
_globals['_DEVICEFLOPS']._serialized_start=1355
_globals['_DEVICEFLOPS']._serialized_end=1410
_globals['_DEVICECAPABILITIES']._serialized_start=1412
_globals['_DEVICECAPABILITIES']._serialized_end=1519
_globals['_SENDRESULTREQUEST']._serialized_start=1521
_globals['_SENDRESULTREQUEST']._serialized_end=1597
_globals['_SENDOPAQUESTATUSREQUEST']._serialized_start=1599
_globals['_SENDOPAQUESTATUSREQUEST']._serialized_end=1660
_globals['_HEALTHCHECKREQUEST']._serialized_start=1662
_globals['_HEALTHCHECKREQUEST']._serialized_end=1682
_globals['_HEALTHCHECKRESPONSE']._serialized_start=1684
_globals['_HEALTHCHECKRESPONSE']._serialized_end=1725
_globals['_EMPTY']._serialized_start=1727
_globals['_EMPTY']._serialized_end=1734
_globals['_NODESERVICE']._serialized_start=1737
_globals['_NODESERVICE']._serialized_end=2368
_globals['_PROMPTREQUEST']._serialized_start=122
_globals['_PROMPTREQUEST']._serialized_end=309
_globals['_TENSORREQUEST']._serialized_start=312
_globals['_TENSORREQUEST']._serialized_end=521
_globals['_EXAMPLEREQUEST']._serialized_start=524
_globals['_EXAMPLEREQUEST']._serialized_end=746
_globals['_LOSS']._serialized_start=748
_globals['_LOSS']._serialized_end=820
_globals['_TENSOR']._serialized_start=822
_globals['_TENSOR']._serialized_end=881
_globals['_TENSORLIST']._serialized_start=883
_globals['_TENSORLIST']._serialized_end=934
_globals['_INFERENCESTATE']._serialized_start=937
_globals['_INFERENCESTATE']._serialized_end=1275
_globals['_INFERENCESTATE_TENSORDATAENTRY']._serialized_start=1123
_globals['_INFERENCESTATE_TENSORDATAENTRY']._serialized_end=1194
_globals['_INFERENCESTATE_TENSORLISTDATAENTRY']._serialized_start=1196
_globals['_INFERENCESTATE_TENSORLISTDATAENTRY']._serialized_end=1275
_globals['_COLLECTTOPOLOGYREQUEST']._serialized_start=1277
_globals['_COLLECTTOPOLOGYREQUEST']._serialized_end=1337
_globals['_TOPOLOGY']._serialized_start=1340
_globals['_TOPOLOGY']._serialized_end=1620
_globals['_TOPOLOGY_NODESENTRY']._serialized_start=1461
_globals['_TOPOLOGY_NODESENTRY']._serialized_end=1539
_globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_start=1541
_globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_end=1620
_globals['_PEERCONNECTION']._serialized_start=1622
_globals['_PEERCONNECTION']._serialized_end=1695
_globals['_PEERCONNECTIONS']._serialized_start=1697
_globals['_PEERCONNECTIONS']._serialized_end=1765
_globals['_DEVICEFLOPS']._serialized_start=1767
_globals['_DEVICEFLOPS']._serialized_end=1822
_globals['_DEVICECAPABILITIES']._serialized_start=1824
_globals['_DEVICECAPABILITIES']._serialized_end=1931
_globals['_SENDRESULTREQUEST']._serialized_start=1934
_globals['_SENDRESULTREQUEST']._serialized_end=2064
_globals['_SENDOPAQUESTATUSREQUEST']._serialized_start=2066
_globals['_SENDOPAQUESTATUSREQUEST']._serialized_end=2127
_globals['_HEALTHCHECKREQUEST']._serialized_start=2129
_globals['_HEALTHCHECKREQUEST']._serialized_end=2149
_globals['_HEALTHCHECKRESPONSE']._serialized_start=2151
_globals['_HEALTHCHECKRESPONSE']._serialized_end=2192
_globals['_EMPTY']._serialized_start=2194
_globals['_EMPTY']._serialized_end=2201
_globals['_NODESERVICE']._serialized_start=2204
_globals['_NODESERVICE']._serialized_end=2739
# @@protoc_insertion_point(module_scope)
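node_service_pb2.py and node_service_pb2_grpc.py are generated files; a hedged sketch of regenerating them after editing node_service.proto (the include and output paths are illustrative):

from grpc_tools import protoc  # requires grpcio-tools

protoc.main([
    "grpc_tools.protoc",
    "-I.",               # directory containing node_service.proto (illustrative)
    "--python_out=.",
    "--grpc_python_out=.",
    "node_service.proto",
])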

View File

@@ -49,11 +49,6 @@ class NodeServiceStub(object):
request_serializer=node__service__pb2.ExampleRequest.SerializeToString,
response_deserializer=node__service__pb2.Loss.FromString,
_registered_method=True)
self.GetInferenceResult = channel.unary_unary(
'/node_service.NodeService/GetInferenceResult',
request_serializer=node__service__pb2.GetInferenceResultRequest.SerializeToString,
response_deserializer=node__service__pb2.InferenceResult.FromString,
_registered_method=True)
self.CollectTopology = channel.unary_unary(
'/node_service.NodeService/CollectTopology',
request_serializer=node__service__pb2.CollectTopologyRequest.SerializeToString,
@@ -97,12 +92,6 @@ class NodeServiceServicer(object):
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def GetInferenceResult(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def CollectTopology(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
@@ -145,11 +134,6 @@ def add_NodeServiceServicer_to_server(servicer, server):
request_deserializer=node__service__pb2.ExampleRequest.FromString,
response_serializer=node__service__pb2.Loss.SerializeToString,
),
'GetInferenceResult': grpc.unary_unary_rpc_method_handler(
servicer.GetInferenceResult,
request_deserializer=node__service__pb2.GetInferenceResultRequest.FromString,
response_serializer=node__service__pb2.InferenceResult.SerializeToString,
),
'CollectTopology': grpc.unary_unary_rpc_method_handler(
servicer.CollectTopology,
request_deserializer=node__service__pb2.CollectTopologyRequest.FromString,
@@ -262,33 +246,6 @@ class NodeService(object):
metadata,
_registered_method=True)
@staticmethod
def GetInferenceResult(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/node_service.NodeService/GetInferenceResult',
node__service__pb2.GetInferenceResultRequest.SerializeToString,
node__service__pb2.InferenceResult.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
@staticmethod
def CollectTopology(request,
target,

View File

@@ -1,7 +1,9 @@
import os
import asyncio
from exo.networking.discovery import Discovery
from typing import Dict, List, Callable
from typing import Dict, List, Callable, Optional
from concurrent.futures import ThreadPoolExecutor
from exo.networking.discovery import Discovery
from exo.topology.device_capabilities import DeviceCapabilities
from exo.networking.manual.network_topology_config import NetworkTopology, PeerConfig
from exo.helpers import DEBUG_DISCOVERY
@@ -13,28 +15,25 @@ class ManualDiscovery(Discovery):
self,
network_config_path: str,
node_id: str,
create_peer_handle: Callable[[str, str, DeviceCapabilities], PeerHandle],
create_peer_handle: Callable[[str, str, str, DeviceCapabilities], PeerHandle],
):
self.topology = NetworkTopology.from_path(network_config_path)
self.network_config_path = network_config_path
self.node_id = node_id
self.create_peer_handle = create_peer_handle
if node_id not in self.topology.peers:
raise ValueError(
f"Node ID {node_id} not found in network config file {network_config_path}. Please run with `node_id` set to one of the keys in the config file: {[k for k, _ in self.topology.peers]}"
)
self.listen_task = None
self.known_peers: Dict[str, PeerHandle] = {}
self.peers_in_network: Dict[str, PeerConfig] = self.topology.peers
self.peers_in_network.pop(node_id)
self._cached_peers: Dict[str, PeerConfig] = {}
self._last_modified_time: Optional[float] = None
self._file_executor = ThreadPoolExecutor(max_workers=1)
async def start(self) -> None:
self.listen_task = asyncio.create_task(self.task_find_peers_from_config())
async def stop(self) -> None:
if self.listen_task:
self.listen_task.cancel()
if self.listen_task: self.listen_task.cancel()
self._file_executor.shutdown(wait=True)
async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
if wait_for_peers > 0:
@@ -47,7 +46,9 @@ class ManualDiscovery(Discovery):
async def task_find_peers_from_config(self):
if DEBUG_DISCOVERY >= 2: print("Starting task to find peers from config...")
while True:
for peer_id, peer_config in self.peers_in_network.items():
peers_from_config = await self._get_peers()
new_known_peers = {}
for peer_id, peer_config in peers_from_config.items():
try:
if DEBUG_DISCOVERY >= 2: print(f"Checking peer {peer_id=} at {peer_config.address}:{peer_config.port}")
peer = self.known_peers.get(peer_id)
@@ -57,15 +58,44 @@ class ManualDiscovery(Discovery):
is_healthy = await peer.health_check()
if is_healthy:
if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is healthy.")
self.known_peers[peer_id] = peer
else:
if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy.")
try:
del self.known_peers[peer_id]
except KeyError:
pass
new_known_peers[peer_id] = peer
elif DEBUG_DISCOVERY >= 2:
print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy. Removing.")
except Exception as e:
if DEBUG_DISCOVERY >= 2: print(f"Exception occured when attempting to add {peer_id=}: {e}")
await asyncio.sleep(1.0)
if DEBUG_DISCOVERY >= 2: print(f"Exception occurred when attempting to add {peer_id=}: {e}")
self.known_peers = new_known_peers
await asyncio.sleep(5.0)
if DEBUG_DISCOVERY >= 2: print(f"Current known peers: {[peer.id() for peer in self.known_peers.values()]}")
async def _get_peers(self):
try:
loop = asyncio.get_running_loop()
current_mtime = await loop.run_in_executor(self._file_executor, os.path.getmtime, self.network_config_path)
if (self._cached_peers is not None and self._last_modified_time is not None and current_mtime <= self._last_modified_time):
return self._cached_peers
topology = await loop.run_in_executor(self._file_executor, NetworkTopology.from_path, self.network_config_path)
if self.node_id not in topology.peers:
raise ValueError(
f"Node ID {self.node_id} not found in network config file "
f"{self.network_config_path}. Please run with `node_id` set to "
f"one of the keys in the config file: {[k for k, _ in topology.peers]}"
)
peers_in_network = topology.peers
peers_in_network.pop(self.node_id)
self._cached_peers = peers_in_network
self._last_modified_time = current_mtime
return peers_in_network
except Exception as e:
if DEBUG_DISCOVERY >= 2:
print(f"Error when loading network config file from {self.network_config_path}. "
f"Please update the config file in order to successfully discover peers. "
f"Exception: {e}")
return self._cached_peers
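A hedged sketch of the config shape _get_peers reloads from disk; node ids, addresses, and ports are illustrative, and the dynamic-update test below writes the same structure:

import json

config = {
    "peers": {
        "node1": {
            "address": "localhost",
            "port": 50051,
            "device_capabilities": {
                "model": "Unknown Model",
                "chip": "Unknown Chip",
                "memory": 0,
                "flops": {"fp32": 0, "fp16": 0, "int8": 0},
            },
        },
    },
}
with open("network_topology_config.json", "w") as f:  # illustrative path
    json.dump(config, f, indent=2)  # picked up on the next mtime check, no restart needed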

View File

@@ -29,4 +29,4 @@
}
}
}
}
}

View File

@@ -1,3 +1,4 @@
import json
import asyncio
import unittest
from unittest import mock
@@ -14,8 +15,12 @@ class TestSingleNodeManualDiscovery(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self):
self.peer1 = mock.AsyncMock()
self.peer1.connect = mock.AsyncMock()
self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1)
_ = self.discovery1.start()
self.discovery1 = ManualDiscovery(
root_path,
"node1",
create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1,
)
await self.discovery1.start()
async def asyncTearDown(self):
await self.discovery1.stop()
@@ -33,8 +38,16 @@ class TestManualDiscovery(unittest.IsolatedAsyncioTestCase):
self.peer2 = mock.AsyncMock()
self.peer1.connect = mock.AsyncMock()
self.peer2.connect = mock.AsyncMock()
self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1)
self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer2)
self.discovery1 = ManualDiscovery(
root_path,
"node1",
create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer1,
)
self.discovery2 = ManualDiscovery(
root_path,
"node2",
create_peer_handle=lambda peer_id, address, description, device_capabilities: self.peer2,
)
await self.discovery1.start()
await self.discovery2.start()
@@ -63,8 +76,16 @@ class TestManualDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase):
self.server2 = GRPCServer(self.node2, config.peers["node2"].address, config.peers["node2"].port)
await self.server1.start()
await self.server2.start()
self.discovery1 = ManualDiscovery(root_path, "node1", create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities))
self.discovery2 = ManualDiscovery(root_path, "node2", create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities))
self.discovery1 = ManualDiscovery(
root_path,
"node1",
create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities),
)
self.discovery2 = ManualDiscovery(
root_path,
"node2",
create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities),
)
await self.discovery1.start()
await self.discovery2.start()
@@ -98,6 +119,63 @@ class TestManualDiscoveryWithGRPCPeerHandle(unittest.IsolatedAsyncioTestCase):
self.assertFalse(await peers1[0].is_connected())
self.assertFalse(await peers2[0].is_connected())
async def test_dynamic_config_update(self):
initial_peers = await self.discovery1.discover_peers(wait_for_peers=1)
self.assertEqual(len(initial_peers), 1)
# Save original config for cleanup
with open(root_path, "r") as f:
original_config = json.load(f)
try:
updated_config = {
"peers": {
**original_config["peers"],
"node3": {
"address": "localhost",
"port": 50053,
"device_capabilities": {
"model": "Unknown Model",
"chip": "Unknown Chip",
"memory": 0,
"flops": {"fp32": 0, "fp16": 0, "int8": 0},
},
},
}
}
with open(root_path, "w") as f:
json.dump(updated_config, f, indent=2)
node3 = mock.AsyncMock(spec=Node)
server3 = GRPCServer(node3, "localhost", 50053)
await server3.start()
try:
# Wait for the config to be reloaded
await asyncio.sleep(1.5)
updated_peers = await self.discovery1.discover_peers(wait_for_peers=2)
self.assertEqual(len(updated_peers), 2)
for peer in updated_peers:
await peer.connect()
self.assertTrue(await peer.is_connected())
finally:
await server3.stop()
finally:
# Restore the original config file
with open(root_path, "w") as f:
json.dump(original_config, f, indent=2)
# Wait for the config to be reloaded again
await asyncio.sleep(1.5)
updated_peers = await self.discovery1.discover_peers(wait_for_peers=1)
self.assertEqual(len(updated_peers), 1)
if __name__ == "__main__":
asyncio.run(unittest.main())

View File

@@ -51,10 +51,6 @@ class PeerHandle(ABC):
async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
pass
@abstractmethod
async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
pass
@abstractmethod
async def collect_topology(self, visited: set[str], max_depth: int) -> Topology:
pass

View File

@@ -40,7 +40,7 @@ class TailscaleDiscovery(Discovery):
self.update_task = None
async def start(self):
self.device_capabilities = device_capabilities()
self.device_capabilities = await device_capabilities()
self.discovery_task = asyncio.create_task(self.task_discover_peers())
self.cleanup_task = asyncio.create_task(self.task_cleanup_peers())
self.update_task = asyncio.create_task(self.task_update_device_posture_attributes())

View File

@@ -3,7 +3,7 @@ import json
import socket
import time
import traceback
from typing import List, Dict, Callable, Tuple, Coroutine
from typing import List, Dict, Callable, Tuple, Coroutine, Optional
from exo.networking.discovery import Discovery
from exo.networking.peer_handle import PeerHandle
from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES
@@ -23,15 +23,29 @@ class ListenProtocol(asyncio.DatagramProtocol):
asyncio.create_task(self.on_message(data, addr))
def get_broadcast_address(ip_addr: str) -> str:
try:
# Split IP into octets and create broadcast address for the subnet
ip_parts = ip_addr.split('.')
return f"{ip_parts[0]}.{ip_parts[1]}.{ip_parts[2]}.255"
except:
return "255.255.255.255"
class BroadcastProtocol(asyncio.DatagramProtocol):
def __init__(self, message: str, broadcast_port: int):
def __init__(self, message: str, broadcast_port: int, source_ip: str):
self.message = message
self.broadcast_port = broadcast_port
self.source_ip = source_ip
def connection_made(self, transport):
sock = transport.get_extra_info("socket")
sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
transport.sendto(self.message.encode("utf-8"), ("<broadcast>", self.broadcast_port))
# Try both subnet-specific and global broadcast
broadcast_addr = get_broadcast_address(self.source_ip)
transport.sendto(self.message.encode("utf-8"), (broadcast_addr, self.broadcast_port))
if broadcast_addr != "255.255.255.255":
transport.sendto(self.message.encode("utf-8"), ("255.255.255.255", self.broadcast_port))
class UDPDiscovery(Discovery):
@@ -45,7 +59,8 @@ class UDPDiscovery(Discovery):
broadcast_interval: int = 2.5,
discovery_timeout: int = 30,
device_capabilities: DeviceCapabilities = UNKNOWN_DEVICE_CAPABILITIES,
allowed_node_ids: List[str] = None,
allowed_node_ids: Optional[List[str]] = None,
allowed_interface_types: Optional[List[str]] = None,
):
self.node_id = node_id
self.node_port = node_port
@@ -56,13 +71,14 @@ class UDPDiscovery(Discovery):
self.discovery_timeout = discovery_timeout
self.device_capabilities = device_capabilities
self.allowed_node_ids = allowed_node_ids
self.allowed_interface_types = allowed_interface_types
self.known_peers: Dict[str, Tuple[PeerHandle, float, float, int]] = {}
self.broadcast_task = None
self.listen_task = None
self.cleanup_task = None
async def start(self):
self.device_capabilities = device_capabilities()
self.device_capabilities = await device_capabilities()
self.broadcast_task = asyncio.create_task(self.task_broadcast_presence())
self.listen_task = asyncio.create_task(self.task_listen_for_peers())
self.cleanup_task = asyncio.create_task(self.task_cleanup_peers())
@@ -82,11 +98,7 @@ class UDPDiscovery(Discovery):
return [peer_handle for peer_handle, _, _, _ in self.known_peers.values()]
async def task_broadcast_presence(self):
if DEBUG_DISCOVERY >= 2: print("Starting task_broadcast_presence...")
while True:
# Explicitly broadcasting on all assigned ips since broadcasting on `0.0.0.0` on MacOS does not broadcast over
# the Thunderbolt bridge when other connection modalities exist such as WiFi or Ethernet
for addr, interface_name in get_all_ip_addresses_and_interfaces():
interface_priority, interface_type = await get_interface_priority_and_type(interface_name)
message = json.dumps({
@@ -94,16 +106,26 @@ class UDPDiscovery(Discovery):
"node_id": self.node_id,
"grpc_port": self.node_port,
"device_capabilities": self.device_capabilities.to_dict(),
"priority": interface_priority, # TODO: Prioritise interfaces based on bandwidth, latency, and jitter e.g. prioritise Thunderbolt over WiFi.
"priority": interface_priority,
"interface_name": interface_name,
"interface_type": interface_type,
})
if DEBUG_DISCOVERY >= 3: print(f"Broadcasting presence at ({addr} - {interface_name} - {interface_priority}): {message}")
transport = None
try:
transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(lambda: BroadcastProtocol(message, self.broadcast_port), local_addr=(addr, 0), family=socket.AF_INET)
if DEBUG_DISCOVERY >= 3: print(f"Broadcasting presence at ({addr} - {interface_name} - {interface_priority})")
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
try:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
except AttributeError:
pass
sock.bind((addr, 0))
transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(
lambda: BroadcastProtocol(message, self.broadcast_port, addr),
sock=sock
)
except Exception as e:
print(f"Error in broadcast presence ({addr} - {interface_name} - {interface_priority}): {e}")
finally:
@@ -111,7 +133,7 @@ class UDPDiscovery(Discovery):
try: transport.close()
except Exception as e:
if DEBUG_DISCOVERY >= 2: print(f"Error closing transport: {e}")
if DEBUG_DISCOVERY >= 2: traceback.print_exc()
await asyncio.sleep(self.broadcast_interval)
async def on_listen_message(self, data, addr):
@@ -147,6 +169,12 @@ class UDPDiscovery(Discovery):
peer_prio = message["priority"]
peer_interface_name = message["interface_name"]
peer_interface_type = message["interface_type"]
# Skip if interface type is not in allowed list
if self.allowed_interface_types and peer_interface_type not in self.allowed_interface_types:
if DEBUG_DISCOVERY >= 2: print(f"Ignoring peer {peer_id} as its interface type {peer_interface_type} is not in the allowed interface types list")
return
device_capabilities = DeviceCapabilities(**message["device_capabilities"])
if peer_id not in self.known_peers or self.known_peers[peer_id][0].addr() != f"{peer_host}:{peer_port}":

View File

@@ -8,14 +8,14 @@ from typing import List, Dict, Optional, Tuple, Union, Set
from exo.networking import Discovery, PeerHandle, Server
from exo.inference.inference_engine import InferenceEngine, Shard
from exo.topology.topology import Topology
from exo.topology.device_capabilities import device_capabilities
from exo.topology.device_capabilities import device_capabilities, UNKNOWN_DEVICE_CAPABILITIES
from exo.topology.partitioning_strategy import Partition, PartitioningStrategy, map_partitions_to_shards
from exo import DEBUG
from exo.helpers import AsyncCallbackSystem
from exo.viz.topology_viz import TopologyViz
from exo.download.hf.hf_helpers import RepoProgressEvent
from exo.download.download_progress import RepoProgressEvent
from exo.inference.inference_engine import get_inference_engine, InferenceEngine
from exo.download.hf.hf_shard_download import HFShardDownloader
from exo.download.shard_download import ShardDownloader
class Node:
def __init__(
@@ -24,20 +24,21 @@ class Node:
server: Server,
inference_engine: InferenceEngine,
discovery: Discovery,
shard_downloader: ShardDownloader,
partitioning_strategy: PartitioningStrategy = None,
max_generate_tokens: int = 1024,
default_sample_temperature: float = 0.0,
topology_viz: Optional[TopologyViz] = None,
shard_downloader: Optional[HFShardDownloader] = None,
):
self.id = _id
self.inference_engine = inference_engine
self.server = server
self.discovery = discovery
self.shard_downloader = shard_downloader
self.partitioning_strategy = partitioning_strategy
self.peers: List[PeerHandle] = {}
self.topology: Topology = Topology()
self.device_capabilities = device_capabilities()
self.device_capabilities = UNKNOWN_DEVICE_CAPABILITIES
self.buffered_token_output: Dict[str, Tuple[List[int], bool]] = {}
self.buffered_logits: Dict[str, List[np.ndarray]] = {}
self.buffered_inputs: Dict[str, List[np.ndarray]] = {}
@@ -52,10 +53,10 @@ class Node:
self._on_opaque_status.register("node_status").on_next(self.on_node_status)
self.node_download_progress: Dict[str, RepoProgressEvent] = {}
self.topology_inference_engines_pool: List[List[str]] = []
self.shard_downloader = shard_downloader
self.outstanding_requests = {}
async def start(self, wait_for_peers: int = 0) -> None:
self.device_capabilities = await device_capabilities()
await self.server.start()
await self.discovery.start()
await self.update_peers(wait_for_peers)
@@ -70,25 +71,28 @@ class Node:
def on_node_status(self, request_id, opaque_status):
try:
status_data = json.loads(opaque_status)
if status_data.get("type", "") == "supported_inference_engines":
status_type = status_data.get("type", "")
if status_type == "supported_inference_engines":
node_id = status_data.get("node_id")
engines = status_data.get("engines", [])
self.topology_inference_engines_pool.append(engines)
if status_data.get("type", "") == "node_status":
elif status_type == "node_status":
if status_data.get("status", "").startswith("start_"):
self.current_topology.active_node_id = status_data.get("node_id")
elif status_data.get("status", "").startswith("end_"):
if status_data.get("node_id") == self.current_topology.active_node_id:
self.current_topology.active_node_id = None
download_progress = None
if status_data.get("type", "") == "download_progress":
if status_type == "download_progress":
if DEBUG >= 8: print(f"Download progress from {status_data.get('node_id')}: {status_data.get('progress')}")
download_progress = RepoProgressEvent.from_dict(status_data.get('progress'))
self.node_download_progress[status_data.get('node_id')] = download_progress
if self.topology_viz:
self.topology_viz.update_visualization(self.topology, self.partitioning_strategy.partition(self.topology), self.id, self.node_download_progress)
except Exception as e:
if DEBUG >= 1: print(f"Error updating visualization: {e}")
if DEBUG >= 1: print(f"Error on_node_status: {e}")
if DEBUG >= 1: traceback.print_exc()
def get_supported_inference_engines(self):
@@ -107,44 +111,58 @@ class Node:
def get_topology_inference_engines(self) -> List[List[str]]:
return self.topology_inference_engines_pool
token_count = 0
first_token_time = 0
async def process_inference_result(
self,
shard,
result: np.ndarray,
request_id: Optional[str] = None,
inference_state: Optional[dict] = None,
):
if request_id not in self.buffered_token_output:
self.buffered_token_output[request_id] = ([], False)
is_finished = len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
if shard.is_last_layer() and not is_finished:
token = await self.inference_engine.sample(result, temp=self.default_sample_temperature)
await self.inference_engine.ensure_shard(shard)
self.buffered_token_output[request_id][0].append(token.item())
is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id or is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
asyncio.create_task(self.broadcast_result(request_id, *self.buffered_token_output[request_id]))
forward = token.reshape(1, -1)
self.trigger_on_token_callbacks(request_id, self.buffered_token_output[request_id][0], is_finished)
asyncio.create_task(self.broadcast_result(request_id, self.buffered_token_output[request_id][0], is_finished))
if shard.model_id != 'stable-diffusion-2-1-base':
if request_id not in self.buffered_token_output:
self.buffered_token_output[request_id] = ([], False)
is_finished = len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
if shard.is_last_layer() and not is_finished:
token = await self.inference_engine.sample(result, temp=self.default_sample_temperature)
await self.inference_engine.ensure_shard(shard)
self.buffered_token_output[request_id][0].append(token.item())
is_finished = token.item() == self.inference_engine.tokenizer.eos_token_id or is_finished or len(self.buffered_token_output[request_id][0]) >= self.max_generate_tokens
if DEBUG >= 2: print(f"[{request_id}] result size: {result.size}, is finished: {is_finished}, buffered tokens: {len(self.buffered_token_output[request_id][0])}")
forward = token.reshape(1, -1)
intermediate_result = [self.buffered_token_output[request_id][0][-1]]
else:
forward = result
else:
await self.inference_engine.ensure_shard(shard)
is_finished = inference_state.get("is_finished", False)
intermediate_result, inference_state = self.handle_stable_diffusion(inference_state, result)
forward = result
if shard.is_last_layer():
self.trigger_on_token_callbacks(request_id, intermediate_result, is_finished)
asyncio.create_task(self.broadcast_result(request_id, intermediate_result, is_finished))
if is_finished:
self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
if shard.model_id != 'stable-diffusion-2-1-base':
self.buffered_token_output[request_id] = (self.buffered_token_output[request_id][0], True)
self.outstanding_requests.pop(request_id)
else:
self.outstanding_requests[request_id] = "waiting"
asyncio.create_task(self.forward_tensor(shard, forward, request_id, self.get_partition_index(offset = 1)))
asyncio.create_task(self.forward_tensor(shard, forward, request_id, self.get_partition_index(offset = 1), inference_state))
return np.array(self.buffered_token_output[request_id][0]) if shard.model_id != 'stable-diffusion-2-1-base' else intermediate_result
return np.array(self.buffered_token_output[request_id][0])
async def process_prompt(
self,
base_shard: Shard,
prompt: str,
request_id: Optional[str] = None,
inference_state: Optional[dict] = {},
) -> Optional[np.ndarray]:
shard = self.get_current_shard(base_shard)
start_time = time.perf_counter_ns()
asyncio.create_task(
self.broadcast_opaque_status(
request_id,
@@ -160,7 +178,7 @@ class Node:
)
)
start_time = time.perf_counter_ns()
resp = await self._process_prompt(base_shard, prompt, request_id)
resp = await self._process_prompt(base_shard, prompt, request_id, inference_state)
end_time = time.perf_counter_ns()
elapsed_time_ns = end_time - start_time
asyncio.create_task(
@@ -175,27 +193,26 @@ class Node:
"prompt": prompt,
"request_id": request_id,
"elapsed_time_ns": elapsed_time_ns,
"result_size": resp.size if resp is not None else 0,
}),
)
)
return resp
if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=} {elapsed_time_ns=}")
async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None) -> Optional[np.ndarray]:
async def _process_prompt(self, base_shard: Shard, prompt: str, request_id: Optional[str] = None, inference_state: Optional[dict] = None) -> Optional[np.ndarray]:
if request_id is None:
request_id = str(uuid.uuid4())
shard = self.get_current_shard(base_shard)
if DEBUG >= 2: print(f"[{request_id}] process prompt: {base_shard=} {shard=} {prompt=}")
if not shard.is_first_layer():
if DEBUG >= 2: print(f"[{request_id}] forwarding to next shard: {base_shard=} {shard=} {prompt=}")
self.outstanding_requests[request_id] = "waiting"
resp = await self.forward_prompt(shard, prompt, request_id, 0)
resp = await self.forward_prompt(shard, prompt, request_id, 0, inference_state)
return None
else:
self.outstanding_requests[request_id] = "processing"
result = await self.inference_engine.infer_prompt(request_id, shard, prompt)
ret = await self.process_inference_result(shard, result, request_id)
result, inference_state = await self.inference_engine.infer_prompt(request_id, shard, prompt, inference_state)
ret = await self.process_inference_result(shard, result, request_id, inference_state)
return result
async def enqueue_example(
@@ -308,7 +325,7 @@ class Node:
loss, grad = await self.inference_engine.train(request_id, shard, example, target, length)
else:
self.outstanding_requests[request_id] = "preprocessing"
step = await self.inference_engine.infer_tensor(request_id, shard, example)
step, _ = await self.inference_engine.infer_tensor(request_id, shard, example)
self.outstanding_requests[request_id] = "waiting"
loss, backgrad = await self.forward_example(shard, step, target, length, train, request_id, self.get_partition_index(offset = 1))
self.outstanding_requests[request_id] = "training"
@@ -324,7 +341,7 @@ class Node:
loss = await self.inference_engine.evaluate(request_id, shard, example, target, length)
else:
self.outstanding_requests[request_id] = "preprocessing"
step = await self.inference_engine.infer_tensor(request_id, shard, example)
step, _ = await self.inference_engine.infer_tensor(request_id, shard, example)
self.outstanding_requests[request_id] = "waiting"
loss = await self.forward_example(shard, step, target, length, train, request_id, self.get_partition_index(offset = 1))
self.outstanding_requests.pop(request_id)
@@ -340,65 +357,35 @@ class Node:
base_shard: Shard,
tensor: np.ndarray,
request_id: Optional[str] = None,
inference_state: Optional[dict] = None,
) -> Optional[np.ndarray]:
shard = self.get_current_shard(base_shard)
asyncio.create_task(
self.broadcast_opaque_status(
request_id,
json.dumps({
"type": "node_status",
"node_id": self.id,
"status": "start_process_tensor",
"base_shard": base_shard.to_dict(),
"shard": shard.to_dict(),
"tensor_size": tensor.size,
"tensor_shape": tensor.shape,
"request_id": request_id,
}),
)
)
start_time = time.perf_counter_ns()
resp = await self._process_tensor(shard, tensor, request_id)
resp = await self._process_tensor(shard, tensor, request_id, inference_state)
end_time = time.perf_counter_ns()
elapsed_time_ns = end_time - start_time
asyncio.create_task(
self.broadcast_opaque_status(
request_id,
json.dumps({
"type": "node_status",
"node_id": self.id,
"status": "end_process_tensor",
"base_shard": base_shard.to_dict(),
"shard": shard.to_dict(),
"request_id": request_id,
"elapsed_time_ns": elapsed_time_ns,
"result_size": resp.size if resp is not None else 0,
}),
)
)
return resp
if DEBUG >= 2: print(f"[{request_id}] process_tensor: {base_shard=} {shard=} {tensor.size=} {tensor.shape=} {elapsed_time_ns=}")
async def _process_tensor(
self,
base_shard: Shard,
tensor: np.ndarray,
request_id: Optional[str] = None,
inference_state: Optional[dict] = None,
) -> Optional[np.ndarray]:
if request_id is None:
request_id = str(uuid.uuid4())
shard = self.get_current_shard(base_shard)
if DEBUG >= 1: print(f"[{request_id}] process_tensor: {tensor.size=} {tensor.shape=}")
try:
self.outstanding_requests[request_id] = "processing"
result = await self.inference_engine.infer_tensor(request_id, shard, tensor)
ret = await self.process_inference_result(shard, result, request_id)
result, inference_state = await self.inference_engine.infer_tensor(request_id, shard, tensor, inference_state)
ret = await self.process_inference_result(shard, result, request_id, inference_state)
return ret
except Exception as e:
self.outstanding_requests.pop(request_id)
print(f"Error processing tensor for shard {shard}: {e}")
traceback.print_exc()
return None
async def forward_example(
self,
@@ -427,19 +414,20 @@ class Node:
prompt: str,
request_id: str,
target_index: int,
inference_state: Optional[dict] = None,
) -> None:
if DEBUG >= 1: print(f"target partition index: {target_index}")
target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
next_shard = self.get_current_shard(base_shard, target_index)
if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. next shard: {next_shard}")
if target_id == self.id:
await self.process_prompt(next_shard, prompt, request_id)
await self.process_prompt(next_shard, prompt, request_id, inference_state)
else:
target_peer = next((p for p in self.peers if p.id() == target_id), None)
if not target_peer:
raise ValueError(f"Peer for {target_index} not found")
if DEBUG >= 1: print(f"Sending prompt to {target_peer.id()}: {prompt}")
await target_peer.send_prompt(next_shard, prompt, request_id=request_id)
await target_peer.send_prompt(next_shard, prompt, request_id=request_id, inference_state=inference_state)
async def forward_tensor(
self,
@@ -447,19 +435,20 @@ class Node:
tensor: np.ndarray,
request_id: str,
target_index: int,
inference_state: Optional[dict] = None,
) -> None:
if DEBUG >= 1: print(f"target partition index: {target_index}")
target_id = self.partitioning_strategy.partition(self.topology)[target_index].node_id
next_shard = self.get_current_shard(base_shard, target_index)
if DEBUG >= 2: print(f"Computed target from: {base_shard} {target_index}, {self.topology}. target shard: {next_shard}")
if target_id == self.id:
await self.process_tensor(next_shard, tensor, request_id)
await self.process_tensor(next_shard, tensor, request_id, inference_state)
else:
target_peer = next((p for p in self.peers if p.id() == target_id), None)
if not target_peer:
raise ValueError(f"Peer for {target_index} not found")
if DEBUG >= 1: print(f"Sending tensor to {target_peer.id()}: {tensor}")
await target_peer.send_tensor(next_shard, tensor, request_id=request_id)
await target_peer.send_tensor(next_shard, tensor, request_id=request_id, inference_state=inference_state)
def get_partition_index(self, offset: int = 0):
if not self.partitioning_strategy:
@@ -542,18 +531,13 @@ class Node:
try:
did_peers_change = await self.update_peers()
if DEBUG >= 2: print(f"{did_peers_change=}")
await self.collect_topology(set())
if did_peers_change:
await self.collect_topology(set())
await self.select_best_inference_engine()
except Exception as e:
print(f"Error collecting topology: {e}")
traceback.print_exc()
async def get_inference_result(self, request_id: str) -> Tuple[Optional[np.ndarray], bool]:
if request_id not in self.buffered_token_output:
return None, False
return np.array(self.buffered_token_output[request_id][0]), self.buffered_token_output[request_id][1]
async def collect_topology(self, visited: set[str], max_depth: int = 4) -> Topology:
next_topology = Topology()
next_topology.update_node(self.id, self.device_capabilities)
@@ -598,10 +582,11 @@ class Node:
return self._on_opaque_status
def trigger_on_token_callbacks(self, request_id: str, tokens: List[int], is_finished: bool) -> None:
if DEBUG >= 2: print(f"Triggering all on_token callbacks with {request_id=} num_tokens={len(tokens)} {is_finished=}")
if DEBUG >= 2: print(f"Triggering all on_token callbacks with {request_id=} {tokens=} {is_finished=}")
self.on_token.trigger_all(request_id, tokens, is_finished)
async def broadcast_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
if DEBUG >= 2: print(f"Broadcasting result: {request_id=} {result=} {is_finished=}")
async def send_result_to_peer(peer):
try:
await asyncio.wait_for(peer.send_result(request_id, result, is_finished), timeout=15.0)
@@ -632,3 +617,12 @@ class Node:
@property
def current_topology(self) -> Topology:
return self.topology
def handle_stable_diffusion(self, inference_state, result):
if inference_state['is_step_finished']:
inference_state['step']+=1
progress = [inference_state['step'],inference_state['total_steps']]
intermediate_result = result
if progress[0] == progress[1]:
intermediate_result = result
return intermediate_result, inference_state
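Editor's note: a minimal sketch (not part of the diff) of the inference-engine contract the call sites above now assume — infer_prompt/infer_tensor take an optional inference_state dict and return an (output, inference_state) tuple instead of a bare array. The EchoEngine class and its behaviour are illustrative assumptions, not exo's real engine.

# Sketch only: stub engine matching the new (output, inference_state) tuple contract.
import asyncio
import numpy as np

class EchoEngine:
  async def infer_prompt(self, request_id, shard, prompt, inference_state=None):
    state = dict(inference_state or {})
    state["step"] = state.get("step", 0) + 1  # carry mutable state forward between calls
    return np.array([len(prompt)]), state

  async def infer_tensor(self, request_id, shard, tensor, inference_state=None):
    state = dict(inference_state or {})
    state["step"] = state.get("step", 0) + 1
    return tensor, state

async def demo():
  engine = EchoEngine()
  out, state = await engine.infer_prompt("req-1", None, "hello", {"total_steps": 3})
  print(out, state)  # [5] {'total_steps': 3, 'step': 1}

if __name__ == "__main__":
  asyncio.run(demo())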

View File

@@ -1,10 +1,11 @@
import unittest
from unittest.mock import Mock, AsyncMock
import numpy as np
import pytest
from .node import Node
from exo.networking.peer_handle import PeerHandle
from exo.download.shard_download import NoopShardDownloader
class TestNode(unittest.IsolatedAsyncioTestCase):
def setUp(self):
@@ -21,7 +22,7 @@ class TestNode(unittest.IsolatedAsyncioTestCase):
mock_peer2.id.return_value = "peer2"
self.mock_discovery.discover_peers = AsyncMock(return_value=[mock_peer1, mock_peer2])
self.node = Node("test_node", self.mock_server, self.mock_inference_engine, "localhost", 50051, self.mock_discovery)
self.node = Node("test_node", self.mock_server, self.mock_inference_engine, "localhost", 50051, self.mock_discovery, NoopShardDownloader())
async def asyncSetUp(self):
await self.node.start()
@@ -55,3 +56,11 @@ class TestNode(unittest.IsolatedAsyncioTestCase):
await self.node.process_tensor(input_tensor, None)
self.node.inference_engine.process_shard.assert_called_once_with(input_tensor)
@pytest.mark.asyncio
async def test_node_capabilities():
node = Node()
await node.initialize()
caps = await node.get_device_capabilities()
assert caps is not None
assert caps.model != ""

View File

@@ -0,0 +1,166 @@
from dataclasses import dataclass
from typing import Dict, Optional, Any
from opentelemetry import trace, context
from opentelemetry.trace import Status, StatusCode, SpanContext
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from contextlib import contextmanager
import time
from threading import Lock
@dataclass
class TraceContext:
request_id: str
sequence_number: int
current_span: Optional[trace.Span] = None
trace_parent: Optional[str] = None
token_group_span: Optional[trace.Span] = None
token_count: int = 0
token_group_size: int = 10 # Default group size
request_span: Optional[trace.Span] = None # Track the main request span
class Tracer:
def __init__(self):
self.tracer = trace.get_tracer("exo")
self.contexts: Dict[str, TraceContext] = {}
self._lock = Lock()
self.propagator = TraceContextTextMapPropagator()
def get_context(self, request_id: str) -> Optional[TraceContext]:
with self._lock:
return self.contexts.get(request_id)
def set_context(self, request_id: str, context: TraceContext):
with self._lock:
self.contexts[request_id] = context
def inject_context(self, span: trace.Span) -> str:
"""Inject current span context into carrier for propagation"""
carrier = {}
ctx = trace.set_span_in_context(span)
self.propagator.inject(carrier, context=ctx)
return carrier.get("traceparent", "")
def extract_context(self, trace_parent: str) -> Optional[context.Context]:
"""Extract span context from carrier"""
if not trace_parent:
return None
carrier = {"traceparent": trace_parent}
return self.propagator.extract(carrier)
def create_context_from_parent(self, request_id: str, trace_parent: str, sequence_number: int = 0) -> TraceContext:
"""Create a new context with the given trace parent"""
parent_ctx = self.extract_context(trace_parent)
if parent_ctx:
# Create a new request span that links to the parent context
request_span = self.tracer.start_span(
"request",
context=parent_ctx,
attributes={
"request_id": request_id,
"sequence_number": sequence_number
}
)
return TraceContext(
request_id=request_id,
sequence_number=sequence_number,
request_span=request_span,
current_span=request_span,
trace_parent=trace_parent
)
return TraceContext(request_id=request_id, sequence_number=sequence_number)
def handle_token(self, context: TraceContext, token: int, is_finished: bool = False):
"""Handle token generation and manage token group spans"""
context.token_count += 1
# Start a new token group span if needed
if not context.token_group_span and context.request_span:
group_number = (context.token_count - 1) // context.token_group_size + 1
# Create token group span as child of request span
parent_ctx = trace.set_span_in_context(context.request_span)
context.token_group_span = self.tracer.start_span(
f"token_group_{group_number}",
context=parent_ctx,
attributes={
"request_id": context.request_id,
"group.number": group_number,
"group.start_token": context.token_count,
"group.max_tokens": context.token_group_size
}
)
# Add token to current group span
if context.token_group_span:
relative_pos = ((context.token_count - 1) % context.token_group_size) + 1
context.token_group_span.set_attribute(f"token.{relative_pos}", token)
context.token_group_span.set_attribute("token.count", relative_pos)
# End current group span if we've reached the group size or if generation is finished
if context.token_count % context.token_group_size == 0 or is_finished:
context.token_group_span.set_attribute("token.final_count", relative_pos)
context.token_group_span.end()
context.token_group_span = None
@contextmanager
def start_span(self, name: str, context: TraceContext, extra_attributes: Optional[Dict[str, Any]] = None):
"""Start a new span with proper parent context"""
attributes = {
"request_id": context.request_id,
"sequence_number": context.sequence_number
}
if extra_attributes:
attributes.update(extra_attributes)
# Use request span as parent if available
parent_ctx = None
if context.request_span:
parent_ctx = trace.set_span_in_context(context.request_span)
elif context.trace_parent:
parent_ctx = self.extract_context(context.trace_parent)
if parent_ctx and not context.request_span:
# Create a new request span that links to the parent context
context.request_span = self.tracer.start_span(
"request",
context=parent_ctx,
attributes={
"request_id": context.request_id,
"sequence_number": context.sequence_number
}
)
parent_ctx = trace.set_span_in_context(context.request_span)
elif context.current_span:
parent_ctx = trace.set_span_in_context(context.current_span)
# Create span with parent context if it exists
if parent_ctx:
span = self.tracer.start_span(
name,
context=parent_ctx,
attributes=attributes
)
else:
span = self.tracer.start_span(
name,
attributes=attributes
)
# Update context with current span
prev_span = context.current_span
context.current_span = span
try:
start_time = time.perf_counter()
yield span
duration = time.perf_counter() - start_time
span.set_attribute("duration_s", duration)
span.set_status(Status(StatusCode.OK))
except Exception as e:
span.set_status(Status(StatusCode.ERROR, str(e)))
raise
finally:
span.end()
context.current_span = prev_span
# Global tracer instance
tracer = Tracer()
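Editor's note: a minimal usage sketch (not part of the diff) for the Tracer above. It assumes the TraceContext class and the module-level tracer are in scope and that an OpenTelemetry TracerProvider has been configured elsewhere (without one, the spans are no-ops but the calls still work). The request id, attribute names and token values are illustrative, and parenting the token groups by assigning ctx.request_span is an assumption.

# Sketch only: create a context, record a span, and group token events under it.
ctx = TraceContext(request_id="req-123", sequence_number=0)
tracer.set_context("req-123", ctx)

with tracer.start_span("process_prompt", ctx, extra_attributes={"prompt_len": 42}) as span:
  ctx.request_span = span  # assumed: token-group spans are parented under this span
  for i, token in enumerate([101, 102, 103]):
    tracer.handle_token(ctx, token, is_finished=(i == 2))
  # traceparent string a peer could feed into create_context_from_parent()
  carrier = tracer.inject_context(span)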

View File

View File

@@ -1,27 +0,0 @@
version: '3.8'
services:
prometheus:
image: prom/prometheus:latest
container_name: prometheus
volumes:
- ./prometheus.yml:/etc/prometheus/prometheus.yml
command:
- '--config.file=/etc/prometheus/prometheus.yml'
ports:
- "9090:9090"
networks:
- monitoring
grafana:
image: grafana/grafana:latest
container_name: grafana
ports:
- "3000:3000"
networks:
- monitoring
depends_on:
- prometheus
networks:
monitoring:

View File

@@ -1,29 +0,0 @@
from exo.orchestration import Node
from prometheus_client import start_http_server, Counter, Histogram
import json
# Create metrics to track time spent and requests made.
PROCESS_PROMPT_COUNTER = Counter("process_prompt_total", "Total number of prompts processed", ["node_id"])
PROCESS_TENSOR_COUNTER = Counter("process_tensor_total", "Total number of tensors processed", ["node_id"])
PROCESS_TENSOR_TIME = Histogram("process_tensor_seconds", "Time spent processing tensor", ["node_id"])
def start_metrics_server(node: Node, port: int):
start_http_server(port)
def _on_opaque_status(request_id, opaque_status: str):
status_data = json.loads(opaque_status)
_type = status_data.get("type", "")
node_id = status_data.get("node_id", "")
if _type != "node_status":
return
status = status_data.get("status", "")
if status == "end_process_prompt":
PROCESS_PROMPT_COUNTER.labels(node_id=node_id).inc()
elif status == "end_process_tensor":
elapsed_time_ns = status_data.get("elapsed_time_ns", 0)
PROCESS_TENSOR_COUNTER.labels(node_id=node_id).inc()
PROCESS_TENSOR_TIME.labels(node_id=node_id).observe(elapsed_time_ns/1e9) # Convert ns to seconds
node.on_opaque_status.register("stats").on_next(_on_opaque_status)

View File

@@ -1,7 +0,0 @@
global:
scrape_interval: 15s
scrape_configs:
- job_name: 'exo-node'
static_configs:
- targets: ['host.docker.internal:8005']

View File

@@ -139,6 +139,14 @@ main {
padding: 0.5rem 1rem;
border-radius: 20px;
}
@media(max-width: 1482px) {
.messages {
padding-left: 16px;
padding-right: 16px;
}
}
.message-role-assistant {
background-color: var(--secondary-color);
margin-right: auto;
@@ -149,6 +157,12 @@ main {
background-color: var(--primary-color);
color: #000;
}
.message-role-user p {
white-space: pre-wrap;
word-wrap: break-word;
}
.download-progress {
position: fixed;
bottom: 11rem;
@@ -654,4 +668,179 @@ main {
.model-download-button i {
font-size: 0.9em;
}
}
.topology-section {
margin-bottom: 30px;
padding: 15px;
background: rgba(255, 255, 255, 0.05);
border-radius: 8px;
}
.topology-visualization {
min-height: 150px;
position: relative;
margin-top: 10px;
}
.topology-loading {
display: flex;
align-items: center;
gap: 10px;
color: #666;
font-size: 0.9em;
}
.topology-node {
padding: 8px;
background: rgba(255, 255, 255, 0.05);
border-radius: 4px;
margin: 4px 0;
display: flex;
flex-direction: column;
gap: 4px;
}
.node-info {
display: flex;
align-items: center;
gap: 6px;
font-size: 0.9em;
}
.topology-node .status {
width: 6px;
height: 6px;
border-radius: 50%;
flex-shrink: 0;
}
.topology-node .status.active {
background: #4CAF50;
}
.topology-node .status.inactive {
background: #666;
}
.node-details {
padding-left: 12px;
display: flex;
flex-direction: column;
gap: 2px;
font-size: 0.8em;
opacity: 0.6;
}
.node-details span {
display: flex;
align-items: center;
}
.peer-connections {
margin-top: 8px;
padding-left: 12px;
display: flex;
flex-direction: column;
gap: 4px;
}
.peer-connection {
display: flex;
align-items: center;
gap: 8px;
font-size: 0.85em;
color: #a0a0a0;
}
.peer-connection i {
font-size: 0.8em;
color: #666;
}
.thinking-block {
background-color: rgba(255, 255, 255, 0.05);
border-radius: 8px;
margin: 8px 0;
overflow: hidden;
}
.thinking-header {
background-color: rgba(255, 255, 255, 0.1);
padding: 8px 12px;
font-size: 0.9em;
color: #a0a0a0;
display: flex;
align-items: center;
gap: 8px;
}
.thinking-content {
padding: 12px;
white-space: pre-wrap;
}
@keyframes thinking-spin {
to { transform: rotate(360deg); }
}
.thinking-header.thinking::before {
content: '';
width: 12px;
height: 12px;
border: 2px solid #a0a0a0;
border-top-color: transparent;
border-radius: 50%;
animation: thinking-spin 1s linear infinite;
}
.model-group {
margin-bottom: 12px;
}
.model-group-header,
.model-subgroup-header {
display: flex;
justify-content: space-between;
align-items: center;
padding: 8px 12px;
background-color: var(--primary-bg-color);
border-radius: 6px;
cursor: pointer;
transition: all 0.2s ease;
margin-bottom: 8px;
}
.model-group-header:hover,
.model-subgroup-header:hover {
background-color: var(--secondary-color-transparent);
}
.model-group-content {
padding-left: 12px;
}
.model-subgroup {
margin-bottom: 8px;
}
.model-subgroup-header {
font-size: 0.9em;
background-color: rgba(255, 255, 255, 0.05);
}
.model-subgroup-content {
padding-left: 12px;
}
.group-header-content {
display: flex;
align-items: center;
gap: 8px;
}
.model-count {
font-size: 0.8em;
color: var(--secondary-color-transparent);
font-family: monospace;
}

View File

@@ -22,75 +22,125 @@
<link href="/static/unpkg.com/@highlightjs/cdn-assets@11.9.0/styles/vs2015.min.css" rel="stylesheet"/>
<link href="/index.css" rel="stylesheet"/>
<link href="/common.css" rel="stylesheet"/>
<script src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
</head>
<body>
<main x-data="state" x-init="console.log(endpoint)">
<div class="sidebar">
<!-- Add topology section -->
<div class="topology-section">
<h2 class="megrim-regular">Network Topology</h2>
<div class="topology-visualization"
x-init="initTopology()"
x-ref="topologyViz">
<!-- Loading indicator for topology -->
<div class="topology-loading" x-show="!topology">
<i class="fas fa-spinner fa-spin"></i>
<span>Loading topology...</span>
</div>
<!-- Topology visualization will be rendered here -->
</div>
</div>
<h2 class="megrim-regular" style="margin-bottom: 20px;">Models</h2>
<div style="display: flex; align-items: center; margin-bottom: 10px;">
<label style="margin-right: 5px;">
<input type="checkbox" x-model="showDownloadedOnly" style="margin-right: 5px;">
Downloaded only
</label>
</div>
<!-- Loading indicator -->
<div class="loading-container" x-show="Object.keys(models).length === 0">
<i class="fas fa-spinner fa-spin"></i>
<span>Loading models...</span>
</div>
<template x-for="(model, key) in models" :key="key">
<div class="model-option"
:class="{ 'selected': cstate.selectedModel === key }"
@click="cstate.selectedModel = key">
<div class="model-header">
<div class="model-name" x-text="model.name"></div>
<button
@click.stop="deleteModel(key, model)"
class="model-delete-button"
x-show="model.download_percentage > 0">
<i class="fas fa-trash"></i>
</button>
</div>
<div class="model-info">
<div class="model-progress">
<template x-if="model.loading">
<span><i class="fas fa-spinner fa-spin"></i> Checking download status...</span>
</template>
<div class="model-progress-info">
<template x-if="!model.loading && model.download_percentage != null">
<span>
<!-- Check if there's an active download for this model -->
<template x-if="downloadProgress?.some(p =>
p.repo_id && p.repo_id.toLowerCase().includes(key.toLowerCase()) && !p.isComplete
)">
<i class="fas fa-circle-notch fa-spin"></i>
</template>
<span x-text="model.downloaded ? 'Downloaded' : `${Math.round(model.download_percentage)}% downloaded`"></span>
</span>
</template>
<template x-if="!model.loading && (model.download_percentage === null || model.download_percentage < 100) && !downloadProgress?.some(p => !p.isComplete)">
<button
@click.stop="handleDownload(key)"
class="model-download-button">
<i class="fas fa-download"></i>
<span x-text="(model.download_percentage > 0 && model.download_percentage < 100) ? 'Continue Downloading' : 'Download'"></span>
</button>
</template>
</div>
<!-- Group models by prefix -->
<template x-for="[mainPrefix, subGroups] in Object.entries(groupModelsByPrefix(models))" :key="mainPrefix">
<div class="model-group">
<div class="model-group-header" @click="toggleGroup(mainPrefix)">
<div class="group-header-content">
<span x-text="mainPrefix"></span>
<span class="model-count" x-text="getGroupCounts(Object.values(subGroups).flatMap(group => Object.values(group)))"></span>
</div>
<template x-if="model.total_size">
<div class="model-size" x-text="model.total_downloaded ?
`${formatBytes(model.total_downloaded)} / ${formatBytes(model.total_size)}` :
formatBytes(model.total_size)">
<i class="fas" :class="isGroupExpanded(mainPrefix) ? 'fa-chevron-down' : 'fa-chevron-right'"></i>
</div>
<div class="model-group-content" x-show="isGroupExpanded(mainPrefix)" x-transition>
<template x-for="[subPrefix, groupModels] in Object.entries(subGroups)" :key="subPrefix">
<div class="model-subgroup">
<div class="model-subgroup-header" @click.stop="toggleGroup(mainPrefix, subPrefix)">
<div class="group-header-content">
<span x-text="subPrefix"></span>
<span class="model-count" x-text="getGroupCounts(groupModels)"></span>
</div>
<i class="fas" :class="isGroupExpanded(mainPrefix, subPrefix) ? 'fa-chevron-down' : 'fa-chevron-right'"></i>
</div>
<div class="model-subgroup-content" x-show="isGroupExpanded(mainPrefix, subPrefix)" x-transition>
<template x-for="(model, key) in groupModels" :key="key">
<div class="model-option"
:class="{ 'selected': cstate.selectedModel === key }"
@click="cstate.selectedModel = key">
<div class="model-header">
<div class="model-name" x-text="model.name"></div>
<button
@click.stop="deleteModel(key, model)"
class="model-delete-button"
x-show="model.download_percentage > 0">
<i class="fas fa-trash"></i>
</button>
</div>
<div class="model-info">
<div class="model-progress">
<template x-if="model.loading">
<span><i class="fas fa-spinner fa-spin"></i> Checking download status...</span>
</template>
<div class="model-progress-info">
<template x-if="!model.loading && model.download_percentage != null">
<span>
<template x-if="downloadProgress?.some(p =>
p.repo_id && p.repo_id.toLowerCase().includes(key.toLowerCase()) && !p.isComplete
)">
<i class="fas fa-circle-notch fa-spin"></i>
</template>
<span x-text="model.downloaded ? 'Downloaded' : `${Math.round(model.download_percentage)}% downloaded`"></span>
</span>
</template>
<template x-if="!model.loading && (model.download_percentage === null || model.download_percentage < 100) && !downloadProgress?.some(p => !p.isComplete)">
<button
@click.stop="handleDownload(key)"
class="model-download-button">
<i class="fas fa-download"></i>
<span x-text="(model.download_percentage > 0 && model.download_percentage < 100) ? 'Continue Downloading' : 'Download'"></span>
</button>
</template>
</div>
</div>
<template x-if="model.total_size">
<div class="model-size" x-text="model.total_downloaded ?
`${formatBytes(model.total_downloaded)} / ${formatBytes(model.total_size)}` :
formatBytes(model.total_size)">
</div>
</template>
</div>
</div>
</template>
</div>
</div>
</template>
</div>
</div>
</template>
</div>
</div>
<!-- Error Toast -->
<div x-show="errorMessage !== null" x-transition.opacity class="toast">
<div class="toast-header">
<span class="toast-error-message" x-text="errorMessage?.basic || ''"></span>
<div class="toast-header-buttons">
<button @click="errorExpanded = !errorExpanded; if (errorTimeout) { clearTimeout(errorTimeout); errorTimeout = null; }"
class="toast-expand-button"
<button @click="errorExpanded = !errorExpanded; if (errorTimeout) { clearTimeout(errorTimeout); errorTimeout = null; }"
class="toast-expand-button"
x-show="errorMessage?.stack">
<span x-text="errorExpanded ? 'Hide Details' : 'Show Details'"></span>
</button>
@@ -119,8 +169,8 @@
" x-show="home === 0" x-transition="">
<h1 class="title megrim-regular">tinychat</h1>
<template x-if="histories.length">
<button
@click="if(confirm('Are you sure you want to clear all history?')) clearAllHistory();"
<button
@click="if(confirm('Are you sure you want to clear all history?')) clearAllHistory();"
class="clear-history-button">
<i class="fas fa-trash"></i> Clear All History
</button>
@@ -149,7 +199,7 @@
otx = $event.changedTouches[0].clientX;
" class="history" x-data="{ otx: 0, trigger: 75 }">
<h3 x-text="new Date(_state.time).toLocaleString()"></h3>
<p x-text="$truncate(_state.messages[0].content, 80)"></p>
<p x-text="$truncate(_state.messages[0].content, 76)"></p>
<!-- delete button -->
<button @click.stop="removeHistory(_state);" class="history-delete-button">
<i class="fas fa-trash"></i>
@@ -162,62 +212,101 @@
</template>
</div>
</div>
<button
</div>
<button
@click="
home = 0;
cstate = { time: null, messages: [], selectedModel: cstate.selectedModel };
time_till_first = 0;
tokens_per_second = 0;
total_tokens = 0;
"
"
class="back-button"
x-show="home === 2">
<i class="fas fa-arrow-left"></i>
Back to Chats
</button>
<div class="messages" x-init="
$watch('cstate', value =&gt; {
$el.innerHTML = '';
value.messages.forEach(({ role, content }) =&gt; {
const div = document.createElement('div');
div.className = `message message-role-${role}`;
try {
div.innerHTML = DOMPurify.sanitize(marked.parse(content));
} catch (e) {
console.log(content);
console.error(e);
<div class="messages"
x-init="
$watch('cstate', (value) => {
$el.innerHTML = '';
value.messages.forEach((msg) => {
const div = document.createElement('div');
div.className = `message message-role-${msg.role}`;
try {
// If there's an embedded generated image
if (msg.content.includes('![Generated Image]')) {
const imageUrlMatch = msg.content.match(/\((.*?)\)/);
if (imageUrlMatch) {
const imageUrl = imageUrlMatch[1];
const img = document.createElement('img');
img.src = imageUrl;
img.alt = 'Generated Image';
img.onclick = async () => {
try {
const response = await fetch(img.src);
const blob = await response.blob();
const file = new File([blob], 'image.png', { type: 'image/png' });
handleImageUpload({ target: { files: [file] } });
} catch (error) {
console.error('Error fetching image:', error);
}
};
div.appendChild(img);
} else {
// fallback if markdown is malformed
div.textContent = msg.content;
}
} else {
// Otherwise, transform message text (including streamed think blocks).
div.innerHTML = transformMessageContent(msg);
// Render math after content is inserted
MathJax.typesetPromise([div]);
}
} catch (e) {
console.error('Error rendering message:', e);
div.textContent = msg.content; // fallback
}
// add a clipboard button to all code blocks
const codeBlocks = div.querySelectorAll('.hljs');
codeBlocks.forEach(codeBlock =&gt; {
const button = document.createElement('button');
button.className = 'clipboard-button';
button.innerHTML = '&lt;i class=\'fas fa-clipboard\'&gt;&lt;/i&gt;';
button.onclick = () =&gt; {
// navigator.clipboard.writeText(codeBlock.textContent);
const range = document.createRange();
range.setStartBefore(codeBlock);
range.setEndAfter(codeBlock);
window.getSelection()?.removeAllRanges();
window.getSelection()?.addRange(range);
document.execCommand('copy');
window.getSelection()?.removeAllRanges();
// Add a clipboard button to code blocks
const codeBlocks = div.querySelectorAll('.hljs');
codeBlocks.forEach((codeBlock) => {
const button = document.createElement('button');
button.className = 'clipboard-button';
button.innerHTML = '<i class=\'fas fa-clipboard\'></i>';
button.innerHTML = '&lt;i class=\'fas fa-check\'&gt;&lt;/i&gt;';
setTimeout(() =&gt; button.innerHTML = '&lt;i class=\'fas fa-clipboard\'&gt;&lt;/i&gt;', 1000);
};
codeBlock.appendChild(button);
});
button.onclick = () => {
const range = document.createRange();
range.setStartBefore(codeBlock);
range.setEndAfter(codeBlock);
window.getSelection()?.removeAllRanges();
window.getSelection()?.addRange(range);
document.execCommand('copy');
window.getSelection()?.removeAllRanges();
$el.appendChild(div);
button.innerHTML = '<i class=\'fas fa-check\'></i>';
setTimeout(() => {
button.innerHTML = '<i class=\'fas fa-clipboard\'></i>';
}, 1000);
};
codeBlock.appendChild(button);
});
$el.scrollTo({ top: $el.scrollHeight, behavior: 'smooth' });
$el.appendChild(div);
});
" x-intersect="
// Scroll to bottom after rendering
$el.scrollTo({ top: $el.scrollHeight, behavior: 'smooth' });
" x-ref="messages" x-show="home === 2" x-transition="">
});
"
x-ref="messages"
x-show="home === 2"
x-transition=""
>
</div>
<!-- Download Progress Section -->
@@ -232,7 +321,7 @@
<p><strong>Model:</strong> <span x-text="progress.repo_id + '@' + progress.repo_revision"></span></p>
<p><strong>Status:</strong> <span x-text="progress.status"></span></p>
<div class="progress-bar-container">
<div class="progress-bar"
<div class="progress-bar"
:class="progress.isComplete ? 'complete' : 'in-progress'"
:style="`width: ${progress.percentage}%;`">
</div>
@@ -253,6 +342,10 @@
<div class="input-container">
<div class="input-performance">
<span class="input-performance-point">
<p class="monospace" x-text="models[cstate.selectedModel]?.name || cstate.selectedModel"></p>
<p class="megrim-regular">-</p>
</span>
<span class="input-performance-point">
<p class="monospace" x-text="(time_till_first / 1000).toFixed(2)"></p>
<p class="megrim-regular">SEC TO FIRST TOKEN</p>
</span>
@@ -266,7 +359,7 @@
</span>
</div>
<div class="input">
<button @click="$refs.imageUpload.click()" class="image-input-button" x-show="cstate.selectedModel === 'llava-1.5-7b-hf'">
<button @click="$refs.imageUpload.click()" class="image-input-button" x-show="cstate.selectedModel === 'llava-1.5-7b-hf' || cstate.selectedModel === 'stable-diffusion-2-1-base'">
<i class="fas fa-image"></i>
</button>
<input @change="$data.handleImageUpload($event)" accept="image/*" id="image-upload" style="display: none;" type="file" x-ref="imageUpload"/>
@@ -276,10 +369,10 @@
<i class="fas fa-times"></i>
</button>
</div>
<textarea
:disabled="generating || (downloadProgress?.length > 0 && downloadProgress.some(p => !p.isComplete))"
<textarea
:disabled="generating || (downloadProgress?.length > 0 && downloadProgress.some(p => !p.isComplete))"
:placeholder="
generating ? 'Generating...' :
generating ? 'Generating...' :
(downloadProgress?.length > 0 && downloadProgress.some(p => !p.isComplete)) ? 'Download in progress...' :
'Say something'
"
@@ -311,13 +404,51 @@
});
"
x-ref="inputForm"></textarea>
<button
:disabled="generating || (downloadProgress?.length > 0 && downloadProgress.some(p => !p.isComplete))"
@click="await handleSend()"
<button
:disabled="generating || (downloadProgress?.length > 0 && downloadProgress.some(p => !p.isComplete))"
@click="await handleSend()"
class="input-button">
<i :class="generating ? 'fa-spinner fa-spin' : 'fa-paper-plane'" class="fas"></i>
</button>
</div>
</div>
</main>
<script>
/**
* Transform a single message's content into HTML, preserving <think> blocks.
* Ensure LaTeX expressions are properly delimited for MathJax.
*/
function transformMessageContent(message) {
let text = message.content;
console.log('Processing message content:', text);
// First replace think blocks
text = text.replace(
/<think>([\s\S]*?)(?:<\/think>|$)/g,
(match, body) => {
console.log('Found think block with content:', body);
const isComplete = match.includes('</think>');
const spinnerClass = isComplete ? '' : ' thinking';
const parsedBody = DOMPurify.sanitize(marked.parse(body));
return `
<div class='thinking-block'>
<div class='thinking-header${spinnerClass}'>Thinking...</div>
<div class='thinking-content'>${parsedBody}</div>
</div>`;
}
);
// Add backslashes to parentheses and brackets for LaTeX
text = text
.replace(/\((?=\s*[\d\\])/g, '\\(') // Add backslash before opening parentheses
.replace(/\)(?!\w)/g, '\\)') // Add backslash before closing parentheses
.replace(/\[(?=\s*[\d\\])/g, '\\[') // Add backslash before opening brackets
.replace(/\](?!\w)/g, '\\]') // Add backslash before closing brackets
.replace(/\[[\s\n]*\\boxed/g, '\\[\\boxed') // Ensure boxed expressions are properly delimited
.replace(/\\!/g, '\\\\!'); // Preserve LaTeX spacing commands
return DOMPurify.sanitize(marked.parse(text));
}
</script>
</body>

View File

@@ -5,7 +5,7 @@ document.addEventListener("alpine:init", () => {
time: null,
messages: [],
selectedModel: 'llama-3.2-1b',
},
},
// historical state
histories: JSON.parse(localStorage.getItem("histories")) || [],
@@ -13,7 +13,7 @@ document.addEventListener("alpine:init", () => {
home: 0,
generating: false,
endpoint: `${window.location.origin}/v1`,
// Initialize error message structure
errorMessage: null,
errorExpanded: false,
@@ -39,6 +39,15 @@ document.addEventListener("alpine:init", () => {
// Add models state alongside existing state
models: {},
// Show only models available locally
showDownloadedOnly: false,
topology: null,
topologyInterval: null,
// Add these new properties
expandedGroups: {},
init() {
// Clean up any pending messages
localStorage.removeItem("pendingMessage");
@@ -48,7 +57,7 @@ document.addEventListener("alpine:init", () => {
// Start polling for download progress
this.startDownloadProgressPolling();
// Start model polling with the new pattern
this.startModelPolling();
},
@@ -69,12 +78,12 @@ document.addEventListener("alpine:init", () => {
while (true) {
try {
await this.populateSelector();
// Wait 5 seconds before next poll
await new Promise(resolve => setTimeout(resolve, 5000));
// Wait 15 seconds before next poll
await new Promise(resolve => setTimeout(resolve, 15000));
} catch (error) {
console.error('Model polling error:', error);
// If there's an error, wait before retrying
await new Promise(resolve => setTimeout(resolve, 5000));
await new Promise(resolve => setTimeout(resolve, 15000));
}
}
},
@@ -82,14 +91,14 @@ document.addEventListener("alpine:init", () => {
async populateSelector() {
return new Promise((resolve, reject) => {
const evtSource = new EventSource(`${window.location.origin}/modelpool`);
evtSource.onmessage = (event) => {
if (event.data === "[DONE]") {
evtSource.close();
resolve();
return;
}
const modelData = JSON.parse(event.data);
// Update existing model data while preserving other properties
Object.entries(modelData).forEach(([modelName, data]) => {
@@ -102,7 +111,7 @@ document.addEventListener("alpine:init", () => {
}
});
};
evtSource.onerror = (error) => {
console.error('EventSource failed:', error);
evtSource.close();
@@ -228,53 +237,110 @@ document.addEventListener("alpine:init", () => {
};
}
});
const containsImage = apiMessages.some(msg => Array.isArray(msg.content) && msg.content.some(item => item.type === 'image_url'));
if (containsImage) {
// Map all messages with string content to object with type text
apiMessages = apiMessages.map(msg => {
if (typeof msg.content === 'string') {
return {
...msg,
content: [
{
type: "text",
text: msg.content
}
]
};
}
return msg;
if (this.cstate.selectedModel === "stable-diffusion-2-1-base") {
// Send a request to the image generation endpoint
console.log(apiMessages[apiMessages.length - 1].content)
console.log(this.cstate.selectedModel)
console.log(this.endpoint)
const response = await fetch(`${this.endpoint}/image/generations`, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
"model": 'stable-diffusion-2-1-base',
"prompt": apiMessages[apiMessages.length - 1].content,
"image_url": this.imageUrl
}),
});
}
// start receiving server sent events
let gottenFirstChunk = false;
for await (
const chunk of this.openaiChatCompletion(this.cstate.selectedModel, apiMessages)
) {
if (!gottenFirstChunk) {
this.cstate.messages.push({ role: "assistant", content: "" });
gottenFirstChunk = true;
if (!response.ok) {
throw new Error("Failed to fetch");
}
// add chunk to the last message
this.cstate.messages[this.cstate.messages.length - 1].content += chunk;
// calculate performance tracking
tokens += 1;
this.total_tokens += 1;
if (start_time === 0) {
start_time = Date.now();
this.time_till_first = start_time - prefill_start;
} else {
const diff = Date.now() - start_time;
if (diff > 0) {
this.tokens_per_second = tokens / (diff / 1000);
const reader = response.body.getReader();
let done = false;
let gottenFirstChunk = false;
while (!done) {
const { value, done: readerDone } = await reader.read();
done = readerDone;
const decoder = new TextDecoder();
if (value) {
// Assume non-binary data (text) comes first
const chunk = decoder.decode(value, { stream: true });
const parsed = JSON.parse(chunk);
console.log(parsed)
if (parsed.progress) {
if (!gottenFirstChunk) {
this.cstate.messages.push({ role: "assistant", content: "" });
gottenFirstChunk = true;
}
this.cstate.messages[this.cstate.messages.length - 1].content = parsed.progress;
}
else if (parsed.images) {
if (!gottenFirstChunk) {
this.cstate.messages.push({ role: "assistant", content: "" });
gottenFirstChunk = true;
}
const imageUrl = parsed.images[0].url;
console.log(imageUrl)
this.cstate.messages[this.cstate.messages.length - 1].content = `![Generated Image](${imageUrl}?t=${Date.now()})`;
}
}
}
}
else{
const containsImage = apiMessages.some(msg => Array.isArray(msg.content) && msg.content.some(item => item.type === 'image_url'));
if (containsImage) {
// Map all messages with string content to object with type text
apiMessages = apiMessages.map(msg => {
if (typeof msg.content === 'string') {
return {
...msg,
content: [
{
type: "text",
text: msg.content
}
]
};
}
return msg;
});
}
console.log(apiMessages)
// start receiving server-sent events
let gottenFirstChunk = false;
for await (
const chunk of this.openaiChatCompletion(this.cstate.selectedModel, apiMessages)
) {
if (!gottenFirstChunk) {
this.cstate.messages.push({ role: "assistant", content: "" });
gottenFirstChunk = true;
}
// add chunk to the last message
this.cstate.messages[this.cstate.messages.length - 1].content += chunk;
// calculate performance tracking
tokens += 1;
this.total_tokens += 1;
if (start_time === 0) {
start_time = Date.now();
this.time_till_first = start_time - prefill_start;
} else {
const diff = Date.now() - start_time;
if (diff > 0) {
this.tokens_per_second = tokens / (diff / 1000);
}
}
}
}
// Clean the cstate before adding it to histories
const cleanedCstate = JSON.parse(JSON.stringify(this.cstate));
cleanedCstate.messages = cleanedCstate.messages.map(msg => {
@@ -333,8 +399,6 @@ document.addEventListener("alpine:init", () => {
},
async *openaiChatCompletion(model, messages) {
// stream response
console.log("model", model)
const response = await fetch(`${this.endpoint}/chat/completions`, {
method: "POST",
headers: {
@@ -357,19 +421,17 @@ document.addEventListener("alpine:init", () => {
const reader = response.body.pipeThrough(new TextDecoderStream())
.pipeThrough(new EventSourceParserStream()).getReader();
while (true) {
const { done, value } = await reader.read();
if (done) {
break;
}
if (done) break;
if (value.type === "event") {
const json = JSON.parse(value.data);
if (json.choices) {
const choice = json.choices[0];
if (choice.finish_reason === "stop") {
break;
}
yield choice.delta.content;
if (choice.finish_reason === "stop") break;
if (choice.delta.content) yield choice.delta.content;
}
}
}
@@ -452,7 +514,7 @@ document.addEventListener("alpine:init", () => {
stack: error.stack || ""
};
this.errorExpanded = false;
if (this.errorTimeout) {
clearTimeout(this.errorTimeout);
}
@@ -467,10 +529,10 @@ document.addEventListener("alpine:init", () => {
async deleteModel(modelName, model) {
const downloadedSize = model.total_downloaded || 0;
const sizeMessage = downloadedSize > 0 ?
const sizeMessage = downloadedSize > 0 ?
`This will free up ${this.formatBytes(downloadedSize)} of space.` :
'This will remove any partially downloaded files.';
if (!confirm(`Are you sure you want to delete ${model.name}? ${sizeMessage}`)) {
return;
}
@@ -484,7 +546,7 @@ document.addEventListener("alpine:init", () => {
});
const data = await response.json();
if (!response.ok) {
throw new Error(data.detail || 'Failed to delete model');
}
@@ -543,7 +605,127 @@ document.addEventListener("alpine:init", () => {
console.error('Error starting download:', error);
this.setError(error);
}
}
},
async fetchTopology() {
try {
const response = await fetch(`${this.endpoint}/topology`);
if (!response.ok) throw new Error('Failed to fetch topology');
return await response.json();
} catch (error) {
console.error('Topology fetch error:', error);
return null;
}
},
initTopology() {
// Initial fetch
this.updateTopology();
// Set up periodic updates
this.topologyInterval = setInterval(() => this.updateTopology(), 5000);
// Cleanup on page unload
window.addEventListener('beforeunload', () => {
if (this.topologyInterval) {
clearInterval(this.topologyInterval);
}
});
},
async updateTopology() {
const topologyData = await this.fetchTopology();
if (!topologyData) return;
const vizElement = this.$refs.topologyViz;
vizElement.innerHTML = ''; // Clear existing visualization
// Helper function to truncate node ID
const truncateNodeId = (id) => id.substring(0, 8);
// Create nodes from object
Object.entries(topologyData.nodes).forEach(([nodeId, node]) => {
const nodeElement = document.createElement('div');
nodeElement.className = 'topology-node';
// Get peer connections for this node
const peerConnections = topologyData.peer_graph[nodeId] || [];
const peerConnectionsHtml = peerConnections.map(peer => `
<div class="peer-connection">
<i class="fas fa-arrow-right"></i>
<span>To ${truncateNodeId(peer.to_id)}: ${peer.description}</span>
</div>
`).join('');
nodeElement.innerHTML = `
<div class="node-info">
<span class="status ${nodeId === topologyData.active_node_id ? 'active' : 'inactive'}"></span>
<span>${node.model} [${truncateNodeId(nodeId)}]</span>
</div>
<div class="node-details">
<span>${node.chip}</span>
<span>${(node.memory / 1024).toFixed(1)}GB RAM</span>
<span>${node.flops.fp32.toFixed(1)} TF</span>
</div>
<div class="peer-connections">
${peerConnectionsHtml}
</div>
`;
vizElement.appendChild(nodeElement);
});
},
// Add these helper methods
countDownloadedModels(models) {
return Object.values(models).filter(model => model.downloaded).length;
},
getGroupCounts(groupModels) {
const total = Object.keys(groupModels).length;
const downloaded = this.countDownloadedModels(groupModels);
return `[${downloaded}/${total}]`;
},
// Update the existing groupModelsByPrefix method to include counts
groupModelsByPrefix(models) {
const groups = {};
const filteredModels = this.showDownloadedOnly ?
Object.fromEntries(Object.entries(models).filter(([, model]) => model.downloaded)) :
models;
Object.entries(filteredModels).forEach(([key, model]) => {
const parts = key.split('-');
const mainPrefix = parts[0].toUpperCase();
let subPrefix;
if (parts.length === 2) {
subPrefix = parts[1].toUpperCase();
} else if (parts.length > 2) {
subPrefix = parts[1].toUpperCase();
} else {
subPrefix = 'OTHER';
}
if (!groups[mainPrefix]) {
groups[mainPrefix] = {};
}
if (!groups[mainPrefix][subPrefix]) {
groups[mainPrefix][subPrefix] = {};
}
groups[mainPrefix][subPrefix][key] = model;
});
return groups;
},
toggleGroup(prefix, subPrefix = null) {
const key = subPrefix ? `${prefix}-${subPrefix}` : prefix;
this.expandedGroups[key] = !this.expandedGroups[key];
},
isGroupExpanded(prefix, subPrefix = null) {
const key = subPrefix ? `${prefix}-${subPrefix}` : prefix;
return this.expandedGroups[key] || false;
},
}));
});

View File

@@ -3,6 +3,8 @@ from pydantic import BaseModel
from exo import DEBUG
import subprocess
import psutil
import asyncio
from exo.helpers import get_mac_system_info, subprocess_pool
TFLOPS = 1.00
@@ -144,11 +146,13 @@ CHIP_FLOPS.update({f"{key} LAPTOP GPU": value for key, value in CHIP_FLOPS.items
CHIP_FLOPS.update({f"{key} Laptop GPU": value for key, value in CHIP_FLOPS.items()})
def device_capabilities() -> DeviceCapabilities:
async def device_capabilities() -> DeviceCapabilities:
if psutil.MACOS:
return mac_device_capabilities()
return await mac_device_capabilities()
elif psutil.LINUX:
return linux_device_capabilities()
return await linux_device_capabilities()
elif psutil.WINDOWS:
return await windows_device_capabilities()
else:
return DeviceCapabilities(
model="Unknown Device",
@@ -158,27 +162,18 @@ def device_capabilities() -> DeviceCapabilities:
)
def mac_device_capabilities() -> DeviceCapabilities:
# Fetch the model of the Mac using system_profiler
model = subprocess.check_output(["system_profiler", "SPHardwareDataType"]).decode("utf-8")
model_line = next((line for line in model.split("\n") if "Model Name" in line), None)
model_id = model_line.split(": ")[1] if model_line else "Unknown Model"
chip_line = next((line for line in model.split("\n") if "Chip" in line), None)
chip_id = chip_line.split(": ")[1] if chip_line else "Unknown Chip"
memory_line = next((line for line in model.split("\n") if "Memory" in line), None)
memory_str = memory_line.split(": ")[1] if memory_line else "Unknown Memory"
memory_units = memory_str.split()
memory_value = int(memory_units[0])
if memory_units[1] == "GB":
memory = memory_value*1024
else:
memory = memory_value
# Assuming static values for other attributes for demonstration
return DeviceCapabilities(model=model_id, chip=chip_id, memory=memory, flops=CHIP_FLOPS.get(chip_id, DeviceFlops(fp32=0, fp16=0, int8=0)))
async def mac_device_capabilities() -> DeviceCapabilities:
model_id, chip_id, memory = await get_mac_system_info()
return DeviceCapabilities(
model=model_id,
chip=chip_id,
memory=memory,
flops=CHIP_FLOPS.get(chip_id, DeviceFlops(fp32=0, fp16=0, int8=0))
)
def linux_device_capabilities() -> DeviceCapabilities:
async def linux_device_capabilities() -> DeviceCapabilities:
import psutil
from tinygrad import Device
@@ -194,6 +189,8 @@ def linux_device_capabilities() -> DeviceCapabilities:
if DEBUG >= 2: print(f"NVIDIA device {gpu_name=} {gpu_memory_info=}")
pynvml.nvmlShutdown()
return DeviceCapabilities(
model=f"Linux Box ({gpu_name})",
chip=gpu_name,
@@ -201,13 +198,24 @@ def linux_device_capabilities() -> DeviceCapabilities:
flops=CHIP_FLOPS.get(gpu_name, DeviceFlops(fp32=0, fp16=0, int8=0)),
)
elif Device.DEFAULT == "AMD":
# TODO AMD support
# For AMD GPUs, use pyrsmi (the official Python package for rocm-smi)
from pyrsmi import rocml
rocml.smi_initialize()
gpu_name = rocml.smi_get_device_name(0).upper()
gpu_memory_info = rocml.smi_get_device_memory_total(0)
if DEBUG >= 2: print(f"AMD device {gpu_name=} {gpu_memory_info=}")
rocml.smi_shutdown()
return DeviceCapabilities(
model="Linux Box (AMD)",
chip="Unknown AMD",
memory=psutil.virtual_memory().total // 2**20,
model="Linux Box ({gpu_name})",
chip=gpu_name,
memory=gpu_memory_info // 2**20,
flops=DeviceFlops(fp32=0, fp16=0, int8=0),
)
else:
return DeviceCapabilities(
model=f"Linux Box (Device: {Device.DEFAULT})",
@@ -215,3 +223,74 @@ def linux_device_capabilities() -> DeviceCapabilities:
memory=psutil.virtual_memory().total // 2**20,
flops=DeviceFlops(fp32=0, fp16=0, int8=0),
)
async def windows_device_capabilities() -> DeviceCapabilities:
import psutil
def get_gpu_info():
import win32com.client # install pywin32
wmiObj = win32com.client.GetObject("winmgmts:\\\\.\\root\\cimv2")
gpus = wmiObj.ExecQuery("SELECT * FROM Win32_VideoController")
gpu_info = []
for gpu in gpus:
info = {
"Name": gpu.Name,
"AdapterRAM": gpu.AdapterRAM, # Bug in this property, returns -ve for VRAM > 4GB (uint32 overflow)
"DriverVersion": gpu.DriverVersion,
"VideoProcessor": gpu.VideoProcessor
}
gpu_info.append(info)
return gpu_info
gpus_info = get_gpu_info()
gpu_names = [gpu['Name'] for gpu in gpus_info]
contains_nvidia = any('nvidia' in gpu_name.lower() for gpu_name in gpu_names)
contains_amd = any('amd' in gpu_name.lower() for gpu_name in gpu_names)
if contains_nvidia:
import pynvml
pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)
gpu_raw_name = pynvml.nvmlDeviceGetName(handle).upper()
gpu_name = gpu_raw_name.rsplit(" ", 1)[0] if gpu_raw_name.endswith("GB") else gpu_raw_name
gpu_memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
if DEBUG >= 2: print(f"NVIDIA device {gpu_name=} {gpu_memory_info=}")
return DeviceCapabilities(
model=f"Windows Box ({gpu_name})",
chip=gpu_name,
memory=gpu_memory_info.total // 2**20,
flops=CHIP_FLOPS.get(gpu_name, DeviceFlops(fp32=0, fp16=0, int8=0)),
)
elif contains_amd:
# For AMD GPUs, use pyrsmi (the official Python package for rocm-smi)
from pyrsmi import rocml
rocml.smi_initialize()
gpu_name = rocml.smi_get_device_name(0).upper()
gpu_memory_info = rocml.smi_get_device_memory_total(0)
if DEBUG >= 2: print(f"AMD device {gpu_name=} {gpu_memory_info=}")
rocml.smi_shutdown()
return DeviceCapabilities(
model="Windows Box ({gpu_name})",
chip={gpu_name},
memory=gpu_memory_info.total // 2**20,
flops=DeviceFlops(fp32=0, fp16=0, int8=0),
)
else:
return DeviceCapabilities(
model=f"Windows Box (Device: Unknown)",
chip=f"Unknown Chip (Device(s): {gpu_names})",
memory=psutil.virtual_memory().total // 2**20,
flops=DeviceFlops(fp32=0, fp16=0, int8=0),
)
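Editor's note: a quick sketch (not part of the diff) showing that device_capabilities() is now a coroutine and must be awaited from an asyncio entrypoint. The import path matches the test file below; the printed fields are illustrative.

# Sketch only: await the async capability probe and print a few fields.
import asyncio
from exo.topology.device_capabilities import device_capabilities

async def main():
  caps = await device_capabilities()  # dispatches to the macOS/Linux/Windows branch above
  # memory is reported in MB; flops values are already expressed in TFLOPS (TFLOPS = 1.00)
  print(caps.model, caps.chip, f"{caps.memory} MB", f"{caps.flops.fp32:.2f} TFLOPS")

if __name__ == "__main__":
  asyncio.run(main())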

View File

@@ -1,8 +1,10 @@
from abc import ABC, abstractmethod
from typing import List
from typing import List, Dict
from dataclasses import dataclass
from .topology import Topology
from exo.inference.shard import Shard
from exo.topology.device_capabilities import device_capabilities
import asyncio
# Partitions shard-space into pieces of contiguous shards, represented by floating point range [start, end) between 0 and 1

View File

@@ -1,11 +1,11 @@
import unittest
import pytest
from unittest.mock import patch
from exo.topology.device_capabilities import mac_device_capabilities, DeviceCapabilities, DeviceFlops, TFLOPS
from exo.topology.device_capabilities import mac_device_capabilities, DeviceCapabilities, DeviceFlops, TFLOPS, device_capabilities
class TestMacDeviceCapabilities(unittest.TestCase):
@patch("subprocess.check_output")
def test_mac_device_capabilities_pro(self, mock_check_output):
@pytest.mark.asyncio
@patch("subprocess.check_output")
async def test_mac_device_capabilities_pro(mock_check_output):
# Mock the subprocess output
mock_check_output.return_value = b"""
Hardware:
@@ -27,20 +27,19 @@ Activation Lock Status: Enabled
"""
# Call the function
result = mac_device_capabilities()
result = await mac_device_capabilities()
# Check the results
self.assertIsInstance(result, DeviceCapabilities)
self.assertEqual(result.model, "MacBook Pro")
self.assertEqual(result.chip, "Apple M3 Max")
self.assertEqual(result.memory, 131072) # 128 GB in MB
self.assertEqual(
str(result),
"Model: MacBook Pro. Chip: Apple M3 Max. Memory: 131072MB. Flops: 14.20 TFLOPS, fp16: 28.40 TFLOPS, int8: 56.80 TFLOPS",
)
assert isinstance(result, DeviceCapabilities)
assert result.model == "MacBook Pro"
assert result.chip == "Apple M3 Max"
assert result.memory == 131072 # 128 GB in MB
assert str(result) == "Model: MacBook Pro. Chip: Apple M3 Max. Memory: 131072MB. Flops: 14.20 TFLOPS, fp16: 28.40 TFLOPS, int8: 56.80 TFLOPS"
@patch("subprocess.check_output")
def test_mac_device_capabilities_air(self, mock_check_output):
@pytest.mark.asyncio
@patch("subprocess.check_output")
async def test_mac_device_capabilities_air(mock_check_output):
# Mock the subprocess output
mock_check_output.return_value = b"""
Hardware:
@@ -62,30 +61,34 @@ Activation Lock Status: Disabled
"""
# Call the function
result = mac_device_capabilities()
result = await mac_device_capabilities()
# Check the results
self.assertIsInstance(result, DeviceCapabilities)
self.assertEqual(result.model, "MacBook Air")
self.assertEqual(result.chip, "Apple M2")
self.assertEqual(result.memory, 8192) # 8 GB in MB
assert isinstance(result, DeviceCapabilities)
assert result.model == "MacBook Air"
assert result.chip == "Apple M2"
assert result.memory == 8192 # 8 GB in MB
@unittest.skip("Unskip this test when running on a MacBook Pro, Apple M3 Max, 128GB")
def test_mac_device_capabilities_real(self):
@pytest.mark.skip(reason="Unskip this test when running on a MacBook Pro, Apple M3 Max, 128GB")
@pytest.mark.asyncio
async def test_mac_device_capabilities_real():
# Call the function without mocking
result = mac_device_capabilities()
result = await mac_device_capabilities()
# Check the results
self.assertIsInstance(result, DeviceCapabilities)
self.assertEqual(result.model, "MacBook Pro")
self.assertEqual(result.chip, "Apple M3 Max")
self.assertEqual(result.memory, 131072) # 128 GB in MB
self.assertEqual(result.flops, DeviceFlops(fp32=14.20*TFLOPS, fp16=28.40*TFLOPS, int8=56.80*TFLOPS))
self.assertEqual(
str(result),
"Model: MacBook Pro. Chip: Apple M3 Max. Memory: 131072MB. Flops: 14.20 TFLOPS, fp16: 28.40 TFLOPS, int8: 56.80 TFLOPS",
)
assert isinstance(result, DeviceCapabilities)
assert result.model == "MacBook Pro"
assert result.chip == "Apple M3 Max"
assert result.memory == 131072 # 128 GB in MB
assert result.flops == DeviceFlops(fp32=14.20*TFLOPS, fp16=28.40*TFLOPS, int8=56.80*TFLOPS)
assert str(result) == "Model: MacBook Pro. Chip: Apple M3 Max. Memory: 131072MB. Flops: 14.20 TFLOPS, fp16: 28.40 TFLOPS, int8: 56.80 TFLOPS"
if __name__ == "__main__":
unittest.main()
@pytest.mark.asyncio
async def test_device_capabilities():
caps = await device_capabilities()
assert caps.model != ""
assert caps.chip != ""
assert caps.memory > 0
assert caps.flops is not None

View File

@@ -5,7 +5,7 @@ from exo.viz.topology_viz import TopologyViz
from exo.topology.topology import Topology
from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
from exo.topology.partitioning_strategy import Partition
from exo.download.hf.hf_helpers import RepoProgressEvent, RepoFileProgressEvent
from exo.download.download_progress import RepoProgressEvent
def create_hf_repo_progress_event(

View File

@@ -4,7 +4,7 @@ from typing import List, Optional, Tuple, Dict
from exo.helpers import exo_text, pretty_print_bytes, pretty_print_bytes_per_second
from exo.topology.topology import Topology
from exo.topology.partitioning_strategy import Partition
from exo.download.hf.hf_helpers import RepoProgressEvent
from exo.download.download_progress import RepoProgressEvent
from exo.topology.device_capabilities import UNKNOWN_DEVICE_CAPABILITIES
from rich.console import Console, Group
from rich.text import Text
@@ -51,17 +51,11 @@ class TopologyViz:
self.refresh()
def update_prompt(self, request_id: str, prompt: Optional[str] = None):
if request_id in self.requests:
self.requests[request_id] = [prompt, self.requests[request_id][1]]
else:
self.requests[request_id] = [prompt, ""]
self.requests[request_id] = [prompt, self.requests.get(request_id, ["", ""])[1]]
self.refresh()
def update_prompt_output(self, request_id: str, output: Optional[str] = None):
if request_id in self.requests:
self.requests[request_id] = [self.requests[request_id][0], output]
else:
self.requests[request_id] = ["", output]
self.requests[request_id] = [self.requests.get(request_id, ["", ""])[0], output]
self.refresh()
def refresh(self):
@@ -91,36 +85,96 @@ class TopologyViz:
    content = []
    requests = list(self.requests.values())[-3:]  # Get the 3 most recent requests
    max_width = self.console.width - 6  # Full width minus padding and icon
    max_lines = 13  # Maximum number of lines for the entire panel content
    # Calculate available height for content
    panel_height = 15  # Fixed panel height
    available_lines = panel_height - 2  # Subtract 2 for panel borders
    lines_per_request = available_lines // len(requests) if requests else 0

    for (prompt, output) in reversed(requests):
      prompt_icon, output_icon = "💬️", "🤖"

      # Equal space allocation for prompt and output
      max_prompt_lines = lines_per_request // 2
      max_output_lines = lines_per_request - max_prompt_lines - 1  # -1 for spacing

      # Process prompt
      prompt_lines = prompt.split('\n')
      if len(prompt_lines) > max_lines // 2:
        prompt_lines = prompt_lines[:max_lines//2 - 1] + ['...']
      prompt_lines = []
      for line in prompt.split('\n'):
        words = line.split()
        current_line = []
        current_length = 0
        for word in words:
          if current_length + len(word) + 1 <= max_width:
            current_line.append(word)
            current_length += len(word) + 1
          else:
            if current_line:
              prompt_lines.append(' '.join(current_line))
            current_line = [word]
            current_length = len(word)
        if current_line:
          prompt_lines.append(' '.join(current_line))

      # Truncate prompt if needed
      if len(prompt_lines) > max_prompt_lines:
        prompt_lines = prompt_lines[:max_prompt_lines]
        if prompt_lines:
          last_line = prompt_lines[-1]
          if len(last_line) + 4 <= max_width:
            prompt_lines[-1] = last_line + " ..."
          else:
            prompt_lines[-1] = last_line[:max_width-4] + " ..."

      prompt_text = Text(f"{prompt_icon} ", style="bold bright_blue")
      prompt_text.append('\n'.join(line[:max_width] for line in prompt_lines), style="white")
      # Process output
      output_lines = output.split('\n')
      remaining_lines = max_lines - len(prompt_lines) - 2  # -2 for spacing
      if len(output_lines) > remaining_lines:
        output_lines = output_lines[:remaining_lines - 1] + ['...']
      output_text = Text(f"\n{output_icon} ", style="bold bright_magenta")
      output_text.append('\n'.join(line[:max_width] for line in output_lines), style="white")
      prompt_text.append('\n'.join(prompt_lines), style="white")

      content.append(prompt_text)
      content.append(output_text)

      # Process output with similar word wrapping
      if output:  # Only process output if it exists
        output_lines = []
        for line in output.split('\n'):
          words = line.split()
          current_line = []
          current_length = 0
          for word in words:
            if current_length + len(word) + 1 <= max_width:
              current_line.append(word)
              current_length += len(word) + 1
            else:
              if current_line:
                output_lines.append(' '.join(current_line))
              current_line = [word]
              current_length = len(word)
          if current_line:
            output_lines.append(' '.join(current_line))

        # Truncate output if needed
        if len(output_lines) > max_output_lines:
          output_lines = output_lines[:max_output_lines]
          if output_lines:
            last_line = output_lines[-1]
            if len(last_line) + 4 <= max_width:
              output_lines[-1] = last_line + " ..."
            else:
              output_lines[-1] = last_line[:max_width-4] + " ..."

        output_text = Text(f"{output_icon} ", style="bold bright_magenta")
        output_text.append('\n'.join(output_lines), style="white")
        content.append(output_text)

      content.append(Text())  # Empty line between entries

    return Panel(
      Group(*content),
      title="",
      border_style="cyan",
      height=15,  # Increased height to accommodate multiple lines
      expand=True  # Allow the panel to expand to full width
      height=panel_height,
      expand=True
    )

  def _generate_main_layout(self) -> str:

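The prompt and output branches above repeat the same greedy word-wrap-and-truncate steps. For clarity, here is that algorithm pulled out into a standalone sketch; wrap_and_truncate is a hypothetical helper, not a function in topology_viz.py, which inlines these steps twice.

# Greedy word wrap plus ellipsis truncation, mirroring the inlined logic above.
def wrap_and_truncate(text: str, max_width: int, max_lines: int) -> list[str]:
  wrapped = []
  for line in text.split('\n'):
    current, length = [], 0
    for word in line.split():
      if length + len(word) + 1 <= max_width:
        current.append(word)
        length += len(word) + 1
      else:
        if current:
          wrapped.append(' '.join(current))
        current, length = [word], len(word)
    if current:
      wrapped.append(' '.join(current))
  if len(wrapped) > max_lines:
    wrapped = wrapped[:max_lines]
    last = wrapped[-1]
    wrapped[-1] = last + " ..." if len(last) + 4 <= max_width else last[:max_width - 4] + " ..."
  return wrapped

print(wrap_and_truncate("one two three four five six", max_width=10, max_lines=2))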
View File

@@ -1,50 +0,0 @@
import argparse
import asyncio
from exo.download.hf.hf_helpers import download_all_files, RepoProgressEvent

DEFAULT_ALLOW_PATTERNS = [
  "*.json",
  "*.py",
  "tokenizer.model",
  "*.tiktoken",
  "*.txt",
  "*.safetensors",
]
# Always ignore `.git` and `.cache/huggingface` folders in commits
DEFAULT_IGNORE_PATTERNS = [
  ".git",
  ".git/*",
  "*/.git",
  "**/.git/**",
  ".cache/huggingface",
  ".cache/huggingface/*",
  "*/.cache/huggingface",
  "**/.cache/huggingface/**",
]


async def main(repo_id, revision="main", allow_patterns=None, ignore_patterns=None):
  async def progress_callback(event: RepoProgressEvent):
    print(f"Overall Progress: {event.completed_files}/{event.total_files} files, {event.downloaded_bytes}/{event.total_bytes} bytes")
    print(f"Estimated time remaining: {event.overall_eta}")
    print("File Progress:")
    for file_path, progress in event.file_progress.items():
      status_icon = {'not_started': '', 'in_progress': '🔵', 'complete': ''}[progress.status]
      eta_str = str(progress.eta)
      print(f"{status_icon} {file_path}: {progress.downloaded}/{progress.total} bytes, "
            f"Speed: {progress.speed:.2f} B/s, ETA: {eta_str}")
    print("\n")

  await download_all_files(repo_id, revision, progress_callback, allow_patterns, ignore_patterns)


if __name__ == "__main__":
  parser = argparse.ArgumentParser(description="Download files from a Hugging Face model repository.")
  parser.add_argument("--repo-id", required=True, help="The repository ID (e.g., 'meta-llama/Meta-Llama-3.1-8B-Instruct')")
  parser.add_argument("--revision", default="main", help="The revision to download (branch, tag, or commit hash)")
  parser.add_argument("--allow-patterns", nargs="*", default=None, help="Patterns of files to allow (e.g., '*.json' '*.safetensors')")
  parser.add_argument("--ignore-patterns", nargs="*", default=None, help="Patterns of files to ignore (e.g., '.*')")
  args = parser.parse_args()
  asyncio.run(main(args.repo_id, args.revision, args.allow_patterns, args.ignore_patterns))

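The allow/ignore lists in the removed script are shell-style globs. A quick illustration of how such patterns match, using fnmatch semantics (the exact matcher inside the Hugging Face download helpers is not shown in this diff):

# Glob-style matching as used for allow/ignore patterns; fnmatch here is only an illustration.
from fnmatch import fnmatch

assert fnmatch("model-00001-of-00002.safetensors", "*.safetensors")
assert fnmatch("config.json", "*.json")
assert not fnmatch(".git/config", "*.safetensors")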
View File

@@ -74,9 +74,9 @@ def gen_diff(table_old, table_new):
def create_json_report(table, is_diff=False):
  timestamp = datetime.now(timezone.utc).isoformat()
  commit_sha = os.environ.get('CIRCLE_SHA1', 'unknown')
  branch = os.environ.get('CIRCLE_BRANCH', 'unknown')
  pr_number = os.environ.get('CIRCLE_PR_NUMBER', '')
  commit_sha = os.environ.get('GITHUB_SHA', 'unknown')
  branch = os.environ.get('GITHUB_REF_NAME', 'unknown')
  pr_number = os.environ.get('GITHUB_EVENT_NUMBER', '')
  if is_diff:
    files = [{

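This hunk swaps the CircleCI metadata variables for their GitHub Actions counterparts. If the report ever needed to run under both providers, a fallback lookup would be one option; the committed code reads only the GitHub variables, so this is purely illustrative:

# Illustrative fallback across CI providers; not part of the committed change.
import os

commit_sha = os.environ.get('GITHUB_SHA') or os.environ.get('CIRCLE_SHA1', 'unknown')
branch = os.environ.get('GITHUB_REF_NAME') or os.environ.get('CIRCLE_BRANCH', 'unknown')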
View File

@@ -1,4 +1,4 @@
#!/bin/bash
#!/usr/bin/env bash
if command -v python3.12 &>/dev/null; then
echo "Python 3.12 is installed, proceeding with python3.12..."

View File

@@ -6,6 +6,9 @@ import pkgutil
def run():
  site_packages = site.getsitepackages()[0]
  base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
  baseimages_dir = os.path.join(base_dir, "exo", "apputil", "baseimages")
  command = [
    f"{sys.executable}", "-m", "nuitka", "exo/main.py",
    "--company-name=exolabs",
@@ -15,7 +18,8 @@ def run():
    "--standalone",
    "--output-filename=exo",
    "--python-flag=no_site",
    "--onefile"
    "--onefile",
    f"--include-data-dir={baseimages_dir}=exo/apputil/baseimages"
  ]
  if sys.platform == "darwin":
@@ -23,7 +27,7 @@ def run():
      "--macos-app-name=exo",
      "--macos-app-mode=gui",
      "--macos-app-version=0.0.1",
      "--macos-signed-app-name=com.exolabs.exo",
      "--macos-signed-app-name=net.exolabs.exo",
      "--include-distribution-meta=mlx",
      "--include-module=mlx._reprlib_fix",
      "--include-module=mlx._os_warning",

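The new --include-data-dir flag bundles exo/apputil/baseimages into the standalone/onefile build. At runtime such data is normally resolved relative to the package rather than the source checkout; a hedged sketch of that lookup (the image file name is hypothetical):

# Resolving a bundled data file relative to the package; this should work from a source
# checkout and, assuming the data dir was included as exo/apputil/baseimages, from the
# Nuitka build as well. "baseimage.png" is a made-up file name.
import os
import exo.apputil

baseimages_dir = os.path.join(os.path.dirname(os.path.abspath(exo.apputil.__file__)), "baseimages")
image_path = os.path.join(baseimages_dir, "baseimage.png")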
View File

@@ -1,4 +1,4 @@
#!/bin/bash
#!/usr/bin/env bash
source ./install.sh
pushd exo/networking/grpc
python3 -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. node_service.proto

View File

@@ -1,5 +1,6 @@
import sys
import platform
import subprocess
from setuptools import find_packages, setup
@@ -11,7 +12,6 @@ install_requires = [
  "grpcio==1.68.0",
  "grpcio-tools==1.68.0",
  "Jinja2==3.1.4",
  "netifaces==0.11.0",
  "numpy==2.0.0",
  "nuitka==2.5.1",
  "nvidia-ml-py==12.560.30",
@@ -23,27 +23,60 @@ install_requires = [
  "pydantic==2.9.2",
  "requests==2.32.3",
  "rich==13.7.1",
  "tenacity==9.0.0",
  "scapy==2.6.1",
  "tqdm==4.66.4",
  "transformers==4.46.3",
  "uuid==1.30",
  "tinygrad @ git+https://github.com/tinygrad/tinygrad.git@3b26e51fcebfc6576f4e0f99693e6f1406d61d79",
  "uvloop==0.21.0",
  "tinygrad @ git+https://github.com/tinygrad/tinygrad.git@ec120ce6b9ce8e4ff4b5692566a683ef240e8bc8",
]
extras_require = {
  "formatting": [
    "yapf==0.40.2",
  ],
  "formatting": ["yapf==0.40.2",],
  "apple_silicon": [
    "mlx==0.20.0",
    "mlx-lm==0.19.3",
    "mlx==0.22.0",
    "mlx-lm==0.21.1",
  ],
  "windows": ["pywin32==308",],
  "nvidia-gpu": ["nvidia-ml-py==12.560.30",],
  "amd-gpu": ["pyrsmi==0.2.0"],
}

# Check if running on macOS with Apple Silicon
if sys.platform.startswith("darwin") and platform.machine() == "arm64":
  install_requires.extend(extras_require["apple_silicon"])

# Check if running Windows
if sys.platform.startswith("win32"):
  install_requires.extend(extras_require["windows"])


def _add_gpu_requires():
  global install_requires
  # Add Nvidia-GPU
  try:
    out = subprocess.run(['nvidia-smi', '--query-gpu=name', '--format=csv,noheader'], shell=True, text=True, capture_output=True, check=False)
    if out.returncode == 0:
      install_requires.extend(extras_require["nvidia-gpu"])
  except subprocess.CalledProcessError:
    pass

  # Add AMD-GPU
  # This will mostly work only on Linux, amd/rocm-smi is not yet supported on Windows
  try:
    out = subprocess.run(['amd-smi', 'list', '--csv'], shell=True, text=True, capture_output=True, check=False)
    if out.returncode == 0:
      install_requires.extend(extras_require["amd-gpu"])
  except:
    out = subprocess.run(['rocm-smi', 'list', '--csv'], shell=True, text=True, capture_output=True, check=False)
    if out.returncode == 0:
      install_requires.extend(extras_require["amd-gpu"])
  finally:
    pass


_add_gpu_requires()

setup(
  name="exo",
  version="0.0.1",

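The new _add_gpu_requires probes for vendor CLIs at install time and folds the matching extra into install_requires. For comparison, the same vendor check can be sketched with a PATH lookup; detect_gpu_extras is illustrative and not part of setup.py. The extras can also be requested explicitly, e.g. pip install -e '.[nvidia-gpu]'.

# Equivalent vendor probe using shutil.which instead of invoking the CLIs; illustrative only.
import shutil

def detect_gpu_extras() -> list[str]:
  extras = []
  if shutil.which("nvidia-smi"):
    extras.append("nvidia-gpu")
  if shutil.which("amd-smi") or shutil.which("rocm-smi"):
    extras.append("amd-gpu")
  return extras

print(detect_gpu_extras())  # e.g. ['nvidia-gpu'] on a CUDA machine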
View File

@@ -1,4 +1,4 @@
#!/bin/bash
#!/usr/bin/env bash
echo "Starting node 1"
DEBUG_DISCOVERY=7 DEBUG=7 python3 main.py --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 52415 --chatgpt-api-response-timeout 900 > output1.log 2>&1 &

View File

@@ -1,26 +0,0 @@
import os
import sys

# Add the project root to the Python path
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)

import asyncio
from exo.download.hf.hf_helpers import get_weight_map


async def test_get_weight_map():
  repo_ids = [
    "mlx-community/quantized-gemma-2b",
    "mlx-community/Meta-Llama-3.1-8B-4bit",
    "mlx-community/Meta-Llama-3.1-70B-4bit",
    "mlx-community/Meta-Llama-3.1-405B-4bit",
  ]
  for repo_id in repo_ids:
    weight_map = await get_weight_map(repo_id)
    assert weight_map is not None, "Weight map should not be None"
    assert isinstance(weight_map, dict), "Weight map should be a dictionary"
    assert len(weight_map) > 0, "Weight map should not be empty"
    print(f"OK: {repo_id}")


if __name__ == "__main__":
  asyncio.run(test_get_weight_map())

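For context on what the removed test asserted: get_weight_map returns the "weight_map" mapping from a repo's safetensors index, which maps tensor names to the shard file containing them. A local illustration with invented keys (nothing is fetched from the Hub here):

# What a weight map looks like once loaded; keys and file names are invented examples.
index = {
  "weight_map": {
    "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
    "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
    "lm_head.weight": "model-00002-of-00002.safetensors",
  }
}
weight_map = index["weight_map"]
assert isinstance(weight_map, dict) and len(weight_map) > 0
print(f"{len(weight_map)} tensors across {len(set(weight_map.values()))} shards")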
View File

@@ -24,7 +24,7 @@ def test_tokenizer(name, tokenizer, verbose=False):
  strip_tokens = lambda s: s.lstrip(tokenizer.decode([tokenizer.bos_token_id])).rstrip(tokenizer.decode([tokenizer.eos_token_id]))
  assert text == strip_tokens(decoded) == strip_tokens(reconstructed)

ignore = ["TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", "mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", "llava-hf/llava-1.5-7b-hf", "mlx-community/Qwen*", "dummy", "mlx-community/Meta-Llama-3.1-405B-Instruct-8bit"]
ignore = ["TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R", "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", "mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", "llava-hf/llava-1.5-7b-hf", "mlx-community/Qwen*", "dummy", "mlx-community/Meta-Llama-3.1-405B-Instruct-8bit", "mlx-community/Phi-3.5-mini-instruct-4bit", "mlx-community/phi-4-4bit", "stabilityai/stable-diffusion-2-1-base"]
ignore_pattern = re.compile(r"^(" + "|".join(model.replace("*", ".*") for model in ignore) + r")")
models = []
for model_id in model_cards:
@@ -37,5 +37,6 @@ verbose = os.environ.get("VERBOSE", "0").lower() == "1"
for m in models:
  # TODO: figure out why use_fast=False is giving inconsistent behaviour (no spaces decoding individual tokens) for Mistral-Large-Instruct-2407-4bit
  # test_tokenizer(m, AutoProcessor.from_pretrained(m, use_fast=False), verbose)
  test_tokenizer(m, AutoProcessor.from_pretrained(m, use_fast=True), verbose)
  test_tokenizer(m, AutoTokenizer.from_pretrained(m), verbose)
  if m not in ["mlx-community/DeepSeek-R1-4bit", "mlx-community/DeepSeek-R1-3bit", "mlx-community/DeepSeek-V3-4bit", "mlx-community/DeepSeek-V3-3bit"]:
    test_tokenizer(m, AutoProcessor.from_pretrained(m, use_fast=True, trust_remote_code=True), verbose)
  test_tokenizer(m, AutoTokenizer.from_pretrained(m, trust_remote_code=True), verbose)
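The property test_tokenizer exercises is essentially an encode/decode round trip with BOS/EOS stripped. A minimal version of that check using a small public tokenizer; gpt2 is an arbitrary example here, not one of the model cards exo tests, and requires a one-off download:

# Encode/decode round trip in miniature; gpt2 adds no special tokens, so no stripping is needed.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
text = "exo runs models across many devices"
ids = tok.encode(text)
assert tok.decode(ids, skip_special_tokens=True) == text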