Mirror of https://github.com/exo-explore/exo.git, synced 2026-02-18 14:55:13 -05:00

Compare commits: alexcheema...sami/iOS-a (13 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | ab622f79c3 |  |
|  | d6301ed593 |  |
|  | 6d1ca6689b |  |
|  | c01b6fff21 |  |
|  | 8392e78afe |  |
|  | 86735ece78 |  |
|  | 2759e92334 |  |
|  | 131fb141a6 |  |
|  | 2d8bfc2e3c |  |
|  | 042999f728 |  |
|  | b61dc2eb35 |  |
|  | 36a7115b6f |  |
|  | 0b7d88b43b |  |
.github/workflows/pipeline.yml (vendored, 27 lines changed)
@@ -8,33 +8,6 @@ on:
      - main

jobs:
  typecheck:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          lfs: false

      - uses: cachix/install-nix-action@v31
        with:
          nix_path: nixpkgs=channel:nixos-unstable

      - uses: cachix/cachix-action@v14
        name: Configure Cachix
        with:
          name: exo
          authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"

      - name: Load nix develop environment
        run: nix run github:nicknovitski/nix-develop/v1

      - name: Sync dependencies
        run: uv sync --all-packages

      - name: Run type checker
        run: uv run basedpyright --project pyproject.toml

  nix:
    name: Build and check (${{ matrix.system }})
    runs-on: ${{ matrix.runner }}
app/EXO-iOS/EXO-iOS.xcodeproj/project.pbxproj (new file, 628 lines)
@@ -0,0 +1,628 @@
// !$*UTF8*$!
{
  archiveVersion = 1;
  classes = {
  };
  objectVersion = 77;
  objects = {

/* Begin PBXBuildFile section */
    E09D17522F44F359009C51A3 /* MLXLLM in Frameworks */ = {isa = PBXBuildFile; productRef = E09D17512F44F359009C51A3 /* MLXLLM */; };
    E09D17542F44F359009C51A3 /* MLXLMCommon in Frameworks */ = {isa = PBXBuildFile; productRef = E09D17532F44F359009C51A3 /* MLXLMCommon */; };
/* End PBXBuildFile section */

/* Begin PBXContainerItemProxy section */
    E09D167D2F44CA20009C51A3 /* PBXContainerItemProxy */ = {
      isa = PBXContainerItemProxy;
      containerPortal = E09D16672F44CA1E009C51A3 /* Project object */;
      proxyType = 1;
      remoteGlobalIDString = E09D166E2F44CA1E009C51A3;
      remoteInfo = "EXO-iOS";
    };
    E09D16872F44CA20009C51A3 /* PBXContainerItemProxy */ = {
      isa = PBXContainerItemProxy;
      containerPortal = E09D16672F44CA1E009C51A3 /* Project object */;
      proxyType = 1;
      remoteGlobalIDString = E09D166E2F44CA1E009C51A3;
      remoteInfo = "EXO-iOS";
    };
/* End PBXContainerItemProxy section */

/* Begin PBXFileReference section */
    E09D166F2F44CA1E009C51A3 /* EXO-iOS.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = "EXO-iOS.app"; sourceTree = BUILT_PRODUCTS_DIR; };
    E09D167C2F44CA20009C51A3 /* EXO-iOSTests.xctest */ = {isa = PBXFileReference; explicitFileType = wrapper.cfbundle; includeInIndex = 0; path = "EXO-iOSTests.xctest"; sourceTree = BUILT_PRODUCTS_DIR; };
    E09D16862F44CA20009C51A3 /* EXO-iOSUITests.xctest */ = {isa = PBXFileReference; explicitFileType = wrapper.cfbundle; includeInIndex = 0; path = "EXO-iOSUITests.xctest"; sourceTree = BUILT_PRODUCTS_DIR; };
/* End PBXFileReference section */

/* Begin PBXFileSystemSynchronizedBuildFileExceptionSet section */
    E09D169A2F44CA20009C51A3 /* Exceptions for "EXO-iOS" folder in "EXO-iOS" target */ = {
      isa = PBXFileSystemSynchronizedBuildFileExceptionSet;
      membershipExceptions = (
        Info.plist,
      );
      target = E09D166E2F44CA1E009C51A3 /* EXO-iOS */;
    };
/* End PBXFileSystemSynchronizedBuildFileExceptionSet section */

/* Begin PBXFileSystemSynchronizedRootGroup section */
    E09D16712F44CA1E009C51A3 /* EXO-iOS */ = {
      isa = PBXFileSystemSynchronizedRootGroup;
      exceptions = (
        E09D169A2F44CA20009C51A3 /* Exceptions for "EXO-iOS" folder in "EXO-iOS" target */,
      );
      path = "EXO-iOS";
      sourceTree = "<group>";
    };
    E09D167F2F44CA20009C51A3 /* EXO-iOSTests */ = {
      isa = PBXFileSystemSynchronizedRootGroup;
      path = "EXO-iOSTests";
      sourceTree = "<group>";
    };
    E09D16892F44CA20009C51A3 /* EXO-iOSUITests */ = {
      isa = PBXFileSystemSynchronizedRootGroup;
      path = "EXO-iOSUITests";
      sourceTree = "<group>";
    };
/* End PBXFileSystemSynchronizedRootGroup section */

/* Begin PBXFrameworksBuildPhase section */
    E09D166C2F44CA1E009C51A3 /* Frameworks */ = {
      isa = PBXFrameworksBuildPhase;
      buildActionMask = 2147483647;
      files = (
        E09D17542F44F359009C51A3 /* MLXLMCommon in Frameworks */,
        E09D17522F44F359009C51A3 /* MLXLLM in Frameworks */,
      );
      runOnlyForDeploymentPostprocessing = 0;
    };
    E09D16792F44CA20009C51A3 /* Frameworks */ = {
      isa = PBXFrameworksBuildPhase;
      buildActionMask = 2147483647;
      files = (
      );
      runOnlyForDeploymentPostprocessing = 0;
    };
    E09D16832F44CA20009C51A3 /* Frameworks */ = {
      isa = PBXFrameworksBuildPhase;
      buildActionMask = 2147483647;
      files = (
      );
      runOnlyForDeploymentPostprocessing = 0;
    };
/* End PBXFrameworksBuildPhase section */

/* Begin PBXGroup section */
    E09D16662F44CA1E009C51A3 = {
      isa = PBXGroup;
      children = (
        E09D16712F44CA1E009C51A3 /* EXO-iOS */,
        E09D167F2F44CA20009C51A3 /* EXO-iOSTests */,
        E09D16892F44CA20009C51A3 /* EXO-iOSUITests */,
        E09D16702F44CA1E009C51A3 /* Products */,
      );
      sourceTree = "<group>";
    };
    E09D16702F44CA1E009C51A3 /* Products */ = {
      isa = PBXGroup;
      children = (
        E09D166F2F44CA1E009C51A3 /* EXO-iOS.app */,
        E09D167C2F44CA20009C51A3 /* EXO-iOSTests.xctest */,
        E09D16862F44CA20009C51A3 /* EXO-iOSUITests.xctest */,
      );
      name = Products;
      sourceTree = "<group>";
    };
/* End PBXGroup section */

/* Begin PBXNativeTarget section */
    E09D166E2F44CA1E009C51A3 /* EXO-iOS */ = {
      isa = PBXNativeTarget;
      buildConfigurationList = E09D16902F44CA20009C51A3 /* Build configuration list for PBXNativeTarget "EXO-iOS" */;
      buildPhases = (
        E09D166B2F44CA1E009C51A3 /* Sources */,
        E09D166C2F44CA1E009C51A3 /* Frameworks */,
        E09D166D2F44CA1E009C51A3 /* Resources */,
      );
      buildRules = (
      );
      dependencies = (
      );
      fileSystemSynchronizedGroups = (
        E09D16712F44CA1E009C51A3 /* EXO-iOS */,
      );
      name = "EXO-iOS";
      packageProductDependencies = (
        E09D17512F44F359009C51A3 /* MLXLLM */,
        E09D17532F44F359009C51A3 /* MLXLMCommon */,
      );
      productName = "EXO-iOS";
      productReference = E09D166F2F44CA1E009C51A3 /* EXO-iOS.app */;
      productType = "com.apple.product-type.application";
    };
    E09D167B2F44CA20009C51A3 /* EXO-iOSTests */ = {
      isa = PBXNativeTarget;
      buildConfigurationList = E09D16932F44CA20009C51A3 /* Build configuration list for PBXNativeTarget "EXO-iOSTests" */;
      buildPhases = (
        E09D16782F44CA20009C51A3 /* Sources */,
        E09D16792F44CA20009C51A3 /* Frameworks */,
        E09D167A2F44CA20009C51A3 /* Resources */,
      );
      buildRules = (
      );
      dependencies = (
        E09D167E2F44CA20009C51A3 /* PBXTargetDependency */,
      );
      fileSystemSynchronizedGroups = (
        E09D167F2F44CA20009C51A3 /* EXO-iOSTests */,
      );
      name = "EXO-iOSTests";
      packageProductDependencies = (
      );
      productName = "EXO-iOSTests";
      productReference = E09D167C2F44CA20009C51A3 /* EXO-iOSTests.xctest */;
      productType = "com.apple.product-type.bundle.unit-test";
    };
    E09D16852F44CA20009C51A3 /* EXO-iOSUITests */ = {
      isa = PBXNativeTarget;
      buildConfigurationList = E09D16962F44CA20009C51A3 /* Build configuration list for PBXNativeTarget "EXO-iOSUITests" */;
      buildPhases = (
        E09D16822F44CA20009C51A3 /* Sources */,
        E09D16832F44CA20009C51A3 /* Frameworks */,
        E09D16842F44CA20009C51A3 /* Resources */,
      );
      buildRules = (
      );
      dependencies = (
        E09D16882F44CA20009C51A3 /* PBXTargetDependency */,
      );
      fileSystemSynchronizedGroups = (
        E09D16892F44CA20009C51A3 /* EXO-iOSUITests */,
      );
      name = "EXO-iOSUITests";
      packageProductDependencies = (
      );
      productName = "EXO-iOSUITests";
      productReference = E09D16862F44CA20009C51A3 /* EXO-iOSUITests.xctest */;
      productType = "com.apple.product-type.bundle.ui-testing";
    };
/* End PBXNativeTarget section */

/* Begin PBXProject section */
    E09D16672F44CA1E009C51A3 /* Project object */ = {
      isa = PBXProject;
      attributes = {
        BuildIndependentTargetsInParallel = 1;
        LastSwiftUpdateCheck = 2620;
        LastUpgradeCheck = 2620;
        TargetAttributes = {
          E09D166E2F44CA1E009C51A3 = {
            CreatedOnToolsVersion = 26.2;
          };
          E09D167B2F44CA20009C51A3 = {
            CreatedOnToolsVersion = 26.2;
            TestTargetID = E09D166E2F44CA1E009C51A3;
          };
          E09D16852F44CA20009C51A3 = {
            CreatedOnToolsVersion = 26.2;
            TestTargetID = E09D166E2F44CA1E009C51A3;
          };
        };
      };
      buildConfigurationList = E09D166A2F44CA1E009C51A3 /* Build configuration list for PBXProject "EXO-iOS" */;
      developmentRegion = en;
      hasScannedForEncodings = 0;
      knownRegions = (
        en,
        Base,
      );
      mainGroup = E09D16662F44CA1E009C51A3;
      minimizedProjectReferenceProxies = 1;
      packageReferences = (
        E09D17502F44F359009C51A3 /* XCRemoteSwiftPackageReference "mlx-swift-lm" */,
      );
      preferredProjectObjectVersion = 77;
      productRefGroup = E09D16702F44CA1E009C51A3 /* Products */;
      projectDirPath = "";
      projectRoot = "";
      targets = (
        E09D166E2F44CA1E009C51A3 /* EXO-iOS */,
        E09D167B2F44CA20009C51A3 /* EXO-iOSTests */,
        E09D16852F44CA20009C51A3 /* EXO-iOSUITests */,
      );
    };
/* End PBXProject section */

/* Begin PBXResourcesBuildPhase section */
    E09D166D2F44CA1E009C51A3 /* Resources */ = {
      isa = PBXResourcesBuildPhase;
      buildActionMask = 2147483647;
      files = (
      );
      runOnlyForDeploymentPostprocessing = 0;
    };
    E09D167A2F44CA20009C51A3 /* Resources */ = {
      isa = PBXResourcesBuildPhase;
      buildActionMask = 2147483647;
      files = (
      );
      runOnlyForDeploymentPostprocessing = 0;
    };
    E09D16842F44CA20009C51A3 /* Resources */ = {
      isa = PBXResourcesBuildPhase;
      buildActionMask = 2147483647;
      files = (
      );
      runOnlyForDeploymentPostprocessing = 0;
    };
/* End PBXResourcesBuildPhase section */

/* Begin PBXSourcesBuildPhase section */
    E09D166B2F44CA1E009C51A3 /* Sources */ = {
      isa = PBXSourcesBuildPhase;
      buildActionMask = 2147483647;
      files = (
      );
      runOnlyForDeploymentPostprocessing = 0;
    };
    E09D16782F44CA20009C51A3 /* Sources */ = {
      isa = PBXSourcesBuildPhase;
      buildActionMask = 2147483647;
      files = (
      );
      runOnlyForDeploymentPostprocessing = 0;
    };
    E09D16822F44CA20009C51A3 /* Sources */ = {
      isa = PBXSourcesBuildPhase;
      buildActionMask = 2147483647;
      files = (
      );
      runOnlyForDeploymentPostprocessing = 0;
    };
/* End PBXSourcesBuildPhase section */

/* Begin PBXTargetDependency section */
    E09D167E2F44CA20009C51A3 /* PBXTargetDependency */ = {
      isa = PBXTargetDependency;
      target = E09D166E2F44CA1E009C51A3 /* EXO-iOS */;
      targetProxy = E09D167D2F44CA20009C51A3 /* PBXContainerItemProxy */;
    };
    E09D16882F44CA20009C51A3 /* PBXTargetDependency */ = {
      isa = PBXTargetDependency;
      target = E09D166E2F44CA1E009C51A3 /* EXO-iOS */;
      targetProxy = E09D16872F44CA20009C51A3 /* PBXContainerItemProxy */;
    };
/* End PBXTargetDependency section */

/* Begin XCBuildConfiguration section */
    E09D168E2F44CA20009C51A3 /* Debug */ = {
      isa = XCBuildConfiguration;
      buildSettings = {
        ALWAYS_SEARCH_USER_PATHS = NO;
        ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES;
        CLANG_ANALYZER_NONNULL = YES;
        CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
        CLANG_CXX_LANGUAGE_STANDARD = "gnu++20";
        CLANG_ENABLE_MODULES = YES;
        CLANG_ENABLE_OBJC_ARC = YES;
        CLANG_ENABLE_OBJC_WEAK = YES;
        CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
        CLANG_WARN_BOOL_CONVERSION = YES;
        CLANG_WARN_COMMA = YES;
        CLANG_WARN_CONSTANT_CONVERSION = YES;
        CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
        CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
        CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
        CLANG_WARN_EMPTY_BODY = YES;
        CLANG_WARN_ENUM_CONVERSION = YES;
        CLANG_WARN_INFINITE_RECURSION = YES;
        CLANG_WARN_INT_CONVERSION = YES;
        CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
        CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
        CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
        CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
        CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;
        CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
        CLANG_WARN_STRICT_PROTOTYPES = YES;
        CLANG_WARN_SUSPICIOUS_MOVE = YES;
        CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
        CLANG_WARN_UNREACHABLE_CODE = YES;
        CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
        COPY_PHASE_STRIP = NO;
        DEBUG_INFORMATION_FORMAT = dwarf;
        ENABLE_STRICT_OBJC_MSGSEND = YES;
        ENABLE_TESTABILITY = YES;
        ENABLE_USER_SCRIPT_SANDBOXING = YES;
        GCC_C_LANGUAGE_STANDARD = gnu17;
        GCC_DYNAMIC_NO_PIC = NO;
        GCC_NO_COMMON_BLOCKS = YES;
        GCC_OPTIMIZATION_LEVEL = 0;
        GCC_PREPROCESSOR_DEFINITIONS = (
          "DEBUG=1",
          "$(inherited)",
        );
        GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
        GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
        GCC_WARN_UNDECLARED_SELECTOR = YES;
        GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
        GCC_WARN_UNUSED_FUNCTION = YES;
        GCC_WARN_UNUSED_VARIABLE = YES;
        IPHONEOS_DEPLOYMENT_TARGET = 26.2;
        LOCALIZATION_PREFERS_STRING_CATALOGS = YES;
        MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE;
        MTL_FAST_MATH = YES;
        ONLY_ACTIVE_ARCH = YES;
        SDKROOT = iphoneos;
        SWIFT_ACTIVE_COMPILATION_CONDITIONS = "DEBUG $(inherited)";
        SWIFT_OPTIMIZATION_LEVEL = "-Onone";
      };
      name = Debug;
    };
    E09D168F2F44CA20009C51A3 /* Release */ = {
      isa = XCBuildConfiguration;
      buildSettings = {
        ALWAYS_SEARCH_USER_PATHS = NO;
        ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES;
        CLANG_ANALYZER_NONNULL = YES;
        CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
        CLANG_CXX_LANGUAGE_STANDARD = "gnu++20";
        CLANG_ENABLE_MODULES = YES;
        CLANG_ENABLE_OBJC_ARC = YES;
        CLANG_ENABLE_OBJC_WEAK = YES;
        CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
        CLANG_WARN_BOOL_CONVERSION = YES;
        CLANG_WARN_COMMA = YES;
        CLANG_WARN_CONSTANT_CONVERSION = YES;
        CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
        CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
        CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
        CLANG_WARN_EMPTY_BODY = YES;
        CLANG_WARN_ENUM_CONVERSION = YES;
        CLANG_WARN_INFINITE_RECURSION = YES;
        CLANG_WARN_INT_CONVERSION = YES;
        CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
        CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
        CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
        CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
        CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;
        CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
        CLANG_WARN_STRICT_PROTOTYPES = YES;
        CLANG_WARN_SUSPICIOUS_MOVE = YES;
        CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
        CLANG_WARN_UNREACHABLE_CODE = YES;
        CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
        COPY_PHASE_STRIP = NO;
        DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym";
        ENABLE_NS_ASSERTIONS = NO;
        ENABLE_STRICT_OBJC_MSGSEND = YES;
        ENABLE_USER_SCRIPT_SANDBOXING = YES;
        GCC_C_LANGUAGE_STANDARD = gnu17;
        GCC_NO_COMMON_BLOCKS = YES;
        GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
        GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
        GCC_WARN_UNDECLARED_SELECTOR = YES;
        GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
        GCC_WARN_UNUSED_FUNCTION = YES;
        GCC_WARN_UNUSED_VARIABLE = YES;
        IPHONEOS_DEPLOYMENT_TARGET = 26.2;
        LOCALIZATION_PREFERS_STRING_CATALOGS = YES;
        MTL_ENABLE_DEBUG_INFO = NO;
        MTL_FAST_MATH = YES;
        SDKROOT = iphoneos;
        SWIFT_COMPILATION_MODE = wholemodule;
        VALIDATE_PRODUCT = YES;
      };
      name = Release;
    };
    E09D16912F44CA20009C51A3 /* Debug */ = {
      isa = XCBuildConfiguration;
      buildSettings = {
        ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
        ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor;
        CODE_SIGN_STYLE = Automatic;
        CURRENT_PROJECT_VERSION = 1;
        DEVELOPMENT_TEAM = 3M3M67U93M;
        ENABLE_PREVIEWS = YES;
        GENERATE_INFOPLIST_FILE = YES;
        INFOPLIST_FILE = "EXO-iOS/Info.plist";
        INFOPLIST_KEY_UIApplicationSceneManifest_Generation = YES;
        INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents = YES;
        INFOPLIST_KEY_UILaunchScreen_Generation = YES;
        INFOPLIST_KEY_UISupportedInterfaceOrientations_iPad = "UIInterfaceOrientationPortrait UIInterfaceOrientationPortraitUpsideDown UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight";
        INFOPLIST_KEY_UISupportedInterfaceOrientations_iPhone = "UIInterfaceOrientationPortrait UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight";
        LD_RUNPATH_SEARCH_PATHS = (
          "$(inherited)",
          "@executable_path/Frameworks",
        );
        MARKETING_VERSION = 1.0;
        PRODUCT_BUNDLE_IDENTIFIER = "com.exo.EXO-iOS";
        PRODUCT_NAME = "$(TARGET_NAME)";
        STRING_CATALOG_GENERATE_SYMBOLS = YES;
        SWIFT_APPROACHABLE_CONCURRENCY = YES;
        SWIFT_DEFAULT_ACTOR_ISOLATION = MainActor;
        SWIFT_EMIT_LOC_STRINGS = YES;
        SWIFT_UPCOMING_FEATURE_MEMBER_IMPORT_VISIBILITY = YES;
        SWIFT_VERSION = 5.0;
        TARGETED_DEVICE_FAMILY = "1,2";
      };
      name = Debug;
    };
    E09D16922F44CA20009C51A3 /* Release */ = {
      isa = XCBuildConfiguration;
      buildSettings = {
        ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
        ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor;
        CODE_SIGN_STYLE = Automatic;
        CURRENT_PROJECT_VERSION = 1;
        DEVELOPMENT_TEAM = 3M3M67U93M;
        ENABLE_PREVIEWS = YES;
        GENERATE_INFOPLIST_FILE = YES;
        INFOPLIST_FILE = "EXO-iOS/Info.plist";
        INFOPLIST_KEY_UIApplicationSceneManifest_Generation = YES;
        INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents = YES;
        INFOPLIST_KEY_UILaunchScreen_Generation = YES;
        INFOPLIST_KEY_UISupportedInterfaceOrientations_iPad = "UIInterfaceOrientationPortrait UIInterfaceOrientationPortraitUpsideDown UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight";
        INFOPLIST_KEY_UISupportedInterfaceOrientations_iPhone = "UIInterfaceOrientationPortrait UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight";
        LD_RUNPATH_SEARCH_PATHS = (
          "$(inherited)",
          "@executable_path/Frameworks",
        );
        MARKETING_VERSION = 1.0;
        PRODUCT_BUNDLE_IDENTIFIER = "com.exo.EXO-iOS";
        PRODUCT_NAME = "$(TARGET_NAME)";
        STRING_CATALOG_GENERATE_SYMBOLS = YES;
        SWIFT_APPROACHABLE_CONCURRENCY = YES;
        SWIFT_DEFAULT_ACTOR_ISOLATION = MainActor;
        SWIFT_EMIT_LOC_STRINGS = YES;
        SWIFT_UPCOMING_FEATURE_MEMBER_IMPORT_VISIBILITY = YES;
        SWIFT_VERSION = 5.0;
        TARGETED_DEVICE_FAMILY = "1,2";
      };
      name = Release;
    };
    E09D16942F44CA20009C51A3 /* Debug */ = {
      isa = XCBuildConfiguration;
      buildSettings = {
        BUNDLE_LOADER = "$(TEST_HOST)";
        CODE_SIGN_STYLE = Automatic;
        CURRENT_PROJECT_VERSION = 1;
        GENERATE_INFOPLIST_FILE = YES;
        IPHONEOS_DEPLOYMENT_TARGET = 26.2;
        MARKETING_VERSION = 1.0;
        PRODUCT_BUNDLE_IDENTIFIER = "com.exo.EXO-iOSTests";
        PRODUCT_NAME = "$(TARGET_NAME)";
        STRING_CATALOG_GENERATE_SYMBOLS = NO;
        SWIFT_APPROACHABLE_CONCURRENCY = YES;
        SWIFT_EMIT_LOC_STRINGS = NO;
        SWIFT_UPCOMING_FEATURE_MEMBER_IMPORT_VISIBILITY = YES;
        SWIFT_VERSION = 5.0;
        TARGETED_DEVICE_FAMILY = "1,2";
        TEST_HOST = "$(BUILT_PRODUCTS_DIR)/EXO-iOS.app/$(BUNDLE_EXECUTABLE_FOLDER_PATH)/EXO-iOS";
      };
      name = Debug;
    };
    E09D16952F44CA20009C51A3 /* Release */ = {
      isa = XCBuildConfiguration;
      buildSettings = {
        BUNDLE_LOADER = "$(TEST_HOST)";
        CODE_SIGN_STYLE = Automatic;
        CURRENT_PROJECT_VERSION = 1;
        GENERATE_INFOPLIST_FILE = YES;
        IPHONEOS_DEPLOYMENT_TARGET = 26.2;
        MARKETING_VERSION = 1.0;
        PRODUCT_BUNDLE_IDENTIFIER = "com.exo.EXO-iOSTests";
        PRODUCT_NAME = "$(TARGET_NAME)";
        STRING_CATALOG_GENERATE_SYMBOLS = NO;
        SWIFT_APPROACHABLE_CONCURRENCY = YES;
        SWIFT_EMIT_LOC_STRINGS = NO;
        SWIFT_UPCOMING_FEATURE_MEMBER_IMPORT_VISIBILITY = YES;
        SWIFT_VERSION = 5.0;
        TARGETED_DEVICE_FAMILY = "1,2";
        TEST_HOST = "$(BUILT_PRODUCTS_DIR)/EXO-iOS.app/$(BUNDLE_EXECUTABLE_FOLDER_PATH)/EXO-iOS";
      };
      name = Release;
    };
    E09D16972F44CA20009C51A3 /* Debug */ = {
      isa = XCBuildConfiguration;
      buildSettings = {
        CODE_SIGN_STYLE = Automatic;
        CURRENT_PROJECT_VERSION = 1;
        GENERATE_INFOPLIST_FILE = YES;
        MARKETING_VERSION = 1.0;
        PRODUCT_BUNDLE_IDENTIFIER = "com.exo.EXO-iOSUITests";
        PRODUCT_NAME = "$(TARGET_NAME)";
        STRING_CATALOG_GENERATE_SYMBOLS = NO;
        SWIFT_APPROACHABLE_CONCURRENCY = YES;
        SWIFT_EMIT_LOC_STRINGS = NO;
        SWIFT_UPCOMING_FEATURE_MEMBER_IMPORT_VISIBILITY = YES;
        SWIFT_VERSION = 5.0;
        TARGETED_DEVICE_FAMILY = "1,2";
        TEST_TARGET_NAME = "EXO-iOS";
      };
      name = Debug;
    };
    E09D16982F44CA20009C51A3 /* Release */ = {
      isa = XCBuildConfiguration;
      buildSettings = {
        CODE_SIGN_STYLE = Automatic;
        CURRENT_PROJECT_VERSION = 1;
        GENERATE_INFOPLIST_FILE = YES;
        MARKETING_VERSION = 1.0;
        PRODUCT_BUNDLE_IDENTIFIER = "com.exo.EXO-iOSUITests";
        PRODUCT_NAME = "$(TARGET_NAME)";
        STRING_CATALOG_GENERATE_SYMBOLS = NO;
        SWIFT_APPROACHABLE_CONCURRENCY = YES;
        SWIFT_EMIT_LOC_STRINGS = NO;
        SWIFT_UPCOMING_FEATURE_MEMBER_IMPORT_VISIBILITY = YES;
        SWIFT_VERSION = 5.0;
        TARGETED_DEVICE_FAMILY = "1,2";
        TEST_TARGET_NAME = "EXO-iOS";
      };
      name = Release;
    };
/* End XCBuildConfiguration section */

/* Begin XCConfigurationList section */
    E09D166A2F44CA1E009C51A3 /* Build configuration list for PBXProject "EXO-iOS" */ = {
      isa = XCConfigurationList;
      buildConfigurations = (
        E09D168E2F44CA20009C51A3 /* Debug */,
        E09D168F2F44CA20009C51A3 /* Release */,
      );
      defaultConfigurationIsVisible = 0;
      defaultConfigurationName = Release;
    };
    E09D16902F44CA20009C51A3 /* Build configuration list for PBXNativeTarget "EXO-iOS" */ = {
      isa = XCConfigurationList;
      buildConfigurations = (
        E09D16912F44CA20009C51A3 /* Debug */,
        E09D16922F44CA20009C51A3 /* Release */,
      );
      defaultConfigurationIsVisible = 0;
      defaultConfigurationName = Release;
    };
    E09D16932F44CA20009C51A3 /* Build configuration list for PBXNativeTarget "EXO-iOSTests" */ = {
      isa = XCConfigurationList;
      buildConfigurations = (
        E09D16942F44CA20009C51A3 /* Debug */,
        E09D16952F44CA20009C51A3 /* Release */,
      );
      defaultConfigurationIsVisible = 0;
      defaultConfigurationName = Release;
    };
    E09D16962F44CA20009C51A3 /* Build configuration list for PBXNativeTarget "EXO-iOSUITests" */ = {
      isa = XCConfigurationList;
      buildConfigurations = (
        E09D16972F44CA20009C51A3 /* Debug */,
        E09D16982F44CA20009C51A3 /* Release */,
      );
      defaultConfigurationIsVisible = 0;
      defaultConfigurationName = Release;
    };
/* End XCConfigurationList section */

/* Begin XCRemoteSwiftPackageReference section */
    E09D17502F44F359009C51A3 /* XCRemoteSwiftPackageReference "mlx-swift-lm" */ = {
      isa = XCRemoteSwiftPackageReference;
      repositoryURL = "https://github.com/ml-explore/mlx-swift-lm";
      requirement = {
        kind = upToNextMajorVersion;
        minimumVersion = 2.30.3;
      };
    };
/* End XCRemoteSwiftPackageReference section */

/* Begin XCSwiftPackageProductDependency section */
    E09D17512F44F359009C51A3 /* MLXLLM */ = {
      isa = XCSwiftPackageProductDependency;
      package = E09D17502F44F359009C51A3 /* XCRemoteSwiftPackageReference "mlx-swift-lm" */;
      productName = MLXLLM;
    };
    E09D17532F44F359009C51A3 /* MLXLMCommon */ = {
      isa = XCSwiftPackageProductDependency;
      package = E09D17502F44F359009C51A3 /* XCRemoteSwiftPackageReference "mlx-swift-lm" */;
      productName = MLXLMCommon;
    };
/* End XCSwiftPackageProductDependency section */
  };
  rootObject = E09D16672F44CA1E009C51A3 /* Project object */;
}
app/EXO-iOS/EXO-iOS.xcodeproj/project.xcworkspace/contents.xcworkspacedata (generated, new file, 7 lines)
@@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<Workspace
   version = "1.0">
   <FileRef
      location = "self:">
   </FileRef>
</Workspace>
@@ -0,0 +1,60 @@
{
  "originHash" : "facc0ac7c70363ea20f6cd1235de91dea6b06f0d00190946045a6c8ae753abc2",
  "pins" : [
    {
      "identity" : "mlx-swift",
      "kind" : "remoteSourceControl",
      "location" : "https://github.com/ml-explore/mlx-swift",
      "state" : {
        "revision" : "6ba4827fb82c97d012eec9ab4b2de21f85c3b33d",
        "version" : "0.30.6"
      }
    },
    {
      "identity" : "mlx-swift-lm",
      "kind" : "remoteSourceControl",
      "location" : "https://github.com/ml-explore/mlx-swift-lm",
      "state" : {
        "revision" : "360c5052b81cc154b04ee0933597a4ad6db4b8ae",
        "version" : "2.30.3"
      }
    },
    {
      "identity" : "swift-collections",
      "kind" : "remoteSourceControl",
      "location" : "https://github.com/apple/swift-collections.git",
      "state" : {
        "revision" : "7b847a3b7008b2dc2f47ca3110d8c782fb2e5c7e",
        "version" : "1.3.0"
      }
    },
    {
      "identity" : "swift-jinja",
      "kind" : "remoteSourceControl",
      "location" : "https://github.com/huggingface/swift-jinja.git",
      "state" : {
        "revision" : "d81197f35f41445bc10e94600795e68c6f5e94b0",
        "version" : "2.3.1"
      }
    },
    {
      "identity" : "swift-numerics",
      "kind" : "remoteSourceControl",
      "location" : "https://github.com/apple/swift-numerics",
      "state" : {
        "revision" : "0c0290ff6b24942dadb83a929ffaaa1481df04a2",
        "version" : "1.1.1"
      }
    },
    {
      "identity" : "swift-transformers",
      "kind" : "remoteSourceControl",
      "location" : "https://github.com/huggingface/swift-transformers",
      "state" : {
        "revision" : "573e5c9036c2f136b3a8a071da8e8907322403d0",
        "version" : "1.1.6"
      }
    }
  ],
  "version" : 3
}
@@ -0,0 +1,11 @@
{
  "colors" : [
    {
      "idiom" : "universal"
    }
  ],
  "info" : {
    "author" : "xcode",
    "version" : 1
  }
}
@@ -0,0 +1,35 @@
{
  "images" : [
    {
      "idiom" : "universal",
      "platform" : "ios",
      "size" : "1024x1024"
    },
    {
      "appearances" : [
        {
          "appearance" : "luminosity",
          "value" : "dark"
        }
      ],
      "idiom" : "universal",
      "platform" : "ios",
      "size" : "1024x1024"
    },
    {
      "appearances" : [
        {
          "appearance" : "luminosity",
          "value" : "tinted"
        }
      ],
      "idiom" : "universal",
      "platform" : "ios",
      "size" : "1024x1024"
    }
  ],
  "info" : {
    "author" : "xcode",
    "version" : 1
  }
}
app/EXO-iOS/EXO-iOS/Assets.xcassets/Contents.json (new file, 6 lines)
@@ -0,0 +1,6 @@
{
  "info" : {
    "author" : "xcode",
    "version" : 1
  }
}
app/EXO-iOS/EXO-iOS/EXO-iOS.entitlements (new file, 8 lines)
@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
  <key>com.apple.developer.kernel.increased-memory-limit</key>
  <true/>
</dict>
</plist>
app/EXO-iOS/EXO-iOS/EXO_iOSApp.swift (new file, 42 lines)
@@ -0,0 +1,42 @@
import SwiftUI

@main
struct EXO_iOSApp: App {
  @State private var clusterService = ClusterService()
  @State private var discoveryService = DiscoveryService()
  @State private var localInferenceService = LocalInferenceService()
  @State private var chatService: ChatService?

  var body: some Scene {
    WindowGroup {
      if let chatService {
        RootView()
          .environment(clusterService)
          .environment(discoveryService)
          .environment(chatService)
          .environment(localInferenceService)
          .task {
            await clusterService.attemptAutoReconnect()
            discoveryService.startBrowsing()
            await localInferenceService.prepareModel()
          }
          .onChange(of: discoveryService.discoveredClusters) { _, clusters in
            guard !clusterService.isConnected,
              case .disconnected = clusterService.connectionState,
              let first = clusters.first
            else { return }
            Task {
              await clusterService.connectToDiscoveredCluster(first, using: discoveryService)
            }
          }
      } else {
        Color.clear.onAppear {
          chatService = ChatService(
            clusterService: clusterService,
            localInferenceService: localInferenceService
          )
        }
      }
    }
  }
}
app/EXO-iOS/EXO-iOS/Info.plist (new file, 17 lines)
@@ -0,0 +1,17 @@
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
  <key>CFBundleDisplayName</key>
  <string>EXO</string>
  <key>NSLocalNetworkUsageDescription</key>
  <string>EXO needs local network access to connect to your EXO cluster.</string>
  <key>NSBonjourServices</key>
  <array>
    <string>_exo._tcp</string>
    <string>_p2p._tcp</string>
    <string>_p2p._udp</string>
    <string>_libp2p._udp</string>
  </array>
</dict>
</plist>
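The NSBonjourServices array above gates which service types the app may browse on the local network; since iOS 14, browsing a type not declared there fails. A minimal sketch of the pairing, using the _exo._tcp type that DiscoveryService browses later in this diff:

import Network

// Browsing "_exo._tcp" only works because Info.plist declares it under NSBonjourServices;
// the NSLocalNetworkUsageDescription string is what iOS shows in the permission prompt.
let browser = NWBrowser(for: .bonjour(type: "_exo._tcp", domain: nil), using: .tcp)
browser.start(queue: .main)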
app/EXO-iOS/EXO-iOS/Models/ChatCompletionTypes.swift (new file, 129 lines)
@@ -0,0 +1,129 @@
import Foundation

// MARK: - Request

struct ChatCompletionRequest: Encodable {
  let model: String
  let messages: [ChatCompletionMessageParam]
  let stream: Bool
  let maxTokens: Int?
  let temperature: Double?

  enum CodingKeys: String, CodingKey {
    case model, messages, stream, temperature
    case maxTokens = "max_tokens"
  }
}

struct ChatCompletionMessageParam: Encodable {
  let role: String
  let content: String
}

// MARK: - Streaming Response

struct ChatCompletionChunk: Decodable {
  let id: String
  let model: String?
  let choices: [StreamingChoice]
  let usage: ChunkUsage?

  init(id: String, model: String?, choices: [StreamingChoice], usage: ChunkUsage?) {
    self.id = id
    self.model = model
    self.choices = choices
    self.usage = usage
  }
}

struct StreamingChoice: Decodable {
  let index: Int
  let delta: Delta
  let finishReason: String?

  enum CodingKeys: String, CodingKey {
    case index, delta
    case finishReason = "finish_reason"
  }

  init(index: Int, delta: Delta, finishReason: String?) {
    self.index = index
    self.delta = delta
    self.finishReason = finishReason
  }
}

struct Delta: Decodable {
  let role: String?
  let content: String?

  init(role: String?, content: String?) {
    self.role = role
    self.content = content
  }
}

struct ChunkUsage: Decodable {
  let promptTokens: Int?
  let completionTokens: Int?
  let totalTokens: Int?

  enum CodingKeys: String, CodingKey {
    case promptTokens = "prompt_tokens"
    case completionTokens = "completion_tokens"
    case totalTokens = "total_tokens"
  }

  init(promptTokens: Int?, completionTokens: Int?, totalTokens: Int?) {
    self.promptTokens = promptTokens
    self.completionTokens = completionTokens
    self.totalTokens = totalTokens
  }
}

// MARK: - Non-Streaming Response

struct ChatCompletionResponse: Decodable {
  let id: String
  let model: String?
  let choices: [ResponseChoice]
}

struct ResponseChoice: Decodable {
  let index: Int
  let message: ResponseMessage
  let finishReason: String?

  enum CodingKeys: String, CodingKey {
    case index, message
    case finishReason = "finish_reason"
  }
}

struct ResponseMessage: Decodable {
  let role: String?
  let content: String?
}

// MARK: - Models List

struct ModelListResponse: Decodable {
  let data: [ModelInfo]
}

struct ModelInfo: Decodable, Identifiable {
  let id: String
  let name: String?
}

// MARK: - Error

struct APIErrorResponse: Decodable {
  let error: APIErrorInfo
}

struct APIErrorInfo: Decodable {
  let message: String
  let type: String?
  let code: Int?
}
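These types follow the OpenAI-style chat-completions wire format, with explicit CodingKeys mapping the snake_case fields. As a minimal sketch, one streamed data payload of that shape decodes like this (the JSON literal is illustrative, not captured from a real exo node):

import Foundation

// Hypothetical SSE payload in the OpenAI-compatible shape the types above model.
let payload = #"{"id":"chatcmpl-1","model":"demo-model","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}],"usage":null}"#

if let chunk = try? JSONDecoder().decode(ChatCompletionChunk.self, from: Data(payload.utf8)) {
  print(chunk.choices.first?.delta.content ?? "")  // prints "Hello"
}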
app/EXO-iOS/EXO-iOS/Models/ChatMessage.swift (new file, 26 lines)
@@ -0,0 +1,26 @@
import Foundation

struct ChatMessage: Identifiable, Equatable {
  let id: UUID
  let role: Role
  var content: String
  let timestamp: Date
  var isStreaming: Bool

  enum Role: String, Codable {
    case user
    case assistant
    case system
  }

  init(
    id: UUID = UUID(), role: Role, content: String, timestamp: Date = Date(),
    isStreaming: Bool = false
  ) {
    self.id = id
    self.role = role
    self.content = content
    self.timestamp = timestamp
    self.isStreaming = isStreaming
  }
}
app/EXO-iOS/EXO-iOS/Models/ConnectionInfo.swift (new file, 11 lines)
@@ -0,0 +1,11 @@
import Foundation

struct ConnectionInfo: Codable, Equatable {
  let host: String
  let port: Int
  let nodeId: String?

  var baseURL: URL { URL(string: "http://\(host):\(port)")! }

  static let defaultPort = 52415
}
app/EXO-iOS/EXO-iOS/Models/Conversation.swift (new file, 34 lines)
@@ -0,0 +1,34 @@
import Foundation

struct Conversation: Identifiable, Codable, Equatable {
  let id: UUID
  var title: String
  var messages: [StoredMessage]
  var modelId: String?
  let createdAt: Date

  init(
    id: UUID = UUID(), title: String = "New Chat", messages: [StoredMessage] = [],
    modelId: String? = nil, createdAt: Date = Date()
  ) {
    self.id = id
    self.title = title
    self.messages = messages
    self.modelId = modelId
    self.createdAt = createdAt
  }
}

struct StoredMessage: Identifiable, Codable, Equatable {
  let id: UUID
  let role: String
  var content: String
  let timestamp: Date

  init(id: UUID = UUID(), role: String, content: String, timestamp: Date = Date()) {
    self.id = id
    self.role = role
    self.content = content
    self.timestamp = timestamp
  }
}
app/EXO-iOS/EXO-iOS/Services/ChatService.swift (new file, 225 lines)
@@ -0,0 +1,225 @@
import Foundation

@Observable
@MainActor
final class ChatService {
  var conversations: [Conversation] = []
  var activeConversationId: UUID?
  private(set) var isGenerating: Bool = false
  private var currentGenerationTask: Task<Void, Never>?

  private let clusterService: ClusterService
  private let localInferenceService: LocalInferenceService

  var canSendMessage: Bool {
    clusterService.isConnected || localInferenceService.isAvailable
  }

  var activeConversation: Conversation? {
    guard let id = activeConversationId else { return nil }
    return conversations.first { $0.id == id }
  }

  var activeMessages: [ChatMessage] {
    guard let conversation = activeConversation else { return [] }
    return conversation.messages.map { stored in
      ChatMessage(
        id: stored.id,
        role: ChatMessage.Role(rawValue: stored.role) ?? .user,
        content: stored.content,
        timestamp: stored.timestamp
      )
    }
  }

  init(clusterService: ClusterService, localInferenceService: LocalInferenceService) {
    self.clusterService = clusterService
    self.localInferenceService = localInferenceService
    loadConversations()
  }

  // MARK: - Conversation Management

  func createConversation(modelId: String? = nil) {
    let conversation = Conversation(
      modelId: modelId ?? clusterService.availableModels.first?.id)
    conversations.insert(conversation, at: 0)
    activeConversationId = conversation.id
    saveConversations()
  }

  func deleteConversation(id: UUID) {
    conversations.removeAll { $0.id == id }
    if activeConversationId == id {
      activeConversationId = conversations.first?.id
    }
    saveConversations()
  }

  func setActiveConversation(id: UUID) {
    activeConversationId = id
  }

  func setModelForActiveConversation(_ modelId: String) {
    guard let index = conversations.firstIndex(where: { $0.id == activeConversationId }) else {
      return
    }
    conversations[index].modelId = modelId
    saveConversations()
  }

  // MARK: - Messaging

  func sendMessage(_ text: String) {
    guard !text.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty else { return }

    if activeConversation == nil {
      createConversation()
    }

    guard let index = conversations.firstIndex(where: { $0.id == activeConversationId }) else {
      return
    }

    let userMessage = StoredMessage(role: "user", content: text)
    conversations[index].messages.append(userMessage)

    if conversations[index].title == "New Chat" {
      let preview = String(text.prefix(40))
      conversations[index].title = preview + (text.count > 40 ? "..." : "")
    }

    let modelId: String
    if clusterService.isConnected {
      guard let clusterId = conversations[index].modelId ?? clusterService.availableModels.first?.id
      else {
        let errorMessage = StoredMessage(
          role: "assistant", content: "No model selected. Please select a model first.")
        conversations[index].messages.append(errorMessage)
        saveConversations()
        return
      }
      modelId = clusterId
    } else if localInferenceService.isAvailable {
      modelId = localInferenceService.defaultModelId
    } else {
      let errorMessage = StoredMessage(
        role: "assistant",
        content: "Not connected to a cluster and local model is not available.")
      conversations[index].messages.append(errorMessage)
      saveConversations()
      return
    }

    conversations[index].modelId = modelId

    let assistantMessageId = UUID()
    let assistantMessage = StoredMessage(
      id: assistantMessageId, role: "assistant", content: "", timestamp: Date())
    conversations[index].messages.append(assistantMessage)

    let messagesForAPI = conversations[index].messages.dropLast().map { stored in
      ChatCompletionMessageParam(role: stored.role, content: stored.content)
    }

    let request = ChatCompletionRequest(
      model: modelId,
      messages: Array(messagesForAPI),
      stream: true,
      maxTokens: 4096,
      temperature: nil
    )

    let conversationId = conversations[index].id

    isGenerating = true
    currentGenerationTask = Task { [weak self] in
      guard let self else { return }
      await self.performStreaming(
        request: request, conversationId: conversationId,
        assistantMessageId: assistantMessageId)
    }

    saveConversations()
  }

  func cancelGeneration() {
    currentGenerationTask?.cancel()
    currentGenerationTask = nil
    localInferenceService.cancelGeneration()
    isGenerating = false
  }

  // MARK: - Streaming

  private func performStreaming(
    request: ChatCompletionRequest, conversationId: UUID, assistantMessageId: UUID
  ) async {
    defer {
      isGenerating = false
      currentGenerationTask = nil
      saveConversations()
    }

    do {
      let stream =
        clusterService.isConnected
        ? clusterService.streamChatCompletion(request: request)
        : localInferenceService.streamChatCompletion(request: request)
      for try await chunk in stream {
        guard !Task.isCancelled else { return }
        guard let content = chunk.choices.first?.delta.content, !content.isEmpty else {
          continue
        }

        if let convIndex = conversations.firstIndex(where: { $0.id == conversationId }),
          let msgIndex = conversations[convIndex].messages.firstIndex(where: {
            $0.id == assistantMessageId
          })
        {
          conversations[convIndex].messages[msgIndex].content += content
        }
      }
    } catch {
      if !Task.isCancelled {
        if let convIndex = conversations.firstIndex(where: { $0.id == conversationId }),
          let msgIndex = conversations[convIndex].messages.firstIndex(where: {
            $0.id == assistantMessageId
          })
        {
          if conversations[convIndex].messages[msgIndex].content.isEmpty {
            conversations[convIndex].messages[msgIndex].content =
              "Error: \(error.localizedDescription)"
          }
        }
      }
    }
  }

  // MARK: - Persistence

  private static var storageURL: URL {
    let documents = FileManager.default.urls(for: .documentDirectory, in: .userDomainMask)
      .first!
    return documents.appendingPathComponent("exo_conversations.json")
  }

  private func saveConversations() {
    do {
      let data = try JSONEncoder().encode(conversations)
      try data.write(to: Self.storageURL, options: .atomic)
    } catch {
      // Save failed silently
    }
  }

  private func loadConversations() {
    do {
      let data = try Data(contentsOf: Self.storageURL)
      conversations = try JSONDecoder().decode([Conversation].self, from: data)
      activeConversationId = conversations.first?.id
    } catch {
      conversations = []
    }
  }
}
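ChatService is the seam between the two transports: sendMessage picks the cluster model when connected, falls back to the on-device MLX model otherwise, and appends an empty assistant message that performStreaming fills in chunk by chunk. A hedged wiring sketch on the main actor (service names from this diff; the message text is illustrative):

// Mirrors the wiring in EXO_iOSApp; must run on the main actor since all
// three services are @MainActor.
let cluster = ClusterService()
let local = LocalInferenceService()
let chat = ChatService(clusterService: cluster, localInferenceService: local)

chat.createConversation()
chat.sendMessage("Hello from the sketch")  // streams into the active conversation
// chat.cancelGeneration() would stop the in-flight stream.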
app/EXO-iOS/EXO-iOS/Services/ClusterService.swift (new file, 198 lines)
@@ -0,0 +1,198 @@
import Foundation

enum ConnectionState: Equatable {
  case disconnected
  case connecting
  case connected(ConnectionInfo)
}

struct ModelOption: Identifiable, Equatable {
  let id: String
  let displayName: String
}

@Observable
@MainActor
final class ClusterService {
  private(set) var connectionState: ConnectionState = .disconnected
  private(set) var availableModels: [ModelOption] = []
  private(set) var lastError: String?

  private let session: URLSession
  private let decoder: JSONDecoder
  private var pollingTask: Task<Void, Never>?

  private static let connectionInfoKey = "exo_last_connection_info"

  var isConnected: Bool {
    if case .connected = connectionState { return true }
    return false
  }

  var currentConnection: ConnectionInfo? {
    if case .connected(let info) = connectionState { return info }
    return nil
  }

  init(session: URLSession = .shared) {
    self.session = session
    let decoder = JSONDecoder()
    self.decoder = decoder
  }

  // MARK: - Connection

  func connect(to info: ConnectionInfo) async {
    connectionState = .connecting
    lastError = nil

    do {
      let url = info.baseURL.appendingPathComponent("node_id")
      var request = URLRequest(url: url)
      request.timeoutInterval = 5
      request.cachePolicy = .reloadIgnoringLocalCacheData
      let (_, response) = try await session.data(for: request)

      guard let httpResponse = response as? HTTPURLResponse,
        (200..<300).contains(httpResponse.statusCode)
      else {
        throw URLError(.badServerResponse)
      }

      connectionState = .connected(info)
      persistConnection(info)
      startPolling()
      await fetchModels(baseURL: info.baseURL)
    } catch {
      connectionState = .disconnected
      lastError = "Could not connect to \(info.host):\(info.port)"
    }
  }

  func connectToDiscoveredCluster(_ cluster: DiscoveredCluster, using discoveryService: DiscoveryService) async {
    guard case .disconnected = connectionState else { return }
    connectionState = .connecting
    lastError = nil

    guard let info = await discoveryService.resolve(cluster) else {
      connectionState = .disconnected
      lastError = "Could not resolve \(cluster.name)"
      return
    }
    connectionState = .disconnected  // reset so connect() can proceed
    await connect(to: info)
  }

  func disconnect() {
    stopPolling()
    connectionState = .disconnected
    availableModels = []
    lastError = nil
  }

  func attemptAutoReconnect() async {
    guard case .disconnected = connectionState,
      let info = loadPersistedConnection()
    else { return }
    await connect(to: info)
  }

  // MARK: - Polling

  private func startPolling(interval: TimeInterval = 2.0) {
    stopPolling()
    pollingTask = Task { [weak self] in
      while !Task.isCancelled {
        try? await Task.sleep(for: .seconds(interval))
        guard let self, !Task.isCancelled else { return }
        guard let connection = self.currentConnection else { return }
        await self.fetchModels(baseURL: connection.baseURL)
      }
    }
  }

  private func stopPolling() {
    pollingTask?.cancel()
    pollingTask = nil
  }

  // MARK: - API

  private func fetchModels(baseURL: URL) async {
    do {
      let url = baseURL.appendingPathComponent("models")
      var request = URLRequest(url: url)
      request.cachePolicy = .reloadIgnoringLocalCacheData
      let (data, response) = try await session.data(for: request)

      guard let httpResponse = response as? HTTPURLResponse,
        (200..<300).contains(httpResponse.statusCode)
      else { return }

      let list = try decoder.decode(ModelListResponse.self, from: data)
      availableModels = list.data.map {
        ModelOption(id: $0.id, displayName: $0.name ?? $0.id)
      }
    } catch {
      // Models fetch failed silently — will retry on next poll
    }
  }

  func streamChatCompletion(request body: ChatCompletionRequest) -> AsyncThrowingStream<
    ChatCompletionChunk, Error
  > {
    AsyncThrowingStream { continuation in
      let task = Task { [weak self] in
        guard let self, let connection = self.currentConnection else {
          continuation.finish(throwing: URLError(.notConnectedToInternet))
          return
        }

        do {
          let url = connection.baseURL.appendingPathComponent("v1/chat/completions")
          var request = URLRequest(url: url)
          request.httpMethod = "POST"
          request.setValue("application/json", forHTTPHeaderField: "Content-Type")
          request.httpBody = try JSONEncoder().encode(body)

          let (bytes, response) = try await self.session.bytes(for: request)

          guard let httpResponse = response as? HTTPURLResponse,
            (200..<300).contains(httpResponse.statusCode)
          else {
            continuation.finish(throwing: URLError(.badServerResponse))
            return
          }

          let parser = SSEStreamParser<ChatCompletionChunk>(
            bytes: bytes, decoder: self.decoder)
          for try await chunk in parser {
            continuation.yield(chunk)
          }
          continuation.finish()
        } catch {
          continuation.finish(throwing: error)
        }
      }

      continuation.onTermination = { _ in
        task.cancel()
      }
    }
  }

  // MARK: - Persistence

  private func persistConnection(_ info: ConnectionInfo) {
    if let data = try? JSONEncoder().encode(info) {
      UserDefaults.standard.set(data, forKey: Self.connectionInfoKey)
    }
  }

  private func loadPersistedConnection() -> ConnectionInfo? {
    guard let data = UserDefaults.standard.data(forKey: Self.connectionInfoKey) else {
      return nil
    }
    return try? JSONDecoder().decode(ConnectionInfo.self, from: data)
  }
}
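streamChatCompletion defers wire parsing to SSEStreamParser, whose implementation does not appear in the hunks shown here. A minimal sketch of what such a parser might look like, assuming the OpenAI convention of one JSON object per "data:" line with a "[DONE]" sentinel (both assumptions, not confirmed by this diff; the type name SSELines is hypothetical):

import Foundation

// Hypothetical stand-in for the SSEStreamParser referenced above: walks an
// URLSession byte stream line by line and decodes each "data:" payload as T.
struct SSELines<T: Decodable>: AsyncSequence {
  typealias Element = T
  let bytes: URLSession.AsyncBytes
  let decoder: JSONDecoder

  struct AsyncIterator: AsyncIteratorProtocol {
    var lines: AsyncLineSequence<URLSession.AsyncBytes>.AsyncIterator
    let decoder: JSONDecoder

    mutating func next() async throws -> T? {
      while let line = try await lines.next() {
        guard line.hasPrefix("data:") else { continue }  // skip comments and blank lines
        let payload = line.dropFirst(5).trimmingCharacters(in: .whitespaces)
        if payload == "[DONE]" { return nil }  // end-of-stream sentinel
        return try decoder.decode(T.self, from: Data(payload.utf8))
      }
      return nil
    }
  }

  func makeAsyncIterator() -> AsyncIterator {
    AsyncIterator(lines: bytes.lines.makeAsyncIterator(), decoder: decoder)
  }
}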
app/EXO-iOS/EXO-iOS/Services/DiscoveryService.swift (new file, 118 lines)
@@ -0,0 +1,118 @@
import Foundation
import Network
import os

struct DiscoveredCluster: Identifiable, Equatable {
  let id: String
  let name: String
  let endpoint: NWEndpoint

  static func == (lhs: DiscoveredCluster, rhs: DiscoveredCluster) -> Bool {
    lhs.id == rhs.id && lhs.name == rhs.name
  }
}

@Observable
@MainActor
final class DiscoveryService {
  private(set) var discoveredClusters: [DiscoveredCluster] = []
  private(set) var isSearching = false

  private var browser: NWBrowser?

  func startBrowsing() {
    guard browser == nil else { return }

    let browser = NWBrowser(for: .bonjour(type: "_exo._tcp", domain: nil), using: .tcp)

    browser.stateUpdateHandler = { [weak self] state in
      guard let service = self else { return }
      Task { @MainActor in
        switch state {
        case .ready:
          service.isSearching = true
        case .failed, .cancelled:
          service.isSearching = false
        default:
          break
        }
      }
    }

    browser.browseResultsChangedHandler = { [weak self] results, _ in
      guard let service = self else { return }
      Task { @MainActor in
        service.discoveredClusters = results.compactMap { result in
          guard case .service(let name, _, _, _) = result.endpoint else {
            return nil
          }
          return DiscoveredCluster(
            id: name,
            name: name,
            endpoint: result.endpoint
          )
        }
      }
    }

    browser.start(queue: .main)
    self.browser = browser
  }

  func stopBrowsing() {
    browser?.cancel()
    browser = nil
    isSearching = false
    discoveredClusters = []
  }

  /// Resolve a discovered Bonjour endpoint to an IP address and port, then return a ConnectionInfo.
  func resolve(_ cluster: DiscoveredCluster) async -> ConnectionInfo? {
    await withCheckedContinuation { continuation in
      let didResume = OSAllocatedUnfairLock(initialState: false)
      let connection = NWConnection(to: cluster.endpoint, using: .tcp)
      connection.stateUpdateHandler = { state in
        guard didResume.withLock({ guard !$0 else { return false }; $0 = true; return true })
        else { return }
        switch state {
        case .ready:
          if let innerEndpoint = connection.currentPath?.remoteEndpoint,
            case .hostPort(let host, let port) = innerEndpoint
          {
            var hostString: String
            switch host {
            case .ipv4(let addr):
              hostString = "\(addr)"
            case .ipv6(let addr):
              hostString = "\(addr)"
            case .name(let name, _):
              hostString = name
            @unknown default:
              hostString = "\(host)"
            }
            // Strip interface scope suffix (e.g. "%en0")
            if let pct = hostString.firstIndex(of: "%") {
              hostString = String(hostString[..<pct])
            }
            let info = ConnectionInfo(
              host: hostString,
              port: Int(port.rawValue),
              nodeId: nil
            )
            connection.cancel()
            continuation.resume(returning: info)
          } else {
            connection.cancel()
            continuation.resume(returning: nil)
          }
        case .failed, .cancelled:
          continuation.resume(returning: nil)
        default:
          // Not a terminal state — allow future callbacks
          didResume.withLock { $0 = false }
        }
      }
      connection.start(queue: .global(qos: .userInitiated))
    }
  }
}
201
app/EXO-iOS/EXO-iOS/Services/LocalInferenceService.swift
Normal file
201
app/EXO-iOS/EXO-iOS/Services/LocalInferenceService.swift
Normal file
@@ -0,0 +1,201 @@
|
||||
import Foundation
import MLXLLM
import MLXLMCommon

enum LocalModelState: Equatable {
    case notDownloaded
    case downloading(progress: Double)
    case downloaded
    case loading
    case ready
    case generating
    case error(String)
}

@Observable
@MainActor
final class LocalInferenceService {
    private(set) var modelState: LocalModelState = .notDownloaded
    private var modelContainer: ModelContainer?
    private var generationTask: Task<Void, Never>?

    let defaultModelId = "mlx-community/Qwen3-0.6B-4bit"

    private static let modelDownloadedKey = "exo_local_model_downloaded"

    var isReady: Bool {
        modelState == .ready
    }

    var isAvailable: Bool {
        modelState == .ready || modelState == .generating
    }

    init() {
        if UserDefaults.standard.bool(forKey: Self.modelDownloadedKey) {
            modelState = .downloaded
        }
    }

    // MARK: - Model Lifecycle

    func prepareModel() async {
        guard modelState == .notDownloaded || modelState == .downloaded else { return }

        let wasDownloaded = modelState == .downloaded

        if !wasDownloaded {
            modelState = .downloading(progress: 0)
        } else {
            modelState = .loading
        }

        do {
            let container = try await loadModelContainer(
                id: defaultModelId
            ) { [weak self] progress in
                guard let self else { return }
                Task { @MainActor in
                    if case .downloading = self.modelState {
                        self.modelState = .downloading(progress: progress.fractionCompleted)
                    }
                }
            }

            self.modelContainer = container
            UserDefaults.standard.set(true, forKey: Self.modelDownloadedKey)
            modelState = .ready
        } catch {
            modelState = .error(error.localizedDescription)
        }
    }

    func unloadModel() {
        cancelGeneration()
        modelContainer = nil
        modelState = .downloaded
    }

    // MARK: - Generation

    func streamChatCompletion(request: ChatCompletionRequest) -> AsyncThrowingStream<
        ChatCompletionChunk, Error
    > {
        AsyncThrowingStream { continuation in
            let task = Task { [weak self] in
                guard let self else {
                    continuation.finish(throwing: LocalInferenceError.serviceUnavailable)
                    return
                }

                guard let container = self.modelContainer else {
                    continuation.finish(throwing: LocalInferenceError.modelNotLoaded)
                    return
                }

                await MainActor.run {
                    self.modelState = .generating
                }

                defer {
                    Task { @MainActor [weak self] in
                        if self?.modelState == .generating {
                            self?.modelState = .ready
                        }
                    }
                }

                let chunkId = "local-\(UUID().uuidString)"

                do {
                    // Build Chat.Message array from the request
                    var chatMessages: [Chat.Message] = []
                    for msg in request.messages {
                        switch msg.role {
                        case "system":
                            chatMessages.append(.system(msg.content))
                        case "assistant":
                            chatMessages.append(.assistant(msg.content))
                        default:
                            chatMessages.append(.user(msg.content))
                        }
                    }

                    // Use ChatSession for streaming generation
                    let session = ChatSession(
                        container,
                        history: chatMessages,
                        generateParameters: GenerateParameters(
                            maxTokens: request.maxTokens ?? 4096,
                            temperature: Float(request.temperature ?? 0.7)
                        )
                    )

                    // Stream with an empty prompt since history already contains the conversation
                    let stream = session.streamResponse(to: "")
                    for try await text in stream {
                        if Task.isCancelled { break }

                        let chunk = ChatCompletionChunk(
                            id: chunkId,
                            model: request.model,
                            choices: [
                                StreamingChoice(
                                    index: 0,
                                    delta: Delta(role: nil, content: text),
                                    finishReason: nil
                                )
                            ],
                            usage: nil
                        )
                        continuation.yield(chunk)
                    }

                    // Send final chunk with finish reason
                    let finalChunk = ChatCompletionChunk(
                        id: chunkId,
                        model: request.model,
                        choices: [
                            StreamingChoice(
                                index: 0,
                                delta: Delta(role: nil, content: nil),
                                finishReason: "stop"
                            )
                        ],
                        usage: nil
                    )
                    continuation.yield(finalChunk)
                    continuation.finish()
                } catch {
                    continuation.finish(throwing: error)
                }
            }

            self.generationTask = task

            continuation.onTermination = { _ in
                task.cancel()
            }
        }
    }

    func cancelGeneration() {
        generationTask?.cancel()
        generationTask = nil
        if modelState == .generating {
            modelState = .ready
        }
    }
}

enum LocalInferenceError: LocalizedError {
    case serviceUnavailable
    case modelNotLoaded

    var errorDescription: String? {
        switch self {
        case .serviceUnavailable: "Local inference service is unavailable"
        case .modelNotLoaded: "Local model is not loaded"
        }
    }
}
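
As a rough usage sketch of the service above: prepare the model, then drain the chunk stream. The ChatCompletionRequest initializer shape here is an assumption; only the field and type names are taken from the code above.

// Hypothetical usage sketch, not part of the diff.
func runLocalCompletion(service: LocalInferenceService) async {
    await service.prepareModel()  // downloads on first run, then loads into memory
    guard service.isReady else { return }
    // Assumed initializer; field names follow the request type used above.
    let request = ChatCompletionRequest(
        model: service.defaultModelId,
        messages: [.init(role: "user", content: "Hello!")],
        temperature: 0.7,
        maxTokens: 256
    )
    do {
        for try await chunk in service.streamChatCompletion(request: request) {
            if let text = chunk.choices.first?.delta.content {
                print(text, terminator: "")
            }
        }
    } catch {
        print("generation failed: \(error)")
    }
}
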
50
app/EXO-iOS/EXO-iOS/Services/SSEStreamParser.swift
Normal file
@@ -0,0 +1,50 @@
import Foundation

struct SSEStreamParser<T: Decodable>: AsyncSequence {
    typealias Element = T

    let bytes: URLSession.AsyncBytes
    let decoder: JSONDecoder

    init(bytes: URLSession.AsyncBytes, decoder: JSONDecoder = JSONDecoder()) {
        self.bytes = bytes
        self.decoder = decoder
    }

    func makeAsyncIterator() -> AsyncIterator {
        AsyncIterator(lines: bytes.lines, decoder: decoder)
    }

    struct AsyncIterator: AsyncIteratorProtocol {
        var lines: AsyncLineSequence<URLSession.AsyncBytes>.AsyncIterator
        let decoder: JSONDecoder

        init(lines: AsyncLineSequence<URLSession.AsyncBytes>, decoder: JSONDecoder) {
            self.lines = lines.makeAsyncIterator()
            self.decoder = decoder
        }

        mutating func next() async throws -> T? {
            while let line = try await lines.next() {
                let trimmed = line.trimmingCharacters(in: .whitespaces)

                guard trimmed.hasPrefix("data: ") else { continue }

                let payload = String(trimmed.dropFirst(6))

                if payload == "[DONE]" {
                    return nil
                }

                guard let data = payload.data(using: .utf8) else { continue }

                do {
                    return try decoder.decode(T.self, from: data)
                } catch {
                    continue
                }
            }
            return nil
        }
    }
}
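
For context, a sketch of how this parser might be wired to a streaming chat endpoint; the URL, request body, and the ChatCompletionChunk element type are assumptions based on the surrounding code.

// Hypothetical usage sketch, not part of the diff.
func streamChunks(from url: URL, body: Data) async throws {
    var request = URLRequest(url: url)
    request.httpMethod = "POST"
    request.setValue("application/json", forHTTPHeaderField: "Content-Type")
    request.httpBody = body  // an encoded ChatCompletionRequest (assumed)
    let (bytes, _) = try await URLSession.shared.bytes(for: request)
    // Decode each "data: {...}" line until "[DONE]" terminates the sequence.
    for try await chunk in SSEStreamParser<ChatCompletionChunk>(bytes: bytes) {
        print(chunk.choices.first?.delta.content ?? "", terminator: "")
    }
}
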
171
app/EXO-iOS/EXO-iOS/Views/Chat/ChatView.swift
Normal file
@@ -0,0 +1,171 @@
import SwiftUI

struct ChatView: View {
    @Environment(ClusterService.self) private var clusterService
    @Environment(ChatService.self) private var chatService
    @Environment(LocalInferenceService.self) private var localInferenceService
    @State private var inputText = ""
    @State private var showModelSelector = false

    var body: some View {
        VStack(spacing: 0) {
            modelBar

            Divider()

            messageList

            Divider()

            inputBar
        }
        .sheet(isPresented: $showModelSelector) {
            ModelSelectorView(
                models: clusterService.availableModels,
                selectedModelId: chatService.activeConversation?.modelId
            ) { modelId in
                chatService.setModelForActiveConversation(modelId)
            }
        }
    }

    // MARK: - Model Bar

    private var useLocalModel: Bool {
        !clusterService.isConnected && localInferenceService.isAvailable
    }

    private var modelBar: some View {
        Button {
            if !useLocalModel {
                showModelSelector = true
            }
        } label: {
            HStack {
                Image(systemName: useLocalModel ? "iphone" : "cpu")
                    .foregroundStyle(useLocalModel ? .blue : .secondary)

                if useLocalModel {
                    Text(localInferenceService.defaultModelId)
                        .font(.subheadline)
                        .lineLimit(1)
                } else if let modelId = chatService.activeConversation?.modelId {
                    Text(modelId)
                        .font(.subheadline)
                        .lineLimit(1)
                } else {
                    Text("Select Model")
                        .font(.subheadline)
                        .foregroundStyle(.secondary)
                }

                Spacer()

                if useLocalModel {
                    Text("On-Device")
                        .font(.caption2)
                        .foregroundStyle(.blue)
                        .padding(.horizontal, 6)
                        .padding(.vertical, 2)
                        .background(.blue.opacity(0.1))
                        .clipShape(Capsule())
                } else {
                    Image(systemName: "chevron.right")
                        .font(.caption)
                        .foregroundStyle(.tertiary)
                }
            }
            .padding(.horizontal)
            .padding(.vertical, 10)
            .background(Color(.secondarySystemBackground))
        }
        .tint(.primary)
        .disabled(useLocalModel)
    }

    // MARK: - Messages

    private var messageList: some View {
        ScrollViewReader { proxy in
            ScrollView {
                LazyVStack(spacing: 12) {
                    if chatService.activeMessages.isEmpty {
                        emptyState
                    } else {
                        ForEach(chatService.activeMessages) { message in
                            MessageBubbleView(message: message)
                                .id(message.id)
                        }
                    }
                }
                .padding()
            }
            .onChange(of: chatService.activeMessages.last?.content) {
                if let lastId = chatService.activeMessages.last?.id {
                    withAnimation(.easeOut(duration: 0.2)) {
                        proxy.scrollTo(lastId, anchor: .bottom)
                    }
                }
            }
        }
    }

    private var emptyState: some View {
        VStack(spacing: 12) {
            Spacer(minLength: 80)
            Image(systemName: "bubble.left.and.bubble.right")
                .font(.system(size: 48))
                .foregroundStyle(.tertiary)
            Text("Start a conversation")
                .font(.headline)
                .foregroundStyle(.secondary)
            Text("Send a message to begin chatting with the model.")
                .font(.subheadline)
                .foregroundStyle(.tertiary)
                .multilineTextAlignment(.center)
            Spacer(minLength: 80)
        }
        .padding()
    }

    // MARK: - Input

    private var inputBar: some View {
        HStack(alignment: .bottom, spacing: 8) {
            TextField("Message...", text: $inputText, axis: .vertical)
                .lineLimit(1...6)
                .textFieldStyle(.plain)
                .padding(10)
                .background(Color(.tertiarySystemBackground))
                .clipShape(RoundedRectangle(cornerRadius: 20))

            if chatService.isGenerating {
                Button {
                    chatService.cancelGeneration()
                } label: {
                    Image(systemName: "stop.circle.fill")
                        .font(.title2)
                        .foregroundStyle(.red)
                }
            } else {
                Button {
                    let text = inputText
                    inputText = ""
                    chatService.sendMessage(text)
                } label: {
                    Image(systemName: "arrow.up.circle.fill")
                        .font(.title2)
                        .foregroundStyle(canSend ? Color.accentColor : Color.gray)
                }
                .disabled(!canSend)
            }
        }
        .padding(.horizontal)
        .padding(.vertical, 8)
    }

    private var canSend: Bool {
        !inputText.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty
            && (clusterService.isConnected || localInferenceService.isAvailable)
    }
}
27
app/EXO-iOS/EXO-iOS/Views/Chat/MessageBubbleView.swift
Normal file
@@ -0,0 +1,27 @@
import SwiftUI

struct MessageBubbleView: View {
    let message: ChatMessage

    var body: some View {
        HStack {
            if message.role == .user { Spacer(minLength: 48) }

            VStack(alignment: message.role == .user ? .trailing : .leading, spacing: 4) {
                Text(message.content + (message.isStreaming ? " \u{258C}" : ""))
                    .textSelection(.enabled)
                    .padding(.horizontal, 14)
                    .padding(.vertical, 10)
                    .background(bubbleBackground)
                    .foregroundStyle(message.role == .user ? .white : .primary)
                    .clipShape(RoundedRectangle(cornerRadius: 16))
            }

            if message.role == .assistant { Spacer(minLength: 48) }
        }
    }

    private var bubbleBackground: Color {
        message.role == .user ? .accentColor : Color(.secondarySystemBackground)
    }
}
66
app/EXO-iOS/EXO-iOS/Views/Chat/ModelSelectorView.swift
Normal file
@@ -0,0 +1,66 @@
import SwiftUI

struct ModelSelectorView: View {
    let models: [ModelOption]
    let selectedModelId: String?
    let onSelect: (String) -> Void
    @Environment(\.dismiss) private var dismiss

    var body: some View {
        NavigationStack {
            List {
                if models.isEmpty {
                    emptyContent
                } else {
                    modelsList
                }
            }
            .navigationTitle("Select Model")
            .navigationBarTitleDisplayMode(.inline)
            .toolbar {
                ToolbarItem(placement: .cancellationAction) {
                    Button("Cancel") { dismiss() }
                }
            }
        }
    }

    private var emptyContent: some View {
        ContentUnavailableView(
            "No Models Available",
            systemImage: "cpu",
            description: Text("Connect to an EXO cluster to see available models.")
        )
    }

    private var modelsList: some View {
        ForEach(models) { model in
            Button {
                onSelect(model.id)
                dismiss()
            } label: {
                modelRow(model)
            }
            .tint(.primary)
        }
    }

    private func modelRow(_ model: ModelOption) -> some View {
        HStack {
            VStack(alignment: .leading, spacing: 2) {
                Text(model.displayName)
                    .fontWeight(.medium)
                Text(model.id)
                    .font(.caption)
                    .foregroundStyle(.secondary)
            }

            Spacer()

            if model.id == selectedModelId {
                Image(systemName: "checkmark")
                    .foregroundStyle(Color.accentColor)
            }
        }
    }
}
@@ -0,0 +1,62 @@
import SwiftUI

struct ConnectionStatusBadge: View {
    let connectionState: ConnectionState
    var localModelState: LocalModelState = .notDownloaded

    private var isLocalReady: Bool {
        if case .disconnected = connectionState {
            return localModelState == .ready || localModelState == .generating
        }
        return false
    }

    var body: some View {
        HStack(spacing: 6) {
            Circle()
                .fill(dotColor)
                .frame(width: 8, height: 8)

            Text(label)
                .font(.caption)
                .fontWeight(.medium)
        }
        .padding(.horizontal, 10)
        .padding(.vertical, 5)
        .background(backgroundColor)
        .clipShape(Capsule())
    }

    private var dotColor: Color {
        if isLocalReady {
            return .blue
        }
        switch connectionState {
        case .connected: return .green
        case .connecting: return .orange
        case .disconnected: return .gray
        }
    }

    private var label: String {
        if isLocalReady {
            return "Local"
        }
        switch connectionState {
        case .connected: return "Connected"
        case .connecting: return "Connecting..."
        case .disconnected: return "Disconnected"
        }
    }

    private var backgroundColor: Color {
        if isLocalReady {
            return .blue.opacity(0.15)
        }
        switch connectionState {
        case .connected: return .green.opacity(0.15)
        case .connecting: return .orange.opacity(0.15)
        case .disconnected: return .gray.opacity(0.15)
        }
    }
}
117
app/EXO-iOS/EXO-iOS/Views/RootView.swift
Normal file
@@ -0,0 +1,117 @@
import SwiftUI

struct RootView: View {
    @Environment(ClusterService.self) private var clusterService
    @Environment(DiscoveryService.self) private var discoveryService
    @Environment(ChatService.self) private var chatService
    @Environment(LocalInferenceService.self) private var localInferenceService
    @State private var showSettings = false
    @State private var showConversations = false

    var body: some View {
        NavigationStack {
            ChatView()
                .navigationTitle("EXO")
                .navigationBarTitleDisplayMode(.inline)
                .toolbar {
                    ToolbarItem(placement: .topBarLeading) {
                        conversationMenuButton
                    }

                    ToolbarItem(placement: .principal) {
                        ConnectionStatusBadge(
                            connectionState: clusterService.connectionState,
                            localModelState: localInferenceService.modelState
                        )
                    }

                    ToolbarItem(placement: .topBarTrailing) {
                        Button {
                            showSettings = true
                        } label: {
                            Image(systemName: "gear")
                        }
                    }
                }
        }
        .sheet(isPresented: $showSettings) {
            SettingsView()
                .environment(discoveryService)
        }
        .sheet(isPresented: $showConversations) {
            conversationList
        }
    }

    // MARK: - Conversations

    private var conversationMenuButton: some View {
        HStack(spacing: 12) {
            Button {
                showConversations = true
            } label: {
                Image(systemName: "sidebar.left")
            }

            Button {
                chatService.createConversation()
            } label: {
                Image(systemName: "square.and.pencil")
            }
        }
    }

    private var conversationList: some View {
        NavigationStack {
            List {
                if chatService.conversations.isEmpty {
                    Text("No conversations yet")
                        .foregroundStyle(.secondary)
                } else {
                    ForEach(chatService.conversations) { conversation in
                        Button {
                            chatService.setActiveConversation(id: conversation.id)
                            showConversations = false
                        } label: {
                            VStack(alignment: .leading, spacing: 4) {
                                Text(conversation.title)
                                    .fontWeight(
                                        conversation.id == chatService.activeConversationId
                                            ? .semibold : .regular
                                    )
                                    .lineLimit(1)

                                if let modelId = conversation.modelId {
                                    Text(modelId)
                                        .font(.caption)
                                        .foregroundStyle(.secondary)
                                        .lineLimit(1)
                                }
                            }
                        }
                        .tint(.primary)
                    }
                    .onDelete { indexSet in
                        for index in indexSet {
                            chatService.deleteConversation(id: chatService.conversations[index].id)
                        }
                    }
                }
            }
            .navigationTitle("Conversations")
            .navigationBarTitleDisplayMode(.inline)
            .toolbar {
                ToolbarItem(placement: .confirmationAction) {
                    Button("Done") { showConversations = false }
                }
                ToolbarItem(placement: .topBarLeading) {
                    Button {
                        chatService.createConversation()
                    } label: {
                        Image(systemName: "plus")
                    }
                }
            }
        }
    }
}
197
app/EXO-iOS/EXO-iOS/Views/Settings/SettingsView.swift
Normal file
@@ -0,0 +1,197 @@
import SwiftUI

struct SettingsView: View {
    @Environment(ClusterService.self) private var clusterService
    @Environment(DiscoveryService.self) private var discoveryService
    @Environment(LocalInferenceService.self) private var localInferenceService
    @Environment(\.dismiss) private var dismiss
    @State private var host: String = ""
    @State private var port: String = "52415"

    var body: some View {
        NavigationStack {
            Form {
                localModelSection
                nearbyClustersSection
                connectionSection
                statusSection
            }
            .navigationTitle("Settings")
            .navigationBarTitleDisplayMode(.inline)
            .toolbar {
                ToolbarItem(placement: .confirmationAction) {
                    Button("Done") { dismiss() }
                }
            }
        }
    }

    private var localModelSection: some View {
        Section {
            HStack {
                VStack(alignment: .leading, spacing: 4) {
                    Text(localInferenceService.defaultModelId)
                        .font(.subheadline)
                        .fontWeight(.medium)

                    Text(localModelStatusText)
                        .font(.caption)
                        .foregroundStyle(.secondary)
                }

                Spacer()

                localModelActionButton
            }

            if case .downloading(let progress) = localInferenceService.modelState {
                ProgressView(value: progress)
                    .tint(.blue)
            }
        } header: {
            Text("Local Model")
        } footer: {
            Text("When disconnected from a cluster, messages are processed on-device using this model.")
        }
    }

    private var localModelStatusText: String {
        switch localInferenceService.modelState {
        case .notDownloaded: "Not downloaded"
        case .downloading(let progress): "Downloading \(Int(progress * 100))%..."
        case .downloaded: "Downloaded — not loaded"
        case .loading: "Loading into memory..."
        case .ready: "Ready"
        case .generating: "Generating..."
        case .error(let message): "Error: \(message)"
        }
    }

    @ViewBuilder
    private var localModelActionButton: some View {
        switch localInferenceService.modelState {
        case .notDownloaded:
            Button("Download") {
                Task { await localInferenceService.prepareModel() }
            }
            .buttonStyle(.borderedProminent)
            .controlSize(.small)
        case .downloading:
            ProgressView()
                .controlSize(.small)
        case .downloaded:
            Button("Load") {
                Task { await localInferenceService.prepareModel() }
            }
            .buttonStyle(.bordered)
            .controlSize(.small)
        case .loading:
            ProgressView()
                .controlSize(.small)
        case .ready, .generating:
            Button("Unload") {
                localInferenceService.unloadModel()
            }
            .buttonStyle(.bordered)
            .controlSize(.small)
        case .error:
            Button("Retry") {
                Task { await localInferenceService.prepareModel() }
            }
            .buttonStyle(.borderedProminent)
            .controlSize(.small)
            .tint(.red)
        }
    }

    private var nearbyClustersSection: some View {
        Section {
            if discoveryService.discoveredClusters.isEmpty {
                if discoveryService.isSearching {
                    HStack {
                        ProgressView()
                            .padding(.trailing, 8)
                        Text("Searching for clusters...")
                            .foregroundStyle(.secondary)
                    }
                } else {
                    Text("No clusters found")
                        .foregroundStyle(.secondary)
                }
            } else {
                ForEach(discoveryService.discoveredClusters) { cluster in
                    HStack {
                        VStack(alignment: .leading) {
                            Text(cluster.name)
                                .font(.body)
                        }
                        Spacer()
                        Button("Connect") {
                            Task {
                                await clusterService.connectToDiscoveredCluster(
                                    cluster, using: discoveryService
                                )
                                if clusterService.isConnected {
                                    dismiss()
                                }
                            }
                        }
                        .buttonStyle(.borderedProminent)
                        .controlSize(.small)
                    }
                }
            }
        } header: {
            Text("Nearby Clusters")
        }
    }

    private var connectionSection: some View {
        Section("Manual Connection") {
            TextField("IP Address (e.g. 192.168.1.42)", text: $host)
                .keyboardType(.decimalPad)
                .textContentType(.URL)
                .autocorrectionDisabled()

            TextField("Port", text: $port)
                .keyboardType(.numberPad)

            Button(clusterService.isConnected ? "Reconnect" : "Connect") {
                Task {
                    let portNum = Int(port) ?? ConnectionInfo.defaultPort
                    let info = ConnectionInfo(host: host, port: portNum, nodeId: nil)
                    await clusterService.connect(to: info)
                    if clusterService.isConnected {
                        dismiss()
                    }
                }
            }
            .disabled(host.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty)
        }
    }

    private var statusSection: some View {
        Section("Status") {
            if let connection = clusterService.currentConnection {
                LabeledContent("Host", value: connection.host)
                LabeledContent("Port", value: "\(connection.port)")
                if let nodeId = connection.nodeId {
                    LabeledContent("Node ID", value: String(nodeId.prefix(12)) + "...")
                }
                LabeledContent("Models", value: "\(clusterService.availableModels.count)")

                Button("Disconnect", role: .destructive) {
                    clusterService.disconnect()
                }
            } else {
                if let error = clusterService.lastError {
                    Label(error, systemImage: "exclamationmark.triangle")
                        .foregroundStyle(.red)
                } else {
                    Text("Not connected")
                        .foregroundStyle(.secondary)
                }
            }
        }
    }
}
18
app/EXO-iOS/EXO-iOSTests/EXO_iOSTests.swift
Normal file
@@ -0,0 +1,18 @@
//
// EXO_iOSTests.swift
// EXO-iOSTests
//
// Created by Sami Khan on 2026-02-17.
//

import Testing

@testable import EXO_iOS

struct EXO_iOSTests {

    @Test func example() async throws {
        // Write your test here and use APIs like `#expect(...)` to check expected conditions.
    }

}
41
app/EXO-iOS/EXO-iOSUITests/EXO_iOSUITests.swift
Normal file
@@ -0,0 +1,41 @@
//
// EXO_iOSUITests.swift
// EXO-iOSUITests
//
// Created by Sami Khan on 2026-02-17.
//

import XCTest

final class EXO_iOSUITests: XCTestCase {

    override func setUpWithError() throws {
        // Put setup code here. This method is called before the invocation of each test method in the class.

        // In UI tests it is usually best to stop immediately when a failure occurs.
        continueAfterFailure = false

        // In UI tests it’s important to set the initial state - such as interface orientation - required for your tests before they run. The setUp method is a good place to do this.
    }

    override func tearDownWithError() throws {
        // Put teardown code here. This method is called after the invocation of each test method in the class.
    }

    @MainActor
    func testExample() throws {
        // UI tests must launch the application that they test.
        let app = XCUIApplication()
        app.launch()

        // Use XCTAssert and related functions to verify your tests produce the correct results.
    }

    @MainActor
    func testLaunchPerformance() throws {
        // This measures how long it takes to launch your application.
        measure(metrics: [XCTApplicationLaunchMetric()]) {
            XCUIApplication().launch()
        }
    }
}
33
app/EXO-iOS/EXO-iOSUITests/EXO_iOSUITestsLaunchTests.swift
Normal file
@@ -0,0 +1,33 @@
//
// EXO_iOSUITestsLaunchTests.swift
// EXO-iOSUITests
//
// Created by Sami Khan on 2026-02-17.
//

import XCTest

final class EXO_iOSUITestsLaunchTests: XCTestCase {

    override class var runsForEachTargetApplicationUIConfiguration: Bool {
        true
    }

    override func setUpWithError() throws {
        continueAfterFailure = false
    }

    @MainActor
    func testLaunch() throws {
        let app = XCUIApplication()
        app.launch()

        // Insert steps here to perform after app launch but before taking a screenshot,
        // such as logging into a test account or navigating somewhere in the app

        let attachment = XCTAttachment(screenshot: app.screenshot())
        attachment.name = "Launch Screen"
        attachment.lifetime = .keepAlways
        add(attachment)
    }
}
7
bench/bench.toml
Normal file
@@ -0,0 +1,7 @@
# Canary benchmark manifest
#
# Lists the suite files to include. Each file defines benchmarks
# with shared constraints, topology, and default args.
include = [
    "single-m3-ultra.toml",
]
@@ -288,6 +288,151 @@ def resolve_model_short_id(client: ExoClient, model_arg: str) -> tuple[str, str]
    raise ValueError(f"Model not found in /models: {model_arg}")


def run_planning_phase(
    client: ExoClient,
    full_model_id: str,
    preview: dict[str, Any],
    danger_delete: bool,
    timeout: float,
    settle_deadline: float | None,
) -> None:
    """Check disk space and ensure model is downloaded before benchmarking."""
    # Get model size from /models
    models = client.request_json("GET", "/models") or {}
    model_bytes = 0
    for m in models.get("data", []):
        if m.get("hugging_face_id") == full_model_id:
            model_bytes = m.get("storage_size_megabytes", 0) * 1024 * 1024
            break

    if not model_bytes:
        logger.warning(
            f"Could not determine size for {full_model_id}, skipping disk check"
        )
        return

    # Get nodes from preview
    inner = unwrap_instance(preview["instance"])
    node_ids = list(inner["shardAssignments"]["nodeToRunner"].keys())
    runner_to_shard = inner["shardAssignments"]["runnerToShard"]

    state = client.request_json("GET", "/state")
    downloads = state.get("downloads", {})
    node_disk = state.get("nodeDisk", {})

    for node_id in node_ids:
        node_downloads = downloads.get(node_id, [])

        # Check if model already downloaded on this node
        already_downloaded = any(
            "DownloadCompleted" in p
            and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
                "modelId"
            ]
            == full_model_id
            for p in node_downloads
        )
        if already_downloaded:
            continue

        # Wait for disk info if settle_deadline is set
        disk_info = node_disk.get(node_id, {})
        backoff = _SETTLE_INITIAL_BACKOFF_S
        while not disk_info and settle_deadline and time.monotonic() < settle_deadline:
            remaining = settle_deadline - time.monotonic()
            logger.info(
                f"Waiting for disk info on {node_id} ({remaining:.0f}s remaining)..."
            )
            time.sleep(min(backoff, remaining))
            backoff = min(backoff * _SETTLE_BACKOFF_MULTIPLIER, _SETTLE_MAX_BACKOFF_S)
            state = client.request_json("GET", "/state")
            node_disk = state.get("nodeDisk", {})
            disk_info = node_disk.get(node_id, {})

        if not disk_info:
            logger.warning(f"No disk info for {node_id}, skipping space check")
            continue

        avail = disk_info.get("available", {}).get("inBytes", 0)
        if avail >= model_bytes:
            continue

        if not danger_delete:
            raise RuntimeError(
                f"Insufficient disk on {node_id}: need {model_bytes // (1024**3)}GB, "
                f"have {avail // (1024**3)}GB. Use --danger-delete-downloads to free space."
            )

        # Delete from smallest to largest
        completed = [
            (
                unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
                    "modelId"
                ],
                p["DownloadCompleted"]["totalBytes"]["inBytes"],
            )
            for p in node_downloads
            if "DownloadCompleted" in p
        ]
        for del_model, size in sorted(completed, key=lambda x: x[1]):
            logger.info(f"Deleting {del_model} from {node_id} ({size // (1024**2)}MB)")
            client.request_json("DELETE", f"/download/{node_id}/{del_model}")
            avail += size
            if avail >= model_bytes:
                break

        if avail < model_bytes:
            raise RuntimeError(f"Could not free enough space on {node_id}")

    # Start downloads (idempotent)
    for node_id in node_ids:
        runner_id = inner["shardAssignments"]["nodeToRunner"][node_id]
        shard = runner_to_shard[runner_id]
        client.request_json(
            "POST",
            "/download/start",
            body={
                "targetNodeId": node_id,
                "shardMetadata": shard,
            },
        )
        logger.info(f"Started download on {node_id}")

    # Wait for downloads
    start = time.time()
    while time.time() - start < timeout:
        state = client.request_json("GET", "/state")
        downloads = state.get("downloads", {})
        all_done = True
        for node_id in node_ids:
            done = any(
                "DownloadCompleted" in p
                and unwrap_instance(p["DownloadCompleted"]["shardMetadata"])[
                    "modelCard"
                ]["modelId"]
                == full_model_id
                for p in downloads.get(node_id, [])
            )
            failed = [
                p["DownloadFailed"]["errorMessage"]
                for p in downloads.get(node_id, [])
                if "DownloadFailed" in p
                and unwrap_instance(p["DownloadFailed"]["shardMetadata"])["modelCard"][
                    "modelId"
                ]
                == full_model_id
            ]
            if failed:
                raise RuntimeError(f"Download failed on {node_id}: {failed[0]}")
            if not done:
                all_done = False
        if all_done:
            return
        time.sleep(1)

    raise TimeoutError("Downloads did not complete in time")


def placement_filter(instance_meta: str, wanted: str) -> bool:
    s = (instance_meta or "").lower()
    if wanted == "both":
@@ -535,6 +680,11 @@ def main() -> int:
        default=0,
        help="Max seconds to wait for the cluster to produce valid placements (0 = try once).",
    )
    ap.add_argument(
        "--danger-delete-downloads",
        action="store_true",
        help="Delete existing models from smallest to largest to make room for benchmark model.",
    )
    args = ap.parse_args()

    pp_list = parse_int_list(args.pp)
@@ -569,13 +719,16 @@ def main() -> int:
            logger.error("[exo-bench] tokenizer usable but prompt sizing failed")
            raise

    settle_deadline = (
        time.monotonic() + args.settle_timeout if args.settle_timeout > 0 else None
    )

    selected = fetch_and_filter_placements(client, full_model_id, args)

    if not selected and args.settle_timeout > 0:
    if not selected and settle_deadline:
        backoff = _SETTLE_INITIAL_BACKOFF_S
        deadline = time.monotonic() + args.settle_timeout
        while not selected and time.monotonic() < deadline:
            remaining = deadline - time.monotonic()
        while not selected and time.monotonic() < settle_deadline:
            remaining = settle_deadline - time.monotonic()
            logger.warning(
                f"No valid placements yet (cluster may still be settling). "
                f"Retrying in {backoff:.1f}s ({remaining:.0f}s remaining)..."
@@ -607,6 +760,16 @@ def main() -> int:
    if args.dry_run:
        return 0

    logger.info("Planning phase: checking downloads...")
    run_planning_phase(
        client,
        full_model_id,
        selected[0],
        args.danger_delete_downloads,
        args.timeout,
        settle_deadline,
    )

    all_rows: list[dict[str, Any]] = []

    for preview in selected:
189
bench/single-m3-ultra.toml
Normal file
@@ -0,0 +1,189 @@
# Single-node M3 Ultra benchmarks
#
# Shared constraints applied to ALL benchmarks in this file.
constraints = [
    "All(MacOsBuild(=25D125))",
    "Hosts(=1)",
    "All(Chip(m3_ultra))",
    "All(GpuCores(=80))",
]

[topology]
type = "none"

# Default args merged into each benchmark's args (benchmark-level args win).
[defaults]
pp = [512, 2048, 8192, 16384]
tg = 128

[[benchmark]]
model = "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]

[[benchmark]]
model = "mlx-community/gpt-oss-120b-MXFP4-Q8"
extra_constraints = ["All(Memory(>=96GiB))"]

[[benchmark]]
model = "mlx-community/GLM-4.7-Flash-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]

[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-6bit"
extra_constraints = ["All(Memory(>=96GiB))"]

[[benchmark]]
model = "mlx-community/Qwen3-30B-A3B-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]

[[benchmark]]
model = "mlx-community/Qwen3-0.6B-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]

[[benchmark]]
model = "mlx-community/Qwen3-0.6B-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]

[[benchmark]]
model = "mlx-community/Llama-3.2-1B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]

[[benchmark]]
model = "mlx-community/Llama-3.2-3B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]

[[benchmark]]
model = "mlx-community/Llama-3.2-3B-Instruct-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]

[[benchmark]]
model = "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]

[[benchmark]]
model = "mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]

[[benchmark]]
model = "mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"
extra_constraints = ["All(Memory(>=96GiB))"]

[[benchmark]]
model = "mlx-community/gpt-oss-20b-MXFP4-Q8"
extra_constraints = ["All(Memory(>=96GiB))"]

[[benchmark]]
model = "mlx-community/Qwen3-30B-A3B-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]

[[benchmark]]
model = "mlx-community/GLM-4.7-Flash-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]

[[benchmark]]
model = "mlx-community/GLM-4.7-Flash-5bit"
extra_constraints = ["All(Memory(>=96GiB))"]

[[benchmark]]
model = "mlx-community/GLM-4.7-Flash-6bit"
extra_constraints = ["All(Memory(>=96GiB))"]

[[benchmark]]
model = "mlx-community/Llama-3.3-70B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]

[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]

[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-5bit"
extra_constraints = ["All(Memory(>=96GiB))"]

[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]

[[benchmark]]
model = "mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]

[[benchmark]]
model = "mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]

[[benchmark]]
model = "mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"
extra_constraints = ["All(Memory(>=96GiB))"]

[[benchmark]]
model = "mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"
extra_constraints = ["All(Memory(>=96GiB))"]

[[benchmark]]
model = "mlx-community/Llama-3.3-70B-Instruct-8bit"
extra_constraints = ["All(Memory(>=256GiB))"]

[[benchmark]]
model = "mlx-community/llama-3.3-70b-instruct-fp16"
extra_constraints = ["All(Memory(>=256GiB))"]

[[benchmark]]
model = "mlx-community/GLM-4.5-Air-8bit"
extra_constraints = ["All(Memory(>=256GiB))"]

[[benchmark]]
model = "mlx-community/GLM-4.5-Air-bf16"
extra_constraints = ["All(Memory(>=256GiB))"]

[[benchmark]]
model = "mlx-community/GLM-4.7-4bit"
extra_constraints = ["All(Memory(>=256GiB))"]

[[benchmark]]
model = "mlx-community/MiniMax-M2.1-3bit"
extra_constraints = ["All(Memory(>=256GiB))"]

[[benchmark]]
model = "mlx-community/MiniMax-M2.1-8bit"
extra_constraints = ["All(Memory(>=256GiB))"]

[[benchmark]]
model = "mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"
extra_constraints = ["All(Memory(>=256GiB))"]

[[benchmark]]
model = "mlx-community/Qwen3-Coder-Next-bf16"
extra_constraints = ["All(Memory(>=256GiB))"]

[[benchmark]]
model = "mlx-community/Step-3.5-Flash-4bit"
extra_constraints = ["All(Memory(>=256GiB))"]

[[benchmark]]
model = "mlx-community/Step-3.5-Flash-6bit"
extra_constraints = ["All(Memory(>=256GiB))"]

[[benchmark]]
model = "mlx-community/Step-3.5-Flash-8Bit"
extra_constraints = ["All(Memory(>=256GiB))"]

[[benchmark]]
model = "mlx-community/DeepSeek-V3.1-4bit"
extra_constraints = ["All(Memory(>=512GiB))"]

[[benchmark]]
model = "mlx-community/GLM-4.7-6bit"
extra_constraints = ["All(Memory(>=512GiB))"]

[[benchmark]]
model = "mlx-community/GLM-4.7-8bit-gs32"
extra_constraints = ["All(Memory(>=512GiB))"]

[[benchmark]]
model = "mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"
extra_constraints = ["All(Memory(>=512GiB))"]

[[benchmark]]
model = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"
extra_constraints = ["All(Memory(>=512GiB))"]
@@ -1,7 +1,6 @@
<script lang="ts">
  import {
    isLoading,
    stopGeneration,
    sendMessage,
    generateImage,
    editImage,
@@ -266,6 +265,7 @@

  function handleSubmit() {
    if ((!message.trim() && uploadedFiles.length === 0) || loading) return;
    if (isEditOnlyWithoutImage) return;

    const content = message.trim();
    const files = [...uploadedFiles];
@@ -290,7 +290,11 @@
      if (imageFile.preview) {
        editImage(content, imageFile.preview);
      }
    } else if (isImageModel() && content) {
    } else if (
      currentModel &&
      modelSupportsTextToImage(currentModel) &&
      content
    ) {
      // Use image generation for text-to-image models
      generateImage(content);
    } else {
@@ -649,92 +653,86 @@
        style="min-height: 28px; max-height: 150px;"
      ></textarea>

      {#if loading}
        <button
          type="button"
          onclick={() => stopGeneration()}
          class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap bg-exo-medium-gray/70 text-exo-light-gray hover:bg-red-900/50 hover:text-red-400 border border-exo-medium-gray/50 hover:border-red-500/50 cursor-pointer"
          aria-label="Stop generation"
        >
      <button
        type="submit"
        disabled={!canSend || loading || isEditOnlyWithoutImage}
        class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap
        {!canSend || loading || isEditOnlyWithoutImage
          ? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
          : 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
        aria-label={shouldShowEditMode
          ? "Edit image"
          : isImageModel()
            ? "Generate image"
            : "Send message"}
      >
        {#if loading}
          <span class="inline-flex items-center gap-1 sm:gap-2">
            <svg
              class="w-2.5 h-2.5 sm:w-3 sm:h-3"
              viewBox="0 0 24 24"
              fill="currentColor"
            <span
              class="w-2.5 h-2.5 sm:w-3 sm:h-3 border-2 border-current border-t-transparent rounded-full animate-spin"
            ></span>
            <span class="hidden sm:inline"
              >{shouldShowEditMode
                ? "EDITING"
                : isImageModel()
                  ? "GENERATING"
                  : "PROCESSING"}</span
            >
              <rect x="4" y="4" width="16" height="16" rx="2" />
            </svg>
            <span class="hidden sm:inline">STOP</span>
            <span class="sm:hidden">...</span>
          </span>
        </button>
      {:else}
        <button
          type="submit"
          disabled={!canSend || isEditOnlyWithoutImage}
          class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap
          {!canSend || isEditOnlyWithoutImage
            ? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
            : 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
          aria-label={shouldShowEditMode
            ? "Edit image"
            : isImageModel()
              ? "Generate image"
              : "Send message"}
        >
          {#if shouldShowEditMode}
            <span class="inline-flex items-center gap-1.5">
              <svg
                class="w-3.5 h-3.5"
                fill="none"
                viewBox="0 0 24 24"
                stroke="currentColor"
                stroke-width="2"
              >
                <path
                  stroke-linecap="round"
                  stroke-linejoin="round"
                  d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
                />
              </svg>
              <span>EDIT</span>
            </span>
          {:else if isEditOnlyWithoutImage}
            <span class="inline-flex items-center gap-1.5">
              <svg
                class="w-3.5 h-3.5"
                fill="none"
                viewBox="0 0 24 24"
                stroke="currentColor"
                stroke-width="2"
              >
                <path
                  stroke-linecap="round"
                  stroke-linejoin="round"
                  d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
                />
              </svg>
              <span>EDIT</span>
            </span>
          {:else if isImageModel()}
            <span class="inline-flex items-center gap-1.5">
              <svg
                class="w-3.5 h-3.5"
                fill="none"
                viewBox="0 0 24 24"
                stroke="currentColor"
                stroke-width="2"
              >
                <rect x="3" y="3" width="18" height="18" rx="2" ry="2" />
                <circle cx="8.5" cy="8.5" r="1.5" />
                <polyline points="21 15 16 10 5 21" />
              </svg>
              <span>GENERATE</span>
            </span>
          {:else}
            SEND
          {/if}
        </button>
      {/if}
        {:else if shouldShowEditMode}
          <span class="inline-flex items-center gap-1.5">
            <svg
              class="w-3.5 h-3.5"
              fill="none"
              viewBox="0 0 24 24"
              stroke="currentColor"
              stroke-width="2"
            >
              <path
                stroke-linecap="round"
                stroke-linejoin="round"
                d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
              />
            </svg>
            <span>EDIT</span>
          </span>
        {:else if isEditOnlyWithoutImage}
          <span class="inline-flex items-center gap-1.5">
            <svg
              class="w-3.5 h-3.5"
              fill="none"
              viewBox="0 0 24 24"
              stroke="currentColor"
              stroke-width="2"
            >
              <path
                stroke-linecap="round"
                stroke-linejoin="round"
                d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
              />
            </svg>
            <span>EDIT</span>
          </span>
        {:else if isImageModel()}
          <span class="inline-flex items-center gap-1.5">
            <svg
              class="w-3.5 h-3.5"
              fill="none"
              viewBox="0 0 24 24"
              stroke="currentColor"
              stroke-width="2"
            >
              <rect x="3" y="3" width="18" height="18" rx="2" ry="2" />
              <circle cx="8.5" cy="8.5" r="1.5" />
              <polyline points="21 15 16 10 5 21" />
            </svg>
            <span>GENERATE</span>
          </span>
        {:else}
          SEND
        {/if}
      </button>
    </div>

    <!-- Bottom accent line -->
@@ -225,6 +225,7 @@
  }

  function handleDeleteClick(messageId: string) {
    if (loading) return;
    deleteConfirmId = messageId;
  }

@@ -255,7 +256,7 @@
</script>

<div class="flex flex-col gap-4 sm:gap-6 {className}">
  {#each messageList as message (message.id)}
  {#each messageList as message, i (message.id)}
    <div
      class="group flex {message.role === 'user'
        ? 'justify-end'
@@ -317,9 +318,11 @@
              <!-- Delete confirmation -->
              <div class="bg-red-500/10 border border-red-500/30 rounded-lg p-3">
                <p class="text-xs text-red-400 mb-3">
                  Delete this message{message.role === "user"
                    ? " and all responses after it"
                    : ""}?
                  {#if i === messageList.length - 1}
                    Delete this message?
                  {:else}
                    Delete this message and all messages after it?
                  {/if}
                </p>
                <div class="flex gap-2 justify-end">
                  <button
@@ -751,8 +754,13 @@
            <!-- Delete button -->
            <button
              onclick={() => handleDeleteClick(message.id)}
              class="p-1.5 text-exo-light-gray hover:text-red-400 transition-colors rounded hover:bg-red-500/10 cursor-pointer"
              title="Delete message"
              disabled={loading}
              class="p-1.5 transition-colors rounded {loading
                ? 'text-exo-light-gray/30 cursor-not-allowed'
                : 'text-exo-light-gray hover:text-red-400 hover:bg-red-500/10 cursor-pointer'}"
              title={loading
                ? "Cannot delete while generating"
                : "Delete message"}
            >
              <svg
                class="w-3.5 h-3.5"

@@ -514,7 +514,6 @@ class AppStore {
  messages = $state<Message[]>([]);
  currentResponse = $state("");
  isLoading = $state(false);
  private currentAbortController: AbortController | null = null;

  // Performance metrics
  ttftMs = $state<number | null>(null); // Time to first token in ms
@@ -1815,11 +1814,9 @@ class AppStore {
      return;
    }

    this.currentAbortController = new AbortController();
    const response = await fetch("/v1/chat/completions", {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      signal: this.currentAbortController.signal,
      body: JSON.stringify({
        model: modelToUse,
        messages: apiMessages,
@@ -1933,7 +1930,6 @@ class AppStore {
        "Unknown error",
      );
    } finally {
      this.currentAbortController = null;
      this.isLoading = false;
      this.currentResponse = "";
      this.saveConversationsToStorage();
@@ -2070,10 +2066,6 @@ class AppStore {
    assistantMessageId: string,
    errorPrefix = "Failed to get response",
  ): void {
    // Don't show error for user-initiated abort (stop button)
    if (error instanceof DOMException && error.name === "AbortError") {
      return;
    }
    if (this.conversationExists(targetConversationId)) {
      this.updateConversationMessage(
        targetConversationId,
@@ -2115,17 +2107,6 @@ class AppStore {
    return null;
  }

  /**
   * Stop the current generation by aborting the HTTP connection.
   * This triggers backend cancellation via the mechanism in PR #1276.
   */
  stopGeneration() {
    if (this.currentAbortController) {
      this.currentAbortController.abort();
      this.currentAbortController = null;
    }
  }

  /**
   * Send a message to the LLM and stream the response
   */
@@ -2274,13 +2255,11 @@ class AppStore {
    let firstTokenTime: number | null = null;
    let tokenCount = 0;

    this.currentAbortController = new AbortController();
    const response = await fetch("/v1/chat/completions", {
      method: "POST",
      headers: {
        "Content-Type": "application/json",
      },
      signal: this.currentAbortController.signal,
      body: JSON.stringify({
        model: modelToUse,
        messages: apiMessages,
@@ -2431,7 +2410,6 @@ class AppStore {
        "Failed to get response",
      );
    } finally {
      this.currentAbortController = null;
      this.isLoading = false;
      this.currentResponse = "";
      this.saveConversationsToStorage();
@@ -2536,13 +2514,11 @@ class AppStore {
      };
    }

    this.currentAbortController = new AbortController();
    const response = await fetch("/v1/images/generations", {
      method: "POST",
      headers: {
        "Content-Type": "application/json",
      },
      signal: this.currentAbortController.signal,
      body: JSON.stringify(requestBody),
    });

@@ -2691,7 +2667,6 @@ class AppStore {
        "Failed to generate image",
      );
    } finally {
      this.currentAbortController = null;
      this.isLoading = false;
      this.saveConversationsToStorage();
    }
@@ -2813,10 +2788,8 @@ class AppStore {
      );
    }

    this.currentAbortController = new AbortController();
    const apiResponse = await fetch("/v1/images/edits", {
      method: "POST",
      signal: this.currentAbortController.signal,
      body: formData,
    });

@@ -2926,7 +2899,6 @@ class AppStore {
        "Failed to edit image",
      );
    } finally {
      this.currentAbortController = null;
      this.isLoading = false;
      this.saveConversationsToStorage();
    }

@@ -3067,7 +3039,6 @@ export const hasStartedChat = () => appStore.hasStartedChat;
export const messages = () => appStore.messages;
export const currentResponse = () => appStore.currentResponse;
export const isLoading = () => appStore.isLoading;
export const stopGeneration = () => appStore.stopGeneration();
export const ttftMs = () => appStore.ttftMs;
export const tps = () => appStore.tps;
export const totalTokens = () => appStore.totalTokens;
File diff suppressed because it is too large
@@ -14,7 +14,9 @@
|
||||
|
||||
# Override overlay to inject Nix-built components
|
||||
exoOverlay = final: prev: {
|
||||
# Replace workspace exo_pyo3_bindings with Nix-built wheel
|
||||
# Replace workspace exo_pyo3_bindings with Nix-built wheel.
|
||||
# Preserve passthru so mkVirtualEnv can resolve dependency groups.
|
||||
# Copy .pyi stub + py.typed marker so basedpyright can find the types.
|
||||
exo-pyo3-bindings = pkgs.stdenv.mkDerivation {
|
||||
pname = "exo-pyo3-bindings";
|
||||
version = "0.1.0";
|
||||
@@ -22,6 +24,12 @@
|
||||
# Install from pre-built wheel
|
||||
nativeBuildInputs = [ final.pyprojectWheelHook ];
|
||||
dontStrip = true;
|
||||
passthru = prev.exo-pyo3-bindings.passthru or { };
|
||||
postInstall = ''
|
||||
local siteDir=$out/${final.python.sitePackages}/exo_pyo3_bindings
|
||||
cp ${inputs.self}/rust/exo_pyo3_bindings/exo_pyo3_bindings.pyi $siteDir/
|
||||
touch $siteDir/py.typed
|
||||
'';
|
||||
};
|
||||
};
|
||||
|
||||
@@ -29,17 +37,32 @@

      # Overlay to provide build systems and custom packages
      buildSystemsOverlay = final: prev: {
        # Use our pure Nix-built MLX with Metal support
        mlx = self'.packages.mlx;

        # mlx-lm is a git dependency that needs setuptools
        mlx-lm = prev.mlx-lm.overrideAttrs (old: {
          nativeBuildInputs = (old.nativeBuildInputs or [ ]) ++ [
            final.setuptools
          ];
        });
      } // lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin {
        # Use our pure Nix-built MLX with Metal support (macOS only)
        mlx = self'.packages.mlx;
      };

      # Additional overlay for Linux-specific fixes (type checking env).
      # Native wheels have shared lib dependencies we don't need at type-check time.
      linuxOverlay = final: prev:
        let
          ignoreMissing = drv: drv.overrideAttrs { autoPatchelfIgnoreMissingDeps = [ "*" ]; };
          nvidiaPackages = lib.filterAttrs (name: _: lib.hasPrefix "nvidia-" name) prev;
        in
        lib.optionalAttrs pkgs.stdenv.hostPlatform.isLinux (
          (lib.mapAttrs (_: ignoreMissing) nvidiaPackages) // {
            mlx = ignoreMissing prev.mlx;
            torch = ignoreMissing prev.torch;
            triton = ignoreMissing prev.triton;
          }
        );

      pythonSet = (pkgs.callPackage inputs.pyproject-nix.build.packages {
        inherit python;
      }).overrideScope (
@@ -48,6 +71,7 @@
          overlay
          exoOverlay
          buildSystemsOverlay
          linuxOverlay
        ]
      );
      exoVenv = pythonSet.mkVirtualEnv "exo-env" workspace.deps.default;
@@ -118,6 +142,21 @@
        ${pkgs.ruff}/bin/ruff check ${inputs.self}
        touch $out
      '';

      # Hermetic basedpyright type checking
      typecheck = pkgs.runCommand "typecheck"
        {
          nativeBuildInputs = [
            testVenv
            pkgs.basedpyright
          ];
        }
        ''
          cd ${inputs.self}
          export HOME=$TMPDIR
          basedpyright --pythonpath ${testVenv}/bin/python
          touch $out
        '';
    };
  };
}
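Editor's note: the typecheck derivation above replaces the CI typecheck job with a hermetic Nix check. A minimal sketch (not part of the commit) of approximating the same check outside Nix; the venv path is an assumption, while --pythonpath is the same flag the derivation passes to point basedpyright at a specific interpreter:

# Editor's sketch: run basedpyright against a chosen interpreter, as the
# hermetic Nix check does.
import subprocess
import sys

def run_typecheck(venv_python: str) -> int:
    proc = subprocess.run(
        ["basedpyright", "--pythonpath", venv_python],
        check=False,
    )
    return proc.returncode

if __name__ == "__main__":
    sys.exit(run_typecheck(".venv/bin/python"))  # assumed venv location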
@@ -14,6 +14,7 @@ from exo.download.download_utils import (
    map_repo_download_progress_to_download_progress_data,
)
from exo.download.shard_downloader import ShardDownloader
from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.models.model_cards import ModelId
from exo.shared.types.commands import (
    CancelDownload,
@@ -63,6 +64,9 @@ class DownloadCoordinator:
        self.event_sender, self.event_receiver = channel[Event]()
        self.shard_downloader.on_progress(self._download_progress_callback)

    def _model_dir(self, model_id: ModelId) -> str:
        return str(EXO_MODELS_DIR / model_id.normalize())

    async def _download_progress_callback(
        self, callback_shard: ShardMetadata, progress: RepoDownloadProgress
    ) -> None:
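Editor's note: every download-progress event below now carries the resolved on-disk model directory via _model_dir. A minimal sketch (not part of the commit) of what that helper computes; the EXO_MODELS_DIR value and the normalize() rule are assumptions for illustration, the real constant lives in exo.shared.constants:

# Editor's sketch: derive a local model directory from a model id.
from pathlib import Path

EXO_MODELS_DIR = Path.home() / ".exo" / "models"  # assumed default location

def model_dir(model_id: str) -> str:
    normalized = model_id.replace("/", "--")  # assumed normalization scheme
    return str(EXO_MODELS_DIR / normalized)

print(model_dir("mlx-community/Llama-3.2-1B-Instruct-4bit"))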
@@ -74,6 +78,7 @@ class DownloadCoordinator:
                shard_metadata=callback_shard,
                node_id=self.node_id,
                total_bytes=progress.total_bytes,
                model_directory=self._model_dir(model_id),
            )
            self.download_status[model_id] = completed
            await self.event_sender.send(
@@ -93,6 +98,7 @@ class DownloadCoordinator:
                download_progress=map_repo_download_progress_to_download_progress_data(
                    progress
                ),
                model_directory=self._model_dir(model_id),
            )
            self.download_status[model_id] = ongoing
            await self.event_sender.send(
@@ -170,7 +176,11 @@ class DownloadCoordinator:
            return

        # Emit pending status
        progress = DownloadPending(shard_metadata=shard, node_id=self.node_id)
        progress = DownloadPending(
            shard_metadata=shard,
            node_id=self.node_id,
            model_directory=self._model_dir(model_id),
        )
        self.download_status[model_id] = progress
        await self.event_sender.send(NodeDownloadProgress(download_progress=progress))

@@ -184,6 +194,7 @@ class DownloadCoordinator:
                shard_metadata=shard,
                node_id=self.node_id,
                total_bytes=initial_progress.total_bytes,
                model_directory=self._model_dir(model_id),
            )
            self.download_status[model_id] = completed
            await self.event_sender.send(
@@ -206,6 +217,7 @@ class DownloadCoordinator:
                download_progress=map_repo_download_progress_to_download_progress_data(
                    initial_progress
                ),
                model_directory=self._model_dir(model_id),
            )
            self.download_status[model_id] = status
            self.event_sender.send_nowait(NodeDownloadProgress(download_progress=status))
@@ -219,6 +231,7 @@ class DownloadCoordinator:
                shard_metadata=shard,
                node_id=self.node_id,
                error_message=str(e),
                model_directory=self._model_dir(model_id),
            )
            self.download_status[model_id] = failed
            await self.event_sender.send(
@@ -253,6 +266,7 @@ class DownloadCoordinator:
            pending = DownloadPending(
                shard_metadata=current_status.shard_metadata,
                node_id=self.node_id,
                model_directory=self._model_dir(model_id),
            )
            await self.event_sender.send(
                NodeDownloadProgress(download_progress=pending)
@@ -295,11 +309,18 @@ class DownloadCoordinator:
                        node_id=self.node_id,
                        shard_metadata=progress.shard,
                        total_bytes=progress.total_bytes,
                        model_directory=self._model_dir(
                            progress.shard.model_card.model_id
                        ),
                    )
                elif progress.status in ["in_progress", "not_started"]:
                    if progress.downloaded_bytes_this_session.in_bytes == 0:
                        status = DownloadPending(
                            node_id=self.node_id, shard_metadata=progress.shard
                            node_id=self.node_id,
                            shard_metadata=progress.shard,
                            model_directory=self._model_dir(
                                progress.shard.model_card.model_id
                            ),
                        )
                    else:
                        status = DownloadOngoing(
@@ -308,6 +329,9 @@ class DownloadCoordinator:
                            download_progress=map_repo_download_progress_to_download_progress_data(
                                progress
                            ),
                            model_directory=self._model_dir(
                                progress.shard.model_card.model_id
                            ),
                        )
                else:
                    continue
@@ -17,6 +17,7 @@ from exo.shared.types.api import (
    LogprobsContentItem,
    StreamingChoiceResponse,
    ToolCall,
    Usage,
)
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
from exo.shared.types.common import CommandId
@@ -125,6 +126,8 @@ async def generate_chat_stream(
    chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
) -> AsyncGenerator[str, None]:
    """Generate Chat Completions API streaming events from chunks."""
    last_usage: Usage | None = None

    async for chunk in chunk_stream:
        if isinstance(chunk, ErrorChunk):
            error_response = ErrorResponse(
@@ -138,6 +141,8 @@ async def generate_chat_stream(
            yield "data: [DONE]\n\n"
            return

        last_usage = chunk.usage or last_usage

        if isinstance(chunk, ToolCallChunk):
            tool_call_deltas = [
                ToolCall(
@@ -161,12 +166,15 @@ async def generate_chat_stream(
                        finish_reason="tool_calls",
                    )
                ],
                usage=last_usage,
            )
            yield f"data: {tool_response.model_dump_json()}\n\n"
            yield "data: [DONE]\n\n"
            return

        chunk_response = chunk_to_response(chunk, command_id)
        if chunk.finish_reason is not None:
            chunk_response = chunk_response.model_copy(update={"usage": last_usage})
        yield f"data: {chunk_response.model_dump_json()}\n\n"

        if chunk.finish_reason is not None:
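Editor's note: generate_chat_stream emits server-sent-event frames of the form "data: <json>\n\n" terminated by "data: [DONE]". A minimal sketch (not part of the commit) of how a client can consume those frames; the field names in the fake frame are illustrative only:

# Editor's sketch: parse "data: ..." SSE frames until the [DONE] sentinel.
import asyncio
import json
from collections.abc import AsyncIterator

async def parse_sse(lines: AsyncIterator[str]):
    async for line in lines:
        if not line.startswith("data: "):
            continue  # skip blank separators and non-data fields
        payload = line[len("data: "):].strip()
        if payload == "[DONE]":
            return
        yield json.loads(payload)

async def _demo():
    async def fake_stream():
        yield 'data: {"choices": [{"delta": {"content": "hi"}}]}\n'
        yield "data: [DONE]\n"
    async for event in parse_sse(fake_stream()):
        print(event)

asyncio.run(_demo())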
@@ -177,6 +185,8 @@ async def collect_chat_response(
    command_id: CommandId,
    chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
) -> AsyncGenerator[str]:
    # This is an AsyncGenerator[str] rather than returning a ChatCompletionResponse because
    # FastAPI handles the cancellation better but wouldn't auto-serialize for some reason
    """Collect all token chunks and return a single ChatCompletionResponse."""
    text_parts: list[str] = []
    tool_calls: list[ToolCall] = []
@@ -184,6 +194,7 @@ async def collect_chat_response(
    model: str | None = None
    finish_reason: FinishReason | None = None
    error_message: str | None = None
    last_usage: Usage | None = None

    async for chunk in chunk_stream:
        if isinstance(chunk, ErrorChunk):
@@ -193,6 +204,8 @@ async def collect_chat_response(
        if model is None:
            model = chunk.model

        last_usage = chunk.usage or last_usage

        if isinstance(chunk, TokenChunk):
            text_parts.append(chunk.text)
            if chunk.logprob is not None:
@@ -241,5 +254,6 @@ async def collect_chat_response(
                finish_reason=finish_reason,
            )
        ],
        usage=last_usage,
    ).model_dump_json()
    return
@@ -4,7 +4,7 @@ import json
from collections.abc import AsyncGenerator
from typing import Any

from exo.shared.types.api import FinishReason
from exo.shared.types.api import FinishReason, Usage
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
from exo.shared.types.claude_api import (
    ClaudeContentBlock,
@@ -161,12 +161,14 @@ async def collect_claude_response(
    command_id: CommandId,
    model: str,
    chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
) -> ClaudeMessagesResponse:
) -> AsyncGenerator[str]:
    # This is an AsyncGenerator[str] rather than returning a ChatCompletionResponse because
    # FastAPI handles the cancellation better but wouldn't auto-serialize for some reason
    """Collect all token chunks and return a single ClaudeMessagesResponse."""
    text_parts: list[str] = []
    tool_use_blocks: list[ClaudeToolUseBlock] = []
    stop_reason: ClaudeStopReason | None = None
    last_stats = None
    last_usage: Usage | None = None
    error_message: str | None = None

    async for chunk in chunk_stream:
@@ -174,6 +176,8 @@ async def collect_claude_response(
            error_message = chunk.error_message or "Internal server error"
            break

        last_usage = chunk.usage or last_usage

        if isinstance(chunk, ToolCallChunk):
            for tool in chunk.tool_calls:
                tool_use_blocks.append(
@@ -183,12 +187,10 @@ async def collect_claude_response(
                        input=json.loads(tool.arguments),  # pyright: ignore[reportAny]
                    )
                )
            last_stats = chunk.stats or last_stats
            stop_reason = "tool_use"
            continue

        text_parts.append(chunk.text)
        last_stats = chunk.stats or last_stats

        if chunk.finish_reason is not None:
            stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)
@@ -208,11 +210,11 @@ async def collect_claude_response(
    if not content:
        content.append(ClaudeTextBlock(text=""))

    # Use actual usage data from stats if available
    input_tokens = last_stats.prompt_tokens if last_stats else 0
    output_tokens = last_stats.generation_tokens if last_stats else 0
    # Use actual usage data if available
    input_tokens = last_usage.prompt_tokens if last_usage else 0
    output_tokens = last_usage.completion_tokens if last_usage else 0

    return ClaudeMessagesResponse(
    yield ClaudeMessagesResponse(
        id=f"msg_{command_id}",
        model=model,
        content=content,
@@ -221,7 +223,8 @@ async def collect_claude_response(
            input_tokens=input_tokens,
            output_tokens=output_tokens,
        ),
    )
    ).model_dump_json()
    return


async def generate_claude_stream(
@@ -249,7 +252,7 @@ async def generate_claude_stream(

    output_tokens = 0
    stop_reason: ClaudeStopReason | None = None
    last_stats = None
    last_usage: Usage | None = None
    next_block_index = 1  # text block is 0, tool blocks start at 1

    async for chunk in chunk_stream:
@@ -257,8 +260,9 @@ async def generate_claude_stream(
            # Close text block and bail
            break

        last_usage = chunk.usage or last_usage

        if isinstance(chunk, ToolCallChunk):
            last_stats = chunk.stats or last_stats
            stop_reason = "tool_use"

            # Emit tool_use content blocks
@@ -290,7 +294,6 @@ async def generate_claude_stream(
            continue

        output_tokens += 1  # Count each chunk as one token
        last_stats = chunk.stats or last_stats

        # content_block_delta
        delta_event = ClaudeContentBlockDeltaEvent(
@@ -302,9 +305,9 @@ async def generate_claude_stream(
        if chunk.finish_reason is not None:
            stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)

    # Use actual token count from stats if available
    if last_stats is not None:
        output_tokens = last_stats.generation_tokens
    # Use actual token count from usage if available
    if last_usage is not None:
        output_tokens = last_usage.completion_tokens

    # content_block_stop for text block
    block_stop = ClaudeContentBlockStopEvent(index=0)
@@ -4,6 +4,7 @@ from collections.abc import AsyncGenerator
from itertools import count
from typing import Any

from exo.shared.types.api import Usage
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
from exo.shared.types.common import CommandId
from exo.shared.types.openai_responses import (
@@ -121,13 +122,15 @@ async def collect_responses_response(
    command_id: CommandId,
    model: str,
    chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
) -> ResponsesResponse:
) -> AsyncGenerator[str]:
    # This is an AsyncGenerator[str] rather than returning a ChatCompletionResponse because
    # FastAPI handles the cancellation better but wouldn't auto-serialize for some reason
    """Collect all token chunks and return a single ResponsesResponse."""
    response_id = f"resp_{command_id}"
    item_id = f"item_{command_id}"
    accumulated_text = ""
    function_call_items: list[ResponseFunctionCallItem] = []
    last_stats = None
    last_usage: Usage | None = None
    error_message: str | None = None

    async for chunk in chunk_stream:
@@ -135,32 +138,32 @@ async def collect_responses_response(
            error_message = chunk.error_message or "Internal server error"
            break

        last_usage = chunk.usage or last_usage

        if isinstance(chunk, ToolCallChunk):
            for tool in chunk.tool_calls:
                function_call_items.append(
                    ResponseFunctionCallItem(
                        id=f"fc_{tool.id}",
                        call_id=f"call_{tool.id}",
                        id=tool.id,
                        call_id=tool.id,
                        name=tool.name,
                        arguments=tool.arguments,
                    )
                )
            last_stats = chunk.stats or last_stats
            continue

        accumulated_text += chunk.text
        last_stats = chunk.stats or last_stats

    if error_message is not None:
        raise ValueError(error_message)

    # Create usage from stats if available
    # Create usage from usage data if available
    usage = None
    if last_stats is not None:
    if last_usage is not None:
        usage = ResponseUsage(
            input_tokens=last_stats.prompt_tokens,
            output_tokens=last_stats.generation_tokens,
            total_tokens=last_stats.prompt_tokens + last_stats.generation_tokens,
            input_tokens=last_usage.prompt_tokens,
            output_tokens=last_usage.completion_tokens,
            total_tokens=last_usage.total_tokens,
        )

    output: list[ResponseItem] = [
@@ -172,14 +175,15 @@ async def collect_responses_response(
    ]
    output.extend(function_call_items)

    return ResponsesResponse(
    yield ResponsesResponse(
        id=response_id,
        model=model,
        status="completed",
        output=output,
        output_text=accumulated_text,
        usage=usage,
    )
    ).model_dump_json()
    return


async def generate_responses_stream(
@@ -235,15 +239,16 @@ async def generate_responses_stream(

    accumulated_text = ""
    function_call_items: list[ResponseFunctionCallItem] = []
    last_stats = None
    last_usage: Usage | None = None
    next_output_index = 1  # message item is at 0

    async for chunk in chunk_stream:
        if isinstance(chunk, ErrorChunk):
            break

        last_usage = chunk.usage or last_usage

        if isinstance(chunk, ToolCallChunk):
            last_stats = chunk.stats or last_stats
            for tool in chunk.tool_calls:
                fc_id = f"fc_{tool.id}"
                call_id = f"call_{tool.id}"
@@ -302,7 +307,6 @@ async def generate_responses_stream(
            continue

        accumulated_text += chunk.text
        last_stats = chunk.stats or last_stats

        # response.output_text.delta
        delta_event = ResponseTextDeltaEvent(
@@ -346,13 +350,13 @@ async def generate_responses_stream(
    )
    yield f"event: response.output_item.done\ndata: {item_done.model_dump_json()}\n\n"

    # Create usage from stats if available
    # Create usage from usage data if available
    usage = None
    if last_stats is not None:
    if last_usage is not None:
        usage = ResponseUsage(
            input_tokens=last_stats.prompt_tokens,
            output_tokens=last_stats.generation_tokens,
            total_tokens=last_stats.prompt_tokens + last_stats.generation_tokens,
            input_tokens=last_usage.prompt_tokens,
            output_tokens=last_usage.completion_tokens,
            total_tokens=last_usage.total_tokens,
        )

    # response.completed
@@ -1232,12 +1232,15 @@ class API:
                    "X-Accel-Buffering": "no",
                },
            )

        return await collect_claude_response(
            command.command_id,
            payload.model,
            self._token_chunk_stream(command.command_id),
        )
        else:
            return StreamingResponse(
                collect_claude_response(
                    command.command_id,
                    payload.model,
                    self._token_chunk_stream(command.command_id),
                ),
                media_type="application/json",
            )

    async def openai_responses(
        self, payload: ResponsesRequest
@@ -1265,11 +1268,15 @@ class API:
                },
            )

        return await collect_responses_response(
            command.command_id,
            payload.model,
            self._token_chunk_stream(command.command_id),
        )
        else:
            return StreamingResponse(
                collect_responses_response(
                    command.command_id,
                    payload.model,
                    self._token_chunk_stream(command.command_id),
                ),
                media_type="application/json",
            )
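Editor's note: the non-streaming branches above now wrap the one-shot collectors in StreamingResponse, matching the "AsyncGenerator[str] for better cancellation" comment in the adapters. A minimal sketch (not part of the commit) of the pattern; the endpoint path and generator body are illustrative:

# Editor's sketch: return a single JSON document through StreamingResponse so
# the framework can tear the generator down on client disconnect.
from collections.abc import AsyncGenerator

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()

async def collect_one_json() -> AsyncGenerator[str, None]:
    # In exo this role is played by collect_claude_response /
    # collect_responses_response: yield one serialized model, then return.
    yield '{"status": "completed"}'

@app.post("/demo")
async def demo() -> StreamingResponse:
    return StreamingResponse(collect_one_json(), media_type="application/json")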
    def _calculate_total_available_memory(self) -> Memory:
        """Calculate total available memory across all nodes in bytes."""
@@ -1360,6 +1367,7 @@ class API:
    async def run(self):
        shutdown_ev = anyio.Event()

        bonjour_cleanup = self._register_bonjour_service()
        try:
            async with create_task_group() as tg:
                self._tg = tg
@@ -1375,10 +1383,38 @@ class API:
                with anyio.CancelScope(shield=True):
                    shutdown_ev.set()
        finally:
            bonjour_cleanup()
            self._event_log.close()
            self.command_sender.close()
            self.global_event_receiver.close()

    def _register_bonjour_service(self) -> Callable[[], None]:
        """Register a Bonjour service via the system mDNSResponder. Returns a cleanup function."""
        import subprocess
        import sys

        if sys.platform != "darwin":
            logger.info("Bonjour service registration is only supported on macOS")
            return lambda: None

        service_name = f"EXO Cluster ({self.node_id[:8]})"
        try:
            proc = subprocess.Popen(
                ["dns-sd", "-R", service_name, "_exo._tcp", "local", str(self.port), f"node_id={self.node_id}"],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            )
            logger.info(f"Registered Bonjour service _exo._tcp on port {self.port} (pid {proc.pid})")

            def cleanup() -> None:
                proc.terminate()
                proc.wait()

            return cleanup
        except Exception as e:
            logger.warning(f"Failed to register Bonjour service: {e}")
            return lambda: None

    async def run_api(self, ev: anyio.Event):
        cfg = Config()
        cfg.bind = [f"0.0.0.0:{self.port}"]
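Editor's note: _register_bonjour_service above advertises the node with `dns-sd -R`. A minimal sketch (not part of the commit) of discovering that advertisement from another machine, using the browse mode of the same macOS tool:

# Editor's sketch: browse for the _exo._tcp service that the API registers.
import subprocess

def browse_exo_services(timeout: float = 5.0) -> list[str]:
    proc = subprocess.Popen(
        ["dns-sd", "-B", "_exo._tcp", "local"],
        stdout=subprocess.PIPE,
        text=True,
    )
    try:
        out, _ = proc.communicate(timeout=timeout)
    except subprocess.TimeoutExpired:
        proc.terminate()  # dns-sd browses forever; stop it after the window
        out, _ = proc.communicate()
    return [line for line in out.splitlines() if "_exo._tcp" in line]

if __name__ == "__main__":
    print("\n".join(browse_exo_services()))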
@@ -4,7 +4,11 @@ import json
from collections.abc import AsyncGenerator
from typing import Any, cast

from exo.master.adapters.claude import collect_claude_response, generate_claude_stream
from exo.master.adapters.claude import (
    ClaudeMessagesResponse,
    collect_claude_response,
    generate_claude_stream,
)
from exo.shared.types.api import ToolCallItem
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
from exo.shared.types.common import CommandId, ModelId
@@ -17,6 +21,18 @@ async def _chunks_to_stream(
    yield chunk


async def _collect_response(
    command_id: CommandId,
    model: str,
    chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
) -> ClaudeMessagesResponse:
    """Helper to consume the async generator and parse the JSON response."""
    parts: list[str] = []
    async for part in collect_claude_response(command_id, model, chunk_stream):
        parts.append(part)
    return ClaudeMessagesResponse.model_validate_json("".join(parts))


MODEL = ModelId("test-model")
COMMAND_ID = CommandId("cmd_test123")

@@ -47,7 +63,7 @@ class TestCollectClaudeResponseToolUse:
                ],
            ),
        ]
        response = await collect_claude_response(
        response = await _collect_response(
            COMMAND_ID, "test-model", _chunks_to_stream(chunks)
        )

@@ -77,7 +93,7 @@ class TestCollectClaudeResponseToolUse:
                ],
            ),
        ]
        response = await collect_claude_response(
        response = await _collect_response(
            COMMAND_ID, "test-model", _chunks_to_stream(chunks)
        )

@@ -102,7 +118,7 @@ class TestCollectClaudeResponseToolUse:
                ],
            ),
        ]
        response = await collect_claude_response(
        response = await _collect_response(
            COMMAND_ID, "test-model", _chunks_to_stream(chunks)
        )

@@ -116,7 +132,7 @@ class TestCollectClaudeResponseToolUse:

    async def test_no_content_produces_empty_text_block(self):
        chunks: list[ErrorChunk | ToolCallChunk | TokenChunk] = []
        response = await collect_claude_response(
        response = await _collect_response(
            COMMAND_ID, "test-model", _chunks_to_stream(chunks)
        )
        assert len(response.content) == 1
@@ -218,11 +218,6 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
        key: value for key, value in state.downloads.items() if key != event.node_id
    }
    # Clean up all granular node mappings
    node_identities = {
        key: value
        for key, value in state.node_identities.items()
        if key != event.node_id
    }
    node_memory = {
        key: value for key, value in state.node_memory.items() if key != event.node_id
    }
@@ -263,7 +258,6 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
        "downloads": downloads,
        "topology": topology,
        "last_seen": last_seen,
        "node_identities": node_identities,
        "node_memory": node_memory,
        "node_disk": node_disk,
        "node_system": node_system,
@@ -3,8 +3,7 @@ from collections.abc import Generator
from typing import Annotated, Any, Literal
from uuid import uuid4

from pydantic import BaseModel, Field, field_validator
from pydantic_core import PydanticUseDefault
from pydantic import BaseModel, Field

from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.common import CommandId, NodeId
@@ -228,13 +227,6 @@ class PlaceInstanceParams(BaseModel):
    instance_meta: InstanceMeta = InstanceMeta.MlxRing
    min_nodes: int = 1

    @field_validator("sharding", "instance_meta", mode="plain")
    @classmethod
    def use_default(cls, v: object):
        if not v or not isinstance(v, (Sharding, InstanceMeta)):
            raise PydanticUseDefault()
        return v


class CreateInstanceParams(BaseModel):
    instance: Instance
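Editor's note: the hunk above removes a PydanticUseDefault validator. A minimal sketch (not part of the commit) of the pattern being deleted, mirroring the removed code: raising PydanticUseDefault from a mode="plain" validator makes pydantic fall back to the field default instead of failing validation. Model and field names here are illustrative:

# Editor's sketch: fall back to the field default on invalid input.
from pydantic import BaseModel, field_validator
from pydantic_core import PydanticUseDefault

class Params(BaseModel):
    min_nodes: int = 1

    @field_validator("min_nodes", mode="plain")
    @classmethod
    def use_default(cls, v: object):
        if not isinstance(v, int):
            raise PydanticUseDefault()
        return v

print(Params(min_nodes=None).min_nodes)  # prints 1: the default is used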
@@ -26,6 +26,7 @@ class DownloadProgressData(CamelCaseModel):
class BaseDownloadProgress(TaggedModel):
    node_id: NodeId
    shard_metadata: ShardMetadata
    model_directory: str = ""


class DownloadPending(BaseDownloadProgress):

@@ -62,6 +62,7 @@ class PartialImageResponse(BaseRunnerResponse):
class ToolCallResponse(BaseRunnerResponse):
    tool_calls: list[ToolCallItem]
    usage: Usage | None
    stats: GenerationStats | None = None


class FinishedResponse(BaseRunnerResponse):
@@ -1,5 +1,7 @@
import sys


def print_startup_banner(port: int) -> None:
    """Print a prominent startup banner with API endpoint information."""
    dashboard_url = f"http://localhost:{port}"
    banner = f"""
╔═══════════════════════════════════════════════════════════════════════╗
@@ -27,4 +29,4 @@ def print_startup_banner(port: int) -> None:

"""

    print(banner)
    print(banner, file=sys.stderr)
@@ -125,7 +125,9 @@ class MpSender[T]:
        self._state.buffer.put(item, block=True)

    async def send_async(self, item: T) -> None:
        await to_thread.run_sync(self.send, item, limiter=CapacityLimiter(1))
        await to_thread.run_sync(
            self.send, item, limiter=CapacityLimiter(1), abandon_on_cancel=True
        )

    def close(self) -> None:
        if not self._state.closed.is_set():
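Editor's note: the change above adds abandon_on_cancel=True to the threaded send. A minimal sketch (not part of the commit) of what that flag changes, assuming anyio >= 4.1 where abandon_on_cancel is the supported keyword: without it, cancelling the caller still waits for the worker thread to finish its blocking call; with it, the awaiting task is released immediately and the thread runs to completion in the background:

# Editor's sketch: abandon a blocked worker thread on cancellation.
import time

import anyio
from anyio import to_thread

def blocking_put() -> None:
    time.sleep(2)  # stands in for the blocking queue put in MpSender.send

async def main() -> None:
    start = time.monotonic()
    with anyio.move_on_after(0.1):
        await to_thread.run_sync(blocking_put, abandon_on_cancel=True)
    print(f"returned after {time.monotonic() - start:.2f}s")  # ~0.10s, not ~2s

anyio.run(main)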
@@ -306,7 +306,7 @@ def mlx_generate(
    max_stop_len = max((len(s) for s in stop_sequences), default=0)

    mx_barrier(group)
    logger.info("Ready to prefill")
    logger.info("Starting prefill")

    # Prefill cache with all tokens except the last one
    prefill_tps, prefill_tokens, ssm_snapshots_list = prefill(
@@ -393,10 +393,11 @@ def mlx_generate(
            f"Model generated unexpected finish_reason: {out.finish_reason}"
        )

    total_prompt_tokens = len(all_prompt_tokens)
    usage = Usage(
        prompt_tokens=int(out.prompt_tokens),
        prompt_tokens=total_prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=int(out.prompt_tokens) + completion_tokens,
        total_tokens=total_prompt_tokens + completion_tokens,
        prompt_tokens_details=PromptTokensDetails(
            cached_tokens=prefix_hit_length
        ),
@@ -353,7 +353,13 @@ def load_tokenizer_for_model_id(
            return list(hf_tokenizer.model.encode(text, allowed_special="all"))  # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType]

        hf_tokenizer.encode = _patched_encode
        return TokenizerWrapper(hf_tokenizer, eos_token_ids=eos_token_ids)
        return TokenizerWrapper(
            hf_tokenizer,
            eos_token_ids=eos_token_ids,
            tool_call_start="<|tool_calls_section_begin|>",
            tool_call_end="<|tool_calls_section_end|>",
            tool_parser=_parse_kimi_tool_calls,
        )

    tokenizer = load_tokenizer(
        model_path,
@@ -585,3 +591,41 @@ def mx_barrier(group: Group | None):
            mx.array(1.0), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
        )
    )


def _parse_kimi_tool_calls(text: str):
    import regex as re

    # kimi has a fixed function naming scheme, with a json formatted arg
    # functions.multiply:0<|tool_call_argument_begin|>{"a": 2, "b": 3}
    _func_name_regex = re.compile(
        r"^\s*((?:functions\.)?(.+?):\d+)\s*<\|tool_call_argument_begin\|>", re.DOTALL
    )
    _func_arg_regex = re.compile(r"<\|tool_call_argument_begin\|>\s*(.*)\s*", re.DOTALL)
    _tool_call_split_regex = re.compile(
        r"<\|tool_call_begin\|>(.*?)<\|tool_call_end\|>", re.DOTALL
    )

    def _parse_single_tool(text: str) -> dict[str, Any]:
        func_name_match = _func_name_regex.search(text)
        if func_name_match is None:
            raise ValueError("No tool call found.")
        tool_call_id = func_name_match.group(1)  # e.g. "functions.get_weather:0"
        func_name = func_name_match.group(2)  # e.g. "get_weather"

        func_args_match = _func_arg_regex.search(text)
        if func_args_match is None:
            raise ValueError("No tool call arguments found.")
        func_args = func_args_match.group(1)
        try:
            arg_dct = json.loads(func_args)  # pyright: ignore[reportAny]
        except Exception:
            arg_dct = None

        return dict(id=tool_call_id, name=func_name, arguments=arg_dct)

    tool_matches = _tool_call_split_regex.findall(text)
    if tool_matches:
        return [_parse_single_tool(match) for match in tool_matches]  # pyright: ignore[reportAny]
    else:
        return [_parse_single_tool(text)]
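Editor's note: a minimal worked example (not part of the commit) of the token layout that _parse_kimi_tool_calls handles, applying the same regexes from the diff to the sample from its comment; the stdlib re module stands in for the regex package here:

# Editor's sketch: exercise the kimi tool-call regexes on a sample string.
import json
import re

_func_name_regex = re.compile(
    r"^\s*((?:functions\.)?(.+?):\d+)\s*<\|tool_call_argument_begin\|>", re.DOTALL
)
_func_arg_regex = re.compile(r"<\|tool_call_argument_begin\|>\s*(.*)\s*", re.DOTALL)

sample = 'functions.multiply:0<|tool_call_argument_begin|>{"a": 2, "b": 3}'
name_match = _func_name_regex.search(sample)
args_match = _func_arg_regex.search(sample)
assert name_match and args_match
print(name_match.group(1))              # functions.multiply:0  (tool call id)
print(name_match.group(2))              # multiply              (function name)
print(json.loads(args_match.group(1)))  # {'a': 2, 'b': 3}      (arguments)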
@@ -1,11 +1,10 @@
import base64
import json
import math
import resource
import time
from collections.abc import Generator
from functools import cache
from typing import Any, Callable, Literal
from typing import Literal

import mlx.core as mx
from mlx_lm.models.gpt_oss import Model as GptOssModel
@@ -16,7 +15,6 @@ from openai_harmony import (  # pyright: ignore[reportMissingTypeStubs]
    StreamableParser,
    load_harmony_encoding,
)
from pydantic import ValidationError

from exo.shared.constants import EXO_MAX_CHUNK_SIZE, EXO_TRACING_ENABLED
from exo.shared.models.model_cards import ModelId, ModelTask
@@ -93,6 +91,8 @@ from exo.worker.engines.mlx.utils_mlx import (
)
from exo.worker.runner.bootstrap import logger

from .tool_parsers import ToolParser, make_mlx_parser


def _is_primary_output_node(shard_metadata: ShardMetadata) -> bool:
    """Check if this node is the primary output node for image generation.
@@ -138,6 +138,7 @@ def main(
    inference_model: Model | None = None
    image_model: DistributedImageModel | None = None
    tokenizer = None
    tool_parser: ToolParser | None = None
    group = None
    kv_prefix_cache: KVPrefixCache | None = None
    check_for_cancel_every: int | None = None
@@ -203,8 +204,17 @@ def main(
            bound_instance, group, on_timeout=on_model_load_timeout
        )
        logger.info(
            f"model has_tool_calling={tokenizer.has_tool_calling}"
            f"model has_tool_calling={tokenizer.has_tool_calling} using tokens {tokenizer.tool_call_start}, {tokenizer.tool_call_end}"
        )
        if tokenizer.has_tool_calling:
            assert tokenizer.tool_call_start
            assert tokenizer.tool_call_end
            assert tokenizer.tool_parser  # pyright: ignore[reportAny]
            tool_parser = make_mlx_parser(
                tokenizer.tool_call_start,
                tokenizer.tool_call_end,
                tokenizer.tool_parser,  # pyright: ignore[reportAny]
            )
        kv_prefix_cache = KVPrefixCache(group)

    elif (
@@ -233,7 +243,7 @@ def main(
        assert inference_model
        assert tokenizer

        t = time.perf_counter()
        t = time.monotonic()
        toks = warmup_inference(
            model=inference_model,
            tokenizer=tokenizer,
@@ -241,7 +251,7 @@ def main(
        )
        logger.info(f"warmed up by generating {toks} tokens")
        check_for_cancel_every = min(
            math.ceil(toks / (time.perf_counter() - t)), 100
            math.ceil(toks / min(time.monotonic() - t, 0.001)), 100
        )
        if group is not None:
            check_for_cancel_every = int(
@@ -310,31 +320,11 @@ def main(
                mlx_generator, tokenizer
            )

            # Kimi-K2 has tool call sections - we don't care about them
            if "kimi" in shard_metadata.model_card.model_id.lower():
                mlx_generator = filter_kimi_tokens(mlx_generator)
                patch_kimi_tokenizer(tokenizer)

            # GLM models need patched parser (upstream has bug with None regex match)
            elif "glm" in shard_metadata.model_card.model_id.lower():
                patch_glm_tokenizer(tokenizer)

            # GPT-OSS specific parsing to match other model formats.
            elif isinstance(inference_model, GptOssModel):
            if isinstance(inference_model, GptOssModel):
                mlx_generator = parse_gpt_oss(mlx_generator)

            if tokenizer.has_tool_calling and not isinstance(
                inference_model, GptOssModel
            ):
                assert tokenizer.tool_call_start
                assert tokenizer.tool_call_end
                assert tokenizer.tool_parser  # pyright: ignore[reportAny]
                mlx_generator = parse_tool_calls(
                    mlx_generator,
                    tokenizer.tool_call_start,
                    tokenizer.tool_call_end,
                    tokenizer.tool_parser,  # pyright: ignore[reportAny]
                )
            elif tool_parser:
                mlx_generator = parse_tool_calls(mlx_generator, tool_parser)

            completion_tokens = 0
            tokens_since_last_cancel_check = 0
@@ -396,6 +386,7 @@ def main(
                        tool_calls=response.tool_calls,
                        model=shard_metadata.model_card.model_id,
                        usage=response.usage,
                        stats=response.stats,
                    ),
                )
            )
@@ -559,9 +550,15 @@ def main(
            raise ValueError(
                f"Received {task.__class__.__name__} outside of state machine in {current_status=}"
            )
        event_sender.send(
            TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Complete)
        was_cancelled = (task.task_id in cancelled_tasks) or (
            TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
        )
        if not was_cancelled:
            event_sender.send(
                TaskStatusUpdated(
                    task_id=task.task_id, task_status=TaskStatus.Complete
                )
            )
        event_sender.send(
            RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
        )
@@ -580,21 +577,8 @@ def get_gpt_oss_encoding():
    return encoding


def filter_kimi_tokens(
    responses: Generator[GenerationResponse | ToolCallResponse],
) -> Generator[GenerationResponse]:
    for resp in responses:
        assert isinstance(resp, GenerationResponse)
        if (
            resp.text == "<|tool_calls_section_begin|>"
            or resp.text == "<|tool_calls_section_end|>"
        ):
            continue
        yield resp


def parse_gpt_oss(
    responses: Generator[GenerationResponse | ToolCallResponse],
    responses: Generator[GenerationResponse],
) -> Generator[GenerationResponse | ToolCallResponse]:
    encoding = get_gpt_oss_encoding()
    stream = StreamableParser(encoding, role=Role.ASSISTANT)
@@ -651,9 +635,9 @@ def parse_gpt_oss(


def parse_thinking_models(
    responses: Generator[GenerationResponse | ToolCallResponse],
    responses: Generator[GenerationResponse],
    tokenizer: TokenizerWrapper,
) -> Generator[GenerationResponse | ToolCallResponse]:
) -> Generator[GenerationResponse]:
    """
    For models that inject thinking tags in the prompt (like GLM-4.7),
    prepend the thinking tag to the output stream so the frontend
@@ -774,218 +758,55 @@ def _process_image_response(


def parse_tool_calls(
    responses: Generator[GenerationResponse | ToolCallResponse],
    tool_call_start: str,
    tool_call_end: str,
    tool_parser: Callable[[str], dict[str, Any] | list[dict[str, Any]]],
    responses: Generator[GenerationResponse], tool_parser: ToolParser
) -> Generator[GenerationResponse | ToolCallResponse]:
    in_tool_call = False
    tool_call_text_parts: list[str] = []
    for response in responses:
        assert isinstance(response, GenerationResponse)
        # assumption: the tool call start is one token
        if response.text == tool_call_start:
        if response.text.startswith(tool_parser.start_parsing):
            in_tool_call = True
            continue
        # assumption: the tool call end is one token
        if in_tool_call and response.text == tool_call_end:
            try:
                # tool_parser returns an arbitrarily nested python dictionary
                # we actually don't want the python dictionary, we just want to
                # parse the top level { function: ..., arguments: ... } structure
                # as we're just gonna hand it back to the api anyway
                parsed = tool_parser("".join(tool_call_text_parts).strip())
                logger.info(f"parsed {tool_call_text_parts=} into {parsed=}")
                if isinstance(parsed, list):
                    tools = [_validate_single_tool(tool) for tool in parsed]
                else:
                    tools = [_validate_single_tool(parsed)]
                yield ToolCallResponse(tool_calls=tools, usage=response.usage)

            except (
                json.JSONDecodeError,
                ValidationError,
                ValueError,
                AttributeError,
            ) as e:
                # ValueError: our parsers raise this for malformed tool calls
                # AttributeError: upstream parsers (e.g. glm47) may raise this when regex doesn't match
                logger.opt(exception=e).warning("tool call parsing failed")
                # assumption: talking about tool calls, not making a tool call
                response.text = (
                    tool_call_start + "".join(tool_call_text_parts) + tool_call_end
                )
                yield response

            in_tool_call = False
            tool_call_text_parts = []
            continue

        if in_tool_call:
            tool_call_text_parts.append(response.text)
            if response.text.endswith(tool_parser.end_parsing):
                # parse the actual tool calls from the tool call text
                parsed = tool_parser.parse_tool_calls(
                    "".join(tool_call_text_parts).strip()
                )
                logger.info(f"parsed {tool_call_text_parts=} into {parsed=}")
                if parsed is not None:
                    yield ToolCallResponse(
                        tool_calls=parsed, usage=response.usage, stats=response.stats
                    )
                else:
                    logger.warning(
                        f"tool call parsing failed for text {''.join(tool_call_text_parts)}"
                    )
                    response.text = "".join(tool_call_text_parts)
                    yield response

                in_tool_call = False
                tool_call_text_parts = []
                continue

            if response.finish_reason is not None:
                logger.info(
                    "toll call parsing interrupted, yield partial tool call as text"
                    "tool call parsing interrupted, yield partial tool call as text"
                )
                yield GenerationResponse(
                    text=tool_call_start + "".join(tool_call_text_parts),
                    token=0,
                    finish_reason=response.finish_reason,
                    usage=None,
                response = response.model_copy(
                    update={
                        "text": "".join(tool_call_text_parts),
                        "token": 0,
                    }
                )
                yield response

            continue

        # fallthrough
        yield response
def patch_kimi_tokenizer(tokenizer: TokenizerWrapper):
    """
    Version of to-be-upstreamed kimi-k2 tool parser
    """
    import ast
    import json
    from typing import Any

    import regex as re

    # kimi has a fixed function naming scheme, with a json formatted arg
    # functions.multiply:0 <|tool_call_argument_begin|> {"a": 2, "b": 3}
    # Also needs to handle tools like call_0<|tool_call_argument_begin|>{"filePath": "..."}
    _func_name_regex = re.compile(
        r"^\s*(.+)[:](\d+)\s*<\|tool_call_argument_begin\|>", re.DOTALL
    )
    _func_arg_regex = re.compile(r"<\|tool_call_argument_begin\|>\s*(.*)\s*", re.DOTALL)

    # kimi has a tool_calls_section - we're leaving this up to the caller to handle
    tool_call_start = "<|tool_call_begin|>"
    tool_call_end = "<|tool_call_end|>"

    def _deserialize(value: str) -> Any:  # pyright: ignore[reportAny]
        try:
            return json.loads(value)  # pyright: ignore[reportAny]
        except Exception:
            pass

        try:
            return ast.literal_eval(value)  # pyright: ignore[reportAny]
        except Exception:
            pass
        return value

    def parse_tool_call(text: str, tools: Any | None = None):
        func_name_match = _func_name_regex.search(text)
        if func_name_match is None:
            raise ValueError(f"Could not parse function name from tool call: {text!r}")
        original_func_name = func_name_match.group(1)
        tool_id = func_name_match.group(2)
        # strip off the `functions.` prefix, if it exists.
        func_name = original_func_name[original_func_name.find(".") + 1 :]

        func_args_match = _func_arg_regex.search(text)
        if func_args_match is None:
            raise ValueError(f"Could not parse function args from tool call: {text!r}")
        func_args = func_args_match.group(1)
        # the args should be valid json - no need to check against our tools to deserialize
        arg_dct = _deserialize(func_args)  # pyright: ignore[reportAny]

        return dict(
            id=f"{original_func_name}:{tool_id}",
            name=func_name,
            arguments=arg_dct,  # pyright: ignore[reportAny]
        )

    tokenizer._tool_call_start = tool_call_start
    tokenizer._tool_call_end = tool_call_end
    tokenizer._tool_parser = parse_tool_call


def patch_glm_tokenizer(tokenizer: TokenizerWrapper):
    """
    Fixed version of mlx_lm's glm47 tool parser that handles regex match failures.
    """
    import ast
    import json
    from typing import Any

    import regex as re

    _func_name_regex = re.compile(r"^(.*?)<arg_key>", re.DOTALL)
    _func_arg_regex = re.compile(
        r"<arg_key>(.*?)</arg_key>(?:\n|\s)*<arg_value>(.*?)(?:</arg_value>|(?=<arg_key>)|$)",
        re.DOTALL,
    )

    tool_call_start = "<tool_call>"
    tool_call_end = "</tool_call>"

    def _is_string_type(
        tool_name: str,
        arg_name: str,
        tools: list[Any] | None,
    ) -> bool:
        if tools is None:
            return False
        for tool in tools:  # pyright: ignore[reportAny]
            func = tool["function"]  # pyright: ignore[reportAny]
            if func["name"] == tool_name:
                params = func["parameters"]  # pyright: ignore[reportAny]
                if params is None:
                    return False
                props = params.get("properties", {})  # pyright: ignore[reportAny]
                arg_props = props.get(arg_name, {})  # pyright: ignore[reportAny]
                arg_type = arg_props.get("type", None)  # pyright: ignore[reportAny]
                return arg_type == "string"  # pyright: ignore[reportAny]
        return False

    def _deserialize(value: str) -> Any:  # pyright: ignore[reportAny]
        try:
            return json.loads(value)  # pyright: ignore[reportAny]
        except Exception:
            pass
        try:
            return ast.literal_eval(value)  # pyright: ignore[reportAny]
        except Exception:
            pass
        return value

    def parse_tool_call(text: str, tools: list[Any] | None = None):
        func_name_match = _func_name_regex.search(text)
        if func_name_match is None:
            raise ValueError(f"Could not parse function name from tool call: {text!r}")
        func_name = func_name_match.group(1)

        pairs = _func_arg_regex.findall(text)
        arg_dct: dict[str, Any] = {}
        for key, value in pairs:  # pyright: ignore[reportAny]
            arg_key = key.strip()  # pyright: ignore[reportAny]
            arg_val = value.strip()  # pyright: ignore[reportAny]
            if not _is_string_type(func_name, arg_key, tools):  # pyright: ignore[reportAny]
                arg_val = _deserialize(arg_val)  # pyright: ignore[reportAny]
            arg_dct[arg_key] = arg_val
        return dict(name=func_name, arguments=arg_dct)

    tokenizer._tool_call_start = tool_call_start
    tokenizer._tool_call_end = tool_call_end
    tokenizer._tool_parser = parse_tool_call


def _validate_single_tool(obj: dict[str, Any]) -> ToolCallItem:
    if (
        ((name := obj.get("name")) is not None)
        and ((args := obj.get("arguments")) is not None)
        and isinstance(name, str)
    ):
        raw_id: object = obj.get("id")
        extra = {"id": str(raw_id)} if raw_id is not None else {}
        return ToolCallItem(
            **extra,
            name=name,
            arguments=json.dumps(args),
        )
    else:
        raise ValidationError


EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"
EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM"
EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT"

@@ -101,6 +101,7 @@ class RunnerSupervisor:
        self._ev_recv.close()
        self._task_sender.close()
        self._event_sender.close()
        self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
        self._cancel_sender.close()
        self.runner_process.join(1)
        if not self.runner_process.is_alive():
@@ -143,7 +144,11 @@ class RunnerSupervisor:
            logger.info(f"Unable to cancel {task_id} as it has been completed")
            return
        self.cancelled.add(task_id)
        await self._cancel_sender.send_async(task_id)
        with anyio.move_on_after(0.5) as scope:
            await self._cancel_sender.send_async(task_id)
        if scope.cancel_called:
            logger.error("RunnerSupervisor cancel pipe blocked")
            await self._check_runner(TimeoutError("cancel pipe blocked"))

    async def _forward_events(self):
        with self._ev_recv as events:
72
src/exo/worker/runner/tool_parsers.py
Normal file
@@ -0,0 +1,72 @@
import json
from dataclasses import dataclass
from typing import Any, Callable

from exo.shared.types.api import ToolCallItem


@dataclass
class ToolParser:
    start_parsing: str
    end_parsing: str
    parse_tool_calls: Callable[[str], list[ToolCallItem] | None]


def make_mlx_parser(
    tool_call_start: str,
    tool_call_end: str,
    tool_parser: Callable[[str], dict[str, Any] | list[dict[str, Any]]],
) -> ToolParser:
    def parse_tool_calls(text: str) -> list[ToolCallItem] | None:
        try:
            text = text.removeprefix(tool_call_start)
            text = text.removesuffix(tool_call_end)
            parsed = tool_parser(text)
            if isinstance(parsed, list):
                return [ToolCallItem.model_validate(_flatten(p)) for p in parsed]
            else:
                return [ToolCallItem.model_validate(_flatten(parsed))]

        except Exception:
            return None

    return ToolParser(
        start_parsing=tool_call_start,
        end_parsing=tool_call_end,
        parse_tool_calls=parse_tool_calls,
    )


# TODO / example code:
def _parse_json_calls(text: str) -> list[ToolCallItem] | None:
    try:
        text = text.removeprefix("<tool_call>")
        text = text.removesuffix("</tool_call>")
        top_level = {
            k: json.dumps(v) if isinstance(v, (dict, list)) else v
            for k, v in json.loads(text).items()  # pyright: ignore[reportAny]
        }
        return [ToolCallItem.model_validate(top_level)]
    except Exception:
        return None


def _flatten(p: dict[str, Any]) -> dict[str, str]:
    return {
        k: json.dumps(v) if isinstance(v, (dict, list)) else str(v)  # pyright: ignore[reportAny]
        for k, v in p.items()  # pyright: ignore[reportAny]
    }


json_tool_parser = ToolParser(
    start_parsing="<tool_call>",
    end_parsing="</tool_call>",
    parse_tool_calls=_parse_json_calls,
)


def infer_tool_parser(chat_template: str) -> ToolParser | None:
    """Attempt to auto-infer a tool parser from the chat template."""
    if "<tool_call>" in chat_template and "tool_call.name" in chat_template:
        return json_tool_parser
    return None
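Editor's note: a minimal usage sketch (not part of the commit) of the new module above, feeding json_tool_parser one buffered tool-call span. It assumes ToolCallItem is a pydantic model with string name/arguments fields, as the diff's usage suggests:

# Editor's sketch: parse one <tool_call>...</tool_call> span.
from exo.worker.runner.tool_parsers import json_tool_parser

text = '<tool_call>{"name": "get_weather", "arguments": {"city": "SF"}}</tool_call>'
calls = json_tool_parser.parse_tool_calls(text)
# Expected: one ToolCallItem with name="get_weather" and arguments
# re-serialized to the JSON string '{"city": "SF"}'; None on a parse failure.
print(calls)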
@@ -1,4 +1,5 @@
# Check tasks are complete before runner is ever ready.
import unittest.mock
from collections.abc import Iterable
from typing import Callable

@@ -115,7 +116,6 @@ def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer)))
    monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
    monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
    monkeypatch.setattr(mx.distributed, "all_gather", make_nothin(mx.array([1])))
    monkeypatch.setattr(mlx_runner, "mx_any", make_nothin(False))
    # Mock apply_chat_template since we're using a fake tokenizer (integer 1).
    # Returns a prompt without thinking tag so detect_thinking_prompt_suffix returns None.
@@ -178,8 +178,16 @@ def _run(tasks: Iterable[Task]):
    # this is some c++ nonsense
    task_receiver.close = nothin
    task_receiver.join = nothin

    mlx_runner.main(bound_instance, event_sender, task_receiver, cancel_receiver)  # pyright: ignore[reportArgumentType]
    with unittest.mock.patch(
        "exo.worker.runner.runner.mx.distributed.all_gather",
        make_nothin(mx.array([1])),
    ):
        mlx_runner.main(
            bound_instance,
            event_sender,  # pyright: ignore[reportArgumentType]
            task_receiver,
            cancel_receiver,
        )

    return event_sender.events


@@ -5,12 +5,13 @@ from typing import Any

from exo.shared.types.worker.runner_response import GenerationResponse, ToolCallResponse
from exo.worker.runner.runner import parse_tool_calls
from exo.worker.runner.tool_parsers import make_mlx_parser


def _make_responses(
    texts: list[str],
    finish_on_last: bool = True,
) -> Generator[GenerationResponse | ToolCallResponse]:
) -> Generator[GenerationResponse]:
    """Create a sequence of GenerationResponses from text strings."""
    for i, text in enumerate(texts):
        is_last = i == len(texts) - 1
@@ -22,10 +23,13 @@ def _make_responses(
        )


def _dummy_parser(text: str) -> dict[str, Any]:
def _dummier_parser(text: str) -> dict[str, Any]:
    return {"name": "test_fn", "arguments": {"arg": text}}


_dummy_parser = make_mlx_parser("<tool_call>", "</tool_call>", _dummier_parser)


class TestParseToolCalls:
    """Tests for parse_tool_calls generator."""

@@ -35,8 +39,6 @@ class TestParseToolCalls:
        results = list(
            parse_tool_calls(
                _make_responses(texts, finish_on_last=False),
                "<tool_call>",
                "</tool_call>",
                _dummy_parser,
            )
        )
@@ -50,8 +52,6 @@ class TestParseToolCalls:
        results = list(
            parse_tool_calls(
                _make_responses(texts),
                "<tool_call>",
                "</tool_call>",
                _dummy_parser,
            )
        )
@@ -76,9 +76,7 @@ class TestParseToolCalls:
        results = list(
            parse_tool_calls(
                _make_responses(texts, finish_on_last=False),
                "<tool_call>",
                "</tool_call>",
                _failing_parser,
                make_mlx_parser("<tool_call>", "</tool_call>", _failing_parser),
            )
        )