Compare commits

...

6 Commits

Author SHA1 Message Date
Sami Khan
ab622f79c3 EXO iOS app 2026-02-18 06:40:07 +05:00
Alex Cheema
d6301ed593 dashboard: redesign downloads page as model×node table (#1465)
## Motivation

The current downloads page uses a node-centric card grid layout that is
messy and hard to read — the same model across different nodes appears
in separate cards, and deep nesting wastes space. This makes it
difficult to quickly see which models are on which nodes.

## Changes

Rewrote the downloads page
(`dashboard/src/routes/downloads/+page.svelte`) from a card grid to a
clean table layout:

- **Rows** = models (unique across all nodes)
- **Columns** = nodes (with disk free shown in header)
- **Cells** show status at a glance:
  - ✅ Green checkmark + size for completed downloads
  - 🟡 Yellow percentage + mini progress bar + speed for active downloads
  - `...` for pending downloads
  - ❌ Red X for failed downloads
  - `--` for models not present on a node
- Delete/download action buttons appear on row hover
- Model name column is sticky on horizontal scroll (for many-node
clusters)
- Models sorted by number of nodes with completed downloads
- Imported shared utilities from `$lib/utils/downloads` instead of
inline re-implementations

### Backend: model directory in download events

- Added `model_directory` field to `BaseDownloadProgress` so all
download status events include the on-disk path
- Added `_model_dir()` helper to `DownloadCoordinator` to compute the
path from `EXO_MODELS_DIR`
- Dashboard uses this to show file location and enable "open in Finder"
for completed downloads

### Info modal

- Clicking a model name opens an info modal showing card details
(family, quantization, capabilities, storage size, layer count, tensor
parallelism support)

### Other fixes

- Fixed model name truncation in the table
- Excluded `tests/start_distributed_test.py` from pytest collection (CLI
script that calls `sys.exit()` at import time)

## Test Plan

- [x] `uv run basedpyright` — 0 errors
- [x] `uv run ruff check` — all passed
- [x] `nix fmt` — clean
- [x] `uv run pytest` — 188 passed, 1 skipped

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 14:31:47 +00:00
Evan Quiney
6d1ca6689b don't time out node identities (#1493)
Currently, nodes leaving and rejoining the cluster can lose their identity. We have no need to delete this data when a node times out, so let's just persist it.
2026-02-17 11:48:28 +00:00
Evan
c01b6fff21 eprint banner
Our banner was being printed to stdout but should be printed to stderr,
as it's essentially a log message.
2026-02-17 11:43:06 +00:00
Jake Hillion
8392e78afe bench: add spec for automatic canary benchmarks (#1483)
Adds all the models that can fit onto a single M3 Ultra for single
machine benchmarks. Fixes the macOS version, GPU spec, and chip type for
maximum reproducibility. Specifies the minimum memory accordingly for
each type of model, using the smallest machine available (the smallest
M3 Ultra is 96GiB).

Test plan:
- Running this with some code that makes machines of this spec available
and stores the results. It works.

This will become part of a larger testing/stability strategy once we've
collected more of the data.
2026-02-17 10:52:05 +00:00
Evan
86735ece78 begins
begins
2026-02-16 19:26:19 +00:00
42 changed files with 3503 additions and 729 deletions

View File

@@ -0,0 +1,628 @@
// !$*UTF8*$!
{
archiveVersion = 1;
classes = {
};
objectVersion = 77;
objects = {
/* Begin PBXBuildFile section */
E09D17522F44F359009C51A3 /* MLXLLM in Frameworks */ = {isa = PBXBuildFile; productRef = E09D17512F44F359009C51A3 /* MLXLLM */; };
E09D17542F44F359009C51A3 /* MLXLMCommon in Frameworks */ = {isa = PBXBuildFile; productRef = E09D17532F44F359009C51A3 /* MLXLMCommon */; };
/* End PBXBuildFile section */
/* Begin PBXContainerItemProxy section */
E09D167D2F44CA20009C51A3 /* PBXContainerItemProxy */ = {
isa = PBXContainerItemProxy;
containerPortal = E09D16672F44CA1E009C51A3 /* Project object */;
proxyType = 1;
remoteGlobalIDString = E09D166E2F44CA1E009C51A3;
remoteInfo = "EXO-iOS";
};
E09D16872F44CA20009C51A3 /* PBXContainerItemProxy */ = {
isa = PBXContainerItemProxy;
containerPortal = E09D16672F44CA1E009C51A3 /* Project object */;
proxyType = 1;
remoteGlobalIDString = E09D166E2F44CA1E009C51A3;
remoteInfo = "EXO-iOS";
};
/* End PBXContainerItemProxy section */
/* Begin PBXFileReference section */
E09D166F2F44CA1E009C51A3 /* EXO-iOS.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = "EXO-iOS.app"; sourceTree = BUILT_PRODUCTS_DIR; };
E09D167C2F44CA20009C51A3 /* EXO-iOSTests.xctest */ = {isa = PBXFileReference; explicitFileType = wrapper.cfbundle; includeInIndex = 0; path = "EXO-iOSTests.xctest"; sourceTree = BUILT_PRODUCTS_DIR; };
E09D16862F44CA20009C51A3 /* EXO-iOSUITests.xctest */ = {isa = PBXFileReference; explicitFileType = wrapper.cfbundle; includeInIndex = 0; path = "EXO-iOSUITests.xctest"; sourceTree = BUILT_PRODUCTS_DIR; };
/* End PBXFileReference section */
/* Begin PBXFileSystemSynchronizedBuildFileExceptionSet section */
E09D169A2F44CA20009C51A3 /* Exceptions for "EXO-iOS" folder in "EXO-iOS" target */ = {
isa = PBXFileSystemSynchronizedBuildFileExceptionSet;
membershipExceptions = (
Info.plist,
);
target = E09D166E2F44CA1E009C51A3 /* EXO-iOS */;
};
/* End PBXFileSystemSynchronizedBuildFileExceptionSet section */
/* Begin PBXFileSystemSynchronizedRootGroup section */
E09D16712F44CA1E009C51A3 /* EXO-iOS */ = {
isa = PBXFileSystemSynchronizedRootGroup;
exceptions = (
E09D169A2F44CA20009C51A3 /* Exceptions for "EXO-iOS" folder in "EXO-iOS" target */,
);
path = "EXO-iOS";
sourceTree = "<group>";
};
E09D167F2F44CA20009C51A3 /* EXO-iOSTests */ = {
isa = PBXFileSystemSynchronizedRootGroup;
path = "EXO-iOSTests";
sourceTree = "<group>";
};
E09D16892F44CA20009C51A3 /* EXO-iOSUITests */ = {
isa = PBXFileSystemSynchronizedRootGroup;
path = "EXO-iOSUITests";
sourceTree = "<group>";
};
/* End PBXFileSystemSynchronizedRootGroup section */
/* Begin PBXFrameworksBuildPhase section */
E09D166C2F44CA1E009C51A3 /* Frameworks */ = {
isa = PBXFrameworksBuildPhase;
buildActionMask = 2147483647;
files = (
E09D17542F44F359009C51A3 /* MLXLMCommon in Frameworks */,
E09D17522F44F359009C51A3 /* MLXLLM in Frameworks */,
);
runOnlyForDeploymentPostprocessing = 0;
};
E09D16792F44CA20009C51A3 /* Frameworks */ = {
isa = PBXFrameworksBuildPhase;
buildActionMask = 2147483647;
files = (
);
runOnlyForDeploymentPostprocessing = 0;
};
E09D16832F44CA20009C51A3 /* Frameworks */ = {
isa = PBXFrameworksBuildPhase;
buildActionMask = 2147483647;
files = (
);
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXFrameworksBuildPhase section */
/* Begin PBXGroup section */
E09D16662F44CA1E009C51A3 = {
isa = PBXGroup;
children = (
E09D16712F44CA1E009C51A3 /* EXO-iOS */,
E09D167F2F44CA20009C51A3 /* EXO-iOSTests */,
E09D16892F44CA20009C51A3 /* EXO-iOSUITests */,
E09D16702F44CA1E009C51A3 /* Products */,
);
sourceTree = "<group>";
};
E09D16702F44CA1E009C51A3 /* Products */ = {
isa = PBXGroup;
children = (
E09D166F2F44CA1E009C51A3 /* EXO-iOS.app */,
E09D167C2F44CA20009C51A3 /* EXO-iOSTests.xctest */,
E09D16862F44CA20009C51A3 /* EXO-iOSUITests.xctest */,
);
name = Products;
sourceTree = "<group>";
};
/* End PBXGroup section */
/* Begin PBXNativeTarget section */
E09D166E2F44CA1E009C51A3 /* EXO-iOS */ = {
isa = PBXNativeTarget;
buildConfigurationList = E09D16902F44CA20009C51A3 /* Build configuration list for PBXNativeTarget "EXO-iOS" */;
buildPhases = (
E09D166B2F44CA1E009C51A3 /* Sources */,
E09D166C2F44CA1E009C51A3 /* Frameworks */,
E09D166D2F44CA1E009C51A3 /* Resources */,
);
buildRules = (
);
dependencies = (
);
fileSystemSynchronizedGroups = (
E09D16712F44CA1E009C51A3 /* EXO-iOS */,
);
name = "EXO-iOS";
packageProductDependencies = (
E09D17512F44F359009C51A3 /* MLXLLM */,
E09D17532F44F359009C51A3 /* MLXLMCommon */,
);
productName = "EXO-iOS";
productReference = E09D166F2F44CA1E009C51A3 /* EXO-iOS.app */;
productType = "com.apple.product-type.application";
};
E09D167B2F44CA20009C51A3 /* EXO-iOSTests */ = {
isa = PBXNativeTarget;
buildConfigurationList = E09D16932F44CA20009C51A3 /* Build configuration list for PBXNativeTarget "EXO-iOSTests" */;
buildPhases = (
E09D16782F44CA20009C51A3 /* Sources */,
E09D16792F44CA20009C51A3 /* Frameworks */,
E09D167A2F44CA20009C51A3 /* Resources */,
);
buildRules = (
);
dependencies = (
E09D167E2F44CA20009C51A3 /* PBXTargetDependency */,
);
fileSystemSynchronizedGroups = (
E09D167F2F44CA20009C51A3 /* EXO-iOSTests */,
);
name = "EXO-iOSTests";
packageProductDependencies = (
);
productName = "EXO-iOSTests";
productReference = E09D167C2F44CA20009C51A3 /* EXO-iOSTests.xctest */;
productType = "com.apple.product-type.bundle.unit-test";
};
E09D16852F44CA20009C51A3 /* EXO-iOSUITests */ = {
isa = PBXNativeTarget;
buildConfigurationList = E09D16962F44CA20009C51A3 /* Build configuration list for PBXNativeTarget "EXO-iOSUITests" */;
buildPhases = (
E09D16822F44CA20009C51A3 /* Sources */,
E09D16832F44CA20009C51A3 /* Frameworks */,
E09D16842F44CA20009C51A3 /* Resources */,
);
buildRules = (
);
dependencies = (
E09D16882F44CA20009C51A3 /* PBXTargetDependency */,
);
fileSystemSynchronizedGroups = (
E09D16892F44CA20009C51A3 /* EXO-iOSUITests */,
);
name = "EXO-iOSUITests";
packageProductDependencies = (
);
productName = "EXO-iOSUITests";
productReference = E09D16862F44CA20009C51A3 /* EXO-iOSUITests.xctest */;
productType = "com.apple.product-type.bundle.ui-testing";
};
/* End PBXNativeTarget section */
/* Begin PBXProject section */
E09D16672F44CA1E009C51A3 /* Project object */ = {
isa = PBXProject;
attributes = {
BuildIndependentTargetsInParallel = 1;
LastSwiftUpdateCheck = 2620;
LastUpgradeCheck = 2620;
TargetAttributes = {
E09D166E2F44CA1E009C51A3 = {
CreatedOnToolsVersion = 26.2;
};
E09D167B2F44CA20009C51A3 = {
CreatedOnToolsVersion = 26.2;
TestTargetID = E09D166E2F44CA1E009C51A3;
};
E09D16852F44CA20009C51A3 = {
CreatedOnToolsVersion = 26.2;
TestTargetID = E09D166E2F44CA1E009C51A3;
};
};
};
buildConfigurationList = E09D166A2F44CA1E009C51A3 /* Build configuration list for PBXProject "EXO-iOS" */;
developmentRegion = en;
hasScannedForEncodings = 0;
knownRegions = (
en,
Base,
);
mainGroup = E09D16662F44CA1E009C51A3;
minimizedProjectReferenceProxies = 1;
packageReferences = (
E09D17502F44F359009C51A3 /* XCRemoteSwiftPackageReference "mlx-swift-lm" */,
);
preferredProjectObjectVersion = 77;
productRefGroup = E09D16702F44CA1E009C51A3 /* Products */;
projectDirPath = "";
projectRoot = "";
targets = (
E09D166E2F44CA1E009C51A3 /* EXO-iOS */,
E09D167B2F44CA20009C51A3 /* EXO-iOSTests */,
E09D16852F44CA20009C51A3 /* EXO-iOSUITests */,
);
};
/* End PBXProject section */
/* Begin PBXResourcesBuildPhase section */
E09D166D2F44CA1E009C51A3 /* Resources */ = {
isa = PBXResourcesBuildPhase;
buildActionMask = 2147483647;
files = (
);
runOnlyForDeploymentPostprocessing = 0;
};
E09D167A2F44CA20009C51A3 /* Resources */ = {
isa = PBXResourcesBuildPhase;
buildActionMask = 2147483647;
files = (
);
runOnlyForDeploymentPostprocessing = 0;
};
E09D16842F44CA20009C51A3 /* Resources */ = {
isa = PBXResourcesBuildPhase;
buildActionMask = 2147483647;
files = (
);
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXResourcesBuildPhase section */
/* Begin PBXSourcesBuildPhase section */
E09D166B2F44CA1E009C51A3 /* Sources */ = {
isa = PBXSourcesBuildPhase;
buildActionMask = 2147483647;
files = (
);
runOnlyForDeploymentPostprocessing = 0;
};
E09D16782F44CA20009C51A3 /* Sources */ = {
isa = PBXSourcesBuildPhase;
buildActionMask = 2147483647;
files = (
);
runOnlyForDeploymentPostprocessing = 0;
};
E09D16822F44CA20009C51A3 /* Sources */ = {
isa = PBXSourcesBuildPhase;
buildActionMask = 2147483647;
files = (
);
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXSourcesBuildPhase section */
/* Begin PBXTargetDependency section */
E09D167E2F44CA20009C51A3 /* PBXTargetDependency */ = {
isa = PBXTargetDependency;
target = E09D166E2F44CA1E009C51A3 /* EXO-iOS */;
targetProxy = E09D167D2F44CA20009C51A3 /* PBXContainerItemProxy */;
};
E09D16882F44CA20009C51A3 /* PBXTargetDependency */ = {
isa = PBXTargetDependency;
target = E09D166E2F44CA1E009C51A3 /* EXO-iOS */;
targetProxy = E09D16872F44CA20009C51A3 /* PBXContainerItemProxy */;
};
/* End PBXTargetDependency section */
/* Begin XCBuildConfiguration section */
E09D168E2F44CA20009C51A3 /* Debug */ = {
isa = XCBuildConfiguration;
buildSettings = {
ALWAYS_SEARCH_USER_PATHS = NO;
ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES;
CLANG_ANALYZER_NONNULL = YES;
CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
CLANG_CXX_LANGUAGE_STANDARD = "gnu++20";
CLANG_ENABLE_MODULES = YES;
CLANG_ENABLE_OBJC_ARC = YES;
CLANG_ENABLE_OBJC_WEAK = YES;
CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
CLANG_WARN_BOOL_CONVERSION = YES;
CLANG_WARN_COMMA = YES;
CLANG_WARN_CONSTANT_CONVERSION = YES;
CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
CLANG_WARN_EMPTY_BODY = YES;
CLANG_WARN_ENUM_CONVERSION = YES;
CLANG_WARN_INFINITE_RECURSION = YES;
CLANG_WARN_INT_CONVERSION = YES;
CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;
CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
CLANG_WARN_STRICT_PROTOTYPES = YES;
CLANG_WARN_SUSPICIOUS_MOVE = YES;
CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
CLANG_WARN_UNREACHABLE_CODE = YES;
CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
COPY_PHASE_STRIP = NO;
DEBUG_INFORMATION_FORMAT = dwarf;
ENABLE_STRICT_OBJC_MSGSEND = YES;
ENABLE_TESTABILITY = YES;
ENABLE_USER_SCRIPT_SANDBOXING = YES;
GCC_C_LANGUAGE_STANDARD = gnu17;
GCC_DYNAMIC_NO_PIC = NO;
GCC_NO_COMMON_BLOCKS = YES;
GCC_OPTIMIZATION_LEVEL = 0;
GCC_PREPROCESSOR_DEFINITIONS = (
"DEBUG=1",
"$(inherited)",
);
GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
GCC_WARN_UNDECLARED_SELECTOR = YES;
GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
GCC_WARN_UNUSED_FUNCTION = YES;
GCC_WARN_UNUSED_VARIABLE = YES;
IPHONEOS_DEPLOYMENT_TARGET = 26.2;
LOCALIZATION_PREFERS_STRING_CATALOGS = YES;
MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE;
MTL_FAST_MATH = YES;
ONLY_ACTIVE_ARCH = YES;
SDKROOT = iphoneos;
SWIFT_ACTIVE_COMPILATION_CONDITIONS = "DEBUG $(inherited)";
SWIFT_OPTIMIZATION_LEVEL = "-Onone";
};
name = Debug;
};
E09D168F2F44CA20009C51A3 /* Release */ = {
isa = XCBuildConfiguration;
buildSettings = {
ALWAYS_SEARCH_USER_PATHS = NO;
ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES;
CLANG_ANALYZER_NONNULL = YES;
CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
CLANG_CXX_LANGUAGE_STANDARD = "gnu++20";
CLANG_ENABLE_MODULES = YES;
CLANG_ENABLE_OBJC_ARC = YES;
CLANG_ENABLE_OBJC_WEAK = YES;
CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
CLANG_WARN_BOOL_CONVERSION = YES;
CLANG_WARN_COMMA = YES;
CLANG_WARN_CONSTANT_CONVERSION = YES;
CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
CLANG_WARN_EMPTY_BODY = YES;
CLANG_WARN_ENUM_CONVERSION = YES;
CLANG_WARN_INFINITE_RECURSION = YES;
CLANG_WARN_INT_CONVERSION = YES;
CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;
CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
CLANG_WARN_STRICT_PROTOTYPES = YES;
CLANG_WARN_SUSPICIOUS_MOVE = YES;
CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
CLANG_WARN_UNREACHABLE_CODE = YES;
CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
COPY_PHASE_STRIP = NO;
DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym";
ENABLE_NS_ASSERTIONS = NO;
ENABLE_STRICT_OBJC_MSGSEND = YES;
ENABLE_USER_SCRIPT_SANDBOXING = YES;
GCC_C_LANGUAGE_STANDARD = gnu17;
GCC_NO_COMMON_BLOCKS = YES;
GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
GCC_WARN_UNDECLARED_SELECTOR = YES;
GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
GCC_WARN_UNUSED_FUNCTION = YES;
GCC_WARN_UNUSED_VARIABLE = YES;
IPHONEOS_DEPLOYMENT_TARGET = 26.2;
LOCALIZATION_PREFERS_STRING_CATALOGS = YES;
MTL_ENABLE_DEBUG_INFO = NO;
MTL_FAST_MATH = YES;
SDKROOT = iphoneos;
SWIFT_COMPILATION_MODE = wholemodule;
VALIDATE_PRODUCT = YES;
};
name = Release;
};
E09D16912F44CA20009C51A3 /* Debug */ = {
isa = XCBuildConfiguration;
buildSettings = {
ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor;
CODE_SIGN_STYLE = Automatic;
CURRENT_PROJECT_VERSION = 1;
DEVELOPMENT_TEAM = 3M3M67U93M;
ENABLE_PREVIEWS = YES;
GENERATE_INFOPLIST_FILE = YES;
INFOPLIST_FILE = "EXO-iOS/Info.plist";
INFOPLIST_KEY_UIApplicationSceneManifest_Generation = YES;
INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents = YES;
INFOPLIST_KEY_UILaunchScreen_Generation = YES;
INFOPLIST_KEY_UISupportedInterfaceOrientations_iPad = "UIInterfaceOrientationPortrait UIInterfaceOrientationPortraitUpsideDown UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight";
INFOPLIST_KEY_UISupportedInterfaceOrientations_iPhone = "UIInterfaceOrientationPortrait UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight";
LD_RUNPATH_SEARCH_PATHS = (
"$(inherited)",
"@executable_path/Frameworks",
);
MARKETING_VERSION = 1.0;
PRODUCT_BUNDLE_IDENTIFIER = "com.exo.EXO-iOS";
PRODUCT_NAME = "$(TARGET_NAME)";
STRING_CATALOG_GENERATE_SYMBOLS = YES;
SWIFT_APPROACHABLE_CONCURRENCY = YES;
SWIFT_DEFAULT_ACTOR_ISOLATION = MainActor;
SWIFT_EMIT_LOC_STRINGS = YES;
SWIFT_UPCOMING_FEATURE_MEMBER_IMPORT_VISIBILITY = YES;
SWIFT_VERSION = 5.0;
TARGETED_DEVICE_FAMILY = "1,2";
};
name = Debug;
};
E09D16922F44CA20009C51A3 /* Release */ = {
isa = XCBuildConfiguration;
buildSettings = {
ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor;
CODE_SIGN_STYLE = Automatic;
CURRENT_PROJECT_VERSION = 1;
DEVELOPMENT_TEAM = 3M3M67U93M;
ENABLE_PREVIEWS = YES;
GENERATE_INFOPLIST_FILE = YES;
INFOPLIST_FILE = "EXO-iOS/Info.plist";
INFOPLIST_KEY_UIApplicationSceneManifest_Generation = YES;
INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents = YES;
INFOPLIST_KEY_UILaunchScreen_Generation = YES;
INFOPLIST_KEY_UISupportedInterfaceOrientations_iPad = "UIInterfaceOrientationPortrait UIInterfaceOrientationPortraitUpsideDown UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight";
INFOPLIST_KEY_UISupportedInterfaceOrientations_iPhone = "UIInterfaceOrientationPortrait UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight";
LD_RUNPATH_SEARCH_PATHS = (
"$(inherited)",
"@executable_path/Frameworks",
);
MARKETING_VERSION = 1.0;
PRODUCT_BUNDLE_IDENTIFIER = "com.exo.EXO-iOS";
PRODUCT_NAME = "$(TARGET_NAME)";
STRING_CATALOG_GENERATE_SYMBOLS = YES;
SWIFT_APPROACHABLE_CONCURRENCY = YES;
SWIFT_DEFAULT_ACTOR_ISOLATION = MainActor;
SWIFT_EMIT_LOC_STRINGS = YES;
SWIFT_UPCOMING_FEATURE_MEMBER_IMPORT_VISIBILITY = YES;
SWIFT_VERSION = 5.0;
TARGETED_DEVICE_FAMILY = "1,2";
};
name = Release;
};
E09D16942F44CA20009C51A3 /* Debug */ = {
isa = XCBuildConfiguration;
buildSettings = {
BUNDLE_LOADER = "$(TEST_HOST)";
CODE_SIGN_STYLE = Automatic;
CURRENT_PROJECT_VERSION = 1;
GENERATE_INFOPLIST_FILE = YES;
IPHONEOS_DEPLOYMENT_TARGET = 26.2;
MARKETING_VERSION = 1.0;
PRODUCT_BUNDLE_IDENTIFIER = "com.exo.EXO-iOSTests";
PRODUCT_NAME = "$(TARGET_NAME)";
STRING_CATALOG_GENERATE_SYMBOLS = NO;
SWIFT_APPROACHABLE_CONCURRENCY = YES;
SWIFT_EMIT_LOC_STRINGS = NO;
SWIFT_UPCOMING_FEATURE_MEMBER_IMPORT_VISIBILITY = YES;
SWIFT_VERSION = 5.0;
TARGETED_DEVICE_FAMILY = "1,2";
TEST_HOST = "$(BUILT_PRODUCTS_DIR)/EXO-iOS.app/$(BUNDLE_EXECUTABLE_FOLDER_PATH)/EXO-iOS";
};
name = Debug;
};
E09D16952F44CA20009C51A3 /* Release */ = {
isa = XCBuildConfiguration;
buildSettings = {
BUNDLE_LOADER = "$(TEST_HOST)";
CODE_SIGN_STYLE = Automatic;
CURRENT_PROJECT_VERSION = 1;
GENERATE_INFOPLIST_FILE = YES;
IPHONEOS_DEPLOYMENT_TARGET = 26.2;
MARKETING_VERSION = 1.0;
PRODUCT_BUNDLE_IDENTIFIER = "com.exo.EXO-iOSTests";
PRODUCT_NAME = "$(TARGET_NAME)";
STRING_CATALOG_GENERATE_SYMBOLS = NO;
SWIFT_APPROACHABLE_CONCURRENCY = YES;
SWIFT_EMIT_LOC_STRINGS = NO;
SWIFT_UPCOMING_FEATURE_MEMBER_IMPORT_VISIBILITY = YES;
SWIFT_VERSION = 5.0;
TARGETED_DEVICE_FAMILY = "1,2";
TEST_HOST = "$(BUILT_PRODUCTS_DIR)/EXO-iOS.app/$(BUNDLE_EXECUTABLE_FOLDER_PATH)/EXO-iOS";
};
name = Release;
};
E09D16972F44CA20009C51A3 /* Debug */ = {
isa = XCBuildConfiguration;
buildSettings = {
CODE_SIGN_STYLE = Automatic;
CURRENT_PROJECT_VERSION = 1;
GENERATE_INFOPLIST_FILE = YES;
MARKETING_VERSION = 1.0;
PRODUCT_BUNDLE_IDENTIFIER = "com.exo.EXO-iOSUITests";
PRODUCT_NAME = "$(TARGET_NAME)";
STRING_CATALOG_GENERATE_SYMBOLS = NO;
SWIFT_APPROACHABLE_CONCURRENCY = YES;
SWIFT_EMIT_LOC_STRINGS = NO;
SWIFT_UPCOMING_FEATURE_MEMBER_IMPORT_VISIBILITY = YES;
SWIFT_VERSION = 5.0;
TARGETED_DEVICE_FAMILY = "1,2";
TEST_TARGET_NAME = "EXO-iOS";
};
name = Debug;
};
E09D16982F44CA20009C51A3 /* Release */ = {
isa = XCBuildConfiguration;
buildSettings = {
CODE_SIGN_STYLE = Automatic;
CURRENT_PROJECT_VERSION = 1;
GENERATE_INFOPLIST_FILE = YES;
MARKETING_VERSION = 1.0;
PRODUCT_BUNDLE_IDENTIFIER = "com.exo.EXO-iOSUITests";
PRODUCT_NAME = "$(TARGET_NAME)";
STRING_CATALOG_GENERATE_SYMBOLS = NO;
SWIFT_APPROACHABLE_CONCURRENCY = YES;
SWIFT_EMIT_LOC_STRINGS = NO;
SWIFT_UPCOMING_FEATURE_MEMBER_IMPORT_VISIBILITY = YES;
SWIFT_VERSION = 5.0;
TARGETED_DEVICE_FAMILY = "1,2";
TEST_TARGET_NAME = "EXO-iOS";
};
name = Release;
};
/* End XCBuildConfiguration section */
/* Begin XCConfigurationList section */
E09D166A2F44CA1E009C51A3 /* Build configuration list for PBXProject "EXO-iOS" */ = {
isa = XCConfigurationList;
buildConfigurations = (
E09D168E2F44CA20009C51A3 /* Debug */,
E09D168F2F44CA20009C51A3 /* Release */,
);
defaultConfigurationIsVisible = 0;
defaultConfigurationName = Release;
};
E09D16902F44CA20009C51A3 /* Build configuration list for PBXNativeTarget "EXO-iOS" */ = {
isa = XCConfigurationList;
buildConfigurations = (
E09D16912F44CA20009C51A3 /* Debug */,
E09D16922F44CA20009C51A3 /* Release */,
);
defaultConfigurationIsVisible = 0;
defaultConfigurationName = Release;
};
E09D16932F44CA20009C51A3 /* Build configuration list for PBXNativeTarget "EXO-iOSTests" */ = {
isa = XCConfigurationList;
buildConfigurations = (
E09D16942F44CA20009C51A3 /* Debug */,
E09D16952F44CA20009C51A3 /* Release */,
);
defaultConfigurationIsVisible = 0;
defaultConfigurationName = Release;
};
E09D16962F44CA20009C51A3 /* Build configuration list for PBXNativeTarget "EXO-iOSUITests" */ = {
isa = XCConfigurationList;
buildConfigurations = (
E09D16972F44CA20009C51A3 /* Debug */,
E09D16982F44CA20009C51A3 /* Release */,
);
defaultConfigurationIsVisible = 0;
defaultConfigurationName = Release;
};
/* End XCConfigurationList section */
/* Begin XCRemoteSwiftPackageReference section */
E09D17502F44F359009C51A3 /* XCRemoteSwiftPackageReference "mlx-swift-lm" */ = {
isa = XCRemoteSwiftPackageReference;
repositoryURL = "https://github.com/ml-explore/mlx-swift-lm";
requirement = {
kind = upToNextMajorVersion;
minimumVersion = 2.30.3;
};
};
/* End XCRemoteSwiftPackageReference section */
/* Begin XCSwiftPackageProductDependency section */
E09D17512F44F359009C51A3 /* MLXLLM */ = {
isa = XCSwiftPackageProductDependency;
package = E09D17502F44F359009C51A3 /* XCRemoteSwiftPackageReference "mlx-swift-lm" */;
productName = MLXLLM;
};
E09D17532F44F359009C51A3 /* MLXLMCommon */ = {
isa = XCSwiftPackageProductDependency;
package = E09D17502F44F359009C51A3 /* XCRemoteSwiftPackageReference "mlx-swift-lm" */;
productName = MLXLMCommon;
};
/* End XCSwiftPackageProductDependency section */
};
rootObject = E09D16672F44CA1E009C51A3 /* Project object */;
}

View File

@@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<Workspace
version = "1.0">
<FileRef
location = "self:">
</FileRef>
</Workspace>

View File

@@ -0,0 +1,60 @@
{
"originHash" : "facc0ac7c70363ea20f6cd1235de91dea6b06f0d00190946045a6c8ae753abc2",
"pins" : [
{
"identity" : "mlx-swift",
"kind" : "remoteSourceControl",
"location" : "https://github.com/ml-explore/mlx-swift",
"state" : {
"revision" : "6ba4827fb82c97d012eec9ab4b2de21f85c3b33d",
"version" : "0.30.6"
}
},
{
"identity" : "mlx-swift-lm",
"kind" : "remoteSourceControl",
"location" : "https://github.com/ml-explore/mlx-swift-lm",
"state" : {
"revision" : "360c5052b81cc154b04ee0933597a4ad6db4b8ae",
"version" : "2.30.3"
}
},
{
"identity" : "swift-collections",
"kind" : "remoteSourceControl",
"location" : "https://github.com/apple/swift-collections.git",
"state" : {
"revision" : "7b847a3b7008b2dc2f47ca3110d8c782fb2e5c7e",
"version" : "1.3.0"
}
},
{
"identity" : "swift-jinja",
"kind" : "remoteSourceControl",
"location" : "https://github.com/huggingface/swift-jinja.git",
"state" : {
"revision" : "d81197f35f41445bc10e94600795e68c6f5e94b0",
"version" : "2.3.1"
}
},
{
"identity" : "swift-numerics",
"kind" : "remoteSourceControl",
"location" : "https://github.com/apple/swift-numerics",
"state" : {
"revision" : "0c0290ff6b24942dadb83a929ffaaa1481df04a2",
"version" : "1.1.1"
}
},
{
"identity" : "swift-transformers",
"kind" : "remoteSourceControl",
"location" : "https://github.com/huggingface/swift-transformers",
"state" : {
"revision" : "573e5c9036c2f136b3a8a071da8e8907322403d0",
"version" : "1.1.6"
}
}
],
"version" : 3
}

View File

@@ -0,0 +1,11 @@
{
"colors" : [
{
"idiom" : "universal"
}
],
"info" : {
"author" : "xcode",
"version" : 1
}
}

View File

@@ -0,0 +1,35 @@
{
"images" : [
{
"idiom" : "universal",
"platform" : "ios",
"size" : "1024x1024"
},
{
"appearances" : [
{
"appearance" : "luminosity",
"value" : "dark"
}
],
"idiom" : "universal",
"platform" : "ios",
"size" : "1024x1024"
},
{
"appearances" : [
{
"appearance" : "luminosity",
"value" : "tinted"
}
],
"idiom" : "universal",
"platform" : "ios",
"size" : "1024x1024"
}
],
"info" : {
"author" : "xcode",
"version" : 1
}
}

View File

@@ -0,0 +1,6 @@
{
"info" : {
"author" : "xcode",
"version" : 1
}
}

View File

@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>com.apple.developer.kernel.increased-memory-limit</key>
<true/>
</dict>
</plist>

View File

@@ -0,0 +1,42 @@
import SwiftUI
// App entry point: owns the long-lived services and injects them into the
// SwiftUI environment once ChatService has been constructed.
@main
struct EXO_iOSApp: App {
@State private var clusterService = ClusterService()
@State private var discoveryService = DiscoveryService()
@State private var localInferenceService = LocalInferenceService()
// Created lazily on first render because it depends on the two services above.
@State private var chatService: ChatService?
var body: some Scene {
WindowGroup {
if let chatService {
RootView()
.environment(clusterService)
.environment(discoveryService)
.environment(chatService)
.environment(localInferenceService)
.task {
// Startup sequence: try the last-known cluster first, then start
// discovery browsing, then prepare the on-device model.
await clusterService.attemptAutoReconnect()
discoveryService.startBrowsing()
await localInferenceService.prepareModel()
}
.onChange(of: discoveryService.discoveredClusters) { _, clusters in
// Auto-connect to the first discovered cluster, but only while fully
// disconnected so an in-flight connection attempt is not interrupted.
guard !clusterService.isConnected,
case .disconnected = clusterService.connectionState,
let first = clusters.first
else { return }
Task {
await clusterService.connectToDiscoveredCluster(first, using: discoveryService)
}
}
} else {
// First render pass: construct ChatService here; setting the @State
// property flips the `if let` above on the next view update.
Color.clear.onAppear {
chatService = ChatService(
clusterService: clusterService,
localInferenceService: localInferenceService
)
}
}
}
}
}

View File

@@ -0,0 +1,17 @@
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>CFBundleDisplayName</key>
<string>EXO</string>
<key>NSLocalNetworkUsageDescription</key>
<string>EXO needs local network access to connect to your EXO cluster.</string>
<key>NSBonjourServices</key>
<array>
<string>_exo._tcp</string>
<string>_p2p._tcp</string>
<string>_p2p._udp</string>
<string>_libp2p._udp</string>
</array>
</dict>
</plist>

View File

@@ -0,0 +1,129 @@
import Foundation
// MARK: - Request
/// Body of a chat-completions request in the OpenAI-compatible wire format.
struct ChatCompletionRequest: Encodable {
    let model: String
    let messages: [ChatCompletionMessageParam]
    let stream: Bool
    let maxTokens: Int?
    let temperature: Double?

    /// Maps Swift camelCase names onto the API's snake_case JSON fields.
    enum CodingKeys: String, CodingKey {
        case model, messages, stream, temperature
        case maxTokens = "max_tokens"
    }
}

/// One message entry inside the request payload.
struct ChatCompletionMessageParam: Encodable {
    let role: String
    let content: String
}
// MARK: - Streaming Response
/// One SSE chunk of a streaming chat-completions response.
struct ChatCompletionChunk: Decodable {
    let id: String
    let model: String?
    let choices: [StreamingChoice]
    let usage: ChunkUsage?

    init(id: String, model: String?, choices: [StreamingChoice], usage: ChunkUsage?) {
        self.id = id
        self.model = model
        self.choices = choices
        self.usage = usage
    }
}

/// A single choice inside a streaming chunk.
struct StreamingChoice: Decodable {
    let index: Int
    let delta: Delta
    let finishReason: String?

    /// snake_case mapping for the `finish_reason` field.
    enum CodingKeys: String, CodingKey {
        case index, delta
        case finishReason = "finish_reason"
    }

    init(index: Int, delta: Delta, finishReason: String?) {
        self.index = index
        self.delta = delta
        self.finishReason = finishReason
    }
}

/// Incremental message content carried by a streaming choice.
struct Delta: Decodable {
    let role: String?
    let content: String?

    init(role: String?, content: String?) {
        self.role = role
        self.content = content
    }
}

/// Token-usage accounting attached to the final streaming chunk.
struct ChunkUsage: Decodable {
    let promptTokens: Int?
    let completionTokens: Int?
    let totalTokens: Int?

    /// snake_case mapping for usage counters.
    enum CodingKeys: String, CodingKey {
        case promptTokens = "prompt_tokens"
        case completionTokens = "completion_tokens"
        case totalTokens = "total_tokens"
    }

    init(promptTokens: Int?, completionTokens: Int?, totalTokens: Int?) {
        self.promptTokens = promptTokens
        self.completionTokens = completionTokens
        self.totalTokens = totalTokens
    }
}
// MARK: - Non-Streaming Response
/// A complete (non-streaming) chat-completions response.
struct ChatCompletionResponse: Decodable {
    let id: String
    let model: String?
    let choices: [ResponseChoice]
}

/// A single choice inside a non-streaming response.
struct ResponseChoice: Decodable {
    let index: Int
    let message: ResponseMessage
    let finishReason: String?

    /// snake_case mapping for the `finish_reason` field.
    enum CodingKeys: String, CodingKey {
        case index, message
        case finishReason = "finish_reason"
    }
}

/// The assistant message carried by a non-streaming choice.
struct ResponseMessage: Decodable {
    let role: String?
    let content: String?
}

// MARK: - Models List

/// Envelope returned by the models-list endpoint.
struct ModelListResponse: Decodable {
    let data: [ModelInfo]
}

/// One entry in the models list.
struct ModelInfo: Decodable, Identifiable {
    let id: String
    let name: String?
}

// MARK: - Error

/// Envelope of an API error payload.
struct APIErrorResponse: Decodable {
    let error: APIErrorInfo
}

/// Details of an API error.
struct APIErrorInfo: Decodable {
    let message: String
    let type: String?
    let code: Int?
}

View File

@@ -0,0 +1,26 @@
import Foundation
/// A single chat turn as rendered in the UI.
struct ChatMessage: Identifiable, Equatable {
    /// Who produced the message.
    enum Role: String, Codable {
        case user
        case assistant
        case system
    }

    let id: UUID
    let role: Role
    var content: String
    let timestamp: Date
    // True while tokens are still being appended to `content`.
    var isStreaming: Bool

    init(
        id: UUID = UUID(), role: Role, content: String, timestamp: Date = Date(),
        isStreaming: Bool = false
    ) {
        self.id = id
        self.role = role
        self.content = content
        self.timestamp = timestamp
        self.isStreaming = isStreaming
    }
}

View File

@@ -0,0 +1,11 @@
import Foundation
/// How to reach an EXO node's HTTP API.
struct ConnectionInfo: Codable, Equatable {
    let host: String
    let port: Int
    let nodeId: String?

    /// Base URL of the node's API, e.g. `http://192.168.1.2:52415`.
    ///
    /// Built with `URLComponents` rather than force-unwrapping
    /// `URL(string:)`, so a host containing URL-invalid characters (or an
    /// un-bracketed IPv6 literal) cannot crash the app.
    var baseURL: URL {
        var components = URLComponents()
        components.scheme = "http"
        components.host = host
        components.port = port
        // With a scheme and host set, `url` is expected to be non-nil; keep a
        // guaranteed-valid literal as a defensive fallback instead of `!`.
        return components.url ?? URL(string: "http://127.0.0.1:\(ConnectionInfo.defaultPort)")!
    }

    /// Default EXO API port.
    static let defaultPort = 52415
}

View File

@@ -0,0 +1,34 @@
import Foundation
/// A persisted message inside a conversation (storage representation).
struct StoredMessage: Identifiable, Codable, Equatable {
    let id: UUID
    let role: String
    var content: String
    let timestamp: Date

    init(id: UUID = UUID(), role: String, content: String, timestamp: Date = Date()) {
        self.id = id
        self.role = role
        self.content = content
        self.timestamp = timestamp
    }
}

/// A persisted chat thread: title, messages, and the model it targets.
struct Conversation: Identifiable, Codable, Equatable {
    let id: UUID
    var title: String
    var messages: [StoredMessage]
    var modelId: String?
    let createdAt: Date

    init(
        id: UUID = UUID(), title: String = "New Chat", messages: [StoredMessage] = [],
        modelId: String? = nil, createdAt: Date = Date()
    ) {
        self.id = id
        self.title = title
        self.messages = messages
        self.modelId = modelId
        self.createdAt = createdAt
    }
}

View File

@@ -0,0 +1,225 @@
import Foundation
/// Owns all conversations and drives message sending/streaming.
/// Routes requests to the cluster when connected, else to the on-device
/// model. All state is MainActor-confined; conversations are persisted as
/// JSON in the app's Documents directory.
@Observable
@MainActor
final class ChatService {
  /// All conversations, newest first (new ones are inserted at index 0).
  var conversations: [Conversation] = []
  /// Conversation currently shown in the UI, if any.
  var activeConversationId: UUID?
  /// True while a completion stream is in flight.
  private(set) var isGenerating: Bool = false
  private var currentGenerationTask: Task<Void, Never>?
  private let clusterService: ClusterService
  private let localInferenceService: LocalInferenceService
  /// A message can be sent if either backend is usable.
  var canSendMessage: Bool {
    clusterService.isConnected || localInferenceService.isAvailable
  }
  var activeConversation: Conversation? {
    guard let id = activeConversationId else { return nil }
    return conversations.first { $0.id == id }
  }
  /// UI-facing projection of the active conversation's messages.
  /// NOTE(review): isStreaming is never set here (defaults to false), so the
  /// streaming-cursor flag on ChatMessage is effectively unused — confirm intent.
  var activeMessages: [ChatMessage] {
    guard let conversation = activeConversation else { return [] }
    return conversation.messages.map { stored in
      ChatMessage(
        id: stored.id,
        role: ChatMessage.Role(rawValue: stored.role) ?? .user,
        content: stored.content,
        timestamp: stored.timestamp
      )
    }
  }
  init(clusterService: ClusterService, localInferenceService: LocalInferenceService) {
    self.clusterService = clusterService
    self.localInferenceService = localInferenceService
    loadConversations()
  }
  // MARK: - Conversation Management
  /// Creates a new conversation (defaulting its model to the cluster's first
  /// model when none is given), makes it active, and persists.
  func createConversation(modelId: String? = nil) {
    let conversation = Conversation(
      modelId: modelId ?? clusterService.availableModels.first?.id)
    conversations.insert(conversation, at: 0)
    activeConversationId = conversation.id
    saveConversations()
  }
  /// Deletes a conversation; if it was active, falls back to the newest one.
  func deleteConversation(id: UUID) {
    conversations.removeAll { $0.id == id }
    if activeConversationId == id {
      activeConversationId = conversations.first?.id
    }
    saveConversations()
  }
  func setActiveConversation(id: UUID) {
    activeConversationId = id
  }
  /// Assigns a model to the active conversation and persists. No-op when
  /// there is no active conversation.
  func setModelForActiveConversation(_ modelId: String) {
    guard let index = conversations.firstIndex(where: { $0.id == activeConversationId }) else {
      return
    }
    conversations[index].modelId = modelId
    saveConversations()
  }
  // MARK: - Messaging
  /// Appends the user message, resolves which model/backend to use, inserts
  /// an empty assistant placeholder, and kicks off the streaming task.
  /// Backend errors are surfaced as assistant-role messages rather than thrown.
  func sendMessage(_ text: String) {
    guard !text.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty else { return }
    if activeConversation == nil {
      createConversation()
    }
    guard let index = conversations.firstIndex(where: { $0.id == activeConversationId }) else {
      return
    }
    let userMessage = StoredMessage(role: "user", content: text)
    conversations[index].messages.append(userMessage)
    // First message titles the conversation with a 40-char preview.
    if conversations[index].title == "New Chat" {
      let preview = String(text.prefix(40))
      conversations[index].title = preview + (text.count > 40 ? "..." : "")
    }
    // Backend selection: cluster wins when connected; otherwise the local
    // on-device model; otherwise report the failure inline.
    let modelId: String
    if clusterService.isConnected {
      guard let clusterId = conversations[index].modelId ?? clusterService.availableModels.first?.id
      else {
        let errorMessage = StoredMessage(
          role: "assistant", content: "No model selected. Please select a model first.")
        conversations[index].messages.append(errorMessage)
        saveConversations()
        return
      }
      modelId = clusterId
    } else if localInferenceService.isAvailable {
      modelId = localInferenceService.defaultModelId
    } else {
      let errorMessage = StoredMessage(
        role: "assistant",
        content: "Not connected to a cluster and local model is not available.")
      conversations[index].messages.append(errorMessage)
      saveConversations()
      return
    }
    conversations[index].modelId = modelId
    // Placeholder that the stream appends tokens into, located by ID.
    let assistantMessageId = UUID()
    let assistantMessage = StoredMessage(
      id: assistantMessageId, role: "assistant", content: "", timestamp: Date())
    conversations[index].messages.append(assistantMessage)
    // dropLast() excludes the empty placeholder from the request history.
    let messagesForAPI = conversations[index].messages.dropLast().map { stored in
      ChatCompletionMessageParam(role: stored.role, content: stored.content)
    }
    let request = ChatCompletionRequest(
      model: modelId,
      messages: Array(messagesForAPI),
      stream: true,
      maxTokens: 4096,
      temperature: nil
    )
    let conversationId = conversations[index].id
    isGenerating = true
    currentGenerationTask = Task { [weak self] in
      guard let self else { return }
      await self.performStreaming(
        request: request, conversationId: conversationId,
        assistantMessageId: assistantMessageId)
    }
    saveConversations()
  }
  /// Cancels any in-flight generation on both backends and resets state.
  func cancelGeneration() {
    currentGenerationTask?.cancel()
    currentGenerationTask = nil
    localInferenceService.cancelGeneration()
    isGenerating = false
  }
  // MARK: - Streaming
  /// Consumes the token stream and appends each delta to the placeholder
  /// message, re-resolving indices by ID each chunk so concurrent edits
  /// (e.g. deleting the conversation mid-stream) are tolerated.
  private func performStreaming(
    request: ChatCompletionRequest, conversationId: UUID, assistantMessageId: UUID
  ) async {
    defer {
      isGenerating = false
      currentGenerationTask = nil
      saveConversations()
    }
    do {
      let stream =
        clusterService.isConnected
        ? clusterService.streamChatCompletion(request: request)
        : localInferenceService.streamChatCompletion(request: request)
      for try await chunk in stream {
        guard !Task.isCancelled else { return }
        guard let content = chunk.choices.first?.delta.content, !content.isEmpty else {
          continue
        }
        if let convIndex = conversations.firstIndex(where: { $0.id == conversationId }),
          let msgIndex = conversations[convIndex].messages.firstIndex(where: {
            $0.id == assistantMessageId
          })
        {
          conversations[convIndex].messages[msgIndex].content += content
        }
      }
    } catch {
      if !Task.isCancelled {
        // Only surface the error if nothing was generated; a partial
        // response is kept as-is.
        if let convIndex = conversations.firstIndex(where: { $0.id == conversationId }),
          let msgIndex = conversations[convIndex].messages.firstIndex(where: {
            $0.id == assistantMessageId
          })
        {
          if conversations[convIndex].messages[msgIndex].content.isEmpty {
            conversations[convIndex].messages[msgIndex].content =
              "Error: \(error.localizedDescription)"
          }
        }
      }
    }
  }
  // MARK: - Persistence
  /// JSON file in the app's Documents directory holding all conversations.
  private static var storageURL: URL {
    let documents = FileManager.default.urls(for: .documentDirectory, in: .userDomainMask)
      .first!
    return documents.appendingPathComponent("exo_conversations.json")
  }
  /// Best-effort save; failures are intentionally swallowed (chat history
  /// loss is preferable to crashing mid-conversation).
  private func saveConversations() {
    do {
      let data = try JSONEncoder().encode(conversations)
      try data.write(to: Self.storageURL, options: .atomic)
    } catch {
      // Save failed silently
    }
  }
  /// Loads persisted conversations; a missing/corrupt file yields an empty list.
  private func loadConversations() {
    do {
      let data = try Data(contentsOf: Self.storageURL)
      conversations = try JSONDecoder().decode([Conversation].self, from: data)
      activeConversationId = conversations.first?.id
    } catch {
      conversations = []
    }
  }
}

View File

@@ -0,0 +1,198 @@
import Foundation
/// Cluster connection lifecycle; the connected case carries the endpoint info.
enum ConnectionState: Equatable {
  case disconnected
  case connecting
  case connected(ConnectionInfo)
}
/// A selectable model in the UI; `id` is the API model identifier.
struct ModelOption: Identifiable, Equatable {
  let id: String
  let displayName: String
}
/// Manages the HTTP connection to an EXO cluster node: connect/disconnect,
/// periodic model-list polling, streaming chat completions, and persistence
/// of the last successful connection for auto-reconnect.
@Observable
@MainActor
final class ClusterService {
  private(set) var connectionState: ConnectionState = .disconnected
  /// Models advertised by the connected node; refreshed by polling.
  private(set) var availableModels: [ModelOption] = []
  /// Last user-visible connection error, if any.
  private(set) var lastError: String?
  private let session: URLSession
  private let decoder: JSONDecoder
  private var pollingTask: Task<Void, Never>?
  /// UserDefaults key for the persisted last connection.
  private static let connectionInfoKey = "exo_last_connection_info"
  var isConnected: Bool {
    if case .connected = connectionState { return true }
    return false
  }
  var currentConnection: ConnectionInfo? {
    if case .connected(let info) = connectionState { return info }
    return nil
  }
  init(session: URLSession = .shared) {
    self.session = session
    let decoder = JSONDecoder()
    self.decoder = decoder
  }
  // MARK: - Connection
  /// Probes the node's `/node_id` endpoint (5s timeout) to validate the
  /// address; on success marks connected, persists the address, starts
  /// polling, and fetches the model list. On failure records lastError.
  func connect(to info: ConnectionInfo) async {
    connectionState = .connecting
    lastError = nil
    do {
      let url = info.baseURL.appendingPathComponent("node_id")
      var request = URLRequest(url: url)
      request.timeoutInterval = 5
      request.cachePolicy = .reloadIgnoringLocalCacheData
      let (_, response) = try await session.data(for: request)
      guard let httpResponse = response as? HTTPURLResponse,
        (200..<300).contains(httpResponse.statusCode)
      else {
        throw URLError(.badServerResponse)
      }
      connectionState = .connected(info)
      persistConnection(info)
      startPolling()
      await fetchModels(baseURL: info.baseURL)
    } catch {
      connectionState = .disconnected
      lastError = "Could not connect to \(info.host):\(info.port)"
    }
  }
  /// Resolves a Bonjour-discovered cluster to a concrete address, then
  /// connects. No-op unless currently disconnected.
  func connectToDiscoveredCluster(_ cluster: DiscoveredCluster, using discoveryService: DiscoveryService) async {
    guard case .disconnected = connectionState else { return }
    connectionState = .connecting
    lastError = nil
    guard let info = await discoveryService.resolve(cluster) else {
      connectionState = .disconnected
      lastError = "Could not resolve \(cluster.name)"
      return
    }
    connectionState = .disconnected // reset so connect() can proceed
    await connect(to: info)
  }
  func disconnect() {
    stopPolling()
    connectionState = .disconnected
    availableModels = []
    lastError = nil
  }
  /// Retries the last persisted connection, if any, when disconnected.
  func attemptAutoReconnect() async {
    guard case .disconnected = connectionState,
      let info = loadPersistedConnection()
    else { return }
    await connect(to: info)
  }
  // MARK: - Polling
  /// Refreshes the model list every `interval` seconds while connected;
  /// the loop exits once the connection is gone or the task is cancelled.
  private func startPolling(interval: TimeInterval = 2.0) {
    stopPolling()
    pollingTask = Task { [weak self] in
      while !Task.isCancelled {
        try? await Task.sleep(for: .seconds(interval))
        guard let self, !Task.isCancelled else { return }
        guard let connection = self.currentConnection else { return }
        await self.fetchModels(baseURL: connection.baseURL)
      }
    }
  }
  private func stopPolling() {
    pollingTask?.cancel()
    pollingTask = nil
  }
  // MARK: - API
  /// Fetches `/models` and replaces `availableModels`. Errors are swallowed;
  /// the next poll tick retries.
  private func fetchModels(baseURL: URL) async {
    do {
      let url = baseURL.appendingPathComponent("models")
      var request = URLRequest(url: url)
      request.cachePolicy = .reloadIgnoringLocalCacheData
      let (data, response) = try await session.data(for: request)
      guard let httpResponse = response as? HTTPURLResponse,
        (200..<300).contains(httpResponse.statusCode)
      else { return }
      let list = try decoder.decode(ModelListResponse.self, from: data)
      availableModels = list.data.map {
        ModelOption(id: $0.id, displayName: $0.name ?? $0.id)
      }
    } catch {
      // Fetch failed silently; will retry on next poll.
    }
  }
  /// POSTs `/v1/chat/completions` and yields SSE-decoded chunks.
  /// Cancelling the consuming task cancels the underlying request via
  /// onTermination. Fails with URLError(.notConnectedToInternet) when
  /// there is no current connection.
  func streamChatCompletion(request body: ChatCompletionRequest) -> AsyncThrowingStream<
    ChatCompletionChunk, Error
  > {
    AsyncThrowingStream { continuation in
      let task = Task { [weak self] in
        guard let self, let connection = self.currentConnection else {
          continuation.finish(throwing: URLError(.notConnectedToInternet))
          return
        }
        do {
          let url = connection.baseURL.appendingPathComponent("v1/chat/completions")
          var request = URLRequest(url: url)
          request.httpMethod = "POST"
          request.setValue("application/json", forHTTPHeaderField: "Content-Type")
          request.httpBody = try JSONEncoder().encode(body)
          let (bytes, response) = try await self.session.bytes(for: request)
          guard let httpResponse = response as? HTTPURLResponse,
            (200..<300).contains(httpResponse.statusCode)
          else {
            continuation.finish(throwing: URLError(.badServerResponse))
            return
          }
          let parser = SSEStreamParser<ChatCompletionChunk>(
            bytes: bytes, decoder: self.decoder)
          for try await chunk in parser {
            continuation.yield(chunk)
          }
          continuation.finish()
        } catch {
          continuation.finish(throwing: error)
        }
      }
      continuation.onTermination = { _ in
        task.cancel()
      }
    }
  }
  // MARK: - Persistence
  /// Saves the connection as JSON in UserDefaults for auto-reconnect.
  private func persistConnection(_ info: ConnectionInfo) {
    if let data = try? JSONEncoder().encode(info) {
      UserDefaults.standard.set(data, forKey: Self.connectionInfoKey)
    }
  }
  private func loadPersistedConnection() -> ConnectionInfo? {
    guard let data = UserDefaults.standard.data(forKey: Self.connectionInfoKey) else {
      return nil
    }
    return try? JSONDecoder().decode(ConnectionInfo.self, from: data)
  }
}

View File

@@ -0,0 +1,118 @@
import Foundation
import Network
import os
/// A cluster found via Bonjour browsing, identified by its service name.
struct DiscoveredCluster: Identifiable, Equatable {
  let id: String
  let name: String
  /// Unresolved Bonjour endpoint; resolved to host/port on connect.
  let endpoint: NWEndpoint
  // Manual Equatable that compares only id/name, deliberately ignoring
  // `endpoint` (presumably because NWEndpoint equality is unsuitable here
  // — confirm).
  static func == (lhs: DiscoveredCluster, rhs: DiscoveredCluster) -> Bool {
    lhs.id == rhs.id && lhs.name == rhs.name
  }
}
/// Discovers EXO clusters on the local network via Bonjour (`_exo._tcp`)
/// and resolves a discovered service to a concrete host/port.
@Observable
@MainActor
final class DiscoveryService {
  /// Clusters currently visible on the network; updated as results change.
  private(set) var discoveredClusters: [DiscoveredCluster] = []
  /// True while the Bonjour browser is active.
  private(set) var isSearching = false
  private var browser: NWBrowser?
  /// Starts Bonjour browsing; no-op if already browsing.
  func startBrowsing() {
    guard browser == nil else { return }
    let browser = NWBrowser(for: .bonjour(type: "_exo._tcp", domain: nil), using: .tcp)
    browser.stateUpdateHandler = { [weak self] state in
      guard let service = self else { return }
      // Handlers fire on the browser queue; hop to MainActor for state.
      Task { @MainActor in
        switch state {
        case .ready:
          service.isSearching = true
        case .failed, .cancelled:
          service.isSearching = false
        default:
          break
        }
      }
    }
    browser.browseResultsChangedHandler = { [weak self] results, _ in
      guard let service = self else { return }
      Task { @MainActor in
        // Keep only `.service` endpoints; the service name is used as
        // both the stable ID and the display name.
        service.discoveredClusters = results.compactMap { result in
          guard case .service(let name, _, _, _) = result.endpoint else {
            return nil
          }
          return DiscoveredCluster(
            id: name,
            name: name,
            endpoint: result.endpoint
          )
        }
      }
    }
    browser.start(queue: .main)
    self.browser = browser
  }
  /// Stops browsing and clears all discovered results.
  func stopBrowsing() {
    browser?.cancel()
    browser = nil
    isSearching = false
    discoveredClusters = []
  }
  /// Resolve a discovered Bonjour endpoint to an IP address and port, then return a ConnectionInfo.
  /// Opens a throwaway TCP connection to the endpoint and reads the remote
  /// host/port off the established path. Returns nil on failure.
  /// NOTE(review): there is no timeout — if the connection never reaches a
  /// terminal state the continuation waits indefinitely; consider adding one.
  func resolve(_ cluster: DiscoveredCluster) async -> ConnectionInfo? {
    await withCheckedContinuation { continuation in
      // Guards against resuming the checked continuation more than once:
      // the first terminal callback wins; non-terminal states re-arm it.
      let didResume = OSAllocatedUnfairLock(initialState: false)
      let connection = NWConnection(to: cluster.endpoint, using: .tcp)
      connection.stateUpdateHandler = { state in
        guard didResume.withLock({ guard !$0 else { return false }; $0 = true; return true })
        else { return }
        switch state {
        case .ready:
          if let innerEndpoint = connection.currentPath?.remoteEndpoint,
            case .hostPort(let host, let port) = innerEndpoint
          {
            var hostString: String
            switch host {
            case .ipv4(let addr):
              hostString = "\(addr)"
            case .ipv6(let addr):
              hostString = "\(addr)"
            case .name(let name, _):
              hostString = name
            @unknown default:
              hostString = "\(host)"
            }
            // Strip interface scope suffix (e.g. "%en0")
            if let pct = hostString.firstIndex(of: "%") {
              hostString = String(hostString[..<pct])
            }
            let info = ConnectionInfo(
              host: hostString,
              port: Int(port.rawValue),
              nodeId: nil
            )
            connection.cancel()
            continuation.resume(returning: info)
          } else {
            connection.cancel()
            continuation.resume(returning: nil)
          }
        case .failed, .cancelled:
          // Fix: previously the connection was not cancelled on .failed,
          // leaking the NWConnection until deallocation. Cancelling an
          // already-cancelled connection is harmless.
          connection.cancel()
          continuation.resume(returning: nil)
        default:
          // Not a terminal state — re-arm the guard to allow future callbacks.
          didResume.withLock { $0 = false }
        }
      }
      connection.start(queue: .global(qos: .userInitiated))
    }
  }
}

View File

@@ -0,0 +1,201 @@
import Foundation
import MLXLLM
import MLXLMCommon
/// Lifecycle of the on-device MLX model, from download through generation.
enum LocalModelState: Equatable {
  case notDownloaded
  /// Download in progress; `progress` is the fraction completed (0...1).
  case downloading(progress: Double)
  /// Weights on disk but not loaded into memory.
  case downloaded
  case loading
  case ready
  case generating
  case error(String)
}
/// On-device fallback inference using an MLX model. Handles download/load/
/// unload of the model container and exposes the same streaming interface
/// as ClusterService so ChatService can swap backends transparently.
@Observable
@MainActor
final class LocalInferenceService {
  private(set) var modelState: LocalModelState = .notDownloaded
  /// Loaded MLX model container; nil until prepareModel() succeeds.
  private var modelContainer: ModelContainer?
  private var generationTask: Task<Void, Never>?
  /// Hugging Face ID of the bundled small on-device model.
  let defaultModelId = "mlx-community/Qwen3-0.6B-4bit"
  /// UserDefaults flag remembering a completed download across launches.
  private static let modelDownloadedKey = "exo_local_model_downloaded"
  var isReady: Bool {
    modelState == .ready
  }
  /// Usable as a backend: loaded, whether idle or mid-generation.
  var isAvailable: Bool {
    modelState == .ready || modelState == .generating
  }
  init() {
    // Restore the downloaded-but-not-loaded state from a prior launch.
    if UserDefaults.standard.bool(forKey: Self.modelDownloadedKey) {
      modelState = .downloaded
    }
  }
  // MARK: - Model Lifecycle
  /// Downloads the model if needed, then loads it into memory. Progress is
  /// reported through `modelState`. No-op unless currently .notDownloaded
  /// or .downloaded.
  func prepareModel() async {
    guard modelState == .notDownloaded || modelState == .downloaded else { return }
    let wasDownloaded = modelState == .downloaded
    if !wasDownloaded {
      modelState = .downloading(progress: 0)
    } else {
      modelState = .loading
    }
    do {
      // loadModelContainer presumably comes from MLXLLM and downloads on
      // first use — its progress callback may fire off the main actor,
      // hence the hop below.
      let container = try await loadModelContainer(
        id: defaultModelId
      ) { [weak self] progress in
        guard let self else { return }
        Task { @MainActor in
          // Only overwrite state while still downloading; ignore late
          // callbacks after a transition.
          if case .downloading = self.modelState {
            self.modelState = .downloading(progress: progress.fractionCompleted)
          }
        }
      }
      self.modelContainer = container
      UserDefaults.standard.set(true, forKey: Self.modelDownloadedKey)
      modelState = .ready
    } catch {
      modelState = .error(error.localizedDescription)
    }
  }
  /// Releases the in-memory model (weights remain on disk).
  func unloadModel() {
    cancelGeneration()
    modelContainer = nil
    modelState = .downloaded
  }
  // MARK: - Generation
  /// Streams a chat completion from the local model, shaped as OpenAI-style
  /// chunks so the consumer code is identical to the cluster path.
  func streamChatCompletion(request: ChatCompletionRequest) -> AsyncThrowingStream<
    ChatCompletionChunk, Error
  > {
    AsyncThrowingStream { continuation in
      let task = Task { [weak self] in
        guard let self else {
          continuation.finish(throwing: LocalInferenceError.serviceUnavailable)
          return
        }
        guard let container = self.modelContainer else {
          continuation.finish(throwing: LocalInferenceError.modelNotLoaded)
          return
        }
        await MainActor.run {
          self.modelState = .generating
        }
        defer {
          // Restore .ready only if still generating (avoid clobbering an
          // unload that happened mid-stream).
          Task { @MainActor [weak self] in
            if self?.modelState == .generating {
              self?.modelState = .ready
            }
          }
        }
        let chunkId = "local-\(UUID().uuidString)"
        do {
          // Build Chat.Message array from the request
          var chatMessages: [Chat.Message] = []
          for msg in request.messages {
            switch msg.role {
            case "system":
              chatMessages.append(.system(msg.content))
            case "assistant":
              chatMessages.append(.assistant(msg.content))
            default:
              chatMessages.append(.user(msg.content))
            }
          }
          // Use ChatSession for streaming generation
          let session = ChatSession(
            container,
            history: chatMessages,
            generateParameters: GenerateParameters(
              maxTokens: request.maxTokens ?? 4096,
              temperature: Float(request.temperature ?? 0.7)
            )
          )
          // Stream with an empty prompt since history already contains the conversation
          let stream = session.streamResponse(to: "")
          for try await text in stream {
            if Task.isCancelled { break }
            let chunk = ChatCompletionChunk(
              id: chunkId,
              model: request.model,
              choices: [
                StreamingChoice(
                  index: 0,
                  delta: Delta(role: nil, content: text),
                  finishReason: nil
                )
              ],
              usage: nil
            )
            continuation.yield(chunk)
          }
          // Send final chunk with finish reason
          let finalChunk = ChatCompletionChunk(
            id: chunkId,
            model: request.model,
            choices: [
              StreamingChoice(
                index: 0,
                delta: Delta(role: nil, content: nil),
                finishReason: "stop"
              )
            ],
            usage: nil
          )
          continuation.yield(finalChunk)
          continuation.finish()
        } catch {
          continuation.finish(throwing: error)
        }
      }
      self.generationTask = task
      continuation.onTermination = { _ in
        task.cancel()
      }
    }
  }
  /// Cancels an in-flight generation and restores the .ready state.
  func cancelGeneration() {
    generationTask?.cancel()
    generationTask = nil
    if modelState == .generating {
      modelState = .ready
    }
  }
}
/// Errors surfaced by LocalInferenceService's streaming path.
enum LocalInferenceError: LocalizedError {
  /// The service was deallocated before/while generating.
  case serviceUnavailable
  /// streamChatCompletion was called before the model container was loaded.
  case modelNotLoaded
  var errorDescription: String? {
    switch self {
    case .serviceUnavailable: "Local inference service is unavailable"
    case .modelNotLoaded: "Local model is not loaded"
    }
  }
}

View File

@@ -0,0 +1,50 @@
import Foundation
/// Decodes a Server-Sent-Events byte stream into a typed AsyncSequence.
/// Each `data: <json>` line is decoded as one element; `data: [DONE]`
/// terminates the sequence (OpenAI streaming convention).
struct SSEStreamParser<T: Decodable>: AsyncSequence {
  typealias Element = T
  let bytes: URLSession.AsyncBytes
  let decoder: JSONDecoder
  init(bytes: URLSession.AsyncBytes, decoder: JSONDecoder = JSONDecoder()) {
    self.bytes = bytes
    self.decoder = decoder
  }
  func makeAsyncIterator() -> AsyncIterator {
    AsyncIterator(lines: bytes.lines, decoder: decoder)
  }
  struct AsyncIterator: AsyncIteratorProtocol {
    var lines: AsyncLineSequence<URLSession.AsyncBytes>.AsyncIterator
    let decoder: JSONDecoder
    init(lines: AsyncLineSequence<URLSession.AsyncBytes>, decoder: JSONDecoder) {
      self.lines = lines.makeAsyncIterator()
      self.decoder = decoder
    }
    /// Returns the next decoded payload, skipping non-data lines and
    /// payloads that fail to decode; nil at [DONE] or end of stream.
    /// NOTE(review): requires the exact prefix "data: " (with space); the
    /// SSE spec also allows "data:" without one — confirm servers used.
    mutating func next() async throws -> T? {
      while let line = try await lines.next() {
        let trimmed = line.trimmingCharacters(in: .whitespaces)
        guard trimmed.hasPrefix("data: ") else { continue }
        let payload = String(trimmed.dropFirst(6))
        if payload == "[DONE]" {
          return nil
        }
        guard let data = payload.data(using: .utf8) else { continue }
        do {
          return try decoder.decode(T.self, from: data)
        } catch {
          // Undecodable chunk (e.g. keep-alive or unrelated event): skip.
          continue
        }
      }
      return nil
    }
  }
}

View File

@@ -0,0 +1,171 @@
import SwiftUI
/// Main chat screen: model bar, scrolling message list, and input bar.
struct ChatView: View {
  @Environment(ClusterService.self) private var clusterService
  @Environment(ChatService.self) private var chatService
  @Environment(LocalInferenceService.self) private var localInferenceService
  @State private var inputText = ""
  @State private var showModelSelector = false
  var body: some View {
    VStack(spacing: 0) {
      modelBar
      Divider()
      messageList
      Divider()
      inputBar
    }
    .sheet(isPresented: $showModelSelector) {
      ModelSelectorView(
        models: clusterService.availableModels,
        selectedModelId: chatService.activeConversation?.modelId
      ) { modelId in
        chatService.setModelForActiveConversation(modelId)
      }
    }
  }
  // MARK: - Model Bar
  /// True when falling back to the on-device model (no cluster connection).
  private var useLocalModel: Bool {
    !clusterService.isConnected && localInferenceService.isAvailable
  }
  /// Tappable bar showing the active model; opens the selector only for
  /// cluster models (the local model is fixed, so the bar is disabled).
  private var modelBar: some View {
    Button {
      if !useLocalModel {
        showModelSelector = true
      }
    } label: {
      HStack {
        Image(systemName: useLocalModel ? "iphone" : "cpu")
          .foregroundStyle(useLocalModel ? .blue : .secondary)
        if useLocalModel {
          Text(localInferenceService.defaultModelId)
            .font(.subheadline)
            .lineLimit(1)
        } else if let modelId = chatService.activeConversation?.modelId {
          Text(modelId)
            .font(.subheadline)
            .lineLimit(1)
        } else {
          Text("Select Model")
            .font(.subheadline)
            .foregroundStyle(.secondary)
        }
        Spacer()
        if useLocalModel {
          // "On-Device" badge instead of a disclosure chevron.
          Text("On-Device")
            .font(.caption2)
            .foregroundStyle(.blue)
            .padding(.horizontal, 6)
            .padding(.vertical, 2)
            .background(.blue.opacity(0.1))
            .clipShape(Capsule())
        } else {
          Image(systemName: "chevron.right")
            .font(.caption)
            .foregroundStyle(.tertiary)
        }
      }
      .padding(.horizontal)
      .padding(.vertical, 10)
      .background(Color(.secondarySystemBackground))
    }
    .tint(.primary)
    .disabled(useLocalModel)
  }
  // MARK: - Messages
  /// Message list that auto-scrolls to the newest content as tokens stream in.
  private var messageList: some View {
    ScrollViewReader { proxy in
      ScrollView {
        LazyVStack(spacing: 12) {
          if chatService.activeMessages.isEmpty {
            emptyState
          } else {
            ForEach(chatService.activeMessages) { message in
              MessageBubbleView(message: message)
                .id(message.id)
            }
          }
        }
        .padding()
      }
      // Scroll whenever the last message's content changes (streaming).
      .onChange(of: chatService.activeMessages.last?.content) {
        if let lastId = chatService.activeMessages.last?.id {
          withAnimation(.easeOut(duration: 0.2)) {
            proxy.scrollTo(lastId, anchor: .bottom)
          }
        }
      }
    }
  }
  /// Placeholder shown when the active conversation has no messages.
  private var emptyState: some View {
    VStack(spacing: 12) {
      Spacer(minLength: 80)
      Image(systemName: "bubble.left.and.bubble.right")
        .font(.system(size: 48))
        .foregroundStyle(.tertiary)
      Text("Start a conversation")
        .font(.headline)
        .foregroundStyle(.secondary)
      Text("Send a message to begin chatting with the model.")
        .font(.subheadline)
        .foregroundStyle(.tertiary)
        .multilineTextAlignment(.center)
      Spacer(minLength: 80)
    }
    .padding()
  }
  // MARK: - Input
  /// Text field plus a send button that morphs into a stop button while
  /// a response is being generated.
  private var inputBar: some View {
    HStack(alignment: .bottom, spacing: 8) {
      TextField("Message...", text: $inputText, axis: .vertical)
        .lineLimit(1...6)
        .textFieldStyle(.plain)
        .padding(10)
        .background(Color(.tertiarySystemBackground))
        .clipShape(RoundedRectangle(cornerRadius: 20))
      if chatService.isGenerating {
        Button {
          chatService.cancelGeneration()
        } label: {
          Image(systemName: "stop.circle.fill")
            .font(.title2)
            .foregroundStyle(.red)
        }
      } else {
        Button {
          // Clear the field before sending so the UI feels immediate.
          let text = inputText
          inputText = ""
          chatService.sendMessage(text)
        } label: {
          Image(systemName: "arrow.up.circle.fill")
            .font(.title2)
            .foregroundStyle(canSend ? Color.accentColor : Color.gray)
        }
        .disabled(!canSend)
      }
    }
    .padding(.horizontal)
    .padding(.vertical, 8)
  }
  /// Sending requires non-blank input and at least one usable backend.
  private var canSend: Bool {
    !inputText.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty
      && (clusterService.isConnected || localInferenceService.isAvailable)
  }
}

View File

@@ -0,0 +1,27 @@
import SwiftUI
/// A single chat bubble: user messages align right on accent color,
/// assistant messages align left on the secondary background.
struct MessageBubbleView: View {
  let message: ChatMessage
  var body: some View {
    HStack {
      if message.role == .user { Spacer(minLength: 48) }
      VStack(alignment: message.role == .user ? .trailing : .leading, spacing: 4) {
        // U+258C (left half block) acts as a typing cursor while streaming.
        Text(message.content + (message.isStreaming ? " \u{258C}" : ""))
          .textSelection(.enabled)
          .padding(.horizontal, 14)
          .padding(.vertical, 10)
          .background(bubbleBackground)
          .foregroundStyle(message.role == .user ? .white : .primary)
          .clipShape(RoundedRectangle(cornerRadius: 16))
      }
      if message.role == .assistant { Spacer(minLength: 48) }
    }
  }
  private var bubbleBackground: Color {
    message.role == .user ? .accentColor : Color(.secondarySystemBackground)
  }
}

View File

@@ -0,0 +1,66 @@
import SwiftUI
/// Sheet listing the cluster's models; invokes `onSelect` with the chosen
/// model ID and dismisses itself.
struct ModelSelectorView: View {
  let models: [ModelOption]
  /// Currently selected model, marked with a checkmark.
  let selectedModelId: String?
  let onSelect: (String) -> Void
  @Environment(\.dismiss) private var dismiss
  var body: some View {
    NavigationStack {
      List {
        if models.isEmpty {
          emptyContent
        } else {
          modelsList
        }
      }
      .navigationTitle("Select Model")
      .navigationBarTitleDisplayMode(.inline)
      .toolbar {
        ToolbarItem(placement: .cancellationAction) {
          Button("Cancel") { dismiss() }
        }
      }
    }
  }
  /// Shown when the cluster advertises no models.
  private var emptyContent: some View {
    ContentUnavailableView(
      "No Models Available",
      systemImage: "cpu",
      description: Text("Connect to an EXO cluster to see available models.")
    )
  }
  private var modelsList: some View {
    ForEach(models) { model in
      Button {
        onSelect(model.id)
        dismiss()
      } label: {
        modelRow(model)
      }
      .tint(.primary)
    }
  }
  /// One row: display name over model ID, with a checkmark if selected.
  private func modelRow(_ model: ModelOption) -> some View {
    HStack {
      VStack(alignment: .leading, spacing: 2) {
        Text(model.displayName)
          .fontWeight(.medium)
        Text(model.id)
          .font(.caption)
          .foregroundStyle(.secondary)
      }
      Spacer()
      if model.id == selectedModelId {
        Image(systemName: "checkmark")
          .foregroundStyle(Color.accentColor)
      }
    }
  }
}

View File

@@ -0,0 +1,62 @@
import SwiftUI
/// Capsule badge summarizing connectivity: green "Connected", orange
/// "Connecting...", gray "Disconnected" — or blue "Local" when disconnected
/// but the on-device model can serve requests.
struct ConnectionStatusBadge: View {
  let connectionState: ConnectionState
  /// Local-model state; defaults to .notDownloaded so cluster-only callers
  /// can omit it.
  var localModelState: LocalModelState = .notDownloaded
  /// True when the badge should advertise local inference instead of the
  /// (absent) cluster connection.
  private var isLocalReady: Bool {
    if case .disconnected = connectionState {
      return localModelState == .ready || localModelState == .generating
    }
    return false
  }
  var body: some View {
    HStack(spacing: 6) {
      Circle()
        .fill(dotColor)
        .frame(width: 8, height: 8)
      Text(label)
        .font(.caption)
        .fontWeight(.medium)
    }
    .padding(.horizontal, 10)
    .padding(.vertical, 5)
    .background(backgroundColor)
    .clipShape(Capsule())
  }
  private var dotColor: Color {
    if isLocalReady {
      return .blue
    }
    switch connectionState {
    case .connected: return .green
    case .connecting: return .orange
    case .disconnected: return .gray
    }
  }
  private var label: String {
    if isLocalReady {
      return "Local"
    }
    switch connectionState {
    case .connected: return "Connected"
    case .connecting: return "Connecting..."
    case .disconnected: return "Disconnected"
    }
  }
  /// Translucent background tinted to match the dot color.
  private var backgroundColor: Color {
    if isLocalReady {
      return .blue.opacity(0.15)
    }
    switch connectionState {
    case .connected: return .green.opacity(0.15)
    case .connecting: return .orange.opacity(0.15)
    case .disconnected: return .gray.opacity(0.15)
    }
  }
}

View File

@@ -0,0 +1,117 @@
import SwiftUI
/// Top-level screen: hosts ChatView inside a NavigationStack with toolbar
/// access to conversations, the status badge, and settings.
struct RootView: View {
  @Environment(ClusterService.self) private var clusterService
  @Environment(DiscoveryService.self) private var discoveryService
  @Environment(ChatService.self) private var chatService
  @Environment(LocalInferenceService.self) private var localInferenceService
  @State private var showSettings = false
  @State private var showConversations = false
  var body: some View {
    NavigationStack {
      ChatView()
        .navigationTitle("EXO")
        .navigationBarTitleDisplayMode(.inline)
        .toolbar {
          ToolbarItem(placement: .topBarLeading) {
            conversationMenuButton
          }
          ToolbarItem(placement: .principal) {
            ConnectionStatusBadge(
              connectionState: clusterService.connectionState,
              localModelState: localInferenceService.modelState
            )
          }
          ToolbarItem(placement: .topBarTrailing) {
            Button {
              showSettings = true
            } label: {
              Image(systemName: "gear")
            }
          }
        }
    }
    .sheet(isPresented: $showSettings) {
      SettingsView()
        .environment(discoveryService)
    }
    .sheet(isPresented: $showConversations) {
      conversationList
    }
  }
  // MARK: - Conversations
  /// Leading toolbar pair: open the conversation list / start a new chat.
  private var conversationMenuButton: some View {
    HStack(spacing: 12) {
      Button {
        showConversations = true
      } label: {
        Image(systemName: "sidebar.left")
      }
      Button {
        chatService.createConversation()
      } label: {
        Image(systemName: "square.and.pencil")
      }
    }
  }
  /// Sheet listing all conversations with select, create, and swipe-to-delete.
  private var conversationList: some View {
    NavigationStack {
      List {
        if chatService.conversations.isEmpty {
          Text("No conversations yet")
            .foregroundStyle(.secondary)
        } else {
          ForEach(chatService.conversations) { conversation in
            Button {
              chatService.setActiveConversation(id: conversation.id)
              showConversations = false
            } label: {
              VStack(alignment: .leading, spacing: 4) {
                // Active conversation is emphasized with semibold weight.
                Text(conversation.title)
                  .fontWeight(
                    conversation.id == chatService.activeConversationId
                      ? .semibold : .regular
                  )
                  .lineLimit(1)
                if let modelId = conversation.modelId {
                  Text(modelId)
                    .font(.caption)
                    .foregroundStyle(.secondary)
                    .lineLimit(1)
                }
              }
            }
            .tint(.primary)
          }
          .onDelete { indexSet in
            for index in indexSet {
              chatService.deleteConversation(id: chatService.conversations[index].id)
            }
          }
        }
      }
      .navigationTitle("Conversations")
      .navigationBarTitleDisplayMode(.inline)
      .toolbar {
        ToolbarItem(placement: .confirmationAction) {
          Button("Done") { showConversations = false }
        }
        ToolbarItem(placement: .topBarLeading) {
          Button {
            chatService.createConversation()
          } label: {
            Image(systemName: "plus")
          }
        }
      }
    }
  }
}

View File

@@ -0,0 +1,197 @@
import SwiftUI
/// Settings sheet: local-model management, Bonjour cluster discovery,
/// manual host/port connection, and current connection status.
struct SettingsView: View {
  @Environment(ClusterService.self) private var clusterService
  @Environment(DiscoveryService.self) private var discoveryService
  @Environment(LocalInferenceService.self) private var localInferenceService
  @Environment(\.dismiss) private var dismiss
  @State private var host: String = ""
  @State private var port: String = "52415"
  var body: some View {
    NavigationStack {
      Form {
        localModelSection
        nearbyClustersSection
        connectionSection
        statusSection
      }
      .navigationTitle("Settings")
      .navigationBarTitleDisplayMode(.inline)
      .toolbar {
        ToolbarItem(placement: .confirmationAction) {
          Button("Done") { dismiss() }
        }
      }
    }
  }
  /// On-device model: status text, state-dependent action button, and a
  /// progress bar while downloading.
  private var localModelSection: some View {
    Section {
      HStack {
        VStack(alignment: .leading, spacing: 4) {
          Text(localInferenceService.defaultModelId)
            .font(.subheadline)
            .fontWeight(.medium)
          Text(localModelStatusText)
            .font(.caption)
            .foregroundStyle(.secondary)
        }
        Spacer()
        localModelActionButton
      }
      if case .downloading(let progress) = localInferenceService.modelState {
        ProgressView(value: progress)
          .tint(.blue)
      }
    } header: {
      Text("Local Model")
    } footer: {
      Text("When disconnected from a cluster, messages are processed on-device using this model.")
    }
  }
  /// Human-readable caption for every LocalModelState case.
  private var localModelStatusText: String {
    switch localInferenceService.modelState {
    case .notDownloaded: "Not downloaded"
    case .downloading(let progress): "Downloading \(Int(progress * 100))%..."
    case .downloaded: "Downloaded — not loaded"
    case .loading: "Loading into memory..."
    case .ready: "Ready"
    case .generating: "Generating..."
    case .error(let message): "Error: \(message)"
    }
  }
  /// Action button matching the model state: Download / Load / Unload /
  /// Retry, with spinners during transitions.
  @ViewBuilder
  private var localModelActionButton: some View {
    switch localInferenceService.modelState {
    case .notDownloaded:
      Button("Download") {
        Task { await localInferenceService.prepareModel() }
      }
      .buttonStyle(.borderedProminent)
      .controlSize(.small)
    case .downloading:
      ProgressView()
        .controlSize(.small)
    case .downloaded:
      Button("Load") {
        Task { await localInferenceService.prepareModel() }
      }
      .buttonStyle(.bordered)
      .controlSize(.small)
    case .loading:
      ProgressView()
        .controlSize(.small)
    case .ready, .generating:
      Button("Unload") {
        localInferenceService.unloadModel()
      }
      .buttonStyle(.bordered)
      .controlSize(.small)
    case .error:
      Button("Retry") {
        Task { await localInferenceService.prepareModel() }
      }
      .buttonStyle(.borderedProminent)
      .controlSize(.small)
      .tint(.red)
    }
  }
  /// Bonjour-discovered clusters with one-tap connect; dismisses the sheet
  /// on success.
  private var nearbyClustersSection: some View {
    Section {
      if discoveryService.discoveredClusters.isEmpty {
        if discoveryService.isSearching {
          HStack {
            ProgressView()
              .padding(.trailing, 8)
            Text("Searching for clusters...")
              .foregroundStyle(.secondary)
          }
        } else {
          Text("No clusters found")
            .foregroundStyle(.secondary)
        }
      } else {
        ForEach(discoveryService.discoveredClusters) { cluster in
          HStack {
            VStack(alignment: .leading) {
              Text(cluster.name)
                .font(.body)
            }
            Spacer()
            Button("Connect") {
              Task {
                await clusterService.connectToDiscoveredCluster(
                  cluster, using: discoveryService
                )
                if clusterService.isConnected {
                  dismiss()
                }
              }
            }
            .buttonStyle(.borderedProminent)
            .controlSize(.small)
          }
        }
      }
    } header: {
      Text("Nearby Clusters")
    }
  }
  /// Manual host/port entry; an unparsable port falls back to the default.
  private var connectionSection: some View {
    Section("Manual Connection") {
      TextField("IP Address (e.g. 192.168.1.42)", text: $host)
        .keyboardType(.decimalPad)
        .textContentType(.URL)
        .autocorrectionDisabled()
      TextField("Port", text: $port)
        .keyboardType(.numberPad)
      Button(clusterService.isConnected ? "Reconnect" : "Connect") {
        Task {
          let portNum = Int(port) ?? ConnectionInfo.defaultPort
          let info = ConnectionInfo(host: host, port: portNum, nodeId: nil)
          await clusterService.connect(to: info)
          if clusterService.isConnected {
            dismiss()
          }
        }
      }
      .disabled(host.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty)
    }
  }
  /// Current connection details (or last error) plus a disconnect button.
  private var statusSection: some View {
    Section("Status") {
      if let connection = clusterService.currentConnection {
        LabeledContent("Host", value: connection.host)
        LabeledContent("Port", value: "\(connection.port)")
        if let nodeId = connection.nodeId {
          LabeledContent("Node ID", value: String(nodeId.prefix(12)) + "...")
        }
        LabeledContent("Models", value: "\(clusterService.availableModels.count)")
        Button("Disconnect", role: .destructive) {
          clusterService.disconnect()
        }
      } else {
        if let error = clusterService.lastError {
          Label(error, systemImage: "exclamationmark.triangle")
            .foregroundStyle(.red)
        } else {
          Text("Not connected")
            .foregroundStyle(.secondary)
        }
      }
    }
  }
}

View File

@@ -0,0 +1,18 @@
//
// EXO_iOSTests.swift
// EXO-iOSTests
//
// Created by Sami Khan on 2026-02-17.
//
import Testing
@testable import EXO_iOS
// Unit-test suite for the EXO iOS app, using the Swift Testing framework.
struct EXO_iOSTests {
    // Placeholder test from the Xcode template; no assertions yet.
    @Test func example() async throws {
        // Write your test here and use APIs like `#expect(...)` to check expected conditions.
    }
}

View File

@@ -0,0 +1,41 @@
//
// EXO_iOSUITests.swift
// EXO-iOSUITests
//
// Created by Sami Khan on 2026-02-17.
//
import XCTest
// UI-test suite for the EXO iOS app (XCTest). Currently only the Xcode
// template placeholders: a launch smoke test and a launch-time measurement.
final class EXO_iOSUITests: XCTestCase {
    override func setUpWithError() throws {
        // Put setup code here. This method is called before the invocation of each test method in the class.
        // In UI tests it is usually best to stop immediately when a failure occurs.
        continueAfterFailure = false
        // In UI tests it's important to set the initial state - such as interface orientation - required for your tests before they run. The setUp method is a good place to do this.
    }
    override func tearDownWithError() throws {
        // Put teardown code here. This method is called after the invocation of each test method in the class.
    }
    @MainActor
    func testExample() throws {
        // UI tests must launch the application that they test.
        let app = XCUIApplication()
        app.launch()
        // Use XCTAssert and related functions to verify your tests produce the correct results.
    }
    @MainActor
    func testLaunchPerformance() throws {
        // This measures how long it takes to launch your application.
        measure(metrics: [XCTApplicationLaunchMetric()]) {
            XCUIApplication().launch()
        }
    }
}

View File

@@ -0,0 +1,33 @@
//
// EXO_iOSUITestsLaunchTests.swift
// EXO-iOSUITests
//
// Created by Sami Khan on 2026-02-17.
//
import XCTest
// Launch-screen test (XCTest template): launches the app once per UI
// configuration and attaches a screenshot of the launch screen.
final class EXO_iOSUITestsLaunchTests: XCTestCase {
    // Run this test once per target-application UI configuration
    // (e.g. light/dark mode, orientations).
    override class var runsForEachTargetApplicationUIConfiguration: Bool {
        true
    }
    override func setUpWithError() throws {
        continueAfterFailure = false
    }
    @MainActor
    func testLaunch() throws {
        let app = XCUIApplication()
        app.launch()
        // Insert steps here to perform after app launch but before taking a screenshot,
        // such as logging into a test account or navigating somewhere in the app
        let attachment = XCTAttachment(screenshot: app.screenshot())
        attachment.name = "Launch Screen"
        // Keep the screenshot even when the test passes.
        attachment.lifetime = .keepAlways
        add(attachment)
    }
}

7
bench/bench.toml Normal file
View File

@@ -0,0 +1,7 @@
# Canary benchmark manifest
#
# Lists the suite files to include. Each file defines benchmarks
# with shared constraints, topology, and default args.
include = [
"single-m3-ultra.toml",
]

189
bench/single-m3-ultra.toml Normal file
View File

@@ -0,0 +1,189 @@
# Single-node M3 Ultra benchmarks
#
# Shared constraints applied to ALL benchmarks in this file.
constraints = [
"All(MacOsBuild(=25D125))",
"Hosts(=1)",
"All(Chip(m3_ultra))",
"All(GpuCores(=80))",
]
[topology]
type = "none"
# Default args merged into each benchmark's args (benchmark-level args win).
[defaults]
pp = [512, 2048, 8192, 16384]
tg = 128
[[benchmark]]
model = "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/gpt-oss-120b-MXFP4-Q8"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-Flash-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-6bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-30B-A3B-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-0.6B-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-0.6B-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Llama-3.2-1B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Llama-3.2-3B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Llama-3.2-3B-Instruct-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/gpt-oss-20b-MXFP4-Q8"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-30B-A3B-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-Flash-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-Flash-5bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-Flash-6bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Llama-3.3-70B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-5bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]
[[benchmark]]
model = "mlx-community/Llama-3.3-70B-Instruct-8bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/llama-3.3-70b-instruct-fp16"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.5-Air-8bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.5-Air-bf16"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-4bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/MiniMax-M2.1-3bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/MiniMax-M2.1-8bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-bf16"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/Step-3.5-Flash-4bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/Step-3.5-Flash-6bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/Step-3.5-Flash-8Bit"
extra_constraints = ["All(Memory(>=256GiB))"]
[[benchmark]]
model = "mlx-community/DeepSeek-V3.1-4bit"
extra_constraints = ["All(Memory(>=512GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-6bit"
extra_constraints = ["All(Memory(>=512GiB))"]
[[benchmark]]
model = "mlx-community/GLM-4.7-8bit-gs32"
extra_constraints = ["All(Memory(>=512GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"
extra_constraints = ["All(Memory(>=512GiB))"]
[[benchmark]]
model = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"
extra_constraints = ["All(Memory(>=512GiB))"]

View File

File diff suppressed because it is too large Load Diff

View File

@@ -132,7 +132,7 @@ markers = [
env = [
"EXO_TESTS=1"
]
addopts = "-m 'not slow'"
addopts = "-m 'not slow' --ignore=tests/start_distributed_test.py"
filterwarnings = [
"ignore:builtin type Swig:DeprecationWarning",
]

View File

@@ -14,6 +14,7 @@ from exo.download.download_utils import (
map_repo_download_progress_to_download_progress_data,
)
from exo.download.shard_downloader import ShardDownloader
from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.models.model_cards import ModelId
from exo.shared.types.commands import (
CancelDownload,
@@ -63,6 +64,9 @@ class DownloadCoordinator:
self.event_sender, self.event_receiver = channel[Event]()
self.shard_downloader.on_progress(self._download_progress_callback)
def _model_dir(self, model_id: ModelId) -> str:
    """Return the on-disk directory for *model_id* under EXO_MODELS_DIR, as a string."""
    return str(EXO_MODELS_DIR / model_id.normalize())
async def _download_progress_callback(
self, callback_shard: ShardMetadata, progress: RepoDownloadProgress
) -> None:
@@ -74,6 +78,7 @@ class DownloadCoordinator:
shard_metadata=callback_shard,
node_id=self.node_id,
total_bytes=progress.total_bytes,
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = completed
await self.event_sender.send(
@@ -93,6 +98,7 @@ class DownloadCoordinator:
download_progress=map_repo_download_progress_to_download_progress_data(
progress
),
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = ongoing
await self.event_sender.send(
@@ -170,7 +176,11 @@ class DownloadCoordinator:
return
# Emit pending status
progress = DownloadPending(shard_metadata=shard, node_id=self.node_id)
progress = DownloadPending(
shard_metadata=shard,
node_id=self.node_id,
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = progress
await self.event_sender.send(NodeDownloadProgress(download_progress=progress))
@@ -184,6 +194,7 @@ class DownloadCoordinator:
shard_metadata=shard,
node_id=self.node_id,
total_bytes=initial_progress.total_bytes,
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = completed
await self.event_sender.send(
@@ -206,6 +217,7 @@ class DownloadCoordinator:
download_progress=map_repo_download_progress_to_download_progress_data(
initial_progress
),
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = status
self.event_sender.send_nowait(NodeDownloadProgress(download_progress=status))
@@ -219,6 +231,7 @@ class DownloadCoordinator:
shard_metadata=shard,
node_id=self.node_id,
error_message=str(e),
model_directory=self._model_dir(model_id),
)
self.download_status[model_id] = failed
await self.event_sender.send(
@@ -253,6 +266,7 @@ class DownloadCoordinator:
pending = DownloadPending(
shard_metadata=current_status.shard_metadata,
node_id=self.node_id,
model_directory=self._model_dir(model_id),
)
await self.event_sender.send(
NodeDownloadProgress(download_progress=pending)
@@ -295,11 +309,18 @@ class DownloadCoordinator:
node_id=self.node_id,
shard_metadata=progress.shard,
total_bytes=progress.total_bytes,
model_directory=self._model_dir(
progress.shard.model_card.model_id
),
)
elif progress.status in ["in_progress", "not_started"]:
if progress.downloaded_bytes_this_session.in_bytes == 0:
status = DownloadPending(
node_id=self.node_id, shard_metadata=progress.shard
node_id=self.node_id,
shard_metadata=progress.shard,
model_directory=self._model_dir(
progress.shard.model_card.model_id
),
)
else:
status = DownloadOngoing(
@@ -308,6 +329,9 @@ class DownloadCoordinator:
download_progress=map_repo_download_progress_to_download_progress_data(
progress
),
model_directory=self._model_dir(
progress.shard.model_card.model_id
),
)
else:
continue

View File

@@ -144,8 +144,8 @@ async def collect_responses_response(
for tool in chunk.tool_calls:
function_call_items.append(
ResponseFunctionCallItem(
id=f"fc_{tool.id}",
call_id=f"call_{tool.id}",
id=tool.id,
call_id=tool.id,
name=tool.name,
arguments=tool.arguments,
)

View File

@@ -1367,6 +1367,7 @@ class API:
async def run(self):
shutdown_ev = anyio.Event()
bonjour_cleanup = self._register_bonjour_service()
try:
async with create_task_group() as tg:
self._tg = tg
@@ -1382,10 +1383,38 @@ class API:
with anyio.CancelScope(shield=True):
shutdown_ev.set()
finally:
bonjour_cleanup()
self._event_log.close()
self.command_sender.close()
self.global_event_receiver.close()
def _register_bonjour_service(self) -> Callable[[], None]:
    """Register a Bonjour service via the system mDNSResponder.

    Spawns ``dns-sd -R`` to advertise ``_exo._tcp`` on this node's API port so
    clients on the local network can discover the cluster. Returns a cleanup
    function that tears the advertisement down. On non-macOS platforms, or if
    the subprocess cannot be started, registration is skipped and the returned
    cleanup is a no-op.
    """
    import subprocess
    import sys

    if sys.platform != "darwin":
        logger.info("Bonjour service registration is only supported on macOS")
        return lambda: None

    service_name = f"EXO Cluster ({self.node_id[:8]})"
    try:
        proc = subprocess.Popen(
            [
                "dns-sd",
                "-R",
                service_name,
                "_exo._tcp",
                "local",
                str(self.port),
                f"node_id={self.node_id}",
            ],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
        logger.info(f"Registered Bonjour service _exo._tcp on port {self.port} (pid {proc.pid})")

        def cleanup() -> None:
            # dns-sd keeps the registration alive only while it runs.
            # Bound the wait so shutdown can never hang on a stuck child;
            # escalate to SIGKILL if it ignores SIGTERM.
            proc.terminate()
            try:
                proc.wait(timeout=5)
            except subprocess.TimeoutExpired:
                proc.kill()
                proc.wait()

        return cleanup
    except Exception as e:
        # Best-effort feature: discovery failing should never break the API.
        logger.warning(f"Failed to register Bonjour service: {e}")
        return lambda: None
async def run_api(self, ev: anyio.Event):
cfg = Config()
cfg.bind = [f"0.0.0.0:{self.port}"]

View File

@@ -218,11 +218,6 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
key: value for key, value in state.downloads.items() if key != event.node_id
}
# Clean up all granular node mappings
node_identities = {
key: value
for key, value in state.node_identities.items()
if key != event.node_id
}
node_memory = {
key: value for key, value in state.node_memory.items() if key != event.node_id
}
@@ -263,7 +258,6 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
"downloads": downloads,
"topology": topology,
"last_seen": last_seen,
"node_identities": node_identities,
"node_memory": node_memory,
"node_disk": node_disk,
"node_system": node_system,

View File

@@ -26,6 +26,7 @@ class DownloadProgressData(CamelCaseModel):
class BaseDownloadProgress(TaggedModel):
node_id: NodeId
shard_metadata: ShardMetadata
model_directory: str = ""
class DownloadPending(BaseDownloadProgress):

View File

@@ -1,5 +1,7 @@
import sys
def print_startup_banner(port: int) -> None:
"""Print a prominent startup banner with API endpoint information."""
dashboard_url = f"http://localhost:{port}"
banner = f"""
╔═══════════════════════════════════════════════════════════════════════╗
@@ -27,4 +29,4 @@ def print_startup_banner(port: int) -> None:
"""
print(banner)
print(banner, file=sys.stderr)

View File

@@ -306,7 +306,7 @@ def mlx_generate(
max_stop_len = max((len(s) for s in stop_sequences), default=0)
mx_barrier(group)
logger.info("Ready to prefill")
logger.info("Starting prefill")
# Prefill cache with all tokens except the last one
prefill_tps, prefill_tokens, ssm_snapshots_list = prefill(

View File

@@ -353,7 +353,13 @@ def load_tokenizer_for_model_id(
return list(hf_tokenizer.model.encode(text, allowed_special="all")) # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType]
hf_tokenizer.encode = _patched_encode
return TokenizerWrapper(hf_tokenizer, eos_token_ids=eos_token_ids)
return TokenizerWrapper(
hf_tokenizer,
eos_token_ids=eos_token_ids,
tool_call_start="<|tool_calls_section_begin|>",
tool_call_end="<|tool_calls_section_end|>",
tool_parser=_parse_kimi_tool_calls,
)
tokenizer = load_tokenizer(
model_path,
@@ -585,3 +591,41 @@ def mx_barrier(group: Group | None):
mx.array(1.0), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
)
)
def _parse_kimi_tool_calls(text: str):
import regex as re
# kimi has a fixed function naming scheme, with a json formatted arg
# functions.multiply:0<|tool_call_argument_begin|>{"a": 2, "b": 3}
_func_name_regex = re.compile(
r"^\s*((?:functions\.)?(.+?):\d+)\s*<\|tool_call_argument_begin\|>", re.DOTALL
)
_func_arg_regex = re.compile(r"<\|tool_call_argument_begin\|>\s*(.*)\s*", re.DOTALL)
_tool_call_split_regex = re.compile(
r"<\|tool_call_begin\|>(.*?)<\|tool_call_end\|>", re.DOTALL
)
def _parse_single_tool(text: str) -> dict[str, Any]:
func_name_match = _func_name_regex.search(text)
if func_name_match is None:
raise ValueError("No tool call found.")
tool_call_id = func_name_match.group(1) # e.g. "functions.get_weather:0"
func_name = func_name_match.group(2) # e.g. "get_weather"
func_args_match = _func_arg_regex.search(text)
if func_args_match is None:
raise ValueError("No tool call arguments found.")
func_args = func_args_match.group(1)
try:
arg_dct = json.loads(func_args) # pyright: ignore[reportAny]
except Exception:
arg_dct = None
return dict(id=tool_call_id, name=func_name, arguments=arg_dct)
tool_matches = _tool_call_split_regex.findall(text)
if tool_matches:
return [_parse_single_tool(match) for match in tool_matches] # pyright: ignore[reportAny]
else:
return [_parse_single_tool(text)]

View File

@@ -1,11 +1,10 @@
import base64
import json
import math
import resource
import time
from collections.abc import Generator
from functools import cache
from typing import Any, Callable, Literal
from typing import Literal
import mlx.core as mx
from mlx_lm.models.gpt_oss import Model as GptOssModel
@@ -16,7 +15,6 @@ from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
StreamableParser,
load_harmony_encoding,
)
from pydantic import ValidationError
from exo.shared.constants import EXO_MAX_CHUNK_SIZE, EXO_TRACING_ENABLED
from exo.shared.models.model_cards import ModelId, ModelTask
@@ -93,6 +91,8 @@ from exo.worker.engines.mlx.utils_mlx import (
)
from exo.worker.runner.bootstrap import logger
from .tool_parsers import ToolParser, make_mlx_parser
def _is_primary_output_node(shard_metadata: ShardMetadata) -> bool:
"""Check if this node is the primary output node for image generation.
@@ -138,6 +138,7 @@ def main(
inference_model: Model | None = None
image_model: DistributedImageModel | None = None
tokenizer = None
tool_parser: ToolParser | None = None
group = None
kv_prefix_cache: KVPrefixCache | None = None
check_for_cancel_every: int | None = None
@@ -203,8 +204,17 @@ def main(
bound_instance, group, on_timeout=on_model_load_timeout
)
logger.info(
f"model has_tool_calling={tokenizer.has_tool_calling}"
f"model has_tool_calling={tokenizer.has_tool_calling} using tokens {tokenizer.tool_call_start}, {tokenizer.tool_call_end}"
)
if tokenizer.has_tool_calling:
assert tokenizer.tool_call_start
assert tokenizer.tool_call_end
assert tokenizer.tool_parser # pyright: ignore[reportAny]
tool_parser = make_mlx_parser(
tokenizer.tool_call_start,
tokenizer.tool_call_end,
tokenizer.tool_parser, # pyright: ignore[reportAny]
)
kv_prefix_cache = KVPrefixCache(group)
elif (
@@ -310,31 +320,11 @@ def main(
mlx_generator, tokenizer
)
# Kimi-K2 has tool call sections - we don't care about them
if "kimi" in shard_metadata.model_card.model_id.lower():
mlx_generator = filter_kimi_tokens(mlx_generator)
patch_kimi_tokenizer(tokenizer)
# GLM models need patched parser (upstream has bug with None regex match)
elif "glm" in shard_metadata.model_card.model_id.lower():
patch_glm_tokenizer(tokenizer)
# GPT-OSS specific parsing to match other model formats.
elif isinstance(inference_model, GptOssModel):
if isinstance(inference_model, GptOssModel):
mlx_generator = parse_gpt_oss(mlx_generator)
if tokenizer.has_tool_calling and not isinstance(
inference_model, GptOssModel
):
assert tokenizer.tool_call_start
assert tokenizer.tool_call_end
assert tokenizer.tool_parser # pyright: ignore[reportAny]
mlx_generator = parse_tool_calls(
mlx_generator,
tokenizer.tool_call_start,
tokenizer.tool_call_end,
tokenizer.tool_parser, # pyright: ignore[reportAny]
)
elif tool_parser:
mlx_generator = parse_tool_calls(mlx_generator, tool_parser)
completion_tokens = 0
tokens_since_last_cancel_check = 0
@@ -587,21 +577,8 @@ def get_gpt_oss_encoding():
return encoding
def filter_kimi_tokens(
responses: Generator[GenerationResponse | ToolCallResponse],
) -> Generator[GenerationResponse]:
for resp in responses:
assert isinstance(resp, GenerationResponse)
if (
resp.text == "<|tool_calls_section_begin|>"
or resp.text == "<|tool_calls_section_end|>"
):
continue
yield resp
def parse_gpt_oss(
responses: Generator[GenerationResponse | ToolCallResponse],
responses: Generator[GenerationResponse],
) -> Generator[GenerationResponse | ToolCallResponse]:
encoding = get_gpt_oss_encoding()
stream = StreamableParser(encoding, role=Role.ASSISTANT)
@@ -658,9 +635,9 @@ def parse_gpt_oss(
def parse_thinking_models(
responses: Generator[GenerationResponse | ToolCallResponse],
responses: Generator[GenerationResponse],
tokenizer: TokenizerWrapper,
) -> Generator[GenerationResponse | ToolCallResponse]:
) -> Generator[GenerationResponse]:
"""
For models that inject thinking tags in the prompt (like GLM-4.7),
prepend the thinking tag to the output stream so the frontend
@@ -781,221 +758,55 @@ def _process_image_response(
def parse_tool_calls(
responses: Generator[GenerationResponse | ToolCallResponse],
tool_call_start: str,
tool_call_end: str,
tool_parser: Callable[[str], dict[str, Any] | list[dict[str, Any]]],
responses: Generator[GenerationResponse], tool_parser: ToolParser
) -> Generator[GenerationResponse | ToolCallResponse]:
in_tool_call = False
tool_call_text_parts: list[str] = []
for response in responses:
assert isinstance(response, GenerationResponse)
# assumption: the tool call start is one token
if response.text == tool_call_start:
if response.text.startswith(tool_parser.start_parsing):
in_tool_call = True
continue
# assumption: the tool call end is one token
if in_tool_call and response.text == tool_call_end:
try:
# tool_parser returns an arbitrarily nested python dictionary
# we actually don't want the python dictionary, we just want to
# parse the top level { function: ..., arguments: ... } structure
# as we're just gonna hand it back to the api anyway
parsed = tool_parser("".join(tool_call_text_parts).strip())
logger.info(f"parsed {tool_call_text_parts=} into {parsed=}")
if isinstance(parsed, list):
tools = [_validate_single_tool(tool) for tool in parsed]
else:
tools = [_validate_single_tool(parsed)]
yield ToolCallResponse(
tool_calls=tools, usage=response.usage, stats=response.stats
)
except (
json.JSONDecodeError,
ValidationError,
ValueError,
AttributeError,
) as e:
# ValueError: our parsers raise this for malformed tool calls
# AttributeError: upstream parsers (e.g. glm47) may raise this when regex doesn't match
logger.opt(exception=e).warning("tool call parsing failed")
# assumption: talking about tool calls, not making a tool call
response.text = (
tool_call_start + "".join(tool_call_text_parts) + tool_call_end
)
yield response
in_tool_call = False
tool_call_text_parts = []
continue
if in_tool_call:
tool_call_text_parts.append(response.text)
if response.text.endswith(tool_parser.end_parsing):
# parse the actual tool calls from the tool call text
parsed = tool_parser.parse_tool_calls(
"".join(tool_call_text_parts).strip()
)
logger.info(f"parsed {tool_call_text_parts=} into {parsed=}")
if parsed is not None:
yield ToolCallResponse(
tool_calls=parsed, usage=response.usage, stats=response.stats
)
else:
logger.warning(
f"tool call parsing failed for text {''.join(tool_call_text_parts)}"
)
response.text = "".join(tool_call_text_parts)
yield response
in_tool_call = False
tool_call_text_parts = []
continue
if response.finish_reason is not None:
logger.info(
"toll call parsing interrupted, yield partial tool call as text"
"tool call parsing interrupted, yield partial tool call as text"
)
yield GenerationResponse(
text=tool_call_start + "".join(tool_call_text_parts),
token=0,
finish_reason=response.finish_reason,
usage=response.usage,
stats=response.stats,
response = response.model_copy(
update={
"text": "".join(tool_call_text_parts),
"token": 0,
}
)
yield response
continue
# fallthrough
yield response
def patch_kimi_tokenizer(tokenizer: TokenizerWrapper):
"""
Version of to-be-upstreamed kimi-k2 tool parser
"""
import ast
import json
from typing import Any
import regex as re
# kimi has a fixed function naming scheme, with a json formatted arg
# functions.multiply:0 <|tool_call_argument_begin|> {"a": 2, "b": 3}
# Also needs to handle tools like call_0<|tool_call_argument_begin|>{"filePath": "..."}
_func_name_regex = re.compile(
r"^\s*(.+)[:](\d+)\s*<\|tool_call_argument_begin\|>", re.DOTALL
)
_func_arg_regex = re.compile(r"<\|tool_call_argument_begin\|>\s*(.*)\s*", re.DOTALL)
# kimi has a tool_calls_section - we're leaving this up to the caller to handle
tool_call_start = "<|tool_call_begin|>"
tool_call_end = "<|tool_call_end|>"
def _deserialize(value: str) -> Any: # pyright: ignore[reportAny]
try:
return json.loads(value) # pyright: ignore[reportAny]
except Exception:
pass
try:
return ast.literal_eval(value) # pyright: ignore[reportAny]
except Exception:
pass
return value
def parse_tool_call(text: str, tools: Any | None = None):
func_name_match = _func_name_regex.search(text)
if func_name_match is None:
raise ValueError(f"Could not parse function name from tool call: {text!r}")
original_func_name = func_name_match.group(1)
tool_id = func_name_match.group(2)
# strip off the `functions.` prefix, if it exists.
func_name = original_func_name[original_func_name.find(".") + 1 :]
func_args_match = _func_arg_regex.search(text)
if func_args_match is None:
raise ValueError(f"Could not parse function args from tool call: {text!r}")
func_args = func_args_match.group(1)
# the args should be valid json - no need to check against our tools to deserialize
arg_dct = _deserialize(func_args) # pyright: ignore[reportAny]
return dict(
id=f"{original_func_name}:{tool_id}",
name=func_name,
arguments=arg_dct, # pyright: ignore[reportAny]
)
tokenizer._tool_call_start = tool_call_start
tokenizer._tool_call_end = tool_call_end
tokenizer._tool_parser = parse_tool_call
def patch_glm_tokenizer(tokenizer: TokenizerWrapper):
"""
Fixed version of mlx_lm's glm47 tool parser that handles regex match failures.
"""
import ast
import json
from typing import Any
import regex as re
_func_name_regex = re.compile(r"^(.*?)<arg_key>", re.DOTALL)
_func_arg_regex = re.compile(
r"<arg_key>(.*?)</arg_key>(?:\n|\s)*<arg_value>(.*?)(?:</arg_value>|(?=<arg_key>)|$)",
re.DOTALL,
)
tool_call_start = "<tool_call>"
tool_call_end = "</tool_call>"
def _is_string_type(
tool_name: str,
arg_name: str,
tools: list[Any] | None,
) -> bool:
if tools is None:
return False
for tool in tools: # pyright: ignore[reportAny]
func = tool["function"] # pyright: ignore[reportAny]
if func["name"] == tool_name:
params = func["parameters"] # pyright: ignore[reportAny]
if params is None:
return False
props = params.get("properties", {}) # pyright: ignore[reportAny]
arg_props = props.get(arg_name, {}) # pyright: ignore[reportAny]
arg_type = arg_props.get("type", None) # pyright: ignore[reportAny]
return arg_type == "string" # pyright: ignore[reportAny]
return False
def _deserialize(value: str) -> Any: # pyright: ignore[reportAny]
try:
return json.loads(value) # pyright: ignore[reportAny]
except Exception:
pass
try:
return ast.literal_eval(value) # pyright: ignore[reportAny]
except Exception:
pass
return value
def parse_tool_call(text: str, tools: list[Any] | None = None):
func_name_match = _func_name_regex.search(text)
if func_name_match is None:
raise ValueError(f"Could not parse function name from tool call: {text!r}")
func_name = func_name_match.group(1)
pairs = _func_arg_regex.findall(text)
arg_dct: dict[str, Any] = {}
for key, value in pairs: # pyright: ignore[reportAny]
arg_key = key.strip() # pyright: ignore[reportAny]
arg_val = value.strip() # pyright: ignore[reportAny]
if not _is_string_type(func_name, arg_key, tools): # pyright: ignore[reportAny]
arg_val = _deserialize(arg_val) # pyright: ignore[reportAny]
arg_dct[arg_key] = arg_val
return dict(name=func_name, arguments=arg_dct)
tokenizer._tool_call_start = tool_call_start
tokenizer._tool_call_end = tool_call_end
tokenizer._tool_parser = parse_tool_call
def _validate_single_tool(obj: dict[str, Any]) -> ToolCallItem:
if (
((name := obj.get("name")) is not None)
and ((args := obj.get("arguments")) is not None)
and isinstance(name, str)
):
raw_id: object = obj.get("id")
extra = {"id": str(raw_id)} if raw_id is not None else {}
return ToolCallItem(
**extra,
name=name,
arguments=json.dumps(args),
)
else:
raise ValidationError
EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"
EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM"
EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT"

View File

@@ -0,0 +1,72 @@
import json
from dataclasses import dataclass
from typing import Any, Callable
from exo.shared.types.api import ToolCallItem
@dataclass
class ToolParser:
    """Tool-call parsing contract used by the runner's streaming loop.

    The runner starts accumulating tool-call text when a generated token
    starts with ``start_parsing``, and hands the accumulated text to
    ``parse_tool_calls`` once a token ends with ``end_parsing``.
    """

    # Marker string that opens a tool-call region in the generated text.
    start_parsing: str
    # Marker string that closes the region and triggers parsing.
    end_parsing: str
    # Parses the accumulated region text; None signals a parse failure.
    parse_tool_calls: Callable[[str], list[ToolCallItem] | None]
def make_mlx_parser(
    tool_call_start: str,
    tool_call_end: str,
    tool_parser: Callable[[str], dict[str, Any] | list[dict[str, Any]]],
) -> ToolParser:
    """Adapt an mlx_lm-style tool parser callable into a ToolParser.

    The wrapped parser receives the accumulated tool-call text with the
    start/end marker strings stripped, and its dict output is flattened to
    string values before validation. Any failure yields None so the caller
    can fall back to emitting the raw text.
    """

    def parse_tool_calls(text: str) -> list[ToolCallItem] | None:
        stripped = text.removeprefix(tool_call_start).removesuffix(tool_call_end)
        try:
            raw = tool_parser(stripped)
            calls = raw if isinstance(raw, list) else [raw]
            return [ToolCallItem.model_validate(_flatten(call)) for call in calls]
        except Exception:
            # Deliberate best-effort: a malformed tool call is reported as
            # None and handled (logged) by the caller.
            return None

    return ToolParser(
        start_parsing=tool_call_start,
        end_parsing=tool_call_end,
        parse_tool_calls=parse_tool_calls,
    )
# TODO / example code:
def _parse_json_calls(text: str) -> list[ToolCallItem] | None:
    """Parse a single ``<tool_call>{...}</tool_call>`` JSON payload; None on failure.

    Nested containers are re-serialized to JSON strings; scalar values are
    passed through unchanged.
    """
    try:
        payload = text.removeprefix("<tool_call>").removesuffix("</tool_call>")
        fields: dict[str, Any] = {}
        for key, value in json.loads(payload).items():  # pyright: ignore[reportAny]
            fields[key] = json.dumps(value) if isinstance(value, (dict, list)) else value
        return [ToolCallItem.model_validate(fields)]
    except Exception:
        return None
def _flatten(p: dict[str, Any]) -> dict[str, str]:
return {
k: json.dumps(v) if isinstance(v, (dict, list)) else str(v) # pyright: ignore[reportAny]
for k, v in p.items() # pyright: ignore[reportAny]
}
# Generic parser for templates that emit one JSON object between
# <tool_call> ... </tool_call> markers.
json_tool_parser = ToolParser(
    start_parsing="<tool_call>",
    end_parsing="</tool_call>",
    parse_tool_calls=_parse_json_calls,
)
def infer_tool_parser(chat_template: str) -> ToolParser | None:
    """Attempt to auto-infer a tool parser from the chat template.

    Heuristic: a template that both renders a ``<tool_call>`` marker and
    references ``tool_call.name`` is assumed to emit the generic JSON
    format; anything else is unrecognized and yields None.
    """
    looks_like_json_format = all(
        marker in chat_template for marker in ("<tool_call>", "tool_call.name")
    )
    return json_tool_parser if looks_like_json_format else None

View File

@@ -5,12 +5,13 @@ from typing import Any
from exo.shared.types.worker.runner_response import GenerationResponse, ToolCallResponse
from exo.worker.runner.runner import parse_tool_calls
from exo.worker.runner.tool_parsers import make_mlx_parser
def _make_responses(
texts: list[str],
finish_on_last: bool = True,
) -> Generator[GenerationResponse | ToolCallResponse]:
) -> Generator[GenerationResponse]:
"""Create a sequence of GenerationResponses from text strings."""
for i, text in enumerate(texts):
is_last = i == len(texts) - 1
@@ -22,10 +23,13 @@ def _make_responses(
)
def _dummy_parser(text: str) -> dict[str, Any]:
def _dummier_parser(text: str) -> dict[str, Any]:
return {"name": "test_fn", "arguments": {"arg": text}}
_dummy_parser = make_mlx_parser("<tool_call>", "</tool_call>", _dummier_parser)
class TestParseToolCalls:
"""Tests for parse_tool_calls generator."""
@@ -35,8 +39,6 @@ class TestParseToolCalls:
results = list(
parse_tool_calls(
_make_responses(texts, finish_on_last=False),
"<tool_call>",
"</tool_call>",
_dummy_parser,
)
)
@@ -50,8 +52,6 @@ class TestParseToolCalls:
results = list(
parse_tool_calls(
_make_responses(texts),
"<tool_call>",
"</tool_call>",
_dummy_parser,
)
)
@@ -76,9 +76,7 @@ class TestParseToolCalls:
results = list(
parse_tool_calls(
_make_responses(texts, finish_on_last=False),
"<tool_call>",
"</tool_call>",
_failing_parser,
make_mlx_parser("<tool_call>", "</tool_call>", _failing_parser),
)
)