Mirror of https://github.com/exo-explore/exo.git, synced 2026-02-18 23:06:23 -05:00.

Compare commits — 2 commits by alexcheema on sami/iOS-a
| Author | SHA1 | Date |
|---|---|---|
| alexcheema | 3e9eb93f82 | |
| alexcheema | ab622f79c3 | |
README.md — 11 lines changed
@@ -72,23 +72,16 @@ There are two ways to run exo:

### Run from Source (macOS)

If you have [Nix](https://nixos.org/) installed, you can skip most of the steps below and run exo directly (after accepting the Cachix cache):

```bash
nix run .#exo
```

**Prerequisites:**

- [Xcode](https://developer.apple.com/xcode/) (provides the Metal Toolchain required for MLX compilation)
- [brew](https://github.com/Homebrew/brew) (for simple package management on macOS)

```bash
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
```

- [uv](https://github.com/astral-sh/uv) (for Python dependency management)
- [macmon](https://github.com/vladkens/macmon) (for hardware monitoring on Apple Silicon)
- [node](https://github.com/nodejs/node) (for building the dashboard)

```bash
brew install uv macmon node
```
app/EXO-iOS/EXO-iOS.xcodeproj/project.pbxproj — 628 lines (new file)
@@ -0,0 +1,628 @@
// !$*UTF8*$!
{
  archiveVersion = 1;
  classes = {
  };
  objectVersion = 77;
  objects = {

/* Begin PBXBuildFile section */
    E09D17522F44F359009C51A3 /* MLXLLM in Frameworks */ = {isa = PBXBuildFile; productRef = E09D17512F44F359009C51A3 /* MLXLLM */; };
    E09D17542F44F359009C51A3 /* MLXLMCommon in Frameworks */ = {isa = PBXBuildFile; productRef = E09D17532F44F359009C51A3 /* MLXLMCommon */; };
/* End PBXBuildFile section */

/* Begin PBXContainerItemProxy section */
    E09D167D2F44CA20009C51A3 /* PBXContainerItemProxy */ = {
      isa = PBXContainerItemProxy;
      containerPortal = E09D16672F44CA1E009C51A3 /* Project object */;
      proxyType = 1;
      remoteGlobalIDString = E09D166E2F44CA1E009C51A3;
      remoteInfo = "EXO-iOS";
    };
    E09D16872F44CA20009C51A3 /* PBXContainerItemProxy */ = {
      isa = PBXContainerItemProxy;
      containerPortal = E09D16672F44CA1E009C51A3 /* Project object */;
      proxyType = 1;
      remoteGlobalIDString = E09D166E2F44CA1E009C51A3;
      remoteInfo = "EXO-iOS";
    };
/* End PBXContainerItemProxy section */

/* Begin PBXFileReference section */
    E09D166F2F44CA1E009C51A3 /* EXO-iOS.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = "EXO-iOS.app"; sourceTree = BUILT_PRODUCTS_DIR; };
    E09D167C2F44CA20009C51A3 /* EXO-iOSTests.xctest */ = {isa = PBXFileReference; explicitFileType = wrapper.cfbundle; includeInIndex = 0; path = "EXO-iOSTests.xctest"; sourceTree = BUILT_PRODUCTS_DIR; };
    E09D16862F44CA20009C51A3 /* EXO-iOSUITests.xctest */ = {isa = PBXFileReference; explicitFileType = wrapper.cfbundle; includeInIndex = 0; path = "EXO-iOSUITests.xctest"; sourceTree = BUILT_PRODUCTS_DIR; };
/* End PBXFileReference section */

/* Begin PBXFileSystemSynchronizedBuildFileExceptionSet section */
    E09D169A2F44CA20009C51A3 /* Exceptions for "EXO-iOS" folder in "EXO-iOS" target */ = {
      isa = PBXFileSystemSynchronizedBuildFileExceptionSet;
      membershipExceptions = (
        Info.plist,
      );
      target = E09D166E2F44CA1E009C51A3 /* EXO-iOS */;
    };
/* End PBXFileSystemSynchronizedBuildFileExceptionSet section */

/* Begin PBXFileSystemSynchronizedRootGroup section */
    E09D16712F44CA1E009C51A3 /* EXO-iOS */ = {
      isa = PBXFileSystemSynchronizedRootGroup;
      exceptions = (
        E09D169A2F44CA20009C51A3 /* Exceptions for "EXO-iOS" folder in "EXO-iOS" target */,
      );
      path = "EXO-iOS";
      sourceTree = "<group>";
    };
    E09D167F2F44CA20009C51A3 /* EXO-iOSTests */ = {
      isa = PBXFileSystemSynchronizedRootGroup;
      path = "EXO-iOSTests";
      sourceTree = "<group>";
    };
    E09D16892F44CA20009C51A3 /* EXO-iOSUITests */ = {
      isa = PBXFileSystemSynchronizedRootGroup;
      path = "EXO-iOSUITests";
      sourceTree = "<group>";
    };
/* End PBXFileSystemSynchronizedRootGroup section */

/* Begin PBXFrameworksBuildPhase section */
    E09D166C2F44CA1E009C51A3 /* Frameworks */ = {
      isa = PBXFrameworksBuildPhase;
      buildActionMask = 2147483647;
      files = (
        E09D17542F44F359009C51A3 /* MLXLMCommon in Frameworks */,
        E09D17522F44F359009C51A3 /* MLXLLM in Frameworks */,
      );
      runOnlyForDeploymentPostprocessing = 0;
    };
    E09D16792F44CA20009C51A3 /* Frameworks */ = {
      isa = PBXFrameworksBuildPhase;
      buildActionMask = 2147483647;
      files = (
      );
      runOnlyForDeploymentPostprocessing = 0;
    };
    E09D16832F44CA20009C51A3 /* Frameworks */ = {
      isa = PBXFrameworksBuildPhase;
      buildActionMask = 2147483647;
      files = (
      );
      runOnlyForDeploymentPostprocessing = 0;
    };
/* End PBXFrameworksBuildPhase section */

/* Begin PBXGroup section */
    E09D16662F44CA1E009C51A3 = {
      isa = PBXGroup;
      children = (
        E09D16712F44CA1E009C51A3 /* EXO-iOS */,
        E09D167F2F44CA20009C51A3 /* EXO-iOSTests */,
        E09D16892F44CA20009C51A3 /* EXO-iOSUITests */,
        E09D16702F44CA1E009C51A3 /* Products */,
      );
      sourceTree = "<group>";
    };
    E09D16702F44CA1E009C51A3 /* Products */ = {
      isa = PBXGroup;
      children = (
        E09D166F2F44CA1E009C51A3 /* EXO-iOS.app */,
        E09D167C2F44CA20009C51A3 /* EXO-iOSTests.xctest */,
        E09D16862F44CA20009C51A3 /* EXO-iOSUITests.xctest */,
      );
      name = Products;
      sourceTree = "<group>";
    };
/* End PBXGroup section */

/* Begin PBXNativeTarget section */
    E09D166E2F44CA1E009C51A3 /* EXO-iOS */ = {
      isa = PBXNativeTarget;
      buildConfigurationList = E09D16902F44CA20009C51A3 /* Build configuration list for PBXNativeTarget "EXO-iOS" */;
      buildPhases = (
        E09D166B2F44CA1E009C51A3 /* Sources */,
        E09D166C2F44CA1E009C51A3 /* Frameworks */,
        E09D166D2F44CA1E009C51A3 /* Resources */,
      );
      buildRules = (
      );
      dependencies = (
      );
      fileSystemSynchronizedGroups = (
        E09D16712F44CA1E009C51A3 /* EXO-iOS */,
      );
      name = "EXO-iOS";
      packageProductDependencies = (
        E09D17512F44F359009C51A3 /* MLXLLM */,
        E09D17532F44F359009C51A3 /* MLXLMCommon */,
      );
      productName = "EXO-iOS";
      productReference = E09D166F2F44CA1E009C51A3 /* EXO-iOS.app */;
      productType = "com.apple.product-type.application";
    };
    E09D167B2F44CA20009C51A3 /* EXO-iOSTests */ = {
      isa = PBXNativeTarget;
      buildConfigurationList = E09D16932F44CA20009C51A3 /* Build configuration list for PBXNativeTarget "EXO-iOSTests" */;
      buildPhases = (
        E09D16782F44CA20009C51A3 /* Sources */,
        E09D16792F44CA20009C51A3 /* Frameworks */,
        E09D167A2F44CA20009C51A3 /* Resources */,
      );
      buildRules = (
      );
      dependencies = (
        E09D167E2F44CA20009C51A3 /* PBXTargetDependency */,
      );
      fileSystemSynchronizedGroups = (
        E09D167F2F44CA20009C51A3 /* EXO-iOSTests */,
      );
      name = "EXO-iOSTests";
      packageProductDependencies = (
      );
      productName = "EXO-iOSTests";
      productReference = E09D167C2F44CA20009C51A3 /* EXO-iOSTests.xctest */;
      productType = "com.apple.product-type.bundle.unit-test";
    };
    E09D16852F44CA20009C51A3 /* EXO-iOSUITests */ = {
      isa = PBXNativeTarget;
      buildConfigurationList = E09D16962F44CA20009C51A3 /* Build configuration list for PBXNativeTarget "EXO-iOSUITests" */;
      buildPhases = (
        E09D16822F44CA20009C51A3 /* Sources */,
        E09D16832F44CA20009C51A3 /* Frameworks */,
        E09D16842F44CA20009C51A3 /* Resources */,
      );
      buildRules = (
      );
      dependencies = (
        E09D16882F44CA20009C51A3 /* PBXTargetDependency */,
      );
      fileSystemSynchronizedGroups = (
        E09D16892F44CA20009C51A3 /* EXO-iOSUITests */,
      );
      name = "EXO-iOSUITests";
      packageProductDependencies = (
      );
      productName = "EXO-iOSUITests";
      productReference = E09D16862F44CA20009C51A3 /* EXO-iOSUITests.xctest */;
      productType = "com.apple.product-type.bundle.ui-testing";
    };
/* End PBXNativeTarget section */

/* Begin PBXProject section */
    E09D16672F44CA1E009C51A3 /* Project object */ = {
      isa = PBXProject;
      attributes = {
        BuildIndependentTargetsInParallel = 1;
        LastSwiftUpdateCheck = 2620;
        LastUpgradeCheck = 2620;
        TargetAttributes = {
          E09D166E2F44CA1E009C51A3 = {
            CreatedOnToolsVersion = 26.2;
          };
          E09D167B2F44CA20009C51A3 = {
            CreatedOnToolsVersion = 26.2;
            TestTargetID = E09D166E2F44CA1E009C51A3;
          };
          E09D16852F44CA20009C51A3 = {
            CreatedOnToolsVersion = 26.2;
            TestTargetID = E09D166E2F44CA1E009C51A3;
          };
        };
      };
      buildConfigurationList = E09D166A2F44CA1E009C51A3 /* Build configuration list for PBXProject "EXO-iOS" */;
      developmentRegion = en;
      hasScannedForEncodings = 0;
      knownRegions = (
        en,
        Base,
      );
      mainGroup = E09D16662F44CA1E009C51A3;
      minimizedProjectReferenceProxies = 1;
      packageReferences = (
        E09D17502F44F359009C51A3 /* XCRemoteSwiftPackageReference "mlx-swift-lm" */,
      );
      preferredProjectObjectVersion = 77;
      productRefGroup = E09D16702F44CA1E009C51A3 /* Products */;
      projectDirPath = "";
      projectRoot = "";
      targets = (
        E09D166E2F44CA1E009C51A3 /* EXO-iOS */,
        E09D167B2F44CA20009C51A3 /* EXO-iOSTests */,
        E09D16852F44CA20009C51A3 /* EXO-iOSUITests */,
      );
    };
/* End PBXProject section */

/* Begin PBXResourcesBuildPhase section */
    E09D166D2F44CA1E009C51A3 /* Resources */ = {
      isa = PBXResourcesBuildPhase;
      buildActionMask = 2147483647;
      files = (
      );
      runOnlyForDeploymentPostprocessing = 0;
    };
    E09D167A2F44CA20009C51A3 /* Resources */ = {
      isa = PBXResourcesBuildPhase;
      buildActionMask = 2147483647;
      files = (
      );
      runOnlyForDeploymentPostprocessing = 0;
    };
    E09D16842F44CA20009C51A3 /* Resources */ = {
      isa = PBXResourcesBuildPhase;
      buildActionMask = 2147483647;
      files = (
      );
      runOnlyForDeploymentPostprocessing = 0;
    };
/* End PBXResourcesBuildPhase section */

/* Begin PBXSourcesBuildPhase section */
    E09D166B2F44CA1E009C51A3 /* Sources */ = {
      isa = PBXSourcesBuildPhase;
      buildActionMask = 2147483647;
      files = (
      );
      runOnlyForDeploymentPostprocessing = 0;
    };
    E09D16782F44CA20009C51A3 /* Sources */ = {
      isa = PBXSourcesBuildPhase;
      buildActionMask = 2147483647;
      files = (
      );
      runOnlyForDeploymentPostprocessing = 0;
    };
    E09D16822F44CA20009C51A3 /* Sources */ = {
      isa = PBXSourcesBuildPhase;
      buildActionMask = 2147483647;
      files = (
      );
      runOnlyForDeploymentPostprocessing = 0;
    };
/* End PBXSourcesBuildPhase section */

/* Begin PBXTargetDependency section */
    E09D167E2F44CA20009C51A3 /* PBXTargetDependency */ = {
      isa = PBXTargetDependency;
      target = E09D166E2F44CA1E009C51A3 /* EXO-iOS */;
      targetProxy = E09D167D2F44CA20009C51A3 /* PBXContainerItemProxy */;
    };
    E09D16882F44CA20009C51A3 /* PBXTargetDependency */ = {
      isa = PBXTargetDependency;
      target = E09D166E2F44CA1E009C51A3 /* EXO-iOS */;
      targetProxy = E09D16872F44CA20009C51A3 /* PBXContainerItemProxy */;
    };
/* End PBXTargetDependency section */

/* Begin XCBuildConfiguration section */
    E09D168E2F44CA20009C51A3 /* Debug */ = {
      isa = XCBuildConfiguration;
      buildSettings = {
        ALWAYS_SEARCH_USER_PATHS = NO;
        ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES;
        CLANG_ANALYZER_NONNULL = YES;
        CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
        CLANG_CXX_LANGUAGE_STANDARD = "gnu++20";
        CLANG_ENABLE_MODULES = YES;
        CLANG_ENABLE_OBJC_ARC = YES;
        CLANG_ENABLE_OBJC_WEAK = YES;
        CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
        CLANG_WARN_BOOL_CONVERSION = YES;
        CLANG_WARN_COMMA = YES;
        CLANG_WARN_CONSTANT_CONVERSION = YES;
        CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
        CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
        CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
        CLANG_WARN_EMPTY_BODY = YES;
        CLANG_WARN_ENUM_CONVERSION = YES;
        CLANG_WARN_INFINITE_RECURSION = YES;
        CLANG_WARN_INT_CONVERSION = YES;
        CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
        CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
        CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
        CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
        CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;
        CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
        CLANG_WARN_STRICT_PROTOTYPES = YES;
        CLANG_WARN_SUSPICIOUS_MOVE = YES;
        CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
        CLANG_WARN_UNREACHABLE_CODE = YES;
        CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
        COPY_PHASE_STRIP = NO;
        DEBUG_INFORMATION_FORMAT = dwarf;
        ENABLE_STRICT_OBJC_MSGSEND = YES;
        ENABLE_TESTABILITY = YES;
        ENABLE_USER_SCRIPT_SANDBOXING = YES;
        GCC_C_LANGUAGE_STANDARD = gnu17;
        GCC_DYNAMIC_NO_PIC = NO;
        GCC_NO_COMMON_BLOCKS = YES;
        GCC_OPTIMIZATION_LEVEL = 0;
        GCC_PREPROCESSOR_DEFINITIONS = (
          "DEBUG=1",
          "$(inherited)",
        );
        GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
        GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
        GCC_WARN_UNDECLARED_SELECTOR = YES;
        GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
        GCC_WARN_UNUSED_FUNCTION = YES;
        GCC_WARN_UNUSED_VARIABLE = YES;
        IPHONEOS_DEPLOYMENT_TARGET = 26.2;
        LOCALIZATION_PREFERS_STRING_CATALOGS = YES;
        MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE;
        MTL_FAST_MATH = YES;
        ONLY_ACTIVE_ARCH = YES;
        SDKROOT = iphoneos;
        SWIFT_ACTIVE_COMPILATION_CONDITIONS = "DEBUG $(inherited)";
        SWIFT_OPTIMIZATION_LEVEL = "-Onone";
      };
      name = Debug;
    };
    E09D168F2F44CA20009C51A3 /* Release */ = {
      isa = XCBuildConfiguration;
      buildSettings = {
        ALWAYS_SEARCH_USER_PATHS = NO;
        ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES;
        CLANG_ANALYZER_NONNULL = YES;
        CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
        CLANG_CXX_LANGUAGE_STANDARD = "gnu++20";
        CLANG_ENABLE_MODULES = YES;
        CLANG_ENABLE_OBJC_ARC = YES;
        CLANG_ENABLE_OBJC_WEAK = YES;
        CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
        CLANG_WARN_BOOL_CONVERSION = YES;
        CLANG_WARN_COMMA = YES;
        CLANG_WARN_CONSTANT_CONVERSION = YES;
        CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
        CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
        CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
        CLANG_WARN_EMPTY_BODY = YES;
        CLANG_WARN_ENUM_CONVERSION = YES;
        CLANG_WARN_INFINITE_RECURSION = YES;
        CLANG_WARN_INT_CONVERSION = YES;
        CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
        CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
        CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
        CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
        CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;
        CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
        CLANG_WARN_STRICT_PROTOTYPES = YES;
        CLANG_WARN_SUSPICIOUS_MOVE = YES;
        CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
        CLANG_WARN_UNREACHABLE_CODE = YES;
        CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
        COPY_PHASE_STRIP = NO;
        DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym";
        ENABLE_NS_ASSERTIONS = NO;
        ENABLE_STRICT_OBJC_MSGSEND = YES;
        ENABLE_USER_SCRIPT_SANDBOXING = YES;
        GCC_C_LANGUAGE_STANDARD = gnu17;
        GCC_NO_COMMON_BLOCKS = YES;
        GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
        GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
        GCC_WARN_UNDECLARED_SELECTOR = YES;
        GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
        GCC_WARN_UNUSED_FUNCTION = YES;
        GCC_WARN_UNUSED_VARIABLE = YES;
        IPHONEOS_DEPLOYMENT_TARGET = 26.2;
        LOCALIZATION_PREFERS_STRING_CATALOGS = YES;
        MTL_ENABLE_DEBUG_INFO = NO;
        MTL_FAST_MATH = YES;
        SDKROOT = iphoneos;
        SWIFT_COMPILATION_MODE = wholemodule;
        VALIDATE_PRODUCT = YES;
      };
      name = Release;
    };
    E09D16912F44CA20009C51A3 /* Debug */ = {
      isa = XCBuildConfiguration;
      buildSettings = {
        ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
        ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor;
        CODE_SIGN_STYLE = Automatic;
        CURRENT_PROJECT_VERSION = 1;
        DEVELOPMENT_TEAM = 3M3M67U93M;
        ENABLE_PREVIEWS = YES;
        GENERATE_INFOPLIST_FILE = YES;
        INFOPLIST_FILE = "EXO-iOS/Info.plist";
        INFOPLIST_KEY_UIApplicationSceneManifest_Generation = YES;
        INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents = YES;
        INFOPLIST_KEY_UILaunchScreen_Generation = YES;
        INFOPLIST_KEY_UISupportedInterfaceOrientations_iPad = "UIInterfaceOrientationPortrait UIInterfaceOrientationPortraitUpsideDown UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight";
        INFOPLIST_KEY_UISupportedInterfaceOrientations_iPhone = "UIInterfaceOrientationPortrait UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight";
        LD_RUNPATH_SEARCH_PATHS = (
          "$(inherited)",
          "@executable_path/Frameworks",
        );
        MARKETING_VERSION = 1.0;
        PRODUCT_BUNDLE_IDENTIFIER = "com.exo.EXO-iOS";
        PRODUCT_NAME = "$(TARGET_NAME)";
        STRING_CATALOG_GENERATE_SYMBOLS = YES;
        SWIFT_APPROACHABLE_CONCURRENCY = YES;
        SWIFT_DEFAULT_ACTOR_ISOLATION = MainActor;
        SWIFT_EMIT_LOC_STRINGS = YES;
        SWIFT_UPCOMING_FEATURE_MEMBER_IMPORT_VISIBILITY = YES;
        SWIFT_VERSION = 5.0;
        TARGETED_DEVICE_FAMILY = "1,2";
      };
      name = Debug;
    };
    E09D16922F44CA20009C51A3 /* Release */ = {
      isa = XCBuildConfiguration;
      buildSettings = {
        ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
        ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor;
        CODE_SIGN_STYLE = Automatic;
        CURRENT_PROJECT_VERSION = 1;
        DEVELOPMENT_TEAM = 3M3M67U93M;
        ENABLE_PREVIEWS = YES;
        GENERATE_INFOPLIST_FILE = YES;
        INFOPLIST_FILE = "EXO-iOS/Info.plist";
        INFOPLIST_KEY_UIApplicationSceneManifest_Generation = YES;
        INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents = YES;
        INFOPLIST_KEY_UILaunchScreen_Generation = YES;
        INFOPLIST_KEY_UISupportedInterfaceOrientations_iPad = "UIInterfaceOrientationPortrait UIInterfaceOrientationPortraitUpsideDown UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight";
        INFOPLIST_KEY_UISupportedInterfaceOrientations_iPhone = "UIInterfaceOrientationPortrait UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight";
        LD_RUNPATH_SEARCH_PATHS = (
          "$(inherited)",
          "@executable_path/Frameworks",
        );
        MARKETING_VERSION = 1.0;
        PRODUCT_BUNDLE_IDENTIFIER = "com.exo.EXO-iOS";
        PRODUCT_NAME = "$(TARGET_NAME)";
        STRING_CATALOG_GENERATE_SYMBOLS = YES;
        SWIFT_APPROACHABLE_CONCURRENCY = YES;
        SWIFT_DEFAULT_ACTOR_ISOLATION = MainActor;
        SWIFT_EMIT_LOC_STRINGS = YES;
        SWIFT_UPCOMING_FEATURE_MEMBER_IMPORT_VISIBILITY = YES;
        SWIFT_VERSION = 5.0;
        TARGETED_DEVICE_FAMILY = "1,2";
      };
      name = Release;
    };
    E09D16942F44CA20009C51A3 /* Debug */ = {
      isa = XCBuildConfiguration;
      buildSettings = {
        BUNDLE_LOADER = "$(TEST_HOST)";
        CODE_SIGN_STYLE = Automatic;
        CURRENT_PROJECT_VERSION = 1;
        GENERATE_INFOPLIST_FILE = YES;
        IPHONEOS_DEPLOYMENT_TARGET = 26.2;
        MARKETING_VERSION = 1.0;
        PRODUCT_BUNDLE_IDENTIFIER = "com.exo.EXO-iOSTests";
        PRODUCT_NAME = "$(TARGET_NAME)";
        STRING_CATALOG_GENERATE_SYMBOLS = NO;
        SWIFT_APPROACHABLE_CONCURRENCY = YES;
        SWIFT_EMIT_LOC_STRINGS = NO;
        SWIFT_UPCOMING_FEATURE_MEMBER_IMPORT_VISIBILITY = YES;
        SWIFT_VERSION = 5.0;
        TARGETED_DEVICE_FAMILY = "1,2";
        TEST_HOST = "$(BUILT_PRODUCTS_DIR)/EXO-iOS.app/$(BUNDLE_EXECUTABLE_FOLDER_PATH)/EXO-iOS";
      };
      name = Debug;
    };
    E09D16952F44CA20009C51A3 /* Release */ = {
      isa = XCBuildConfiguration;
      buildSettings = {
        BUNDLE_LOADER = "$(TEST_HOST)";
        CODE_SIGN_STYLE = Automatic;
        CURRENT_PROJECT_VERSION = 1;
        GENERATE_INFOPLIST_FILE = YES;
        IPHONEOS_DEPLOYMENT_TARGET = 26.2;
        MARKETING_VERSION = 1.0;
        PRODUCT_BUNDLE_IDENTIFIER = "com.exo.EXO-iOSTests";
        PRODUCT_NAME = "$(TARGET_NAME)";
        STRING_CATALOG_GENERATE_SYMBOLS = NO;
        SWIFT_APPROACHABLE_CONCURRENCY = YES;
        SWIFT_EMIT_LOC_STRINGS = NO;
        SWIFT_UPCOMING_FEATURE_MEMBER_IMPORT_VISIBILITY = YES;
        SWIFT_VERSION = 5.0;
        TARGETED_DEVICE_FAMILY = "1,2";
        TEST_HOST = "$(BUILT_PRODUCTS_DIR)/EXO-iOS.app/$(BUNDLE_EXECUTABLE_FOLDER_PATH)/EXO-iOS";
      };
      name = Release;
    };
    E09D16972F44CA20009C51A3 /* Debug */ = {
      isa = XCBuildConfiguration;
      buildSettings = {
        CODE_SIGN_STYLE = Automatic;
        CURRENT_PROJECT_VERSION = 1;
        GENERATE_INFOPLIST_FILE = YES;
        MARKETING_VERSION = 1.0;
        PRODUCT_BUNDLE_IDENTIFIER = "com.exo.EXO-iOSUITests";
        PRODUCT_NAME = "$(TARGET_NAME)";
        STRING_CATALOG_GENERATE_SYMBOLS = NO;
        SWIFT_APPROACHABLE_CONCURRENCY = YES;
        SWIFT_EMIT_LOC_STRINGS = NO;
        SWIFT_UPCOMING_FEATURE_MEMBER_IMPORT_VISIBILITY = YES;
        SWIFT_VERSION = 5.0;
        TARGETED_DEVICE_FAMILY = "1,2";
        TEST_TARGET_NAME = "EXO-iOS";
      };
      name = Debug;
    };
    E09D16982F44CA20009C51A3 /* Release */ = {
      isa = XCBuildConfiguration;
      buildSettings = {
        CODE_SIGN_STYLE = Automatic;
        CURRENT_PROJECT_VERSION = 1;
        GENERATE_INFOPLIST_FILE = YES;
        MARKETING_VERSION = 1.0;
        PRODUCT_BUNDLE_IDENTIFIER = "com.exo.EXO-iOSUITests";
        PRODUCT_NAME = "$(TARGET_NAME)";
        STRING_CATALOG_GENERATE_SYMBOLS = NO;
        SWIFT_APPROACHABLE_CONCURRENCY = YES;
        SWIFT_EMIT_LOC_STRINGS = NO;
        SWIFT_UPCOMING_FEATURE_MEMBER_IMPORT_VISIBILITY = YES;
        SWIFT_VERSION = 5.0;
        TARGETED_DEVICE_FAMILY = "1,2";
        TEST_TARGET_NAME = "EXO-iOS";
      };
      name = Release;
    };
/* End XCBuildConfiguration section */

/* Begin XCConfigurationList section */
    E09D166A2F44CA1E009C51A3 /* Build configuration list for PBXProject "EXO-iOS" */ = {
      isa = XCConfigurationList;
      buildConfigurations = (
        E09D168E2F44CA20009C51A3 /* Debug */,
        E09D168F2F44CA20009C51A3 /* Release */,
      );
      defaultConfigurationIsVisible = 0;
      defaultConfigurationName = Release;
    };
    E09D16902F44CA20009C51A3 /* Build configuration list for PBXNativeTarget "EXO-iOS" */ = {
      isa = XCConfigurationList;
      buildConfigurations = (
        E09D16912F44CA20009C51A3 /* Debug */,
        E09D16922F44CA20009C51A3 /* Release */,
      );
      defaultConfigurationIsVisible = 0;
      defaultConfigurationName = Release;
    };
    E09D16932F44CA20009C51A3 /* Build configuration list for PBXNativeTarget "EXO-iOSTests" */ = {
      isa = XCConfigurationList;
      buildConfigurations = (
        E09D16942F44CA20009C51A3 /* Debug */,
        E09D16952F44CA20009C51A3 /* Release */,
      );
      defaultConfigurationIsVisible = 0;
      defaultConfigurationName = Release;
    };
    E09D16962F44CA20009C51A3 /* Build configuration list for PBXNativeTarget "EXO-iOSUITests" */ = {
      isa = XCConfigurationList;
      buildConfigurations = (
        E09D16972F44CA20009C51A3 /* Debug */,
        E09D16982F44CA20009C51A3 /* Release */,
      );
      defaultConfigurationIsVisible = 0;
      defaultConfigurationName = Release;
    };
/* End XCConfigurationList section */

/* Begin XCRemoteSwiftPackageReference section */
    E09D17502F44F359009C51A3 /* XCRemoteSwiftPackageReference "mlx-swift-lm" */ = {
      isa = XCRemoteSwiftPackageReference;
      repositoryURL = "https://github.com/ml-explore/mlx-swift-lm";
      requirement = {
        kind = upToNextMajorVersion;
        minimumVersion = 2.30.3;
      };
    };
/* End XCRemoteSwiftPackageReference section */

/* Begin XCSwiftPackageProductDependency section */
    E09D17512F44F359009C51A3 /* MLXLLM */ = {
      isa = XCSwiftPackageProductDependency;
      package = E09D17502F44F359009C51A3 /* XCRemoteSwiftPackageReference "mlx-swift-lm" */;
      productName = MLXLLM;
    };
    E09D17532F44F359009C51A3 /* MLXLMCommon */ = {
      isa = XCSwiftPackageProductDependency;
      package = E09D17502F44F359009C51A3 /* XCRemoteSwiftPackageReference "mlx-swift-lm" */;
      productName = MLXLMCommon;
    };
/* End XCSwiftPackageProductDependency section */
  };
  rootObject = E09D16672F44CA1E009C51A3 /* Project object */;
}
app/EXO-iOS/EXO-iOS.xcodeproj/project.xcworkspace/contents.xcworkspacedata — 7 lines (generated, new file)
@@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<Workspace
   version = "1.0">
   <FileRef
      location = "self:">
   </FileRef>
</Workspace>
Package.resolved (Swift package pins) — 60 lines (new file)

@@ -0,0 +1,60 @@
{
  "originHash" : "facc0ac7c70363ea20f6cd1235de91dea6b06f0d00190946045a6c8ae753abc2",
  "pins" : [
    {
      "identity" : "mlx-swift",
      "kind" : "remoteSourceControl",
      "location" : "https://github.com/ml-explore/mlx-swift",
      "state" : {
        "revision" : "6ba4827fb82c97d012eec9ab4b2de21f85c3b33d",
        "version" : "0.30.6"
      }
    },
    {
      "identity" : "mlx-swift-lm",
      "kind" : "remoteSourceControl",
      "location" : "https://github.com/ml-explore/mlx-swift-lm",
      "state" : {
        "revision" : "360c5052b81cc154b04ee0933597a4ad6db4b8ae",
        "version" : "2.30.3"
      }
    },
    {
      "identity" : "swift-collections",
      "kind" : "remoteSourceControl",
      "location" : "https://github.com/apple/swift-collections.git",
      "state" : {
        "revision" : "7b847a3b7008b2dc2f47ca3110d8c782fb2e5c7e",
        "version" : "1.3.0"
      }
    },
    {
      "identity" : "swift-jinja",
      "kind" : "remoteSourceControl",
      "location" : "https://github.com/huggingface/swift-jinja.git",
      "state" : {
        "revision" : "d81197f35f41445bc10e94600795e68c6f5e94b0",
        "version" : "2.3.1"
      }
    },
    {
      "identity" : "swift-numerics",
      "kind" : "remoteSourceControl",
      "location" : "https://github.com/apple/swift-numerics",
      "state" : {
        "revision" : "0c0290ff6b24942dadb83a929ffaaa1481df04a2",
        "version" : "1.1.1"
      }
    },
    {
      "identity" : "swift-transformers",
      "kind" : "remoteSourceControl",
      "location" : "https://github.com/huggingface/swift-transformers",
      "state" : {
        "revision" : "573e5c9036c2f136b3a8a071da8e8907322403d0",
        "version" : "1.1.6"
      }
    }
  ],
  "version" : 3
}
Accent color set Contents.json — 20 lines (new file)

@@ -0,0 +1,20 @@
{
  "colors" : [
    {
      "color" : {
        "color-space" : "srgb",
        "components" : {
          "alpha" : "1.000",
          "blue" : "0x00",
          "green" : "0xD7",
          "red" : "0xFF"
        }
      },
      "idiom" : "universal"
    }
  ],
  "info" : {
    "author" : "xcode",
    "version" : 1
  }
}
Binary image file not shown (after: 10 KiB).
App icon set Contents.json — 38 lines (new file)

@@ -0,0 +1,38 @@
{
  "images" : [
    {
      "filename" : "AppIcon.png",
      "idiom" : "universal",
      "platform" : "ios",
      "size" : "1024x1024"
    },
    {
      "appearances" : [
        {
          "appearance" : "luminosity",
          "value" : "dark"
        }
      ],
      "filename" : "AppIcon.png",
      "idiom" : "universal",
      "platform" : "ios",
      "size" : "1024x1024"
    },
    {
      "appearances" : [
        {
          "appearance" : "luminosity",
          "value" : "tinted"
        }
      ],
      "filename" : "AppIcon.png",
      "idiom" : "universal",
      "platform" : "ios",
      "size" : "1024x1024"
    }
  ],
  "info" : {
    "author" : "xcode",
    "version" : 1
  }
}
app/EXO-iOS/EXO-iOS/Assets.xcassets/Contents.json — 6 lines (new file)
@@ -0,0 +1,6 @@
{
  "info" : {
    "author" : "xcode",
    "version" : 1
  }
}
app/EXO-iOS/EXO-iOS/Assets.xcassets/ExoLogo.imageset/Contents.json — 21 lines (vendored, new file)
@@ -0,0 +1,21 @@
{
  "images" : [
    {
      "filename" : "exo-logo.png",
      "idiom" : "universal",
      "scale" : "1x"
    },
    {
      "idiom" : "universal",
      "scale" : "2x"
    },
    {
      "idiom" : "universal",
      "scale" : "3x"
    }
  ],
  "info" : {
    "author" : "xcode",
    "version" : 1
  }
}
app/EXO-iOS/EXO-iOS/Assets.xcassets/ExoLogo.imageset/exo-logo.png — binary (vendored, new file; 1.6 KiB, not shown)
app/EXO-iOS/EXO-iOS/EXO-iOS.entitlements — 8 lines (new file)
@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
  <key>com.apple.developer.kernel.increased-memory-limit</key>
  <true/>
</dict>
</plist>
app/EXO-iOS/EXO-iOS/EXO_iOSApp.swift — 67 lines (new file)
@@ -0,0 +1,67 @@
import SwiftUI
import UIKit

@main
struct EXO_iOSApp: App {
    @State private var clusterService = ClusterService()
    @State private var discoveryService = DiscoveryService()
    @State private var localInferenceService = LocalInferenceService()
    @State private var chatService: ChatService?

    init() {
        let darkGray = UIColor(red: 0x1F / 255.0, green: 0x1F / 255.0, blue: 0x1F / 255.0, alpha: 1)
        let yellow = UIColor(red: 0xFF / 255.0, green: 0xD7 / 255.0, blue: 0x00 / 255.0, alpha: 1)

        let navAppearance = UINavigationBarAppearance()
        navAppearance.configureWithOpaqueBackground()
        navAppearance.backgroundColor = darkGray
        navAppearance.titleTextAttributes = [
            .foregroundColor: yellow,
            .font: UIFont.monospacedSystemFont(ofSize: 17, weight: .semibold),
        ]
        navAppearance.largeTitleTextAttributes = [
            .foregroundColor: yellow,
            .font: UIFont.monospacedSystemFont(ofSize: 34, weight: .bold),
        ]

        UINavigationBar.appearance().standardAppearance = navAppearance
        UINavigationBar.appearance().compactAppearance = navAppearance
        UINavigationBar.appearance().scrollEdgeAppearance = navAppearance
        UINavigationBar.appearance().tintColor = yellow
    }

    var body: some Scene {
        WindowGroup {
            if let chatService {
                RootView()
                    .environment(clusterService)
                    .environment(discoveryService)
                    .environment(chatService)
                    .environment(localInferenceService)
                    .preferredColorScheme(.dark)
                    .task {
                        await clusterService.attemptAutoReconnect()
                        discoveryService.startBrowsing()
                        await localInferenceService.prepareModel()
                    }
                    .onChange(of: discoveryService.discoveredClusters) { _, clusters in
                        guard !clusterService.isConnected,
                            case .disconnected = clusterService.connectionState,
                            let first = clusters.first
                        else { return }
                        Task {
                            await clusterService.connectToDiscoveredCluster(
                                first, using: discoveryService)
                        }
                    }
            } else {
                Color.exoBlack.onAppear {
                    chatService = ChatService(
                        clusterService: clusterService,
                        localInferenceService: localInferenceService
                    )
                }
            }
        }
    }
}
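The app injects each `@Observable` service into the SwiftUI environment above. As a minimal sketch of how a downstream view reads one back via the Observation-based environment API (`StatusBadge` is hypothetical, not part of this diff):

```swift
import SwiftUI

// Hypothetical consumer view (not in this diff): reads the ClusterService
// injected with .environment(_:) in EXO_iOSApp.
struct StatusBadge: View {
    @Environment(ClusterService.self) private var clusterService

    var body: some View {
        Text(clusterService.isConnected ? "CONNECTED" : "OFFLINE")
            .font(.system(.caption, design: .monospaced))
    }
}
```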
app/EXO-iOS/EXO-iOS/Info.plist — 19 lines (new file)
@@ -0,0 +1,19 @@
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
  <key>UIUserInterfaceStyle</key>
  <string>Dark</string>
  <key>CFBundleDisplayName</key>
  <string>EXO</string>
  <key>NSLocalNetworkUsageDescription</key>
  <string>EXO needs local network access to connect to your EXO cluster.</string>
  <key>NSBonjourServices</key>
  <array>
    <string>_exo._tcp</string>
    <string>_p2p._tcp</string>
    <string>_p2p._udp</string>
    <string>_libp2p._udp</string>
  </array>
</dict>
</plist>
app/EXO-iOS/EXO-iOS/Models/ChatCompletionTypes.swift — 129 lines (new file)
@@ -0,0 +1,129 @@
import Foundation

// MARK: - Request

struct ChatCompletionRequest: Encodable {
    let model: String
    let messages: [ChatCompletionMessageParam]
    let stream: Bool
    let maxTokens: Int?
    let temperature: Double?

    enum CodingKeys: String, CodingKey {
        case model, messages, stream, temperature
        case maxTokens = "max_tokens"
    }
}

struct ChatCompletionMessageParam: Encodable {
    let role: String
    let content: String
}

// MARK: - Streaming Response

struct ChatCompletionChunk: Decodable {
    let id: String
    let model: String?
    let choices: [StreamingChoice]
    let usage: ChunkUsage?

    init(id: String, model: String?, choices: [StreamingChoice], usage: ChunkUsage?) {
        self.id = id
        self.model = model
        self.choices = choices
        self.usage = usage
    }
}

struct StreamingChoice: Decodable {
    let index: Int
    let delta: Delta
    let finishReason: String?

    enum CodingKeys: String, CodingKey {
        case index, delta
        case finishReason = "finish_reason"
    }

    init(index: Int, delta: Delta, finishReason: String?) {
        self.index = index
        self.delta = delta
        self.finishReason = finishReason
    }
}

struct Delta: Decodable {
    let role: String?
    let content: String?

    init(role: String?, content: String?) {
        self.role = role
        self.content = content
    }
}

struct ChunkUsage: Decodable {
    let promptTokens: Int?
    let completionTokens: Int?
    let totalTokens: Int?

    enum CodingKeys: String, CodingKey {
        case promptTokens = "prompt_tokens"
        case completionTokens = "completion_tokens"
        case totalTokens = "total_tokens"
    }

    init(promptTokens: Int?, completionTokens: Int?, totalTokens: Int?) {
        self.promptTokens = promptTokens
        self.completionTokens = completionTokens
        self.totalTokens = totalTokens
    }
}

// MARK: - Non-Streaming Response

struct ChatCompletionResponse: Decodable {
    let id: String
    let model: String?
    let choices: [ResponseChoice]
}

struct ResponseChoice: Decodable {
    let index: Int
    let message: ResponseMessage
    let finishReason: String?

    enum CodingKeys: String, CodingKey {
        case index, message
        case finishReason = "finish_reason"
    }
}

struct ResponseMessage: Decodable {
    let role: String?
    let content: String?
}

// MARK: - Models List

struct ModelListResponse: Decodable {
    let data: [ModelInfo]
}

struct ModelInfo: Decodable, Identifiable {
    let id: String
    let name: String?
}

// MARK: - Error

struct APIErrorResponse: Decodable {
    let error: APIErrorInfo
}

struct APIErrorInfo: Decodable {
    let message: String
    let type: String?
    let code: Int?
}
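These types track the OpenAI-style chat completions wire format, with `CodingKeys` mapping Swift camelCase to the API's snake_case. An illustrative encoding check (key order is not guaranteed by `JSONEncoder`, and nil optionals such as `temperature` are simply omitted):

```swift
import Foundation

// Illustrative only: show the JSON body the cluster receives.
let request = ChatCompletionRequest(
    model: "mlx-community/Qwen3-0.6B-4bit",  // the app's local default model id
    messages: [ChatCompletionMessageParam(role: "user", content: "Hello")],
    stream: true,
    maxTokens: 4096,
    temperature: nil  // nil optionals are dropped from the encoded JSON
)
let data = try JSONEncoder().encode(request)
print(String(decoding: data, as: UTF8.self))
// e.g. {"model":"mlx-community/Qwen3-0.6B-4bit","messages":[{"role":"user","content":"Hello"}],"stream":true,"max_tokens":4096}
```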
app/EXO-iOS/EXO-iOS/Models/ChatMessage.swift — 26 lines (new file)
@@ -0,0 +1,26 @@
import Foundation

struct ChatMessage: Identifiable, Equatable {
    let id: UUID
    let role: Role
    var content: String
    let timestamp: Date
    var isStreaming: Bool

    enum Role: String, Codable {
        case user
        case assistant
        case system
    }

    init(
        id: UUID = UUID(), role: Role, content: String, timestamp: Date = Date(),
        isStreaming: Bool = false
    ) {
        self.id = id
        self.role = role
        self.content = content
        self.timestamp = timestamp
        self.isStreaming = isStreaming
    }
}
app/EXO-iOS/EXO-iOS/Models/ConnectionInfo.swift — 11 lines (new file)
@@ -0,0 +1,11 @@
import Foundation

struct ConnectionInfo: Codable, Equatable {
    let host: String
    let port: Int
    let nodeId: String?

    var baseURL: URL { URL(string: "http://\(host):\(port)")! }

    static let defaultPort = 52415
}
app/EXO-iOS/EXO-iOS/Models/Conversation.swift — 34 lines (new file)
@@ -0,0 +1,34 @@
import Foundation

struct Conversation: Identifiable, Codable, Equatable {
    let id: UUID
    var title: String
    var messages: [StoredMessage]
    var modelId: String?
    let createdAt: Date

    init(
        id: UUID = UUID(), title: String = "New Chat", messages: [StoredMessage] = [],
        modelId: String? = nil, createdAt: Date = Date()
    ) {
        self.id = id
        self.title = title
        self.messages = messages
        self.modelId = modelId
        self.createdAt = createdAt
    }
}

struct StoredMessage: Identifiable, Codable, Equatable {
    let id: UUID
    let role: String
    var content: String
    let timestamp: Date

    init(id: UUID = UUID(), role: String, content: String, timestamp: Date = Date()) {
        self.id = id
        self.role = role
        self.content = content
        self.timestamp = timestamp
    }
}
app/EXO-iOS/EXO-iOS/Services/ChatService.swift — 227 lines (new file)
@@ -0,0 +1,227 @@
import Foundation

@Observable
@MainActor
final class ChatService {
    var conversations: [Conversation] = []
    var activeConversationId: UUID?
    private(set) var isGenerating: Bool = false
    private var currentGenerationTask: Task<Void, Never>?

    private let clusterService: ClusterService
    private let localInferenceService: LocalInferenceService

    var canSendMessage: Bool {
        clusterService.isConnected || localInferenceService.isAvailable
    }

    var activeConversation: Conversation? {
        guard let id = activeConversationId else { return nil }
        return conversations.first { $0.id == id }
    }

    var activeMessages: [ChatMessage] {
        guard let conversation = activeConversation else { return [] }
        return conversation.messages.map { stored in
            ChatMessage(
                id: stored.id,
                role: ChatMessage.Role(rawValue: stored.role) ?? .user,
                content: stored.content,
                timestamp: stored.timestamp
            )
        }
    }

    init(clusterService: ClusterService, localInferenceService: LocalInferenceService) {
        self.clusterService = clusterService
        self.localInferenceService = localInferenceService
        loadConversations()
    }

    // MARK: - Conversation Management

    func createConversation(modelId: String? = nil) {
        let conversation = Conversation(
            modelId: modelId ?? clusterService.availableModels.first?.id)
        conversations.insert(conversation, at: 0)
        activeConversationId = conversation.id
        saveConversations()
    }

    func deleteConversation(id: UUID) {
        conversations.removeAll { $0.id == id }
        if activeConversationId == id {
            activeConversationId = conversations.first?.id
        }
        saveConversations()
    }

    func setActiveConversation(id: UUID) {
        activeConversationId = id
    }

    func setModelForActiveConversation(_ modelId: String) {
        guard let index = conversations.firstIndex(where: { $0.id == activeConversationId }) else {
            return
        }
        conversations[index].modelId = modelId
        saveConversations()
    }

    // MARK: - Messaging

    func sendMessage(_ text: String) {
        guard !text.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty else { return }

        if activeConversation == nil {
            createConversation()
        }

        guard let index = conversations.firstIndex(where: { $0.id == activeConversationId }) else {
            return
        }

        let userMessage = StoredMessage(role: "user", content: text)
        conversations[index].messages.append(userMessage)

        if conversations[index].title == "New Chat" {
            let preview = String(text.prefix(40))
            conversations[index].title = preview + (text.count > 40 ? "..." : "")
        }

        let modelId: String
        if clusterService.isConnected {
            guard
                let clusterId = conversations[index].modelId
                    ?? clusterService.availableModels.first?.id
            else {
                let errorMessage = StoredMessage(
                    role: "assistant", content: "No model selected. Please select a model first.")
                conversations[index].messages.append(errorMessage)
                saveConversations()
                return
            }
            modelId = clusterId
        } else if localInferenceService.isAvailable {
            modelId = localInferenceService.defaultModelId
        } else {
            let errorMessage = StoredMessage(
                role: "assistant",
                content: "Not connected to a cluster and local model is not available.")
            conversations[index].messages.append(errorMessage)
            saveConversations()
            return
        }

        conversations[index].modelId = modelId

        let assistantMessageId = UUID()
        let assistantMessage = StoredMessage(
            id: assistantMessageId, role: "assistant", content: "", timestamp: Date())
        conversations[index].messages.append(assistantMessage)

        let messagesForAPI = conversations[index].messages.dropLast().map { stored in
            ChatCompletionMessageParam(role: stored.role, content: stored.content)
        }

        let request = ChatCompletionRequest(
            model: modelId,
            messages: Array(messagesForAPI),
            stream: true,
            maxTokens: 4096,
            temperature: nil
        )

        let conversationId = conversations[index].id

        isGenerating = true
        currentGenerationTask = Task { [weak self] in
            guard let self else { return }
            await self.performStreaming(
                request: request, conversationId: conversationId,
                assistantMessageId: assistantMessageId)
        }

        saveConversations()
    }

    func cancelGeneration() {
        currentGenerationTask?.cancel()
        currentGenerationTask = nil
        localInferenceService.cancelGeneration()
        isGenerating = false
    }

    // MARK: - Streaming

    private func performStreaming(
        request: ChatCompletionRequest, conversationId: UUID, assistantMessageId: UUID
    ) async {
        defer {
            isGenerating = false
            currentGenerationTask = nil
            saveConversations()
        }

        do {
            let stream =
                clusterService.isConnected
                ? clusterService.streamChatCompletion(request: request)
                : localInferenceService.streamChatCompletion(request: request)
            for try await chunk in stream {
                guard !Task.isCancelled else { return }
                guard let content = chunk.choices.first?.delta.content, !content.isEmpty else {
                    continue
                }

                if let convIndex = conversations.firstIndex(where: { $0.id == conversationId }),
                    let msgIndex = conversations[convIndex].messages.firstIndex(where: {
                        $0.id == assistantMessageId
                    })
                {
                    conversations[convIndex].messages[msgIndex].content += content
                }
            }
        } catch {
            if !Task.isCancelled {
                if let convIndex = conversations.firstIndex(where: { $0.id == conversationId }),
                    let msgIndex = conversations[convIndex].messages.firstIndex(where: {
                        $0.id == assistantMessageId
                    })
                {
                    if conversations[convIndex].messages[msgIndex].content.isEmpty {
                        conversations[convIndex].messages[msgIndex].content =
                            "Error: \(error.localizedDescription)"
                    }
                }
            }
        }
    }

    // MARK: - Persistence

    private static var storageURL: URL {
        let documents = FileManager.default.urls(for: .documentDirectory, in: .userDomainMask)
            .first!
        return documents.appendingPathComponent("exo_conversations.json")
    }

    private func saveConversations() {
        do {
            let data = try JSONEncoder().encode(conversations)
            try data.write(to: Self.storageURL, options: .atomic)
        } catch {
            // Save failed silently
        }
    }

    private func loadConversations() {
        do {
            let data = try Data(contentsOf: Self.storageURL)
            conversations = try JSONDecoder().decode([Conversation].self, from: data)
            activeConversationId = conversations.first?.id
        } catch {
            conversations = []
        }
    }
}
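ChatService routes each message either to the cluster or to on-device inference and streams tokens back into the active conversation. A minimal wiring sketch (illustrative only; in the app these services are created and injected by EXO_iOSApp):

```swift
// Illustrative usage only — the @MainActor task satisfies the services'
// isolation requirement; the app normally does this wiring itself.
Task { @MainActor in
    let cluster = ClusterService()
    let local = LocalInferenceService()
    let chat = ChatService(clusterService: cluster, localInferenceService: local)

    chat.createConversation()
    chat.sendMessage("What is exo?")  // streams the assistant reply into activeMessages
}
```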
app/EXO-iOS/EXO-iOS/Services/ClusterService.swift — 200 lines (new file)
@@ -0,0 +1,200 @@
import Foundation

enum ConnectionState: Equatable {
    case disconnected
    case connecting
    case connected(ConnectionInfo)
}

struct ModelOption: Identifiable, Equatable {
    let id: String
    let displayName: String
}

@Observable
@MainActor
final class ClusterService {
    private(set) var connectionState: ConnectionState = .disconnected
    private(set) var availableModels: [ModelOption] = []
    private(set) var lastError: String?

    private let session: URLSession
    private let decoder: JSONDecoder
    private var pollingTask: Task<Void, Never>?

    private static let connectionInfoKey = "exo_last_connection_info"

    var isConnected: Bool {
        if case .connected = connectionState { return true }
        return false
    }

    var currentConnection: ConnectionInfo? {
        if case .connected(let info) = connectionState { return info }
        return nil
    }

    init(session: URLSession = .shared) {
        self.session = session
        let decoder = JSONDecoder()
        self.decoder = decoder
    }

    // MARK: - Connection

    func connect(to info: ConnectionInfo) async {
        connectionState = .connecting
        lastError = nil

        do {
            let url = info.baseURL.appendingPathComponent("node_id")
            var request = URLRequest(url: url)
            request.timeoutInterval = 5
            request.cachePolicy = .reloadIgnoringLocalCacheData
            let (_, response) = try await session.data(for: request)

            guard let httpResponse = response as? HTTPURLResponse,
                (200..<300).contains(httpResponse.statusCode)
            else {
                throw URLError(.badServerResponse)
            }

            connectionState = .connected(info)
            persistConnection(info)
            startPolling()
            await fetchModels(baseURL: info.baseURL)
        } catch {
            connectionState = .disconnected
            lastError = "Could not connect to \(info.host):\(info.port)"
        }
    }

    func connectToDiscoveredCluster(
        _ cluster: DiscoveredCluster, using discoveryService: DiscoveryService
    ) async {
        guard case .disconnected = connectionState else { return }
        connectionState = .connecting
        lastError = nil

        guard let info = await discoveryService.resolve(cluster) else {
            connectionState = .disconnected
            lastError = "Could not resolve \(cluster.name)"
            return
        }
        connectionState = .disconnected  // reset so connect() can proceed
        await connect(to: info)
    }

    func disconnect() {
        stopPolling()
        connectionState = .disconnected
        availableModels = []
        lastError = nil
    }

    func attemptAutoReconnect() async {
        guard case .disconnected = connectionState,
            let info = loadPersistedConnection()
        else { return }
        await connect(to: info)
    }

    // MARK: - Polling

    private func startPolling(interval: TimeInterval = 2.0) {
        stopPolling()
        pollingTask = Task { [weak self] in
            while !Task.isCancelled {
                try? await Task.sleep(for: .seconds(interval))
                guard let self, !Task.isCancelled else { return }
                guard let connection = self.currentConnection else { return }
                await self.fetchModels(baseURL: connection.baseURL)
            }
        }
    }

    private func stopPolling() {
        pollingTask?.cancel()
        pollingTask = nil
    }

    // MARK: - API

    private func fetchModels(baseURL: URL) async {
        do {
            let url = baseURL.appendingPathComponent("models")
            var request = URLRequest(url: url)
            request.cachePolicy = .reloadIgnoringLocalCacheData
            let (data, response) = try await session.data(for: request)

            guard let httpResponse = response as? HTTPURLResponse,
                (200..<300).contains(httpResponse.statusCode)
            else { return }

            let list = try decoder.decode(ModelListResponse.self, from: data)
            availableModels = list.data.map {
                ModelOption(id: $0.id, displayName: $0.name ?? $0.id)
            }
        } catch {
            // Models fetch failed silently — will retry on next poll
        }
    }

    func streamChatCompletion(request body: ChatCompletionRequest) -> AsyncThrowingStream<
        ChatCompletionChunk, Error
    > {
        AsyncThrowingStream { continuation in
            let task = Task { [weak self] in
                guard let self, let connection = self.currentConnection else {
                    continuation.finish(throwing: URLError(.notConnectedToInternet))
                    return
                }

                do {
                    let url = connection.baseURL.appendingPathComponent("v1/chat/completions")
                    var request = URLRequest(url: url)
                    request.httpMethod = "POST"
                    request.setValue("application/json", forHTTPHeaderField: "Content-Type")
                    request.httpBody = try JSONEncoder().encode(body)

                    let (bytes, response) = try await self.session.bytes(for: request)

                    guard let httpResponse = response as? HTTPURLResponse,
                        (200..<300).contains(httpResponse.statusCode)
                    else {
                        continuation.finish(throwing: URLError(.badServerResponse))
                        return
                    }

                    let parser = SSEStreamParser<ChatCompletionChunk>(
                        bytes: bytes, decoder: self.decoder)
                    for try await chunk in parser {
                        continuation.yield(chunk)
                    }
                    continuation.finish()
                } catch {
                    continuation.finish(throwing: error)
                }
            }

            continuation.onTermination = { _ in
                task.cancel()
            }
        }
    }

    // MARK: - Persistence

    private func persistConnection(_ info: ConnectionInfo) {
        if let data = try? JSONEncoder().encode(info) {
            UserDefaults.standard.set(data, forKey: Self.connectionInfoKey)
        }
    }

    private func loadPersistedConnection() -> ConnectionInfo? {
        guard let data = UserDefaults.standard.data(forKey: Self.connectionInfoKey) else {
            return nil
        }
        return try? JSONDecoder().decode(ConnectionInfo.self, from: data)
    }
}
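`streamChatCompletion` hands the response bytes to an `SSEStreamParser` that is referenced here but not included in this diff. A minimal sketch of what such a parser could look like, assuming standard `data:`-prefixed server-sent-event lines with an OpenAI-style `[DONE]` sentinel (the real implementation may differ):

```swift
import Foundation

// Hypothetical sketch of the SSEStreamParser referenced above (not part of
// this diff): walks the response line by line, decodes each `data: {...}`
// payload as Element, and ends the sequence at `data: [DONE]`.
struct SSEStreamParser<Element: Decodable>: AsyncSequence {
    let bytes: URLSession.AsyncBytes
    let decoder: JSONDecoder

    struct AsyncIterator: AsyncIteratorProtocol {
        var lines: AsyncLineSequence<URLSession.AsyncBytes>.AsyncIterator
        let decoder: JSONDecoder

        mutating func next() async throws -> Element? {
            while let line = try await lines.next() {
                guard line.hasPrefix("data:") else { continue }  // skip blank/comment lines
                let payload = line.dropFirst(5).trimmingCharacters(in: .whitespaces)
                if payload == "[DONE]" { return nil }  // end-of-stream sentinel
                return try decoder.decode(Element.self, from: Data(payload.utf8))
            }
            return nil
        }
    }

    func makeAsyncIterator() -> AsyncIterator {
        AsyncIterator(lines: bytes.lines.makeAsyncIterator(), decoder: decoder)
    }
}
```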
app/EXO-iOS/EXO-iOS/Services/DiscoveryService.swift — 123 lines (new file)
@@ -0,0 +1,123 @@
|
||||
import Foundation
|
||||
import Network
|
||||
import os
|
||||
|
||||
struct DiscoveredCluster: Identifiable, Equatable {
|
||||
let id: String
|
||||
let name: String
|
||||
let endpoint: NWEndpoint
|
||||
|
||||
static func == (lhs: DiscoveredCluster, rhs: DiscoveredCluster) -> Bool {
|
||||
lhs.id == rhs.id && lhs.name == rhs.name
|
||||
}
|
||||
}
|
||||
|
||||
@Observable
|
||||
@MainActor
|
||||
final class DiscoveryService {
|
||||
private(set) var discoveredClusters: [DiscoveredCluster] = []
|
||||
private(set) var isSearching = false
|
||||
|
||||
private var browser: NWBrowser?
|
||||
|
||||
func startBrowsing() {
|
||||
guard browser == nil else { return }
|
||||
|
||||
let browser = NWBrowser(for: .bonjour(type: "_exo._tcp", domain: nil), using: .tcp)
|
||||
|
||||
browser.stateUpdateHandler = { [weak self] state in
|
||||
guard let service = self else { return }
|
||||
Task { @MainActor in
|
||||
switch state {
|
||||
case .ready:
|
||||
service.isSearching = true
|
||||
case .failed, .cancelled:
|
||||
service.isSearching = false
|
||||
default:
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
browser.browseResultsChangedHandler = { [weak self] results, _ in
|
||||
guard let service = self else { return }
|
||||
Task { @MainActor in
|
||||
service.discoveredClusters = results.compactMap { result in
|
||||
guard case .service(let name, _, _, _) = result.endpoint else {
|
||||
return nil
|
||||
}
|
||||
return DiscoveredCluster(
|
||||
id: name,
|
||||
name: name,
|
||||
endpoint: result.endpoint
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
browser.start(queue: .main)
|
||||
self.browser = browser
|
||||
}
|
||||
|
||||
func stopBrowsing() {
|
||||
browser?.cancel()
|
||||
browser = nil
|
||||
isSearching = false
|
||||
discoveredClusters = []
|
||||
}
|
||||
|
||||
/// Resolve a discovered Bonjour endpoint to an IP address and port, then return a ConnectionInfo.
|
||||
func resolve(_ cluster: DiscoveredCluster) async -> ConnectionInfo? {
|
||||
await withCheckedContinuation { continuation in
|
||||
let didResume = OSAllocatedUnfairLock(initialState: false)
|
||||
let connection = NWConnection(to: cluster.endpoint, using: .tcp)
|
||||
connection.stateUpdateHandler = { state in
|
||||
guard
|
||||
didResume.withLock({
|
||||
guard !$0 else { return false }
|
||||
$0 = true
|
||||
return true
|
||||
})
|
||||
else { return }
|
||||
switch state {
|
||||
case .ready:
|
||||
if let innerEndpoint = connection.currentPath?.remoteEndpoint,
|
||||
case .hostPort(let host, let port) = innerEndpoint
|
||||
{
|
||||
var hostString: String
|
||||
switch host {
|
||||
case .ipv4(let addr):
|
||||
hostString = "\(addr)"
|
||||
case .ipv6(let addr):
|
||||
hostString = "\(addr)"
|
||||
case .name(let name, _):
|
||||
hostString = name
|
||||
@unknown default:
|
||||
hostString = "\(host)"
|
||||
}
|
||||
// Strip interface scope suffix (e.g. "%en0")
|
||||
if let pct = hostString.firstIndex(of: "%") {
|
||||
hostString = String(hostString[..<pct])
|
||||
}
|
||||
let info = ConnectionInfo(
|
||||
host: hostString,
|
||||
port: Int(port.rawValue),
|
||||
nodeId: nil
|
||||
)
|
||||
connection.cancel()
|
||||
continuation.resume(returning: info)
|
||||
} else {
|
||||
connection.cancel()
|
||||
continuation.resume(returning: nil)
|
||||
}
|
||||
case .failed, .cancelled:
|
||||
continuation.resume(returning: nil)
|
||||
default:
|
||||
// Not a terminal state — allow future callbacks
|
||||
didResume.withLock { $0 = false }
|
||||
}
|
||||
}
|
||||
connection.start(queue: .global(qos: .userInitiated))
|
||||
}
|
||||
}
|
||||
}
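A minimal sketch of the browse, resolve, connect flow this service supports, using `connect(to:)` from `ClusterService` as it appears in `SettingsView` further down:

```swift
// Sketch only: discover a cluster over Bonjour, resolve its endpoint to
// host/port, and hand the result to ClusterService.
let discovery = DiscoveryService()
discovery.startBrowsing()

// Later, e.g. when the user picks a cluster in Settings:
if let cluster = discovery.discoveredClusters.first,
    let info = await discovery.resolve(cluster) {
    await clusterService.connect(to: info)
}
discovery.stopBrowsing()
```

The `OSAllocatedUnfairLock` around `didResume` matters here: `NWConnection` can invoke `stateUpdateHandler` multiple times, but a checked continuation must be resumed exactly once.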
201
app/EXO-iOS/EXO-iOS/Services/LocalInferenceService.swift
Normal file
@@ -0,0 +1,201 @@
import Foundation
import MLXLLM
import MLXLMCommon

enum LocalModelState: Equatable {
    case notDownloaded
    case downloading(progress: Double)
    case downloaded
    case loading
    case ready
    case generating
    case error(String)
}

@Observable
@MainActor
final class LocalInferenceService {
    private(set) var modelState: LocalModelState = .notDownloaded
    private var modelContainer: ModelContainer?
    private var generationTask: Task<Void, Never>?

    let defaultModelId = "mlx-community/Qwen3-0.6B-4bit"

    private static let modelDownloadedKey = "exo_local_model_downloaded"

    var isReady: Bool {
        modelState == .ready
    }

    var isAvailable: Bool {
        modelState == .ready || modelState == .generating
    }

    init() {
        if UserDefaults.standard.bool(forKey: Self.modelDownloadedKey) {
            modelState = .downloaded
        }
    }

    // MARK: - Model Lifecycle

    func prepareModel() async {
        guard modelState == .notDownloaded || modelState == .downloaded else { return }

        let wasDownloaded = modelState == .downloaded

        if !wasDownloaded {
            modelState = .downloading(progress: 0)
        } else {
            modelState = .loading
        }

        do {
            let container = try await loadModelContainer(
                id: defaultModelId
            ) { [weak self] progress in
                guard let self else { return }
                Task { @MainActor in
                    if case .downloading = self.modelState {
                        self.modelState = .downloading(progress: progress.fractionCompleted)
                    }
                }
            }

            self.modelContainer = container
            UserDefaults.standard.set(true, forKey: Self.modelDownloadedKey)
            modelState = .ready
        } catch {
            modelState = .error(error.localizedDescription)
        }
    }

    func unloadModel() {
        cancelGeneration()
        modelContainer = nil
        modelState = .downloaded
    }

    // MARK: - Generation

    func streamChatCompletion(request: ChatCompletionRequest) -> AsyncThrowingStream<
        ChatCompletionChunk, Error
    > {
        AsyncThrowingStream { continuation in
            let task = Task { [weak self] in
                guard let self else {
                    continuation.finish(throwing: LocalInferenceError.serviceUnavailable)
                    return
                }

                guard let container = self.modelContainer else {
                    continuation.finish(throwing: LocalInferenceError.modelNotLoaded)
                    return
                }

                await MainActor.run {
                    self.modelState = .generating
                }

                defer {
                    Task { @MainActor [weak self] in
                        if self?.modelState == .generating {
                            self?.modelState = .ready
                        }
                    }
                }

                let chunkId = "local-\(UUID().uuidString)"

                do {
                    // Build Chat.Message array from the request
                    var chatMessages: [Chat.Message] = []
                    for msg in request.messages {
                        switch msg.role {
                        case "system":
                            chatMessages.append(.system(msg.content))
                        case "assistant":
                            chatMessages.append(.assistant(msg.content))
                        default:
                            chatMessages.append(.user(msg.content))
                        }
                    }

                    // Use ChatSession for streaming generation
                    let session = ChatSession(
                        container,
                        history: chatMessages,
                        generateParameters: GenerateParameters(
                            maxTokens: request.maxTokens ?? 4096,
                            temperature: Float(request.temperature ?? 0.7)
                        )
                    )

                    // Stream with an empty prompt since history already contains the conversation
                    let stream = session.streamResponse(to: "")
                    for try await text in stream {
                        if Task.isCancelled { break }

                        let chunk = ChatCompletionChunk(
                            id: chunkId,
                            model: request.model,
                            choices: [
                                StreamingChoice(
                                    index: 0,
                                    delta: Delta(role: nil, content: text),
                                    finishReason: nil
                                )
                            ],
                            usage: nil
                        )
                        continuation.yield(chunk)
                    }

                    // Send final chunk with finish reason
                    let finalChunk = ChatCompletionChunk(
                        id: chunkId,
                        model: request.model,
                        choices: [
                            StreamingChoice(
                                index: 0,
                                delta: Delta(role: nil, content: nil),
                                finishReason: "stop"
                            )
                        ],
                        usage: nil
                    )
                    continuation.yield(finalChunk)
                    continuation.finish()
                } catch {
                    continuation.finish(throwing: error)
                }
            }

            self.generationTask = task

            continuation.onTermination = { _ in
                task.cancel()
            }
        }
    }

    func cancelGeneration() {
        generationTask?.cancel()
        generationTask = nil
        if modelState == .generating {
            modelState = .ready
        }
    }
}

enum LocalInferenceError: LocalizedError {
    case serviceUnavailable
    case modelNotLoaded

    var errorDescription: String? {
        switch self {
        case .serviceUnavailable: "Local inference service is unavailable"
        case .modelNotLoaded: "Local model is not loaded"
        }
    }
}
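A minimal sketch of driving this lifecycle end to end. `LocalInferenceService` is `@MainActor`, so the caller runs in a main-actor context; the `ChatCompletionRequest` value is assumed to be built elsewhere:

```swift
// Sketch only: download/load the model once, then stream a reply on-device.
@MainActor
func runLocally(_ request: ChatCompletionRequest) async throws {
    let local = LocalInferenceService()
    await local.prepareModel()  // .downloading -> .loading -> .ready

    guard local.isReady else { return }
    for try await chunk in local.streamChatCompletion(request: request) {
        if let text = chunk.choices.first?.delta.content {
            print(text, terminator: "")
        }
    }
}
```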
50
app/EXO-iOS/EXO-iOS/Services/SSEStreamParser.swift
Normal file
@@ -0,0 +1,50 @@
import Foundation

struct SSEStreamParser<T: Decodable>: AsyncSequence {
    typealias Element = T

    let bytes: URLSession.AsyncBytes
    let decoder: JSONDecoder

    init(bytes: URLSession.AsyncBytes, decoder: JSONDecoder = JSONDecoder()) {
        self.bytes = bytes
        self.decoder = decoder
    }

    func makeAsyncIterator() -> AsyncIterator {
        AsyncIterator(lines: bytes.lines, decoder: decoder)
    }

    struct AsyncIterator: AsyncIteratorProtocol {
        var lines: AsyncLineSequence<URLSession.AsyncBytes>.AsyncIterator
        let decoder: JSONDecoder

        init(lines: AsyncLineSequence<URLSession.AsyncBytes>, decoder: JSONDecoder) {
            self.lines = lines.makeAsyncIterator()
            self.decoder = decoder
        }

        mutating func next() async throws -> T? {
            while let line = try await lines.next() {
                let trimmed = line.trimmingCharacters(in: .whitespaces)

                guard trimmed.hasPrefix("data: ") else { continue }

                let payload = String(trimmed.dropFirst(6))

                if payload == "[DONE]" {
                    return nil
                }

                guard let data = payload.data(using: .utf8) else { continue }

                do {
                    return try decoder.decode(T.self, from: data)
                } catch {
                    continue
                }
            }
            return nil
        }
    }
}
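For reference, the wire format this parser consumes is the OpenAI-style SSE stream: one `data: ` line per JSON chunk, terminated by `data: [DONE]`. A sketch of wiring it to a `URLSession` byte stream (`url` and `handle` are placeholders):

```swift
// Sketch only: each "data: {...}" line decodes to one ChatCompletionChunk;
// "data: [DONE]" ends iteration; undecodable lines are silently skipped.
let (bytes, _) = try await URLSession.shared.bytes(from: url)
for try await chunk in SSEStreamParser<ChatCompletionChunk>(bytes: bytes) {
    handle(chunk)
}
```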
203
app/EXO-iOS/EXO-iOS/Views/Chat/ChatView.swift
Normal file
@@ -0,0 +1,203 @@
import SwiftUI

struct ChatView: View {
    @Environment(ClusterService.self) private var clusterService
    @Environment(ChatService.self) private var chatService
    @Environment(LocalInferenceService.self) private var localInferenceService
    @State private var inputText = ""
    @State private var showModelSelector = false

    var body: some View {
        VStack(spacing: 0) {
            modelBar

            GradientDivider()

            messageList

            GradientDivider()

            inputBar
        }
        .background(Color.exoBlack)
        .sheet(isPresented: $showModelSelector) {
            ModelSelectorView(
                models: clusterService.availableModels,
                selectedModelId: chatService.activeConversation?.modelId
            ) { modelId in
                chatService.setModelForActiveConversation(modelId)
            }
            .presentationBackground(Color.exoDarkGray)
        }
    }

    // MARK: - Model Bar

    private var useLocalModel: Bool {
        !clusterService.isConnected && localInferenceService.isAvailable
    }

    private var modelBar: some View {
        Button {
            if !useLocalModel {
                showModelSelector = true
            }
        } label: {
            HStack {
                Image(systemName: useLocalModel ? "iphone" : "cpu")
                    .font(.exoCaption)
                    .foregroundStyle(useLocalModel ? Color.exoYellow : Color.exoLightGray)

                if useLocalModel {
                    Text(localInferenceService.defaultModelId)
                        .font(.exoSubheadline)
                        .foregroundStyle(Color.exoForeground)
                        .lineLimit(1)
                } else if let modelId = chatService.activeConversation?.modelId {
                    Text(modelId)
                        .font(.exoSubheadline)
                        .foregroundStyle(Color.exoForeground)
                        .lineLimit(1)
                } else {
                    Text("SELECT MODEL")
                        .font(.exoSubheadline)
                        .tracking(1.5)
                        .foregroundStyle(Color.exoLightGray)
                }

                Spacer()

                if useLocalModel {
                    Text("ON-DEVICE")
                        .font(.exoCaption)
                        .tracking(1)
                        .foregroundStyle(Color.exoYellow)
                        .padding(.horizontal, 6)
                        .padding(.vertical, 2)
                        .background(Color.exoYellow.opacity(0.15))
                        .clipShape(Capsule())
                } else {
                    Image(systemName: "chevron.right")
                        .font(.caption)
                        .foregroundStyle(Color.exoLightGray)
                }
            }
            .padding(.horizontal)
            .padding(.vertical, 10)
            .background(Color.exoDarkGray)
        }
        .tint(.primary)
        .disabled(useLocalModel)
    }

    // MARK: - Messages

    private var messageList: some View {
        ScrollViewReader { proxy in
            ScrollView {
                LazyVStack(spacing: 12) {
                    if chatService.activeMessages.isEmpty {
                        emptyState
                    } else {
                        ForEach(chatService.activeMessages) { message in
                            MessageBubbleView(message: message)
                                .id(message.id)
                        }
                    }
                }
                .padding()
            }
            .background(Color.exoBlack)
            .onChange(of: chatService.activeMessages.last?.content) {
                if let lastId = chatService.activeMessages.last?.id {
                    withAnimation(.easeOut(duration: 0.2)) {
                        proxy.scrollTo(lastId, anchor: .bottom)
                    }
                }
            }
        }
    }

    private var emptyState: some View {
        VStack(spacing: 16) {
            Spacer(minLength: 80)

            ZStack {
                Circle()
                    .stroke(Color.exoYellow.opacity(0.15), lineWidth: 1)
                    .frame(width: 80, height: 80)
                Circle()
                    .stroke(Color.exoYellow.opacity(0.3), lineWidth: 1)
                    .frame(width: 56, height: 56)
                Circle()
                    .fill(Color.exoYellow.opacity(0.15))
                    .frame(width: 32, height: 32)
                Circle()
                    .fill(Color.exoYellow)
                    .frame(width: 8, height: 8)
                    .shadow(color: Color.exoYellow.opacity(0.6), radius: 6)
            }

            Text("AWAITING INPUT")
                .font(.exoSubheadline)
                .tracking(3)
                .foregroundStyle(Color.exoLightGray)

            Text("Send a message to begin.")
                .font(.exoCaption)
                .foregroundStyle(Color.exoLightGray.opacity(0.6))

            Spacer(minLength: 80)
        }
        .padding()
    }

    // MARK: - Input

    private var inputBar: some View {
        HStack(alignment: .bottom, spacing: 8) {
            TextField("Message...", text: $inputText, axis: .vertical)
                .font(.exoBody)
                .lineLimit(1...6)
                .textFieldStyle(.plain)
                .padding(10)
                .background(Color.exoMediumGray)
                .foregroundStyle(Color.exoForeground)
                .clipShape(RoundedRectangle(cornerRadius: 8))

            if chatService.isGenerating {
                Button {
                    chatService.cancelGeneration()
                } label: {
                    Image(systemName: "stop.circle.fill")
                        .font(.title2)
                        .foregroundStyle(Color.exoDestructive)
                }
            } else {
                Button {
                    let text = inputText
                    inputText = ""
                    chatService.sendMessage(text)
                } label: {
                    Text("SEND")
                        .font(.exoMono(12, weight: .bold))
                        .tracking(1)
                        .foregroundStyle(canSend ? Color.exoBlack : Color.exoLightGray)
                        .padding(.horizontal, 14)
                        .padding(.vertical, 8)
                        .background(canSend ? Color.exoYellow : Color.exoMediumGray)
                        .clipShape(RoundedRectangle(cornerRadius: 8))
                }
                .disabled(!canSend)
            }
        }
        .padding(.horizontal)
        .padding(.vertical, 8)
        .background(Color.exoDarkGray)
    }

    private var canSend: Bool {
        !inputText.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty
            && (clusterService.isConnected || localInferenceService.isAvailable)
    }
}
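`ChatService` itself is not part of this excerpt, but the `useLocalModel` and `canSend` gating above implies the send path selects a backend roughly like this (a hypothetical sketch, not this branch's actual implementation):

```swift
// Hypothetical: prefer the cluster when connected, else fall back on-device.
func backendStream(
    for request: ChatCompletionRequest
) -> AsyncThrowingStream<ChatCompletionChunk, Error> {
    if clusterService.isConnected {
        return clusterService.streamChatCompletion(request: request)
    }
    return localInferenceService.streamChatCompletion(request: request)
}
```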
54
app/EXO-iOS/EXO-iOS/Views/Chat/MessageBubbleView.swift
Normal file
@@ -0,0 +1,54 @@
import SwiftUI

struct MessageBubbleView: View {
    let message: ChatMessage

    private var isAssistant: Bool { message.role == .assistant }

    var body: some View {
        HStack {
            if message.role == .user { Spacer(minLength: 48) }

            VStack(alignment: isAssistant ? .leading : .trailing, spacing: 6) {
                // Header
                HStack(spacing: 4) {
                    if isAssistant {
                        Circle()
                            .fill(Color.exoYellow)
                            .frame(width: 6, height: 6)
                            .shadow(color: Color.exoYellow.opacity(0.6), radius: 4)
                        Text("EXO")
                            .font(.exoMono(10, weight: .bold))
                            .tracking(1.5)
                            .foregroundStyle(Color.exoYellow)
                    } else {
                        Text("QUERY")
                            .font(.exoMono(10, weight: .medium))
                            .tracking(1.5)
                            .foregroundStyle(Color.exoLightGray)
                    }
                }

                // Bubble
                HStack(spacing: 0) {
                    if isAssistant {
                        RoundedRectangle(cornerRadius: 1)
                            .fill(Color.exoYellow.opacity(0.5))
                            .frame(width: 2)
                    }

                    Text(message.content + (message.isStreaming ? " \u{258C}" : ""))
                        .font(.exoBody)
                        .textSelection(.enabled)
                        .foregroundStyle(Color.exoForeground)
                        .padding(.horizontal, 14)
                        .padding(.vertical, 10)
                }
                .background(Color.exoDarkGray)
                .clipShape(RoundedRectangle(cornerRadius: 8))
            }

            if isAssistant { Spacer(minLength: 48) }
        }
    }
}
75
app/EXO-iOS/EXO-iOS/Views/Chat/ModelSelectorView.swift
Normal file
@@ -0,0 +1,75 @@
import SwiftUI

struct ModelSelectorView: View {
    let models: [ModelOption]
    let selectedModelId: String?
    let onSelect: (String) -> Void
    @Environment(\.dismiss) private var dismiss

    var body: some View {
        NavigationStack {
            List {
                if models.isEmpty {
                    emptyContent
                } else {
                    modelsList
                }
            }
            .scrollContentBackground(.hidden)
            .background(Color.exoBlack)
            .navigationTitle("SELECT MODEL")
            .navigationBarTitleDisplayMode(.inline)
            .toolbar {
                ToolbarItem(placement: .cancellationAction) {
                    Button("Cancel") { dismiss() }
                        .font(.exoSubheadline)
                        .foregroundStyle(Color.exoYellow)
                }
            }
        }
    }

    private var emptyContent: some View {
        ContentUnavailableView(
            "No Models Available",
            systemImage: "cpu",
            description: Text("Connect to an EXO cluster to see available models.")
        )
        .foregroundStyle(Color.exoLightGray)
        .listRowBackground(Color.exoBlack)
    }

    private var modelsList: some View {
        ForEach(models) { model in
            Button {
                onSelect(model.id)
                dismiss()
            } label: {
                modelRow(model)
            }
            .tint(.primary)
            .listRowBackground(Color.exoDarkGray)
        }
    }

    private func modelRow(_ model: ModelOption) -> some View {
        HStack {
            VStack(alignment: .leading, spacing: 2) {
                Text(model.displayName)
                    .font(.exoSubheadline)
                    .foregroundStyle(Color.exoForeground)
                Text(model.id)
                    .font(.exoCaption)
                    .foregroundStyle(Color.exoLightGray)
            }

            Spacer()

            if model.id == selectedModelId {
                Image(systemName: "checkmark")
                    .font(.exoSubheadline)
                    .foregroundStyle(Color.exoYellow)
            }
        }
    }
}
@@ -0,0 +1,68 @@
import SwiftUI

struct ConnectionStatusBadge: View {
    let connectionState: ConnectionState
    var localModelState: LocalModelState = .notDownloaded

    private var isLocalReady: Bool {
        if case .disconnected = connectionState {
            return localModelState == .ready || localModelState == .generating
        }
        return false
    }

    var body: some View {
        HStack(spacing: 6) {
            Circle()
                .fill(dotColor)
                .frame(width: 8, height: 8)
                .shadow(color: dotColor.opacity(0.6), radius: 4)

            Text(label.uppercased())
                .font(.exoMono(10, weight: .medium))
                .tracking(1)
                .foregroundStyle(Color.exoForeground)
        }
        .padding(.horizontal, 10)
        .padding(.vertical, 5)
        .background(backgroundColor)
        .clipShape(Capsule())
        .overlay(
            Capsule()
                .stroke(dotColor.opacity(0.3), lineWidth: 1)
        )
    }

    private var dotColor: Color {
        if isLocalReady {
            return .exoYellow
        }
        switch connectionState {
        case .connected: return .green
        case .connecting: return .orange
        case .disconnected: return .exoLightGray
        }
    }

    private var label: String {
        if isLocalReady {
            return "Local"
        }
        switch connectionState {
        case .connected: return "Connected"
        case .connecting: return "Connecting"
        case .disconnected: return "Disconnected"
        }
    }

    private var backgroundColor: Color {
        if isLocalReady {
            return Color.exoYellow.opacity(0.1)
        }
        switch connectionState {
        case .connected: return .green.opacity(0.1)
        case .connecting: return .orange.opacity(0.1)
        case .disconnected: return Color.exoMediumGray.opacity(0.5)
        }
    }
}
136
app/EXO-iOS/EXO-iOS/Views/RootView.swift
Normal file
@@ -0,0 +1,136 @@
import SwiftUI

struct RootView: View {
    @Environment(ClusterService.self) private var clusterService
    @Environment(DiscoveryService.self) private var discoveryService
    @Environment(ChatService.self) private var chatService
    @Environment(LocalInferenceService.self) private var localInferenceService
    @State private var showSettings = false
    @State private var showConversations = false

    var body: some View {
        NavigationStack {
            ChatView()
                .navigationTitle("EXO")
                .navigationBarTitleDisplayMode(.inline)
                .toolbar {
                    ToolbarItem(placement: .topBarLeading) {
                        conversationMenuButton
                    }

                    ToolbarItem(placement: .principal) {
                        ConnectionStatusBadge(
                            connectionState: clusterService.connectionState,
                            localModelState: localInferenceService.modelState
                        )
                    }

                    ToolbarItem(placement: .topBarTrailing) {
                        Button {
                            showSettings = true
                        } label: {
                            Image(systemName: "gear")
                                .foregroundStyle(Color.exoYellow)
                        }
                    }
                }
        }
        .tint(Color.exoYellow)
        .sheet(isPresented: $showSettings) {
            SettingsView()
                .environment(discoveryService)
                .presentationBackground(Color.exoDarkGray)
        }
        .sheet(isPresented: $showConversations) {
            conversationList
                .presentationBackground(Color.exoDarkGray)
        }
    }

    // MARK: - Conversations

    private var conversationMenuButton: some View {
        HStack(spacing: 12) {
            Button {
                showConversations = true
            } label: {
                Image(systemName: "sidebar.left")
                    .foregroundStyle(Color.exoYellow)
            }

            Button {
                chatService.createConversation()
            } label: {
                Image(systemName: "square.and.pencil")
                    .foregroundStyle(Color.exoYellow)
            }
        }
    }

    private var conversationList: some View {
        NavigationStack {
            List {
                if chatService.conversations.isEmpty {
                    Text("No conversations yet")
                        .font(.exoBody)
                        .foregroundStyle(Color.exoLightGray)
                        .listRowBackground(Color.exoDarkGray)
                } else {
                    ForEach(chatService.conversations) { conversation in
                        let isActive = conversation.id == chatService.activeConversationId
                        Button {
                            chatService.setActiveConversation(id: conversation.id)
                            showConversations = false
                        } label: {
                            VStack(alignment: .leading, spacing: 4) {
                                Text(conversation.title)
                                    .font(.exoSubheadline)
                                    .fontWeight(isActive ? .semibold : .regular)
                                    .foregroundStyle(
                                        isActive ? Color.exoYellow : Color.exoForeground
                                    )
                                    .lineLimit(1)

                                if let modelId = conversation.modelId {
                                    Text(modelId)
                                        .font(.exoCaption)
                                        .foregroundStyle(Color.exoLightGray)
                                        .lineLimit(1)
                                }
                            }
                        }
                        .listRowBackground(
                            isActive
                                ? Color.exoYellow.opacity(0.1)
                                : Color.exoDarkGray
                        )
                    }
                    .onDelete { indexSet in
                        for index in indexSet {
                            chatService.deleteConversation(id: chatService.conversations[index].id)
                        }
                    }
                }
            }
            .scrollContentBackground(.hidden)
            .background(Color.exoBlack)
            .navigationTitle("Conversations")
            .navigationBarTitleDisplayMode(.inline)
            .toolbar {
                ToolbarItem(placement: .confirmationAction) {
                    Button("Done") { showConversations = false }
                        .font(.exoSubheadline)
                        .foregroundStyle(Color.exoYellow)
                }
                ToolbarItem(placement: .topBarLeading) {
                    Button {
                        chatService.createConversation()
                    } label: {
                        Image(systemName: "plus")
                            .foregroundStyle(Color.exoYellow)
                    }
                }
            }
        }
    }
}
314
app/EXO-iOS/EXO-iOS/Views/Settings/SettingsView.swift
Normal file
@@ -0,0 +1,314 @@
import SwiftUI

struct SettingsView: View {
    @Environment(ClusterService.self) private var clusterService
    @Environment(DiscoveryService.self) private var discoveryService
    @Environment(LocalInferenceService.self) private var localInferenceService
    @Environment(\.dismiss) private var dismiss
    @State private var host: String = ""
    @State private var port: String = "52415"

    var body: some View {
        NavigationStack {
            Form {
                localModelSection
                nearbyClustersSection
                connectionSection
                statusSection
            }
            .scrollContentBackground(.hidden)
            .background(Color.exoBlack)
            .navigationTitle("Settings")
            .navigationBarTitleDisplayMode(.inline)
            .toolbar {
                ToolbarItem(placement: .confirmationAction) {
                    Button("Done") { dismiss() }
                        .font(.exoSubheadline)
                        .foregroundStyle(Color.exoYellow)
                }
            }
        }
    }

    // MARK: - Section Headers

    private func sectionHeader(_ title: String) -> some View {
        Text(title.uppercased())
            .font(.exoMono(10, weight: .semibold))
            .tracking(2)
            .foregroundStyle(Color.exoYellow)
    }

    // MARK: - Local Model

    private var localModelSection: some View {
        Section {
            HStack {
                VStack(alignment: .leading, spacing: 4) {
                    Text(localInferenceService.defaultModelId)
                        .font(.exoSubheadline)
                        .foregroundStyle(Color.exoForeground)

                    Text(localModelStatusText)
                        .font(.exoCaption)
                        .foregroundStyle(Color.exoLightGray)
                }

                Spacer()

                localModelActionButton
            }
            .listRowBackground(Color.exoDarkGray)

            if case .downloading(let progress) = localInferenceService.modelState {
                ProgressView(value: progress)
                    .tint(Color.exoYellow)
                    .listRowBackground(Color.exoDarkGray)
            }
        } header: {
            sectionHeader("Local Model")
        } footer: {
            Text(
                "When disconnected from a cluster, messages are processed on-device using this model."
            )
            .font(.exoCaption)
            .foregroundStyle(Color.exoLightGray.opacity(0.7))
        }
    }

    private var localModelStatusText: String {
        switch localInferenceService.modelState {
        case .notDownloaded: "Not downloaded"
        case .downloading(let progress): "Downloading \(Int(progress * 100))%..."
        case .downloaded: "Downloaded — not loaded"
        case .loading: "Loading into memory..."
        case .ready: "Ready"
        case .generating: "Generating..."
        case .error(let message): "Error: \(message)"
        }
    }

    @ViewBuilder
    private var localModelActionButton: some View {
        switch localInferenceService.modelState {
        case .notDownloaded:
            exoButton("Download") {
                Task { await localInferenceService.prepareModel() }
            }
        case .downloading:
            ProgressView()
                .controlSize(.small)
                .tint(Color.exoYellow)
        case .downloaded:
            exoButton("Load") {
                Task { await localInferenceService.prepareModel() }
            }
        case .loading:
            ProgressView()
                .controlSize(.small)
                .tint(Color.exoYellow)
        case .ready, .generating:
            exoButton("Unload") {
                localInferenceService.unloadModel()
            }
        case .error:
            exoButton("Retry", destructive: true) {
                Task { await localInferenceService.prepareModel() }
            }
        }
    }

    private func exoButton(_ title: String, destructive: Bool = false, action: @escaping () -> Void)
        -> some View
    {
        let borderColor = destructive ? Color.exoDestructive : Color.exoYellow
        return Button(action: action) {
            Text(title.uppercased())
                .font(.exoMono(11, weight: .semibold))
                .tracking(1)
                .foregroundStyle(borderColor)
                .padding(.horizontal, 10)
                .padding(.vertical, 5)
                .overlay(
                    RoundedRectangle(cornerRadius: 6)
                        .stroke(borderColor, lineWidth: 1)
                )
        }
    }

    // MARK: - Nearby Clusters

    private var nearbyClustersSection: some View {
        Section {
            if discoveryService.discoveredClusters.isEmpty {
                if discoveryService.isSearching {
                    HStack {
                        ProgressView()
                            .tint(Color.exoYellow)
                            .padding(.trailing, 8)
                        Text("Searching for clusters...")
                            .font(.exoBody)
                            .foregroundStyle(Color.exoLightGray)
                    }
                    .listRowBackground(Color.exoDarkGray)
                } else {
                    Text("No clusters found")
                        .font(.exoBody)
                        .foregroundStyle(Color.exoLightGray)
                        .listRowBackground(Color.exoDarkGray)
                }
            } else {
                ForEach(discoveryService.discoveredClusters) { cluster in
                    HStack {
                        VStack(alignment: .leading) {
                            Text(cluster.name)
                                .font(.exoBody)
                                .foregroundStyle(Color.exoForeground)
                        }
                        Spacer()
                        exoButton("Connect") {
                            Task {
                                await clusterService.connectToDiscoveredCluster(
                                    cluster, using: discoveryService
                                )
                                if clusterService.isConnected {
                                    dismiss()
                                }
                            }
                        }
                    }
                    .listRowBackground(Color.exoDarkGray)
                }
            }
        } header: {
            sectionHeader("Nearby Clusters")
        }
    }

    // MARK: - Manual Connection

    private var connectionSection: some View {
        Section {
            TextField("IP Address (e.g. 192.168.1.42)", text: $host)
                .font(.exoBody)
                .keyboardType(.decimalPad)
                .textContentType(.URL)
                .autocorrectionDisabled()
                .foregroundStyle(Color.exoForeground)
                .listRowBackground(Color.exoDarkGray)

            TextField("Port", text: $port)
                .font(.exoBody)
                .keyboardType(.numberPad)
                .foregroundStyle(Color.exoForeground)
                .listRowBackground(Color.exoDarkGray)

            Button {
                Task {
                    let portNum = Int(port) ?? ConnectionInfo.defaultPort
                    let info = ConnectionInfo(host: host, port: portNum, nodeId: nil)
                    await clusterService.connect(to: info)
                    if clusterService.isConnected {
                        dismiss()
                    }
                }
            } label: {
                Text(clusterService.isConnected ? "RECONNECT" : "CONNECT")
                    .font(.exoMono(13, weight: .semibold))
                    .tracking(1.5)
                    .foregroundStyle(
                        host.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty
                            ? Color.exoLightGray : Color.exoYellow
                    )
                    .frame(maxWidth: .infinity)
            }
            .disabled(host.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty)
            .listRowBackground(Color.exoDarkGray)
        } header: {
            sectionHeader("Manual Connection")
        }
    }

    // MARK: - Status

    private var statusSection: some View {
        Section {
            if let connection = clusterService.currentConnection {
                LabeledContent {
                    Text(connection.host)
                        .font(.exoCaption)
                        .foregroundStyle(Color.exoForeground)
                } label: {
                    Text("Host")
                        .font(.exoCaption)
                        .foregroundStyle(Color.exoLightGray)
                }
                .listRowBackground(Color.exoDarkGray)

                LabeledContent {
                    Text("\(connection.port)")
                        .font(.exoCaption)
                        .foregroundStyle(Color.exoForeground)
                } label: {
                    Text("Port")
                        .font(.exoCaption)
                        .foregroundStyle(Color.exoLightGray)
                }
                .listRowBackground(Color.exoDarkGray)

                if let nodeId = connection.nodeId {
                    LabeledContent {
                        Text(String(nodeId.prefix(12)) + "...")
                            .font(.exoCaption)
                            .foregroundStyle(Color.exoForeground)
                    } label: {
                        Text("Node ID")
                            .font(.exoCaption)
                            .foregroundStyle(Color.exoLightGray)
                    }
                    .listRowBackground(Color.exoDarkGray)
                }

                LabeledContent {
                    Text("\(clusterService.availableModels.count)")
                        .font(.exoCaption)
                        .foregroundStyle(Color.exoForeground)
                } label: {
                    Text("Models")
                        .font(.exoCaption)
                        .foregroundStyle(Color.exoLightGray)
                }
                .listRowBackground(Color.exoDarkGray)

                Button(role: .destructive) {
                    clusterService.disconnect()
                } label: {
                    Text("DISCONNECT")
                        .font(.exoMono(13, weight: .semibold))
                        .tracking(1.5)
                        .foregroundStyle(Color.exoDestructive)
                        .frame(maxWidth: .infinity)
                }
                .listRowBackground(Color.exoDarkGray)
            } else {
                if let error = clusterService.lastError {
                    Label {
                        Text(error)
                            .font(.exoCaption)
                    } icon: {
                        Image(systemName: "exclamationmark.triangle")
                    }
                    .foregroundStyle(Color.exoDestructive)
                    .listRowBackground(Color.exoDarkGray)
                } else {
                    Text("Not connected")
                        .font(.exoBody)
                        .foregroundStyle(Color.exoLightGray)
                        .listRowBackground(Color.exoDarkGray)
                }
            }
        } header: {
            sectionHeader("Status")
        }
    }
}
51
app/EXO-iOS/EXO-iOS/Views/Theme/EXOTheme.swift
Normal file
@@ -0,0 +1,51 @@
import SwiftUI

// MARK: - EXO Color Palette

extension Color {
    /// Primary background — near-black (#121212)
    static let exoBlack = Color(red: 0x12 / 255.0, green: 0x12 / 255.0, blue: 0x12 / 255.0)
    /// Card / surface background (#1F1F1F)
    static let exoDarkGray = Color(red: 0x1F / 255.0, green: 0x1F / 255.0, blue: 0x1F / 255.0)
    /// Input field / elevated surface (#353535)
    static let exoMediumGray = Color(red: 0x35 / 255.0, green: 0x35 / 255.0, blue: 0x35 / 255.0)
    /// Secondary text (#999999)
    static let exoLightGray = Color(red: 0x99 / 255.0, green: 0x99 / 255.0, blue: 0x99 / 255.0)
    /// Accent yellow — matches dashboard (#FFD700)
    static let exoYellow = Color(red: 0xFF / 255.0, green: 0xD7 / 255.0, blue: 0x00 / 255.0)
    /// Primary foreground text (#E5E5E5)
    static let exoForeground = Color(red: 0xE5 / 255.0, green: 0xE5 / 255.0, blue: 0xE5 / 255.0)
    /// Destructive / error (#E74C3C)
    static let exoDestructive = Color(red: 0xE7 / 255.0, green: 0x4C / 255.0, blue: 0x3C / 255.0)
}

// MARK: - EXO Typography (SF Mono via .monospaced design)

extension Font {
    /// Monospaced font at a given size and weight.
    static func exoMono(_ size: CGFloat, weight: Font.Weight = .regular) -> Font {
        .system(size: size, weight: weight, design: .monospaced)
    }

    /// Body text — 15pt monospaced
    static let exoBody: Font = .system(size: 15, weight: .regular, design: .monospaced)
    /// Caption — 11pt monospaced
    static let exoCaption: Font = .system(size: 11, weight: .regular, design: .monospaced)
    /// Subheadline — 13pt monospaced medium
    static let exoSubheadline: Font = .system(size: 13, weight: .medium, design: .monospaced)
    /// Headline — 17pt monospaced semibold
    static let exoHeadline: Font = .system(size: 17, weight: .semibold, design: .monospaced)
}

// MARK: - Reusable Gradient Divider

struct GradientDivider: View {
    var body: some View {
        LinearGradient(
            colors: [.clear, Color.exoYellow.opacity(0.3), .clear],
            startPoint: .leading,
            endPoint: .trailing
        )
        .frame(height: 1)
    }
}
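A small usage sketch of the palette and type ramp defined above, in the badge style the app's views use:

```swift
// Sketch only: combining the theme's mono font, accent color, and surfaces.
Text("ON-DEVICE")
    .font(.exoMono(10, weight: .medium))
    .foregroundStyle(Color.exoYellow)
    .padding(.horizontal, 10)
    .padding(.vertical, 5)
    .background(Color.exoDarkGray)
    .clipShape(Capsule())
```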
18
app/EXO-iOS/EXO-iOSTests/EXO_iOSTests.swift
Normal file
@@ -0,0 +1,18 @@
//
//  EXO_iOSTests.swift
//  EXO-iOSTests
//
//  Created by Sami Khan on 2026-02-17.
//

import Testing

@testable import EXO_iOS

struct EXO_iOSTests {

    @Test func example() async throws {
        // Write your test here and use APIs like `#expect(...)` to check expected conditions.
    }

}
41
app/EXO-iOS/EXO-iOSUITests/EXO_iOSUITests.swift
Normal file
@@ -0,0 +1,41 @@
//
//  EXO_iOSUITests.swift
//  EXO-iOSUITests
//
//  Created by Sami Khan on 2026-02-17.
//

import XCTest

final class EXO_iOSUITests: XCTestCase {

    override func setUpWithError() throws {
        // Put setup code here. This method is called before the invocation of each test method in the class.

        // In UI tests it is usually best to stop immediately when a failure occurs.
        continueAfterFailure = false

        // In UI tests it's important to set the initial state - such as interface orientation - required for your tests before they run. The setUp method is a good place to do this.
    }

    override func tearDownWithError() throws {
        // Put teardown code here. This method is called after the invocation of each test method in the class.
    }

    @MainActor
    func testExample() throws {
        // UI tests must launch the application that they test.
        let app = XCUIApplication()
        app.launch()

        // Use XCTAssert and related functions to verify your tests produce the correct results.
    }

    @MainActor
    func testLaunchPerformance() throws {
        // This measures how long it takes to launch your application.
        measure(metrics: [XCTApplicationLaunchMetric()]) {
            XCUIApplication().launch()
        }
    }
}
33
app/EXO-iOS/EXO-iOSUITests/EXO_iOSUITestsLaunchTests.swift
Normal file
@@ -0,0 +1,33 @@
//
//  EXO_iOSUITestsLaunchTests.swift
//  EXO-iOSUITests
//
//  Created by Sami Khan on 2026-02-17.
//

import XCTest

final class EXO_iOSUITestsLaunchTests: XCTestCase {

    override class var runsForEachTargetApplicationUIConfiguration: Bool {
        true
    }

    override func setUpWithError() throws {
        continueAfterFailure = false
    }

    @MainActor
    func testLaunch() throws {
        let app = XCUIApplication()
        app.launch()

        // Insert steps here to perform after app launch but before taking a screenshot,
        // such as logging into a test account or navigating somewhere in the app

        let attachment = XCTAttachment(screenshot: app.screenshot())
        attachment.name = "Launch Screen"
        attachment.lifetime = .keepAlways
        add(attachment)
    }
}
@@ -126,37 +126,11 @@ final class ExoProcessController: ObservableObject {
            return
        }
        process.terminationHandler = nil
        status = .stopped

        guard process.isRunning else {
            self.process = nil
            return
        if process.isRunning {
            process.terminate()
        }

        let proc = process
        self.process = nil

        Task.detached {
            proc.interrupt()

            for _ in 0..<50 {
                if !proc.isRunning { return }
                try? await Task.sleep(nanoseconds: 100_000_000)
            }

            if proc.isRunning {
                proc.terminate()
            }

            for _ in 0..<30 {
                if !proc.isRunning { return }
                try? await Task.sleep(nanoseconds: 100_000_000)
            }

            if proc.isRunning {
                kill(proc.processIdentifier, SIGKILL)
            }
        }
        status = .stopped
    }

    func restart() {

@@ -185,7 +185,11 @@

    let instanceType: string | null = null;
    if (instanceTag === "MlxRingInstance") instanceType = "MLX Ring";
    else if (instanceTag === "MlxJacclInstance") instanceType = "MLX RDMA";
    else if (
      instanceTag === "MlxIbvInstance" ||
      instanceTag === "MlxJacclInstance"
    )
      instanceType = "MLX RDMA";

    let sharding: string | null = null;
    const inst = instance as {

@@ -21,7 +21,7 @@
  } | null;
  nodes?: Record<string, NodeInfo>;
  sharding?: "Pipeline" | "Tensor";
  runtime?: "MlxRing" | "MlxJaccl";
  runtime?: "MlxRing" | "MlxIbv" | "MlxJaccl";
  onLaunch?: () => void;
  tags?: string[];
  apiPreview?: PlacementPreview | null;

@@ -348,7 +348,7 @@
  // Debug mode state
  const isDebugMode = $derived(debugMode());
  const topology = $derived(topologyData());
  const isRdma = $derived(runtime === "MlxJaccl");
  const isRdma = $derived(runtime === "MlxIbv" || runtime === "MlxJaccl");

  // Get interface name for an IP from node data
  function getInterfaceForIp(nodeId: string, ip?: string): string | null {

@@ -575,7 +575,7 @@
  >
    {runtime === "MlxRing"
      ? "MLX Ring"
      : runtime === "MlxJaccl"
      : runtime === "MlxIbv" || runtime === "MlxJaccl"
        ? "MLX RDMA"
        : runtime}
  </span>

@@ -168,7 +168,7 @@ export interface ModelDownloadStatus {
export interface PlacementPreview {
  model_id: string;
  sharding: "Pipeline" | "Tensor";
  instance_meta: "MlxRing" | "MlxJaccl";
  instance_meta: "MlxRing" | "MlxIbv" | "MlxJaccl";
  instance: unknown | null;
  memory_delta_by_node: Record<string, number> | null;
  error: string | null;

@@ -219,6 +219,7 @@
    string,
    {
      MlxRingInstance?: Instance;
      MlxIbvInstance?: Instance;
      MlxJacclInstance?: Instance;
    }
  >;

@@ -249,20 +250,6 @@
  >;
  // Thunderbolt bridge cycles (nodes with bridge enabled forming loops)
  thunderboltBridgeCycles?: string[][];
  // MetaInstances (declarative instance constraints)
  metaInstances?: Record<string, MetaInstanceData>;
}

export interface MetaInstanceData {
  metaInstanceId: string;
  modelId: string;
  sharding: string;
  instanceMeta: string;
  minNodes: number;
  nodeIds: string[] | null;
  placementError: string | null;
  consecutiveFailures: number;
  lastFailureError: string | null;
}

export interface MessageAttachment {

@@ -550,7 +537,6 @@ class AppStore {
  previewNodeFilter = $state<Set<string>>(new Set());
  lastUpdate = $state<number | null>(null);
  nodeIdentities = $state<Record<string, RawNodeIdentity>>({});
  metaInstances = $state<Record<string, MetaInstanceData>>({});
  thunderboltBridgeCycles = $state<string[][]>([]);
  nodeThunderbolt = $state<
    Record<

@@ -909,7 +895,11 @@

    let instanceType: string | null = null;
    if (instanceTag === "MlxRingInstance") instanceType = "MLX Ring";
    else if (instanceTag === "MlxJacclInstance") instanceType = "MLX RDMA";
    else if (
      instanceTag === "MlxIbvInstance" ||
      instanceTag === "MlxJacclInstance"
    )
      instanceType = "MLX RDMA";

    let sharding: string | null = null;
    const inst = instance as {

@@ -1283,8 +1273,6 @@
    this.nodeThunderbolt = data.nodeThunderbolt ?? {};
    // RDMA ctl status per node
    this.nodeRdmaCtl = data.nodeRdmaCtl ?? {};
    // MetaInstances
    this.metaInstances = data.metaInstances ?? {};
    // Thunderbolt bridge cycles
    this.thunderboltBridgeCycles = data.thunderboltBridgeCycles ?? [];
    // Thunderbolt bridge status per node

@@ -3056,7 +3044,6 @@ export const tps = () => appStore.tps;
export const totalTokens = () => appStore.totalTokens;
export const topologyData = () => appStore.topologyData;
export const instances = () => appStore.instances;
export const metaInstances = () => appStore.metaInstances;
export const runners = () => appStore.runners;
export const downloads = () => appStore.downloads;
export const nodeDisk = () => appStore.nodeDisk;

File diff suppressed because it is too large
@@ -115,7 +115,7 @@
      packages = lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin (
        let
          uvLock = builtins.fromTOML (builtins.readFile ./uv.lock);
          mlxPackage = builtins.head (builtins.filter (p: p.name == "mlx" && p.source ? git) uvLock.package);
          mlxPackage = builtins.head (builtins.filter (p: p.name == "mlx") uvLock.package);
          uvLockMlxVersion = mlxPackage.version;
        in
        {

10
nix/mlx.nix
@@ -41,16 +41,16 @@ let

  mlx = stdenv.mkDerivation rec {
    pname = "mlx";
    version = let v = "0.30.7.dev20260217+50487b41"; in
    version = let v = "0.30.6"; in
      assert v == uvLockMlxVersion || throw "MLX version mismatch: nix/mlx.nix has ${v} but uv.lock has ${uvLockMlxVersion}. Update both the version and hash in nix/mlx.nix.";
      v;
    pyproject = true;

    src = fetchFromGitHub {
      owner = "rltakashige";
      repo = "mlx-jaccl-fix-small-recv";
      rev = "50487b4141f3c951122655db3b83df5146c1fbeb";
      hash = "sha256-IL4a9vMX5nocgJU1WG4zE8hArHkHJtnh4sdYh3od5zU=";
      owner = "ml-explore";
      repo = "mlx";
      tag = "v${version}";
      hash = "sha256-avD5EGhwgmPdXLAyQSqTO6AXk/W3ziH+f6AetjK3Sdo=";
    };

    patches = [

@@ -17,7 +17,7 @@ dependencies = [
  "loguru>=0.7.3",
  "exo_pyo3_bindings", # rust bindings
  "anyio==4.11.0",
  "mlx; sys_platform == 'darwin'",
  "mlx==0.30.6; sys_platform == 'darwin'",
  "mlx[cpu]==0.30.6; sys_platform == 'linux'",
  "mlx-lm==0.30.6",
  "tiktoken>=0.12.0", # required for kimi k2 tokenizer

@@ -64,7 +64,6 @@ members = [

[tool.uv.sources]
exo_pyo3_bindings = { workspace = true }
mlx = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git", branch = "address-rdma-gpu-locks", marker = "sys_platform == 'darwin'" }
#mlx-lm = { git = "https://github.com/davidmcc73/mlx-lm", branch = "stable" }
# Uncomment to use local mlx/mlx-lm development versions:
# mlx = { path = "/Users/Shared/mlx", editable=true }

@@ -58,21 +58,6 @@
      lib.optionalAttrs pkgs.stdenv.hostPlatform.isLinux (
        (lib.mapAttrs (_: ignoreMissing) nvidiaPackages) // {
          mlx = ignoreMissing prev.mlx;
          mlx-cuda-13 = prev.mlx-cuda-13.overrideAttrs (old: {
            buildInputs = (old.buildInputs or [ ]) ++ [
              final.nvidia-cublas
              final.nvidia-cuda-nvrtc
              final.nvidia-cudnn-cu13
              final.nvidia-nccl-cu13
            ];
            preFixup = ''
              addAutoPatchelfSearchPath ${final.nvidia-cublas}
              addAutoPatchelfSearchPath ${final.nvidia-cuda-nvrtc}
              addAutoPatchelfSearchPath ${final.nvidia-cudnn-cu13}
              addAutoPatchelfSearchPath ${final.nvidia-nccl-cu13}
            '';
            autoPatchelfIgnoreMissingDeps = [ "libcuda.so.1" ];
          });
          torch = ignoreMissing prev.torch;
          triton = ignoreMissing prev.triton;
        }

@@ -89,25 +74,14 @@
          linuxOverlay
        ]
      );
      # mlx-cpu and mlx-cuda-13 both ship mlx/ site-packages files; keep first.
      # mlx-cpu/mlx-cuda-13 and nvidia-cudnn-cu12/cu13 ship overlapping files.
      venvCollisionPaths = lib.optionals pkgs.stdenv.hostPlatform.isLinux [
        "lib/python3.13/site-packages/mlx*"
        "lib/python3.13/site-packages/nvidia*"
      ];

      exoVenv = (pythonSet.mkVirtualEnv "exo-env" workspace.deps.default).overrideAttrs {
        venvIgnoreCollisions = venvCollisionPaths;
      };
      exoVenv = pythonSet.mkVirtualEnv "exo-env" workspace.deps.default;

      # Virtual environment with dev dependencies for testing
      testVenv = (pythonSet.mkVirtualEnv "exo-test-env" (
      testVenv = pythonSet.mkVirtualEnv "exo-test-env" (
        workspace.deps.default // {
          exo = [ "dev" ]; # Include pytest, pytest-asyncio, pytest-env
        }
      )).overrideAttrs {
        venvIgnoreCollisions = venvCollisionPaths;
      };
      );

      mkPythonScript = name: path: pkgs.writeShellApplication {
        inherit name;

@@ -314,17 +314,7 @@ class DownloadCoordinator:
                        ),
                    )
                elif progress.status in ["in_progress", "not_started"]:
                    if (
                        progress.downloaded_bytes.in_bytes
                        >= progress.total_bytes.in_bytes
                        > 0
                    ):
                        status = DownloadCompleted(
                            node_id=self.node_id,
                            shard_metadata=progress.shard,
                            total_bytes=progress.total_bytes,
                        )
                    elif progress.downloaded_bytes_this_session.in_bytes == 0:
                    if progress.downloaded_bytes_this_session.in_bytes == 0:
                        status = DownloadPending(
                            node_id=self.node_id,
                            shard_metadata=progress.shard,

@@ -136,8 +136,6 @@ class Node:

    async def run(self):
        async with self._tg as tg:
            signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
            signal.signal(signal.SIGTERM, lambda _, __: self.shutdown())
            tg.start_soon(self.router.run)
            tg.start_soon(self.election.run)
            if self.download_coordinator:

@@ -149,6 +147,8 @@ class Node:
            if self.api:
                tg.start_soon(self.api.run)
            tg.start_soon(self._elect_loop)
            signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
            signal.signal(signal.SIGTERM, lambda _, __: self.shutdown())

    def shutdown(self):
        # if this is our second call to shutdown, just sys.exit

@@ -254,7 +254,7 @@ def main():
        target = min(max(soft, 65535), hard)
        resource.setrlimit(resource.RLIMIT_NOFILE, (target, hard))

    mp.set_start_method("spawn", force=True)
    mp.set_start_method("spawn")
    # TODO: Refactor the current verbosity system
    logger_setup(EXO_LOG, args.verbosity)
    logger.info("Starting EXO")

@@ -71,11 +71,8 @@ from exo.shared.types.api import (
|
||||
ChatCompletionResponse,
|
||||
CreateInstanceParams,
|
||||
CreateInstanceResponse,
|
||||
CreateMetaInstanceParams,
|
||||
CreateMetaInstanceResponse,
|
||||
DeleteDownloadResponse,
|
||||
DeleteInstanceResponse,
|
||||
DeleteMetaInstanceResponse,
|
||||
ErrorInfo,
|
||||
ErrorResponse,
|
||||
FinishReason,
|
||||
@@ -118,10 +115,8 @@ from exo.shared.types.claude_api import (
|
||||
from exo.shared.types.commands import (
|
||||
Command,
|
||||
CreateInstance,
|
||||
CreateMetaInstance,
|
||||
DeleteDownload,
|
||||
DeleteInstance,
|
||||
DeleteMetaInstance,
|
||||
DownloadCommand,
|
||||
ForwarderCommand,
|
||||
ForwarderDownloadCommand,
|
||||
@@ -134,7 +129,7 @@ from exo.shared.types.commands import (
|
||||
TaskFinished,
|
||||
TextGeneration,
|
||||
)
|
||||
from exo.shared.types.common import CommandId, Id, MetaInstanceId, NodeId, SessionId
|
||||
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
|
||||
from exo.shared.types.events import (
|
||||
ChunkGenerated,
|
||||
Event,
|
||||
@@ -143,7 +138,6 @@ from exo.shared.types.events import (
|
||||
TracesMerged,
|
||||
)
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.meta_instance import MetaInstance
|
||||
from exo.shared.types.openai_responses import (
|
||||
ResponsesRequest,
|
||||
ResponsesResponse,
|
||||
@@ -282,9 +276,6 @@ class API:
|
||||
self.app.get("/instance/previews")(self.get_placement_previews)
|
||||
self.app.get("/instance/{instance_id}")(self.get_instance)
|
||||
self.app.delete("/instance/{instance_id}")(self.delete_instance)
|
||||
self.app.get("/meta_instances")(self.list_meta_instances)
|
||||
self.app.post("/meta_instance")(self.create_meta_instance)
|
||||
self.app.delete("/meta_instance/{meta_instance_id}")(self.delete_meta_instance)
|
||||
self.app.get("/models")(self.get_models)
|
||||
self.app.get("/v1/models")(self.get_models)
|
||||
self.app.post("/models/add")(self.add_custom_model)
|
||||
@@ -314,27 +305,12 @@ class API:
|
||||
self.app.get("/v1/traces/{task_id}/raw")(self.get_trace_raw)
|
||||
|
||||
async def place_instance(self, payload: PlaceInstanceParams):
|
||||
model_card = await ModelCard.load(payload.model_id)
|
||||
command = PlaceInstance(
|
||||
model_card=model_card,
|
||||
model_card=await ModelCard.load(payload.model_id),
|
||||
sharding=payload.sharding,
|
||||
instance_meta=payload.instance_meta,
|
||||
min_nodes=payload.min_nodes,
|
||||
)
|
||||
|
||||
# Validate placement before sending — fail fast with a clear error
|
||||
# instead of silently dropping the command in the master.
|
||||
try:
|
||||
get_instance_placements(
|
||||
command,
|
||||
topology=self.state.topology,
|
||||
current_instances=self.state.instances,
|
||||
node_memory=self.state.node_memory,
|
||||
node_network=self.state.node_network,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
await self._send(command)
|
||||
|
||||
return CreateInstanceResponse(
|
||||
@@ -546,44 +522,6 @@ class API:
|
||||
instance_id=instance_id,
|
||||
)
|
||||
|
||||
def list_meta_instances(self) -> dict[MetaInstanceId, MetaInstance]:
|
||||
return dict(self.state.meta_instances)
|
||||
|
||||
async def create_meta_instance(
|
||||
self, payload: CreateMetaInstanceParams
|
||||
) -> CreateMetaInstanceResponse:
|
||||
meta_instance = MetaInstance(
|
||||
model_id=payload.model_id,
|
||||
sharding=payload.sharding,
|
||||
instance_meta=payload.instance_meta,
|
||||
min_nodes=payload.min_nodes,
|
||||
node_ids=payload.node_ids,
|
||||
)
|
||||
command = CreateMetaInstance(meta_instance=meta_instance)
|
||||
await self._send(command)
|
||||
return CreateMetaInstanceResponse(
|
||||
message="Command received.",
|
||||
command_id=command.command_id,
|
||||
meta_instance_id=meta_instance.meta_instance_id,
|
||||
)
|
||||
|
||||
async def delete_meta_instance(
|
||||
self, meta_instance_id: MetaInstanceId
|
||||
) -> DeleteMetaInstanceResponse:
|
||||
meta = self.state.meta_instances.get(meta_instance_id)
|
||||
if not meta:
|
||||
raise HTTPException(status_code=404, detail="MetaInstance not found")
|
||||
|
||||
# Command processor handles cascade-deleting backing instances
|
||||
command = DeleteMetaInstance(meta_instance_id=meta_instance_id)
|
||||
await self._send(command)
|
||||
|
||||
return DeleteMetaInstanceResponse(
|
||||
message="Command received.",
|
||||
command_id=command.command_id,
|
||||
meta_instance_id=meta_instance_id,
|
||||
)
|
||||
|
||||

    async def _token_chunk_stream(
        self, command_id: CommandId
    ) -> AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None]:
@@ -1429,6 +1367,7 @@ class API:
    async def run(self):
        shutdown_ev = anyio.Event()

        bonjour_cleanup = self._register_bonjour_service()
        try:
            async with create_task_group() as tg:
                self._tg = tg
@@ -1444,10 +1383,48 @@ class API:
                    with anyio.CancelScope(shield=True):
                        shutdown_ev.set()
        finally:
            bonjour_cleanup()
            self._event_log.close()
            self.command_sender.close()
            self.global_event_receiver.close()

    def _register_bonjour_service(self) -> Callable[[], None]:
        """Register a Bonjour service via the system mDNSResponder. Returns a cleanup function."""
        import subprocess
        import sys

        if sys.platform != "darwin":
            logger.info("Bonjour service registration is only supported on macOS")
            return lambda: None

        service_name = f"EXO Cluster ({self.node_id[:8]})"
        try:
            proc = subprocess.Popen(
                [
                    "dns-sd",
                    "-R",
                    service_name,
                    "_exo._tcp",
                    "local",
                    str(self.port),
                    f"node_id={self.node_id}",
                ],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            )
            logger.info(
                f"Registered Bonjour service _exo._tcp on port {self.port} (pid {proc.pid})"
            )

            def cleanup() -> None:
                proc.terminate()
                proc.wait()

            return cleanup
        except Exception as e:
            logger.warning(f"Failed to register Bonjour service: {e}")
            return lambda: None
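
Reviewer note: to verify registration from another machine on the LAN, the same macOS `dns-sd` tool can browse for the service. A small sketch using standard `dns-sd -B` usage (nothing here is added by this diff):

```python
# Sketch: browse for the `_exo._tcp` service registered above. `dns-sd -B`
# streams results until interrupted, so read for a few seconds, then stop.
# macOS-only, like the registration code above.
import subprocess


def browse_exo_services(seconds: float = 3.0) -> str:
    proc = subprocess.Popen(
        ["dns-sd", "-B", "_exo._tcp", "local"],
        stdout=subprocess.PIPE,
        stderr=subprocess.DEVNULL,
        text=True,
    )
    try:
        out, _ = proc.communicate(timeout=seconds)
    except subprocess.TimeoutExpired:
        proc.terminate()
        out, _ = proc.communicate()
    return out  # one line per discovered "EXO Cluster (...)" instance


if __name__ == "__main__":
    print(browse_exo_services())
```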

    async def run_api(self, ev: anyio.Event):
        cfg = Config()
        cfg.bind = [f"0.0.0.0:{self.port}"]

@@ -1,5 +1,4 @@
from collections.abc import Sequence
from datetime import datetime, timezone
from datetime import datetime, timedelta, timezone

import anyio
from anyio.abc import TaskGroup
@@ -13,22 +12,11 @@ from exo.master.placement import (
    get_transition_events,
    place_instance,
)
from exo.master.process_managers import ProcessManager
from exo.master.process_managers.instance_health import InstanceHealthReconciler
from exo.master.process_managers.meta_instance import MetaInstanceReconciler
from exo.master.process_managers.node_timeout import NodeTimeoutReconciler
from exo.master.reconcile import (
    find_unsatisfied_meta_instances,
    try_place_for_meta_instance,
)
from exo.shared.apply import apply
from exo.shared.constants import EXO_EVENT_LOG_DIR, EXO_TRACING_ENABLED
from exo.shared.models.model_cards import ModelCard
from exo.shared.types.commands import (
    CreateInstance,
    CreateMetaInstance,
    DeleteInstance,
    DeleteMetaInstance,
    ForwarderCommand,
    ForwarderDownloadCommand,
    ImageEdits,
@@ -48,12 +36,8 @@ from exo.shared.types.events import (
    IndexedEvent,
    InputChunkReceived,
    InstanceDeleted,
    JacclSideChannelData,
    JacclSideChannelGathered,
    MetaInstanceCreated,
    MetaInstanceDeleted,
    MetaInstancePlacementFailed,
    NodeGatheredInfo,
    NodeTimedOut,
    TaskCreated,
    TaskDeleted,
    TaskStatusUpdated,
@@ -76,8 +60,7 @@ from exo.shared.types.tasks import (
    TextGeneration as TextGenerationTask,
)
from exo.shared.types.worker.instances import InstanceId
from exo.shared.types.worker.runners import RunnerId
from exo.utils.channels import Receiver, Sender
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.event_buffer import MultiSourceBuffer


@@ -101,16 +84,16 @@ class Master:
        self.local_event_receiver = local_event_receiver
        self.global_event_sender = global_event_sender
        self.download_command_sender = download_command_sender
        send, recv = channel[Event]()
        self.event_sender: Sender[Event] = send
        self._loopback_event_receiver: Receiver[Event] = recv
        self._loopback_event_sender: Sender[ForwarderEvent] = (
            local_event_receiver.clone_sender()
        )
        self._multi_buffer = MultiSourceBuffer[NodeId, Event]()
        self._event_log = DiskEventLog(EXO_EVENT_LOG_DIR / "master")
        self._pending_traces: dict[TaskId, dict[int, list[TraceEventData]]] = {}
        self._expected_ranks: dict[TaskId, set[int]] = {}
        self._jaccl_pending: dict[InstanceId, dict[int, dict[RunnerId, bytes]]] = {}
        self._process_managers: Sequence[ProcessManager] = [
            InstanceHealthReconciler(),
            NodeTimeoutReconciler(),
            MetaInstanceReconciler(),
        ]

    async def run(self):
        logger.info("Starting Master")
@@ -119,12 +102,15 @@ class Master:
            async with self._tg as tg:
                tg.start_soon(self._event_processor)
                tg.start_soon(self._command_processor)
                tg.start_soon(self._reconcile)
                tg.start_soon(self._loopback_processor)
                tg.start_soon(self._plan)
        finally:
            self._event_log.close()
            self.global_event_sender.close()
            self.local_event_receiver.close()
            self.command_receiver.close()
            self._loopback_event_sender.close()
            self._loopback_event_receiver.close()

    async def shutdown(self):
        logger.info("Stopping Master")
@@ -306,86 +292,6 @@ class Master:
                        )
                    )
                    generated_events.extend(transition_events)
                case CreateMetaInstance():
                    logger.info(
                        f"Creating MetaInstance for {command.meta_instance.model_id}"
                        f" (min_nodes={command.meta_instance.min_nodes},"
                        f" sharding={command.meta_instance.sharding})"
                    )
                    # Apply immediately so self.state is fresh across
                    # the await below and the reconciler won't race.
                    await self._apply_and_broadcast(
                        MetaInstanceCreated(meta_instance=command.meta_instance)
                    )
                    # Immediate placement attempt for responsiveness
                    model_card = await ModelCard.load(
                        command.meta_instance.model_id
                    )
                    # Re-check: reconciler may have satisfied it during the await
                    meta_id = command.meta_instance.meta_instance_id
                    still_unsatisfied = any(
                        m.meta_instance_id == meta_id
                        for m in find_unsatisfied_meta_instances(
                            self.state.meta_instances,
                            self.state.instances,
                            self.state.topology,
                        )
                    )
                    if still_unsatisfied:
                        result = try_place_for_meta_instance(
                            command.meta_instance,
                            model_card,
                            self.state.topology,
                            self.state.instances,
                            self.state.node_memory,
                            self.state.node_network,
                            self.state.tasks,
                        )
                        generated_events.extend(result.events)
                        if result.error is not None:
                            generated_events.append(
                                MetaInstancePlacementFailed(
                                    meta_instance_id=meta_id,
                                    reason=result.error,
                                )
                            )
                case DeleteMetaInstance():
                    backing_count = sum(
                        1
                        for inst in self.state.instances.values()
                        if inst.meta_instance_id == command.meta_instance_id
                    )
                    logger.info(
                        f"Deleting MetaInstance {command.meta_instance_id}"
                        f" (cascade-deleting {backing_count} backing instance(s))"
                    )
                    generated_events.append(
                        MetaInstanceDeleted(
                            meta_instance_id=command.meta_instance_id
                        )
                    )
                    # Cascade-delete backing instances atomically,
                    # cancelling any active tasks first.
                    for iid, inst in self.state.instances.items():
                        if inst.meta_instance_id == command.meta_instance_id:
                            for task in self.state.tasks.values():
                                if (
                                    task.instance_id == iid
                                    and task.task_status
                                    in (
                                        TaskStatus.Pending,
                                        TaskStatus.Running,
                                    )
                                ):
                                    generated_events.append(
                                        TaskStatusUpdated(
                                            task_status=TaskStatus.Cancelled,
                                            task_id=task.task_id,
                                        )
                                    )
                            generated_events.append(
                                InstanceDeleted(instance_id=iid)
                            )
                case PlaceInstance():
                    placement = place_instance(
                        command,
@@ -448,32 +354,31 @@ class Master:
        ):
            await self._send_event(IndexedEvent(idx=i, event=event))
        for event in generated_events:
            await self._apply_and_broadcast(event)
            await self.event_sender.send(event)
    except ValueError as e:
        logger.opt(exception=e).warning("Error in command processor")

    async def _apply_and_broadcast(self, event: Event) -> None:
        """Apply event to state, persist to disk, and broadcast to workers.

        State is updated synchronously (before any await), so callers can
        rely on ``self.state`` reflecting this event immediately after the
        call. Python's cooperative scheduling guarantees no interleaving
        between the state read and write.
        """
        logger.debug(f"Master indexing event: {str(event)[:100]}")
        indexed = IndexedEvent(event=event, idx=len(self._event_log))
        self.state = apply(self.state, indexed)
        event._master_time_stamp = datetime.now(tz=timezone.utc)  # pyright: ignore[reportPrivateUsage]
        self._event_log.append(event)
        await self._send_event(indexed)
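
Reviewer note: the synchronous-apply invariant documented in `_apply_and_broadcast` is the heart of the event-sourcing design here: the pure reducer runs before the first await, so no other coroutine can observe a half-applied event. A stripped-down illustration with toy types (not exo's real `State`/`apply`):

```python
# Toy illustration of the apply-before-await invariant; these types stand in
# for exo's State / Event / apply and are not the real implementations.
import asyncio
from dataclasses import dataclass, field


@dataclass(frozen=True)
class ToyState:
    counters: dict[str, int] = field(default_factory=dict)


@dataclass(frozen=True)
class Increment:
    key: str


def toy_apply(state: ToyState, event: Increment) -> ToyState:
    # Pure reducer: returns a new state, never mutates in place.
    counters = dict(state.counters)
    counters[event.key] = counters.get(event.key, 0) + 1
    return ToyState(counters=counters)


class ToyMaster:
    def __init__(self) -> None:
        self.state = ToyState()
        self.log: list[Increment] = []

    async def apply_and_broadcast(self, event: Increment) -> None:
        # State is updated *before* the first await, so under cooperative
        # scheduling no other coroutine can observe a half-applied event.
        self.state = toy_apply(self.state, event)
        self.log.append(event)
        await asyncio.sleep(0)  # stands in for broadcasting to workers


asyncio.run(ToyMaster().apply_and_broadcast(Increment(key="a")))
```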

    async def _reconcile(self) -> None:
    # These plan loops are the cracks showing in our event sourcing architecture - more things could be commands
    async def _plan(self) -> None:
        while True:
            for pm in self._process_managers:
                events = await pm.reconcile(self.state)
                for event in events:
                    await self._apply_and_broadcast(event)
            await anyio.sleep(1)
            # kill broken instances
            connected_node_ids = set(self.state.topology.list_nodes())
            for instance_id, instance in self.state.instances.items():
                for node_id in instance.shard_assignments.node_to_runner:
                    if node_id not in connected_node_ids:
                        await self.event_sender.send(
                            InstanceDeleted(instance_id=instance_id)
                        )
                        break

            # time out dead nodes
            for node_id, time in self.state.last_seen.items():
                now = datetime.now(tz=timezone.utc)
                if now - time > timedelta(seconds=30):
                    logger.info(f"Manually removing node {node_id} due to inactivity")
                    await self.event_sender.send(NodeTimedOut(node_id=node_id))

            await anyio.sleep(10)
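
Reviewer note: both sides of this hunk are the same control loop in two shapes: periodically derive corrective events from observed state and feed them back through the normal event path. The generic shape, with every name a stand-in:

```python
# Generic reconcile-loop sketch; get_state / reconcilers / emit are stand-ins
# for self.state, self._process_managers, and _apply_and_broadcast above.
import anyio


async def reconcile_loop(get_state, reconcilers, emit, interval: float = 1.0):
    while True:
        state = get_state()  # re-read every pass; state changes between sleeps
        for reconcile in reconcilers:
            for event in await reconcile(state):
                await emit(event)
        await anyio.sleep(interval)
```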

    async def _event_processor(self) -> None:
        with self.local_event_receiver as local_events:
@@ -491,15 +396,32 @@ class Master:
                    await self._handle_traces_collected(event)
                    continue

                if isinstance(event, JacclSideChannelData):
                    await self._apply_and_broadcast(event)
                    await self._handle_jaccl_side_channel(event)
                    continue
                logger.debug(f"Master indexing event: {str(event)[:100]}")
                indexed = IndexedEvent(event=event, idx=len(self._event_log))
                self.state = apply(self.state, indexed)

                event._master_time_stamp = datetime.now(tz=timezone.utc)  # pyright: ignore[reportPrivateUsage]
                if isinstance(event, NodeGatheredInfo):
                    event.when = str(datetime.now(tz=timezone.utc))

                await self._apply_and_broadcast(event)
                self._event_log.append(event)
                await self._send_event(indexed)

    async def _loopback_processor(self) -> None:
        # this would ideally not be necessary.
        # this is WAY less hacky than how I was working around this before
        local_index = 0
        with self._loopback_event_receiver as events:
            async for event in events:
                await self._loopback_event_sender.send(
                    ForwarderEvent(
                        origin=NodeId(f"master_{self.node_id}"),
                        origin_idx=local_index,
                        session=self.session_id,
                        event=event,
                    )
                )
                local_index += 1
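
Reviewer note: the loopback path stamps each locally generated event with a synthetic origin and a monotonically increasing `origin_idx`, so the master's own events ride the same per-origin ordering machinery as remote ones. The envelope pattern in isolation (illustrative types, not exo's `ForwarderEvent`):

```python
# Toy envelope pattern: per-origin monotonic sequence numbers let a receiver
# order and de-duplicate events from many sources.
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class Envelope:
    origin: str
    origin_idx: int
    event: Any


def make_stamper(origin: str):
    idx = 0

    def stamp(event: Any) -> Envelope:
        nonlocal idx
        env = Envelope(origin=origin, origin_idx=idx, event=event)
        idx += 1
        return env

    return stamp


stamp = make_stamper("master_node-123")
assert [stamp("a").origin_idx, stamp("b").origin_idx] == [0, 1]
```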

    # This function is re-entrant, take care!
    async def _send_event(self, event: IndexedEvent):
@@ -531,49 +453,10 @@ class Master:
            for trace_data in self._pending_traces[task_id].values():
                all_trace_data.extend(trace_data)

            await self._apply_and_broadcast(
            await self.event_sender.send(
                TracesMerged(task_id=task_id, traces=all_trace_data)
            )

            del self._pending_traces[task_id]
            if task_id in self._expected_ranks:
                del self._expected_ranks[task_id]

    async def _handle_jaccl_side_channel(self, event: JacclSideChannelData) -> None:
        """Accumulate SideChannel contributions; when all runners for an instance
        have submitted for the same sequence, emit JacclSideChannelGathered."""
        iid = event.instance_id
        seq = event.sequence

        if iid not in self._jaccl_pending:
            self._jaccl_pending[iid] = {}
        if seq not in self._jaccl_pending[iid]:
            self._jaccl_pending[iid][seq] = {}
        self._jaccl_pending[iid][seq][event.runner_id] = event.data

        instance = self.state.instances.get(iid)
        if instance is None:
            logger.warning(f"JacclSideChannelData for unknown instance {iid}")
            return

        expected_runners = set(instance.shard_assignments.runner_to_shard.keys())
        submitted = set(self._jaccl_pending[iid][seq].keys())

        logger.info(
            f"JACCL side channel: instance={iid} seq={seq} "
            f"submitted={len(submitted)}/{len(expected_runners)}"
        )

        if submitted >= expected_runners:
            gathered = dict(self._jaccl_pending[iid][seq])
            del self._jaccl_pending[iid][seq]
            if not self._jaccl_pending[iid]:
                del self._jaccl_pending[iid]

            await self._apply_and_broadcast(
                JacclSideChannelGathered(
                    instance_id=iid,
                    sequence=seq,
                    gathered_data=gathered,
                )
            )
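
Reviewer note: this handler is a small gather barrier: contributions are keyed by (instance, sequence, runner) and released only once every expected runner has submitted for that sequence. The same bookkeeping in isolation, with plain strings in place of exo's ids:

```python
# Toy gather barrier mirroring _handle_jaccl_side_channel's bookkeeping.
# Keys are plain strings/ints here; exo uses InstanceId/RunnerId/bytes.

pending: dict[str, dict[int, dict[str, bytes]]] = {}


def submit(iid: str, seq: int, runner: str, data: bytes,
           expected: set[str]) -> dict[str, bytes] | None:
    """Record one contribution; return the gathered dict once complete."""
    pending.setdefault(iid, {}).setdefault(seq, {})[runner] = data
    if set(pending[iid][seq]) >= expected:
        gathered = pending[iid].pop(seq)
        if not pending[iid]:
            del pending[iid]
        return gathered
    return None


assert submit("inst", 0, "r1", b"x", {"r1", "r2"}) is None
assert submit("inst", 0, "r2", b"y", {"r1", "r2"}) == {"r1": b"x", "r2": b"y"}
assert pending == {}  # fully drained once the sequence completes
```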

@@ -6,11 +6,11 @@ from typing import Sequence
from exo.master.placement_utils import (
    Cycle,
    filter_cycles_by_memory,
    get_largest_cycles,
    get_mlx_jaccl_coordinators,
    get_mlx_jaccl_devices_matrix,
    get_mlx_ring_hosts_by_node,
    get_shard_assignments,
    get_smallest_cycles,
)
from exo.shared.models.model_cards import ModelId
from exo.shared.topology import Topology
@@ -106,27 +106,23 @@ def place_instance(
            "Pipeline parallelism is not supported for DeepSeek V3.1 (8-bit)"
        )

    largest_cycles = get_largest_cycles(cycles_with_sufficient_memory)
    smallest_cycles = get_smallest_cycles(cycles_with_sufficient_memory)

    largest_rdma_cycles = [
        cycle for cycle in largest_cycles if topology.is_rdma_cycle(cycle)
    smallest_rdma_cycles = [
        cycle for cycle in smallest_cycles if topology.is_rdma_cycle(cycle)
    ]

    if command.instance_meta == InstanceMeta.MlxJaccl:
        if not largest_rdma_cycles:
            raise ValueError(
                "Requested RDMA (MlxJaccl) but no RDMA-connected cycles available"
            )
        largest_cycles = largest_rdma_cycles
    if command.instance_meta == InstanceMeta.MlxJaccl and smallest_rdma_cycles != []:
        smallest_cycles = smallest_rdma_cycles

    cycles_with_leaf_nodes: list[Cycle] = [
        cycle
        for cycle in largest_cycles
        for cycle in smallest_cycles
        if any(topology.node_is_leaf(node_id) for node_id in cycle)
    ]

    selected_cycle = max(
        cycles_with_leaf_nodes if cycles_with_leaf_nodes != [] else largest_cycles,
        cycles_with_leaf_nodes if cycles_with_leaf_nodes != [] else smallest_cycles,
        key=lambda cycle: sum(
            (node_memory[node_id].ram_available for node_id in cycle),
            start=Memory(),

@@ -37,11 +37,11 @@ def filter_cycles_by_memory(
    return filtered_cycles


def get_largest_cycles(
def get_smallest_cycles(
    cycles: list[Cycle],
) -> list[Cycle]:
    max_nodes = max(len(cycle) for cycle in cycles)
    return [cycle for cycle in cycles if len(cycle) == max_nodes]
    min_nodes = min(len(cycle) for cycle in cycles)
    return [cycle for cycle in cycles if len(cycle) == min_nodes]


def allocate_layers_proportionally(
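
Reviewer note: the net effect of these two hunks is a policy flip: placement now prefers the fewest nodes that fit the model, then cycles containing a leaf node, then the most available RAM. A toy walk-through of that selection order (plain tuples and ints instead of `Cycle` and `Memory`):

```python
# Toy version of the new selection policy: min-size cycles first, then prefer
# cycles containing a leaf node, then break ties by total available memory.
cycles = [("a",), ("b",), ("a", "b")]  # candidate node cycles
ram = {"a": 32, "b": 64}               # available GiB per node (illustrative)
leaves = {"b"}                         # nodes that are topology leaves

min_len = min(len(c) for c in cycles)
smallest = [c for c in cycles if len(c) == min_len]          # [("a",), ("b",)]
with_leaf = [c for c in smallest if any(n in leaves for n in c)]
candidates = with_leaf if with_leaf else smallest
selected = max(candidates, key=lambda c: sum(ram[n] for n in c))
assert selected == ("b",)  # smallest leaf-containing cycle with most free RAM
```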

@@ -1,12 +0,0 @@
from collections.abc import Sequence
from typing import Protocol, runtime_checkable

from exo.shared.types.events import Event
from exo.shared.types.state import State


@runtime_checkable
class ProcessManager(Protocol):
    """A reconciliation step that examines state and returns corrective events."""

    async def reconcile(self, state: State) -> Sequence[Event]: ...
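
Reviewer note: anything matching this structural protocol can sit in the master's `_process_managers` list. A minimal conforming implementation (the class itself is illustrative; the real reconcilers follow below):

```python
# Minimal ProcessManager: structurally conforms to the protocol above.
from collections.abc import Sequence

from exo.master.process_managers import ProcessManager  # import path as used above
from exo.shared.types.events import Event
from exo.shared.types.state import State


class NoopReconciler:
    """Examines state, proposes no corrective events."""

    async def reconcile(self, state: State) -> Sequence[Event]:
        return []


# @runtime_checkable makes the structural check available at runtime:
assert isinstance(NoopReconciler(), ProcessManager)
```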

@@ -1,62 +0,0 @@
from collections.abc import Sequence
from typing import final

from loguru import logger

from exo.master.reconcile import instance_connections_healthy, instance_runners_failed
from exo.shared.types.events import Event, InstanceDeleted, InstanceRetrying
from exo.shared.types.state import State

MAX_INSTANCE_RETRIES = 3


@final
class InstanceHealthReconciler:
    """Delete instances whose network connections are broken or whose runners have all failed."""

    async def reconcile(self, state: State) -> Sequence[Event]:
        events: list[Event] = []
        for instance_id, instance in state.instances.items():
            if not instance_connections_healthy(instance, state.topology):
                events.append(
                    InstanceDeleted(
                        instance_id=instance_id,
                        failure_error="Network connection lost",
                    )
                )
                continue

            is_failed, error_message = instance_runners_failed(
                instance, state.runners, state.node_identities
            )
            if is_failed:
                # Retry within the same instance if backed by a MetaInstance
                mid = instance.meta_instance_id
                mi = state.meta_instances.get(mid) if mid else None
                if mid and mi and mi.consecutive_failures < MAX_INSTANCE_RETRIES:
                    logger.info(
                        f"Instance {instance_id} failed (attempt"
                        f" {mi.consecutive_failures + 1}/{MAX_INSTANCE_RETRIES}),"
                        f" retrying: {error_message}"
                    )
                    events.append(
                        InstanceRetrying(
                            instance_id=instance_id,
                            meta_instance_id=mid,
                            failure_error=error_message or "Runner failed",
                        )
                    )
                else:
                    if mid and mi:
                        logger.warning(
                            f"Instance {instance_id} exceeded retry limit"
                            f" ({MAX_INSTANCE_RETRIES}), deleting:"
                            f" {error_message}"
                        )
                    events.append(
                        InstanceDeleted(
                            instance_id=instance_id,
                            failure_error=error_message,
                        )
                    )
        return events

@@ -1,92 +0,0 @@
from collections.abc import Sequence
from typing import final

import anyio
from loguru import logger

from exo.master.reconcile import (
    find_unsatisfied_meta_instances,
    try_place_for_meta_instance,
)
from exo.shared.models.model_cards import ModelCard
from exo.shared.types.events import Event, InstanceCreated, MetaInstancePlacementFailed
from exo.shared.types.state import State
from exo.shared.types.worker.instances import Instance, InstanceId

MODEL_CARD_LOAD_TIMEOUT_SECONDS = 10


@final
class MetaInstanceReconciler:
    """Place instances for unsatisfied MetaInstances."""

    async def reconcile(self, state: State) -> Sequence[Event]:
        all_events: list[Event] = []
        # Local copy for intermediate tracking — so placement of B
        # sees A's instance and doesn't double-place on same resources.
        current_instances: dict[InstanceId, Instance] = dict(state.instances)

        unsatisfied = find_unsatisfied_meta_instances(
            state.meta_instances,
            current_instances,
            state.topology,
        )
        for meta_instance in unsatisfied:
            try:
                with anyio.fail_after(MODEL_CARD_LOAD_TIMEOUT_SECONDS):
                    model_card = await ModelCard.load(meta_instance.model_id)
            except TimeoutError:
                logger.warning(
                    f"ModelCard.load timed out for {meta_instance.model_id}, skipping this cycle"
                )
                continue
            except Exception as exc:
                logger.warning(
                    f"ModelCard.load failed for {meta_instance.model_id}: {exc}"
                )
                error = f"Failed to load model card: {exc}"
                if meta_instance.placement_error != error:
                    all_events.append(
                        MetaInstancePlacementFailed(
                            meta_instance_id=meta_instance.meta_instance_id,
                            reason=error,
                        )
                    )
                continue

            result = try_place_for_meta_instance(
                meta_instance,
                model_card,
                state.topology,
                current_instances,
                state.node_memory,
                state.node_network,
                state.tasks,
            )
            # Update local instance map so next placement sees this one
            for event in result.events:
                if isinstance(event, InstanceCreated):
                    logger.info(
                        f"MetaInstance reconciler placed instance"
                        f" {event.instance.instance_id} for"
                        f" {meta_instance.model_id}"
                    )
                    current_instances[event.instance.instance_id] = event.instance
            all_events.extend(result.events)

            # Emit placement failure if error differs from what's already in state
            if (
                result.error is not None
                and meta_instance.placement_error != result.error
            ):
                logger.warning(
                    f"MetaInstance placement failed for"
                    f" {meta_instance.model_id}: {result.error}"
                )
                all_events.append(
                    MetaInstancePlacementFailed(
                        meta_instance_id=meta_instance.meta_instance_id,
                        reason=result.error,
                    )
                )
        return all_events

@@ -1,27 +0,0 @@
from collections.abc import Sequence
from datetime import datetime, timedelta, timezone
from typing import final

from loguru import logger

from exo.shared.types.events import Event, NodeTimedOut
from exo.shared.types.state import State

_DEFAULT_TIMEOUT = timedelta(seconds=30)


@final
class NodeTimeoutReconciler:
    """Time out nodes that haven't been seen recently."""

    def __init__(self, timeout: timedelta = _DEFAULT_TIMEOUT) -> None:
        self.timeout = timeout

    async def reconcile(self, state: State) -> Sequence[Event]:
        now = datetime.now(tz=timezone.utc)
        events: list[Event] = []
        for node_id, last_seen in state.last_seen.items():
            if now - last_seen > self.timeout:
                logger.info(f"Removing node {node_id} due to inactivity")
                events.append(NodeTimedOut(node_id=node_id))
        return events

@@ -1,244 +0,0 @@
from collections.abc import Mapping, Sequence
from typing import NamedTuple

from loguru import logger

from exo.master.placement import get_transition_events, place_instance
from exo.shared.models.model_cards import ModelCard
from exo.shared.topology import Topology
from exo.shared.types.commands import PlaceInstance
from exo.shared.types.common import MetaInstanceId, NodeId
from exo.shared.types.events import Event
from exo.shared.types.meta_instance import MetaInstance
from exo.shared.types.profiling import MemoryUsage, NodeIdentity, NodeNetworkInfo
from exo.shared.types.tasks import Task, TaskId
from exo.shared.types.topology import RDMAConnection, SocketConnection
from exo.shared.types.worker.instances import (
    BaseInstance,
    Instance,
    InstanceId,
    MlxJacclInstance,
    MlxRingInstance,
)
from exo.shared.types.worker.runners import (
    RunnerFailed,
    RunnerId,
    RunnerShutdown,
    RunnerStatus,
)


class PlacementResult(NamedTuple):
    """Result of a placement attempt: events to apply and optional error reason."""

    events: Sequence[Event]
    error: str | None


def _get_ring_order(instance: BaseInstance) -> list[NodeId]:
    """Reconstruct ring order from shard device_rank."""
    node_ranks: list[tuple[NodeId, int]] = []
    for node_id, runner_id in instance.shard_assignments.node_to_runner.items():
        shard = instance.shard_assignments.runner_to_shard[runner_id]
        node_ranks.append((node_id, shard.device_rank))
    node_ranks.sort(key=lambda x: x[1])
    return [node_id for node_id, _ in node_ranks]


def _ring_connections_healthy(instance: MlxRingInstance, topology: Topology) -> bool:
    """Check that the specific IPs used by a ring instance still exist in the topology."""
    ring = _get_ring_order(instance)
    n = len(ring)
    for node in ring:
        hosts = instance.hosts_by_node[node]
        for idx in range(n):
            host = hosts[idx]
            if host.ip in ("0.0.0.0", "198.51.100.1"):
                continue  # self or placeholder
            # Real connection: node → ring[idx]. Check specific IP.
            connections = topology.get_all_connections_between(node, ring[idx])
            if not any(
                isinstance(c, SocketConnection)
                and c.sink_multiaddr.ip_address == host.ip
                for c in connections
            ):
                return False
    return True


def _jaccl_connections_healthy(instance: MlxJacclInstance, topology: Topology) -> bool:
    """Check that the specific RDMA interfaces used by a JACCL instance still exist."""
    ring = _get_ring_order(instance)
    n = len(ring)
    for i in range(n):
        for j in range(n):
            iface = instance.jaccl_devices[i][j]
            if iface is None:
                continue
            connections = topology.get_all_connections_between(ring[i], ring[j])
            if not any(
                isinstance(c, RDMAConnection) and c.source_rdma_iface == iface
                for c in connections
            ):
                return False
    return True


def instance_connections_healthy(instance: Instance, topology: Topology) -> bool:
    """Check that an instance's nodes and specific connections are still in the topology."""
    instance_nodes = set(instance.shard_assignments.node_to_runner.keys())
    if not all(topology.contains_node(n) for n in instance_nodes):
        return False
    if len(instance_nodes) <= 1:
        return True
    match instance:
        case MlxRingInstance():
            return _ring_connections_healthy(instance, topology)
        case MlxJacclInstance():
            return _jaccl_connections_healthy(instance, topology)


def instance_runners_failed(
    instance: Instance,
    runners: Mapping[RunnerId, RunnerStatus],
    node_identities: Mapping[NodeId, NodeIdentity],
) -> tuple[bool, str | None]:
    """Check if an instance's runners have all reached terminal failure states.

    Returns ``(True, error_message)`` when ALL runners are terminal
    (``RunnerFailed`` or ``RunnerShutdown``) and at least one is ``RunnerFailed``.

    Returns ``(False, None)`` when runners are still active, haven't reported
    yet, or all gracefully shut down (no ``RunnerFailed``).
    """
    instance_runner_ids = set(instance.shard_assignments.node_to_runner.values())

    if not instance_runner_ids:
        return False, None

    # Build reverse mapping: runner_id -> node_id
    runner_to_node: dict[RunnerId, NodeId] = {
        runner_id: node_id
        for node_id, runner_id in instance.shard_assignments.node_to_runner.items()
    }

    has_any_failed = False
    error_messages: list[str] = []

    for runner_id in instance_runner_ids:
        status = runners.get(runner_id)
        if status is None:
            # Runner hasn't reported yet — instance is still starting
            return False, None
        if isinstance(status, RunnerFailed):
            has_any_failed = True
            if status.error_message:
                node_id = runner_to_node.get(runner_id)
                name = (
                    node_identities[node_id].friendly_name
                    if node_id and node_id in node_identities
                    else node_id or "unknown"
                )
                error_messages.append(f"{name}: {status.error_message}")
        elif isinstance(status, RunnerShutdown):
            pass  # Terminal but not a failure indicator on its own
        else:
            # Runner is still active (connecting, loading, running, etc.)
            return False, None

    if has_any_failed:
        return True, "; ".join(error_messages) if error_messages else "Runner failed"

    # All runners are Shutdown but none Failed — graceful shutdown, not a failure
    return False, None


def instance_satisfies_meta_instance(
    meta_instance: MetaInstance,
    instance: Instance,
) -> bool:
    """Check if a single instance satisfies a meta-instance's constraints.

    This is a pure constraint check (model, min_nodes, node_ids).
    Use ``instance_connections_healthy`` separately for topology health.
    """
    if instance.shard_assignments.model_id != meta_instance.model_id:
        return False

    instance_nodes = set(instance.shard_assignments.node_to_runner.keys())

    if len(instance_nodes) < meta_instance.min_nodes:
        return False

    return meta_instance.node_ids is None or set(meta_instance.node_ids).issubset(
        instance_nodes
    )


def find_unsatisfied_meta_instances(
    meta_instances: Mapping[MetaInstanceId, MetaInstance],
    instances: Mapping[InstanceId, Instance],
    topology: Topology,
) -> Sequence[MetaInstance]:
    """Return meta-instances that have no healthy backing instance."""
    unsatisfied: list[MetaInstance] = []
    for meta_id, meta_instance in meta_instances.items():
        has_healthy_backing = any(
            instance.meta_instance_id == meta_id
            and instance_connections_healthy(instance, topology)
            for instance in instances.values()
        )
        if not has_healthy_backing:
            unsatisfied.append(meta_instance)
    return unsatisfied


def try_place_for_meta_instance(
    meta_instance: MetaInstance,
    model_card: ModelCard,
    topology: Topology,
    current_instances: Mapping[InstanceId, Instance],
    node_memory: Mapping[NodeId, MemoryUsage],
    node_network: Mapping[NodeId, NodeNetworkInfo],
    tasks: Mapping[TaskId, Task] | None = None,
) -> PlacementResult:
    """Try to place an instance satisfying the meta-instance constraints.

    Returns a :class:`PlacementResult` with events on success, or an error
    reason on failure.
    """
    command = PlaceInstance(
        model_card=model_card,
        sharding=meta_instance.sharding,
        instance_meta=meta_instance.instance_meta,
        min_nodes=meta_instance.min_nodes,
    )
    try:
        target_instances = place_instance(
            command,
            topology,
            current_instances,
            node_memory,
            node_network,
            required_nodes=(
                set(meta_instance.node_ids) if meta_instance.node_ids else None
            ),
        )
        # Tag the new instance with meta_instance_id
        new_instance_ids = set(target_instances.keys()) - set(current_instances.keys())
        if new_instance_ids:
            new_id = next(iter(new_instance_ids))
            target_instances[new_id] = target_instances[new_id].model_copy(
                update={"meta_instance_id": meta_instance.meta_instance_id}
            )
        return PlacementResult(
            events=list(
                get_transition_events(current_instances, target_instances, tasks or {})
            ),
            error=None,
        )
    except ValueError as e:
        logger.debug(
            f"MetaInstance placement not possible for {meta_instance.model_id}: {e}"
        )
        return PlacementResult(events=[], error=str(e))

@@ -1,778 +0,0 @@
"""Edge-case and regression tests for MetaInstance lifecycle, concurrent operations, and error handling."""

import pytest

from exo.master.process_managers.instance_health import (
    MAX_INSTANCE_RETRIES,
    InstanceHealthReconciler,
)
from exo.master.process_managers.meta_instance import MetaInstanceReconciler
from exo.master.reconcile import (
    find_unsatisfied_meta_instances,
    instance_connections_healthy,
    instance_runners_failed,
    instance_satisfies_meta_instance,
)
from exo.shared.apply import apply
from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
from exo.shared.topology import Topology
from exo.shared.types.common import Host, MetaInstanceId, NodeId
from exo.shared.types.events import (
    IndexedEvent,
    InstanceCreated,
    InstanceDeleted,
    InstanceRetrying,
    MetaInstanceCreated,
    MetaInstanceDeleted,
    MetaInstancePlacementFailed,
    TaskStatusUpdated,
)
from exo.shared.types.memory import Memory
from exo.shared.types.meta_instance import MetaInstance
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import NodeIdentity
from exo.shared.types.state import State
from exo.shared.types.tasks import LoadModel, TaskId, TaskStatus
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.worker.instances import (
    InstanceId,
    MlxRingInstance,
)
from exo.shared.types.worker.runners import (
    RunnerFailed,
    RunnerId,
    RunnerReady,
    ShardAssignments,
)
from exo.shared.types.worker.shards import PipelineShardMetadata

# --- Helpers (copied from test_reconcile.py for independence) ---


def _model_card(model_id: str = "test-org/test-model") -> ModelCard:
    return ModelCard(
        model_id=ModelId(model_id),
        storage_size=Memory.from_kb(1000),
        n_layers=10,
        hidden_size=30,
        supports_tensor=True,
        tasks=[ModelTask.TextGeneration],
    )


def _topology(*node_ids: str, connect: bool = True) -> Topology:
    t = Topology()
    nodes = [NodeId(n) for n in node_ids]
    for n in nodes:
        t.add_node(n)
    if connect and len(nodes) > 1:
        for i in range(len(nodes)):
            j = (i + 1) % len(nodes)
            t.add_connection(
                Connection(
                    source=nodes[i],
                    sink=nodes[j],
                    edge=SocketConnection(
                        sink_multiaddr=Multiaddr(
                            address=f"/ip4/10.0.0.{j + 1}/tcp/50000"
                        )
                    ),
                )
            )
            t.add_connection(
                Connection(
                    source=nodes[j],
                    sink=nodes[i],
                    edge=SocketConnection(
                        sink_multiaddr=Multiaddr(
                            address=f"/ip4/10.0.0.{i + 1}/tcp/50000"
                        )
                    ),
                )
            )
    return t


def _meta_instance(
    model_id: str = "test-org/test-model",
    *,
    min_nodes: int = 1,
    node_ids: list[NodeId] | None = None,
    meta_instance_id: MetaInstanceId | None = None,
    consecutive_failures: int = 0,
    last_failure_error: str | None = None,
    placement_error: str | None = None,
) -> MetaInstance:
    return MetaInstance(
        meta_instance_id=meta_instance_id or MetaInstanceId(),
        model_id=ModelId(model_id),
        min_nodes=min_nodes,
        node_ids=node_ids,
        consecutive_failures=consecutive_failures,
        last_failure_error=last_failure_error,
        placement_error=placement_error,
    )


def _instance(
    model_id: str = "test-org/test-model",
    node_ids: list[str] | None = None,
    instance_id: InstanceId | None = None,
    meta_instance_id: MetaInstanceId | None = None,
) -> tuple[InstanceId, MlxRingInstance]:
    iid = instance_id or InstanceId()
    nodes = node_ids or ["node-a"]
    n = len(nodes)
    mc = _model_card(model_id)
    ephemeral_port = 50000
    node_to_runner = {NodeId(nd): RunnerId() for nd in nodes}
    runner_to_shard = {
        runner_id: PipelineShardMetadata(
            model_card=mc,
            device_rank=i,
            world_size=n,
            start_layer=0,
            end_layer=mc.n_layers,
            n_layers=mc.n_layers,
        )
        for i, runner_id in enumerate(node_to_runner.values())
    }
    hosts_by_node: dict[NodeId, list[Host]] = {}
    for r, node_str in enumerate(nodes):
        hosts: list[Host] = []
        for idx in range(n):
            if idx == r:
                hosts.append(Host(ip="0.0.0.0", port=ephemeral_port))
            elif n > 1 and idx in ((r - 1) % n, (r + 1) % n):
                hosts.append(Host(ip=f"10.0.0.{idx + 1}", port=ephemeral_port))
            else:
                hosts.append(Host(ip="198.51.100.1", port=0))
        hosts_by_node[NodeId(node_str)] = hosts
    return iid, MlxRingInstance(
        instance_id=iid,
        shard_assignments=ShardAssignments(
            model_id=ModelId(model_id),
            runner_to_shard=runner_to_shard,
            node_to_runner=node_to_runner,
        ),
        hosts_by_node=hosts_by_node,
        ephemeral_port=ephemeral_port,
        meta_instance_id=meta_instance_id,
    )


# =============================================================================
# 1. MetaInstance lifecycle edge cases
# =============================================================================


def test_meta_instance_model_is_frozen():
    """MetaInstance should be immutable (frozen model)."""
    meta = _meta_instance()
    try:
        meta.model_id = ModelId("something-else")
        raise AssertionError("Should have raised")
    except Exception:
        pass  # Expected — frozen model


def test_meta_instance_created_then_deleted_roundtrip():
    """Create and delete a MetaInstance through apply — state should be clean."""
    state = State()
    meta = _meta_instance()
    state = apply(
        state, IndexedEvent(idx=0, event=MetaInstanceCreated(meta_instance=meta))
    )
    assert meta.meta_instance_id in state.meta_instances
    state = apply(
        state,
        IndexedEvent(
            idx=1, event=MetaInstanceDeleted(meta_instance_id=meta.meta_instance_id)
        ),
    )
    assert meta.meta_instance_id not in state.meta_instances
    assert len(state.meta_instances) == 0


def test_delete_nonexistent_meta_instance_is_safe():
    """Deleting a MetaInstance that doesn't exist should not crash."""
    state = State()
    event = MetaInstanceDeleted(meta_instance_id=MetaInstanceId("nonexistent"))
    new_state = apply(state, IndexedEvent(idx=0, event=event))
    assert len(new_state.meta_instances) == 0


def test_placement_failed_for_nonexistent_meta_instance_is_safe():
    """MetaInstancePlacementFailed for unknown ID should not crash."""
    state = State()
    event = MetaInstancePlacementFailed(
        meta_instance_id=MetaInstanceId("nonexistent"),
        reason="test",
    )
    new_state = apply(state, IndexedEvent(idx=0, event=event))
    assert len(new_state.meta_instances) == 0


def test_multiple_meta_instances_for_same_model():
    """Multiple MetaInstances for the same model are tracked independently."""
    state = State()
    meta_a = _meta_instance("test-org/model-x")
    meta_b = _meta_instance("test-org/model-x")
    state = apply(
        state, IndexedEvent(idx=0, event=MetaInstanceCreated(meta_instance=meta_a))
    )
    state = apply(
        state, IndexedEvent(idx=1, event=MetaInstanceCreated(meta_instance=meta_b))
    )
    assert len(state.meta_instances) == 2
    assert meta_a.meta_instance_id in state.meta_instances
    assert meta_b.meta_instance_id in state.meta_instances


# =============================================================================
# 2. Retry logic edge cases
# =============================================================================


def test_retry_counter_resets_on_successful_instance_creation():
    """When a new instance is created for a meta-instance, failures should reset."""
    meta = _meta_instance(consecutive_failures=2, last_failure_error="old")
    _, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
    state = State(meta_instances={meta.meta_instance_id: meta})
    state = apply(state, IndexedEvent(idx=0, event=InstanceCreated(instance=inst)))
    mi = state.meta_instances[meta.meta_instance_id]
    assert mi.consecutive_failures == 0
    # last_failure_error is preserved (for UI display)
    assert mi.last_failure_error == "old"


async def test_retry_count_increments_through_full_cycle():
    """Walk through MAX_INSTANCE_RETRIES worth of retries, then verify delete."""
    meta = _meta_instance()
    iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
    topology = _topology("node-a")
    state = State(
        meta_instances={meta.meta_instance_id: meta},
        instances={iid: inst},
        topology=topology,
    )

    runner_ids = list(inst.shard_assignments.node_to_runner.values())
    for i in range(MAX_INSTANCE_RETRIES):
        # Simulate runners failing
        state_with_runners = state.model_copy(
            update={"runners": {runner_ids[0]: RunnerFailed(error_message=f"fail-{i}")}}
        )
        reconciler = InstanceHealthReconciler()
        events = await reconciler.reconcile(state_with_runners)
        assert len(events) == 1
        assert isinstance(events[0], InstanceRetrying), f"iteration {i}"
        state = apply(state, IndexedEvent(idx=i, event=events[0]))

    # After MAX_INSTANCE_RETRIES retries, failure counter should be at max
    mi = state.meta_instances[meta.meta_instance_id]
    assert mi.consecutive_failures == MAX_INSTANCE_RETRIES

    # Next failure should result in deletion
    state_with_runners = state.model_copy(
        update={"runners": {runner_ids[0]: RunnerFailed(error_message="final")}}
    )
    reconciler = InstanceHealthReconciler()
    events = await reconciler.reconcile(state_with_runners)
    assert len(events) == 1
    assert isinstance(events[0], InstanceDeleted)


async def test_health_reconciler_respects_exact_limit():
    """At exactly MAX_INSTANCE_RETRIES, reconciler should delete, not retry."""
    meta = _meta_instance(consecutive_failures=MAX_INSTANCE_RETRIES)
    iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
    runner_ids = list(inst.shard_assignments.node_to_runner.values())
    state = State(
        meta_instances={meta.meta_instance_id: meta},
        instances={iid: inst},
        runners={runner_ids[0]: RunnerFailed(error_message="OOM")},
        topology=_topology("node-a"),
    )
    reconciler = InstanceHealthReconciler()
    events = await reconciler.reconcile(state)
    assert len(events) == 1
    assert isinstance(events[0], InstanceDeleted)


async def test_health_reconciler_at_limit_minus_one_retries():
    """At MAX_INSTANCE_RETRIES - 1, reconciler should still retry."""
    meta = _meta_instance(consecutive_failures=MAX_INSTANCE_RETRIES - 1)
    iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
    runner_ids = list(inst.shard_assignments.node_to_runner.values())
    state = State(
        meta_instances={meta.meta_instance_id: meta},
        instances={iid: inst},
        runners={runner_ids[0]: RunnerFailed(error_message="OOM")},
        topology=_topology("node-a"),
    )
    reconciler = InstanceHealthReconciler()
    events = await reconciler.reconcile(state)
    assert len(events) == 1
    assert isinstance(events[0], InstanceRetrying)


# =============================================================================
# 3. Error handling edge cases
# =============================================================================


def test_runners_failed_with_empty_error_message():
    """RunnerFailed with empty error_message should still report as failed."""
    _, inst = _instance(node_ids=["node-a"])
    runners = {
        rid: RunnerFailed(error_message="")
        for rid in inst.shard_assignments.node_to_runner.values()
    }
    is_failed, error = instance_runners_failed(inst, runners, {})
    assert is_failed is True
    # Empty error message means we get the fallback
    assert error == "Runner failed"


def test_runners_failed_with_none_error_message():
    """RunnerFailed with None error_message should still report as failed."""
    _, inst = _instance(node_ids=["node-a"])
    runners = {
        rid: RunnerFailed(error_message=None)
        for rid in inst.shard_assignments.node_to_runner.values()
    }
    is_failed, error = instance_runners_failed(inst, runners, {})
    assert is_failed is True
    assert error == "Runner failed"


def test_runners_failed_collects_all_error_messages():
    """With multiple failed runners, all error messages should be collected."""
    _, inst = _instance(node_ids=["node-a", "node-b", "node-c"])
    runner_ids = list(inst.shard_assignments.node_to_runner.values())
    runners = {
        runner_ids[0]: RunnerFailed(error_message="OOM on GPU 0"),
        runner_ids[1]: RunnerFailed(error_message="OOM on GPU 1"),
        runner_ids[2]: RunnerFailed(error_message="OOM on GPU 2"),
    }
    is_failed, error = instance_runners_failed(inst, runners, {})
    assert is_failed is True
    assert error is not None
    assert "OOM on GPU 0" in error
    assert "OOM on GPU 1" in error
    assert "OOM on GPU 2" in error


def test_runners_failed_includes_friendly_name():
    """Error messages should include node friendly names when available."""
    _, inst = _instance(node_ids=["node-a"])
    node_id = NodeId("node-a")
    runner_ids = list(inst.shard_assignments.node_to_runner.values())
    runners = {runner_ids[0]: RunnerFailed(error_message="OOM")}
    identities = {node_id: NodeIdentity(friendly_name="My Mac Studio")}
    is_failed, error = instance_runners_failed(inst, runners, identities)
    assert is_failed is True
    assert error is not None
    assert "My Mac Studio" in error


def test_instance_retrying_for_missing_instance_is_safe():
    """InstanceRetrying for an instance not in state should not crash.

    NOTE: When the instance is missing, the handler returns early WITHOUT
    incrementing the MetaInstance failure counter. This means stale retry
    events for already-deleted instances are silently dropped. This is
    acceptable since the InstanceDeleted handler already increments failures.
    """
    meta = _meta_instance()
    state = State(meta_instances={meta.meta_instance_id: meta})
    event = InstanceRetrying(
        instance_id=InstanceId("nonexistent"),
        meta_instance_id=meta.meta_instance_id,
        failure_error="crash",
    )
    new_state = apply(state, IndexedEvent(idx=0, event=event))
    # Does not crash, but failure count is NOT incremented (early return)
    mi = new_state.meta_instances[meta.meta_instance_id]
    assert mi.consecutive_failures == 0


# =============================================================================
# 4. Backward compatibility
# =============================================================================


def test_instance_without_meta_instance_id_works():
    """Instances created without meta_instance_id should still function normally."""
    _, inst = _instance(node_ids=["node-a"])
    assert inst.meta_instance_id is None
    topology = _topology("node-a")
    assert instance_connections_healthy(inst, topology) is True


def test_instance_deleted_without_meta_does_not_affect_meta_instances():
    """Deleting an instance without meta_instance_id should not affect meta_instances."""
    meta = _meta_instance()
    iid, inst = _instance(node_ids=["node-a"])  # no meta_instance_id
    state = State(
        meta_instances={meta.meta_instance_id: meta},
        instances={iid: inst},
    )
    event = InstanceDeleted(instance_id=iid, failure_error="crash")
    new_state = apply(state, IndexedEvent(idx=0, event=event))
    mi = new_state.meta_instances[meta.meta_instance_id]
    assert mi.consecutive_failures == 0  # unchanged


def test_satisfies_ignores_meta_instance_id_binding():
    """instance_satisfies_meta_instance checks constraints only, not binding."""
    meta = _meta_instance()
    _, inst = _instance(node_ids=["node-a"])  # no meta_instance_id set
    # Should match on constraints (model, min_nodes) regardless of binding
    assert instance_satisfies_meta_instance(meta, inst) is True


def test_find_unsatisfied_uses_binding_not_constraints():
    """find_unsatisfied checks meta_instance_id binding, not just constraint matching."""
    meta = _meta_instance()
    # Instance matches constraints but is NOT bound to this meta_instance
    iid, inst = _instance(node_ids=["node-a"])
    topology = _topology("node-a")
    result = find_unsatisfied_meta_instances(
        {meta.meta_instance_id: meta}, {iid: inst}, topology
    )
    # Should be unsatisfied because instance.meta_instance_id != meta.meta_instance_id
    assert list(result) == [meta]


# =============================================================================
# 5. Concurrent / multi-instance scenarios
# =============================================================================


async def test_health_reconciler_handles_multiple_failing_instances():
    """Multiple instances failing simultaneously should each get their own event."""
    meta_a = _meta_instance()
    meta_b = _meta_instance()
    iid_a, inst_a = _instance(
        node_ids=["node-a"], meta_instance_id=meta_a.meta_instance_id
    )
    iid_b, inst_b = _instance(
        node_ids=["node-b"], meta_instance_id=meta_b.meta_instance_id
    )
    runner_ids_a = list(inst_a.shard_assignments.node_to_runner.values())
    runner_ids_b = list(inst_b.shard_assignments.node_to_runner.values())
    state = State(
        meta_instances={
            meta_a.meta_instance_id: meta_a,
            meta_b.meta_instance_id: meta_b,
        },
        instances={iid_a: inst_a, iid_b: inst_b},
        runners={
            runner_ids_a[0]: RunnerFailed(error_message="OOM"),
            runner_ids_b[0]: RunnerFailed(error_message="OOM"),
        },
        topology=_topology("node-a", "node-b"),
    )
    reconciler = InstanceHealthReconciler()
    events = await reconciler.reconcile(state)
    assert len(events) == 2
    # Both should be InstanceRetrying since failures < MAX
    assert all(isinstance(e, InstanceRetrying) for e in events)
    instance_ids = {e.instance_id for e in events}  # type: ignore[union-attr]
    assert instance_ids == {iid_a, iid_b}


async def test_health_reconciler_mixed_healthy_and_failing():
    """Only failing instances should produce events; healthy ones should not."""
    meta_healthy = _meta_instance()
    meta_failing = _meta_instance()
    iid_h, inst_h = _instance(
        node_ids=["node-a"], meta_instance_id=meta_healthy.meta_instance_id
    )
    iid_f, inst_f = _instance(
        node_ids=["node-b"], meta_instance_id=meta_failing.meta_instance_id
    )
    runner_ids_h = list(inst_h.shard_assignments.node_to_runner.values())
    runner_ids_f = list(inst_f.shard_assignments.node_to_runner.values())
    state = State(
        meta_instances={
            meta_healthy.meta_instance_id: meta_healthy,
            meta_failing.meta_instance_id: meta_failing,
        },
        instances={iid_h: inst_h, iid_f: inst_f},
        runners={
            runner_ids_h[0]: RunnerReady(),
            runner_ids_f[0]: RunnerFailed(error_message="crash"),
        },
        topology=_topology("node-a", "node-b"),
    )
    reconciler = InstanceHealthReconciler()
    events = await reconciler.reconcile(state)
    assert len(events) == 1
    assert isinstance(events[0], InstanceRetrying)
    assert events[0].instance_id == iid_f


async def test_meta_instance_reconciler_empty_state():
    """MetaInstanceReconciler with no meta_instances should produce no events."""
    state = State()
    reconciler = MetaInstanceReconciler()
    events = await reconciler.reconcile(state)
    assert len(events) == 0


# =============================================================================
# 6. Placement error tracking
# =============================================================================


def test_placement_failed_sets_error():
    """MetaInstancePlacementFailed should set placement_error on the MetaInstance."""
    meta = _meta_instance()
    state = State(meta_instances={meta.meta_instance_id: meta})
    event = MetaInstancePlacementFailed(
        meta_instance_id=meta.meta_instance_id,
        reason="Not enough memory",
    )
    new_state = apply(state, IndexedEvent(idx=0, event=event))
    mi = new_state.meta_instances[meta.meta_instance_id]
    assert mi.placement_error == "Not enough memory"


def test_instance_created_clears_placement_error():
    """InstanceCreated should clear placement_error on the MetaInstance."""
    meta = _meta_instance(placement_error="Not enough memory")
    _, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
    state = State(meta_instances={meta.meta_instance_id: meta})
    state = apply(state, IndexedEvent(idx=0, event=InstanceCreated(instance=inst)))
    mi = state.meta_instances[meta.meta_instance_id]
    assert mi.placement_error is None


def test_placement_error_does_not_increment_failures():
    """Placement failures should only set placement_error, not increment consecutive_failures."""
    meta = _meta_instance()
    state = State(meta_instances={meta.meta_instance_id: meta})
    event = MetaInstancePlacementFailed(
        meta_instance_id=meta.meta_instance_id,
        reason="No resources",
    )
    new_state = apply(state, IndexedEvent(idx=0, event=event))
    mi = new_state.meta_instances[meta.meta_instance_id]
    assert mi.consecutive_failures == 0
    assert mi.placement_error == "No resources"


# =============================================================================
# 7. State serialization roundtrip
# =============================================================================


def test_state_with_meta_instances_serializes():
    """State with meta_instances should serialize and deserialize correctly."""
    meta = _meta_instance(consecutive_failures=2, last_failure_error="test")
    iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
    state = State(
        meta_instances={meta.meta_instance_id: meta},
        instances={iid: inst},
    )
    json_str = state.model_dump_json()
    restored = State.model_validate_json(json_str)
    assert meta.meta_instance_id in restored.meta_instances
    mi = restored.meta_instances[meta.meta_instance_id]
    assert mi.model_id == meta.model_id
    assert mi.consecutive_failures == 2
    assert mi.last_failure_error == "test"
    assert iid in restored.instances
    assert restored.instances[iid].meta_instance_id == meta.meta_instance_id


# =============================================================================
# 8. MetaInstanceReconciler error handling
# =============================================================================


async def test_meta_instance_reconciler_model_load_error_emits_placement_failed(
    monkeypatch: "pytest.MonkeyPatch",
):
    """When ModelCard.load raises, reconciler emits MetaInstancePlacementFailed."""
    import exo.master.process_managers.meta_instance as mi_mod

    meta = _meta_instance()
    topo = _topology("node-a")
    state = State(
        meta_instances={meta.meta_instance_id: meta},
        topology=topo,
    )

    async def _failing_load(_model_id: ModelId) -> ModelCard:
        raise RuntimeError("Network error")

    monkeypatch.setattr(
        mi_mod, "ModelCard", type("MC", (), {"load": staticmethod(_failing_load)})
    )

    reconciler = MetaInstanceReconciler()
    events = await reconciler.reconcile(state)

    placement_failed = [e for e in events if isinstance(e, MetaInstancePlacementFailed)]
    assert len(placement_failed) == 1
    assert "Failed to load model card" in placement_failed[0].reason
    assert meta.meta_instance_id == placement_failed[0].meta_instance_id


async def test_meta_instance_reconciler_model_load_error_skips_dedup(
    monkeypatch: "pytest.MonkeyPatch",
):
    """When ModelCard.load error matches existing placement_error, no duplicate event."""
    import exo.master.process_managers.meta_instance as mi_mod

    meta = _meta_instance(placement_error="Failed to load model card: Network error")
    topo = _topology("node-a")
    state = State(
        meta_instances={meta.meta_instance_id: meta},
        topology=topo,
    )

    async def _failing_load(_model_id: ModelId) -> ModelCard:
        raise RuntimeError("Network error")

    monkeypatch.setattr(
        mi_mod, "ModelCard", type("MC", (), {"load": staticmethod(_failing_load)})
    )

    reconciler = MetaInstanceReconciler()
    events = await reconciler.reconcile(state)

    # Error matches existing placement_error, so no duplicate event emitted
    assert len(events) == 0


async def test_meta_instance_reconciler_continues_after_error(
    monkeypatch: "pytest.MonkeyPatch",
):
    """Reconciler should continue to next meta-instance after one fails to load."""
    import exo.master.process_managers.meta_instance as mi_mod

    meta_a = _meta_instance(model_id="org/model-a")
    meta_b = _meta_instance(model_id="org/model-b")
    topo = _topology("node-a")
    state = State(
        meta_instances={
            meta_a.meta_instance_id: meta_a,
            meta_b.meta_instance_id: meta_b,
        },
        topology=topo,
    )

    call_count = 0

    async def _load_second_fails(model_id: ModelId) -> ModelCard:
        nonlocal call_count
        call_count += 1
        raise RuntimeError(f"Cannot load {model_id}")

    monkeypatch.setattr(
        mi_mod, "ModelCard", type("MC", (), {"load": staticmethod(_load_second_fails)})
    )

    reconciler = MetaInstanceReconciler()
    events = await reconciler.reconcile(state)

    # Both meta-instances should have been attempted (not short-circuited)
    assert call_count == 2
    # Both should have placement failed events
    placement_failed = [e for e in events if isinstance(e, MetaInstancePlacementFailed)]
    assert len(placement_failed) == 2
# =============================================================================
|
||||
# 8. Cascade delete with task cancellation
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_cascade_delete_cancels_active_tasks():
|
||||
"""Deleting a MetaInstance should cancel tasks on backing instances.
|
||||
|
||||
Regression test: previously, cascade-deleting backing instances via
|
||||
DeleteMetaInstance did not emit TaskStatusUpdated(Cancelled) for active
|
||||
tasks, leaving orphaned task references in state.
|
||||
"""
|
||||
meta = _meta_instance()
|
||||
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
|
||||
task_id = TaskId()
|
||||
task = LoadModel(task_id=task_id, instance_id=iid, task_status=TaskStatus.Running)
|
||||
|
||||
# Build state with meta-instance, backing instance, and active task
|
||||
state = State(
|
||||
meta_instances={meta.meta_instance_id: meta},
|
||||
instances={iid: inst},
|
||||
tasks={task_id: task},
|
||||
topology=_topology("node-a"),
|
||||
)
|
||||
|
||||
# Simulate the cascade-delete event sequence produced by main.py:
|
||||
# 1. MetaInstanceDeleted
|
||||
# 2. TaskStatusUpdated(Cancelled) for active tasks
|
||||
# 3. InstanceDeleted
|
||||
idx = 0
|
||||
state = apply(
|
||||
state,
|
||||
IndexedEvent(
|
||||
idx=idx,
|
||||
event=MetaInstanceDeleted(meta_instance_id=meta.meta_instance_id),
|
||||
),
|
||||
)
|
||||
idx += 1
|
||||
state = apply(
|
||||
state,
|
||||
IndexedEvent(
|
||||
idx=idx,
|
||||
event=TaskStatusUpdated(task_id=task_id, task_status=TaskStatus.Cancelled),
|
||||
),
|
||||
)
|
||||
idx += 1
|
||||
state = apply(
|
||||
state,
|
||||
IndexedEvent(idx=idx, event=InstanceDeleted(instance_id=iid)),
|
||||
)
|
||||
|
||||
# Verify everything is cleaned up
|
||||
assert len(state.meta_instances) == 0
|
||||
assert len(state.instances) == 0
|
||||
assert state.tasks[task_id].task_status == TaskStatus.Cancelled
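
# The three-event cascade above could come from a helper like this sketch
# (hypothetical; the real emitter in main.py is not shown in this diff):
def _cascade_delete_events(state: State, meta_instance_id) -> list:
    events: list = [MetaInstanceDeleted(meta_instance_id=meta_instance_id)]
    backing = [
        iid
        for iid, inst in state.instances.items()
        if inst.meta_instance_id == meta_instance_id
    ]
    for iid in backing:
        # Cancel only tasks that are still Pending or Running on this instance.
        for task in state.tasks.values():
            if task.instance_id == iid and task.task_status in (
                TaskStatus.Pending,
                TaskStatus.Running,
            ):
                events.append(
                    TaskStatusUpdated(
                        task_id=task.task_id, task_status=TaskStatus.Cancelled
                    )
                )
        events.append(InstanceDeleted(instance_id=iid))
    return events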


def test_cascade_delete_skips_completed_tasks():
    """Cascade delete should only cancel Pending/Running tasks, not completed ones."""
    meta = _meta_instance()
    iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)

    running_task_id = TaskId()
    completed_task_id = TaskId()
    running_task = LoadModel(
        task_id=running_task_id, instance_id=iid, task_status=TaskStatus.Running
    )
    completed_task = LoadModel(
        task_id=completed_task_id, instance_id=iid, task_status=TaskStatus.Complete
    )

    state = State(
        meta_instances={meta.meta_instance_id: meta},
        instances={iid: inst},
        tasks={running_task_id: running_task, completed_task_id: completed_task},
        topology=_topology("node-a"),
    )

    # Only the running task should be cancelled — we verify the logic pattern
    # by checking which tasks are Pending or Running
    active_tasks = [
        t
        for t in state.tasks.values()
        if t.instance_id == iid
        and t.task_status in (TaskStatus.Pending, TaskStatus.Running)
    ]
    assert len(active_tasks) == 1
    assert active_tasks[0].task_id == running_task_id

@@ -3,10 +3,10 @@ import pytest
from exo.master.placement_utils import (
    allocate_layers_proportionally,
    filter_cycles_by_memory,
    get_largest_cycles,
    get_mlx_jaccl_coordinators,
    get_shard_assignments,
    get_shard_assignments_for_pipeline_parallel,
    get_smallest_cycles,
)
from exo.master.tests.conftest import (
    create_node_memory,
@@ -143,7 +143,7 @@ def test_filter_multiple_cycles_by_memory():
    }


def test_get_largest_cycles():
def test_get_smallest_cycles():
    # arrange
    node_a_id = NodeId()
    node_b_id = NodeId()
@@ -175,12 +175,12 @@ def test_get_largest_cycles():
    cycles = [c for c in topology.get_cycles() if len(c) != 1]  # ignore singletons

    # act
    largest_cycles = get_largest_cycles(cycles)
    smallest_cycles = get_smallest_cycles(cycles)

    # assert
    assert len(largest_cycles) == 1
    assert len(largest_cycles[0]) == 3
    assert set(n for n in largest_cycles[0]) == {node_a_id, node_b_id, node_c_id}
    assert len(smallest_cycles) == 1
    assert len(smallest_cycles[0]) == 2
    assert set(n for n in smallest_cycles[0]) == {node_a_id, node_b_id}


@pytest.mark.parametrize(

@@ -1,742 +0,0 @@
from exo.master.process_managers.instance_health import InstanceHealthReconciler
from exo.master.reconcile import (
    find_unsatisfied_meta_instances,
    instance_connections_healthy,
    instance_runners_failed,
    instance_satisfies_meta_instance,
)
from exo.shared.apply import apply
from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
from exo.shared.topology import Topology
from exo.shared.types.common import Host, MetaInstanceId, NodeId
from exo.shared.types.events import (
    IndexedEvent,
    InstanceCreated,
    InstanceDeleted,
    InstanceRetrying,
    MetaInstanceCreated,
    MetaInstanceDeleted,
)
from exo.shared.types.memory import Memory
from exo.shared.types.meta_instance import MetaInstance
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.state import State
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.worker.instances import (
    InstanceId,
    MlxRingInstance,
)
from exo.shared.types.worker.runners import (
    RunnerFailed,
    RunnerId,
    RunnerLoading,
    RunnerReady,
    RunnerShutdown,
    ShardAssignments,
)
from exo.shared.types.worker.shards import PipelineShardMetadata


def _model_card(model_id: str = "test-org/test-model") -> ModelCard:
    return ModelCard(
        model_id=ModelId(model_id),
        storage_size=Memory.from_kb(1000),
        n_layers=10,
        hidden_size=30,
        supports_tensor=True,
        tasks=[ModelTask.TextGeneration],
    )


def _topology(*node_ids: str, connect: bool = True) -> Topology:
    """Build a topology with nodes connected in a bidirectional ring with unique IPs.

    Node at index ``i`` gets IP ``10.0.0.{i+1}``. Edges go in both directions
    between consecutive nodes (including wrap-around).
    """
    t = Topology()
    nodes = [NodeId(n) for n in node_ids]
    for n in nodes:
        t.add_node(n)
    if connect and len(nodes) > 1:
        for i in range(len(nodes)):
            j = (i + 1) % len(nodes)
            t.add_connection(
                Connection(
                    source=nodes[i],
                    sink=nodes[j],
                    edge=SocketConnection(
                        sink_multiaddr=Multiaddr(
                            address=f"/ip4/10.0.0.{j + 1}/tcp/50000"
                        )
                    ),
                )
            )
            t.add_connection(
                Connection(
                    source=nodes[j],
                    sink=nodes[i],
                    edge=SocketConnection(
                        sink_multiaddr=Multiaddr(
                            address=f"/ip4/10.0.0.{i + 1}/tcp/50000"
                        )
                    ),
                )
            )
    return t
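
# Illustrative usage of the helper above (not from the original file): a
# two-node ring assigns node-a → 10.0.0.1 and node-b → 10.0.0.2 and should
# yield one non-singleton cycle.
def _demo_ring_topology() -> None:
    topo = _topology("node-a", "node-b")
    cycles = [c for c in topo.get_cycles() if len(c) != 1]
    assert any(len(c) == 2 for c in cycles)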


def _meta_instance(
    model_id: str = "test-org/test-model",
    *,
    min_nodes: int = 1,
    node_ids: list[NodeId] | None = None,
    meta_instance_id: MetaInstanceId | None = None,
) -> MetaInstance:
    return MetaInstance(
        meta_instance_id=meta_instance_id or MetaInstanceId(),
        model_id=ModelId(model_id),
        min_nodes=min_nodes,
        node_ids=node_ids,
    )


def _instance(
    model_id: str = "test-org/test-model",
    node_ids: list[str] | None = None,
    instance_id: InstanceId | None = None,
    meta_instance_id: MetaInstanceId | None = None,
) -> tuple[InstanceId, MlxRingInstance]:
    """Create a test instance with hosts_by_node matching ``_topology()`` IPs."""
    iid = instance_id or InstanceId()
    nodes = node_ids or ["node-a"]
    n = len(nodes)
    mc = _model_card(model_id)
    ephemeral_port = 50000
    node_to_runner = {NodeId(nd): RunnerId() for nd in nodes}
    runner_to_shard = {
        runner_id: PipelineShardMetadata(
            model_card=mc,
            device_rank=i,
            world_size=n,
            start_layer=0,
            end_layer=mc.n_layers,
            n_layers=mc.n_layers,
        )
        for i, runner_id in enumerate(node_to_runner.values())
    }
    # Build hosts_by_node with IPs matching _topology() convention:
    # node at index idx has IP 10.0.0.{idx+1}
    hosts_by_node: dict[NodeId, list[Host]] = {}
    for r, node_str in enumerate(nodes):
        hosts: list[Host] = []
        for idx in range(n):
            if idx == r:
                hosts.append(Host(ip="0.0.0.0", port=ephemeral_port))
            elif n > 1 and idx in ((r - 1) % n, (r + 1) % n):
                hosts.append(Host(ip=f"10.0.0.{idx + 1}", port=ephemeral_port))
            else:
                hosts.append(Host(ip="198.51.100.1", port=0))
        hosts_by_node[NodeId(node_str)] = hosts
    return iid, MlxRingInstance(
        instance_id=iid,
        shard_assignments=ShardAssignments(
            model_id=ModelId(model_id),
            runner_to_shard=runner_to_shard,
            node_to_runner=node_to_runner,
        ),
        hosts_by_node=hosts_by_node,
        ephemeral_port=ephemeral_port,
        meta_instance_id=meta_instance_id,
    )


# --- instance_satisfies_meta_instance (pure constraint matching) ---


def test_satisfies_matching_model():
    meta = _meta_instance()
    _, inst = _instance(node_ids=["node-a"])
    assert instance_satisfies_meta_instance(meta, inst) is True


def test_not_satisfies_wrong_model():
    meta = _meta_instance("test-org/model-a")
    _, inst = _instance("test-org/model-b")
    assert instance_satisfies_meta_instance(meta, inst) is False


def test_not_satisfies_missing_required_node():
    meta = _meta_instance(node_ids=[NodeId("node-c")])
    _, inst = _instance(node_ids=["node-a", "node-b"])
    assert instance_satisfies_meta_instance(meta, inst) is False


def test_not_satisfies_fewer_than_min_nodes():
    meta = _meta_instance(min_nodes=3)
    _, inst = _instance(node_ids=["node-a", "node-b"])
    assert instance_satisfies_meta_instance(meta, inst) is False


def test_satisfies_with_node_ids_specified():
    meta = _meta_instance(node_ids=[NodeId("node-a"), NodeId("node-b")], min_nodes=2)
    _, inst = _instance(node_ids=["node-a", "node-b", "node-c"])
    assert instance_satisfies_meta_instance(meta, inst) is True
instances={iid: inst},
|
||||
)
|
||||
event = InstanceDeleted(instance_id=iid, failure_error="Runner OOM")
|
||||
new_state = apply(state, IndexedEvent(idx=0, event=event))
|
||||
mi = new_state.meta_instances[meta.meta_instance_id]
|
||||
assert mi.consecutive_failures == 1
|
||||
assert mi.last_failure_error == "Runner OOM"
|
||||
|
||||
|
||||
def test_apply_instance_deleted_increments_failure():
|
||||
"""Subsequent failures increment the counter."""
|
||||
meta = _meta_instance().model_copy(
|
||||
update={"consecutive_failures": 2, "last_failure_error": "previous error"}
|
||||
)
|
||||
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
|
||||
state = State(
|
||||
meta_instances={meta.meta_instance_id: meta},
|
||||
instances={iid: inst},
|
||||
)
|
||||
event = InstanceDeleted(instance_id=iid, failure_error="new error")
|
||||
new_state = apply(state, IndexedEvent(idx=0, event=event))
|
||||
mi = new_state.meta_instances[meta.meta_instance_id]
|
||||
assert mi.consecutive_failures == 3
|
||||
assert mi.last_failure_error == "new error"
|
||||
|
||||
|
||||
def test_apply_instance_deleted_no_failure_no_tracking():
|
||||
"""InstanceDeleted without failure_error does not track."""
|
||||
meta = _meta_instance()
|
||||
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
|
||||
state = State(
|
||||
meta_instances={meta.meta_instance_id: meta},
|
||||
instances={iid: inst},
|
||||
)
|
||||
event = InstanceDeleted(instance_id=iid)
|
||||
new_state = apply(state, IndexedEvent(idx=0, event=event))
|
||||
mi = new_state.meta_instances[meta.meta_instance_id]
|
||||
assert mi.consecutive_failures == 0
|
||||
|
||||
|
||||
def test_apply_instance_deleted_orphan_no_tracking():
|
||||
"""InstanceDeleted for orphan instance (no meta_instance_id) does not track."""
|
||||
iid, inst = _instance(node_ids=["node-a"])
|
||||
state = State(instances={iid: inst})
|
||||
event = InstanceDeleted(instance_id=iid, failure_error="crash")
|
||||
new_state = apply(state, IndexedEvent(idx=0, event=event))
|
||||
assert len(new_state.meta_instances) == 0
|
||||
|
||||
|
||||
# --- InstanceRetrying ---
|
||||
|
||||
|
||||
def test_apply_instance_retrying_removes_runners():
|
||||
"""InstanceRetrying removes the instance's runners from state but keeps the instance."""
|
||||
meta = _meta_instance()
|
||||
iid, inst = _instance(
|
||||
node_ids=["node-a", "node-b"], meta_instance_id=meta.meta_instance_id
|
||||
)
|
||||
runner_ids = list(inst.shard_assignments.node_to_runner.values())
|
||||
runners = {
|
||||
runner_ids[0]: RunnerFailed(error_message="OOM"),
|
||||
runner_ids[1]: RunnerShutdown(),
|
||||
}
|
||||
state = State(
|
||||
meta_instances={meta.meta_instance_id: meta},
|
||||
instances={iid: inst},
|
||||
runners=runners,
|
||||
)
|
||||
event = InstanceRetrying(
|
||||
instance_id=iid,
|
||||
meta_instance_id=meta.meta_instance_id,
|
||||
failure_error="OOM",
|
||||
)
|
||||
new_state = apply(state, IndexedEvent(idx=0, event=event))
|
||||
# Instance still exists
|
||||
assert iid in new_state.instances
|
||||
# Runners removed
|
||||
assert runner_ids[0] not in new_state.runners
|
||||
assert runner_ids[1] not in new_state.runners
|
||||
|
||||
|
||||
def test_apply_instance_retrying_increments_failure():
|
||||
"""InstanceRetrying increments consecutive_failures on the MetaInstance."""
|
||||
meta = _meta_instance()
|
||||
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
|
||||
state = State(
|
||||
meta_instances={meta.meta_instance_id: meta},
|
||||
instances={iid: inst},
|
||||
)
|
||||
event = InstanceRetrying(
|
||||
instance_id=iid,
|
||||
meta_instance_id=meta.meta_instance_id,
|
||||
failure_error="crash",
|
||||
)
|
||||
new_state = apply(state, IndexedEvent(idx=0, event=event))
|
||||
mi = new_state.meta_instances[meta.meta_instance_id]
|
||||
assert mi.consecutive_failures == 1
|
||||
assert mi.last_failure_error == "crash"
|
||||
|
||||
|
||||
def test_apply_instance_retrying_skips_missing_runners():
|
||||
"""InstanceRetrying doesn't assert if runners haven't reported yet."""
|
||||
meta = _meta_instance()
|
||||
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
|
||||
# No runners in state at all
|
||||
state = State(
|
||||
meta_instances={meta.meta_instance_id: meta},
|
||||
instances={iid: inst},
|
||||
)
|
||||
event = InstanceRetrying(
|
||||
instance_id=iid,
|
||||
meta_instance_id=meta.meta_instance_id,
|
||||
failure_error="crash",
|
||||
)
|
||||
# Should not raise
|
||||
new_state = apply(state, IndexedEvent(idx=0, event=event))
|
||||
assert iid in new_state.instances
|
||||
|
||||
|
||||
def test_apply_instance_created_resets_failure_counter():
|
||||
"""InstanceCreated resets consecutive_failures but preserves last_failure_error."""
|
||||
meta = _meta_instance().model_copy(
|
||||
update={"consecutive_failures": 3, "last_failure_error": "old error"}
|
||||
)
|
||||
_, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
|
||||
state = State(meta_instances={meta.meta_instance_id: meta})
|
||||
event = InstanceCreated(instance=inst)
|
||||
new_state = apply(state, IndexedEvent(idx=0, event=event))
|
||||
mi = new_state.meta_instances[meta.meta_instance_id]
|
||||
assert mi.consecutive_failures == 0
|
||||
assert mi.last_failure_error == "old error"
|
||||
assert mi.placement_error is None
|
||||
|
||||
|
||||
# --- InstanceHealthReconciler retry-vs-delete ---
|
||||
|
||||
|
||||
async def test_health_reconciler_retries_when_under_limit():
|
||||
"""InstanceHealthReconciler emits InstanceRetrying when consecutive_failures < 3."""
|
||||
meta = _meta_instance()
|
||||
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
|
||||
runner_ids = list(inst.shard_assignments.node_to_runner.values())
|
||||
state = State(
|
||||
meta_instances={meta.meta_instance_id: meta},
|
||||
instances={iid: inst},
|
||||
runners={runner_ids[0]: RunnerFailed(error_message="OOM")},
|
||||
topology=_topology("node-a"),
|
||||
)
|
||||
reconciler = InstanceHealthReconciler()
|
||||
events = await reconciler.reconcile(state)
|
||||
assert len(events) == 1
|
||||
assert isinstance(events[0], InstanceRetrying)
|
||||
assert events[0].instance_id == iid
|
||||
assert events[0].meta_instance_id == meta.meta_instance_id
|
||||
|
||||
|
||||
async def test_health_reconciler_deletes_when_limit_reached():
|
||||
"""InstanceHealthReconciler emits InstanceDeleted when consecutive_failures >= 3."""
|
||||
meta = _meta_instance().model_copy(update={"consecutive_failures": 3})
|
||||
iid, inst = _instance(node_ids=["node-a"], meta_instance_id=meta.meta_instance_id)
|
||||
runner_ids = list(inst.shard_assignments.node_to_runner.values())
|
||||
state = State(
|
||||
meta_instances={meta.meta_instance_id: meta},
|
||||
instances={iid: inst},
|
||||
runners={runner_ids[0]: RunnerFailed(error_message="OOM")},
|
||||
topology=_topology("node-a"),
|
||||
)
|
||||
reconciler = InstanceHealthReconciler()
|
||||
events = await reconciler.reconcile(state)
|
||||
assert len(events) == 1
|
||||
assert isinstance(events[0], InstanceDeleted)
|
||||
|
||||
|
||||
async def test_health_reconciler_deletes_without_meta_instance():
|
||||
"""Instances without a MetaInstance are deleted immediately on runner failure."""
|
||||
iid, inst = _instance(node_ids=["node-a"])
|
||||
runner_ids = list(inst.shard_assignments.node_to_runner.values())
|
||||
state = State(
|
||||
instances={iid: inst},
|
||||
runners={runner_ids[0]: RunnerFailed(error_message="crash")},
|
||||
topology=_topology("node-a"),
|
||||
)
|
||||
reconciler = InstanceHealthReconciler()
|
||||
events = await reconciler.reconcile(state)
|
||||
assert len(events) == 1
|
||||
assert isinstance(events[0], InstanceDeleted)
|
||||
|
||||
|
||||
async def test_health_reconciler_network_failure_always_deletes():
|
||||
"""Network failure always triggers InstanceDeleted regardless of retry count."""
|
||||
meta = _meta_instance()
|
||||
iid, inst = _instance(
|
||||
node_ids=["node-a", "node-b"], meta_instance_id=meta.meta_instance_id
|
||||
)
|
||||
state = State(
|
||||
meta_instances={meta.meta_instance_id: meta},
|
||||
instances={iid: inst},
|
||||
topology=_topology("node-a"), # node-b missing
|
||||
)
|
||||
reconciler = InstanceHealthReconciler()
|
||||
events = await reconciler.reconcile(state)
|
||||
assert len(events) == 1
|
||||
assert isinstance(events[0], InstanceDeleted)
|
||||
assert events[0].failure_error == "Network connection lost"
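
# Sketch of the retry-vs-delete policy the four reconciler tests above pin
# down (retry limit of 3 assumed from the docstrings; the real decision is
# made inside InstanceHealthReconciler):
def _retry_or_delete_sketch(meta: MetaInstance | None, network_lost: bool) -> str:
    if network_lost:
        return "delete"  # network failures are never retried
    if meta is None:
        return "delete"  # orphan instances have no retry budget
    if meta.consecutive_failures >= 3:
        return "delete"  # retry limit reached
    return "retry"  # emit InstanceRetrying and keep the instance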

@@ -4,7 +4,7 @@ from datetime import datetime

from loguru import logger

from exo.shared.types.common import MetaInstanceId, NodeId
from exo.shared.types.common import NodeId
from exo.shared.types.events import (
    ChunkGenerated,
    Event,
@@ -12,12 +12,6 @@ from exo.shared.types.events import (
    InputChunkReceived,
    InstanceCreated,
    InstanceDeleted,
    InstanceRetrying,
    JacclSideChannelData,
    JacclSideChannelGathered,
    MetaInstanceCreated,
    MetaInstanceDeleted,
    MetaInstancePlacementFailed,
    NodeDownloadProgress,
    NodeGatheredInfo,
    NodeTimedOut,
@@ -34,7 +28,6 @@ from exo.shared.types.events import (
    TracesCollected,
    TracesMerged,
)
from exo.shared.types.meta_instance import MetaInstance
from exo.shared.types.profiling import (
    NodeIdentity,
    NodeNetworkInfo,
@@ -73,22 +66,12 @@ def event_apply(event: Event, state: State) -> State:
            | InputChunkReceived()
            | TracesCollected()
            | TracesMerged()
            | JacclSideChannelData()
            | JacclSideChannelGathered()
        ):  # Pass-through events that don't modify state
            return state
        case InstanceCreated():
            return apply_instance_created(event, state)
        case InstanceDeleted():
            return apply_instance_deleted(event, state)
        case InstanceRetrying():
            return apply_instance_retrying(event, state)
        case MetaInstanceCreated():
            return apply_meta_instance_created(event, state)
        case MetaInstanceDeleted():
            return apply_meta_instance_deleted(event, state)
        case MetaInstancePlacementFailed():
            return apply_meta_instance_placement_failed(event, state)
        case NodeTimedOut():
            return apply_node_timed_out(event, state)
        case NodeDownloadProgress():
@@ -191,123 +174,20 @@ def apply_task_failed(event: TaskFailed, state: State) -> State:
    return state.model_copy(update={"tasks": new_tasks})


def _update_meta_instance(
    state: State, mid: MetaInstanceId, **fields: object
) -> Mapping[MetaInstanceId, MetaInstance]:
    mi = state.meta_instances[mid]
    return {**state.meta_instances, mid: mi.model_copy(update=fields)}


def apply_instance_created(event: InstanceCreated, state: State) -> State:
    instance = event.instance
    new_instances: Mapping[InstanceId, Instance] = {
        **state.instances,
        instance.instance_id: instance,
    }
    update: dict[str, object] = {"instances": new_instances}
    # Reset failure tracking when a new instance is created for a meta-instance
    if instance.meta_instance_id and instance.meta_instance_id in state.meta_instances:
        mi = state.meta_instances[instance.meta_instance_id]
        if mi.placement_error is not None or mi.consecutive_failures > 0:
            update["meta_instances"] = _update_meta_instance(
                state,
                instance.meta_instance_id,
                placement_error=None,
                consecutive_failures=0,
            )
    return state.model_copy(update=update)
    return state.model_copy(update={"instances": new_instances})


def apply_instance_deleted(event: InstanceDeleted, state: State) -> State:
    deleted_instance = state.instances.get(event.instance_id)
    new_instances: Mapping[InstanceId, Instance] = {
        iid: inst for iid, inst in state.instances.items() if iid != event.instance_id
    }
    update: dict[str, object] = {"instances": new_instances}

    # Track failure on the MetaInstance itself
    if (
        event.failure_error
        and deleted_instance
        and deleted_instance.meta_instance_id
        and deleted_instance.meta_instance_id in state.meta_instances
    ):
        mid = deleted_instance.meta_instance_id
        mi = state.meta_instances[mid]
        update["meta_instances"] = {
            **state.meta_instances,
            mid: mi.model_copy(
                update={
                    "consecutive_failures": mi.consecutive_failures + 1,
                    "last_failure_error": event.failure_error,
                }
            ),
        }

    return state.model_copy(update=update)


def apply_instance_retrying(event: InstanceRetrying, state: State) -> State:
    """Runners failed but retry limit not reached — remove runners, keep instance."""
    instance = state.instances.get(event.instance_id)
    if instance is None:
        # Instance was already deleted (e.g. cascade from DeleteMetaInstance).
        # The InstanceDeleted handler already incremented consecutive_failures
        # on the MetaInstance, so skipping here avoids double-counting.
        return state

    # Remove all runners belonging to this instance from state
    runner_ids_to_remove = set(instance.shard_assignments.node_to_runner.values())
    new_runners: Mapping[RunnerId, RunnerStatus] = {
        rid: rs for rid, rs in state.runners.items() if rid not in runner_ids_to_remove
    }

    update: dict[str, object] = {"runners": new_runners}

    # Increment failure count on the MetaInstance
    if event.meta_instance_id in state.meta_instances:
        update["meta_instances"] = _update_meta_instance(
            state,
            event.meta_instance_id,
            consecutive_failures=state.meta_instances[
                event.meta_instance_id
            ].consecutive_failures
            + 1,
            last_failure_error=event.failure_error,
        )

    return state.model_copy(update=update)


def apply_meta_instance_created(event: MetaInstanceCreated, state: State) -> State:
    new_meta: Mapping[MetaInstanceId, MetaInstance] = {
        **state.meta_instances,
        event.meta_instance.meta_instance_id: event.meta_instance,
    }
    return state.model_copy(update={"meta_instances": new_meta})


def apply_meta_instance_deleted(event: MetaInstanceDeleted, state: State) -> State:
    new_meta: Mapping[MetaInstanceId, MetaInstance] = {
        mid: mi
        for mid, mi in state.meta_instances.items()
        if mid != event.meta_instance_id
    }
    return state.model_copy(update={"meta_instances": new_meta})


def apply_meta_instance_placement_failed(
    event: MetaInstancePlacementFailed, state: State
) -> State:
    if event.meta_instance_id not in state.meta_instances:
        return state
    return state.model_copy(
        update={
            "meta_instances": _update_meta_instance(
                state, event.meta_instance_id, placement_error=event.reason
            )
        }
    )
    return state.model_copy(update={"instances": new_instances})


def apply_runner_status_updated(event: RunnerStatusUpdated, state: State) -> State:

@@ -6,7 +6,7 @@ from uuid import uuid4
from pydantic import BaseModel, Field

from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.common import CommandId, MetaInstanceId, NodeId
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding, ShardMetadata
@@ -262,26 +262,6 @@ class DeleteInstanceResponse(BaseModel):
    instance_id: InstanceId


class CreateMetaInstanceParams(BaseModel):
    model_id: ModelId
    sharding: Sharding = Sharding.Pipeline
    instance_meta: InstanceMeta = InstanceMeta.MlxRing
    min_nodes: int = 1
    node_ids: list[NodeId] | None = None


class CreateMetaInstanceResponse(BaseModel):
    message: str
    command_id: CommandId
    meta_instance_id: MetaInstanceId


class DeleteMetaInstanceResponse(BaseModel):
    message: str
    command_id: CommandId
    meta_instance_id: MetaInstanceId


class AdvancedImageParams(BaseModel):
    seed: Annotated[int, Field(ge=0)] | None = None
    num_inference_steps: Annotated[int, Field(ge=1, le=100)] | None = None

@@ -6,8 +6,7 @@ from exo.shared.types.api import (
    ImageGenerationTaskParams,
)
from exo.shared.types.chunks import InputImageChunk
from exo.shared.types.common import CommandId, MetaInstanceId, NodeId
from exo.shared.types.meta_instance import MetaInstance
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding, ShardMetadata
@@ -53,14 +52,6 @@ class TaskCancelled(BaseCommand):
    cancelled_command_id: CommandId


class CreateMetaInstance(BaseCommand):
    meta_instance: MetaInstance


class DeleteMetaInstance(BaseCommand):
    meta_instance_id: MetaInstanceId


class TaskFinished(BaseCommand):
    finished_command_id: CommandId

@@ -103,9 +94,6 @@ Command = (
    | CreateInstance
    | DeleteInstance
    | TaskCancelled
    | CreateMetaInstance
    | DeleteMetaInstance
    | TaskCancelled
    | TaskFinished
    | SendInputChunk
)

@@ -42,10 +42,6 @@ class CommandId(Id):
    pass


class MetaInstanceId(Id):
    """Identifier for a MetaInstance."""


class Host(CamelCaseModel):
    ip: str
    port: int

@@ -1,14 +1,11 @@
import base64
from collections.abc import Mapping
from datetime import datetime
from typing import Annotated, final
from typing import final

from pydantic import BeforeValidator, Field, PlainSerializer
from pydantic import Field

from exo.shared.topology import Connection
from exo.shared.types.chunks import GenerationChunk, InputImageChunk
from exo.shared.types.common import CommandId, Id, MetaInstanceId, NodeId, SessionId
from exo.shared.types.meta_instance import MetaInstance
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
@@ -17,28 +14,6 @@ from exo.utils.info_gatherer.info_gatherer import GatheredInfo
from exo.utils.pydantic_ext import CamelCaseModel, FrozenModel, TaggedModel


def _decode_base64_bytes(v: bytes | str) -> bytes:
    if isinstance(v, bytes):
        return v
    return base64.b64decode(v)


def _encode_base64_bytes(v: bytes) -> str:
    return base64.b64encode(v).decode("ascii")


Base64Bytes = Annotated[
    bytes,
    BeforeValidator(_decode_base64_bytes),
    PlainSerializer(_encode_base64_bytes, return_type=str),
]
"""bytes that serialize to/from base64 strings in JSON.

Needed because TaggedModel's wrap validator converts JSON→Python validation
context, which breaks strict-mode bytes deserialization from JSON strings.
"""


class EventId(Id):
    """
    Newtype around `ID`
@@ -91,30 +66,6 @@ class InstanceCreated(BaseEvent):

class InstanceDeleted(BaseEvent):
    instance_id: InstanceId
    failure_error: str | None = None


class MetaInstanceCreated(BaseEvent):
    meta_instance: MetaInstance


class MetaInstanceDeleted(BaseEvent):
    meta_instance_id: MetaInstanceId


@final
class MetaInstancePlacementFailed(BaseEvent):
    meta_instance_id: MetaInstanceId
    reason: str


@final
class InstanceRetrying(BaseEvent):
    """Runners failed but retry count is below the limit — restart runners, keep instance."""

    instance_id: InstanceId
    meta_instance_id: MetaInstanceId
    failure_error: str


class RunnerStatusUpdated(BaseEvent):
@@ -181,25 +132,6 @@ class TracesMerged(BaseEvent):
    traces: list[TraceEventData]


@final
class JacclSideChannelData(BaseEvent):
    """A runner's local contribution to a JACCL SideChannel all_gather round."""

    instance_id: InstanceId
    runner_id: RunnerId
    sequence: int
    data: Base64Bytes


@final
class JacclSideChannelGathered(BaseEvent):
    """Gathered result of a JACCL SideChannel all_gather round."""

    instance_id: InstanceId
    sequence: int
    gathered_data: Mapping[RunnerId, Base64Bytes]


Event = (
    TestEvent
    | TaskCreated
@@ -209,10 +141,6 @@ Event = (
    | TaskAcknowledged
    | InstanceCreated
    | InstanceDeleted
    | InstanceRetrying
    | MetaInstanceCreated
    | MetaInstanceDeleted
    | MetaInstancePlacementFailed
    | RunnerStatusUpdated
    | RunnerDeleted
    | NodeTimedOut
@@ -224,8 +152,6 @@ Event = (
    | TopologyEdgeDeleted
    | TracesCollected
    | TracesMerged
    | JacclSideChannelData
    | JacclSideChannelGathered
)


@@ -1,25 +0,0 @@
from typing import final

from pydantic import Field

from exo.shared.models.model_cards import ModelId
from exo.shared.types.common import MetaInstanceId, NodeId
from exo.shared.types.worker.instances import InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.pydantic_ext import FrozenModel


@final
class MetaInstance(FrozenModel):
    """Declarative constraint: ensure an instance matching these parameters always exists."""

    meta_instance_id: MetaInstanceId = Field(default_factory=MetaInstanceId)
    model_id: ModelId
    sharding: Sharding = Sharding.Pipeline
    instance_meta: InstanceMeta = InstanceMeta.MlxRing
    min_nodes: int = 1
    node_ids: list[NodeId] | None = None
    # Failure tracking
    placement_error: str | None = None
    consecutive_failures: int = 0
    last_failure_error: str | None = None
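
# Illustrative construction (values hypothetical): only model_id is required;
# placement and failure-tracking fields start at their defaults.
def _demo_meta_instance() -> MetaInstance:
    return MetaInstance(
        model_id=ModelId("test-org/test-model"),
        min_nodes=2,
        node_ids=[NodeId("node-a"), NodeId("node-b")],
    )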

@@ -6,8 +6,7 @@ from pydantic import ConfigDict, Field, field_serializer, field_validator
from pydantic.alias_generators import to_camel

from exo.shared.topology import Topology, TopologySnapshot
from exo.shared.types.common import MetaInstanceId, NodeId
from exo.shared.types.meta_instance import MetaInstance
from exo.shared.types.common import NodeId
from exo.shared.types.profiling import (
    DiskUsage,
    MemoryUsage,
@@ -42,7 +41,6 @@ class State(CamelCaseModel):
        arbitrary_types_allowed=True,
    )
    instances: Mapping[InstanceId, Instance] = {}
    meta_instances: Mapping[MetaInstanceId, MetaInstance] = {}
    runners: Mapping[RunnerId, RunnerStatus] = {}
    downloads: Mapping[NodeId, Sequence[DownloadProgress]] = {}
    tasks: Mapping[TaskId, Task] = {}

@@ -2,7 +2,7 @@ from enum import Enum

from pydantic import model_validator

from exo.shared.types.common import Host, Id, MetaInstanceId, NodeId
from exo.shared.types.common import Host, Id, NodeId
from exo.shared.types.worker.runners import RunnerId, ShardAssignments, ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel

@@ -19,7 +19,6 @@ class InstanceMeta(str, Enum):
class BaseInstance(TaggedModel):
    instance_id: InstanceId
    shard_assignments: ShardAssignments
    meta_instance_id: MetaInstanceId | None = None

    def shard(self, runner_id: RunnerId) -> ShardMetadata | None:
        return self.shard_assignments.runner_to_shard.get(runner_id, None)

@@ -574,11 +574,6 @@ def mlx_cleanup(


def mx_any(bool_: bool, group: Group | None) -> bool:
    """Synchronize a boolean across all distributed nodes.

    Returns True if any node has bool_=True. Uses all_sum so every
    node participates in the collective — preventing GPU deadlocks.
    """
    if group is None:
        return bool_
    num_true = mx.distributed.all_sum(
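
# The hunk ends mid-expression above. A sketch of a plausible completion,
# written to the docstring's contract rather than copied from the original:
def _mx_any_sketch(bool_: bool, group: Group | None) -> bool:
    if group is None:
        return bool_
    # Every rank contributes 0 or 1; the sum is positive iff any rank is True.
    num_true = mx.distributed.all_sum(mx.array(1 if bool_ else 0), group=group)
    return num_true.item() > 0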

@@ -1,11 +1,10 @@
import contextlib
from collections import defaultdict
from datetime import datetime, timezone
from random import random
from typing import Iterator

import anyio
from anyio import CancelScope, ClosedResourceError, create_task_group, fail_after
from anyio import CancelScope, create_task_group, fail_after
from anyio.abc import TaskGroup
from loguru import logger

@@ -25,7 +24,6 @@ from exo.shared.types.events import (
    ForwarderEvent,
    IndexedEvent,
    InputChunkReceived,
    JacclSideChannelGathered,
    NodeGatheredInfo,
    TaskCreated,
    TaskStatusUpdated,
@@ -118,8 +116,7 @@ class Worker:
        self.command_sender.close()
        self.download_command_sender.close()
        for runner in self.runners.values():
            with contextlib.suppress(ClosedResourceError):
                runner.shutdown()
            runner.shutdown()

    async def _forward_info(self, recv: Receiver[GatheredInfo]):
        with recv as info_stream:
@@ -162,15 +159,6 @@ class Worker:
            for idx, event in indexed_events:
                self.state = apply(self.state, IndexedEvent(idx=idx, event=event))

                # Dispatch JACCL gathered events to the relevant RunnerSupervisor
                if isinstance(event, JacclSideChannelGathered):
                    for runner in self.runners.values():
                        if (
                            runner.bound_instance.instance.instance_id
                            == event.instance_id
                        ):
                            runner.notify_gathered(event)

                # Buffer input image chunks for image editing
                if isinstance(event, InputChunkReceived):
                    cmd_id = event.command_id
@@ -248,8 +236,7 @@ class Worker:
                            )
                        )
                    finally:
                        with contextlib.suppress(ClosedResourceError):
                            runner.shutdown()
                        runner.shutdown()
                case CancelTask(
                    cancelled_task_id=cancelled_task_id, runner_id=runner_id
                ):

@@ -35,7 +35,6 @@ from exo.shared.types.worker.runners import (
    RunnerLoading,
    RunnerReady,
    RunnerRunning,
    RunnerShutdown,
    RunnerStatus,
    RunnerWarmingUp,
)
@@ -57,7 +56,7 @@ def plan(
    return (
        _cancel_tasks(runners, tasks)
        or _kill_runner(runners, all_runners, instances)
        or _create_runner(node_id, runners, instances, all_runners)
        or _create_runner(node_id, runners, instances)
        or _model_needs_download(node_id, runners, global_download_status)
        or _init_distributed_backend(runners, all_runners)
        or _load_model(runners, all_runners, global_download_status)
@@ -76,12 +75,6 @@ def _kill_runner(
    if (instance_id := runner.bound_instance.instance.instance_id) not in instances:
        return Shutdown(instance_id=instance_id, runner_id=runner_id)

    # Master removed our runner from state (retry signal) and process is dead
    if runner_id not in all_runners and isinstance(
        runner.status, (RunnerFailed, RunnerShutdown)
    ):
        return Shutdown(instance_id=instance_id, runner_id=runner_id)

    for (
        global_runner_id
    ) in runner.bound_instance.instance.shard_assignments.node_to_runner.values():
@@ -99,7 +92,6 @@ def _create_runner(
    node_id: NodeId,
    runners: Mapping[RunnerId, RunnerSupervisor],
    instances: Mapping[InstanceId, Instance],
    all_runners: Mapping[RunnerId, RunnerStatus],
) -> CreateRunner | None:
    for instance in instances.values():
        runner_id = instance.shard_assignments.node_to_runner.get(node_id, None)
@@ -109,16 +101,6 @@ def _create_runner(
        if runner_id in runners:
            continue

        # Don't create while any peer runner is in a terminal state — wait for
        # the master to emit InstanceRetrying which removes them from state.
        has_terminal_peer = any(
            isinstance(all_runners.get(peer_rid), (RunnerFailed, RunnerShutdown))
            for peer_rid in instance.shard_assignments.node_to_runner.values()
            if peer_rid != runner_id
        )
        if has_terminal_peer:
            continue

        shard = instance.shard(runner_id)
        assert shard is not None


@@ -17,7 +17,6 @@ def entrypoint(
    task_receiver: MpReceiver[Task],
    cancel_receiver: MpReceiver[TaskId],
    _logger: "loguru.Logger",
    pipe_fifo_paths: tuple[str, str] | None = None,
) -> None:
    fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
    if fast_synch_override == "on" or (
@@ -31,16 +30,6 @@ def entrypoint(
    else:
        os.environ["MLX_METAL_FAST_SYNCH"] = "0"

    # Open JACCL FIFOs by path and set env vars for C++ SideChannel.
    # Named pipes (FIFOs) work across multiprocessing spawn (macOS default).
    if pipe_fifo_paths is not None:
        fifo_c2p, fifo_p2c = pipe_fifo_paths
        # C++ reads gathered data from p2c (PIPE_IN), writes local data to c2p (PIPE_OUT)
        pipe_in_fd = os.open(fifo_p2c, os.O_RDONLY)
        pipe_out_fd = os.open(fifo_c2p, os.O_WRONLY)
        os.environ["MLX_JACCL_PIPE_IN"] = str(pipe_in_fd)
        os.environ["MLX_JACCL_PIPE_OUT"] = str(pipe_out_fd)

    global logger
    logger = _logger


@@ -1,10 +1,6 @@
import contextlib
import os
import signal
import struct
import tempfile
from dataclasses import dataclass, field
from functools import partial
from multiprocessing import Process
from typing import Self

@@ -18,14 +14,12 @@ from loguru import logger

from exo.shared.types.events import (
    Event,
    JacclSideChannelData,
    JacclSideChannelGathered,
    RunnerStatusUpdated,
    TaskAcknowledged,
    TaskStatusUpdated,
)
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runners import (
    RunnerConnecting,
    RunnerFailed,
@@ -40,26 +34,6 @@ from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import MpReceiver, MpSender, Sender, mp_channel
from exo.worker.runner.bootstrap import entrypoint


def _pipe_read_exact(fd: int, n: int) -> bytes | None:
    """Read exactly n bytes from a file descriptor. Returns None on EOF."""
    data = b""
    while len(data) < n:
        chunk = os.read(fd, n - len(data))
        if not chunk:
            return None
        data += chunk
    return data


def _pipe_write_all(fd: int, data: bytes) -> None:
    """Write all bytes to a file descriptor."""
    view = memoryview(data)
    while view:
        written = os.write(fd, view)
        view = view[written:]
|
||||
|
||||
|
||||
PREFILL_TIMEOUT_SECONDS = 60
|
||||
DECODE_TIMEOUT_SECONDS = 5
|
||||
|
||||
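The two helpers above implement a tiny length-prefixed framing protocol: every message is a little-endian uint32 size followed by exactly that many payload bytes. A runnable sketch of one round trip over an anonymous os.pipe() (the real code runs the same framing over the named pipes; the helpers are repeated so the snippet is self-contained):

```python
# One framed message: [uint32 size][payload], written then read back.
import os
import struct


def _pipe_read_exact(fd: int, n: int) -> bytes | None:
    data = b""
    while len(data) < n:
        chunk = os.read(fd, n - len(data))
        if not chunk:
            return None  # EOF
        data += chunk
    return data


def _pipe_write_all(fd: int, data: bytes) -> None:
    view = memoryview(data)
    while view:
        written = os.write(fd, view)
        view = view[written:]


r_fd, w_fd = os.pipe()
payload = b"local-shard-data"
_pipe_write_all(w_fd, struct.pack("<I", len(payload)) + payload)

header = _pipe_read_exact(r_fd, 4)
assert header is not None
(size,) = struct.unpack("<I", header)
assert _pipe_read_exact(r_fd, size) == payload
```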
@@ -74,19 +48,10 @@ class RunnerSupervisor:
    _task_sender: MpSender[Task]
    _event_sender: Sender[Event]
    _cancel_sender: MpSender[TaskId]
    _pipe_read_fd: int | None = None  # Python reads runner's pipe output
    _pipe_write_fd: int | None = None  # Python writes gathered data to runner
    _child_pipe_fds: tuple[int, int] | None = None  # fds to close after fork
    _fifo_dir: str | None = None  # Temp dir for FIFO files (for cleanup)
    _fifo_c2p: str | None = None  # FIFO path: C++ writes → Python reads
    _fifo_p2c: str | None = None  # FIFO path: Python writes → C++ reads
    status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
    pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
    completed: set[TaskId] = field(default_factory=set, init=False)
    cancelled: set[TaskId] = field(default_factory=set, init=False)
    _gathered_waiters: dict[
        int, tuple[anyio.Event, JacclSideChannelGathered | None]
    ] = field(default_factory=dict, init=False)
    @classmethod
    def create(
@@ -100,23 +65,6 @@ class RunnerSupervisor:
        task_sender, task_recv = mp_channel[Task]()
        cancel_sender, cancel_recv = mp_channel[TaskId]()

        # For MlxJaccl instances, create named pipes (FIFOs) for SideChannel relay.
        # Named pipes work across multiprocessing.Process spawn (macOS default).
        # FIFO c2p: C++ writes local data → Python reads it
        # FIFO p2c: Python writes gathered data → C++ reads it
        fifo_dir: str | None = None
        fifo_c2p: str | None = None
        fifo_p2c: str | None = None
        pipe_fifo_paths: tuple[str, str] | None = None

        if isinstance(bound_instance.instance, MlxJacclInstance):
            fifo_dir = tempfile.mkdtemp(prefix="exo_jaccl_")
            fifo_c2p = os.path.join(fifo_dir, "c2p")  # C++ → Python
            fifo_p2c = os.path.join(fifo_dir, "p2c")  # Python → C++
            os.mkfifo(fifo_c2p)
            os.mkfifo(fifo_p2c)
            pipe_fifo_paths = (fifo_c2p, fifo_p2c)

        runner_process = Process(
            target=entrypoint,
            args=(
@@ -125,7 +73,6 @@ class RunnerSupervisor:
                task_recv,
                cancel_recv,
                logger,
                pipe_fifo_paths,
            ),
            daemon=True,
        )
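Worth noting on the hunk above: os.mkfifo only creates a filesystem entry, so plain path strings can be handed to the spawned Process as ordinary arguments, whereas raw inherited pipe fds would not survive the spawn start method. A minimal runnable sketch of that handoff, reduced to one FIFO with hypothetical names:

```python
# Pass a FIFO *path* to a spawned child; both sides open it independently.
import os
import tempfile
from multiprocessing import Process


def child(path: str) -> None:
    out_fd = os.open(path, os.O_WRONLY)  # blocks until the parent opens the read end
    os.write(out_fd, b"ping")
    os.close(out_fd)


if __name__ == "__main__":
    fifo_dir = tempfile.mkdtemp(prefix="exo_jaccl_demo_")
    fifo_path = os.path.join(fifo_dir, "c2p")
    os.mkfifo(fifo_path)  # filesystem entry only; no fd is open yet

    p = Process(target=child, args=(fifo_path,), daemon=True)
    p.start()
    in_fd = os.open(fifo_path, os.O_RDONLY)  # rendezvous with the child's open
    assert os.read(in_fd, 4) == b"ping"
    os.close(in_fd)
    p.join()

    os.unlink(fifo_path)
    os.rmdir(fifo_dir)
```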
@@ -141,57 +88,21 @@ class RunnerSupervisor:
            _task_sender=task_sender,
            _cancel_sender=cancel_sender,
            _event_sender=event_sender,
            _fifo_dir=fifo_dir,
            _fifo_c2p=fifo_c2p,
            _fifo_p2c=fifo_p2c,
        )

        return self

    async def run(self):
        self.runner_process.start()

        if self._fifo_c2p is not None and self._fifo_p2c is not None:
            # Open FIFOs from parent side. These block until child opens the other end,
            # so we run them in threads concurrently to avoid deadlock.
            fifo_c2p = self._fifo_c2p
            fifo_p2c = self._fifo_p2c

            async def open_read() -> None:
                self._pipe_read_fd = await to_thread.run_sync(
                    partial(os.open, fifo_c2p, os.O_RDONLY)
                )

            async def open_write() -> None:
                self._pipe_write_fd = await to_thread.run_sync(
                    partial(os.open, fifo_p2c, os.O_WRONLY)
                )

            async with anyio.create_task_group() as open_tg:
                open_tg.start_soon(open_read)
                open_tg.start_soon(open_write)

            logger.info(
                f"JACCL pipe relay: FIFOs opened (read_fd={self._pipe_read_fd}, write_fd={self._pipe_write_fd})"
            )

            async with anyio.create_task_group() as tg:
                tg.start_soon(self._pipe_relay)
                tg.start_soon(self._forward_events)
        else:
            await self._forward_events()
        await self._forward_events()

    def shutdown(self):
        logger.info("Runner supervisor shutting down")
        self._ev_recv.close()
        self._task_sender.close()
        try:
            self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
            self._cancel_sender.close()
        except ClosedResourceError:
            pass
        self._event_sender.close()
        self._close_pipe_fds()
        self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
        self._cancel_sender.close()
        self.runner_process.join(1)
        if not self.runner_process.is_alive():
            logger.info("Runner process succesfully terminated")
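The concurrent open in run() above exists because opening one end of a FIFO blocks until the peer opens the opposite end; sequential opens can deadlock if the two processes open in different orders, so the supervisor opens both in worker threads at once. A runnable sketch of the pattern, with a plain thread standing in for the runner process:

```python
# Open both FIFO ends concurrently so the open order of the peer doesn't matter.
import os
import tempfile
import threading
from functools import partial

import anyio
from anyio import to_thread

fifo_dir = tempfile.mkdtemp(prefix="exo_jaccl_demo_")
fifo_c2p = os.path.join(fifo_dir, "c2p")
fifo_p2c = os.path.join(fifo_dir, "p2c")
os.mkfifo(fifo_c2p)
os.mkfifo(fifo_p2c)


def fake_runner() -> None:
    # "Child" side opens the opposite ends; each open blocks until the
    # "supervisor" opens the matching end.
    out_fd = os.open(fifo_c2p, os.O_WRONLY)
    in_fd = os.open(fifo_p2c, os.O_RDONLY)
    os.close(out_fd)
    os.close(in_fd)


async def main() -> None:
    threading.Thread(target=fake_runner, daemon=True).start()
    read_fd = write_fd = -1

    async def open_read() -> None:
        nonlocal read_fd
        read_fd = await to_thread.run_sync(partial(os.open, fifo_c2p, os.O_RDONLY))

    async def open_write() -> None:
        nonlocal write_fd
        write_fd = await to_thread.run_sync(partial(os.open, fifo_p2c, os.O_WRONLY))

    # Both opens run concurrently, mirroring run() above.
    async with anyio.create_task_group() as tg:
        tg.start_soon(open_read)
        tg.start_soon(open_write)

    os.close(read_fd)
    os.close(write_fd)


anyio.run(main)
for path in (fifo_c2p, fifo_p2c):
    os.unlink(path)
os.rmdir(fifo_dir)
```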
@@ -229,7 +140,6 @@ class RunnerSupervisor:
        await event.wait()

    async def cancel_task(self, task_id: TaskId):
        """Send a cancellation signal to the runner process."""
        if task_id in self.completed:
            logger.info(f"Unable to cancel {task_id} as it has been completed")
            return
@@ -271,110 +181,6 @@ class RunnerSupervisor:
        for tid in self.pending:
            self.pending[tid].set()

    def _close_pipe_fds(self) -> None:
        if self._pipe_read_fd is not None:
            with contextlib.suppress(OSError):
                os.close(self._pipe_read_fd)
            self._pipe_read_fd = None
        if self._pipe_write_fd is not None:
            with contextlib.suppress(OSError):
                os.close(self._pipe_write_fd)
            self._pipe_write_fd = None
        if self._child_pipe_fds is not None:
            for fd in self._child_pipe_fds:
                with contextlib.suppress(OSError):
                    os.close(fd)
            self._child_pipe_fds = None
        # Clean up FIFO files
        if self._fifo_c2p is not None:
            with contextlib.suppress(OSError):
                os.unlink(self._fifo_c2p)
            self._fifo_c2p = None
        if self._fifo_p2c is not None:
            with contextlib.suppress(OSError):
                os.unlink(self._fifo_p2c)
            self._fifo_p2c = None
        if self._fifo_dir is not None:
            with contextlib.suppress(OSError):
                os.rmdir(self._fifo_dir)
            self._fifo_dir = None

    async def _pipe_relay(self) -> None:
        """Relay JACCL SideChannel all_gather rounds between runner pipes and exo events."""
        assert self._pipe_read_fd is not None
        assert self._pipe_write_fd is not None
        read_fd = self._pipe_read_fd
        write_fd = self._pipe_write_fd
        sequence = 0

        try:
            while True:
                # 1. Read local data from runner: [uint32 size][size bytes]
                header = await to_thread.run_sync(partial(_pipe_read_exact, read_fd, 4))
                if header is None:
                    logger.info("JACCL pipe relay: runner closed pipe (EOF)")
                    break
                data_size: int = struct.unpack("<I", header)[0]  # pyright: ignore[reportAny]
                local_data = await to_thread.run_sync(
                    partial(_pipe_read_exact, read_fd, data_size)
                )
                if local_data is None:
                    logger.warning("JACCL pipe relay: EOF reading data payload")
                    break

                logger.info(
                    f"JACCL pipe relay: read {data_size} bytes from runner, seq={sequence}"
                )

                # 2. Emit JacclSideChannelData event
                waiter = anyio.Event()
                self._gathered_waiters[sequence] = (waiter, None)
                await self._event_sender.send(
                    JacclSideChannelData(
                        instance_id=self.bound_instance.instance.instance_id,
                        runner_id=self.bound_instance.bound_runner_id,
                        sequence=sequence,
                        data=local_data,
                    )
                )

                # 3. Wait for gathered result
                await waiter.wait()
                _, gathered_event = self._gathered_waiters.pop(sequence)
                assert gathered_event is not None

                # 4. Order gathered data by runner rank and concatenate
                instance = self.bound_instance.instance
                assert isinstance(instance, MlxJacclInstance)
                runner_order = list(instance.shard_assignments.runner_to_shard.keys())
                ordered_data = b"".join(
                    gathered_event.gathered_data[rid] for rid in runner_order
                )

                # 5. Write gathered data to runner: [uint32 total_size][total_size bytes]
                total_size = len(ordered_data)
                response = struct.pack("<I", total_size) + ordered_data
                await to_thread.run_sync(partial(_pipe_write_all, write_fd, response))

                logger.info(
                    f"JACCL pipe relay: wrote {total_size} bytes to runner, seq={sequence}"
                )
                sequence += 1
        except OSError as e:
            logger.warning(f"JACCL pipe relay: OS error: {e}")
        except Exception as e:
            logger.opt(exception=e).error("JACCL pipe relay: unexpected error")

    def notify_gathered(self, event: JacclSideChannelGathered) -> None:
        """Called by the worker when a JacclSideChannelGathered event arrives."""
        seq = event.sequence
        if seq not in self._gathered_waiters:
            logger.warning(f"JACCL: received gathered event for unknown sequence {seq}")
            return
        waiter, _ = self._gathered_waiters[seq]
        self._gathered_waiters[seq] = (waiter, event)
        waiter.set()

    def __del__(self) -> None:
        if self.runner_process.is_alive():
            logger.warning("RunnerSupervisor was not stopped cleanly.")
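The _gathered_waiters dict above is a small sequence-numbered rendezvous: _pipe_relay registers an anyio.Event per sequence, and notify_gathered later fills in the result and sets the event. A self-contained sketch of the same handshake (GatheredResult is a hypothetical stand-in for JacclSideChannelGathered):

```python
# Sequence-numbered waiter handshake: one task waits, another delivers.
from dataclasses import dataclass

import anyio


@dataclass
class GatheredResult:  # hypothetical stand-in for JacclSideChannelGathered
    sequence: int
    data: bytes


waiters: dict[int, tuple[anyio.Event, GatheredResult | None]] = {}


def notify(result: GatheredResult) -> None:
    """Deliver a result to whoever registered its sequence number."""
    if result.sequence not in waiters:
        return  # unknown/late sequence: drop it, as notify_gathered does
    event, _ = waiters[result.sequence]
    waiters[result.sequence] = (event, result)
    event.set()


async def wait_for(sequence: int) -> GatheredResult:
    """Register a waiter, then block until notify() fills it in."""
    event = anyio.Event()
    waiters[sequence] = (event, None)
    await event.wait()
    _, result = waiters.pop(sequence)
    assert result is not None
    return result


async def main() -> None:
    async def producer() -> None:
        await anyio.sleep(0.01)
        notify(GatheredResult(sequence=0, data=b"rank0" + b"rank1"))

    async with anyio.create_task_group() as tg:
        tg.start_soon(producer)
        print((await wait_for(0)).data)  # b'rank0rank1'


anyio.run(main)
```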
@@ -178,9 +178,6 @@ def _run(tasks: Iterable[Task]):
    # this is some c++ nonsense
    task_receiver.close = nothin
    task_receiver.join = nothin
    cancel_receiver.close = nothin
    cancel_receiver.join = nothin

    with unittest.mock.patch(
        "exo.worker.runner.runner.mx.distributed.all_gather",
        make_nothin(mx.array([1])),
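The test hunk above neutralizes the distributed collective by patching it with a stub that returns a fixed array, so the runner logic can be exercised in a single process. The same trick can be demonstrated against mlx directly; fake_all_gather below is a hypothetical stand-in for what make_nothin(mx.array([1])) produces:

```python
# Patch the collective with a stub that skips communication entirely.
import unittest.mock

import mlx.core as mx


def fake_all_gather(x: mx.array, **_kwargs: object) -> mx.array:
    # No communication: pretend the "gathered" result is a fixed array.
    return mx.array([1])


with unittest.mock.patch("mlx.core.distributed.all_gather", fake_all_gather):
    out = mx.distributed.all_gather(mx.array([7]))
    assert out.tolist() == [1]
```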
uv.lock (generated file, 40 lines changed)
@@ -377,8 +377,8 @@ dependencies = [
    { name = "hypercorn", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
    { name = "loguru", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
    { name = "mflux", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
    { name = "mlx", version = "0.30.6", source = { registry = "https://pypi.org/simple" }, extra = ["cpu"], marker = "sys_platform == 'linux'" },
    { name = "mlx", version = "0.30.7.dev20260217+50487b41", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#50487b4141f3c951122655db3b83df5146c1fbeb" }, marker = "sys_platform == 'darwin'" },
    { name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
    { name = "mlx", extra = ["cpu"], marker = "sys_platform == 'linux'" },
    { name = "mlx-lm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
    { name = "msgspec", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
    { name = "openai-harmony", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -416,7 +416,7 @@ requires-dist = [
    { name = "hypercorn", specifier = ">=0.18.0" },
    { name = "loguru", specifier = ">=0.7.3" },
    { name = "mflux", specifier = "==0.15.5" },
    { name = "mlx", marker = "sys_platform == 'darwin'", git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks" },
    { name = "mlx", marker = "sys_platform == 'darwin'", specifier = "==0.30.6" },
    { name = "mlx", extras = ["cpu"], marker = "sys_platform == 'linux'", specifier = "==0.30.6" },
    { name = "mlx-lm", specifier = "==0.30.6" },
    { name = "msgspec", specifier = ">=0.19.0" },
@@ -1020,8 +1020,8 @@ dependencies = [
    { name = "fonttools", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
    { name = "huggingface-hub", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
    { name = "matplotlib", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
    { name = "mlx", version = "0.30.6", source = { registry = "https://pypi.org/simple" }, extra = ["cuda13"], marker = "sys_platform == 'linux'" },
    { name = "mlx", version = "0.30.7.dev20260217+50487b41", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#50487b4141f3c951122655db3b83df5146c1fbeb" }, marker = "sys_platform == 'darwin'" },
    { name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
    { name = "mlx", extra = ["cuda13"], marker = "sys_platform == 'linux'" },
    { name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
    { name = "opencv-python", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
    { name = "piexif", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -1048,12 +1048,18 @@ wheels = [
name = "mlx"
version = "0.30.6"
source = { registry = "https://pypi.org/simple" }
resolution-markers = [
    "sys_platform == 'linux'",
dependencies = [
    { name = "mlx-metal", marker = "sys_platform == 'darwin'" },
]
wheels = [
    { url = "https://files.pythonhosted.org/packages/ae/5b/e460e144a34d5529e010056cccf50b538d56ed001473bc6b246018fd58cb/mlx-0.30.6-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:ed86f8bffc174c2f259ca589ea25464c96cf69d1bb457074a2bf2ef53737e54f", size = 573515, upload-time = "2026-02-06T03:45:23.405Z" },
    { url = "https://files.pythonhosted.org/packages/60/25/69833fefb9a3fef30b56792b1bcd022496c4fea83e45411d289b77ef7546/mlx-0.30.6-cp313-cp313-macosx_15_0_arm64.whl", hash = "sha256:c52294958269e20f300639a17c1900ca8fc737d859ddda737f9811e94bd040e5", size = 573516, upload-time = "2026-02-06T03:45:24.618Z" },
    { url = "https://files.pythonhosted.org/packages/9c/6a/7e7fbeebc5cb51b6a5eba96b263a6298707bcbdc059f4b0b73e088bc3dea/mlx-0.30.6-cp313-cp313-macosx_26_0_arm64.whl", hash = "sha256:b5b6636f7c49a4d86d8ec82643b972f45a144a7a9f3a967b27b2e6e22cf71e6a", size = 573592, upload-time = "2026-02-06T03:45:25.928Z" },
    { url = "https://files.pythonhosted.org/packages/93/06/280f6f2ba80520a7109730425eda0d966658793aa0d02d8be8d351f75253/mlx-0.30.6-cp313-cp313-manylinux_2_35_aarch64.whl", hash = "sha256:67e6c9e30a9faeacc209917ef5523177cf9b086914b6b5d83ff886e4294b727d", size = 622011, upload-time = "2026-02-06T03:45:28.165Z" },
    { url = "https://files.pythonhosted.org/packages/fe/35/f872afbee9c079cc69924d9e9c46f5663adb7da58cba3511db082dd307c1/mlx-0.30.6-cp313-cp313-manylinux_2_35_x86_64.whl", hash = "sha256:47db8b16fcb6f6c5a47c0bdb24ed377b41237017ac93aa6cb6aa206c9bdf82e4", size = 663650, upload-time = "2026-02-06T03:45:30.315Z" },
    { url = "https://files.pythonhosted.org/packages/60/23/361dc7a5797634e4d7e9bdd6564c6b28f9b1246672632def2f91bf066b18/mlx-0.30.6-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:78804a89dcff4a838f7c2da72392fe87a523e95122a3c840e53df019122aad45", size = 575028, upload-time = "2026-02-06T03:45:31.549Z" },
    { url = "https://files.pythonhosted.org/packages/a8/69/1854484d414171586814dfbe8def95f75c4ea2c7341ba13ba8ee675f7c62/mlx-0.30.6-cp314-cp314-macosx_15_0_arm64.whl", hash = "sha256:ec13584ab069665cc7ad34a05494d9291cd623aef6ae96be48875fc87cfc25d6", size = 575026, upload-time = "2026-02-06T03:45:33.072Z" },
    { url = "https://files.pythonhosted.org/packages/6b/b8/3adbc441924209a7e4c568308b2a0b54bd09aee6a68db5bae85304791e54/mlx-0.30.6-cp314-cp314-macosx_26_0_arm64.whl", hash = "sha256:b2c5e8a090a753ef99a1380a4d059c983083f36198864f6df9faaf1223d083df", size = 575041, upload-time = "2026-02-06T03:45:34.814Z" },
    { url = "https://files.pythonhosted.org/packages/3f/54/9d9e06804fb2088202a2cdf60458e00b221f71420bea285720b60f9e82b5/mlx-0.30.6-cp314-cp314-manylinux_2_35_aarch64.whl", hash = "sha256:9ceddede4af0de31d1f6b3099f70e5469d60cd7c546975dedbdbeab3519cab3f", size = 624002, upload-time = "2026-02-06T03:45:36Z" },
    { url = "https://files.pythonhosted.org/packages/42/92/3140a15a50cb1f9267a6552171e1dfa577861de53e093124bc43707f2a0e/mlx-0.30.6-cp314-cp314-manylinux_2_35_x86_64.whl", hash = "sha256:4a6ffd2d16728cf95f63a1b555d7c2eaeea686a0e6b73228bd265411cb5d77a4", size = 663569, upload-time = "2026-02-06T03:45:37.242Z" },
]
@@ -1066,14 +1072,6 @@ cuda13 = [
    { name = "mlx-cuda-13", marker = "sys_platform == 'linux'" },
]

[[package]]
name = "mlx"
version = "0.30.7.dev20260217+50487b41"
source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#50487b4141f3c951122655db3b83df5146c1fbeb" }
resolution-markers = [
    "sys_platform == 'darwin'",
]

[[package]]
name = "mlx-cpu"
version = "0.30.6"
@@ -1104,7 +1102,7 @@ version = "0.30.6"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
    { name = "mlx", version = "0.30.7.dev20260217+50487b41", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git?branch=address-rdma-gpu-locks#50487b4141f3c951122655db3b83df5146c1fbeb" }, marker = "sys_platform == 'darwin'" },
    { name = "mlx", marker = "sys_platform == 'darwin'" },
    { name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
    { name = "protobuf", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
    { name = "pyyaml", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -1116,6 +1114,16 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/20/5f/01d281f1fa8a1521d5936659beb4f5ab1f32b463d059263cf9d4cef969d9/mlx_lm-0.30.6-py3-none-any.whl", hash = "sha256:a7405bd581eacc4bf8209d7a6b7f23629585a0d7c6740c2a97e51fee35b3b0e1", size = 379451, upload-time = "2026-02-04T21:27:43.222Z" },
]

[[package]]
name = "mlx-metal"
version = "0.30.6"
source = { registry = "https://pypi.org/simple" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/f3/85/44406b521f920248fad621334d4dc15e77660a494edf890e7cbee33bf38d/mlx_metal-0.30.6-py3-none-macosx_14_0_arm64.whl", hash = "sha256:ea6d0c973def9a5b4f652cc77036237db3f88c9d0af63701d76b5fddde99b820", size = 38437818, upload-time = "2026-02-06T03:44:56.19Z" },
    { url = "https://files.pythonhosted.org/packages/d0/cb/10a516995f7d0c154b0d7e633c54b51e96977a86a355105b6474cfcbe0d0/mlx_metal-0.30.6-py3-none-macosx_15_0_arm64.whl", hash = "sha256:0f8cb94634d07e06a372d6ad9a090f38a18bab1ff19a140aede60eacf707bb94", size = 38433701, upload-time = "2026-02-06T03:44:59.678Z" },
    { url = "https://files.pythonhosted.org/packages/4c/7d/70cb272f7373c334709f210ed8420511fc9d64d05a7a646c0b3b94c29c04/mlx_metal-0.30.6-py3-none-macosx_26_0_arm64.whl", hash = "sha256:d761ae26304f2c4b454eeea7f612a56919d9e5e57dbb1dc0788f8e34aa6f41c2", size = 47718448, upload-time = "2026-02-06T03:45:03.133Z" },
]

[[package]]
name = "more-itertools"
version = "10.8.0"