Mirror of https://github.com/exo-explore/exo.git (synced 2026-02-07 04:32:28 -05:00)

Compare commits: v0.0.8-alp...v0.0.12-al (13 commits)
| SHA1 |
|---|
| 07ceb19f0a |
| 27b4577f38 |
| a70943f8d2 |
| 410d901505 |
| 5c4ce5392c |
| 819ec7626e |
| ba5bb3e171 |
| f4bbcf4c8f |
| b9eccedc3d |
| 5f06aa2759 |
| 349b5344eb |
| df3624d27a |
| 6737e36e23 |
@@ -21,13 +21,20 @@ from PIL import Image
import numpy as np
import base64
from io import BytesIO
import mlx.core as mx
import platform

if platform.system().lower() == "darwin" and platform.machine().lower() == "arm64":
  import mlx.core as mx
else:
  import numpy as mx

import tempfile
from exo.download.hf.hf_shard_download import HFShardDownloader
import shutil
from exo.download.hf.hf_helpers import get_hf_home, get_repo_root
from exo.apputil import create_animation_mp4


class Message:
  def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]], tools: Optional[List[Dict]] = None):
    self.role = role
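The hunk above switches exo to a platform-conditional array backend: on Apple Silicon it keeps mlx.core, everywhere else NumPy stands in under the same mx alias. A minimal sketch (not part of the diff; scale_to_unit is a hypothetical helper) of how downstream code stays backend-agnostic:

import platform

if platform.system().lower() == "darwin" and platform.machine().lower() == "arm64":
  import mlx.core as mx  # Apple Silicon: use MLX arrays
else:
  import numpy as mx     # everywhere else: NumPy under the same alias

def scale_to_unit(values):
  # Only the API surface shared by mlx.core and numpy is used here,
  # so the function runs unchanged on either backend.
  arr = mx.array(values)
  return arr / (arr.max() + 1e-8)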
@@ -41,7 +48,6 @@ class Message:
|
||||
return data
|
||||
|
||||
|
||||
|
||||
class ChatCompletionRequest:
|
||||
def __init__(self, model: str, messages: List[Message], temperature: float, tools: Optional[List[Dict]] = None):
|
||||
self.model = model
|
||||
@@ -132,16 +138,24 @@ def remap_messages(messages: List[Message]) -> List[Message]:
|
||||
|
||||
def build_prompt(tokenizer, _messages: List[Message], tools: Optional[List[Dict]] = None):
|
||||
messages = remap_messages(_messages)
|
||||
chat_template_args = {
|
||||
"conversation": [m.to_dict() for m in messages],
|
||||
"tokenize": False,
|
||||
"add_generation_prompt": True
|
||||
}
|
||||
if tools: chat_template_args["tools"] = tools
|
||||
chat_template_args = {"conversation": [m.to_dict() for m in messages], "tokenize": False, "add_generation_prompt": True}
|
||||
if tools:
|
||||
chat_template_args["tools"] = tools
|
||||
|
||||
prompt = tokenizer.apply_chat_template(**chat_template_args)
|
||||
print(f"!!! Prompt: {prompt}")
|
||||
return prompt
|
||||
try:
|
||||
prompt = tokenizer.apply_chat_template(**chat_template_args)
|
||||
if DEBUG >= 3: print(f"!!! Prompt: {prompt}")
|
||||
return prompt
|
||||
except UnicodeEncodeError:
|
||||
# Handle Unicode encoding by ensuring everything is UTF-8
|
||||
chat_template_args["conversation"] = [
|
||||
{k: v.encode('utf-8').decode('utf-8') if isinstance(v, str) else v
|
||||
for k, v in m.to_dict().items()}
|
||||
for m in messages
|
||||
]
|
||||
prompt = tokenizer.apply_chat_template(**chat_template_args)
|
||||
if DEBUG >= 3: print(f"!!! Prompt (UTF-8 encoded): {prompt}")
|
||||
return prompt
|
||||
|
||||
|
||||
def parse_message(data: dict):
|
||||
@@ -165,8 +179,17 @@ class PromptSession:
|
||||
self.timestamp = timestamp
|
||||
self.prompt = prompt
|
||||
|
||||
|
||||
class ChatGPTAPI:
|
||||
def __init__(self, node: Node, inference_engine_classname: str, response_timeout: int = 90, on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None, default_model: Optional[str] = None, system_prompt: Optional[str] = None):
|
||||
def __init__(
|
||||
self,
|
||||
node: Node,
|
||||
inference_engine_classname: str,
|
||||
response_timeout: int = 90,
|
||||
on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None,
|
||||
default_model: Optional[str] = None,
|
||||
system_prompt: Optional[str] = None
|
||||
):
|
||||
self.node = node
|
||||
self.inference_engine_classname = inference_engine_classname
|
||||
self.response_timeout = response_timeout
|
||||
@@ -202,18 +225,22 @@ class ChatGPTAPI:
|
||||
cors.add(self.app.router.add_post("/download", self.handle_post_download), {"*": cors_options})
|
||||
cors.add(self.app.router.add_get("/topology", self.handle_get_topology), {"*": cors_options})
|
||||
|
||||
|
||||
# Add static routes
|
||||
if "__compiled__" not in globals():
|
||||
self.static_dir = Path(__file__).parent.parent/"tinychat"
|
||||
self.app.router.add_get("/", self.handle_root)
|
||||
self.app.router.add_static("/", self.static_dir, name="static")
|
||||
self.app.router.add_static('/images/', get_exo_images_dir(), name='static_images')
|
||||
|
||||
# Always add images route, regardless of compilation status
|
||||
self.images_dir = get_exo_images_dir()
|
||||
self.images_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.app.router.add_static('/images/', self.images_dir, name='static_images')
|
||||
|
||||
self.app.middlewares.append(self.timeout_middleware)
|
||||
self.app.middlewares.append(self.log_request)
|
||||
|
||||
async def handle_quit(self, request):
|
||||
if DEBUG>=1: print("Received quit signal")
|
||||
if DEBUG >= 1: print("Received quit signal")
|
||||
response = web.json_response({"detail": "Quit signal received"}, status=200)
|
||||
await response.prepare(request)
|
||||
await response.write_eof()
|
||||
@@ -243,61 +270,48 @@ class ChatGPTAPI:
|
||||
|
||||
async def handle_model_support(self, request):
|
||||
try:
|
||||
response = web.StreamResponse(
|
||||
status=200,
|
||||
reason='OK',
|
||||
headers={
|
||||
'Content-Type': 'text/event-stream',
|
||||
'Cache-Control': 'no-cache',
|
||||
'Connection': 'keep-alive',
|
||||
}
|
||||
)
|
||||
await response.prepare(request)
|
||||
response = web.StreamResponse(status=200, reason='OK', headers={
|
||||
'Content-Type': 'text/event-stream',
|
||||
'Cache-Control': 'no-cache',
|
||||
'Connection': 'keep-alive',
|
||||
})
|
||||
await response.prepare(request)
|
||||
|
||||
async def process_model(model_name, pretty):
|
||||
if model_name in model_cards:
|
||||
model_info = model_cards[model_name]
|
||||
async def process_model(model_name, pretty):
|
||||
if model_name in model_cards:
|
||||
model_info = model_cards[model_name]
|
||||
|
||||
if self.inference_engine_classname in model_info.get("repo", {}):
|
||||
shard = build_base_shard(model_name, self.inference_engine_classname)
|
||||
if shard:
|
||||
downloader = HFShardDownloader(quick_check=True)
|
||||
downloader.current_shard = shard
|
||||
downloader.current_repo_id = get_repo(shard.model_id, self.inference_engine_classname)
|
||||
status = await downloader.get_shard_download_status()
|
||||
if self.inference_engine_classname in model_info.get("repo", {}):
|
||||
shard = build_base_shard(model_name, self.inference_engine_classname)
|
||||
if shard:
|
||||
downloader = HFShardDownloader(quick_check=True)
|
||||
downloader.current_shard = shard
|
||||
downloader.current_repo_id = get_repo(shard.model_id, self.inference_engine_classname)
|
||||
status = await downloader.get_shard_download_status()
|
||||
|
||||
download_percentage = status.get("overall") if status else None
|
||||
total_size = status.get("total_size") if status else None
|
||||
total_downloaded = status.get("total_downloaded") if status else False
|
||||
download_percentage = status.get("overall") if status else None
|
||||
total_size = status.get("total_size") if status else None
|
||||
total_downloaded = status.get("total_downloaded") if status else False
|
||||
|
||||
model_data = {
|
||||
model_name: {
|
||||
"name": pretty,
|
||||
"downloaded": download_percentage == 100 if download_percentage is not None else False,
|
||||
"download_percentage": download_percentage,
|
||||
"total_size": total_size,
|
||||
"total_downloaded": total_downloaded
|
||||
}
|
||||
}
|
||||
model_data = {
|
||||
model_name: {
|
||||
"name": pretty, "downloaded": download_percentage == 100 if download_percentage is not None else False, "download_percentage": download_percentage, "total_size": total_size,
|
||||
"total_downloaded": total_downloaded
|
||||
}
|
||||
}
|
||||
|
||||
await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
|
||||
await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
|
||||
|
||||
# Process all models in parallel
|
||||
await asyncio.gather(*[
|
||||
process_model(model_name, pretty)
|
||||
for model_name, pretty in pretty_name.items()
|
||||
])
|
||||
# Process all models in parallel
|
||||
await asyncio.gather(*[process_model(model_name, pretty) for model_name, pretty in pretty_name.items()])
|
||||
|
||||
await response.write(b"data: [DONE]\n\n")
|
||||
return response
|
||||
await response.write(b"data: [DONE]\n\n")
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in handle_model_support: {str(e)}")
|
||||
traceback.print_exc()
|
||||
return web.json_response(
|
||||
{"detail": f"Server error: {str(e)}"},
|
||||
status=500
|
||||
)
|
||||
print(f"Error in handle_model_support: {str(e)}")
|
||||
traceback.print_exc()
|
||||
return web.json_response({"detail": f"Server error: {str(e)}"}, status=500)
|
||||
|
||||
async def handle_get_models(self, request):
|
||||
models_list = [{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()]
|
||||
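The handle_model_support handler above streams per-model download status as server-sent events and closes the stream with a [DONE] marker. A hedged client-side sketch of consuming that stream follows; the endpoint path ("/modelpool" here) is an assumption and should be checked against the registered routes, while the port matches the API's default of 52415:

import json
import requests

with requests.get("http://localhost:52415/modelpool", stream=True) as resp:
  for line in resp.iter_lines():
    if not line or not line.startswith(b"data: "):
      continue
    payload = line[len(b"data: "):]
    if payload == b"[DONE]":
      break
    # Each event is {model_name: {"name", "downloaded", "download_percentage", ...}}
    print(json.loads(payload))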
@@ -466,7 +480,6 @@ class ChatGPTAPI:
|
||||
deregistered_callback = self.node.on_token.deregister(callback_id)
|
||||
if DEBUG >= 2: print(f"Deregister {callback_id=} {deregistered_callback=}")
|
||||
|
||||
|
||||
async def handle_post_image_generations(self, request):
|
||||
data = await request.json()
|
||||
|
||||
@@ -479,7 +492,7 @@ class ChatGPTAPI:
|
||||
shard = build_base_shard(model, self.inference_engine_classname)
|
||||
if DEBUG >= 2: print(f"shard: {shard}")
|
||||
if not shard:
|
||||
return web.json_response({"error": f"Unsupported model: {model} with inference engine {self.inference_engine_classname}"}, status=400)
|
||||
return web.json_response({"error": f"Unsupported model: {model} with inference engine {self.inference_engine_classname}"}, status=400)
|
||||
|
||||
request_id = str(uuid.uuid4())
|
||||
callback_id = f"chatgpt-api-wait-response-{request_id}"
|
||||
@@ -491,77 +504,85 @@ class ChatGPTAPI:
|
||||
img = None
|
||||
await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, request_id=request_id, inference_state={"image": img}))), timeout=self.response_timeout)
|
||||
|
||||
|
||||
response = web.StreamResponse(status=200, reason='OK', headers={'Content-Type': 'application/octet-stream',"Cache-Control": "no-cache",})
|
||||
response = web.StreamResponse(status=200, reason='OK', headers={
|
||||
'Content-Type': 'application/octet-stream',
|
||||
"Cache-Control": "no-cache",
|
||||
})
|
||||
await response.prepare(request)
|
||||
|
||||
def get_progress_bar(current_step, total_steps, bar_length=50):
|
||||
# Calculate the percentage of completion
|
||||
percent = float(current_step) / total_steps
|
||||
percent = float(current_step)/total_steps
|
||||
# Calculate the number of hashes to display
|
||||
arrow = '-' * int(round(percent * bar_length) - 1) + '>'
|
||||
spaces = ' ' * (bar_length - len(arrow))
|
||||
|
||||
arrow = '-'*int(round(percent*bar_length) - 1) + '>'
|
||||
spaces = ' '*(bar_length - len(arrow))
|
||||
|
||||
# Create the progress bar string
|
||||
progress_bar = f'Progress: [{arrow}{spaces}] {int(percent * 100)}% ({current_step}/{total_steps})'
|
||||
return progress_bar
|
||||
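# Illustrative check (not part of the diff): with bar_length=20,
# get_progress_bar(5, 10, bar_length=20) returns
# "Progress: [--------->          ] 50% (5/10)"
# because round(0.5*20) - 1 = 9 dashes plus ">" fill 10 of the 20 slots.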
|
||||
async def stream_image(_request_id: str, result, is_finished: bool):
|
||||
if isinstance(result, list):
|
||||
await response.write(json.dumps({'progress': get_progress_bar((result[0]), (result[1]))}).encode('utf-8') + b'\n')
|
||||
if isinstance(result, list):
|
||||
await response.write(json.dumps({'progress': get_progress_bar((result[0]), (result[1]))}).encode('utf-8') + b'\n')
|
||||
|
||||
elif isinstance(result, np.ndarray):
|
||||
elif isinstance(result, np.ndarray):
|
||||
try:
|
||||
im = Image.fromarray(np.array(result))
|
||||
images_folder = get_exo_images_dir()
|
||||
# Save the image to a file
|
||||
image_filename = f"{_request_id}.png"
|
||||
image_path = images_folder / image_filename
|
||||
image_path = self.images_dir/image_filename
|
||||
im.save(image_path)
|
||||
image_url = request.app.router['static_images'].url_for(filename=image_filename)
|
||||
base_url = f"{request.scheme}://{request.host}"
|
||||
# Construct the full URL correctly
|
||||
full_image_url = base_url + str(image_url)
|
||||
|
||||
await response.write(json.dumps({'images': [{'url': str(full_image_url), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n')
|
||||
# Get URL for the saved image
|
||||
try:
|
||||
image_url = request.app.router['static_images'].url_for(filename=image_filename)
|
||||
base_url = f"{request.scheme}://{request.host}"
|
||||
full_image_url = base_url + str(image_url)
|
||||
|
||||
await response.write(json.dumps({'images': [{'url': str(full_image_url), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n')
|
||||
except KeyError as e:
|
||||
if DEBUG >= 2: print(f"Error getting image URL: {e}")
|
||||
# Fallback to direct file path if URL generation fails
|
||||
await response.write(json.dumps({'images': [{'url': str(image_path), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n')
|
||||
|
||||
if is_finished:
|
||||
await response.write_eof()
|
||||
|
||||
|
||||
except Exception as e:
|
||||
if DEBUG >= 2: print(f"Error processing image: {e}")
|
||||
if DEBUG >= 2: traceback.print_exc()
|
||||
await response.write(json.dumps({'error': str(e)}).encode('utf-8') + b'\n')
|
||||
|
||||
stream_task = None
|
||||
|
||||
def on_result(_request_id: str, result, is_finished: bool):
|
||||
nonlocal stream_task
|
||||
stream_task = asyncio.create_task(stream_image(_request_id, result, is_finished))
|
||||
return _request_id == request_id and is_finished
|
||||
nonlocal stream_task
|
||||
stream_task = asyncio.create_task(stream_image(_request_id, result, is_finished))
|
||||
return _request_id == request_id and is_finished
|
||||
|
||||
await callback.wait(on_result, timeout=self.response_timeout*10)
|
||||
|
||||
|
||||
if stream_task:
|
||||
# Wait for the stream task to complete before returning
|
||||
await stream_task
|
||||
# Wait for the stream task to complete before returning
|
||||
await stream_task
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
if DEBUG >= 2: traceback.print_exc()
|
||||
return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
|
||||
|
||||
if DEBUG >= 2: traceback.print_exc()
|
||||
return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
|
||||
|
||||
async def handle_delete_model(self, request):
|
||||
try:
|
||||
model_name = request.match_info.get('model_name')
|
||||
if DEBUG >= 2: print(f"Attempting to delete model: {model_name}")
|
||||
|
||||
if not model_name or model_name not in model_cards:
|
||||
return web.json_response(
|
||||
{"detail": f"Invalid model name: {model_name}"},
|
||||
status=400
|
||||
)
|
||||
return web.json_response({"detail": f"Invalid model name: {model_name}"}, status=400)
|
||||
|
||||
shard = build_base_shard(model_name, self.inference_engine_classname)
|
||||
if not shard:
|
||||
return web.json_response(
|
||||
{"detail": "Could not build shard for model"},
|
||||
status=400
|
||||
)
|
||||
return web.json_response({"detail": "Could not build shard for model"}, status=400)
|
||||
|
||||
repo_id = get_repo(shard.model_id, self.inference_engine_classname)
|
||||
if DEBUG >= 2: print(f"Repo ID for model: {repo_id}")
|
||||
@@ -576,38 +597,28 @@ class ChatGPTAPI:
|
||||
if DEBUG >= 2: print(f"Found model files at {cache_dir}, deleting...")
|
||||
try:
|
||||
shutil.rmtree(cache_dir)
|
||||
return web.json_response({
|
||||
"status": "success",
|
||||
"message": f"Model {model_name} deleted successfully",
|
||||
"path": str(cache_dir)
|
||||
})
|
||||
return web.json_response({"status": "success", "message": f"Model {model_name} deleted successfully", "path": str(cache_dir)})
|
||||
except Exception as e:
|
||||
return web.json_response({
|
||||
"detail": f"Failed to delete model files: {str(e)}"
|
||||
}, status=500)
|
||||
return web.json_response({"detail": f"Failed to delete model files: {str(e)}"}, status=500)
|
||||
else:
|
||||
return web.json_response({
|
||||
"detail": f"Model files not found at {cache_dir}"
|
||||
}, status=404)
|
||||
return web.json_response({"detail": f"Model files not found at {cache_dir}"}, status=404)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in handle_delete_model: {str(e)}")
|
||||
traceback.print_exc()
|
||||
return web.json_response({
|
||||
"detail": f"Server error: {str(e)}"
|
||||
}, status=500)
|
||||
print(f"Error in handle_delete_model: {str(e)}")
|
||||
traceback.print_exc()
|
||||
return web.json_response({"detail": f"Server error: {str(e)}"}, status=500)
|
||||
|
||||
async def handle_get_initial_models(self, request):
|
||||
model_data = {}
|
||||
for model_name, pretty in pretty_name.items():
|
||||
model_data[model_name] = {
|
||||
"name": pretty,
|
||||
"downloaded": None, # Initially unknown
|
||||
"download_percentage": None, # Change from 0 to null
|
||||
"total_size": None,
|
||||
"total_downloaded": None,
|
||||
"loading": True # Add loading state
|
||||
}
|
||||
model_data[model_name] = {
|
||||
"name": pretty,
|
||||
"downloaded": None, # Initially unknown
|
||||
"download_percentage": None, # Change from 0 to null
|
||||
"total_size": None,
|
||||
"total_downloaded": None,
|
||||
"loading": True # Add loading state
|
||||
}
|
||||
return web.json_response(model_data)
|
||||
|
||||
async def handle_create_animation(self, request):
|
||||
@@ -633,17 +644,9 @@ class ChatGPTAPI:
|
||||
if DEBUG >= 2: print(f"Animation temp directory: {tmp_dir}, output file: {output_path}, directory exists: {tmp_dir.exists()}, directory permissions: {oct(tmp_dir.stat().st_mode)[-3:]}")
|
||||
|
||||
# Create the animation
|
||||
create_animation_mp4(
|
||||
replacement_image_path,
|
||||
output_path,
|
||||
device_name,
|
||||
prompt_text
|
||||
)
|
||||
create_animation_mp4(replacement_image_path, output_path, device_name, prompt_text)
|
||||
|
||||
return web.json_response({
|
||||
"status": "success",
|
||||
"output_path": output_path
|
||||
})
|
||||
return web.json_response({"status": "success", "output_path": output_path})
|
||||
|
||||
except Exception as e:
|
||||
if DEBUG >= 2: traceback.print_exc()
|
||||
@@ -659,10 +662,7 @@ class ChatGPTAPI:
|
||||
if not shard: return web.json_response({"error": f"Could not build shard for model {model_name}"}, status=400)
|
||||
asyncio.create_task(self.node.inference_engine.shard_downloader.ensure_shard(shard, self.inference_engine_classname))
|
||||
|
||||
return web.json_response({
|
||||
"status": "success",
|
||||
"message": f"Download started for model: {model_name}"
|
||||
})
|
||||
return web.json_response({"status": "success", "message": f"Download started for model: {model_name}"})
|
||||
except Exception as e:
|
||||
if DEBUG >= 2: traceback.print_exc()
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
@@ -676,10 +676,7 @@ class ChatGPTAPI:
|
||||
return web.json_response({})
|
||||
except Exception as e:
|
||||
if DEBUG >= 2: traceback.print_exc()
|
||||
return web.json_response(
|
||||
{"detail": f"Error getting topology: {str(e)}"},
|
||||
status=500
|
||||
)
|
||||
return web.json_response({"detail": f"Error getting topology: {str(e)}"}, status=500)
|
||||
|
||||
async def run(self, host: str = "0.0.0.0", port: int = 52415):
|
||||
runner = web.AppRunner(self.app)
|
||||
@@ -690,15 +687,14 @@ class ChatGPTAPI:
|
||||
def base64_decode(self, base64_string):
|
||||
#decode and reshape image
|
||||
if base64_string.startswith('data:image'):
|
||||
base64_string = base64_string.split(',')[1]
|
||||
base64_string = base64_string.split(',')[1]
|
||||
image_data = base64.b64decode(base64_string)
|
||||
img = Image.open(BytesIO(image_data))
|
||||
W, H = (dim - dim % 64 for dim in (img.width, img.height))
|
||||
W, H = (dim - dim%64 for dim in (img.width, img.height))
|
||||
if W != img.width or H != img.height:
|
||||
if DEBUG >= 2: print(f"Warning: image shape is not divisible by 64, downsampling to {W}x{H}")
|
||||
img = img.resize((W, H), Image.NEAREST) # use desired downsampling filter
|
||||
if DEBUG >= 2: print(f"Warning: image shape is not divisible by 64, downsampling to {W}x{H}")
|
||||
img = img.resize((W, H), Image.NEAREST) # use desired downsampling filter
|
||||
img = mx.array(np.array(img))
|
||||
img = (img[:, :, :3].astype(mx.float32) / 255) * 2 - 1
|
||||
img = (img[:, :, :3].astype(mx.float32)/255)*2 - 1
|
||||
img = img[None]
|
||||
return img
|
||||
|
||||
|
||||
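A usage sketch for the base64_decode helper above (illustrative, not from the diff; api stands for an already-constructed ChatGPTAPI instance). It shows the preprocessing contract: dimensions are rounded down to multiples of 64, RGB channels are kept, pixel values are rescaled to [-1, 1], and a batch dimension is added:

import base64
from io import BytesIO
from PIL import Image

# Build a data URL for a 300x200 white test image.
img = Image.new("RGB", (300, 200), "white")
buf = BytesIO()
img.save(buf, format="PNG")
data_url = "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()

tensor = api.base64_decode(data_url)
# 300x200 is rounded down to 256x192, so the result has shape (1, 192, 256, 3)
# with float values in [-1, 1].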
@@ -2,6 +2,7 @@ from PIL import Image, ImageDraw, ImageFont, ImageFilter
|
||||
import os
|
||||
import numpy as np
|
||||
import cv2
|
||||
import sys
|
||||
|
||||
def draw_rounded_rectangle(draw, coords, radius, fill):
|
||||
left, top, right, bottom = coords
|
||||
@@ -80,14 +81,20 @@ def create_animation_mp4(
|
||||
font = ImageFont.load_default()
|
||||
promptfont = ImageFont.load_default()
|
||||
|
||||
# Get the base directory for images when running as a bundled app
|
||||
if hasattr(sys, '_MEIPASS'):
|
||||
base_dir = os.path.join(sys._MEIPASS, "exo", "apputil", "baseimages")
|
||||
else:
|
||||
base_dir = os.path.join(os.path.dirname(__file__), "baseimages")
|
||||
|
||||
# Process first frame
|
||||
base_img = Image.open(os.path.join(os.path.dirname(__file__), "baseimages", "image1.png"))
|
||||
base_img = Image.open(os.path.join(base_dir, "image1.png"))
|
||||
draw = ImageDraw.Draw(base_img)
|
||||
draw_centered_text_rounded(draw, device_name, font, device_coords)
|
||||
frames.extend([crop_image(base_img)] * 30) # 1 second at 30fps
|
||||
|
||||
# Process second frame with typing animation
|
||||
base_img2 = Image.open(os.path.join(os.path.dirname(__file__), "baseimages", "image2.png"))
|
||||
base_img2 = Image.open(os.path.join(base_dir, "image2.png"))
|
||||
for i in range(len(prompt_text) + 1):
|
||||
current_frame = base_img2.copy()
|
||||
draw = ImageDraw.Draw(current_frame)
|
||||
@@ -101,7 +108,7 @@ def create_animation_mp4(
|
||||
|
||||
# Create blur sequence
|
||||
replacement_img = Image.open(replacement_image_path)
|
||||
base_img = Image.open(os.path.join(os.path.dirname(__file__), "baseimages", "image3.png"))
|
||||
base_img = Image.open(os.path.join(base_dir, "image3.png"))
|
||||
blur_steps = [int(80 * (1 - i/8)) for i in range(9)]
|
||||
|
||||
for i, blur_amount in enumerate(blur_steps):
|
||||
@@ -123,7 +130,7 @@ def create_animation_mp4(
|
||||
frames.extend([crop_image(new_frame)] * 15) # 0.5 seconds at 30fps
|
||||
|
||||
# Create and add final frame (image4)
|
||||
final_base = Image.open(os.path.join(os.path.dirname(__file__), "baseimages", "image4.png"))
|
||||
final_base = Image.open(os.path.join(base_dir, "image4.png"))
|
||||
draw = ImageDraw.Draw(final_base)
|
||||
|
||||
draw_centered_text_rounded(draw, device_name, font, device_coords)
|
||||
@@ -158,4 +165,4 @@ def create_animation_mp4(
|
||||
out.write(frame_array)
|
||||
|
||||
out.release()
|
||||
print(f"Video saved successfully to {output_path}")
|
||||
print(f"Video saved successfully to {output_path}")
|
||||
|
||||
@@ -7,7 +7,8 @@ import random
|
||||
import platform
|
||||
import psutil
|
||||
import uuid
|
||||
import netifaces
|
||||
from scapy.all import get_if_addr, get_if_list
|
||||
import re
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
@@ -231,26 +232,26 @@ def pretty_print_bytes_per_second(bytes_per_second: int) -> str:
|
||||
def get_all_ip_addresses_and_interfaces():
|
||||
try:
|
||||
ip_addresses = []
|
||||
for interface in netifaces.interfaces():
|
||||
ifaddresses = netifaces.ifaddresses(interface)
|
||||
if netifaces.AF_INET in ifaddresses:
|
||||
for link in ifaddresses[netifaces.AF_INET]:
|
||||
ip = link['addr']
|
||||
ip_addresses.append((ip, interface))
|
||||
for interface in get_if_list():
|
||||
ip = get_if_addr(interface)
|
||||
# Include all addresses, including loopback
|
||||
# Filter out link-local addresses
|
||||
if not ip.startswith('169.254.') and not ip.startswith('0.0.'):
|
||||
# Remove "\\Device\\NPF_" prefix from interface name
|
||||
simplified_interface = re.sub(r'^\\Device\\NPF_', '', interface)
|
||||
ip_addresses.append((ip, simplified_interface))
|
||||
return list(set(ip_addresses))
|
||||
except:
|
||||
if DEBUG >= 1: print("Failed to get all IP addresses. Defaulting to localhost.")
|
||||
return [("localhost", "lo")]
|
||||
|
||||
|
||||
async def get_macos_interface_type(ifname: str) -> Optional[Tuple[int, str]]:
|
||||
try:
|
||||
# Use the shared subprocess_pool
|
||||
output = await asyncio.get_running_loop().run_in_executor(subprocess_pool, lambda: subprocess.run(
|
||||
['system_profiler', 'SPNetworkDataType', '-json'],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
close_fds=True
|
||||
).stdout)
|
||||
output = await asyncio.get_running_loop().run_in_executor(
|
||||
subprocess_pool, lambda: subprocess.run(['system_profiler', 'SPNetworkDataType', '-json'], capture_output=True, text=True, close_fds=True).stdout
|
||||
)
|
||||
|
||||
data = json.loads(output)
|
||||
|
||||
@@ -276,6 +277,7 @@ async def get_macos_interface_type(ifname: str) -> Optional[Tuple[int, str]]:
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def get_interface_priority_and_type(ifname: str) -> Tuple[int, str]:
|
||||
# On macOS, try to get interface type using networksetup
|
||||
if psutil.MACOS:
|
||||
@@ -283,8 +285,7 @@ async def get_interface_priority_and_type(ifname: str) -> Tuple[int, str]:
|
||||
if macos_type is not None: return macos_type
|
||||
|
||||
# Local container/virtual interfaces
|
||||
if (ifname.startswith(('docker', 'br-', 'veth', 'cni', 'flannel', 'calico', 'weave')) or
|
||||
'bridge' in ifname):
|
||||
if (ifname.startswith(('docker', 'br-', 'veth', 'cni', 'flannel', 'calico', 'weave')) or 'bridge' in ifname):
|
||||
return (7, "Container Virtual")
|
||||
|
||||
# Loopback interface
|
||||
@@ -310,6 +311,7 @@ async def get_interface_priority_and_type(ifname: str) -> Tuple[int, str]:
|
||||
# Other physical interfaces
|
||||
return (2, "Other")
|
||||
|
||||
|
||||
async def shutdown(signal, loop, server):
|
||||
"""Gracefully shutdown the server and close the asyncio loop."""
|
||||
print(f"Received exit signal {signal.name}...")
|
||||
@@ -329,16 +331,16 @@ def is_frozen():
|
||||
|
||||
|
||||
def get_exo_home() -> Path:
|
||||
if psutil.WINDOWS: docs_folder = Path(os.environ["USERPROFILE"]) / "Documents"
|
||||
else: docs_folder = Path.home() / "Documents"
|
||||
if psutil.WINDOWS: docs_folder = Path(os.environ["USERPROFILE"])/"Documents"
|
||||
else: docs_folder = Path.home()/"Documents"
|
||||
if not docs_folder.exists(): docs_folder.mkdir(exist_ok=True)
|
||||
exo_folder = docs_folder / "Exo"
|
||||
exo_folder = docs_folder/"Exo"
|
||||
if not exo_folder.exists(): exo_folder.mkdir(exist_ok=True)
|
||||
return exo_folder
|
||||
|
||||
|
||||
def get_exo_images_dir() -> Path:
|
||||
exo_home = get_exo_home()
|
||||
images_dir = exo_home / "Images"
|
||||
images_dir = exo_home/"Images"
|
||||
if not images_dir.exists(): images_dir.mkdir(exist_ok=True)
|
||||
return images_dir
|
||||
|
||||
@@ -5,6 +5,7 @@ from exo.helpers import DEBUG # Make sure to import DEBUG
|
||||
from typing import Tuple, Optional
|
||||
from abc import ABC, abstractmethod
|
||||
from .shard import Shard
|
||||
from exo.download.shard_download import ShardDownloader
|
||||
|
||||
|
||||
class InferenceEngine(ABC):
|
||||
@@ -13,7 +14,7 @@ class InferenceEngine(ABC):
|
||||
@abstractmethod
|
||||
async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
async def sample(self, x: np.ndarray) -> np.ndarray:
|
||||
pass
|
||||
@@ -32,13 +33,13 @@ class InferenceEngine(ABC):
|
||||
|
||||
async def save_checkpoint(self, shard: Shard, path: str):
|
||||
pass
|
||||
|
||||
|
||||
async def save_session(self, key, value):
|
||||
self.session[key] = value
|
||||
|
||||
|
||||
async def clear_session(self):
|
||||
self.session.empty()
|
||||
|
||||
|
||||
async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
|
||||
tokens = await self.encode(shard, prompt)
|
||||
if shard.model_id != 'stable-diffusion-2-1-base':
|
||||
@@ -49,13 +50,15 @@ class InferenceEngine(ABC):
|
||||
|
||||
return output_data, inference_state
|
||||
|
||||
|
||||
inference_engine_classes = {
|
||||
"mlx": "MLXDynamicShardInferenceEngine",
|
||||
"tinygrad": "TinygradDynamicShardInferenceEngine",
|
||||
"dummy": "DummyInferenceEngine",
|
||||
}
|
||||
|
||||
def get_inference_engine(inference_engine_name: str, shard_downloader: 'ShardDownloader'):
|
||||
|
||||
def get_inference_engine(inference_engine_name: str, shard_downloader: ShardDownloader):
|
||||
if DEBUG >= 2:
|
||||
print(f"get_inference_engine called with: {inference_engine_name}")
|
||||
if inference_engine_name == "mlx":
|
||||
|
||||
@@ -12,7 +12,13 @@ from exo.topology.topology import Topology
|
||||
from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
|
||||
from exo.helpers import DEBUG
|
||||
import json
|
||||
import mlx.core as mx
|
||||
import platform
|
||||
|
||||
if platform.system().lower() == "darwin" and platform.machine().lower() == "arm64":
|
||||
import mlx.core as mx
|
||||
else:
|
||||
import numpy as mx
|
||||
|
||||
|
||||
class GRPCPeerHandle(PeerHandle):
|
||||
def __init__(self, _id: str, address: str, desc: str, device_capabilities: DeviceCapabilities):
|
||||
@@ -37,11 +43,9 @@ class GRPCPeerHandle(PeerHandle):
|
||||
|
||||
async def connect(self):
|
||||
if self.channel is None:
|
||||
self.channel = grpc.aio.insecure_channel(self.address, options=[
|
||||
("grpc.max_metadata_size", 32*1024*1024),
|
||||
('grpc.max_receive_message_length', 32*1024*1024),
|
||||
('grpc.max_send_message_length', 32*1024*1024)
|
||||
])
|
||||
self.channel = grpc.aio.insecure_channel(
|
||||
self.address, options=[("grpc.max_metadata_size", 32*1024*1024), ('grpc.max_receive_message_length', 32*1024*1024), ('grpc.max_send_message_length', 32*1024*1024)]
|
||||
)
|
||||
self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
|
||||
await self.channel.channel_ready()
|
||||
|
||||
@@ -109,7 +113,7 @@ class GRPCPeerHandle(PeerHandle):
|
||||
return None
|
||||
|
||||
return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
|
||||
|
||||
|
||||
async def send_example(self, shard: Shard, example: np.ndarray, target: np.ndarray, length: np.ndarray, train: bool, request_id: Optional[str] = None) -> Optional[np.array]:
|
||||
request = node_service_pb2.ExampleRequest(
|
||||
shard=node_service_pb2.Shard(
|
||||
@@ -131,7 +135,7 @@ class GRPCPeerHandle(PeerHandle):
|
||||
return loss, grads
|
||||
else:
|
||||
return loss
|
||||
|
||||
|
||||
async def send_loss(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> Optional[np.array]:
|
||||
request = node_service_pb2.TensorRequest(
|
||||
shard=node_service_pb2.Shard(
|
||||
@@ -166,10 +170,7 @@ class GRPCPeerHandle(PeerHandle):
|
||||
topology = Topology()
|
||||
for node_id, capabilities in response.nodes.items():
|
||||
device_capabilities = DeviceCapabilities(
|
||||
model=capabilities.model,
|
||||
chip=capabilities.chip,
|
||||
memory=capabilities.memory,
|
||||
flops=DeviceFlops(fp16=capabilities.flops.fp16, fp32=capabilities.flops.fp32, int8=capabilities.flops.int8)
|
||||
model=capabilities.model, chip=capabilities.chip, memory=capabilities.memory, flops=DeviceFlops(fp16=capabilities.flops.fp16, fp32=capabilities.flops.fp32, int8=capabilities.flops.int8)
|
||||
)
|
||||
topology.update_node(node_id, device_capabilities)
|
||||
for node_id, peer_connections in response.peer_graph.items():
|
||||
@@ -193,28 +194,20 @@ class GRPCPeerHandle(PeerHandle):
|
||||
proto_inference_state = node_service_pb2.InferenceState()
|
||||
other_data = {}
|
||||
for k, v in inference_state.items():
|
||||
if isinstance(v, mx.array):
|
||||
np_array = np.array(v)
|
||||
tensor_data = node_service_pb2.Tensor(
|
||||
tensor_data=np_array.tobytes(),
|
||||
shape=list(np_array.shape),
|
||||
dtype=str(np_array.dtype)
|
||||
)
|
||||
proto_inference_state.tensor_data[k].CopyFrom(tensor_data)
|
||||
elif isinstance(v, list) and all(isinstance(item, mx.array) for item in v):
|
||||
tensor_list = node_service_pb2.TensorList()
|
||||
for tensor in v:
|
||||
np_array = np.array(tensor)
|
||||
tensor_data = node_service_pb2.Tensor(
|
||||
tensor_data=np_array.tobytes(),
|
||||
shape=list(np_array.shape),
|
||||
dtype=str(np_array.dtype)
|
||||
)
|
||||
tensor_list.tensors.append(tensor_data)
|
||||
proto_inference_state.tensor_list_data[k].CopyFrom(tensor_list)
|
||||
else:
|
||||
# For non-tensor data, we'll still use JSON
|
||||
other_data[k] = v
|
||||
if isinstance(v, mx.array):
|
||||
np_array = np.array(v)
|
||||
tensor_data = node_service_pb2.Tensor(tensor_data=np_array.tobytes(), shape=list(np_array.shape), dtype=str(np_array.dtype))
|
||||
proto_inference_state.tensor_data[k].CopyFrom(tensor_data)
|
||||
elif isinstance(v, list) and all(isinstance(item, mx.array) for item in v):
|
||||
tensor_list = node_service_pb2.TensorList()
|
||||
for tensor in v:
|
||||
np_array = np.array(tensor)
|
||||
tensor_data = node_service_pb2.Tensor(tensor_data=np_array.tobytes(), shape=list(np_array.shape), dtype=str(np_array.dtype))
|
||||
tensor_list.tensors.append(tensor_data)
|
||||
proto_inference_state.tensor_list_data[k].CopyFrom(tensor_list)
|
||||
else:
|
||||
# For non-tensor data, we'll still use JSON
|
||||
other_data[k] = v
|
||||
if other_data:
|
||||
proto_inference_state.other_data_json = json.dumps(other_data)
|
||||
return proto_inference_state
|
||||
|
||||
@@ -3,13 +3,19 @@ from concurrent import futures
|
||||
import numpy as np
|
||||
from asyncio import CancelledError
|
||||
|
||||
import platform
|
||||
|
||||
from . import node_service_pb2
|
||||
from . import node_service_pb2_grpc
|
||||
from exo import DEBUG
|
||||
from exo.inference.shard import Shard
|
||||
from exo.orchestration import Node
|
||||
import json
|
||||
import mlx.core as mx
|
||||
|
||||
if platform.system().lower() == "darwin" and platform.machine().lower() == "arm64":
|
||||
import mlx.core as mx
|
||||
else:
|
||||
import numpy as mx
|
||||
|
||||
|
||||
class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
|
||||
@@ -74,7 +80,7 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
|
||||
if DEBUG >= 5: print(f"SendTensor tensor {shard=} {tensor=} {request_id=} result: {result}")
|
||||
tensor_data = result.tobytes() if result is not None else None
|
||||
return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
|
||||
|
||||
|
||||
async def SendExample(self, request, context):
|
||||
shard = Shard(
|
||||
model_id=request.shard.model_id,
|
||||
@@ -96,7 +102,7 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
|
||||
else:
|
||||
loss = await self.node.process_example(shard, example, target, length, train, request_id)
|
||||
return node_service_pb2.Loss(loss=loss, grads=None)
|
||||
|
||||
|
||||
async def CollectTopology(self, request, context):
|
||||
max_depth = request.max_depth
|
||||
visited = set(request.visited)
|
||||
@@ -112,12 +118,7 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
|
||||
for node_id, cap in topology.nodes.items()
|
||||
}
|
||||
peer_graph = {
|
||||
node_id: node_service_pb2.PeerConnections(
|
||||
connections=[
|
||||
node_service_pb2.PeerConnection(to_id=conn.to_id, description=conn.description)
|
||||
for conn in connections
|
||||
]
|
||||
)
|
||||
node_id: node_service_pb2.PeerConnections(connections=[node_service_pb2.PeerConnection(to_id=conn.to_id, description=conn.description) for conn in connections])
|
||||
for node_id, connections in topology.peer_graph.items()
|
||||
}
|
||||
if DEBUG >= 5: print(f"CollectTopology {max_depth=} {visited=} {nodes=} {peer_graph=}")
|
||||
@@ -131,7 +132,7 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
|
||||
if DEBUG >= 5: print(f"Received SendResult request: {request_id=} {result=} {is_finished=}")
|
||||
result = list(result)
|
||||
if len(img.tensor_data) > 0:
|
||||
result=np.frombuffer(img.tensor_data, dtype=np.dtype(img.dtype)).reshape(img.shape)
|
||||
result = np.frombuffer(img.tensor_data, dtype=np.dtype(img.dtype)).reshape(img.shape)
|
||||
self.node.on_token.trigger_all(request_id, result, is_finished)
|
||||
return node_service_pb2.Empty()
|
||||
|
||||
@@ -145,21 +146,18 @@ class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
|
||||
async def HealthCheck(self, request, context):
|
||||
return node_service_pb2.HealthCheckResponse(is_healthy=True)
|
||||
|
||||
def deserialize_inference_state(self,inference_state_proto: node_service_pb2.InferenceState) -> dict:
|
||||
def deserialize_inference_state(self, inference_state_proto: node_service_pb2.InferenceState) -> dict:
|
||||
inference_state = {}
|
||||
|
||||
|
||||
for k, tensor_data in inference_state_proto.tensor_data.items():
|
||||
np_array = np.frombuffer(tensor_data.tensor_data, dtype=tensor_data.dtype).reshape(tensor_data.shape)
|
||||
inference_state[k] = mx.array(np_array)
|
||||
|
||||
np_array = np.frombuffer(tensor_data.tensor_data, dtype=tensor_data.dtype).reshape(tensor_data.shape)
|
||||
inference_state[k] = mx.array(np_array)
|
||||
|
||||
for k, tensor_list in inference_state_proto.tensor_list_data.items():
|
||||
inference_state[k] = [
|
||||
mx.array(np.frombuffer(tensor.tensor_data, dtype=tensor.dtype).reshape(tensor.shape))
|
||||
for tensor in tensor_list.tensors
|
||||
]
|
||||
|
||||
inference_state[k] = [mx.array(np.frombuffer(tensor.tensor_data, dtype=tensor.dtype).reshape(tensor.shape)) for tensor in tensor_list.tensors]
|
||||
|
||||
if inference_state_proto.other_data_json:
|
||||
other_data = json.loads(inference_state_proto.other_data_json)
|
||||
inference_state.update(other_data)
|
||||
|
||||
other_data = json.loads(inference_state_proto.other_data_json)
|
||||
inference_state.update(other_data)
|
||||
|
||||
return inference_state
|
||||
|
||||
@@ -149,6 +149,8 @@ def device_capabilities() -> DeviceCapabilities:
|
||||
return mac_device_capabilities()
|
||||
elif psutil.LINUX:
|
||||
return linux_device_capabilities()
|
||||
elif psutil.WINDOWS:
|
||||
return windows_device_capabilities()
|
||||
else:
|
||||
return DeviceCapabilities(
|
||||
model="Unknown Device",
|
||||
@@ -194,6 +196,8 @@ def linux_device_capabilities() -> DeviceCapabilities:
|
||||
|
||||
if DEBUG >= 2: print(f"NVIDIA device {gpu_name=} {gpu_memory_info=}")
|
||||
|
||||
pynvml.nvmlShutdown()
|
||||
|
||||
return DeviceCapabilities(
|
||||
model=f"Linux Box ({gpu_name})",
|
||||
chip=gpu_name,
|
||||
@@ -201,13 +205,24 @@ def linux_device_capabilities() -> DeviceCapabilities:
|
||||
flops=CHIP_FLOPS.get(gpu_name, DeviceFlops(fp32=0, fp16=0, int8=0)),
|
||||
)
|
||||
elif Device.DEFAULT == "AMD":
|
||||
# TODO AMD support
|
||||
# For AMD GPUs, pyrsmi is the way (Official python package for rocm-smi)
|
||||
from pyrsmi import rocml
|
||||
|
||||
rocml.smi_initialize()
|
||||
gpu_name = rocml.smi_get_device_name(0).upper()
|
||||
gpu_memory_info = rocml.smi_get_device_memory_total(0)
|
||||
|
||||
if DEBUG >= 2: print(f"AMD device {gpu_name=} {gpu_memory_info=}")
|
||||
|
||||
rocml.smi_shutdown()
|
||||
|
||||
return DeviceCapabilities(
|
||||
model="Linux Box (AMD)",
|
||||
chip="Unknown AMD",
|
||||
memory=psutil.virtual_memory().total // 2**20,
|
||||
model="Linux Box ({gpu_name})",
|
||||
chip={gpu_name},
|
||||
memory=gpu_memory_info.total // 2**20,
|
||||
flops=DeviceFlops(fp32=0, fp16=0, int8=0),
|
||||
)
|
||||
|
||||
else:
|
||||
return DeviceCapabilities(
|
||||
model=f"Linux Box (Device: {Device.DEFAULT})",
|
||||
@@ -215,3 +230,74 @@ def linux_device_capabilities() -> DeviceCapabilities:
|
||||
memory=psutil.virtual_memory().total // 2**20,
|
||||
flops=DeviceFlops(fp32=0, fp16=0, int8=0),
|
||||
)
|
||||
|
||||
|
||||
def windows_device_capabilities() -> DeviceCapabilities:
|
||||
import psutil
|
||||
|
||||
def get_gpu_info():
|
||||
import win32com.client # install pywin32
|
||||
|
||||
wmiObj = win32com.client.GetObject("winmgmts:\\\\.\\root\\cimv2")
|
||||
gpus = wmiObj.ExecQuery("SELECT * FROM Win32_VideoController")
|
||||
|
||||
gpu_info = []
|
||||
for gpu in gpus:
|
||||
info = {
|
||||
"Name": gpu.Name,
|
||||
"AdapterRAM": gpu.AdapterRAM, # Bug in this property, returns -ve for VRAM > 4GB (uint32 overflow)
|
||||
"DriverVersion": gpu.DriverVersion,
|
||||
"VideoProcessor": gpu.VideoProcessor
|
||||
}
|
||||
gpu_info.append(info)
|
||||
|
||||
return gpu_info
|
||||
|
||||
gpus_info = get_gpu_info()
|
||||
gpu_names = [gpu['Name'] for gpu in gpus_info]
|
||||
|
||||
contains_nvidia = any('nvidia' in gpu_name.lower() for gpu_name in gpu_names)
|
||||
contains_amd = any('amd' in gpu_name.lower() for gpu_name in gpu_names)
|
||||
|
||||
if contains_nvidia:
|
||||
import pynvml
|
||||
|
||||
pynvml.nvmlInit()
|
||||
handle = pynvml.nvmlDeviceGetHandleByIndex(0)
|
||||
gpu_raw_name = pynvml.nvmlDeviceGetName(handle).upper()
|
||||
gpu_name = gpu_raw_name.rsplit(" ", 1)[0] if gpu_raw_name.endswith("GB") else gpu_raw_name
|
||||
gpu_memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
|
||||
|
||||
if DEBUG >= 2: print(f"NVIDIA device {gpu_name=} {gpu_memory_info=}")
|
||||
|
||||
return DeviceCapabilities(
|
||||
model=f"Windows Box ({gpu_name})",
|
||||
chip=gpu_name,
|
||||
memory=gpu_memory_info.total // 2**20,
|
||||
flops=CHIP_FLOPS.get(gpu_name, DeviceFlops(fp32=0, fp16=0, int8=0)),
|
||||
)
|
||||
elif contains_amd:
|
||||
# For AMD GPUs, pyrsmi is the way (Official python package for rocm-smi)
|
||||
from pyrsmi import rocml
|
||||
|
||||
rocml.smi_initialize()
|
||||
gpu_name = rocml.smi_get_device_name(0).upper()
|
||||
gpu_memory_info = rocml.smi_get_device_memory_total(0)
|
||||
|
||||
if DEBUG >= 2: print(f"AMD device {gpu_name=} {gpu_memory_info=}")
|
||||
|
||||
rocml.smi_shutdown()
|
||||
|
||||
return DeviceCapabilities(
|
||||
model="Windows Box ({gpu_name})",
|
||||
chip={gpu_name},
|
||||
memory=gpu_memory_info.total // 2**20,
|
||||
flops=DeviceFlops(fp32=0, fp16=0, int8=0),
|
||||
)
|
||||
else:
|
||||
return DeviceCapabilities(
|
||||
model=f"Windows Box (Device: Unknown)",
|
||||
chip=f"Unknown Chip (Device(s): {gpu_names})",
|
||||
memory=psutil.virtual_memory().total // 2**20,
|
||||
flops=DeviceFlops(fp32=0, fp16=0, int8=0),
|
||||
)
|
||||
|
||||
@@ -6,6 +6,9 @@ import pkgutil
|
||||
|
||||
def run():
|
||||
site_packages = site.getsitepackages()[0]
|
||||
base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
baseimages_dir = os.path.join(base_dir, "exo", "apputil", "baseimages")
|
||||
|
||||
command = [
|
||||
f"{sys.executable}", "-m", "nuitka", "exo/main.py",
|
||||
"--company-name=exolabs",
|
||||
@@ -15,7 +18,8 @@ def run():
|
||||
"--standalone",
|
||||
"--output-filename=exo",
|
||||
"--python-flag=no_site",
|
||||
"--onefile"
|
||||
"--onefile",
|
||||
f"--include-data-dir={baseimages_dir}=exo/apputil/baseimages"
|
||||
]
|
||||
|
||||
if sys.platform == "darwin":
|
||||
@@ -23,7 +27,7 @@ def run():
|
||||
"--macos-app-name=exo",
|
||||
"--macos-app-mode=gui",
|
||||
"--macos-app-version=0.0.1",
|
||||
"--macos-signed-app-name=com.exolabs.exo",
|
||||
"--macos-signed-app-name=net.exolabs.exo",
|
||||
"--include-distribution-meta=mlx",
|
||||
"--include-module=mlx._reprlib_fix",
|
||||
"--include-module=mlx._os_warning",
|
||||
|
||||
setup.py
@@ -1,5 +1,6 @@
|
||||
import sys
|
||||
import platform
|
||||
import subprocess
|
||||
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
@@ -11,7 +12,6 @@ install_requires = [
|
||||
"grpcio==1.68.0",
|
||||
"grpcio-tools==1.68.0",
|
||||
"Jinja2==3.1.4",
|
||||
"netifaces==0.11.0",
|
||||
"numpy==2.0.0",
|
||||
"nuitka==2.5.1",
|
||||
"nvidia-ml-py==12.560.30",
|
||||
@@ -23,6 +23,7 @@ install_requires = [
|
||||
"pydantic==2.9.2",
|
||||
"requests==2.32.3",
|
||||
"rich==13.7.1",
|
||||
"scapy==2.6.1",
|
||||
"tenacity==9.0.0",
|
||||
"tqdm==4.66.4",
|
||||
"transformers==4.46.3",
|
||||
@@ -31,19 +32,47 @@ install_requires = [
|
||||
]
|
||||
|
||||
extras_require = {
|
||||
"formatting": [
|
||||
"yapf==0.40.2",
|
||||
],
|
||||
"apple_silicon": [
|
||||
"formatting": ["yapf==0.40.2",], "apple_silicon": [
|
||||
"mlx==0.20.0",
|
||||
"mlx-lm==0.19.3",
|
||||
],
|
||||
], "windows": ["pywin32==308",], "nvidia-gpu": ["nvidia-ml-py==12.560.30",], "amd-gpu": ["pyrsmi==0.2.0"]
|
||||
}
|
||||
|
||||
# Check if running on macOS with Apple Silicon
|
||||
if sys.platform.startswith("darwin") and platform.machine() == "arm64":
|
||||
install_requires.extend(extras_require["apple_silicon"])
|
||||
|
||||
# Check if running Windows
|
||||
if sys.platform.startswith("win32"):
|
||||
install_requires.extend(extras_require["windows"])
|
||||
|
||||
|
||||
def _add_gpu_requires():
|
||||
global install_requires
|
||||
# Add Nvidia-GPU
|
||||
try:
|
||||
out = subprocess.run(['nvidia-smi', '--query-gpu=name', '--format=csv,noheader'], shell=True, text=True, capture_output=True, check=False)
|
||||
if out.returncode == 0:
|
||||
install_requires.extend(extras_require["nvidia-gpu"])
|
||||
except subprocess.CalledProcessError:
|
||||
pass
|
||||
|
||||
# Add AMD-GPU
|
||||
# This will mostly work only on Linux, amd/rocm-smi is not yet supported on Windows
|
||||
try:
|
||||
out = subprocess.run(['amd-smi', 'list', '--csv'], shell=True, text=True, capture_output=True, check=False)
|
||||
if out.returncode == 0:
|
||||
install_requires.extend(extras_require["amd-gpu"])
|
||||
except:
|
||||
out = subprocess.run(['rocm-smi', 'list', '--csv'], shell=True, text=True, capture_output=True, check=False)
|
||||
if out.returncode == 0:
|
||||
install_requires.extend(extras_require["amd-gpu"])
|
||||
finally:
|
||||
pass
|
||||
|
||||
|
||||
_add_gpu_requires()
|
||||
|
||||
setup(
|
||||
name="exo",
|
||||
version="0.0.1",
|
||||