Compare commits

...

10 Commits

Author SHA1 Message Date
Alex Cheema
410d901505 Merge pull request #613 from samiamjidkhan/dmg-backend
image and text mode fix
2025-01-21 13:12:08 +00:00
Sami Khan
5c4ce5392c image and text mode fix 2025-01-21 04:33:54 -05:00
Alex Cheema
819ec7626e Merge pull request #611 from exo-explore/fixbuildname
fix scripts/build_exo.py: com.exolabs.exo -> net.exolabs.exo
2025-01-21 05:36:34 +00:00
Alex Cheema
ba5bb3e171 fix scripts/build_exo.py: com.exolabs.exo -> net.exolabs.exo 2025-01-21 05:36:02 +00:00
Alex Cheema
f4bbcf4c8f Merge pull request #607 from tensorsofthewall/smol_fix
Fixes for cross-platform operability
2025-01-21 02:21:18 +00:00
Sandesh Bharadwaj
b9eccedc3d Formatting 2025-01-17 05:40:42 -05:00
Sandesh Bharadwaj
5f06aa2759 Replace netifaces (unmaintained,outdated) with scapy + add dependencies for previous fixes 2025-01-17 05:37:01 -05:00
Sandesh Bharadwaj
349b5344eb Minor fix for Shard typing 2025-01-16 14:36:46 -05:00
Sandesh Bharadwaj
df3624d27a Add AMD GPU querying + Windows device capabilities 2025-01-14 20:37:02 -05:00
Sandesh Bharadwaj
6737e36e23 Fixed MLX import blocking native Windows execution of exo. (Not Final) 2025-01-14 20:35:21 -05:00
8 changed files with 345 additions and 238 deletions

View File

@@ -21,13 +21,20 @@ from PIL import Image
import numpy as np
import base64
from io import BytesIO
import mlx.core as mx
import platform
if platform.system().lower() == "darwin" and platform.machine().lower() == "arm64":
import mlx.core as mx
else:
import numpy as mx
import tempfile
from exo.download.hf.hf_shard_download import HFShardDownloader
import shutil
from exo.download.hf.hf_helpers import get_hf_home, get_repo_root
from exo.apputil import create_animation_mp4
class Message:
def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]], tools: Optional[List[Dict]] = None):
self.role = role
@@ -41,7 +48,6 @@ class Message:
return data
class ChatCompletionRequest:
def __init__(self, model: str, messages: List[Message], temperature: float, tools: Optional[List[Dict]] = None):
self.model = model
@@ -132,16 +138,24 @@ def remap_messages(messages: List[Message]) -> List[Message]:
def build_prompt(tokenizer, _messages: List[Message], tools: Optional[List[Dict]] = None):
messages = remap_messages(_messages)
chat_template_args = {
"conversation": [m.to_dict() for m in messages],
"tokenize": False,
"add_generation_prompt": True
}
if tools: chat_template_args["tools"] = tools
chat_template_args = {"conversation": [m.to_dict() for m in messages], "tokenize": False, "add_generation_prompt": True}
if tools:
chat_template_args["tools"] = tools
prompt = tokenizer.apply_chat_template(**chat_template_args)
print(f"!!! Prompt: {prompt}")
return prompt
try:
prompt = tokenizer.apply_chat_template(**chat_template_args)
if DEBUG >= 3: print(f"!!! Prompt: {prompt}")
return prompt
except UnicodeEncodeError:
# Handle Unicode encoding by ensuring everything is UTF-8
chat_template_args["conversation"] = [
{k: v.encode('utf-8').decode('utf-8') if isinstance(v, str) else v
for k, v in m.to_dict().items()}
for m in messages
]
prompt = tokenizer.apply_chat_template(**chat_template_args)
if DEBUG >= 3: print(f"!!! Prompt (UTF-8 encoded): {prompt}")
return prompt
def parse_message(data: dict):
@@ -165,8 +179,17 @@ class PromptSession:
self.timestamp = timestamp
self.prompt = prompt
class ChatGPTAPI:
def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None, default_model: Optional[str] = None, system_prompt: Optional[str] = None):
def __init__(
self,
node: Node,
inference_engine_classname: str,
response_timeout: int = 90,
on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None,
default_model: Optional[str] = None,
system_prompt: Optional[str] = None
):
self.node = node
self.inference_engine_classname = inference_engine_classname
self.response_timeout = response_timeout
@@ -202,18 +225,22 @@ class ChatGPTAPI:
cors.add(self.app.router.add_post("/download", self.handle_post_download), {"*": cors_options})
cors.add(self.app.router.add_get("/topology", self.handle_get_topology), {"*": cors_options})
# Add static routes
if "__compiled__" not in globals():
self.static_dir = Path(__file__).parent.parent/"tinychat"
self.app.router.add_get("/", self.handle_root)
self.app.router.add_static("/", self.static_dir, name="static")
self.app.router.add_static('/images/', get_exo_images_dir(), name='static_images')
# Always add images route, regardless of compilation status
self.images_dir = get_exo_images_dir()
self.images_dir.mkdir(parents=True, exist_ok=True)
self.app.router.add_static('/images/', self.images_dir, name='static_images')
self.app.middlewares.append(self.timeout_middleware)
self.app.middlewares.append(self.log_request)
async def handle_quit(self, request):
if DEBUG>=1: print("Received quit signal")
if DEBUG >= 1: print("Received quit signal")
response = web.json_response({"detail": "Quit signal received"}, status=200)
await response.prepare(request)
await response.write_eof()
@@ -243,61 +270,48 @@ class ChatGPTAPI:
async def handle_model_support(self, request):
try:
response = web.StreamResponse(
status=200,
reason='OK',
headers={
'Content-Type': 'text/event-stream',
'Cache-Control': 'no-cache',
'Connection': 'keep-alive',
}
)
await response.prepare(request)
response = web.StreamResponse(status=200, reason='OK', headers={
'Content-Type': 'text/event-stream',
'Cache-Control': 'no-cache',
'Connection': 'keep-alive',
})
await response.prepare(request)
async def process_model(model_name, pretty):
if model_name in model_cards:
model_info = model_cards[model_name]
async def process_model(model_name, pretty):
if model_name in model_cards:
model_info = model_cards[model_name]
if self.inference_engine_classname in model_info.get("repo", {}):
shard = build_base_shard(model_name, self.inference_engine_classname)
if shard:
downloader = HFShardDownloader(quick_check=True)
downloader.current_shard = shard
downloader.current_repo_id = get_repo(shard.model_id, self.inference_engine_classname)
status = await downloader.get_shard_download_status()
if self.inference_engine_classname in model_info.get("repo", {}):
shard = build_base_shard(model_name, self.inference_engine_classname)
if shard:
downloader = HFShardDownloader(quick_check=True)
downloader.current_shard = shard
downloader.current_repo_id = get_repo(shard.model_id, self.inference_engine_classname)
status = await downloader.get_shard_download_status()
download_percentage = status.get("overall") if status else None
total_size = status.get("total_size") if status else None
total_downloaded = status.get("total_downloaded") if status else False
download_percentage = status.get("overall") if status else None
total_size = status.get("total_size") if status else None
total_downloaded = status.get("total_downloaded") if status else False
model_data = {
model_name: {
"name": pretty,
"downloaded": download_percentage == 100 if download_percentage is not None else False,
"download_percentage": download_percentage,
"total_size": total_size,
"total_downloaded": total_downloaded
}
}
model_data = {
model_name: {
"name": pretty, "downloaded": download_percentage == 100 if download_percentage is not None else False, "download_percentage": download_percentage, "total_size": total_size,
"total_downloaded": total_downloaded
}
}
await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
# Process all models in parallel
await asyncio.gather(*[
process_model(model_name, pretty)
for model_name, pretty in pretty_name.items()
])
# Process all models in parallel
await asyncio.gather(*[process_model(model_name, pretty) for model_name, pretty in pretty_name.items()])
await response.write(b"data: [DONE]\n\n")
return response
await response.write(b"data: [DONE]\n\n")
return response
except Exception as e:
print(f"Error in handle_model_support: {str(e)}")
traceback.print_exc()
return web.json_response(
{"detail": f"Server error: {str(e)}"},
status=500
)
print(f"Error in handle_model_support: {str(e)}")
traceback.print_exc()
return web.json_response({"detail": f"Server error: {str(e)}"}, status=500)
async def handle_get_models(self, request):
models_list = [{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()]
@@ -466,7 +480,6 @@ class ChatGPTAPI:
deregistered_callback = self.node.on_token.deregister(callback_id)
if DEBUG >= 2: print(f"Deregister {callback_id=} {deregistered_callback=}")
async def handle_post_image_generations(self, request):
data = await request.json()
@@ -479,7 +492,7 @@ class ChatGPTAPI:
shard = build_base_shard(model, self.inference_engine_classname)
if DEBUG >= 2: print(f"shard: {shard}")
if not shard:
return web.json_response({"error": f"Unsupported model: {model} with inference engine {self.inference_engine_classname}"}, status=400)
return web.json_response({"error": f"Unsupported model: {model} with inference engine {self.inference_engine_classname}"}, status=400)
request_id = str(uuid.uuid4())
callback_id = f"chatgpt-api-wait-response-{request_id}"
@@ -491,77 +504,85 @@ class ChatGPTAPI:
img = None
await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, request_id=request_id, inference_state={"image": img}))), timeout=self.response_timeout)
response = web.StreamResponse(status=200, reason='OK', headers={'Content-Type': 'application/octet-stream',"Cache-Control": "no-cache",})
response = web.StreamResponse(status=200, reason='OK', headers={
'Content-Type': 'application/octet-stream',
"Cache-Control": "no-cache",
})
await response.prepare(request)
def get_progress_bar(current_step, total_steps, bar_length=50):
# Calculate the percentage of completion
percent = float(current_step) / total_steps
percent = float(current_step)/total_steps
# Calculate the number of hashes to display
arrow = '-' * int(round(percent * bar_length) - 1) + '>'
spaces = ' ' * (bar_length - len(arrow))
arrow = '-'*int(round(percent*bar_length) - 1) + '>'
spaces = ' '*(bar_length - len(arrow))
# Create the progress bar string
progress_bar = f'Progress: [{arrow}{spaces}] {int(percent * 100)}% ({current_step}/{total_steps})'
return progress_bar
async def stream_image(_request_id: str, result, is_finished: bool):
if isinstance(result, list):
await response.write(json.dumps({'progress': get_progress_bar((result[0]), (result[1]))}).encode('utf-8') + b'\n')
if isinstance(result, list):
await response.write(json.dumps({'progress': get_progress_bar((result[0]), (result[1]))}).encode('utf-8') + b'\n')
elif isinstance(result, np.ndarray):
elif isinstance(result, np.ndarray):
try:
im = Image.fromarray(np.array(result))
images_folder = get_exo_images_dir()
# Save the image to a file
image_filename = f"{_request_id}.png"
image_path = images_folder / image_filename
image_path = self.images_dir/image_filename
im.save(image_path)
image_url = request.app.router['static_images'].url_for(filename=image_filename)
base_url = f"{request.scheme}://{request.host}"
# Construct the full URL correctly
full_image_url = base_url + str(image_url)
await response.write(json.dumps({'images': [{'url': str(full_image_url), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n')
# Get URL for the saved image
try:
image_url = request.app.router['static_images'].url_for(filename=image_filename)
base_url = f"{request.scheme}://{request.host}"
full_image_url = base_url + str(image_url)
await response.write(json.dumps({'images': [{'url': str(full_image_url), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n')
except KeyError as e:
if DEBUG >= 2: print(f"Error getting image URL: {e}")
# Fallback to direct file path if URL generation fails
await response.write(json.dumps({'images': [{'url': str(image_path), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n')
if is_finished:
await response.write_eof()
except Exception as e:
if DEBUG >= 2: print(f"Error processing image: {e}")
if DEBUG >= 2: traceback.print_exc()
await response.write(json.dumps({'error': str(e)}).encode('utf-8') + b'\n')
stream_task = None
def on_result(_request_id: str, result, is_finished: bool):
nonlocal stream_task
stream_task = asyncio.create_task(stream_image(_request_id, result, is_finished))
return _request_id == request_id and is_finished
nonlocal stream_task
stream_task = asyncio.create_task(stream_image(_request_id, result, is_finished))
return _request_id == request_id and is_finished
await callback.wait(on_result, timeout=self.response_timeout*10)
if stream_task:
# Wait for the stream task to complete before returning
await stream_task
# Wait for the stream task to complete before returning
await stream_task
return response
except Exception as e:
if DEBUG >= 2: traceback.print_exc()
return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
if DEBUG >= 2: traceback.print_exc()
return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
async def handle_delete_model(self, request):
try:
model_name = request.match_info.get('model_name')
if DEBUG >= 2: print(f"Attempting to delete model: {model_name}")
if not model_name or model_name not in model_cards:
return web.json_response(
{"detail": f"Invalid model name: {model_name}"},
status=400
)
return web.json_response({"detail": f"Invalid model name: {model_name}"}, status=400)
shard = build_base_shard(model_name, self.inference_engine_classname)
if not shard:
return web.json_response(
{"detail": "Could not build shard for model"},
status=400
)
return web.json_response({"detail": "Could not build shard for model"}, status=400)
repo_id = get_repo(shard.model_id, self.inference_engine_classname)
if DEBUG >= 2: print(f"Repo ID for model: {repo_id}")
@@ -576,38 +597,28 @@ class ChatGPTAPI:
if DEBUG >= 2: print(f"Found model files at {cache_dir}, deleting...")
try:
shutil.rmtree(cache_dir)
return web.json_response({
"status": "success",
"message": f"Model {model_name} deleted successfully",
"path": str(cache_dir)
})
return web.json_response({"status": "success", "message": f"Model {model_name} deleted successfully", "path": str(cache_dir)})
except Exception as e:
return web.json_response({
"detail": f"Failed to delete model files: {str(e)}"
}, status=500)
return web.json_response({"detail": f"Failed to delete model files: {str(e)}"}, status=500)
else:
return web.json_response({
"detail": f"Model files not found at {cache_dir}"
}, status=404)
return web.json_response({"detail": f"Model files not found at {cache_dir}"}, status=404)
except Exception as e:
print(f"Error in handle_delete_model: {str(e)}")
traceback.print_exc()
return web.json_response({
"detail": f"Server error: {str(e)}"
}, status=500)
print(f"Error in handle_delete_model: {str(e)}")
traceback.print_exc()
return web.json_response({"detail": f"Server error: {str(e)}"}, status=500)
async def handle_get_initial_models(self, request):
model_data = {}
for model_name, pretty in pretty_name.items():
model_data[model_name] = {
"name": pretty,
"downloaded": None, # Initially unknown
"download_percentage": None, # Change from 0 to null
"total_size": None,
"total_downloaded": None,
"loading": True # Add loading state
}
model_data[model_name] = {
"name": pretty,
"downloaded": None, # Initially unknown
"download_percentage": None, # Change from 0 to null
"total_size": None,
"total_downloaded": None,
"loading": True # Add loading state
}
return web.json_response(model_data)
async def handle_create_animation(self, request):
@@ -633,17 +644,9 @@ class ChatGPTAPI:
if DEBUG >= 2: print(f"Animation temp directory: {tmp_dir}, output file: {output_path}, directory exists: {tmp_dir.exists()}, directory permissions: {oct(tmp_dir.stat().st_mode)[-3:]}")
# Create the animation
create_animation_mp4(
replacement_image_path,
output_path,
device_name,
prompt_text
)
create_animation_mp4(replacement_image_path, output_path, device_name, prompt_text)
return web.json_response({
"status": "success",
"output_path": output_path
})
return web.json_response({"status": "success", "output_path": output_path})
except Exception as e:
if DEBUG >= 2: traceback.print_exc()
@@ -659,10 +662,7 @@ class ChatGPTAPI:
if not shard: return web.json_response({"error": f"Could not build shard for model {model_name}"}, status=400)
asyncio.create_task(self.node.inference_engine.shard_downloader.ensure_shard(shard, self.inference_engine_classname))
return web.json_response({
"status": "success",
"message": f"Download started for model: {model_name}"
})
return web.json_response({"status": "success", "message": f"Download started for model: {model_name}"})
except Exception as e:
if DEBUG >= 2: traceback.print_exc()
return web.json_response({"error": str(e)}, status=500)
@@ -676,10 +676,7 @@ class ChatGPTAPI:
return web.json_response({})
except Exception as e:
if DEBUG >= 2: traceback.print_exc()
return web.json_response(
{"detail": f"Error getting topology: {str(e)}"},
status=500
)
return web.json_response({"detail": f"Error getting topology: {str(e)}"}, status=500)
async def run(self, host: str = "0.0.0.0", port: int = 52415):
runner = web.AppRunner(self.app)
@@ -690,15 +687,14 @@ class ChatGPTAPI:
def base64_decode(self, base64_string):
#decode and reshape image
if base64_string.startswith('data:image'):
base64_string = base64_string.split(',')[1]
base64_string = base64_string.split(',')[1]
image_data = base64.b64decode(base64_string)
img = Image.open(BytesIO(image_data))
W, H = (dim - dim % 64 for dim in (img.width, img.height))
W, H = (dim - dim%64 for dim in (img.width, img.height))
if W != img.width or H != img.height:
if DEBUG >= 2: print(f"Warning: image shape is not divisible by 64, downsampling to {W}x{H}")
img = img.resize((W, H), Image.NEAREST) # use desired downsampling filter
if DEBUG >= 2: print(f"Warning: image shape is not divisible by 64, downsampling to {W}x{H}")
img = img.resize((W, H), Image.NEAREST) # use desired downsampling filter
img = mx.array(np.array(img))
img = (img[:, :, :3].astype(mx.float32) / 255) * 2 - 1
img = (img[:, :, :3].astype(mx.float32)/255)*2 - 1
img = img[None]
return img

View File

@@ -7,7 +7,8 @@ import random
import platform
import psutil
import uuid
import netifaces
from scapy.all import get_if_addr, get_if_list
import re
import subprocess
from pathlib import Path
import tempfile
@@ -231,26 +232,26 @@ def pretty_print_bytes_per_second(bytes_per_second: int) -> str:
def get_all_ip_addresses_and_interfaces():
try:
ip_addresses = []
for interface in netifaces.interfaces():
ifaddresses = netifaces.ifaddresses(interface)
if netifaces.AF_INET in ifaddresses:
for link in ifaddresses[netifaces.AF_INET]:
ip = link['addr']
ip_addresses.append((ip, interface))
for interface in get_if_list():
ip = get_if_addr(interface)
# Include all addresses, including loopback
# Filter out link-local addresses
if not ip.startswith('169.254.') and not ip.startswith('0.0.'):
# Remove "\\Device\\NPF_" prefix from interface name
simplified_interface = re.sub(r'^\\Device\\NPF_', '', interface)
ip_addresses.append((ip, simplified_interface))
return list(set(ip_addresses))
except:
if DEBUG >= 1: print("Failed to get all IP addresses. Defaulting to localhost.")
return [("localhost", "lo")]
async def get_macos_interface_type(ifname: str) -> Optional[Tuple[int, str]]:
try:
# Use the shared subprocess_pool
output = await asyncio.get_running_loop().run_in_executor(subprocess_pool, lambda: subprocess.run(
['system_profiler', 'SPNetworkDataType', '-json'],
capture_output=True,
text=True,
close_fds=True
).stdout)
output = await asyncio.get_running_loop().run_in_executor(
subprocess_pool, lambda: subprocess.run(['system_profiler', 'SPNetworkDataType', '-json'], capture_output=True, text=True, close_fds=True).stdout
)
data = json.loads(output)
@@ -276,6 +277,7 @@ async def get_macos_interface_type(ifname: str) -> Optional[Tuple[int, str]]:
return None
async def get_interface_priority_and_type(ifname: str) -> Tuple[int, str]:
# On macOS, try to get interface type using networksetup
if psutil.MACOS:
@@ -283,8 +285,7 @@ async def get_interface_priority_and_type(ifname: str) -> Tuple[int, str]:
if macos_type is not None: return macos_type
# Local container/virtual interfaces
if (ifname.startswith(('docker', 'br-', 'veth', 'cni', 'flannel', 'calico', 'weave')) or
'bridge' in ifname):
if (ifname.startswith(('docker', 'br-', 'veth', 'cni', 'flannel', 'calico', 'weave')) or 'bridge' in ifname):
return (7, "Container Virtual")
# Loopback interface
@@ -310,6 +311,7 @@ async def get_interface_priority_and_type(ifname: str) -> Tuple[int, str]:
# Other physical interfaces
return (2, "Other")
async def shutdown(signal, loop, server):
"""Gracefully shutdown the server and close the asyncio loop."""
print(f"Received exit signal {signal.name}...")
@@ -329,16 +331,16 @@ def is_frozen():
def get_exo_home() -> Path:
if psutil.WINDOWS: docs_folder = Path(os.environ["USERPROFILE"]) / "Documents"
else: docs_folder = Path.home() / "Documents"
if psutil.WINDOWS: docs_folder = Path(os.environ["USERPROFILE"])/"Documents"
else: docs_folder = Path.home()/"Documents"
if not docs_folder.exists(): docs_folder.mkdir(exist_ok=True)
exo_folder = docs_folder / "Exo"
exo_folder = docs_folder/"Exo"
if not exo_folder.exists(): exo_folder.mkdir(exist_ok=True)
return exo_folder
def get_exo_images_dir() -> Path:
exo_home = get_exo_home()
images_dir = exo_home / "Images"
images_dir = exo_home/"Images"
if not images_dir.exists(): images_dir.mkdir(exist_ok=True)
return images_dir

View File

@@ -5,6 +5,7 @@ from exo.helpers import DEBUG # Make sure to import DEBUG
from typing import Tuple, Optional
from abc import ABC, abstractmethod
from .shard import Shard
from exo.download.shard_download import ShardDownloader
class InferenceEngine(ABC):
@@ -13,7 +14,7 @@ class InferenceEngine(ABC):
@abstractmethod
async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
pass
@abstractmethod
async def sample(self, x: np.ndarray) -> np.ndarray:
pass
@@ -32,13 +33,13 @@ class InferenceEngine(ABC):
async def save_checkpoint(self, shard: Shard, path: str):
pass
async def save_session(self, key, value):
self.session[key] = value
async def clear_session(self):
self.session.empty()
async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
tokens = await self.encode(shard, prompt)
if shard.model_id != 'stable-diffusion-2-1-base':
@@ -49,13 +50,15 @@ class InferenceEngine(ABC):
return output_data, inference_state
inference_engine_classes = {
"mlx": "MLXDynamicShardInferenceEngine",
"tinygrad": "TinygradDynamicShardInferenceEngine",
"dummy": "DummyInferenceEngine",
}
def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDownloader'):
def get_inference_engine(inference_engine_name: str, shard_downloader: ShardDownloader):
if DEBUG >= 2:
print(f"get_inference_engine called with: {inference_engine_name}")
if inference_engine_name == "mlx":

View File

@@ -12,7 +12,13 @@ from exo.topology.topology import Topology
from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
from exo.helpers import DEBUG
import json
import mlx.core as mx
import platform
if platform.system().lower() == "darwin" and platform.machine().lower() == "arm64":
import mlx.core as mx
else:
import numpy as mx
class GRPCPeerHandle(PeerHandle):
def __init__(self, _id: str, address: str, desc: str, device_capabilities: DeviceCapabilities):
@@ -37,11 +43,9 @@ class GRPCPeerHandle(PeerHandle):
async def connect(self):
if self.channel is None:
self.channel = grpc.aio.insecure_channel(self.address, options=[
("grpc.max_metadata_size", 32*1024*1024),
('grpc.max_receive_message_length', 32*1024*1024),
('grpc.max_send_message_length', 32*1024*1024)
])
self.channel = grpc.aio.insecure_channel(
self.address, options=[("grpc.max_metadata_size", 32*1024*1024), ('grpc.max_receive_message_length', 32*1024*1024), ('grpc.max_send_message_length', 32*1024*1024)]
)
self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
await self.channel.channel_ready()
@@ -109,7 +113,7 @@ class GRPCPeerHandle(PeerHandle):
return None
return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
async def send_example(self, shard: Shard, example: np.ndarray, target: np.ndarray, length: np.ndarray, train: bool, request_id: Optional[str] = None) -> Optional[np.array]:
request = node_service_pb2.ExampleRequest(
shard=node_service_pb2.Shard(
@@ -131,7 +135,7 @@ class GRPCPeerHandle(PeerHandle):
return loss, grads
else:
return loss
async def send_loss(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> Optional[np.array]:
request = node_service_pb2.TensorRequest(
shard=node_service_pb2.Shard(
@@ -166,10 +170,7 @@ class GRPCPeerHandle(PeerHandle):
topology = Topology()
for node_id, capabilities in response.nodes.items():
device_capabilities = DeviceCapabilities(
model=capabilities.model,
chip=capabilities.chip,
memory=capabilities.memory,
flops=DeviceFlops(fp16=capabilities.flops.fp16, fp32=capabilities.flops.fp32, int8=capabilities.flops.int8)
model=capabilities.model, chip=capabilities.chip, memory=capabilities.memory, flops=DeviceFlops(fp16=capabilities.flops.fp16, fp32=capabilities.flops.fp32, int8=capabilities.flops.int8)
)
topology.update_node(node_id, device_capabilities)
for node_id, peer_connections in response.peer_graph.items():
@@ -193,28 +194,20 @@ class GRPCPeerHandle(PeerHandle):
proto_inference_state = node_service_pb2.InferenceState()
other_data = {}
for k, v in inference_state.items():
if isinstance(v, mx.array):
np_array = np.array(v)
tensor_data = node_service_pb2.Tensor(
tensor_data=np_array.tobytes(),
shape=list(np_array.shape),
dtype=str(np_array.dtype)
)
proto_inference_state.tensor_data[k].CopyFrom(tensor_data)
elif isinstance(v, list) and all(isinstance(item, mx.array) for item in v):
tensor_list = node_service_pb2.TensorList()
for tensor in v:
np_array = np.array(tensor)
tensor_data = node_service_pb2.Tensor(
tensor_data=np_array.tobytes(),
shape=list(np_array.shape),
dtype=str(np_array.dtype)
)
tensor_list.tensors.append(tensor_data)
proto_inference_state.tensor_list_data[k].CopyFrom(tensor_list)
else:
# For non-tensor data, we'll still use JSON
other_data[k] = v
if isinstance(v, mx.array):
np_array = np.array(v)
tensor_data = node_service_pb2.Tensor(tensor_data=np_array.tobytes(), shape=list(np_array.shape), dtype=str(np_array.dtype))
proto_inference_state.tensor_data[k].CopyFrom(tensor_data)
elif isinstance(v, list) and all(isinstance(item, mx.array) for item in v):
tensor_list = node_service_pb2.TensorList()
for tensor in v:
np_array = np.array(tensor)
tensor_data = node_service_pb2.Tensor(tensor_data=np_array.tobytes(), shape=list(np_array.shape), dtype=str(np_array.dtype))
tensor_list.tensors.append(tensor_data)
proto_inference_state.tensor_list_data[k].CopyFrom(tensor_list)
else:
# For non-tensor data, we'll still use JSON
other_data[k] = v
if other_data:
proto_inference_state.other_data_json = json.dumps(other_data)
return proto_inference_state

View File

@@ -3,13 +3,19 @@ from concurrent import futures
import numpy as np
from asyncio import CancelledError
import platform
from . import node_service_pb2
from . import node_service_pb2_grpc
from exo import DEBUG
from exo.inference.shard import Shard
from exo.orchestration import Node
import json
import mlx.core as mx
if platform.system().lower() == "darwin" and platform.machine().lower() == "arm64":
import mlx.core as mx
else:
import numpy as mx
class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
@@ -74,7 +80,7 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
if DEBUG >= 5: print(f"SendTensor tensor {shard=} {tensor=} {request_id=} result: {result}")
tensor_data = result.tobytes() if result is not None else None
return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
async def SendExample(self, request, context):
shard = Shard(
model_id=request.shard.model_id,
@@ -96,7 +102,7 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
else:
loss = await self.node.process_example(shard, example, target, length, train, request_id)
return node_service_pb2.Loss(loss=loss, grads=None)
async def CollectTopology(self, request, context):
max_depth = request.max_depth
visited = set(request.visited)
@@ -112,12 +118,7 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
for node_id, cap in topology.nodes.items()
}
peer_graph = {
node_id: node_service_pb2.PeerConnections(
connections=[
node_service_pb2.PeerConnection(to_id=conn.to_id, description=conn.description)
for conn in connections
]
)
node_id: node_service_pb2.PeerConnections(connections=[node_service_pb2.PeerConnection(to_id=conn.to_id, description=conn.description) for conn in connections])
for node_id, connections in topology.peer_graph.items()
}
if DEBUG >= 5: print(f"CollectTopology {max_depth=} {visited=} {nodes=} {peer_graph=}")
@@ -131,7 +132,7 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
if DEBUG >= 5: print(f"Received SendResult request: {request_id=} {result=} {is_finished=}")
result = list(result)
if len(img.tensor_data) > 0:
result=np.frombuffer(img.tensor_data, dtype=np.dtype(img.dtype)).reshape(img.shape)
result = np.frombuffer(img.tensor_data, dtype=np.dtype(img.dtype)).reshape(img.shape)
self.node.on_token.trigger_all(request_id, result, is_finished)
return node_service_pb2.Empty()
@@ -145,21 +146,18 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
async def HealthCheck(self, request, context):
return node_service_pb2.HealthCheckResponse(is_healthy=True)
def deserialize_inference_state(self,inference_state_proto: node_service_pb2.InferenceState) -> dict:
def deserialize_inference_state(self, inference_state_proto: node_service_pb2.InferenceState) -> dict:
inference_state = {}
for k, tensor_data in inference_state_proto.tensor_data.items():
np_array = np.frombuffer(tensor_data.tensor_data, dtype=tensor_data.dtype).reshape(tensor_data.shape)
inference_state[k] = mx.array(np_array)
np_array = np.frombuffer(tensor_data.tensor_data, dtype=tensor_data.dtype).reshape(tensor_data.shape)
inference_state[k] = mx.array(np_array)
for k, tensor_list in inference_state_proto.tensor_list_data.items():
inference_state[k] = [
mx.array(np.frombuffer(tensor.tensor_data, dtype=tensor.dtype).reshape(tensor.shape))
for tensor in tensor_list.tensors
]
inference_state[k] = [mx.array(np.frombuffer(tensor.tensor_data, dtype=tensor.dtype).reshape(tensor.shape)) for tensor in tensor_list.tensors]
if inference_state_proto.other_data_json:
other_data = json.loads(inference_state_proto.other_data_json)
inference_state.update(other_data)
other_data = json.loads(inference_state_proto.other_data_json)
inference_state.update(other_data)
return inference_state

View File

@@ -149,6 +149,8 @@ def device_capabilities() -> DeviceCapabilities:
return mac_device_capabilities()
elif psutil.LINUX:
return linux_device_capabilities()
elif psutil.WINDOWS:
return windows_device_capabilities()
else:
return DeviceCapabilities(
model="Unknown Device",
@@ -194,6 +196,8 @@ def linux_device_capabilities() -> DeviceCapabilities:
if DEBUG >= 2: print(f"NVIDIA device {gpu_name=} {gpu_memory_info=}")
pynvml.nvmlShutdown()
return DeviceCapabilities(
model=f"Linux Box ({gpu_name})",
chip=gpu_name,
@@ -201,13 +205,24 @@ def linux_device_capabilities() -> DeviceCapabilities:
flops=CHIP_FLOPS.get(gpu_name, DeviceFlops(fp32=0, fp16=0, int8=0)),
)
elif Device.DEFAULT == "AMD":
# TODO AMD support
# For AMD GPUs, pyrsmi is the way (Official python package for rocm-smi)
from pyrsmi import rocml
rocml.smi_initialize()
gpu_name = rocml.smi_get_device_name(0).upper()
gpu_memory_info = rocml.smi_get_device_memory_total(0)
if DEBUG >= 2: print(f"AMD device {gpu_name=} {gpu_memory_info=}")
rocml.smi_shutdown()
return DeviceCapabilities(
model="Linux Box (AMD)",
chip="Unknown AMD",
memory=psutil.virtual_memory().total // 2**20,
model="Linux Box ({gpu_name})",
chip={gpu_name},
memory=gpu_memory_info.total // 2**20,
flops=DeviceFlops(fp32=0, fp16=0, int8=0),
)
else:
return DeviceCapabilities(
model=f"Linux Box (Device: {Device.DEFAULT})",
@@ -215,3 +230,74 @@ def linux_device_capabilities() -> DeviceCapabilities:
memory=psutil.virtual_memory().total // 2**20,
flops=DeviceFlops(fp32=0, fp16=0, int8=0),
)
def windows_device_capabilities() -> DeviceCapabilities:
import psutil
def get_gpu_info():
import win32com.client # install pywin32
wmiObj = win32com.client.GetObject("winmgmts:\\\\.\\root\\cimv2")
gpus = wmiObj.ExecQuery("SELECT * FROM Win32_VideoController")
gpu_info = []
for gpu in gpus:
info = {
"Name": gpu.Name,
"AdapterRAM": gpu.AdapterRAM, # Bug in this property, returns -ve for VRAM > 4GB (uint32 overflow)
"DriverVersion": gpu.DriverVersion,
"VideoProcessor": gpu.VideoProcessor
}
gpu_info.append(info)
return gpu_info
gpus_info = get_gpu_info()
gpu_names = [gpu['Name'] for gpu in gpus_info]
contains_nvidia = any('nvidia' in gpu_name.lower() for gpu_name in gpu_names)
contains_amd = any('amd' in gpu_name.lower() for gpu_name in gpu_names)
if contains_nvidia:
import pynvml
pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)
gpu_raw_name = pynvml.nvmlDeviceGetName(handle).upper()
gpu_name = gpu_raw_name.rsplit(" ", 1)[0] if gpu_raw_name.endswith("GB") else gpu_raw_name
gpu_memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
if DEBUG >= 2: print(f"NVIDIA device {gpu_name=} {gpu_memory_info=}")
return DeviceCapabilities(
model=f"Windows Box ({gpu_name})",
chip=gpu_name,
memory=gpu_memory_info.total // 2**20,
flops=CHIP_FLOPS.get(gpu_name, DeviceFlops(fp32=0, fp16=0, int8=0)),
)
elif contains_amd:
# For AMD GPUs, pyrsmi is the way (Official python package for rocm-smi)
from pyrsmi import rocml
rocml.smi_initialize()
gpu_name = rocml.smi_get_device_name(0).upper()
gpu_memory_info = rocml.smi_get_device_memory_total(0)
if DEBUG >= 2: print(f"AMD device {gpu_name=} {gpu_memory_info=}")
rocml.smi_shutdown()
return DeviceCapabilities(
model="Windows Box ({gpu_name})",
chip={gpu_name},
memory=gpu_memory_info.total // 2**20,
flops=DeviceFlops(fp32=0, fp16=0, int8=0),
)
else:
return DeviceCapabilities(
model=f"Windows Box (Device: Unknown)",
chip=f"Unknown Chip (Device(s): {gpu_names})",
memory=psutil.virtual_memory().total // 2**20,
flops=DeviceFlops(fp32=0, fp16=0, int8=0),
)

View File

@@ -23,7 +23,7 @@ def run():
"--macos-app-name=exo",
"--macos-app-mode=gui",
"--macos-app-version=0.0.1",
"--macos-signed-app-name=com.exolabs.exo",
"--macos-signed-app-name=net.exolabs.exo",
"--include-distribution-meta=mlx",
"--include-module=mlx._reprlib_fix",
"--include-module=mlx._os_warning",

View File

@@ -1,5 +1,6 @@
import sys
import platform
import subprocess
from setuptools import find_packages, setup
@@ -11,7 +12,6 @@ install_requires = [
"grpcio==1.68.0",
"grpcio-tools==1.68.0",
"Jinja2==3.1.4",
"netifaces==0.11.0",
"numpy==2.0.0",
"nuitka==2.5.1",
"nvidia-ml-py==12.560.30",
@@ -23,6 +23,7 @@ install_requires = [
"pydantic==2.9.2",
"requests==2.32.3",
"rich==13.7.1",
"scapy==2.6.1",
"tenacity==9.0.0",
"tqdm==4.66.4",
"transformers==4.46.3",
@@ -31,19 +32,47 @@ install_requires = [
]
extras_require = {
"formatting": [
"yapf==0.40.2",
],
"apple_silicon": [
"formatting": ["yapf==0.40.2",], "apple_silicon": [
"mlx==0.20.0",
"mlx-lm==0.19.3",
],
], "windows": ["pywin32==308",], "nvidia-gpu": ["nvidia-ml-py==12.560.30",], "amd-gpu": ["pyrsmi==0.2.0"]
}
# Check if running on macOS with Apple Silicon
if sys.platform.startswith("darwin") and platform.machine() == "arm64":
install_requires.extend(extras_require["apple_silicon"])
# Check if running Windows
if sys.platform.startswith("win32"):
install_requires.extend(extras_require["windows"])
def _add_gpu_requires():
global install_requires
# Add Nvidia-GPU
try:
out = subprocess.run(['nvidia-smi', '--query-gpu=name', '--format=csv,noheader'], shell=True, text=True, capture_output=True, check=False)
if out.returncode == 0:
install_requires.extend(extras_require["nvidia-gpu"])
except subprocess.CalledProcessError:
pass
# Add AMD-GPU
# This will mostly work only on Linux, amd/rocm-smi is not yet supported on Windows
try:
out = subprocess.run(['amd-smi', 'list', '--csv'], shell=True, text=True, capture_output=True, check=False)
if out.returncode == 0:
install_requires.extend(extras_require["amd-gpu"])
except:
out = subprocess.run(['rocm-smi', 'list', '--csv'], shell=True, text=True, capture_output=True, check=False)
if out.returncode == 0:
install_requires.extend(extras_require["amd-gpu"])
finally:
pass
_add_gpu_requires()
setup(
name="exo",
version="0.0.1",