mirror of
https://github.com/exo-explore/exo.git
synced 2025-12-23 22:27:50 -05:00
prep repo for v1
This commit is contained in:
@@ -1,376 +0,0 @@
|
||||
version: 2.1
|
||||
|
||||
orbs:
|
||||
python: circleci/python@2
|
||||
|
||||
commands:
|
||||
run_chatgpt_api_test:
|
||||
parameters:
|
||||
inference_engine:
|
||||
type: string
|
||||
model_id:
|
||||
type: string
|
||||
expected_output:
|
||||
type: string
|
||||
prompt:
|
||||
type: string
|
||||
steps:
|
||||
- run:
|
||||
name: Run chatgpt api integration test (<<parameters.inference_engine>>, <<parameters.model_id>>)
|
||||
command: |
|
||||
source env/bin/activate
|
||||
|
||||
# Set CLANG=1 for tinygrad only
|
||||
if [ "<<parameters.inference_engine>>" = "tinygrad" ]; then
|
||||
pip install llvmlite
|
||||
export TOKENIZERS_PARALLELISM=true SUPPORT_BF16=0 CLANG=1
|
||||
fi
|
||||
|
||||
# Start first instance
|
||||
EXO_HOME="$(pwd)/.exo_cache_node1" DEBUG_DISCOVERY=7 DEBUG=7 exo --inference-engine <<parameters.inference_engine>> \
|
||||
--node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 8000 \
|
||||
--chatgpt-api-response-timeout 900 --disable-tui > output1.log &
|
||||
PID1=$!
|
||||
tail -f output1.log &
|
||||
TAIL1=$!
|
||||
|
||||
# Start second instance
|
||||
EXO_HOME="$(pwd)/.exo_cache_node2" DEBUG_DISCOVERY=7 DEBUG=7 exo --inference-engine <<parameters.inference_engine>> \
|
||||
--node-id "node2" --listen-port 5679 --broadcast-port 5678 --chatgpt-api-port 8001 \
|
||||
--chatgpt-api-response-timeout 900 --disable-tui > output2.log &
|
||||
PID2=$!
|
||||
tail -f output2.log &
|
||||
TAIL2=$!
|
||||
|
||||
# Remember to kill the tail processes at the end
|
||||
trap 'kill $TAIL1 $TAIL2' EXIT
|
||||
|
||||
# Wait for discovery
|
||||
sleep 10
|
||||
|
||||
# Function to check if processes are still running
|
||||
check_processes() {
|
||||
if ! kill -0 $PID1 2>/dev/null; then
|
||||
echo "First instance (PID $PID1) died unexpectedly. Log output:"
|
||||
cat output1.log
|
||||
exit 1
|
||||
fi
|
||||
if ! kill -0 $PID2 2>/dev/null; then
|
||||
echo "Second instance (PID $PID2) died unexpectedly. Log output:"
|
||||
cat output2.log
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
# Check processes before proceeding
|
||||
check_processes
|
||||
|
||||
echo "Sending request to first instance..."
|
||||
response_1=$(curl -s http://localhost:8000/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "<<parameters.model_id>>",
|
||||
"messages": [{"role": "user", "content": "<<parameters.prompt>>"}],
|
||||
"temperature": 0.7
|
||||
}')
|
||||
echo "Response 1: $response_1"
|
||||
|
||||
# Check processes after first response
|
||||
check_processes
|
||||
|
||||
echo "Sending request to second instance..."
|
||||
response_2=$(curl -s http://localhost:8001/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "<<parameters.model_id>>",
|
||||
"messages": [{"role": "user", "content": "<<parameters.prompt>>"}],
|
||||
"temperature": 0.7
|
||||
}')
|
||||
echo "Response 2: $response_2"
|
||||
|
||||
# Check processes after second response
|
||||
check_processes
|
||||
|
||||
# Stop both instances
|
||||
kill $PID1 $PID2
|
||||
|
||||
echo ""
|
||||
# Extract content using jq and check if it contains expected output
|
||||
content1=$(echo "$response_1" | jq -r '.choices[0].message.content')
|
||||
content2=$(echo "$response_2" | jq -r '.choices[0].message.content')
|
||||
|
||||
if [[ "$content1" != *"<<parameters.expected_output>>"* ]] || [[ "$content2" != *"<<parameters.expected_output>>"* ]]; then
|
||||
echo "Test failed: Response does not match '<<parameters.expected_output>>'"
|
||||
echo "Response 1 content: $content1"
|
||||
echo ""
|
||||
echo "Response 2 content: $content2"
|
||||
echo "Output of first instance:"
|
||||
cat output1.log
|
||||
echo "Output of second instance:"
|
||||
cat output2.log
|
||||
exit 1
|
||||
else
|
||||
echo "Test passed: Response from both nodes matches '<<parameters.expected_output>>'"
|
||||
fi
|
||||
|
||||
jobs:
|
||||
unit_test:
|
||||
macos:
|
||||
xcode: "16.0.0"
|
||||
resource_class: m2pro.large
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: Set up Python
|
||||
command: |
|
||||
brew install python@3.12
|
||||
python3.12 -m venv env
|
||||
source env/bin/activate
|
||||
- run:
|
||||
name: Install dependencies
|
||||
command: |
|
||||
source env/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install .
|
||||
- run:
|
||||
name: Run tests
|
||||
command: |
|
||||
source env/bin/activate
|
||||
# set TEMPERATURE to 0 for deterministic sampling
|
||||
echo "Running inference engine tests..."
|
||||
METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 METAL_XCODE=1 TEMPERATURE=0 python3 -m exo.inference.test_inference_engine
|
||||
echo "Running tokenizer tests..."
|
||||
python3 ./test/test_tokenizers.py
|
||||
python3 ./test/test_model_helpers.py
|
||||
|
||||
discovery_integration_test:
|
||||
macos:
|
||||
xcode: "16.0.0"
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: Set up Python
|
||||
command: |
|
||||
brew install python@3.12
|
||||
python3.12 -m venv env
|
||||
source env/bin/activate
|
||||
- run:
|
||||
name: Install dependencies
|
||||
command: |
|
||||
source env/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install .
|
||||
- run:
|
||||
name: Run discovery integration test
|
||||
command: |
|
||||
source env/bin/activate
|
||||
DEBUG_DISCOVERY=7 DEBUG=7 exo --node-id "node1" --listen-port 5678 --broadcast-port 5679 --chatgpt-api-port 8000 --disable-tui > output1.log 2>&1 &
|
||||
PID1=$!
|
||||
DEBUG_DISCOVERY=7 DEBUG=7 exo --node-id "node2" --listen-port 5679 --broadcast-port 5678 --chatgpt-api-port 8001 --disable-tui > output2.log 2>&1 &
|
||||
PID2=$!
|
||||
sleep 10
|
||||
kill $PID1 $PID2
|
||||
if grep -q "Peer statuses: {\\'node2\\': \\'is_connected=True, health_check=True" output1.log && ! grep -q "Failed to connect peers:" output1.log && grep -q "Peer statuses: {\\'node1\\': \\'is_connected=True, health_check=True" output2.log && ! grep -q "Failed to connect peers:" output2.log; then
|
||||
echo "Test passed: Both instances discovered each other"
|
||||
exit 0
|
||||
else
|
||||
echo "Test failed: Devices did not discover each other"
|
||||
echo "Output of first instance:"
|
||||
cat output1.log
|
||||
echo "Output of second instance:"
|
||||
cat output2.log
|
||||
exit 1
|
||||
fi
|
||||
|
||||
chatgpt_api_integration_test_mlx:
|
||||
macos:
|
||||
xcode: "16.0.0"
|
||||
resource_class: m2pro.large
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: Set up Python
|
||||
command: |
|
||||
brew install python@3.12
|
||||
python3.12 -m venv env
|
||||
source env/bin/activate
|
||||
- run:
|
||||
name: Install dependencies
|
||||
command: |
|
||||
source env/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install .
|
||||
- run_chatgpt_api_test:
|
||||
inference_engine: mlx
|
||||
model_id: llama-3.2-1b
|
||||
prompt: "Keep responses concise. Who was the king of pop?"
|
||||
expected_output: "Michael Jackson"
|
||||
|
||||
chatgpt_api_integration_test_dummy:
|
||||
macos:
|
||||
xcode: "16.0.0"
|
||||
resource_class: m2pro.large
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: Set up Python
|
||||
command: |
|
||||
brew install python@3.12
|
||||
python3.12 -m venv env
|
||||
source env/bin/activate
|
||||
- run:
|
||||
name: Install dependencies
|
||||
command: |
|
||||
source env/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install .
|
||||
- run_chatgpt_api_test:
|
||||
inference_engine: dummy
|
||||
model_id: dummy
|
||||
prompt: "Dummy prompt."
|
||||
expected_output: "dummy"
|
||||
|
||||
chatgpt_api_integration_test_tinygrad:
|
||||
macos:
|
||||
xcode: "16.0.0"
|
||||
resource_class: m2pro.large
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: Set up Python
|
||||
command: |
|
||||
brew install python@3.12
|
||||
python3.12 -m venv env
|
||||
source env/bin/activate
|
||||
- run:
|
||||
name: Install dependencies
|
||||
command: |
|
||||
source env/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install .
|
||||
- run_chatgpt_api_test:
|
||||
inference_engine: tinygrad
|
||||
model_id: llama-3.2-1b
|
||||
prompt: "Keep responses concise. Who was the king of pop?"
|
||||
expected_output: "Michael Jackson"
|
||||
|
||||
chatgpt_api_integration_test_tinygrad_linux:
|
||||
machine:
|
||||
image: ubuntu-2204:current
|
||||
resource_class: xlarge
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: Set up Python
|
||||
command: |
|
||||
export DEBIAN_FRONTEND=noninteractive
|
||||
export DEBCONF_NONINTERACTIVE_SEEN=true
|
||||
sudo apt-get update
|
||||
sudo add-apt-repository -y ppa:deadsnakes/ppa
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y python3.12 python3.12-venv clang
|
||||
python3.12 -m venv env
|
||||
source env/bin/activate
|
||||
- run:
|
||||
name: Install dependencies
|
||||
command: |
|
||||
source env/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install .
|
||||
- run_chatgpt_api_test:
|
||||
inference_engine: tinygrad
|
||||
model_id: llama-3.2-1b
|
||||
prompt: "Keep responses concise. Who was the king of pop?"
|
||||
expected_output: "Michael Jackson"
|
||||
|
||||
measure_pip_sizes:
|
||||
macos:
|
||||
xcode: "16.0.0"
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: Set up Python
|
||||
command: |
|
||||
brew install python@3.12
|
||||
python3.12 -m venv env
|
||||
source env/bin/activate
|
||||
- run:
|
||||
name: Install dependencies and measure sizes
|
||||
command: |
|
||||
source env/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install .
|
||||
python ./extra/pipsize.py --json ./pipsize.json
|
||||
- store_artifacts:
|
||||
path: ./pipsize.json
|
||||
destination: pip-sizes.json
|
||||
|
||||
check_line_count:
|
||||
docker:
|
||||
- image: cimg/python:3.10
|
||||
steps:
|
||||
- checkout
|
||||
|
||||
- run:
|
||||
name: Setup git for PR comparison
|
||||
command: |
|
||||
if [[ -n "$CIRCLE_PULL_REQUEST" ]]; then
|
||||
PR_NUMBER=$(echo $CIRCLE_PULL_REQUEST | rev | cut -d'/' -f1 | rev)
|
||||
BASE_BRANCH=$(curl -s -H "Circle-Token: $CIRCLE_TOKEN" \
|
||||
"https://circleci.com/api/v2/project/github/$CIRCLE_PROJECT_USERNAME/$CIRCLE_PROJECT_REPONAME/pipeline/$CIRCLE_WORKFLOW_ID" \
|
||||
| jq -r '.target_branch')
|
||||
|
||||
git clone -b $BASE_BRANCH --single-branch \
|
||||
https://github.com/$CIRCLE_PROJECT_USERNAME/$CIRCLE_PROJECT_REPONAME.git \
|
||||
base_branch
|
||||
fi
|
||||
|
||||
- run:
|
||||
name: Install dependencies
|
||||
command: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install tabulate
|
||||
|
||||
- run:
|
||||
name: Run line count check
|
||||
command: |
|
||||
if [[ -n "$CIRCLE_PULL_REQUEST" ]]; then
|
||||
python extra/line_counter.py base_branch .
|
||||
else
|
||||
python extra/line_counter.py .
|
||||
fi
|
||||
|
||||
- store_artifacts:
|
||||
path: line-count-snapshot.json
|
||||
destination: line-count-snapshot.json
|
||||
|
||||
- store_artifacts:
|
||||
path: line-count-diff.json
|
||||
destination: line-count-diff.json
|
||||
|
||||
- run:
|
||||
name: Create test results directory
|
||||
command: |
|
||||
mkdir -p test-results/line-count
|
||||
cp line-count-*.json test-results/line-count/
|
||||
|
||||
- store_test_results:
|
||||
path: test-results
|
||||
|
||||
workflows:
|
||||
version: 2
|
||||
build_and_test:
|
||||
jobs:
|
||||
- check_line_count:
|
||||
filters:
|
||||
branches:
|
||||
only: /.*/
|
||||
tags:
|
||||
only: /.*/
|
||||
- unit_test
|
||||
- discovery_integration_test
|
||||
- chatgpt_api_integration_test_mlx
|
||||
- chatgpt_api_integration_test_tinygrad
|
||||
- chatgpt_api_integration_test_tinygrad_linux
|
||||
- chatgpt_api_integration_test_dummy
|
||||
- measure_pip_sizes
|
||||
2
.gitattributes
vendored
2
.gitattributes
vendored
@@ -1,2 +0,0 @@
|
||||
*.mp3 filter=lfs diff=lfs merge=lfs -text
|
||||
*.png filter=lfs diff=lfs merge=lfs -text
|
||||
401
.github/bench.py
vendored
401
.github/bench.py
vendored
@@ -1,401 +0,0 @@
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import time
|
||||
import json
|
||||
import os
|
||||
import boto3
|
||||
from typing import Dict, Any
|
||||
from datetime import datetime
|
||||
import subprocess
|
||||
import psutil
|
||||
import platform
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def check_system_state():
|
||||
print("\n=== System State Check ===", flush=True)
|
||||
|
||||
# Add macOS-specific checks
|
||||
try:
|
||||
# Check powermetrics with sudo
|
||||
try:
|
||||
power_metrics = subprocess.run(
|
||||
['sudo', 'powermetrics', '-n', '1', '-i', '1000', '--samplers', 'cpu_power'],
|
||||
capture_output=True, text=True
|
||||
)
|
||||
print("\nPower Metrics:", power_metrics.stdout, flush=True)
|
||||
except Exception as e:
|
||||
print(f"Error getting power metrics: {e}", flush=True)
|
||||
|
||||
# Check thermal state
|
||||
thermal_state = subprocess.run(['pmset', '-g', 'therm'], capture_output=True, text=True)
|
||||
print("\nThermal State:", thermal_state.stdout, flush=True)
|
||||
|
||||
# Check if running under Rosetta
|
||||
arch = subprocess.run(['arch'], capture_output=True, text=True)
|
||||
print("\nArchitecture:", arch.stdout, flush=True)
|
||||
|
||||
# Check MLX compilation mode - only if mlx is available
|
||||
try:
|
||||
import mlx.core as mx
|
||||
if hasattr(mx, 'build_info'):
|
||||
print("\nMLX Build Info:", mx.build_info(), flush=True)
|
||||
else:
|
||||
print("\nMLX Build Info: Not available in this version", flush=True)
|
||||
except ImportError:
|
||||
print("\nMLX: Not installed", flush=True)
|
||||
except Exception as e:
|
||||
print(f"\nError checking MLX: {e}", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in macOS checks: {e}", flush=True)
|
||||
|
||||
# CPU Info
|
||||
print("\nCPU Information:", flush=True)
|
||||
try:
|
||||
if platform.system() == 'Darwin' and platform.processor() == 'arm':
|
||||
# Use sysctl for Apple Silicon Macs
|
||||
cpu_info = subprocess.run(['sysctl', 'machdep.cpu'], capture_output=True, text=True)
|
||||
if cpu_info.returncode == 0:
|
||||
print(f"CPU Info (Apple Silicon):", cpu_info.stdout, flush=True)
|
||||
|
||||
# Parse powermetrics output for clearer CPU frequency display
|
||||
try:
|
||||
power_metrics = subprocess.run(
|
||||
['sudo', 'powermetrics', '-n', '1', '-i', '100', '--samplers', 'cpu_power'],
|
||||
capture_output=True, text=True
|
||||
)
|
||||
if power_metrics.returncode == 0:
|
||||
output = power_metrics.stdout
|
||||
print("\nDetailed CPU Frequency Information:")
|
||||
|
||||
# Extract cluster frequencies and max frequencies
|
||||
current_cluster = None
|
||||
max_freqs = {'E': 0, 'P0': 0, 'P1': 0}
|
||||
|
||||
for line in output.split('\n'):
|
||||
# Track which cluster we're processing
|
||||
if "E-Cluster" in line:
|
||||
current_cluster = 'E'
|
||||
elif "P0-Cluster" in line:
|
||||
current_cluster = 'P0'
|
||||
elif "P1-Cluster" in line:
|
||||
current_cluster = 'P1'
|
||||
|
||||
# Get current frequencies
|
||||
if "HW active frequency:" in line:
|
||||
freq = line.split(':')[1].strip()
|
||||
if freq != "0 MHz":
|
||||
print(f"Current {current_cluster}-Cluster Frequency: {freq}")
|
||||
|
||||
# Get max frequencies from residency lines
|
||||
if current_cluster and "active residency:" in line and "MHz:" in line:
|
||||
try:
|
||||
# Extract all frequency values
|
||||
freqs = []
|
||||
parts = line.split('MHz:')[:-1] # Skip last part as it's not a frequency
|
||||
for part in parts:
|
||||
freq_str = part.split()[-1]
|
||||
try:
|
||||
freq = float(freq_str)
|
||||
freqs.append(freq)
|
||||
except ValueError:
|
||||
continue
|
||||
if freqs:
|
||||
max_freqs[current_cluster] = max(max_freqs[current_cluster], max(freqs))
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# Print max frequencies
|
||||
print("\nMaximum Available Frequencies:")
|
||||
for cluster, max_freq in max_freqs.items():
|
||||
if max_freq > 0:
|
||||
print(f"{cluster}-Cluster Max: {max_freq:.0f} MHz")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error parsing powermetrics: {e}", flush=True)
|
||||
else:
|
||||
# Use psutil for other systems
|
||||
cpu_freq = psutil.cpu_freq()
|
||||
print(f"CPU Frequency - Current: {cpu_freq.current:.2f}MHz, Min: {cpu_freq.min:.2f}MHz, Max: {cpu_freq.max:.2f}MHz", flush=True)
|
||||
|
||||
print(f"\nCPU Usage per Core: {psutil.cpu_percent(percpu=True)}%", flush=True)
|
||||
|
||||
# Check if running in low power mode
|
||||
power_mode = subprocess.run(['pmset', '-g'], capture_output=True, text=True)
|
||||
print("\nPower Settings:", power_mode.stdout, flush=True)
|
||||
except Exception as e:
|
||||
print(f"Error getting CPU info: {e}", flush=True)
|
||||
|
||||
# Memory Info
|
||||
print("\nMemory Information:", flush=True)
|
||||
try:
|
||||
mem = psutil.virtual_memory()
|
||||
print(f"Total: {mem.total/1024/1024/1024:.2f}GB", flush=True)
|
||||
print(f"Available: {mem.available/1024/1024/1024:.2f}GB", flush=True)
|
||||
print(f"Used: {mem.used/1024/1024/1024:.2f}GB ({mem.percent}%)", flush=True)
|
||||
|
||||
# Check swap
|
||||
swap = psutil.swap_memory()
|
||||
print(f"Swap Used: {swap.used/1024/1024/1024:.2f}GB of {swap.total/1024/1024/1024:.2f}GB", flush=True)
|
||||
except Exception as e:
|
||||
print(f"Error getting memory info: {e}", flush=True)
|
||||
|
||||
# GPU Info
|
||||
print("\nGPU Information:", flush=True)
|
||||
try:
|
||||
# Check MLX GPU settings
|
||||
print("MLX Environment Variables:", flush=True)
|
||||
mlx_vars = {k: v for k, v in os.environ.items() if k.startswith('MLX')}
|
||||
print(json.dumps(mlx_vars, indent=2), flush=True)
|
||||
|
||||
# Check Metal GPU memory allocation
|
||||
gpu_mem = subprocess.run(['sysctl', 'iogpu'], capture_output=True, text=True)
|
||||
print("GPU Memory Settings:", gpu_mem.stdout, flush=True)
|
||||
except Exception as e:
|
||||
print(f"Error getting GPU info: {e}", flush=True)
|
||||
|
||||
# Process Priority
|
||||
print("\nProcess Priority Information:", flush=True)
|
||||
try:
|
||||
current_process = psutil.Process()
|
||||
print(f"Process Nice Value: {current_process.nice()}", flush=True)
|
||||
# Only try to get ionice if the platform supports it
|
||||
if hasattr(current_process, 'ionice'):
|
||||
print(f"Process IO Nice Value: {current_process.ionice()}", flush=True)
|
||||
except Exception as e:
|
||||
print(f"Error getting process priority info: {e}", flush=True)
|
||||
|
||||
# System Load
|
||||
print("\nSystem Load:", flush=True)
|
||||
try:
|
||||
load_avg = psutil.getloadavg()
|
||||
print(f"Load Average: {load_avg}", flush=True)
|
||||
|
||||
# Get top processes by CPU and Memory
|
||||
print("\nTop Processes by CPU Usage:", flush=True)
|
||||
processes = []
|
||||
for proc in psutil.process_iter(['pid', 'name', 'cpu_percent', 'memory_percent']):
|
||||
try:
|
||||
pinfo = proc.info
|
||||
if pinfo['cpu_percent'] is not None and pinfo['memory_percent'] is not None:
|
||||
processes.append(pinfo)
|
||||
except (psutil.NoSuchProcess, psutil.AccessDenied):
|
||||
continue
|
||||
|
||||
# Sort and display top 5 CPU-consuming processes
|
||||
sorted_by_cpu = sorted(processes, key=lambda x: x['cpu_percent'] or 0, reverse=True)[:5]
|
||||
for proc in sorted_by_cpu:
|
||||
print(f"PID: {proc['pid']}, Name: {proc['name']}, CPU: {proc['cpu_percent']}%, Memory: {proc['memory_percent']:.1f}%")
|
||||
except Exception as e:
|
||||
print(f"Error getting system load info: {e}", flush=True)
|
||||
|
||||
print("\n=== End System State Check ===\n", flush=True)
|
||||
|
||||
|
||||
def check_gpu_access():
|
||||
try:
|
||||
# Check if MLX can see the GPU
|
||||
import mlx.core as mx
|
||||
print("MLX device info:", mx.default_device())
|
||||
|
||||
# Check Metal device availability
|
||||
result = subprocess.run(['system_profiler', 'SPDisplaysDataType'], capture_output=True, text=True)
|
||||
print("GPU Info:", result.stdout)
|
||||
except Exception as e:
|
||||
print(f"Failed to check GPU access: {e}")
|
||||
|
||||
|
||||
async def measure_performance(api_endpoint: str, prompt: str, model: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Measures the performance of an API endpoint by sending a prompt and recording metrics.
|
||||
|
||||
Args:
|
||||
api_endpoint (str): The API endpoint URL.
|
||||
prompt (str): The prompt to send to the API.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: A dictionary containing performance metrics or error information.
|
||||
"""
|
||||
|
||||
results = {
|
||||
'model': model,
|
||||
'run_id': os.environ.get('GITHUB_RUN_ID', 'unknown'),
|
||||
'branch': os.environ.get('GITHUB_REF_NAME', 'unknown'),
|
||||
'commit': os.environ.get('GITHUB_SHA', 'unknown'),
|
||||
'configuration': json.loads(os.environ.get('HARDWARE_CONFIG', '{}'))
|
||||
}
|
||||
|
||||
# Get token count
|
||||
session = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=600, connect=10, sock_read=600, sock_connect=10))
|
||||
try:
|
||||
response = await session.post(
|
||||
"http://localhost:52415/v1/chat/token/encode",
|
||||
json={
|
||||
"model": model,
|
||||
"messages": [{"role": "user", "content": prompt}]
|
||||
}
|
||||
)
|
||||
response.raise_for_status()
|
||||
token_data = await response.json()
|
||||
results['prompt_len'] = token_data['num_tokens']
|
||||
except Exception as e:
|
||||
await session.close()
|
||||
raise RuntimeError(f"Failed to get token count: {str(e)}")
|
||||
|
||||
# Measure completion performance
|
||||
try:
|
||||
start_time = time.time()
|
||||
response = await session.post(
|
||||
api_endpoint,
|
||||
json={
|
||||
"model": model,
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"temperature": 0,
|
||||
"stream": True
|
||||
}
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
first_token_time = None
|
||||
total_tokens = 0
|
||||
|
||||
async for line in response.content.iter_chunks():
|
||||
line = line[0].decode('utf-8').strip()
|
||||
if not line.startswith('data: '):
|
||||
continue
|
||||
|
||||
data = json.loads(line[6:]) # Skip 'data: ' prefix
|
||||
if content := data.get('choices', [{}])[0].get('delta', {}).get('content'):
|
||||
print(f"Received content: {content}", flush=True)
|
||||
if first_token_time is None:
|
||||
first_token_time = time.time()
|
||||
ttft = first_token_time - start_time
|
||||
results.update({
|
||||
'ttft': ttft,
|
||||
'prompt_tps': results['prompt_len'] / ttft
|
||||
})
|
||||
total_tokens += 1
|
||||
|
||||
total_time = time.time() - start_time
|
||||
results.update({
|
||||
'generation_tps': total_tokens / total_time,
|
||||
'response_len': total_tokens,
|
||||
'total_time': total_time
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Performance measurement failed: {str(e)}")
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
api_endpoint = "http://localhost:52415/v1/chat/completions"
|
||||
|
||||
# Define prompts
|
||||
prompt_warmup = "what is the capital of France?"
|
||||
prompt_essay = "write an essay about cats"
|
||||
|
||||
model = os.environ.get('model', 'llama-3.2-1b')
|
||||
# Warmup request
|
||||
print("\nPerforming warmup request...", flush=True)
|
||||
try:
|
||||
warmup_results = await measure_performance(api_endpoint, prompt_warmup, model)
|
||||
print("Warmup completed successfully", flush=True)
|
||||
except Exception as e:
|
||||
print(f"Warmup request failed: {e}", flush=True)
|
||||
|
||||
# Measure performance for the essay prompt
|
||||
print("\nMeasuring performance for the essay prompt...", flush=True)
|
||||
results = await measure_performance(api_endpoint, prompt_essay, model)
|
||||
|
||||
try:
|
||||
s3_client = boto3.client(
|
||||
's3',
|
||||
aws_access_key_id=os.environ.get('aws_access_key_id'),
|
||||
aws_secret_access_key=os.environ.get('aws_secret_key')
|
||||
)
|
||||
job_name = os.environ.get('GITHUB_JOB')
|
||||
|
||||
# Create S3 key with timestamp and commit info
|
||||
now = datetime.utcnow()
|
||||
timestamp = now.strftime('%H-%M-%S')
|
||||
commit_sha = os.environ.get('GITHUB_SHA', 'unknown')[:7]
|
||||
s3_key = f"{job_name}/{model}/{now.year}/{now.month}/{now.day}/{timestamp}_{commit_sha}.json"
|
||||
|
||||
# Upload to S3
|
||||
s3_client.put_object(
|
||||
Bucket='exo-benchmarks',
|
||||
Key=s3_key,
|
||||
Body=json.dumps(results),
|
||||
ContentType='application/json'
|
||||
)
|
||||
print(f"Performance metrics uploaded to S3: s3://exo-benchmarks/{s3_key}", flush=True)
|
||||
except Exception as e:
|
||||
print(f"Failed to upload metrics to S3: {e}", flush=True)
|
||||
|
||||
# Optionally print the metrics for visibility
|
||||
print("Performance metrics:", flush=True)
|
||||
print(json.dumps(results, indent=4), flush=True)
|
||||
|
||||
|
||||
def optimize_system_performance():
|
||||
"""Set optimal system performance settings before running benchmark."""
|
||||
try:
|
||||
# Try to set high performance power mode
|
||||
subprocess.run(['sudo', 'pmset', '-a', 'powermode', '2'], check=False)
|
||||
|
||||
# Ensure MLX uses performance cores and GPU
|
||||
os.environ['MLX_FORCE_P_CORES'] = '1'
|
||||
os.environ['MLX_METAL_PREWARM'] = '1'
|
||||
os.environ['MLX_USE_GPU'] = '1'
|
||||
|
||||
# Set process priority
|
||||
current_process = psutil.Process()
|
||||
try:
|
||||
# Set highest priority
|
||||
subprocess.run(['sudo', 'renice', '-n', '-20', '-p', str(current_process.pid)], check=False)
|
||||
|
||||
# Print current process state
|
||||
print("\nProcess State Before Benchmark:", flush=True)
|
||||
proc_info = subprocess.run(
|
||||
['ps', '-o', 'pid,ppid,user,%cpu,%mem,nice,stat,pri,command', '-p', str(current_process.pid)],
|
||||
capture_output=True, text=True
|
||||
)
|
||||
print(proc_info.stdout, flush=True)
|
||||
|
||||
# Verify power mode
|
||||
power_info = subprocess.run(['pmset', '-g'], capture_output=True, text=True)
|
||||
if 'powermode 0' in power_info.stdout:
|
||||
print("\nWarning: System still in normal power mode. Trying to set high performance mode again...", flush=True)
|
||||
subprocess.run(['sudo', 'pmset', '-a', 'powermode', '2'], check=False)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not set process priority: {e}", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not optimize system performance: {e}", flush=True)
|
||||
|
||||
# Print optimization status
|
||||
print("\nOptimization Settings:", flush=True)
|
||||
print("MLX Environment Variables:", flush=True)
|
||||
for var in ['MLX_FORCE_P_CORES', 'MLX_METAL_PREWARM', 'MLX_USE_GPU']:
|
||||
print(f"{var}: {os.environ.get(var, 'Not set')}", flush=True)
|
||||
|
||||
try:
|
||||
nice_value = psutil.Process().nice()
|
||||
print(f"Process Nice Value: {nice_value}", flush=True)
|
||||
if nice_value != -20:
|
||||
print("Warning: Process not running at highest priority", flush=True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
check_system_state()
|
||||
check_gpu_access()
|
||||
optimize_system_performance()
|
||||
asyncio.run(main())
|
||||
330
.github/bootstrap.sh
vendored
330
.github/bootstrap.sh
vendored
@@ -1,330 +0,0 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
command_exists() {
|
||||
command -v "$1" >/dev/null 2>&1
|
||||
}
|
||||
|
||||
log() {
|
||||
echo "[$(date '+%Y-%m-%d %H:%M:%S')] $1"
|
||||
}
|
||||
|
||||
if [ "$EUID" -eq 0 ]; then
|
||||
log "Please do not run as root. Run as regular user with sudo access."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Check for required arguments
|
||||
if [ -z "$1" ]; then
|
||||
log "Error: Runner token is required"
|
||||
log "Usage: $0 <runner-token> [tailscale-auth-key]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
RUNNER_TOKEN=$1
|
||||
TAILSCALE_AUTH_KEY=$2
|
||||
REPO="exo-explore/exo"
|
||||
|
||||
# Add sudoers configuration
|
||||
log "Configuring sudo access..."
|
||||
SUDOERS_CONTENT="$(whoami) ALL=(ALL) NOPASSWD: ALL"
|
||||
echo "$SUDOERS_CONTENT" | sudo tee /etc/sudoers.d/github-runner > /dev/null
|
||||
sudo chmod 440 /etc/sudoers.d/github-runner
|
||||
|
||||
log "Configuring privacy permissions..."
|
||||
sudo tccutil reset All
|
||||
sudo tccutil reset SystemPolicyAllFiles
|
||||
sudo tccutil reset SystemPolicyNetworkVolumes
|
||||
|
||||
# Configure power management for maximum performance
|
||||
log "Configuring power management..."
|
||||
sudo pmset -a powermode 2 # Force highest performance mode
|
||||
sudo pmset -a gpuswitch 2 # Force discrete/high-performance GPU
|
||||
sudo pmset -a lowpowermode 0
|
||||
sudo pmset -a lessbright 0
|
||||
sudo pmset -a disablesleep 1
|
||||
sudo pmset -a sleep 0
|
||||
sudo pmset -a hibernatemode 0
|
||||
sudo pmset -a autopoweroff 0
|
||||
sudo pmset -a standby 0
|
||||
sudo pmset -a powernap 0
|
||||
|
||||
# For Python specifically
|
||||
PYTHON_PATH="/opt/homebrew/bin/python3.12"
|
||||
sudo chmod 755 "$PYTHON_PATH"
|
||||
|
||||
# Add to firewall
|
||||
log "Configuring firewall access..."
|
||||
sudo /usr/libexec/ApplicationFirewall/socketfilterfw --add "$PYTHON_PATH"
|
||||
sudo /usr/libexec/ApplicationFirewall/socketfilterfw --unblock "$PYTHON_PATH"
|
||||
|
||||
# Set Homebrew paths based on architecture
|
||||
if [ "$(uname -p)" = "arm" ]; then
|
||||
BREW_PREFIX="/opt/homebrew"
|
||||
else
|
||||
BREW_PREFIX="/usr/local"
|
||||
fi
|
||||
|
||||
# Install Homebrew if not present
|
||||
if ! command_exists brew; then
|
||||
log "Installing Homebrew..."
|
||||
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
|
||||
echo 'eval "$(/opt/homebrew/bin/brew shellenv)"' >> ~/.zshrc
|
||||
eval "$(/opt/homebrew/bin/brew shellenv)"
|
||||
fi
|
||||
|
||||
# Install required packages
|
||||
log "Installing required packages..."
|
||||
export HOMEBREW_NO_AUTO_UPDATE=1
|
||||
brew install python@3.12 coreutils
|
||||
|
||||
# Optional Tailscale setup if auth key is provided
|
||||
if [ -n "$TAILSCALE_AUTH_KEY" ]; then
|
||||
log "Installing and configuring Tailscale..."
|
||||
brew install --quiet tailscale
|
||||
sudo brew services stop tailscale 2>/dev/null || true
|
||||
sudo rm -f /var/db/tailscale/tailscaled.state 2>/dev/null || true
|
||||
sudo brew services start tailscale
|
||||
sleep 2
|
||||
sudo tailscale up --authkey=$TAILSCALE_AUTH_KEY
|
||||
|
||||
# Enable SSH and Screen Sharing
|
||||
log "Enabling remote access services..."
|
||||
sudo launchctl load -w /System/Library/LaunchDaemons/ssh.plist
|
||||
sudo /System/Library/CoreServices/RemoteManagement/ARDAgent.app/Contents/Resources/kickstart \
|
||||
-activate \
|
||||
-configure -access -on \
|
||||
-configure -allowAccessFor -allUsers \
|
||||
-configure -restart -agent -privs -all
|
||||
|
||||
# Create launch daemon for remote access
|
||||
sudo bash -c 'cat > /Library/LaunchDaemons/com.remote.access.setup.plist' << 'EOL'
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
||||
<plist version="1.0">
|
||||
<dict>
|
||||
<key>Label</key>
|
||||
<string>com.remote.access.setup</string>
|
||||
<key>ProgramArguments</key>
|
||||
<array>
|
||||
<string>/bin/bash</string>
|
||||
<string>-c</string>
|
||||
<string>
|
||||
launchctl load -w /System/Library/LaunchDaemons/ssh.plist;
|
||||
/System/Library/CoreServices/RemoteManagement/ARDAgent.app/Contents/Resources/kickstart -activate -configure -access -on
|
||||
</string>
|
||||
</array>
|
||||
<key>RunAtLoad</key>
|
||||
<true/>
|
||||
</dict>
|
||||
</plist>
|
||||
EOL
|
||||
|
||||
sudo chmod 644 /Library/LaunchDaemons/com.remote.access.setup.plist
|
||||
sudo launchctl load -w /Library/LaunchDaemons/com.remote.access.setup.plist
|
||||
fi
|
||||
|
||||
# Configure GitHub Actions Runner
|
||||
log "Gathering system metadata..."
|
||||
MACHINE_NAME=$(scutil --get ComputerName)
|
||||
MACHINE_NAME="runner-$(echo -n "$MACHINE_NAME" | tr '[:upper:]' '[:lower:]' | tr -cd '[:alnum:]-')"
|
||||
|
||||
# Enhanced Apple Silicon detection
|
||||
MACHINE_INFO=$(system_profiler SPHardwareDataType)
|
||||
CHIP_FULL=$(echo "$MACHINE_INFO" | grep "Chip" | cut -d: -f2 | xargs)
|
||||
if [[ $CHIP_FULL =~ "Apple" ]]; then
|
||||
CHIP_MODEL=$(echo "$CHIP_FULL" | sed 's/^Apple //' | tr -d ' ' | tr '[:lower:]' '[:upper:]')
|
||||
GPU_CORES=$(ioreg -l | grep "gpu-core-count" | awk -F'= ' '{print $2}')
|
||||
if [ -z "$GPU_CORES" ]; then
|
||||
GPU_CORES="N/A"
|
||||
fi
|
||||
else
|
||||
CHIP_MODEL="Intel"
|
||||
GPU_CORES="N/A"
|
||||
fi
|
||||
|
||||
MEMORY=$(($(sysctl -n hw.memsize) / 1024 / 1024 / 1024))
|
||||
|
||||
# Set up GitHub Runner
|
||||
RUNNER_DIR="$HOME/actions-runner"
|
||||
|
||||
# Check if runner is already configured
|
||||
if [ -f "$RUNNER_DIR/.runner" ]; then
|
||||
log "Runner already configured. Stopping existing service..."
|
||||
sudo launchctl unload /Library/LaunchDaemons/com.github.runner.plist 2>/dev/null || true
|
||||
fi
|
||||
|
||||
# Create runner directory if it doesn't exist
|
||||
mkdir -p "$RUNNER_DIR"
|
||||
cd "$RUNNER_DIR"
|
||||
|
||||
CUSTOM_LABELS="self-hosted,macos,arm64,${CHIP_MODEL}_GPU${GPU_CORES}_${MEMORY}GB"
|
||||
|
||||
# Only download and extract if not already present or if forced
|
||||
if [ ! -f "$RUNNER_DIR/run.sh" ] || [ "${FORCE_SETUP:-false}" = "true" ]; then
|
||||
log "Downloading GitHub Actions runner..."
|
||||
RUNNER_VERSION=$(curl -s https://api.github.com/repos/actions/runner/releases/latest | grep '"tag_name":' | cut -d'"' -f4)
|
||||
curl -o actions-runner.tar.gz -L "https://github.com/actions/runner/releases/download/${RUNNER_VERSION}/actions-runner-osx-arm64-${RUNNER_VERSION#v}.tar.gz"
|
||||
tar xzf actions-runner.tar.gz
|
||||
rm actions-runner.tar.gz
|
||||
else
|
||||
log "Runner already downloaded, skipping download step"
|
||||
fi
|
||||
|
||||
log "Configuring runner with labels: $CUSTOM_LABELS"
|
||||
./config.sh --unattended \
|
||||
--url "https://github.com/${REPO}" \
|
||||
--token "${RUNNER_TOKEN}" \
|
||||
--name "${MACHINE_NAME}" \
|
||||
--labels "${CUSTOM_LABELS}" \
|
||||
--work "_work"
|
||||
|
||||
# Set optimal performance settings
|
||||
log "Configuring system for optimal performance..."
|
||||
|
||||
# Configure CPU performance
|
||||
log "Setting CPU performance controls..."
|
||||
# Disable timer coalescing
|
||||
sudo sysctl -w kern.timer.coalescing_enabled=0
|
||||
sudo sysctl -w kern.timer_coalesce_bg_scale=-5
|
||||
sudo sysctl -w kern.timer_resort_threshold_ns=0
|
||||
# Set minimum timer intervals
|
||||
sudo sysctl -w kern.wq_max_timer_interval_usecs=1000
|
||||
sudo sysctl -w kern.timer_coalesce_bg_ns_max=1000
|
||||
# Set minimum timer coalescing for all tiers
|
||||
sudo sysctl -w kern.timer_coalesce_tier0_scale=-5
|
||||
sudo sysctl -w kern.timer_coalesce_tier0_ns_max=1000
|
||||
sudo sysctl -w kern.timer_coalesce_tier1_scale=-5
|
||||
sudo sysctl -w kern.timer_coalesce_tier1_ns_max=1000
|
||||
sudo sysctl -w kern.timer_coalesce_tier2_scale=-5
|
||||
sudo sysctl -w kern.timer_coalesce_tier2_ns_max=1000
|
||||
sudo sysctl -w kern.timer_coalesce_tier3_scale=-5
|
||||
sudo sysctl -w kern.timer_coalesce_tier3_ns_max=1000
|
||||
sudo sysctl -w kern.timer_coalesce_tier4_scale=-5
|
||||
sudo sysctl -w kern.timer_coalesce_tier4_ns_max=1000
|
||||
# Disable QoS restrictions
|
||||
sudo sysctl -w net.qos.policy.restricted=0
|
||||
sudo sysctl -w net.qos.policy.restrict_avapps=0
|
||||
sudo sysctl -w net.qos.policy.wifi_enabled=0
|
||||
sudo sysctl -w net.qos.policy.capable_enabled=0
|
||||
# Set scheduler parameters
|
||||
sudo sysctl -w kern.sched_rt_avoid_cpu0=0
|
||||
sudo sysctl -w debug.sched=2
|
||||
sudo sysctl -w net.pktsched.netem.sched_output_ival_ms=1
|
||||
|
||||
# Clean up any existing runner services
|
||||
log "Cleaning up existing runner services..."
|
||||
for service in com.github.runner com.github.runner.monitor com.github.runner.cpuaffinity com.github.runner.affinity; do
|
||||
sudo launchctl bootout system/$service 2>/dev/null || true
|
||||
sudo rm -f /Library/LaunchDaemons/$service.plist
|
||||
done
|
||||
|
||||
# Create a simple runner service configuration
|
||||
sudo tee /Library/LaunchDaemons/com.github.runner.plist > /dev/null << EOF
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
||||
<plist version="1.0">
|
||||
<dict>
|
||||
<key>Label</key>
|
||||
<string>com.github.runner</string>
|
||||
<key>UserName</key>
|
||||
<string>$(whoami)</string>
|
||||
<key>GroupName</key>
|
||||
<string>staff</string>
|
||||
<key>WorkingDirectory</key>
|
||||
<string>$RUNNER_DIR</string>
|
||||
<key>ProgramArguments</key>
|
||||
<array>
|
||||
<string>$RUNNER_DIR/run.sh</string>
|
||||
</array>
|
||||
<key>RunAtLoad</key>
|
||||
<true/>
|
||||
<key>KeepAlive</key>
|
||||
<dict>
|
||||
<key>SuccessfulExit</key>
|
||||
<false/>
|
||||
<key>Crashed</key>
|
||||
<true/>
|
||||
</dict>
|
||||
<key>ProcessType</key>
|
||||
<string>Interactive</string>
|
||||
<key>LowPriorityIO</key>
|
||||
<false/>
|
||||
<key>AbandonProcessGroup</key>
|
||||
<false/>
|
||||
<key>EnableTransactions</key>
|
||||
<true/>
|
||||
<key>ThrottleInterval</key>
|
||||
<integer>0</integer>
|
||||
<key>HardResourceLimits</key>
|
||||
<dict>
|
||||
<key>NumberOfFiles</key>
|
||||
<integer>524288</integer>
|
||||
<key>MemoryLock</key>
|
||||
<integer>-1</integer>
|
||||
</dict>
|
||||
<key>SoftResourceLimits</key>
|
||||
<dict>
|
||||
<key>NumberOfFiles</key>
|
||||
<integer>524288</integer>
|
||||
<key>MemoryLock</key>
|
||||
<integer>-1</integer>
|
||||
</dict>
|
||||
<key>QOSClass</key>
|
||||
<string>User-Interactive</string>
|
||||
<key>StandardOutPath</key>
|
||||
<string>$RUNNER_DIR/_diag/runner.log</string>
|
||||
<key>StandardErrorPath</key>
|
||||
<string>$RUNNER_DIR/_diag/runner.err</string>
|
||||
<key>EnvironmentVariables</key>
|
||||
<dict>
|
||||
<key>PATH</key>
|
||||
<string>/usr/local/bin:/opt/homebrew/bin:/usr/bin:/bin:/usr/sbin:/sbin</string>
|
||||
</dict>
|
||||
<key>Nice</key>
|
||||
<integer>-20</integer>
|
||||
</dict>
|
||||
</plist>
|
||||
EOF
|
||||
|
||||
# Set proper permissions for the LaunchDaemon
|
||||
sudo chown root:wheel /Library/LaunchDaemons/com.github.runner.plist
|
||||
sudo chmod 644 /Library/LaunchDaemons/com.github.runner.plist
|
||||
|
||||
# Remove any existing service
|
||||
sudo launchctl bootout system/com.github.runner 2>/dev/null || true
|
||||
|
||||
# Load the new service using bootstrap
|
||||
sudo launchctl bootstrap system /Library/LaunchDaemons/com.github.runner.plist
|
||||
|
||||
# Add Runner.Listener permissions (after runner installation)
|
||||
RUNNER_PATH="$RUNNER_DIR/bin/Runner.Listener"
|
||||
sudo chmod 755 "$RUNNER_PATH"
|
||||
sudo /usr/libexec/ApplicationFirewall/socketfilterfw --add "$RUNNER_PATH"
|
||||
sudo /usr/libexec/ApplicationFirewall/socketfilterfw --unblock "$RUNNER_PATH"
|
||||
|
||||
# Create connection info file if Tailscale is configured
|
||||
if [ -n "$TAILSCALE_AUTH_KEY" ]; then
|
||||
TAILSCALE_IP=$(tailscale ip)
|
||||
cat > "$HOME/remote_access_info.txt" << EOL
|
||||
Mac Remote Access Information
|
||||
============================
|
||||
Computer Name: $MACHINE_NAME
|
||||
Username: $USER
|
||||
Tailscale IP: $TAILSCALE_IP
|
||||
|
||||
SSH Command: ssh $USER@$TAILSCALE_IP
|
||||
Screen Sharing: vnc://$TAILSCALE_IP
|
||||
EOL
|
||||
chmod 600 "$HOME/remote_access_info.txt"
|
||||
fi
|
||||
|
||||
log "Verifying runner service status..."
|
||||
if sudo launchctl list | grep com.github.runner > /dev/null; then
|
||||
log "GitHub Actions runner service is running successfully!"
|
||||
log "Runner labels: $CUSTOM_LABELS"
|
||||
[ -n "$TAILSCALE_AUTH_KEY" ] && log "Remote access details saved to: $HOME/remote_access_info.txt"
|
||||
else
|
||||
log "Error: Failed to start GitHub Actions runner service"
|
||||
exit 1
|
||||
fi
|
||||
95
.github/optimize_performance.sh
vendored
95
.github/optimize_performance.sh
vendored
@@ -1,95 +0,0 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
# Function to log with timestamp
|
||||
log() {
|
||||
echo "[$(date '+%Y-%m-%d %H:%M:%S')] $1"
|
||||
}
|
||||
|
||||
log "Applying comprehensive performance optimizations..."
|
||||
|
||||
# System-wide power management
|
||||
log "Configuring power management..."
|
||||
sudo pmset -a lessbright 0
|
||||
sudo pmset -a disablesleep 1
|
||||
sudo pmset -a sleep 0
|
||||
sudo pmset -a hibernatemode 0
|
||||
sudo pmset -a autopoweroff 0
|
||||
sudo pmset -a standby 0
|
||||
sudo pmset -a powernap 0
|
||||
sudo pmset -a proximitywake 0
|
||||
sudo pmset -a tcpkeepalive 1
|
||||
sudo pmset -a powermode 2
|
||||
sudo pmset -a gpuswitch 2
|
||||
sudo pmset -a displaysleep 0
|
||||
sudo pmset -a disksleep 0
|
||||
|
||||
# Memory and kernel optimizations
|
||||
log "Configuring memory and kernel settings..."
|
||||
sudo sysctl -w kern.memorystatus_purge_on_warning=0
|
||||
sudo sysctl -w kern.memorystatus_purge_on_critical=0
|
||||
sudo sysctl -w kern.timer.coalescing_enabled=0
|
||||
|
||||
# Metal and GPU optimizations
|
||||
log "Configuring Metal and GPU settings..."
|
||||
defaults write com.apple.CoreML MPSEnableGPUValidation -bool false
|
||||
defaults write com.apple.CoreML MPSEnableMetalValidation -bool false
|
||||
defaults write com.apple.CoreML MPSEnableGPUDebug -bool false
|
||||
defaults write com.apple.Metal GPUDebug -bool false
|
||||
defaults write com.apple.Metal GPUValidation -bool false
|
||||
defaults write com.apple.Metal MetalValidation -bool false
|
||||
defaults write com.apple.Metal MetalCaptureEnabled -bool false
|
||||
defaults write com.apple.Metal MTLValidationBehavior -string "Disabled"
|
||||
defaults write com.apple.Metal EnableMTLDebugLayer -bool false
|
||||
defaults write com.apple.Metal MTLDebugLevel -int 0
|
||||
defaults write com.apple.Metal PreferIntegratedGPU -bool false
|
||||
defaults write com.apple.Metal ForceMaximumPerformance -bool true
|
||||
defaults write com.apple.Metal MTLPreferredDeviceGPUFrame -bool true
|
||||
|
||||
# Create MPS cache directory with proper permissions
|
||||
sudo mkdir -p /tmp/mps_cache
|
||||
sudo chmod 777 /tmp/mps_cache
|
||||
|
||||
# Process and resource limits
|
||||
log "Configuring process limits..."
|
||||
sudo launchctl limit maxfiles 524288 524288
|
||||
ulimit -n 524288 || log "Warning: Could not set file descriptor limit"
|
||||
ulimit -c 0
|
||||
ulimit -l unlimited || log "Warning: Could not set memory lock limit"
|
||||
|
||||
# Export performance-related environment variables
|
||||
cat << 'EOF' > /tmp/performance_env.sh
|
||||
# Metal optimizations
|
||||
export MTL_DEBUG_LAYER=0
|
||||
export METAL_DEVICE_WRAPPER_TYPE=1
|
||||
export METAL_DEBUG_ERROR_MODE=0
|
||||
export METAL_FORCE_PERFORMANCE_MODE=1
|
||||
export METAL_DEVICE_PRIORITY=high
|
||||
export METAL_MAX_COMMAND_QUEUES=1024
|
||||
export METAL_LOAD_LIMIT=0
|
||||
export METAL_VALIDATION_ENABLED=0
|
||||
export METAL_ENABLE_VALIDATION_LAYER=0
|
||||
export OBJC_DEBUG_MISSING_POOLS=NO
|
||||
export MPS_CACHEDIR=/tmp/mps_cache
|
||||
|
||||
# MLX optimizations
|
||||
export MLX_USE_GPU=1
|
||||
export MLX_METAL_COMPILE_ASYNC=1
|
||||
export MLX_METAL_PREALLOCATE=1
|
||||
export MLX_METAL_MEMORY_GUARD=0
|
||||
export MLX_METAL_CACHE_KERNELS=1
|
||||
export MLX_PLACEMENT_POLICY=metal
|
||||
export MLX_METAL_VALIDATION=0
|
||||
export MLX_METAL_DEBUG=0
|
||||
export MLX_FORCE_P_CORES=1
|
||||
export MLX_METAL_MEMORY_BUDGET=0
|
||||
export MLX_METAL_PREWARM=1
|
||||
|
||||
# Python optimizations
|
||||
export PYTHONUNBUFFERED=1
|
||||
export PYTHONOPTIMIZE=2
|
||||
export PYTHONHASHSEED=0
|
||||
export PYTHONDONTWRITEBYTECODE=1
|
||||
EOF
|
||||
|
||||
log "Performance optimizations completed. Environment variables written to /tmp/performance_env.sh"
|
||||
207
.github/workflows/bench_job.yml
vendored
207
.github/workflows/bench_job.yml
vendored
@@ -1,207 +0,0 @@
|
||||
# This is the reusable workflow file
|
||||
name: Distributed Job Runner
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
config:
|
||||
required: true
|
||||
type: string
|
||||
model:
|
||||
required: true
|
||||
type: string
|
||||
calling_job_name:
|
||||
required: true
|
||||
type: string
|
||||
network_interface:
|
||||
required: true
|
||||
type: string
|
||||
jobs:
|
||||
generate-matrix:
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
matrix: ${{ steps.set-matrix.outputs.matrix }}
|
||||
steps:
|
||||
- id: set-matrix
|
||||
env:
|
||||
CONFIG: ${{ inputs.config }}
|
||||
run: |
|
||||
MATRIX=$(echo $CONFIG | jq -c '{cpu: [to_entries | .[] | .key as $k | range(.value) | $k]}')
|
||||
echo "matrix=$MATRIX" >> $GITHUB_OUTPUT
|
||||
|
||||
run-distributed-job:
|
||||
needs: generate-matrix
|
||||
strategy:
|
||||
matrix: ${{fromJson(needs.generate-matrix.outputs.matrix)}}
|
||||
runs-on: ['self-hosted', 'macOS', '${{ matrix.cpu }}']
|
||||
env:
|
||||
HARDWARE_CONFIG: ${{ inputs.config }}
|
||||
model: ${{ inputs.model }}
|
||||
# Add performance-related environment variables
|
||||
MTL_DEBUG_LAYER: 0
|
||||
METAL_VALIDATION_ENABLED: 0
|
||||
MLX_METAL_VALIDATION: 0
|
||||
MLX_METAL_DEBUG: 0
|
||||
MLX_FORCE_P_CORES: 1
|
||||
MLX_METAL_PREWARM: 1
|
||||
PYTHONOPTIMIZE: 2
|
||||
steps:
|
||||
- name: Cleanup workspace
|
||||
run: |
|
||||
sudo rm -rf "$GITHUB_WORKSPACE"
|
||||
sudo mkdir -p "$GITHUB_WORKSPACE"
|
||||
sudo chown -R $(whoami):$(id -g) "$GITHUB_WORKSPACE"
|
||||
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
export PATH="/usr/local/bin:/opt/homebrew/bin:$PATH"
|
||||
python3.12 -m venv .venv || {
|
||||
echo "Failed to find python3.12. Checking installation locations:"
|
||||
ls -l /usr/local/bin/python* /opt/homebrew/bin/python* 2>/dev/null || true
|
||||
exit 1
|
||||
}
|
||||
source .venv/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install -e .
|
||||
pip install boto3==1.35.76
|
||||
|
||||
- name: Apply Performance Optimizations
|
||||
run: |
|
||||
# Export performance-related environment variables
|
||||
cat << 'EOF' > /tmp/performance_env.sh
|
||||
# MLX and Metal optimizations
|
||||
export MTL_DEBUG_LAYER=0
|
||||
export METAL_VALIDATION_ENABLED=0
|
||||
export MLX_METAL_VALIDATION=0
|
||||
export MLX_METAL_DEBUG=0
|
||||
export MLX_FORCE_P_CORES=1
|
||||
export MLX_METAL_PREWARM=1
|
||||
export PYTHONOPTIMIZE=2
|
||||
EOF
|
||||
|
||||
# Source the performance environment variables
|
||||
source /tmp/performance_env.sh
|
||||
|
||||
# MLX Memory Settings
|
||||
./configure_mlx.sh
|
||||
|
||||
# Verify optimizations
|
||||
echo "Verifying performance settings..."
|
||||
env | grep -E "MLX_|METAL_|MTL_"
|
||||
|
||||
- name: Run exo
|
||||
env:
|
||||
aws_access_key_id: ${{ secrets.S3_EXO_BENCHMARKS_AWS_ACCESS_KEY_ID }}
|
||||
aws_secret_key: ${{ secrets.S3_EXO_BENCHMARKS_AWS_SECRET_ACCESS_KEY }}
|
||||
run: |
|
||||
# Source performance environment variables
|
||||
source /tmp/performance_env.sh
|
||||
|
||||
# Debug information
|
||||
echo "Current commit SHA: $GITHUB_SHA"
|
||||
git rev-parse HEAD
|
||||
git status
|
||||
|
||||
CALLING_JOB="${{ inputs.calling_job_name }}"
|
||||
UNIQUE_JOB_ID="${CALLING_JOB}_${model}_${GITHUB_RUN_ID}"
|
||||
ALL_NODE_IDS=$(for i in $(seq ${{ strategy.job-total }} -1 0); do echo -n "${UNIQUE_JOB_ID}_${i},"; done | sed 's/,$//')
|
||||
MY_NODE_ID="${UNIQUE_JOB_ID}_${{ strategy.job-index }}"
|
||||
|
||||
source .venv/bin/activate
|
||||
export PATH="/usr/local/bin:/opt/homebrew/bin:$PATH"
|
||||
|
||||
echo "=== Before starting exo ==="
|
||||
ps -eo pid,ppid,user,%cpu,%mem,nice,state,pri,command | head -1
|
||||
ps -eo pid,ppid,user,%cpu,%mem,nice,state,pri,command | grep -i python
|
||||
|
||||
echo "Starting exo daemon..."
|
||||
|
||||
echo "Power mode settings:"
|
||||
sudo pmset -g
|
||||
|
||||
# Start exo with explicit process control
|
||||
sudo taskpolicy -d default -g default -a -t 0 -l 0 .venv/bin/exo \
|
||||
--node-id="${MY_NODE_ID}" \
|
||||
--node-id-filter="${ALL_NODE_IDS}" \
|
||||
--interface-type-filter="${{ inputs.network_interface }}" \
|
||||
--disable-tui \
|
||||
--max-generate-tokens 250 \
|
||||
--chatgpt-api-response-timeout 900 \
|
||||
--chatgpt-api-port 52415 > output1.log 2>&1 &
|
||||
PID1=$!
|
||||
|
||||
echo "Exo process started with PID: $PID1"
|
||||
tail -f output1.log &
|
||||
TAIL1=$!
|
||||
|
||||
# Give process time to start
|
||||
sleep 2
|
||||
|
||||
# Set additional process priorities
|
||||
sudo renice -n -20 -p $PID1
|
||||
sudo taskpolicy -t 4 -p $PID1
|
||||
|
||||
echo "=== After starting exo ==="
|
||||
ps -eo pid,ppid,user,%cpu,%mem,nice,state,pri,command | head -1
|
||||
ps -eo pid,ppid,user,%cpu,%mem,nice,state,pri,command | grep $PID1
|
||||
|
||||
echo "Additional process details:"
|
||||
sudo powermetrics -n 1 -i 1000 --show-process-energy | grep -A 5 $PID1 || true
|
||||
|
||||
trap 'kill $TAIL1' EXIT
|
||||
trap 'kill $PID1' EXIT
|
||||
|
||||
echo "Waiting for all nodes to connect..."
|
||||
for i in {1..20}; do
|
||||
echo "Attempt $i: Checking node count..."
|
||||
nodes=$(curl -s http://localhost:52415/topology | jq ".nodes | length")
|
||||
echo "Current node count: $nodes"
|
||||
if [ "$nodes" -eq "${{ strategy.job-total }}" ]; then
|
||||
echo "All nodes connected successfully!"
|
||||
break
|
||||
fi
|
||||
if [ $i -eq 20 ]; then
|
||||
echo "ERROR: Failed to connect all nodes after 20 attempts. Expected ${{ strategy.job-total }} nodes, but got $nodes"
|
||||
exit 1
|
||||
fi
|
||||
sleep 5
|
||||
done
|
||||
|
||||
if ! kill -0 $PID1 2>/dev/null; then
|
||||
echo "ERROR: Instance (PID $PID1) died unexpectedly. Full log output:"
|
||||
cat output1.log
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ "${{ strategy.job-index }}" -eq "0" ]; then
|
||||
sleep 10
|
||||
echo "This is the primary node (index 0). Running benchmark..."
|
||||
GITHUB_JOB=$CALLING_JOB python .github/bench.py
|
||||
else
|
||||
echo "This is a secondary node (index ${{ strategy.job-index }}). Waiting for completion..."
|
||||
sleep 10
|
||||
while true; do
|
||||
echo "Checking if primary node is still running..."
|
||||
nodes=$(curl -s http://localhost:52415/topology | jq ".nodes | length")
|
||||
echo "Current node count: $nodes"
|
||||
if [ "$nodes" -lt "${{ strategy.job-total }}" ]; then
|
||||
echo "Primary node completed, exiting..."
|
||||
break
|
||||
fi
|
||||
sleep 5
|
||||
done
|
||||
fi
|
||||
|
||||
- name: Check Final System State
|
||||
if: always()
|
||||
run: |
|
||||
echo "=== Final System State ==="
|
||||
sudo pmset -g
|
||||
sudo powermetrics -n 1 -i 1000 --show-process-energy || true
|
||||
system_profiler SPDisplaysDataType
|
||||
sysctl iogpu
|
||||
ps -eo pid,ppid,user,%cpu,%mem,nice,state,command | grep -i python
|
||||
env | grep -E "MLX_|METAL_|MTL_"
|
||||
echo "=== End Final System State ==="
|
||||
71
.github/workflows/benchmarks.yml
vendored
71
.github/workflows/benchmarks.yml
vendored
@@ -1,71 +0,0 @@
|
||||
name: Build and Test
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ '*' ]
|
||||
tags: [ '*' ]
|
||||
pull_request:
|
||||
branches: [ '*' ]
|
||||
|
||||
jobs:
|
||||
single-m4-pro:
|
||||
strategy:
|
||||
matrix:
|
||||
model: ['llama-3.2-1b', 'llama-3.2-3b', 'llama-3.1-8b']
|
||||
uses: ./.github/workflows/bench_job.yml
|
||||
with:
|
||||
config: '{"M4PRO_GPU16_24GB": 1}'
|
||||
model: ${{ matrix.model }}
|
||||
calling_job_name: 'single-m4-pro'
|
||||
network_interface: 'Ethernet'
|
||||
secrets: inherit
|
||||
|
||||
two-m4-pro-cluster:
|
||||
strategy:
|
||||
matrix:
|
||||
model: ['llama-3.2-1b', 'llama-3.2-3b', 'llama-3.1-8b']
|
||||
uses: ./.github/workflows/bench_job.yml
|
||||
with:
|
||||
config: '{"M4PRO_GPU16_24GB": 2}'
|
||||
model: ${{ matrix.model }}
|
||||
calling_job_name: 'two-m4-pro-cluster'
|
||||
network_interface: 'Ethernet'
|
||||
secrets: inherit
|
||||
|
||||
# two-m4-pro-cluster-thunderbolt:
|
||||
# strategy:
|
||||
# matrix:
|
||||
# model: ['llama-3.2-1b', 'llama-3.2-3b', 'llama-3.1-8b']
|
||||
# uses: ./.github/workflows/bench_job.yml
|
||||
# with:
|
||||
# config: '{"M4PRO_GPU16_24GB": 2}'
|
||||
# model: ${{ matrix.model }}
|
||||
# calling_job_name: 'two-m4-pro-cluster-thunderbolt'
|
||||
# network_interface: 'Thunderbolt'
|
||||
# secrets: inherit
|
||||
|
||||
three-m4-pro-cluster:
|
||||
strategy:
|
||||
matrix:
|
||||
model: ['llama-3.2-1b', 'llama-3.2-3b', 'llama-3.1-8b', 'llama-3.3-70b']
|
||||
fail-fast: false
|
||||
uses: ./.github/workflows/bench_job.yml
|
||||
with:
|
||||
config: '{"M4PRO_GPU16_24GB": 3}'
|
||||
model: ${{ matrix.model }}
|
||||
calling_job_name: 'three-m4-pro-cluster'
|
||||
network_interface: 'Ethernet'
|
||||
secrets: inherit
|
||||
|
||||
# test-m3-single-node:
|
||||
# strategy:
|
||||
# matrix:
|
||||
# model: ['llama-3.2-1b']
|
||||
# fail-fast: false
|
||||
# uses: ./.github/workflows/bench_job.yml
|
||||
# with:
|
||||
# config: '{"M3MAX_GPU40_128GB": 1}'
|
||||
# model: ${{ matrix.model }}
|
||||
# calling_job_name: 'test-m3-cluster'
|
||||
# network_interface: 'Ethernet'
|
||||
# secrets: inherit
|
||||
175
.gitignore
vendored
175
.gitignore
vendored
@@ -1,175 +0,0 @@
|
||||
__pycache__/
|
||||
.venv*
|
||||
test_weights.npz
|
||||
.exo_used_ports
|
||||
.exo_node_id
|
||||
.idea
|
||||
.DS_Store
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
/.Python
|
||||
/develop-eggs/
|
||||
/dist/
|
||||
/downloads/
|
||||
/eggs/
|
||||
/.eggs/
|
||||
/lib/
|
||||
/lib64/
|
||||
/parts/
|
||||
/sdist/
|
||||
/var/
|
||||
/wheels/
|
||||
/share/python-wheels/
|
||||
/*.egg-info/
|
||||
/.installed.cfg
|
||||
/*.egg
|
||||
/MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
Untitled.ipynb
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control.
|
||||
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
||||
.pdm.toml
|
||||
.pdm-python
|
||||
.pdm-build/
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
|
||||
**/*.xcodeproj/*
|
||||
.aider*
|
||||
|
||||
exo/tinychat/images/*.png
|
||||
19
.style.yapf
19
.style.yapf
@@ -1,19 +0,0 @@
|
||||
[style]
|
||||
based_on_style = pep8
|
||||
indent_width = 2
|
||||
column_limit = 200
|
||||
allow_split_before_dict_value = False
|
||||
dedent_closing_brackets = True
|
||||
split_before_first_argument = False
|
||||
split_complex_comprehension = False
|
||||
continuation_indent_width = 2
|
||||
indent_dictionary_value = True
|
||||
allow_multiline_dictionary_keys = True
|
||||
each_dict_entry_on_separate_line = False
|
||||
allow_multiline_lambdas = True
|
||||
blank_line_before_nested_class_or_def = False
|
||||
arithmetic_precedence_indication = True
|
||||
no_spaces_around_selected_binary_operators = "*,/"
|
||||
coalesce_brackets = True
|
||||
space_between_ending_comma_and_closing_bracket = False
|
||||
split_before_expression_after_opening_paren = False
|
||||
675
LICENSE
675
LICENSE
@@ -1,675 +0,0 @@
|
||||
GNU GENERAL PUBLIC LICENSE
|
||||
Version 3, 29 June 2007
|
||||
|
||||
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
|
||||
Everyone is permitted to copy and distribute verbatim copies
|
||||
of this license document, but changing it is not allowed.
|
||||
|
||||
Preamble
|
||||
|
||||
The GNU General Public License is a free, copyleft license for
|
||||
software and other kinds of works.
|
||||
|
||||
The licenses for most software and other practical works are designed
|
||||
to take away your freedom to share and change the works. By contrast,
|
||||
the GNU General Public License is intended to guarantee your freedom to
|
||||
share and change all versions of a program--to make sure it remains free
|
||||
software for all its users. We, the Free Software Foundation, use the
|
||||
GNU General Public License for most of our software; it applies also to
|
||||
any other work released this way by its authors. You can apply it to
|
||||
your programs, too.
|
||||
|
||||
When we speak of free software, we are referring to freedom, not
|
||||
price. Our General Public Licenses are designed to make sure that you
|
||||
have the freedom to distribute copies of free software (and charge for
|
||||
them if you wish), that you receive source code or can get it if you
|
||||
want it, that you can change the software or use pieces of it in new
|
||||
free programs, and that you know you can do these things.
|
||||
|
||||
To protect your rights, we need to prevent others from denying you
|
||||
these rights or asking you to surrender the rights. Therefore, you have
|
||||
certain responsibilities if you distribute copies of the software, or if
|
||||
you modify it: responsibilities to respect the freedom of others.
|
||||
|
||||
For example, if you distribute copies of such a program, whether
|
||||
gratis or for a fee, you must pass on to the recipients the same
|
||||
freedoms that you received. You must make sure that they, too, receive
|
||||
or can get the source code. And you must show them these terms so they
|
||||
know their rights.
|
||||
|
||||
Developers that use the GNU GPL protect your rights with two steps:
|
||||
(1) assert copyright on the software, and (2) offer you this License
|
||||
giving you legal permission to copy, distribute and/or modify it.
|
||||
|
||||
For the developers' and authors' protection, the GPL clearly explains
|
||||
that there is no warranty for this free software. For both users' and
|
||||
authors' sake, the GPL requires that modified versions be marked as
|
||||
changed, so that their problems will not be attributed erroneously to
|
||||
authors of previous versions.
|
||||
|
||||
Some devices are designed to deny users access to install or run
|
||||
modified versions of the software inside them, although the manufacturer
|
||||
can do so. This is fundamentally incompatible with the aim of
|
||||
protecting users' freedom to change the software. The systematic
|
||||
pattern of such abuse occurs in the area of products for individuals to
|
||||
use, which is precisely where it is most unacceptable. Therefore, we
|
||||
have designed this version of the GPL to prohibit the practice for those
|
||||
products. If such problems arise substantially in other domains, we
|
||||
stand ready to extend this provision to those domains in future versions
|
||||
of the GPL, as needed to protect the freedom of users.
|
||||
|
||||
Finally, every program is threatened constantly by software patents.
|
||||
States should not allow patents to restrict development and use of
|
||||
software on general-purpose computers, but in those that do, we wish to
|
||||
avoid the special danger that patents applied to a free program could
|
||||
make it effectively proprietary. To prevent this, the GPL assures that
|
||||
patents cannot be used to render the program non-free.
|
||||
|
||||
The precise terms and conditions for copying, distribution and
|
||||
modification follow.
|
||||
|
||||
TERMS AND CONDITIONS
|
||||
|
||||
0. Definitions.
|
||||
|
||||
"This License" refers to version 3 of the GNU General Public License.
|
||||
|
||||
"Copyright" also means copyright-like laws that apply to other kinds of
|
||||
works, such as semiconductor masks.
|
||||
|
||||
"The Program" refers to any copyrightable work licensed under this
|
||||
License. Each licensee is addressed as "you". "Licensees" and
|
||||
"recipients" may be individuals or organizations.
|
||||
|
||||
To "modify" a work means to copy from or adapt all or part of the work
|
||||
in a fashion requiring copyright permission, other than the making of an
|
||||
exact copy. The resulting work is called a "modified version" of the
|
||||
earlier work or a work "based on" the earlier work.
|
||||
|
||||
A "covered work" means either the unmodified Program or a work based
|
||||
on the Program.
|
||||
|
||||
To "propagate" a work means to do anything with it that, without
|
||||
permission, would make you directly or secondarily liable for
|
||||
infringement under applicable copyright law, except executing it on a
|
||||
computer or modifying a private copy. Propagation includes copying,
|
||||
distribution (with or without modification), making available to the
|
||||
public, and in some countries other activities as well.
|
||||
|
||||
To "convey" a work means any kind of propagation that enables other
|
||||
parties to make or receive copies. Mere interaction with a user through
|
||||
a computer network, with no transfer of a copy, is not conveying.
|
||||
|
||||
An interactive user interface displays "Appropriate Legal Notices"
|
||||
to the extent that it includes a convenient and prominently visible
|
||||
feature that (1) displays an appropriate copyright notice, and (2)
|
||||
tells the user that there is no warranty for the work (except to the
|
||||
extent that warranties are provided), that licensees may convey the
|
||||
work under this License, and how to view a copy of this License. If
|
||||
the interface presents a list of user commands or options, such as a
|
||||
menu, a prominent item in the list meets this criterion.
|
||||
|
||||
1. Source Code.
|
||||
|
||||
The "source code" for a work means the preferred form of the work
|
||||
for making modifications to it. "Object code" means any non-source
|
||||
form of a work.
|
||||
|
||||
A "Standard Interface" means an interface that either is an official
|
||||
standard defined by a recognized standards body, or, in the case of
|
||||
interfaces specified for a particular programming language, one that
|
||||
is widely used among developers working in that language.
|
||||
|
||||
The "System Libraries" of an executable work include anything, other
|
||||
than the work as a whole, that (a) is included in the normal form of
|
||||
packaging a Major Component, but which is not part of that Major
|
||||
Component, and (b) serves only to enable use of the work with that
|
||||
Major Component, or to implement a Standard Interface for which an
|
||||
implementation is available to the public in source code form. A
|
||||
"Major Component", in this context, means a major essential component
|
||||
(kernel, window system, and so on) of the specific operating system
|
||||
(if any) on which the executable work runs, or a compiler used to
|
||||
produce the work, or an object code interpreter used to run it.
|
||||
|
||||
The "Corresponding Source" for a work in object code form means all
|
||||
the source code needed to generate, install, and (for an executable
|
||||
work) run the object code and to modify the work, including scripts to
|
||||
control those activities. However, it does not include the work's
|
||||
System Libraries, or general-purpose tools or generally available free
|
||||
programs which are used unmodified in performing those activities but
|
||||
which are not part of the work. For example, Corresponding Source
|
||||
includes interface definition files associated with source files for
|
||||
the work, and the source code for shared libraries and dynamically
|
||||
linked subprograms that the work is specifically designed to require,
|
||||
such as by intimate data communication or control flow between those
|
||||
subprograms and other parts of the work.
|
||||
|
||||
The Corresponding Source need not include anything that users
|
||||
can regenerate automatically from other parts of the Corresponding
|
||||
Source.
|
||||
|
||||
The Corresponding Source for a work in source code form is that
|
||||
same work.
|
||||
|
||||
2. Basic Permissions.
|
||||
|
||||
All rights granted under this License are granted for the term of
|
||||
copyright on the Program, and are irrevocable provided the stated
|
||||
conditions are met. This License explicitly affirms your unlimited
|
||||
permission to run the unmodified Program. The output from running a
|
||||
covered work is covered by this License only if the output, given its
|
||||
content, constitutes a covered work. This License acknowledges your
|
||||
rights of fair use or other equivalent, as provided by copyright law.
|
||||
|
||||
You may make, run and propagate covered works that you do not
|
||||
convey, without conditions so long as your license otherwise remains
|
||||
in force. You may convey covered works to others for the sole purpose
|
||||
of having them make modifications exclusively for you, or provide you
|
||||
with facilities for running those works, provided that you comply with
|
||||
the terms of this License in conveying all material for which you do
|
||||
not control copyright. Those thus making or running the covered works
|
||||
for you must do so exclusively on your behalf, under your direction
|
||||
and control, on terms that prohibit them from making any copies of
|
||||
your copyrighted material outside their relationship with you.
|
||||
|
||||
Conveying under any other circumstances is permitted solely under
|
||||
the conditions stated below. Sublicensing is not allowed; section 10
|
||||
makes it unnecessary.
|
||||
|
||||
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
|
||||
|
||||
No covered work shall be deemed part of an effective technological
|
||||
measure under any applicable law fulfilling obligations under article
|
||||
11 of the WIPO copyright treaty adopted on 20 December 1996, or
|
||||
similar laws prohibiting or restricting circumvention of such
|
||||
measures.
|
||||
|
||||
When you convey a covered work, you waive any legal power to forbid
|
||||
circumvention of technological measures to the extent such circumvention
|
||||
is effected by exercising rights under this License with respect to
|
||||
the covered work, and you disclaim any intention to limit operation or
|
||||
modification of the work as a means of enforcing, against the work's
|
||||
users, your or third parties' legal rights to forbid circumvention of
|
||||
technological measures.
|
||||
|
||||
4. Conveying Verbatim Copies.
|
||||
|
||||
You may convey verbatim copies of the Program's source code as you
|
||||
receive it, in any medium, provided that you conspicuously and
|
||||
appropriately publish on each copy an appropriate copyright notice;
|
||||
keep intact all notices stating that this License and any
|
||||
non-permissive terms added in accord with section 7 apply to the code;
|
||||
keep intact all notices of the absence of any warranty; and give all
|
||||
recipients a copy of this License along with the Program.
|
||||
|
||||
You may charge any price or no price for each copy that you convey,
|
||||
and you may offer support or warranty protection for a fee.
|
||||
|
||||
5. Conveying Modified Source Versions.
|
||||
|
||||
You may convey a work based on the Program, or the modifications to
|
||||
produce it from the Program, in the form of source code under the
|
||||
terms of section 4, provided that you also meet all of these conditions:
|
||||
|
||||
a) The work must carry prominent notices stating that you modified
|
||||
it, and giving a relevant date.
|
||||
|
||||
b) The work must carry prominent notices stating that it is
|
||||
released under this License and any conditions added under section
|
||||
7. This requirement modifies the requirement in section 4 to
|
||||
"keep intact all notices".
|
||||
|
||||
c) You must license the entire work, as a whole, under this
|
||||
License to anyone who comes into possession of a copy. This
|
||||
License will therefore apply, along with any applicable section 7
|
||||
additional terms, to the whole of the work, and all its parts,
|
||||
regardless of how they are packaged. This License gives no
|
||||
permission to license the work in any other way, but it does not
|
||||
invalidate such permission if you have separately received it.
|
||||
|
||||
d) If the work has interactive user interfaces, each must display
|
||||
Appropriate Legal Notices; however, if the Program has interactive
|
||||
interfaces that do not display Appropriate Legal Notices, your
|
||||
work need not make them do so.
|
||||
|
||||
A compilation of a covered work with other separate and independent
|
||||
works, which are not by their nature extensions of the covered work,
|
||||
and which are not combined with it such as to form a larger program,
|
||||
in or on a volume of a storage or distribution medium, is called an
|
||||
"aggregate" if the compilation and its resulting copyright are not
|
||||
used to limit the access or legal rights of the compilation's users
|
||||
beyond what the individual works permit. Inclusion of a covered work
|
||||
in an aggregate does not cause this License to apply to the other
|
||||
parts of the aggregate.
|
||||
|
||||
6. Conveying Non-Source Forms.
|
||||
|
||||
You may convey a covered work in object code form under the terms
|
||||
of sections 4 and 5, provided that you also convey the
|
||||
machine-readable Corresponding Source under the terms of this License,
|
||||
in one of these ways:
|
||||
|
||||
a) Convey the object code in, or embodied in, a physical product
|
||||
(including a physical distribution medium), accompanied by the
|
||||
Corresponding Source fixed on a durable physical medium
|
||||
customarily used for software interchange.
|
||||
|
||||
b) Convey the object code in, or embodied in, a physical product
|
||||
(including a physical distribution medium), accompanied by a
|
||||
written offer, valid for at least three years and valid for as
|
||||
long as you offer spare parts or customer support for that product
|
||||
model, to give anyone who possesses the object code either (1) a
|
||||
copy of the Corresponding Source for all the software in the
|
||||
product that is covered by this License, on a durable physical
|
||||
medium customarily used for software interchange, for a price no
|
||||
more than your reasonable cost of physically performing this
|
||||
conveying of source, or (2) access to copy the
|
||||
Corresponding Source from a network server at no charge.
|
||||
|
||||
c) Convey individual copies of the object code with a copy of the
|
||||
written offer to provide the Corresponding Source. This
|
||||
alternative is allowed only occasionally and noncommercially, and
|
||||
only if you received the object code with such an offer, in accord
|
||||
with subsection 6b.
|
||||
|
||||
d) Convey the object code by offering access from a designated
|
||||
place (gratis or for a charge), and offer equivalent access to the
|
||||
Corresponding Source in the same way through the same place at no
|
||||
further charge. You need not require recipients to copy the
|
||||
Corresponding Source along with the object code. If the place to
|
||||
copy the object code is a network server, the Corresponding Source
|
||||
may be on a different server (operated by you or a third party)
|
||||
that supports equivalent copying facilities, provided you maintain
|
||||
clear directions next to the object code saying where to find the
|
||||
Corresponding Source. Regardless of what server hosts the
|
||||
Corresponding Source, you remain obligated to ensure that it is
|
||||
available for as long as needed to satisfy these requirements.
|
||||
|
||||
e) Convey the object code using peer-to-peer transmission, provided
|
||||
you inform other peers where the object code and Corresponding
|
||||
Source of the work are being offered to the general public at no
|
||||
charge under subsection 6d.
|
||||
|
||||
A separable portion of the object code, whose source code is excluded
|
||||
from the Corresponding Source as a System Library, need not be
|
||||
included in conveying the object code work.
|
||||
|
||||
A "User Product" is either (1) a "consumer product", which means any
|
||||
tangible personal property which is normally used for personal, family,
|
||||
or household purposes, or (2) anything designed or sold for incorporation
|
||||
into a dwelling. In determining whether a product is a consumer product,
|
||||
doubtful cases shall be resolved in favor of coverage. For a particular
|
||||
product received by a particular user, "normally used" refers to a
|
||||
typical or common use of that class of product, regardless of the status
|
||||
of the particular user or of the way in which the particular user
|
||||
actually uses, or expects or is expected to use, the product. A product
|
||||
is a consumer product regardless of whether the product has substantial
|
||||
commercial, industrial or non-consumer uses, unless such uses represent
|
||||
the only significant mode of use of the product.
|
||||
|
||||
"Installation Information" for a User Product means any methods,
|
||||
procedures, authorization keys, or other information required to install
|
||||
and execute modified versions of a covered work in that User Product from
|
||||
a modified version of its Corresponding Source. The information must
|
||||
suffice to ensure that the continued functioning of the modified object
|
||||
code is in no case prevented or interfered with solely because
|
||||
modification has been made.
|
||||
|
||||
If you convey an object code work under this section in, or with, or
|
||||
specifically for use in, a User Product, and the conveying occurs as
|
||||
part of a transaction in which the right of possession and use of the
|
||||
User Product is transferred to the recipient in perpetuity or for a
|
||||
fixed term (regardless of how the transaction is characterized), the
|
||||
Corresponding Source conveyed under this section must be accompanied
|
||||
by the Installation Information. But this requirement does not apply
|
||||
if neither you nor any third party retains the ability to install
|
||||
modified object code on the User Product (for example, the work has
|
||||
been installed in ROM).
|
||||
|
||||
The requirement to provide Installation Information does not include a
|
||||
requirement to continue to provide support service, warranty, or updates
|
||||
for a work that has been modified or installed by the recipient, or for
|
||||
the User Product in which it has been modified or installed. Access to a
|
||||
network may be denied when the modification itself materially and
|
||||
adversely affects the operation of the network or violates the rules and
|
||||
protocols for communication across the network.
|
||||
|
||||
Corresponding Source conveyed, and Installation Information provided,
|
||||
in accord with this section must be in a format that is publicly
|
||||
documented (and with an implementation available to the public in
|
||||
source code form), and must require no special password or key for
|
||||
unpacking, reading or copying.
|
||||
|
||||
7. Additional Terms.
|
||||
|
||||
"Additional permissions" are terms that supplement the terms of this
|
||||
License by making exceptions from one or more of its conditions.
|
||||
Additional permissions that are applicable to the entire Program shall
|
||||
be treated as though they were included in this License, to the extent
|
||||
that they are valid under applicable law. If additional permissions
|
||||
apply only to part of the Program, that part may be used separately
|
||||
under those permissions, but the entire Program remains governed by
|
||||
this License without regard to the additional permissions.
|
||||
|
||||
When you convey a copy of a covered work, you may at your option
|
||||
remove any additional permissions from that copy, or from any part of
|
||||
it. (Additional permissions may be written to require their own
|
||||
removal in certain cases when you modify the work.) You may place
|
||||
additional permissions on material, added by you to a covered work,
|
||||
for which you have or can give appropriate copyright permission.
|
||||
|
||||
Notwithstanding any other provision of this License, for material you
|
||||
add to a covered work, you may (if authorized by the copyright holders of
|
||||
that material) supplement the terms of this License with terms:
|
||||
|
||||
a) Disclaiming warranty or limiting liability differently from the
|
||||
terms of sections 15 and 16 of this License; or
|
||||
|
||||
b) Requiring preservation of specified reasonable legal notices or
|
||||
author attributions in that material or in the Appropriate Legal
|
||||
Notices displayed by works containing it; or
|
||||
|
||||
c) Prohibiting misrepresentation of the origin of that material, or
|
||||
requiring that modified versions of such material be marked in
|
||||
reasonable ways as different from the original version; or
|
||||
|
||||
d) Limiting the use for publicity purposes of names of licensors or
|
||||
authors of the material; or
|
||||
|
||||
e) Declining to grant rights under trademark law for use of some
|
||||
trade names, trademarks, or service marks; or
|
||||
|
||||
f) Requiring indemnification of licensors and authors of that
|
||||
material by anyone who conveys the material (or modified versions of
|
||||
it) with contractual assumptions of liability to the recipient, for
|
||||
any liability that these contractual assumptions directly impose on
|
||||
those licensors and authors.
|
||||
|
||||
All other non-permissive additional terms are considered "further
|
||||
restrictions" within the meaning of section 10. If the Program as you
|
||||
received it, or any part of it, contains a notice stating that it is
|
||||
governed by this License along with a term that is a further
|
||||
restriction, you may remove that term. If a license document contains
|
||||
a further restriction but permits relicensing or conveying under this
|
||||
License, you may add to a covered work material governed by the terms
|
||||
of that license document, provided that the further restriction does
|
||||
not survive such relicensing or conveying.
|
||||
|
||||
If you add terms to a covered work in accord with this section, you
|
||||
must place, in the relevant source files, a statement of the
|
||||
additional terms that apply to those files, or a notice indicating
|
||||
where to find the applicable terms.
|
||||
|
||||
Additional terms, permissive or non-permissive, may be stated in the
|
||||
form of a separately written license, or stated as exceptions;
|
||||
the above requirements apply either way.
|
||||
|
||||
8. Termination.
|
||||
|
||||
You may not propagate or modify a covered work except as expressly
|
||||
provided under this License. Any attempt otherwise to propagate or
|
||||
modify it is void, and will automatically terminate your rights under
|
||||
this License (including any patent licenses granted under the third
|
||||
paragraph of section 11).
|
||||
|
||||
However, if you cease all violation of this License, then your
|
||||
license from a particular copyright holder is reinstated (a)
|
||||
provisionally, unless and until the copyright holder explicitly and
|
||||
finally terminates your license, and (b) permanently, if the copyright
|
||||
holder fails to notify you of the violation by some reasonable means
|
||||
prior to 60 days after the cessation.
|
||||
|
||||
Moreover, your license from a particular copyright holder is
|
||||
reinstated permanently if the copyright holder notifies you of the
|
||||
violation by some reasonable means, this is the first time you have
|
||||
received notice of violation of this License (for any work) from that
|
||||
copyright holder, and you cure the violation prior to 30 days after
|
||||
your receipt of the notice.
|
||||
|
||||
Termination of your rights under this section does not terminate the
|
||||
licenses of parties who have received copies or rights from you under
|
||||
this License. If your rights have been terminated and not permanently
|
||||
reinstated, you do not qualify to receive new licenses for the same
|
||||
material under section 10.
|
||||
|
||||
9. Acceptance Not Required for Having Copies.
|
||||
|
||||
You are not required to accept this License in order to receive or
|
||||
run a copy of the Program. Ancillary propagation of a covered work
|
||||
occurring solely as a consequence of using peer-to-peer transmission
|
||||
to receive a copy likewise does not require acceptance. However,
|
||||
nothing other than this License grants you permission to propagate or
|
||||
modify any covered work. These actions infringe copyright if you do
|
||||
not accept this License. Therefore, by modifying or propagating a
|
||||
covered work, you indicate your acceptance of this License to do so.
|
||||
|
||||
10. Automatic Licensing of Downstream Recipients.
|
||||
|
||||
Each time you convey a covered work, the recipient automatically
|
||||
receives a license from the original licensors, to run, modify and
|
||||
propagate that work, subject to this License. You are not responsible
|
||||
for enforcing compliance by third parties with this License.
|
||||
|
||||
An "entity transaction" is a transaction transferring control of an
|
||||
organization, or substantially all assets of one, or subdividing an
|
||||
organization, or merging organizations. If propagation of a covered
|
||||
work results from an entity transaction, each party to that
|
||||
transaction who receives a copy of the work also receives whatever
|
||||
licenses to the work the party's predecessor in interest had or could
|
||||
give under the previous paragraph, plus a right to possession of the
|
||||
Corresponding Source of the work from the predecessor in interest, if
|
||||
the predecessor has it or can get it with reasonable efforts.
|
||||
|
||||
You may not impose any further restrictions on the exercise of the
|
||||
rights granted or affirmed under this License. For example, you may
|
||||
not impose a license fee, royalty, or other charge for exercise of
|
||||
rights granted under this License, and you may not initiate litigation
|
||||
(including a cross-claim or counterclaim in a lawsuit) alleging that
|
||||
any patent claim is infringed by making, using, selling, offering for
|
||||
sale, or importing the Program or any portion of it.
|
||||
|
||||
11. Patents.
|
||||
|
||||
A "contributor" is a copyright holder who authorizes use under this
|
||||
License of the Program or a work on which the Program is based. The
|
||||
work thus licensed is called the contributor's "contributor version".
|
||||
|
||||
A contributor's "essential patent claims" are all patent claims
|
||||
owned or controlled by the contributor, whether already acquired or
|
||||
hereafter acquired, that would be infringed by some manner, permitted
|
||||
by this License, of making, using, or selling its contributor version,
|
||||
but do not include claims that would be infringed only as a
|
||||
consequence of further modification of the contributor version. For
|
||||
purposes of this definition, "control" includes the right to grant
|
||||
patent sublicenses in a manner consistent with the requirements of
|
||||
this License.
|
||||
|
||||
Each contributor grants you a non-exclusive, worldwide, royalty-free
|
||||
patent license under the contributor's essential patent claims, to
|
||||
make, use, sell, offer for sale, import and otherwise run, modify and
|
||||
propagate the contents of its contributor version.
|
||||
|
||||
In the following three paragraphs, a "patent license" is any express
|
||||
agreement or commitment, however denominated, not to enforce a patent
|
||||
(such as an express permission to practice a patent or covenant not to
|
||||
sue for patent infringement). To "grant" such a patent license to a
|
||||
party means to make such an agreement or commitment not to enforce a
|
||||
patent against the party.
|
||||
|
||||
If you convey a covered work, knowingly relying on a patent license,
|
||||
and the Corresponding Source of the work is not available for anyone
|
||||
to copy, free of charge and under the terms of this License, through a
|
||||
publicly available network server or other readily accessible means,
|
||||
then you must either (1) cause the Corresponding Source to be so
|
||||
available, or (2) arrange to deprive yourself of the benefit of the
|
||||
patent license for this particular work, or (3) arrange, in a manner
|
||||
consistent with the requirements of this License, to extend the patent
|
||||
license to downstream recipients. "Knowingly relying" means you have
|
||||
actual knowledge that, but for the patent license, your conveying the
|
||||
covered work in a country, or your recipient's use of the covered work
|
||||
in a country, would infringe one or more identifiable patents in that
|
||||
country that you have reason to believe are valid.
|
||||
|
||||
If, pursuant to or in connection with a single transaction or
|
||||
arrangement, you convey, or propagate by procuring conveyance of, a
|
||||
covered work, and grant a patent license to some of the parties
|
||||
receiving the covered work authorizing them to use, propagate, modify
|
||||
or convey a specific copy of the covered work, then the patent license
|
||||
you grant is automatically extended to all recipients of the covered
|
||||
work and works based on it.
|
||||
|
||||
A patent license is "discriminatory" if it does not include within
|
||||
the scope of its coverage, prohibits the exercise of, or is
|
||||
conditioned on the non-exercise of one or more of the rights that are
|
||||
specifically granted under this License. You may not convey a covered
|
||||
work if you are a party to an arrangement with a third party that is
|
||||
in the business of distributing software, under which you make payment
|
||||
to the third party based on the extent of your activity of conveying
|
||||
the work, and under which the third party grants, to any of the
|
||||
parties who would receive the covered work from you, a discriminatory
|
||||
patent license (a) in connection with copies of the covered work
|
||||
conveyed by you (or copies made from those copies), or (b) primarily
|
||||
for and in connection with specific products or compilations that
|
||||
contain the covered work, unless you entered into that arrangement,
|
||||
or that patent license was granted, prior to 28 March 2007.
|
||||
|
||||
Nothing in this License shall be construed as excluding or limiting
|
||||
any implied license or other defenses to infringement that may
|
||||
otherwise be available to you under applicable patent law.
|
||||
|
||||
12. No Surrender of Others' Freedom.
|
||||
|
||||
If conditions are imposed on you (whether by court order, agreement or
|
||||
otherwise) that contradict the conditions of this License, they do not
|
||||
excuse you from the conditions of this License. If you cannot convey a
|
||||
covered work so as to satisfy simultaneously your obligations under this
|
||||
License and any other pertinent obligations, then as a consequence you may
|
||||
not convey it at all. For example, if you agree to terms that obligate you
|
||||
to collect a royalty for further conveying from those to whom you convey
|
||||
the Program, the only way you could satisfy both those terms and this
|
||||
License would be to refrain entirely from conveying the Program.
|
||||
|
||||
13. Use with the GNU Affero General Public License.
|
||||
|
||||
Notwithstanding any other provision of this License, you have
|
||||
permission to link or combine any covered work with a work licensed
|
||||
under version 3 of the GNU Affero General Public License into a single
|
||||
combined work, and to convey the resulting work. The terms of this
|
||||
License will continue to apply to the part which is the covered work,
|
||||
but the special requirements of the GNU Affero General Public License,
|
||||
section 13, concerning interaction through a network will apply to the
|
||||
combination as such.
|
||||
|
||||
14. Revised Versions of this License.
|
||||
|
||||
The Free Software Foundation may publish revised and/or new versions of
|
||||
the GNU General Public License from time to time. Such new versions will
|
||||
be similar in spirit to the present version, but may differ in detail to
|
||||
address new problems or concerns.
|
||||
|
||||
Each version is given a distinguishing version number. If the
|
||||
Program specifies that a certain numbered version of the GNU General
|
||||
Public License "or any later version" applies to it, you have the
|
||||
option of following the terms and conditions either of that numbered
|
||||
version or of any later version published by the Free Software
|
||||
Foundation. If the Program does not specify a version number of the
|
||||
GNU General Public License, you may choose any version ever published
|
||||
by the Free Software Foundation.
|
||||
|
||||
If the Program specifies that a proxy can decide which future
|
||||
versions of the GNU General Public License can be used, that proxy's
|
||||
public statement of acceptance of a version permanently authorizes you
|
||||
to choose that version for the Program.
|
||||
|
||||
Later license versions may give you additional or different
|
||||
permissions. However, no additional obligations are imposed on any
|
||||
author or copyright holder as a result of your choosing to follow a
|
||||
later version.
|
||||
|
||||
15. Disclaimer of Warranty.
|
||||
|
||||
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
|
||||
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
|
||||
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
|
||||
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
|
||||
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
|
||||
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
|
||||
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
|
||||
|
||||
16. Limitation of Liability.
|
||||
|
||||
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
|
||||
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
|
||||
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
|
||||
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
|
||||
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
|
||||
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
|
||||
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
|
||||
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
|
||||
SUCH DAMAGES.
|
||||
|
||||
17. Interpretation of Sections 15 and 16.
|
||||
|
||||
If the disclaimer of warranty and limitation of liability provided
|
||||
above cannot be given local legal effect according to their terms,
|
||||
reviewing courts shall apply local law that most closely approximates
|
||||
an absolute waiver of all civil liability in connection with the
|
||||
Program, unless a warranty or assumption of liability accompanies a
|
||||
copy of the Program in return for a fee.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
How to Apply These Terms to Your New Programs
|
||||
|
||||
If you develop a new program, and you want it to be of the greatest
|
||||
possible use to the public, the best way to achieve this is to make it
|
||||
free software which everyone can redistribute and change under these terms.
|
||||
|
||||
To do so, attach the following notices to the program. It is safest
|
||||
to attach them to the start of each source file to most effectively
|
||||
state the exclusion of warranty; and each file should have at least
|
||||
the "copyright" line and a pointer to where the full notice is found.
|
||||
|
||||
<one line to give the program's name and a brief idea of what it does.>
|
||||
Copyright (C) <year> <name of author>
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
Also add information on how to contact you by electronic and paper mail.
|
||||
|
||||
If the program does terminal interaction, make it output a short
|
||||
notice like this when it starts in an interactive mode:
|
||||
|
||||
<program> Copyright (C) <year> <name of author>
|
||||
This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
|
||||
This is free software, and you are welcome to redistribute it
|
||||
under certain conditions; type `show c' for details.
|
||||
|
||||
The hypothetical commands `show w' and `show c' should show the appropriate
|
||||
parts of the General Public License. Of course, your program's commands
|
||||
might be different; for a GUI interface, you would use an "about box".
|
||||
|
||||
You should also get your employer (if you work as a programmer) or school,
|
||||
if any, to sign a "copyright disclaimer" for the program, if necessary.
|
||||
For more information on this, and how to apply and follow the GNU GPL, see
|
||||
<https://www.gnu.org/licenses/>.
|
||||
|
||||
The GNU General Public License does not permit incorporating your program
|
||||
into proprietary programs. If your program is a subroutine library, you
|
||||
may consider it more useful to permit linking proprietary applications with
|
||||
the library. If this is what you want to do, use the GNU Lesser General
|
||||
Public License instead of this License. But first, please read
|
||||
<https://www.gnu.org/licenses/why-not-lgpl.html>.
|
||||
|
||||
285
README.md
285
README.md
@@ -15,8 +15,7 @@ exo: Run your own AI cluster at home with everyday devices. Maintained by [exo l
|
||||
</h3>
|
||||
|
||||
[](https://github.com/exo-explore/exo/stargazers)
|
||||
[](https://dl.circleci.com/status-badge/redirect/circleci/TrkofJDoGzdQAeL6yVHKsg/4i5hJuafuwZYZQxbRAWS71/tree/main)
|
||||
[](https://www.gnu.org/licenses/gpl-3.0)
|
||||
[](https://www.apache.org/licenses/LICENSE-2.0.html)
|
||||
|
||||
<a href="https://trendshift.io/repositories/11849" target="_blank"><img src="https://trendshift.io/api/badge/repositories/11849" alt="exo-explore%2Fexo | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||
|
||||
@@ -26,286 +25,6 @@ exo: Run your own AI cluster at home with everyday devices. Maintained by [exo l
|
||||
|
||||
> **EXO**
|
||||
>
|
||||
> EXO started out of a desire to run research experiments on large language models using the hardware we already owned.
|
||||
>
|
||||
> What began here is becoming part of something much larger.
|
||||
>
|
||||
> soon™
|
||||
>
|
||||
> \- The EXO Team
|
||||
> Coming soon. For legacy exo, see this repo's history or [exo-explore/ex-exo](https://github.com/exo-explore/ex-exo) for a snapshot.
|
||||
|
||||
---
|
||||
|
||||
Unify your existing devices into one powerful GPU: iPhone, iPad, Android, Mac, NVIDIA, Raspberry Pi, pretty much any device!
|
||||
|
||||
<div align="center">
|
||||
<h2>Update: exo is hiring. See <a href="https://exolabs.net">here</a> for more details.</h2>
|
||||
</div>
|
||||
|
||||
## Get Involved
|
||||
|
||||
exo is **experimental** software. Expect bugs early on. Create issues so they can be fixed. The [exo labs](https://x.com/exolabs) team will strive to resolve issues quickly.
|
||||
|
||||
We also welcome contributions from the community. We have a list of bounties in [this sheet](https://docs.google.com/spreadsheets/d/1cTCpTIp48UnnIvHeLEUNg1iMy_Q6lRybgECSFCoVJpE/edit?usp=sharing).
|
||||
|
||||
## Features
|
||||
|
||||
### Wide Model Support
|
||||
|
||||
exo supports different models including LLaMA ([MLX](exo/inference/mlx/models/llama.py) and [tinygrad](exo/inference/tinygrad/models/llama.py)), Mistral, LlaVA, Qwen, and Deepseek.
|
||||
|
||||
### Dynamic Model Partitioning
|
||||
|
||||
exo [optimally splits up models](exo/topology/ring_memory_weighted_partitioning_strategy.py) based on the current network topology and device resources available. This enables you to run larger models than you would be able to on any single device.
|
||||
|
||||
### Automatic Device Discovery
|
||||
|
||||
exo will [automatically discover](https://github.com/exo-explore/exo/blob/945f90f676182a751d2ad7bcf20987ab7fe0181e/exo/orchestration/node.py#L154) other devices using the best method available. Zero manual configuration.
|
||||
|
||||
### ChatGPT-compatible API
|
||||
|
||||
exo provides a [ChatGPT-compatible API](exo/api/chatgpt_api.py) for running models. It's a [one-line change](examples/chatgpt_api.sh) in your application to run models on your own hardware using exo.
|
||||
|
||||
### Device Equality
|
||||
|
||||
Unlike other distributed inference frameworks, exo does not use a master-worker architecture. Instead, exo devices [connect p2p](https://github.com/exo-explore/exo/blob/945f90f676182a751d2ad7bcf20987ab7fe0181e/exo/orchestration/node.py#L161). As long as a device is connected somewhere in the network, it can be used to run models.
|
||||
|
||||
Exo supports different [partitioning strategies](exo/topology/partitioning_strategy.py) to split up a model across devices. The default partitioning strategy is [ring memory weighted partitioning](exo/topology/ring_memory_weighted_partitioning_strategy.py). This runs an inference in a ring where each device runs a number of model layers proportional to the memory of the device.
|
||||
|
||||

|
||||
|
||||
## Installation
|
||||
|
||||
The current recommended way to install exo is from source.
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- Python>=3.12.0 is required because of [issues with asyncio](https://github.com/exo-explore/exo/issues/5) in previous versions.
|
||||
- For Linux with NVIDIA GPU support (Linux-only, skip if not using Linux or NVIDIA):
|
||||
- NVIDIA driver - verify with `nvidia-smi`
|
||||
- CUDA toolkit - install from [NVIDIA CUDA guide](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#cuda-cross-platform-installation), verify with `nvcc --version`
|
||||
- cuDNN library - download from [NVIDIA cuDNN page](https://developer.nvidia.com/cudnn-downloads), verify installation by following [these steps](https://docs.nvidia.com/deeplearning/cudnn/latest/installation/linux.html#verifying-the-install-on-linux:~:text=at%20a%20time.-,Verifying%20the%20Install%20on%20Linux,Test%20passed!,-Upgrading%20From%20Older)
|
||||
|
||||
### Hardware Requirements
|
||||
|
||||
- The only requirement to run exo is to have enough memory across all your devices to fit the entire model into memory. For example, if you are running llama 3.1 8B (fp16), you need 16GB of memory across all devices. Any of the following configurations would work since they each have more than 16GB of memory in total:
|
||||
- 2 x 8GB M3 MacBook Airs
|
||||
- 1 x 16GB NVIDIA RTX 4070 Ti Laptop
|
||||
- 2 x Raspberry Pi 400 with 4GB of RAM each (running on CPU) + 1 x 8GB Mac Mini
|
||||
- exo is designed to run on devices with heterogeneous capabilities. For example, you can have some devices with powerful GPUs and others with integrated GPUs or even CPUs. Adding less capable devices will slow down individual inference latency but will increase the overall throughput of the cluster.
|
||||
|
||||
### From source
|
||||
|
||||
|
||||
```sh
|
||||
git clone https://github.com/exo-explore/exo.git
|
||||
cd exo
|
||||
pip install -e .
|
||||
# alternatively, with venv
|
||||
source install.sh
|
||||
```
|
||||
|
||||
|
||||
### Troubleshooting
|
||||
|
||||
- If running on Mac, MLX has an [install guide](https://ml-explore.github.io/mlx/build/html/install.html) with troubleshooting steps.
|
||||
|
||||
### Performance
|
||||
|
||||
- There are a number of things users have empirically found to improve performance on Apple Silicon Macs:
|
||||
|
||||
1. Upgrade to the latest version of macOS Sequoia.
|
||||
2. Run `./configure_mlx.sh`. This runs commands to optimize GPU memory allocation on Apple Silicon Macs.
|
||||
|
||||
|
||||
## Documentation
|
||||
|
||||
### Example Usage on Multiple macOS Devices
|
||||
|
||||
#### Device 1:
|
||||
|
||||
```sh
|
||||
exo
|
||||
```
|
||||
|
||||
#### Device 2:
|
||||
```sh
|
||||
exo
|
||||
```
|
||||
|
||||
That's it! No configuration required - exo will automatically discover the other device(s).
|
||||
|
||||
exo starts a ChatGPT-like WebUI (powered by [tinygrad tinychat](https://github.com/tinygrad/tinygrad/tree/master/examples/tinychat)) on http://localhost:52415
|
||||
|
||||
For developers, exo also starts a ChatGPT-compatible API endpoint on http://localhost:52415/v1/chat/completions. Examples with curl:
|
||||
|
||||
#### Llama 3.2 3B:
|
||||
|
||||
```sh
|
||||
curl http://localhost:52415/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "llama-3.2-3b",
|
||||
"messages": [{"role": "user", "content": "What is the meaning of exo?"}],
|
||||
"temperature": 0.7
|
||||
}'
|
||||
```
|
||||
|
||||
#### Llama 3.1 405B:
|
||||
|
||||
```sh
|
||||
curl http://localhost:52415/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "llama-3.1-405b",
|
||||
"messages": [{"role": "user", "content": "What is the meaning of exo?"}],
|
||||
"temperature": 0.7
|
||||
}'
|
||||
```
|
||||
|
||||
#### DeepSeek R1 (full 671B):
|
||||
|
||||
```sh
|
||||
curl http://localhost:52415/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "deepseek-r1",
|
||||
"messages": [{"role": "user", "content": "What is the meaning of exo?"}],
|
||||
"temperature": 0.7
|
||||
}'
|
||||
```
|
||||
|
||||
#### Llava 1.5 7B (Vision Language Model):
|
||||
|
||||
```sh
|
||||
curl http://localhost:52415/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "llava-1.5-7b-hf",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What are these?"
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"temperature": 0.0
|
||||
}'
|
||||
```
|
||||
|
||||
### Example Usage on Multiple Heterogenous Devices (macOS + Linux)
|
||||
|
||||
#### Device 1 (macOS):
|
||||
|
||||
```sh
|
||||
exo
|
||||
```
|
||||
|
||||
Note: We don't need to explicitly tell exo to use the **tinygrad** inference engine. **MLX** and **tinygrad** are interoperable!
|
||||
|
||||
#### Device 2 (Linux):
|
||||
```sh
|
||||
exo
|
||||
```
|
||||
|
||||
Linux devices will automatically default to using the **tinygrad** inference engine.
|
||||
|
||||
You can read about tinygrad-specific env vars [here](https://docs.tinygrad.org/env_vars/). For example, you can configure tinygrad to use the cpu by specifying `CLANG=1`.
|
||||
|
||||
### Example Usage on a single device with "exo run" command
|
||||
|
||||
```sh
|
||||
exo run llama-3.2-3b
|
||||
```
|
||||
|
||||
With a custom prompt:
|
||||
|
||||
```sh
|
||||
exo run llama-3.2-3b --prompt "What is the meaning of exo?"
|
||||
```
|
||||
|
||||
### Model Storage
|
||||
|
||||
Models by default are stored in `~/.cache/exo/downloads`.
|
||||
|
||||
You can set a different model storage location by setting the `EXO_HOME` env var.
|
||||
|
||||
## Model Downloading
|
||||
|
||||
Models are downloaded from Hugging Face. If you are running exo in a country with strict internet censorship, you may need to download the models manually and put them in the `~/.cache/exo/downloads` directory.
|
||||
|
||||
To download models from a proxy endpoint, set the `HF_ENDPOINT` environment variable. For example, to run exo with the huggingface mirror endpoint:
|
||||
|
||||
```sh
|
||||
HF_ENDPOINT=https://hf-mirror.com exo
|
||||
```
|
||||
|
||||
## Debugging
|
||||
|
||||
Enable debug logs with the DEBUG environment variable (0-9).
|
||||
|
||||
```sh
|
||||
DEBUG=9 exo
|
||||
```
|
||||
|
||||
For the **tinygrad** inference engine specifically, there is a separate DEBUG flag `TINYGRAD_DEBUG` that can be used to enable debug logs (1-6).
|
||||
|
||||
```sh
|
||||
TINYGRAD_DEBUG=2 exo
|
||||
```
|
||||
|
||||
## Formatting
|
||||
|
||||
We use [yapf](https://github.com/google/yapf) to format the code. To format the code, first install the formatting requirements:
|
||||
|
||||
```sh
|
||||
pip3 install -e '.[formatting]'
|
||||
```
|
||||
|
||||
Then run the formatting script:
|
||||
|
||||
```sh
|
||||
python3 format.py ./exo
|
||||
```
|
||||
|
||||
## Known Issues
|
||||
|
||||
- On certain versions of Python on macOS, certificates may not installed correctly, potentially causing SSL errors (e.g., when accessing huggingface.co). To resolve this, run the `Install Certificates` command, typicall as follows:
|
||||
|
||||
```sh
|
||||
/Applications/Python 3.x/Install Certificates.command
|
||||
```
|
||||
|
||||
- 🚧 As the library is evolving so quickly, the iOS implementation has fallen behind Python. We have decided for now not to put out the buggy iOS version and receive a bunch of GitHub issues for outdated code. We are working on solving this properly and will make an announcement when it's ready. If you would like access to the iOS implementation now, please email alex@exolabs.net with your GitHub username explaining your use-case and you will be granted access on GitHub.
|
||||
|
||||
## Inference Engines
|
||||
|
||||
exo supports the following inference engines:
|
||||
|
||||
- ✅ [MLX](exo/inference/mlx/sharded_inference_engine.py)
|
||||
- ✅ [tinygrad](exo/inference/tinygrad/inference.py)
|
||||
- 🚧 [PyTorch](https://github.com/exo-explore/exo/pull/139)
|
||||
- 🚧 [llama.cpp](https://github.com/exo-explore/exo/issues/167)
|
||||
|
||||
## Discovery Modules
|
||||
|
||||
- ✅ [UDP](exo/networking/udp)
|
||||
- ✅ [Manual](exo/networking/manual)
|
||||
- ✅ [Tailscale](exo/networking/tailscale)
|
||||
- 🚧 Radio
|
||||
- 🚧 Bluetooth
|
||||
|
||||
# Peer Networking Modules
|
||||
|
||||
- ✅ [GRPC](exo/networking/grpc)
|
||||
- 🚧 NCCL
|
||||
|
||||
@@ -1,43 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# Get the total memory in MB
|
||||
TOTAL_MEM_MB=$(($(sysctl -n hw.memsize) / 1024 / 1024))
|
||||
|
||||
# Calculate 80% and TOTAL_MEM_GB-5GB in MB
|
||||
EIGHTY_PERCENT=$(($TOTAL_MEM_MB * 80 / 100))
|
||||
MINUS_5GB=$((($TOTAL_MEM_MB - 5120)))
|
||||
|
||||
# Calculate 70% and TOTAL_MEM_GB-8GB in MB
|
||||
SEVENTY_PERCENT=$(($TOTAL_MEM_MB * 70 / 100))
|
||||
MINUS_8GB=$((($TOTAL_MEM_MB - 8192)))
|
||||
|
||||
# Set WIRED_LIMIT_MB to higher value
|
||||
if [ $EIGHTY_PERCENT -gt $MINUS_5GB ]; then
|
||||
WIRED_LIMIT_MB=$EIGHTY_PERCENT
|
||||
else
|
||||
WIRED_LIMIT_MB=$MINUS_5GB
|
||||
fi
|
||||
|
||||
# Set WIRED_LWM_MB to higher value
|
||||
if [ $SEVENTY_PERCENT -gt $MINUS_8GB ]; then
|
||||
WIRED_LWM_MB=$SEVENTY_PERCENT
|
||||
else
|
||||
WIRED_LWM_MB=$MINUS_8GB
|
||||
fi
|
||||
|
||||
# Display the calculated values
|
||||
echo "Total memory: $TOTAL_MEM_MB MB"
|
||||
echo "Maximum limit (iogpu.wired_limit_mb): $WIRED_LIMIT_MB MB"
|
||||
echo "Lower bound (iogpu.wired_lwm_mb): $WIRED_LWM_MB MB"
|
||||
|
||||
# Apply the values with sysctl, but check if we're already root
|
||||
if [ "$EUID" -eq 0 ]; then
|
||||
sysctl -w iogpu.wired_limit_mb=$WIRED_LIMIT_MB
|
||||
sysctl -w iogpu.wired_lwm_mb=$WIRED_LWM_MB
|
||||
else
|
||||
# Try without sudo first, fall back to sudo if needed
|
||||
sysctl -w iogpu.wired_limit_mb=$WIRED_LIMIT_MB 2>/dev/null || \
|
||||
sudo sysctl -w iogpu.wired_limit_mb=$WIRED_LIMIT_MB
|
||||
sysctl -w iogpu.wired_lwm_mb=$WIRED_LWM_MB 2>/dev/null || \
|
||||
sudo sysctl -w iogpu.wired_lwm_mb=$WIRED_LWM_MB
|
||||
fi
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 7.9 KiB |
@@ -1,3 +0,0 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:1c6f0b66b68ffc11a42cf25fbd43a6fbea99869ed4ba82e5f480d8213e9b7061
|
||||
size 1296
|
||||
@@ -1,3 +0,0 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:c7aeca6a876a195df706f3221f1bfd4792884e6042c2b355026f94cba0f7576d
|
||||
size 1296
|
||||
@@ -1,3 +0,0 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:1316a53899f32ba6c33b083fca232b638aea4efbcf36bc99e640369169e6a1c9
|
||||
size 28651
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 295 KiB |
@@ -1,3 +0,0 @@
|
||||
## Overview
|
||||
|
||||
This example app is an open-source alternative to [Google's Project Astra](https://deepmind.google/technologies/gemini/project-astra/). It leverages the exo library to run on your own devices, providing a fully transparent and customizable experience compared to Google's closed-source API.
|
||||
@@ -1,653 +0,0 @@
|
||||
// !$*UTF8*$!
|
||||
{
|
||||
archiveVersion = 1;
|
||||
classes = {
|
||||
};
|
||||
objectVersion = 56;
|
||||
objects = {
|
||||
|
||||
/* Begin PBXBuildFile section */
|
||||
FA3E988F2C725A0200E4E795 /* astraApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = FA3E988E2C725A0200E4E795 /* astraApp.swift */; };
|
||||
FA3E98912C725A0200E4E795 /* ContentView.swift in Sources */ = {isa = PBXBuildFile; fileRef = FA3E98902C725A0200E4E795 /* ContentView.swift */; };
|
||||
FA3E98932C725A0300E4E795 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = FA3E98922C725A0300E4E795 /* Assets.xcassets */; };
|
||||
FA3E98972C725A0300E4E795 /* Preview Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = FA3E98962C725A0300E4E795 /* Preview Assets.xcassets */; };
|
||||
FA3E98A12C725A0300E4E795 /* astraTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = FA3E98A02C725A0300E4E795 /* astraTests.swift */; };
|
||||
FA3E98AB2C725A0300E4E795 /* astraUITests.swift in Sources */ = {isa = PBXBuildFile; fileRef = FA3E98AA2C725A0300E4E795 /* astraUITests.swift */; };
|
||||
FA3E98AD2C725A0300E4E795 /* astraUITestsLaunchTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = FA3E98AC2C725A0300E4E795 /* astraUITestsLaunchTests.swift */; };
|
||||
FA3E98BB2C725BF800E4E795 /* WhisperKit in Frameworks */ = {isa = PBXBuildFile; productRef = FA3E98BA2C725BF800E4E795 /* WhisperKit */; };
|
||||
/* End PBXBuildFile section */
|
||||
|
||||
/* Begin PBXContainerItemProxy section */
|
||||
FA3E989D2C725A0300E4E795 /* PBXContainerItemProxy */ = {
|
||||
isa = PBXContainerItemProxy;
|
||||
containerPortal = FA3E98832C725A0200E4E795 /* Project object */;
|
||||
proxyType = 1;
|
||||
remoteGlobalIDString = FA3E988A2C725A0200E4E795;
|
||||
remoteInfo = astra;
|
||||
};
|
||||
FA3E98A72C725A0300E4E795 /* PBXContainerItemProxy */ = {
|
||||
isa = PBXContainerItemProxy;
|
||||
containerPortal = FA3E98832C725A0200E4E795 /* Project object */;
|
||||
proxyType = 1;
|
||||
remoteGlobalIDString = FA3E988A2C725A0200E4E795;
|
||||
remoteInfo = astra;
|
||||
};
|
||||
/* End PBXContainerItemProxy section */
|
||||
|
||||
/* Begin PBXFileReference section */
|
||||
FA3E988B2C725A0200E4E795 /* astra.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = astra.app; sourceTree = BUILT_PRODUCTS_DIR; };
|
||||
FA3E988E2C725A0200E4E795 /* astraApp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = astraApp.swift; sourceTree = "<group>"; };
|
||||
FA3E98902C725A0200E4E795 /* ContentView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ContentView.swift; sourceTree = "<group>"; };
|
||||
FA3E98922C725A0300E4E795 /* Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = Assets.xcassets; sourceTree = "<group>"; };
|
||||
FA3E98942C725A0300E4E795 /* astra.entitlements */ = {isa = PBXFileReference; lastKnownFileType = text.plist.entitlements; path = astra.entitlements; sourceTree = "<group>"; };
|
||||
FA3E98962C725A0300E4E795 /* Preview Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = "Preview Assets.xcassets"; sourceTree = "<group>"; };
|
||||
FA3E989C2C725A0300E4E795 /* astraTests.xctest */ = {isa = PBXFileReference; explicitFileType = wrapper.cfbundle; includeInIndex = 0; path = astraTests.xctest; sourceTree = BUILT_PRODUCTS_DIR; };
|
||||
FA3E98A02C725A0300E4E795 /* astraTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = astraTests.swift; sourceTree = "<group>"; };
|
||||
FA3E98A62C725A0300E4E795 /* astraUITests.xctest */ = {isa = PBXFileReference; explicitFileType = wrapper.cfbundle; includeInIndex = 0; path = astraUITests.xctest; sourceTree = BUILT_PRODUCTS_DIR; };
|
||||
FA3E98AA2C725A0300E4E795 /* astraUITests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = astraUITests.swift; sourceTree = "<group>"; };
|
||||
FA3E98AC2C725A0300E4E795 /* astraUITestsLaunchTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = astraUITestsLaunchTests.swift; sourceTree = "<group>"; };
|
||||
/* End PBXFileReference section */
|
||||
|
||||
/* Begin PBXFrameworksBuildPhase section */
|
||||
FA3E98882C725A0200E4E795 /* Frameworks */ = {
|
||||
isa = PBXFrameworksBuildPhase;
|
||||
buildActionMask = 2147483647;
|
||||
files = (
|
||||
FA3E98BB2C725BF800E4E795 /* WhisperKit in Frameworks */,
|
||||
);
|
||||
runOnlyForDeploymentPostprocessing = 0;
|
||||
};
|
||||
FA3E98992C725A0300E4E795 /* Frameworks */ = {
|
||||
isa = PBXFrameworksBuildPhase;
|
||||
buildActionMask = 2147483647;
|
||||
files = (
|
||||
);
|
||||
runOnlyForDeploymentPostprocessing = 0;
|
||||
};
|
||||
FA3E98A32C725A0300E4E795 /* Frameworks */ = {
|
||||
isa = PBXFrameworksBuildPhase;
|
||||
buildActionMask = 2147483647;
|
||||
files = (
|
||||
);
|
||||
runOnlyForDeploymentPostprocessing = 0;
|
||||
};
|
||||
/* End PBXFrameworksBuildPhase section */
|
||||
|
||||
/* Begin PBXGroup section */
|
||||
FA3E98822C725A0200E4E795 = {
|
||||
isa = PBXGroup;
|
||||
children = (
|
||||
FA3E988D2C725A0200E4E795 /* astra */,
|
||||
FA3E989F2C725A0300E4E795 /* astraTests */,
|
||||
FA3E98A92C725A0300E4E795 /* astraUITests */,
|
||||
FA3E988C2C725A0200E4E795 /* Products */,
|
||||
);
|
||||
sourceTree = "<group>";
|
||||
};
|
||||
FA3E988C2C725A0200E4E795 /* Products */ = {
|
||||
isa = PBXGroup;
|
||||
children = (
|
||||
FA3E988B2C725A0200E4E795 /* astra.app */,
|
||||
FA3E989C2C725A0300E4E795 /* astraTests.xctest */,
|
||||
FA3E98A62C725A0300E4E795 /* astraUITests.xctest */,
|
||||
);
|
||||
name = Products;
|
||||
sourceTree = "<group>";
|
||||
};
|
||||
FA3E988D2C725A0200E4E795 /* astra */ = {
|
||||
isa = PBXGroup;
|
||||
children = (
|
||||
FA3E988E2C725A0200E4E795 /* astraApp.swift */,
|
||||
FA3E98902C725A0200E4E795 /* ContentView.swift */,
|
||||
FA3E98922C725A0300E4E795 /* Assets.xcassets */,
|
||||
FA3E98942C725A0300E4E795 /* astra.entitlements */,
|
||||
FA3E98952C725A0300E4E795 /* Preview Content */,
|
||||
);
|
||||
path = astra;
|
||||
sourceTree = "<group>";
|
||||
};
|
||||
FA3E98952C725A0300E4E795 /* Preview Content */ = {
|
||||
isa = PBXGroup;
|
||||
children = (
|
||||
FA3E98962C725A0300E4E795 /* Preview Assets.xcassets */,
|
||||
);
|
||||
path = "Preview Content";
|
||||
sourceTree = "<group>";
|
||||
};
|
||||
FA3E989F2C725A0300E4E795 /* astraTests */ = {
|
||||
isa = PBXGroup;
|
||||
children = (
|
||||
FA3E98A02C725A0300E4E795 /* astraTests.swift */,
|
||||
);
|
||||
path = astraTests;
|
||||
sourceTree = "<group>";
|
||||
};
|
||||
FA3E98A92C725A0300E4E795 /* astraUITests */ = {
|
||||
isa = PBXGroup;
|
||||
children = (
|
||||
FA3E98AA2C725A0300E4E795 /* astraUITests.swift */,
|
||||
FA3E98AC2C725A0300E4E795 /* astraUITestsLaunchTests.swift */,
|
||||
);
|
||||
path = astraUITests;
|
||||
sourceTree = "<group>";
|
||||
};
|
||||
/* End PBXGroup section */
|
||||
|
||||
/* Begin PBXNativeTarget section */
|
||||
FA3E988A2C725A0200E4E795 /* astra */ = {
|
||||
isa = PBXNativeTarget;
|
||||
buildConfigurationList = FA3E98B02C725A0300E4E795 /* Build configuration list for PBXNativeTarget "astra" */;
|
||||
buildPhases = (
|
||||
FA3E98872C725A0200E4E795 /* Sources */,
|
||||
FA3E98882C725A0200E4E795 /* Frameworks */,
|
||||
FA3E98892C725A0200E4E795 /* Resources */,
|
||||
);
|
||||
buildRules = (
|
||||
);
|
||||
dependencies = (
|
||||
);
|
||||
name = astra;
|
||||
packageProductDependencies = (
|
||||
FA3E98BA2C725BF800E4E795 /* WhisperKit */,
|
||||
);
|
||||
productName = astra;
|
||||
productReference = FA3E988B2C725A0200E4E795 /* astra.app */;
|
||||
productType = "com.apple.product-type.application";
|
||||
};
|
||||
FA3E989B2C725A0300E4E795 /* astraTests */ = {
|
||||
isa = PBXNativeTarget;
|
||||
buildConfigurationList = FA3E98B32C725A0300E4E795 /* Build configuration list for PBXNativeTarget "astraTests" */;
|
||||
buildPhases = (
|
||||
FA3E98982C725A0300E4E795 /* Sources */,
|
||||
FA3E98992C725A0300E4E795 /* Frameworks */,
|
||||
FA3E989A2C725A0300E4E795 /* Resources */,
|
||||
);
|
||||
buildRules = (
|
||||
);
|
||||
dependencies = (
|
||||
FA3E989E2C725A0300E4E795 /* PBXTargetDependency */,
|
||||
);
|
||||
name = astraTests;
|
||||
productName = astraTests;
|
||||
productReference = FA3E989C2C725A0300E4E795 /* astraTests.xctest */;
|
||||
productType = "com.apple.product-type.bundle.unit-test";
|
||||
};
|
||||
FA3E98A52C725A0300E4E795 /* astraUITests */ = {
|
||||
isa = PBXNativeTarget;
|
||||
buildConfigurationList = FA3E98B62C725A0300E4E795 /* Build configuration list for PBXNativeTarget "astraUITests" */;
|
||||
buildPhases = (
|
||||
FA3E98A22C725A0300E4E795 /* Sources */,
|
||||
FA3E98A32C725A0300E4E795 /* Frameworks */,
|
||||
FA3E98A42C725A0300E4E795 /* Resources */,
|
||||
);
|
||||
buildRules = (
|
||||
);
|
||||
dependencies = (
|
||||
FA3E98A82C725A0300E4E795 /* PBXTargetDependency */,
|
||||
);
|
||||
name = astraUITests;
|
||||
productName = astraUITests;
|
||||
productReference = FA3E98A62C725A0300E4E795 /* astraUITests.xctest */;
|
||||
productType = "com.apple.product-type.bundle.ui-testing";
|
||||
};
|
||||
/* End PBXNativeTarget section */
|
||||
|
||||
/* Begin PBXProject section */
|
||||
FA3E98832C725A0200E4E795 /* Project object */ = {
|
||||
isa = PBXProject;
|
||||
attributes = {
|
||||
BuildIndependentTargetsInParallel = 1;
|
||||
LastSwiftUpdateCheck = 1540;
|
||||
LastUpgradeCheck = 1540;
|
||||
TargetAttributes = {
|
||||
FA3E988A2C725A0200E4E795 = {
|
||||
CreatedOnToolsVersion = 15.4;
|
||||
};
|
||||
FA3E989B2C725A0300E4E795 = {
|
||||
CreatedOnToolsVersion = 15.4;
|
||||
TestTargetID = FA3E988A2C725A0200E4E795;
|
||||
};
|
||||
FA3E98A52C725A0300E4E795 = {
|
||||
CreatedOnToolsVersion = 15.4;
|
||||
TestTargetID = FA3E988A2C725A0200E4E795;
|
||||
};
|
||||
};
|
||||
};
|
||||
buildConfigurationList = FA3E98862C725A0200E4E795 /* Build configuration list for PBXProject "astra" */;
|
||||
compatibilityVersion = "Xcode 14.0";
|
||||
developmentRegion = en;
|
||||
hasScannedForEncodings = 0;
|
||||
knownRegions = (
|
||||
en,
|
||||
Base,
|
||||
);
|
||||
mainGroup = FA3E98822C725A0200E4E795;
|
||||
packageReferences = (
|
||||
FA3E98B92C725BF800E4E795 /* XCRemoteSwiftPackageReference "whisperkit" */,
|
||||
);
|
||||
productRefGroup = FA3E988C2C725A0200E4E795 /* Products */;
|
||||
projectDirPath = "";
|
||||
projectRoot = "";
|
||||
targets = (
|
||||
FA3E988A2C725A0200E4E795 /* astra */,
|
||||
FA3E989B2C725A0300E4E795 /* astraTests */,
|
||||
FA3E98A52C725A0300E4E795 /* astraUITests */,
|
||||
);
|
||||
};
|
||||
/* End PBXProject section */
|
||||
|
||||
/* Begin PBXResourcesBuildPhase section */
|
||||
FA3E98892C725A0200E4E795 /* Resources */ = {
|
||||
isa = PBXResourcesBuildPhase;
|
||||
buildActionMask = 2147483647;
|
||||
files = (
|
||||
FA3E98972C725A0300E4E795 /* Preview Assets.xcassets in Resources */,
|
||||
FA3E98932C725A0300E4E795 /* Assets.xcassets in Resources */,
|
||||
);
|
||||
runOnlyForDeploymentPostprocessing = 0;
|
||||
};
|
||||
FA3E989A2C725A0300E4E795 /* Resources */ = {
|
||||
isa = PBXResourcesBuildPhase;
|
||||
buildActionMask = 2147483647;
|
||||
files = (
|
||||
);
|
||||
runOnlyForDeploymentPostprocessing = 0;
|
||||
};
|
||||
FA3E98A42C725A0300E4E795 /* Resources */ = {
|
||||
isa = PBXResourcesBuildPhase;
|
||||
buildActionMask = 2147483647;
|
||||
files = (
|
||||
);
|
||||
runOnlyForDeploymentPostprocessing = 0;
|
||||
};
|
||||
/* End PBXResourcesBuildPhase section */
|
||||
|
||||
/* Begin PBXSourcesBuildPhase section */
|
||||
FA3E98872C725A0200E4E795 /* Sources */ = {
|
||||
isa = PBXSourcesBuildPhase;
|
||||
buildActionMask = 2147483647;
|
||||
files = (
|
||||
FA3E98912C725A0200E4E795 /* ContentView.swift in Sources */,
|
||||
FA3E988F2C725A0200E4E795 /* astraApp.swift in Sources */,
|
||||
);
|
||||
runOnlyForDeploymentPostprocessing = 0;
|
||||
};
|
||||
FA3E98982C725A0300E4E795 /* Sources */ = {
|
||||
isa = PBXSourcesBuildPhase;
|
||||
buildActionMask = 2147483647;
|
||||
files = (
|
||||
FA3E98A12C725A0300E4E795 /* astraTests.swift in Sources */,
|
||||
);
|
||||
runOnlyForDeploymentPostprocessing = 0;
|
||||
};
|
||||
FA3E98A22C725A0300E4E795 /* Sources */ = {
|
||||
isa = PBXSourcesBuildPhase;
|
||||
buildActionMask = 2147483647;
|
||||
files = (
|
||||
FA3E98AD2C725A0300E4E795 /* astraUITestsLaunchTests.swift in Sources */,
|
||||
FA3E98AB2C725A0300E4E795 /* astraUITests.swift in Sources */,
|
||||
);
|
||||
runOnlyForDeploymentPostprocessing = 0;
|
||||
};
|
||||
/* End PBXSourcesBuildPhase section */
|
||||
|
||||
/* Begin PBXTargetDependency section */
|
||||
FA3E989E2C725A0300E4E795 /* PBXTargetDependency */ = {
|
||||
isa = PBXTargetDependency;
|
||||
target = FA3E988A2C725A0200E4E795 /* astra */;
|
||||
targetProxy = FA3E989D2C725A0300E4E795 /* PBXContainerItemProxy */;
|
||||
};
|
||||
FA3E98A82C725A0300E4E795 /* PBXTargetDependency */ = {
|
||||
isa = PBXTargetDependency;
|
||||
target = FA3E988A2C725A0200E4E795 /* astra */;
|
||||
targetProxy = FA3E98A72C725A0300E4E795 /* PBXContainerItemProxy */;
|
||||
};
|
||||
/* End PBXTargetDependency section */
|
||||
|
||||
/* Begin XCBuildConfiguration section */
|
||||
FA3E98AE2C725A0300E4E795 /* Debug */ = {
|
||||
isa = XCBuildConfiguration;
|
||||
buildSettings = {
|
||||
ALWAYS_SEARCH_USER_PATHS = NO;
|
||||
ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES;
|
||||
CLANG_ANALYZER_NONNULL = YES;
|
||||
CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
|
||||
CLANG_CXX_LANGUAGE_STANDARD = "gnu++20";
|
||||
CLANG_ENABLE_MODULES = YES;
|
||||
CLANG_ENABLE_OBJC_ARC = YES;
|
||||
CLANG_ENABLE_OBJC_WEAK = YES;
|
||||
CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
|
||||
CLANG_WARN_BOOL_CONVERSION = YES;
|
||||
CLANG_WARN_COMMA = YES;
|
||||
CLANG_WARN_CONSTANT_CONVERSION = YES;
|
||||
CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
|
||||
CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
|
||||
CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
|
||||
CLANG_WARN_EMPTY_BODY = YES;
|
||||
CLANG_WARN_ENUM_CONVERSION = YES;
|
||||
CLANG_WARN_INFINITE_RECURSION = YES;
|
||||
CLANG_WARN_INT_CONVERSION = YES;
|
||||
CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
|
||||
CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
|
||||
CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
|
||||
CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
|
||||
CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;
|
||||
CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
|
||||
CLANG_WARN_STRICT_PROTOTYPES = YES;
|
||||
CLANG_WARN_SUSPICIOUS_MOVE = YES;
|
||||
CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
|
||||
CLANG_WARN_UNREACHABLE_CODE = YES;
|
||||
CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
|
||||
COPY_PHASE_STRIP = NO;
|
||||
DEBUG_INFORMATION_FORMAT = dwarf;
|
||||
ENABLE_STRICT_OBJC_MSGSEND = YES;
|
||||
ENABLE_TESTABILITY = YES;
|
||||
ENABLE_USER_SCRIPT_SANDBOXING = YES;
|
||||
GCC_C_LANGUAGE_STANDARD = gnu17;
|
||||
GCC_DYNAMIC_NO_PIC = NO;
|
||||
GCC_NO_COMMON_BLOCKS = YES;
|
||||
GCC_OPTIMIZATION_LEVEL = 0;
|
||||
GCC_PREPROCESSOR_DEFINITIONS = (
|
||||
"DEBUG=1",
|
||||
"$(inherited)",
|
||||
);
|
||||
GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
|
||||
GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
|
||||
GCC_WARN_UNDECLARED_SELECTOR = YES;
|
||||
GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
|
||||
GCC_WARN_UNUSED_FUNCTION = YES;
|
||||
GCC_WARN_UNUSED_VARIABLE = YES;
|
||||
LOCALIZATION_PREFERS_STRING_CATALOGS = YES;
|
||||
MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE;
|
||||
MTL_FAST_MATH = YES;
|
||||
ONLY_ACTIVE_ARCH = YES;
|
||||
SWIFT_ACTIVE_COMPILATION_CONDITIONS = "DEBUG $(inherited)";
|
||||
SWIFT_OPTIMIZATION_LEVEL = "-Onone";
|
||||
};
|
||||
name = Debug;
|
||||
};
|
||||
FA3E98AF2C725A0300E4E795 /* Release */ = {
|
||||
isa = XCBuildConfiguration;
|
||||
buildSettings = {
|
||||
ALWAYS_SEARCH_USER_PATHS = NO;
|
||||
ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES;
|
||||
CLANG_ANALYZER_NONNULL = YES;
|
||||
CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
|
||||
CLANG_CXX_LANGUAGE_STANDARD = "gnu++20";
|
||||
CLANG_ENABLE_MODULES = YES;
|
||||
CLANG_ENABLE_OBJC_ARC = YES;
|
||||
CLANG_ENABLE_OBJC_WEAK = YES;
|
||||
CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
|
||||
CLANG_WARN_BOOL_CONVERSION = YES;
|
||||
CLANG_WARN_COMMA = YES;
|
||||
CLANG_WARN_CONSTANT_CONVERSION = YES;
|
||||
CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
|
||||
CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
|
||||
CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
|
||||
CLANG_WARN_EMPTY_BODY = YES;
|
||||
CLANG_WARN_ENUM_CONVERSION = YES;
|
||||
CLANG_WARN_INFINITE_RECURSION = YES;
|
||||
CLANG_WARN_INT_CONVERSION = YES;
|
||||
CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
|
||||
CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
|
||||
CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
|
||||
CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
|
||||
CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;
|
||||
CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
|
||||
CLANG_WARN_STRICT_PROTOTYPES = YES;
|
||||
CLANG_WARN_SUSPICIOUS_MOVE = YES;
|
||||
CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
|
||||
CLANG_WARN_UNREACHABLE_CODE = YES;
|
||||
CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
|
||||
COPY_PHASE_STRIP = NO;
|
||||
DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym";
|
||||
ENABLE_NS_ASSERTIONS = NO;
|
||||
ENABLE_STRICT_OBJC_MSGSEND = YES;
|
||||
ENABLE_USER_SCRIPT_SANDBOXING = YES;
|
||||
GCC_C_LANGUAGE_STANDARD = gnu17;
|
||||
GCC_NO_COMMON_BLOCKS = YES;
|
||||
GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
|
||||
GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
|
||||
GCC_WARN_UNDECLARED_SELECTOR = YES;
|
||||
GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
|
||||
GCC_WARN_UNUSED_FUNCTION = YES;
|
||||
GCC_WARN_UNUSED_VARIABLE = YES;
|
||||
LOCALIZATION_PREFERS_STRING_CATALOGS = YES;
|
||||
MTL_ENABLE_DEBUG_INFO = NO;
|
||||
MTL_FAST_MATH = YES;
|
||||
SWIFT_COMPILATION_MODE = wholemodule;
|
||||
};
|
||||
name = Release;
|
||||
};
|
||||
FA3E98B12C725A0300E4E795 /* Debug */ = {
|
||||
isa = XCBuildConfiguration;
|
||||
buildSettings = {
|
||||
ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
|
||||
ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor;
|
||||
CODE_SIGN_ENTITLEMENTS = astra/astra.entitlements;
|
||||
CODE_SIGN_STYLE = Automatic;
|
||||
CURRENT_PROJECT_VERSION = 1;
|
||||
DEVELOPMENT_ASSET_PATHS = "\"astra/Preview Content\"";
|
||||
DEVELOPMENT_TEAM = 8NFAS2P4ND;
|
||||
ENABLE_HARDENED_RUNTIME = YES;
|
||||
ENABLE_PREVIEWS = YES;
|
||||
GENERATE_INFOPLIST_FILE = YES;
|
||||
INFOPLIST_KEY_NSCameraUsageDescription = "Capture from camera to send to vision model";
|
||||
INFOPLIST_KEY_NSMicrophoneUsageDescription = "Uses your microphone for transcribing audio";
|
||||
"INFOPLIST_KEY_UIApplicationSceneManifest_Generation[sdk=iphoneos*]" = YES;
|
||||
"INFOPLIST_KEY_UIApplicationSceneManifest_Generation[sdk=iphonesimulator*]" = YES;
|
||||
"INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents[sdk=iphoneos*]" = YES;
|
||||
"INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents[sdk=iphonesimulator*]" = YES;
|
||||
"INFOPLIST_KEY_UILaunchScreen_Generation[sdk=iphoneos*]" = YES;
|
||||
"INFOPLIST_KEY_UILaunchScreen_Generation[sdk=iphonesimulator*]" = YES;
|
||||
"INFOPLIST_KEY_UIStatusBarStyle[sdk=iphoneos*]" = UIStatusBarStyleDefault;
|
||||
"INFOPLIST_KEY_UIStatusBarStyle[sdk=iphonesimulator*]" = UIStatusBarStyleDefault;
|
||||
INFOPLIST_KEY_UISupportedInterfaceOrientations_iPad = "UIInterfaceOrientationPortrait UIInterfaceOrientationPortraitUpsideDown UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight";
|
||||
INFOPLIST_KEY_UISupportedInterfaceOrientations_iPhone = "UIInterfaceOrientationPortrait UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight";
|
||||
IPHONEOS_DEPLOYMENT_TARGET = 17.5;
|
||||
LD_RUNPATH_SEARCH_PATHS = "@executable_path/Frameworks";
|
||||
"LD_RUNPATH_SEARCH_PATHS[sdk=macosx*]" = "@executable_path/../Frameworks";
|
||||
MACOSX_DEPLOYMENT_TARGET = 14.4;
|
||||
MARKETING_VERSION = 1.0;
|
||||
PRODUCT_BUNDLE_IDENTIFIER = net.exolabs.astra;
|
||||
PRODUCT_NAME = "$(TARGET_NAME)";
|
||||
SDKROOT = auto;
|
||||
SUPPORTED_PLATFORMS = "iphoneos iphonesimulator macosx";
|
||||
SWIFT_EMIT_LOC_STRINGS = YES;
|
||||
SWIFT_VERSION = 5.0;
|
||||
TARGETED_DEVICE_FAMILY = "1,2";
|
||||
};
|
||||
name = Debug;
|
||||
};
|
||||
FA3E98B22C725A0300E4E795 /* Release */ = {
|
||||
isa = XCBuildConfiguration;
|
||||
buildSettings = {
|
||||
ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
|
||||
ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor;
|
||||
CODE_SIGN_ENTITLEMENTS = astra/astra.entitlements;
|
||||
CODE_SIGN_STYLE = Automatic;
|
||||
CURRENT_PROJECT_VERSION = 1;
|
||||
DEVELOPMENT_ASSET_PATHS = "\"astra/Preview Content\"";
|
||||
DEVELOPMENT_TEAM = 8NFAS2P4ND;
|
||||
ENABLE_HARDENED_RUNTIME = YES;
|
||||
ENABLE_PREVIEWS = YES;
|
||||
GENERATE_INFOPLIST_FILE = YES;
|
||||
INFOPLIST_KEY_NSCameraUsageDescription = "Capture from camera to send to vision model";
|
||||
INFOPLIST_KEY_NSMicrophoneUsageDescription = "Uses your microphone for transcribing audio";
|
||||
"INFOPLIST_KEY_UIApplicationSceneManifest_Generation[sdk=iphoneos*]" = YES;
|
||||
"INFOPLIST_KEY_UIApplicationSceneManifest_Generation[sdk=iphonesimulator*]" = YES;
|
||||
"INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents[sdk=iphoneos*]" = YES;
|
||||
"INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents[sdk=iphonesimulator*]" = YES;
|
||||
"INFOPLIST_KEY_UILaunchScreen_Generation[sdk=iphoneos*]" = YES;
|
||||
"INFOPLIST_KEY_UILaunchScreen_Generation[sdk=iphonesimulator*]" = YES;
|
||||
"INFOPLIST_KEY_UIStatusBarStyle[sdk=iphoneos*]" = UIStatusBarStyleDefault;
|
||||
"INFOPLIST_KEY_UIStatusBarStyle[sdk=iphonesimulator*]" = UIStatusBarStyleDefault;
|
||||
INFOPLIST_KEY_UISupportedInterfaceOrientations_iPad = "UIInterfaceOrientationPortrait UIInterfaceOrientationPortraitUpsideDown UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight";
|
||||
INFOPLIST_KEY_UISupportedInterfaceOrientations_iPhone = "UIInterfaceOrientationPortrait UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight";
|
||||
IPHONEOS_DEPLOYMENT_TARGET = 17.5;
|
||||
LD_RUNPATH_SEARCH_PATHS = "@executable_path/Frameworks";
|
||||
"LD_RUNPATH_SEARCH_PATHS[sdk=macosx*]" = "@executable_path/../Frameworks";
|
||||
MACOSX_DEPLOYMENT_TARGET = 14.4;
|
||||
MARKETING_VERSION = 1.0;
|
||||
PRODUCT_BUNDLE_IDENTIFIER = net.exolabs.astra;
|
||||
PRODUCT_NAME = "$(TARGET_NAME)";
|
||||
SDKROOT = auto;
|
||||
SUPPORTED_PLATFORMS = "iphoneos iphonesimulator macosx";
|
||||
SWIFT_EMIT_LOC_STRINGS = YES;
|
||||
SWIFT_VERSION = 5.0;
|
||||
TARGETED_DEVICE_FAMILY = "1,2";
|
||||
};
|
||||
name = Release;
|
||||
};
|
||||
FA3E98B42C725A0300E4E795 /* Debug */ = {
|
||||
isa = XCBuildConfiguration;
|
||||
buildSettings = {
|
||||
ALWAYS_EMBED_SWIFT_STANDARD_LIBRARIES = YES;
|
||||
BUNDLE_LOADER = "$(TEST_HOST)";
|
||||
CODE_SIGN_STYLE = Automatic;
|
||||
CURRENT_PROJECT_VERSION = 1;
|
||||
DEVELOPMENT_TEAM = 8NFAS2P4ND;
|
||||
GENERATE_INFOPLIST_FILE = YES;
|
||||
IPHONEOS_DEPLOYMENT_TARGET = 17.5;
|
||||
MACOSX_DEPLOYMENT_TARGET = 14.4;
|
||||
MARKETING_VERSION = 1.0;
|
||||
PRODUCT_BUNDLE_IDENTIFIER = net.exolabs.astraTests;
|
||||
PRODUCT_NAME = "$(TARGET_NAME)";
|
||||
SDKROOT = auto;
|
||||
SUPPORTED_PLATFORMS = "iphoneos iphonesimulator macosx";
|
||||
SWIFT_EMIT_LOC_STRINGS = NO;
|
||||
SWIFT_VERSION = 5.0;
|
||||
TARGETED_DEVICE_FAMILY = "1,2";
|
||||
TEST_HOST = "$(BUILT_PRODUCTS_DIR)/astra.app/$(BUNDLE_EXECUTABLE_FOLDER_PATH)/astra";
|
||||
};
|
||||
name = Debug;
|
||||
};
|
||||
FA3E98B52C725A0300E4E795 /* Release */ = {
|
||||
isa = XCBuildConfiguration;
|
||||
buildSettings = {
|
||||
ALWAYS_EMBED_SWIFT_STANDARD_LIBRARIES = YES;
|
||||
BUNDLE_LOADER = "$(TEST_HOST)";
|
||||
CODE_SIGN_STYLE = Automatic;
|
||||
CURRENT_PROJECT_VERSION = 1;
|
||||
DEVELOPMENT_TEAM = 8NFAS2P4ND;
|
||||
GENERATE_INFOPLIST_FILE = YES;
|
||||
IPHONEOS_DEPLOYMENT_TARGET = 17.5;
|
||||
MACOSX_DEPLOYMENT_TARGET = 14.4;
|
||||
MARKETING_VERSION = 1.0;
|
||||
PRODUCT_BUNDLE_IDENTIFIER = net.exolabs.astraTests;
|
||||
PRODUCT_NAME = "$(TARGET_NAME)";
|
||||
SDKROOT = auto;
|
||||
SUPPORTED_PLATFORMS = "iphoneos iphonesimulator macosx";
|
||||
SWIFT_EMIT_LOC_STRINGS = NO;
|
||||
SWIFT_VERSION = 5.0;
|
||||
TARGETED_DEVICE_FAMILY = "1,2";
|
||||
TEST_HOST = "$(BUILT_PRODUCTS_DIR)/astra.app/$(BUNDLE_EXECUTABLE_FOLDER_PATH)/astra";
|
||||
};
|
||||
name = Release;
|
||||
};
|
||||
FA3E98B72C725A0300E4E795 /* Debug */ = {
|
||||
isa = XCBuildConfiguration;
|
||||
buildSettings = {
|
||||
ALWAYS_EMBED_SWIFT_STANDARD_LIBRARIES = YES;
|
||||
CODE_SIGN_STYLE = Automatic;
|
||||
CURRENT_PROJECT_VERSION = 1;
|
||||
DEVELOPMENT_TEAM = 8NFAS2P4ND;
|
||||
GENERATE_INFOPLIST_FILE = YES;
|
||||
IPHONEOS_DEPLOYMENT_TARGET = 17.5;
|
||||
MACOSX_DEPLOYMENT_TARGET = 14.4;
|
||||
MARKETING_VERSION = 1.0;
|
||||
PRODUCT_BUNDLE_IDENTIFIER = net.exolabs.astraUITests;
|
||||
PRODUCT_NAME = "$(TARGET_NAME)";
|
||||
SDKROOT = auto;
|
||||
SUPPORTED_PLATFORMS = "iphoneos iphonesimulator macosx";
|
||||
SWIFT_EMIT_LOC_STRINGS = NO;
|
||||
SWIFT_VERSION = 5.0;
|
||||
TARGETED_DEVICE_FAMILY = "1,2";
|
||||
TEST_TARGET_NAME = astra;
|
||||
};
|
||||
name = Debug;
|
||||
};
|
||||
FA3E98B82C725A0300E4E795 /* Release */ = {
|
||||
isa = XCBuildConfiguration;
|
||||
buildSettings = {
|
||||
ALWAYS_EMBED_SWIFT_STANDARD_LIBRARIES = YES;
|
||||
CODE_SIGN_STYLE = Automatic;
|
||||
CURRENT_PROJECT_VERSION = 1;
|
||||
DEVELOPMENT_TEAM = 8NFAS2P4ND;
|
||||
GENERATE_INFOPLIST_FILE = YES;
|
||||
IPHONEOS_DEPLOYMENT_TARGET = 17.5;
|
||||
MACOSX_DEPLOYMENT_TARGET = 14.4;
|
||||
MARKETING_VERSION = 1.0;
|
||||
PRODUCT_BUNDLE_IDENTIFIER = net.exolabs.astraUITests;
|
||||
PRODUCT_NAME = "$(TARGET_NAME)";
|
||||
SDKROOT = auto;
|
||||
SUPPORTED_PLATFORMS = "iphoneos iphonesimulator macosx";
|
||||
SWIFT_EMIT_LOC_STRINGS = NO;
|
||||
SWIFT_VERSION = 5.0;
|
||||
TARGETED_DEVICE_FAMILY = "1,2";
|
||||
TEST_TARGET_NAME = astra;
|
||||
};
|
||||
name = Release;
|
||||
};
|
||||
/* End XCBuildConfiguration section */
|
||||
|
||||
/* Begin XCConfigurationList section */
|
||||
FA3E98862C725A0200E4E795 /* Build configuration list for PBXProject "astra" */ = {
|
||||
isa = XCConfigurationList;
|
||||
buildConfigurations = (
|
||||
FA3E98AE2C725A0300E4E795 /* Debug */,
|
||||
FA3E98AF2C725A0300E4E795 /* Release */,
|
||||
);
|
||||
defaultConfigurationIsVisible = 0;
|
||||
defaultConfigurationName = Release;
|
||||
};
|
||||
FA3E98B02C725A0300E4E795 /* Build configuration list for PBXNativeTarget "astra" */ = {
|
||||
isa = XCConfigurationList;
|
||||
buildConfigurations = (
|
||||
FA3E98B12C725A0300E4E795 /* Debug */,
|
||||
FA3E98B22C725A0300E4E795 /* Release */,
|
||||
);
|
||||
defaultConfigurationIsVisible = 0;
|
||||
defaultConfigurationName = Release;
|
||||
};
|
||||
FA3E98B32C725A0300E4E795 /* Build configuration list for PBXNativeTarget "astraTests" */ = {
|
||||
isa = XCConfigurationList;
|
||||
buildConfigurations = (
|
||||
FA3E98B42C725A0300E4E795 /* Debug */,
|
||||
FA3E98B52C725A0300E4E795 /* Release */,
|
||||
);
|
||||
defaultConfigurationIsVisible = 0;
|
||||
defaultConfigurationName = Release;
|
||||
};
|
||||
FA3E98B62C725A0300E4E795 /* Build configuration list for PBXNativeTarget "astraUITests" */ = {
|
||||
isa = XCConfigurationList;
|
||||
buildConfigurations = (
|
||||
FA3E98B72C725A0300E4E795 /* Debug */,
|
||||
FA3E98B82C725A0300E4E795 /* Release */,
|
||||
);
|
||||
defaultConfigurationIsVisible = 0;
|
||||
defaultConfigurationName = Release;
|
||||
};
|
||||
/* End XCConfigurationList section */
|
||||
|
||||
/* Begin XCRemoteSwiftPackageReference section */
|
||||
FA3E98B92C725BF800E4E795 /* XCRemoteSwiftPackageReference "whisperkit" */ = {
|
||||
isa = XCRemoteSwiftPackageReference;
|
||||
repositoryURL = "https://github.com/argmaxinc/whisperkit";
|
||||
requirement = {
|
||||
branch = main;
|
||||
kind = branch;
|
||||
};
|
||||
};
|
||||
/* End XCRemoteSwiftPackageReference section */
|
||||
|
||||
/* Begin XCSwiftPackageProductDependency section */
|
||||
FA3E98BA2C725BF800E4E795 /* WhisperKit */ = {
|
||||
isa = XCSwiftPackageProductDependency;
|
||||
package = FA3E98B92C725BF800E4E795 /* XCRemoteSwiftPackageReference "whisperkit" */;
|
||||
productName = WhisperKit;
|
||||
};
|
||||
/* End XCSwiftPackageProductDependency section */
|
||||
};
|
||||
rootObject = FA3E98832C725A0200E4E795 /* Project object */;
|
||||
}
|
||||
@@ -1,7 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<Workspace
|
||||
version = "1.0">
|
||||
<FileRef
|
||||
location = "self:">
|
||||
</FileRef>
|
||||
</Workspace>
|
||||
@@ -1,8 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
||||
<plist version="1.0">
|
||||
<dict>
|
||||
<key>IDEDidComputeMac32BitWarning</key>
|
||||
<true/>
|
||||
</dict>
|
||||
</plist>
|
||||
@@ -1,33 +0,0 @@
|
||||
{
|
||||
"originHash" : "8f61689e55c5551e76f2c686d145061dc1fa621a58cbca576565ebfabc15c894",
|
||||
"pins" : [
|
||||
{
|
||||
"identity" : "swift-argument-parser",
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/apple/swift-argument-parser.git",
|
||||
"state" : {
|
||||
"revision" : "c8ed701b513cf5177118a175d85fbbbcd707ab41",
|
||||
"version" : "1.3.0"
|
||||
}
|
||||
},
|
||||
{
|
||||
"identity" : "swift-transformers",
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/huggingface/swift-transformers.git",
|
||||
"state" : {
|
||||
"revision" : "74b94211bdc741694ed7e700a1104c72e5ba68fe",
|
||||
"version" : "0.1.7"
|
||||
}
|
||||
},
|
||||
{
|
||||
"identity" : "whisperkit",
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/argmaxinc/whisperkit",
|
||||
"state" : {
|
||||
"branch" : "main",
|
||||
"revision" : "59aaa4e5f211622f9a5e133440220d9974641d3b"
|
||||
}
|
||||
}
|
||||
],
|
||||
"version" : 3
|
||||
}
|
||||
@@ -1,11 +0,0 @@
|
||||
{
|
||||
"colors" : [
|
||||
{
|
||||
"idiom" : "universal"
|
||||
}
|
||||
],
|
||||
"info" : {
|
||||
"author" : "xcode",
|
||||
"version" : 1
|
||||
}
|
||||
}
|
||||
@@ -1,63 +0,0 @@
|
||||
{
|
||||
"images" : [
|
||||
{
|
||||
"idiom" : "universal",
|
||||
"platform" : "ios",
|
||||
"size" : "1024x1024"
|
||||
},
|
||||
{
|
||||
"idiom" : "mac",
|
||||
"scale" : "1x",
|
||||
"size" : "16x16"
|
||||
},
|
||||
{
|
||||
"idiom" : "mac",
|
||||
"scale" : "2x",
|
||||
"size" : "16x16"
|
||||
},
|
||||
{
|
||||
"idiom" : "mac",
|
||||
"scale" : "1x",
|
||||
"size" : "32x32"
|
||||
},
|
||||
{
|
||||
"idiom" : "mac",
|
||||
"scale" : "2x",
|
||||
"size" : "32x32"
|
||||
},
|
||||
{
|
||||
"idiom" : "mac",
|
||||
"scale" : "1x",
|
||||
"size" : "128x128"
|
||||
},
|
||||
{
|
||||
"idiom" : "mac",
|
||||
"scale" : "2x",
|
||||
"size" : "128x128"
|
||||
},
|
||||
{
|
||||
"idiom" : "mac",
|
||||
"scale" : "1x",
|
||||
"size" : "256x256"
|
||||
},
|
||||
{
|
||||
"idiom" : "mac",
|
||||
"scale" : "2x",
|
||||
"size" : "256x256"
|
||||
},
|
||||
{
|
||||
"idiom" : "mac",
|
||||
"scale" : "1x",
|
||||
"size" : "512x512"
|
||||
},
|
||||
{
|
||||
"idiom" : "mac",
|
||||
"scale" : "2x",
|
||||
"size" : "512x512"
|
||||
}
|
||||
],
|
||||
"info" : {
|
||||
"author" : "xcode",
|
||||
"version" : 1
|
||||
}
|
||||
}
|
||||
@@ -1,6 +0,0 @@
|
||||
{
|
||||
"info" : {
|
||||
"author" : "xcode",
|
||||
"version" : 1
|
||||
}
|
||||
}
|
||||
@@ -1,729 +0,0 @@
|
||||
import SwiftUI
|
||||
import WhisperKit
|
||||
import AVFoundation
|
||||
import Foundation
|
||||
import Combine
|
||||
import Vision
|
||||
import AVFAudio
|
||||
|
||||
actor CameraActor {
|
||||
let captureSession = AVCaptureSession()
|
||||
private let photoOutput = AVCapturePhotoOutput()
|
||||
private var isConfigured = false
|
||||
private var currentPhotoCaptureDelegate: PhotoCaptureDelegate?
|
||||
|
||||
func configure() throws {
|
||||
guard !isConfigured else {
|
||||
print("Camera already configured")
|
||||
return
|
||||
}
|
||||
|
||||
print("Starting camera configuration")
|
||||
|
||||
guard let camera = AVCaptureDevice.default(for: .video) else {
|
||||
print("No camera device available")
|
||||
throw CameraError.cameraUnavailable
|
||||
}
|
||||
|
||||
do {
|
||||
let input = try AVCaptureDeviceInput(device: camera)
|
||||
print("Camera input created successfully")
|
||||
|
||||
guard captureSession.canAddInput(input) else {
|
||||
print("Cannot add camera input to session")
|
||||
throw CameraError.cannotAddInputOutput
|
||||
}
|
||||
|
||||
guard captureSession.canAddOutput(photoOutput) else {
|
||||
print("Cannot add photo output to session")
|
||||
throw CameraError.cannotAddInputOutput
|
||||
}
|
||||
|
||||
captureSession.beginConfiguration()
|
||||
captureSession.addInput(input)
|
||||
captureSession.addOutput(photoOutput)
|
||||
captureSession.commitConfiguration()
|
||||
|
||||
print("Camera session configured successfully")
|
||||
|
||||
Task.detached { [weak self] in
|
||||
self?.captureSession.startRunning()
|
||||
print("Camera session started running")
|
||||
}
|
||||
|
||||
isConfigured = true
|
||||
print("Camera fully configured and ready")
|
||||
} catch {
|
||||
print("Error during camera configuration: \(error)")
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
func capturePhoto() async throws -> String {
|
||||
guard isConfigured else {
|
||||
throw CameraError.notConfigured
|
||||
}
|
||||
|
||||
return try await withCheckedThrowingContinuation { continuation in
|
||||
let photoSettings = AVCapturePhotoSettings()
|
||||
|
||||
let delegate = PhotoCaptureDelegate { result in
|
||||
self.currentPhotoCaptureDelegate = nil
|
||||
continuation.resume(with: result)
|
||||
}
|
||||
|
||||
self.currentPhotoCaptureDelegate = delegate
|
||||
|
||||
Task { @MainActor in
|
||||
self.photoOutput.capturePhoto(with: photoSettings, delegate: delegate)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class PhotoCaptureDelegate: NSObject, AVCapturePhotoCaptureDelegate {
|
||||
private let completionHandler: (Result<String, Error>) -> Void
|
||||
|
||||
init(completionHandler: @escaping (Result<String, Error>) -> Void) {
|
||||
self.completionHandler = completionHandler
|
||||
}
|
||||
|
||||
func photoOutput(_ output: AVCapturePhotoOutput, didFinishProcessingPhoto photo: AVCapturePhoto, error: Error?) {
|
||||
if let error = error {
|
||||
completionHandler(.failure(error))
|
||||
return
|
||||
}
|
||||
|
||||
guard let imageData = photo.fileDataRepresentation() else {
|
||||
completionHandler(.failure(CameraError.imageProcessingFailed))
|
||||
return
|
||||
}
|
||||
|
||||
let base64String = imageData.base64EncodedString()
|
||||
completionHandler(.success(base64String))
|
||||
}
|
||||
}
|
||||
|
||||
enum CameraError: Error {
|
||||
case cameraUnavailable
|
||||
case cannotAddInputOutput
|
||||
case notConfigured
|
||||
case imageProcessingFailed
|
||||
}
|
||||
|
||||
struct CameraPreview: UIViewControllerRepresentable {
|
||||
let cameraActor: CameraActor
|
||||
|
||||
func makeUIViewController(context: Context) -> UIViewController {
|
||||
let viewController = UIViewController()
|
||||
let previewLayer = AVCaptureVideoPreviewLayer(session: cameraActor.captureSession)
|
||||
previewLayer.videoGravity = .resizeAspectFill
|
||||
viewController.view.layer.addSublayer(previewLayer)
|
||||
previewLayer.frame = viewController.view.bounds
|
||||
return viewController
|
||||
}
|
||||
|
||||
func updateUIViewController(_ uiViewController: UIViewController, context: Context) {
|
||||
if let previewLayer = uiViewController.view.layer.sublayers?.first as? AVCaptureVideoPreviewLayer {
|
||||
previewLayer.frame = uiViewController.view.bounds
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct ContentView: View {
|
||||
@State private var whisperKit: WhisperKit?
|
||||
@State private var isListening = false
|
||||
@State private var currentText = ""
|
||||
@State private var bufferSeconds: Double = 0.5 // or whatever the actual buffer size is
|
||||
@State private var modelState: ModelState = .unloaded
|
||||
|
||||
@AppStorage("selectedModel") private var selectedModel: String = "large-v3"
|
||||
@AppStorage("selectedLanguage") private var selectedLanguage: String = "english"
|
||||
@AppStorage("selectedTask") private var selectedTask: String = "transcribe"
|
||||
|
||||
@State private var isRecordingMemo = false
|
||||
@State private var currentMemo = ""
|
||||
@State private var lastVoiceActivityTime = Date()
|
||||
@State private var silenceTimer: Timer?
|
||||
@State private var voiceActivityThreshold: Float = 0.40
|
||||
@State private var silenceTimeThreshold = 1.0
|
||||
@State private var debugText = ""
|
||||
@State private var apiEndpoint = "http://192.168.212.74:52415/v1/chat/completions"
|
||||
@State private var audioBuffer: [Float] = []
|
||||
@State private var bufferDuration: Double = 0.5 // 0.5 seconds buffer
|
||||
@State private var isInitialTranscription = true
|
||||
@State private var streamingResponse = ""
|
||||
@State private var cancellables = Set<AnyCancellable>()
|
||||
|
||||
@State private var cameraActor: CameraActor?
|
||||
@State private var showLiveCamera = false
|
||||
@State private var capturedImageBase64: String?
|
||||
@State private var errorMessage: String?
|
||||
@State private var isCameraReady = false
|
||||
|
||||
@State private var speechSynthesizer = AVSpeechSynthesizer()
|
||||
@State private var speechBuffer = ""
|
||||
@State private var wordCount = 0
|
||||
let maxWords = 12
|
||||
@State private var originalSilenceThreshold: Float = 0.40
|
||||
@State private var isTTSActive: Bool = false
|
||||
@State private var canRecordAudio: Bool = true
|
||||
@State private var ttsFinishTime: Date?
|
||||
|
||||
@State private var isRequestInProgress = false
|
||||
@State private var isFirst3WordsOfResponse = true
|
||||
|
||||
var body: some View {
|
||||
ZStack {
|
||||
if showLiveCamera, isCameraReady, let actor = cameraActor {
|
||||
CameraPreview(cameraActor: actor)
|
||||
.edgesIgnoringSafeArea(.all)
|
||||
}
|
||||
|
||||
ScrollView {
|
||||
VStack {
|
||||
Text(currentText)
|
||||
.padding()
|
||||
|
||||
Text(isListening ? "Listening..." : "Not listening")
|
||||
.foregroundColor(isListening ? .green : .red)
|
||||
|
||||
if isRecordingMemo {
|
||||
Text("Recording...")
|
||||
.foregroundColor(.blue)
|
||||
}
|
||||
|
||||
Picker("Model", selection: $selectedModel) {
|
||||
Text("large-v3").tag("large-v3")
|
||||
Text("base").tag("base")
|
||||
Text("small").tag("small")
|
||||
}
|
||||
.pickerStyle(SegmentedPickerStyle())
|
||||
.padding()
|
||||
|
||||
Button("Load Model") {
|
||||
loadModel(selectedModel)
|
||||
}
|
||||
.disabled(modelState == .loaded)
|
||||
.padding()
|
||||
|
||||
Text("Model State: \(modelState.description)")
|
||||
|
||||
Text(debugText)
|
||||
.font(.caption)
|
||||
.foregroundColor(.gray)
|
||||
|
||||
Text("TTS Active: \(isTTSActive ? "Yes" : "No")")
|
||||
.font(.caption)
|
||||
.foregroundColor(isTTSActive ? .green : .red)
|
||||
|
||||
Text("Current Silence Threshold: \(voiceActivityThreshold, specifier: "%.2f")")
|
||||
.font(.caption)
|
||||
.foregroundColor(.blue)
|
||||
|
||||
Text("Original Silence Threshold: \(originalSilenceThreshold, specifier: "%.2f")")
|
||||
.font(.caption)
|
||||
.foregroundColor(.orange)
|
||||
|
||||
Slider(value: $voiceActivityThreshold, in: 0.01...1.0) {
|
||||
Text("Voice Activity Threshold: \(voiceActivityThreshold, specifier: "%.2f")")
|
||||
}
|
||||
|
||||
Text("API Response:")
|
||||
.font(.headline)
|
||||
.padding(.top)
|
||||
|
||||
ScrollView {
|
||||
Text(streamingResponse)
|
||||
.padding()
|
||||
}
|
||||
.frame(height: 200)
|
||||
.border(Color.gray, width: 1)
|
||||
|
||||
Toggle("Show Live Camera", isOn: $showLiveCamera)
|
||||
.padding()
|
||||
.onChange(of: showLiveCamera) { newValue in
|
||||
if newValue {
|
||||
Task {
|
||||
await setupCamera()
|
||||
}
|
||||
} else {
|
||||
cameraActor = nil
|
||||
isCameraReady = false
|
||||
print("Camera disabled")
|
||||
}
|
||||
}
|
||||
|
||||
if !showLiveCamera {
|
||||
Text("Camera Ready: \(isCameraReady ? "Yes" : "No")")
|
||||
.padding()
|
||||
|
||||
if let errorMessage = errorMessage {
|
||||
Text("Error: \(errorMessage)")
|
||||
.foregroundColor(.red)
|
||||
.padding()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
.opacity(showLiveCamera ? 0.7 : 1)
|
||||
}
|
||||
.onAppear {
|
||||
setupWhisperKit()
|
||||
startTTSMonitoring()
|
||||
}
|
||||
}
|
||||
|
||||
private func setupWhisperKit() {
|
||||
Task {
|
||||
do {
|
||||
whisperKit = try await WhisperKit(verbose: true)
|
||||
print("WhisperKit initialized successfully")
|
||||
startListening()
|
||||
startAudioBuffering()
|
||||
} catch {
|
||||
print("Error initializing WhisperKit: \(error)")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private func startTTSMonitoring() {
|
||||
Timer.scheduledTimer(withTimeInterval: 0.1, repeats: true) { _ in
|
||||
let newTTSActive = speechSynthesizer.isSpeaking
|
||||
if newTTSActive != isTTSActive {
|
||||
isTTSActive = newTTSActive
|
||||
canRecordAudio = !newTTSActive
|
||||
if isTTSActive {
|
||||
voiceActivityThreshold = 1.0 // Set to max to prevent recording
|
||||
whisperKit?.audioProcessor.purgeAudioSamples(keepingLast: 0) // Flush audio buffer
|
||||
print("TTS Started - Audio recording paused")
|
||||
} else {
|
||||
ttsFinishTime = Date()
|
||||
print("TTS Finished - Waiting 0.5 seconds before resuming audio recording")
|
||||
}
|
||||
updateDebugText()
|
||||
}
|
||||
|
||||
if !isTTSActive, let finishTime = ttsFinishTime, Date().timeIntervalSince(finishTime) >= 0.5 {
|
||||
whisperKit?.audioProcessor.purgeAudioSamples(keepingLast: 0) // Flush audio buffer
|
||||
voiceActivityThreshold = originalSilenceThreshold
|
||||
canRecordAudio = true
|
||||
ttsFinishTime = nil
|
||||
print("Audio recording resumed after TTS delay")
|
||||
updateDebugText()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private func updateDebugText() {
|
||||
debugText += "\nTTS Active: \(isTTSActive)"
|
||||
debugText += "\nCurrent Silence Threshold: \(voiceActivityThreshold)"
|
||||
debugText += "\nOriginal Silence Threshold: \(originalSilenceThreshold)"
|
||||
debugText += "\n---"
|
||||
}
|
||||
|
||||
private func startAudioBuffering() {
|
||||
Task {
|
||||
while true {
|
||||
if let samples = whisperKit?.audioProcessor.audioSamples {
|
||||
let bufferSize = Int(Double(WhisperKit.sampleRate) * bufferDuration)
|
||||
audioBuffer = Array(samples.suffix(bufferSize))
|
||||
}
|
||||
try await Task.sleep(nanoseconds: 100_000_000) // Update every 0.1 seconds
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private func loadModel(_ model: String) {
|
||||
Task {
|
||||
let success = try await loadModel(selectedModel)
|
||||
if success {
|
||||
startListening()
|
||||
} else {
|
||||
print("Model failed to load, cannot start listening")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private func startListening() {
|
||||
guard let audioProcessor = whisperKit?.audioProcessor else {
|
||||
print("AudioProcessor not available")
|
||||
return
|
||||
}
|
||||
|
||||
do {
|
||||
try audioProcessor.startRecordingLive { buffer in
|
||||
DispatchQueue.main.async {
|
||||
checkVoiceActivity()
|
||||
}
|
||||
}
|
||||
isListening = true
|
||||
} catch {
|
||||
print("Error starting listening: \(error)")
|
||||
}
|
||||
}
|
||||
|
||||
private func checkVoiceActivity() {
|
||||
guard canRecordAudio, let audioProcessor = whisperKit?.audioProcessor else { return }
|
||||
|
||||
let voiceDetected = AudioProcessor.isVoiceDetected(
|
||||
in: audioProcessor.relativeEnergy,
|
||||
nextBufferInSeconds: Float(bufferSeconds),
|
||||
silenceThreshold: Float(voiceActivityThreshold)
|
||||
)
|
||||
|
||||
let energyValuesToConsider = Int(Float(bufferSeconds) / 0.1)
|
||||
let nextBufferEnergies = audioProcessor.relativeEnergy.suffix(energyValuesToConsider)
|
||||
let numberOfValuesToCheck = max(10, nextBufferEnergies.count - 10)
|
||||
let relevantEnergies = Array(nextBufferEnergies.prefix(numberOfValuesToCheck))
|
||||
|
||||
debugText = """
|
||||
Buffer seconds: \(bufferSeconds)
|
||||
Energy values to consider: \(energyValuesToConsider)
|
||||
Number of values to check: \(numberOfValuesToCheck)
|
||||
Silence threshold: \(voiceActivityThreshold)
|
||||
Relevant energies: \(relevantEnergies)
|
||||
Max energy: \(relevantEnergies.max() ?? 0)
|
||||
Voice detected: \(voiceDetected)
|
||||
"""
|
||||
|
||||
if voiceDetected {
|
||||
lastVoiceActivityTime = Date()
|
||||
if !isRecordingMemo {
|
||||
startNewMemo()
|
||||
}
|
||||
} else {
|
||||
checkSilence()
|
||||
}
|
||||
}
|
||||
|
||||
private func checkSilence() {
|
||||
let silenceDuration = Date().timeIntervalSince(lastVoiceActivityTime)
|
||||
debugText += "\nSilence duration: \(silenceDuration)"
|
||||
|
||||
if silenceDuration > silenceTimeThreshold {
|
||||
endCurrentMemo()
|
||||
}
|
||||
}
|
||||
|
||||
private func endCurrentMemo() {
|
||||
if isRecordingMemo {
|
||||
isRecordingMemo = false
|
||||
silenceTimer?.invalidate()
|
||||
silenceTimer = nil
|
||||
if !currentMemo.isEmpty {
|
||||
saveMemoToFile(currentMemo)
|
||||
currentMemo = ""
|
||||
}
|
||||
currentText = ""
|
||||
whisperKit?.audioProcessor.purgeAudioSamples(keepingLast: 0)
|
||||
print("Ended memo")
|
||||
debugText += "\nMemo ended"
|
||||
}
|
||||
}
|
||||
|
||||
private func startNewMemo() {
|
||||
isRecordingMemo = true
|
||||
currentMemo = ""
|
||||
isInitialTranscription = true
|
||||
silenceTimer?.invalidate()
|
||||
silenceTimer = Timer.scheduledTimer(withTimeInterval: 0.5, repeats: true) { _ in
|
||||
checkSilence()
|
||||
}
|
||||
transcribeInRealTime()
|
||||
print("Started new memo")
|
||||
}
|
||||
|
||||
private func transcribeInRealTime() {
|
||||
Task {
|
||||
while isRecordingMemo {
|
||||
if canRecordAudio, let samples = whisperKit?.audioProcessor.audioSamples, samples.count > WhisperKit.sampleRate {
|
||||
do {
|
||||
let samplesToTranscribe: [Float]
|
||||
if isInitialTranscription {
|
||||
samplesToTranscribe = audioBuffer + samples
|
||||
isInitialTranscription = false
|
||||
} else {
|
||||
samplesToTranscribe = Array(samples)
|
||||
}
|
||||
|
||||
let result = try await whisperKit?.transcribe(audioArray: samplesToTranscribe)
|
||||
await MainActor.run {
|
||||
let newText = result?.first?.text ?? ""
|
||||
if !newText.isEmpty {
|
||||
currentMemo = newText
|
||||
currentText = newText
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
print("Transcription error: \(error)")
|
||||
}
|
||||
}
|
||||
try await Task.sleep(nanoseconds: 500_000_000) // Sleep for 0.5 seconds
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private func saveMemoToFile(_ memo: String) {
|
||||
let dateFormatter = DateFormatter()
|
||||
dateFormatter.dateFormat = "yyyy-MM-dd_HH-mm-ss"
|
||||
let fileName = "memo_\(dateFormatter.string(from: Date())).txt"
|
||||
|
||||
guard let documentsDirectory = FileManager.default.urls(for: .documentDirectory, in: .userDomainMask).first else {
|
||||
print("Unable to access documents directory")
|
||||
return
|
||||
}
|
||||
|
||||
let fileURL = documentsDirectory.appendingPathComponent(fileName)
|
||||
|
||||
do {
|
||||
try memo.write(to: fileURL, atomically: true, encoding: .utf8)
|
||||
print("Memo saved to: \(fileURL.path)")
|
||||
} catch {
|
||||
print("Error saving memo: \(error)")
|
||||
}
|
||||
|
||||
Task {
|
||||
if !isCameraReady {
|
||||
print("Camera not ready, initializing...")
|
||||
await setupCamera()
|
||||
}
|
||||
|
||||
if let imageBase64 = await capturePhotoBase64() {
|
||||
sendMemoToAPI(memo, imageBase64: imageBase64)
|
||||
} else {
|
||||
sendMemoToAPI(memo)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private func setupCamera() async {
|
||||
print("Setting up camera...")
|
||||
do {
|
||||
let actor = CameraActor()
|
||||
print("CameraActor instance created")
|
||||
try await actor.configure()
|
||||
print("Camera configured successfully")
|
||||
await MainActor.run {
|
||||
self.cameraActor = actor
|
||||
self.errorMessage = nil
|
||||
self.isCameraReady = true
|
||||
print("Camera setup complete, UI updated")
|
||||
}
|
||||
} catch {
|
||||
print("Camera setup failed: \(error)")
|
||||
await MainActor.run {
|
||||
self.errorMessage = "Failed to initialize camera: \(error.localizedDescription)"
|
||||
self.isCameraReady = false
|
||||
print("Camera setup failure reflected in UI")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private func capturePhotoBase64() async -> String? {
|
||||
print("Attempting to capture photo...")
|
||||
if !isCameraReady {
|
||||
print("Camera not ready, attempting to initialize...")
|
||||
await setupCamera()
|
||||
}
|
||||
|
||||
guard let actor = cameraActor, isCameraReady else {
|
||||
print("Camera not initialized or not ready, cannot capture photo")
|
||||
await MainActor.run {
|
||||
self.errorMessage = "Camera not initialized or not ready"
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
do {
|
||||
let base64String = try await actor.capturePhoto()
|
||||
print("Photo captured successfully")
|
||||
await MainActor.run {
|
||||
self.errorMessage = nil
|
||||
}
|
||||
return base64String
|
||||
} catch {
|
||||
print("Error capturing photo: \(error)")
|
||||
await MainActor.run {
|
||||
self.errorMessage = "Failed to capture photo: \(error.localizedDescription)"
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
private func sendMemoToAPI(_ memo: String, imageBase64: String? = nil) {
|
||||
Task {
|
||||
guard !isRequestInProgress else {
|
||||
print("A request is already in progress. Skipping this one.")
|
||||
return
|
||||
}
|
||||
|
||||
isRequestInProgress = true
|
||||
isFirst3WordsOfResponse = true // Reset for new request
|
||||
defer { isRequestInProgress = false }
|
||||
|
||||
do {
|
||||
print("Starting API request for memo: \(memo.prefix(50))...")
|
||||
|
||||
guard let url = URL(string: apiEndpoint) else {
|
||||
print("Invalid API endpoint URL: \(apiEndpoint)")
|
||||
return
|
||||
}
|
||||
|
||||
var payload: [String: Any] = [
|
||||
"model": "llava-1.5-7b-hf",
|
||||
"messages": [
|
||||
["role": "user", "content": [
|
||||
["type": "text", "text": "You are a helpful conversational assistant chatting with a Gen Z user using their iPhone for voice transcription and sending images to you with their iPhone camera. Be conversational and concise, with a laid back attitude and be cheerful with humour. User said: " + memo],
|
||||
]]
|
||||
],
|
||||
"temperature": 0.7,
|
||||
"stream": true
|
||||
]
|
||||
|
||||
if let imageBase64 = imageBase64 {
|
||||
if var userMessage = (payload["messages"] as? [[String: Any]])?.last,
|
||||
var content = userMessage["content"] as? [[String: Any]] {
|
||||
content.append(["type": "image_url", "image_url": ["url": "data:image/jpeg;base64,\(imageBase64)"]])
|
||||
userMessage["content"] = content
|
||||
payload["messages"] = [userMessage]
|
||||
}
|
||||
}
|
||||
|
||||
guard let jsonData = try? JSONSerialization.data(withJSONObject: payload) else {
|
||||
print("Failed to serialize JSON payload")
|
||||
return
|
||||
}
|
||||
|
||||
var request = URLRequest(url: url)
|
||||
request.httpMethod = "POST"
|
||||
request.setValue("application/json", forHTTPHeaderField: "Content-Type")
|
||||
request.httpBody = jsonData
|
||||
|
||||
print("Sending request to \(url.absoluteString)")
|
||||
|
||||
await MainActor.run {
|
||||
self.streamingResponse = ""
|
||||
}
|
||||
|
||||
let (bytes, response) = try await URLSession.shared.bytes(for: request)
|
||||
|
||||
guard let httpResponse = response as? HTTPURLResponse else {
|
||||
print("Invalid response")
|
||||
return
|
||||
}
|
||||
|
||||
print("Response status code: \(httpResponse.statusCode)")
|
||||
|
||||
for try await line in bytes.lines {
|
||||
print("Received line: \(line)")
|
||||
await processStreamLine(line)
|
||||
}
|
||||
|
||||
print("Stream completed")
|
||||
} catch {
|
||||
print("Error: \(error.localizedDescription)")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private func processStreamLine(_ line: String) async {
|
||||
let jsonString: String
|
||||
if line.hasPrefix("data: ") {
|
||||
jsonString = String(line.dropFirst(6))
|
||||
} else {
|
||||
jsonString = line
|
||||
}
|
||||
|
||||
if jsonString.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty {
|
||||
return
|
||||
}
|
||||
|
||||
if let jsonData = jsonString.data(using: .utf8),
|
||||
let json = try? JSONSerialization.jsonObject(with: jsonData, options: []) as? [String: Any],
|
||||
let choices = json["choices"] as? [[String: Any]],
|
||||
let firstChoice = choices.first,
|
||||
let delta = firstChoice["delta"] as? [String: String],
|
||||
let content = delta["content"] {
|
||||
print("Extracted content: \(content)")
|
||||
await MainActor.run {
|
||||
self.streamingResponse += content
|
||||
bufferContent(content)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private func bufferContent(_ content: String) {
|
||||
speechBuffer += content
|
||||
let words = speechBuffer.split(separator: " ")
|
||||
wordCount = words.count
|
||||
|
||||
if isFirst3WordsOfResponse && wordCount >= 3 {
|
||||
isFirst3WordsOfResponse = false
|
||||
speakBufferedContent()
|
||||
} else if content.contains(".") || content.contains("!") || content.contains("?") || wordCount >= maxWords {
|
||||
speakBufferedContent()
|
||||
}
|
||||
}
|
||||
|
||||
private func speakBufferedContent() {
|
||||
guard !speechBuffer.isEmpty else { return }
|
||||
speakContent(speechBuffer)
|
||||
speechBuffer = ""
|
||||
wordCount = 0
|
||||
}
|
||||
|
||||
private func speakContent(_ content: String) {
|
||||
let utterance = AVSpeechUtterance(string: content)
|
||||
utterance.voice = AVSpeechSynthesisVoice(language: "en-US")
|
||||
utterance.rate = 0.5
|
||||
speechSynthesizer.speak(utterance)
|
||||
}
|
||||
|
||||
private func loadModel(_ model: String) async throws -> Bool {
|
||||
guard let whisperKit = whisperKit else {
|
||||
print("WhisperKit instance not initialized")
|
||||
return false
|
||||
}
|
||||
modelState = .loading
|
||||
do {
|
||||
print("Starting to load model: \(model)")
|
||||
try await whisperKit.loadModels()
|
||||
await MainActor.run {
|
||||
modelState = .loaded
|
||||
print("Model loaded successfully: \(model)")
|
||||
}
|
||||
return true
|
||||
} catch {
|
||||
print("Error loading model: \(error)")
|
||||
await MainActor.run { modelState = .unloaded }
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
private func capturePhoto() async {
|
||||
print("Attempting to capture photo...")
|
||||
print("Camera ready: \(isCameraReady), CameraActor exists: \(cameraActor != nil)")
|
||||
guard let actor = cameraActor, isCameraReady else {
|
||||
print("Camera not initialized or not ready, cannot capture photo")
|
||||
await MainActor.run {
|
||||
self.errorMessage = "Camera not initialized or not ready"
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
do {
|
||||
let base64String = try await actor.capturePhoto()
|
||||
print("Photo captured successfully")
|
||||
await MainActor.run {
|
||||
self.capturedImageBase64 = base64String
|
||||
self.errorMessage = nil
|
||||
}
|
||||
} catch {
|
||||
print("Error capturing photo: \(error)")
|
||||
await MainActor.run {
|
||||
self.errorMessage = "Failed to capture photo: \(error.localizedDescription)"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,6 +0,0 @@
|
||||
{
|
||||
"info" : {
|
||||
"author" : "xcode",
|
||||
"version" : 1
|
||||
}
|
||||
}
|
||||
@@ -1,22 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
||||
<plist version="1.0">
|
||||
<dict>
|
||||
<key>com.apple.developer.kernel.increased-memory-limit</key>
|
||||
<true/>
|
||||
<key>com.apple.security.app-sandbox</key>
|
||||
<true/>
|
||||
<key>com.apple.security.device.audio-input</key>
|
||||
<true/>
|
||||
<key>com.apple.security.files.downloads.read-only</key>
|
||||
<true/>
|
||||
<key>com.apple.security.files.user-selected.read-write</key>
|
||||
<true/>
|
||||
<key>com.apple.security.network.client</key>
|
||||
<true/>
|
||||
<key>com.apple.security.network.server</key>
|
||||
<true/>
|
||||
<key>com.apple.security.device.camera</key>
|
||||
<true/>
|
||||
</dict>
|
||||
</plist>
|
||||
@@ -1,17 +0,0 @@
|
||||
//
|
||||
// astraApp.swift
|
||||
// astra
|
||||
//
|
||||
// Created by Alex on 18/08/2024.
|
||||
//
|
||||
|
||||
import SwiftUI
|
||||
|
||||
@main
|
||||
struct astraApp: App {
|
||||
var body: some Scene {
|
||||
WindowGroup {
|
||||
ContentView()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,35 +0,0 @@
|
||||
//
|
||||
// astraTests.swift
|
||||
// astraTests
|
||||
//
|
||||
// Created by Alex on 18/08/2024.
|
||||
//
|
||||
|
||||
import XCTest
|
||||
|
||||
final class astraTests: XCTestCase {
|
||||
|
||||
override func setUpWithError() throws {
|
||||
// Put setup code here. This method is called before the invocation of each test method in the class.
|
||||
}
|
||||
|
||||
override func tearDownWithError() throws {
|
||||
// Put teardown code here. This method is called after the invocation of each test method in the class.
|
||||
}
|
||||
|
||||
func testExample() throws {
|
||||
// This is an example of a functional test case.
|
||||
// Use XCTAssert and related functions to verify your tests produce the correct results.
|
||||
// Any test you write for XCTest can be annotated as throws and async.
|
||||
// Mark your test throws to produce an unexpected failure when your test encounters an uncaught error.
|
||||
// Mark your test async to allow awaiting for asynchronous code to complete. Check the results with assertions afterwards.
|
||||
}
|
||||
|
||||
func testPerformanceExample() throws {
|
||||
// This is an example of a performance test case.
|
||||
measure {
|
||||
// Put the code you want to measure the time of here.
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,41 +0,0 @@
|
||||
//
|
||||
// astraUITests.swift
|
||||
// astraUITests
|
||||
//
|
||||
// Created by Alex on 18/08/2024.
|
||||
//
|
||||
|
||||
import XCTest
|
||||
|
||||
final class astraUITests: XCTestCase {
|
||||
|
||||
override func setUpWithError() throws {
|
||||
// Put setup code here. This method is called before the invocation of each test method in the class.
|
||||
|
||||
// In UI tests it is usually best to stop immediately when a failure occurs.
|
||||
continueAfterFailure = false
|
||||
|
||||
// In UI tests it’s important to set the initial state - such as interface orientation - required for your tests before they run. The setUp method is a good place to do this.
|
||||
}
|
||||
|
||||
override func tearDownWithError() throws {
|
||||
// Put teardown code here. This method is called after the invocation of each test method in the class.
|
||||
}
|
||||
|
||||
func testExample() throws {
|
||||
// UI tests must launch the application that they test.
|
||||
let app = XCUIApplication()
|
||||
app.launch()
|
||||
|
||||
// Use XCTAssert and related functions to verify your tests produce the correct results.
|
||||
}
|
||||
|
||||
func testLaunchPerformance() throws {
|
||||
if #available(macOS 10.15, iOS 13.0, tvOS 13.0, watchOS 7.0, *) {
|
||||
// This measures how long it takes to launch your application.
|
||||
measure(metrics: [XCTApplicationLaunchMetric()]) {
|
||||
XCUIApplication().launch()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,32 +0,0 @@
|
||||
//
|
||||
// astraUITestsLaunchTests.swift
|
||||
// astraUITests
|
||||
//
|
||||
// Created by Alex on 18/08/2024.
|
||||
//
|
||||
|
||||
import XCTest
|
||||
|
||||
final class astraUITestsLaunchTests: XCTestCase {
|
||||
|
||||
override class var runsForEachTargetApplicationUIConfiguration: Bool {
|
||||
true
|
||||
}
|
||||
|
||||
override func setUpWithError() throws {
|
||||
continueAfterFailure = false
|
||||
}
|
||||
|
||||
func testLaunch() throws {
|
||||
let app = XCUIApplication()
|
||||
app.launch()
|
||||
|
||||
// Insert steps here to perform after app launch but before taking a screenshot,
|
||||
// such as logging into a test account or navigating somewhere in the app
|
||||
|
||||
let attachment = XCTAttachment(screenshot: app.screenshot())
|
||||
attachment.name = "Launch Screen"
|
||||
attachment.lifetime = .keepAlways
|
||||
add(attachment)
|
||||
}
|
||||
}
|
||||
@@ -1,39 +0,0 @@
|
||||
# exo provides an API that aims to be a drop-in replacements for the ChatGPT-API.
|
||||
# This example shows how you can use the API first without streaming and second with streaming.
|
||||
# This works the same in a single-node set up and in a multi-node setup.
|
||||
# You need to start exo before running this by running `python3 main.py`.
|
||||
|
||||
API_ENDPOINT="http://${API_ENDPOINT:-$(ifconfig | grep 'inet ' | grep -v '127.0.0.1' | awk '{print $2}' | head -n 1):52415}"
|
||||
MODEL="llama-3.1-8b"
|
||||
PROMPT="What is the meaning of exo?"
|
||||
TEMPERATURE=0.7
|
||||
|
||||
echo ""
|
||||
echo ""
|
||||
echo "--- Output without streaming:"
|
||||
echo ""
|
||||
curl "${API_ENDPOINT}/v1/chat/completions" --silent \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "'"${MODEL}"'",
|
||||
"messages": [{"role": "user", "content": "'"${PROMPT}"'"}],
|
||||
"temperature": '"${TEMPERATURE}"'
|
||||
}'
|
||||
|
||||
echo ""
|
||||
echo ""
|
||||
echo "--- Output with streaming:"
|
||||
echo ""
|
||||
curl "${API_ENDPOINT}/v1/chat/completions" --silent \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "'"${MODEL}"'",
|
||||
"messages": [{"role": "user", "content": "'"${PROMPT}"'"}],
|
||||
"temperature": '"${TEMPERATURE}"',
|
||||
"stream": true
|
||||
}' | while read -r line; do
|
||||
if [[ $line == data:* ]]; then
|
||||
content=$(echo "$line" | sed 's/^data: //')
|
||||
echo "$content" | jq -r '.choices[].delta.content' --unbuffered | tr -d '\n'
|
||||
fi
|
||||
done
|
||||
@@ -1,111 +0,0 @@
|
||||
import json
|
||||
import re
|
||||
import requests
|
||||
|
||||
def get_current_weather(location: str, unit: str = "celsius"):
|
||||
"""Mock weather data function"""
|
||||
# Hardcoded response for demo purposes
|
||||
return {
|
||||
"location": location,
|
||||
"temperature": 22 if unit == "celsius" else 72,
|
||||
"unit": unit,
|
||||
"forecast": "Sunny with light clouds"
|
||||
}
|
||||
|
||||
def try_parse_tool_calls(content: str):
|
||||
"""Try parse the tool calls."""
|
||||
tool_calls = []
|
||||
offset = 0
|
||||
for i, m in enumerate(re.finditer(r"<tool_call>\n(.+)?\n</tool_call>", content)):
|
||||
if i == 0:
|
||||
offset = m.start()
|
||||
try:
|
||||
func = json.loads(m.group(1))
|
||||
tool_calls.append({"type": "function", "function": func})
|
||||
if isinstance(func["arguments"], str):
|
||||
func["arguments"] = json.loads(func["arguments"])
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"Failed to parse tool calls: the content is {m.group(1)} and {e}")
|
||||
pass
|
||||
if tool_calls:
|
||||
if offset > 0 and content[:offset].strip():
|
||||
c = content[:offset]
|
||||
else:
|
||||
c = ""
|
||||
return {"role": "assistant", "content": c, "tool_calls": tool_calls}
|
||||
return {"role": "assistant", "content": re.sub(r"<\|im_end\|>$", "", content)}
|
||||
|
||||
def chat_completion(messages):
|
||||
"""Send chat completion request to local server"""
|
||||
response = requests.post(
|
||||
"http://localhost:52415/v1/chat/completions",
|
||||
json={
|
||||
"model": "qwen-2.5-1.5b",
|
||||
"messages": messages,
|
||||
"tools": [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA"
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"]
|
||||
}
|
||||
},
|
||||
"required": ["location"]
|
||||
}
|
||||
}
|
||||
}],
|
||||
"tool_choice": "auto"
|
||||
}
|
||||
)
|
||||
return response.json()
|
||||
|
||||
def main():
|
||||
# Initial conversation
|
||||
messages = [{
|
||||
"role": "user",
|
||||
"content": "Hi there, what's the weather in Boston?"
|
||||
}]
|
||||
|
||||
# Get initial response
|
||||
response = chat_completion(messages)
|
||||
print(f"First response: {response}")
|
||||
assistant_message = try_parse_tool_calls(response["choices"][0]["message"]["content"])
|
||||
messages.append(assistant_message)
|
||||
|
||||
# If there are tool calls, execute them and continue conversation
|
||||
if "tool_calls" in assistant_message:
|
||||
for tool_call in assistant_message["tool_calls"]:
|
||||
if tool_call["function"]["name"] == "get_current_weather":
|
||||
args = tool_call["function"]["arguments"]
|
||||
weather_data = get_current_weather(**args)
|
||||
|
||||
# Add tool response to messages
|
||||
messages.append({
|
||||
"role": "tool",
|
||||
"content": json.dumps(weather_data),
|
||||
"name": tool_call["function"]["name"]
|
||||
})
|
||||
|
||||
# Get final response with weather data
|
||||
response = chat_completion(messages)
|
||||
print(f"Final response: {response}")
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": response["choices"][0]["message"]["content"]
|
||||
})
|
||||
|
||||
# Print full conversation
|
||||
for msg in messages:
|
||||
print(f"\n{msg['role'].upper()}: {msg['content']}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1 +0,0 @@
|
||||
from exo.helpers import DEBUG as DEBUG, DEBUG_DISCOVERY as DEBUG_DISCOVERY, VERSION as VERSION
|
||||
@@ -1 +0,0 @@
|
||||
from exo.api.chatgpt_api import ChatGPTAPI as ChatGPTAPI
|
||||
@@ -1,645 +0,0 @@
|
||||
import uuid
|
||||
import time
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from transformers import AutoTokenizer
|
||||
from typing import List, Literal, Union, Dict, Optional
|
||||
from aiohttp import web
|
||||
import aiohttp_cors
|
||||
import traceback
|
||||
import signal
|
||||
from exo import DEBUG, VERSION
|
||||
from exo.helpers import PrefixDict, shutdown, get_exo_images_dir
|
||||
from exo.inference.tokenizers import resolve_tokenizer
|
||||
from exo.orchestration import Node
|
||||
from exo.models import build_base_shard, build_full_shard, model_cards, get_repo, get_supported_models, get_pretty_name
|
||||
from typing import Callable, Optional
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import base64
|
||||
from io import BytesIO
|
||||
import platform
|
||||
from exo.download.download_progress import RepoProgressEvent
|
||||
from exo.download.new_shard_download import delete_model
|
||||
import tempfile
|
||||
from exo.apputil import create_animation_mp4
|
||||
from collections import defaultdict
|
||||
|
||||
if platform.system().lower() == "darwin" and platform.machine().lower() == "arm64":
|
||||
import mlx.core as mx
|
||||
else:
|
||||
import numpy as mx
|
||||
|
||||
|
||||
class Message:
|
||||
def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]], tools: Optional[List[Dict]] = None):
|
||||
self.role = role
|
||||
self.content = content
|
||||
self.tools = tools
|
||||
|
||||
def to_dict(self):
|
||||
data = {"role": self.role, "content": self.content}
|
||||
if self.tools:
|
||||
data["tools"] = self.tools
|
||||
return data
|
||||
|
||||
|
||||
class ChatCompletionRequest:
|
||||
def __init__(self, model: str, messages: List[Message], temperature: float, tools: Optional[List[Dict]] = None):
|
||||
self.model = model
|
||||
self.messages = messages
|
||||
self.temperature = temperature
|
||||
self.tools = tools
|
||||
|
||||
def to_dict(self):
|
||||
return {"model": self.model, "messages": [message.to_dict() for message in self.messages], "temperature": self.temperature, "tools": self.tools}
|
||||
|
||||
|
||||
def generate_completion(
|
||||
chat_request: ChatCompletionRequest,
|
||||
tokenizer,
|
||||
prompt: str,
|
||||
request_id: str,
|
||||
tokens: List[int],
|
||||
stream: bool,
|
||||
finish_reason: Union[Literal["length", "stop"], None],
|
||||
object_type: Literal["chat.completion", "text_completion"],
|
||||
) -> dict:
|
||||
completion = {
|
||||
"id": f"chatcmpl-{request_id}",
|
||||
"object": object_type,
|
||||
"created": int(time.time()),
|
||||
"model": chat_request.model,
|
||||
"system_fingerprint": f"exo_{VERSION}",
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"message": {"role": "assistant", "content": tokenizer.decode(tokens)},
|
||||
"logprobs": None,
|
||||
"finish_reason": finish_reason,
|
||||
}],
|
||||
}
|
||||
|
||||
if not stream:
|
||||
completion["usage"] = {
|
||||
"prompt_tokens": len(tokenizer.encode(prompt)),
|
||||
"completion_tokens": len(tokens),
|
||||
"total_tokens": len(tokenizer.encode(prompt)) + len(tokens),
|
||||
}
|
||||
|
||||
choice = completion["choices"][0]
|
||||
if object_type.startswith("chat.completion"):
|
||||
key_name = "delta" if stream else "message"
|
||||
choice[key_name] = {"role": "assistant", "content": tokenizer.decode(tokens)}
|
||||
elif object_type == "text_completion":
|
||||
choice["text"] = tokenizer.decode(tokens)
|
||||
else:
|
||||
ValueError(f"Unsupported response type: {object_type}")
|
||||
|
||||
return completion
|
||||
|
||||
|
||||
def remap_messages(messages: List[Message]) -> List[Message]:
|
||||
remapped_messages = []
|
||||
last_image = None
|
||||
for message in messages:
|
||||
if not isinstance(message.content, list):
|
||||
remapped_messages.append(message)
|
||||
continue
|
||||
|
||||
remapped_content = []
|
||||
for content in message.content:
|
||||
if isinstance(content, dict):
|
||||
if content.get("type") in ["image_url", "image"]:
|
||||
image_url = content.get("image_url", {}).get("url") or content.get("image")
|
||||
if image_url:
|
||||
last_image = {"type": "image", "image": image_url}
|
||||
remapped_content.append({"type": "text", "text": "[An image was uploaded but is not displayed here]"})
|
||||
else:
|
||||
remapped_content.append(content)
|
||||
else:
|
||||
remapped_content.append(content)
|
||||
remapped_messages.append(Message(role=message.role, content=remapped_content))
|
||||
|
||||
if last_image:
|
||||
# Replace the last image placeholder with the actual image content
|
||||
for message in reversed(remapped_messages):
|
||||
for i, content in enumerate(message.content):
|
||||
if isinstance(content, dict):
|
||||
if content.get("type") == "text" and content.get("text") == "[An image was uploaded but is not displayed here]":
|
||||
message.content[i] = last_image
|
||||
return remapped_messages
|
||||
|
||||
return remapped_messages
|
||||
|
||||
|
||||
def build_prompt(tokenizer, _messages: List[Message], tools: Optional[List[Dict]] = None):
|
||||
messages = remap_messages(_messages)
|
||||
chat_template_args = {"conversation": [m.to_dict() for m in messages], "tokenize": False, "add_generation_prompt": True}
|
||||
if tools:
|
||||
chat_template_args["tools"] = tools
|
||||
|
||||
try:
|
||||
prompt = tokenizer.apply_chat_template(**chat_template_args)
|
||||
if DEBUG >= 3: print(f"!!! Prompt: {prompt}")
|
||||
return prompt
|
||||
except UnicodeEncodeError:
|
||||
# Handle Unicode encoding by ensuring everything is UTF-8
|
||||
chat_template_args["conversation"] = [
|
||||
{k: v.encode('utf-8').decode('utf-8') if isinstance(v, str) else v
|
||||
for k, v in m.to_dict().items()}
|
||||
for m in messages
|
||||
]
|
||||
prompt = tokenizer.apply_chat_template(**chat_template_args)
|
||||
if DEBUG >= 3: print(f"!!! Prompt (UTF-8 encoded): {prompt}")
|
||||
return prompt
|
||||
|
||||
|
||||
def parse_message(data: dict):
|
||||
if "role" not in data or "content" not in data:
|
||||
raise ValueError(f"Invalid message: {data}. Must have 'role' and 'content'")
|
||||
return Message(data["role"], data["content"], data.get("tools"))
|
||||
|
||||
|
||||
def parse_chat_request(data: dict, default_model: str):
|
||||
return ChatCompletionRequest(
|
||||
data.get("model", default_model),
|
||||
[parse_message(msg) for msg in data["messages"]],
|
||||
data.get("temperature", 0.0),
|
||||
data.get("tools", None),
|
||||
)
|
||||
|
||||
|
||||
class PromptSession:
|
||||
def __init__(self, request_id: str, timestamp: int, prompt: str):
|
||||
self.request_id = request_id
|
||||
self.timestamp = timestamp
|
||||
self.prompt = prompt
|
||||
|
||||
|
||||
class ChatGPTAPI:
|
||||
def __init__(
|
||||
self,
|
||||
node: Node,
|
||||
inference_engine_classname: str,
|
||||
response_timeout: int = 90,
|
||||
on_chat_completion_request: Callable[[str, ChatCompletionRequest, str], None] = None,
|
||||
default_model: Optional[str] = None,
|
||||
system_prompt: Optional[str] = None
|
||||
):
|
||||
self.node = node
|
||||
self.inference_engine_classname = inference_engine_classname
|
||||
self.response_timeout = response_timeout
|
||||
self.on_chat_completion_request = on_chat_completion_request
|
||||
self.app = web.Application(client_max_size=100*1024*1024) # 100MB to support image upload
|
||||
self.prompts: PrefixDict[str, PromptSession] = PrefixDict()
|
||||
self.prev_token_lens: Dict[str, int] = {}
|
||||
self.stream_tasks: Dict[str, asyncio.Task] = {}
|
||||
self.default_model = default_model or "llama-3.2-1b"
|
||||
self.token_queues = defaultdict(asyncio.Queue)
|
||||
|
||||
# Get the callback system and register our handler
|
||||
self.token_callback = node.on_token.register("chatgpt-api-token-handler")
|
||||
self.token_callback.on_next(lambda _request_id, tokens, is_finished: asyncio.create_task(self.handle_tokens(_request_id, tokens, is_finished)))
|
||||
self.system_prompt = system_prompt
|
||||
|
||||
cors = aiohttp_cors.setup(self.app)
|
||||
cors_options = aiohttp_cors.ResourceOptions(
|
||||
allow_credentials=True,
|
||||
expose_headers="*",
|
||||
allow_headers="*",
|
||||
allow_methods="*",
|
||||
)
|
||||
cors.add(self.app.router.add_get("/models", self.handle_get_models), {"*": cors_options})
|
||||
cors.add(self.app.router.add_get("/v1/models", self.handle_get_models), {"*": cors_options})
|
||||
cors.add(self.app.router.add_post("/chat/token/encode", self.handle_post_chat_token_encode), {"*": cors_options})
|
||||
cors.add(self.app.router.add_post("/v1/chat/token/encode", self.handle_post_chat_token_encode), {"*": cors_options})
|
||||
cors.add(self.app.router.add_post("/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
|
||||
cors.add(self.app.router.add_post("/v1/chat/completions", self.handle_post_chat_completions), {"*": cors_options})
|
||||
cors.add(self.app.router.add_post("/v1/image/generations", self.handle_post_image_generations), {"*": cors_options})
|
||||
cors.add(self.app.router.add_get("/v1/download/progress", self.handle_get_download_progress), {"*": cors_options})
|
||||
cors.add(self.app.router.add_get("/modelpool", self.handle_model_support), {"*": cors_options})
|
||||
cors.add(self.app.router.add_get("/healthcheck", self.handle_healthcheck), {"*": cors_options})
|
||||
cors.add(self.app.router.add_post("/quit", self.handle_quit), {"*": cors_options})
|
||||
cors.add(self.app.router.add_delete("/models/{model_name}", self.handle_delete_model), {"*": cors_options})
|
||||
cors.add(self.app.router.add_get("/initial_models", self.handle_get_initial_models), {"*": cors_options})
|
||||
cors.add(self.app.router.add_post("/create_animation", self.handle_create_animation), {"*": cors_options})
|
||||
cors.add(self.app.router.add_post("/download", self.handle_post_download), {"*": cors_options})
|
||||
cors.add(self.app.router.add_get("/v1/topology", self.handle_get_topology), {"*": cors_options})
|
||||
cors.add(self.app.router.add_get("/topology", self.handle_get_topology), {"*": cors_options})
|
||||
|
||||
# Add static routes
|
||||
if "__compiled__" not in globals():
|
||||
self.static_dir = Path(__file__).parent.parent/"tinychat"
|
||||
self.app.router.add_get("/", self.handle_root)
|
||||
self.app.router.add_static("/", self.static_dir, name="static")
|
||||
|
||||
# Always add images route, regardless of compilation status
|
||||
self.images_dir = get_exo_images_dir()
|
||||
self.images_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.app.router.add_static('/images/', self.images_dir, name='static_images')
|
||||
|
||||
self.app.middlewares.append(self.timeout_middleware)
|
||||
self.app.middlewares.append(self.log_request)
|
||||
|
||||
async def handle_quit(self, request):
|
||||
if DEBUG >= 1: print("Received quit signal")
|
||||
response = web.json_response({"detail": "Quit signal received"}, status=200)
|
||||
await response.prepare(request)
|
||||
await response.write_eof()
|
||||
await shutdown(signal.SIGINT, asyncio.get_event_loop(), self.node.server)
|
||||
|
||||
async def timeout_middleware(self, app, handler):
|
||||
async def middleware(request):
|
||||
try:
|
||||
return await asyncio.wait_for(handler(request), timeout=self.response_timeout)
|
||||
except asyncio.TimeoutError:
|
||||
return web.json_response({"detail": "Request timed out"}, status=408)
|
||||
|
||||
return middleware
|
||||
|
||||
async def log_request(self, app, handler):
|
||||
async def middleware(request):
|
||||
if DEBUG >= 2: print(f"Received request: {request.method} {request.path}")
|
||||
return await handler(request)
|
||||
|
||||
return middleware
|
||||
|
||||
async def handle_root(self, request):
|
||||
return web.FileResponse(self.static_dir/"index.html")
|
||||
|
||||
async def handle_healthcheck(self, request):
|
||||
return web.json_response({"status": "ok"})
|
||||
|
||||
async def handle_model_support(self, request):
|
||||
try:
|
||||
response = web.StreamResponse(status=200, reason='OK', headers={ 'Content-Type': 'text/event-stream', 'Cache-Control': 'no-cache', 'Connection': 'keep-alive' })
|
||||
await response.prepare(request)
|
||||
async for path, s in self.node.shard_downloader.get_shard_download_status(self.inference_engine_classname):
|
||||
model_data = { s.shard.model_id: { "downloaded": s.downloaded_bytes == s.total_bytes, "download_percentage": 100 if s.downloaded_bytes == s.total_bytes else 100 * float(s.downloaded_bytes) / float(s.total_bytes), "total_size": s.total_bytes, "total_downloaded": s.downloaded_bytes } }
|
||||
await response.write(f"data: {json.dumps(model_data)}\n\n".encode())
|
||||
await response.write(b"data: [DONE]\n\n")
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in handle_model_support: {str(e)}")
|
||||
traceback.print_exc()
|
||||
return web.json_response({"detail": f"Server error: {str(e)}"}, status=500)
|
||||
|
||||
async def handle_get_models(self, request):
|
||||
models_list = [{"id": model_name, "object": "model", "owned_by": "exo", "ready": True} for model_name, _ in model_cards.items()]
|
||||
return web.json_response({"object": "list", "data": models_list})
|
||||
|
||||
async def handle_post_chat_token_encode(self, request):
|
||||
data = await request.json()
|
||||
model = data.get("model", self.default_model)
|
||||
if model and model.startswith("gpt-"): # Handle gpt- model requests
|
||||
model = self.default_model
|
||||
if not model or model not in model_cards:
|
||||
if DEBUG >= 1: print(f"Invalid model: {model}. Supported: {list(model_cards.keys())}. Defaulting to {self.default_model}")
|
||||
model = self.default_model
|
||||
shard = build_base_shard(model, self.inference_engine_classname)
|
||||
messages = [parse_message(msg) for msg in data.get("messages", [])]
|
||||
tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
|
||||
prompt = build_prompt(tokenizer, messages, data.get("tools", None))
|
||||
tokens = tokenizer.encode(prompt)
|
||||
return web.json_response({
|
||||
"length": len(prompt),
|
||||
"num_tokens": len(tokens),
|
||||
"encoded_tokens": tokens,
|
||||
"encoded_prompt": prompt,
|
||||
})
|
||||
|
||||
async def handle_get_download_progress(self, request):
|
||||
progress_data = {}
|
||||
for node_id, progress_event in self.node.node_download_progress.items():
|
||||
if isinstance(progress_event, RepoProgressEvent):
|
||||
if progress_event.status != "in_progress": continue
|
||||
progress_data[node_id] = progress_event.to_dict()
|
||||
else:
|
||||
print(f"Unknown progress event type: {type(progress_event)}. {progress_event}")
|
||||
return web.json_response(progress_data)
|
||||
|
||||
async def handle_post_chat_completions(self, request):
|
||||
data = await request.json()
|
||||
if DEBUG >= 2: print(f"[ChatGPTAPI] Handling chat completions request from {request.remote}: {data}")
|
||||
stream = data.get("stream", False)
|
||||
chat_request = parse_chat_request(data, self.default_model)
|
||||
if chat_request.model and chat_request.model.startswith("gpt-"): # to be compatible with ChatGPT tools, point all gpt- model requests to default model
|
||||
chat_request.model = self.default_model
|
||||
if not chat_request.model or chat_request.model not in model_cards:
|
||||
if DEBUG >= 1: print(f"[ChatGPTAPI] Invalid model: {chat_request.model}. Supported: {list(model_cards.keys())}. Defaulting to {self.default_model}")
|
||||
chat_request.model = self.default_model
|
||||
shard = build_base_shard(chat_request.model, self.inference_engine_classname)
|
||||
if not shard:
|
||||
supported_models = [model for model, info in model_cards.items() if self.inference_engine_classname in info.get("repo", {})]
|
||||
return web.json_response(
|
||||
{"detail": f"Unsupported model: {chat_request.model} with inference engine {self.inference_engine_classname}. Supported models for this engine: {supported_models}"},
|
||||
status=400,
|
||||
)
|
||||
|
||||
tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
|
||||
if DEBUG >= 4: print(f"[ChatGPTAPI] Resolved tokenizer: {tokenizer}")
|
||||
|
||||
# Add system prompt if set
|
||||
if self.system_prompt and not any(msg.role == "system" for msg in chat_request.messages):
|
||||
chat_request.messages.insert(0, Message("system", self.system_prompt))
|
||||
|
||||
prompt = build_prompt(tokenizer, chat_request.messages, chat_request.tools)
|
||||
request_id = str(uuid.uuid4())
|
||||
if self.on_chat_completion_request:
|
||||
try:
|
||||
self.on_chat_completion_request(request_id, chat_request, prompt)
|
||||
except Exception as e:
|
||||
if DEBUG >= 2: traceback.print_exc()
|
||||
|
||||
if DEBUG >= 2: print(f"[ChatGPTAPI] Processing prompt: {request_id=} {shard=} {prompt=}")
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, request_id=request_id))), timeout=self.response_timeout)
|
||||
|
||||
if DEBUG >= 2: print(f"[ChatGPTAPI] Waiting for response to finish. timeout={self.response_timeout}s")
|
||||
|
||||
if stream:
|
||||
response = web.StreamResponse(
|
||||
status=200,
|
||||
reason="OK",
|
||||
headers={
|
||||
"Content-Type": "text/event-stream",
|
||||
"Cache-Control": "no-cache",
|
||||
},
|
||||
)
|
||||
await response.prepare(request)
|
||||
|
||||
try:
|
||||
# Stream tokens while waiting for inference to complete
|
||||
while True:
|
||||
if DEBUG >= 2: print(f"[ChatGPTAPI] Waiting for token from queue: {request_id=}")
|
||||
tokens, is_finished = await asyncio.wait_for(
|
||||
self.token_queues[request_id].get(),
|
||||
timeout=self.response_timeout
|
||||
)
|
||||
if DEBUG >= 2: print(f"[ChatGPTAPI] Got token from queue: {request_id=} {tokens=} {is_finished=}")
|
||||
|
||||
eos_token_id = None
|
||||
if not eos_token_id and hasattr(tokenizer, "eos_token_id"): eos_token_id = tokenizer.eos_token_id
|
||||
if not eos_token_id and hasattr(tokenizer, "_tokenizer"): eos_token_id = tokenizer.special_tokens_map.get("eos_token_id")
|
||||
|
||||
finish_reason = None
|
||||
if is_finished: finish_reason = "stop" if tokens[-1] == eos_token_id else "length"
|
||||
if DEBUG >= 2: print(f"{eos_token_id=} {tokens[-1]=} {finish_reason=}")
|
||||
|
||||
completion = generate_completion(
|
||||
chat_request,
|
||||
tokenizer,
|
||||
prompt,
|
||||
request_id,
|
||||
tokens,
|
||||
stream,
|
||||
finish_reason,
|
||||
"chat.completion",
|
||||
)
|
||||
|
||||
await response.write(f"data: {json.dumps(completion)}\n\n".encode())
|
||||
|
||||
if is_finished:
|
||||
break
|
||||
|
||||
await response.write_eof()
|
||||
return response
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
if DEBUG >= 2: print(f"[ChatGPTAPI] Timeout waiting for token: {request_id=}")
|
||||
return web.json_response({"detail": "Response generation timed out"}, status=408)
|
||||
|
||||
except Exception as e:
|
||||
if DEBUG >= 2:
|
||||
print(f"[ChatGPTAPI] Error processing prompt: {e}")
|
||||
traceback.print_exc()
|
||||
return web.json_response(
|
||||
{"detail": f"Error processing prompt: {str(e)}"},
|
||||
status=500
|
||||
)
|
||||
|
||||
finally:
|
||||
# Clean up the queue for this request
|
||||
if request_id in self.token_queues:
|
||||
if DEBUG >= 2: print(f"[ChatGPTAPI] Cleaning up token queue: {request_id=}")
|
||||
del self.token_queues[request_id]
|
||||
else:
|
||||
tokens = []
|
||||
while True:
|
||||
_tokens, is_finished = await asyncio.wait_for(self.token_queues[request_id].get(), timeout=self.response_timeout)
|
||||
tokens.extend(_tokens)
|
||||
if is_finished:
|
||||
break
|
||||
finish_reason = "length"
|
||||
eos_token_id = None
|
||||
if not eos_token_id and hasattr(tokenizer, "eos_token_id"): eos_token_id = tokenizer.eos_token_id
|
||||
if not eos_token_id and hasattr(tokenizer, "_tokenizer"): eos_token_id = tokenizer.special_tokens_map.get("eos_token_id")
|
||||
if DEBUG >= 2: print(f"Checking if end of tokens result {tokens[-1]=} is {eos_token_id=}")
|
||||
if tokens[-1] == eos_token_id:
|
||||
finish_reason = "stop"
|
||||
|
||||
return web.json_response(generate_completion(chat_request, tokenizer, prompt, request_id, tokens, stream, finish_reason, "chat.completion"))
|
||||
except asyncio.TimeoutError:
|
||||
return web.json_response({"detail": "Response generation timed out"}, status=408)
|
||||
except Exception as e:
|
||||
if DEBUG >= 2: traceback.print_exc()
|
||||
return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
|
||||
|
||||
async def handle_post_image_generations(self, request):
|
||||
data = await request.json()
|
||||
|
||||
if DEBUG >= 2: print(f"Handling chat completions request from {request.remote}: {data}")
|
||||
stream = data.get("stream", False)
|
||||
model = data.get("model", "")
|
||||
prompt = data.get("prompt", "")
|
||||
image_url = data.get("image_url", "")
|
||||
if DEBUG >= 2: print(f"model: {model}, prompt: {prompt}, stream: {stream}")
|
||||
shard = build_base_shard(model, self.inference_engine_classname)
|
||||
if DEBUG >= 2: print(f"shard: {shard}")
|
||||
if not shard:
|
||||
return web.json_response({"error": f"Unsupported model: {model} with inference engine {self.inference_engine_classname}"}, status=400)
|
||||
|
||||
request_id = str(uuid.uuid4())
|
||||
callback_id = f"chatgpt-api-wait-response-{request_id}"
|
||||
callback = self.node.on_token.register(callback_id)
|
||||
try:
|
||||
if image_url != "" and image_url != None:
|
||||
img = self.base64_decode(image_url)
|
||||
else:
|
||||
img = None
|
||||
await asyncio.wait_for(asyncio.shield(asyncio.create_task(self.node.process_prompt(shard, prompt, request_id=request_id, inference_state={"image": img}))), timeout=self.response_timeout)
|
||||
|
||||
response = web.StreamResponse(status=200, reason='OK', headers={
|
||||
'Content-Type': 'application/octet-stream',
|
||||
"Cache-Control": "no-cache",
|
||||
})
|
||||
await response.prepare(request)
|
||||
|
||||
def get_progress_bar(current_step, total_steps, bar_length=50):
|
||||
# Calculate the percentage of completion
|
||||
percent = float(current_step)/total_steps
|
||||
# Calculate the number of hashes to display
|
||||
arrow = '-'*int(round(percent*bar_length) - 1) + '>'
|
||||
spaces = ' '*(bar_length - len(arrow))
|
||||
|
||||
# Create the progress bar string
|
||||
progress_bar = f'Progress: [{arrow}{spaces}] {int(percent * 100)}% ({current_step}/{total_steps})'
|
||||
return progress_bar
|
||||
|
||||
async def stream_image(_request_id: str, result, is_finished: bool):
|
||||
if isinstance(result, list):
|
||||
await response.write(json.dumps({'progress': get_progress_bar((result[0]), (result[1]))}).encode('utf-8') + b'\n')
|
||||
|
||||
elif isinstance(result, np.ndarray):
|
||||
try:
|
||||
im = Image.fromarray(np.array(result))
|
||||
# Save the image to a file
|
||||
image_filename = f"{_request_id}.png"
|
||||
image_path = self.images_dir/image_filename
|
||||
im.save(image_path)
|
||||
|
||||
# Get URL for the saved image
|
||||
try:
|
||||
image_url = request.app.router['static_images'].url_for(filename=image_filename)
|
||||
base_url = f"{request.scheme}://{request.host}"
|
||||
full_image_url = base_url + str(image_url)
|
||||
|
||||
await response.write(json.dumps({'images': [{'url': str(full_image_url), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n')
|
||||
except KeyError as e:
|
||||
if DEBUG >= 2: print(f"Error getting image URL: {e}")
|
||||
# Fallback to direct file path if URL generation fails
|
||||
await response.write(json.dumps({'images': [{'url': str(image_path), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n')
|
||||
|
||||
if is_finished:
|
||||
await response.write_eof()
|
||||
|
||||
except Exception as e:
|
||||
if DEBUG >= 2: print(f"Error processing image: {e}")
|
||||
if DEBUG >= 2: traceback.print_exc()
|
||||
await response.write(json.dumps({'error': str(e)}).encode('utf-8') + b'\n')
|
||||
|
||||
stream_task = None
|
||||
|
||||
def on_result(_request_id: str, result, is_finished: bool):
|
||||
nonlocal stream_task
|
||||
stream_task = asyncio.create_task(stream_image(_request_id, result, is_finished))
|
||||
return _request_id == request_id and is_finished
|
||||
|
||||
await callback.wait(on_result, timeout=self.response_timeout*10)
|
||||
|
||||
if stream_task:
|
||||
# Wait for the stream task to complete before returning
|
||||
await stream_task
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
if DEBUG >= 2: traceback.print_exc()
|
||||
return web.json_response({"detail": f"Error processing prompt (see logs with DEBUG>=2): {str(e)}"}, status=500)
|
||||
|
||||
async def handle_delete_model(self, request):
|
||||
model_id = request.match_info.get('model_name')
|
||||
try:
|
||||
if await delete_model(model_id, self.inference_engine_classname): return web.json_response({"status": "success", "message": f"Model {model_id} deleted successfully"})
|
||||
else: return web.json_response({"detail": f"Model {model_id} files not found"}, status=404)
|
||||
except Exception as e:
|
||||
if DEBUG >= 2: traceback.print_exc()
|
||||
return web.json_response({"detail": f"Error deleting model: {str(e)}"}, status=500)
|
||||
|
||||
async def handle_get_initial_models(self, request):
|
||||
model_data = {}
|
||||
for model_id in get_supported_models([[self.inference_engine_classname]]):
|
||||
model_data[model_id] = {
|
||||
"name": get_pretty_name(model_id),
|
||||
"downloaded": None, # Initially unknown
|
||||
"download_percentage": None, # Change from 0 to null
|
||||
"total_size": None,
|
||||
"total_downloaded": None,
|
||||
"loading": True # Add loading state
|
||||
}
|
||||
return web.json_response(model_data)
|
||||
|
||||
async def handle_create_animation(self, request):
|
||||
try:
|
||||
data = await request.json()
|
||||
replacement_image_path = data.get("replacement_image_path")
|
||||
device_name = data.get("device_name", "Local Device")
|
||||
prompt_text = data.get("prompt", "")
|
||||
|
||||
if DEBUG >= 2: print(f"Creating animation with params: replacement_image={replacement_image_path}, device={device_name}, prompt={prompt_text}")
|
||||
|
||||
if not replacement_image_path:
|
||||
return web.json_response({"error": "replacement_image_path is required"}, status=400)
|
||||
|
||||
# Create temp directory if it doesn't exist
|
||||
tmp_dir = Path(tempfile.gettempdir())/"exo_animations"
|
||||
tmp_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Generate unique output filename in temp directory
|
||||
output_filename = f"animation_{uuid.uuid4()}.mp4"
|
||||
output_path = str(tmp_dir/output_filename)
|
||||
|
||||
if DEBUG >= 2: print(f"Animation temp directory: {tmp_dir}, output file: {output_path}, directory exists: {tmp_dir.exists()}, directory permissions: {oct(tmp_dir.stat().st_mode)[-3:]}")
|
||||
|
||||
# Create the animation
|
||||
create_animation_mp4(replacement_image_path, output_path, device_name, prompt_text)
|
||||
|
||||
return web.json_response({"status": "success", "output_path": output_path})
|
||||
|
||||
except Exception as e:
|
||||
if DEBUG >= 2: traceback.print_exc()
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def handle_post_download(self, request):
|
||||
try:
|
||||
data = await request.json()
|
||||
model_name = data.get("model")
|
||||
if not model_name: return web.json_response({"error": "model parameter is required"}, status=400)
|
||||
if model_name not in model_cards: return web.json_response({"error": f"Invalid model: {model_name}. Supported models: {list(model_cards.keys())}"}, status=400)
|
||||
shard = build_full_shard(model_name, self.inference_engine_classname)
|
||||
if not shard: return web.json_response({"error": f"Could not build shard for model {model_name}"}, status=400)
|
||||
asyncio.create_task(self.node.inference_engine.shard_downloader.ensure_shard(shard, self.inference_engine_classname))
|
||||
|
||||
return web.json_response({"status": "success", "message": f"Download started for model: {model_name}"})
|
||||
except Exception as e:
|
||||
if DEBUG >= 2: traceback.print_exc()
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def handle_get_topology(self, request):
|
||||
try:
|
||||
topology = self.node.current_topology
|
||||
if topology:
|
||||
return web.json_response(topology.to_json())
|
||||
else:
|
||||
return web.json_response({})
|
||||
except Exception as e:
|
||||
if DEBUG >= 2: traceback.print_exc()
|
||||
return web.json_response({"detail": f"Error getting topology: {str(e)}"}, status=500)
|
||||
|
||||
async def handle_tokens(self, request_id: str, tokens: List[int], is_finished: bool):
|
||||
await self.token_queues[request_id].put((tokens, is_finished))
|
||||
|
||||
async def run(self, host: str = "0.0.0.0", port: int = 52415):
|
||||
runner = web.AppRunner(self.app)
|
||||
await runner.setup()
|
||||
site = web.TCPSite(runner, host, port)
|
||||
await site.start()
|
||||
|
||||
def base64_decode(self, base64_string):
|
||||
#decode and reshape image
|
||||
if base64_string.startswith('data:image'):
|
||||
base64_string = base64_string.split(',')[1]
|
||||
image_data = base64.b64decode(base64_string)
|
||||
img = Image.open(BytesIO(image_data))
|
||||
W, H = (dim - dim%64 for dim in (img.width, img.height))
|
||||
if W != img.width or H != img.height:
|
||||
if DEBUG >= 2: print(f"Warning: image shape is not divisible by 64, downsampling to {W}x{H}")
|
||||
img = img.resize((W, H), Image.NEAREST) # use desired downsampling filter
|
||||
img = mx.array(np.array(img))
|
||||
img = (img[:, :, :3].astype(mx.float32)/255)*2 - 1
|
||||
img = img[None]
|
||||
return img
|
||||
@@ -1 +0,0 @@
|
||||
from exo.apputil.anim import create_animation_mp4
|
||||
@@ -1,168 +0,0 @@
|
||||
from PIL import Image, ImageDraw, ImageFont, ImageFilter
|
||||
import os
|
||||
import numpy as np
|
||||
import cv2
|
||||
import sys
|
||||
|
||||
def draw_rounded_rectangle(draw, coords, radius, fill):
|
||||
left, top, right, bottom = coords
|
||||
diameter = radius * 2
|
||||
draw.rectangle([left + radius, top, right - radius, bottom], fill=fill)
|
||||
draw.rectangle([left, top + radius, right, bottom - radius], fill=fill)
|
||||
draw.pieslice([left, top, left + diameter, top + diameter], 180, 270, fill=fill)
|
||||
draw.pieslice([right - diameter, top, right, top + diameter], 270, 360, fill=fill)
|
||||
draw.pieslice([left, bottom - diameter, left + diameter, bottom], 90, 180, fill=fill)
|
||||
draw.pieslice([right - diameter, bottom - diameter, right, bottom], 0, 90, fill=fill)
|
||||
|
||||
def draw_centered_text_rounded(draw, text, font, rect_coords, radius=10, text_color="yellow", bg_color=(43,33,44)):
|
||||
bbox = font.getbbox(text)
|
||||
text_width = bbox[2] - bbox[0]
|
||||
text_height = bbox[3] - bbox[1]
|
||||
rect_left, rect_top, rect_right, rect_bottom = rect_coords
|
||||
rect_width = rect_right - rect_left
|
||||
rect_height = rect_bottom - rect_top
|
||||
text_x = rect_left + (rect_width - text_width) // 2
|
||||
text_y = rect_top + (rect_height - text_height) // 2
|
||||
draw_rounded_rectangle(draw, rect_coords, radius, bg_color)
|
||||
draw.text((text_x, text_y), text, fill=text_color, font=font)
|
||||
|
||||
def draw_left_aligned_text_rounded(draw, text, font, rect_coords, padding_left=20, radius=10, text_color="yellow", bg_color=(43,33,44)):
|
||||
bbox = font.getbbox(text)
|
||||
text_height = bbox[3] - bbox[1]
|
||||
rect_left, rect_top, rect_right, rect_bottom = rect_coords
|
||||
rect_height = rect_bottom - rect_top
|
||||
text_y = rect_top + (rect_height - text_height) // 2
|
||||
text_x = rect_left + padding_left
|
||||
draw_rounded_rectangle(draw, rect_coords, radius, bg_color)
|
||||
draw.text((text_x, text_y), text, fill=text_color, font=font)
|
||||
|
||||
def draw_right_text_dynamic_width_rounded(draw, text, font, base_coords, padding=20, radius=10, text_color="yellow", bg_color=(43,33,44)):
|
||||
bbox = font.getbbox(text)
|
||||
text_width = bbox[2] - bbox[0]
|
||||
text_height = bbox[3] - bbox[1]
|
||||
_, rect_top, rect_right, rect_bottom = base_coords
|
||||
rect_height = rect_bottom - rect_top
|
||||
new_rect_left = rect_right - (text_width + (padding * 2))
|
||||
text_y = rect_top + (rect_height - text_height) // 2
|
||||
text_x = new_rect_left + padding
|
||||
draw_rounded_rectangle(draw, (new_rect_left, rect_top, rect_right, rect_bottom), radius, bg_color)
|
||||
draw.text((text_x, text_y), text, fill=text_color, font=font)
|
||||
return new_rect_left
|
||||
|
||||
def draw_progress_bar(draw, progress, coords, color="yellow", bg_color=(70, 70, 70)):
|
||||
left, top, right, bottom = coords
|
||||
total_width = right - left
|
||||
draw.rectangle(coords, fill=bg_color)
|
||||
progress_width = int(total_width * progress)
|
||||
if progress_width > 0:
|
||||
draw.rectangle((left, top, left + progress_width, bottom), fill=color)
|
||||
|
||||
def crop_image(image, top_crop=70):
|
||||
width, height = image.size
|
||||
return image.crop((0, top_crop, width, height))
|
||||
|
||||
def create_animation_mp4(
|
||||
replacement_image_path,
|
||||
output_path,
|
||||
device_name,
|
||||
prompt_text,
|
||||
fps=30,
|
||||
target_size=(512, 512),
|
||||
target_position=(139, 755),
|
||||
progress_coords=(139, 1285, 655, 1295),
|
||||
device_coords=(1240, 370, 1640, 416),
|
||||
prompt_coords=(332, 1702, 2662, 1745)
|
||||
):
|
||||
frames = []
|
||||
try:
|
||||
font = ImageFont.truetype("/System/Library/Fonts/SFNSMono.ttf", 20)
|
||||
promptfont = ImageFont.truetype("/System/Library/Fonts/SFNSMono.ttf", 24)
|
||||
except:
|
||||
font = ImageFont.load_default()
|
||||
promptfont = ImageFont.load_default()
|
||||
|
||||
# Get the base directory for images when running as a bundled app
|
||||
if hasattr(sys, '_MEIPASS'):
|
||||
base_dir = os.path.join(sys._MEIPASS, "exo", "apputil", "baseimages")
|
||||
else:
|
||||
base_dir = os.path.join(os.path.dirname(__file__), "baseimages")
|
||||
|
||||
# Process first frame
|
||||
base_img = Image.open(os.path.join(base_dir, "image1.png"))
|
||||
draw = ImageDraw.Draw(base_img)
|
||||
draw_centered_text_rounded(draw, device_name, font, device_coords)
|
||||
frames.extend([crop_image(base_img)] * 30) # 1 second at 30fps
|
||||
|
||||
# Process second frame with typing animation
|
||||
base_img2 = Image.open(os.path.join(base_dir, "image2.png"))
|
||||
for i in range(len(prompt_text) + 1):
|
||||
current_frame = base_img2.copy()
|
||||
draw = ImageDraw.Draw(current_frame)
|
||||
draw_centered_text_rounded(draw, device_name, font, device_coords)
|
||||
if i > 0: # Only draw if we have at least one character
|
||||
draw_left_aligned_text_rounded(draw, prompt_text[:i], promptfont, prompt_coords)
|
||||
frames.extend([crop_image(current_frame)] * 2) # 2 frames per character for smooth typing
|
||||
|
||||
# Hold the complete prompt for a moment
|
||||
frames.extend([frames[-1]] * 30) # Hold for 1 second
|
||||
|
||||
# Create blur sequence
|
||||
replacement_img = Image.open(replacement_image_path)
|
||||
base_img = Image.open(os.path.join(base_dir, "image3.png"))
|
||||
blur_steps = [int(80 * (1 - i/8)) for i in range(9)]
|
||||
|
||||
for i, blur_amount in enumerate(blur_steps):
|
||||
new_frame = base_img.copy()
|
||||
draw = ImageDraw.Draw(new_frame)
|
||||
|
||||
replacement_copy = replacement_img.copy()
|
||||
replacement_copy.thumbnail(target_size, Image.Resampling.LANCZOS)
|
||||
if blur_amount > 0:
|
||||
replacement_copy = replacement_copy.filter(ImageFilter.GaussianBlur(radius=blur_amount))
|
||||
|
||||
mask = replacement_copy.split()[-1] if replacement_copy.mode in ('RGBA', 'LA') else None
|
||||
new_frame.paste(replacement_copy, target_position, mask)
|
||||
|
||||
draw_progress_bar(draw, (i + 1) / 9, progress_coords)
|
||||
draw_centered_text_rounded(draw, device_name, font, device_coords)
|
||||
draw_right_text_dynamic_width_rounded(draw, prompt_text, promptfont, (None, 590, 2850, 685), padding=30)
|
||||
|
||||
frames.extend([crop_image(new_frame)] * 15) # 0.5 seconds at 30fps
|
||||
|
||||
# Create and add final frame (image4)
|
||||
final_base = Image.open(os.path.join(base_dir, "image4.png"))
|
||||
draw = ImageDraw.Draw(final_base)
|
||||
|
||||
draw_centered_text_rounded(draw, device_name, font, device_coords)
|
||||
draw_right_text_dynamic_width_rounded(draw, prompt_text, promptfont, (None, 590, 2850, 685), padding=30)
|
||||
|
||||
replacement_copy = replacement_img.copy()
|
||||
replacement_copy.thumbnail(target_size, Image.Resampling.LANCZOS)
|
||||
mask = replacement_copy.split()[-1] if replacement_copy.mode in ('RGBA', 'LA') else None
|
||||
final_base.paste(replacement_copy, target_position, mask)
|
||||
|
||||
frames.extend([crop_image(final_base)] * 30) # 1 second at 30fps
|
||||
|
||||
# Convert frames to video using H.264 codec
|
||||
if frames:
|
||||
first_frame = np.array(frames[0])
|
||||
height, width = first_frame.shape[:2]
|
||||
fourcc = cv2.VideoWriter_fourcc(*'avc1')
|
||||
out = cv2.VideoWriter(
|
||||
output_path,
|
||||
fourcc,
|
||||
fps,
|
||||
(width, height),
|
||||
isColor=True
|
||||
)
|
||||
|
||||
if not out.isOpened():
|
||||
print("Error: VideoWriter failed to open")
|
||||
return
|
||||
|
||||
for frame in frames:
|
||||
frame_array = cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR)
|
||||
out.write(frame_array)
|
||||
|
||||
out.release()
|
||||
print(f"Video saved successfully to {output_path}")
|
||||
@@ -1,3 +0,0 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:361fdadd67c277d45cd18b0bfc8c5ceea5fd89f2d65aef157fd915ce9cbb8599
|
||||
size 814460
|
||||
@@ -1,3 +0,0 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:f0e3891bc6b4f4dfa7444af53fcaa4b3ba06b0549546202be3243f08a0e6bd7e
|
||||
size 814235
|
||||
@@ -1,3 +0,0 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:a2dc5b3378aef397d60fd1252da8a1c578ad97e202a859590ffa416b49551d19
|
||||
size 146633
|
||||
@@ -1,3 +0,0 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:dbc6883e2a3c5233ec7b844c98646922bdc4f5e42e1f424857eaff56f785dbcd
|
||||
size 668550
|
||||
@@ -1,65 +0,0 @@
|
||||
from typing import Dict, Callable, Coroutine, Any, Literal
|
||||
from exo.inference.shard import Shard
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
|
||||
|
||||
@dataclass
|
||||
class RepoFileProgressEvent:
|
||||
repo_id: str
|
||||
repo_revision: str
|
||||
file_path: str
|
||||
downloaded: int
|
||||
downloaded_this_session: int
|
||||
total: int
|
||||
speed: int
|
||||
eta: timedelta
|
||||
status: Literal["not_started", "in_progress", "complete"]
|
||||
start_time: float
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"repo_id": self.repo_id, "repo_revision": self.repo_revision, "file_path": self.file_path, "downloaded": self.downloaded, "downloaded_this_session": self.downloaded_this_session,
|
||||
"total": self.total, "speed": self.speed, "eta": self.eta.total_seconds(), "status": self.status, "start_time": self.start_time
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data):
|
||||
if 'eta' in data: data['eta'] = timedelta(seconds=data['eta'])
|
||||
return cls(**data)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RepoProgressEvent:
|
||||
shard: Shard
|
||||
repo_id: str
|
||||
repo_revision: str
|
||||
completed_files: int
|
||||
total_files: int
|
||||
downloaded_bytes: int
|
||||
downloaded_bytes_this_session: int
|
||||
total_bytes: int
|
||||
overall_speed: int
|
||||
overall_eta: timedelta
|
||||
file_progress: Dict[str, RepoFileProgressEvent]
|
||||
status: Literal["not_started", "in_progress", "complete"]
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"shard": self.shard.to_dict(), "repo_id": self.repo_id, "repo_revision": self.repo_revision, "completed_files": self.completed_files, "total_files": self.total_files, "downloaded_bytes": self.downloaded_bytes,
|
||||
"downloaded_bytes_this_session": self.downloaded_bytes_this_session, "total_bytes": self.total_bytes, "overall_speed": self.overall_speed, "overall_eta": self.overall_eta.total_seconds(),
|
||||
"file_progress": {k: v.to_dict()
|
||||
for k, v in self.file_progress.items()}, "status": self.status
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data):
|
||||
if 'overall_eta' in data: data['overall_eta'] = timedelta(seconds=data['overall_eta'])
|
||||
if 'file_progress' in data: data['file_progress'] = {k: RepoFileProgressEvent.from_dict(v) for k, v in data['file_progress'].items()}
|
||||
if 'shard' in data: data['shard'] = Shard.from_dict(data['shard'])
|
||||
|
||||
return cls(**data)
|
||||
|
||||
|
||||
RepoFileProgressCallback = Callable[[RepoFileProgressEvent], Coroutine[Any, Any, None]]
|
||||
RepoProgressCallback = Callable[[RepoProgressEvent], Coroutine[Any, Any, None]]
|
||||
@@ -1,98 +0,0 @@
|
||||
import aiofiles.os as aios
|
||||
from typing import Union
|
||||
import os
|
||||
from typing import Callable, Optional, Dict, List, Union
|
||||
from fnmatch import fnmatch
|
||||
from pathlib import Path
|
||||
from typing import Generator, Iterable, TypeVar
|
||||
from exo.helpers import DEBUG
|
||||
from exo.inference.shard import Shard
|
||||
import aiofiles
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
def filter_repo_objects(
|
||||
items: Iterable[T],
|
||||
*,
|
||||
allow_patterns: Optional[Union[List[str], str]] = None,
|
||||
ignore_patterns: Optional[Union[List[str], str]] = None,
|
||||
key: Optional[Callable[[T], str]] = None,
|
||||
) -> Generator[T, None, None]:
|
||||
if isinstance(allow_patterns, str):
|
||||
allow_patterns = [allow_patterns]
|
||||
if isinstance(ignore_patterns, str):
|
||||
ignore_patterns = [ignore_patterns]
|
||||
if allow_patterns is not None:
|
||||
allow_patterns = [_add_wildcard_to_directories(p) for p in allow_patterns]
|
||||
if ignore_patterns is not None:
|
||||
ignore_patterns = [_add_wildcard_to_directories(p) for p in ignore_patterns]
|
||||
|
||||
if key is None:
|
||||
def _identity(item: T) -> str:
|
||||
if isinstance(item, str):
|
||||
return item
|
||||
if isinstance(item, Path):
|
||||
return str(item)
|
||||
raise ValueError(f"Please provide `key` argument in `filter_repo_objects`: `{item}` is not a string.")
|
||||
key = _identity
|
||||
|
||||
for item in items:
|
||||
path = key(item)
|
||||
if allow_patterns is not None and not any(fnmatch(path, r) for r in allow_patterns):
|
||||
continue
|
||||
if ignore_patterns is not None and any(fnmatch(path, r) for r in ignore_patterns):
|
||||
continue
|
||||
yield item
|
||||
|
||||
def _add_wildcard_to_directories(pattern: str) -> str:
|
||||
if pattern[-1] == "/":
|
||||
return pattern + "*"
|
||||
return pattern
|
||||
|
||||
def get_hf_endpoint() -> str:
|
||||
return os.environ.get('HF_ENDPOINT', "https://huggingface.co")
|
||||
|
||||
def get_hf_home() -> Path:
|
||||
"""Get the Hugging Face home directory."""
|
||||
return Path(os.environ.get("HF_HOME", Path.home()/".cache"/"huggingface"))
|
||||
|
||||
async def get_hf_token():
|
||||
"""Retrieve the Hugging Face token from the user's HF_HOME directory."""
|
||||
token_path = get_hf_home()/"token"
|
||||
if await aios.path.exists(token_path):
|
||||
async with aiofiles.open(token_path, 'r') as f:
|
||||
return (await f.read()).strip()
|
||||
return None
|
||||
|
||||
async def get_auth_headers():
|
||||
"""Get authentication headers if a token is available."""
|
||||
token = await get_hf_token()
|
||||
if token:
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
return {}
|
||||
|
||||
def extract_layer_num(tensor_name: str) -> Optional[int]:
|
||||
# This is a simple example and might need to be adjusted based on the actual naming convention
|
||||
parts = tensor_name.split('.')
|
||||
for part in parts:
|
||||
if part.isdigit():
|
||||
return int(part)
|
||||
return None
|
||||
|
||||
def get_allow_patterns(weight_map: Dict[str, str], shard: Shard) -> List[str]:
|
||||
default_patterns = set(["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt"])
|
||||
shard_specific_patterns = set()
|
||||
if weight_map:
|
||||
for tensor_name, filename in weight_map.items():
|
||||
layer_num = extract_layer_num(tensor_name)
|
||||
if layer_num is not None and shard.start_layer <= layer_num <= shard.end_layer:
|
||||
shard_specific_patterns.add(filename)
|
||||
sorted_file_names = sorted(weight_map.values())
|
||||
if shard.is_first_layer():
|
||||
shard_specific_patterns.add(sorted_file_names[0])
|
||||
elif shard.is_last_layer():
|
||||
shard_specific_patterns.add(sorted_file_names[-1])
|
||||
else:
|
||||
shard_specific_patterns = set(["*.safetensors"])
|
||||
if DEBUG >= 3: print(f"get_allow_patterns {weight_map=} {shard=} {shard_specific_patterns=}")
|
||||
return list(default_patterns | shard_specific_patterns)
|
||||
@@ -1,307 +0,0 @@
|
||||
from exo.inference.shard import Shard
|
||||
from exo.models import get_repo
|
||||
from pathlib import Path
|
||||
from exo.download.hf.hf_helpers import get_hf_endpoint, get_auth_headers, filter_repo_objects, get_allow_patterns
|
||||
from exo.download.shard_download import ShardDownloader
|
||||
from exo.download.download_progress import RepoProgressEvent, RepoFileProgressEvent
|
||||
from exo.helpers import AsyncCallbackSystem, DEBUG
|
||||
from exo.models import get_supported_models, build_full_shard
|
||||
import os
|
||||
import aiofiles.os as aios
|
||||
import aiohttp
|
||||
import aiofiles
|
||||
from urllib.parse import urljoin
|
||||
from typing import Callable, Union, Tuple, Dict, List, Optional, Literal, AsyncIterator
|
||||
import time
|
||||
from datetime import timedelta
|
||||
import asyncio
|
||||
import json
|
||||
import traceback
|
||||
import shutil
|
||||
import tempfile
|
||||
import hashlib
|
||||
|
||||
def exo_home() -> Path:
|
||||
return Path(os.environ.get("EXO_HOME", Path.home()/".cache"/"exo"))
|
||||
|
||||
def exo_tmp() -> Path:
|
||||
return Path(tempfile.gettempdir())/"exo"
|
||||
|
||||
async def ensure_exo_home() -> Path:
|
||||
await aios.makedirs(exo_home(), exist_ok=True)
|
||||
return exo_home()
|
||||
|
||||
async def ensure_exo_tmp() -> Path:
|
||||
await aios.makedirs(exo_tmp(), exist_ok=True)
|
||||
return exo_tmp()
|
||||
|
||||
async def has_exo_home_read_access() -> bool:
|
||||
try: return await aios.access(exo_home(), os.R_OK)
|
||||
except OSError: return False
|
||||
|
||||
async def has_exo_home_write_access() -> bool:
|
||||
try: return await aios.access(exo_home(), os.W_OK)
|
||||
except OSError: return False
|
||||
|
||||
async def ensure_downloads_dir() -> Path:
|
||||
downloads_dir = exo_home()/"downloads"
|
||||
await aios.makedirs(downloads_dir, exist_ok=True)
|
||||
return downloads_dir
|
||||
|
||||
async def delete_model(model_id: str, inference_engine_name: str) -> bool:
|
||||
repo_id = get_repo(model_id, inference_engine_name)
|
||||
model_dir = await ensure_downloads_dir()/repo_id.replace("/", "--")
|
||||
if not await aios.path.exists(model_dir): return False
|
||||
await asyncio.to_thread(shutil.rmtree, model_dir, ignore_errors=False)
|
||||
return True
|
||||
|
||||
async def seed_models(seed_dir: Union[str, Path]):
|
||||
"""Move model in resources folder of app to .cache/huggingface/hub"""
|
||||
source_dir = Path(seed_dir)
|
||||
dest_dir = await ensure_downloads_dir()
|
||||
for path in source_dir.iterdir():
|
||||
if path.is_dir() and path.name.startswith("models--"):
|
||||
dest_path = dest_dir/path.name
|
||||
if await aios.path.exists(dest_path): print('Skipping moving model to .cache directory')
|
||||
else:
|
||||
try: await aios.rename(str(path), str(dest_path))
|
||||
except:
|
||||
print(f"Error seeding model {path} to {dest_path}")
|
||||
traceback.print_exc()
|
||||
|
||||
async def fetch_file_list_with_cache(repo_id: str, revision: str = "main") -> List[Dict[str, Union[str, int]]]:
|
||||
cache_file = (await ensure_exo_tmp())/f"{repo_id.replace('/', '--')}--{revision}--file_list.json"
|
||||
if await aios.path.exists(cache_file):
|
||||
async with aiofiles.open(cache_file, 'r') as f: return json.loads(await f.read())
|
||||
file_list = await fetch_file_list_with_retry(repo_id, revision)
|
||||
await aios.makedirs(cache_file.parent, exist_ok=True)
|
||||
async with aiofiles.open(cache_file, 'w') as f: await f.write(json.dumps(file_list))
|
||||
return file_list
|
||||
|
||||
async def fetch_file_list_with_retry(repo_id: str, revision: str = "main", path: str = "") -> List[Dict[str, Union[str, int]]]:
|
||||
n_attempts = 30
|
||||
for attempt in range(n_attempts):
|
||||
try: return await _fetch_file_list(repo_id, revision, path)
|
||||
except Exception as e:
|
||||
if attempt == n_attempts - 1: raise e
|
||||
await asyncio.sleep(min(8, 0.1 * (2 ** attempt)))
|
||||
|
||||
async def _fetch_file_list(repo_id: str, revision: str = "main", path: str = "") -> List[Dict[str, Union[str, int]]]:
|
||||
api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
|
||||
url = f"{api_url}/{path}" if path else api_url
|
||||
|
||||
headers = await get_auth_headers()
|
||||
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=30, connect=10, sock_read=30, sock_connect=10)) as session:
|
||||
async with session.get(url, headers=headers) as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
files = []
|
||||
for item in data:
|
||||
if item["type"] == "file":
|
||||
files.append({"path": item["path"], "size": item["size"]})
|
||||
elif item["type"] == "directory":
|
||||
subfiles = await _fetch_file_list(repo_id, revision, item["path"])
|
||||
files.extend(subfiles)
|
||||
return files
|
||||
else:
|
||||
raise Exception(f"Failed to fetch file list: {response.status}")
|
||||
|
||||
async def calc_hash(path: Path, type: Literal["sha1", "sha256"] = "sha1") -> str:
|
||||
hash = hashlib.sha1() if type == "sha1" else hashlib.sha256()
|
||||
if type == "sha1":
|
||||
header = f"blob {(await aios.stat(path)).st_size}\0".encode()
|
||||
hash.update(header)
|
||||
async with aiofiles.open(path, 'rb') as f:
|
||||
while chunk := await f.read(8 * 1024 * 1024):
|
||||
hash.update(chunk)
|
||||
return hash.hexdigest()
|
||||
|
||||
async def file_meta(repo_id: str, revision: str, path: str) -> Tuple[int, str]:
|
||||
url = urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", path)
|
||||
headers = await get_auth_headers()
|
||||
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as session:
|
||||
async with session.head(url, headers=headers) as r:
|
||||
content_length = int(r.headers.get('x-linked-size') or r.headers.get('content-length') or 0)
|
||||
etag = r.headers.get('X-Linked-ETag') or r.headers.get('ETag') or r.headers.get('Etag')
|
||||
assert content_length > 0, f"No content length for {url}"
|
||||
assert etag is not None, f"No remote hash for {url}"
|
||||
if (etag[0] == '"' and etag[-1] == '"') or (etag[0] == "'" and etag[-1] == "'"): etag = etag[1:-1]
|
||||
return content_length, etag
|
||||
|
||||
async def download_file_with_retry(repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Callable[[int, int], None] = lambda _, __: None) -> Path:
|
||||
n_attempts = 30
|
||||
for attempt in range(n_attempts):
|
||||
try: return await _download_file(repo_id, revision, path, target_dir, on_progress)
|
||||
except Exception as e:
|
||||
if isinstance(e, FileNotFoundError) or attempt == n_attempts - 1: raise e
|
||||
print(f"Download error on attempt {attempt}/{n_attempts} for {repo_id=} {revision=} {path=} {target_dir=}")
|
||||
traceback.print_exc()
|
||||
await asyncio.sleep(min(8, 0.1 * (2 ** attempt)))
|
||||
|
||||
async def _download_file(repo_id: str, revision: str, path: str, target_dir: Path, on_progress: Callable[[int, int], None] = lambda _, __: None) -> Path:
|
||||
if await aios.path.exists(target_dir/path): return target_dir/path
|
||||
await aios.makedirs((target_dir/path).parent, exist_ok=True)
|
||||
length, etag = await file_meta(repo_id, revision, path)
|
||||
remote_hash = etag[:-5] if etag.endswith("-gzip") else etag
|
||||
partial_path = target_dir/f"{path}.partial"
|
||||
resume_byte_pos = (await aios.stat(partial_path)).st_size if (await aios.path.exists(partial_path)) else None
|
||||
if resume_byte_pos != length:
|
||||
url = urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", path)
|
||||
headers = await get_auth_headers()
|
||||
if resume_byte_pos: headers['Range'] = f'bytes={resume_byte_pos}-'
|
||||
n_read = resume_byte_pos or 0
|
||||
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as session:
|
||||
async with session.get(url, headers=headers, timeout=aiohttp.ClientTimeout(total=1800, connect=60, sock_read=1800, sock_connect=60)) as r:
|
||||
if r.status == 404: raise FileNotFoundError(f"File not found: {url}")
|
||||
assert r.status in [200, 206], f"Failed to download {path} from {url}: {r.status}"
|
||||
async with aiofiles.open(partial_path, 'ab' if resume_byte_pos else 'wb') as f:
|
||||
while chunk := await r.content.read(8 * 1024 * 1024): on_progress(n_read := n_read + await f.write(chunk), length)
|
||||
|
||||
final_hash = await calc_hash(partial_path, type="sha256" if len(remote_hash) == 64 else "sha1")
|
||||
integrity = final_hash == remote_hash
|
||||
if not integrity:
|
||||
try: await aios.remove(partial_path)
|
||||
except Exception as e: print(f"Error removing partial file {partial_path}: {e}")
|
||||
raise Exception(f"Downloaded file {target_dir/path} has hash {final_hash} but remote hash is {remote_hash}")
|
||||
await aios.rename(partial_path, target_dir/path)
|
||||
return target_dir/path
|
||||
|
||||
|
||||
def calculate_repo_progress(shard: Shard, repo_id: str, revision: str, file_progress: Dict[str, RepoFileProgressEvent], all_start_time: float) -> RepoProgressEvent:
|
||||
all_total_bytes = sum([p.total for p in file_progress.values()])
|
||||
all_downloaded_bytes = sum([p.downloaded for p in file_progress.values()])
|
||||
all_downloaded_bytes_this_session = sum([p.downloaded_this_session for p in file_progress.values()])
|
||||
elapsed_time = time.time() - all_start_time
|
||||
all_speed = all_downloaded_bytes_this_session / elapsed_time if elapsed_time > 0 else 0
|
||||
all_eta = timedelta(seconds=(all_total_bytes - all_downloaded_bytes) / all_speed) if all_speed > 0 else timedelta(seconds=0)
|
||||
status = "complete" if all(p.status == "complete" for p in file_progress.values()) else "in_progress" if any(p.status == "in_progress" for p in file_progress.values()) else "not_started"
|
||||
return RepoProgressEvent(shard, repo_id, revision, len([p for p in file_progress.values() if p.downloaded == p.total]), len(file_progress), all_downloaded_bytes, all_downloaded_bytes_this_session, all_total_bytes, all_speed, all_eta, file_progress, status)
|
||||
|
||||
async def get_weight_map(repo_id: str, revision: str = "main") -> Dict[str, str]:
|
||||
target_dir = (await ensure_exo_tmp())/repo_id.replace("/", "--")
|
||||
index_file = await download_file_with_retry(repo_id, revision, "model.safetensors.index.json", target_dir)
|
||||
async with aiofiles.open(index_file, 'r') as f: index_data = json.loads(await f.read())
|
||||
return index_data.get("weight_map")
|
||||
|
||||
async def resolve_allow_patterns(shard: Shard, inference_engine_classname: str) -> List[str]:
|
||||
try:
|
||||
weight_map = await get_weight_map(get_repo(shard.model_id, inference_engine_classname))
|
||||
return get_allow_patterns(weight_map, shard)
|
||||
except:
|
||||
if DEBUG >= 1: print(f"Error getting weight map for {shard.model_id=} and inference engine {inference_engine_classname}")
|
||||
if DEBUG >= 1: traceback.print_exc()
|
||||
return ["*"]
|
||||
|
||||
async def get_downloaded_size(path: Path) -> int:
|
||||
partial_path = path.with_suffix(path.suffix + ".partial")
|
||||
if await aios.path.exists(path): return (await aios.stat(path)).st_size
|
||||
if await aios.path.exists(partial_path): return (await aios.stat(partial_path)).st_size
|
||||
return 0
|
||||
|
||||
async def download_shard(shard: Shard, inference_engine_classname: str, on_progress: AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]], max_parallel_downloads: int = 8, skip_download: bool = False) -> tuple[Path, RepoProgressEvent]:
|
||||
if DEBUG >= 2 and not skip_download: print(f"Downloading {shard.model_id=} for {inference_engine_classname}")
|
||||
repo_id = get_repo(shard.model_id, inference_engine_classname)
|
||||
revision = "main"
|
||||
target_dir = await ensure_downloads_dir()/repo_id.replace("/", "--")
|
||||
if not skip_download: await aios.makedirs(target_dir, exist_ok=True)
|
||||
|
||||
if repo_id is None:
|
||||
raise ValueError(f"No repo found for {shard.model_id=} and inference engine {inference_engine_classname}")
|
||||
|
||||
allow_patterns = await resolve_allow_patterns(shard, inference_engine_classname)
|
||||
if DEBUG >= 2: print(f"Downloading {shard.model_id=} with {allow_patterns=}")
|
||||
|
||||
all_start_time = time.time()
|
||||
file_list = await fetch_file_list_with_cache(repo_id, revision)
|
||||
filtered_file_list = list(filter_repo_objects(file_list, allow_patterns=allow_patterns, key=lambda x: x["path"]))
|
||||
file_progress: Dict[str, RepoFileProgressEvent] = {}
|
||||
def on_progress_wrapper(file: dict, curr_bytes: int, total_bytes: int):
|
||||
start_time = file_progress[file["path"]].start_time if file["path"] in file_progress else time.time()
|
||||
downloaded_this_session = file_progress[file["path"]].downloaded_this_session + (curr_bytes - file_progress[file["path"]].downloaded) if file["path"] in file_progress else curr_bytes
|
||||
speed = downloaded_this_session / (time.time() - start_time) if time.time() - start_time > 0 else 0
|
||||
eta = timedelta(seconds=(total_bytes - curr_bytes) / speed) if speed > 0 else timedelta(seconds=0)
|
||||
file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], curr_bytes, downloaded_this_session, total_bytes, speed, eta, "complete" if curr_bytes == total_bytes else "in_progress", start_time)
|
||||
on_progress.trigger_all(shard, calculate_repo_progress(shard, repo_id, revision, file_progress, all_start_time))
|
||||
if DEBUG >= 6: print(f"Downloading {file['path']} {curr_bytes}/{total_bytes} {speed} {eta}")
|
||||
for file in filtered_file_list:
|
||||
downloaded_bytes = await get_downloaded_size(target_dir/file["path"])
|
||||
file_progress[file["path"]] = RepoFileProgressEvent(repo_id, revision, file["path"], downloaded_bytes, 0, file["size"], 0, timedelta(0), "complete" if downloaded_bytes == file["size"] else "not_started", time.time())
|
||||
|
||||
semaphore = asyncio.Semaphore(max_parallel_downloads)
|
||||
async def download_with_semaphore(file):
|
||||
async with semaphore:
|
||||
await download_file_with_retry(repo_id, revision, file["path"], target_dir, lambda curr_bytes, total_bytes: on_progress_wrapper(file, curr_bytes, total_bytes))
|
||||
if not skip_download: await asyncio.gather(*[download_with_semaphore(file) for file in filtered_file_list])
|
||||
final_repo_progress = calculate_repo_progress(shard, repo_id, revision, file_progress, all_start_time)
|
||||
on_progress.trigger_all(shard, final_repo_progress)
|
||||
if gguf := next((f for f in filtered_file_list if f["path"].endswith(".gguf")), None):
|
||||
return target_dir/gguf["path"], final_repo_progress
|
||||
else:
|
||||
return target_dir, final_repo_progress
|
||||
|
||||
def new_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:
|
||||
return SingletonShardDownloader(CachedShardDownloader(NewShardDownloader(max_parallel_downloads)))
|
||||
|
||||
class SingletonShardDownloader(ShardDownloader):
|
||||
def __init__(self, shard_downloader: ShardDownloader):
|
||||
self.shard_downloader = shard_downloader
|
||||
self.active_downloads: Dict[Shard, asyncio.Task] = {}
|
||||
|
||||
@property
|
||||
def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
|
||||
return self.shard_downloader.on_progress
|
||||
|
||||
async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
|
||||
if shard not in self.active_downloads: self.active_downloads[shard] = asyncio.create_task(self.shard_downloader.ensure_shard(shard, inference_engine_name))
|
||||
try: return await self.active_downloads[shard]
|
||||
finally:
|
||||
if shard in self.active_downloads and self.active_downloads[shard].done(): del self.active_downloads[shard]
|
||||
|
||||
async def get_shard_download_status(self, inference_engine_name: str) -> AsyncIterator[tuple[Path, RepoProgressEvent]]:
|
||||
async for path, status in self.shard_downloader.get_shard_download_status(inference_engine_name):
|
||||
yield path, status
|
||||
|
||||
class CachedShardDownloader(ShardDownloader):
|
||||
def __init__(self, shard_downloader: ShardDownloader):
|
||||
self.shard_downloader = shard_downloader
|
||||
self.cache: Dict[tuple[str, Shard], Path] = {}
|
||||
|
||||
@property
|
||||
def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
|
||||
return self.shard_downloader.on_progress
|
||||
|
||||
async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
|
||||
if (inference_engine_name, shard) in self.cache:
|
||||
if DEBUG >= 2: print(f"ensure_shard cache hit {shard=} for {inference_engine_name}")
|
||||
return self.cache[(inference_engine_name, shard)]
|
||||
if DEBUG >= 2: print(f"ensure_shard cache miss {shard=} for {inference_engine_name}")
|
||||
target_dir = await self.shard_downloader.ensure_shard(shard, inference_engine_name)
|
||||
self.cache[(inference_engine_name, shard)] = target_dir
|
||||
return target_dir
|
||||
|
||||
async def get_shard_download_status(self, inference_engine_name: str) -> AsyncIterator[tuple[Path, RepoProgressEvent]]:
|
||||
async for path, status in self.shard_downloader.get_shard_download_status(inference_engine_name):
|
||||
yield path, status
|
||||
|
||||
class NewShardDownloader(ShardDownloader):
|
||||
def __init__(self, max_parallel_downloads: int = 8):
|
||||
self.max_parallel_downloads = max_parallel_downloads
|
||||
self._on_progress = AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]()
|
||||
|
||||
@property
|
||||
def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
|
||||
return self._on_progress
|
||||
|
||||
async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
|
||||
target_dir, _ = await download_shard(shard, inference_engine_name, self.on_progress, max_parallel_downloads=self.max_parallel_downloads)
|
||||
return target_dir
|
||||
|
||||
async def get_shard_download_status(self, inference_engine_name: str) -> AsyncIterator[tuple[Path, RepoProgressEvent]]:
|
||||
if DEBUG >= 2: print("Getting shard download status for", inference_engine_name)
|
||||
tasks = [download_shard(build_full_shard(model_id, inference_engine_name), inference_engine_name, self.on_progress, skip_download=True) for model_id in get_supported_models([[inference_engine_name]])]
|
||||
for task in asyncio.as_completed(tasks):
|
||||
try:
|
||||
path, progress = await task
|
||||
yield (path, progress)
|
||||
except Exception as e:
|
||||
print("Error downloading shard:", e)
|
||||
@@ -1,49 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Tuple, Dict, AsyncIterator
|
||||
from pathlib import Path
|
||||
from exo.inference.shard import Shard
|
||||
from exo.download.download_progress import RepoProgressEvent
|
||||
from exo.helpers import AsyncCallbackSystem
|
||||
|
||||
|
||||
class ShardDownloader(ABC):
|
||||
@abstractmethod
|
||||
async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
|
||||
"""
|
||||
Ensures that the shard is downloaded.
|
||||
Does not allow multiple overlapping downloads at once.
|
||||
If you try to download a Shard which overlaps a Shard that is already being downloaded,
|
||||
the download will be cancelled and a new download will start.
|
||||
|
||||
Args:
|
||||
shard (Shard): The shard to download.
|
||||
inference_engine_name (str): The inference engine used on the node hosting the shard
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_shard_download_status(self, inference_engine_name: str) -> AsyncIterator[tuple[Path, RepoProgressEvent]]:
|
||||
"""Get the download status of shards.
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, float]]: A dictionary mapping shard IDs to their download percentage (0-100),
|
||||
or None if status cannot be determined
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class NoopShardDownloader(ShardDownloader):
|
||||
async def ensure_shard(self, shard: Shard, inference_engine_name: str) -> Path:
|
||||
return Path("/tmp/noop_shard")
|
||||
|
||||
@property
|
||||
def on_progress(self) -> AsyncCallbackSystem[str, Tuple[Shard, RepoProgressEvent]]:
|
||||
return AsyncCallbackSystem()
|
||||
|
||||
async def get_shard_download_status(self, inference_engine_name: str) -> AsyncIterator[tuple[Path, RepoProgressEvent]]:
|
||||
if False: yield
|
||||
@@ -1,14 +0,0 @@
|
||||
from exo.download.new_shard_download import NewShardDownloader
|
||||
from exo.inference.shard import Shard
|
||||
import asyncio
|
||||
|
||||
async def test_new_shard_download():
|
||||
shard_downloader = NewShardDownloader()
|
||||
shard_downloader.on_progress.register("test").on_next(lambda shard, event: print(shard, event))
|
||||
await shard_downloader.ensure_shard(Shard(model_id="llama-3.2-1b", start_layer=0, end_layer=0, n_layers=16), "MLXDynamicShardInferenceEngine")
|
||||
async for path, shard_status in shard_downloader.get_shard_download_status("MLXDynamicShardInferenceEngine"):
|
||||
print("Shard download status:", path, shard_status)
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(test_new_shard_download())
|
||||
|
||||
372
exo/helpers.py
372
exo/helpers.py
@@ -1,372 +0,0 @@
|
||||
import os
|
||||
import sys
|
||||
import asyncio
|
||||
from typing import Callable, TypeVar, Optional, Dict, Generic, Tuple, List
|
||||
import socket
|
||||
import random
|
||||
import platform
|
||||
import psutil
|
||||
import uuid
|
||||
from scapy.all import get_if_addr, get_if_list
|
||||
import re
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
import json
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import traceback
|
||||
|
||||
DEBUG = int(os.getenv("DEBUG", default="0"))
|
||||
DEBUG_DISCOVERY = int(os.getenv("DEBUG_DISCOVERY", default="0"))
|
||||
VERSION = "0.0.1"
|
||||
|
||||
exo_text = r"""
|
||||
_____ _____
|
||||
/ _ \ \/ / _ \
|
||||
| __/> < (_) |
|
||||
\___/_/\_\___/
|
||||
"""
|
||||
|
||||
# Single shared thread pool for subprocess operations
|
||||
subprocess_pool = ThreadPoolExecutor(max_workers=4, thread_name_prefix="subprocess_worker")
|
||||
|
||||
|
||||
def get_system_info():
|
||||
if psutil.MACOS:
|
||||
if platform.machine() == "arm64":
|
||||
return "Apple Silicon Mac"
|
||||
if platform.machine() in ["x86_64", "i386"]:
|
||||
return "Intel Mac"
|
||||
return "Unknown Mac architecture"
|
||||
if psutil.LINUX:
|
||||
return "Linux"
|
||||
return "Non-Mac, non-Linux system"
|
||||
|
||||
|
||||
def find_available_port(host: str = "", min_port: int = 49152, max_port: int = 65535) -> int:
|
||||
used_ports_file = os.path.join(tempfile.gettempdir(), "exo_used_ports")
|
||||
|
||||
def read_used_ports():
|
||||
if os.path.exists(used_ports_file):
|
||||
with open(used_ports_file, "r") as f:
|
||||
return [int(line.strip()) for line in f if line.strip().isdigit()]
|
||||
return []
|
||||
|
||||
def write_used_port(port, used_ports):
|
||||
with open(used_ports_file, "w") as f:
|
||||
print(used_ports[-19:])
|
||||
for p in used_ports[-19:] + [port]:
|
||||
f.write(f"{p}\n")
|
||||
|
||||
used_ports = read_used_ports()
|
||||
available_ports = set(range(min_port, max_port + 1)) - set(used_ports)
|
||||
|
||||
while available_ports:
|
||||
port = random.choice(list(available_ports))
|
||||
if DEBUG >= 2: print(f"Trying to find available port {port=}")
|
||||
try:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind((host, port))
|
||||
write_used_port(port, used_ports)
|
||||
return port
|
||||
except socket.error:
|
||||
available_ports.remove(port)
|
||||
|
||||
raise RuntimeError("No available ports in the specified range")
|
||||
|
||||
|
||||
def print_exo():
|
||||
print(exo_text)
|
||||
|
||||
|
||||
def print_yellow_exo():
|
||||
yellow = "\033[93m" # ANSI escape code for yellow
|
||||
reset = "\033[0m" # ANSI escape code to reset color
|
||||
print(f"{yellow}{exo_text}{reset}")
|
||||
|
||||
|
||||
def terminal_link(uri, label=None):
|
||||
if label is None:
|
||||
label = uri
|
||||
parameters = ""
|
||||
|
||||
# OSC 8 ; params ; URI ST <name> OSC 8 ;; ST
|
||||
escape_mask = "\033]8;{};{}\033\\{}\033]8;;\033\\"
|
||||
|
||||
return escape_mask.format(parameters, uri, label)
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
K = TypeVar("K")
|
||||
|
||||
|
||||
class AsyncCallback(Generic[T]):
|
||||
def __init__(self) -> None:
|
||||
self.condition: asyncio.Condition = asyncio.Condition()
|
||||
self.result: Optional[Tuple[T, ...]] = None
|
||||
self.observers: list[Callable[..., None]] = []
|
||||
|
||||
async def wait(self, check_condition: Callable[..., bool], timeout: Optional[float] = None) -> Tuple[T, ...]:
|
||||
async with self.condition:
|
||||
await asyncio.wait_for(self.condition.wait_for(lambda: self.result is not None and check_condition(*self.result)), timeout)
|
||||
assert self.result is not None # for type checking
|
||||
return self.result
|
||||
|
||||
def on_next(self, callback: Callable[..., None]) -> None:
|
||||
self.observers.append(callback)
|
||||
|
||||
def set(self, *args: T) -> None:
|
||||
self.result = args
|
||||
for observer in self.observers:
|
||||
observer(*args)
|
||||
asyncio.create_task(self.notify())
|
||||
|
||||
async def notify(self) -> None:
|
||||
async with self.condition:
|
||||
self.condition.notify_all()
|
||||
|
||||
|
||||
class AsyncCallbackSystem(Generic[K, T]):
|
||||
def __init__(self) -> None:
|
||||
self.callbacks: Dict[K, AsyncCallback[T]] = {}
|
||||
|
||||
def register(self, name: K) -> AsyncCallback[T]:
|
||||
if name not in self.callbacks:
|
||||
self.callbacks[name] = AsyncCallback[T]()
|
||||
return self.callbacks[name]
|
||||
|
||||
def deregister(self, name: K) -> None:
|
||||
if name in self.callbacks:
|
||||
del self.callbacks[name]
|
||||
|
||||
def trigger(self, name: K, *args: T) -> None:
|
||||
if name in self.callbacks:
|
||||
self.callbacks[name].set(*args)
|
||||
|
||||
def trigger_all(self, *args: T) -> None:
|
||||
for callback in self.callbacks.values():
|
||||
callback.set(*args)
|
||||
|
||||
|
||||
K = TypeVar('K', bound=str)
|
||||
V = TypeVar('V')
|
||||
|
||||
|
||||
class PrefixDict(Generic[K, V]):
|
||||
def __init__(self):
|
||||
self.items: Dict[K, V] = {}
|
||||
|
||||
def add(self, key: K, value: V) -> None:
|
||||
self.items[key] = value
|
||||
|
||||
def find_prefix(self, argument: str) -> List[Tuple[K, V]]:
|
||||
return [(key, value) for key, value in self.items.items() if argument.startswith(key)]
|
||||
|
||||
def find_longest_prefix(self, argument: str) -> Optional[Tuple[K, V]]:
|
||||
matches = self.find_prefix(argument)
|
||||
if len(matches) == 0:
|
||||
return None
|
||||
|
||||
return max(matches, key=lambda x: len(x[0]))
|
||||
|
||||
|
||||
def is_valid_uuid(val):
|
||||
try:
|
||||
uuid.UUID(str(val))
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
def get_or_create_node_id():
|
||||
NODE_ID_FILE = Path(tempfile.gettempdir())/".exo_node_id"
|
||||
try:
|
||||
if NODE_ID_FILE.is_file():
|
||||
with open(NODE_ID_FILE, "r") as f:
|
||||
stored_id = f.read().strip()
|
||||
if is_valid_uuid(stored_id):
|
||||
if DEBUG >= 2: print(f"Retrieved existing node ID: {stored_id}")
|
||||
return stored_id
|
||||
else:
|
||||
if DEBUG >= 2: print("Stored ID is not a valid UUID. Generating a new one.")
|
||||
|
||||
new_id = str(uuid.uuid4())
|
||||
with open(NODE_ID_FILE, "w") as f:
|
||||
f.write(new_id)
|
||||
|
||||
if DEBUG >= 2: print(f"Generated and stored new node ID: {new_id}")
|
||||
return new_id
|
||||
except IOError as e:
|
||||
if DEBUG >= 2: print(f"IO error creating node_id: {e}")
|
||||
return str(uuid.uuid4())
|
||||
except Exception as e:
|
||||
if DEBUG >= 2: print(f"Unexpected error creating node_id: {e}")
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
def pretty_print_bytes(size_in_bytes: int) -> str:
|
||||
if size_in_bytes < 1024:
|
||||
return f"{size_in_bytes} B"
|
||||
elif size_in_bytes < 1024**2:
|
||||
return f"{size_in_bytes / 1024:.2f} KB"
|
||||
elif size_in_bytes < 1024**3:
|
||||
return f"{size_in_bytes / (1024 ** 2):.2f} MB"
|
||||
elif size_in_bytes < 1024**4:
|
||||
return f"{size_in_bytes / (1024 ** 3):.2f} GB"
|
||||
else:
|
||||
return f"{size_in_bytes / (1024 ** 4):.2f} TB"
|
||||
|
||||
|
||||
def pretty_print_bytes_per_second(bytes_per_second: int) -> str:
|
||||
if bytes_per_second < 1024:
|
||||
return f"{bytes_per_second} B/s"
|
||||
elif bytes_per_second < 1024**2:
|
||||
return f"{bytes_per_second / 1024:.2f} KB/s"
|
||||
elif bytes_per_second < 1024**3:
|
||||
return f"{bytes_per_second / (1024 ** 2):.2f} MB/s"
|
||||
elif bytes_per_second < 1024**4:
|
||||
return f"{bytes_per_second / (1024 ** 3):.2f} GB/s"
|
||||
else:
|
||||
return f"{bytes_per_second / (1024 ** 4):.2f} TB/s"
|
||||
|
||||
|
||||
def get_all_ip_addresses_and_interfaces():
|
||||
ip_addresses = []
|
||||
for interface in get_if_list():
|
||||
try:
|
||||
ip = get_if_addr(interface)
|
||||
if ip.startswith("0.0."): continue
|
||||
simplified_interface = re.sub(r'^\\Device\\NPF_', '', interface)
|
||||
ip_addresses.append((ip, simplified_interface))
|
||||
except:
|
||||
if DEBUG >= 1: print(f"Failed to get IP address for interface {interface}")
|
||||
if DEBUG >= 1: traceback.print_exc()
|
||||
if not ip_addresses:
|
||||
if DEBUG >= 1: print("Failed to get any IP addresses. Defaulting to localhost.")
|
||||
return [("localhost", "lo")]
|
||||
return list(set(ip_addresses))
|
||||
|
||||
|
||||
|
||||
async def get_macos_interface_type(ifname: str) -> Optional[Tuple[int, str]]:
|
||||
try:
|
||||
# Use the shared subprocess_pool
|
||||
output = await asyncio.get_running_loop().run_in_executor(
|
||||
subprocess_pool, lambda: subprocess.run(['system_profiler', 'SPNetworkDataType', '-json'], capture_output=True, text=True, close_fds=True).stdout
|
||||
)
|
||||
|
||||
data = json.loads(output)
|
||||
|
||||
for interface in data.get('SPNetworkDataType', []):
|
||||
if interface.get('interface') == ifname:
|
||||
hardware = interface.get('hardware', '').lower()
|
||||
type_name = interface.get('type', '').lower()
|
||||
name = interface.get('_name', '').lower()
|
||||
|
||||
if 'thunderbolt' in name:
|
||||
return (5, "Thunderbolt")
|
||||
if hardware == 'ethernet' or type_name == 'ethernet':
|
||||
if 'usb' in name:
|
||||
return (4, "Ethernet [USB]")
|
||||
return (4, "Ethernet")
|
||||
if hardware == 'airport' or type_name == 'airport' or 'wi-fi' in name:
|
||||
return (3, "WiFi")
|
||||
if type_name == 'vpn':
|
||||
return (1, "External Virtual")
|
||||
|
||||
except Exception as e:
|
||||
if DEBUG >= 2: print(f"Error detecting macOS interface type: {e}")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def get_interface_priority_and_type(ifname: str) -> Tuple[int, str]:
|
||||
# On macOS, try to get interface type using networksetup
|
||||
if psutil.MACOS:
|
||||
macos_type = await get_macos_interface_type(ifname)
|
||||
if macos_type is not None: return macos_type
|
||||
|
||||
# Local container/virtual interfaces
|
||||
if (ifname.startswith(('docker', 'br-', 'veth', 'cni', 'flannel', 'calico', 'weave')) or 'bridge' in ifname):
|
||||
return (7, "Container Virtual")
|
||||
|
||||
# Loopback interface
|
||||
if ifname.startswith('lo'):
|
||||
return (6, "Loopback")
|
||||
|
||||
# Traditional detection for non-macOS systems or fallback
|
||||
if ifname.startswith(('tb', 'nx', 'ten')):
|
||||
return (5, "Thunderbolt")
|
||||
|
||||
# Regular ethernet detection
|
||||
if ifname.startswith(('eth', 'en')) and not ifname.startswith(('en1', 'en0')):
|
||||
return (4, "Ethernet")
|
||||
|
||||
# WiFi detection
|
||||
if ifname.startswith(('wlan', 'wifi', 'wl')) or ifname in ['en0', 'en1']:
|
||||
return (3, "WiFi")
|
||||
|
||||
# Non-local virtual interfaces (VPNs, tunnels)
|
||||
if ifname.startswith(('tun', 'tap', 'vtun', 'utun', 'gif', 'stf', 'awdl', 'llw')):
|
||||
return (1, "External Virtual")
|
||||
|
||||
# Other physical interfaces
|
||||
return (2, "Other")
|
||||
|
||||
|
||||
async def shutdown(signal, loop, server):
|
||||
"""Gracefully shutdown the server and close the asyncio loop."""
|
||||
print(f"Received exit signal {signal.name}...")
|
||||
print("Thank you for using exo.")
|
||||
print_yellow_exo()
|
||||
server_tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
|
||||
[task.cancel() for task in server_tasks]
|
||||
print(f"Cancelling {len(server_tasks)} outstanding tasks")
|
||||
await asyncio.gather(*server_tasks, return_exceptions=True)
|
||||
await server.stop()
|
||||
|
||||
|
||||
def is_frozen():
|
||||
return getattr(sys, 'frozen', False) or os.path.basename(sys.executable) == "exo" \
|
||||
or ('Contents/MacOS' in str(os.path.dirname(sys.executable))) \
|
||||
or '__nuitka__' in globals() or getattr(sys, '__compiled__', False)
|
||||
|
||||
async def get_mac_system_info() -> Tuple[str, str, int]:
|
||||
"""Get Mac system information using system_profiler."""
|
||||
try:
|
||||
output = await asyncio.get_running_loop().run_in_executor(
|
||||
subprocess_pool,
|
||||
lambda: subprocess.check_output(["system_profiler", "SPHardwareDataType"]).decode("utf-8")
|
||||
)
|
||||
|
||||
model_line = next((line for line in output.split("\n") if "Model Name" in line), None)
|
||||
model_id = model_line.split(": ")[1] if model_line else "Unknown Model"
|
||||
|
||||
chip_line = next((line for line in output.split("\n") if "Chip" in line), None)
|
||||
chip_id = chip_line.split(": ")[1] if chip_line else "Unknown Chip"
|
||||
|
||||
memory_line = next((line for line in output.split("\n") if "Memory" in line), None)
|
||||
memory_str = memory_line.split(": ")[1] if memory_line else "Unknown Memory"
|
||||
memory_units = memory_str.split()
|
||||
memory_value = int(memory_units[0])
|
||||
memory = memory_value * 1024 if memory_units[1] == "GB" else memory_value
|
||||
|
||||
return model_id, chip_id, memory
|
||||
except Exception as e:
|
||||
if DEBUG >= 2: print(f"Error getting Mac system info: {e}")
|
||||
return "Unknown Model", "Unknown Chip", 0
|
||||
|
||||
def get_exo_home() -> Path:
|
||||
if psutil.WINDOWS: docs_folder = Path(os.environ["USERPROFILE"])/"Documents"
|
||||
else: docs_folder = Path.home()/"Documents"
|
||||
if not docs_folder.exists(): docs_folder.mkdir(exist_ok=True)
|
||||
exo_folder = docs_folder/"Exo"
|
||||
if not exo_folder.exists(): exo_folder.mkdir(exist_ok=True)
|
||||
return exo_folder
|
||||
|
||||
|
||||
def get_exo_images_dir() -> Path:
|
||||
exo_home = get_exo_home()
|
||||
images_dir = exo_home/"Images"
|
||||
if not images_dir.exists(): images_dir.mkdir(exist_ok=True)
|
||||
return images_dir
|
||||
@@ -1,58 +0,0 @@
|
||||
from exo.inference.inference_engine import InferenceEngine
|
||||
from exo.inference.shard import Shard
|
||||
from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
|
||||
import asyncio
|
||||
import numpy as np
|
||||
|
||||
|
||||
# An inference engine should work the same for any number of Shards, as long as the Shards are continuous.
|
||||
async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str):
|
||||
from exo.inference.tinygrad.inference import Tokenizer
|
||||
from pathlib import Path
|
||||
|
||||
_tokenizer = Tokenizer(str(Path(model_id)/"tokenizer.model"))
|
||||
|
||||
prompt = "In a single word only, what is the last name of the president of the United States? "
|
||||
resp_full = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt)
|
||||
token_full = await inference_engine_1.sample(resp_full)
|
||||
|
||||
next_resp_full, _ = await inference_engine_1.infer_tensor(
|
||||
"A",
|
||||
shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32),
|
||||
input_data=token_full,
|
||||
)
|
||||
|
||||
resp1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32), prompt=prompt)
|
||||
resp2, _ = await inference_engine_2.infer_tensor(
|
||||
"B",
|
||||
shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),
|
||||
input_data=resp1,
|
||||
)
|
||||
token2 = await inference_engine_2.sample(resp2)
|
||||
resp3, _ = await inference_engine_1.infer_tensor(
|
||||
"B",
|
||||
shard=Shard(model_id=model_id, start_layer=0, end_layer=30, n_layers=32),
|
||||
input_data=token2,
|
||||
)
|
||||
resp4, _ = await inference_engine_2.infer_tensor(
|
||||
"B",
|
||||
shard=Shard(model_id=model_id, start_layer=31, end_layer=31, n_layers=32),
|
||||
input_data=resp3,
|
||||
)
|
||||
|
||||
print(f"{resp2=}")
|
||||
print(f"full: {_tokenizer.decode(resp_full)}")
|
||||
print(f"next full: {_tokenizer.decode(next_resp_full)}")
|
||||
print(f"resp2: {_tokenizer.decode(resp2)}")
|
||||
print(f"{resp4=}")
|
||||
print(f"resp4: {_tokenizer.decode(resp4)}")
|
||||
|
||||
assert np.array_equal(resp_full, resp2)
|
||||
assert np.array_equal(next_resp_full, resp4)
|
||||
|
||||
|
||||
asyncio.run(test_inference_engine(
|
||||
TinygradDynamicShardInferenceEngine(),
|
||||
TinygradDynamicShardInferenceEngine(),
|
||||
"llama3-8b-sfr",
|
||||
))
|
||||
@@ -1,37 +0,0 @@
|
||||
from typing import Optional, Tuple, TYPE_CHECKING
|
||||
import numpy as np
|
||||
from exo.inference.inference_engine import InferenceEngine
|
||||
from exo.inference.shard import Shard
|
||||
from exo.inference.tokenizers import DummyTokenizer
|
||||
|
||||
class DummyInferenceEngine(InferenceEngine):
|
||||
def __init__(self):
|
||||
self.shard = None
|
||||
self.vocab_size = 1000
|
||||
self.hidden_size = 256
|
||||
self.eos_token_id = 0
|
||||
self.latency_mean = 0.1
|
||||
self.latency_stddev = 0.02
|
||||
self.num_generate_dummy_tokens = 10
|
||||
self.tokenizer = DummyTokenizer()
|
||||
|
||||
async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
|
||||
return np.array(self.tokenizer.encode(prompt))
|
||||
|
||||
async def sample(self, x: np.ndarray, temp: float = 0.0, top_p: float = 1.0) -> np.ndarray:
|
||||
if x[0] > self.num_generate_dummy_tokens: return np.array([self.tokenizer.eos_token_id])
|
||||
return x
|
||||
|
||||
async def decode(self, shard: Shard, tokens: np.ndarray) -> str:
|
||||
return self.tokenizer.decode(tokens)
|
||||
|
||||
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
|
||||
await self.ensure_shard(shard)
|
||||
return input_data + 1 if self.shard.is_last_layer() else input_data, None
|
||||
|
||||
async def ensure_shard(self, shard: Shard):
|
||||
if self.shard == shard: return
|
||||
self.shard = shard
|
||||
|
||||
async def load_checkpoint(self, shard: Shard, path: str):
|
||||
await self.ensure_shard(shard)
|
||||
@@ -1,77 +0,0 @@
|
||||
import numpy as np
|
||||
import os
|
||||
from exo.helpers import DEBUG # Make sure to import DEBUG
|
||||
|
||||
from typing import Tuple, Optional
|
||||
from abc import ABC, abstractmethod
|
||||
from .shard import Shard
|
||||
from exo.download.shard_download import ShardDownloader
|
||||
|
||||
|
||||
class InferenceEngine(ABC):
|
||||
session = {}
|
||||
|
||||
@abstractmethod
|
||||
async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def sample(self, x: np.ndarray) -> np.ndarray:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def decode(self, shard: Shard, tokens: np.ndarray) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def load_checkpoint(self, shard: Shard, path: str):
|
||||
pass
|
||||
|
||||
async def save_checkpoint(self, shard: Shard, path: str):
|
||||
pass
|
||||
|
||||
async def save_session(self, key, value):
|
||||
self.session[key] = value
|
||||
|
||||
async def clear_session(self):
|
||||
self.session.empty()
|
||||
|
||||
async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
|
||||
tokens = await self.encode(shard, prompt)
|
||||
if shard.model_id != 'stable-diffusion-2-1-base':
|
||||
x = tokens.reshape(1, -1)
|
||||
else:
|
||||
x = tokens
|
||||
output_data, inference_state = await self.infer_tensor(request_id, shard, x, inference_state)
|
||||
|
||||
return output_data, inference_state
|
||||
|
||||
|
||||
inference_engine_classes = {
|
||||
"mlx": "MLXDynamicShardInferenceEngine",
|
||||
"tinygrad": "TinygradDynamicShardInferenceEngine",
|
||||
"dummy": "DummyInferenceEngine",
|
||||
}
|
||||
|
||||
|
||||
def get_inference_engine(inference_engine_name: str, shard_downloader: ShardDownloader):
|
||||
if DEBUG >= 2:
|
||||
print(f"get_inference_engine called with: {inference_engine_name}")
|
||||
if inference_engine_name == "mlx":
|
||||
from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
|
||||
|
||||
return MLXDynamicShardInferenceEngine(shard_downloader)
|
||||
elif inference_engine_name == "tinygrad":
|
||||
from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
|
||||
import tinygrad.helpers
|
||||
tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
|
||||
|
||||
return TinygradDynamicShardInferenceEngine(shard_downloader)
|
||||
elif inference_engine_name == "dummy":
|
||||
from exo.inference.dummy_inference_engine import DummyInferenceEngine
|
||||
return DummyInferenceEngine()
|
||||
raise ValueError(f"Unsupported inference engine: {inference_engine_name}")
|
||||
@@ -1,37 +0,0 @@
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
def length_masked_ce_loss(model, inputs, targets, lengths):
|
||||
# Run model on inputs
|
||||
logits = model(inputs).astype(mx.float32)
|
||||
|
||||
# Mask padding tokens
|
||||
length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]
|
||||
|
||||
# Calculate the loss
|
||||
ce = nn.losses.cross_entropy(logits, targets) * length_mask
|
||||
loss = ce.sum() / length_mask.sum()
|
||||
# print(f"| {inputs=}\n| ==>{logits=}\n| ~^~{ce=}\n| == {loss=}")
|
||||
return loss
|
||||
|
||||
#Naive intermediate layer loss, where we replace the targets with gradients and just multiply the output by the gradients to derive the loss. This is naive and may warrant some further iteration, but will do the job for now
|
||||
def back_gradient_loss(model, inputs, gradients, lengths):
|
||||
out = model(inputs).astype(mx.float32)
|
||||
grad = gradients.astype(mx.float32)
|
||||
|
||||
# Mask padding tokens
|
||||
length_mask = mx.repeat(mx.arange(inputs.shape[1])[None, :] < lengths[:, None], out.shape[-1]).reshape(out.shape)
|
||||
|
||||
masked_sum = (out * length_mask).sum(axis=1)
|
||||
gradient_lens = mx.abs(grad * masked_sum)
|
||||
loss = gradient_lens.sum() / length_mask.sum()
|
||||
# print(f"| {inputs=}\n"
|
||||
# + f"| ==>{out=}\n"
|
||||
# + f"| ~^~{masked_sum=}\n"
|
||||
# + f"| <~>{gradient_lens=}\n"
|
||||
# + f"| == {loss=}")
|
||||
return loss
|
||||
|
||||
loss_fns = {
|
||||
"back_gradient": back_gradient_loss,
|
||||
"length_masked_ce": length_masked_ce_loss,
|
||||
}
|
||||
@@ -1,307 +0,0 @@
|
||||
# Adapted from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/__init__.py
|
||||
|
||||
import time
|
||||
from typing import Optional, Tuple
|
||||
import inspect
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from pathlib import Path
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from .sd_models.vae import ModelArgs as VAEArgs
|
||||
from .sd_models.vae import Autoencoder
|
||||
from .sd_models.tokenizer import load_tokenizer
|
||||
from .sd_models.clip import CLIPTextModel
|
||||
from .sd_models.clip import ModelArgs as CLIPArgs
|
||||
from .sd_models.unet import UNetConfig, UNetModel
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from exo.inference.shard import Shard
|
||||
|
||||
@dataclass
|
||||
class DiffusionConfig:
|
||||
beta_schedule: str = "scaled_linear"
|
||||
beta_start: float = 0.00085
|
||||
beta_end: float = 0.012
|
||||
num_train_steps: int = 1000
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, params):
|
||||
return cls(**{k: v for k, v in params.items() if k in inspect.signature(cls).parameters})
|
||||
|
||||
|
||||
#Sampler
|
||||
def _linspace(a, b, num):
|
||||
x = mx.arange(0, num) / (num - 1)
|
||||
return (b - a) * x + a
|
||||
|
||||
|
||||
def _interp(y, x_new):
|
||||
"""Interpolate the function defined by (arange(0, len(y)), y) at positions x_new."""
|
||||
x_low = x_new.astype(mx.int32)
|
||||
x_high = mx.minimum(x_low + 1, len(y) - 1)
|
||||
|
||||
y_low = y[x_low]
|
||||
y_high = y[x_high]
|
||||
delta_x = x_new - x_low
|
||||
y_new = y_low * (1 - delta_x) + delta_x * y_high
|
||||
|
||||
return y_new
|
||||
|
||||
class SimpleEulerSampler:
|
||||
"""A simple Euler integrator that can be used to sample from our diffusion models.
|
||||
|
||||
The method ``step()`` performs one Euler step from x_t to x_t_prev.
|
||||
"""
|
||||
|
||||
def __init__(self, config: DiffusionConfig):
|
||||
# Compute the noise schedule
|
||||
if config.beta_schedule == "linear":
|
||||
betas = _linspace(
|
||||
config.beta_start, config.beta_end, config.num_train_steps
|
||||
)
|
||||
elif config.beta_schedule == "scaled_linear":
|
||||
betas = _linspace(
|
||||
config.beta_start**0.5, config.beta_end**0.5, config.num_train_steps
|
||||
).square()
|
||||
else:
|
||||
raise NotImplementedError(f"{config.beta_schedule} is not implemented.")
|
||||
|
||||
alphas = 1 - betas
|
||||
alphas_cumprod = mx.cumprod(alphas)
|
||||
|
||||
self._sigmas = mx.concatenate(
|
||||
[mx.zeros(1), ((1 - alphas_cumprod) / alphas_cumprod).sqrt()]
|
||||
)
|
||||
|
||||
@property
|
||||
def max_time(self):
|
||||
return len(self._sigmas) - 1
|
||||
|
||||
def sample_prior(self, shape, dtype=mx.float32, key=None):
|
||||
noise = mx.random.normal(shape, key=key)
|
||||
return (
|
||||
noise * self._sigmas[-1] * (self._sigmas[-1].square() + 1).rsqrt()
|
||||
).astype(dtype)
|
||||
|
||||
def add_noise(self, x, t, key=None):
|
||||
noise = mx.random.normal(x.shape, key=key)
|
||||
s = self.sigmas(t)
|
||||
return (x + noise * s) * (s.square() + 1).rsqrt()
|
||||
|
||||
def sigmas(self, t):
|
||||
return _interp(self._sigmas, t)
|
||||
|
||||
def timesteps(self, num_steps: int, start_time=None, dtype=mx.float32):
|
||||
start_time = start_time or (len(self._sigmas) - 1)
|
||||
assert 0 < start_time <= (len(self._sigmas) - 1)
|
||||
steps = _linspace(start_time, 0, num_steps + 1).astype(dtype)
|
||||
return list(zip(steps, steps[1:]))
|
||||
|
||||
def current_timestep(self, step, total_steps, start_time=None):
|
||||
if step < total_steps:
|
||||
steps = self.timesteps(total_steps, start_time)
|
||||
return steps[step]
|
||||
else:
|
||||
return mx.array(0),mx.array(0)
|
||||
|
||||
def step(self, eps_pred, x_t, t, t_prev):
|
||||
sigma = self.sigmas(t).astype(eps_pred.dtype)
|
||||
sigma_prev = self.sigmas(t_prev).astype(eps_pred.dtype)
|
||||
|
||||
dt = sigma_prev - sigma
|
||||
x_t_prev = (sigma.square() + 1).sqrt() * x_t + eps_pred * dt
|
||||
|
||||
x_t_prev = x_t_prev * (sigma_prev.square() + 1).rsqrt()
|
||||
|
||||
return x_t_prev
|
||||
|
||||
@dataclass
|
||||
class ShardConfig:
|
||||
model_id:str
|
||||
start_layer:int
|
||||
end_layer:int
|
||||
n_layers:int
|
||||
|
||||
@dataclass
|
||||
class StableDiffusionConfig:
|
||||
model_type:str
|
||||
vae:VAEArgs
|
||||
text_encoder:CLIPArgs
|
||||
scheduler:DiffusionConfig
|
||||
unet:UNetConfig
|
||||
shard:ShardConfig
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, params):
|
||||
return cls(**{k: v for k, v in params.items() if k in inspect.signature(cls).parameters})
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(StableDiffusionConfig):
|
||||
shard:Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
|
||||
|
||||
def __post_init__(self):
|
||||
if isinstance(self.shard, dict):
|
||||
self.shard = Shard(**self.shard)
|
||||
|
||||
if not isinstance(self.shard, Shard):
|
||||
raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.model_type = config.model_type
|
||||
self.config = config
|
||||
self.model_path = config.vae['path'].split('/vae')[0]
|
||||
self.shard = config.shard
|
||||
self.shard_clip, self.shard_encoder, self.shard_unet, self.shard_decoder = model_shards(config.shard)
|
||||
self.config_clip=CLIPArgs.from_dict(config.text_encoder['config'])
|
||||
if self.shard_clip.start_layer != -1:
|
||||
self.text_encoder = CLIPTextModel(self.config_clip, shard=self.shard_clip)
|
||||
else:
|
||||
self.text_encoder = nn.Identity()
|
||||
self.tokenizer = load_tokenizer(Path(self.model_path), "vocab.json", "merges.txt")
|
||||
self.diffusion_config = DiffusionConfig.from_dict(config.scheduler['config'])
|
||||
self.sampler = SimpleEulerSampler(self.diffusion_config)
|
||||
if self.shard_unet.start_layer!=-1:
|
||||
self.config_unet = UNetConfig.from_dict(config.unet['config'])
|
||||
self.unet = UNetModel(self.config_unet, self.shard_unet)
|
||||
else:
|
||||
self.unet = nn.Identity()
|
||||
self.config_vae=VAEArgs.from_dict(config.vae['config'])
|
||||
if self.shard_encoder.start_layer != -1:
|
||||
self.encoder=Autoencoder(self.config_vae, self.shard_encoder, "vae_encoder")
|
||||
else:
|
||||
self.encoder = nn.Identity()
|
||||
if self.shard_decoder.start_layer != -1:
|
||||
self.decoder=Autoencoder(self.config_vae, self.shard_decoder, "vae_decoder")
|
||||
else:
|
||||
self.decoder = nn.Identity()
|
||||
|
||||
def __call__(self,x, step= 0, cfg_weight: float = 7.5,total_steps=50,conditioning=None,mask=None,residual=None,x_t_prev=None,is_finished=False,is_step_finished=False, image=None, strength=0.65, start_step=None):
|
||||
t, t_prev = self.sampler.current_timestep(step=step, total_steps=total_steps, start_time=start_step)
|
||||
is_finished = False
|
||||
is_step_finished = False
|
||||
if t.item()==1000:
|
||||
if self.shard_clip.start_layer == 0:
|
||||
conditioning = x
|
||||
if self.shard_clip.start_layer != -1:
|
||||
conditioning, mask= self.text_encoder(conditioning,mask)
|
||||
seed = int(time.time())
|
||||
mx.random.seed(seed)
|
||||
if image is None:
|
||||
if self.shard_encoder.is_last_layer():
|
||||
x = self.sampler.sample_prior((1, *(64, 64), self.config_vae.latent_channels_in), dtype=mx.float32)
|
||||
x_t_prev=x
|
||||
start_step = self.sampler.max_time
|
||||
else:
|
||||
if self.shard_encoder.start_layer != -1:
|
||||
image= self.encoder.encode(image)
|
||||
if self.shard_encoder.is_last_layer():
|
||||
start_step = self.sampler.max_time*strength
|
||||
total_steps = int(total_steps*strength)
|
||||
image = mx.broadcast_to(image, (1,) + image.shape[1:])
|
||||
x_t_prev=self.sampler.add_noise(image, mx.array(start_step))
|
||||
image = None
|
||||
t, t_prev = self.sampler.current_timestep(step=step, total_steps=total_steps, start_time=start_step)
|
||||
# Perform the denoising loop
|
||||
if self.shard_unet.start_layer != -1:
|
||||
with tqdm(total=total_steps,initial=step+1) as pbar:
|
||||
if step<total_steps:
|
||||
x = x_t_prev
|
||||
if self.shard_unet.is_first_layer():
|
||||
x_t_unet = mx.concatenate([x] * 2, axis=0) if cfg_weight> 1 else x
|
||||
else:
|
||||
x_t_unet = x
|
||||
t_unet = mx.broadcast_to(t, [len(x_t_unet)])
|
||||
x, residual= self.unet(x_t_unet, t_unet, encoder_x=conditioning, residuals=residual)
|
||||
if self.shard_unet.is_last_layer():
|
||||
if cfg_weight > 1:
|
||||
eps_text, eps_neg = x.split(2)
|
||||
eps_pred = eps_neg + cfg_weight * (eps_text - eps_neg)
|
||||
x = self.sampler.step(eps_pred, x_t_prev, t, t_prev)
|
||||
x_t_prev=x
|
||||
mx.eval(x)
|
||||
|
||||
if self.shard_decoder.is_last_layer():
|
||||
is_step_finished=True
|
||||
if self.shard_decoder.start_layer != -1:
|
||||
x=self.decoder.decode(x)
|
||||
if self.shard_decoder.is_last_layer():
|
||||
x = mx.clip(x / 2 + 0.5, 0, 1)
|
||||
B, H, W, C = x.shape
|
||||
x = x.reshape(1, B // 1, H, W, C).transpose(0, 2, 1, 3, 4)
|
||||
x = x.reshape(1 * H, B // 1 * W, C)
|
||||
x = (x * 255).astype(mx.uint8)
|
||||
if t_prev.item() ==0:
|
||||
is_finished=True
|
||||
mx.eval(x)
|
||||
|
||||
return x, {'conditioning':conditioning, 'mask':mask,'residual':residual,'x_t_prev':x_t_prev,'is_finished':is_finished,'is_step_finished':is_step_finished, 'step':step, 'total_steps':total_steps, 'start_step':start_step, 'image':image}
|
||||
|
||||
|
||||
def load(self):
|
||||
if self.shard_encoder.start_layer != -1:
|
||||
vae_weights = mx.load(self.config_vae.weight_files[0])
|
||||
vae_weights = self.encoder.sanitize(vae_weights)
|
||||
self.encoder.load_weights(list(vae_weights.items()), strict=True)
|
||||
if self.shard_decoder.start_layer != -1:
|
||||
vae_weights = mx.load(self.config_vae.weight_files[0])
|
||||
vae_weights = self.decoder.sanitize(vae_weights)
|
||||
self.decoder.load_weights(list(vae_weights.items()), strict=True)
|
||||
if self.shard_clip.start_layer != -1:
|
||||
clip_weights = mx.load(self.config_clip.weight_files[0])
|
||||
clip_weights = self.text_encoder.sanitize(clip_weights)
|
||||
self.text_encoder.load_weights(list(clip_weights.items()), strict=True)
|
||||
if self.shard_unet.start_layer !=-1:
|
||||
unet_weights = mx.load(self.config_unet.weight_files[0])
|
||||
unet_weights = self.unet.sanitize(unet_weights)
|
||||
self.unet.load_weights(list(unet_weights.items()), strict=True)
|
||||
|
||||
def model_shards(shard:ShardConfig):
|
||||
def create_shard(shard, model_ranges):
|
||||
start_layer = shard.start_layer
|
||||
end_layer = shard.end_layer
|
||||
|
||||
shards = {}
|
||||
|
||||
for model_name, (range_start, range_end) in model_ranges.items():
|
||||
if start_layer < range_end and end_layer >= range_start:
|
||||
# Calculate the overlap with the model range
|
||||
overlap_start = max(start_layer, range_start)
|
||||
overlap_end = min(end_layer, range_end - 1)
|
||||
|
||||
# Adjust the layers relative to the model's range
|
||||
relative_start = overlap_start - range_start
|
||||
relative_end = overlap_end - range_start
|
||||
shards[model_name] = Shard(model_name, relative_start, relative_end, range_end - range_start)
|
||||
else:
|
||||
# If no overlap, create a zero-layer shard
|
||||
shards[model_name] = Shard(model_name, -1, -1, range_end - range_start)
|
||||
|
||||
return shards
|
||||
|
||||
# Define the ranges for different models
|
||||
model_ranges = {
|
||||
'clip': (0, 12),
|
||||
'vae_encoder':(12,17),
|
||||
'unet':(17,26),
|
||||
'vae_decoder': (26, 31) # Example range for unet
|
||||
}
|
||||
|
||||
# Call the function and get the shards for all models
|
||||
shards = create_shard(shard, model_ranges)
|
||||
|
||||
# Access individual shards
|
||||
shard_clip = shards['clip']
|
||||
shard_encoder = shards['vae_encoder']
|
||||
shard_unet = shards['unet']
|
||||
shard_decoder = shards['vae_decoder']
|
||||
|
||||
return shard_clip, shard_encoder, shard_unet, shard_decoder
|
||||
|
||||
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
from typing import Optional
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx_lm.models.cache import KVCache
|
||||
|
||||
|
||||
class IdentityBlock(nn.Module):
|
||||
def __call__(self, x: mx.array, mask: Optional[mx.array] = None, cache: Optional[KVCache] = None) -> mx.array:
|
||||
return x
|
||||
@@ -1,127 +0,0 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_lm.models.cache import KVCache
|
||||
from mlx_lm.models.deepseek_v2 import ModelArgs, DeepseekV2DecoderLayer
|
||||
from .base import IdentityBlock
|
||||
from exo.inference.shard import Shard
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(ModelArgs):
|
||||
shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
|
||||
|
||||
def __post_init__(self):
|
||||
if isinstance(self.shard, Shard):
|
||||
return
|
||||
if not isinstance(self.shard, dict):
|
||||
raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
|
||||
|
||||
self.shard = Shard(**self.shard)
|
||||
|
||||
|
||||
class DeepseekV2Model(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = config
|
||||
self.num_hidden_layers = config.num_hidden_layers
|
||||
self.vocab_size = config.vocab_size
|
||||
if self.args.shard.is_first_layer():
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
|
||||
self.layers = []
|
||||
for i in range(self.num_hidden_layers):
|
||||
if self.args.shard.start_layer <= i <= self.args.shard.end_layer:
|
||||
self.layers.append(DeepseekV2DecoderLayer(config, i))
|
||||
else:
|
||||
self.layers.append(IdentityBlock())
|
||||
|
||||
if self.args.shard.is_last_layer():
|
||||
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
cache: Optional[KVCache] = None,
|
||||
) -> mx.array:
|
||||
if self.args.shard.is_first_layer():
|
||||
h = self.embed_tokens(x)
|
||||
else:
|
||||
h = x
|
||||
|
||||
mask = None
|
||||
T = h.shape[1]
|
||||
if T > 1:
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
|
||||
mask = mask.astype(h.dtype)
|
||||
|
||||
if cache is None:
|
||||
cache = [None]*len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, c)
|
||||
|
||||
if self.args.shard.is_last_layer():
|
||||
h = self.norm(h)
|
||||
return h
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = config
|
||||
self.model_type = config.model_type
|
||||
self.model = DeepseekV2Model(config)
|
||||
if self.args.shard.is_last_layer():
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache: Optional[KVCache] = None,
|
||||
):
|
||||
out = self.model(inputs, cache)
|
||||
if self.args.shard.is_last_layer():
|
||||
return self.lm_head(out)
|
||||
return out
|
||||
|
||||
def sanitize(self, weights):
|
||||
shard_state_dict = {}
|
||||
|
||||
for key, value in weights.items():
|
||||
if key.startswith('model.layers.'):
|
||||
layer_num = int(key.split('.')[2])
|
||||
if self.args.shard.start_layer <= layer_num <= self.args.shard.end_layer:
|
||||
shard_state_dict[key] = value
|
||||
elif self.args.shard.is_first_layer() and key.startswith('model.embed_tokens'):
|
||||
shard_state_dict[key] = value
|
||||
elif self.args.shard.is_last_layer() and (key.startswith('model.norm') or key.startswith('lm_head')):
|
||||
shard_state_dict[key] = value
|
||||
|
||||
for l in range(self.args.num_hidden_layers):
|
||||
prefix = f"model.layers.{l}"
|
||||
for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
|
||||
for k in ["weight", "scales", "biases"]:
|
||||
if f"{prefix}.mlp.experts.0.{m}.{k}" in shard_state_dict:
|
||||
to_join = [shard_state_dict.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}") for e in range(self.args.n_routed_experts)]
|
||||
shard_state_dict[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)
|
||||
|
||||
return shard_state_dict
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return (
|
||||
self.args.qk_nope_head_dim + self.args.qk_rope_head_dim,
|
||||
self.args.v_head_dim,
|
||||
)
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
@@ -1,134 +0,0 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_lm.models.cache import KVCache
|
||||
from mlx_lm.models.deepseek_v3 import (
|
||||
ModelArgs as V3ModelArgs,
|
||||
DeepseekV3DecoderLayer,
|
||||
)
|
||||
from .base import IdentityBlock
|
||||
from exo.inference.shard import Shard
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(V3ModelArgs):
|
||||
shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
|
||||
|
||||
def __post_init__(self):
|
||||
if isinstance(self.shard, Shard):
|
||||
return
|
||||
if not isinstance(self.shard, dict):
|
||||
raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
|
||||
|
||||
self.shard = Shard(**self.shard)
|
||||
|
||||
|
||||
class DeepseekV3Model(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = config
|
||||
self.num_hidden_layers = config.num_hidden_layers
|
||||
self.vocab_size = config.vocab_size
|
||||
if self.args.shard.is_first_layer():
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
|
||||
self.layers = []
|
||||
for i in range(self.num_hidden_layers):
|
||||
if self.args.shard.start_layer <= i <= self.args.shard.end_layer:
|
||||
self.layers.append(DeepseekV3DecoderLayer(config, i))
|
||||
else:
|
||||
self.layers.append(IdentityBlock())
|
||||
|
||||
if self.args.shard.is_last_layer():
|
||||
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
cache: Optional[KVCache] = None,
|
||||
) -> mx.array:
|
||||
if self.args.shard.is_first_layer():
|
||||
h = self.embed_tokens(x)
|
||||
else:
|
||||
h = x
|
||||
|
||||
mask = None
|
||||
T = h.shape[1]
|
||||
if T > 1:
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
|
||||
mask = mask.astype(h.dtype)
|
||||
|
||||
if cache is None:
|
||||
cache = [None]*len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, c)
|
||||
|
||||
if self.args.shard.is_last_layer():
|
||||
h = self.norm(h)
|
||||
return h
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = config
|
||||
self.model_type = config.model_type
|
||||
self.model = DeepseekV3Model(config)
|
||||
if self.args.shard.is_last_layer():
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache: Optional[KVCache] = None,
|
||||
):
|
||||
out = self.model(inputs, cache)
|
||||
if self.args.shard.is_last_layer():
|
||||
return self.lm_head(out)
|
||||
return out
|
||||
|
||||
def sanitize(self, weights):
|
||||
shard_state_dict = {}
|
||||
|
||||
for key, value in weights.items():
|
||||
if key.startswith('model.layers.'):
|
||||
layer_num = int(key.split('.')[2])
|
||||
if self.args.shard.start_layer <= layer_num <= self.args.shard.end_layer:
|
||||
shard_state_dict[key] = value
|
||||
elif self.args.shard.is_first_layer() and key.startswith('model.embed_tokens'):
|
||||
shard_state_dict[key] = value
|
||||
elif self.args.shard.is_last_layer() and (key.startswith('model.norm') or key.startswith('lm_head')):
|
||||
shard_state_dict[key] = value
|
||||
|
||||
for l in range(self.args.num_hidden_layers):
|
||||
prefix = f"model.layers.{l}"
|
||||
for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
|
||||
for k in ["weight", "scales", "biases"]:
|
||||
expert_key = f"{prefix}.mlp.experts.0.{m}.{k}"
|
||||
if expert_key in shard_state_dict:
|
||||
to_join = [
|
||||
shard_state_dict.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}")
|
||||
for e in range(self.args.n_routed_experts)
|
||||
]
|
||||
shard_state_dict[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)
|
||||
|
||||
return shard_state_dict
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return (
|
||||
self.args.qk_nope_head_dim + self.args.qk_rope_head_dim,
|
||||
self.args.v_head_dim,
|
||||
)
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
@@ -1,118 +0,0 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_lm.models.base import create_attention_mask
|
||||
from mlx_lm.models.gemma2 import TransformerBlock, ModelArgs, RMSNorm
|
||||
|
||||
from ...shard import Shard
|
||||
from .base import IdentityBlock
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(ModelArgs):
|
||||
shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
|
||||
|
||||
def __post_init__(self):
|
||||
if isinstance(self.shard, Shard):
|
||||
return
|
||||
if not isinstance(self.shard, dict):
|
||||
raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
|
||||
|
||||
self.shard = Shard(**self.shard)
|
||||
|
||||
|
||||
class GemmaModel(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
assert self.vocab_size > 0
|
||||
if args.shard.is_first_layer() or args.shard.is_last_layer():
|
||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.layers = []
|
||||
for i in range(self.num_hidden_layers):
|
||||
if args.shard.start_layer <= i <= args.shard.end_layer:
|
||||
self.layers.append(TransformerBlock(args=args))
|
||||
else:
|
||||
self.layers.append(IdentityBlock())
|
||||
if args.shard.is_last_layer():
|
||||
self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
if self.args.shard.is_first_layer():
|
||||
h = self.embed_tokens(inputs)
|
||||
h = h * (self.args.hidden_size**0.5)
|
||||
else:
|
||||
h = inputs
|
||||
|
||||
mask = None
|
||||
if h.ndim > 1 and h.shape[1] > 1:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None]*len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, cache=c)
|
||||
|
||||
if self.args.shard.is_last_layer():
|
||||
h = self.norm(h)
|
||||
return h
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.model_type = args.model_type
|
||||
self.model = GemmaModel(args)
|
||||
if args.shard.is_last_layer():
|
||||
self.final_logit_softcapping = args.final_logit_softcapping
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, cache)
|
||||
if self.args.shard.is_last_layer():
|
||||
out = self.model.embed_tokens.as_linear(out)
|
||||
out = mx.tanh(out / self.final_logit_softcapping)
|
||||
out = out * self.final_logit_softcapping
|
||||
return out
|
||||
|
||||
def sanitize(self, weights):
|
||||
shard_state_dict = {}
|
||||
|
||||
for key, value in weights.items():
|
||||
if "self_attn.rotary_emb.inv_freq" in key:
|
||||
continue
|
||||
if key.startswith('model.layers.'):
|
||||
layer_num = int(key.split('.')[2])
|
||||
if self.args.shard.start_layer <= layer_num <= self.args.shard.end_layer:
|
||||
shard_state_dict[key] = value
|
||||
elif (self.args.shard.is_first_layer() or self.args.shard.is_last_layer()) and key.startswith('model.embed_tokens'):
|
||||
shard_state_dict[key] = value
|
||||
elif self.args.shard.is_last_layer() and (key.startswith('model.norm')):
|
||||
shard_state_dict[key] = value
|
||||
|
||||
return shard_state_dict
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.head_dim
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
@@ -1,125 +0,0 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_lm.models.base import create_attention_mask
|
||||
from mlx_lm.models.llama import TransformerBlock, ModelArgs
|
||||
|
||||
from ...shard import Shard
|
||||
from .base import IdentityBlock
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(ModelArgs):
|
||||
shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__() # Ensure parent initializations are respected
|
||||
|
||||
if isinstance(self.shard, Shard):
|
||||
return
|
||||
if not isinstance(self.shard, dict):
|
||||
raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
|
||||
|
||||
self.shard = Shard(**self.shard)
|
||||
|
||||
|
||||
class LlamaModel(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
assert self.vocab_size > 0
|
||||
if args.shard.is_first_layer() or (args.shard.is_last_layer() and args.tie_word_embeddings):
|
||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
self.layers = []
|
||||
for i in range(self.num_hidden_layers):
|
||||
if args.shard.start_layer <= i <= args.shard.end_layer:
|
||||
self.layers.append(TransformerBlock(args=args))
|
||||
else:
|
||||
self.layers.append(IdentityBlock())
|
||||
if args.shard.is_last_layer():
|
||||
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
if self.args.shard.is_first_layer():
|
||||
h = self.embed_tokens(inputs)
|
||||
else:
|
||||
h = inputs
|
||||
|
||||
mask = None
|
||||
if h.ndim > 1 and h.shape[1] > 1:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None]*len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, cache=c)
|
||||
|
||||
if self.args.shard.is_last_layer():
|
||||
h = self.norm(h)
|
||||
return h
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.model_type = args.model_type
|
||||
self.model = LlamaModel(args)
|
||||
if args.shard.is_last_layer():
|
||||
if not args.tie_word_embeddings:
|
||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, cache)
|
||||
if self.args.shard.is_last_layer():
|
||||
if self.args.tie_word_embeddings:
|
||||
out = self.model.embed_tokens.as_linear(out)
|
||||
else:
|
||||
out = self.lm_head(out)
|
||||
return out
|
||||
|
||||
def sanitize(self, weights):
|
||||
shard_state_dict = {}
|
||||
|
||||
for key, value in weights.items():
|
||||
if "self_attn.rotary_emb.inv_freq" in key:
|
||||
continue
|
||||
if key.startswith('model.layers.'):
|
||||
layer_num = int(key.split('.')[2])
|
||||
if self.args.shard.start_layer <= layer_num <= self.args.shard.end_layer:
|
||||
shard_state_dict[key] = value
|
||||
elif self.args.shard.is_first_layer() and key.startswith('model.embed_tokens'):
|
||||
shard_state_dict[key] = value
|
||||
elif (self.args.shard.is_last_layer() and self.args.tie_word_embeddings) and key.startswith('model.embed_tokens'):
|
||||
shard_state_dict[key] = value
|
||||
elif (self.args.shard.is_last_layer() and not self.args.tie_word_embeddings) and key.startswith('lm_head'):
|
||||
shard_state_dict[key] = value
|
||||
elif self.args.shard.is_last_layer() and (key.startswith('model.norm')):
|
||||
shard_state_dict[key] = value
|
||||
|
||||
return shard_state_dict
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return (self.args.head_dim or self.args.hidden_size // self.args.num_attention_heads)
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
@@ -1,585 +0,0 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import math
|
||||
import inspect
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, Dict, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx_lm.models.base import BaseModelArgs, KVCache
|
||||
from exo.inference.shard import Shard
|
||||
from .base import IdentityBlock
|
||||
import numpy as np
|
||||
|
||||
|
||||
@dataclass
|
||||
class VisionConfig:
|
||||
model_type: str
|
||||
num_hidden_layers: int = 24
|
||||
hidden_size: int = 1024
|
||||
intermediate_size: int = 4096
|
||||
num_attention_heads: int = 16
|
||||
image_size: int = 336
|
||||
patch_size: int = 14
|
||||
projection_dim: int = 768
|
||||
vocab_size: int = 32000
|
||||
num_channels: int = 3
|
||||
layer_norm_eps: float = 1e-5
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, params):
|
||||
return cls(**{k: v for k, v in params.items() if k in inspect.signature(cls).parameters})
|
||||
|
||||
|
||||
class VisionAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dims: int,
|
||||
num_heads: int,
|
||||
query_input_dims: Optional[int] = None,
|
||||
key_input_dims: Optional[int] = None,
|
||||
value_input_dims: Optional[int] = None,
|
||||
value_dims: Optional[int] = None,
|
||||
value_output_dims: Optional[int] = None,
|
||||
bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if (dims % num_heads) != 0:
|
||||
raise ValueError("The input feature dimensions should be divisible by the "
|
||||
f"number of heads ({dims} % {num_heads}) != 0")
|
||||
|
||||
query_input_dims = query_input_dims or dims
|
||||
key_input_dims = key_input_dims or dims
|
||||
value_input_dims = value_input_dims or key_input_dims
|
||||
value_dims = value_dims or dims
|
||||
value_output_dims = value_output_dims or dims
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.q_proj = nn.Linear(query_input_dims, dims, bias=bias)
|
||||
self.k_proj = nn.Linear(key_input_dims, dims, bias=bias)
|
||||
self.v_proj = nn.Linear(value_input_dims, value_dims, bias=bias)
|
||||
self.out_proj = nn.Linear(value_dims, value_output_dims, bias=bias)
|
||||
|
||||
def __call__(self, queries, keys, values, mask=None):
|
||||
queries = self.q_proj(queries)
|
||||
keys = self.k_proj(keys)
|
||||
values = self.v_proj(values)
|
||||
|
||||
num_heads = self.num_heads
|
||||
B, L, D = queries.shape
|
||||
_, S, _ = keys.shape
|
||||
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
|
||||
values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
scale = math.sqrt(1/queries.shape[-1])
|
||||
scores = (queries*scale) @ keys
|
||||
if mask is not None:
|
||||
scores = scores + mask.astype(scores.dtype)
|
||||
scores = mx.softmax(scores, axis=-1)
|
||||
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
|
||||
return self.out_proj(values_hat)
|
||||
|
||||
|
||||
class VisionMLP(nn.Module):
|
||||
def __init__(self, config: VisionConfig):
|
||||
super().__init__()
|
||||
self.activation_fn = nn.GELU(approx="fast")
|
||||
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
|
||||
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
x = self.activation_fn(self.fc1(x))
|
||||
x = self.fc2(x)
|
||||
return x
|
||||
|
||||
|
||||
class VisionEncoderLayer(nn.Module):
|
||||
def __init__(self, config: VisionConfig):
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
self.self_attn = VisionAttention(config.hidden_size, config.num_attention_heads, bias=True)
|
||||
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
||||
self.mlp = VisionMLP(config)
|
||||
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
||||
|
||||
def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
|
||||
y = self.layer_norm1(x)
|
||||
y = self.self_attn(y, y, y, mask)
|
||||
x = x + y
|
||||
y = self.layer_norm2(x)
|
||||
y = self.mlp(y)
|
||||
return x + y
|
||||
|
||||
|
||||
class VisionEncoder(nn.Module):
|
||||
def __init__(self, config: VisionConfig):
|
||||
super().__init__()
|
||||
self.layers = [VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]
|
||||
|
||||
|
||||
class VisionEmbeddings(nn.Module):
|
||||
def __init__(self, config: VisionConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.image_size = config.image_size
|
||||
self.patch_size = config.patch_size
|
||||
|
||||
self.class_embedding = mx.zeros((config.hidden_size,))
|
||||
|
||||
self.patch_embedding = nn.Conv2d(
|
||||
in_channels=config.num_channels,
|
||||
out_channels=self.embed_dim,
|
||||
kernel_size=self.patch_size,
|
||||
stride=self.patch_size,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
self.num_patches = (self.image_size // self.patch_size)**2
|
||||
self.num_positions = self.num_patches + 1
|
||||
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
batch_size = x.shape[0]
|
||||
patch_embeddings = self.patch_embedding(x)
|
||||
patch_embeddings = mx.flatten(patch_embeddings, start_axis=1, end_axis=2)
|
||||
embed_dim = patch_embeddings.shape[-1]
|
||||
cls_embeddings = mx.broadcast_to(self.class_embedding, (batch_size, 1, embed_dim))
|
||||
embeddings = mx.concatenate((cls_embeddings, patch_embeddings), axis=1)
|
||||
embeddings += self.position_embedding.weight
|
||||
return embeddings
|
||||
|
||||
|
||||
class ClipVisionModel(nn.Module):
|
||||
def __init__(self, config: VisionConfig):
|
||||
super().__init__()
|
||||
self.embeddings = VisionEmbeddings(config)
|
||||
self.pre_layrnorm = nn.LayerNorm(config.hidden_size)
|
||||
self.encoder = VisionEncoder(config)
|
||||
self.post_layernorm = nn.LayerNorm(config.hidden_size)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
) -> mx.array:
|
||||
x = self.embeddings(x)
|
||||
x = self.pre_layrnorm(x)
|
||||
|
||||
encoder_states = (x,) if output_hidden_states else None
|
||||
|
||||
for l in self.encoder.layers:
|
||||
x = l(x, mask=None)
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (x,)
|
||||
|
||||
pooler_output = self.post_layernorm(x[:, 0, :])
|
||||
return pooler_output, x, encoder_states
|
||||
|
||||
|
||||
class VisionModel(nn.Module):
|
||||
def __init__(self, config: VisionConfig):
|
||||
super().__init__()
|
||||
|
||||
self.model_type = config.model_type
|
||||
if self.model_type != "clip_vision_model":
|
||||
raise ValueError(f"Unsupported model type: {self.model_type}")
|
||||
|
||||
self.vision_model = ClipVisionModel(config)
|
||||
|
||||
def __call__(self, x: mx.array, output_hidden_states: Optional[bool] = None) -> mx.array:
|
||||
return self.vision_model(x, output_hidden_states)
|
||||
|
||||
def sanitize(self, weights):
|
||||
sanitized_weights = {}
|
||||
for k, v in weights.items():
|
||||
if "position_ids" in k:
|
||||
# Remove unused position_ids
|
||||
continue
|
||||
elif "patch_embedding.weight" in k:
|
||||
# PyTorch conv2d weight tensors have shape:
|
||||
# [out_channels, in_channels, kH, KW]
|
||||
# MLX conv2d expects the weight be of shape:
|
||||
# [out_channels, kH, KW, in_channels]
|
||||
sanitized_weights[k] = v.transpose(0, 2, 3, 1)
|
||||
else:
|
||||
sanitized_weights[k] = v
|
||||
|
||||
return sanitized_weights
|
||||
|
||||
|
||||
@dataclass
|
||||
class TextConfig:
|
||||
model_type: str
|
||||
hidden_size: int = 4096
|
||||
num_hidden_layers: int = 32
|
||||
intermediate_size: int = 11008
|
||||
num_attention_heads: int = 32
|
||||
head_dim: int = None
|
||||
rms_norm_eps: float = 1e-6
|
||||
vocab_size: int = 32000
|
||||
num_key_value_heads: int = None
|
||||
rope_theta: float = 10000
|
||||
rope_traditional: bool = False
|
||||
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, params):
|
||||
return cls(**{k: v for k, v in params.items() if k in inspect.signature(cls).parameters})
|
||||
|
||||
def __post_init__(self):
|
||||
if self.num_key_value_heads is None:
|
||||
self.num_key_value_heads = self.num_attention_heads
|
||||
|
||||
if self.head_dim is None:
|
||||
self.head_dim = self.hidden_size // self.num_attention_heads
|
||||
|
||||
if self.model_type is None:
|
||||
self.model_type = "llama"
|
||||
|
||||
if self.rope_scaling:
|
||||
required_keys = {"factor", "type"}
|
||||
if not all(key in self.rope_scaling for key in required_keys):
|
||||
raise ValueError(f"rope_scaling must contain keys {required_keys}")
|
||||
|
||||
if self.rope_scaling["type"] != "linear":
|
||||
raise ValueError("rope_scaling 'type' currently only supports 'linear'")
|
||||
|
||||
|
||||
class TextAttention(nn.Module):
|
||||
def __init__(self, config: TextConfig):
|
||||
super().__init__()
|
||||
|
||||
dim = config.hidden_size
|
||||
self.n_heads = n_heads = config.num_attention_heads
|
||||
self.n_kv_heads = n_kv_heads = config.num_key_value_heads
|
||||
|
||||
self.repeats = n_heads // n_kv_heads
|
||||
|
||||
head_dim = config.hidden_size // n_heads
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.q_proj = nn.Linear(dim, n_heads*head_dim, bias=False)
|
||||
self.k_proj = nn.Linear(dim, n_kv_heads*head_dim, bias=False)
|
||||
self.v_proj = nn.Linear(dim, n_kv_heads*head_dim, bias=False)
|
||||
self.o_proj = nn.Linear(n_heads*head_dim, dim, bias=False)
|
||||
|
||||
rope_scale = (1/config.rope_scaling["factor"] if config.rope_scaling is not None and config.rope_scaling["type"] == "linear" else 1)
|
||||
self.rope = nn.RoPE(
|
||||
head_dim,
|
||||
traditional=config.rope_traditional,
|
||||
base=config.rope_theta,
|
||||
scale=rope_scale,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[KVCache] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset)
|
||||
keys = self.rope(keys, offset=cache.offset)
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
else:
|
||||
queries = self.rope(queries)
|
||||
keys = self.rope(keys)
|
||||
|
||||
output = mx.fast.scaled_dot_product_attention(queries, keys, values, scale=self.scale, mask=mask)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o_proj(output)
|
||||
|
||||
|
||||
class TextMLP(nn.Module):
|
||||
def __init__(self, dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
|
||||
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
|
||||
|
||||
def __call__(self, x) -> mx.array:
|
||||
return self.down_proj(nn.silu(self.gate_proj(x))*self.up_proj(x))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, config: TextConfig):
|
||||
super().__init__()
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.hidden_size = config.hidden_size
|
||||
self.self_attn = TextAttention(config)
|
||||
self.mlp = TextMLP(config.hidden_size, config.intermediate_size)
|
||||
self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.config = config
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[KVCache] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.mlp(self.post_attention_layernorm(h))
|
||||
out = h + r
|
||||
return out
|
||||
|
||||
|
||||
class Llama(nn.Module):
|
||||
def __init__(self, config: TextConfig, shard: Shard):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.shard = shard
|
||||
self.vocab_size = config.vocab_size
|
||||
self.model_type = config.model_type
|
||||
self.num_hidden_layers = config.num_hidden_layers
|
||||
self.num_key_value_heads = config.num_key_value_heads
|
||||
self.head_dim = config.head_dim
|
||||
assert self.vocab_size > 0
|
||||
if self.shard.is_first_layer():
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
self.layers = []
|
||||
for i in range(self.num_hidden_layers):
|
||||
if self.shard.start_layer <= i <= self.shard.end_layer:
|
||||
self.layers.append(TransformerBlock(config=config))
|
||||
else:
|
||||
self.layers.append(IdentityBlock())
|
||||
if self.shard.is_last_layer():
|
||||
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
inputs_embeds=None,
|
||||
):
|
||||
# for passing merged input embeddings
|
||||
if inputs_embeds is None:
|
||||
if self.shard.is_first_layer():
|
||||
h = self.embed_tokens(inputs)
|
||||
else:
|
||||
h = inputs
|
||||
else:
|
||||
h = inputs_embeds
|
||||
|
||||
mask = None
|
||||
if h.shape[1] > 1:
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
|
||||
mask = mask.astype(h.dtype)
|
||||
|
||||
if cache is None:
|
||||
cache = [None]*len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, c)
|
||||
|
||||
if self.shard.is_last_layer():
|
||||
h = self.norm(h)
|
||||
return h
|
||||
|
||||
|
||||
class LanguageModel(nn.Module):
|
||||
def __init__(self, config: TextConfig, shard: Shard):
|
||||
super().__init__()
|
||||
self.model_type = config.model_type
|
||||
if self.model_type != "llama":
|
||||
raise ValueError(f"Model type {self.model_type} not supported. Currently only 'llama' is supported")
|
||||
self.shard = shard
|
||||
self.model = Llama(config, shard)
|
||||
if self.shard.is_last_layer():
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
inputs_embeds=None,
|
||||
):
|
||||
out = self.model(inputs, cache, inputs_embeds)
|
||||
if self.shard.is_last_layer():
|
||||
out = self.lm_head(out)
|
||||
return out
|
||||
|
||||
def sanitize(self, weights):
|
||||
shard_state_dict = {}
|
||||
for key, value in weights.items():
|
||||
if "self_attn.rotary_emb.inv_freq" in key:
|
||||
continue
|
||||
|
||||
if key.startswith('language_model.model.layers.'):
|
||||
layer_num = int(key.split('.')[3])
|
||||
if layer_num < self.shard.start_layer or layer_num > self.shard.end_layer:
|
||||
continue
|
||||
if not self.shard.is_first_layer() and key.startswith('language_model.model.embed_tokens'):
|
||||
continue
|
||||
elif not self.shard.is_last_layer() and (key.startswith('language_model.model.norm') or key.startswith('language_model.lm_head')):
|
||||
continue
|
||||
|
||||
shard_state_dict[key] = value
|
||||
|
||||
return shard_state_dict
|
||||
|
||||
|
||||
@dataclass
|
||||
class LlaVAConfig(BaseModelArgs):
|
||||
text_config: TextConfig
|
||||
vision_config: VisionConfig = None
|
||||
model_type: str = "llava"
|
||||
ignore_index: int = -100
|
||||
image_token_index: int = 32000
|
||||
vision_feature_select_strategy: str = "default"
|
||||
vision_feature_layer: int = -2
|
||||
vocab_size: int = 32000
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, params):
|
||||
updated_params = {}
|
||||
class_params = inspect.signature(cls).parameters
|
||||
for k, v in params.items():
|
||||
if k in class_params:
|
||||
if k in ["text_config", "vision_config"]:
|
||||
v = class_params[k].annotation.from_dict(v)
|
||||
updated_params.update({k: v})
|
||||
|
||||
return cls(**updated_params)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(LlaVAConfig):
|
||||
shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
|
||||
|
||||
def __post_init__(self):
|
||||
if isinstance(self.shard, dict):
|
||||
self.shard = Shard(**self.shard)
|
||||
|
||||
if not isinstance(self.shard, Shard):
|
||||
raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
|
||||
|
||||
if not self.shard.is_first_layer():
|
||||
self.vision_config = None
|
||||
|
||||
|
||||
class LlavaMultiModalProjector(nn.Module):
|
||||
def __init__(self, config: LlaVAConfig):
|
||||
super().__init__()
|
||||
self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
|
||||
self.gelu = nn.GELU()
|
||||
self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
x = self.linear_1(x)
|
||||
x = self.gelu(x)
|
||||
x = self.linear_2(x)
|
||||
return x
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.model_type = config.model_type
|
||||
if config.vision_config:
|
||||
self.vision_tower = VisionModel(config.vision_config)
|
||||
self.multi_modal_projector = LlavaMultiModalProjector(config)
|
||||
self.vision_feature_layer = config.vision_feature_layer
|
||||
self.vision_feature_select_strategy = config.vision_feature_select_strategy
|
||||
self.language_model = LanguageModel(config.text_config, config.shard)
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: Optional[mx.array] = None,
|
||||
pixel_values: Optional[mx.array] = None,
|
||||
):
|
||||
if pixel_values is None:
|
||||
return self.language_model(input_ids)
|
||||
|
||||
# Get the input embeddings from the language model
|
||||
inputs_embeds = self.language_model.model.embed_tokens(input_ids)
|
||||
|
||||
# Get the ouptut hidden states from the vision model
|
||||
*_, hidden_states = self.vision_tower(pixel_values.transpose(0, 2, 3, 1), output_hidden_states=True)
|
||||
|
||||
# Select the hidden states from the desired layer
|
||||
selected_image_feature = hidden_states[self.vision_feature_layer]
|
||||
|
||||
if self.vision_feature_select_strategy == "default":
|
||||
selected_image_feature = selected_image_feature[:, 1:]
|
||||
elif self.vision_feature_select_strategy == "full":
|
||||
selected_image_feature = selected_image_feature
|
||||
else:
|
||||
raise ValueError("Unexpected feature selection strategy: "
|
||||
f"{self.vision_feature_select_strategy}")
|
||||
|
||||
# Pass image features through the multi-modal projector
|
||||
image_features = self.multi_modal_projector(selected_image_feature)
|
||||
|
||||
# Insert special image tokens in the input_ids
|
||||
final_inputs_embeds = self._merge_input_ids_with_image_features(image_features, inputs_embeds, input_ids)
|
||||
return final_inputs_embeds
|
||||
|
||||
def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids):
|
||||
image_token_index = self.config.image_token_index
|
||||
num_images, num_image_patches, embed_dim = image_features.shape
|
||||
|
||||
# Positions of <image> tokens in input_ids, assuming batch size is 1
|
||||
image_positions = np.where(input_ids[0] == image_token_index)[0].tolist()
|
||||
|
||||
if len(image_positions) != num_images:
|
||||
raise ValueError(f"The number of image tokens ({len(image_positions)}) does not "
|
||||
f" match the number of image inputs ({num_images}).")
|
||||
|
||||
text_segments = []
|
||||
start_idx = 0
|
||||
|
||||
for position in image_positions:
|
||||
text_segments.append(inputs_embeds[:, start_idx:position])
|
||||
start_idx = position + 1
|
||||
|
||||
image_embeddings = mx.split(image_features, image_features.shape[0])
|
||||
final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p]
|
||||
final_embeddings += [inputs_embeds[:, start_idx:]]
|
||||
|
||||
# Create a final embedding of shape
|
||||
# (1, num_image_patches*num_images + sequence_len, embed_dim)
|
||||
return mx.concatenate(final_embeddings, axis=1)
|
||||
|
||||
def __call__(self, input_ids: mx.array, pixel_values: mx.array = None, cache=None):
|
||||
input_embddings = None
|
||||
if pixel_values is not None:
|
||||
input_embddings = self.get_input_embeddings(input_ids, pixel_values)
|
||||
logits = self.language_model(input_ids, cache=cache, inputs_embeds=input_embddings)
|
||||
return logits
|
||||
|
||||
def sanitize(self, weights):
|
||||
if self.config.vision_config:
|
||||
weights = self.vision_tower.sanitize(weights)
|
||||
else:
|
||||
weights = {k: v for k, v in weights.items() if not k.startswith(('vision_tower', 'multi_modal_projector', 'vision_feature_layer', 'vision_feature_select_strategy'))}
|
||||
weights = self.language_model.sanitize(weights)
|
||||
return weights
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.language_model.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return (self.language_model.model.head_dim or self.language_model.model.hidden_size // self.language_model.model.num_attention_heads)
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.language_model.model.num_key_value_heads
|
||||
@@ -1,117 +0,0 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_lm.models.base import create_attention_mask
|
||||
from mlx_lm.models.phi3 import TransformerBlock, ModelArgs
|
||||
|
||||
from ...shard import Shard
|
||||
from .base import IdentityBlock
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(ModelArgs):
|
||||
shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
if isinstance(self.shard, Shard):
|
||||
return
|
||||
if not isinstance(self.shard, dict):
|
||||
raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
|
||||
|
||||
self.shard = Shard(**self.shard)
|
||||
|
||||
class Phi3Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
assert self.vocab_size > 0
|
||||
|
||||
if self.args.shard.is_first_layer():
|
||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
|
||||
self.layers = []
|
||||
for i in range(self.num_hidden_layers):
|
||||
if self.args.shard.start_layer <= i <= self.args.shard.end_layer:
|
||||
self.layers.append(TransformerBlock(args=args))
|
||||
else:
|
||||
self.layers.append(IdentityBlock())
|
||||
|
||||
if self.args.shard.is_last_layer():
|
||||
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
if self.args.shard.is_first_layer():
|
||||
h = self.embed_tokens(inputs)
|
||||
else:
|
||||
h = inputs
|
||||
|
||||
mask = None
|
||||
if h.shape[1] > 1:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, c)
|
||||
|
||||
if self.args.shard.is_last_layer():
|
||||
h = self.norm(h)
|
||||
return h
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.model_type = args.model_type
|
||||
self.model = Phi3Model(args)
|
||||
if self.args.shard.is_last_layer():
|
||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, cache)
|
||||
if self.args.shard.is_last_layer():
|
||||
out = self.lm_head(out)
|
||||
return out
|
||||
|
||||
def sanitize(self, weights):
|
||||
shard_state_dict = {}
|
||||
|
||||
for key, value in weights.items():
|
||||
if "self_attn.rope.inv_freq" in key:
|
||||
continue
|
||||
if key.startswith('model.layers.'):
|
||||
layer_num = int(key.split('.')[2])
|
||||
if self.args.shard.start_layer <= layer_num <= self.args.shard.end_layer:
|
||||
shard_state_dict[key] = value
|
||||
elif self.args.shard.is_first_layer() and key.startswith('model.embed_tokens'):
|
||||
shard_state_dict[key] = value
|
||||
elif self.args.shard.is_last_layer() and (key.startswith('lm_head') or key.startswith('model.norm')):
|
||||
shard_state_dict[key] = value
|
||||
|
||||
return shard_state_dict
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.hidden_size // self.args.num_attention_heads
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
@@ -1,129 +0,0 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_lm.models.base import create_attention_mask
|
||||
from mlx_lm.models.qwen2 import TransformerBlock, ModelArgs
|
||||
|
||||
from ...shard import Shard
|
||||
from .base import IdentityBlock
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(ModelArgs):
|
||||
shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
if isinstance(self.shard, Shard):
|
||||
return
|
||||
if not isinstance(self.shard, dict):
|
||||
raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
|
||||
|
||||
self.shard = Shard(**self.shard)
|
||||
|
||||
class Qwen2Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.vocab_size = args.vocab_size
|
||||
self.num_hidden_layers = args.num_hidden_layers
|
||||
assert self.vocab_size > 0
|
||||
|
||||
if self.args.shard.is_first_layer() or (self.args.shard.is_last_layer() and args.tie_word_embeddings):
|
||||
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
|
||||
|
||||
self.layers = []
|
||||
for i in range(self.num_hidden_layers):
|
||||
if self.args.shard.start_layer <= i <= self.args.shard.end_layer:
|
||||
self.layers.append(TransformerBlock(args=args))
|
||||
else:
|
||||
self.layers.append(IdentityBlock())
|
||||
|
||||
if self.args.shard.is_last_layer():
|
||||
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
if self.args.shard.is_first_layer():
|
||||
h = self.embed_tokens(inputs)
|
||||
else:
|
||||
h = inputs
|
||||
|
||||
mask = None
|
||||
if h.shape[1] > 1:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None]*len(self.layers)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, c)
|
||||
|
||||
if self.args.shard.is_last_layer():
|
||||
h = self.norm(h)
|
||||
return h
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.model_type = args.model_type
|
||||
self.model = Qwen2Model(args)
|
||||
if self.args.shard.is_last_layer():
|
||||
if not args.tie_word_embeddings:
|
||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
out = self.model(inputs, cache)
|
||||
if self.args.shard.is_last_layer():
|
||||
if self.args.tie_word_embeddings:
|
||||
out = self.model.embed_tokens.as_linear(out)
|
||||
else:
|
||||
out = self.lm_head(out)
|
||||
return out
|
||||
|
||||
def sanitize(self, weights):
|
||||
shard_state_dict = {}
|
||||
|
||||
for key, value in weights.items():
|
||||
if "self_attn.rotary_emb.inv_freq" in key:
|
||||
continue
|
||||
if key.startswith('model.layers.'):
|
||||
layer_num = int(key.split('.')[2])
|
||||
if self.args.shard.start_layer <= layer_num <= self.args.shard.end_layer:
|
||||
shard_state_dict[key] = value
|
||||
elif self.args.shard.is_first_layer() and key.startswith('model.embed_tokens'):
|
||||
shard_state_dict[key] = value
|
||||
elif (self.args.shard.is_last_layer() and self.args.tie_word_embeddings) and key.startswith('model.embed_tokens'):
|
||||
shard_state_dict[key] = value
|
||||
elif (self.args.shard.is_last_layer() and not self.args.tie_word_embeddings) and key.startswith('lm_head'):
|
||||
shard_state_dict[key] = value
|
||||
elif self.args.shard.is_last_layer() and (key.startswith('model.norm')):
|
||||
shard_state_dict[key] = value
|
||||
|
||||
if self.args.tie_word_embeddings:
|
||||
shard_state_dict.pop("lm_head.weight", None)
|
||||
|
||||
return shard_state_dict
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.hidden_size // self.args.num_attention_heads
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
@@ -1,191 +0,0 @@
|
||||
# Adapted from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/clip.py
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from dataclasses import field, dataclass
|
||||
from exo.inference.shard import Shard
|
||||
from exo.inference.mlx.models.base import IdentityBlock
|
||||
|
||||
_ACTIVATIONS = {"quick_gelu": nn.gelu_fast_approx, "gelu": nn.gelu}
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class CLIPTextModelConfig:
|
||||
num_layers: int = 23
|
||||
model_dims: int = 1024
|
||||
num_heads: int = 16
|
||||
max_length: int = 77
|
||||
vocab_size: int = 49408
|
||||
projection_dim: Optional[int] = None
|
||||
hidden_act: str = "quick_gelu"
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, config):
|
||||
return ModelArgs(
|
||||
num_layers=config["num_hidden_layers"],
|
||||
model_dims=config["hidden_size"],
|
||||
num_heads=config["num_attention_heads"],
|
||||
max_length=config["max_position_embeddings"],
|
||||
vocab_size=config["vocab_size"],
|
||||
projection_dim=config["projection_dim"] if "WithProjection" in config['architectures'][0] else None,
|
||||
hidden_act=config.get("hidden_act", "quick_gelu"),
|
||||
weight_files=config.get("weight_files", [])
|
||||
)
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(CLIPTextModelConfig):
|
||||
shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
|
||||
weight_files: List[str] = field(default_factory=lambda: [])
|
||||
def __post_init__(self):
|
||||
if isinstance(self.shard, dict):
|
||||
self.shard = Shard(**self.shard)
|
||||
|
||||
if not isinstance(self.shard, Shard):
|
||||
raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
|
||||
|
||||
if not self.shard.is_first_layer():
|
||||
self.vision_config = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class CLIPOutput:
|
||||
pooled_output: Optional[mx.array] = None
|
||||
last_hidden_state: Optional[mx.array] = None
|
||||
hidden_states: Optional[List[mx.array]] = None
|
||||
|
||||
|
||||
class CLIPEncoderLayer(nn.Module):
|
||||
"""The transformer encoder layer from CLIP."""
|
||||
|
||||
def __init__(self, model_dims: int, num_heads: int, activation: str):
|
||||
super().__init__()
|
||||
|
||||
self.layer_norm1 = nn.LayerNorm(model_dims)
|
||||
self.layer_norm2 = nn.LayerNorm(model_dims)
|
||||
|
||||
self.attention = nn.MultiHeadAttention(model_dims, num_heads)
|
||||
self.attention.query_proj.bias = mx.zeros(model_dims)
|
||||
self.attention.key_proj.bias = mx.zeros(model_dims)
|
||||
self.attention.value_proj.bias = mx.zeros(model_dims)
|
||||
self.attention.out_proj.bias = mx.zeros(model_dims)
|
||||
|
||||
self.linear1 = nn.Linear(model_dims, 4 * model_dims)
|
||||
self.linear2 = nn.Linear(4 * model_dims, model_dims)
|
||||
|
||||
self.act = _ACTIVATIONS[activation]
|
||||
|
||||
def __call__(self, x, attn_mask=None):
|
||||
|
||||
y = self.layer_norm1(x)
|
||||
y = self.attention(y, y, y, attn_mask)
|
||||
x = y + x
|
||||
|
||||
y = self.layer_norm2(x)
|
||||
y = self.linear1(y)
|
||||
y = self.act(y)
|
||||
y = self.linear2(y)
|
||||
x = y + x
|
||||
return x
|
||||
|
||||
|
||||
class CLIPTextModel(nn.Module):
|
||||
"""Implements the text encoder transformer from CLIP."""
|
||||
|
||||
def __init__(self, config: CLIPTextModelConfig, shard: Shard):
|
||||
super().__init__()
|
||||
|
||||
self.shard = shard
|
||||
self.layers_range = range(self.shard.start_layer*2, self.shard.end_layer*2+2)
|
||||
if self.shard.is_first_layer():
|
||||
self.token_embedding = nn.Embedding(config.vocab_size, config.model_dims)
|
||||
self.position_embedding = nn.Embedding(config.max_length, config.model_dims)
|
||||
self.layers = []
|
||||
for i in range(math.ceil(config.num_layers/2)):
|
||||
if 2*i in self.layers_range:
|
||||
self.layers.append(CLIPEncoderLayer(config.model_dims, config.num_heads, config.hidden_act))
|
||||
if 2*i+1 in self.layers_range and 2*i+1 < config.num_layers:
|
||||
self.layers.append(CLIPEncoderLayer(config.model_dims, config.num_heads, config.hidden_act))
|
||||
else:
|
||||
self.layers.append(IdentityBlock())
|
||||
if self.shard.is_last_layer():
|
||||
self.final_layer_norm = nn.LayerNorm(config.model_dims)
|
||||
|
||||
if config.projection_dim is not None:
|
||||
self.text_projection = nn.Linear(
|
||||
config.model_dims, config.projection_dim, bias=False
|
||||
)
|
||||
|
||||
def _get_mask(self, N, dtype):
|
||||
indices = mx.arange(N)
|
||||
mask = indices[:, None] < indices[None]
|
||||
mask = mask.astype(dtype) * (-6e4 if dtype == mx.float16 else -1e9)
|
||||
return mask
|
||||
|
||||
def __call__(self, x, mask=None):
|
||||
# Extract some shapes
|
||||
if self.shard.is_first_layer():
|
||||
B, N = x.shape
|
||||
eos_tokens = x.argmax(-1)
|
||||
|
||||
# Compute the embeddings
|
||||
x = self.token_embedding(x)
|
||||
|
||||
x = x + self.position_embedding.weight[:N]
|
||||
# Compute the features from the transformer
|
||||
mask = self._get_mask(N, x.dtype)
|
||||
|
||||
for l in self.layers:
|
||||
x = l(x, mask)
|
||||
# Apply the final layernorm and return
|
||||
|
||||
if self.shard.is_last_layer():
|
||||
x = self.final_layer_norm(x)
|
||||
|
||||
|
||||
|
||||
return x, mask
|
||||
def sanitize(self, weights):
|
||||
sanitized_weights = {}
|
||||
for key, value in weights.items():
|
||||
if "position_ids" in key:
|
||||
continue
|
||||
if key.startswith("text_model."):
|
||||
key = key[11:]
|
||||
if key.startswith("embeddings."):
|
||||
key = key[11:]
|
||||
if key.startswith("encoder."):
|
||||
key = key[8:]
|
||||
|
||||
# Map attention layers
|
||||
if "self_attn." in key:
|
||||
key = key.replace("self_attn.", "attention.")
|
||||
if "q_proj." in key:
|
||||
key = key.replace("q_proj.", "query_proj.")
|
||||
if "k_proj." in key:
|
||||
key = key.replace("k_proj.", "key_proj.")
|
||||
if "v_proj." in key:
|
||||
key = key.replace("v_proj.", "value_proj.")
|
||||
|
||||
# Map ffn layers
|
||||
if "mlp.fc1" in key:
|
||||
key = key.replace("mlp.fc1", "linear1")
|
||||
if "mlp.fc2" in key:
|
||||
key = key.replace("mlp.fc2", "linear2")
|
||||
|
||||
if key.startswith("layers."):
|
||||
layer_num = int(key.split(".")[1])
|
||||
if layer_num not in self.layers_range:
|
||||
continue
|
||||
if not self.shard.is_first_layer() and "embedding" in key:
|
||||
continue
|
||||
if not self.shard.is_last_layer() and key.startswith("final_layer_norm"):
|
||||
continue
|
||||
if not self.shard.is_last_layer() and key.startswith("text_projection"):
|
||||
continue
|
||||
sanitized_weights[key] = value
|
||||
return sanitized_weights
|
||||
@@ -1,131 +0,0 @@
|
||||
# adapted from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/tokenizer.py
|
||||
|
||||
import regex
|
||||
import json
|
||||
import glob
|
||||
|
||||
|
||||
class Tokenizer:
|
||||
"""A simple port of CLIPTokenizer from https://github.com/huggingface/transformers/ ."""
|
||||
|
||||
def __init__(self, bpe_ranks, vocab):
|
||||
self.bpe_ranks = bpe_ranks
|
||||
self.vocab = vocab
|
||||
self.pat = regex.compile(
|
||||
r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
|
||||
regex.IGNORECASE,
|
||||
)
|
||||
|
||||
self._cache = {self.bos: self.bos, self.eos: self.eos}
|
||||
|
||||
@property
|
||||
def bos(self):
|
||||
return "<|startoftext|>"
|
||||
|
||||
@property
|
||||
def bos_token(self):
|
||||
return self.vocab[self.bos]
|
||||
|
||||
@property
|
||||
def eos(self):
|
||||
return "<|endoftext|>"
|
||||
|
||||
@property
|
||||
def eos_token(self):
|
||||
return self.vocab[self.eos]
|
||||
|
||||
def bpe(self, text):
|
||||
if text in self._cache:
|
||||
return self._cache[text]
|
||||
|
||||
unigrams = list(text[:-1]) + [text[-1] + "</w>"]
|
||||
unique_bigrams = set(zip(unigrams, unigrams[1:]))
|
||||
|
||||
if not unique_bigrams:
|
||||
return unigrams
|
||||
|
||||
# In every iteration try to merge the two most likely bigrams. If none
|
||||
# was merged we are done.
|
||||
#
|
||||
# Ported from https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/tokenization_clip.py
|
||||
while unique_bigrams:
|
||||
bigram = min(
|
||||
unique_bigrams, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))
|
||||
)
|
||||
if bigram not in self.bpe_ranks:
|
||||
break
|
||||
|
||||
new_unigrams = []
|
||||
skip = False
|
||||
for a, b in zip(unigrams, unigrams[1:]):
|
||||
if skip:
|
||||
skip = False
|
||||
continue
|
||||
|
||||
if (a, b) == bigram:
|
||||
new_unigrams.append(a + b)
|
||||
skip = True
|
||||
|
||||
else:
|
||||
new_unigrams.append(a)
|
||||
|
||||
if not skip:
|
||||
new_unigrams.append(b)
|
||||
|
||||
unigrams = new_unigrams
|
||||
unique_bigrams = set(zip(unigrams, unigrams[1:]))
|
||||
|
||||
self._cache[text] = unigrams
|
||||
|
||||
return unigrams
|
||||
|
||||
def tokenize(self, text, prepend_bos=True, append_eos=True):
|
||||
if isinstance(text, list):
|
||||
return [self.tokenize(t, prepend_bos, append_eos) for t in text]
|
||||
|
||||
# Lower case cleanup and split according to self.pat. Hugging Face does
|
||||
# a much more thorough job here but this should suffice for 95% of
|
||||
# cases.
|
||||
clean_text = regex.sub(r"\s+", " ", text.lower())
|
||||
tokens = regex.findall(self.pat, clean_text)
|
||||
|
||||
# Split the tokens according to the byte-pair merge file
|
||||
bpe_tokens = [ti for t in tokens for ti in self.bpe(t)]
|
||||
|
||||
# Map to token ids and return
|
||||
tokens = [self.vocab[t] for t in bpe_tokens]
|
||||
if prepend_bos:
|
||||
tokens = [self.bos_token] + tokens
|
||||
if append_eos:
|
||||
tokens.append(self.eos_token)
|
||||
|
||||
return tokens
|
||||
|
||||
def encode(self, prompt):
|
||||
tokens = [self.tokenize(prompt)]
|
||||
negative_text = ""
|
||||
if negative_text is not None:
|
||||
tokens += [self.tokenize(negative_text)]
|
||||
lengths = [len(t) for t in tokens]
|
||||
N = max(lengths)
|
||||
tokens = [t + [0] * (N - len(t)) for t in tokens]
|
||||
return tokens
|
||||
|
||||
def load_tokenizer(
|
||||
model_path: str,
|
||||
vocab_key: str = "tokenizer_vocab",
|
||||
merges_key: str = "tokenizer_merges",
|
||||
):
|
||||
|
||||
vocab_file = glob.glob(str(model_path/"tokenizer"/vocab_key))[0]
|
||||
with open(vocab_file, encoding="utf-8") as f:
|
||||
vocab = json.load(f)
|
||||
|
||||
merges_file = glob.glob(str(model_path/"tokenizer"/merges_key))[0]
|
||||
with open(merges_file, encoding="utf-8") as f:
|
||||
bpe_merges = f.read().strip().split("\n")[1 : 49152 - 256 - 2 + 1]
|
||||
bpe_merges = [tuple(m.split()) for m in bpe_merges]
|
||||
bpe_ranks = dict(map(reversed, enumerate(bpe_merges)))
|
||||
|
||||
return Tokenizer(bpe_ranks, vocab)
|
||||
|
||||
@@ -1,629 +0,0 @@
|
||||
# Adapted from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/unet.py
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Tuple, Optional, List
|
||||
from exo.inference.shard import Shard
|
||||
|
||||
@dataclass
|
||||
class UNetConfig:
|
||||
in_channels: int = 4
|
||||
out_channels: int = 4
|
||||
conv_in_kernel: int = 3
|
||||
conv_out_kernel: int = 3
|
||||
block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
|
||||
layers_per_block: Tuple[int] = (2, 2, 2, 2)
|
||||
mid_block_layers: int = 2
|
||||
transformer_layers_per_block: Tuple[int] = (1, 1, 1, 1)
|
||||
num_attention_heads: Tuple[int] = (5, 10, 20, 20)
|
||||
cross_attention_dim: Tuple[int] = (1024,) * 4
|
||||
norm_num_groups: int = 32
|
||||
down_block_types: Tuple[str] = (
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D",
|
||||
)
|
||||
up_block_types: Tuple[str] = (
|
||||
"UpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
)
|
||||
addition_embed_type: Optional[str] = None
|
||||
addition_time_embed_dim: Optional[int] = None
|
||||
projection_class_embeddings_input_dim: Optional[int] = None
|
||||
weight_files: List[str] = field(default_factory=lambda: [])
|
||||
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls,config):
|
||||
n_blocks = len(config['block_out_channels'])
|
||||
return UNetConfig(
|
||||
in_channels=config["in_channels"],
|
||||
out_channels=config["out_channels"],
|
||||
block_out_channels=config["block_out_channels"],
|
||||
layers_per_block=[config["layers_per_block"]] * n_blocks,
|
||||
transformer_layers_per_block=config.get(
|
||||
"transformer_layers_per_block", (1,) * 4
|
||||
),
|
||||
num_attention_heads=(
|
||||
[config["attention_head_dim"]] * n_blocks
|
||||
if isinstance(config["attention_head_dim"], int)
|
||||
else config["attention_head_dim"]
|
||||
),
|
||||
cross_attention_dim=[config["cross_attention_dim"]] * n_blocks,
|
||||
norm_num_groups=config["norm_num_groups"],
|
||||
down_block_types=config["down_block_types"],
|
||||
up_block_types=config["up_block_types"][::-1],
|
||||
addition_embed_type=config.get("addition_embed_type", None),
|
||||
addition_time_embed_dim=config.get("addition_time_embed_dim", None),
|
||||
projection_class_embeddings_input_dim=config.get(
|
||||
"projection_class_embeddings_input_dim", None
|
||||
),
|
||||
weight_files=config.get("weight_files", [])
|
||||
|
||||
)
|
||||
|
||||
|
||||
def upsample_nearest(x, scale: int = 2):
|
||||
B, H, W, C = x.shape
|
||||
x = mx.broadcast_to(x[:, :, None, :, None, :], (B, H, scale, W, scale, C))
|
||||
x = x.reshape(B, H * scale, W * scale, C)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class TimestepEmbedding(nn.Module):
|
||||
def __init__(self, in_channels: int, time_embed_dim: int):
|
||||
super().__init__()
|
||||
|
||||
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
|
||||
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
|
||||
|
||||
def __call__(self, x):
|
||||
x = self.linear_1(x)
|
||||
x = nn.silu(x)
|
||||
x = self.linear_2(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model_dims: int,
|
||||
num_heads: int,
|
||||
hidden_dims: Optional[int] = None,
|
||||
memory_dims: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.norm1 = nn.LayerNorm(model_dims)
|
||||
self.attn1 = nn.MultiHeadAttention(model_dims, num_heads)
|
||||
self.attn1.out_proj.bias = mx.zeros(model_dims)
|
||||
|
||||
memory_dims = memory_dims or model_dims
|
||||
self.norm2 = nn.LayerNorm(model_dims)
|
||||
self.attn2 = nn.MultiHeadAttention(
|
||||
model_dims, num_heads, key_input_dims=memory_dims
|
||||
)
|
||||
self.attn2.out_proj.bias = mx.zeros(model_dims)
|
||||
|
||||
hidden_dims = hidden_dims or 4 * model_dims
|
||||
self.norm3 = nn.LayerNorm(model_dims)
|
||||
self.linear1 = nn.Linear(model_dims, hidden_dims)
|
||||
self.linear2 = nn.Linear(model_dims, hidden_dims)
|
||||
self.linear3 = nn.Linear(hidden_dims, model_dims)
|
||||
|
||||
def __call__(self, x, memory, attn_mask, memory_mask):
|
||||
# Self attention
|
||||
y = self.norm1(x)
|
||||
y = self.attn1(y, y, y, attn_mask)
|
||||
x = x + y
|
||||
|
||||
# Cross attention
|
||||
y = self.norm2(x)
|
||||
y = self.attn2(y, memory, memory, memory_mask)
|
||||
x = x + y
|
||||
|
||||
# FFN
|
||||
y = self.norm3(x)
|
||||
y_a = self.linear1(y)
|
||||
y_b = self.linear2(y)
|
||||
y = y_a * nn.gelu(y_b)
|
||||
y = self.linear3(y)
|
||||
x = x + y
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Transformer2D(nn.Module):
|
||||
"""A transformer model for inputs with 2 spatial dimensions."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
model_dims: int,
|
||||
encoder_dims: int,
|
||||
num_heads: int,
|
||||
num_layers: int = 1,
|
||||
norm_num_groups: int = 32,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.norm = nn.GroupNorm(norm_num_groups, in_channels, pytorch_compatible=True)
|
||||
self.proj_in = nn.Linear(in_channels, model_dims)
|
||||
self.transformer_blocks = [
|
||||
TransformerBlock(model_dims, num_heads, memory_dims=encoder_dims)
|
||||
for i in range(num_layers)
|
||||
]
|
||||
self.proj_out = nn.Linear(model_dims, in_channels)
|
||||
|
||||
def __call__(self, x, encoder_x, attn_mask, encoder_attn_mask):
|
||||
# Save the input to add to the output
|
||||
input_x = x
|
||||
dtype = x.dtype
|
||||
|
||||
# Perform the input norm and projection
|
||||
B, H, W, C = x.shape
|
||||
x = self.norm(x.astype(mx.float32)).astype(dtype).reshape(B, -1, C)
|
||||
x = self.proj_in(x)
|
||||
|
||||
# Apply the transformer
|
||||
for block in self.transformer_blocks:
|
||||
x = block(x, encoder_x, attn_mask, encoder_attn_mask)
|
||||
|
||||
# Apply the output projection and reshape
|
||||
x = self.proj_out(x)
|
||||
x = x.reshape(B, H, W, C)
|
||||
|
||||
return x + input_x
|
||||
|
||||
|
||||
class ResnetBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: Optional[int] = None,
|
||||
groups: int = 32,
|
||||
temb_channels: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
out_channels = out_channels or in_channels
|
||||
|
||||
self.norm1 = nn.GroupNorm(groups, in_channels, pytorch_compatible=True)
|
||||
self.conv1 = nn.Conv2d(
|
||||
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
if temb_channels is not None:
|
||||
self.time_emb_proj = nn.Linear(temb_channels, out_channels)
|
||||
self.norm2 = nn.GroupNorm(groups, out_channels, pytorch_compatible=True)
|
||||
self.conv2 = nn.Conv2d(
|
||||
out_channels, out_channels, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
|
||||
if in_channels != out_channels:
|
||||
self.conv_shortcut = nn.Linear(in_channels, out_channels)
|
||||
|
||||
def __call__(self, x, temb=None):
|
||||
dtype = x.dtype
|
||||
|
||||
if temb is not None:
|
||||
temb = self.time_emb_proj(nn.silu(temb))
|
||||
y = self.norm1(x.astype(mx.float32)).astype(dtype)
|
||||
|
||||
y = nn.silu(y)
|
||||
|
||||
y = self.conv1(y)
|
||||
|
||||
|
||||
if temb is not None:
|
||||
y = y + temb[:, None, None, :]
|
||||
y = self.norm2(y.astype(mx.float32)).astype(dtype)
|
||||
y = nn.silu(y)
|
||||
y = self.conv2(y)
|
||||
|
||||
x = y + (x if "conv_shortcut" not in self else self.conv_shortcut(x))
|
||||
return x
|
||||
|
||||
|
||||
class UNetBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
temb_channels: int,
|
||||
prev_out_channels: Optional[int] = None,
|
||||
num_layers: int = 1,
|
||||
transformer_layers_per_block: int = 1,
|
||||
num_attention_heads: int = 8,
|
||||
cross_attention_dim=1280,
|
||||
resnet_groups: int = 32,
|
||||
add_downsample=True,
|
||||
add_upsample=True,
|
||||
add_cross_attention=True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Prepare the in channels list for the resnets
|
||||
if prev_out_channels is None:
|
||||
in_channels_list = [in_channels] + [out_channels] * (num_layers - 1)
|
||||
else:
|
||||
in_channels_list = [prev_out_channels] + [out_channels] * (num_layers - 1)
|
||||
res_channels_list = [out_channels] * (num_layers - 1) + [in_channels]
|
||||
in_channels_list = [
|
||||
a + b for a, b in zip(in_channels_list, res_channels_list)
|
||||
]
|
||||
|
||||
# Add resnet blocks that also process the time embedding
|
||||
self.resnets = [
|
||||
ResnetBlock2D(
|
||||
in_channels=ic,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
groups=resnet_groups,
|
||||
)
|
||||
for ic in in_channels_list
|
||||
]
|
||||
|
||||
# Add optional cross attention layers
|
||||
if add_cross_attention:
|
||||
self.attentions = [
|
||||
Transformer2D(
|
||||
in_channels=out_channels,
|
||||
model_dims=out_channels,
|
||||
num_heads=num_attention_heads,
|
||||
num_layers=transformer_layers_per_block,
|
||||
encoder_dims=cross_attention_dim,
|
||||
)
|
||||
for i in range(num_layers)
|
||||
]
|
||||
|
||||
# Add an optional downsampling layer
|
||||
if add_downsample:
|
||||
self.downsample = nn.Conv2d(
|
||||
out_channels, out_channels, kernel_size=3, stride=2, padding=1
|
||||
)
|
||||
|
||||
# or upsampling layer
|
||||
if add_upsample:
|
||||
self.upsample = nn.Conv2d(
|
||||
out_channels, out_channels, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x,
|
||||
encoder_x=None,
|
||||
temb=None,
|
||||
attn_mask=None,
|
||||
encoder_attn_mask=None,
|
||||
residual_hidden_states=None,
|
||||
):
|
||||
output_states = []
|
||||
|
||||
for i in range(len(self.resnets)):
|
||||
if residual_hidden_states is not None:
|
||||
x = mx.concatenate([x, residual_hidden_states.pop()], axis=-1)
|
||||
|
||||
x = self.resnets[i](x, temb)
|
||||
|
||||
if "attentions" in self:
|
||||
x = self.attentions[i](x, encoder_x, attn_mask, encoder_attn_mask)
|
||||
|
||||
output_states.append(x)
|
||||
|
||||
if "downsample" in self:
|
||||
x = self.downsample(x)
|
||||
output_states.append(x)
|
||||
|
||||
if "upsample" in self:
|
||||
x = self.upsample(upsample_nearest(x))
|
||||
output_states.append(x)
|
||||
|
||||
return x, output_states
|
||||
|
||||
|
||||
class UNetModel(nn.Module):
|
||||
"""The conditional 2D UNet model that actually performs the denoising."""
|
||||
|
||||
def __init__(self, config: UNetConfig, shard: Shard):
|
||||
super().__init__()
|
||||
self.shard = shard
|
||||
self.start_layer = shard.start_layer
|
||||
self.end_layer = shard.end_layer
|
||||
self.layers_range = list(range(self.start_layer, self.end_layer+1))
|
||||
if shard.is_first_layer():
|
||||
self.conv_in = nn.Conv2d(
|
||||
config.in_channels,
|
||||
config.block_out_channels[0],
|
||||
config.conv_in_kernel,
|
||||
padding=(config.conv_in_kernel - 1) // 2,
|
||||
)
|
||||
|
||||
self.timesteps = nn.SinusoidalPositionalEncoding(
|
||||
config.block_out_channels[0],
|
||||
max_freq=1,
|
||||
min_freq=math.exp(
|
||||
-math.log(10000) + 2 * math.log(10000) / config.block_out_channels[0]
|
||||
),
|
||||
scale=1.0,
|
||||
cos_first=True,
|
||||
full_turns=False,
|
||||
)
|
||||
self.time_embedding = TimestepEmbedding(
|
||||
config.block_out_channels[0],
|
||||
config.block_out_channels[0] * 4,
|
||||
)
|
||||
|
||||
if config.addition_embed_type == "text_time":
|
||||
self.add_time_proj = nn.SinusoidalPositionalEncoding(
|
||||
config.addition_time_embed_dim,
|
||||
max_freq=1,
|
||||
min_freq=math.exp(
|
||||
-math.log(10000)
|
||||
+ 2 * math.log(10000) / config.addition_time_embed_dim
|
||||
),
|
||||
scale=1.0,
|
||||
cos_first=True,
|
||||
full_turns=False,
|
||||
)
|
||||
self.add_embedding = TimestepEmbedding(
|
||||
config.projection_class_embeddings_input_dim,
|
||||
config.block_out_channels[0] * 4,
|
||||
)
|
||||
|
||||
# Make the downsampling blocks
|
||||
block_channels = [config.block_out_channels[0]] + list(
|
||||
config.block_out_channels
|
||||
)
|
||||
self.down_blocks = []
|
||||
|
||||
for i, (in_channels, out_channels) in enumerate(zip(block_channels, block_channels[1:])):
|
||||
if i in self.layers_range:
|
||||
self.down_blocks.append(
|
||||
UNetBlock2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=config.block_out_channels[0] * 4,
|
||||
num_layers=config.layers_per_block[i],
|
||||
transformer_layers_per_block=config.transformer_layers_per_block[i],
|
||||
num_attention_heads=config.num_attention_heads[i],
|
||||
cross_attention_dim=config.cross_attention_dim[i],
|
||||
resnet_groups=config.norm_num_groups,
|
||||
add_downsample=(i < len(config.block_out_channels) - 1),
|
||||
add_upsample=False,
|
||||
add_cross_attention="CrossAttn" in config.down_block_types[i],
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.down_blocks.append(nn.Identity())
|
||||
|
||||
|
||||
# Make the middle block
|
||||
if 4 in self.layers_range:
|
||||
self.mid_blocks = [
|
||||
ResnetBlock2D(
|
||||
in_channels=config.block_out_channels[-1],
|
||||
out_channels=config.block_out_channels[-1],
|
||||
temb_channels=config.block_out_channels[0] * 4,
|
||||
groups=config.norm_num_groups,
|
||||
),
|
||||
Transformer2D(
|
||||
in_channels=config.block_out_channels[-1],
|
||||
model_dims=config.block_out_channels[-1],
|
||||
num_heads=config.num_attention_heads[-1],
|
||||
num_layers=config.transformer_layers_per_block[-1],
|
||||
encoder_dims=config.cross_attention_dim[-1],
|
||||
),
|
||||
ResnetBlock2D(
|
||||
in_channels=config.block_out_channels[-1],
|
||||
out_channels=config.block_out_channels[-1],
|
||||
temb_channels=config.block_out_channels[0] * 4,
|
||||
groups=config.norm_num_groups,
|
||||
),
|
||||
]
|
||||
|
||||
# Make the upsampling blocks
|
||||
block_channels = (
|
||||
[config.block_out_channels[0]]
|
||||
+ list(config.block_out_channels)
|
||||
+ [config.block_out_channels[-1]]
|
||||
)
|
||||
|
||||
total_items = len(block_channels) - 3
|
||||
reversed_channels = list(reversed(list(zip(block_channels, block_channels[1:], block_channels[2:]))))
|
||||
|
||||
self.up_blocks = []
|
||||
for rev_i, (in_channels, out_channels, prev_out_channels) in enumerate(reversed_channels):
|
||||
i = total_items - rev_i
|
||||
if rev_i+5 in self.layers_range:
|
||||
self.up_blocks.append(
|
||||
UNetBlock2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=config.block_out_channels[0] * 4,
|
||||
prev_out_channels=prev_out_channels,
|
||||
num_layers=config.layers_per_block[i] + 1,
|
||||
transformer_layers_per_block=config.transformer_layers_per_block[i],
|
||||
num_attention_heads=config.num_attention_heads[i],
|
||||
cross_attention_dim=config.cross_attention_dim[i],
|
||||
resnet_groups=config.norm_num_groups,
|
||||
add_downsample=False,
|
||||
add_upsample=(i > 0),
|
||||
add_cross_attention="CrossAttn" in config.up_block_types[i],
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.up_blocks.append(nn.Identity())
|
||||
|
||||
|
||||
if shard.is_last_layer():
|
||||
self.conv_norm_out = nn.GroupNorm(
|
||||
config.norm_num_groups,
|
||||
config.block_out_channels[0],
|
||||
pytorch_compatible=True,
|
||||
)
|
||||
self.conv_out = nn.Conv2d(
|
||||
config.block_out_channels[0],
|
||||
config.out_channels,
|
||||
config.conv_out_kernel,
|
||||
padding=(config.conv_out_kernel - 1) // 2,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x,
|
||||
timestep,
|
||||
encoder_x,
|
||||
attn_mask=None,
|
||||
encoder_attn_mask=None,
|
||||
text_time=None,
|
||||
residuals=None,
|
||||
):
|
||||
# Compute the time embeddings
|
||||
|
||||
temb = self.timesteps(timestep).astype(x.dtype)
|
||||
temb = self.time_embedding(temb)
|
||||
|
||||
# Add the extra text_time conditioning
|
||||
if text_time is not None:
|
||||
text_emb, time_ids = text_time
|
||||
emb = self.add_time_proj(time_ids).flatten(1).astype(x.dtype)
|
||||
emb = mx.concatenate([text_emb, emb], axis=-1)
|
||||
emb = self.add_embedding(emb)
|
||||
temb = temb + emb
|
||||
|
||||
if self.shard.is_first_layer():
|
||||
# Preprocess the input
|
||||
x = self.conv_in(x)
|
||||
residuals = [x]
|
||||
# Run the downsampling part of the unet
|
||||
|
||||
for i in range(len(self.down_blocks)):
|
||||
if i in self.layers_range:
|
||||
x, res = self.down_blocks[i](
|
||||
x,
|
||||
encoder_x=encoder_x,
|
||||
temb=temb,
|
||||
attn_mask=attn_mask,
|
||||
encoder_attn_mask=encoder_attn_mask,
|
||||
)
|
||||
residuals.extend(res)
|
||||
else:
|
||||
x= self.down_blocks[i](x)
|
||||
|
||||
if 4 in self.layers_range:
|
||||
# Run the middle part of the unet
|
||||
x = self.mid_blocks[0](x, temb)
|
||||
x = self.mid_blocks[1](x, encoder_x, attn_mask, encoder_attn_mask)
|
||||
x = self.mid_blocks[2](x, temb)
|
||||
|
||||
# Run the upsampling part of the unet
|
||||
for i in range(len(self.up_blocks)):
|
||||
if i+5 in self.layers_range:
|
||||
x, _ = self.up_blocks[i](
|
||||
x,
|
||||
encoder_x=encoder_x,
|
||||
temb=temb,
|
||||
attn_mask=attn_mask,
|
||||
encoder_attn_mask=encoder_attn_mask,
|
||||
residual_hidden_states=residuals,
|
||||
)
|
||||
else:
|
||||
x= self.up_blocks[i](x)
|
||||
|
||||
# Postprocess the output
|
||||
if self.shard.is_last_layer():
|
||||
dtype = x.dtype
|
||||
x = self.conv_norm_out(x.astype(mx.float32)).astype(dtype)
|
||||
x = nn.silu(x)
|
||||
x = self.conv_out(x)
|
||||
|
||||
return x, residuals
|
||||
def sanitize(self, weights):
|
||||
sanitized_weights = {}
|
||||
for key, value in weights.items():
|
||||
k1=""
|
||||
k2=""
|
||||
if "downsamplers" in key:
|
||||
key = key.replace("downsamplers.0.conv", "downsample")
|
||||
if "upsamplers" in key:
|
||||
key = key.replace("upsamplers.0.conv", "upsample")
|
||||
|
||||
# Map the mid block
|
||||
if "mid_block.resnets.0" in key:
|
||||
key = key.replace("mid_block.resnets.0", "mid_blocks.0")
|
||||
if "mid_block.attentions.0" in key:
|
||||
key = key.replace("mid_block.attentions.0", "mid_blocks.1")
|
||||
if "mid_block.resnets.1" in key:
|
||||
key = key.replace("mid_block.resnets.1", "mid_blocks.2")
|
||||
|
||||
# Map attention layers
|
||||
if "to_k" in key:
|
||||
key = key.replace("to_k", "key_proj")
|
||||
if "to_out.0" in key:
|
||||
key = key.replace("to_out.0", "out_proj")
|
||||
if "to_q" in key:
|
||||
key = key.replace("to_q", "query_proj")
|
||||
if "to_v" in key:
|
||||
key = key.replace("to_v", "value_proj")
|
||||
|
||||
# Map transformer ffn
|
||||
if "ff.net.2" in key:
|
||||
key = key.replace("ff.net.2", "linear3")
|
||||
if "ff.net.0" in key:
|
||||
k1 = key.replace("ff.net.0.proj", "linear1")
|
||||
k2 = key.replace("ff.net.0.proj", "linear2")
|
||||
v1, v2 = mx.split(value, 2)
|
||||
|
||||
|
||||
if "conv_shortcut.weight" in key:
|
||||
value = value.squeeze()
|
||||
|
||||
# Transform the weights from 1x1 convs to linear
|
||||
if len(value.shape) == 4 and ("proj_in" in key or "proj_out" in key):
|
||||
value = value.squeeze()
|
||||
|
||||
if len(value.shape) == 4:
|
||||
value = value.transpose(0, 2, 3, 1)
|
||||
value = value.reshape(-1).reshape(value.shape)
|
||||
|
||||
if key.startswith("conv_in") :
|
||||
if 0 not in self.layers_range:
|
||||
continue
|
||||
|
||||
if key.startswith("down_blocks"):
|
||||
layer_num = int(key.split(".")[1])
|
||||
if layer_num not in self.layers_range:
|
||||
continue
|
||||
|
||||
if key.startswith("mid_block"):
|
||||
if 4 not in self.layers_range:
|
||||
continue
|
||||
|
||||
if key.startswith("up_blocks"):
|
||||
layer_num = int(key.split(".")[1])
|
||||
if (layer_num+5) not in self.layers_range:
|
||||
continue
|
||||
|
||||
if key.startswith("conv_out") or key.startswith("conv_norm_out"):
|
||||
if 8 not in self.layers_range:
|
||||
continue
|
||||
|
||||
if len(k1)>0:
|
||||
sanitized_weights[k1] = v1
|
||||
sanitized_weights[k2] = v2
|
||||
else:
|
||||
sanitized_weights[key] = value
|
||||
|
||||
|
||||
return sanitized_weights
|
||||
@@ -1,429 +0,0 @@
|
||||
# Adapted from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/vae.py
|
||||
|
||||
import math
|
||||
from typing import List
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .unet import ResnetBlock2D, upsample_nearest
|
||||
from dataclasses import dataclass, field
|
||||
from exo.inference.shard import Shard
|
||||
from typing import Tuple
|
||||
import inspect
|
||||
from ..base import IdentityBlock
|
||||
|
||||
@dataclass
|
||||
class AutoencoderConfig:
|
||||
in_channels: int = 3
|
||||
out_channels: int = 3
|
||||
latent_channels_out: int = 8
|
||||
latent_channels_in: int = 4
|
||||
block_out_channels: Tuple[int] = (128, 256, 512, 512)
|
||||
layers_per_block: int = 2
|
||||
norm_num_groups: int = 32
|
||||
scaling_factor: float = 0.18215
|
||||
weight_files: List[str] = field(default_factory=lambda: [])
|
||||
@classmethod
|
||||
def from_dict(cls, params):
|
||||
return cls(**{k: v for k, v in params.items() if k in inspect.signature(cls).parameters})
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(AutoencoderConfig):
|
||||
shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
|
||||
|
||||
def __post_init__(self):
|
||||
if isinstance(self.shard, dict):
|
||||
self.shard = Shard(**self.shard)
|
||||
|
||||
if not isinstance(self.shard, Shard):
|
||||
raise TypeError(f"Expected shard to be a Shard instance or a dict, got {type(self.shard)} instead")
|
||||
|
||||
if not self.shard.is_first_layer():
|
||||
self.vision_config = None
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
"""A single head unmasked attention for use with the VAE."""
|
||||
|
||||
def __init__(self, dims: int, norm_groups: int = 32):
|
||||
super().__init__()
|
||||
|
||||
self.group_norm = nn.GroupNorm(norm_groups, dims, pytorch_compatible=True)
|
||||
self.query_proj = nn.Linear(dims, dims)
|
||||
self.key_proj = nn.Linear(dims, dims)
|
||||
self.value_proj = nn.Linear(dims, dims)
|
||||
self.out_proj = nn.Linear(dims, dims)
|
||||
|
||||
def __call__(self, x):
|
||||
B, H, W, C = x.shape
|
||||
|
||||
y = self.group_norm(x)
|
||||
|
||||
queries = self.query_proj(y).reshape(B, H * W, C)
|
||||
keys = self.key_proj(y).reshape(B, H * W, C)
|
||||
values = self.value_proj(y).reshape(B, H * W, C)
|
||||
|
||||
scale = 1 / math.sqrt(queries.shape[-1])
|
||||
scores = (queries * scale) @ keys.transpose(0, 2, 1)
|
||||
attn = mx.softmax(scores, axis=-1)
|
||||
y = (attn @ values).reshape(B, H, W, C)
|
||||
|
||||
y = self.out_proj(y)
|
||||
x = x + y
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class EncoderDecoderBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
num_layers: int = 1,
|
||||
resnet_groups: int = 32,
|
||||
add_downsample=True,
|
||||
add_upsample=True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Add the resnet blocks
|
||||
self.resnets = [
|
||||
ResnetBlock2D(
|
||||
in_channels=in_channels if i == 0 else out_channels,
|
||||
out_channels=out_channels,
|
||||
groups=resnet_groups,
|
||||
)
|
||||
for i in range(num_layers)
|
||||
]
|
||||
|
||||
# Add an optional downsampling layer
|
||||
if add_downsample:
|
||||
self.downsample = nn.Conv2d(
|
||||
out_channels, out_channels, kernel_size=3, stride=2, padding=0
|
||||
)
|
||||
|
||||
# or upsampling layer
|
||||
if add_upsample:
|
||||
self.upsample = nn.Conv2d(
|
||||
out_channels, out_channels, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
for resnet in self.resnets:
|
||||
x = resnet(x)
|
||||
if "downsample" in self:
|
||||
x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)])
|
||||
x = self.downsample(x)
|
||||
|
||||
if "upsample" in self:
|
||||
x = self.upsample(upsample_nearest(x))
|
||||
return x
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
"""Implements the encoder side of the Autoencoder."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
latent_channels_out: int,
|
||||
block_out_channels: List[int] = [64],
|
||||
layers_per_block: int = 2,
|
||||
resnet_groups: int = 32,
|
||||
layers_range: List[int] = [],
|
||||
shard: Shard = field(default_factory=lambda: Shard("", 0, 0, 0))
|
||||
):
|
||||
super().__init__()
|
||||
self.layers_range = layers_range
|
||||
self.shard = shard
|
||||
if self.shard.is_first_layer():
|
||||
self.conv_in = nn.Conv2d(
|
||||
in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
|
||||
channels = [block_out_channels[0]] + list(block_out_channels)
|
||||
self.down_blocks = []
|
||||
current_layer = 1
|
||||
for i, (in_channels, out_channels) in enumerate(zip(channels, channels[1:])):
|
||||
if current_layer in self.layers_range:
|
||||
self.down_blocks.append(
|
||||
EncoderDecoderBlock2D(
|
||||
in_channels,
|
||||
out_channels,
|
||||
num_layers=layers_per_block,
|
||||
resnet_groups=resnet_groups,
|
||||
add_downsample=i < len(block_out_channels) - 1,
|
||||
add_upsample=False,
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.down_blocks.append(IdentityBlock())
|
||||
current_layer += 1
|
||||
|
||||
if self.shard.is_last_layer():
|
||||
self.mid_blocks = [
|
||||
ResnetBlock2D(
|
||||
in_channels=block_out_channels[-1],
|
||||
out_channels=block_out_channels[-1],
|
||||
groups=resnet_groups,
|
||||
),
|
||||
Attention(block_out_channels[-1], resnet_groups),
|
||||
ResnetBlock2D(
|
||||
in_channels=block_out_channels[-1],
|
||||
out_channels=block_out_channels[-1],
|
||||
groups=resnet_groups,
|
||||
),
|
||||
]
|
||||
|
||||
self.conv_norm_out = nn.GroupNorm(
|
||||
resnet_groups, block_out_channels[-1], pytorch_compatible=True
|
||||
)
|
||||
self.conv_out = nn.Conv2d(block_out_channels[-1], latent_channels_out, 3, padding=1)
|
||||
|
||||
def __call__(self, x):
|
||||
if self.shard.is_first_layer():
|
||||
x = self.conv_in(x)
|
||||
|
||||
for l in self.down_blocks:
|
||||
x = l(x)
|
||||
|
||||
if self.shard.is_last_layer():
|
||||
x = self.mid_blocks[0](x)
|
||||
x = self.mid_blocks[1](x)
|
||||
x = self.mid_blocks[2](x)
|
||||
|
||||
x = self.conv_norm_out(x)
|
||||
x = nn.silu(x)
|
||||
x = self.conv_out(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
"""Implements the decoder side of the Autoencoder."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
shard: Shard,
|
||||
layer_range: List[int],
|
||||
block_out_channels: List[int] = [64],
|
||||
layers_per_block: int = 2,
|
||||
resnet_groups: int = 32,
|
||||
):
|
||||
super().__init__()
|
||||
self.out_channels = out_channels
|
||||
self.layers_range = layer_range
|
||||
if 0 in layer_range:
|
||||
self.conv_in = nn.Conv2d(
|
||||
in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
|
||||
if 0 in layer_range:
|
||||
self.mid_blocks = [
|
||||
ResnetBlock2D(
|
||||
in_channels=block_out_channels[-1],
|
||||
out_channels=block_out_channels[-1],
|
||||
groups=resnet_groups,
|
||||
),
|
||||
Attention(block_out_channels[-1], resnet_groups),
|
||||
ResnetBlock2D(
|
||||
in_channels=block_out_channels[-1],
|
||||
out_channels=block_out_channels[-1],
|
||||
groups=resnet_groups,
|
||||
),
|
||||
]
|
||||
|
||||
channels = list(reversed(block_out_channels))
|
||||
channels = [channels[0]] + channels
|
||||
|
||||
self.up_blocks = []
|
||||
current_layer = 1
|
||||
|
||||
for i, (in_channels, out_channels) in enumerate(zip(channels, channels[1:])):
|
||||
if current_layer in layer_range:
|
||||
self.up_blocks.append(
|
||||
EncoderDecoderBlock2D(
|
||||
in_channels,
|
||||
out_channels,
|
||||
num_layers=layers_per_block,
|
||||
resnet_groups=resnet_groups,
|
||||
add_downsample=False,
|
||||
add_upsample=i < len(block_out_channels) - 1,
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.up_blocks.append(IdentityBlock())
|
||||
current_layer += 1
|
||||
if 4 in layer_range:
|
||||
self.conv_norm_out = nn.GroupNorm(
|
||||
resnet_groups, block_out_channels[0], pytorch_compatible=True
|
||||
)
|
||||
self.conv_out = nn.Conv2d(block_out_channels[0], self.out_channels, 3, padding=1)
|
||||
|
||||
|
||||
def __call__(self, x):
|
||||
if 0 in self.layers_range:
|
||||
x = self.conv_in(x)
|
||||
x = self.mid_blocks[0](x)
|
||||
x = self.mid_blocks[1](x)
|
||||
x = self.mid_blocks[2](x)
|
||||
|
||||
for l in self.up_blocks:
|
||||
x = l(x)
|
||||
if 4 in self.layers_range:
|
||||
x = self.conv_norm_out(x)
|
||||
x = nn.silu(x)
|
||||
x = self.conv_out(x)
|
||||
return x
|
||||
|
||||
|
||||
class Autoencoder(nn.Module):
|
||||
"""The autoencoder that allows us to perform diffusion in the latent space."""
|
||||
|
||||
def __init__(self, config: AutoencoderConfig, shard: Shard, model_shard: str):
|
||||
super().__init__()
|
||||
self.shard = shard
|
||||
self.start_layer = shard.start_layer
|
||||
self.end_layer = shard.end_layer
|
||||
self.layers_range = list(range(self.start_layer, self.end_layer+1))
|
||||
self.latent_channels = config.latent_channels_in
|
||||
self.scaling_factor = config.scaling_factor
|
||||
self.model_shard = model_shard
|
||||
if self.model_shard == "vae_encoder":
|
||||
self.encoder = Encoder(
|
||||
config.in_channels,
|
||||
config.latent_channels_out,
|
||||
config.block_out_channels,
|
||||
config.layers_per_block,
|
||||
resnet_groups=config.norm_num_groups,
|
||||
layers_range=self.layers_range,
|
||||
shard=shard
|
||||
)
|
||||
if self.shard.is_last_layer():
|
||||
self.quant_proj = nn.Linear(
|
||||
config.latent_channels_out, config.latent_channels_out
|
||||
)
|
||||
if self.model_shard == "vae_decoder":
|
||||
self.decoder = Decoder(
|
||||
config.latent_channels_in,
|
||||
config.out_channels,
|
||||
shard,
|
||||
self.layers_range,
|
||||
config.block_out_channels,
|
||||
config.layers_per_block + 1,
|
||||
resnet_groups=config.norm_num_groups,
|
||||
)
|
||||
if self.shard.is_first_layer():
|
||||
self.post_quant_proj = nn.Linear(
|
||||
config.latent_channels_in, config.latent_channels_in
|
||||
)
|
||||
|
||||
def decode(self, z):
|
||||
if self.shard.is_first_layer():
|
||||
z = z / self.scaling_factor
|
||||
z=self.post_quant_proj(z)
|
||||
return self.decoder(z)
|
||||
|
||||
def encode(self, x):
|
||||
x = self.encoder(x)
|
||||
if self.shard.is_last_layer():
|
||||
x = self.quant_proj(x)
|
||||
mean, logvar = x.split(2, axis=-1)
|
||||
mean = mean * self.scaling_factor
|
||||
logvar = logvar + 2 * math.log(self.scaling_factor)
|
||||
x = mean
|
||||
return x
|
||||
|
||||
def __call__(self, x, key=None):
|
||||
mean, logvar = self.encode(x)
|
||||
z = mx.random.normal(mean.shape, key=key) * mx.exp(0.5 * logvar) + mean
|
||||
x_hat = self.decode(z)
|
||||
|
||||
return dict(x_hat=x_hat, z=z, mean=mean, logvar=logvar)
|
||||
|
||||
def sanitize(self, weights):
|
||||
shard = self.shard
|
||||
layers = self.layers_range
|
||||
sanitized_weights = {}
|
||||
for key, value in weights.items():
|
||||
|
||||
if "downsamplers" in key:
|
||||
key = key.replace("downsamplers.0.conv", "downsample")
|
||||
if "upsamplers" in key:
|
||||
key = key.replace("upsamplers.0.conv", "upsample")
|
||||
|
||||
# Map attention layers
|
||||
if "key" in key:
|
||||
key = key.replace("key", "key_proj")
|
||||
if "proj_attn" in key:
|
||||
key = key.replace("proj_attn", "out_proj")
|
||||
if "query" in key:
|
||||
key = key.replace("query", "query_proj")
|
||||
if "value" in key:
|
||||
key = key.replace("value", "value_proj")
|
||||
|
||||
# Map the mid block
|
||||
if "mid_block.resnets.0" in key:
|
||||
key = key.replace("mid_block.resnets.0", "mid_blocks.0")
|
||||
if "mid_block.attentions.0" in key:
|
||||
key = key.replace("mid_block.attentions.0", "mid_blocks.1")
|
||||
if "mid_block.resnets.1" in key:
|
||||
key = key.replace("mid_block.resnets.1", "mid_blocks.2")
|
||||
|
||||
# Map the quant/post_quant layers
|
||||
if "quant_conv" in key:
|
||||
key = key.replace("quant_conv", "quant_proj")
|
||||
value = value.squeeze()
|
||||
|
||||
# Map the conv_shortcut to linear
|
||||
if "conv_shortcut.weight" in key:
|
||||
value = value.squeeze()
|
||||
|
||||
if len(value.shape) == 4:
|
||||
value = value.transpose(0, 2, 3, 1)
|
||||
value = value.reshape(-1).reshape(value.shape)
|
||||
|
||||
|
||||
if "post_quant_conv" in key :
|
||||
key = key.replace("quant_conv", "quant_proj")
|
||||
value = value.squeeze()
|
||||
|
||||
if 'decoder' in key and self.model_shard == "vae_decoder":
|
||||
if key.startswith("decoder.mid_blocks."):
|
||||
if 0 in layers:
|
||||
sanitized_weights[key] = value
|
||||
if "conv_in" in key and 0 in layers:
|
||||
sanitized_weights[key] = value
|
||||
if key.startswith("decoder.up_blocks."):
|
||||
layer_num = int(key.split(".")[2])+1
|
||||
if layer_num in layers:
|
||||
sanitized_weights[key] = value
|
||||
if key.startswith("decoder.conv_norm_out") and 4 in layers:
|
||||
sanitized_weights[key] = value
|
||||
if key.startswith("decoder.conv_out") and 4 in layers:
|
||||
sanitized_weights[key] = value
|
||||
if self.model_shard == "vae_decoder":
|
||||
if key.startswith("post_quant_proj") and 0 in layers:
|
||||
sanitized_weights[key] = value
|
||||
if self.model_shard == "vae_encoder":
|
||||
if key.startswith("encoder."):
|
||||
if "conv_in" in key and shard.is_first_layer():
|
||||
sanitized_weights[key] = value
|
||||
if key.startswith("encoder.down_blocks."):
|
||||
layer_num = int(key.split(".")[2])+1
|
||||
if layer_num in layers:
|
||||
sanitized_weights[key] = value
|
||||
if key.startswith("encoder.mid_blocks.") and shard.is_last_layer():
|
||||
sanitized_weights[key] = value
|
||||
if "conv_norm_out" in key and shard.is_last_layer():
|
||||
sanitized_weights[key] = value
|
||||
if "conv_out" in key and shard.is_last_layer():
|
||||
sanitized_weights[key] = value
|
||||
if key.startswith("quant_proj") and shard.is_last_layer():
|
||||
sanitized_weights[key] = value
|
||||
return sanitized_weights
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
# Perf improvements
|
||||
|
||||
Target: 460 tok/sec
|
||||
- removing sample goes from 369 -> 402
|
||||
- performance degrades as we generate more tokens
|
||||
- make mlx inference engien synchronous, removing thread pool executor: 402 -> 413
|
||||
- remove self.on_opaque_status.trigger_all: 413 -> 418
|
||||
@@ -1,179 +0,0 @@
|
||||
import numpy as np
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx_lm.sample_utils import make_sampler
|
||||
import mlx.optimizers as optim
|
||||
from ..inference_engine import InferenceEngine
|
||||
from .sharded_utils import load_model_shard, resolve_tokenizer
|
||||
from .losses import loss_fns
|
||||
from ..shard import Shard
|
||||
from typing import Dict, Optional, Tuple
|
||||
from exo.download.shard_download import ShardDownloader
|
||||
import asyncio
|
||||
from collections import OrderedDict
|
||||
from mlx_lm.models.cache import make_prompt_cache
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
class MLXDynamicShardInferenceEngine(InferenceEngine):
|
||||
def __init__(self, shard_downloader: ShardDownloader):
|
||||
self.shard = None
|
||||
self.shard_downloader = shard_downloader
|
||||
self.caches = OrderedDict()
|
||||
self.sampler_params: tuple[float, float] = (0.0, 0.0, 0.0, 1)
|
||||
self.sampler = make_sampler(*self.sampler_params)
|
||||
self._mlx_thread = ThreadPoolExecutor(max_workers=1, thread_name_prefix="mlx")
|
||||
self._tokenizer_thread = ThreadPoolExecutor(max_workers=1, thread_name_prefix="tokenizer")
|
||||
self.session = {}
|
||||
self._shard_lock = asyncio.Lock()
|
||||
|
||||
async def _eval_mlx(self, *args):
|
||||
await asyncio.get_running_loop().run_in_executor(self._mlx_thread, mx.eval, *args)
|
||||
|
||||
async def poll_state(self, request_id: str, max_caches=2):
|
||||
if request_id in self.caches:
|
||||
self.caches.move_to_end(request_id)
|
||||
else:
|
||||
newcache = make_prompt_cache(self.model)
|
||||
if len(self.caches) > max_caches:
|
||||
self.caches.popitem(last=False)
|
||||
self.caches[request_id] = newcache
|
||||
return {"cache": self.caches[request_id]}
|
||||
|
||||
async def sample(self, x: np.ndarray, temp: float = 0.0, top_p: float = 1.0) -> np.ndarray:
|
||||
if (temp, top_p, 0.0, 1) != self.sampler_params:
|
||||
self.sampler_params = (temp, top_p, 0.0, 1)
|
||||
self.sampler = make_sampler(*self.sampler_params)
|
||||
logits = mx.array(x)
|
||||
logits = logits[:, -1, :]
|
||||
logprobs = logits - mx.logsumexp(logits, keepdims=True)
|
||||
result = self.sampler(logprobs)
|
||||
await self._eval_mlx(result)
|
||||
return np.asarray(result, dtype=int)
|
||||
|
||||
async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
|
||||
await self.ensure_shard(shard)
|
||||
return np.asarray(
|
||||
await asyncio.get_running_loop().run_in_executor(
|
||||
self._tokenizer_thread,
|
||||
self.tokenizer.encode,
|
||||
prompt
|
||||
)
|
||||
)
|
||||
|
||||
async def decode(self, shard: Shard, tokens) -> str:
|
||||
await self.ensure_shard(shard)
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
self._tokenizer_thread,
|
||||
self.tokenizer.decode,
|
||||
tokens
|
||||
)
|
||||
|
||||
async def save_checkpoint(self, shard: Shard, path: str):
|
||||
await self.ensure_shard(shard)
|
||||
await asyncio.get_running_loop().run_in_executor(self._mlx_thread, lambda: self.model.save_weights(path))
|
||||
|
||||
async def load_checkpoint(self, shard: Shard, path: str):
|
||||
await self.ensure_shard(shard)
|
||||
await asyncio.get_running_loop().run_in_executor(self._mlx_thread, lambda: self.model.load_weights(path))
|
||||
|
||||
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
|
||||
await self.ensure_shard(shard)
|
||||
state = await self.poll_state(request_id) if self.model.model_type != 'StableDiffusionPipeline' else {}
|
||||
x = mx.array(input_data)
|
||||
|
||||
if self.model.model_type != 'StableDiffusionPipeline':
|
||||
output_data = await asyncio.get_running_loop().run_in_executor(
|
||||
self._mlx_thread,
|
||||
lambda: self.model(x, **state, **(inference_state or {}))
|
||||
)
|
||||
inference_state = None
|
||||
else:
|
||||
result = await asyncio.get_running_loop().run_in_executor(
|
||||
self._mlx_thread,
|
||||
lambda: self.model(x, **state, **(inference_state or {}))
|
||||
)
|
||||
output_data, inference_state = result
|
||||
|
||||
await self._eval_mlx(output_data)
|
||||
output_data = await asyncio.get_running_loop().run_in_executor(
|
||||
self._mlx_thread,
|
||||
lambda: np.array(output_data, copy=False)
|
||||
)
|
||||
return output_data, inference_state
|
||||
|
||||
async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce"):
|
||||
await self.ensure_shard(shard)
|
||||
await self.save_session('loss', loss_fns[loss])
|
||||
x = mx.array(inputs)
|
||||
y = mx.array(targets)
|
||||
l = mx.array(lengths)
|
||||
|
||||
score = await asyncio.get_running_loop().run_in_executor(
|
||||
self._mlx_thread,
|
||||
lambda: self.session['loss'](self.model, x, y, l)
|
||||
)
|
||||
return score
|
||||
|
||||
async def ensure_train(self, shard: Shard, loss: str, opt=optim.SGD, lr=1e-5, trainable_layers=['input_layernorm', 'gate_proj']):
|
||||
await self.ensure_shard(shard)
|
||||
|
||||
if 'train_layers' not in self.session or self.session['train_layers'] != trainable_layers:
|
||||
await self.save_session('train_layers', trainable_layers)
|
||||
def freeze_unfreeze():
|
||||
self.model.freeze()
|
||||
self.model.apply_to_modules(
|
||||
lambda k, v: v.unfreeze() if any(k.endswith(layer_name) for layer_name in trainable_layers) else None
|
||||
)
|
||||
await asyncio.get_running_loop().run_in_executor(self._mlx_thread, freeze_unfreeze)
|
||||
|
||||
if 'lossname' not in self.session or 'LVaG' not in self.session or self.session['lossname'] != loss:
|
||||
await self.save_session('lossname', loss)
|
||||
await self.save_session('LVaG', nn.value_and_grad(self.model, loss_fns[loss]))
|
||||
|
||||
if 'opt' not in self.session:
|
||||
await self.save_session('opt', opt(lr))
|
||||
return True
|
||||
|
||||
async def train(self, request_id: str, shard: Shard, inputs, targets, lengths, loss: str = "length_masked_ce", opt=optim.SGD, lr=1e-5):
|
||||
await self.ensure_train(shard, loss, opt, lr)
|
||||
|
||||
def train_step(inp, tar, lng):
|
||||
lval, grad = self.session['LVaG'](self.model, inp, tar, lng)
|
||||
gradlayers = grad['model']['layers']
|
||||
self.session['opt'].update(self.model, grad)
|
||||
return lval, gradlayers, (self.model.parameters(), self.session['opt'].state, lval)
|
||||
|
||||
x = mx.array(inputs)
|
||||
y = mx.array(targets)
|
||||
l = mx.array(lengths)
|
||||
score, gradients, eval_args = await asyncio.get_running_loop().run_in_executor(
|
||||
self._mlx_thread,
|
||||
lambda: train_step(x, y, l)
|
||||
)
|
||||
await self._eval_mlx(*eval_args)
|
||||
|
||||
layers = [{k: v["weight"] for k, v in layer.items() if 'weight' in v} for layer in gradients if layer]
|
||||
first_layer = np.array(layers[0]['input_layernorm'], copy=False)
|
||||
await self._eval_mlx(first_layer)
|
||||
return score, first_layer
|
||||
|
||||
async def ensure_shard(self, shard: Shard):
|
||||
async with self._shard_lock:
|
||||
if self.shard == shard: return
|
||||
model_path = await self.shard_downloader.ensure_shard(shard, self.__class__.__name__)
|
||||
if self.shard != shard:
|
||||
model_shard = await asyncio.get_running_loop().run_in_executor(
|
||||
self._mlx_thread,
|
||||
lambda: load_model_shard(model_path, shard, lazy=False)
|
||||
)
|
||||
if hasattr(model_shard, "tokenizer"):
|
||||
self.tokenizer = model_shard.tokenizer
|
||||
else:
|
||||
self.tokenizer = await resolve_tokenizer(model_path)
|
||||
self.shard = shard
|
||||
self.model = model_shard
|
||||
self.caches = OrderedDict()
|
||||
self.session = {}
|
||||
|
||||
async def cleanup(self):
|
||||
self._mlx_thread.shutdown(wait=True)
|
||||
@@ -1,257 +0,0 @@
|
||||
# Adapted from https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/utils.py
|
||||
|
||||
import glob
|
||||
import importlib
|
||||
import json
|
||||
import logging
|
||||
import asyncio
|
||||
import aiohttp
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple, Union, List, Callable
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
import base64
|
||||
import traceback
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from transformers import AutoProcessor
|
||||
|
||||
from mlx_lm.tokenizer_utils import load_tokenizer, TokenizerWrapper
|
||||
|
||||
from exo import DEBUG
|
||||
from exo.inference.tokenizers import resolve_tokenizer
|
||||
from ..shard import Shard
|
||||
|
||||
|
||||
class ModelNotFoundError(Exception):
|
||||
def __init__(self, message):
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
MODEL_REMAPPING = {
|
||||
"mistral": "llama", # mistral is compatible with llama
|
||||
"phi-msft": "phixtral",
|
||||
}
|
||||
|
||||
|
||||
def _get_classes(config: dict):
|
||||
"""
|
||||
Retrieve the model and model args classes based on the configuration.
|
||||
|
||||
Args:
|
||||
config (dict): The model configuration.
|
||||
|
||||
Returns:
|
||||
A tuple containing the Model class and the ModelArgs class.
|
||||
"""
|
||||
model_type = config["model_type"]
|
||||
model_type = MODEL_REMAPPING.get(model_type, model_type)
|
||||
try:
|
||||
arch = importlib.import_module(f"exo.inference.mlx.models.{model_type}")
|
||||
except ImportError:
|
||||
msg = f"Model type {model_type} not supported."
|
||||
logging.error(msg)
|
||||
traceback.print_exc()
|
||||
raise ValueError(msg)
|
||||
|
||||
return arch.Model, arch.ModelArgs
|
||||
|
||||
|
||||
def load_config(model_path: Path) -> dict:
|
||||
try:
|
||||
config_path = model_path / "config.json"
|
||||
if config_path.exists():
|
||||
with open(config_path, "r") as f:
|
||||
config = json.load(f)
|
||||
return config
|
||||
|
||||
model_index_path = model_path / "model_index.json"
|
||||
if model_index_path.exists():
|
||||
config = load_model_index(model_path, model_index_path)
|
||||
return config
|
||||
except FileNotFoundError:
|
||||
logging.error(f"Config file not found in {model_path}")
|
||||
raise
|
||||
return config
|
||||
|
||||
def load_model_shard(
|
||||
model_path: Path,
|
||||
shard: Shard,
|
||||
lazy: bool = False,
|
||||
model_config: dict = {},
|
||||
) -> nn.Module:
|
||||
"""
|
||||
Load and initialize the model from a given path.
|
||||
|
||||
Args:
|
||||
model_path (Path): The path to load the model from.
|
||||
lazy (bool): If False eval the model parameters to make sure they are
|
||||
loaded in memory before returning, otherwise they will be loaded
|
||||
when needed. Default: ``False``
|
||||
model_config(dict, optional): Configuration parameters for the model.
|
||||
Defaults to an empty dictionary.
|
||||
|
||||
Returns:
|
||||
nn.Module: The loaded and initialized model.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the weight files (.safetensors) are not found.
|
||||
ValueError: If the model class or args class are not found or cannot be instantiated.
|
||||
"""
|
||||
config = load_config(model_path)
|
||||
config.update(model_config)
|
||||
|
||||
# TODO hack
|
||||
config["shard"] = {
|
||||
"model_id": model_path.name,
|
||||
"start_layer": shard.start_layer,
|
||||
"end_layer": shard.end_layer,
|
||||
"n_layers": shard.n_layers,
|
||||
}
|
||||
|
||||
weight_files = glob.glob(str(model_path/"model*.safetensors"))
|
||||
|
||||
if not weight_files:
|
||||
# Try weight for back-compat
|
||||
weight_files = glob.glob(str(model_path/"weight*.safetensors"))
|
||||
|
||||
model_class, model_args_class = _get_classes(config=config)
|
||||
|
||||
class ShardedModel(model_class):
|
||||
def __init__(self, args):
|
||||
super().__init__(args)
|
||||
self.shard = Shard(args.shard.model_id, args.shard.start_layer, args.shard.end_layer, args.shard.n_layers)
|
||||
|
||||
def __call__(self, x, *args, **kwargs):
|
||||
y = super().__call__(x, *args, **kwargs)
|
||||
return y
|
||||
|
||||
model_args = model_args_class.from_dict(config)
|
||||
model = ShardedModel(model_args)
|
||||
|
||||
if config.get("model_index", False):
|
||||
model.load()
|
||||
return model
|
||||
|
||||
if not weight_files:
|
||||
logging.error(f"No safetensors found in {model_path}")
|
||||
raise FileNotFoundError(f"No safetensors found in {model_path}")
|
||||
|
||||
weights = {}
|
||||
for wf in sorted(weight_files):
|
||||
if DEBUG >= 8:
|
||||
layer_nums = set()
|
||||
for k in mx.load(wf):
|
||||
if k.startswith("model.layers."):
|
||||
layer_num = int(k.split(".")[2])
|
||||
layer_nums.add(layer_num)
|
||||
if k.startswith("language_model.model.layers."):
|
||||
layer_num = int(k.split(".")[3])
|
||||
layer_nums.add(layer_num)
|
||||
print(f"\"{wf.split('/')[-1]}\": {sorted(layer_nums)},")
|
||||
|
||||
weights.update(mx.load(wf))
|
||||
|
||||
|
||||
|
||||
if hasattr(model, "sanitize"):
|
||||
weights = model.sanitize(weights)
|
||||
if DEBUG >= 8:
|
||||
print(f"\n|| {config=} ||\n")
|
||||
|
||||
if (quantization := config.get("quantization", None)) is not None:
|
||||
# Handle legacy models which may not have everything quantized
|
||||
def class_predicate(p, m):
|
||||
if not hasattr(m, "to_quantized"):
|
||||
return False
|
||||
return f"{p}.scales" in weights
|
||||
|
||||
|
||||
nn.quantize(
|
||||
model,
|
||||
**quantization,
|
||||
class_predicate=class_predicate,
|
||||
)
|
||||
|
||||
model.load_weights(list(weights.items()), strict=True)
|
||||
|
||||
if not lazy:
|
||||
mx.eval(model.parameters())
|
||||
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
async def load_shard(
|
||||
model_path: str,
|
||||
shard: Shard,
|
||||
tokenizer_config={},
|
||||
model_config={},
|
||||
adapter_path: Optional[str] = None,
|
||||
lazy: bool = False,
|
||||
) -> Tuple[nn.Module, TokenizerWrapper]:
|
||||
model = load_model_shard(model_path, shard, lazy, model_config)
|
||||
|
||||
# TODO: figure out a generic solution
|
||||
if model.model_type == "llava":
|
||||
processor = AutoProcessor.from_pretrained(model_path)
|
||||
processor.eos_token_id = processor.tokenizer.eos_token_id
|
||||
processor.encode = processor.tokenizer.encode
|
||||
return model, processor
|
||||
elif hasattr(model, "tokenizer"):
|
||||
tokenizer = model.tokenizer
|
||||
return model, tokenizer
|
||||
else:
|
||||
tokenizer = await resolve_tokenizer(model_path)
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
async def get_image_from_str(_image_str: str):
|
||||
image_str = _image_str.strip()
|
||||
|
||||
if image_str.startswith("http"):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(image_str, timeout=10) as response:
|
||||
content = await response.read()
|
||||
return Image.open(BytesIO(content)).convert("RGB")
|
||||
elif image_str.startswith("data:image/"):
|
||||
# Extract the image format and base64 data
|
||||
format_prefix, base64_data = image_str.split(";base64,")
|
||||
image_format = format_prefix.split("/")[1].lower()
|
||||
if DEBUG >= 2: print(f"{image_str=} {image_format=}")
|
||||
imgdata = base64.b64decode(base64_data)
|
||||
img = Image.open(BytesIO(imgdata))
|
||||
|
||||
# Convert to RGB if not already
|
||||
if img.mode != "RGB":
|
||||
img = img.convert("RGB")
|
||||
|
||||
return img
|
||||
else:
|
||||
raise ValueError("Invalid image_str format. Must be a URL or a base64 encoded image.")
|
||||
|
||||
# loading a combined config for all models in the index
|
||||
def load_model_index(model_path: Path, model_index_path: Path):
|
||||
models_config = {}
|
||||
with open(model_index_path, "r") as f:
|
||||
model_index = json.load(f)
|
||||
models_config["model_index"] = True
|
||||
models_config["model_type"] = model_index["_class_name"]
|
||||
models_config["models"] = {}
|
||||
for model in model_index.keys():
|
||||
model_config_path = glob.glob(str(model_path / model / "*config.json"))
|
||||
if len(model_config_path)>0:
|
||||
with open(model_config_path[0], "r") as f:
|
||||
model_config = { }
|
||||
model_config["model_type"] = model
|
||||
model_config["config"] = json.load(f)
|
||||
model_config["path"] = model_path / model
|
||||
if model_config["path"]/"*model.safetensors":
|
||||
model_config["config"].update({"weight_files": list(glob.glob(str(model_config["path"]/"*model.safetensors")))})
|
||||
model_config["path"] = str(model_path / model)
|
||||
m = {}
|
||||
m[model] = model_config
|
||||
models_config.update(m)
|
||||
return models_config
|
||||
@@ -1,81 +0,0 @@
|
||||
import asyncio
|
||||
import time
|
||||
import numpy as np
|
||||
from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
|
||||
from exo.download.new_shard_download import NewShardDownloader
|
||||
from exo.inference.shard import Shard
|
||||
from exo.models import build_base_shard
|
||||
from collections import deque
|
||||
from statistics import mean, median
|
||||
|
||||
async def test_non_blocking():
|
||||
# Setup
|
||||
shard_downloader = NewShardDownloader()
|
||||
engine = MLXDynamicShardInferenceEngine(shard_downloader)
|
||||
_shard = build_base_shard("llama-3.1-8b", "MLXDynamicShardInferenceEngine")
|
||||
shard = Shard(_shard.model_id, _shard.start_layer, _shard.n_layers - 1, _shard.n_layers)
|
||||
await engine.ensure_shard(shard)
|
||||
|
||||
queue = asyncio.Queue()
|
||||
measurements = deque(maxlen=1000000)
|
||||
running = True
|
||||
|
||||
async def mlx_worker():
|
||||
try:
|
||||
start_time = time.time()
|
||||
count = 0
|
||||
while running and (time.time() - start_time) < 5: # Hard time limit
|
||||
start = time.perf_counter_ns()
|
||||
await engine.infer_prompt("req1", shard, "test prompt")
|
||||
duration = (time.perf_counter_ns() - start) / 1_000_000 # Convert to ms
|
||||
count += 1
|
||||
print(f"MLX operation {count} took: {duration:.3f}ms")
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
finally:
|
||||
print(f"\nTotal MLX operations completed: {count}")
|
||||
print(f"Average rate: {count/5:.1f} ops/second")
|
||||
|
||||
async def latency_producer():
|
||||
try:
|
||||
start_time = time.perf_counter_ns()
|
||||
count = 0
|
||||
while running:
|
||||
await queue.put(time.perf_counter_ns())
|
||||
count += 1
|
||||
await asyncio.sleep(0) # Yield to event loop without delay
|
||||
duration = (time.perf_counter_ns() - start_time) / 1e9 # Convert to seconds
|
||||
print(f"\nProducer iterations: {count}")
|
||||
print(f"Producer rate: {count/duration:.1f} iterations/second")
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
async def latency_consumer():
|
||||
try:
|
||||
while running:
|
||||
timestamp = await queue.get()
|
||||
latency = (time.perf_counter_ns() - timestamp) / 1_000_000 # Convert to ms
|
||||
measurements.append(latency)
|
||||
queue.task_done()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
tasks = [
|
||||
asyncio.create_task(mlx_worker()),
|
||||
asyncio.create_task(latency_producer()),
|
||||
asyncio.create_task(latency_consumer())
|
||||
]
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(asyncio.gather(*tasks), timeout=6)
|
||||
except asyncio.TimeoutError:
|
||||
print("\nTest timed out")
|
||||
finally:
|
||||
running = False
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
print(f"\nFinal measurement count: {len(measurements)}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(test_non_blocking())
|
||||
@@ -1,52 +0,0 @@
|
||||
from exo.inference.shard import Shard
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from typing import Optional
|
||||
import numpy as np
|
||||
|
||||
|
||||
class DummyModel(nn.Module):
|
||||
def __init__(self, shard: Optional[Shard] = None):
|
||||
self.shard = shard
|
||||
self.layers = [
|
||||
nn.Linear(8, 128),
|
||||
nn.Linear(128, 128),
|
||||
nn.Linear(128, 128),
|
||||
nn.Linear(128, 128),
|
||||
nn.Linear(128, 8),
|
||||
]
|
||||
|
||||
self.n_kv_heads = 4
|
||||
self.head_dim = 4
|
||||
|
||||
def __call__(self, x, cache=None):
|
||||
if self.shard:
|
||||
for layer in self.layers[self.shard.start_layer:self.shard.end_layer + 1]:
|
||||
x = layer(x)
|
||||
if self.shard.is_last_layer():
|
||||
x = x.reshape((1, 2, 4))
|
||||
else:
|
||||
for layer in self.layers:
|
||||
x = layer(x)
|
||||
x = x.reshape((1, 2, 4))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
model = DummyModel()
|
||||
model.save_weights("./test_weights.npz")
|
||||
n_layers = 5
|
||||
shard1 = Shard("test", 0, n_layers // 2, n_layers)
|
||||
sharded_model1 = DummyModel(shard1)
|
||||
shard2 = Shard("test", n_layers//2 + 1, n_layers - 1, n_layers)
|
||||
sharded_model2 = DummyModel(shard2)
|
||||
|
||||
model.load_weights("./test_weights.npz")
|
||||
sharded_model1.load_weights("./test_weights.npz")
|
||||
sharded_model2.load_weights("./test_weights.npz")
|
||||
|
||||
fullresp = model(mx.array([1, 2, 3, 4, 5, 6, 7, 8]))
|
||||
resp1 = sharded_model1(mx.array([1, 2, 3, 4, 5, 6, 7, 8]))
|
||||
resp2 = sharded_model2(resp1)
|
||||
|
||||
assert np.all(np.array(fullresp) == np.array(resp2))
|
||||
@@ -1,39 +0,0 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Shard:
|
||||
model_id: str
|
||||
start_layer: int
|
||||
end_layer: int
|
||||
n_layers: int
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.model_id, self.start_layer, self.end_layer, self.n_layers))
|
||||
|
||||
def is_first_layer(self) -> bool:
|
||||
return self.start_layer == 0
|
||||
|
||||
def is_last_layer(self) -> bool:
|
||||
return self.end_layer == self.n_layers - 1
|
||||
|
||||
def get_layer_count(self) -> int:
|
||||
return self.end_layer - self.start_layer + 1
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"model_id": self.model_id,
|
||||
"start_layer": self.start_layer,
|
||||
"end_layer": self.end_layer,
|
||||
"n_layers": self.n_layers,
|
||||
}
|
||||
|
||||
def from_dict(data: dict) -> 'Shard':
|
||||
return Shard(**data)
|
||||
|
||||
def overlaps(self, other: 'Shard') -> bool:
|
||||
return shards_overlap(self, other)
|
||||
|
||||
|
||||
def shards_overlap(shard1: Shard, shard2: Shard) -> bool:
|
||||
return (shard1.model_id == shard2.model_id and max(shard1.start_layer, shard2.start_layer) <= min(shard1.end_layer, shard2.end_layer))
|
||||
@@ -1,47 +0,0 @@
|
||||
import pytest
|
||||
import numpy as np
|
||||
from exo.inference.dummy_inference_engine import DummyInferenceEngine
|
||||
from exo.inference.shard import Shard
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dummy_inference_specific():
|
||||
engine = DummyInferenceEngine()
|
||||
test_shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)
|
||||
test_prompt = "This is a test prompt"
|
||||
|
||||
result, _ = await engine.infer_prompt("test_request", test_shard, test_prompt)
|
||||
|
||||
print(f"Inference result shape: {result.shape}")
|
||||
|
||||
assert result.shape[0] == 1, "Result should be a 2D array with first dimension 1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dummy_inference_engine():
|
||||
# Initialize the DummyInferenceEngine
|
||||
engine = DummyInferenceEngine()
|
||||
|
||||
# Create a test shard
|
||||
shard = Shard(model_id="test_model", start_layer=0, end_layer=1, n_layers=1)
|
||||
|
||||
# Test infer_prompt
|
||||
output, _ = await engine.infer_prompt("test_id", shard, "Test prompt")
|
||||
|
||||
assert isinstance(output, np.ndarray), "Output should be a numpy array"
|
||||
assert output.ndim == 2, "Output should be 2-dimensional"
|
||||
|
||||
# Test infer_tensor
|
||||
input_tensor = np.array([[1, 2, 3]])
|
||||
output, _ = await engine.infer_tensor("test_id", shard, input_tensor)
|
||||
|
||||
assert isinstance(output, np.ndarray), "Output should be a numpy array"
|
||||
assert output.ndim == 2, "Output should be 2-dimensional"
|
||||
|
||||
print("All tests passed!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
asyncio.run(test_dummy_inference_engine())
|
||||
asyncio.run(test_dummy_inference_specific())
|
||||
@@ -1,54 +0,0 @@
|
||||
from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
|
||||
from exo.inference.inference_engine import InferenceEngine
|
||||
from exo.download.new_shard_download import NewShardDownloader
|
||||
from exo.inference.shard import Shard
|
||||
from exo.helpers import DEBUG
|
||||
import os
|
||||
import asyncio
|
||||
import numpy as np
|
||||
|
||||
|
||||
# An inference engine should work the same for any number of Shards, as long as the Shards are continuous.
|
||||
async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str, n_layers: int):
|
||||
prompt = "In a single word only, what is the last name of the current president of the USA?"
|
||||
resp_full, _ = await inference_engine_1.infer_prompt("A", shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers), prompt=prompt)
|
||||
token_full = await inference_engine_1.sample(resp_full)
|
||||
token_full = token_full.reshape(1, -1)
|
||||
next_resp_full, _ = await inference_engine_1.infer_tensor(
|
||||
"A",
|
||||
shard=Shard(model_id=model_id, start_layer=0, end_layer=n_layers - 1, n_layers=n_layers),
|
||||
input_data=token_full,
|
||||
)
|
||||
|
||||
pp = n_layers // 2
|
||||
resp1, _ = await inference_engine_1.infer_prompt("B", shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers), prompt=prompt)
|
||||
resp2, _ = await inference_engine_2.infer_tensor(
|
||||
"B",
|
||||
shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=n_layers - 1, n_layers=n_layers),
|
||||
input_data=resp1,
|
||||
)
|
||||
tokens2 = await inference_engine_1.sample(resp2)
|
||||
tokens2 = tokens2.reshape(1, -1)
|
||||
resp3, _ = await inference_engine_1.infer_tensor(
|
||||
"B",
|
||||
shard=Shard(model_id=model_id, start_layer=0, end_layer=pp, n_layers=n_layers),
|
||||
input_data=tokens2,
|
||||
)
|
||||
resp4, _ = await inference_engine_2.infer_tensor(
|
||||
"B",
|
||||
shard=Shard(model_id=model_id, start_layer=pp + 1, end_layer=n_layers - 1, n_layers=n_layers),
|
||||
input_data=resp3,
|
||||
)
|
||||
|
||||
assert np.array_equal(resp_full, resp2)
|
||||
assert np.array_equal(next_resp_full, resp4)
|
||||
|
||||
|
||||
asyncio.run(test_inference_engine(MLXDynamicShardInferenceEngine(NewShardDownloader()), MLXDynamicShardInferenceEngine(NewShardDownloader()), "llama-3.2-1b", 16))
|
||||
|
||||
if os.getenv("RUN_TINYGRAD", default="0") == "1":
|
||||
import tinygrad
|
||||
import os
|
||||
from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
|
||||
tinygrad.helpers.DEBUG.value = int(os.getenv("TINYGRAD_DEBUG", default="0"))
|
||||
asyncio.run(test_inference_engine(TinygradDynamicShardInferenceEngine(NewShardDownloader()), TinygradDynamicShardInferenceEngine(NewShardDownloader()), "llama-3.2-1b", 32))
|
||||
@@ -1,157 +0,0 @@
|
||||
from pathlib import Path
|
||||
import json
|
||||
import os
|
||||
from exo.inference.tinygrad.models.llama import Transformer, TransformerShard, convert_from_huggingface, fix_bf16, sample_logits
|
||||
from exo.inference.shard import Shard
|
||||
from exo.inference.tokenizers import resolve_tokenizer
|
||||
from tinygrad.nn.state import safe_save, safe_load, get_state_dict, load_state_dict
|
||||
from tinygrad import Tensor, nn, Context, TinyJit
|
||||
from exo.inference.inference_engine import InferenceEngine
|
||||
import numpy as np
|
||||
from exo.inference.tinygrad.tinygrad_helpers import concat_weights, load
|
||||
from exo.download.shard_download import ShardDownloader
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from .stateful_model import make_prompt_state
|
||||
from .losses import length_masked_ce_loss
|
||||
from collections import OrderedDict
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
Tensor.no_grad = True
|
||||
# default settings
|
||||
TEMPERATURE = int(os.getenv("TEMPERATURE", 0.85))
|
||||
TOP_K = 25
|
||||
TOP_P = 0.9
|
||||
ALPHA_F = 0.1
|
||||
ALPHA_P = 0.0
|
||||
MODEL_PARAMS = {
|
||||
"1B": {
|
||||
"args": {
|
||||
"dim": 2048, "n_heads": 32, "n_kv_heads": 8, "n_layers": 16, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 8192,
|
||||
"rope_scaling": {"factor": 32.0, "high_freq_factor": 4.0, "low_freq_factor": 1.0, "original_max_position_embeddings": 8192, "rope_type": "llama3"}, "tie_word_embeddings": True
|
||||
}, "files": 1
|
||||
}, "3B": {
|
||||
"args": {
|
||||
"dim": 3072, "n_heads": 24, "n_kv_heads": 8, "n_layers": 28, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 8192,
|
||||
"rope_scaling": {"factor": 32.0, "high_freq_factor": 4.0, "low_freq_factor": 1.0, "original_max_position_embeddings": 8192, "rope_type": "llama3"}, "tie_word_embeddings": True
|
||||
}, "files": 1
|
||||
}, "8B": {"args": {"dim": 4096, "n_heads": 32, "n_kv_heads": 8, "n_layers": 32, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 14336}, "files": 1},
|
||||
"70B": {"args": {"dim": 8192, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 128256, "hidden_dim": 28672}, "files": 8}
|
||||
}
|
||||
|
||||
|
||||
def build_transformer(model_path: Path, shard: Shard, model_size="8B", device=None):
|
||||
# build model
|
||||
linear = nn.Linear
|
||||
model = Transformer(**MODEL_PARAMS[model_size]["args"], linear=linear, max_context=8192, jit=True, shard=shard)
|
||||
|
||||
# load weights
|
||||
if model_path.is_dir():
|
||||
if (model_path/"model.safetensors.index.json").exists(): weights = load(str(model_path/"model.safetensors.index.json"), shard)
|
||||
elif (model_path/"model.safetensors").exists(): weights = load(str(model_path/"model.safetensors"), shard)
|
||||
else: weights = concat_weights([load(str(model_path/f"consolidated.{i:02d}.pth"), shard) for i in range(MODEL_PARAMS[model_size]["files"])], device[0] if isinstance(device, tuple) else device)
|
||||
else:
|
||||
weights = load(str(model_path), shard)
|
||||
weights = convert_from_huggingface(weights, model, MODEL_PARAMS[model_size]["args"]["n_heads"], MODEL_PARAMS[model_size]["args"]["n_kv_heads"])
|
||||
weights = fix_bf16(weights)
|
||||
|
||||
with Context(BEAM=0):
|
||||
# replace weights in model
|
||||
load_state_dict(model, weights, strict=False, consume=False) # consume=True
|
||||
model = TransformerShard(shard, model)
|
||||
|
||||
return model
|
||||
|
||||
_executor = ThreadPoolExecutor(max_workers=1) # singleton so tinygrad always runs on the same thread
|
||||
class TinygradDynamicShardInferenceEngine(InferenceEngine):
|
||||
def __init__(self, shard_downloader: ShardDownloader):
|
||||
self.shard = None
|
||||
self.shard_downloader = shard_downloader
|
||||
self.states = OrderedDict()
|
||||
self.executor = _executor
|
||||
|
||||
def poll_state(self, x, request_id: str, max_states=2):
|
||||
if request_id not in self.states:
|
||||
if len(self.states) >= max_states:
|
||||
self.states.popitem(last=False)
|
||||
self.states[request_id] = make_prompt_state(x, self.model)
|
||||
else:
|
||||
self.states.move_to_end(request_id)
|
||||
state = self.states[request_id]
|
||||
return {"start_pos": state.start, "cache": state.cache}
|
||||
|
||||
async def sample(self, x: np.ndarray, temp=TEMPERATURE, top_p: float = 0.0) -> np.ndarray:
|
||||
def sample_wrapper():
|
||||
logits = x[:, -1, :]
|
||||
return sample_logits(Tensor(logits).flatten(), temp, 0, 0.8, top_p, 0.0).realize().numpy().astype(int)
|
||||
return await asyncio.get_running_loop().run_in_executor(self.executor, sample_wrapper)
|
||||
|
||||
async def encode(self, shard: Shard, prompt: str) -> np.ndarray:
|
||||
await self.ensure_shard(shard)
|
||||
tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.encode, prompt)
|
||||
return await asyncio.get_running_loop().run_in_executor(self.executor, np.array, tokens)
|
||||
|
||||
async def decode(self, shard: Shard, tokens) -> str:
|
||||
await self.ensure_shard(shard)
|
||||
tokens = await asyncio.get_running_loop().run_in_executor(self.executor, self.tokenizer.decode, tokens)
|
||||
return tokens
|
||||
|
||||
async def load_checkpoint(self, shard: Shard, path: str):
|
||||
await self.ensure_shard(shard)
|
||||
state_dict = safe_load(path)
|
||||
await asyncio.get_running_loop().run_in_executor(self.executor, load_state_dict, self.model, state_dict)
|
||||
|
||||
async def save_checkpoint(self, shard: Shard, path: str):
|
||||
await self.ensure_shard(shard)
|
||||
state_dict = await asyncio.get_running_loop().run_in_executor(self.executor, get_state_dict, self.model)
|
||||
safe_save(state_dict, path)
|
||||
|
||||
async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarray, inference_state: Optional[dict] = None) -> tuple[np.ndarray, Optional[dict]]:
|
||||
await self.ensure_shard(shard)
|
||||
def wrap_infer():
|
||||
x = Tensor(input_data)
|
||||
h = self.model.embed(x)
|
||||
state = self.poll_state(h, request_id)
|
||||
out = self.model.forward(h, **state)
|
||||
self.states[request_id].start += x.shape[1]
|
||||
return out.numpy()
|
||||
output_data = await asyncio.get_running_loop().run_in_executor(self.executor, wrap_infer)
|
||||
return output_data, inference_state
|
||||
|
||||
async def evaluate(self, request_id: str, shard: Shard, inputs, targets, lengths, loss=length_masked_ce_loss):
|
||||
def step(x, y, l):
|
||||
Tensor.training = False
|
||||
return self.session['loss'](self.model, x, y, l)
|
||||
await self.ensure_shard(shard)
|
||||
score = await asyncio.get_running_loop().run_in_executor(self.executor, lambda: self.session['jit'](Tensor(inputs), targets, lengths))
|
||||
out = score.numpy()
|
||||
return out
|
||||
|
||||
async def train(self, request_id: str, shard: Shard, inputs, targets, lengths, loss=length_masked_ce_loss, opt=nn.optim.Adam, lr=1e-5):
|
||||
def step(x, y, l):
|
||||
Tensor.training = True
|
||||
score = self.session['loss'](self.model, x, y, l)
|
||||
self.session['opt'].zero_grad()
|
||||
score.backward()
|
||||
self.session['opt'].step()
|
||||
return score
|
||||
await self.ensure_shard(shard)
|
||||
|
||||
score = await asyncio.get_running_loop().run_in_executor(self.executor, lambda: self.session['jit'](Tensor(inputs), targets, lengths).realize())
|
||||
|
||||
return loss.numpy(), loss.numpy()
|
||||
|
||||
async def ensure_shard(self, shard: Shard):
|
||||
if self.shard == shard:
|
||||
return
|
||||
|
||||
model_path = await self.shard_downloader.ensure_shard(shard, self.__class__.__name__)
|
||||
|
||||
if self.shard != shard:
|
||||
loop = asyncio.get_running_loop()
|
||||
parameters = "1B" if "1b" in shard.model_id.lower() else "3B" if "3b" in shard.model_id.lower() else "8B" if "8b" in shard.model_id.lower() else "70B"
|
||||
model_shard = await loop.run_in_executor(self.executor, build_transformer, model_path, shard, parameters)
|
||||
|
||||
tokenizer_path = str((model_path if model_path.is_dir() else model_path.parent))
|
||||
self.tokenizer = await resolve_tokenizer(tokenizer_path)
|
||||
self.shard = shard
|
||||
self.model = model_shard
|
||||
@@ -1,14 +0,0 @@
|
||||
from tinygrad import Tensor, dtypes
|
||||
import numpy as np
|
||||
def length_masked_ce_loss(model, inputs, targets, lengths):
|
||||
# Run model on inputs
|
||||
logits = model(inputs).cast(dtypes.float32).contiguous()
|
||||
|
||||
# Mask padding tokens
|
||||
length_mask = Tensor(np.arange(inputs.shape[1])[None, :] < lengths[:, None], requires_grad=False)
|
||||
|
||||
# Calculate the loss
|
||||
ce = logits.sparse_categorical_crossentropy(Tensor(targets, requires_grad=False)).mul(length_mask)
|
||||
loss = ce.sum() / length_mask.sum()
|
||||
return loss
|
||||
|
||||
@@ -1,327 +0,0 @@
|
||||
from typing import Tuple, Union, Optional, Dict, Any, List
|
||||
from tinygrad import Tensor, Variable, TinyJit, dtypes, nn, Device
|
||||
from tinygrad.helpers import getenv
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
# https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
|
||||
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, dtype=dtypes.half, rope_scaling: Optional[Dict[str, float]] = None) -> Tensor:
|
||||
freqs = 1.0/(theta**(Tensor.arange(0, dim, 2)[:(dim // 2)]/dim))
|
||||
|
||||
if rope_scaling:
|
||||
factor = rope_scaling.get('factor', 1.0)
|
||||
low_freq_factor = rope_scaling.get('low_freq_factor', 1.0)
|
||||
high_freq_factor = rope_scaling.get('high_freq_factor', 1.0)
|
||||
original_max_pos_emb = rope_scaling.get('original_max_position_embeddings', end)
|
||||
|
||||
freqs[:dim // 4] *= low_freq_factor
|
||||
freqs[dim // 4:] = freqs[dim // 4:].contiguous()*high_freq_factor
|
||||
freqs *= (original_max_pos_emb/end)**(1.0/factor)
|
||||
|
||||
freqs = Tensor.arange(end).unsqueeze(dim=1)*freqs.unsqueeze(dim=0)
|
||||
# TODO: move dtype outside this
|
||||
return Tensor.stack(freqs.cos().cast(dtype), freqs.sin().cast(dtype), dim=-1).reshape(1, end, 1, dim // 2, 2)
|
||||
|
||||
|
||||
# (a+i*b) * (c+i*d) = (ac-bd) + i*(ad+bc)
|
||||
def complex_mult(A, c, d):
|
||||
a, b = A[..., 0:1], A[..., 1:2]
|
||||
ro = a*c - b*d
|
||||
co = a*d + b*c
|
||||
return ro.cat(co, dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_emb(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> Tuple[Tensor, Tensor]:
|
||||
assert freqs_cis.shape[1] == xq.shape[1] == xk.shape[1], f"freqs_cis shape mismatch {freqs_cis.shape} xq:{xq.shape} xk:{xk.shape}"
|
||||
xq = xq.reshape(*xq.shape[0:-1], -1, 2)
|
||||
xk = xk.reshape(*xk.shape[0:-1], -1, 2)
|
||||
assert len(xq.shape) == len(xk.shape) == len(freqs_cis.shape) == 5
|
||||
c, d = freqs_cis[..., 0:1], freqs_cis[..., 1:2]
|
||||
xq_out = complex_mult(xq, c, d)
|
||||
xk_out = complex_mult(xk, c, d)
|
||||
return xq_out.flatten(3), xk_out.flatten(3)
|
||||
|
||||
|
||||
def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
|
||||
bs, seqlen, n_kv_heads, head_dim = x.shape
|
||||
if n_rep == 1: return x
|
||||
# NOTE: this is different from x.repeat((1, 1, n_rep, 1))
|
||||
return x.repeat((1, 1, 1, n_rep)).reshape(bs, seqlen, n_kv_heads*n_rep, head_dim)
|
||||
|
||||
class Attention:
|
||||
def __init__(self, dim, n_heads, n_kv_heads, max_context, linear=nn.Linear):
|
||||
self.n_heads = n_heads
|
||||
self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads # n_kv_heads != n_heads implies MQA [arxiv/2307.09288, A.2.1]
|
||||
self.head_dim = dim // n_heads
|
||||
self.n_rep = self.n_heads // self.n_kv_heads
|
||||
self.max_context = max_context
|
||||
|
||||
self.wq = linear(dim, self.n_heads*self.head_dim, bias=False)
|
||||
self.wk = linear(dim, self.n_kv_heads*self.head_dim, bias=False)
|
||||
self.wv = linear(dim, self.n_kv_heads*self.head_dim, bias=False)
|
||||
self.wo = linear(self.n_heads*self.head_dim, dim, bias=False)
|
||||
|
||||
def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor], cache: Optional[Tensor]=None) -> Tensor:
|
||||
if getenv("WQKV"):
|
||||
if not hasattr(self, 'wqkv'): self.wqkv = Tensor.cat(self.wq.weight, self.wk.weight, self.wv.weight)
|
||||
xqkv = x @ self.wqkv.T
|
||||
xq, xk, xv = xqkv.split([self.wq.weight.shape[0], self.wk.weight.shape[0], self.wv.weight.shape[0]], dim=2)
|
||||
else:
|
||||
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
|
||||
|
||||
xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim)
|
||||
xk = xk.reshape(xk.shape[0], xk.shape[1], self.n_kv_heads, self.head_dim)
|
||||
xv = xv.reshape(xv.shape[0], xv.shape[1], self.n_kv_heads, self.head_dim)
|
||||
|
||||
xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
|
||||
bsz, seqlen, _, _ = xq.shape
|
||||
|
||||
if cache is not None:
|
||||
# update the cache
|
||||
assert xk.dtype == xv.dtype == cache.dtype, f"{xk.dtype=}, {xv.dtype=}, {cache.dtype=}"
|
||||
cache.shrink((None, None, (start_pos, start_pos + seqlen), None, None)).assign(Tensor.stack(xk, xv)).realize()
|
||||
|
||||
keys = cache[0].shrink((None, (0, start_pos + seqlen), None, None)) if start_pos > 0 else xk
|
||||
values = cache[1].shrink((None, (0, start_pos + seqlen), None, None)) if start_pos > 0 else xv
|
||||
else:
|
||||
keys = xk
|
||||
values = xv
|
||||
|
||||
keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
|
||||
xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
|
||||
attn = xq.scaled_dot_product_attention(keys, values, mask).transpose(1, 2)
|
||||
attn = attn.reshape(bsz, seqlen, -1)
|
||||
return self.wo(attn)
|
||||
|
||||
|
||||
class FeedForward:
|
||||
def __init__(self, dim: int, hidden_dim: int, linear=nn.Linear):
|
||||
self.w1 = linear(dim, hidden_dim, bias=False)
|
||||
self.w2 = linear(hidden_dim, dim, bias=False)
|
||||
self.w3 = linear(dim, hidden_dim, bias=False) # the gate in Gated Linear Unit
|
||||
|
||||
def __call__(self, x: Tensor) -> Tensor:
|
||||
return self.w2(self.w1(x).silu()*self.w3(x)) # SwiGLU [arxiv/2002.05202, eq (5)]
|
||||
|
||||
|
||||
class TransformerBlock:
|
||||
def __init__(self, dim: int, hidden_dim: int, n_heads: int, n_kv_heads: int, norm_eps: float, max_context: int, linear=nn.Linear, feed_forward=FeedForward):
|
||||
self.attention = Attention(dim, n_heads, n_kv_heads, max_context, linear)
|
||||
self.feed_forward = feed_forward(dim, hidden_dim, linear)
|
||||
self.attention_norm = nn.RMSNorm(dim, norm_eps)
|
||||
self.ffn_norm = nn.RMSNorm(dim, norm_eps)
|
||||
|
||||
def __call__(self, x: Tensor, start_pos: Union[Variable, int], freqs_cis: Tensor, mask: Optional[Tensor], cache: Optional[Tensor]=None):
|
||||
h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask, cache=cache)
|
||||
return (h + self.feed_forward(self.ffn_norm(h))).contiguous()
|
||||
|
||||
|
||||
# standard openai sampling
|
||||
def sample_logits(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
|
||||
assert logits.ndim == 1, "only works on 1d tensors"
|
||||
assert 0 <= p <= 1, "p must be between 0 and 1"
|
||||
assert 0 <= k <= logits.numel(), "k must be between 0 and numel"
|
||||
|
||||
# if temperature is very low just use argmax
|
||||
if temp < 1e-6: return logits.argmax().reshape(1)
|
||||
|
||||
# alpha sampling
|
||||
if af or ap:
|
||||
if not hasattr(sample, "alpha_counter"):
|
||||
setattr(sample, "alpha_counter", Tensor.zeros_like(logits, dtype=dtypes.int32).contiguous())
|
||||
logits = logits - (sample.alpha_counter*af + (sample.alpha_counter > 0)*ap)
|
||||
|
||||
# replace NaNs with -inf
|
||||
logits = (logits != logits).where(-float("inf"), logits)
|
||||
|
||||
# softmax
|
||||
t = (logits/temp).softmax()
|
||||
|
||||
counter, counter2 = Tensor.arange(t.numel(), device=logits.device).contiguous(), Tensor.arange(t.numel() - 1, -1, -1, device=logits.device).contiguous()
|
||||
# top k
|
||||
if k:
|
||||
output, output_indices = Tensor.zeros(k, device=logits.device).contiguous(), Tensor.zeros(k, device=logits.device, dtype=dtypes.int32).contiguous()
|
||||
for i in range(k):
|
||||
t_argmax = (t.numel() - ((t == (t_max := t.max()))*counter2).max() - 1).cast(dtypes.default_int)
|
||||
output = output + t_max.unsqueeze(0).pad(((i, k - i - 1),))
|
||||
output_indices = output_indices + t_argmax.unsqueeze(0).pad(((i, k - i - 1),))
|
||||
t = (counter == t_argmax).where(0, t)
|
||||
|
||||
# approximate top p
|
||||
# because we are already limited to top k elements we can do top p "without sorting"
|
||||
output_cumsum = output[::-1]._cumsum()[::-1] + t.sum()
|
||||
output = (output_cumsum >= (1 - p))*output
|
||||
output_indices = (output_cumsum >= (1 - p))*output_indices
|
||||
|
||||
# sample
|
||||
output_idx = output.multinomial()
|
||||
output_token = output_indices[output_idx]
|
||||
else:
|
||||
output_token = t.multinomial()
|
||||
|
||||
# increase alpha counter
|
||||
if af or ap:
|
||||
sample.alpha_counter = (counter == output_token).where(sample.alpha_counter + 1, sample.alpha_counter)
|
||||
|
||||
return output_token
|
||||
|
||||
|
||||
from exo.inference.shard import Shard
|
||||
|
||||
|
||||
class Transformer:
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
n_heads: int,
|
||||
n_layers: int,
|
||||
norm_eps: float,
|
||||
vocab_size,
|
||||
shard: Shard = None,
|
||||
linear=nn.Linear,
|
||||
n_kv_heads=None,
|
||||
rope_theta=10000,
|
||||
max_context=1024,
|
||||
jit=True,
|
||||
feed_forward=FeedForward,
|
||||
rope_scaling: Optional[Dict[str, float]] = None,
|
||||
tie_word_embeddings=False,
|
||||
):
|
||||
self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, max_context, linear, feed_forward=feed_forward) for _ in range(n_layers)]
|
||||
self.norm = nn.RMSNorm(dim, norm_eps)
|
||||
self.tok_embeddings = nn.Embedding(vocab_size, dim)
|
||||
self.output = nn.Linear(dim, vocab_size, bias=False)
|
||||
if tie_word_embeddings:
|
||||
self.output.weight = self.tok_embeddings.weight
|
||||
self.max_context = max_context
|
||||
self.freqs_cis = precompute_freqs_cis(dim // n_heads, self.max_context*2, rope_theta, rope_scaling=rope_scaling).contiguous()
|
||||
self.forward_jit = TinyJit(self.forward_base) if jit else None
|
||||
self.shard = shard
|
||||
|
||||
def forward_base(self, x: Tensor, start_pos: Union[Variable, int], cache: Optional[List[Tensor]] = None):
|
||||
seqlen = x.shape[1]
|
||||
freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos + seqlen), None, None, None))
|
||||
mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-100000000"), dtype=x.dtype, device=x.device).triu(start_pos + 1).realize() if seqlen > 1 else None
|
||||
|
||||
h = x
|
||||
|
||||
if cache is None:
|
||||
cache = [None for _ in range(self.shard.start_layer, self.shard.end_layer + 1)]
|
||||
for i, c in zip(range(self.shard.start_layer, self.shard.end_layer + 1), cache):
|
||||
layer = self.layers[i]
|
||||
h = layer(h, start_pos, freqs_cis, mask, cache=c)
|
||||
|
||||
if self.shard.is_last_layer():
|
||||
logits = self.output(self.norm(h)).float().realize()
|
||||
return logits
|
||||
else:
|
||||
return h
|
||||
|
||||
def embed(self, inputs: Tensor):
|
||||
if self.shard.is_first_layer():
|
||||
h = self.tok_embeddings(inputs)
|
||||
else:
|
||||
h = inputs
|
||||
return h
|
||||
|
||||
def forward(self, x: Tensor, start_pos: int, cache: Optional[List[Tensor]] = None):
|
||||
if x.shape[0:2] == (1, 1) and self.forward_jit is not None and start_pos != 0:
|
||||
return self.forward_jit(x, Variable("start_pos", 1, self.max_context).bind(start_pos), cache=cache)
|
||||
return self.forward_base(x, start_pos, cache=cache)
|
||||
|
||||
def __call__(self, x: Tensor, start_pos: Variable, cache: Optional[List[Tensor]] = None):
|
||||
# TODO: better way to handle the first call v.s. the rest?
|
||||
h = self.embed(x)
|
||||
return self.forward(h, start_pos, cache=cache)
|
||||
|
||||
class TransformerShard:
|
||||
def __init__(
|
||||
self,
|
||||
shard: Shard,
|
||||
base,
|
||||
jit: bool = True,
|
||||
):
|
||||
shardrange = range(shard.start_layer, shard.end_layer + 1)
|
||||
self.layers = [layer for layer, n in zip(base.layers, range(shard.n_layers)) if n in shardrange]
|
||||
self.norm = base.norm
|
||||
self.tok_embeddings = base.tok_embeddings
|
||||
self.embed = (lambda x: self.tok_embeddings(x)) if shard.is_first_layer() else (lambda x: x)
|
||||
self.output = base.output
|
||||
self.post = (lambda x: self.output(x)) if shard.is_last_layer() else (lambda x: x)
|
||||
self.max_context = base.max_context
|
||||
self.null_cache = [None for _ in shardrange]
|
||||
self.freqs_cis = base.freqs_cis
|
||||
self.forward_jit = TinyJit(self.forward_base) if jit else None
|
||||
|
||||
def forward_base(self, x: Tensor, start_pos: Union[Variable, int], cache):
|
||||
seqlen = x.shape[1]
|
||||
freqs_cis = self.freqs_cis.shrink((None, (start_pos, start_pos + seqlen), None, None, None))
|
||||
mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-100000000"), dtype=x.dtype, device=x.device).triu(start_pos + 1).realize() if seqlen > 1 else None
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
x = layer(x, start_pos, freqs_cis, mask, cache=c)
|
||||
|
||||
out = self.post(x)
|
||||
return out
|
||||
|
||||
def forward(self, x: Tensor, start_pos: int, cache: Optional[List[Tensor]] = None):
|
||||
if x.shape[0:2] == (1, 1) and self.forward_jit is not None and start_pos != 0:
|
||||
return self.forward_jit(x, Variable("start_pos", 1, self.max_context).bind(start_pos), cache=cache)
|
||||
return self.forward_base(x, start_pos, cache=cache)
|
||||
|
||||
def __call__(self, x: Tensor, start_pos: Variable, cache: Optional[List[Tensor]] = None):
|
||||
# TODO: better way to handle the first call v.s. the rest?
|
||||
h = self.embed(x)
|
||||
return self.forward(h, start_pos, cache=self.null_cache if cache is None else cache)
|
||||
|
||||
# *** helpers ***
|
||||
|
||||
|
||||
def convert_from_huggingface(weights: Dict[str, Tensor], model: Transformer, n_heads: int, n_kv_heads: int):
|
||||
def permute(v: Tensor, n_heads: int):
|
||||
return v.reshape(n_heads, 2, v.shape[0] // n_heads // 2, v.shape[1]).transpose(1, 2).reshape(*v.shape[:2])
|
||||
|
||||
keymap = {
|
||||
"model.embed_tokens.weight": "tok_embeddings.weight",
|
||||
**{f"model.layers.{l}.input_layernorm.weight": f"layers.{l}.attention_norm.weight"
|
||||
for l in range(len(model.layers))},
|
||||
**{f"model.layers.{l}.self_attn.{x}_proj.weight": f"layers.{l}.attention.w{x}.weight"
|
||||
for x in ["q", "k", "v", "o"]
|
||||
for l in range(len(model.layers))},
|
||||
**{f"model.layers.{l}.post_attention_layernorm.weight": f"layers.{l}.ffn_norm.weight"
|
||||
for l in range(len(model.layers))},
|
||||
**{f"model.layers.{l}.mlp.{x}_proj.weight": f"layers.{l}.feed_forward.w{y}.weight"
|
||||
for x, y in {"gate": "1", "down": "2", "up": "3"}.items()
|
||||
for l in range(len(model.layers))},
|
||||
"model.norm.weight": "norm.weight",
|
||||
"lm_head.weight": "output.weight",
|
||||
}
|
||||
sd = {}
|
||||
for k, v in weights.items():
|
||||
if ".rotary_emb." in k: continue
|
||||
v = v.to(Device.DEFAULT)
|
||||
if "model.layers" in k:
|
||||
if "q_proj" in k:
|
||||
v = permute(v, n_heads)
|
||||
elif "k_proj" in k:
|
||||
v = permute(v, n_kv_heads)
|
||||
if k in keymap:
|
||||
sd[keymap[k]] = v
|
||||
else:
|
||||
sd[k] = v
|
||||
return sd
|
||||
|
||||
|
||||
def fix_bf16(weights: Dict[Any, Tensor]):
|
||||
if Device.DEFAULT == "CLANG":
|
||||
# TODO: without casting to float16, 70B llama OOM on tinybox.
|
||||
return {
|
||||
k: (v.llvm_bf16_cast(dtypes.float32).to(v.device) if v.dtype == dtypes.bfloat16 else v)
|
||||
for k, v in weights.items()
|
||||
}
|
||||
if getenv("SUPPORT_BF16", 1):
|
||||
# TODO: without casting to float16, 70B llama OOM on tinybox.
|
||||
return {k: v.cast(dtypes.float32).cast(dtypes.float16) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}
|
||||
# TODO: check if device supports bf16
|
||||
return {k: v.llvm_bf16_cast(dtypes.half).to(v.device) if v.dtype == dtypes.bfloat16 else v for k, v in weights.items()}
|
||||
@@ -1,22 +0,0 @@
|
||||
from tinygrad import Tensor, Variable
|
||||
from collections import OrderedDict
|
||||
from typing import List, Optional
|
||||
|
||||
def create_kv_cache(x: Tensor, layer):
|
||||
cache_kv = Tensor.zeros(2, x.shape[0], layer.max_context, layer.n_kv_heads, layer.head_dim, dtype=x.dtype).contiguous().realize()
|
||||
if isinstance(x.device, tuple):
|
||||
# TODO: instead of specifying how to shard, it can follow how xk and xv are being sharded
|
||||
cache_kv.shard_((x.device), axis=3 if getenv("SHARD_KVCACHE") else None).realize()
|
||||
return cache_kv.realize()
|
||||
|
||||
class ModelState:
|
||||
cache: List[Tensor]
|
||||
start: int
|
||||
def __init__(self, cache: List[Tensor], start: int = 0):
|
||||
self.cache = cache
|
||||
self.start = start
|
||||
|
||||
def make_prompt_state(x: Tensor, model):
|
||||
cache = [create_kv_cache(x, l.attention) for l in model.layers]
|
||||
|
||||
return ModelState(cache)
|
||||
@@ -1,52 +0,0 @@
|
||||
from tinygrad.nn.state import safe_load, torch_load
|
||||
from tinygrad import Tensor
|
||||
from pathlib import Path
|
||||
import json
|
||||
from typing import List
|
||||
from exo.inference.shard import Shard
|
||||
from exo.helpers import DEBUG
|
||||
from exo.download.hf.hf_helpers import get_allow_patterns
|
||||
from fnmatch import fnmatch
|
||||
import re
|
||||
|
||||
|
||||
# **** helper functions ****
|
||||
def concat_weights(models, device=None):
|
||||
def convert(name) -> Tensor:
|
||||
disk_tensors: List[Tensor] = [model[name] for model in models]
|
||||
if len(disk_tensors) == 1 or len(disk_tensors[0].shape) == 1:
|
||||
return disk_tensors[0].to(device=device)
|
||||
axis = 1 if name.endswith(".attention.wo.weight") or name.endswith(".feed_forward.w2.weight") else 0
|
||||
lazy_tensors = [data.to(device=device) for data in disk_tensors]
|
||||
return lazy_tensors[0].cat(*lazy_tensors[1:], dim=axis)
|
||||
|
||||
return {name: convert(name) for name in {name: None for model in models for name in model}}
|
||||
|
||||
|
||||
def load(fn: str, shard: Shard):
|
||||
if fn.endswith('.index.json'):
|
||||
with open(fn) as fp:
|
||||
weight_map = json.load(fp)['weight_map']
|
||||
parts = {}
|
||||
filtered_weight_map = {}
|
||||
allow_patterns = get_allow_patterns(weight_map, shard)
|
||||
for k, n in weight_map.items():
|
||||
if allow_patterns is not None and not any(fnmatch(n, r) for r in allow_patterns):
|
||||
continue
|
||||
if k.startswith("model.layers."):
|
||||
layer_num = int(k.split('.')[2])
|
||||
if layer_num < shard.start_layer or layer_num > shard.end_layer:
|
||||
continue
|
||||
|
||||
parts[n] = load(str(Path(fn).parent/Path(n).name), shard)
|
||||
filtered_weight_map[k] = n
|
||||
if DEBUG >= 2: print(f"Excluded model param keys for {shard=}: {sorted(set(weight_map.keys()) - set(filtered_weight_map.keys()))}")
|
||||
return {k: parts[n][k] for k, n in filtered_weight_map.items()}
|
||||
elif fn.endswith(".safetensors"):
|
||||
weight_map = safe_load(fn)
|
||||
for k in list(weight_map):
|
||||
if (n := re.search(r"\.(\d+)\.", k)) and not (shard.start_layer <= int(n.group(1)) <= shard.end_layer):
|
||||
del weight_map[k]
|
||||
return weight_map
|
||||
else:
|
||||
return torch_load(fn)
|
||||
@@ -1,63 +0,0 @@
|
||||
import traceback
|
||||
from os import PathLike
|
||||
from aiofiles import os as aios
|
||||
from typing import Union
|
||||
from transformers import AutoTokenizer, AutoProcessor
|
||||
import numpy as np
|
||||
from exo.helpers import DEBUG
|
||||
from exo.download.new_shard_download import ensure_downloads_dir
|
||||
|
||||
|
||||
class DummyTokenizer:
|
||||
def __init__(self):
|
||||
self.eos_token_id = 69
|
||||
self.vocab_size = 1000
|
||||
|
||||
def apply_chat_template(self, conversation, tokenize=True, add_generation_prompt=True, tools=None, **kwargs):
|
||||
return "dummy_tokenized_prompt"
|
||||
|
||||
def encode(self, text):
|
||||
return np.array([1])
|
||||
|
||||
def decode(self, tokens):
|
||||
return "dummy" * len(tokens)
|
||||
|
||||
|
||||
async def resolve_tokenizer(repo_id: Union[str, PathLike]):
|
||||
if repo_id == "dummy":
|
||||
return DummyTokenizer()
|
||||
local_path = await ensure_downloads_dir()/str(repo_id).replace("/", "--")
|
||||
if DEBUG >= 2: print(f"Checking if local path exists to load tokenizer from local {local_path=}")
|
||||
try:
|
||||
if local_path and await aios.path.exists(local_path):
|
||||
if DEBUG >= 2: print(f"Resolving tokenizer for {repo_id=} from {local_path=}")
|
||||
return await _resolve_tokenizer(local_path)
|
||||
except:
|
||||
if DEBUG >= 5: print(f"Local check for {local_path=} failed. Resolving tokenizer for {repo_id=} normally...")
|
||||
if DEBUG >= 5: traceback.print_exc()
|
||||
return await _resolve_tokenizer(repo_id)
|
||||
|
||||
|
||||
async def _resolve_tokenizer(repo_id_or_local_path: Union[str, PathLike]):
|
||||
try:
|
||||
if DEBUG >= 4: print(f"Trying AutoProcessor for {repo_id_or_local_path}")
|
||||
processor = AutoProcessor.from_pretrained(repo_id_or_local_path, use_fast=True if "Mistral-Large" in f"{repo_id_or_local_path}" else False, trust_remote_code=True)
|
||||
if not hasattr(processor, 'eos_token_id'):
|
||||
processor.eos_token_id = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).eos_token_id
|
||||
if not hasattr(processor, 'encode'):
|
||||
processor.encode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).encode
|
||||
if not hasattr(processor, 'decode'):
|
||||
processor.decode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).decode
|
||||
return processor
|
||||
except Exception as e:
|
||||
if DEBUG >= 4: print(f"Failed to load processor for {repo_id_or_local_path}. Error: {e}")
|
||||
if DEBUG >= 4: print(traceback.format_exc())
|
||||
|
||||
try:
|
||||
if DEBUG >= 4: print(f"Trying AutoTokenizer for {repo_id_or_local_path}")
|
||||
return AutoTokenizer.from_pretrained(repo_id_or_local_path, trust_remote_code=True)
|
||||
except Exception as e:
|
||||
if DEBUG >= 4: print(f"Failed to load tokenizer for {repo_id_or_local_path}. Falling back to tinygrad tokenizer. Error: {e}")
|
||||
if DEBUG >= 4: print(traceback.format_exc())
|
||||
|
||||
raise ValueError(f"[TODO] Unsupported model: {repo_id_or_local_path}")
|
||||
402
exo/main.py
402
exo/main.py
@@ -1,402 +0,0 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import atexit
|
||||
import signal
|
||||
import json
|
||||
import platform
|
||||
import os
|
||||
import time
|
||||
import traceback
|
||||
import uuid
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from exo.train.dataset import load_dataset, iterate_batches
|
||||
from exo.networking.manual.manual_discovery import ManualDiscovery
|
||||
from exo.orchestration.node import Node
|
||||
from exo.networking.grpc.grpc_server import GRPCServer
|
||||
from exo.networking.udp.udp_discovery import UDPDiscovery
|
||||
from exo.networking.tailscale.tailscale_discovery import TailscaleDiscovery
|
||||
from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle
|
||||
from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
|
||||
from exo.api import ChatGPTAPI
|
||||
from exo.download.shard_download import ShardDownloader, NoopShardDownloader
|
||||
from exo.download.download_progress import RepoProgressEvent
|
||||
from exo.download.new_shard_download import new_shard_downloader, has_exo_home_read_access, has_exo_home_write_access, ensure_exo_home, seed_models
|
||||
from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_system_info, get_or_create_node_id, get_all_ip_addresses_and_interfaces, terminal_link, shutdown
|
||||
from exo.inference.shard import Shard
|
||||
from exo.inference.inference_engine import get_inference_engine
|
||||
from exo.inference.tokenizers import resolve_tokenizer
|
||||
from exo.models import build_base_shard, get_repo
|
||||
from exo.viz.topology_viz import TopologyViz
|
||||
import uvloop
|
||||
import concurrent.futures
|
||||
import resource
|
||||
import psutil
|
||||
|
||||
# TODO: figure out why this is happening
|
||||
os.environ["GRPC_VERBOSITY"] = "error"
|
||||
os.environ["TRANSFORMERS_VERBOSITY"] = "error"
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "true"
|
||||
|
||||
# Configure uvloop for maximum performance
|
||||
def configure_uvloop():
|
||||
uvloop.install()
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
# Increase file descriptor limits on Unix systems
|
||||
if not psutil.WINDOWS:
|
||||
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
|
||||
try: resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
|
||||
except ValueError:
|
||||
try: resource.setrlimit(resource.RLIMIT_NOFILE, (8192, hard))
|
||||
except ValueError: pass
|
||||
|
||||
loop.set_default_executor(concurrent.futures.ThreadPoolExecutor(max_workers=min(32, (os.cpu_count() or 1) * 4)))
|
||||
return loop
|
||||
|
||||
# parse args
|
||||
parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
|
||||
parser.add_argument("command", nargs="?", choices=["run", "eval", "train"], help="Command to run")
|
||||
parser.add_argument("model_name", nargs="?", help="Model name to run")
|
||||
parser.add_argument("--default-model", type=str, default=None, help="Default model")
|
||||
parser.add_argument("--iters", type=int, default=100, help="Training iterations")
|
||||
parser.add_argument("--save-every", type=int, default=5, help="Save the model every N iterations.")
|
||||
parser.add_argument("--data", type=str, default="exo/train/data/lora", help="Directory where training data lives")
|
||||
parser.add_argument("--batch-size", type=int, default=1, help="Minibatch size.")
|
||||
parser.add_argument("--resume-checkpoint", type=str, default=None, help="Path to a custom checkpoint to load")
|
||||
parser.add_argument("--save-checkpoint-dir", type=str, default="checkpoints", help="Path to a folder where checkpoints are stored")
|
||||
parser.add_argument("--node-id", type=str, default=None, help="Node ID")
|
||||
parser.add_argument("--node-host", type=str, default="0.0.0.0", help="Node host")
|
||||
parser.add_argument("--node-port", type=int, default=None, help="Node port")
|
||||
parser.add_argument("--models-seed-dir", type=str, default=None, help="Model seed directory")
|
||||
parser.add_argument("--listen-port", type=int, default=5678, help="Listening port for discovery")
|
||||
parser.add_argument("--download-quick-check", action="store_true", help="Quick check local path for model shards download")
|
||||
parser.add_argument("--max-parallel-downloads", type=int, default=8, help="Max parallel downloads for model shards download")
|
||||
parser.add_argument("--broadcast-port", type=int, default=5678, help="Broadcast port for discovery")
|
||||
parser.add_argument("--discovery-module", type=str, choices=["udp", "tailscale", "manual"], default="udp", help="Discovery module to use")
|
||||
parser.add_argument("--discovery-timeout", type=int, default=30, help="Discovery timeout in seconds")
|
||||
parser.add_argument("--discovery-config-path", type=str, default=None, help="Path to discovery config json file")
|
||||
parser.add_argument("--wait-for-peers", type=int, default=0, help="Number of peers to wait to connect to before starting")
|
||||
parser.add_argument("--chatgpt-api-port", type=int, default=52415, help="ChatGPT API port")
|
||||
parser.add_argument("--chatgpt-api-response-timeout", type=int, default=900, help="ChatGPT API response timeout in seconds")
|
||||
parser.add_argument("--max-generate-tokens", type=int, default=10000, help="Max tokens to generate in each request")
|
||||
parser.add_argument("--inference-engine", type=str, default=None, help="Inference engine to use (mlx, tinygrad, or dummy)")
|
||||
parser.add_argument("--disable-tui", action=argparse.BooleanOptionalAction, help="Disable TUI")
|
||||
parser.add_argument("--run-model", type=str, help="Specify a model to run directly")
|
||||
parser.add_argument("--prompt", type=str, help="Prompt for the model when using --run-model", default="Who are you?")
|
||||
parser.add_argument("--default-temp", type=float, help="Default token sampling temperature", default=0.0)
|
||||
parser.add_argument("--tailscale-api-key", type=str, default=None, help="Tailscale API key")
|
||||
parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name")
|
||||
parser.add_argument("--node-id-filter", type=str, default=None, help="Comma separated list of allowed node IDs (only for UDP and Tailscale discovery)")
|
||||
parser.add_argument("--interface-type-filter", type=str, default=None, help="Comma separated list of allowed interface types (only for UDP discovery)")
|
||||
parser.add_argument("--system-prompt", type=str, default=None, help="System prompt for the ChatGPT API")
|
||||
args = parser.parse_args()
|
||||
print(f"Selected inference engine: {args.inference_engine}")
|
||||
|
||||
print_yellow_exo()
|
||||
|
||||
print("\n" + "="*80)
|
||||
print("EXO")
|
||||
print("="*80)
|
||||
print("\nEXO started out of a desire to run research experiments on large language")
|
||||
print("models using the hardware we already owned.")
|
||||
print("\nWhat began here is becoming part of something much larger.")
|
||||
print("\nsoon™")
|
||||
print("\n- The EXO Team")
|
||||
print("="*80 + "\n")
|
||||
|
||||
system_info = get_system_info()
|
||||
print(f"Detected system: {system_info}")
|
||||
|
||||
shard_downloader: ShardDownloader = new_shard_downloader(args.max_parallel_downloads) if args.inference_engine != "dummy" else NoopShardDownloader()
|
||||
inference_engine_name = args.inference_engine or ("mlx" if system_info == "Apple Silicon Mac" else "tinygrad")
|
||||
print(f"Inference engine name after selection: {inference_engine_name}")
|
||||
|
||||
inference_engine = get_inference_engine(inference_engine_name, shard_downloader)
|
||||
print(f"Using inference engine: {inference_engine.__class__.__name__} with shard downloader: {shard_downloader.__class__.__name__}")
|
||||
|
||||
if args.node_port is None:
|
||||
args.node_port = find_available_port(args.node_host)
|
||||
if DEBUG >= 1: print(f"Using available port: {args.node_port}")
|
||||
|
||||
args.node_id = args.node_id or get_or_create_node_id()
|
||||
chatgpt_api_endpoints = [f"http://{ip}:{args.chatgpt_api_port}/v1/chat/completions" for ip, _ in get_all_ip_addresses_and_interfaces()]
|
||||
web_chat_urls = [f"http://{ip}:{args.chatgpt_api_port}" for ip, _ in get_all_ip_addresses_and_interfaces()]
|
||||
if DEBUG >= 0:
|
||||
print("Chat interface started:")
|
||||
for web_chat_url in web_chat_urls:
|
||||
print(f" - {terminal_link(web_chat_url)}")
|
||||
print("ChatGPT API endpoint served at:")
|
||||
for chatgpt_api_endpoint in chatgpt_api_endpoints:
|
||||
print(f" - {terminal_link(chatgpt_api_endpoint)}")
|
||||
|
||||
# Convert node-id-filter and interface-type-filter to lists if provided
|
||||
allowed_node_ids = args.node_id_filter.split(',') if args.node_id_filter else None
|
||||
allowed_interface_types = args.interface_type_filter.split(',') if args.interface_type_filter else None
|
||||
|
||||
if args.discovery_module == "udp":
|
||||
discovery = UDPDiscovery(
|
||||
args.node_id,
|
||||
args.node_port,
|
||||
args.listen_port,
|
||||
args.broadcast_port,
|
||||
lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities),
|
||||
discovery_timeout=args.discovery_timeout,
|
||||
allowed_node_ids=allowed_node_ids,
|
||||
allowed_interface_types=allowed_interface_types
|
||||
)
|
||||
elif args.discovery_module == "tailscale":
|
||||
discovery = TailscaleDiscovery(
|
||||
args.node_id,
|
||||
args.node_port,
|
||||
lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities),
|
||||
discovery_timeout=args.discovery_timeout,
|
||||
tailscale_api_key=args.tailscale_api_key,
|
||||
tailnet=args.tailnet_name,
|
||||
allowed_node_ids=allowed_node_ids
|
||||
)
|
||||
elif args.discovery_module == "manual":
|
||||
if not args.discovery_config_path:
|
||||
raise ValueError(f"--discovery-config-path is required when using manual discovery. Please provide a path to a config json file.")
|
||||
discovery = ManualDiscovery(args.discovery_config_path, args.node_id, create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities))
|
||||
topology_viz = TopologyViz(chatgpt_api_endpoints=chatgpt_api_endpoints, web_chat_urls=web_chat_urls) if not args.disable_tui else None
|
||||
node = Node(
|
||||
args.node_id,
|
||||
None,
|
||||
inference_engine,
|
||||
discovery,
|
||||
shard_downloader,
|
||||
partitioning_strategy=RingMemoryWeightedPartitioningStrategy(),
|
||||
max_generate_tokens=args.max_generate_tokens,
|
||||
topology_viz=topology_viz,
|
||||
default_sample_temperature=args.default_temp
|
||||
)
|
||||
server = GRPCServer(node, args.node_host, args.node_port)
|
||||
node.server = server
|
||||
api = ChatGPTAPI(
|
||||
node,
|
||||
node.inference_engine.__class__.__name__,
|
||||
response_timeout=args.chatgpt_api_response_timeout,
|
||||
on_chat_completion_request=lambda req_id, __, prompt: topology_viz.update_prompt(req_id, prompt) if topology_viz else None,
|
||||
default_model=args.default_model,
|
||||
system_prompt=args.system_prompt
|
||||
)
|
||||
buffered_token_output = {}
|
||||
def update_topology_viz(req_id, tokens, __):
|
||||
if not topology_viz: return
|
||||
if not node.inference_engine.shard: return
|
||||
if node.inference_engine.shard.model_id == 'stable-diffusion-2-1-base': return
|
||||
if req_id in buffered_token_output: buffered_token_output[req_id].extend(tokens)
|
||||
else: buffered_token_output[req_id] = tokens
|
||||
topology_viz.update_prompt_output(req_id, node.inference_engine.tokenizer.decode(buffered_token_output[req_id]))
|
||||
node.on_token.register("update_topology_viz").on_next(update_topology_viz)
|
||||
def update_prompt_viz(request_id, opaque_status: str):
|
||||
if not topology_viz: return
|
||||
try:
|
||||
status = json.loads(opaque_status)
|
||||
if status.get("type") != "node_status" or status.get("status") != "start_process_prompt": return
|
||||
topology_viz.update_prompt(request_id, status.get("prompt", "corrupted prompt (this should never happen)"))
|
||||
except Exception as e:
|
||||
if DEBUG >= 2:
|
||||
print(f"Failed to update prompt viz: {e}")
|
||||
traceback.print_exc()
|
||||
node.on_opaque_status.register("update_prompt_viz").on_next(update_prompt_viz)
|
||||
|
||||
def preemptively_load_shard(request_id: str, opaque_status: str):
|
||||
try:
|
||||
status = json.loads(opaque_status)
|
||||
if status.get("type") != "node_status" or status.get("status") != "start_process_prompt": return
|
||||
current_shard = node.get_current_shard(Shard.from_dict(status.get("shard")))
|
||||
if DEBUG >= 2: print(f"Preemptively starting download for {current_shard}")
|
||||
asyncio.create_task(node.inference_engine.ensure_shard(current_shard))
|
||||
except Exception as e:
|
||||
if DEBUG >= 2:
|
||||
print(f"Failed to preemptively start download: {e}")
|
||||
traceback.print_exc()
|
||||
node.on_opaque_status.register("preemptively_load_shard").on_next(preemptively_load_shard)
|
||||
|
||||
last_events: dict[str, tuple[float, RepoProgressEvent]] = {}
|
||||
def throttled_broadcast(shard: Shard, event: RepoProgressEvent):
|
||||
global last_events
|
||||
current_time = time.time()
|
||||
if event.status == "not_started": return
|
||||
last_event = last_events.get(shard.model_id)
|
||||
if last_event and last_event[1].status == "complete" and event.status == "complete": return
|
||||
if last_event and last_event[0] == event.status and current_time - last_event[0] < 0.2: return
|
||||
last_events[shard.model_id] = (current_time, event)
|
||||
asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "progress": event.to_dict()})))
|
||||
shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast)
|
||||
|
||||
async def run_model_cli(node: Node, model_name: str, prompt: str):
|
||||
inference_class = node.inference_engine.__class__.__name__
|
||||
shard = build_base_shard(model_name, inference_class)
|
||||
if not shard:
|
||||
print(f"Error: Unsupported model '{model_name}' for inference engine {inference_class}")
|
||||
return
|
||||
tokenizer = await resolve_tokenizer(get_repo(shard.model_id, inference_class))
|
||||
request_id = str(uuid.uuid4())
|
||||
callback_id = f"cli-wait-response-{request_id}"
|
||||
callback = node.on_token.register(callback_id)
|
||||
if topology_viz:
|
||||
topology_viz.update_prompt(request_id, prompt)
|
||||
prompt = tokenizer.apply_chat_template([{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True)
|
||||
|
||||
try:
|
||||
print(f"Processing prompt: {prompt}")
|
||||
await node.process_prompt(shard, prompt, request_id=request_id)
|
||||
|
||||
tokens = []
|
||||
def on_token(_request_id, _tokens, _is_finished):
|
||||
tokens.extend(_tokens)
|
||||
return _request_id == request_id and _is_finished
|
||||
await callback.wait(on_token, timeout=300)
|
||||
|
||||
print("\nGenerated response:")
|
||||
print(tokenizer.decode(tokens))
|
||||
except Exception as e:
|
||||
print(f"Error processing prompt: {str(e)}")
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
node.on_token.deregister(callback_id)
|
||||
|
||||
def clean_path(path):
|
||||
"""Clean and resolve path"""
|
||||
if path.startswith("Optional("):
|
||||
path = path.strip('Optional("').rstrip('")')
|
||||
return os.path.expanduser(path)
|
||||
|
||||
async def hold_outstanding(node: Node):
|
||||
while node.outstanding_requests:
|
||||
await asyncio.sleep(.5)
|
||||
return
|
||||
|
||||
async def run_iter(node: Node, shard: Shard, train: bool, data, batch_size=1):
|
||||
losses = []
|
||||
tokens = []
|
||||
for batch in tqdm(iterate_batches(data, batch_size), total=len(data) // batch_size):
|
||||
_, _, lengths = batch
|
||||
losses.append(np.sum(lengths * await node.enqueue_example(shard, *batch, train=train)))
|
||||
tokens.append(np.sum(lengths))
|
||||
total_tokens = np.sum(tokens)
|
||||
total_loss = np.sum(losses) / total_tokens
|
||||
|
||||
return total_loss, total_tokens
|
||||
|
||||
async def eval_model_cli(node: Node, model_name, dataloader, batch_size, num_batches=-1):
|
||||
inference_class = node.inference_engine.__class__.__name__
|
||||
shard = build_base_shard(model_name, inference_class)
|
||||
if not shard:
|
||||
print(f"Error: Unsupported model '{model_name}' for inference engine {inference_class}")
|
||||
return
|
||||
tokenizer = await resolve_tokenizer(get_repo(shard.model_id, inference_class))
|
||||
train, val, test = dataloader(tokenizer.encode)
|
||||
print(f"Evaluating {len(test)} examples with batch_size {batch_size}")
|
||||
loss, tokens = await run_iter(node, shard, False, test, batch_size)
|
||||
print(f"total | {loss=}, {tokens=}")
|
||||
print("Waiting for outstanding tasks")
|
||||
await hold_outstanding(node)
|
||||
|
||||
async def train_model_cli(node: Node, model_name, dataloader, batch_size, iters, save_interval=0, checkpoint_dir=None):
|
||||
inference_class = node.inference_engine.__class__.__name__
|
||||
shard = build_base_shard(model_name, inference_class)
|
||||
if not shard:
|
||||
print(f"Error: Unsupported model '{model_name}' for inference engine {inference_class}")
|
||||
return
|
||||
tokenizer = await resolve_tokenizer(get_repo(shard.model_id, inference_class))
|
||||
train, val, test = dataloader(tokenizer.encode)
|
||||
print(f"Training on {len(train)} examples with batch_size {batch_size} for {iters} epochs")
|
||||
for i in tqdm(range(3)):
|
||||
await asyncio.sleep(1)
|
||||
for epoch in range(iters):
|
||||
loss, tokens = await run_iter(node, shard, True, train, batch_size)
|
||||
print(f"epoch {epoch + 1}/{iters}\t| loss: {loss}, tokens: {tokens}")
|
||||
if save_interval > 0 and epoch > 0 and (epoch % save_interval) == 0 and checkpoint_dir is not None:
|
||||
await node.coordinate_save(shard, epoch, checkpoint_dir)
|
||||
await hold_outstanding(node)
|
||||
await hold_outstanding(node)
|
||||
|
||||
async def check_exo_home():
|
||||
home, has_read, has_write = await ensure_exo_home(), await has_exo_home_read_access(), await has_exo_home_write_access()
|
||||
if DEBUG >= 1: print(f"exo home directory: {home}")
|
||||
print(f"{has_read=}, {has_write=}")
|
||||
if not has_read or not has_write:
|
||||
print(f"""
|
||||
WARNING: Limited permissions for exo home directory: {home}.
|
||||
This may prevent model downloads from working correctly.
|
||||
{"❌ No read access" if not has_read else ""}
|
||||
{"❌ No write access" if not has_write else ""}
|
||||
""")
|
||||
|
||||
async def main():
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
try: await check_exo_home()
|
||||
except Exception as e: print(f"Error checking exo home directory: {e}")
|
||||
|
||||
if not args.models_seed_dir is None:
|
||||
try:
|
||||
models_seed_dir = clean_path(args.models_seed_dir)
|
||||
await seed_models(models_seed_dir)
|
||||
except Exception as e:
|
||||
print(f"Error seeding models: {e}")
|
||||
|
||||
def restore_cursor():
|
||||
if platform.system() != "Windows":
|
||||
os.system("tput cnorm") # Show cursor
|
||||
|
||||
# Restore the cursor when the program exits
|
||||
atexit.register(restore_cursor)
|
||||
|
||||
# Use a more direct approach to handle signals
|
||||
def handle_exit():
|
||||
asyncio.ensure_future(shutdown(signal.SIGTERM, loop, node.server))
|
||||
|
||||
if platform.system() != "Windows":
|
||||
for s in [signal.SIGINT, signal.SIGTERM]:
|
||||
loop.add_signal_handler(s, handle_exit)
|
||||
|
||||
await node.start(wait_for_peers=args.wait_for_peers)
|
||||
|
||||
if args.command == "run" or args.run_model:
|
||||
model_name = args.model_name or args.run_model
|
||||
if not model_name:
|
||||
print("Error: Model name is required when using 'run' command or --run-model")
|
||||
return
|
||||
await run_model_cli(node, model_name, args.prompt)
|
||||
elif args.command == "eval" or args.command == 'train':
|
||||
model_name = args.model_name
|
||||
dataloader = lambda tok: load_dataset(args.data, preprocess=lambda item: tok(item)
|
||||
, loadline=lambda line: json.loads(line).get("text",""))
|
||||
if args.command == 'eval':
|
||||
if not model_name:
|
||||
print("Error: Much like a human, I can't evaluate anything without a model")
|
||||
return
|
||||
await eval_model_cli(node, model_name, dataloader, args.batch_size)
|
||||
else:
|
||||
if not model_name:
|
||||
print("Error: This train ain't leaving the station without a model")
|
||||
return
|
||||
await train_model_cli(node, model_name, dataloader, args.batch_size, args.iters, save_interval=args.save_every, checkpoint_dir=args.save_checkpoint_dir)
|
||||
|
||||
else:
|
||||
asyncio.create_task(api.run(port=args.chatgpt_api_port)) # Start the API server as a non-blocking task
|
||||
await asyncio.Event().wait()
|
||||
|
||||
if args.wait_for_peers > 0:
|
||||
print("Cooldown to allow peers to exit gracefully")
|
||||
for i in tqdm(range(50)):
|
||||
await asyncio.sleep(.1)
|
||||
|
||||
def run():
|
||||
loop = None
|
||||
try:
|
||||
loop = configure_uvloop()
|
||||
loop.run_until_complete(main())
|
||||
except KeyboardInterrupt:
|
||||
print("\nShutdown requested... exiting")
|
||||
finally:
|
||||
if loop: loop.close()
|
||||
|
||||
if __name__ == "__main__":
|
||||
run()
|
||||
273
exo/models.py
273
exo/models.py
@@ -1,273 +0,0 @@
|
||||
from exo.inference.shard import Shard
|
||||
from typing import Optional, List
|
||||
|
||||
model_cards = {
|
||||
### llama
|
||||
"llama-3.3-70b": {
|
||||
"layers": 80,
|
||||
"repo": {
|
||||
"MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.3-70B-Instruct-4bit",
|
||||
"TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.3-70B-Instruct",
|
||||
},
|
||||
},
|
||||
"llama-3.2-1b": {
|
||||
"layers": 16,
|
||||
"repo": {
|
||||
"MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-1B-Instruct-4bit",
|
||||
"TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-1B-Instruct",
|
||||
},
|
||||
},
|
||||
"llama-3.2-1b-8bit": {
|
||||
"layers": 16,
|
||||
"repo": {
|
||||
"MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-1B-Instruct-8bit",
|
||||
"TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-1B-Instruct",
|
||||
},
|
||||
},
|
||||
"llama-3.2-3b": {
|
||||
"layers": 28,
|
||||
"repo": {
|
||||
"MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-3B-Instruct-4bit",
|
||||
"TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-3B-Instruct",
|
||||
},
|
||||
},
|
||||
"llama-3.2-3b-8bit": {
|
||||
"layers": 28,
|
||||
"repo": {
|
||||
"MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-3B-Instruct-8bit",
|
||||
"TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-3B-Instruct",
|
||||
},
|
||||
},
|
||||
"llama-3.2-3b-bf16": {
|
||||
"layers": 28,
|
||||
"repo": {
|
||||
"MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.2-3B-Instruct",
|
||||
"TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-3B-Instruct",
|
||||
},
|
||||
},
|
||||
"llama-3.1-8b": {
|
||||
"layers": 32,
|
||||
"repo": {
|
||||
"MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
|
||||
"TinygradDynamicShardInferenceEngine": "mlabonne/Meta-Llama-3.1-8B-Instruct-abliterated",
|
||||
},
|
||||
},
|
||||
"llama-3.1-70b": {
|
||||
"layers": 80,
|
||||
"repo": {
|
||||
"MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit",
|
||||
"TinygradDynamicShardInferenceEngine": "NousResearch/Meta-Llama-3.1-70B-Instruct",
|
||||
},
|
||||
},
|
||||
"llama-3.1-70b-bf16": {
|
||||
"layers": 80,
|
||||
"repo": {
|
||||
"MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-70B-Instruct-bf16-CORRECTED",
|
||||
"TinygradDynamicShardInferenceEngine": "NousResearch/Meta-Llama-3.1-70B-Instruct",
|
||||
},
|
||||
},
|
||||
"llama-3-8b": {
|
||||
"layers": 32,
|
||||
"repo": {
|
||||
"MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
|
||||
"TinygradDynamicShardInferenceEngine": "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R",
|
||||
},
|
||||
},
|
||||
"llama-3-70b": {
|
||||
"layers": 80,
|
||||
"repo": {
|
||||
"MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3-70B-Instruct-4bit",
|
||||
"TinygradDynamicShardInferenceEngine": "TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-70B-R",
|
||||
},
|
||||
},
|
||||
"llama-3.1-405b": { "layers": 126, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-405B-4bit", }, },
|
||||
"llama-3.1-405b-8bit": { "layers": 126, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Meta-Llama-3.1-405B-Instruct-8bit", }, },
|
||||
### mistral
|
||||
"mistral-nemo": { "layers": 40, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Mistral-Nemo-Instruct-2407-4bit", }, },
|
||||
"mistral-large": { "layers": 88, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Mistral-Large-Instruct-2407-4bit", }, },
|
||||
### deepseek
|
||||
"deepseek-coder-v2-lite": { "layers": 27, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", }, },
|
||||
"deepseek-coder-v2.5": { "layers": 60, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-V2.5-MLX-AQ4_1_64", }, },
|
||||
"deepseek-v3": { "layers": 61, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-V3-4bit", }, },
|
||||
"deepseek-v3-3bit": { "layers": 61, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-V3-3bit", }, },
|
||||
"deepseek-r1": { "layers": 61, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-4bit", }, },
|
||||
"deepseek-r1-3bit": { "layers": 61, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-3bit", }, },
|
||||
### deepseek distills
|
||||
"deepseek-r1-distill-qwen-1.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/deepseek-r1-distill-qwen-1.5b", }, },
|
||||
"deepseek-r1-distill-qwen-1.5b-3bit": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-3bit", }, },
|
||||
"deepseek-r1-distill-qwen-1.5b-6bit": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-6bit", }, },
|
||||
"deepseek-r1-distill-qwen-1.5b-8bit": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-8bit", }, },
|
||||
"deepseek-r1-distill-qwen-1.5b-bf16": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-bf16", }, },
|
||||
"deepseek-r1-distill-qwen-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-7B-4bit", }, },
|
||||
"deepseek-r1-distill-qwen-7b-3bit": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-7B-3bit", }, },
|
||||
"deepseek-r1-distill-qwen-7b-6bit": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-7B-6bit", }, },
|
||||
"deepseek-r1-distill-qwen-7b-8bit": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-7B-8bit", }, },
|
||||
"deepseek-r1-distill-qwen-7b-bf16": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-7B-bf16", }, },
|
||||
"deepseek-r1-distill-qwen-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-14B-4bit", }, },
|
||||
"deepseek-r1-distill-qwen-14b-3bit": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-14B-3bit", }, },
|
||||
"deepseek-r1-distill-qwen-14b-6bit": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-14B-6bit", }, },
|
||||
"deepseek-r1-distill-qwen-14b-8bit": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-14B-8bit", }, },
|
||||
"deepseek-r1-distill-qwen-14b-bf16": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-14B-bf16", }, },
|
||||
"deepseek-r1-distill-qwen-32b": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-32B-4bit", }, },
|
||||
"deepseek-r1-distill-qwen-32b-3bit": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-32B-3bit", }, },
|
||||
"deepseek-r1-distill-qwen-32b-6bit": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-32B-6bit", }, },
|
||||
"deepseek-r1-distill-qwen-32b-8bit": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-32B-MLX-8Bit", }, },
|
||||
"deepseek-r1-distill-qwen-32b-bf16": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Qwen-32B-bf16", }, },
|
||||
"deepseek-r1-distill-llama-8b": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-8B-4bit", }, },
|
||||
"deepseek-r1-distill-llama-8b-3bit": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-8B-3bit", }, },
|
||||
"deepseek-r1-distill-llama-8b-6bit": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-8B-6bit", }, },
|
||||
"deepseek-r1-distill-llama-8b-8bit": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-8B-8bit", }, },
|
||||
"deepseek-r1-distill-llama-8b-bf16": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-8B-bf16", }, },
|
||||
"deepseek-r1-distill-llama-70b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-70B-4bit", }, },
|
||||
"deepseek-r1-distill-llama-70b-3bit": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-70B-3bit", }, },
|
||||
"deepseek-r1-distill-llama-70b-6bit": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-70B-6bit", }, },
|
||||
"deepseek-r1-distill-llama-70b-8bit": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/DeepSeek-R1-Distill-Llama-70B-8bit", }, },
|
||||
### llava
|
||||
"llava-1.5-7b-hf": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "llava-hf/llava-1.5-7b-hf", }, },
|
||||
### qwen
|
||||
"qwen-2.5-0.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-0.5B-Instruct-4bit", }, },
|
||||
"qwen-2.5-1.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-1.5B-Instruct-4bit", }, },
|
||||
"qwen-2.5-coder-1.5b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-1.5B-Instruct-4bit", }, },
|
||||
"qwen-2.5-3b": { "layers": 36, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-3B-Instruct-4bit", }, },
|
||||
"qwen-2.5-coder-3b": { "layers": 36, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-3B-Instruct-4bit", }, },
|
||||
"qwen-2.5-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-7B-Instruct-4bit", }, },
|
||||
"qwen-2.5-coder-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-7B-Instruct-4bit", }, },
|
||||
"qwen-2.5-math-7b": { "layers": 28, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Math-7B-Instruct-4bit", }, },
|
||||
"qwen-2.5-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-14B-Instruct-4bit", }, },
|
||||
"qwen-2.5-coder-14b": { "layers": 48, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-14B-Instruct-4bit", }, },
|
||||
"qwen-2.5-32b": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-32B-Instruct-4bit", }, },
|
||||
"qwen-2.5-coder-32b": { "layers": 64, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Coder-32B-Instruct-4bit", }, },
|
||||
"qwen-2.5-72b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-72B-Instruct-4bit", }, },
|
||||
"qwen-2.5-math-72b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Qwen2.5-Math-72B-Instruct-4bit", }, },
|
||||
### nemotron
|
||||
"nemotron-70b": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/nvidia_Llama-3.1-Nemotron-70B-Instruct-HF_4bit", }, },
|
||||
"nemotron-70b-bf16": { "layers": 80, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Llama-3.1-Nemotron-70B-Instruct-HF-bf16", }, },
|
||||
# gemma
|
||||
"gemma2-9b": { "layers": 42, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/gemma-2-9b-it-4bit", }, },
|
||||
"gemma2-27b": { "layers": 46, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/gemma-2-27b-it-4bit", }, },
|
||||
# stable diffusion
|
||||
"stable-diffusion-2-1-base": { "layers": 31, "repo": { "MLXDynamicShardInferenceEngine": "stabilityai/stable-diffusion-2-1-base" } },
|
||||
# phi
|
||||
"phi-3.5-mini": { "layers": 32, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/Phi-3.5-mini-instruct-4bit", }, },
|
||||
"phi-4": { "layers": 40, "repo": { "MLXDynamicShardInferenceEngine": "mlx-community/phi-4-4bit", }, },
|
||||
# dummy
|
||||
"dummy": { "layers": 8, "repo": { "DummyInferenceEngine": "dummy", }, },
|
||||
}
|
||||
|
||||
pretty_name = {
|
||||
"llama-3.3-70b": "Llama 3.3 70B",
|
||||
"llama-3.2-1b": "Llama 3.2 1B",
|
||||
"llama-3.2-1b-8bit": "Llama 3.2 1B (8-bit)",
|
||||
"llama-3.2-3b": "Llama 3.2 3B",
|
||||
"llama-3.2-3b-8bit": "Llama 3.2 3B (8-bit)",
|
||||
"llama-3.2-3b-bf16": "Llama 3.2 3B (BF16)",
|
||||
"llama-3.1-8b": "Llama 3.1 8B",
|
||||
"llama-3.1-70b": "Llama 3.1 70B",
|
||||
"llama-3.1-70b-bf16": "Llama 3.1 70B (BF16)",
|
||||
"llama-3.1-405b": "Llama 3.1 405B",
|
||||
"llama-3.1-405b-8bit": "Llama 3.1 405B (8-bit)",
|
||||
"gemma2-9b": "Gemma2 9B",
|
||||
"gemma2-27b": "Gemma2 27B",
|
||||
"nemotron-70b": "Nemotron 70B",
|
||||
"nemotron-70b-bf16": "Nemotron 70B (BF16)",
|
||||
"mistral-nemo": "Mistral Nemo",
|
||||
"mistral-large": "Mistral Large",
|
||||
"deepseek-coder-v2-lite": "Deepseek Coder V2 Lite",
|
||||
"deepseek-coder-v2.5": "Deepseek Coder V2.5",
|
||||
"deepseek-v3": "Deepseek V3 (4-bit)",
|
||||
"deepseek-v3-3bit": "Deepseek V3 (3-bit)",
|
||||
"deepseek-r1": "Deepseek R1 (4-bit)",
|
||||
"deepseek-r1-3bit": "Deepseek R1 (3-bit)",
|
||||
"llava-1.5-7b-hf": "LLaVa 1.5 7B (Vision Model)",
|
||||
"qwen-2.5-0.5b": "Qwen 2.5 0.5B",
|
||||
"qwen-2.5-1.5b": "Qwen 2.5 1.5B",
|
||||
"qwen-2.5-coder-1.5b": "Qwen 2.5 Coder 1.5B",
|
||||
"qwen-2.5-3b": "Qwen 2.5 3B",
|
||||
"qwen-2.5-coder-3b": "Qwen 2.5 Coder 3B",
|
||||
"qwen-2.5-7b": "Qwen 2.5 7B",
|
||||
"qwen-2.5-coder-7b": "Qwen 2.5 Coder 7B",
|
||||
"qwen-2.5-math-7b": "Qwen 2.5 7B (Math)",
|
||||
"qwen-2.5-14b": "Qwen 2.5 14B",
|
||||
"qwen-2.5-coder-14b": "Qwen 2.5 Coder 14B",
|
||||
"qwen-2.5-32b": "Qwen 2.5 32B",
|
||||
"qwen-2.5-coder-32b": "Qwen 2.5 Coder 32B",
|
||||
"qwen-2.5-72b": "Qwen 2.5 72B",
|
||||
"qwen-2.5-math-72b": "Qwen 2.5 72B (Math)",
|
||||
"phi-3.5-mini": "Phi-3.5 Mini",
|
||||
"phi-4": "Phi-4",
|
||||
"llama-3-8b": "Llama 3 8B",
|
||||
"llama-3-70b": "Llama 3 70B",
|
||||
"stable-diffusion-2-1-base": "Stable Diffusion 2.1",
|
||||
"deepseek-r1-distill-qwen-1.5b": "DeepSeek R1 Distill Qwen 1.5B",
|
||||
"deepseek-r1-distill-qwen-1.5b-3bit": "DeepSeek R1 Distill Qwen 1.5B (3-bit)",
|
||||
"deepseek-r1-distill-qwen-1.5b-6bit": "DeepSeek R1 Distill Qwen 1.5B (6-bit)",
|
||||
"deepseek-r1-distill-qwen-1.5b-8bit": "DeepSeek R1 Distill Qwen 1.5B (8-bit)",
|
||||
"deepseek-r1-distill-qwen-1.5b-bf16": "DeepSeek R1 Distill Qwen 1.5B (BF16)",
|
||||
"deepseek-r1-distill-qwen-7b": "DeepSeek R1 Distill Qwen 7B",
|
||||
"deepseek-r1-distill-qwen-7b-3bit": "DeepSeek R1 Distill Qwen 7B (3-bit)",
|
||||
"deepseek-r1-distill-qwen-7b-6bit": "DeepSeek R1 Distill Qwen 7B (6-bit)",
|
||||
"deepseek-r1-distill-qwen-7b-8bit": "DeepSeek R1 Distill Qwen 7B (8-bit)",
|
||||
"deepseek-r1-distill-qwen-7b-bf16": "DeepSeek R1 Distill Qwen 7B (BF16)",
|
||||
"deepseek-r1-distill-qwen-14b": "DeepSeek R1 Distill Qwen 14B",
|
||||
"deepseek-r1-distill-qwen-14b-3bit": "DeepSeek R1 Distill Qwen 14B (3-bit)",
|
||||
"deepseek-r1-distill-qwen-14b-6bit": "DeepSeek R1 Distill Qwen 14B (6-bit)",
|
||||
"deepseek-r1-distill-qwen-14b-8bit": "DeepSeek R1 Distill Qwen 14B (8-bit)",
|
||||
"deepseek-r1-distill-qwen-14b-bf16": "DeepSeek R1 Distill Qwen 14B (BF16)",
|
||||
"deepseek-r1-distill-qwen-32b": "DeepSeek R1 Distill Qwen 32B",
|
||||
"deepseek-r1-distill-qwen-32b-3bit": "DeepSeek R1 Distill Qwen 32B (3-bit)",
|
||||
"deepseek-r1-distill-qwen-32b-8bit": "DeepSeek R1 Distill Qwen 32B (8-bit)",
|
||||
"deepseek-r1-distill-qwen-32b-bf16": "DeepSeek R1 Distill Qwen 32B (BF16)",
|
||||
"deepseek-r1-distill-llama-8b-8bit": "DeepSeek R1 Distill Llama 8B (8-bit)",
|
||||
"deepseek-r1-distill-llama-70b-6bit": "DeepSeek R1 Distill Llama 70B (6-bit)",
|
||||
"deepseek-r1-distill-llama-70b-8bit": "DeepSeek R1 Distill Llama 70B (8-bit)",
|
||||
"deepseek-r1-distill-llama-8b": "DeepSeek R1 Distill Llama 8B",
|
||||
"deepseek-r1-distill-llama-8b-3bit": "DeepSeek R1 Distill Llama 8B (3-bit)",
|
||||
"deepseek-r1-distill-llama-8b-6bit": "DeepSeek R1 Distill Llama 8B (6-bit)",
|
||||
"deepseek-r1-distill-llama-8b-8bit": "DeepSeek R1 Distill Llama 8B (8-bit)",
|
||||
"deepseek-r1-distill-llama-8b-bf16": "DeepSeek R1 Distill Llama 8B (BF16)",
|
||||
"deepseek-r1-distill-llama-70b": "DeepSeek R1 Distill Llama 70B",
|
||||
"deepseek-r1-distill-llama-70b-3bit": "DeepSeek R1 Distill Llama 70B (3-bit)",
|
||||
"deepseek-r1-distill-llama-70b-6bit": "DeepSeek R1 Distill Llama 70B (6-bit)",
|
||||
"deepseek-r1-distill-llama-70b-8bit": "DeepSeek R1 Distill Llama 70B (8-bit)",
|
||||
"deepseek-r1-distill-qwen-32b-6bit": "DeepSeek R1 Distill Qwen 32B (6-bit)",
|
||||
}
|
||||
|
||||
def get_repo(model_id: str, inference_engine_classname: str) -> Optional[str]:
|
||||
return model_cards.get(model_id, {}).get("repo", {}).get(inference_engine_classname, None)
|
||||
|
||||
def get_pretty_name(model_id: str) -> Optional[str]:
|
||||
return pretty_name.get(model_id, None)
|
||||
|
||||
def build_base_shard(model_id: str, inference_engine_classname: str) -> Optional[Shard]:
|
||||
repo = get_repo(model_id, inference_engine_classname)
|
||||
n_layers = model_cards.get(model_id, {}).get("layers", 0)
|
||||
if repo is None or n_layers < 1:
|
||||
return None
|
||||
return Shard(model_id, 0, 0, n_layers)
|
||||
|
||||
def build_full_shard(model_id: str, inference_engine_classname: str) -> Optional[Shard]:
|
||||
base_shard = build_base_shard(model_id, inference_engine_classname)
|
||||
if base_shard is None: return None
|
||||
return Shard(base_shard.model_id, 0, base_shard.n_layers - 1, base_shard.n_layers)
|
||||
|
||||
def get_supported_models(supported_inference_engine_lists: Optional[List[List[str]]] = None) -> List[str]:
|
||||
if not supported_inference_engine_lists:
|
||||
return list(model_cards.keys())
|
||||
|
||||
from exo.inference.inference_engine import inference_engine_classes
|
||||
supported_inference_engine_lists = [
|
||||
[inference_engine_classes[engine] if engine in inference_engine_classes else engine for engine in engine_list]
|
||||
for engine_list in supported_inference_engine_lists
|
||||
]
|
||||
|
||||
def has_any_engine(model_info: dict, engine_list: List[str]) -> bool:
|
||||
return any(engine in model_info.get("repo", {}) for engine in engine_list)
|
||||
|
||||
def supports_all_engine_lists(model_info: dict) -> bool:
|
||||
return all(has_any_engine(model_info, engine_list)
|
||||
for engine_list in supported_inference_engine_lists)
|
||||
|
||||
return [
|
||||
model_id for model_id, model_info in model_cards.items()
|
||||
if supports_all_engine_lists(model_info)
|
||||
]
|
||||
@@ -1,5 +0,0 @@
|
||||
from .discovery import Discovery
|
||||
from .peer_handle import PeerHandle
|
||||
from .server import Server
|
||||
|
||||
__all__ = ["Discovery", "PeerHandle", "Server"]
|
||||
@@ -1,17 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
from .peer_handle import PeerHandle
|
||||
|
||||
|
||||
class Discovery(ABC):
|
||||
@abstractmethod
|
||||
async def start(self) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def stop(self) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
|
||||
pass
|
||||
@@ -1,226 +0,0 @@
|
||||
import grpc
|
||||
import numpy as np
|
||||
import asyncio
|
||||
from typing import Optional, Tuple, List
|
||||
|
||||
from . import node_service_pb2
|
||||
from . import node_service_pb2_grpc
|
||||
|
||||
from ..peer_handle import PeerHandle
|
||||
from exo.inference.shard import Shard
|
||||
from exo.topology.topology import Topology
|
||||
from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
|
||||
from exo.helpers import DEBUG
|
||||
import json
|
||||
import platform
|
||||
|
||||
if platform.system().lower() == "darwin" and platform.machine().lower() == "arm64":
|
||||
import mlx.core as mx
|
||||
else:
|
||||
import numpy as mx
|
||||
|
||||
|
||||
class GRPCPeerHandle(PeerHandle):
|
||||
def __init__(self, _id: str, address: str, desc: str, device_capabilities: DeviceCapabilities):
|
||||
self._id = _id
|
||||
self.address = address
|
||||
self.desc = desc
|
||||
self._device_capabilities = device_capabilities
|
||||
self.channel = None
|
||||
self.stub = None
|
||||
self.channel_options = [
|
||||
("grpc.max_metadata_size", 32 * 1024 * 1024),
|
||||
("grpc.max_receive_message_length", 256 * 1024 * 1024),
|
||||
("grpc.max_send_message_length", 256 * 1024 * 1024),
|
||||
("grpc.max_concurrent_streams", 100),
|
||||
("grpc.http2.min_time_between_pings_ms", 10000),
|
||||
("grpc.keepalive_time_ms", 10000),
|
||||
("grpc.keepalive_timeout_ms", 5000),
|
||||
("grpc.keepalive_permit_without_calls", 1),
|
||||
("grpc.http2.max_pings_without_data", 0),
|
||||
("grpc.http2.min_ping_interval_without_data_ms", 5000),
|
||||
("grpc.tcp_nodelay", 1),
|
||||
("grpc.optimization_target", "throughput"),
|
||||
]
|
||||
|
||||
def id(self) -> str:
|
||||
return self._id
|
||||
|
||||
def addr(self) -> str:
|
||||
return self.address
|
||||
|
||||
def description(self) -> str:
|
||||
return self.desc
|
||||
|
||||
def device_capabilities(self) -> DeviceCapabilities:
|
||||
return self._device_capabilities
|
||||
|
||||
async def connect(self):
|
||||
self.channel = grpc.aio.insecure_channel(
|
||||
self.address,
|
||||
options=self.channel_options,
|
||||
compression=grpc.Compression.Gzip
|
||||
)
|
||||
self.stub = node_service_pb2_grpc.NodeServiceStub(self.channel)
|
||||
await asyncio.wait_for(self.channel.channel_ready(), timeout=10.0)
|
||||
|
||||
async def is_connected(self) -> bool:
|
||||
return self.channel is not None and self.channel.get_state() == grpc.ChannelConnectivity.READY
|
||||
|
||||
async def disconnect(self):
|
||||
if self.channel:
|
||||
await self.channel.close()
|
||||
self.channel = None
|
||||
self.stub = None
|
||||
|
||||
async def _ensure_connected(self):
|
||||
if not (await self.is_connected()):
|
||||
try:
|
||||
await asyncio.wait_for(self.connect(), timeout=10.0)
|
||||
except asyncio.TimeoutError:
|
||||
if DEBUG >= 2: print(f"Connection timeout for {self._id}@{self.address}")
|
||||
await self.disconnect()
|
||||
raise
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
try:
|
||||
await self._ensure_connected()
|
||||
request = node_service_pb2.HealthCheckRequest()
|
||||
response = await asyncio.wait_for(self.stub.HealthCheck(request), timeout=5)
|
||||
return response.is_healthy
|
||||
except asyncio.TimeoutError:
|
||||
return False
|
||||
except Exception:
|
||||
if DEBUG >= 4:
|
||||
print(f"Health check failed for {self._id}@{self.address}.")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
async def send_prompt(self, shard: Shard, prompt: str, inference_state: Optional[dict] = None, request_id: Optional[str] = None) -> Optional[np.array]:
|
||||
await self._ensure_connected()
|
||||
request = node_service_pb2.PromptRequest(
|
||||
prompt=prompt,
|
||||
shard=node_service_pb2.Shard(
|
||||
model_id=shard.model_id,
|
||||
start_layer=shard.start_layer,
|
||||
end_layer=shard.end_layer,
|
||||
n_layers=shard.n_layers,
|
||||
),
|
||||
request_id=request_id,
|
||||
inference_state=None if inference_state is None else self.serialize_inference_state(inference_state)
|
||||
)
|
||||
await self.stub.SendPrompt(request)
|
||||
|
||||
async def send_tensor(self, shard: Shard, tensor: np.ndarray, inference_state: Optional[dict] = None, request_id: Optional[str] = None) -> Optional[np.array]:
|
||||
await self._ensure_connected()
|
||||
request = node_service_pb2.TensorRequest(
|
||||
shard=node_service_pb2.Shard(
|
||||
model_id=shard.model_id,
|
||||
start_layer=shard.start_layer,
|
||||
end_layer=shard.end_layer,
|
||||
n_layers=shard.n_layers,
|
||||
),
|
||||
tensor=node_service_pb2.Tensor(tensor_data=tensor.tobytes(), shape=tensor.shape, dtype=str(tensor.dtype)),
|
||||
request_id=request_id,
|
||||
inference_state=None if inference_state is None else self.serialize_inference_state(inference_state)
|
||||
)
|
||||
response = await self.stub.SendTensor(request)
|
||||
|
||||
if not response.tensor_data or not response.shape or not response.dtype:
|
||||
return None
|
||||
|
||||
return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
|
||||
|
||||
async def send_example(self, shard: Shard, example: np.ndarray, target: np.ndarray, length: np.ndarray, train: bool, request_id: Optional[str] = None) -> Optional[np.array]:
|
||||
await self._ensure_connected()
|
||||
request = node_service_pb2.ExampleRequest(
|
||||
shard=node_service_pb2.Shard(
|
||||
model_id=shard.model_id,
|
||||
start_layer=shard.start_layer,
|
||||
end_layer=shard.end_layer,
|
||||
n_layers=shard.n_layers,
|
||||
),
|
||||
example=node_service_pb2.Tensor(tensor_data=example.tobytes(), shape=example.shape, dtype=str(example.dtype)),
|
||||
target=node_service_pb2.Tensor(tensor_data=target.tobytes(), shape=target.shape, dtype=str(target.dtype)),
|
||||
length=node_service_pb2.Tensor(tensor_data=length.tobytes(), shape=length.shape, dtype=str(length.dtype)),
|
||||
train=train,
|
||||
request_id=request_id,
|
||||
)
|
||||
response = await self.stub.SendExample(request)
|
||||
loss = response.loss
|
||||
if train and not shard.is_first_layer():
|
||||
grads = np.frombuffer(response.grads.tensor_data, dtype=np.dtype(response.grads.dtype)).reshape(response.grads.shape)
|
||||
return loss, grads
|
||||
else:
|
||||
return loss
|
||||
|
||||
async def send_loss(self, shard: Shard, tensor: np.ndarray, request_id: Optional[str] = None) -> Optional[np.array]:
|
||||
await self._ensure_connected()
|
||||
request = node_service_pb2.TensorRequest(
|
||||
shard=node_service_pb2.Shard(
|
||||
model_id=shard.model_id,
|
||||
start_layer=shard.start_layer,
|
||||
end_layer=shard.end_layer,
|
||||
n_layers=shard.n_layers,
|
||||
),
|
||||
tensor=node_service_pb2.Tensor(tensor_data=tensor.tobytes(), shape=tensor.shape, dtype=str(tensor.dtype)),
|
||||
request_id=request_id,
|
||||
)
|
||||
response = await self.stub.SendLoss(request)
|
||||
|
||||
if not response.tensor_data or not response.shape or not response.dtype:
|
||||
return None
|
||||
|
||||
return np.frombuffer(response.tensor_data, dtype=np.dtype(response.dtype)).reshape(response.shape)
|
||||
|
||||
async def collect_topology(self, visited: set[str], max_depth: int) -> Topology:
|
||||
await self._ensure_connected()
|
||||
request = node_service_pb2.CollectTopologyRequest(visited=visited, max_depth=max_depth)
|
||||
response = await self.stub.CollectTopology(request)
|
||||
topology = Topology()
|
||||
for node_id, capabilities in response.nodes.items():
|
||||
device_capabilities = DeviceCapabilities(
|
||||
model=capabilities.model, chip=capabilities.chip, memory=capabilities.memory, flops=DeviceFlops(fp16=capabilities.flops.fp16, fp32=capabilities.flops.fp32, int8=capabilities.flops.int8)
|
||||
)
|
||||
topology.update_node(node_id, device_capabilities)
|
||||
for node_id, peer_connections in response.peer_graph.items():
|
||||
for conn in peer_connections.connections:
|
||||
topology.add_edge(node_id, conn.to_id, conn.description)
|
||||
return topology
|
||||
|
||||
async def send_result(self, request_id: str, result: List[int], is_finished: bool) -> None:
|
||||
await self._ensure_connected()
|
||||
tensor = None
|
||||
if isinstance(result, np.ndarray):
|
||||
tensor = node_service_pb2.Tensor(tensor_data=result.tobytes(), shape=result.shape, dtype=str(result.dtype))
|
||||
result = []
|
||||
request = node_service_pb2.SendResultRequest(request_id=request_id, result=result, tensor=tensor, is_finished=is_finished)
|
||||
await self.stub.SendResult(request)
|
||||
|
||||
async def send_opaque_status(self, request_id: str, status: str) -> None:
|
||||
await self._ensure_connected()
|
||||
request = node_service_pb2.SendOpaqueStatusRequest(request_id=request_id, status=status)
|
||||
await asyncio.wait_for(self.stub.SendOpaqueStatus(request), timeout=10.0)
|
||||
|
||||
def serialize_inference_state(self, inference_state: dict) -> node_service_pb2.InferenceState:
|
||||
proto_inference_state = node_service_pb2.InferenceState()
|
||||
other_data = {}
|
||||
for k, v in inference_state.items():
|
||||
if isinstance(v, mx.array):
|
||||
np_array = np.array(v)
|
||||
tensor_data = node_service_pb2.Tensor(tensor_data=np_array.tobytes(), shape=list(np_array.shape), dtype=str(np_array.dtype))
|
||||
proto_inference_state.tensor_data[k].CopyFrom(tensor_data)
|
||||
elif isinstance(v, list) and all(isinstance(item, mx.array) for item in v):
|
||||
tensor_list = node_service_pb2.TensorList()
|
||||
for tensor in v:
|
||||
np_array = np.array(tensor)
|
||||
tensor_data = node_service_pb2.Tensor(tensor_data=np_array.tobytes(), shape=list(np_array.shape), dtype=str(np_array.dtype))
|
||||
tensor_list.tensors.append(tensor_data)
|
||||
proto_inference_state.tensor_list_data[k].CopyFrom(tensor_list)
|
||||
else:
|
||||
# For non-tensor data, we'll still use JSON
|
||||
other_data[k] = v
|
||||
if other_data:
|
||||
proto_inference_state.other_data_json = json.dumps(other_data)
|
||||
return proto_inference_state
|
||||
@@ -1,173 +0,0 @@
|
||||
import grpc
|
||||
from concurrent import futures
|
||||
import numpy as np
|
||||
from asyncio import CancelledError
|
||||
|
||||
import platform
|
||||
|
||||
from . import node_service_pb2
|
||||
from . import node_service_pb2_grpc
|
||||
from exo import DEBUG
|
||||
from exo.inference.shard import Shard
|
||||
from exo.orchestration import Node
|
||||
import json
|
||||
|
||||
if platform.system().lower() == "darwin" and platform.machine().lower() == "arm64":
|
||||
import mlx.core as mx
|
||||
else:
|
||||
import numpy as mx
|
||||
|
||||
|
||||
class GRPCServer(node_service_pb2_grpc.NodeServiceServicer):
|
||||
def __init__(self, node: Node, host: str, port: int):
|
||||
self.node = node
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.server = None
|
||||
|
||||
async def start(self) -> None:
|
||||
self.server = grpc.aio.server(
|
||||
futures.ThreadPoolExecutor(max_workers=32),
|
||||
options=[
|
||||
("grpc.max_metadata_size", 32*1024*1024),
|
||||
("grpc.max_send_message_length", 256*1024*1024),
|
||||
("grpc.max_receive_message_length", 256*1024*1024),
|
||||
("grpc.keepalive_time_ms", 10000),
|
||||
("grpc.keepalive_timeout_ms", 5000),
|
||||
("grpc.http2.max_pings_without_data", 0),
|
||||
("grpc.http2.min_time_between_pings_ms", 10000),
|
||||
("grpc.http2.min_ping_interval_without_data_ms", 5000),
|
||||
("grpc.max_concurrent_streams", 100),
|
||||
("grpc.tcp_nodelay", 1),
|
||||
("grpc.optimization_target", "throughput"),
|
||||
("grpc.keepalive_permit_without_calls", 1),
|
||||
("grpc.http2.max_concurrent_streams", 0), # Unlimited concurrent streams
|
||||
],
|
||||
)
|
||||
node_service_pb2_grpc.add_NodeServiceServicer_to_server(self, self.server)
|
||||
listen_addr = f"{self.host}:{self.port}"
|
||||
self.server.add_insecure_port(listen_addr)
|
||||
await self.server.start()
|
||||
if DEBUG >= 1: print(f"Server started, listening on {listen_addr}")
|
||||
|
||||
async def stop(self) -> None:
|
||||
if self.server:
|
||||
try:
|
||||
await self.server.stop(grace=5)
|
||||
await self.server.wait_for_termination()
|
||||
except CancelledError:
|
||||
pass
|
||||
if DEBUG >= 1: print("Server stopped and all connections are closed")
|
||||
|
||||
async def SendPrompt(self, request, context):
|
||||
shard = Shard(
|
||||
model_id=request.shard.model_id,
|
||||
start_layer=request.shard.start_layer,
|
||||
end_layer=request.shard.end_layer,
|
||||
n_layers=request.shard.n_layers,
|
||||
)
|
||||
prompt = request.prompt
|
||||
request_id = request.request_id
|
||||
inference_state = None if request.inference_state is None else self.deserialize_inference_state(request.inference_state)
|
||||
result = await self.node.process_prompt(shard, prompt, request_id, inference_state)
|
||||
if DEBUG >= 5: print(f"SendPrompt {shard=} {prompt=} {request_id=} result: {result}")
|
||||
tensor_data = result.tobytes() if result is not None else None
|
||||
return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
|
||||
|
||||
async def SendTensor(self, request, context):
|
||||
shard = Shard(
|
||||
model_id=request.shard.model_id,
|
||||
start_layer=request.shard.start_layer,
|
||||
end_layer=request.shard.end_layer,
|
||||
n_layers=request.shard.n_layers,
|
||||
)
|
||||
tensor = np.frombuffer(request.tensor.tensor_data, dtype=np.dtype(request.tensor.dtype)).reshape(request.tensor.shape)
|
||||
request_id = request.request_id
|
||||
|
||||
inference_state = None if request.inference_state is None else self.deserialize_inference_state(request.inference_state)
|
||||
|
||||
result = await self.node.process_tensor(shard, tensor, request_id, inference_state)
|
||||
if DEBUG >= 5: print(f"SendTensor tensor {shard=} {tensor=} {request_id=} result: {result}")
|
||||
tensor_data = result.tobytes() if result is not None else None
|
||||
return node_service_pb2.Tensor(tensor_data=tensor_data, shape=result.shape, dtype=str(result.dtype)) if result is not None else node_service_pb2.Tensor()
|
||||
|
||||
async def SendExample(self, request, context):
|
||||
shard = Shard(
|
||||
model_id=request.shard.model_id,
|
||||
start_layer=request.shard.start_layer,
|
||||
end_layer=request.shard.end_layer,
|
||||
n_layers=request.shard.n_layers,
|
||||
)
|
||||
example = np.frombuffer(request.example.tensor_data, dtype=np.dtype(request.example.dtype)).reshape(request.example.shape)
|
||||
target = np.frombuffer(request.target.tensor_data, dtype=np.dtype(request.target.dtype)).reshape(request.target.shape)
|
||||
length = np.frombuffer(request.length.tensor_data, dtype=np.dtype(request.length.dtype)).reshape(request.length.shape)
|
||||
train = request.train
|
||||
request_id = request.request_id
|
||||
|
||||
if train and not shard.is_first_layer():
|
||||
loss, grad = await self.node.process_example(shard, example, target, length, train, request_id)
|
||||
tensor_data = grad.tobytes()
|
||||
grad_tensor = node_service_pb2.Tensor(tensor_data=tensor_data, shape=grad.shape, dtype=str(grad.dtype))
|
||||
return node_service_pb2.Loss(loss=loss, grads=grad_tensor)
|
||||
else:
|
||||
loss = await self.node.process_example(shard, example, target, length, train, request_id)
|
||||
return node_service_pb2.Loss(loss=loss, grads=None)
|
||||
|
||||
async def CollectTopology(self, request, context):
|
||||
max_depth = request.max_depth
|
||||
visited = set(request.visited)
|
||||
topology = self.node.current_topology
|
||||
nodes = {
|
||||
node_id:
|
||||
node_service_pb2.DeviceCapabilities(
|
||||
model=cap.model,
|
||||
chip=cap.chip,
|
||||
memory=cap.memory,
|
||||
flops=node_service_pb2.DeviceFlops(fp32=cap.flops.fp32, fp16=cap.flops.fp16, int8=cap.flops.int8),
|
||||
)
|
||||
for node_id, cap in topology.nodes.items()
|
||||
}
|
||||
peer_graph = {
|
||||
node_id: node_service_pb2.PeerConnections(connections=[node_service_pb2.PeerConnection(to_id=conn.to_id, description=conn.description) for conn in connections])
|
||||
for node_id, connections in topology.peer_graph.items()
|
||||
}
|
||||
if DEBUG >= 5: print(f"CollectTopology {max_depth=} {visited=} {nodes=} {peer_graph=}")
|
||||
return node_service_pb2.Topology(nodes=nodes, peer_graph=peer_graph)
|
||||
|
||||
async def SendResult(self, request, context):
|
||||
request_id = request.request_id
|
||||
result = request.result
|
||||
is_finished = request.is_finished
|
||||
img = request.tensor
|
||||
if DEBUG >= 5: print(f"Received SendResult request: {request_id=} {result=} {is_finished=}")
|
||||
result = list(result)
|
||||
if len(img.tensor_data) > 0:
|
||||
result = np.frombuffer(img.tensor_data, dtype=np.dtype(img.dtype)).reshape(img.shape)
|
||||
self.node.on_token.trigger_all(request_id, result, is_finished)
|
||||
return node_service_pb2.Empty()
|
||||
|
||||
async def SendOpaqueStatus(self, request, context):
|
||||
request_id = request.request_id
|
||||
status = request.status
|
||||
if DEBUG >= 8: print(f"Received SendOpaqueStatus request: {request_id=} {status=}")
|
||||
self.node.on_opaque_status.trigger_all(request_id, status)
|
||||
return node_service_pb2.Empty()
|
||||
|
||||
async def HealthCheck(self, request, context):
|
||||
return node_service_pb2.HealthCheckResponse(is_healthy=True)
|
||||
|
||||
def deserialize_inference_state(self, inference_state_proto: node_service_pb2.InferenceState) -> dict:
|
||||
inference_state = {}
|
||||
|
||||
for k, tensor_data in inference_state_proto.tensor_data.items():
|
||||
np_array = np.frombuffer(tensor_data.tensor_data, dtype=tensor_data.dtype).reshape(tensor_data.shape)
|
||||
inference_state[k] = mx.array(np_array)
|
||||
|
||||
for k, tensor_list in inference_state_proto.tensor_list_data.items():
|
||||
inference_state[k] = [mx.array(np.frombuffer(tensor.tensor_data, dtype=tensor.dtype).reshape(tensor.shape)) for tensor in tensor_list.tensors]
|
||||
|
||||
if inference_state_proto.other_data_json:
|
||||
other_data = json.loads(inference_state_proto.other_data_json)
|
||||
inference_state.update(other_data)
|
||||
|
||||
return inference_state
|
||||
@@ -1,116 +0,0 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package node_service;
|
||||
|
||||
service NodeService {
|
||||
rpc SendPrompt (PromptRequest) returns (Tensor) {}
|
||||
rpc SendTensor (TensorRequest) returns (Tensor) {}
|
||||
rpc SendExample (ExampleRequest) returns (Loss) {}
|
||||
rpc CollectTopology (CollectTopologyRequest) returns (Topology) {}
|
||||
rpc SendResult (SendResultRequest) returns (Empty) {}
|
||||
rpc SendOpaqueStatus (SendOpaqueStatusRequest) returns (Empty) {}
|
||||
rpc HealthCheck (HealthCheckRequest) returns (HealthCheckResponse) {}
|
||||
}
|
||||
|
||||
message Shard {
|
||||
string model_id = 1;
|
||||
int32 start_layer = 2;
|
||||
int32 end_layer = 3;
|
||||
int32 n_layers = 4;
|
||||
}
|
||||
|
||||
message PromptRequest {
|
||||
Shard shard = 1;
|
||||
string prompt = 2;
|
||||
optional string request_id = 3;
|
||||
optional InferenceState inference_state = 4;
|
||||
}
|
||||
|
||||
message TensorRequest {
|
||||
Shard shard = 1;
|
||||
Tensor tensor = 2;
|
||||
optional string request_id = 3;
|
||||
optional InferenceState inference_state = 4;
|
||||
}
|
||||
|
||||
message ExampleRequest {
|
||||
Shard shard = 1;
|
||||
Tensor example = 2;
|
||||
Tensor target = 3;
|
||||
Tensor length = 4;
|
||||
bool train = 5;
|
||||
optional string request_id = 6;
|
||||
}
|
||||
|
||||
message Loss {
|
||||
float loss = 1;
|
||||
optional Tensor grads = 2;
|
||||
}
|
||||
|
||||
message Tensor {
|
||||
bytes tensor_data = 1;
|
||||
repeated int32 shape = 2;
|
||||
string dtype = 3;
|
||||
}
|
||||
|
||||
message TensorList {
|
||||
repeated Tensor tensors = 1;
|
||||
}
|
||||
|
||||
message InferenceState {
|
||||
map<string, Tensor> tensor_data = 1;
|
||||
map<string, TensorList> tensor_list_data = 2;
|
||||
string other_data_json = 3;
|
||||
}
|
||||
|
||||
message CollectTopologyRequest {
|
||||
repeated string visited = 1;
|
||||
int32 max_depth = 2;
|
||||
}
|
||||
|
||||
message Topology {
|
||||
map<string, DeviceCapabilities> nodes = 1;
|
||||
map<string, PeerConnections> peer_graph = 2;
|
||||
}
|
||||
|
||||
message PeerConnection {
|
||||
string to_id = 1;
|
||||
optional string description = 2;
|
||||
}
|
||||
|
||||
message PeerConnections {
|
||||
repeated PeerConnection connections = 1;
|
||||
}
|
||||
|
||||
message DeviceFlops {
|
||||
double fp32 = 1;
|
||||
double fp16 = 2;
|
||||
double int8 = 3;
|
||||
}
|
||||
|
||||
message DeviceCapabilities {
|
||||
string model = 1;
|
||||
string chip = 2;
|
||||
int32 memory = 3;
|
||||
DeviceFlops flops = 4;
|
||||
}
|
||||
|
||||
message SendResultRequest {
|
||||
string request_id = 1;
|
||||
repeated int32 result = 2;
|
||||
optional Tensor tensor = 3;
|
||||
bool is_finished = 4;
|
||||
}
|
||||
|
||||
message SendOpaqueStatusRequest {
|
||||
string request_id = 1;
|
||||
string status = 2;
|
||||
}
|
||||
|
||||
message HealthCheckRequest {}
|
||||
|
||||
message HealthCheckResponse {
|
||||
bool is_healthy = 1;
|
||||
}
|
||||
|
||||
message Empty {}
|
||||
@@ -1,90 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# NO CHECKED-IN PROTOBUF GENCODE
|
||||
# source: node_service.proto
|
||||
# Protobuf Python Version: 5.27.2
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import runtime_version as _runtime_version
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
from google.protobuf.internal import builder as _builder
|
||||
_runtime_version.ValidateProtobufRuntimeVersion(
|
||||
_runtime_version.Domain.PUBLIC,
|
||||
5,
|
||||
27,
|
||||
2,
|
||||
'',
|
||||
'node_service.proto'
|
||||
)
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12node_service.proto\x12\x0cnode_service\"S\n\x05Shard\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x13\n\x0bstart_layer\x18\x02 \x01(\x05\x12\x11\n\tend_layer\x18\x03 \x01(\x05\x12\x10\n\x08n_layers\x18\x04 \x01(\x05\"\xbb\x01\n\rPromptRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12\x0e\n\x06prompt\x18\x02 \x01(\t\x12\x17\n\nrequest_id\x18\x03 \x01(\tH\x00\x88\x01\x01\x12:\n\x0finference_state\x18\x04 \x01(\x0b\x32\x1c.node_service.InferenceStateH\x01\x88\x01\x01\x42\r\n\x0b_request_idB\x12\n\x10_inference_state\"\xd1\x01\n\rTensorRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12$\n\x06tensor\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12\x17\n\nrequest_id\x18\x03 \x01(\tH\x00\x88\x01\x01\x12:\n\x0finference_state\x18\x04 \x01(\x0b\x32\x1c.node_service.InferenceStateH\x01\x88\x01\x01\x42\r\n\x0b_request_idB\x12\n\x10_inference_state\"\xde\x01\n\x0e\x45xampleRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12%\n\x07\x65xample\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12$\n\x06target\x18\x03 \x01(\x0b\x32\x14.node_service.Tensor\x12$\n\x06length\x18\x04 \x01(\x0b\x32\x14.node_service.Tensor\x12\r\n\x05train\x18\x05 \x01(\x08\x12\x17\n\nrequest_id\x18\x06 \x01(\tH\x00\x88\x01\x01\x42\r\n\x0b_request_id\"H\n\x04Loss\x12\x0c\n\x04loss\x18\x01 \x01(\x02\x12(\n\x05grads\x18\x02 \x01(\x0b\x32\x14.node_service.TensorH\x00\x88\x01\x01\x42\x08\n\x06_grads\";\n\x06Tensor\x12\x13\n\x0btensor_data\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05\x64type\x18\x03 \x01(\t\"3\n\nTensorList\x12%\n\x07tensors\x18\x01 \x03(\x0b\x32\x14.node_service.Tensor\"\xd2\x02\n\x0eInferenceState\x12\x41\n\x0btensor_data\x18\x01 \x03(\x0b\x32,.node_service.InferenceState.TensorDataEntry\x12J\n\x10tensor_list_data\x18\x02 \x03(\x0b\x32\x30.node_service.InferenceState.TensorListDataEntry\x12\x17\n\x0fother_data_json\x18\x03 \x01(\t\x1aG\n\x0fTensorDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor:\x02\x38\x01\x1aO\n\x13TensorListDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\'\n\x05value\x18\x02 \x01(\x0b\x32\x18.node_service.TensorList:\x02\x38\x01\"<\n\x16\x43ollectTopologyRequest\x12\x0f\n\x07visited\x18\x01 \x03(\t\x12\x11\n\tmax_depth\x18\x02 \x01(\x05\"\x98\x02\n\x08Topology\x12\x30\n\x05nodes\x18\x01 \x03(\x0b\x32!.node_service.Topology.NodesEntry\x12\x39\n\npeer_graph\x18\x02 \x03(\x0b\x32%.node_service.Topology.PeerGraphEntry\x1aN\n\nNodesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12/\n\x05value\x18\x02 \x01(\x0b\x32 .node_service.DeviceCapabilities:\x02\x38\x01\x1aO\n\x0ePeerGraphEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12,\n\x05value\x18\x02 \x01(\x0b\x32\x1d.node_service.PeerConnections:\x02\x38\x01\"I\n\x0ePeerConnection\x12\r\n\x05to_id\x18\x01 \x01(\t\x12\x18\n\x0b\x64\x65scription\x18\x02 \x01(\tH\x00\x88\x01\x01\x42\x0e\n\x0c_description\"D\n\x0fPeerConnections\x12\x31\n\x0b\x63onnections\x18\x01 \x03(\x0b\x32\x1c.node_service.PeerConnection\"7\n\x0b\x44\x65viceFlops\x12\x0c\n\x04\x66p32\x18\x01 \x01(\x01\x12\x0c\n\x04\x66p16\x18\x02 \x01(\x01\x12\x0c\n\x04int8\x18\x03 \x01(\x01\"k\n\x12\x44\x65viceCapabilities\x12\r\n\x05model\x18\x01 \x01(\t\x12\x0c\n\x04\x63hip\x18\x02 \x01(\t\x12\x0e\n\x06memory\x18\x03 \x01(\x05\x12(\n\x05\x66lops\x18\x04 \x01(\x0b\x32\x19.node_service.DeviceFlops\"\x82\x01\n\x11SendResultRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06result\x18\x02 \x03(\x05\x12)\n\x06tensor\x18\x03 \x01(\x0b\x32\x14.node_service.TensorH\x00\x88\x01\x01\x12\x13\n\x0bis_finished\x18\x04 \x01(\x08\x42\t\n\x07_tensor\"=\n\x17SendOpaqueStatusRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06status\x18\x02 \x01(\t\"\x14\n\x12HealthCheckRequest\")\n\x13HealthCheckResponse\x12\x12\n\nis_healthy\x18\x01 \x01(\x08\"\x07\n\x05\x45mpty2\x97\x04\n\x0bNodeService\x12\x41\n\nSendPrompt\x12\x1b.node_service.PromptRequest\x1a\x14.node_service.Tensor\"\x00\x12\x41\n\nSendTensor\x12\x1b.node_service.TensorRequest\x1a\x14.node_service.Tensor\"\x00\x12\x41\n\x0bSendExample\x12\x1c.node_service.ExampleRequest\x1a\x12.node_service.Loss\"\x00\x12Q\n\x0f\x43ollectTopology\x12$.node_service.CollectTopologyRequest\x1a\x16.node_service.Topology\"\x00\x12\x44\n\nSendResult\x12\x1f.node_service.SendResultRequest\x1a\x13.node_service.Empty\"\x00\x12P\n\x10SendOpaqueStatus\x12%.node_service.SendOpaqueStatusRequest\x1a\x13.node_service.Empty\"\x00\x12T\n\x0bHealthCheck\x12 .node_service.HealthCheckRequest\x1a!.node_service.HealthCheckResponse\"\x00\x62\x06proto3')
|
||||
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'node_service_pb2', _globals)
|
||||
if not _descriptor._USE_C_DESCRIPTORS:
|
||||
DESCRIPTOR._loaded_options = None
|
||||
_globals['_INFERENCESTATE_TENSORDATAENTRY']._loaded_options = None
|
||||
_globals['_INFERENCESTATE_TENSORDATAENTRY']._serialized_options = b'8\001'
|
||||
_globals['_INFERENCESTATE_TENSORLISTDATAENTRY']._loaded_options = None
|
||||
_globals['_INFERENCESTATE_TENSORLISTDATAENTRY']._serialized_options = b'8\001'
|
||||
_globals['_TOPOLOGY_NODESENTRY']._loaded_options = None
|
||||
_globals['_TOPOLOGY_NODESENTRY']._serialized_options = b'8\001'
|
||||
_globals['_TOPOLOGY_PEERGRAPHENTRY']._loaded_options = None
|
||||
_globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_options = b'8\001'
|
||||
_globals['_SHARD']._serialized_start=36
|
||||
_globals['_SHARD']._serialized_end=119
|
||||
_globals['_PROMPTREQUEST']._serialized_start=122
|
||||
_globals['_PROMPTREQUEST']._serialized_end=309
|
||||
_globals['_TENSORREQUEST']._serialized_start=312
|
||||
_globals['_TENSORREQUEST']._serialized_end=521
|
||||
_globals['_EXAMPLEREQUEST']._serialized_start=524
|
||||
_globals['_EXAMPLEREQUEST']._serialized_end=746
|
||||
_globals['_LOSS']._serialized_start=748
|
||||
_globals['_LOSS']._serialized_end=820
|
||||
_globals['_TENSOR']._serialized_start=822
|
||||
_globals['_TENSOR']._serialized_end=881
|
||||
_globals['_TENSORLIST']._serialized_start=883
|
||||
_globals['_TENSORLIST']._serialized_end=934
|
||||
_globals['_INFERENCESTATE']._serialized_start=937
|
||||
_globals['_INFERENCESTATE']._serialized_end=1275
|
||||
_globals['_INFERENCESTATE_TENSORDATAENTRY']._serialized_start=1123
|
||||
_globals['_INFERENCESTATE_TENSORDATAENTRY']._serialized_end=1194
|
||||
_globals['_INFERENCESTATE_TENSORLISTDATAENTRY']._serialized_start=1196
|
||||
_globals['_INFERENCESTATE_TENSORLISTDATAENTRY']._serialized_end=1275
|
||||
_globals['_COLLECTTOPOLOGYREQUEST']._serialized_start=1277
|
||||
_globals['_COLLECTTOPOLOGYREQUEST']._serialized_end=1337
|
||||
_globals['_TOPOLOGY']._serialized_start=1340
|
||||
_globals['_TOPOLOGY']._serialized_end=1620
|
||||
_globals['_TOPOLOGY_NODESENTRY']._serialized_start=1461
|
||||
_globals['_TOPOLOGY_NODESENTRY']._serialized_end=1539
|
||||
_globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_start=1541
|
||||
_globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_end=1620
|
||||
_globals['_PEERCONNECTION']._serialized_start=1622
|
||||
_globals['_PEERCONNECTION']._serialized_end=1695
|
||||
_globals['_PEERCONNECTIONS']._serialized_start=1697
|
||||
_globals['_PEERCONNECTIONS']._serialized_end=1765
|
||||
_globals['_DEVICEFLOPS']._serialized_start=1767
|
||||
_globals['_DEVICEFLOPS']._serialized_end=1822
|
||||
_globals['_DEVICECAPABILITIES']._serialized_start=1824
|
||||
_globals['_DEVICECAPABILITIES']._serialized_end=1931
|
||||
_globals['_SENDRESULTREQUEST']._serialized_start=1934
|
||||
_globals['_SENDRESULTREQUEST']._serialized_end=2064
|
||||
_globals['_SENDOPAQUESTATUSREQUEST']._serialized_start=2066
|
||||
_globals['_SENDOPAQUESTATUSREQUEST']._serialized_end=2127
|
||||
_globals['_HEALTHCHECKREQUEST']._serialized_start=2129
|
||||
_globals['_HEALTHCHECKREQUEST']._serialized_end=2149
|
||||
_globals['_HEALTHCHECKRESPONSE']._serialized_start=2151
|
||||
_globals['_HEALTHCHECKRESPONSE']._serialized_end=2192
|
||||
_globals['_EMPTY']._serialized_start=2194
|
||||
_globals['_EMPTY']._serialized_end=2201
|
||||
_globals['_NODESERVICE']._serialized_start=2204
|
||||
_globals['_NODESERVICE']._serialized_end=2739
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
@@ -1,355 +0,0 @@
|
||||
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
||||
"""Client and server classes corresponding to protobuf-defined services."""
|
||||
import grpc
|
||||
import warnings
|
||||
|
||||
from . import node_service_pb2 as node__service__pb2
|
||||
|
||||
GRPC_GENERATED_VERSION = '1.67.0'
|
||||
GRPC_VERSION = grpc.__version__
|
||||
_version_not_supported = False
|
||||
|
||||
try:
|
||||
from grpc._utilities import first_version_is_lower
|
||||
_version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
|
||||
except ImportError:
|
||||
_version_not_supported = True
|
||||
|
||||
if _version_not_supported:
|
||||
raise RuntimeError(
|
||||
f'The grpc package installed is at version {GRPC_VERSION},'
|
||||
+ f' but the generated code in node_service_pb2_grpc.py depends on'
|
||||
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
|
||||
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
|
||||
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
|
||||
)
|
||||
|
||||
|
||||
class NodeServiceStub(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
def __init__(self, channel):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
channel: A grpc.Channel.
|
||||
"""
|
||||
self.SendPrompt = channel.unary_unary(
|
||||
'/node_service.NodeService/SendPrompt',
|
||||
request_serializer=node__service__pb2.PromptRequest.SerializeToString,
|
||||
response_deserializer=node__service__pb2.Tensor.FromString,
|
||||
_registered_method=True)
|
||||
self.SendTensor = channel.unary_unary(
|
||||
'/node_service.NodeService/SendTensor',
|
||||
request_serializer=node__service__pb2.TensorRequest.SerializeToString,
|
||||
response_deserializer=node__service__pb2.Tensor.FromString,
|
||||
_registered_method=True)
|
||||
self.SendExample = channel.unary_unary(
|
||||
'/node_service.NodeService/SendExample',
|
||||
request_serializer=node__service__pb2.ExampleRequest.SerializeToString,
|
||||
response_deserializer=node__service__pb2.Loss.FromString,
|
||||
_registered_method=True)
|
||||
self.CollectTopology = channel.unary_unary(
|
||||
'/node_service.NodeService/CollectTopology',
|
||||
request_serializer=node__service__pb2.CollectTopologyRequest.SerializeToString,
|
||||
response_deserializer=node__service__pb2.Topology.FromString,
|
||||
_registered_method=True)
|
||||
self.SendResult = channel.unary_unary(
|
||||
'/node_service.NodeService/SendResult',
|
||||
request_serializer=node__service__pb2.SendResultRequest.SerializeToString,
|
||||
response_deserializer=node__service__pb2.Empty.FromString,
|
||||
_registered_method=True)
|
||||
self.SendOpaqueStatus = channel.unary_unary(
|
||||
'/node_service.NodeService/SendOpaqueStatus',
|
||||
request_serializer=node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
|
||||
response_deserializer=node__service__pb2.Empty.FromString,
|
||||
_registered_method=True)
|
||||
self.HealthCheck = channel.unary_unary(
|
||||
'/node_service.NodeService/HealthCheck',
|
||||
request_serializer=node__service__pb2.HealthCheckRequest.SerializeToString,
|
||||
response_deserializer=node__service__pb2.HealthCheckResponse.FromString,
|
||||
_registered_method=True)
|
||||
|
||||
|
||||
class NodeServiceServicer(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
def SendPrompt(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def SendTensor(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def SendExample(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def CollectTopology(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def SendResult(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def SendOpaqueStatus(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def HealthCheck(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
|
||||
def add_NodeServiceServicer_to_server(servicer, server):
|
||||
rpc_method_handlers = {
|
||||
'SendPrompt': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.SendPrompt,
|
||||
request_deserializer=node__service__pb2.PromptRequest.FromString,
|
||||
response_serializer=node__service__pb2.Tensor.SerializeToString,
|
||||
),
|
||||
'SendTensor': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.SendTensor,
|
||||
request_deserializer=node__service__pb2.TensorRequest.FromString,
|
||||
response_serializer=node__service__pb2.Tensor.SerializeToString,
|
||||
),
|
||||
'SendExample': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.SendExample,
|
||||
request_deserializer=node__service__pb2.ExampleRequest.FromString,
|
||||
response_serializer=node__service__pb2.Loss.SerializeToString,
|
||||
),
|
||||
'CollectTopology': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.CollectTopology,
|
||||
request_deserializer=node__service__pb2.CollectTopologyRequest.FromString,
|
||||
response_serializer=node__service__pb2.Topology.SerializeToString,
|
||||
),
|
||||
'SendResult': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.SendResult,
|
||||
request_deserializer=node__service__pb2.SendResultRequest.FromString,
|
||||
response_serializer=node__service__pb2.Empty.SerializeToString,
|
||||
),
|
||||
'SendOpaqueStatus': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.SendOpaqueStatus,
|
||||
request_deserializer=node__service__pb2.SendOpaqueStatusRequest.FromString,
|
||||
response_serializer=node__service__pb2.Empty.SerializeToString,
|
||||
),
|
||||
'HealthCheck': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.HealthCheck,
|
||||
request_deserializer=node__service__pb2.HealthCheckRequest.FromString,
|
||||
response_serializer=node__service__pb2.HealthCheckResponse.SerializeToString,
|
||||
),
|
||||
}
|
||||
generic_handler = grpc.method_handlers_generic_handler(
|
||||
'node_service.NodeService', rpc_method_handlers)
|
||||
server.add_generic_rpc_handlers((generic_handler,))
|
||||
server.add_registered_method_handlers('node_service.NodeService', rpc_method_handlers)
|
||||
|
||||
|
||||
# This class is part of an EXPERIMENTAL API.
|
||||
class NodeService(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
@staticmethod
|
||||
def SendPrompt(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/node_service.NodeService/SendPrompt',
|
||||
node__service__pb2.PromptRequest.SerializeToString,
|
||||
node__service__pb2.Tensor.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def SendTensor(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/node_service.NodeService/SendTensor',
|
||||
node__service__pb2.TensorRequest.SerializeToString,
|
||||
node__service__pb2.Tensor.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def SendExample(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/node_service.NodeService/SendExample',
|
||||
node__service__pb2.ExampleRequest.SerializeToString,
|
||||
node__service__pb2.Loss.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def CollectTopology(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/node_service.NodeService/CollectTopology',
|
||||
node__service__pb2.CollectTopologyRequest.SerializeToString,
|
||||
node__service__pb2.Topology.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def SendResult(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/node_service.NodeService/SendResult',
|
||||
node__service__pb2.SendResultRequest.SerializeToString,
|
||||
node__service__pb2.Empty.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def SendOpaqueStatus(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/node_service.NodeService/SendOpaqueStatus',
|
||||
node__service__pb2.SendOpaqueStatusRequest.SerializeToString,
|
||||
node__service__pb2.Empty.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def HealthCheck(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/node_service.NodeService/HealthCheck',
|
||||
node__service__pb2.HealthCheckRequest.SerializeToString,
|
||||
node__service__pb2.HealthCheckResponse.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
@@ -1,101 +0,0 @@
|
||||
import os
|
||||
import asyncio
|
||||
from typing import Dict, List, Callable, Optional
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from exo.networking.discovery import Discovery
|
||||
from exo.topology.device_capabilities import DeviceCapabilities
|
||||
from exo.networking.manual.network_topology_config import NetworkTopology, PeerConfig
|
||||
from exo.helpers import DEBUG_DISCOVERY
|
||||
from exo.networking.peer_handle import PeerHandle
|
||||
|
||||
|
||||
class ManualDiscovery(Discovery):
|
||||
def __init__(
|
||||
self,
|
||||
network_config_path: str,
|
||||
node_id: str,
|
||||
create_peer_handle: Callable[[str, str, str, DeviceCapabilities], PeerHandle],
|
||||
):
|
||||
self.network_config_path = network_config_path
|
||||
self.node_id = node_id
|
||||
self.create_peer_handle = create_peer_handle
|
||||
|
||||
self.listen_task = None
|
||||
self.known_peers: Dict[str, PeerHandle] = {}
|
||||
|
||||
self._cached_peers: Dict[str, PeerConfig] = {}
|
||||
self._last_modified_time: Optional[float] = None
|
||||
self._file_executor = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
async def start(self) -> None:
|
||||
self.listen_task = asyncio.create_task(self.task_find_peers_from_config())
|
||||
|
||||
async def stop(self) -> None:
|
||||
if self.listen_task: self.listen_task.cancel()
|
||||
self._file_executor.shutdown(wait=True)
|
||||
|
||||
async def discover_peers(self, wait_for_peers: int = 0) -> List[PeerHandle]:
|
||||
if wait_for_peers > 0:
|
||||
while len(self.known_peers) < wait_for_peers:
|
||||
if DEBUG_DISCOVERY >= 2: print(f"Current peers: {len(self.known_peers)}/{wait_for_peers}. Waiting for more peers...")
|
||||
await asyncio.sleep(0.1)
|
||||
if DEBUG_DISCOVERY >= 2: print(f"Discovered peers: {[peer.id() for peer in self.known_peers.values()]}")
|
||||
return list(self.known_peers.values())
|
||||
|
||||
async def task_find_peers_from_config(self):
|
||||
if DEBUG_DISCOVERY >= 2: print("Starting task to find peers from config...")
|
||||
while True:
|
||||
peers_from_config = await self._get_peers()
|
||||
new_known_peers = {}
|
||||
for peer_id, peer_config in peers_from_config.items():
|
||||
try:
|
||||
if DEBUG_DISCOVERY >= 2: print(f"Checking peer {peer_id=} at {peer_config.address}:{peer_config.port}")
|
||||
peer = self.known_peers.get(peer_id)
|
||||
if not peer:
|
||||
if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} not found in known peers. Adding.")
|
||||
peer = self.create_peer_handle(peer_id, f"{peer_config.address}:{peer_config.port}", "MAN", peer_config.device_capabilities)
|
||||
is_healthy = await peer.health_check()
|
||||
if is_healthy:
|
||||
if DEBUG_DISCOVERY >= 2: print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is healthy.")
|
||||
new_known_peers[peer_id] = peer
|
||||
elif DEBUG_DISCOVERY >= 2:
|
||||
print(f"{peer_id=} at {peer_config.address}:{peer_config.port} is not healthy. Removing.")
|
||||
except Exception as e:
|
||||
if DEBUG_DISCOVERY >= 2: print(f"Exception occurred when attempting to add {peer_id=}: {e}")
|
||||
self.known_peers = new_known_peers
|
||||
await asyncio.sleep(5.0)
|
||||
|
||||
if DEBUG_DISCOVERY >= 2: print(f"Current known peers: {[peer.id() for peer in self.known_peers.values()]}")
|
||||
|
||||
async def _get_peers(self):
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
current_mtime = await loop.run_in_executor(self._file_executor, os.path.getmtime, self.network_config_path)
|
||||
|
||||
if (self._cached_peers is not None and self._last_modified_time is not None and current_mtime <= self._last_modified_time):
|
||||
return self._cached_peers
|
||||
|
||||
topology = await loop.run_in_executor(self._file_executor, NetworkTopology.from_path, self.network_config_path)
|
||||
|
||||
if self.node_id not in topology.peers:
|
||||
raise ValueError(
|
||||
f"Node ID {self.node_id} not found in network config file "
|
||||
f"{self.network_config_path}. Please run with `node_id` set to "
|
||||
f"one of the keys in the config file: {[k for k, _ in topology.peers]}"
|
||||
)
|
||||
|
||||
peers_in_network = topology.peers
|
||||
peers_in_network.pop(self.node_id)
|
||||
|
||||
self._cached_peers = peers_in_network
|
||||
self._last_modified_time = current_mtime
|
||||
|
||||
return peers_in_network
|
||||
|
||||
except Exception as e:
|
||||
if DEBUG_DISCOVERY >= 2:
|
||||
print(f"Error when loading network config file from {self.network_config_path}. "
|
||||
f"Please update the config file in order to successfully discover peers. "
|
||||
f"Exception: {e}")
|
||||
return self._cached_peers
|
||||
@@ -1,31 +0,0 @@
|
||||
from typing import Dict
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from exo.topology.device_capabilities import DeviceCapabilities
|
||||
|
||||
|
||||
class PeerConfig(BaseModel):
|
||||
address: str
|
||||
port: int
|
||||
device_capabilities: DeviceCapabilities
|
||||
|
||||
|
||||
class NetworkTopology(BaseModel):
|
||||
"""Configuration of the network. A collection outlining all nodes in the network, including the node this is running from."""
|
||||
|
||||
peers: Dict[str, PeerConfig]
|
||||
"""
|
||||
node_id to PeerConfig. The node_id is used to identify the peer in the discovery process. The node that this is running from should be included in this dict.
|
||||
"""
|
||||
@classmethod
|
||||
def from_path(cls, path: str) -> "NetworkTopology":
|
||||
try:
|
||||
with open(path, "r") as f:
|
||||
config_data = f.read()
|
||||
except FileNotFoundError as e:
|
||||
raise FileNotFoundError(f"Config file not found at {path}") from e
|
||||
|
||||
try:
|
||||
return cls.model_validate_json(config_data)
|
||||
except ValidationError as e:
|
||||
raise ValueError(f"Error validating network topology config from {path}: {e}") from e
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user