Compare commits


7 Commits

Author SHA1 Message Date
Sami Khan
3e9eb93f82 exo theme 2026-02-19 04:21:58 +05:00
Sami Khan
ab622f79c3 EXO iOS app 2026-02-18 06:40:07 +05:00
Alex Cheema
d6301ed593 dashboard: redesign downloads page as model×node table (#1465)
## Motivation

The current downloads page uses a node-centric card grid layout that is
messy and hard to read — the same model across different nodes appears
in separate cards, and deep nesting wastes space. This makes it
difficult to quickly see which models are on which nodes.

## Changes

Rewrote the downloads page
(`dashboard/src/routes/downloads/+page.svelte`) from a card grid to a
clean table layout:

- **Rows** = models (unique across all nodes)
- **Columns** = nodes (with disk free shown in header)
- **Cells** show status at a glance:
  - ✅ Green checkmark + size for completed downloads
  - 🟡 Yellow percentage + mini progress bar + speed for active downloads
  - `...` for pending downloads
  - ❌ Red X for failed downloads
  - `--` for models not present on a node
- Delete/download action buttons appear on row hover
- Model name column is sticky on horizontal scroll (for many-node
clusters)
- Models sorted by number of nodes with completed downloads
- Imported shared utilities from `$lib/utils/downloads` instead of
inline re-implementations
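
For reference, the row/column pivot the table performs can be sketched in Python (the real page is Svelte; the event shape and field names here are illustrative, not the actual API):

```python
from collections import defaultdict

# Hypothetical flat event list: one entry per (model, node) download.
events = [
    {"model": "llama-3-8b", "node": "mac-1", "status": "complete"},
    {"model": "llama-3-8b", "node": "mac-2", "status": "downloading"},
    {"model": "qwen-2.5-72b", "node": "mac-1", "status": "failed"},
]

# Pivot into rows = models, columns = nodes, as the table does.
nodes = sorted({e["node"] for e in events})
table: defaultdict[str, dict[str, str]] = defaultdict(dict)
for e in events:
    table[e["model"]][e["node"]] = e["status"]

# Sort models by number of nodes with completed downloads (descending).
def completed_count(model: str) -> int:
    return sum(1 for s in table[model].values() if s == "complete")

rows = sorted(table, key=completed_count, reverse=True)
for model in rows:
    # "--" marks models not present on a node, matching the cell legend above.
    cells = [table[model].get(n, "--") for n in nodes]
    print(model, cells)
```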

### Backend: model directory in download events

- Added `model_directory` field to `BaseDownloadProgress` so all
download status events include the on-disk path
- Added `_model_dir()` helper to `DownloadCoordinator` to compute the
path from `EXO_MODELS_DIR`
- Dashboard uses this to show file location and enable "open in Finder"
for completed downloads
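
A rough Python sketch of the idea, with an assumed HF-style id-to-path mapping (the real `_model_dir()` lives on `DownloadCoordinator` and its exact signature may differ):

```python
import os
from dataclasses import dataclass

# Assumed default; the real coordinator reads EXO_MODELS_DIR from its config.
EXO_MODELS_DIR = os.environ.get("EXO_MODELS_DIR", os.path.expanduser("~/.exo/models"))

def model_dir(model_id: str) -> str:
    # HF-style ids like "mlx-community/Llama-3-8B" map to a nested path.
    return os.path.join(EXO_MODELS_DIR, *model_id.split("/"))

@dataclass
class BaseDownloadProgress:
    model_id: str
    model_directory: str = ""  # new field: on-disk path included in every event

    def __post_init__(self) -> None:
        if not self.model_directory:
            self.model_directory = model_dir(self.model_id)
```

With the path carried on every event, the dashboard can render it directly instead of recomputing it client-side.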

### Info modal

- Clicking a model name opens an info modal showing card details
(family, quantization, capabilities, storage size, layer count, tensor
parallelism support)

### Other fixes

- Fixed model name truncation in the table
- Excluded `tests/start_distributed_test.py` from pytest collection (CLI
script that calls `sys.exit()` at import time)
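
One common way to do this kind of exclusion (a sketch; the repo may instead pass `--ignore` via `pyproject.toml`) is a `collect_ignore` list in a root `conftest.py`:

```python
# conftest.py (repo root)
# pytest skips collecting any path listed here, so importing the CLI
# script (which calls sys.exit() at import time) never happens.
collect_ignore = ["tests/start_distributed_test.py"]
```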

## Test Plan

- [x] `uv run basedpyright` — 0 errors
- [x] `uv run ruff check` — all passed
- [x] `nix fmt` — clean
- [x] `uv run pytest` — 188 passed, 1 skipped

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 14:31:47 +00:00
Evan Quiney
6d1ca6689b don't time out node identities (#1493)
Currently, nodes leaving and rejoining the cluster can lose their identity. We have no need to delete this data when a node times out, so let's just persist it.
2026-02-17 11:48:28 +00:00
Evan
c01b6fff21 eprint banner
our banner was being printed to stdout but should be printed to stderr,
as it's essentially a log message
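
The fix amounts to routing log-like output to stderr so stdout stays clean for program output; a minimal sketch (helper name hypothetical, not exo's actual code):

```python
import sys

def eprint(*args: object, **kwargs: object) -> None:
    """Print to stderr so log-like output never pollutes stdout."""
    print(*args, file=sys.stderr, **kwargs)

eprint("exo: starting up")
```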
2026-02-17 11:43:06 +00:00
Jake Hillion
8392e78afe bench: add spec for automatic canary benchmarks (#1483)
Adds all the models that can fit onto a single M3 Ultra for single
machine benchmarks. Fixes the macOS version, GPU spec, and chip type for
maximum reproducibility. Specifies the minimum memory accordingly for
each type of model, using the smallest machine available (the smallest
M3 Ultra is 96GiB).

Test plan:
- Running this with some code that makes machines of this spec available
and stores the results. It works.

This will become part of a larger testing/stability strategy once we've
collected more of the data.
2026-02-17 10:52:05 +00:00
Evan
86735ece78 begins
begins
2026-02-16 19:26:19 +00:00
69 changed files with 3848 additions and 2137 deletions

View File

@@ -1,15 +0,0 @@
.venv/
.direnv/
target/
.git/
.idea/
.pytest_cache/
.ruff_cache/
dashboard/node_modules/
dashboard/.svelte-kit/
dashboard/build/
dist/
*.pdb
**/__pycache__
**/.DS_Store
.mlx_typings/

View File

@@ -1,44 +0,0 @@
name: e2e-tests
on:
  push:
    branches:
      - e2e-tests
  pull_request:
    branches:
      - staging
      - main
jobs:
  e2e:
    runs-on: ubuntu-latest
    timeout-minutes: 45
    steps:
      - name: Free up disk space
        run: |
          sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc \
            /opt/hostedtoolcache /usr/local/share/boost /usr/share/swift \
            /opt/microsoft /opt/az
          docker system prune -af
          df -h /
      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          lfs: false
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3
      - name: Build E2E image with cache
        uses: docker/build-push-action@v6
        with:
          context: .
          file: e2e/Dockerfile
          tags: exo-e2e:latest
          load: true
          cache-from: type=gha
          cache-to: type=gha,mode=max
      - name: Run E2E tests
        run: python3 e2e/run_all.py

View File

@@ -0,0 +1,628 @@
// !$*UTF8*$!
{
archiveVersion = 1;
classes = {
};
objectVersion = 77;
objects = {
/* Begin PBXBuildFile section */
E09D17522F44F359009C51A3 /* MLXLLM in Frameworks */ = {isa = PBXBuildFile; productRef = E09D17512F44F359009C51A3 /* MLXLLM */; };
E09D17542F44F359009C51A3 /* MLXLMCommon in Frameworks */ = {isa = PBXBuildFile; productRef = E09D17532F44F359009C51A3 /* MLXLMCommon */; };
/* End PBXBuildFile section */
/* Begin PBXContainerItemProxy section */
E09D167D2F44CA20009C51A3 /* PBXContainerItemProxy */ = {
isa = PBXContainerItemProxy;
containerPortal = E09D16672F44CA1E009C51A3 /* Project object */;
proxyType = 1;
remoteGlobalIDString = E09D166E2F44CA1E009C51A3;
remoteInfo = "EXO-iOS";
};
E09D16872F44CA20009C51A3 /* PBXContainerItemProxy */ = {
isa = PBXContainerItemProxy;
containerPortal = E09D16672F44CA1E009C51A3 /* Project object */;
proxyType = 1;
remoteGlobalIDString = E09D166E2F44CA1E009C51A3;
remoteInfo = "EXO-iOS";
};
/* End PBXContainerItemProxy section */
/* Begin PBXFileReference section */
E09D166F2F44CA1E009C51A3 /* EXO-iOS.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = "EXO-iOS.app"; sourceTree = BUILT_PRODUCTS_DIR; };
E09D167C2F44CA20009C51A3 /* EXO-iOSTests.xctest */ = {isa = PBXFileReference; explicitFileType = wrapper.cfbundle; includeInIndex = 0; path = "EXO-iOSTests.xctest"; sourceTree = BUILT_PRODUCTS_DIR; };
E09D16862F44CA20009C51A3 /* EXO-iOSUITests.xctest */ = {isa = PBXFileReference; explicitFileType = wrapper.cfbundle; includeInIndex = 0; path = "EXO-iOSUITests.xctest"; sourceTree = BUILT_PRODUCTS_DIR; };
/* End PBXFileReference section */
/* Begin PBXFileSystemSynchronizedBuildFileExceptionSet section */
E09D169A2F44CA20009C51A3 /* Exceptions for "EXO-iOS" folder in "EXO-iOS" target */ = {
isa = PBXFileSystemSynchronizedBuildFileExceptionSet;
membershipExceptions = (
Info.plist,
);
target = E09D166E2F44CA1E009C51A3 /* EXO-iOS */;
};
/* End PBXFileSystemSynchronizedBuildFileExceptionSet section */
/* Begin PBXFileSystemSynchronizedRootGroup section */
E09D16712F44CA1E009C51A3 /* EXO-iOS */ = {
isa = PBXFileSystemSynchronizedRootGroup;
exceptions = (
E09D169A2F44CA20009C51A3 /* Exceptions for "EXO-iOS" folder in "EXO-iOS" target */,
);
path = "EXO-iOS";
sourceTree = "<group>";
};
E09D167F2F44CA20009C51A3 /* EXO-iOSTests */ = {
isa = PBXFileSystemSynchronizedRootGroup;
path = "EXO-iOSTests";
sourceTree = "<group>";
};
E09D16892F44CA20009C51A3 /* EXO-iOSUITests */ = {
isa = PBXFileSystemSynchronizedRootGroup;
path = "EXO-iOSUITests";
sourceTree = "<group>";
};
/* End PBXFileSystemSynchronizedRootGroup section */
/* Begin PBXFrameworksBuildPhase section */
E09D166C2F44CA1E009C51A3 /* Frameworks */ = {
isa = PBXFrameworksBuildPhase;
buildActionMask = 2147483647;
files = (
E09D17542F44F359009C51A3 /* MLXLMCommon in Frameworks */,
E09D17522F44F359009C51A3 /* MLXLLM in Frameworks */,
);
runOnlyForDeploymentPostprocessing = 0;
};
E09D16792F44CA20009C51A3 /* Frameworks */ = {
isa = PBXFrameworksBuildPhase;
buildActionMask = 2147483647;
files = (
);
runOnlyForDeploymentPostprocessing = 0;
};
E09D16832F44CA20009C51A3 /* Frameworks */ = {
isa = PBXFrameworksBuildPhase;
buildActionMask = 2147483647;
files = (
);
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXFrameworksBuildPhase section */
/* Begin PBXGroup section */
E09D16662F44CA1E009C51A3 = {
isa = PBXGroup;
children = (
E09D16712F44CA1E009C51A3 /* EXO-iOS */,
E09D167F2F44CA20009C51A3 /* EXO-iOSTests */,
E09D16892F44CA20009C51A3 /* EXO-iOSUITests */,
E09D16702F44CA1E009C51A3 /* Products */,
);
sourceTree = "<group>";
};
E09D16702F44CA1E009C51A3 /* Products */ = {
isa = PBXGroup;
children = (
E09D166F2F44CA1E009C51A3 /* EXO-iOS.app */,
E09D167C2F44CA20009C51A3 /* EXO-iOSTests.xctest */,
E09D16862F44CA20009C51A3 /* EXO-iOSUITests.xctest */,
);
name = Products;
sourceTree = "<group>";
};
/* End PBXGroup section */
/* Begin PBXNativeTarget section */
E09D166E2F44CA1E009C51A3 /* EXO-iOS */ = {
isa = PBXNativeTarget;
buildConfigurationList = E09D16902F44CA20009C51A3 /* Build configuration list for PBXNativeTarget "EXO-iOS" */;
buildPhases = (
E09D166B2F44CA1E009C51A3 /* Sources */,
E09D166C2F44CA1E009C51A3 /* Frameworks */,
E09D166D2F44CA1E009C51A3 /* Resources */,
);
buildRules = (
);
dependencies = (
);
fileSystemSynchronizedGroups = (
E09D16712F44CA1E009C51A3 /* EXO-iOS */,
);
name = "EXO-iOS";
packageProductDependencies = (
E09D17512F44F359009C51A3 /* MLXLLM */,
E09D17532F44F359009C51A3 /* MLXLMCommon */,
);
productName = "EXO-iOS";
productReference = E09D166F2F44CA1E009C51A3 /* EXO-iOS.app */;
productType = "com.apple.product-type.application";
};
E09D167B2F44CA20009C51A3 /* EXO-iOSTests */ = {
isa = PBXNativeTarget;
buildConfigurationList = E09D16932F44CA20009C51A3 /* Build configuration list for PBXNativeTarget "EXO-iOSTests" */;
buildPhases = (
E09D16782F44CA20009C51A3 /* Sources */,
E09D16792F44CA20009C51A3 /* Frameworks */,
E09D167A2F44CA20009C51A3 /* Resources */,
);
buildRules = (
);
dependencies = (
E09D167E2F44CA20009C51A3 /* PBXTargetDependency */,
);
fileSystemSynchronizedGroups = (
E09D167F2F44CA20009C51A3 /* EXO-iOSTests */,
);
name = "EXO-iOSTests";
packageProductDependencies = (
);
productName = "EXO-iOSTests";
productReference = E09D167C2F44CA20009C51A3 /* EXO-iOSTests.xctest */;
productType = "com.apple.product-type.bundle.unit-test";
};
E09D16852F44CA20009C51A3 /* EXO-iOSUITests */ = {
isa = PBXNativeTarget;
buildConfigurationList = E09D16962F44CA20009C51A3 /* Build configuration list for PBXNativeTarget "EXO-iOSUITests" */;
buildPhases = (
E09D16822F44CA20009C51A3 /* Sources */,
E09D16832F44CA20009C51A3 /* Frameworks */,
E09D16842F44CA20009C51A3 /* Resources */,
);
buildRules = (
);
dependencies = (
E09D16882F44CA20009C51A3 /* PBXTargetDependency */,
);
fileSystemSynchronizedGroups = (
E09D16892F44CA20009C51A3 /* EXO-iOSUITests */,
);
name = "EXO-iOSUITests";
packageProductDependencies = (
);
productName = "EXO-iOSUITests";
productReference = E09D16862F44CA20009C51A3 /* EXO-iOSUITests.xctest */;
productType = "com.apple.product-type.bundle.ui-testing";
};
/* End PBXNativeTarget section */
/* Begin PBXProject section */
E09D16672F44CA1E009C51A3 /* Project object */ = {
isa = PBXProject;
attributes = {
BuildIndependentTargetsInParallel = 1;
LastSwiftUpdateCheck = 2620;
LastUpgradeCheck = 2620;
TargetAttributes = {
E09D166E2F44CA1E009C51A3 = {
CreatedOnToolsVersion = 26.2;
};
E09D167B2F44CA20009C51A3 = {
CreatedOnToolsVersion = 26.2;
TestTargetID = E09D166E2F44CA1E009C51A3;
};
E09D16852F44CA20009C51A3 = {
CreatedOnToolsVersion = 26.2;
TestTargetID = E09D166E2F44CA1E009C51A3;
};
};
};
buildConfigurationList = E09D166A2F44CA1E009C51A3 /* Build configuration list for PBXProject "EXO-iOS" */;
developmentRegion = en;
hasScannedForEncodings = 0;
knownRegions = (
en,
Base,
);
mainGroup = E09D16662F44CA1E009C51A3;
minimizedProjectReferenceProxies = 1;
packageReferences = (
E09D17502F44F359009C51A3 /* XCRemoteSwiftPackageReference "mlx-swift-lm" */,
);
preferredProjectObjectVersion = 77;
productRefGroup = E09D16702F44CA1E009C51A3 /* Products */;
projectDirPath = "";
projectRoot = "";
targets = (
E09D166E2F44CA1E009C51A3 /* EXO-iOS */,
E09D167B2F44CA20009C51A3 /* EXO-iOSTests */,
E09D16852F44CA20009C51A3 /* EXO-iOSUITests */,
);
};
/* End PBXProject section */
/* Begin PBXResourcesBuildPhase section */
E09D166D2F44CA1E009C51A3 /* Resources */ = {
isa = PBXResourcesBuildPhase;
buildActionMask = 2147483647;
files = (
);
runOnlyForDeploymentPostprocessing = 0;
};
E09D167A2F44CA20009C51A3 /* Resources */ = {
isa = PBXResourcesBuildPhase;
buildActionMask = 2147483647;
files = (
);
runOnlyForDeploymentPostprocessing = 0;
};
E09D16842F44CA20009C51A3 /* Resources */ = {
isa = PBXResourcesBuildPhase;
buildActionMask = 2147483647;
files = (
);
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXResourcesBuildPhase section */
/* Begin PBXSourcesBuildPhase section */
E09D166B2F44CA1E009C51A3 /* Sources */ = {
isa = PBXSourcesBuildPhase;
buildActionMask = 2147483647;
files = (
);
runOnlyForDeploymentPostprocessing = 0;
};
E09D16782F44CA20009C51A3 /* Sources */ = {
isa = PBXSourcesBuildPhase;
buildActionMask = 2147483647;
files = (
);
runOnlyForDeploymentPostprocessing = 0;
};
E09D16822F44CA20009C51A3 /* Sources */ = {
isa = PBXSourcesBuildPhase;
buildActionMask = 2147483647;
files = (
);
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXSourcesBuildPhase section */
/* Begin PBXTargetDependency section */
E09D167E2F44CA20009C51A3 /* PBXTargetDependency */ = {
isa = PBXTargetDependency;
target = E09D166E2F44CA1E009C51A3 /* EXO-iOS */;
targetProxy = E09D167D2F44CA20009C51A3 /* PBXContainerItemProxy */;
};
E09D16882F44CA20009C51A3 /* PBXTargetDependency */ = {
isa = PBXTargetDependency;
target = E09D166E2F44CA1E009C51A3 /* EXO-iOS */;
targetProxy = E09D16872F44CA20009C51A3 /* PBXContainerItemProxy */;
};
/* End PBXTargetDependency section */
/* Begin XCBuildConfiguration section */
E09D168E2F44CA20009C51A3 /* Debug */ = {
isa = XCBuildConfiguration;
buildSettings = {
ALWAYS_SEARCH_USER_PATHS = NO;
ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES;
CLANG_ANALYZER_NONNULL = YES;
CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
CLANG_CXX_LANGUAGE_STANDARD = "gnu++20";
CLANG_ENABLE_MODULES = YES;
CLANG_ENABLE_OBJC_ARC = YES;
CLANG_ENABLE_OBJC_WEAK = YES;
CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
CLANG_WARN_BOOL_CONVERSION = YES;
CLANG_WARN_COMMA = YES;
CLANG_WARN_CONSTANT_CONVERSION = YES;
CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
CLANG_WARN_EMPTY_BODY = YES;
CLANG_WARN_ENUM_CONVERSION = YES;
CLANG_WARN_INFINITE_RECURSION = YES;
CLANG_WARN_INT_CONVERSION = YES;
CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;
CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
CLANG_WARN_STRICT_PROTOTYPES = YES;
CLANG_WARN_SUSPICIOUS_MOVE = YES;
CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
CLANG_WARN_UNREACHABLE_CODE = YES;
CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
COPY_PHASE_STRIP = NO;
DEBUG_INFORMATION_FORMAT = dwarf;
ENABLE_STRICT_OBJC_MSGSEND = YES;
ENABLE_TESTABILITY = YES;
ENABLE_USER_SCRIPT_SANDBOXING = YES;
GCC_C_LANGUAGE_STANDARD = gnu17;
GCC_DYNAMIC_NO_PIC = NO;
GCC_NO_COMMON_BLOCKS = YES;
GCC_OPTIMIZATION_LEVEL = 0;
GCC_PREPROCESSOR_DEFINITIONS = (
"DEBUG=1",
"$(inherited)",
);
GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
GCC_WARN_UNDECLARED_SELECTOR = YES;
GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
GCC_WARN_UNUSED_FUNCTION = YES;
GCC_WARN_UNUSED_VARIABLE = YES;
IPHONEOS_DEPLOYMENT_TARGET = 26.2;
LOCALIZATION_PREFERS_STRING_CATALOGS = YES;
MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE;
MTL_FAST_MATH = YES;
ONLY_ACTIVE_ARCH = YES;
SDKROOT = iphoneos;
SWIFT_ACTIVE_COMPILATION_CONDITIONS = "DEBUG $(inherited)";
SWIFT_OPTIMIZATION_LEVEL = "-Onone";
};
name = Debug;
};
E09D168F2F44CA20009C51A3 /* Release */ = {
isa = XCBuildConfiguration;
buildSettings = {
ALWAYS_SEARCH_USER_PATHS = NO;
ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES;
CLANG_ANALYZER_NONNULL = YES;
CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
CLANG_CXX_LANGUAGE_STANDARD = "gnu++20";
CLANG_ENABLE_MODULES = YES;
CLANG_ENABLE_OBJC_ARC = YES;
CLANG_ENABLE_OBJC_WEAK = YES;
CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
CLANG_WARN_BOOL_CONVERSION = YES;
CLANG_WARN_COMMA = YES;
CLANG_WARN_CONSTANT_CONVERSION = YES;
CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
CLANG_WARN_EMPTY_BODY = YES;
CLANG_WARN_ENUM_CONVERSION = YES;
CLANG_WARN_INFINITE_RECURSION = YES;
CLANG_WARN_INT_CONVERSION = YES;
CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;
CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
CLANG_WARN_STRICT_PROTOTYPES = YES;
CLANG_WARN_SUSPICIOUS_MOVE = YES;
CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
CLANG_WARN_UNREACHABLE_CODE = YES;
CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
COPY_PHASE_STRIP = NO;
DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym";
ENABLE_NS_ASSERTIONS = NO;
ENABLE_STRICT_OBJC_MSGSEND = YES;
ENABLE_USER_SCRIPT_SANDBOXING = YES;
GCC_C_LANGUAGE_STANDARD = gnu17;
GCC_NO_COMMON_BLOCKS = YES;
GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
GCC_WARN_UNDECLARED_SELECTOR = YES;
GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
GCC_WARN_UNUSED_FUNCTION = YES;
GCC_WARN_UNUSED_VARIABLE = YES;
IPHONEOS_DEPLOYMENT_TARGET = 26.2;
LOCALIZATION_PREFERS_STRING_CATALOGS = YES;
MTL_ENABLE_DEBUG_INFO = NO;
MTL_FAST_MATH = YES;
SDKROOT = iphoneos;
SWIFT_COMPILATION_MODE = wholemodule;
VALIDATE_PRODUCT = YES;
};
name = Release;
};
E09D16912F44CA20009C51A3 /* Debug */ = {
isa = XCBuildConfiguration;
buildSettings = {
ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor;
CODE_SIGN_STYLE = Automatic;
CURRENT_PROJECT_VERSION = 1;
DEVELOPMENT_TEAM = 3M3M67U93M;
ENABLE_PREVIEWS = YES;
GENERATE_INFOPLIST_FILE = YES;
INFOPLIST_FILE = "EXO-iOS/Info.plist";
INFOPLIST_KEY_UIApplicationSceneManifest_Generation = YES;
INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents = YES;
INFOPLIST_KEY_UILaunchScreen_Generation = YES;
INFOPLIST_KEY_UISupportedInterfaceOrientations_iPad = "UIInterfaceOrientationPortrait UIInterfaceOrientationPortraitUpsideDown UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight";
INFOPLIST_KEY_UISupportedInterfaceOrientations_iPhone = "UIInterfaceOrientationPortrait UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight";
LD_RUNPATH_SEARCH_PATHS = (
"$(inherited)",
"@executable_path/Frameworks",
);
MARKETING_VERSION = 1.0;
PRODUCT_BUNDLE_IDENTIFIER = "com.exo.EXO-iOS";
PRODUCT_NAME = "$(TARGET_NAME)";
STRING_CATALOG_GENERATE_SYMBOLS = YES;
SWIFT_APPROACHABLE_CONCURRENCY = YES;
SWIFT_DEFAULT_ACTOR_ISOLATION = MainActor;
SWIFT_EMIT_LOC_STRINGS = YES;
SWIFT_UPCOMING_FEATURE_MEMBER_IMPORT_VISIBILITY = YES;
SWIFT_VERSION = 5.0;
TARGETED_DEVICE_FAMILY = "1,2";
};
name = Debug;
};
E09D16922F44CA20009C51A3 /* Release */ = {
isa = XCBuildConfiguration;
buildSettings = {
ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor;
CODE_SIGN_STYLE = Automatic;
CURRENT_PROJECT_VERSION = 1;
DEVELOPMENT_TEAM = 3M3M67U93M;
ENABLE_PREVIEWS = YES;
GENERATE_INFOPLIST_FILE = YES;
INFOPLIST_FILE = "EXO-iOS/Info.plist";
INFOPLIST_KEY_UIApplicationSceneManifest_Generation = YES;
INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents = YES;
INFOPLIST_KEY_UILaunchScreen_Generation = YES;
INFOPLIST_KEY_UISupportedInterfaceOrientations_iPad = "UIInterfaceOrientationPortrait UIInterfaceOrientationPortraitUpsideDown UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight";
INFOPLIST_KEY_UISupportedInterfaceOrientations_iPhone = "UIInterfaceOrientationPortrait UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight";
LD_RUNPATH_SEARCH_PATHS = (
"$(inherited)",
"@executable_path/Frameworks",
);
MARKETING_VERSION = 1.0;
PRODUCT_BUNDLE_IDENTIFIER = "com.exo.EXO-iOS";
PRODUCT_NAME = "$(TARGET_NAME)";
STRING_CATALOG_GENERATE_SYMBOLS = YES;
SWIFT_APPROACHABLE_CONCURRENCY = YES;
SWIFT_DEFAULT_ACTOR_ISOLATION = MainActor;
SWIFT_EMIT_LOC_STRINGS = YES;
SWIFT_UPCOMING_FEATURE_MEMBER_IMPORT_VISIBILITY = YES;
SWIFT_VERSION = 5.0;
TARGETED_DEVICE_FAMILY = "1,2";
};
name = Release;
};
E09D16942F44CA20009C51A3 /* Debug */ = {
isa = XCBuildConfiguration;
buildSettings = {
BUNDLE_LOADER = "$(TEST_HOST)";
CODE_SIGN_STYLE = Automatic;
CURRENT_PROJECT_VERSION = 1;
GENERATE_INFOPLIST_FILE = YES;
IPHONEOS_DEPLOYMENT_TARGET = 26.2;
MARKETING_VERSION = 1.0;
PRODUCT_BUNDLE_IDENTIFIER = "com.exo.EXO-iOSTests";
PRODUCT_NAME = "$(TARGET_NAME)";
STRING_CATALOG_GENERATE_SYMBOLS = NO;
SWIFT_APPROACHABLE_CONCURRENCY = YES;
SWIFT_EMIT_LOC_STRINGS = NO;
SWIFT_UPCOMING_FEATURE_MEMBER_IMPORT_VISIBILITY = YES;
SWIFT_VERSION = 5.0;
TARGETED_DEVICE_FAMILY = "1,2";
TEST_HOST = "$(BUILT_PRODUCTS_DIR)/EXO-iOS.app/$(BUNDLE_EXECUTABLE_FOLDER_PATH)/EXO-iOS";
};
name = Debug;
};
E09D16952F44CA20009C51A3 /* Release */ = {
isa = XCBuildConfiguration;
buildSettings = {
BUNDLE_LOADER = "$(TEST_HOST)";
CODE_SIGN_STYLE = Automatic;
CURRENT_PROJECT_VERSION = 1;
GENERATE_INFOPLIST_FILE = YES;
IPHONEOS_DEPLOYMENT_TARGET = 26.2;
MARKETING_VERSION = 1.0;
PRODUCT_BUNDLE_IDENTIFIER = "com.exo.EXO-iOSTests";
PRODUCT_NAME = "$(TARGET_NAME)";
STRING_CATALOG_GENERATE_SYMBOLS = NO;
SWIFT_APPROACHABLE_CONCURRENCY = YES;
SWIFT_EMIT_LOC_STRINGS = NO;
SWIFT_UPCOMING_FEATURE_MEMBER_IMPORT_VISIBILITY = YES;
SWIFT_VERSION = 5.0;
TARGETED_DEVICE_FAMILY = "1,2";
TEST_HOST = "$(BUILT_PRODUCTS_DIR)/EXO-iOS.app/$(BUNDLE_EXECUTABLE_FOLDER_PATH)/EXO-iOS";
};
name = Release;
};
E09D16972F44CA20009C51A3 /* Debug */ = {
isa = XCBuildConfiguration;
buildSettings = {
CODE_SIGN_STYLE = Automatic;
CURRENT_PROJECT_VERSION = 1;
GENERATE_INFOPLIST_FILE = YES;
MARKETING_VERSION = 1.0;
PRODUCT_BUNDLE_IDENTIFIER = "com.exo.EXO-iOSUITests";
PRODUCT_NAME = "$(TARGET_NAME)";
STRING_CATALOG_GENERATE_SYMBOLS = NO;
SWIFT_APPROACHABLE_CONCURRENCY = YES;
SWIFT_EMIT_LOC_STRINGS = NO;
SWIFT_UPCOMING_FEATURE_MEMBER_IMPORT_VISIBILITY = YES;
SWIFT_VERSION = 5.0;
TARGETED_DEVICE_FAMILY = "1,2";
TEST_TARGET_NAME = "EXO-iOS";
};
name = Debug;
};
E09D16982F44CA20009C51A3 /* Release */ = {
isa = XCBuildConfiguration;
buildSettings = {
CODE_SIGN_STYLE = Automatic;
CURRENT_PROJECT_VERSION = 1;
GENERATE_INFOPLIST_FILE = YES;
MARKETING_VERSION = 1.0;
PRODUCT_BUNDLE_IDENTIFIER = "com.exo.EXO-iOSUITests";
PRODUCT_NAME = "$(TARGET_NAME)";
STRING_CATALOG_GENERATE_SYMBOLS = NO;
SWIFT_APPROACHABLE_CONCURRENCY = YES;
SWIFT_EMIT_LOC_STRINGS = NO;
SWIFT_UPCOMING_FEATURE_MEMBER_IMPORT_VISIBILITY = YES;
SWIFT_VERSION = 5.0;
TARGETED_DEVICE_FAMILY = "1,2";
TEST_TARGET_NAME = "EXO-iOS";
};
name = Release;
};
/* End XCBuildConfiguration section */
/* Begin XCConfigurationList section */
E09D166A2F44CA1E009C51A3 /* Build configuration list for PBXProject "EXO-iOS" */ = {
isa = XCConfigurationList;
buildConfigurations = (
E09D168E2F44CA20009C51A3 /* Debug */,
E09D168F2F44CA20009C51A3 /* Release */,
);
defaultConfigurationIsVisible = 0;
defaultConfigurationName = Release;
};
E09D16902F44CA20009C51A3 /* Build configuration list for PBXNativeTarget "EXO-iOS" */ = {
isa = XCConfigurationList;
buildConfigurations = (
E09D16912F44CA20009C51A3 /* Debug */,
E09D16922F44CA20009C51A3 /* Release */,
);
defaultConfigurationIsVisible = 0;
defaultConfigurationName = Release;
};
E09D16932F44CA20009C51A3 /* Build configuration list for PBXNativeTarget "EXO-iOSTests" */ = {
isa = XCConfigurationList;
buildConfigurations = (
E09D16942F44CA20009C51A3 /* Debug */,
E09D16952F44CA20009C51A3 /* Release */,
);
defaultConfigurationIsVisible = 0;
defaultConfigurationName = Release;
};
E09D16962F44CA20009C51A3 /* Build configuration list for PBXNativeTarget "EXO-iOSUITests" */ = {
isa = XCConfigurationList;
buildConfigurations = (
E09D16972F44CA20009C51A3 /* Debug */,
E09D16982F44CA20009C51A3 /* Release */,
);
defaultConfigurationIsVisible = 0;
defaultConfigurationName = Release;
};
/* End XCConfigurationList section */
/* Begin XCRemoteSwiftPackageReference section */
E09D17502F44F359009C51A3 /* XCRemoteSwiftPackageReference "mlx-swift-lm" */ = {
isa = XCRemoteSwiftPackageReference;
repositoryURL = "https://github.com/ml-explore/mlx-swift-lm";
requirement = {
kind = upToNextMajorVersion;
minimumVersion = 2.30.3;
};
};
/* End XCRemoteSwiftPackageReference section */
/* Begin XCSwiftPackageProductDependency section */
E09D17512F44F359009C51A3 /* MLXLLM */ = {
isa = XCSwiftPackageProductDependency;
package = E09D17502F44F359009C51A3 /* XCRemoteSwiftPackageReference "mlx-swift-lm" */;
productName = MLXLLM;
};
E09D17532F44F359009C51A3 /* MLXLMCommon */ = {
isa = XCSwiftPackageProductDependency;
package = E09D17502F44F359009C51A3 /* XCRemoteSwiftPackageReference "mlx-swift-lm" */;
productName = MLXLMCommon;
};
/* End XCSwiftPackageProductDependency section */
};
rootObject = E09D16672F44CA1E009C51A3 /* Project object */;
}

View File

@@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<Workspace
version = "1.0">
<FileRef
location = "self:">
</FileRef>
</Workspace>

View File

@@ -0,0 +1,60 @@
{
"originHash" : "facc0ac7c70363ea20f6cd1235de91dea6b06f0d00190946045a6c8ae753abc2",
"pins" : [
{
"identity" : "mlx-swift",
"kind" : "remoteSourceControl",
"location" : "https://github.com/ml-explore/mlx-swift",
"state" : {
"revision" : "6ba4827fb82c97d012eec9ab4b2de21f85c3b33d",
"version" : "0.30.6"
}
},
{
"identity" : "mlx-swift-lm",
"kind" : "remoteSourceControl",
"location" : "https://github.com/ml-explore/mlx-swift-lm",
"state" : {
"revision" : "360c5052b81cc154b04ee0933597a4ad6db4b8ae",
"version" : "2.30.3"
}
},
{
"identity" : "swift-collections",
"kind" : "remoteSourceControl",
"location" : "https://github.com/apple/swift-collections.git",
"state" : {
"revision" : "7b847a3b7008b2dc2f47ca3110d8c782fb2e5c7e",
"version" : "1.3.0"
}
},
{
"identity" : "swift-jinja",
"kind" : "remoteSourceControl",
"location" : "https://github.com/huggingface/swift-jinja.git",
"state" : {
"revision" : "d81197f35f41445bc10e94600795e68c6f5e94b0",
"version" : "2.3.1"
}
},
{
"identity" : "swift-numerics",
"kind" : "remoteSourceControl",
"location" : "https://github.com/apple/swift-numerics",
"state" : {
"revision" : "0c0290ff6b24942dadb83a929ffaaa1481df04a2",
"version" : "1.1.1"
}
},
{
"identity" : "swift-transformers",
"kind" : "remoteSourceControl",
"location" : "https://github.com/huggingface/swift-transformers",
"state" : {
"revision" : "573e5c9036c2f136b3a8a071da8e8907322403d0",
"version" : "1.1.6"
}
}
],
"version" : 3
}

View File

@@ -0,0 +1,20 @@
{
"colors" : [
{
"color" : {
"color-space" : "srgb",
"components" : {
"alpha" : "1.000",
"blue" : "0x00",
"green" : "0xD7",
"red" : "0xFF"
}
},
"idiom" : "universal"
}
],
"info" : {
"author" : "xcode",
"version" : 1
}
}

View File

Binary file not shown (new image, 10 KiB)

View File

@@ -0,0 +1,38 @@
{
"images" : [
{
"filename" : "AppIcon.png",
"idiom" : "universal",
"platform" : "ios",
"size" : "1024x1024"
},
{
"appearances" : [
{
"appearance" : "luminosity",
"value" : "dark"
}
],
"filename" : "AppIcon.png",
"idiom" : "universal",
"platform" : "ios",
"size" : "1024x1024"
},
{
"appearances" : [
{
"appearance" : "luminosity",
"value" : "tinted"
}
],
"filename" : "AppIcon.png",
"idiom" : "universal",
"platform" : "ios",
"size" : "1024x1024"
}
],
"info" : {
"author" : "xcode",
"version" : 1
}
}

View File

@@ -0,0 +1,6 @@
{
"info" : {
"author" : "xcode",
"version" : 1
}
}

View File

@@ -0,0 +1,21 @@
{
"images" : [
{
"filename" : "exo-logo.png",
"idiom" : "universal",
"scale" : "1x"
},
{
"idiom" : "universal",
"scale" : "2x"
},
{
"idiom" : "universal",
"scale" : "3x"
}
],
"info" : {
"author" : "xcode",
"version" : 1
}
}

View File

Binary file not shown (new image, 1.6 KiB)

View File

@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>com.apple.developer.kernel.increased-memory-limit</key>
<true/>
</dict>
</plist>

View File

@@ -0,0 +1,67 @@
import SwiftUI
import UIKit

@main
struct EXO_iOSApp: App {
    @State private var clusterService = ClusterService()
    @State private var discoveryService = DiscoveryService()
    @State private var localInferenceService = LocalInferenceService()
    @State private var chatService: ChatService?

    init() {
        let darkGray = UIColor(red: 0x1F / 255.0, green: 0x1F / 255.0, blue: 0x1F / 255.0, alpha: 1)
        let yellow = UIColor(red: 0xFF / 255.0, green: 0xD7 / 255.0, blue: 0x00 / 255.0, alpha: 1)
        let navAppearance = UINavigationBarAppearance()
        navAppearance.configureWithOpaqueBackground()
        navAppearance.backgroundColor = darkGray
        navAppearance.titleTextAttributes = [
            .foregroundColor: yellow,
            .font: UIFont.monospacedSystemFont(ofSize: 17, weight: .semibold),
        ]
        navAppearance.largeTitleTextAttributes = [
            .foregroundColor: yellow,
            .font: UIFont.monospacedSystemFont(ofSize: 34, weight: .bold),
        ]
        UINavigationBar.appearance().standardAppearance = navAppearance
        UINavigationBar.appearance().compactAppearance = navAppearance
        UINavigationBar.appearance().scrollEdgeAppearance = navAppearance
        UINavigationBar.appearance().tintColor = yellow
    }

    var body: some Scene {
        WindowGroup {
            if let chatService {
                RootView()
                    .environment(clusterService)
                    .environment(discoveryService)
                    .environment(chatService)
                    .environment(localInferenceService)
                    .preferredColorScheme(.dark)
                    .task {
                        await clusterService.attemptAutoReconnect()
                        discoveryService.startBrowsing()
                        await localInferenceService.prepareModel()
                    }
                    .onChange(of: discoveryService.discoveredClusters) { _, clusters in
                        guard !clusterService.isConnected,
                            case .disconnected = clusterService.connectionState,
                            let first = clusters.first
                        else { return }
                        Task {
                            await clusterService.connectToDiscoveredCluster(
                                first, using: discoveryService)
                        }
                    }
            } else {
                Color.exoBlack.onAppear {
                    chatService = ChatService(
                        clusterService: clusterService,
                        localInferenceService: localInferenceService
                    )
                }
            }
        }
    }
}

View File

@@ -0,0 +1,19 @@
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>UIUserInterfaceStyle</key>
<string>Dark</string>
<key>CFBundleDisplayName</key>
<string>EXO</string>
<key>NSLocalNetworkUsageDescription</key>
<string>EXO needs local network access to connect to your EXO cluster.</string>
<key>NSBonjourServices</key>
<array>
<string>_exo._tcp</string>
<string>_p2p._tcp</string>
<string>_p2p._udp</string>
<string>_libp2p._udp</string>
</array>
</dict>
</plist>

View File

@@ -0,0 +1,129 @@
import Foundation
// MARK: - Request
struct ChatCompletionRequest: Encodable {
let model: String
let messages: [ChatCompletionMessageParam]
let stream: Bool
let maxTokens: Int?
let temperature: Double?
enum CodingKeys: String, CodingKey {
case model, messages, stream, temperature
case maxTokens = "max_tokens"
}
}
struct ChatCompletionMessageParam: Encodable {
let role: String
let content: String
}
// MARK: - Streaming Response
struct ChatCompletionChunk: Decodable {
let id: String
let model: String?
let choices: [StreamingChoice]
let usage: ChunkUsage?
init(id: String, model: String?, choices: [StreamingChoice], usage: ChunkUsage?) {
self.id = id
self.model = model
self.choices = choices
self.usage = usage
}
}
struct StreamingChoice: Decodable {
let index: Int
let delta: Delta
let finishReason: String?
enum CodingKeys: String, CodingKey {
case index, delta
case finishReason = "finish_reason"
}
init(index: Int, delta: Delta, finishReason: String?) {
self.index = index
self.delta = delta
self.finishReason = finishReason
}
}
struct Delta: Decodable {
let role: String?
let content: String?
init(role: String?, content: String?) {
self.role = role
self.content = content
}
}
struct ChunkUsage: Decodable {
let promptTokens: Int?
let completionTokens: Int?
let totalTokens: Int?
enum CodingKeys: String, CodingKey {
case promptTokens = "prompt_tokens"
case completionTokens = "completion_tokens"
case totalTokens = "total_tokens"
}
init(promptTokens: Int?, completionTokens: Int?, totalTokens: Int?) {
self.promptTokens = promptTokens
self.completionTokens = completionTokens
self.totalTokens = totalTokens
}
}
// MARK: - Non-Streaming Response
struct ChatCompletionResponse: Decodable {
let id: String
let model: String?
let choices: [ResponseChoice]
}
struct ResponseChoice: Decodable {
let index: Int
let message: ResponseMessage
let finishReason: String?
enum CodingKeys: String, CodingKey {
case index, message
case finishReason = "finish_reason"
}
}
struct ResponseMessage: Decodable {
let role: String?
let content: String?
}
// MARK: - Models List
struct ModelListResponse: Decodable {
let data: [ModelInfo]
}
struct ModelInfo: Decodable, Identifiable {
let id: String
let name: String?
}
// MARK: - Error
struct APIErrorResponse: Decodable {
let error: APIErrorInfo
}
struct APIErrorInfo: Decodable {
let message: String
let type: String?
let code: Int?
}
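A quick standalone sketch (not part of the diff) of the wire format these `CodingKeys` produce: a minimal mirror of `ChatCompletionRequest` encoded with `JSONEncoder`, showing the snake_case keys the OpenAI-style endpoint expects. The model name is a placeholder.

```swift
import Foundation

// Minimal mirror of ChatCompletionRequest from the file above.
struct Request: Encodable {
    let model: String
    let messages: [[String: String]]
    let stream: Bool
    let maxTokens: Int?
    enum CodingKeys: String, CodingKey {
        case model, messages, stream
        case maxTokens = "max_tokens"  // camelCase -> snake_case on the wire
    }
}

let req = Request(
    model: "example-model",  // placeholder model id
    messages: [["role": "user", "content": "hi"]],
    stream: true,
    maxTokens: 4096
)
let json = String(data: try! JSONEncoder().encode(req), encoding: .utf8)!
print(json.contains("\"max_tokens\":4096"))  // true
```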

View File

@@ -0,0 +1,26 @@
import Foundation
struct ChatMessage: Identifiable, Equatable {
let id: UUID
let role: Role
var content: String
let timestamp: Date
var isStreaming: Bool
enum Role: String, Codable {
case user
case assistant
case system
}
init(
id: UUID = UUID(), role: Role, content: String, timestamp: Date = Date(),
isStreaming: Bool = false
) {
self.id = id
self.role = role
self.content = content
self.timestamp = timestamp
self.isStreaming = isStreaming
}
}

View File

@@ -0,0 +1,11 @@
import Foundation
struct ConnectionInfo: Codable, Equatable {
let host: String
let port: Int
let nodeId: String?
var baseURL: URL {
// Bracket IPv6 literals so URL(string:) parses host and port correctly.
let hostPart = host.contains(":") ? "[\(host)]" : host
return URL(string: "http://\(hostPart):\(port)")!
}
static let defaultPort = 52415
}
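A standalone sketch of why `Codable` conformance matters here: `ConnectionInfo` round-trips losslessly through `JSONEncoder`/`JSONDecoder`, which is how `ClusterService` persists the last connection in `UserDefaults`. The struct is mirrored so the snippet is self-contained.

```swift
import Foundation

// Mirror of ConnectionInfo from the file above (minus the URL helper).
struct ConnectionInfo: Codable, Equatable {
    let host: String
    let port: Int
    let nodeId: String?
}

let info = ConnectionInfo(host: "192.168.1.10", port: 52415, nodeId: nil)
// Encode to JSON, then decode back; nil optionals are simply omitted.
let data = try! JSONEncoder().encode(info)
let restored = try! JSONDecoder().decode(ConnectionInfo.self, from: data)
print(restored == info)  // true
```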

View File

@@ -0,0 +1,34 @@
import Foundation
struct Conversation: Identifiable, Codable, Equatable {
let id: UUID
var title: String
var messages: [StoredMessage]
var modelId: String?
let createdAt: Date
init(
id: UUID = UUID(), title: String = "New Chat", messages: [StoredMessage] = [],
modelId: String? = nil, createdAt: Date = Date()
) {
self.id = id
self.title = title
self.messages = messages
self.modelId = modelId
self.createdAt = createdAt
}
}
struct StoredMessage: Identifiable, Codable, Equatable {
let id: UUID
let role: String
var content: String
let timestamp: Date
init(id: UUID = UUID(), role: String, content: String, timestamp: Date = Date()) {
self.id = id
self.role = role
self.content = content
self.timestamp = timestamp
}
}

View File

@@ -0,0 +1,227 @@
import Foundation
@Observable
@MainActor
final class ChatService {
var conversations: [Conversation] = []
var activeConversationId: UUID?
private(set) var isGenerating: Bool = false
private var currentGenerationTask: Task<Void, Never>?
private let clusterService: ClusterService
private let localInferenceService: LocalInferenceService
var canSendMessage: Bool {
clusterService.isConnected || localInferenceService.isAvailable
}
var activeConversation: Conversation? {
guard let id = activeConversationId else { return nil }
return conversations.first { $0.id == id }
}
var activeMessages: [ChatMessage] {
guard let conversation = activeConversation else { return [] }
let streamingId = isGenerating ? conversation.messages.last?.id : nil
return conversation.messages.map { stored in
ChatMessage(
id: stored.id,
role: ChatMessage.Role(rawValue: stored.role) ?? .user,
content: stored.content,
timestamp: stored.timestamp,
isStreaming: stored.id == streamingId && stored.role == "assistant"
)
}
}
init(clusterService: ClusterService, localInferenceService: LocalInferenceService) {
self.clusterService = clusterService
self.localInferenceService = localInferenceService
loadConversations()
}
// MARK: - Conversation Management
func createConversation(modelId: String? = nil) {
let conversation = Conversation(
modelId: modelId ?? clusterService.availableModels.first?.id)
conversations.insert(conversation, at: 0)
activeConversationId = conversation.id
saveConversations()
}
func deleteConversation(id: UUID) {
conversations.removeAll { $0.id == id }
if activeConversationId == id {
activeConversationId = conversations.first?.id
}
saveConversations()
}
func setActiveConversation(id: UUID) {
activeConversationId = id
}
func setModelForActiveConversation(_ modelId: String) {
guard let index = conversations.firstIndex(where: { $0.id == activeConversationId }) else {
return
}
conversations[index].modelId = modelId
saveConversations()
}
// MARK: - Messaging
func sendMessage(_ text: String) {
guard !text.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty else { return }
if activeConversation == nil {
createConversation()
}
guard let index = conversations.firstIndex(where: { $0.id == activeConversationId }) else {
return
}
let userMessage = StoredMessage(role: "user", content: text)
conversations[index].messages.append(userMessage)
if conversations[index].title == "New Chat" {
let preview = String(text.prefix(40))
conversations[index].title = preview + (text.count > 40 ? "..." : "")
}
let modelId: String
if clusterService.isConnected {
guard
let clusterId = conversations[index].modelId
?? clusterService.availableModels.first?.id
else {
let errorMessage = StoredMessage(
role: "assistant", content: "No model selected. Please select a model first.")
conversations[index].messages.append(errorMessage)
saveConversations()
return
}
modelId = clusterId
} else if localInferenceService.isAvailable {
modelId = localInferenceService.defaultModelId
} else {
let errorMessage = StoredMessage(
role: "assistant",
content: "Not connected to a cluster and local model is not available.")
conversations[index].messages.append(errorMessage)
saveConversations()
return
}
conversations[index].modelId = modelId
let assistantMessageId = UUID()
let assistantMessage = StoredMessage(
id: assistantMessageId, role: "assistant", content: "", timestamp: Date())
conversations[index].messages.append(assistantMessage)
let messagesForAPI = conversations[index].messages.dropLast().map { stored in
ChatCompletionMessageParam(role: stored.role, content: stored.content)
}
let request = ChatCompletionRequest(
model: modelId,
messages: Array(messagesForAPI),
stream: true,
maxTokens: 4096,
temperature: nil
)
let conversationId = conversations[index].id
isGenerating = true
currentGenerationTask = Task { [weak self] in
guard let self else { return }
await self.performStreaming(
request: request, conversationId: conversationId,
assistantMessageId: assistantMessageId)
}
saveConversations()
}
func cancelGeneration() {
currentGenerationTask?.cancel()
currentGenerationTask = nil
localInferenceService.cancelGeneration()
isGenerating = false
}
// MARK: - Streaming
private func performStreaming(
request: ChatCompletionRequest, conversationId: UUID, assistantMessageId: UUID
) async {
defer {
isGenerating = false
currentGenerationTask = nil
saveConversations()
}
do {
let stream =
clusterService.isConnected
? clusterService.streamChatCompletion(request: request)
: localInferenceService.streamChatCompletion(request: request)
for try await chunk in stream {
guard !Task.isCancelled else { return }
guard let content = chunk.choices.first?.delta.content, !content.isEmpty else {
continue
}
if let convIndex = conversations.firstIndex(where: { $0.id == conversationId }),
let msgIndex = conversations[convIndex].messages.firstIndex(where: {
$0.id == assistantMessageId
})
{
conversations[convIndex].messages[msgIndex].content += content
}
}
} catch {
if !Task.isCancelled {
if let convIndex = conversations.firstIndex(where: { $0.id == conversationId }),
let msgIndex = conversations[convIndex].messages.firstIndex(where: {
$0.id == assistantMessageId
})
{
if conversations[convIndex].messages[msgIndex].content.isEmpty {
conversations[convIndex].messages[msgIndex].content =
"Error: \(error.localizedDescription)"
}
}
}
}
}
// MARK: - Persistence
private static var storageURL: URL {
let documents = FileManager.default.urls(for: .documentDirectory, in: .userDomainMask)
.first!
return documents.appendingPathComponent("exo_conversations.json")
}
private func saveConversations() {
do {
let data = try JSONEncoder().encode(conversations)
try data.write(to: Self.storageURL, options: .atomic)
} catch {
// Save failed silently
}
}
private func loadConversations() {
do {
let data = try Data(contentsOf: Self.storageURL)
conversations = try JSONDecoder().decode([Conversation].self, from: data)
activeConversationId = conversations.first?.id
} catch {
conversations = []
}
}
}
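A standalone sketch of the titling rule in `sendMessage`: the first user message becomes the conversation title, truncated to 40 characters with a trailing ellipsis. Extracted here as a hypothetical free function for illustration.

```swift
// Same truncation logic as sendMessage, factored out for clarity.
func titlePreview(_ text: String, limit: Int = 40) -> String {
    let preview = String(text.prefix(limit))
    return preview + (text.count > limit ? "..." : "")
}

print(titlePreview("Hello"))                                 // "Hello"
print(titlePreview(String(repeating: "x", count: 45)).count) // 43
```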

View File

@@ -0,0 +1,200 @@
import Foundation
enum ConnectionState: Equatable {
case disconnected
case connecting
case connected(ConnectionInfo)
}
struct ModelOption: Identifiable, Equatable {
let id: String
let displayName: String
}
@Observable
@MainActor
final class ClusterService {
private(set) var connectionState: ConnectionState = .disconnected
private(set) var availableModels: [ModelOption] = []
private(set) var lastError: String?
private let session: URLSession
private let decoder: JSONDecoder
private var pollingTask: Task<Void, Never>?
private static let connectionInfoKey = "exo_last_connection_info"
var isConnected: Bool {
if case .connected = connectionState { return true }
return false
}
var currentConnection: ConnectionInfo? {
if case .connected(let info) = connectionState { return info }
return nil
}
init(session: URLSession = .shared) {
self.session = session
self.decoder = JSONDecoder()
}
// MARK: - Connection
func connect(to info: ConnectionInfo) async {
connectionState = .connecting
lastError = nil
do {
let url = info.baseURL.appendingPathComponent("node_id")
var request = URLRequest(url: url)
request.timeoutInterval = 5
request.cachePolicy = .reloadIgnoringLocalCacheData
let (_, response) = try await session.data(for: request)
guard let httpResponse = response as? HTTPURLResponse,
(200..<300).contains(httpResponse.statusCode)
else {
throw URLError(.badServerResponse)
}
connectionState = .connected(info)
persistConnection(info)
startPolling()
await fetchModels(baseURL: info.baseURL)
} catch {
connectionState = .disconnected
lastError = "Could not connect to \(info.host):\(info.port)"
}
}
func connectToDiscoveredCluster(
_ cluster: DiscoveredCluster, using discoveryService: DiscoveryService
) async {
guard case .disconnected = connectionState else { return }
connectionState = .connecting
lastError = nil
guard let info = await discoveryService.resolve(cluster) else {
connectionState = .disconnected
lastError = "Could not resolve \(cluster.name)"
return
}
connectionState = .disconnected // reset so connect() can proceed
await connect(to: info)
}
func disconnect() {
stopPolling()
connectionState = .disconnected
availableModels = []
lastError = nil
}
func attemptAutoReconnect() async {
guard case .disconnected = connectionState,
let info = loadPersistedConnection()
else { return }
await connect(to: info)
}
// MARK: - Polling
private func startPolling(interval: TimeInterval = 2.0) {
stopPolling()
pollingTask = Task { [weak self] in
while !Task.isCancelled {
try? await Task.sleep(for: .seconds(interval))
guard let self, !Task.isCancelled else { return }
guard let connection = self.currentConnection else { return }
await self.fetchModels(baseURL: connection.baseURL)
}
}
}
private func stopPolling() {
pollingTask?.cancel()
pollingTask = nil
}
// MARK: - API
private func fetchModels(baseURL: URL) async {
do {
let url = baseURL.appendingPathComponent("models")
var request = URLRequest(url: url)
request.cachePolicy = .reloadIgnoringLocalCacheData
let (data, response) = try await session.data(for: request)
guard let httpResponse = response as? HTTPURLResponse,
(200..<300).contains(httpResponse.statusCode)
else { return }
let list = try decoder.decode(ModelListResponse.self, from: data)
availableModels = list.data.map {
ModelOption(id: $0.id, displayName: $0.name ?? $0.id)
}
} catch {
// Model fetch failed silently; will retry on next poll
}
}
func streamChatCompletion(request body: ChatCompletionRequest) -> AsyncThrowingStream<
ChatCompletionChunk, Error
> {
AsyncThrowingStream { continuation in
let task = Task { [weak self] in
guard let self, let connection = self.currentConnection else {
continuation.finish(throwing: URLError(.notConnectedToInternet))
return
}
do {
let url = connection.baseURL.appendingPathComponent("v1/chat/completions")
var request = URLRequest(url: url)
request.httpMethod = "POST"
request.setValue("application/json", forHTTPHeaderField: "Content-Type")
request.httpBody = try JSONEncoder().encode(body)
let (bytes, response) = try await self.session.bytes(for: request)
guard let httpResponse = response as? HTTPURLResponse,
(200..<300).contains(httpResponse.statusCode)
else {
continuation.finish(throwing: URLError(.badServerResponse))
return
}
let parser = SSEStreamParser<ChatCompletionChunk>(
bytes: bytes, decoder: self.decoder)
for try await chunk in parser {
continuation.yield(chunk)
}
continuation.finish()
} catch {
continuation.finish(throwing: error)
}
}
continuation.onTermination = { _ in
task.cancel()
}
}
}
// MARK: - Persistence
private func persistConnection(_ info: ConnectionInfo) {
if let data = try? JSONEncoder().encode(info) {
UserDefaults.standard.set(data, forKey: Self.connectionInfoKey)
}
}
private func loadPersistedConnection() -> ConnectionInfo? {
guard let data = UserDefaults.standard.data(forKey: Self.connectionInfoKey) else {
return nil
}
return try? JSONDecoder().decode(ConnectionInfo.self, from: data)
}
}

View File

@@ -0,0 +1,123 @@
import Foundation
import Network
import os
struct DiscoveredCluster: Identifiable, Equatable {
let id: String
let name: String
let endpoint: NWEndpoint
static func == (lhs: DiscoveredCluster, rhs: DiscoveredCluster) -> Bool {
lhs.id == rhs.id && lhs.name == rhs.name
}
}
@Observable
@MainActor
final class DiscoveryService {
private(set) var discoveredClusters: [DiscoveredCluster] = []
private(set) var isSearching = false
private var browser: NWBrowser?
func startBrowsing() {
guard browser == nil else { return }
let browser = NWBrowser(for: .bonjour(type: "_exo._tcp", domain: nil), using: .tcp)
browser.stateUpdateHandler = { [weak self] state in
guard let service = self else { return }
Task { @MainActor in
switch state {
case .ready:
service.isSearching = true
case .failed, .cancelled:
service.isSearching = false
default:
break
}
}
}
browser.browseResultsChangedHandler = { [weak self] results, _ in
guard let service = self else { return }
Task { @MainActor in
service.discoveredClusters = results.compactMap { result in
guard case .service(let name, _, _, _) = result.endpoint else {
return nil
}
return DiscoveredCluster(
id: name,
name: name,
endpoint: result.endpoint
)
}
}
}
browser.start(queue: .main)
self.browser = browser
}
func stopBrowsing() {
browser?.cancel()
browser = nil
isSearching = false
discoveredClusters = []
}
/// Resolve a discovered Bonjour endpoint to an IP address and port, then return a ConnectionInfo.
func resolve(_ cluster: DiscoveredCluster) async -> ConnectionInfo? {
await withCheckedContinuation { continuation in
let didResume = OSAllocatedUnfairLock(initialState: false)
let connection = NWConnection(to: cluster.endpoint, using: .tcp)
connection.stateUpdateHandler = { state in
guard
didResume.withLock({
guard !$0 else { return false }
$0 = true
return true
})
else { return }
switch state {
case .ready:
if let innerEndpoint = connection.currentPath?.remoteEndpoint,
case .hostPort(let host, let port) = innerEndpoint
{
var hostString: String
switch host {
case .ipv4(let addr):
hostString = "\(addr)"
case .ipv6(let addr):
hostString = "\(addr)"
case .name(let name, _):
hostString = name
@unknown default:
hostString = "\(host)"
}
// Strip interface scope suffix (e.g. "%en0")
if let pct = hostString.firstIndex(of: "%") {
hostString = String(hostString[..<pct])
}
let info = ConnectionInfo(
host: hostString,
port: Int(port.rawValue),
nodeId: nil
)
connection.cancel()
continuation.resume(returning: info)
} else {
connection.cancel()
continuation.resume(returning: nil)
}
case .failed, .cancelled:
connection.cancel()
continuation.resume(returning: nil)
default:
// Not a terminal state; allow future callbacks
didResume.withLock { $0 = false }
}
}
connection.start(queue: .global(qos: .userInitiated))
}
}
}
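A standalone sketch of the host clean-up in `resolve(_:)`: link-local IPv6 addresses reported by `NWEndpoint` carry an interface scope suffix (e.g. `%en0`) that must be stripped before the address can be used in a URL. The helper name is illustrative.

```swift
// Drop everything from the "%" scope separator onward, if present.
func stripScope(_ host: String) -> String {
    guard let pct = host.firstIndex(of: "%") else { return host }
    return String(host[..<pct])
}

print(stripScope("fe80::abcd%en0"))  // "fe80::abcd"
print(stripScope("192.168.1.5"))     // "192.168.1.5"
```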

View File

@@ -0,0 +1,201 @@
import Foundation
import MLXLLM
import MLXLMCommon
enum LocalModelState: Equatable {
case notDownloaded
case downloading(progress: Double)
case downloaded
case loading
case ready
case generating
case error(String)
}
@Observable
@MainActor
final class LocalInferenceService {
private(set) var modelState: LocalModelState = .notDownloaded
private var modelContainer: ModelContainer?
private var generationTask: Task<Void, Never>?
let defaultModelId = "mlx-community/Qwen3-0.6B-4bit"
private static let modelDownloadedKey = "exo_local_model_downloaded"
var isReady: Bool {
modelState == .ready
}
var isAvailable: Bool {
modelState == .ready || modelState == .generating
}
init() {
if UserDefaults.standard.bool(forKey: Self.modelDownloadedKey) {
modelState = .downloaded
}
}
// MARK: - Model Lifecycle
func prepareModel() async {
guard modelState == .notDownloaded || modelState == .downloaded else { return }
let wasDownloaded = modelState == .downloaded
if !wasDownloaded {
modelState = .downloading(progress: 0)
} else {
modelState = .loading
}
do {
let container = try await loadModelContainer(
id: defaultModelId
) { [weak self] progress in
guard let self else { return }
Task { @MainActor in
if case .downloading = self.modelState {
self.modelState = .downloading(progress: progress.fractionCompleted)
}
}
}
self.modelContainer = container
UserDefaults.standard.set(true, forKey: Self.modelDownloadedKey)
modelState = .ready
} catch {
modelState = .error(error.localizedDescription)
}
}
func unloadModel() {
cancelGeneration()
modelContainer = nil
modelState = .downloaded
}
// MARK: - Generation
func streamChatCompletion(request: ChatCompletionRequest) -> AsyncThrowingStream<
ChatCompletionChunk, Error
> {
AsyncThrowingStream { continuation in
let task = Task { [weak self] in
guard let self else {
continuation.finish(throwing: LocalInferenceError.serviceUnavailable)
return
}
guard let container = self.modelContainer else {
continuation.finish(throwing: LocalInferenceError.modelNotLoaded)
return
}
await MainActor.run {
self.modelState = .generating
}
defer {
Task { @MainActor [weak self] in
if self?.modelState == .generating {
self?.modelState = .ready
}
}
}
let chunkId = "local-\(UUID().uuidString)"
do {
// Build Chat.Message array from the request
var chatMessages: [Chat.Message] = []
for msg in request.messages {
switch msg.role {
case "system":
chatMessages.append(.system(msg.content))
case "assistant":
chatMessages.append(.assistant(msg.content))
default:
chatMessages.append(.user(msg.content))
}
}
// Use ChatSession for streaming generation
let session = ChatSession(
container,
history: chatMessages,
generateParameters: GenerateParameters(
maxTokens: request.maxTokens ?? 4096,
temperature: Float(request.temperature ?? 0.7)
)
)
// Stream with an empty prompt since history already contains the conversation
let stream = session.streamResponse(to: "")
for try await text in stream {
if Task.isCancelled { break }
let chunk = ChatCompletionChunk(
id: chunkId,
model: request.model,
choices: [
StreamingChoice(
index: 0,
delta: Delta(role: nil, content: text),
finishReason: nil
)
],
usage: nil
)
continuation.yield(chunk)
}
// Send final chunk with finish reason
let finalChunk = ChatCompletionChunk(
id: chunkId,
model: request.model,
choices: [
StreamingChoice(
index: 0,
delta: Delta(role: nil, content: nil),
finishReason: "stop"
)
],
usage: nil
)
continuation.yield(finalChunk)
continuation.finish()
} catch {
continuation.finish(throwing: error)
}
}
self.generationTask = task
continuation.onTermination = { _ in
task.cancel()
}
}
}
func cancelGeneration() {
generationTask?.cancel()
generationTask = nil
if modelState == .generating {
modelState = .ready
}
}
}
enum LocalInferenceError: LocalizedError {
case serviceUnavailable
case modelNotLoaded
var errorDescription: String? {
switch self {
case .serviceUnavailable: "Local inference service is unavailable"
case .modelNotLoaded: "Local model is not loaded"
}
}
}
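A standalone sketch of the role mapping in `streamChatCompletion`: OpenAI-style role strings map onto chat-message cases, with unknown roles falling back to `user`. A local enum stands in for MLX's `Chat.Message` so the snippet runs on its own.

```swift
// Stand-in for MLXLMCommon's Chat.Message, for illustration only.
enum Message: Equatable {
    case system(String), user(String), assistant(String)
}

// Same switch as the request-to-history conversion above.
func toMessage(role: String, content: String) -> Message {
    switch role {
    case "system": return .system(content)
    case "assistant": return .assistant(content)
    default: return .user(content)  // unknown roles become user messages
    }
}

print(toMessage(role: "tool", content: "x") == .user("x"))  // true
```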

View File

@@ -0,0 +1,50 @@
import Foundation
struct SSEStreamParser<T: Decodable>: AsyncSequence {
typealias Element = T
let bytes: URLSession.AsyncBytes
let decoder: JSONDecoder
init(bytes: URLSession.AsyncBytes, decoder: JSONDecoder = JSONDecoder()) {
self.bytes = bytes
self.decoder = decoder
}
func makeAsyncIterator() -> AsyncIterator {
AsyncIterator(lines: bytes.lines, decoder: decoder)
}
struct AsyncIterator: AsyncIteratorProtocol {
var lines: AsyncLineSequence<URLSession.AsyncBytes>.AsyncIterator
let decoder: JSONDecoder
init(lines: AsyncLineSequence<URLSession.AsyncBytes>, decoder: JSONDecoder) {
self.lines = lines.makeAsyncIterator()
self.decoder = decoder
}
mutating func next() async throws -> T? {
while let line = try await lines.next() {
let trimmed = line.trimmingCharacters(in: .whitespaces)
guard trimmed.hasPrefix("data: ") else { continue }
let payload = String(trimmed.dropFirst(6))
if payload == "[DONE]" {
return nil
}
guard let data = payload.data(using: .utf8) else { continue }
do {
return try decoder.decode(T.self, from: data)
} catch {
continue
}
}
return nil
}
}
}
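A standalone sketch of the framing `SSEStreamParser` consumes: each event is a `data: <json>` line, non-`data:` lines are skipped, and `data: [DONE]` terminates the stream. The same line handling is applied here to canned input instead of `URLSession` bytes.

```swift
import Foundation

struct Chunk: Decodable { let id: String }

let lines = [
    ": keep-alive",                 // no "data: " prefix, skipped
    "data: {\"id\": \"chunk-1\"}",
    "",                             // blank separator, skipped
    "data: {\"id\": \"chunk-2\"}",
    "data: [DONE]",                 // end-of-stream sentinel
    "data: {\"id\": \"chunk-3\"}",  // never reached
]

var decoded: [String] = []
let decoder = JSONDecoder()
for line in lines {
    let trimmed = line.trimmingCharacters(in: .whitespaces)
    guard trimmed.hasPrefix("data: ") else { continue }
    let payload = String(trimmed.dropFirst(6))
    if payload == "[DONE]" { break }
    if let data = payload.data(using: .utf8),
       let chunk = try? decoder.decode(Chunk.self, from: data) {
        decoded.append(chunk.id)
    }
}
print(decoded)  // ["chunk-1", "chunk-2"]
```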

View File

@@ -0,0 +1,203 @@
import SwiftUI
struct ChatView: View {
@Environment(ClusterService.self) private var clusterService
@Environment(ChatService.self) private var chatService
@Environment(LocalInferenceService.self) private var localInferenceService
@State private var inputText = ""
@State private var showModelSelector = false
var body: some View {
VStack(spacing: 0) {
modelBar
GradientDivider()
messageList
GradientDivider()
inputBar
}
.background(Color.exoBlack)
.sheet(isPresented: $showModelSelector) {
ModelSelectorView(
models: clusterService.availableModels,
selectedModelId: chatService.activeConversation?.modelId
) { modelId in
chatService.setModelForActiveConversation(modelId)
}
.presentationBackground(Color.exoDarkGray)
}
}
// MARK: - Model Bar
private var useLocalModel: Bool {
!clusterService.isConnected && localInferenceService.isAvailable
}
private var modelBar: some View {
Button {
if !useLocalModel {
showModelSelector = true
}
} label: {
HStack {
Image(systemName: useLocalModel ? "iphone" : "cpu")
.font(.exoCaption)
.foregroundStyle(useLocalModel ? Color.exoYellow : Color.exoLightGray)
if useLocalModel {
Text(localInferenceService.defaultModelId)
.font(.exoSubheadline)
.foregroundStyle(Color.exoForeground)
.lineLimit(1)
} else if let modelId = chatService.activeConversation?.modelId {
Text(modelId)
.font(.exoSubheadline)
.foregroundStyle(Color.exoForeground)
.lineLimit(1)
} else {
Text("SELECT MODEL")
.font(.exoSubheadline)
.tracking(1.5)
.foregroundStyle(Color.exoLightGray)
}
Spacer()
if useLocalModel {
Text("ON-DEVICE")
.font(.exoCaption)
.tracking(1)
.foregroundStyle(Color.exoYellow)
.padding(.horizontal, 6)
.padding(.vertical, 2)
.background(Color.exoYellow.opacity(0.15))
.clipShape(Capsule())
} else {
Image(systemName: "chevron.right")
.font(.caption)
.foregroundStyle(Color.exoLightGray)
}
}
.padding(.horizontal)
.padding(.vertical, 10)
.background(Color.exoDarkGray)
}
.tint(.primary)
.disabled(useLocalModel)
}
// MARK: - Messages
private var messageList: some View {
ScrollViewReader { proxy in
ScrollView {
LazyVStack(spacing: 12) {
if chatService.activeMessages.isEmpty {
emptyState
} else {
ForEach(chatService.activeMessages) { message in
MessageBubbleView(message: message)
.id(message.id)
}
}
}
.padding()
}
.background(Color.exoBlack)
.onChange(of: chatService.activeMessages.last?.content) {
if let lastId = chatService.activeMessages.last?.id {
withAnimation(.easeOut(duration: 0.2)) {
proxy.scrollTo(lastId, anchor: .bottom)
}
}
}
}
}
private var emptyState: some View {
VStack(spacing: 16) {
Spacer(minLength: 80)
ZStack {
Circle()
.stroke(Color.exoYellow.opacity(0.15), lineWidth: 1)
.frame(width: 80, height: 80)
Circle()
.stroke(Color.exoYellow.opacity(0.3), lineWidth: 1)
.frame(width: 56, height: 56)
Circle()
.fill(Color.exoYellow.opacity(0.15))
.frame(width: 32, height: 32)
Circle()
.fill(Color.exoYellow)
.frame(width: 8, height: 8)
.shadow(color: Color.exoYellow.opacity(0.6), radius: 6)
}
Text("AWAITING INPUT")
.font(.exoSubheadline)
.tracking(3)
.foregroundStyle(Color.exoLightGray)
Text("Send a message to begin.")
.font(.exoCaption)
.foregroundStyle(Color.exoLightGray.opacity(0.6))
Spacer(minLength: 80)
}
.padding()
}
// MARK: - Input
private var inputBar: some View {
HStack(alignment: .bottom, spacing: 8) {
TextField("Message...", text: $inputText, axis: .vertical)
.font(.exoBody)
.lineLimit(1...6)
.textFieldStyle(.plain)
.padding(10)
.background(Color.exoMediumGray)
.foregroundStyle(Color.exoForeground)
.clipShape(RoundedRectangle(cornerRadius: 8))
if chatService.isGenerating {
Button {
chatService.cancelGeneration()
} label: {
Image(systemName: "stop.circle.fill")
.font(.title2)
.foregroundStyle(Color.exoDestructive)
}
} else {
Button {
let text = inputText
inputText = ""
chatService.sendMessage(text)
} label: {
Text("SEND")
.font(.exoMono(12, weight: .bold))
.tracking(1)
.foregroundStyle(canSend ? Color.exoBlack : Color.exoLightGray)
.padding(.horizontal, 14)
.padding(.vertical, 8)
.background(canSend ? Color.exoYellow : Color.exoMediumGray)
.clipShape(RoundedRectangle(cornerRadius: 8))
}
.disabled(!canSend)
}
}
.padding(.horizontal)
.padding(.vertical, 8)
.background(Color.exoDarkGray)
}
private var canSend: Bool {
!inputText.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty
&& (clusterService.isConnected || localInferenceService.isAvailable)
}
}

View File

@@ -0,0 +1,54 @@
import SwiftUI
struct MessageBubbleView: View {
let message: ChatMessage
private var isAssistant: Bool { message.role == .assistant }
var body: some View {
HStack {
if message.role == .user { Spacer(minLength: 48) }
VStack(alignment: isAssistant ? .leading : .trailing, spacing: 6) {
// Header
HStack(spacing: 4) {
if isAssistant {
Circle()
.fill(Color.exoYellow)
.frame(width: 6, height: 6)
.shadow(color: Color.exoYellow.opacity(0.6), radius: 4)
Text("EXO")
.font(.exoMono(10, weight: .bold))
.tracking(1.5)
.foregroundStyle(Color.exoYellow)
} else {
Text("QUERY")
.font(.exoMono(10, weight: .medium))
.tracking(1.5)
.foregroundStyle(Color.exoLightGray)
}
}
// Bubble
HStack(spacing: 0) {
if isAssistant {
RoundedRectangle(cornerRadius: 1)
.fill(Color.exoYellow.opacity(0.5))
.frame(width: 2)
}
Text(message.content + (message.isStreaming ? " \u{258C}" : ""))
.font(.exoBody)
.textSelection(.enabled)
.foregroundStyle(Color.exoForeground)
.padding(.horizontal, 14)
.padding(.vertical, 10)
}
.background(Color.exoDarkGray)
.clipShape(RoundedRectangle(cornerRadius: 8))
}
if isAssistant { Spacer(minLength: 48) }
}
}
}

View File

@@ -0,0 +1,75 @@
import SwiftUI
struct ModelSelectorView: View {
let models: [ModelOption]
let selectedModelId: String?
let onSelect: (String) -> Void
@Environment(\.dismiss) private var dismiss
var body: some View {
NavigationStack {
List {
if models.isEmpty {
emptyContent
} else {
modelsList
}
}
.scrollContentBackground(.hidden)
.background(Color.exoBlack)
.navigationTitle("SELECT MODEL")
.navigationBarTitleDisplayMode(.inline)
.toolbar {
ToolbarItem(placement: .cancellationAction) {
Button("Cancel") { dismiss() }
.font(.exoSubheadline)
.foregroundStyle(Color.exoYellow)
}
}
}
}
private var emptyContent: some View {
ContentUnavailableView(
"No Models Available",
systemImage: "cpu",
description: Text("Connect to an EXO cluster to see available models.")
)
.foregroundStyle(Color.exoLightGray)
.listRowBackground(Color.exoBlack)
}
private var modelsList: some View {
ForEach(models) { model in
Button {
onSelect(model.id)
dismiss()
} label: {
modelRow(model)
}
.tint(.primary)
.listRowBackground(Color.exoDarkGray)
}
}
private func modelRow(_ model: ModelOption) -> some View {
HStack {
VStack(alignment: .leading, spacing: 2) {
Text(model.displayName)
.font(.exoSubheadline)
.foregroundStyle(Color.exoForeground)
Text(model.id)
.font(.exoCaption)
.foregroundStyle(Color.exoLightGray)
}
Spacer()
if model.id == selectedModelId {
Image(systemName: "checkmark")
.font(.exoSubheadline)
.foregroundStyle(Color.exoYellow)
}
}
}
}

View File

@@ -0,0 +1,68 @@
import SwiftUI
struct ConnectionStatusBadge: View {
let connectionState: ConnectionState
var localModelState: LocalModelState = .notDownloaded
private var isLocalReady: Bool {
if case .disconnected = connectionState {
return localModelState == .ready || localModelState == .generating
}
return false
}
var body: some View {
HStack(spacing: 6) {
Circle()
.fill(dotColor)
.frame(width: 8, height: 8)
.shadow(color: dotColor.opacity(0.6), radius: 4)
Text(label.uppercased())
.font(.exoMono(10, weight: .medium))
.tracking(1)
.foregroundStyle(Color.exoForeground)
}
.padding(.horizontal, 10)
.padding(.vertical, 5)
.background(backgroundColor)
.clipShape(Capsule())
.overlay(
Capsule()
.stroke(dotColor.opacity(0.3), lineWidth: 1)
)
}
private var dotColor: Color {
if isLocalReady {
return .exoYellow
}
switch connectionState {
case .connected: return .green
case .connecting: return .orange
case .disconnected: return .exoLightGray
}
}
private var label: String {
if isLocalReady {
return "Local"
}
switch connectionState {
case .connected: return "Connected"
case .connecting: return "Connecting"
case .disconnected: return "Disconnected"
}
}
private var backgroundColor: Color {
if isLocalReady {
return Color.exoYellow.opacity(0.1)
}
switch connectionState {
case .connected: return .green.opacity(0.1)
case .connecting: return .orange.opacity(0.1)
case .disconnected: return Color.exoMediumGray.opacity(0.5)
}
}
}

View File

@@ -0,0 +1,136 @@
import SwiftUI
struct RootView: View {
@Environment(ClusterService.self) private var clusterService
@Environment(DiscoveryService.self) private var discoveryService
@Environment(ChatService.self) private var chatService
@Environment(LocalInferenceService.self) private var localInferenceService
@State private var showSettings = false
@State private var showConversations = false
var body: some View {
NavigationStack {
ChatView()
.navigationTitle("EXO")
.navigationBarTitleDisplayMode(.inline)
.toolbar {
ToolbarItem(placement: .topBarLeading) {
conversationMenuButton
}
ToolbarItem(placement: .principal) {
ConnectionStatusBadge(
connectionState: clusterService.connectionState,
localModelState: localInferenceService.modelState
)
}
ToolbarItem(placement: .topBarTrailing) {
Button {
showSettings = true
} label: {
Image(systemName: "gear")
.foregroundStyle(Color.exoYellow)
}
}
}
}
.tint(Color.exoYellow)
.sheet(isPresented: $showSettings) {
SettingsView()
.environment(discoveryService)
.presentationBackground(Color.exoDarkGray)
}
.sheet(isPresented: $showConversations) {
conversationList
.presentationBackground(Color.exoDarkGray)
}
}
// MARK: - Conversations
private var conversationMenuButton: some View {
HStack(spacing: 12) {
Button {
showConversations = true
} label: {
Image(systemName: "sidebar.left")
.foregroundStyle(Color.exoYellow)
}
Button {
chatService.createConversation()
} label: {
Image(systemName: "square.and.pencil")
.foregroundStyle(Color.exoYellow)
}
}
}
private var conversationList: some View {
NavigationStack {
List {
if chatService.conversations.isEmpty {
Text("No conversations yet")
.font(.exoBody)
.foregroundStyle(Color.exoLightGray)
.listRowBackground(Color.exoDarkGray)
} else {
ForEach(chatService.conversations) { conversation in
let isActive = conversation.id == chatService.activeConversationId
Button {
chatService.setActiveConversation(id: conversation.id)
showConversations = false
} label: {
VStack(alignment: .leading, spacing: 4) {
Text(conversation.title)
.font(.exoSubheadline)
.fontWeight(isActive ? .semibold : .regular)
.foregroundStyle(
isActive ? Color.exoYellow : Color.exoForeground
)
.lineLimit(1)
if let modelId = conversation.modelId {
Text(modelId)
.font(.exoCaption)
.foregroundStyle(Color.exoLightGray)
.lineLimit(1)
}
}
}
.listRowBackground(
isActive
? Color.exoYellow.opacity(0.1)
: Color.exoDarkGray
)
}
.onDelete { indexSet in
for index in indexSet {
chatService.deleteConversation(id: chatService.conversations[index].id)
}
}
}
}
.scrollContentBackground(.hidden)
.background(Color.exoBlack)
.navigationTitle("Conversations")
.navigationBarTitleDisplayMode(.inline)
.toolbar {
ToolbarItem(placement: .confirmationAction) {
Button("Done") { showConversations = false }
.font(.exoSubheadline)
.foregroundStyle(Color.exoYellow)
}
ToolbarItem(placement: .topBarLeading) {
Button {
chatService.createConversation()
} label: {
Image(systemName: "plus")
.foregroundStyle(Color.exoYellow)
}
}
}
}
}
}


@@ -0,0 +1,314 @@
import SwiftUI
struct SettingsView: View {
@Environment(ClusterService.self) private var clusterService
@Environment(DiscoveryService.self) private var discoveryService
@Environment(LocalInferenceService.self) private var localInferenceService
@Environment(\.dismiss) private var dismiss
@State private var host: String = ""
@State private var port: String = "52415"
var body: some View {
NavigationStack {
Form {
localModelSection
nearbyClustersSection
connectionSection
statusSection
}
.scrollContentBackground(.hidden)
.background(Color.exoBlack)
.navigationTitle("Settings")
.navigationBarTitleDisplayMode(.inline)
.toolbar {
ToolbarItem(placement: .confirmationAction) {
Button("Done") { dismiss() }
.font(.exoSubheadline)
.foregroundStyle(Color.exoYellow)
}
}
}
}
// MARK: - Section Headers
private func sectionHeader(_ title: String) -> some View {
Text(title.uppercased())
.font(.exoMono(10, weight: .semibold))
.tracking(2)
.foregroundStyle(Color.exoYellow)
}
// MARK: - Local Model
private var localModelSection: some View {
Section {
HStack {
VStack(alignment: .leading, spacing: 4) {
Text(localInferenceService.defaultModelId)
.font(.exoSubheadline)
.foregroundStyle(Color.exoForeground)
Text(localModelStatusText)
.font(.exoCaption)
.foregroundStyle(Color.exoLightGray)
}
Spacer()
localModelActionButton
}
.listRowBackground(Color.exoDarkGray)
if case .downloading(let progress) = localInferenceService.modelState {
ProgressView(value: progress)
.tint(Color.exoYellow)
.listRowBackground(Color.exoDarkGray)
}
} header: {
sectionHeader("Local Model")
} footer: {
Text(
"When disconnected from a cluster, messages are processed on-device using this model."
)
.font(.exoCaption)
.foregroundStyle(Color.exoLightGray.opacity(0.7))
}
}
private var localModelStatusText: String {
switch localInferenceService.modelState {
case .notDownloaded: "Not downloaded"
case .downloading(let progress): "Downloading \(Int(progress * 100))%..."
case .downloaded: "Downloaded — not loaded"
case .loading: "Loading into memory..."
case .ready: "Ready"
case .generating: "Generating..."
case .error(let message): "Error: \(message)"
}
}
@ViewBuilder
private var localModelActionButton: some View {
switch localInferenceService.modelState {
case .notDownloaded:
exoButton("Download") {
Task { await localInferenceService.prepareModel() }
}
case .downloading:
ProgressView()
.controlSize(.small)
.tint(Color.exoYellow)
case .downloaded:
exoButton("Load") {
Task { await localInferenceService.prepareModel() }
}
case .loading:
ProgressView()
.controlSize(.small)
.tint(Color.exoYellow)
case .ready, .generating:
exoButton("Unload") {
localInferenceService.unloadModel()
}
case .error:
exoButton("Retry", destructive: true) {
Task { await localInferenceService.prepareModel() }
}
}
}
private func exoButton(_ title: String, destructive: Bool = false, action: @escaping () -> Void)
-> some View
{
let borderColor = destructive ? Color.exoDestructive : Color.exoYellow
return Button(action: action) {
Text(title.uppercased())
.font(.exoMono(11, weight: .semibold))
.tracking(1)
.foregroundStyle(borderColor)
.padding(.horizontal, 10)
.padding(.vertical, 5)
.overlay(
RoundedRectangle(cornerRadius: 6)
.stroke(borderColor, lineWidth: 1)
)
}
}
// MARK: - Nearby Clusters
private var nearbyClustersSection: some View {
Section {
if discoveryService.discoveredClusters.isEmpty {
if discoveryService.isSearching {
HStack {
ProgressView()
.tint(Color.exoYellow)
.padding(.trailing, 8)
Text("Searching for clusters...")
.font(.exoBody)
.foregroundStyle(Color.exoLightGray)
}
.listRowBackground(Color.exoDarkGray)
} else {
Text("No clusters found")
.font(.exoBody)
.foregroundStyle(Color.exoLightGray)
.listRowBackground(Color.exoDarkGray)
}
} else {
ForEach(discoveryService.discoveredClusters) { cluster in
HStack {
VStack(alignment: .leading) {
Text(cluster.name)
.font(.exoBody)
.foregroundStyle(Color.exoForeground)
}
Spacer()
exoButton("Connect") {
Task {
await clusterService.connectToDiscoveredCluster(
cluster, using: discoveryService
)
if clusterService.isConnected {
dismiss()
}
}
}
}
.listRowBackground(Color.exoDarkGray)
}
}
} header: {
sectionHeader("Nearby Clusters")
}
}
// MARK: - Manual Connection
private var connectionSection: some View {
Section {
TextField("IP Address (e.g. 192.168.1.42)", text: $host)
.font(.exoBody)
.keyboardType(.decimalPad)
.textContentType(.URL)
.autocorrectionDisabled()
.foregroundStyle(Color.exoForeground)
.listRowBackground(Color.exoDarkGray)
TextField("Port", text: $port)
.font(.exoBody)
.keyboardType(.numberPad)
.foregroundStyle(Color.exoForeground)
.listRowBackground(Color.exoDarkGray)
Button {
Task {
let portNum = Int(port) ?? ConnectionInfo.defaultPort
let info = ConnectionInfo(host: host, port: portNum, nodeId: nil)
await clusterService.connect(to: info)
if clusterService.isConnected {
dismiss()
}
}
} label: {
Text(clusterService.isConnected ? "RECONNECT" : "CONNECT")
.font(.exoMono(13, weight: .semibold))
.tracking(1.5)
.foregroundStyle(
host.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty
? Color.exoLightGray : Color.exoYellow
)
.frame(maxWidth: .infinity)
}
.disabled(host.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty)
.listRowBackground(Color.exoDarkGray)
} header: {
sectionHeader("Manual Connection")
}
}
// MARK: - Status
private var statusSection: some View {
Section {
if let connection = clusterService.currentConnection {
LabeledContent {
Text(connection.host)
.font(.exoCaption)
.foregroundStyle(Color.exoForeground)
} label: {
Text("Host")
.font(.exoCaption)
.foregroundStyle(Color.exoLightGray)
}
.listRowBackground(Color.exoDarkGray)
LabeledContent {
Text("\(connection.port)")
.font(.exoCaption)
.foregroundStyle(Color.exoForeground)
} label: {
Text("Port")
.font(.exoCaption)
.foregroundStyle(Color.exoLightGray)
}
.listRowBackground(Color.exoDarkGray)
if let nodeId = connection.nodeId {
LabeledContent {
Text(String(nodeId.prefix(12)) + "...")
.font(.exoCaption)
.foregroundStyle(Color.exoForeground)
} label: {
Text("Node ID")
.font(.exoCaption)
.foregroundStyle(Color.exoLightGray)
}
.listRowBackground(Color.exoDarkGray)
}
LabeledContent {
Text("\(clusterService.availableModels.count)")
.font(.exoCaption)
.foregroundStyle(Color.exoForeground)
} label: {
Text("Models")
.font(.exoCaption)
.foregroundStyle(Color.exoLightGray)
}
.listRowBackground(Color.exoDarkGray)
Button(role: .destructive) {
clusterService.disconnect()
} label: {
Text("DISCONNECT")
.font(.exoMono(13, weight: .semibold))
.tracking(1.5)
.foregroundStyle(Color.exoDestructive)
.frame(maxWidth: .infinity)
}
.listRowBackground(Color.exoDarkGray)
} else {
if let error = clusterService.lastError {
Label {
Text(error)
.font(.exoCaption)
} icon: {
Image(systemName: "exclamationmark.triangle")
}
.foregroundStyle(Color.exoDestructive)
.listRowBackground(Color.exoDarkGray)
} else {
Text("Not connected")
.font(.exoBody)
.foregroundStyle(Color.exoLightGray)
.listRowBackground(Color.exoDarkGray)
}
}
} header: {
sectionHeader("Status")
}
}
}


@@ -0,0 +1,51 @@
import SwiftUI
// MARK: - EXO Color Palette
extension Color {
/// Primary background near-black (#121212)
static let exoBlack = Color(red: 0x12 / 255.0, green: 0x12 / 255.0, blue: 0x12 / 255.0)
/// Card / surface background (#1F1F1F)
static let exoDarkGray = Color(red: 0x1F / 255.0, green: 0x1F / 255.0, blue: 0x1F / 255.0)
/// Input field / elevated surface (#353535)
static let exoMediumGray = Color(red: 0x35 / 255.0, green: 0x35 / 255.0, blue: 0x35 / 255.0)
/// Secondary text (#999999)
static let exoLightGray = Color(red: 0x99 / 255.0, green: 0x99 / 255.0, blue: 0x99 / 255.0)
/// Accent yellow matches dashboard (#FFD700)
static let exoYellow = Color(red: 0xFF / 255.0, green: 0xD7 / 255.0, blue: 0x00 / 255.0)
/// Primary foreground text (#E5E5E5)
static let exoForeground = Color(red: 0xE5 / 255.0, green: 0xE5 / 255.0, blue: 0xE5 / 255.0)
/// Destructive / error (#E74C3C)
static let exoDestructive = Color(red: 0xE7 / 255.0, green: 0x4C / 255.0, blue: 0x3C / 255.0)
}
// MARK: - EXO Typography (SF Mono via .monospaced design)
extension Font {
/// Monospaced font at a given size and weight.
static func exoMono(_ size: CGFloat, weight: Font.Weight = .regular) -> Font {
.system(size: size, weight: weight, design: .monospaced)
}
/// Body text 15pt monospaced
static let exoBody: Font = .system(size: 15, weight: .regular, design: .monospaced)
/// Caption 11pt monospaced
static let exoCaption: Font = .system(size: 11, weight: .regular, design: .monospaced)
/// Subheadline 13pt monospaced medium
static let exoSubheadline: Font = .system(size: 13, weight: .medium, design: .monospaced)
/// Headline 17pt monospaced semibold
static let exoHeadline: Font = .system(size: 17, weight: .semibold, design: .monospaced)
}
// MARK: - Reusable Gradient Divider
struct GradientDivider: View {
var body: some View {
LinearGradient(
colors: [.clear, Color.exoYellow.opacity(0.3), .clear],
startPoint: .leading,
endPoint: .trailing
)
.frame(height: 1)
}
}


@@ -0,0 +1,18 @@
//
// EXO_iOSTests.swift
// EXO-iOSTests
//
// Created by Sami Khan on 2026-02-17.
//
import Testing
@testable import EXO_iOS
struct EXO_iOSTests {
@Test func example() async throws {
// Write your test here and use APIs like `#expect(...)` to check expected conditions.
}
}


@@ -0,0 +1,41 @@
//
// EXO_iOSUITests.swift
// EXO-iOSUITests
//
// Created by Sami Khan on 2026-02-17.
//
import XCTest
final class EXO_iOSUITests: XCTestCase {
override func setUpWithError() throws {
// Put setup code here. This method is called before the invocation of each test method in the class.
// In UI tests it is usually best to stop immediately when a failure occurs.
continueAfterFailure = false
// In UI tests it's important to set the initial state - such as interface orientation - required for your tests before they run. The setUp method is a good place to do this.
}
override func tearDownWithError() throws {
// Put teardown code here. This method is called after the invocation of each test method in the class.
}
@MainActor
func testExample() throws {
// UI tests must launch the application that they test.
let app = XCUIApplication()
app.launch()
// Use XCTAssert and related functions to verify your tests produce the correct results.
}
@MainActor
func testLaunchPerformance() throws {
// This measures how long it takes to launch your application.
measure(metrics: [XCTApplicationLaunchMetric()]) {
XCUIApplication().launch()
}
}
}


@@ -0,0 +1,33 @@
//
// EXO_iOSUITestsLaunchTests.swift
// EXO-iOSUITests
//
// Created by Sami Khan on 2026-02-17.
//
import XCTest
final class EXO_iOSUITestsLaunchTests: XCTestCase {
override class var runsForEachTargetApplicationUIConfiguration: Bool {
true
}
override func setUpWithError() throws {
continueAfterFailure = false
}
@MainActor
func testLaunch() throws {
let app = XCUIApplication()
app.launch()
// Insert steps here to perform after app launch but before taking a screenshot,
// such as logging into a test account or navigating somewhere in the app
let attachment = XCTAttachment(screenshot: app.screenshot())
attachment.name = "Launch Screen"
attachment.lifetime = .keepAlways
add(attachment)
}
}

7
bench/bench.toml Normal file

@@ -0,0 +1,7 @@
# Canary benchmark manifest
#
# Lists the suite files to include. Each file defines benchmarks
# with shared constraints, topology, and default args.
include = [
"single-m3-ultra.toml",
]

189
bench/single-m3-ultra.toml Normal file

@@ -0,0 +1,189 @@
# Single-node M3 Ultra benchmarks
#
# Shared constraints applied to ALL benchmarks in this file.
constraints = [
"All(MacOsBuild(=25D125))",
"Hosts(=1)",
"All(Chip(m3_ultra))",
"All(GpuCores(=80))",
]
[topology]
type = "none"
# Default args merged into each benchmark's args (benchmark-level args win).
[defaults]
pp = [512, 2048, 8192, 16384]
tg = 128
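The merge rule stated in the comment above (defaults apply to every benchmark; benchmark-level args win on conflict) can be sketched with plain dict unpacking — `merged_args` is a hypothetical helper, not part of exo:

```python
# Hypothetical sketch of the stated merge semantics; not exo's code.
defaults = {"pp": [512, 2048, 8192, 16384], "tg": 128}

def merged_args(defaults: dict, bench: dict) -> dict:
    # Later keys (benchmark-level) override earlier ones (defaults).
    return {**defaults, **bench}

assert merged_args(defaults, {})["tg"] == 128           # inherited default
assert merged_args(defaults, {"tg": 256})["tg"] == 256  # benchmark wins
```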
[[benchmark]]
model = "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/gpt-oss-120b-MXFP4-Q8"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-Flash-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-6bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-30B-A3B-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-0.6B-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-0.6B-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Llama-3.2-1B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Llama-3.2-3B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Llama-3.2-3B-Instruct-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/gpt-oss-20b-MXFP4-Q8"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-30B-A3B-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-Flash-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-Flash-5bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-Flash-6bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Llama-3.3-70B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-5bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Llama-3.3-70B-Instruct-8bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/llama-3.3-70b-instruct-fp16"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.5-Air-8bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.5-Air-bf16"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-4bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/MiniMax-M2.1-3bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/MiniMax-M2.1-8bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-bf16"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/Step-3.5-Flash-4bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/Step-3.5-Flash-6bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/Step-3.5-Flash-8Bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/DeepSeek-V3.1-4bit"
extra_constraints = ["All(Memory(>=512GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-6bit"
extra_constraints = ["All(Memory(>=512GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-8bit-gs32"
extra_constraints = ["All(Memory(>=512GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"
extra_constraints = ["All(Memory(>=512GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"
extra_constraints = ["All(Memory(>=512GiB))"]


@@ -1 +0,0 @@
collect_ignore = ["tests/start_distributed_test.py"]

File diff suppressed because it is too large


@@ -1,58 +0,0 @@
# Stage 1: Build the dashboard
FROM node:22-slim AS dashboard
WORKDIR /app/dashboard
COPY dashboard/package.json dashboard/package-lock.json ./
RUN npm ci
COPY dashboard/ .
RUN npm run build
# Stage 2: Build and run exo
FROM python:3.13-slim
# Install system dependencies
# libblas-dev/liblapack-dev/liblapacke-dev are required by MLX CPU backend on Linux
RUN apt-get update && apt-get install -y \
build-essential \
pkg-config \
libssl-dev \
libblas-dev \
liblapack-dev \
liblapacke-dev \
curl \
protobuf-compiler \
iptables \
&& rm -rf /var/lib/apt/lists/*
# Install Rust nightly
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain nightly
ENV PATH="/root/.cargo/bin:${PATH}"
# Wrap g++ with -fpermissive to fix MLX CPU JIT compilation with GCC 14
# (GCC 14 treats _Float128/_Float32/_Float64 as built-in types, conflicting with MLX-generated code)
# Must be done BEFORE uv sync so any source builds also get the fix
RUN mv /usr/bin/g++ /usr/bin/g++.real && \
printf '#!/bin/sh\nexec /usr/bin/g++.real -fpermissive "$@"\n' > /usr/bin/g++ && \
chmod +x /usr/bin/g++
# Install uv
COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv
WORKDIR /app
# Copy dependency files first for better layer caching
COPY pyproject.toml Cargo.toml uv.lock README.md ./
COPY rust/ ./rust/
COPY bench/pyproject.toml ./bench/pyproject.toml
# Copy source and resources
COPY src/ ./src/
COPY resources/ ./resources/
# Copy built dashboard from stage 1
COPY --from=dashboard /app/dashboard/build ./dashboard/build/
# Install Python deps and build Rust bindings, then clean up build artifacts
# to keep the layer small (Rust target/ and cargo registry can be 1-2 GB)
RUN uv sync && rm -rf /app/rust/target /root/.cargo/registry /root/.cargo/git
CMD [".venv/bin/exo", "-v"]


@@ -1,195 +0,0 @@
"""Shared E2E test infrastructure for exo cluster tests."""
import asyncio
import json
import os
import sys
from pathlib import Path
from urllib.error import URLError
from urllib.request import Request, urlopen
E2E_DIR = Path(__file__).parent.resolve()
TIMEOUT = int(os.environ.get("E2E_TIMEOUT", "120"))
class Cluster:
"""Async wrapper around a docker compose exo cluster."""
def __init__(self, name: str, overrides: list[str] | None = None):
self.name = name
self.project = f"e2e-{name}"
compose_files = [str(E2E_DIR / "docker-compose.yml")]
for path in overrides or []:
compose_files.append(str(E2E_DIR / path))
self._compose_base = [
"docker",
"compose",
"-p",
self.project,
*[arg for f in compose_files for arg in ("-f", f)],
]
async def __aenter__(self):
return self
async def __aexit__(self, *exc):
await self.stop()
async def _run(self, *args: str, check: bool = True) -> str:
proc = await asyncio.create_subprocess_exec(
*self._compose_base,
*args,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT,
)
stdout, _ = await proc.communicate()
output = stdout.decode()
if check and proc.returncode != 0:
print(output, file=sys.stderr)
raise RuntimeError(
f"docker compose {' '.join(args)} failed (rc={proc.returncode})"
)
return output
async def build(self):
# Skip build if the image was pre-built (e.g. in CI with buildx cache)
proc = await asyncio.create_subprocess_exec(
"docker",
"image",
"inspect",
"exo-e2e:latest",
stdout=asyncio.subprocess.DEVNULL,
stderr=asyncio.subprocess.DEVNULL,
)
await proc.wait()
if proc.returncode == 0:
print(" Using pre-built image (exo-e2e:latest)")
return
print(" Building images...")
await self._run("build", "--quiet")
async def start(self):
print(" Starting cluster...")
await self._run("up", "-d")
async def stop(self):
print(" Cleaning up...")
await self._run("down", "--timeout", "5", check=False)
async def logs(self) -> str:
return await self._run("logs", check=False)
async def exec(
self, service: str, *cmd: str, check: bool = True
) -> tuple[int, str]:
"""Run a command inside a running container. Returns (returncode, output)."""
proc = await asyncio.create_subprocess_exec(
*self._compose_base,
"exec",
"-T",
service,
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT,
)
stdout, _ = await proc.communicate()
output = stdout.decode()
if check and proc.returncode != 0:
raise RuntimeError(
f"exec {' '.join(cmd)} in {service} failed (rc={proc.returncode})"
)
return proc.returncode, output
async def wait_for(self, description: str, check_fn, timeout: int = TIMEOUT):
"""Poll check_fn every 2s until it returns True or timeout expires."""
print(f" Waiting for {description}...")
deadline = asyncio.get_event_loop().time() + timeout
while asyncio.get_event_loop().time() < deadline:
if await check_fn():
print(f" {description}")
return
await asyncio.sleep(2)
output = await self.logs()
print(f"--- cluster logs ---\n{output}\n---", file=sys.stderr)
raise TimeoutError(f"Timed out waiting for {description}")
async def assert_healthy(self):
"""Verify the cluster formed correctly: nodes started, discovered each other, elected a master, API responds."""
async def both_nodes_started():
log = await self.logs()
return log.count("Starting node") >= 2
async def nodes_discovered():
log = await self.logs()
return log.count("ConnectionMessageType.Connected") >= 2
async def master_elected():
log = await self.logs()
return "demoting self" in log
async def api_responding():
try:
with urlopen("http://localhost:52415/v1/models", timeout=3) as resp:
return resp.status == 200
except (URLError, OSError):
return False
await self.wait_for("Both nodes started", both_nodes_started)
await self.wait_for("Nodes discovered each other", nodes_discovered)
await self.wait_for("Master election resolved", master_elected)
await self.wait_for("API responding", api_responding)
async def _api(
self, method: str, path: str, body: dict | None = None, timeout: int = 30
) -> dict:
"""Make an API request to the cluster. Returns parsed JSON."""
url = f"http://localhost:52415{path}"
data = json.dumps(body).encode() if body else None
req = Request(
url, data=data, headers={"Content-Type": "application/json"}, method=method
)
loop = asyncio.get_event_loop()
resp_bytes = await loop.run_in_executor(
None, lambda: urlopen(req, timeout=timeout).read()
)
return json.loads(resp_bytes)
async def place_model(self, model: str, timeout: int = 600):
"""Place a model instance on the cluster (triggers download) and wait until it's ready."""
await self._api("POST", "/place_instance", {"model_id": model})
async def model_ready():
try:
resp = await self._api("GET", "/v1/models")
return any(m.get("id") == model for m in resp.get("data", []))
except Exception:
return False
await self.wait_for(f"Model {model} ready", model_ready, timeout=timeout)
async def chat(
self, model: str, messages: list[dict], timeout: int = 600, **kwargs
) -> dict:
"""Send a chat completion request. Retries until model is downloaded and inference completes."""
body = json.dumps({"model": model, "messages": messages, **kwargs}).encode()
deadline = asyncio.get_event_loop().time() + timeout
last_error = None
while asyncio.get_event_loop().time() < deadline:
try:
req = Request(
"http://localhost:52415/v1/chat/completions",
data=body,
headers={"Content-Type": "application/json"},
)
loop = asyncio.get_event_loop()
resp_bytes = await loop.run_in_executor(
None, lambda r=req: urlopen(r, timeout=300).read()
)
return json.loads(resp_bytes)
except Exception as e:
last_error = e
await asyncio.sleep(5)
raise TimeoutError(f"Chat request failed after {timeout}s: {last_error}")


@@ -1,20 +0,0 @@
services:
exo-node-1:
image: exo-e2e:latest
build:
context: ..
dockerfile: e2e/Dockerfile
environment:
- EXO_LIBP2P_NAMESPACE=docker-e2e
command: [".venv/bin/exo", "-v"]
ports:
- "52415:52415"
exo-node-2:
image: exo-e2e:latest
build:
context: ..
dockerfile: e2e/Dockerfile
environment:
- EXO_LIBP2P_NAMESPACE=docker-e2e
command: [".venv/bin/exo", "-v"]


@@ -1,83 +0,0 @@
#!/usr/bin/env python3
"""Discovers and runs all E2E tests in e2e/test_*.py.
Tests with a line starting with 'slow' at the top of their docstring are skipped
unless --slow is passed or E2E_SLOW=1 is set.
"""
import os
import subprocess
import sys
from pathlib import Path
E2E_DIR = Path(__file__).parent.resolve()
def is_slow(test_file: Path) -> bool:
"""Check if the test file is marked as slow (has '# slow' in first 3 lines)."""
with open(test_file) as f:
for line in f:
if line.strip().startswith("#"):
continue
if line.strip().startswith('"""') or line.strip().startswith("'''"):
# Read into the docstring
for doc_line in f:
if "slow" in doc_line.lower() and doc_line.strip().startswith(
"slow"
):
return True
if '"""' in doc_line or "'''" in doc_line:
break
break
return False
def main():
run_slow = "--slow" in sys.argv or os.environ.get("E2E_SLOW") == "1"
if "--update-snapshots" in sys.argv:
os.environ["UPDATE_SNAPSHOTS"] = "1"
test_files = sorted(E2E_DIR.glob("test_*.py"))
if not test_files:
print("No test files found")
sys.exit(1)
passed = 0
failed = 0
skipped = 0
failures = []
for test_file in test_files:
name = test_file.stem
if is_slow(test_file) and not run_slow:
print(f"=== {name} === SKIPPED (slow, use --slow to run)")
skipped += 1
continue
print(f"=== {name} ===")
result = subprocess.run([sys.executable, str(test_file)])
if result.returncode == 0:
passed += 1
else:
# Retry once — Docker networking (mDNS) can be slow on first boot
print(f"\n=== {name} === RETRYING (attempt 2/2)")
result = subprocess.run([sys.executable, str(test_file)])
if result.returncode == 0:
passed += 1
else:
failed += 1
failures.append(name)
print()
total = passed + failed + skipped
print("================================")
print(
f"{passed}/{total} tests passed" + (f", {skipped} skipped" if skipped else "")
)
if failed:
print(f"Failed: {' '.join(failures)}")
sys.exit(1)
if __name__ == "__main__":
main()


@@ -1,78 +0,0 @@
"""Snapshot testing infrastructure for E2E tests.
Provides deterministic regression testing by comparing inference output
against committed baseline snapshots. Tests FAIL if no baseline exists —
baselines must be explicitly generated and committed.
Generate baselines: UPDATE_SNAPSHOTS=1 python3 e2e/run_all.py --slow
Update after intentional changes: UPDATE_SNAPSHOTS=1 python3 e2e/run_all.py --slow
Snapshots are stored per-architecture (e.g. snapshots/x86_64/, snapshots/arm64/)
since floating-point results differ between CPU architectures.
"""
import difflib
import json
import os
import platform
from pathlib import Path
ARCH = platform.machine()
SNAPSHOTS_DIR = Path(__file__).parent / "snapshots" / ARCH
def assert_snapshot(
name: str,
content: str,
metadata: dict,
) -> None:
"""Compare content against a saved snapshot, or create one if missing.
Args:
name: Snapshot identifier (used as filename: snapshots/{arch}/{name}.json).
content: The actual inference output to compare.
metadata: Additional context stored alongside content (model, seed, etc.).
Not used for comparison -- purely documentary.
Raises:
AssertionError: If content doesn't match the saved snapshot.
Environment:
UPDATE_SNAPSHOTS=1: Overwrite existing snapshot with actual content.
"""
snapshot_file = SNAPSHOTS_DIR / f"{name}.json"
update = os.environ.get("UPDATE_SNAPSHOTS") == "1"
if update:
# Explicitly regenerate snapshot
SNAPSHOTS_DIR.mkdir(parents=True, exist_ok=True)
snapshot_data = {**metadata, "arch": ARCH, "content": content}
snapshot_file.write_text(json.dumps(snapshot_data, indent=2) + "\n")
print(f" Updated snapshot: {ARCH}/{snapshot_file.name}")
elif not snapshot_file.exists():
raise AssertionError(
f"No baseline snapshot for '{name}' on {ARCH}.\n"
f"Expected file: {snapshot_file}\n\n"
f"Generate baselines with: UPDATE_SNAPSHOTS=1 python3 e2e/run_all.py --slow"
)
else:
snapshot = json.loads(snapshot_file.read_text())
expected = snapshot["content"]
if content != expected:
diff = "\n".join(
difflib.unified_diff(
expected.splitlines(),
content.splitlines(),
fromfile=f"expected ({snapshot_file.relative_to(SNAPSHOTS_DIR.parent.parent)})",
tofile="actual",
lineterm="",
)
)
raise AssertionError(
f"Snapshot mismatch for '{name}' on {ARCH}!\n\n"
f"{diff}\n\n"
f"Expected: {expected!r}\n"
f"Actual: {content!r}\n\n"
f"To update: UPDATE_SNAPSHOTS=1 python3 e2e/run_all.py --slow"
)
print(f" Output matches snapshot ({ARCH}/{snapshot_file.name})")
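The mismatch report in `assert_snapshot` is a plain `difflib.unified_diff`; here is a standalone sketch with made-up expected/actual values standing in for snapshot content:

```python
import difflib

# Illustrative values only; in the real test these come from the
# committed snapshot file and the live inference output.
expected = "4\n"
actual = "The answer is 4.\n"

diff = "\n".join(
    difflib.unified_diff(
        expected.splitlines(),
        actual.splitlines(),
        fromfile="expected",
        tofile="actual",
        lineterm="",
    )
)
print(diff)
```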


@@ -1,22 +0,0 @@
"""Test: Basic cluster formation.
Verifies two nodes discover each other, elect a master, and the API responds.
"""
import asyncio
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent))
from conftest import Cluster
async def main():
async with Cluster("cluster_formation") as cluster:
await cluster.build()
await cluster.start()
await cluster.assert_healthy()
print("PASSED: cluster_formation")
if __name__ == "__main__":
asyncio.run(main())


@@ -1,61 +0,0 @@
"""Test: Deterministic inference output (snapshot test).
Sends a chat completion request with a fixed seed,
then verifies the output matches a known-good snapshot. This ensures
inference produces consistent results across runs.
Uses MLX CPU backend in Docker on x86 Linux.
"""
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
from snapshot import assert_snapshot
from conftest import Cluster
MODEL = "mlx-community/Qwen3-0.6B-4bit"
SEED = 42
PROMPT = "What is 2+2? Reply with just the number."
MAX_TOKENS = 32
async def main():
async with Cluster("inference_snapshot") as cluster:
await cluster.build()
await cluster.start()
await cluster.assert_healthy()
print(f" Launching model {MODEL}...")
await cluster.place_model(MODEL)
print(f" Sending chat completion (seed={SEED})...")
resp = await cluster.chat(
model=MODEL,
messages=[{"role": "user", "content": PROMPT}],
seed=SEED,
temperature=0,
max_tokens=MAX_TOKENS,
)
content = resp["choices"][0]["message"]["content"]
print(f" Response: {content!r}")
assert_snapshot(
name="inference_snapshot",
content=content,
metadata={
"model": MODEL,
"seed": SEED,
"prompt": PROMPT,
"max_tokens": MAX_TOKENS,
},
)
print("PASSED: inference_snapshot")
if __name__ == "__main__":
asyncio.run(main())


@@ -1,47 +0,0 @@
"""Test: Cluster works without internet access.
Verifies exo functions correctly when containers can talk to each other
but cannot reach the internet. Uses iptables to block all outbound traffic
except private subnets and multicast (for mDNS discovery).
"""
import asyncio
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent))
from conftest import Cluster
async def main():
async with Cluster(
"no_internet",
overrides=["tests/no_internet/docker-compose.override.yml"],
) as cluster:
await cluster.build()
await cluster.start()
await cluster.assert_healthy()
# Verify internet is actually blocked from inside the containers
for node in ["exo-node-1", "exo-node-2"]:
rc, _ = await cluster.exec(
node,
"curl",
"-sf",
"--max-time",
"3",
"https://huggingface.co",
check=False,
)
assert rc != 0, f"{node} should not be able to reach the internet"
print(f" {node}: internet correctly blocked")
# Verify exo detected no internet connectivity
log = await cluster.logs()
assert "Internet connectivity: False" in log, "exo should detect no internet"
print(" exo correctly detected no internet connectivity")
print("PASSED: no_internet")
if __name__ == "__main__":
asyncio.run(main())


@@ -1,65 +0,0 @@
"""Test: Runner chaos — abrupt runner death detection.
slow
Sends a chat completion with the EXO_RUNNER_MUST_DIE trigger, which causes
the runner process to call os._exit(1) (simulating an OOM kill). Verifies that
the RunnerSupervisor health check detects the death and the system doesn't hang.
Requires a machine that can run MLX inference at reasonable speed (Apple Silicon).
Run with: python3 e2e/run_all.py --slow or E2E_SLOW=1 python3 e2e/run_all.py
"""
import asyncio
import contextlib
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
from conftest import Cluster
MODEL = "mlx-community/Qwen3-0.6B-4bit"
async def main():
async with Cluster("runner_chaos") as cluster:
await cluster.build()
await cluster.start()
await cluster.assert_healthy()
# Place the model so a runner is loaded and ready
print(f" Launching model {MODEL}...")
await cluster.place_model(MODEL)
# Send a chat request with the die trigger.
# The runner will call os._exit(1) mid-inference, simulating OOM kill.
# The chat request itself will fail — that's expected.
print(" Sending EXO_RUNNER_MUST_DIE trigger...")
with contextlib.suppress(Exception):
await cluster.chat(
model=MODEL,
messages=[{"role": "user", "content": "EXO RUNNER MUST DIE"}],
timeout=60,
)
# Wait for the health check to detect the death and emit RunnerFailed
async def health_check_detected():
log = await cluster.logs()
return "runner process died unexpectedly" in log
await cluster.wait_for(
"Health check detected runner death",
health_check_detected,
timeout=30,
)
# Verify RunnerFailed was emitted (visible in logs)
log = await cluster.logs()
assert "runner process died unexpectedly" in log, (
f"Expected health check to detect runner death but it didn't.\nLogs:\n{log}"
)
print("PASSED: runner_chaos")
if __name__ == "__main__":
asyncio.run(main())


@@ -1,60 +0,0 @@
"""Test: Code generation snapshot.
slow
Verifies deterministic output for a code generation prompt.
"""
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
from snapshot import assert_snapshot
from conftest import Cluster
MODEL = "mlx-community/Qwen3-0.6B-4bit"
SEED = 42
PROMPT = (
"Write a Python function to reverse a string. Only output the code, no explanation."
)
MAX_TOKENS = 64
async def main():
async with Cluster("snapshot_code_gen") as cluster:
await cluster.build()
await cluster.start()
await cluster.assert_healthy()
print(f" Launching model {MODEL}...")
await cluster.place_model(MODEL)
print(f" Sending chat completion (seed={SEED})...")
resp = await cluster.chat(
model=MODEL,
messages=[{"role": "user", "content": PROMPT}],
seed=SEED,
temperature=0,
max_tokens=MAX_TOKENS,
)
content = resp["choices"][0]["message"]["content"]
print(f" Response: {content!r}")
assert_snapshot(
name="snapshot_code_gen",
content=content,
metadata={
"model": MODEL,
"seed": SEED,
"prompt": PROMPT,
"max_tokens": MAX_TOKENS,
},
)
print("PASSED: snapshot_code_gen")
if __name__ == "__main__":
asyncio.run(main())


@@ -1,65 +0,0 @@
"""Test: Edge case snapshots.
slow
Verifies deterministic output for edge-case prompts: single word input,
special characters, and unicode.
"""
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
from snapshot import assert_snapshot
from conftest import Cluster
MODEL = "mlx-community/Qwen3-0.6B-4bit"
SEED = 42
MAX_TOKENS = 32
CASES = [
("edge_single_word", "Hi"),
("edge_special_chars", "What does 2 * (3 + 4) / 7 - 1 equal? Use <math> tags."),
("edge_unicode", "Translate 'hello' to Japanese, Chinese, and Korean."),
]
async def main():
async with Cluster("snapshot_edge") as cluster:
await cluster.build()
await cluster.start()
await cluster.assert_healthy()
print(f" Launching model {MODEL}...")
await cluster.place_model(MODEL)
for snapshot_name, prompt in CASES:
print(f" [{snapshot_name}] Sending: {prompt!r}")
resp = await cluster.chat(
model=MODEL,
messages=[{"role": "user", "content": prompt}],
seed=SEED,
temperature=0,
max_tokens=MAX_TOKENS,
)
content = resp["choices"][0]["message"]["content"]
print(f" [{snapshot_name}] Response: {content!r}")
assert_snapshot(
name=snapshot_name,
content=content,
metadata={
"model": MODEL,
"seed": SEED,
"prompt": prompt,
"max_tokens": MAX_TOKENS,
},
)
print("PASSED: snapshot_edge")
if __name__ == "__main__":
asyncio.run(main())


@@ -1,58 +0,0 @@
"""Test: Longer output snapshot.
slow
Verifies deterministic output with a higher max_tokens (128).
"""
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
from snapshot import assert_snapshot
from conftest import Cluster
MODEL = "mlx-community/Qwen3-0.6B-4bit"
SEED = 42
PROMPT = "Explain how a binary search algorithm works."
MAX_TOKENS = 128
async def main():
async with Cluster("snapshot_long_output") as cluster:
await cluster.build()
await cluster.start()
await cluster.assert_healthy()
print(f" Launching model {MODEL}...")
await cluster.place_model(MODEL)
print(f" Sending chat completion (seed={SEED}, max_tokens={MAX_TOKENS})...")
resp = await cluster.chat(
model=MODEL,
messages=[{"role": "user", "content": PROMPT}],
seed=SEED,
temperature=0,
max_tokens=MAX_TOKENS,
)
content = resp["choices"][0]["message"]["content"]
print(f" Response: {content!r}")
assert_snapshot(
name="snapshot_long_output",
content=content,
metadata={
"model": MODEL,
"seed": SEED,
"prompt": PROMPT,
"max_tokens": MAX_TOKENS,
},
)
print("PASSED: snapshot_long_output")
if __name__ == "__main__":
asyncio.run(main())


@@ -1,73 +0,0 @@
"""Test: Multi-model snapshot tests.
slow
Verifies deterministic output across different model architectures to catch
model-specific regressions. Each model uses its own snapshot file.
Run with: python3 e2e/run_all.py --slow or E2E_SLOW=1 python3 e2e/run_all.py
"""
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
from snapshot import assert_snapshot
from conftest import Cluster
SEED = 42
PROMPT = "What is the capital of France?"
MAX_TOKENS = 32
MODELS = [
"mlx-community/SmolLM2-135M-Instruct",
"mlx-community/Llama-3.2-1B-Instruct-4bit",
"mlx-community/gemma-2-2b-it-4bit",
]
async def main():
async with Cluster("snapshot_multi_model") as cluster:
await cluster.build()
await cluster.start()
await cluster.assert_healthy()
for model in MODELS:
short_name = (
model.split("/")[-1].lower().replace("-", "_").replace(".", "_")
)
snapshot_name = f"snapshot_multi_{short_name}"
print(f" Launching model {model}...")
await cluster.place_model(model)
print(f" Sending chat completion (seed={SEED})...")
resp = await cluster.chat(
model=model,
messages=[{"role": "user", "content": PROMPT}],
seed=SEED,
temperature=0,
max_tokens=MAX_TOKENS,
)
content = resp["choices"][0]["message"]["content"]
print(f" [{short_name}] Response: {content!r}")
assert_snapshot(
name=snapshot_name,
content=content,
metadata={
"model": model,
"seed": SEED,
"prompt": PROMPT,
"max_tokens": MAX_TOKENS,
},
)
print(f" [{short_name}] PASSED")
print("PASSED: snapshot_multi_model")
if __name__ == "__main__":
asyncio.run(main())


@@ -1,58 +0,0 @@
"""Test: Reasoning/math snapshot.
slow
Verifies deterministic output for a simple reasoning prompt.
"""
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
from snapshot import assert_snapshot
from conftest import Cluster
MODEL = "mlx-community/Qwen3-0.6B-4bit"
SEED = 42
PROMPT = "If I have 3 apples and give away 1, how many do I have? Think step by step."
MAX_TOKENS = 64
async def main():
async with Cluster("snapshot_reasoning") as cluster:
await cluster.build()
await cluster.start()
await cluster.assert_healthy()
print(f" Launching model {MODEL}...")
await cluster.place_model(MODEL)
print(f" Sending chat completion (seed={SEED})...")
resp = await cluster.chat(
model=MODEL,
messages=[{"role": "user", "content": PROMPT}],
seed=SEED,
temperature=0,
max_tokens=MAX_TOKENS,
)
content = resp["choices"][0]["message"]["content"]
print(f" Response: {content!r}")
assert_snapshot(
name="snapshot_reasoning",
content=content,
metadata={
"model": MODEL,
"seed": SEED,
"prompt": PROMPT,
"max_tokens": MAX_TOKENS,
},
)
print("PASSED: snapshot_reasoning")
if __name__ == "__main__":
asyncio.run(main())


@@ -1,32 +0,0 @@
# Block all outbound internet traffic using iptables while preserving:
# - Multicast (224.0.0.0/4) for mDNS peer discovery
# - Private subnets (10/8, 172.16/12, 192.168/16) for inter-container communication
# - Loopback (127/8)
# Requires NET_ADMIN capability for iptables.
services:
exo-node-1:
cap_add:
- NET_ADMIN
entrypoint: ["/bin/sh", "-c"]
command:
- |
iptables -A OUTPUT -d 127.0.0.0/8 -j ACCEPT
iptables -A OUTPUT -d 10.0.0.0/8 -j ACCEPT
iptables -A OUTPUT -d 172.16.0.0/12 -j ACCEPT
iptables -A OUTPUT -d 192.168.0.0/16 -j ACCEPT
iptables -A OUTPUT -d 224.0.0.0/4 -j ACCEPT
iptables -A OUTPUT -j REJECT
exec .venv/bin/exo -v
exo-node-2:
cap_add:
- NET_ADMIN
entrypoint: ["/bin/sh", "-c"]
command:
- |
iptables -A OUTPUT -d 127.0.0.0/8 -j ACCEPT
iptables -A OUTPUT -d 10.0.0.0/8 -j ACCEPT
iptables -A OUTPUT -d 172.16.0.0/12 -j ACCEPT
iptables -A OUTPUT -d 192.168.0.0/16 -j ACCEPT
iptables -A OUTPUT -d 224.0.0.0/4 -j ACCEPT
iptables -A OUTPUT -j REJECT
exec .venv/bin/exo -v
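The allow-list logic in the override above (loopback, the RFC 1918 private ranges, and multicast accepted; everything else rejected) can be sanity-checked off-box with Python's stdlib `ipaddress` module. A sketch, with the CIDRs copied verbatim from the iptables rules:

```python
import ipaddress

# The same destinations the iptables OUTPUT rules ACCEPT, in order.
ALLOWED = [
    ipaddress.ip_network("127.0.0.0/8"),     # loopback
    ipaddress.ip_network("10.0.0.0/8"),      # private
    ipaddress.ip_network("172.16.0.0/12"),   # private (Docker bridge nets live here)
    ipaddress.ip_network("192.168.0.0/16"),  # private
    ipaddress.ip_network("224.0.0.0/4"),     # multicast (mDNS discovery)
]


def is_reachable(dest: str) -> bool:
    """Mirror the OUTPUT chain: ACCEPT if any allow rule matches, else REJECT."""
    addr = ipaddress.ip_address(dest)
    return any(addr in net for net in ALLOWED)
```

So a typical Docker bridge peer like `172.18.0.2` stays reachable while any public address is rejected, which is exactly what the `no_internet` test asserts via `curl`.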


@@ -132,7 +132,7 @@ markers = [
env = [
"EXO_TESTS=1"
]
addopts = "-m 'not slow'"
addopts = "-m 'not slow' --ignore=tests/start_distributed_test.py"
filterwarnings = [
"ignore:builtin type Swig:DeprecationWarning",
]


@@ -1,12 +0,0 @@
model_id = "mlx-community/SmolLM2-135M-Instruct"
n_layers = 30
hidden_size = 576
supports_tensor = true
tasks = ["TextGeneration"]
family = "llama"
quantization = "bf16"
base_model = "SmolLM2 135M"
capabilities = ["text"]
[storage_size]
in_bytes = 269060381


@@ -1,12 +0,0 @@
model_id = "mlx-community/gemma-2-2b-it-4bit"
n_layers = 26
hidden_size = 2304
supports_tensor = false
tasks = ["TextGeneration"]
family = "gemma2"
quantization = "4bit"
base_model = "Gemma 2 2B"
capabilities = ["text"]
[storage_size]
in_bytes = 1492755242


@@ -14,6 +14,7 @@ from exo.download.download_utils import (
map_repo_download_progress_to_download_progress_data,
)
from exo.download.shard_downloader import ShardDownloader
from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.models.model_cards import ModelId
from exo.shared.types.commands import (
CancelDownload,
@@ -63,6 +64,9 @@ class DownloadCoordinator:
self.event_sender, self.event_receiver = channel[Event]()
self.shard_downloader.on_progress(self._download_progress_callback)
def _model_dir(self, model_id: ModelId) -> str:
return str(EXO_MODELS_DIR / model_id.normalize())
async def _download_progress_callback(
self, callback_shard: ShardMetadata, progress: RepoDownloadProgress
) -> None:
@@ -74,6 +78,7 @@ class DownloadCoordinator:
shard_metadata=callback_shard,
node_id=self.node_id,
total_bytes=progress.total_bytes,
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = completed
await self.event_sender.send(
@@ -93,6 +98,7 @@ class DownloadCoordinator:
download_progress=map_repo_download_progress_to_download_progress_data(
progress
),
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = ongoing
await self.event_sender.send(
@@ -170,7 +176,11 @@ class DownloadCoordinator:
return
# Emit pending status
progress = DownloadPending(shard_metadata=shard, node_id=self.node_id)
progress = DownloadPending(
shard_metadata=shard,
node_id=self.node_id,
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = progress
await self.event_sender.send(NodeDownloadProgress(download_progress=progress))
@@ -184,6 +194,7 @@ class DownloadCoordinator:
shard_metadata=shard,
node_id=self.node_id,
total_bytes=initial_progress.total_bytes,
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = completed
await self.event_sender.send(
@@ -206,6 +217,7 @@ class DownloadCoordinator:
download_progress=map_repo_download_progress_to_download_progress_data(
initial_progress
),
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = status
self.event_sender.send_nowait(NodeDownloadProgress(download_progress=status))
@@ -219,6 +231,7 @@ class DownloadCoordinator:
shard_metadata=shard,
node_id=self.node_id,
error_message=str(e),
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = failed
await self.event_sender.send(
@@ -253,6 +266,7 @@ class DownloadCoordinator:
pending = DownloadPending(
shard_metadata=current_status.shard_metadata,
node_id=self.node_id,
model_directory=self._model_dir(model_id),
)
await self.event_sender.send(
NodeDownloadProgress(download_progress=pending)
@@ -295,11 +309,18 @@ class DownloadCoordinator:
node_id=self.node_id,
shard_metadata=progress.shard,
total_bytes=progress.total_bytes,
model_directory=self._model_dir(
progress.shard.model_card.model_id
),
)
elif progress.status in ["in_progress", "not_started"]:
if progress.downloaded_bytes_this_session.in_bytes == 0:
status = DownloadPending(
node_id=self.node_id, shard_metadata=progress.shard
node_id=self.node_id,
shard_metadata=progress.shard,
model_directory=self._model_dir(
progress.shard.model_card.model_id
),
)
else:
status = DownloadOngoing(
@@ -308,6 +329,9 @@ class DownloadCoordinator:
download_progress=map_repo_download_progress_to_download_progress_data(
progress
),
model_directory=self._model_dir(
progress.shard.model_card.model_id
),
)
else:
continue
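The new `_model_dir()` helper resolves a model's on-disk path as `EXO_MODELS_DIR / model_id.normalize()`. The actual `normalize()` rules live in exo's `ModelId`; purely as an illustration, assuming a HuggingFace-cache-style substitution of `/` with `--` (a repo slash is not valid inside a single directory name):

```python
from pathlib import Path

# Hypothetical value of the env-derived EXO_MODELS_DIR constant.
EXO_MODELS_DIR = Path("/var/lib/exo/models")


def model_dir(model_id: str) -> str:
    # Assumed normalization: map the repo separator to "--" so the
    # whole id becomes one directory name under the models root.
    return str(EXO_MODELS_DIR / model_id.replace("/", "--"))
```

Embedding this path in every `BaseDownloadProgress` event is what lets the dashboard show the file location and offer "open in Finder" without a second round-trip.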


@@ -144,8 +144,8 @@ async def collect_responses_response(
for tool in chunk.tool_calls:
function_call_items.append(
ResponseFunctionCallItem(
id=f"fc_{tool.id}",
call_id=f"call_{tool.id}",
id=tool.id,
call_id=tool.id,
name=tool.name,
arguments=tool.arguments,
)


@@ -1367,6 +1367,7 @@ class API:
async def run(self):
shutdown_ev = anyio.Event()
bonjour_cleanup = self._register_bonjour_service()
try:
async with create_task_group() as tg:
self._tg = tg
@@ -1382,10 +1383,48 @@ class API:
with anyio.CancelScope(shield=True):
shutdown_ev.set()
finally:
bonjour_cleanup()
self._event_log.close()
self.command_sender.close()
self.global_event_receiver.close()
def _register_bonjour_service(self) -> Callable[[], None]:
"""Register a Bonjour service via the system mDNSResponder. Returns a cleanup function."""
import subprocess
import sys
if sys.platform != "darwin":
logger.info("Bonjour service registration is only supported on macOS")
return lambda: None
service_name = f"EXO Cluster ({self.node_id[:8]})"
try:
proc = subprocess.Popen(
[
"dns-sd",
"-R",
service_name,
"_exo._tcp",
"local",
str(self.port),
f"node_id={self.node_id}",
],
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)
logger.info(
f"Registered Bonjour service _exo._tcp on port {self.port} (pid {proc.pid})"
)
def cleanup() -> None:
proc.terminate()
proc.wait()
return cleanup
except Exception as e:
logger.warning(f"Failed to register Bonjour service: {e}")
return lambda: None
async def run_api(self, ev: anyio.Event):
cfg = Config()
cfg.bind = [f"0.0.0.0:{self.port}"]


@@ -218,11 +218,6 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
key: value for key, value in state.downloads.items() if key != event.node_id
}
# Clean up all granular node mappings
node_identities = {
key: value
for key, value in state.node_identities.items()
if key != event.node_id
}
node_memory = {
key: value for key, value in state.node_memory.items() if key != event.node_id
}
@@ -263,7 +258,6 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
"downloads": downloads,
"topology": topology,
"last_seen": last_seen,
"node_identities": node_identities,
"node_memory": node_memory,
"node_disk": node_disk,
"node_system": node_system,


@@ -26,6 +26,7 @@ class DownloadProgressData(CamelCaseModel):
class BaseDownloadProgress(TaggedModel):
node_id: NodeId
shard_metadata: ShardMetadata
model_directory: str = ""
class DownloadPending(BaseDownloadProgress):


@@ -1,5 +1,7 @@
import sys
def print_startup_banner(port: int) -> None:
"""Print a prominent startup banner with API endpoint information."""
dashboard_url = f"http://localhost:{port}"
banner = f"""
╔═══════════════════════════════════════════════════════════════════════╗
@@ -27,4 +29,4 @@ def print_startup_banner(port: int) -> None:
"""
print(banner)
print(banner, file=sys.stderr)


@@ -306,7 +306,7 @@ def mlx_generate(
max_stop_len = max((len(s) for s in stop_sequences), default=0)
mx_barrier(group)
logger.info("Ready to prefill")
logger.info("Starting prefill")
# Prefill cache with all tokens except the last one
prefill_tps, prefill_tokens, ssm_snapshots_list = prefill(


@@ -353,7 +353,13 @@ def load_tokenizer_for_model_id(
return list(hf_tokenizer.model.encode(text, allowed_special="all")) # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType]
hf_tokenizer.encode = _patched_encode
return TokenizerWrapper(hf_tokenizer, eos_token_ids=eos_token_ids)
return TokenizerWrapper(
hf_tokenizer,
eos_token_ids=eos_token_ids,
tool_call_start="<|tool_calls_section_begin|>",
tool_call_end="<|tool_calls_section_end|>",
tool_parser=_parse_kimi_tool_calls,
)
tokenizer = load_tokenizer(
model_path,
@@ -585,3 +591,41 @@ def mx_barrier(group: Group | None):
mx.array(1.0), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
)
)
def _parse_kimi_tool_calls(text: str):
import regex as re
# kimi has a fixed function naming scheme, with a json formatted arg
# functions.multiply:0<|tool_call_argument_begin|>{"a": 2, "b": 3}
_func_name_regex = re.compile(
r"^\s*((?:functions\.)?(.+?):\d+)\s*<\|tool_call_argument_begin\|>", re.DOTALL
)
_func_arg_regex = re.compile(r"<\|tool_call_argument_begin\|>\s*(.*)\s*", re.DOTALL)
_tool_call_split_regex = re.compile(
r"<\|tool_call_begin\|>(.*?)<\|tool_call_end\|>", re.DOTALL
)
def _parse_single_tool(text: str) -> dict[str, Any]:
func_name_match = _func_name_regex.search(text)
if func_name_match is None:
raise ValueError("No tool call found.")
tool_call_id = func_name_match.group(1) # e.g. "functions.get_weather:0"
func_name = func_name_match.group(2) # e.g. "get_weather"
func_args_match = _func_arg_regex.search(text)
if func_args_match is None:
raise ValueError("No tool call arguments found.")
func_args = func_args_match.group(1)
try:
arg_dct = json.loads(func_args) # pyright: ignore[reportAny]
except Exception:
arg_dct = None
return dict(id=tool_call_id, name=func_name, arguments=arg_dct)
tool_matches = _tool_call_split_regex.findall(text)
if tool_matches:
return [_parse_single_tool(match) for match in tool_matches] # pyright: ignore[reportAny]
else:
return [_parse_single_tool(text)]
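The `_parse_kimi_tool_calls` helper above keys off Kimi's fixed `functions.<name>:<index>` naming scheme. The same two regexes can be exercised standalone (here via stdlib `re`, which handles these patterns the same way as the `regex` package) on the example string from the comment:

```python
import json
import re

FUNC_NAME = re.compile(
    r"^\s*((?:functions\.)?(.+?):\d+)\s*<\|tool_call_argument_begin\|>", re.DOTALL
)
FUNC_ARG = re.compile(r"<\|tool_call_argument_begin\|>\s*(.*)\s*", re.DOTALL)


def parse_one(text: str) -> dict:
    name_m = FUNC_NAME.search(text)
    arg_m = FUNC_ARG.search(text)
    if name_m is None or arg_m is None:
        raise ValueError("no tool call found")
    return {
        "id": name_m.group(1),    # full "functions.multiply:0"
        "name": name_m.group(2),  # bare "multiply"
        "arguments": json.loads(arg_m.group(1)),
    }
```

Note how group 1 keeps the full `functions.multiply:0` form as a stable call id while group 2 strips it down to the bare function name for dispatch.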


@@ -1,8 +1,4 @@
from __future__ import annotations
import os
import threading
from multiprocessing.sharedctypes import Synchronized
import loguru
@@ -14,15 +10,6 @@ from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
logger: "loguru.Logger" = loguru.logger
HEARTBEAT_INTERVAL_SECONDS = 0.5
def _heartbeat_loop(heartbeat: Synchronized[int], stop: threading.Event) -> None:
"""Daemon thread that periodically increments the heartbeat counter."""
while not stop.is_set():
heartbeat.value += 1
stop.wait(HEARTBEAT_INTERVAL_SECONDS)
def entrypoint(
bound_instance: BoundInstance,
@@ -30,7 +17,6 @@ def entrypoint(
task_receiver: MpReceiver[Task],
cancel_receiver: MpReceiver[TaskId],
_logger: "loguru.Logger",
heartbeat: Synchronized[int] | None = None,
) -> None:
fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
if fast_synch_override == "on" or (
@@ -49,17 +35,6 @@ def entrypoint(
logger.info(f"Fast synch flag: {os.environ['MLX_METAL_FAST_SYNCH']}")
# Start heartbeat thread so the supervisor can detect if we freeze.
stop_heartbeat = threading.Event()
heartbeat_thread: threading.Thread | None = None
if heartbeat is not None:
heartbeat_thread = threading.Thread(
target=_heartbeat_loop,
args=(heartbeat, stop_heartbeat),
daemon=True,
)
heartbeat_thread.start()
# Import main after setting global logger - this lets us just import logger from this module
try:
from exo.worker.runner.runner import main
@@ -78,9 +53,6 @@ def entrypoint(
)
)
finally:
stop_heartbeat.set()
if heartbeat_thread is not None:
heartbeat_thread.join(timeout=1)
try:
event_sender.close()
task_receiver.close()


@@ -1,12 +1,10 @@
import base64
import json
import math
import os
import resource
import time
from collections.abc import Generator
from functools import cache
from typing import Any, Callable, Literal
from typing import Literal
import mlx.core as mx
from mlx_lm.models.gpt_oss import Model as GptOssModel
@@ -17,7 +15,6 @@ from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
StreamableParser,
load_harmony_encoding,
)
from pydantic import ValidationError
from exo.shared.constants import EXO_MAX_CHUNK_SIZE, EXO_TRACING_ENABLED
from exo.shared.models.model_cards import ModelId, ModelTask
@@ -94,6 +91,8 @@ from exo.worker.engines.mlx.utils_mlx import (
)
from exo.worker.runner.bootstrap import logger
from .tool_parsers import ToolParser, make_mlx_parser
def _is_primary_output_node(shard_metadata: ShardMetadata) -> bool:
"""Check if this node is the primary output node for image generation.
@@ -139,6 +138,7 @@ def main(
inference_model: Model | None = None
image_model: DistributedImageModel | None = None
tokenizer = None
tool_parser: ToolParser | None = None
group = None
kv_prefix_cache: KVPrefixCache | None = None
check_for_cancel_every: int | None = None
@@ -204,8 +204,17 @@ def main(
bound_instance, group, on_timeout=on_model_load_timeout
)
logger.info(
f"model has_tool_calling={tokenizer.has_tool_calling}"
f"model has_tool_calling={tokenizer.has_tool_calling} using tokens {tokenizer.tool_call_start}, {tokenizer.tool_call_end}"
)
if tokenizer.has_tool_calling:
assert tokenizer.tool_call_start
assert tokenizer.tool_call_end
assert tokenizer.tool_parser # pyright: ignore[reportAny]
tool_parser = make_mlx_parser(
tokenizer.tool_call_start,
tokenizer.tool_call_end,
tokenizer.tool_parser, # pyright: ignore[reportAny]
)
kv_prefix_cache = KVPrefixCache(group)
elif (
@@ -311,31 +320,11 @@ def main(
mlx_generator, tokenizer
)
# Kimi-K2 has tool call sections - we don't care about them
if "kimi" in shard_metadata.model_card.model_id.lower():
mlx_generator = filter_kimi_tokens(mlx_generator)
patch_kimi_tokenizer(tokenizer)
# GLM models need patched parser (upstream has bug with None regex match)
elif "glm" in shard_metadata.model_card.model_id.lower():
patch_glm_tokenizer(tokenizer)
# GPT-OSS specific parsing to match other model formats.
elif isinstance(inference_model, GptOssModel):
if isinstance(inference_model, GptOssModel):
mlx_generator = parse_gpt_oss(mlx_generator)
if tokenizer.has_tool_calling and not isinstance(
inference_model, GptOssModel
):
assert tokenizer.tool_call_start
assert tokenizer.tool_call_end
assert tokenizer.tool_parser # pyright: ignore[reportAny]
mlx_generator = parse_tool_calls(
mlx_generator,
tokenizer.tool_call_start,
tokenizer.tool_call_end,
tokenizer.tool_parser, # pyright: ignore[reportAny]
)
elif tool_parser:
mlx_generator = parse_tool_calls(mlx_generator, tool_parser)
completion_tokens = 0
tokens_since_last_cancel_check = 0
@@ -588,21 +577,8 @@ def get_gpt_oss_encoding():
return encoding
def filter_kimi_tokens(
responses: Generator[GenerationResponse | ToolCallResponse],
) -> Generator[GenerationResponse]:
for resp in responses:
assert isinstance(resp, GenerationResponse)
if (
resp.text == "<|tool_calls_section_begin|>"
or resp.text == "<|tool_calls_section_end|>"
):
continue
yield resp
def parse_gpt_oss(
responses: Generator[GenerationResponse | ToolCallResponse],
responses: Generator[GenerationResponse],
) -> Generator[GenerationResponse | ToolCallResponse]:
encoding = get_gpt_oss_encoding()
stream = StreamableParser(encoding, role=Role.ASSISTANT)
@@ -659,9 +635,9 @@ def parse_gpt_oss(
def parse_thinking_models(
responses: Generator[GenerationResponse | ToolCallResponse],
responses: Generator[GenerationResponse],
tokenizer: TokenizerWrapper,
) -> Generator[GenerationResponse | ToolCallResponse]:
) -> Generator[GenerationResponse]:
"""
For models that inject thinking tags in the prompt (like GLM-4.7),
prepend the thinking tag to the output stream so the frontend
@@ -782,225 +758,58 @@ def _process_image_response(
def parse_tool_calls(
responses: Generator[GenerationResponse | ToolCallResponse],
tool_call_start: str,
tool_call_end: str,
tool_parser: Callable[[str], dict[str, Any] | list[dict[str, Any]]],
responses: Generator[GenerationResponse], tool_parser: ToolParser
) -> Generator[GenerationResponse | ToolCallResponse]:
in_tool_call = False
tool_call_text_parts: list[str] = []
for response in responses:
assert isinstance(response, GenerationResponse)
# assumption: the tool call start is one token
if response.text == tool_call_start:
if response.text.startswith(tool_parser.start_parsing):
in_tool_call = True
continue
# assumption: the tool call end is one token
if in_tool_call and response.text == tool_call_end:
try:
# tool_parser returns an arbitrarily nested python dictionary
# we actually don't want the python dictionary, we just want to
# parse the top level { function: ..., arguments: ... } structure
# as we're just gonna hand it back to the api anyway
parsed = tool_parser("".join(tool_call_text_parts).strip())
logger.info(f"parsed {tool_call_text_parts=} into {parsed=}")
if isinstance(parsed, list):
tools = [_validate_single_tool(tool) for tool in parsed]
else:
tools = [_validate_single_tool(parsed)]
yield ToolCallResponse(
tool_calls=tools, usage=response.usage, stats=response.stats
)
except (
json.JSONDecodeError,
ValidationError,
ValueError,
AttributeError,
) as e:
# ValueError: our parsers raise this for malformed tool calls
# AttributeError: upstream parsers (e.g. glm47) may raise this when regex doesn't match
logger.opt(exception=e).warning("tool call parsing failed")
# assumption: talking about tool calls, not making a tool call
response.text = (
tool_call_start + "".join(tool_call_text_parts) + tool_call_end
)
yield response
in_tool_call = False
tool_call_text_parts = []
continue
if in_tool_call:
tool_call_text_parts.append(response.text)
if response.text.endswith(tool_parser.end_parsing):
# parse the actual tool calls from the tool call text
parsed = tool_parser.parse_tool_calls(
"".join(tool_call_text_parts).strip()
)
logger.info(f"parsed {tool_call_text_parts=} into {parsed=}")
if parsed is not None:
yield ToolCallResponse(
tool_calls=parsed, usage=response.usage, stats=response.stats
)
else:
logger.warning(
f"tool call parsing failed for text {''.join(tool_call_text_parts)}"
)
response.text = "".join(tool_call_text_parts)
yield response
in_tool_call = False
tool_call_text_parts = []
continue
if response.finish_reason is not None:
logger.info(
"toll call parsing interrupted, yield partial tool call as text"
"tool call parsing interrupted, yield partial tool call as text"
)
yield GenerationResponse(
text=tool_call_start + "".join(tool_call_text_parts),
token=0,
finish_reason=response.finish_reason,
usage=response.usage,
stats=response.stats,
response = response.model_copy(
update={
"text": "".join(tool_call_text_parts),
"token": 0,
}
)
yield response
continue
# fallthrough
yield response
def patch_kimi_tokenizer(tokenizer: TokenizerWrapper):
"""
Version of to-be-upstreamed kimi-k2 tool parser
"""
import ast
import json
from typing import Any
import regex as re
# kimi has a fixed function naming scheme, with a json formatted arg
# functions.multiply:0 <|tool_call_argument_begin|> {"a": 2, "b": 3}
# Also needs to handle tools like call_0<|tool_call_argument_begin|>{"filePath": "..."}
_func_name_regex = re.compile(
r"^\s*(.+)[:](\d+)\s*<\|tool_call_argument_begin\|>", re.DOTALL
)
_func_arg_regex = re.compile(r"<\|tool_call_argument_begin\|>\s*(.*)\s*", re.DOTALL)
# kimi has a tool_calls_section - we're leaving this up to the caller to handle
tool_call_start = "<|tool_call_begin|>"
tool_call_end = "<|tool_call_end|>"
def _deserialize(value: str) -> Any: # pyright: ignore[reportAny]
try:
return json.loads(value) # pyright: ignore[reportAny]
except Exception:
pass
try:
return ast.literal_eval(value) # pyright: ignore[reportAny]
except Exception:
pass
return value
def parse_tool_call(text: str, tools: Any | None = None):
func_name_match = _func_name_regex.search(text)
if func_name_match is None:
raise ValueError(f"Could not parse function name from tool call: {text!r}")
original_func_name = func_name_match.group(1)
tool_id = func_name_match.group(2)
# strip off the `functions.` prefix, if it exists.
func_name = original_func_name[original_func_name.find(".") + 1 :]
func_args_match = _func_arg_regex.search(text)
if func_args_match is None:
raise ValueError(f"Could not parse function args from tool call: {text!r}")
func_args = func_args_match.group(1)
# the args should be valid json - no need to check against our tools to deserialize
arg_dct = _deserialize(func_args) # pyright: ignore[reportAny]
return dict(
id=f"{original_func_name}:{tool_id}",
name=func_name,
arguments=arg_dct, # pyright: ignore[reportAny]
)
tokenizer._tool_call_start = tool_call_start
tokenizer._tool_call_end = tool_call_end
tokenizer._tool_parser = parse_tool_call
def patch_glm_tokenizer(tokenizer: TokenizerWrapper):
"""
Fixed version of mlx_lm's glm47 tool parser that handles regex match failures.
"""
import ast
import json
from typing import Any
import regex as re
_func_name_regex = re.compile(r"^(.*?)<arg_key>", re.DOTALL)
_func_arg_regex = re.compile(
r"<arg_key>(.*?)</arg_key>(?:\n|\s)*<arg_value>(.*?)(?:</arg_value>|(?=<arg_key>)|$)",
re.DOTALL,
)
tool_call_start = "<tool_call>"
tool_call_end = "</tool_call>"
def _is_string_type(
tool_name: str,
arg_name: str,
tools: list[Any] | None,
) -> bool:
if tools is None:
return False
for tool in tools: # pyright: ignore[reportAny]
func = tool["function"] # pyright: ignore[reportAny]
if func["name"] == tool_name:
params = func["parameters"] # pyright: ignore[reportAny]
if params is None:
return False
props = params.get("properties", {}) # pyright: ignore[reportAny]
arg_props = props.get(arg_name, {}) # pyright: ignore[reportAny]
arg_type = arg_props.get("type", None) # pyright: ignore[reportAny]
return arg_type == "string" # pyright: ignore[reportAny]
return False
def _deserialize(value: str) -> Any: # pyright: ignore[reportAny]
try:
return json.loads(value) # pyright: ignore[reportAny]
except Exception:
pass
try:
return ast.literal_eval(value) # pyright: ignore[reportAny]
except Exception:
pass
return value
def parse_tool_call(text: str, tools: list[Any] | None = None):
func_name_match = _func_name_regex.search(text)
if func_name_match is None:
raise ValueError(f"Could not parse function name from tool call: {text!r}")
func_name = func_name_match.group(1)
pairs = _func_arg_regex.findall(text)
arg_dct: dict[str, Any] = {}
for key, value in pairs: # pyright: ignore[reportAny]
arg_key = key.strip() # pyright: ignore[reportAny]
arg_val = value.strip() # pyright: ignore[reportAny]
if not _is_string_type(func_name, arg_key, tools): # pyright: ignore[reportAny]
arg_val = _deserialize(arg_val) # pyright: ignore[reportAny]
arg_dct[arg_key] = arg_val
return dict(name=func_name, arguments=arg_dct)
tokenizer._tool_call_start = tool_call_start
tokenizer._tool_call_end = tool_call_end
tokenizer._tool_parser = parse_tool_call
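To make the grammar concrete, here is a minimal self-contained sketch of the parser above applied to a GLM4-style `<arg_key>`/`<arg_value>` payload. The regexes mirror the diff (stdlib `re` suffices here; the `regex` module is not required for these patterns), the sample tool-call string is hypothetical, and for brevity this sketch deserializes every value rather than consulting the tool schema as `_is_string_type` does.

```python
import json
import re

_func_name_regex = re.compile(r"^(.*?)<arg_key>", re.DOTALL)
_func_arg_regex = re.compile(
    r"<arg_key>(.*?)</arg_key>(?:\n|\s)*<arg_value>(.*?)(?:</arg_value>|(?=<arg_key>)|$)",
    re.DOTALL,
)


def parse_tool_call(text: str) -> dict:
    # Function name is everything before the first <arg_key>.
    m = _func_name_regex.search(text)
    if m is None:
        raise ValueError(f"Could not parse function name from tool call: {text!r}")
    name = m.group(1).strip()
    args = {}
    for key, value in _func_arg_regex.findall(text):
        v = value.strip()
        try:  # simplified: always attempt JSON deserialization
            v = json.loads(v)
        except Exception:
            pass
        args[key.strip()] = v
    return {"name": name, "arguments": args}


call = parse_tool_call(
    'get_weather<arg_key>city</arg_key>\n<arg_value>"Paris"</arg_value>'
)
print(call)  # {'name': 'get_weather', 'arguments': {'city': 'Paris'}}
```

Note how the alternation `(?:</arg_value>|(?=<arg_key>)|$)` lets the last value match even when the closing `</arg_value>` tag is missing, which is the regex-match-failure case the fixed parser handles.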
def _validate_single_tool(obj: dict[str, Any]) -> ToolCallItem:
if (
((name := obj.get("name")) is not None)
and ((args := obj.get("arguments")) is not None)
and isinstance(name, str)
):
raw_id: object = obj.get("id")
extra = {"id": str(raw_id)} if raw_id is not None else {}
return ToolCallItem(
**extra,
name=name,
arguments=json.dumps(args),
)
else:
raise ValidationError
EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"
EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM"
EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT"
EXO_RUNNER_MUST_DIE = "EXO RUNNER MUST DIE"
def _check_for_debug_prompts(task_params: TextGenerationTaskParams) -> None:
@@ -1016,9 +825,6 @@ def _check_for_debug_prompts(task_params: TextGenerationTaskParams) -> None:
if not prompt:
return
if EXO_RUNNER_MUST_DIE in prompt:
logger.info("Abrupt process death triggered (simulates OOM kill)")
os._exit(1)
if EXO_RUNNER_MUST_FAIL in prompt:
logger.info("raising exception")
raise Exception("Artificial runner exception - for testing purposes only.")

View File

@@ -1,17 +1,12 @@
from __future__ import annotations
import contextlib
import multiprocessing
import signal
from dataclasses import dataclass, field
from multiprocessing import Process
from multiprocessing.sharedctypes import Synchronized
from typing import Self
import anyio
from anyio import (
BrokenResourceError,
CancelScope,
ClosedResourceError,
to_thread,
)
@@ -31,7 +26,6 @@ from exo.shared.types.worker.runners import (
RunnerIdle,
RunnerLoading,
RunnerRunning,
RunnerShutdown,
RunnerShuttingDown,
RunnerStatus,
RunnerWarmingUp,
@@ -42,8 +36,6 @@ from exo.worker.runner.bootstrap import entrypoint
PREFILL_TIMEOUT_SECONDS = 60
DECODE_TIMEOUT_SECONDS = 5
HEALTH_CHECK_INTERVAL_SECONDS = 1
HEARTBEAT_STALE_CHECKS = 10
@dataclass(eq=False)
@@ -56,14 +48,10 @@ class RunnerSupervisor:
_task_sender: MpSender[Task]
_event_sender: Sender[Event]
_cancel_sender: MpSender[TaskId]
_heartbeat: Synchronized[int]
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
completed: set[TaskId] = field(default_factory=set, init=False)
cancelled: set[TaskId] = field(default_factory=set, init=False)
_death_handled: bool = field(default=False, init=False)
_last_heartbeat_value: int = field(default=0, init=False)
_heartbeat_stale_count: int = field(default=0, init=False)
@classmethod
def create(
@@ -77,8 +65,6 @@ class RunnerSupervisor:
task_sender, task_recv = mp_channel[Task]()
cancel_sender, cancel_recv = mp_channel[TaskId]()
heartbeat: Synchronized[int] = multiprocessing.Value("Q", 0)
runner_process = Process(
target=entrypoint,
args=(
@@ -87,7 +73,6 @@ class RunnerSupervisor:
task_recv,
cancel_recv,
logger,
heartbeat,
),
daemon=True,
)
@@ -103,16 +88,13 @@ class RunnerSupervisor:
_task_sender=task_sender,
_cancel_sender=cancel_sender,
_event_sender=event_sender,
_heartbeat=heartbeat,
)
return self
async def run(self):
self.runner_process.start()
async with anyio.create_task_group() as tg:
tg.start_soon(self._forward_events)
tg.start_soon(self._health_check, tg.cancel_scope)
await self._forward_events()
def shutdown(self):
logger.info("Runner supervisor shutting down")
@@ -195,99 +177,9 @@ class RunnerSupervisor:
self.completed.add(event.task_id)
await self._event_sender.send(event)
except (ClosedResourceError, BrokenResourceError) as e:
if not self._death_handled:
self._death_handled = True
await self._check_runner(e)
for tid in self.pending:
self.pending[tid].set()
async def _health_check(self, cancel_scope: CancelScope) -> None:
"""Periodically check if the runner process is alive and responsive.
Detects two failure modes:
1. Process death (e.g. OOM kill) without cleanly closing the event
channel, which would leave _forward_events blocked on queue.get().
2. Unresponsive process (e.g. frozen by OS memory pressure, deadlock)
detected via a stale heartbeat counter.
"""
while True:
await anyio.sleep(HEALTH_CHECK_INTERVAL_SECONDS)
if not self.runner_process.is_alive():
self._handle_process_exit(cancel_scope)
return
# Check heartbeat counter — if it hasn't changed between
# consecutive checks, the subprocess may be frozen.
current = self._heartbeat.value
if current > 0:
if current == self._last_heartbeat_value:
self._heartbeat_stale_count += 1
if self._heartbeat_stale_count >= HEARTBEAT_STALE_CHECKS:
logger.error(
f"Health check: runner process unresponsive "
f"(heartbeat stale for {self._heartbeat_stale_count} checks), killing"
)
self._handle_unresponsive(cancel_scope)
return
else:
self._heartbeat_stale_count = 0
self._last_heartbeat_value = current
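The stale-heartbeat detection in `_health_check` reduces to a small state machine: an unchanged counter across consecutive checks increments a stale count, any change resets it, and crossing the threshold flags the process as unresponsive. A minimal sketch of that logic, with `HEARTBEAT_STALE_CHECKS` mirroring the constant from the diff and the beat sequence made up for illustration:

```python
HEARTBEAT_STALE_CHECKS = 10  # mirrors the supervisor constant above


def check(current: int, last: int, stale: int):
    """One health-check tick. Returns (new_last, new_stale, unresponsive)."""
    if current > 0 and current == last:
        stale += 1
        return last, stale, stale >= HEARTBEAT_STALE_CHECKS
    # Counter advanced (or never started): reset staleness tracking.
    return current, 0, False


last, stale = 0, 0
for beat in [1, 2, 2, 2]:  # counter stops advancing after 2
    last, stale, dead = check(beat, last, stale)
print(last, stale, dead)  # 2 2 False -- two stale checks, threshold not reached
```

The `current > 0` guard means a subprocess that has not yet written its first heartbeat is never flagged, which matches the supervisor only starting to track once the counter is non-zero.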
def _handle_process_exit(self, cancel_scope: CancelScope) -> None:
"""Handle runner process that has exited."""
if not self._death_handled:
self._death_handled = True
if isinstance(
self.status, (RunnerShutdown, RunnerShuttingDown, RunnerFailed)
):
logger.info("Health check: runner process exited (expected)")
else:
rc = self.runner_process.exitcode
if isinstance(rc, int) and rc < 0:
sig = -rc
try:
cause = f"signal={sig} ({signal.strsignal(sig)})"
except Exception:
cause = f"signal={sig}"
else:
cause = f"exitcode={rc}"
logger.error(
f"Health check: runner process died unexpectedly ({cause})"
)
self._event_sender.send_nowait(
RunnerStatusUpdated(
runner_id=self.bound_instance.bound_runner_id,
runner_status=RunnerFailed(
error_message=f"Terminated ({cause})"
),
)
)
self.shutdown()
for tid in self.pending:
self.pending[tid].set()
cancel_scope.cancel()
def _handle_unresponsive(self, cancel_scope: CancelScope) -> None:
"""Handle runner process that is alive but unresponsive."""
if not self._death_handled:
self._death_handled = True
self._event_sender.send_nowait(
RunnerStatusUpdated(
runner_id=self.bound_instance.bound_runner_id,
runner_status=RunnerFailed(
error_message="Runner process unresponsive (heartbeat timeout)"
),
)
)
for tid in self.pending:
self.pending[tid].set()
self.shutdown()
cancel_scope.cancel()
await self._check_runner(e)
for tid in self.pending:
self.pending[tid].set()
def __del__(self) -> None:
if self.runner_process.is_alive():

View File

@@ -0,0 +1,72 @@
import json
from dataclasses import dataclass
from typing import Any, Callable
from exo.shared.types.api import ToolCallItem
@dataclass
class ToolParser:
start_parsing: str
end_parsing: str
parse_tool_calls: Callable[[str], list[ToolCallItem] | None]
def make_mlx_parser(
tool_call_start: str,
tool_call_end: str,
tool_parser: Callable[[str], dict[str, Any] | list[dict[str, Any]]],
) -> ToolParser:
def parse_tool_calls(text: str) -> list[ToolCallItem] | None:
try:
text = text.removeprefix(tool_call_start)
text = text.removesuffix(tool_call_end)
parsed = tool_parser(text)
if isinstance(parsed, list):
return [ToolCallItem.model_validate(_flatten(p)) for p in parsed]
else:
return [ToolCallItem.model_validate(_flatten(parsed))]
except Exception:
return None
return ToolParser(
start_parsing=tool_call_start,
end_parsing=tool_call_end,
parse_tool_calls=parse_tool_calls,
)
# TODO / example code:
def _parse_json_calls(text: str) -> list[ToolCallItem] | None:
try:
text = text.removeprefix("<tool_call>")
text = text.removesuffix("</tool_call>")
top_level = {
k: json.dumps(v) if isinstance(v, (dict, list)) else v
for k, v in json.loads(text).items() # pyright: ignore[reportAny]
}
return [ToolCallItem.model_validate(top_level)]
except Exception:
return None
def _flatten(p: dict[str, Any]) -> dict[str, str]:
return {
k: json.dumps(v) if isinstance(v, (dict, list)) else str(v) # pyright: ignore[reportAny]
for k, v in p.items() # pyright: ignore[reportAny]
}
json_tool_parser = ToolParser(
start_parsing="<tool_call>",
end_parsing="</tool_call>",
parse_tool_calls=_parse_json_calls,
)
def infer_tool_parser(chat_template: str) -> ToolParser | None:
"""Attempt to auto-infer a tool parser from the chat template."""
if "<tool_call>" in chat_template and "tool_call.name" in chat_template:
return json_tool_parser
return None

View File

@@ -5,12 +5,13 @@ from typing import Any
from exo.shared.types.worker.runner_response import GenerationResponse, ToolCallResponse
from exo.worker.runner.runner import parse_tool_calls
from exo.worker.runner.tool_parsers import make_mlx_parser
def _make_responses(
texts: list[str],
finish_on_last: bool = True,
) -> Generator[GenerationResponse | ToolCallResponse]:
) -> Generator[GenerationResponse]:
"""Create a sequence of GenerationResponses from text strings."""
for i, text in enumerate(texts):
is_last = i == len(texts) - 1
@@ -22,10 +23,13 @@ def _make_responses(
)
def _dummy_parser(text: str) -> dict[str, Any]:
def _dummier_parser(text: str) -> dict[str, Any]:
return {"name": "test_fn", "arguments": {"arg": text}}
_dummy_parser = make_mlx_parser("<tool_call>", "</tool_call>", _dummier_parser)
class TestParseToolCalls:
"""Tests for parse_tool_calls generator."""
@@ -35,8 +39,6 @@ class TestParseToolCalls:
results = list(
parse_tool_calls(
_make_responses(texts, finish_on_last=False),
"<tool_call>",
"</tool_call>",
_dummy_parser,
)
)
@@ -50,8 +52,6 @@ class TestParseToolCalls:
results = list(
parse_tool_calls(
_make_responses(texts),
"<tool_call>",
"</tool_call>",
_dummy_parser,
)
)
@@ -76,9 +76,7 @@ class TestParseToolCalls:
results = list(
parse_tool_calls(
_make_responses(texts, finish_on_last=False),
"<tool_call>",
"</tool_call>",
_failing_parser,
make_mlx_parser("<tool_call>", "</tool_call>", _failing_parser),
)
)

View File

@@ -1,204 +1 @@
from __future__ import annotations
import multiprocessing
import os
import signal as signal_module
from collections.abc import Callable
from multiprocessing.sharedctypes import Synchronized
from typing import Any
import anyio
from exo.shared.types.events import Event, RunnerStatusUpdated
from exo.shared.types.tasks import Task, TaskId
from exo.shared.types.worker.runners import RunnerFailed, RunnerIdle, RunnerShutdown
from exo.utils.channels import Receiver, Sender, channel, mp_channel
from exo.worker.runner.runner_supervisor import (
HEALTH_CHECK_INTERVAL_SECONDS,
HEARTBEAT_STALE_CHECKS,
RunnerSupervisor,
)
from ...constants import (
INSTANCE_1_ID,
MODEL_A_ID,
NODE_A,
RUNNER_1_ID,
)
from ..conftest import get_bound_mlx_ring_instance
def _die_immediately() -> None:
"""Subprocess target that exits with a non-zero code."""
os._exit(1)
def _die_with_signal() -> None:
"""Subprocess target that kills itself with SIGKILL (simulates OOM)."""
os.kill(os.getpid(), signal_module.SIGKILL)
def _exit_cleanly() -> None:
"""Subprocess target that exits with code 0."""
os._exit(0)
def _hang_forever() -> None:
"""Subprocess target that hangs without updating heartbeat (simulates freeze)."""
import time
# Never write a heartbeat; the test pre-seeds the shared counter so the
# supervisor sees it as stale immediately.
time.sleep(100000)
def _build_supervisor(
event_sender: Sender[Event],
target: Callable[..., Any],
) -> RunnerSupervisor:
"""Build a RunnerSupervisor with a custom subprocess target.
Uses a clone of event_sender (matching real Worker behavior) so that
closing the supervisor's copy doesn't close the test's receiver.
"""
bound_instance = get_bound_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
runner_id=RUNNER_1_ID,
node_id=NODE_A,
)
_ev_send, ev_recv = mp_channel[Event]()
task_sender, _task_recv = mp_channel[Task]()
cancel_sender, _cancel_recv = mp_channel[TaskId]()
runner_process = multiprocessing.Process(target=target, daemon=True)
heartbeat: Synchronized[int] = multiprocessing.Value("Q", 0)
return RunnerSupervisor(
bound_instance=bound_instance,
shard_metadata=bound_instance.bound_shard,
runner_process=runner_process,
initialize_timeout=10,
_ev_recv=ev_recv,
_task_sender=task_sender,
_cancel_sender=cancel_sender,
_event_sender=event_sender.clone(),
_heartbeat=heartbeat,
)
def _collect_failed_events(
event_receiver: Receiver[Event],
) -> list[RunnerFailed]:
"""Drain the receiver and return all RunnerFailed statuses."""
out: list[RunnerFailed] = []
while True:
try:
event = event_receiver.receive_nowait()
except Exception:
break
if isinstance(event, RunnerStatusUpdated) and isinstance(
event.runner_status, RunnerFailed
):
out.append(event.runner_status)
return out
async def test_health_check_detects_dead_process():
"""When the runner process dies with a non-zero exit code, the health check
should emit a RunnerFailed event and run() should return."""
event_sender, event_receiver = channel[Event]()
supervisor = _build_supervisor(event_sender, _die_immediately)
with anyio.fail_after(HEALTH_CHECK_INTERVAL_SECONDS + 5):
await supervisor.run()
failures = _collect_failed_events(event_receiver)
assert len(failures) == 1
assert failures[0].error_message is not None
assert "exitcode=1" in failures[0].error_message
async def test_health_check_detects_signal_death():
"""When the runner process is killed by a signal (e.g. OOM -> SIGKILL),
the health check should report the signal in the failure message."""
event_sender, event_receiver = channel[Event]()
supervisor = _build_supervisor(event_sender, _die_with_signal)
with anyio.fail_after(HEALTH_CHECK_INTERVAL_SECONDS + 5):
await supervisor.run()
failures = _collect_failed_events(event_receiver)
assert len(failures) == 1
assert failures[0].error_message is not None
assert "signal=9" in failures[0].error_message
async def test_health_check_releases_pending_tasks():
"""When the runner dies, any pending start_task() waiters should be unblocked."""
event_sender, _event_receiver = channel[Event]()
supervisor = _build_supervisor(event_sender, _die_immediately)
# Register a pending waiter as if start_task() was waiting for acknowledgement
task_event = anyio.Event()
tid = TaskId("pending-task")
supervisor.pending[tid] = task_event
with anyio.fail_after(HEALTH_CHECK_INTERVAL_SECONDS + 5):
await supervisor.run()
assert task_event.is_set()
async def test_clean_exit_no_failure_when_shutdown_status():
"""When the runner was in RunnerShutdown status and exits with code 0,
no RunnerFailed event should be emitted."""
event_sender, event_receiver = channel[Event]()
supervisor = _build_supervisor(event_sender, _exit_cleanly)
# Simulate that the runner had already reported shutdown via events
supervisor.status = RunnerShutdown()
with anyio.fail_after(HEALTH_CHECK_INTERVAL_SECONDS + 5):
await supervisor.run()
failures = _collect_failed_events(event_receiver)
assert len(failures) == 0
async def test_unexpected_exit_code_zero_emits_failure():
"""When the runner exits with code 0 but was NOT in a shutdown state,
this is unexpected and should still emit RunnerFailed."""
event_sender, event_receiver = channel[Event]()
supervisor = _build_supervisor(event_sender, _exit_cleanly)
assert isinstance(supervisor.status, RunnerIdle)
with anyio.fail_after(HEALTH_CHECK_INTERVAL_SECONDS + 5):
await supervisor.run()
failures = _collect_failed_events(event_receiver)
assert len(failures) == 1
assert failures[0].error_message is not None
assert "exitcode=0" in failures[0].error_message
async def test_heartbeat_timeout_detects_unresponsive_process():
"""When the runner process is alive but its heartbeat goes stale,
the health check should kill it and emit RunnerFailed."""
event_sender, event_receiver = channel[Event]()
supervisor = _build_supervisor(event_sender, _hang_forever)
# Pre-seed the heartbeat counter with a non-zero value and set the
# supervisor's last-seen value to match so it appears stale immediately.
# Set stale count to HEARTBEAT_STALE_CHECKS - 1 so a single check triggers.
supervisor._heartbeat.value = 42 # pyright: ignore[reportPrivateUsage]
supervisor._last_heartbeat_value = 42 # pyright: ignore[reportPrivateUsage]
supervisor._heartbeat_stale_count = HEARTBEAT_STALE_CHECKS - 1 # pyright: ignore[reportPrivateUsage]
with anyio.fail_after(HEALTH_CHECK_INTERVAL_SECONDS + 5):
await supervisor.run()
failures = _collect_failed_events(event_receiver)
assert len(failures) == 1
assert failures[0].error_message is not None
assert "unresponsive" in failures[0].error_message.lower()
# TODO: