mirror of
https://github.com/exo-explore/exo.git
synced 2026-01-10 23:18:57 -05:00
Compare commits
232 Commits
linux-cpu-
...
ciaran/ima
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9b6a6059b9 | ||
|
|
fff08568ba | ||
|
|
edde03d74c | ||
|
|
d67eb9f2af | ||
|
|
bcf1fbd561 | ||
|
|
9067033f20 | ||
|
|
17ef7d8838 | ||
|
|
23962ab4e2 | ||
|
|
59c5de8256 | ||
|
|
a43a14da1f | ||
|
|
bb46c12878 | ||
|
|
79b5316efe | ||
|
|
40d49c2720 | ||
|
|
f87d06b1f1 | ||
|
|
617c6ffdcb | ||
|
|
acd246f49e | ||
|
|
d5536c1a2b | ||
|
|
2aba3fc9a9 | ||
|
|
c7cb22d546 | ||
|
|
5a429f4ab6 | ||
|
|
42c427a5bb | ||
|
|
4b8976be51 | ||
|
|
85bf4aba1c | ||
|
|
de12873b1a | ||
|
|
a86bb97d65 | ||
|
|
a5c6db7145 | ||
|
|
e74345bb09 | ||
|
|
58d1f159b7 | ||
|
|
8183225714 | ||
|
|
46f957ee5b | ||
|
|
a2f52e04e3 | ||
|
|
859960608b | ||
|
|
80ad016004 | ||
|
|
d8938e6e72 | ||
|
|
bd6a6cc6d3 | ||
|
|
a88d588de4 | ||
|
|
08e8a30fb7 | ||
|
|
d926df8f95 | ||
|
|
4ff550106d | ||
|
|
1d168dfe61 | ||
|
|
f813b9f5e1 | ||
|
|
0f96083d48 | ||
|
|
1a1b394f6d | ||
|
|
6ab5a9d3d4 | ||
|
|
90bf4608df | ||
|
|
f2a0fdf25c | ||
|
|
f574b3f57e | ||
|
|
5cca9d8493 | ||
|
|
3cd421079b | ||
|
|
d9eb4637ee | ||
|
|
19f52e80fd | ||
|
|
3f4162b732 | ||
|
|
cad86ee76e | ||
|
|
d7be6a09b0 | ||
|
|
79603e73ed | ||
|
|
78901cfe23 | ||
|
|
c0ac199ab8 | ||
|
|
b70d6abfa2 | ||
|
|
16bfab9bab | ||
|
|
28ee6f6370 | ||
|
|
6b299bab8f | ||
|
|
a3754a60b6 | ||
|
|
06039f93f5 | ||
|
|
fcfecc9cd8 | ||
|
|
ba798ae4f9 | ||
|
|
9a0e1e93a9 | ||
|
|
196f504c82 | ||
|
|
e3d89b8d63 | ||
|
|
cb8079525c | ||
|
|
cb03c62c4a | ||
|
|
0653668048 | ||
|
|
0054bc4c14 | ||
|
|
b7b682b7bb | ||
|
|
f7a651c1c1 | ||
|
|
98e8d74cea | ||
|
|
27567f8a4e | ||
|
|
28227bb45a | ||
|
|
7683d4a21f | ||
|
|
0a3cb77a29 | ||
|
|
3f5810c1fe | ||
|
|
fc62ae1b9b | ||
|
|
ec5bad4254 | ||
|
|
f9f54be32b | ||
|
|
36daf9183f | ||
|
|
5d38ffc77e | ||
|
|
1b4851765a | ||
|
|
8787eaf3df | ||
|
|
e1e3aa7a5e | ||
|
|
0fe5239273 | ||
|
|
7eddf7404b | ||
|
|
5f3bc30f17 | ||
|
|
90a7e6601d | ||
|
|
ce2691c8d3 | ||
|
|
076d2901e8 | ||
|
|
7a733b584c | ||
|
|
94fee6f2d2 | ||
|
|
ef4fe09424 | ||
|
|
2919bcf21d | ||
|
|
dd84cc9ca2 | ||
|
|
5a74d76d41 | ||
|
|
e115814c74 | ||
|
|
d85432d4f0 | ||
|
|
da823a2b02 | ||
|
|
8576f4252b | ||
|
|
7ca0bc5b55 | ||
|
|
db24f052d7 | ||
|
|
7b8382be10 | ||
|
|
d3685b0eb5 | ||
|
|
93f4bdc5f9 | ||
|
|
8eea0327b8 | ||
|
|
085358e5e0 | ||
|
|
546efe4dd2 | ||
|
|
4ddfb6e254 | ||
|
|
12f20fd94e | ||
|
|
f7ba70d5ae | ||
|
|
4ecad10a66 | ||
|
|
552ae776fe | ||
|
|
6e0a6e8956 | ||
|
|
e8b0a2124c | ||
|
|
129df1ec89 | ||
|
|
a87fe26973 | ||
|
|
a9ea223dc7 | ||
|
|
0af3349f2f | ||
|
|
20e3319a3e | ||
|
|
4c88fac266 | ||
|
|
e1d916f743 | ||
|
|
09c9b2e29f | ||
|
|
b6359a7199 | ||
|
|
b5a043f676 | ||
|
|
55e690fd49 | ||
|
|
9e4ffb11ec | ||
|
|
d665a8d05a | ||
|
|
cac77816be | ||
|
|
25b9c3369e | ||
|
|
c19c5b4080 | ||
|
|
9592f8b6b0 | ||
|
|
7d7c16ebc1 | ||
|
|
450d0ba923 | ||
|
|
ea64062362 | ||
|
|
206b12e912 | ||
|
|
eecc1da596 | ||
|
|
44e68e4498 | ||
|
|
f1548452fa | ||
|
|
97769c82a9 | ||
|
|
26e5b03285 | ||
|
|
8f93a1ff78 | ||
|
|
e07dcc43b9 | ||
|
|
f91d0797fb | ||
|
|
aaeebaf79e | ||
|
|
c3075a003e | ||
|
|
be796e55ac | ||
|
|
6e0c611f37 | ||
|
|
88996eddcb | ||
|
|
fb4fae51fa | ||
|
|
dbefc209f5 | ||
|
|
e6dd95524c | ||
|
|
c2a9e5e53b | ||
|
|
21587898bc | ||
|
|
b6f23d0b01 | ||
|
|
f00ba03f4b | ||
|
|
73e3713296 | ||
|
|
ecca6b4d20 | ||
|
|
8bac08a236 | ||
|
|
e7cca752fd | ||
|
|
540fe8b278 | ||
|
|
2972f4620c | ||
|
|
0ed81d8afa | ||
|
|
66a24d59b9 | ||
|
|
5dcc359dba | ||
|
|
c2a4d61865 | ||
|
|
ba12ee4897 | ||
|
|
bcd69a3b01 | ||
|
|
f5eb5d0338 | ||
|
|
058aff5145 | ||
|
|
5cb0bc6a63 | ||
|
|
c3aab450c6 | ||
|
|
cf27673e20 | ||
|
|
96c165e297 | ||
|
|
2a589177cd | ||
|
|
f782b619b6 | ||
|
|
dc661e4b5e | ||
|
|
8b7d8ef394 | ||
|
|
7dd2b328c8 | ||
|
|
73a165702d | ||
|
|
0c76978b35 | ||
|
|
25188c845e | ||
|
|
df94169aba | ||
|
|
a2d4c0de2a | ||
|
|
2edbc7e026 | ||
|
|
8f6e360d21 | ||
|
|
085b966a5f | ||
|
|
c64a55bfed | ||
|
|
fee716faab | ||
|
|
b88c89ee9c | ||
|
|
9ef7b913e2 | ||
|
|
0daa4b36db | ||
|
|
3c2da43792 | ||
|
|
8c4c53b50a | ||
|
|
b2beb4c9cd | ||
|
|
098a11b262 | ||
|
|
bedb9045a0 | ||
|
|
8e23841b4e | ||
|
|
4420eac10d | ||
|
|
d0772e9e0f | ||
|
|
8d861168f1 | ||
|
|
242648dff4 | ||
|
|
9b06b754cb | ||
|
|
1603984f45 | ||
|
|
f9418843f8 | ||
|
|
877e7196c3 | ||
|
|
db7c4670b9 | ||
|
|
4f6fcd9e93 | ||
|
|
839b67f318 | ||
|
|
47b8e0ce12 | ||
|
|
17f9b583a4 | ||
|
|
844bcc7ce6 | ||
|
|
c1be5184b2 | ||
|
|
1ec550dff1 | ||
|
|
283c0e39e4 | ||
|
|
35be4c55c3 | ||
|
|
31d4cd8409 | ||
|
|
8a6da58404 | ||
|
|
16e2bfd3b3 | ||
|
|
ade3ee7ec5 | ||
|
|
fea42473dd | ||
|
|
ca7adcc2a8 | ||
|
|
9d9e24f969 | ||
|
|
b5d424b658 | ||
|
|
b465134012 | ||
|
|
eabdcab978 | ||
|
|
8e9332d6a7 | ||
|
|
4b65d5f896 |
2
.gitignore
vendored
2
.gitignore
vendored
@@ -7,6 +7,8 @@ digest.txt
|
||||
# nix
|
||||
.direnv/
|
||||
|
||||
# IDEA (PyCharm)
|
||||
.idea
|
||||
|
||||
# xcode / macos
|
||||
*.xcuserstate
|
||||
|
||||
83
README.md
83
README.md
@@ -61,10 +61,10 @@ Devices running exo automatically discover each other, without needing any manua
|
||||
|
||||
There are two ways to run exo:
|
||||
|
||||
### Run from Source (Mac & Linux)
|
||||
### Run from Source (macOS)
|
||||
|
||||
**Prerequisites:**
|
||||
- [brew](https://github.com/Homebrew/brew) (for simple package management on MacOS)
|
||||
- [brew](https://github.com/Homebrew/brew) (for simple package management on macOS)
|
||||
|
||||
```bash
|
||||
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
|
||||
@@ -98,6 +98,62 @@ uv run exo
|
||||
|
||||
This starts the exo dashboard and API at http://localhost:52415/
|
||||
|
||||
### Run from Source (Linux)
|
||||
|
||||
**Prerequisites:**
|
||||
|
||||
- [uv](https://github.com/astral-sh/uv) (for Python dependency management)
|
||||
- [node](https://github.com/nodejs/node) (for building the dashboard) - version 18 or higher
|
||||
- [rust](https://github.com/rust-lang/rustup) (to build Rust bindings, nightly for now)
|
||||
|
||||
**Installation methods:**
|
||||
|
||||
**Option 1: Using system package manager (Ubuntu/Debian example):**
|
||||
```bash
|
||||
# Install Node.js and npm
|
||||
sudo apt update
|
||||
sudo apt install nodejs npm
|
||||
|
||||
# Install uv
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
|
||||
# Install Rust (using rustup)
|
||||
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
|
||||
rustup toolchain install nightly
|
||||
```
|
||||
|
||||
**Option 2: Using Homebrew on Linux (if preferred):**
|
||||
```bash
|
||||
# Install Homebrew on Linux
|
||||
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
|
||||
|
||||
# Install dependencies
|
||||
brew install uv node
|
||||
|
||||
# Install Rust (using rustup)
|
||||
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
|
||||
rustup toolchain install nightly
|
||||
```
|
||||
|
||||
**Note:** The `macmon` package is macOS-only and not required for Linux.
|
||||
|
||||
Clone the repo, build the dashboard, and run exo:
|
||||
|
||||
```bash
|
||||
# Clone exo
|
||||
git clone https://github.com/exo-explore/exo
|
||||
|
||||
# Build dashboard
|
||||
cd exo/dashboard && npm install && npm run build && cd ..
|
||||
|
||||
# Run exo
|
||||
uv run exo
|
||||
```
|
||||
|
||||
This starts the exo dashboard and API at http://localhost:52415/
|
||||
|
||||
**Important note for Linux users:** Currently, exo runs on CPU on Linux. GPU support for Linux platforms is under development. If you'd like to see support for your specific Linux hardware, please [search for existing feature requests](https://github.com/exo-explore/exo/issues) or create a new one.
|
||||
|
||||
### macOS App
|
||||
|
||||
exo ships a macOS app that runs in the background on your Mac.
|
||||
@@ -112,6 +168,29 @@ The app will ask for permission to modify system settings and install a new Netw
|
||||
|
||||
---
|
||||
|
||||
### Enabling RDMA on macOS
|
||||
|
||||
RDMA is a new capability added to macOS 26.2. It works on any Mac with Thunderbolt 5 (M4 Pro Mac Mini, M4 Max Mac Studio, M4 Max MacBook Pro, M3 Ultra Mac Studio).
|
||||
|
||||
Note that on Mac Studio, you cannot use the Thunderbolt 5 port next to the Ethernet port.
|
||||
|
||||
To enable RDMA on macOS, follow these steps:
|
||||
|
||||
1. Shut down your Mac.
|
||||
2. Hold down the power button for 10 seconds until the boot menu appears.
|
||||
3. Select "Options" to enter Recovery mode.
|
||||
4. When the Recovery UI appears, open the Terminal from the Utilities menu.
|
||||
5. In the Terminal, type:
|
||||
```
|
||||
rdma_ctl enable
|
||||
```
|
||||
and press Enter.
|
||||
6. Reboot your Mac.
|
||||
|
||||
After that, RDMA will be enabled in macOS and exo will take care of the rest.
|
||||
|
||||
---
|
||||
|
||||
### Using the API
|
||||
|
||||
If you prefer to interact with exo via the API, here is an example creating an instance of a small model (`mlx-community/Llama-3.2-1B-Instruct-4bit`), sending a chat completions request and deleting the instance.
|
||||
|
||||
@@ -20,6 +20,8 @@ struct ContentView: View {
|
||||
@State private var showDebugInfo = false
|
||||
@State private var bugReportInFlight = false
|
||||
@State private var bugReportMessage: String?
|
||||
@State private var showAdvancedOptions = false
|
||||
@State private var pendingNamespace: String = ""
|
||||
|
||||
var body: some View {
|
||||
VStack(alignment: .leading, spacing: 12) {
|
||||
@@ -49,7 +51,7 @@ struct ContentView: View {
|
||||
|
||||
private var topologySection: some View {
|
||||
Group {
|
||||
if let topology = stateService.latestSnapshot?.topologyViewModel(), !topology.nodes.isEmpty {
|
||||
if let topology = stateService.latestSnapshot?.topologyViewModel(localNodeId: stateService.localNodeId), !topology.nodes.isEmpty {
|
||||
TopologyMiniView(topology: topology)
|
||||
}
|
||||
}
|
||||
@@ -197,6 +199,8 @@ struct ContentView: View {
|
||||
updater.checkForUpdates()
|
||||
}
|
||||
.padding(.bottom, 8)
|
||||
advancedOptionsSection
|
||||
.padding(.bottom, 8)
|
||||
debugSection
|
||||
.padding(.bottom, 8)
|
||||
controlButton(title: "Quit", tint: .secondary) {
|
||||
@@ -327,6 +331,47 @@ struct ContentView: View {
|
||||
}
|
||||
}
|
||||
|
||||
private var advancedOptionsSection: some View {
|
||||
VStack(alignment: .leading, spacing: 6) {
|
||||
HStack {
|
||||
Text("Advanced Options")
|
||||
.font(.caption)
|
||||
.foregroundColor(.secondary)
|
||||
Spacer()
|
||||
collapseButton(isExpanded: $showAdvancedOptions)
|
||||
}
|
||||
.animation(nil, value: showAdvancedOptions)
|
||||
if showAdvancedOptions {
|
||||
VStack(alignment: .leading, spacing: 8) {
|
||||
VStack(alignment: .leading, spacing: 4) {
|
||||
Text("Cluster Namespace")
|
||||
.font(.caption2)
|
||||
.foregroundColor(.secondary)
|
||||
HStack {
|
||||
TextField("optional", text: $pendingNamespace)
|
||||
.textFieldStyle(.roundedBorder)
|
||||
.font(.caption2)
|
||||
.onAppear {
|
||||
pendingNamespace = controller.customNamespace
|
||||
}
|
||||
Button("Save & Restart") {
|
||||
controller.customNamespace = pendingNamespace
|
||||
if controller.status == .running || controller.status == .starting {
|
||||
controller.restart()
|
||||
}
|
||||
}
|
||||
.font(.caption2)
|
||||
.disabled(pendingNamespace == controller.customNamespace)
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
.transition(.opacity)
|
||||
}
|
||||
}
|
||||
.animation(.easeInOut(duration: 0.25), value: showAdvancedOptions)
|
||||
}
|
||||
|
||||
private var debugSection: some View {
|
||||
VStack(alignment: .leading, spacing: 6) {
|
||||
HStack {
|
||||
|
||||
@@ -2,6 +2,8 @@ import AppKit
|
||||
import Combine
|
||||
import Foundation
|
||||
|
||||
private let customNamespaceKey = "EXOCustomNamespace"
|
||||
|
||||
@MainActor
|
||||
final class ExoProcessController: ObservableObject {
|
||||
enum Status: Equatable {
|
||||
@@ -27,6 +29,13 @@ final class ExoProcessController: ObservableObject {
|
||||
@Published private(set) var status: Status = .stopped
|
||||
@Published private(set) var lastError: String?
|
||||
@Published private(set) var launchCountdownSeconds: Int?
|
||||
@Published var customNamespace: String = {
|
||||
return UserDefaults.standard.string(forKey: customNamespaceKey) ?? ""
|
||||
}() {
|
||||
didSet {
|
||||
UserDefaults.standard.set(customNamespace, forKey: customNamespaceKey)
|
||||
}
|
||||
}
|
||||
|
||||
private var process: Process?
|
||||
private var runtimeDirectoryURL: URL?
|
||||
@@ -180,7 +189,7 @@ final class ExoProcessController: ObservableObject {
|
||||
private func makeEnvironment(for runtimeURL: URL) -> [String: String] {
|
||||
var environment = ProcessInfo.processInfo.environment
|
||||
environment["EXO_RUNTIME_DIR"] = runtimeURL.path
|
||||
environment["EXO_LIBP2P_NAMESPACE"] = buildTag()
|
||||
environment["EXO_LIBP2P_NAMESPACE"] = computeNamespace()
|
||||
|
||||
var paths: [String] = []
|
||||
if let existing = environment["PATH"], !existing.isEmpty {
|
||||
@@ -217,6 +226,12 @@ final class ExoProcessController: ObservableObject {
|
||||
}
|
||||
return "dev"
|
||||
}
|
||||
|
||||
private func computeNamespace() -> String {
|
||||
let base = buildTag()
|
||||
let custom = customNamespace.trimmingCharacters(in: .whitespaces)
|
||||
return custom.isEmpty ? base : custom
|
||||
}
|
||||
}
|
||||
|
||||
struct RuntimeError: LocalizedError {
|
||||
|
||||
@@ -82,7 +82,6 @@ struct BugReportService {
|
||||
}
|
||||
|
||||
private func loadCredentials() throws -> AWSConfig {
|
||||
// These credentials are write-only and necessary to receive bug reports from users
|
||||
return AWSConfig(
|
||||
accessKey: "AKIAYEKP5EMXTOBYDGHX",
|
||||
secretKey: "Ep5gIlUZ1o8ssTLQwmyy34yPGfTPEYQ4evE8NdPE",
|
||||
|
||||
@@ -7,6 +7,7 @@ final class ClusterStateService: ObservableObject {
|
||||
@Published private(set) var lastError: String?
|
||||
@Published private(set) var lastActionMessage: String?
|
||||
@Published private(set) var modelOptions: [ModelOption] = []
|
||||
@Published private(set) var localNodeId: String?
|
||||
|
||||
private var timer: Timer?
|
||||
private let decoder: JSONDecoder
|
||||
@@ -29,6 +30,7 @@ final class ClusterStateService: ObservableObject {
|
||||
func startPolling(interval: TimeInterval = 0.5) {
|
||||
stopPolling()
|
||||
Task {
|
||||
await fetchLocalNodeId()
|
||||
await fetchModels()
|
||||
await fetchSnapshot()
|
||||
}
|
||||
@@ -46,9 +48,31 @@ final class ClusterStateService: ObservableObject {
|
||||
latestSnapshot = nil
|
||||
lastError = nil
|
||||
lastActionMessage = nil
|
||||
localNodeId = nil
|
||||
}
|
||||
|
||||
private func fetchLocalNodeId() async {
|
||||
do {
|
||||
let url = baseURL.appendingPathComponent("node_id")
|
||||
var request = URLRequest(url: url)
|
||||
request.cachePolicy = .reloadIgnoringLocalCacheData
|
||||
let (data, response) = try await session.data(for: request)
|
||||
guard let httpResponse = response as? HTTPURLResponse, (200..<300).contains(httpResponse.statusCode) else {
|
||||
return
|
||||
}
|
||||
if let nodeId = try? decoder.decode(String.self, from: data) {
|
||||
localNodeId = nodeId
|
||||
}
|
||||
} catch {
|
||||
// Silently ignore - localNodeId will remain nil and retry on next poll
|
||||
}
|
||||
}
|
||||
|
||||
private func fetchSnapshot() async {
|
||||
// Retry fetching local node ID if not yet set
|
||||
if localNodeId == nil {
|
||||
await fetchLocalNodeId()
|
||||
}
|
||||
do {
|
||||
var request = URLRequest(url: endpoint)
|
||||
request.cachePolicy = .reloadIgnoringLocalCacheData
|
||||
|
||||
@@ -85,7 +85,7 @@ struct TopologyViewModel {
|
||||
}
|
||||
|
||||
extension ClusterState {
|
||||
func topologyViewModel() -> TopologyViewModel? {
|
||||
func topologyViewModel(localNodeId: String?) -> TopologyViewModel? {
|
||||
let topologyNodeIds = Set(topology?.nodes.map(\.nodeId) ?? [])
|
||||
let allNodes = nodeViewModels().filter { topologyNodeIds.isEmpty || topologyNodeIds.contains($0.id) }
|
||||
guard !allNodes.isEmpty else { return nil }
|
||||
@@ -105,6 +105,11 @@ extension ClusterState {
|
||||
orderedNodes = allNodes
|
||||
}
|
||||
|
||||
// Rotate so the local node (from /node_id API) is first
|
||||
if let localId = localNodeId, let index = orderedNodes.firstIndex(where: { $0.id == localId }) {
|
||||
orderedNodes = Array(orderedNodes[index...]) + Array(orderedNodes[..<index])
|
||||
}
|
||||
|
||||
let nodeIds = Set(orderedNodes.map(\.id))
|
||||
let edgesArray: [TopologyEdgeViewModel] = topology?.connections?.compactMap { connection in
|
||||
guard nodeIds.contains(connection.localNodeId), nodeIds.contains(connection.sendBackNodeId) else { return nil }
|
||||
@@ -112,10 +117,7 @@ extension ClusterState {
|
||||
} ?? []
|
||||
let edges = Set(edgesArray)
|
||||
|
||||
let topologyRootId = topology?.nodes.first?.nodeId
|
||||
let currentId = orderedNodes.first(where: { $0.id == topologyRootId })?.id ?? orderedNodes.first?.id
|
||||
|
||||
return TopologyViewModel(nodes: orderedNodes, edges: Array(edges), currentNodeId: currentId)
|
||||
return TopologyViewModel(nodes: orderedNodes, edges: Array(edges), currentNodeId: localNodeId)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
48
dashboard/package-lock.json
generated
48
dashboard/package-lock.json
generated
@@ -9,6 +9,8 @@
|
||||
"version": "1.0.0",
|
||||
"dependencies": {
|
||||
"highlight.js": "^11.11.1",
|
||||
"katex": "^0.16.27",
|
||||
"marked": "^17.0.1",
|
||||
"mode-watcher": "^1.1.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
@@ -861,7 +863,6 @@
|
||||
"integrity": "sha512-oH8tXw7EZnie8FdOWYrF7Yn4IKrqTFHhXvl8YxXxbKwTMcD/5NNCryUSEXRk2ZR4ojnub0P8rNrsVGHXWqIDtA==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@standard-schema/spec": "^1.0.0",
|
||||
"@sveltejs/acorn-typescript": "^1.0.5",
|
||||
@@ -901,7 +902,6 @@
|
||||
"integrity": "sha512-Y1Cs7hhTc+a5E9Va/xwKlAJoariQyHY+5zBgCZg4PFWNYQ1nMN9sjK1zhw1gK69DuqVP++sht/1GZg1aRwmAXQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@sveltejs/vite-plugin-svelte-inspector": "^4.0.1",
|
||||
"debug": "^4.4.1",
|
||||
@@ -1518,7 +1518,6 @@
|
||||
"integrity": "sha512-LCCV0HdSZZZb34qifBsyWlUmok6W7ouER+oQIGBScS8EsZsQbrtFTUrDX4hOl+CS6p7cnNC4td+qrSVGSCTUfQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"undici-types": "~6.21.0"
|
||||
}
|
||||
@@ -1528,7 +1527,6 @@
|
||||
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz",
|
||||
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"acorn": "bin/acorn"
|
||||
},
|
||||
@@ -1941,7 +1939,6 @@
|
||||
"integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==",
|
||||
"dev": true,
|
||||
"license": "ISC",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
}
|
||||
@@ -2254,6 +2251,31 @@
|
||||
"jiti": "lib/jiti-cli.mjs"
|
||||
}
|
||||
},
|
||||
"node_modules/katex": {
|
||||
"version": "0.16.27",
|
||||
"resolved": "https://registry.npmjs.org/katex/-/katex-0.16.27.tgz",
|
||||
"integrity": "sha512-aeQoDkuRWSqQN6nSvVCEFvfXdqo1OQiCmmW1kc9xSdjutPv7BGO7pqY9sQRJpMOGrEdfDgF2TfRXe5eUAD2Waw==",
|
||||
"funding": [
|
||||
"https://opencollective.com/katex",
|
||||
"https://github.com/sponsors/katex"
|
||||
],
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"commander": "^8.3.0"
|
||||
},
|
||||
"bin": {
|
||||
"katex": "cli.js"
|
||||
}
|
||||
},
|
||||
"node_modules/katex/node_modules/commander": {
|
||||
"version": "8.3.0",
|
||||
"resolved": "https://registry.npmjs.org/commander/-/commander-8.3.0.tgz",
|
||||
"integrity": "sha512-OkTL9umf+He2DZkUq8f8J9of7yL6RJKI24dVITBmNfZBmri9zYZQrKkuXiKhyfPSu8tUhnVBB1iKXevvnlR4Ww==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">= 12"
|
||||
}
|
||||
},
|
||||
"node_modules/kleur": {
|
||||
"version": "4.1.5",
|
||||
"resolved": "https://registry.npmjs.org/kleur/-/kleur-4.1.5.tgz",
|
||||
@@ -2540,6 +2562,18 @@
|
||||
"@jridgewell/sourcemap-codec": "^1.5.5"
|
||||
}
|
||||
},
|
||||
"node_modules/marked": {
|
||||
"version": "17.0.1",
|
||||
"resolved": "https://registry.npmjs.org/marked/-/marked-17.0.1.tgz",
|
||||
"integrity": "sha512-boeBdiS0ghpWcSwoNm/jJBwdpFaMnZWRzjA6SkUMYb40SVaN1x7mmfGKp0jvexGcx+7y2La5zRZsYFZI6Qpypg==",
|
||||
"license": "MIT",
|
||||
"bin": {
|
||||
"marked": "bin/marked.js"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">= 20"
|
||||
}
|
||||
},
|
||||
"node_modules/mode-watcher": {
|
||||
"version": "1.1.0",
|
||||
"resolved": "https://registry.npmjs.org/mode-watcher/-/mode-watcher-1.1.0.tgz",
|
||||
@@ -2612,7 +2646,6 @@
|
||||
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
},
|
||||
@@ -2800,7 +2833,6 @@
|
||||
"resolved": "https://registry.npmjs.org/svelte/-/svelte-5.45.3.tgz",
|
||||
"integrity": "sha512-ngKXNhNvwPzF43QqEhDOue7TQTrG09em1sd4HBxVF0Wr2gopAmdEWan+rgbdgK4fhBtSOTJO8bYU4chUG7VXZQ==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@jridgewell/remapping": "^2.3.4",
|
||||
"@jridgewell/sourcemap-codec": "^1.5.0",
|
||||
@@ -2945,7 +2977,6 @@
|
||||
"integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
|
||||
"dev": true,
|
||||
"license": "Apache-2.0",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"tsc": "bin/tsc",
|
||||
"tsserver": "bin/tsserver"
|
||||
@@ -2967,7 +2998,6 @@
|
||||
"integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"esbuild": "^0.25.0",
|
||||
"fdir": "^6.4.4",
|
||||
|
||||
@@ -27,7 +27,8 @@
|
||||
},
|
||||
"dependencies": {
|
||||
"highlight.js": "^11.11.1",
|
||||
"katex": "^0.16.27",
|
||||
"marked": "^17.0.1",
|
||||
"mode-watcher": "^1.1.0"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
<script lang="ts">
|
||||
import { isLoading, sendMessage, selectedChatModel, setSelectedChatModel, instances, ttftMs, tps, totalTokens } from '$lib/stores/app.svelte';
|
||||
import { isLoading, sendMessage, generateImage, selectedChatModel, setSelectedChatModel, instances, ttftMs, tps, totalTokens } from '$lib/stores/app.svelte';
|
||||
import ChatAttachments from './ChatAttachments.svelte';
|
||||
import type { ChatUploadedFile } from '$lib/types/files';
|
||||
import { processUploadedFiles, getAcceptString } from '$lib/types/files';
|
||||
@@ -10,6 +10,7 @@
|
||||
showHelperText?: boolean;
|
||||
autofocus?: boolean;
|
||||
showModelSelector?: boolean;
|
||||
modelTasks?: Record<string, string[]>;
|
||||
}
|
||||
|
||||
let {
|
||||
@@ -17,7 +18,8 @@
|
||||
placeholder = 'Ask anything',
|
||||
showHelperText = false,
|
||||
autofocus = true,
|
||||
showModelSelector = false
|
||||
showModelSelector = false,
|
||||
modelTasks = {}
|
||||
}: Props = $props();
|
||||
|
||||
let message = $state('');
|
||||
@@ -48,13 +50,29 @@
|
||||
// Accept all supported file types
|
||||
const acceptString = getAcceptString(['image', 'text', 'pdf']);
|
||||
|
||||
// Check if a model supports image generation
|
||||
function modelSupportsImageGeneration(modelId: string): boolean {
|
||||
const tasks = modelTasks[modelId] || [];
|
||||
return tasks.includes('TextToImage') || tasks.includes('ImageToImage');
|
||||
}
|
||||
|
||||
// Check if the currently selected model supports image generation
|
||||
const isImageModel = $derived(() => {
|
||||
if (!currentModel) return false;
|
||||
return modelSupportsImageGeneration(currentModel);
|
||||
});
|
||||
|
||||
// Extract available models from running instances
|
||||
const availableModels = $derived(() => {
|
||||
const models: Array<{id: string, label: string}> = [];
|
||||
const models: Array<{id: string, label: string, isImageModel: boolean}> = [];
|
||||
for (const [, instance] of Object.entries(instanceData)) {
|
||||
const modelId = getInstanceModelId(instance);
|
||||
if (modelId && modelId !== 'Unknown' && !models.some(m => m.id === modelId)) {
|
||||
models.push({ id: modelId, label: modelId.split('/').pop() || modelId });
|
||||
models.push({
|
||||
id: modelId,
|
||||
label: modelId.split('/').pop() || modelId,
|
||||
isImageModel: modelSupportsImageGeneration(modelId)
|
||||
});
|
||||
}
|
||||
}
|
||||
return models;
|
||||
@@ -139,6 +157,11 @@
|
||||
}
|
||||
|
||||
function handleKeydown(event: KeyboardEvent) {
|
||||
// Prevent form submission during IME composition (e.g., Chinese, Japanese, Korean input)
|
||||
if (event.isComposing || event.keyCode === 229) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (event.key === 'Enter' && !event.shiftKey) {
|
||||
event.preventDefault();
|
||||
handleSubmit();
|
||||
@@ -155,7 +178,12 @@
|
||||
uploadedFiles = [];
|
||||
resetTextareaHeight();
|
||||
|
||||
sendMessage(content, files);
|
||||
// Use image generation for image models
|
||||
if (isImageModel() && content) {
|
||||
generateImage(content);
|
||||
} else {
|
||||
sendMessage(content, files);
|
||||
}
|
||||
|
||||
// Refocus the textarea after sending
|
||||
setTimeout(() => textareaRef?.focus(), 10);
|
||||
@@ -292,7 +320,14 @@
|
||||
{:else}
|
||||
<span class="w-3"></span>
|
||||
{/if}
|
||||
<span class="truncate">{model.label}</span>
|
||||
{#if model.isImageModel}
|
||||
<svg class="w-3.5 h-3.5 flex-shrink-0 text-exo-yellow" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2" aria-label="Image generation model">
|
||||
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
|
||||
<circle cx="8.5" cy="8.5" r="1.5"/>
|
||||
<polyline points="21 15 16 10 5 21"/>
|
||||
</svg>
|
||||
{/if}
|
||||
<span class="truncate flex-1">{model.label}</span>
|
||||
</button>
|
||||
{/each}
|
||||
</div>
|
||||
@@ -352,7 +387,7 @@
|
||||
onkeydown={handleKeydown}
|
||||
oninput={handleInput}
|
||||
onpaste={handlePaste}
|
||||
{placeholder}
|
||||
placeholder={isImageModel() ? 'Describe the image you want to generate...' : placeholder}
|
||||
disabled={loading}
|
||||
rows={1}
|
||||
class="flex-1 resize-none bg-transparent text-foreground placeholder:text-exo-light-gray/60 placeholder:text-sm placeholder:tracking-[0.15em] placeholder:leading-7 focus:outline-none focus:ring-0 focus:border-none disabled:opacity-50 text-sm leading-7 font-mono"
|
||||
@@ -366,14 +401,23 @@
|
||||
{!canSend || loading
|
||||
? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
|
||||
: 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
|
||||
aria-label="Send message"
|
||||
aria-label={isImageModel() ? "Generate image" : "Send message"}
|
||||
>
|
||||
{#if loading}
|
||||
<span class="inline-flex items-center gap-1 sm:gap-2">
|
||||
<span class="w-2.5 h-2.5 sm:w-3 sm:h-3 border-2 border-current border-t-transparent rounded-full animate-spin"></span>
|
||||
<span class="hidden sm:inline">PROCESSING</span>
|
||||
<span class="hidden sm:inline">{isImageModel() ? 'GENERATING' : 'PROCESSING'}</span>
|
||||
<span class="sm:hidden">...</span>
|
||||
</span>
|
||||
{:else if isImageModel()}
|
||||
<span class="inline-flex items-center gap-1.5">
|
||||
<svg class="w-3.5 h-3.5" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
|
||||
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
|
||||
<circle cx="8.5" cy="8.5" r="1.5"/>
|
||||
<polyline points="21 15 16 10 5 21"/>
|
||||
</svg>
|
||||
<span>GENERATE</span>
|
||||
</span>
|
||||
{:else}
|
||||
SEND
|
||||
{/if}
|
||||
|
||||
@@ -8,89 +8,80 @@
|
||||
regenerateLastResponse
|
||||
} from '$lib/stores/app.svelte';
|
||||
import type { MessageAttachment } from '$lib/stores/app.svelte';
|
||||
import { tick, onDestroy } from 'svelte';
|
||||
import MarkdownContent from './MarkdownContent.svelte';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
scrollParent?: HTMLElement | null;
|
||||
}
|
||||
interface Props {
|
||||
class?: string;
|
||||
scrollParent?: HTMLElement | null;
|
||||
}
|
||||
|
||||
let { class: className = '', scrollParent = null }: Props = $props();
|
||||
let { class: className = '', scrollParent = null }: Props = $props();
|
||||
|
||||
const messageList = $derived(messages());
|
||||
const response = $derived(currentResponse());
|
||||
const loading = $derived(isLoading());
|
||||
|
||||
// Ref for scroll anchor at bottom
|
||||
let scrollAnchorRef: HTMLDivElement | undefined = $state();
|
||||
// Scroll management - user controls scroll, show button when not at bottom
|
||||
const SCROLL_THRESHOLD = 100;
|
||||
let showScrollButton = $state(false);
|
||||
let lastMessageCount = 0;
|
||||
let containerRef: HTMLDivElement | undefined = $state();
|
||||
|
||||
// Scroll management
|
||||
const SCROLL_BOTTOM_THRESHOLD = 120;
|
||||
let autoScrollEnabled = true;
|
||||
let currentScrollEl: HTMLElement | null = null;
|
||||
|
||||
function resolveScrollElement(): HTMLElement | null {
|
||||
if (scrollParent) return scrollParent;
|
||||
let node: HTMLElement | null = scrollAnchorRef?.parentElement as HTMLElement | null;
|
||||
while (node) {
|
||||
const isScrollable = node.scrollHeight > node.clientHeight + 1;
|
||||
if (isScrollable) return node;
|
||||
node = node.parentElement;
|
||||
function getScrollContainer(): HTMLElement | null {
|
||||
if (scrollParent) return scrollParent;
|
||||
return containerRef?.parentElement ?? null;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
function handleScroll() {
|
||||
if (!currentScrollEl) return;
|
||||
const distanceFromBottom = currentScrollEl.scrollHeight - currentScrollEl.scrollTop - currentScrollEl.clientHeight;
|
||||
const isNearBottom = distanceFromBottom < SCROLL_BOTTOM_THRESHOLD;
|
||||
autoScrollEnabled = isNearBottom;
|
||||
}
|
||||
|
||||
function attachScrollListener() {
|
||||
const nextEl = resolveScrollElement();
|
||||
if (currentScrollEl === nextEl) return;
|
||||
if (currentScrollEl) {
|
||||
currentScrollEl.removeEventListener('scroll', handleScroll);
|
||||
function isNearBottom(el: HTMLElement): boolean {
|
||||
return el.scrollHeight - el.scrollTop - el.clientHeight < SCROLL_THRESHOLD;
|
||||
}
|
||||
currentScrollEl = nextEl;
|
||||
if (currentScrollEl) {
|
||||
currentScrollEl.addEventListener('scroll', handleScroll);
|
||||
// Initialize state based on current position
|
||||
handleScroll();
|
||||
}
|
||||
}
|
||||
|
||||
onDestroy(() => {
|
||||
if (currentScrollEl) {
|
||||
currentScrollEl.removeEventListener('scroll', handleScroll);
|
||||
}
|
||||
});
|
||||
|
||||
$effect(() => {
|
||||
// Re-evaluate scroll container if prop changes or after mount
|
||||
scrollParent;
|
||||
attachScrollListener();
|
||||
});
|
||||
|
||||
// Auto-scroll to bottom when messages change or response updates, but only if user is near bottom
|
||||
$effect(() => {
|
||||
// Track these values to trigger effect
|
||||
const _ = messageList.length;
|
||||
const __ = response;
|
||||
const ___ = loading;
|
||||
|
||||
tick().then(() => {
|
||||
const el = currentScrollEl ?? resolveScrollElement();
|
||||
if (!el || !scrollAnchorRef) return;
|
||||
const distanceFromBottom = el.scrollHeight - el.scrollTop - el.clientHeight;
|
||||
const isNearBottom = distanceFromBottom < SCROLL_BOTTOM_THRESHOLD;
|
||||
if (autoScrollEnabled || isNearBottom) {
|
||||
scrollAnchorRef.scrollIntoView({ behavior: 'smooth', block: 'end' });
|
||||
autoScrollEnabled = true;
|
||||
function scrollToBottom() {
|
||||
const el = getScrollContainer();
|
||||
if (el) {
|
||||
el.scrollTo({ top: el.scrollHeight, behavior: 'smooth' });
|
||||
}
|
||||
}
|
||||
|
||||
function updateScrollButtonVisibility() {
|
||||
const el = getScrollContainer();
|
||||
if (!el) return;
|
||||
showScrollButton = !isNearBottom(el);
|
||||
}
|
||||
|
||||
// Attach scroll listener
|
||||
$effect(() => {
|
||||
const el = scrollParent ?? containerRef?.parentElement;
|
||||
if (!el) return;
|
||||
|
||||
el.addEventListener('scroll', updateScrollButtonVisibility, { passive: true });
|
||||
// Initial check
|
||||
updateScrollButtonVisibility();
|
||||
return () => el.removeEventListener('scroll', updateScrollButtonVisibility);
|
||||
});
|
||||
|
||||
// Auto-scroll when user sends a new message
|
||||
$effect(() => {
|
||||
const count = messageList.length;
|
||||
if (count > lastMessageCount) {
|
||||
const el = getScrollContainer();
|
||||
if (el) {
|
||||
requestAnimationFrame(() => {
|
||||
el.scrollTo({ top: el.scrollHeight, behavior: 'smooth' });
|
||||
});
|
||||
}
|
||||
}
|
||||
lastMessageCount = count;
|
||||
});
|
||||
|
||||
// Update scroll button visibility when content changes
|
||||
$effect(() => {
|
||||
// Track response to trigger re-check during streaming
|
||||
const _ = response;
|
||||
|
||||
// Small delay to let DOM update
|
||||
requestAnimationFrame(() => updateScrollButtonVisibility());
|
||||
});
|
||||
});
|
||||
|
||||
// Edit state
|
||||
let editingMessageId = $state<string | null>(null);
|
||||
@@ -231,7 +222,7 @@ function isThinkingExpanded(messageId: string): boolean {
|
||||
<div class="flex flex-col gap-4 sm:gap-6 {className}">
|
||||
{#each messageList as message (message.id)}
|
||||
<div class="group flex {message.role === 'user' ? 'justify-end' : 'justify-start'}">
|
||||
<div class="{message.role === 'user' ? 'max-w-[85%] sm:max-w-[70%] flex flex-col items-end' : 'max-w-[95%] sm:max-w-[85%]'}">
|
||||
<div class="{message.role === 'user' ? 'max-w-[85%] sm:max-w-[70%] flex flex-col items-end' : 'w-full max-w-[98%] sm:max-w-[95%]'}">
|
||||
{#if message.role === 'assistant'}
|
||||
<!-- Assistant message header -->
|
||||
<div class="flex items-center gap-1.5 sm:gap-2 mb-1.5 sm:mb-2">
|
||||
@@ -305,7 +296,7 @@ function isThinkingExpanded(messageId: string): boolean {
|
||||
{:else}
|
||||
<div class="{message.role === 'user'
|
||||
? 'command-panel rounded-lg rounded-tr-sm inline-block'
|
||||
: 'command-panel rounded-lg rounded-tl-sm border-l-2 border-l-exo-yellow/50 inline-block'}">
|
||||
: 'command-panel rounded-lg rounded-tl-sm border-l-2 border-l-exo-yellow/50 block w-full'}">
|
||||
|
||||
{#if message.role === 'user'}
|
||||
<!-- User message styling -->
|
||||
@@ -331,7 +322,7 @@ function isThinkingExpanded(messageId: string): boolean {
|
||||
{/if}
|
||||
|
||||
{#if message.content}
|
||||
<div class="text-sm text-foreground font-mono tracking-wide whitespace-pre-wrap break-words leading-relaxed">
|
||||
<div class="text-xs text-foreground font-mono tracking-wide whitespace-pre-wrap break-words leading-relaxed">
|
||||
{message.content}
|
||||
</div>
|
||||
{/if}
|
||||
@@ -360,7 +351,7 @@ function isThinkingExpanded(messageId: string): boolean {
|
||||
</svg>
|
||||
<span>Thinking...</span>
|
||||
</span>
|
||||
<span class="text-[10px] tracking-[0.2em] text-exo-light-gray/60">
|
||||
<span class="text-[10px] tracking-[0.2em] text-exo-light-gray/60 ml-4">
|
||||
{isThinkingExpanded(message.id) ? 'HIDE' : 'SHOW'}
|
||||
</span>
|
||||
</button>
|
||||
@@ -374,10 +365,58 @@ function isThinkingExpanded(messageId: string): boolean {
|
||||
{/if}
|
||||
</div>
|
||||
{/if}
|
||||
<div class="text-sm text-foreground font-mono tracking-wide whitespace-pre-wrap break-words leading-relaxed">
|
||||
{message.content || (loading ? response : '')}
|
||||
{#if loading && !message.content}
|
||||
<span class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"></span>
|
||||
|
||||
<!-- Generated Images -->
|
||||
{#if message.attachments?.some(a => a.type === 'generated-image')}
|
||||
<div class="mb-3">
|
||||
{#each message.attachments.filter(a => a.type === 'generated-image') as attachment}
|
||||
<div class="relative group/img inline-block">
|
||||
<img
|
||||
src={attachment.preview}
|
||||
alt=""
|
||||
class="max-w-full max-h-[512px] rounded-lg border border-exo-yellow/20 shadow-lg shadow-black/20"
|
||||
/>
|
||||
<!-- Download button overlay -->
|
||||
<button
|
||||
type="button"
|
||||
class="absolute top-2 right-2 p-2 rounded-lg bg-exo-dark-gray/80 border border-exo-yellow/30 text-exo-yellow opacity-0 group-hover/img:opacity-100 transition-opacity hover:bg-exo-dark-gray hover:border-exo-yellow/50 cursor-pointer"
|
||||
onclick={() => {
|
||||
if (attachment.preview) {
|
||||
const link = document.createElement('a');
|
||||
link.href = attachment.preview;
|
||||
link.download = `generated-image-${Date.now()}.png`;
|
||||
link.click();
|
||||
}
|
||||
}}
|
||||
title="Download image"
|
||||
>
|
||||
<svg class="w-4 h-4" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" d="M4 16v1a3 3 0 003 3h10a3 3 0 003-3v-1m-4-4l-4 4m0 0l-4-4m4 4V4" />
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<div class="text-xs text-foreground">
|
||||
{#if message.content === 'Generating image...'}
|
||||
<div class="flex items-center gap-3 text-exo-yellow">
|
||||
<div class="relative">
|
||||
<div class="w-8 h-8 border-2 border-exo-yellow/30 border-t-exo-yellow rounded-full animate-spin"></div>
|
||||
<svg class="absolute inset-0 w-8 h-8 p-1.5 text-exo-yellow/60" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
|
||||
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
|
||||
<circle cx="8.5" cy="8.5" r="1.5"/>
|
||||
<polyline points="21 15 16 10 5 21"/>
|
||||
</svg>
|
||||
</div>
|
||||
<span class="font-mono tracking-wider uppercase text-sm">Generating image...</span>
|
||||
</div>
|
||||
{:else if message.content || (loading && !message.attachments?.some(a => a.type === 'generated-image'))}
|
||||
<MarkdownContent content={message.content || (loading ? response : '')} />
|
||||
{#if loading && !message.content}
|
||||
<span class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"></span>
|
||||
{/if}
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
@@ -457,6 +496,20 @@ function isThinkingExpanded(messageId: string): boolean {
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Scroll anchor for auto-scroll -->
|
||||
<div bind:this={scrollAnchorRef}></div>
|
||||
<!-- Invisible element for container reference -->
|
||||
<div bind:this={containerRef}></div>
|
||||
|
||||
<!-- Scroll to bottom button -->
|
||||
{#if showScrollButton}
|
||||
<button
|
||||
type="button"
|
||||
onclick={scrollToBottom}
|
||||
class="sticky bottom-4 left-1/2 -translate-x-1/2 w-10 h-10 rounded-full bg-exo-dark-gray/90 border border-exo-medium-gray/50 flex items-center justify-center text-exo-light-gray hover:text-exo-yellow hover:border-exo-yellow/50 transition-all shadow-lg cursor-pointer z-10"
|
||||
title="Scroll to bottom"
|
||||
>
|
||||
<svg class="w-5 h-5" fill="none" viewBox="0 0 24 24" stroke="currentColor">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 14l-7 7m0 0l-7-7m7 7V3" />
|
||||
</svg>
|
||||
</button>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
@@ -10,7 +10,9 @@ import {
|
||||
clearChat,
|
||||
instances,
|
||||
debugMode,
|
||||
toggleDebugMode
|
||||
toggleDebugMode,
|
||||
topologyOnlyMode,
|
||||
toggleTopologyOnlyMode
|
||||
} from '$lib/stores/app.svelte';
|
||||
|
||||
interface Props {
|
||||
@@ -23,6 +25,7 @@ import {
|
||||
const activeId = $derived(activeConversationId());
|
||||
const instanceData = $derived(instances());
|
||||
const debugEnabled = $derived(debugMode());
|
||||
const topologyOnlyEnabled = $derived(topologyOnlyMode());
|
||||
|
||||
let searchQuery = $state('');
|
||||
let editingId = $state<string | null>(null);
|
||||
@@ -424,6 +427,19 @@ const debugEnabled = $derived(debugMode());
|
||||
<div class="text-xs text-white/60 font-mono tracking-wider text-center">
|
||||
{conversationList.length} CONVERSATION{conversationList.length !== 1 ? 'S' : ''}
|
||||
</div>
|
||||
<button
|
||||
type="button"
|
||||
onclick={toggleTopologyOnlyMode}
|
||||
class="p-1.5 rounded border border-exo-medium-gray/40 hover:border-exo-yellow/50 transition-colors cursor-pointer"
|
||||
title="Toggle topology only mode"
|
||||
>
|
||||
<svg class="w-4 h-4 {topologyOnlyEnabled ? 'text-exo-yellow' : 'text-exo-medium-gray'}" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
|
||||
<circle cx="12" cy="5" r="2" fill="currentColor" />
|
||||
<circle cx="5" cy="19" r="2" fill="currentColor" />
|
||||
<circle cx="19" cy="19" r="2" fill="currentColor" />
|
||||
<path stroke-linecap="round" d="M12 7v5m0 0l-5 5m5-5l5 5" />
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</aside>
|
||||
|
||||
@@ -3,6 +3,9 @@
|
||||
|
||||
export let showHome = true;
|
||||
export let onHome: (() => void) | null = null;
|
||||
export let showSidebarToggle = false;
|
||||
export let sidebarVisible = true;
|
||||
export let onToggleSidebar: (() => void) | null = null;
|
||||
|
||||
function handleHome(): void {
|
||||
if (onHome) {
|
||||
@@ -14,13 +17,38 @@
|
||||
window.location.hash = '/';
|
||||
}
|
||||
}
|
||||
|
||||
function handleToggleSidebar(): void {
|
||||
if (onToggleSidebar) {
|
||||
onToggleSidebar();
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<header class="relative z-20 flex items-center justify-center px-6 pt-8 pb-4 bg-exo-dark-gray">
|
||||
<!-- Left: Sidebar Toggle -->
|
||||
{#if showSidebarToggle}
|
||||
<div class="absolute left-6 top-1/2 -translate-y-1/2">
|
||||
<button
|
||||
onclick={handleToggleSidebar}
|
||||
class="p-2 rounded border border-exo-medium-gray/40 hover:border-exo-yellow/50 transition-colors cursor-pointer"
|
||||
title={sidebarVisible ? 'Hide sidebar' : 'Show sidebar'}
|
||||
>
|
||||
<svg class="w-5 h-5 {sidebarVisible ? 'text-exo-yellow' : 'text-exo-medium-gray'}" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
|
||||
{#if sidebarVisible}
|
||||
<path stroke-linecap="round" stroke-linejoin="round" d="M11 19l-7-7 7-7m8 14l-7-7 7-7" />
|
||||
{:else}
|
||||
<path stroke-linecap="round" stroke-linejoin="round" d="M13 5l7 7-7 7M5 5l7 7-7 7" />
|
||||
{/if}
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Center: Logo (clickable to go home) -->
|
||||
<button
|
||||
onclick={handleHome}
|
||||
class="hover:opacity-80 transition-opacity {showHome ? 'cursor-pointer' : 'cursor-default'}"
|
||||
class="bg-transparent border-none outline-none focus:outline-none transition-opacity duration-200 hover:opacity-90 {showHome ? 'cursor-pointer' : 'cursor-default'}"
|
||||
title={showHome ? 'Go to home' : ''}
|
||||
disabled={!showHome}
|
||||
>
|
||||
|
||||
451
dashboard/src/lib/components/MarkdownContent.svelte
Normal file
451
dashboard/src/lib/components/MarkdownContent.svelte
Normal file
@@ -0,0 +1,451 @@
|
||||
<script lang="ts">
|
||||
import { marked } from 'marked';
|
||||
import hljs from 'highlight.js';
|
||||
import katex from 'katex';
|
||||
import 'katex/dist/katex.min.css';
|
||||
import { browser } from '$app/environment';
|
||||
|
||||
interface Props {
|
||||
content: string;
|
||||
class?: string;
|
||||
}
|
||||
|
||||
let { content, class: className = '' }: Props = $props();
|
||||
|
||||
let containerRef = $state<HTMLDivElement>();
|
||||
let processedHtml = $state('');
|
||||
|
||||
// Configure marked with syntax highlighting
|
||||
marked.setOptions({
|
||||
gfm: true,
|
||||
breaks: true
|
||||
});
|
||||
|
||||
// Custom renderer for code blocks
|
||||
const renderer = new marked.Renderer();
|
||||
|
||||
renderer.code = function ({ text, lang }: { text: string; lang?: string }) {
|
||||
const language = lang && hljs.getLanguage(lang) ? lang : 'plaintext';
|
||||
const highlighted = hljs.highlight(text, { language }).value;
|
||||
const codeId = `code-${Date.now()}-${Math.random().toString(36).slice(2, 9)}`;
|
||||
|
||||
return `
|
||||
<div class="code-block-wrapper">
|
||||
<div class="code-block-header">
|
||||
<span class="code-language">${language}</span>
|
||||
<button type="button" class="copy-code-btn" data-code="${encodeURIComponent(text)}" title="Copy code">
|
||||
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
|
||||
<rect width="14" height="14" x="8" y="8" rx="2" ry="2"/>
|
||||
<path d="M4 16c-1.1 0-2-.9-2-2V4c0-1.1.9-2 2-2h10c1.1 0 2 .9 2 2"/>
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
<pre><code class="hljs language-${language}" data-code-id="${codeId}">${highlighted}</code></pre>
|
||||
</div>
|
||||
`;
|
||||
};
|
||||
|
||||
// Inline code
|
||||
renderer.codespan = function ({ text }: { text: string }) {
|
||||
return `<code class="inline-code">${text}</code>`;
|
||||
};
|
||||
|
||||
marked.use({ renderer });
|
||||
|
||||
/**
|
||||
* Preprocess LaTeX: convert \(...\) to $...$ and \[...\] to $$...$$
|
||||
* Also protect code blocks from LaTeX processing
|
||||
*/
|
||||
function preprocessLaTeX(text: string): string {
|
||||
// Protect code blocks
|
||||
const codeBlocks: string[] = [];
|
||||
let processed = text.replace(/```[\s\S]*?```|`[^`]+`/g, (match) => {
|
||||
codeBlocks.push(match);
|
||||
return `<<CODE_${codeBlocks.length - 1}>>`;
|
||||
});
|
||||
|
||||
// Convert \(...\) to $...$
|
||||
processed = processed.replace(/\\\((.+?)\\\)/g, '$$$1$');
|
||||
|
||||
// Convert \[...\] to $$...$$
|
||||
processed = processed.replace(/\\\[([\s\S]*?)\\\]/g, '$$$$$1$$$$');
|
||||
|
||||
// Restore code blocks
|
||||
processed = processed.replace(/<<CODE_(\d+)>>/g, (_, index) => codeBlocks[parseInt(index)]);
|
||||
|
||||
return processed;
|
||||
}
|
||||
|
||||
/**
|
||||
* Render math expressions with KaTeX after HTML is generated
|
||||
*/
|
||||
function renderMath(html: string): string {
|
||||
// Render display math ($$...$$)
|
||||
html = html.replace(/\$\$([\s\S]*?)\$\$/g, (_, math) => {
|
||||
try {
|
||||
return katex.renderToString(math.trim(), {
|
||||
displayMode: true,
|
||||
throwOnError: false,
|
||||
output: 'html'
|
||||
});
|
||||
} catch {
|
||||
return `<span class="math-error">$$${math}$$</span>`;
|
||||
}
|
||||
});
|
||||
|
||||
// Render inline math ($...$) but avoid matching currency like $5
|
||||
html = html.replace(/\$([^\$\n]+?)\$/g, (match, math) => {
|
||||
// Skip if it looks like currency ($ followed by number)
|
||||
if (/^\d/.test(math.trim())) {
|
||||
return match;
|
||||
}
|
||||
try {
|
||||
return katex.renderToString(math.trim(), {
|
||||
displayMode: false,
|
||||
throwOnError: false,
|
||||
output: 'html'
|
||||
});
|
||||
} catch {
|
||||
return `<span class="math-error">$${math}$</span>`;
|
||||
}
|
||||
});
|
||||
|
||||
return html;
|
||||
}
|
||||
|
||||
function processMarkdown(text: string): string {
|
||||
try {
|
||||
// Preprocess LaTeX notation
|
||||
const preprocessed = preprocessLaTeX(text);
|
||||
// Parse markdown
|
||||
let html = marked.parse(preprocessed) as string;
|
||||
// Render math expressions
|
||||
html = renderMath(html);
|
||||
return html;
|
||||
} catch (error) {
|
||||
console.error('Markdown processing error:', error);
|
||||
return text.replace(/\n/g, '<br>');
|
||||
}
|
||||
}
|
||||
|
||||
async function handleCopyClick(event: Event) {
|
||||
const target = event.currentTarget as HTMLButtonElement;
|
||||
const encodedCode = target.getAttribute('data-code');
|
||||
if (!encodedCode) return;
|
||||
|
||||
const code = decodeURIComponent(encodedCode);
|
||||
|
||||
try {
|
||||
await navigator.clipboard.writeText(code);
|
||||
// Show copied feedback
|
||||
const originalHtml = target.innerHTML;
|
||||
target.innerHTML = `
|
||||
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
|
||||
<path d="M20 6L9 17l-5-5"/>
|
||||
</svg>
|
||||
`;
|
||||
target.classList.add('copied');
|
||||
setTimeout(() => {
|
||||
target.innerHTML = originalHtml;
|
||||
target.classList.remove('copied');
|
||||
}, 2000);
|
||||
} catch (error) {
|
||||
console.error('Failed to copy:', error);
|
||||
}
|
||||
}
|
||||
|
||||
function setupCopyButtons() {
|
||||
if (!containerRef || !browser) return;
|
||||
|
||||
const buttons = containerRef.querySelectorAll<HTMLButtonElement>('.copy-code-btn');
|
||||
for (const button of buttons) {
|
||||
if (button.dataset.listenerBound !== 'true') {
|
||||
button.dataset.listenerBound = 'true';
|
||||
button.addEventListener('click', handleCopyClick);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
$effect(() => {
|
||||
if (content) {
|
||||
processedHtml = processMarkdown(content);
|
||||
} else {
|
||||
processedHtml = '';
|
||||
}
|
||||
});
|
||||
|
||||
$effect(() => {
|
||||
if (containerRef && processedHtml) {
|
||||
setupCopyButtons();
|
||||
}
|
||||
});
|
||||
</script>
|
||||
|
||||
<div bind:this={containerRef} class="markdown-content {className}">
|
||||
{@html processedHtml}
|
||||
</div>
|
||||
|
||||
<style>
|
||||
.markdown-content {
|
||||
line-height: 1.6;
|
||||
}
|
||||
|
||||
/* Paragraphs */
|
||||
.markdown-content :global(p) {
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
|
||||
.markdown-content :global(p:last-child) {
|
||||
margin-bottom: 0;
|
||||
}
|
||||
|
||||
/* Headers */
|
||||
.markdown-content :global(h1) {
|
||||
font-size: 1.5rem;
|
||||
font-weight: 700;
|
||||
margin: 1.5rem 0 0.75rem 0;
|
||||
color: var(--exo-yellow, #ffd700);
|
||||
}
|
||||
|
||||
.markdown-content :global(h2) {
|
||||
font-size: 1.25rem;
|
||||
font-weight: 600;
|
||||
margin: 1.25rem 0 0.5rem 0;
|
||||
color: var(--exo-yellow, #ffd700);
|
||||
}
|
||||
|
||||
.markdown-content :global(h3) {
|
||||
font-size: 1.125rem;
|
||||
font-weight: 600;
|
||||
margin: 1rem 0 0.5rem 0;
|
||||
}
|
||||
|
||||
.markdown-content :global(h4),
|
||||
.markdown-content :global(h5),
|
||||
.markdown-content :global(h6) {
|
||||
font-size: 1rem;
|
||||
font-weight: 600;
|
||||
margin: 0.75rem 0 0.25rem 0;
|
||||
}
|
||||
|
||||
/* Bold and italic */
|
||||
.markdown-content :global(strong) {
|
||||
font-weight: 600;
|
||||
}
|
||||
|
||||
.markdown-content :global(em) {
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
/* Inline code */
|
||||
.markdown-content :global(.inline-code) {
|
||||
background: rgba(255, 215, 0, 0.1);
|
||||
color: var(--exo-yellow, #ffd700);
|
||||
padding: 0.125rem 0.375rem;
|
||||
border-radius: 0.25rem;
|
||||
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
|
||||
font-size: 0.875em;
|
||||
}
|
||||
|
||||
/* Links */
|
||||
.markdown-content :global(a) {
|
||||
color: var(--exo-yellow, #ffd700);
|
||||
text-decoration: underline;
|
||||
text-underline-offset: 2px;
|
||||
}
|
||||
|
||||
.markdown-content :global(a:hover) {
|
||||
opacity: 0.8;
|
||||
}
|
||||
|
||||
/* Lists */
|
||||
.markdown-content :global(ul) {
|
||||
list-style-type: disc;
|
||||
margin-left: 1.5rem;
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
|
||||
.markdown-content :global(ol) {
|
||||
list-style-type: decimal;
|
||||
margin-left: 1.5rem;
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
|
||||
.markdown-content :global(li) {
|
||||
margin-bottom: 0.25rem;
|
||||
}
|
||||
|
||||
.markdown-content :global(li::marker) {
|
||||
color: var(--exo-light-gray, #9ca3af);
|
||||
}
|
||||
|
||||
/* Blockquotes */
|
||||
.markdown-content :global(blockquote) {
|
||||
border-left: 3px solid var(--exo-yellow, #ffd700);
|
||||
padding: 0.5rem 1rem;
|
||||
margin: 1rem 0;
|
||||
background: rgba(255, 215, 0, 0.05);
|
||||
border-radius: 0 0.25rem 0.25rem 0;
|
||||
}
|
||||
|
||||
/* Tables */
|
||||
.markdown-content :global(table) {
|
||||
width: 100%;
|
||||
margin: 1rem 0;
|
||||
border-collapse: collapse;
|
||||
font-size: 0.875rem;
|
||||
}
|
||||
|
||||
.markdown-content :global(th) {
|
||||
background: rgba(255, 215, 0, 0.1);
|
||||
border: 1px solid rgba(255, 215, 0, 0.2);
|
||||
padding: 0.5rem;
|
||||
text-align: left;
|
||||
font-weight: 600;
|
||||
}
|
||||
|
||||
.markdown-content :global(td) {
|
||||
border: 1px solid rgba(255, 255, 255, 0.1);
|
||||
padding: 0.5rem;
|
||||
}
|
||||
|
||||
/* Horizontal rule */
|
||||
.markdown-content :global(hr) {
|
||||
border: none;
|
||||
border-top: 1px solid rgba(255, 255, 255, 0.1);
|
||||
margin: 1.5rem 0;
|
||||
}
|
||||
|
||||
/* Code block wrapper */
|
||||
.markdown-content :global(.code-block-wrapper) {
|
||||
margin: 1rem 0;
|
||||
border-radius: 0.5rem;
|
||||
overflow: hidden;
|
||||
border: 1px solid rgba(255, 215, 0, 0.2);
|
||||
background: rgba(0, 0, 0, 0.4);
|
||||
}
|
||||
|
||||
.markdown-content :global(.code-block-header) {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
padding: 0.5rem 0.75rem;
|
||||
background: rgba(255, 215, 0, 0.05);
|
||||
border-bottom: 1px solid rgba(255, 215, 0, 0.1);
|
||||
}
|
||||
|
||||
.markdown-content :global(.code-language) {
|
||||
color: var(--exo-yellow, #ffd700);
|
||||
font-size: 0.7rem;
|
||||
font-weight: 500;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.1em;
|
||||
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
|
||||
}
|
||||
|
||||
.markdown-content :global(.copy-code-btn) {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
padding: 0.25rem;
|
||||
background: transparent;
|
||||
border: none;
|
||||
color: var(--exo-light-gray, #9ca3af);
|
||||
cursor: pointer;
|
||||
transition: color 0.2s;
|
||||
border-radius: 0.25rem;
|
||||
}
|
||||
|
||||
.markdown-content :global(.copy-code-btn:hover) {
|
||||
color: var(--exo-yellow, #ffd700);
|
||||
}
|
||||
|
||||
.markdown-content :global(.copy-code-btn.copied) {
|
||||
color: #22c55e;
|
||||
}
|
||||
|
||||
.markdown-content :global(.code-block-wrapper pre) {
|
||||
margin: 0;
|
||||
padding: 1rem;
|
||||
overflow-x: auto;
|
||||
background: transparent;
|
||||
}
|
||||
|
||||
.markdown-content :global(.code-block-wrapper code) {
|
||||
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
|
||||
font-size: 0.8125rem;
|
||||
line-height: 1.5;
|
||||
background: transparent;
|
||||
}
|
||||
|
||||
/* Syntax highlighting - dark theme matching EXO style */
|
||||
.markdown-content :global(.hljs) {
|
||||
color: #e5e7eb;
|
||||
}
|
||||
|
||||
.markdown-content :global(.hljs-keyword),
|
||||
.markdown-content :global(.hljs-selector-tag),
|
||||
.markdown-content :global(.hljs-literal),
|
||||
.markdown-content :global(.hljs-section),
|
||||
.markdown-content :global(.hljs-link) {
|
||||
color: #c084fc;
|
||||
}
|
||||
|
||||
.markdown-content :global(.hljs-string),
|
||||
.markdown-content :global(.hljs-title),
|
||||
.markdown-content :global(.hljs-name),
|
||||
.markdown-content :global(.hljs-type),
|
||||
.markdown-content :global(.hljs-attribute),
|
||||
.markdown-content :global(.hljs-symbol),
|
||||
.markdown-content :global(.hljs-bullet),
|
||||
.markdown-content :global(.hljs-addition),
|
||||
.markdown-content :global(.hljs-variable),
|
||||
.markdown-content :global(.hljs-template-tag),
|
||||
.markdown-content :global(.hljs-template-variable) {
|
||||
color: #fbbf24;
|
||||
}
|
||||
|
||||
.markdown-content :global(.hljs-comment),
|
||||
.markdown-content :global(.hljs-quote),
|
||||
.markdown-content :global(.hljs-deletion),
|
||||
.markdown-content :global(.hljs-meta) {
|
||||
color: #6b7280;
|
||||
}
|
||||
|
||||
.markdown-content :global(.hljs-number),
|
||||
.markdown-content :global(.hljs-regexp),
|
||||
.markdown-content :global(.hljs-literal),
|
||||
.markdown-content :global(.hljs-built_in) {
|
||||
color: #34d399;
|
||||
}
|
||||
|
||||
.markdown-content :global(.hljs-function),
|
||||
.markdown-content :global(.hljs-class .hljs-title) {
|
||||
color: #60a5fa;
|
||||
}
|
||||
|
||||
/* KaTeX math styling */
|
||||
.markdown-content :global(.katex) {
|
||||
font-size: 1.1em;
|
||||
}
|
||||
|
||||
.markdown-content :global(.katex-display) {
|
||||
margin: 1rem 0;
|
||||
overflow-x: auto;
|
||||
overflow-y: hidden;
|
||||
padding: 0.5rem 0;
|
||||
}
|
||||
|
||||
.markdown-content :global(.katex-display > .katex) {
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
.markdown-content :global(.math-error) {
|
||||
color: #f87171;
|
||||
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
|
||||
font-size: 0.875em;
|
||||
background: rgba(248, 113, 113, 0.1);
|
||||
padding: 0.125rem 0.25rem;
|
||||
border-radius: 0.25rem;
|
||||
}
|
||||
</style>
|
||||
@@ -1,5 +1,6 @@
|
||||
<script lang="ts">
|
||||
import type { DownloadProgress, NodeInfo, PlacementPreview } from '$lib/stores/app.svelte';
|
||||
import type { DownloadProgress, NodeInfo, PlacementPreview, TopologyEdge } from '$lib/stores/app.svelte';
|
||||
import { debugMode, topologyData } from '$lib/stores/app.svelte';
|
||||
|
||||
interface Props {
|
||||
model: { id: string; name?: string; storage_size_megabytes?: number };
|
||||
@@ -206,12 +207,8 @@ function toggleNodeDetails(nodeId: string): void {
|
||||
const centerY = topoHeight / 2;
|
||||
const radius = numNodes === 1 ? 0 : numNodes === 2 ? 45 : Math.min(topoWidth, topoHeight) * 0.32;
|
||||
|
||||
// Use API preview data if available
|
||||
// Only use API preview data - no local estimation
|
||||
const hasApiPreview = apiPreview !== null && apiPreview.error === null && apiPreview.memory_delta_by_node !== null;
|
||||
const canFit = hasApiPreview ? true : (() => {
|
||||
const totalAvailable = nodeArray.reduce((sum, n) => sum + n.availableGB, 0);
|
||||
return totalAvailable >= estimatedMemory;
|
||||
})();
|
||||
const error = apiPreview?.error ?? null;
|
||||
|
||||
let placementNodes: Array<{
|
||||
@@ -232,135 +229,140 @@ function toggleNodeDetails(nodeId: string): void {
|
||||
modelFillHeight: number;
|
||||
}> = [];
|
||||
|
||||
if (hasApiPreview && apiPreview.memory_delta_by_node) {
|
||||
// Use API placement data
|
||||
const memoryDelta = apiPreview.memory_delta_by_node;
|
||||
placementNodes = nodeArray.map((n, i) => {
|
||||
const deltaBytes = memoryDelta[n.id] ?? 0;
|
||||
const modelUsageGB = deltaBytes / (1024 * 1024 * 1024);
|
||||
const isUsed = deltaBytes > 0;
|
||||
const angle = numNodes === 1 ? 0 : (i / numNodes) * Math.PI * 2 - Math.PI / 2;
|
||||
const safeTotal = Math.max(n.totalGB, 0.001);
|
||||
const currentPercent = clampPercent((n.usedGB / safeTotal) * 100);
|
||||
const newPercent = clampPercent(((n.usedGB + modelUsageGB) / safeTotal) * 100);
|
||||
const screenHeight = iconSize * 0.58;
|
||||
|
||||
return {
|
||||
id: n.id,
|
||||
deviceName: n.deviceName,
|
||||
deviceType: n.deviceType,
|
||||
totalGB: n.totalGB,
|
||||
currentUsedGB: n.usedGB,
|
||||
modelUsageGB,
|
||||
currentPercent,
|
||||
newPercent,
|
||||
isUsed,
|
||||
x: centerX + Math.cos(angle) * radius,
|
||||
y: centerY + Math.sin(angle) * radius,
|
||||
iconSize,
|
||||
screenHeight,
|
||||
currentFillHeight: screenHeight * (currentPercent / 100),
|
||||
modelFillHeight: screenHeight * ((newPercent - currentPercent) / 100)
|
||||
};
|
||||
});
|
||||
} else if (apiPreview?.error) {
|
||||
// API returned an error - model can't fit, show all nodes as unused
|
||||
placementNodes = nodeArray.map((n, i) => {
|
||||
const angle = numNodes === 1 ? 0 : (i / numNodes) * Math.PI * 2 - Math.PI / 2;
|
||||
const safeTotal = Math.max(n.totalGB, 0.001);
|
||||
const currentPercent = clampPercent((n.usedGB / safeTotal) * 100);
|
||||
const screenHeight = iconSize * 0.58;
|
||||
|
||||
return {
|
||||
id: n.id,
|
||||
deviceName: n.deviceName,
|
||||
deviceType: n.deviceType,
|
||||
totalGB: n.totalGB,
|
||||
currentUsedGB: n.usedGB,
|
||||
modelUsageGB: 0,
|
||||
currentPercent,
|
||||
newPercent: currentPercent,
|
||||
isUsed: false,
|
||||
x: centerX + Math.cos(angle) * radius,
|
||||
y: centerY + Math.sin(angle) * radius,
|
||||
iconSize,
|
||||
screenHeight,
|
||||
currentFillHeight: screenHeight * (currentPercent / 100),
|
||||
modelFillHeight: 0
|
||||
};
|
||||
});
|
||||
} else {
|
||||
// Fallback: local estimation based on sharding strategy
|
||||
const memoryNeeded = estimatedMemory;
|
||||
// Use API placement data directly
|
||||
const memoryDelta = apiPreview?.memory_delta_by_node ?? {};
|
||||
placementNodes = nodeArray.map((n, i) => {
|
||||
const deltaBytes = memoryDelta[n.id] ?? 0;
|
||||
const modelUsageGB = deltaBytes / (1024 * 1024 * 1024);
|
||||
const isUsed = deltaBytes > 0;
|
||||
const angle = numNodes === 1 ? 0 : (i / numNodes) * Math.PI * 2 - Math.PI / 2;
|
||||
const safeTotal = Math.max(n.totalGB, 0.001);
|
||||
const currentPercent = clampPercent((n.usedGB / safeTotal) * 100);
|
||||
const newPercent = clampPercent(((n.usedGB + modelUsageGB) / safeTotal) * 100);
|
||||
const screenHeight = iconSize * 0.58;
|
||||
|
||||
if (sharding === 'Pipeline') {
|
||||
const memoryPerNode = memoryNeeded / numNodes;
|
||||
placementNodes = nodeArray.map((n, i) => {
|
||||
const angle = numNodes === 1 ? 0 : (i / numNodes) * Math.PI * 2 - Math.PI / 2;
|
||||
const safeTotal = Math.max(n.totalGB, 0.001);
|
||||
const currentPercent = clampPercent((n.usedGB / safeTotal) * 100);
|
||||
const newPercent = clampPercent(((n.usedGB + memoryPerNode) / safeTotal) * 100);
|
||||
const screenHeight = iconSize * 0.58;
|
||||
|
||||
return {
|
||||
id: n.id,
|
||||
deviceName: n.deviceName,
|
||||
deviceType: n.deviceType,
|
||||
totalGB: n.totalGB,
|
||||
currentUsedGB: n.usedGB,
|
||||
modelUsageGB: memoryPerNode,
|
||||
currentPercent,
|
||||
newPercent,
|
||||
isUsed: true,
|
||||
x: centerX + Math.cos(angle) * radius,
|
||||
y: centerY + Math.sin(angle) * radius,
|
||||
iconSize,
|
||||
screenHeight,
|
||||
currentFillHeight: screenHeight * (currentPercent / 100),
|
||||
modelFillHeight: screenHeight * ((newPercent - currentPercent) / 100)
|
||||
};
|
||||
});
|
||||
} else {
|
||||
let remaining = memoryNeeded;
|
||||
placementNodes = nodeArray.map((n, i) => {
|
||||
const allocated = Math.min(remaining, n.availableGB);
|
||||
remaining -= allocated;
|
||||
const isUsed = allocated > 0;
|
||||
const angle = numNodes === 1 ? 0 : (i / numNodes) * Math.PI * 2 - Math.PI / 2;
|
||||
const safeTotal = Math.max(n.totalGB, 0.001);
|
||||
const currentPercent = clampPercent((n.usedGB / safeTotal) * 100);
|
||||
const newPercent = clampPercent(((n.usedGB + allocated) / safeTotal) * 100);
|
||||
const screenHeight = iconSize * 0.58;
|
||||
|
||||
return {
|
||||
id: n.id,
|
||||
deviceName: n.deviceName,
|
||||
deviceType: n.deviceType,
|
||||
totalGB: n.totalGB,
|
||||
currentUsedGB: n.usedGB,
|
||||
modelUsageGB: allocated,
|
||||
currentPercent,
|
||||
newPercent,
|
||||
isUsed,
|
||||
x: centerX + Math.cos(angle) * radius,
|
||||
y: centerY + Math.sin(angle) * radius,
|
||||
iconSize,
|
||||
screenHeight,
|
||||
currentFillHeight: screenHeight * (currentPercent / 100),
|
||||
modelFillHeight: screenHeight * ((newPercent - currentPercent) / 100)
|
||||
};
|
||||
});
|
||||
}
|
||||
}
|
||||
return {
|
||||
id: n.id,
|
||||
deviceName: n.deviceName,
|
||||
deviceType: n.deviceType,
|
||||
totalGB: n.totalGB,
|
||||
currentUsedGB: n.usedGB,
|
||||
modelUsageGB,
|
||||
currentPercent,
|
||||
newPercent,
|
||||
isUsed,
|
||||
x: centerX + Math.cos(angle) * radius,
|
||||
y: centerY + Math.sin(angle) * radius,
|
||||
iconSize,
|
||||
screenHeight,
|
||||
currentFillHeight: screenHeight * (currentPercent / 100),
|
||||
modelFillHeight: screenHeight * ((newPercent - currentPercent) / 100)
|
||||
};
|
||||
});
|
||||
|
||||
const totalAvailable = nodeArray.reduce((sum, n) => sum + n.availableGB, 0);
|
||||
return { nodes: placementNodes, canFit: hasApiPreview || canFit, totalAvailable, topoWidth, topoHeight, error };
|
||||
return { nodes: placementNodes, canFit: hasApiPreview, totalAvailable, topoWidth, topoHeight, error };
|
||||
});
|
||||
|
||||
const canFit = $derived(apiPreview ? apiPreview.error === null : placementPreview().canFit);
|
||||
const placementError = $derived(apiPreview?.error ?? null);
|
||||
const nodeCount = $derived(nodeList().length);
|
||||
const filterId = $derived(model.id.replace(/[^a-zA-Z0-9]/g, ''));
|
||||
|
||||
// Debug mode state
|
||||
const isDebugMode = $derived(debugMode());
|
||||
const topology = $derived(topologyData());
|
||||
const isRdma = $derived(runtime === 'MlxIbv' || runtime === 'MlxJaccl');
|
||||
|
||||
// Get interface name for an IP from node data
|
||||
function getInterfaceForIp(nodeId: string, ip?: string): string | null {
|
||||
if (!ip || !topology?.nodes) return null;
|
||||
|
||||
// Strip port if present
|
||||
const cleanIp = ip.includes(':') && !ip.includes('[') ? ip.split(':')[0] : ip;
|
||||
|
||||
// Check specified node first
|
||||
const node = topology.nodes[nodeId];
|
||||
if (node) {
|
||||
const match = node.network_interfaces?.find((iface) =>
|
||||
(iface.addresses || []).some((addr) => addr === cleanIp || addr === ip)
|
||||
);
|
||||
if (match?.name) return match.name;
|
||||
|
||||
const mapped = node.ip_to_interface?.[cleanIp] || node.ip_to_interface?.[ip];
|
||||
if (mapped) return mapped;
|
||||
}
|
||||
|
||||
// Fallback: check all nodes
|
||||
for (const [, otherNode] of Object.entries(topology.nodes)) {
|
||||
if (!otherNode) continue;
|
||||
const match = otherNode.network_interfaces?.find((iface) =>
|
||||
(iface.addresses || []).some((addr) => addr === cleanIp || addr === ip)
|
||||
);
|
||||
if (match?.name) return match.name;
|
||||
|
||||
const mapped = otherNode.ip_to_interface?.[cleanIp] || otherNode.ip_to_interface?.[ip];
|
||||
if (mapped) return mapped;
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
// Get directional arrow based on node positions
|
||||
function getArrow(fromNode: { x: number; y: number }, toNode: { x: number; y: number }): string {
|
||||
const dx = toNode.x - fromNode.x;
|
||||
const dy = toNode.y - fromNode.y;
|
||||
const absX = Math.abs(dx);
|
||||
const absY = Math.abs(dy);
|
||||
|
||||
if (absX > absY * 2) {
|
||||
return dx > 0 ? '→' : '←';
|
||||
} else if (absY > absX * 2) {
|
||||
return dy > 0 ? '↓' : '↑';
|
||||
} else {
|
||||
if (dx > 0 && dy > 0) return '↘';
|
||||
if (dx > 0 && dy < 0) return '↗';
|
||||
if (dx < 0 && dy > 0) return '↙';
|
||||
return '↖';
|
||||
}
|
||||
}
|
||||
|
||||
// Get connection info for edges between two nodes
|
||||
// Returns exactly one connection per direction (A→B and B→A), preferring non-loopback
|
||||
function getConnectionInfo(nodeId1: string, nodeId2: string): Array<{ ip: string; iface: string | null; from: string; to: string }> {
|
||||
if (!topology?.edges) return [];
|
||||
|
||||
// Collect candidates for each direction
|
||||
const aToBCandidates: Array<{ ip: string; iface: string | null }> = [];
|
||||
const bToACandidates: Array<{ ip: string; iface: string | null }> = [];
|
||||
|
||||
for (const edge of topology.edges) {
|
||||
const ip = edge.sendBackIp || '?';
|
||||
const iface = edge.sendBackInterface || getInterfaceForIp(edge.source, ip);
|
||||
|
||||
if (edge.source === nodeId1 && edge.target === nodeId2) {
|
||||
aToBCandidates.push({ ip, iface });
|
||||
} else if (edge.source === nodeId2 && edge.target === nodeId1) {
|
||||
bToACandidates.push({ ip, iface });
|
||||
}
|
||||
}
|
||||
|
||||
// Pick best (prefer non-loopback)
|
||||
const pickBest = (candidates: Array<{ ip: string; iface: string | null }>) => {
|
||||
if (candidates.length === 0) return null;
|
||||
return candidates.find(c => !c.ip.startsWith('127.')) || candidates[0];
|
||||
};
|
||||
|
||||
const result: Array<{ ip: string; iface: string | null; from: string; to: string }> = [];
|
||||
|
||||
const bestAtoB = pickBest(aToBCandidates);
|
||||
if (bestAtoB) result.push({ ...bestAtoB, from: nodeId1, to: nodeId2 });
|
||||
|
||||
const bestBtoA = pickBest(bToACandidates);
|
||||
if (bestBtoA) result.push({ ...bestBtoA, from: nodeId2, to: nodeId1 });
|
||||
|
||||
return result;
|
||||
}
|
||||
</script>
|
||||
|
||||
<div class="relative group">
|
||||
@@ -453,6 +455,26 @@ function toggleNodeDetails(nodeId: string): void {
|
||||
|
||||
<!-- Connection lines between nodes (if multiple) -->
|
||||
{#if preview.nodes.length > 1}
|
||||
{@const usedNodes = preview.nodes.filter(n => n.isUsed)}
|
||||
{@const nodePositions = Object.fromEntries(preview.nodes.map(n => [n.id, { x: n.x, y: n.y }]))}
|
||||
{@const allConnections = isDebugMode && usedNodes.length > 1 ? (() => {
|
||||
const conns: Array<{ ip: string; iface: string | null; from: string; to: string; midX: number; midY: number; arrow: string }> = [];
|
||||
for (let i = 0; i < usedNodes.length; i++) {
|
||||
for (let j = i + 1; j < usedNodes.length; j++) {
|
||||
const n1 = usedNodes[i];
|
||||
const n2 = usedNodes[j];
|
||||
const midX = (n1.x + n2.x) / 2;
|
||||
const midY = (n1.y + n2.y) / 2;
|
||||
for (const c of getConnectionInfo(n1.id, n2.id)) {
|
||||
const fromPos = nodePositions[c.from];
|
||||
const toPos = nodePositions[c.to];
|
||||
const arrow = fromPos && toPos ? getArrow(fromPos, toPos) : '→';
|
||||
conns.push({ ...c, midX, midY, arrow });
|
||||
}
|
||||
}
|
||||
}
|
||||
return conns;
|
||||
})() : []}
|
||||
{#each preview.nodes as node, i}
|
||||
{#each preview.nodes.slice(i + 1) as node2}
|
||||
<line
|
||||
@@ -464,6 +486,43 @@ function toggleNodeDetails(nodeId: string): void {
|
||||
/>
|
||||
{/each}
|
||||
{/each}
|
||||
<!-- Debug: Show connection IPs/interfaces in corners -->
|
||||
{#if isDebugMode && allConnections.length > 0}
|
||||
{@const centerX = preview.topoWidth / 2}
|
||||
{@const centerY = preview.topoHeight / 2}
|
||||
{@const quadrants = {
|
||||
topLeft: allConnections.filter(c => c.midX < centerX && c.midY < centerY),
|
||||
topRight: allConnections.filter(c => c.midX >= centerX && c.midY < centerY),
|
||||
bottomLeft: allConnections.filter(c => c.midX < centerX && c.midY >= centerY),
|
||||
bottomRight: allConnections.filter(c => c.midX >= centerX && c.midY >= centerY)
|
||||
}}
|
||||
{@const padding = 4}
|
||||
{@const lineHeight = 8}
|
||||
<!-- Top Left -->
|
||||
{#each quadrants.topLeft as conn, idx}
|
||||
<text x={padding} y={padding + idx * lineHeight} text-anchor="start" dominant-baseline="hanging" font-size="6" font-family="SF Mono, Monaco, monospace" fill={conn.iface ? 'rgba(255,255,255,0.85)' : 'rgba(248,113,113,0.85)'}>
|
||||
{conn.arrow} {isRdma ? (conn.iface || '?') : `${conn.ip}${conn.iface ? ` (${conn.iface})` : ''}`}
|
||||
</text>
|
||||
{/each}
|
||||
<!-- Top Right -->
|
||||
{#each quadrants.topRight as conn, idx}
|
||||
<text x={preview.topoWidth - padding} y={padding + idx * lineHeight} text-anchor="end" dominant-baseline="hanging" font-size="6" font-family="SF Mono, Monaco, monospace" fill={conn.iface ? 'rgba(255,255,255,0.85)' : 'rgba(248,113,113,0.85)'}>
|
||||
{conn.arrow} {isRdma ? (conn.iface || '?') : `${conn.ip}${conn.iface ? ` (${conn.iface})` : ''}`}
|
||||
</text>
|
||||
{/each}
|
||||
<!-- Bottom Left -->
|
||||
{#each quadrants.bottomLeft as conn, idx}
|
||||
<text x={padding} y={preview.topoHeight - padding - (quadrants.bottomLeft.length - 1 - idx) * lineHeight} text-anchor="start" dominant-baseline="auto" font-size="6" font-family="SF Mono, Monaco, monospace" fill={conn.iface ? 'rgba(255,255,255,0.85)' : 'rgba(248,113,113,0.85)'}>
|
||||
{conn.arrow} {isRdma ? (conn.iface || '?') : `${conn.ip}${conn.iface ? ` (${conn.iface})` : ''}`}
|
||||
</text>
|
||||
{/each}
|
||||
<!-- Bottom Right -->
|
||||
{#each quadrants.bottomRight as conn, idx}
|
||||
<text x={preview.topoWidth - padding} y={preview.topoHeight - padding - (quadrants.bottomRight.length - 1 - idx) * lineHeight} text-anchor="end" dominant-baseline="auto" font-size="6" font-family="SF Mono, Monaco, monospace" fill={conn.iface ? 'rgba(255,255,255,0.85)' : 'rgba(248,113,113,0.85)'}>
|
||||
{conn.arrow} {isRdma ? (conn.iface || '?') : `${conn.ip}${conn.iface ? ` (${conn.iface})` : ''}`}
|
||||
</text>
|
||||
{/each}
|
||||
{/if}
|
||||
{/if}
|
||||
|
||||
{#each preview.nodes as node}
|
||||
|
||||
@@ -24,19 +24,36 @@ function getNodeLabel(nodeId: string): string {
|
||||
|
||||
function getInterfaceLabel(nodeId: string, ip?: string): { label: string; missing: boolean } {
|
||||
if (!ip) return { label: '?', missing: true };
|
||||
const node = data?.nodes?.[nodeId];
|
||||
if (!node) return { label: '?', missing: true };
|
||||
|
||||
// Strip port if present (e.g., "192.168.1.1:8080" -> "192.168.1.1")
|
||||
const cleanIp = ip.includes(':') && !ip.includes('[') ? ip.split(':')[0] : ip;
|
||||
|
||||
// Helper to check a node's interfaces
|
||||
function checkNode(node: typeof data.nodes[string]): string | null {
|
||||
if (!node) return null;
|
||||
|
||||
const matchFromInterfaces = node.network_interfaces?.find((iface) =>
|
||||
(iface.addresses || []).some((addr) => addr === cleanIp || addr === ip)
|
||||
);
|
||||
if (matchFromInterfaces?.name) {
|
||||
return matchFromInterfaces.name;
|
||||
}
|
||||
|
||||
const matchFromInterfaces = node.network_interfaces?.find((iface) =>
|
||||
(iface.addresses || []).some((addr) => addr === ip)
|
||||
);
|
||||
if (matchFromInterfaces?.name) {
|
||||
return { label: matchFromInterfaces.name, missing: false };
|
||||
const mapped = node.ip_to_interface?.[cleanIp] || node.ip_to_interface?.[ip];
|
||||
if (mapped && mapped.trim().length > 0) {
|
||||
return mapped;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
const mapped = node.ip_to_interface?.[ip];
|
||||
if (mapped && mapped.trim().length > 0) {
|
||||
return { label: mapped, missing: false };
|
||||
|
||||
// Try specified node first
|
||||
const result = checkNode(data?.nodes?.[nodeId]);
|
||||
if (result) return { label: result, missing: false };
|
||||
|
||||
// Fallback: search all nodes for this IP
|
||||
for (const [, otherNode] of Object.entries(data?.nodes || {})) {
|
||||
const otherResult = checkNode(otherNode);
|
||||
if (otherResult) return { label: otherResult, missing: false };
|
||||
}
|
||||
|
||||
return { label: '?', missing: true };
|
||||
@@ -67,6 +84,7 @@ function wrapLine(text: string, maxLen: number): string[] {
|
||||
return lines;
|
||||
}
|
||||
|
||||
|
||||
// Apple logo path for MacBook Pro screen
|
||||
const APPLE_LOGO_PATH = "M788.1 340.9c-5.8 4.5-108.2 62.2-108.2 190.5 0 148.4 130.3 200.9 134.2 202.2-.6 3.2-20.7 71.9-68.7 141.9-42.8 61.6-87.5 123.1-155.5 123.1s-85.5-39.5-164-39.5c-76.5 0-103.7 40.8-165.9 40.8s-105.6-57-155.5-127C46.7 790.7 0 663 0 541.8c0-194.4 126.4-297.5 250.8-297.5 66.1 0 121.2 43.4 162.7 43.4 39.5 0 101.1-46 176.3-46 28.5 0 130.9 2.6 198.3 99.2zm-234-181.5c31.1-36.9 53.1-88.1 53.1-139.3 0-7.1-.6-14.3-1.9-20.1-50.6 1.9-110.8 33.7-147.1 75.8-28.5 32.4-55.1 83.6-55.1 135.5 0 7.8 1.3 15.6 1.9 18.1 3.2.6 8.4 1.3 13.6 1.3 45.4 0 102.5-30.4 135.5-71.3z";
|
||||
const LOGO_NATIVE_WIDTH = 814;
|
||||
@@ -238,6 +256,7 @@ function wrapLine(text: string, maxLen: number): string[] {
|
||||
const debugLabelsGroup = svg.append('g').attr('class', 'debug-edge-labels');
|
||||
|
||||
const pairMap = new Map<string, { a: string; b: string; aToB: boolean; bToA: boolean; connections: Array<{ from: string; to: string; ip: string; ifaceLabel: string; missingIface: boolean }> }>();
|
||||
let debugEdgeLabels: Array<{ connections: typeof pairMap extends Map<string, infer V> ? V['connections'] : never; isLeft: boolean; isTop: boolean; mx: number; my: number }> | null = null;
|
||||
edges.forEach(edge => {
|
||||
if (!edge.source || !edge.target || edge.source === edge.target) return;
|
||||
if (!positionById[edge.source] || !positionById[edge.target]) return;
|
||||
@@ -314,110 +333,98 @@ function wrapLine(text: string, maxLen: number): string[] {
|
||||
.attr('marker-end', 'url(#arrowhead)');
|
||||
}
|
||||
|
||||
// Collect debug labels for later positioning at edges
|
||||
if (debugEnabled && entry.connections.length > 0) {
|
||||
const maxBoxes = 6;
|
||||
const fontSize = isMinimized ? 8 : 9;
|
||||
const lineGap = 2;
|
||||
const labelOffsetOut = Math.max(140, minDimension * 0.38);
|
||||
const labelOffsetSide = isMinimized ? 16 : 20;
|
||||
const boxWidth = 170;
|
||||
const maxLineLen = 26;
|
||||
|
||||
const connections = entry.connections.slice(0, maxBoxes);
|
||||
if (entry.connections.length > maxBoxes) {
|
||||
const remaining = entry.connections.length - maxBoxes;
|
||||
connections.push({
|
||||
from: '',
|
||||
to: '',
|
||||
ip: `(+${remaining} more)`,
|
||||
ifaceLabel: '',
|
||||
missingIface: false
|
||||
});
|
||||
}
|
||||
|
||||
let dirX = mx - centerX;
|
||||
let dirY = my - centerY;
|
||||
const dirLen = Math.hypot(dirX, dirY);
|
||||
if (dirLen < 1) {
|
||||
dirX = -uy;
|
||||
dirY = ux;
|
||||
} else {
|
||||
dirX /= dirLen;
|
||||
dirY /= dirLen;
|
||||
}
|
||||
|
||||
const nx = -dirY;
|
||||
const ny = dirX;
|
||||
|
||||
const labelXRaw = mx + dirX * labelOffsetOut + nx * labelOffsetSide;
|
||||
const labelYRaw = my + dirY * labelOffsetOut + ny * labelOffsetSide;
|
||||
const clampPad = Math.min(120, minDimension * 0.12);
|
||||
const labelX = Math.max(clampPad, Math.min(width - clampPad, labelXRaw));
|
||||
const labelY = Math.max(clampPad, Math.min(height - clampPad, labelYRaw));
|
||||
|
||||
const labelGroup = debugLabelsGroup.append('g')
|
||||
.attr('transform', `translate(${labelX}, ${labelY})`);
|
||||
|
||||
const textGroup = labelGroup.append('g');
|
||||
|
||||
connections.forEach((conn, idx) => {
|
||||
const rawLines = conn.from && conn.to
|
||||
? [
|
||||
`${getNodeLabel(conn.from)}→${getNodeLabel(conn.to)}`,
|
||||
`${conn.ip}`,
|
||||
`${conn.ifaceLabel}`
|
||||
]
|
||||
: [conn.ip];
|
||||
|
||||
const wrapped = rawLines.flatMap(line => wrapLine(line, maxLineLen));
|
||||
|
||||
wrapped.forEach((line, lineIdx) => {
|
||||
textGroup.append('text')
|
||||
.attr('x', 0)
|
||||
.attr('y', (idx * (wrapped.length * (fontSize + lineGap))) + lineIdx * (fontSize + lineGap))
|
||||
.attr('text-anchor', 'middle')
|
||||
.attr('dominant-baseline', 'hanging')
|
||||
.attr('font-size', fontSize)
|
||||
.attr('font-family', 'SF Mono, monospace')
|
||||
.attr('fill', conn.missingIface ? 'rgba(248,113,113,0.9)' : 'rgba(255,255,255,0.9)')
|
||||
.text(line);
|
||||
});
|
||||
// Determine which side of viewport based on edge midpoint
|
||||
const isLeft = mx < centerX;
|
||||
const isTop = my < safeCenterY;
|
||||
|
||||
// Store for batch rendering after all edges processed
|
||||
if (!debugEdgeLabels) debugEdgeLabels = [];
|
||||
debugEdgeLabels.push({
|
||||
connections: entry.connections,
|
||||
isLeft,
|
||||
isTop,
|
||||
mx,
|
||||
my
|
||||
});
|
||||
|
||||
const bbox = textGroup.node()?.getBBox();
|
||||
if (bbox) {
|
||||
const paddedWidth = Math.max(boxWidth, bbox.width + 14);
|
||||
const boxHeight = bbox.height + 8;
|
||||
const boxMinX = labelX - paddedWidth / 2;
|
||||
const boxMaxX = labelX + paddedWidth / 2;
|
||||
const boxMinY = labelY + bbox.y - 4;
|
||||
const boxMaxY = boxMinY + boxHeight;
|
||||
|
||||
const clampPadDynamic = Math.min(140, minDimension * 0.18);
|
||||
let shiftX = 0;
|
||||
let shiftY = 0;
|
||||
if (boxMinX < clampPadDynamic) shiftX = clampPadDynamic - boxMinX;
|
||||
if (boxMaxX > width - clampPadDynamic) shiftX = (width - clampPadDynamic) - boxMaxX;
|
||||
if (boxMinY < clampPadDynamic) shiftY = clampPadDynamic - boxMinY;
|
||||
if (boxMaxY > height - clampPadDynamic) shiftY = (height - clampPadDynamic) - boxMaxY;
|
||||
|
||||
const finalX = labelX + shiftX;
|
||||
const finalY = labelY + shiftY;
|
||||
labelGroup.attr('transform', `translate(${finalX}, ${finalY})`);
|
||||
|
||||
labelGroup.insert('rect', 'g')
|
||||
.attr('x', -paddedWidth / 2)
|
||||
.attr('y', bbox.y - 4)
|
||||
.attr('width', paddedWidth)
|
||||
.attr('height', boxHeight)
|
||||
.attr('rx', 4)
|
||||
.attr('fill', 'rgba(0,0,0,0.75)')
|
||||
.attr('stroke', 'rgba(255,255,255,0.12)')
|
||||
.attr('stroke-width', 0.6);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Render debug labels at viewport edges/corners
|
||||
if (debugEdgeLabels && debugEdgeLabels.length > 0) {
|
||||
const fontSize = isMinimized ? 10 : 12;
|
||||
const lineHeight = fontSize + 4;
|
||||
const padding = 10;
|
||||
|
||||
// Helper to get arrow based on direction vector
|
||||
function getArrow(fromId: string, toId: string): string {
|
||||
const fromPos = positionById[fromId];
|
||||
const toPos = positionById[toId];
|
||||
if (!fromPos || !toPos) return '→';
|
||||
|
||||
const dirX = toPos.x - fromPos.x;
|
||||
const dirY = toPos.y - fromPos.y;
|
||||
const absX = Math.abs(dirX);
|
||||
const absY = Math.abs(dirY);
|
||||
|
||||
if (absX > absY * 2) {
|
||||
return dirX > 0 ? '→' : '←';
|
||||
} else if (absY > absX * 2) {
|
||||
return dirY > 0 ? '↓' : '↑';
|
||||
} else {
|
||||
if (dirX > 0 && dirY > 0) return '↘';
|
||||
if (dirX > 0 && dirY < 0) return '↗';
|
||||
if (dirX < 0 && dirY > 0) return '↙';
|
||||
return '↖';
|
||||
}
|
||||
}
|
||||
|
||||
// Group by quadrant: topLeft, topRight, bottomLeft, bottomRight
|
||||
const quadrants: Record<string, typeof debugEdgeLabels> = {
|
||||
topLeft: [],
|
||||
topRight: [],
|
||||
bottomLeft: [],
|
||||
bottomRight: []
|
||||
};
|
||||
|
||||
debugEdgeLabels.forEach(edge => {
|
||||
const key = (edge.isTop ? 'top' : 'bottom') + (edge.isLeft ? 'Left' : 'Right');
|
||||
quadrants[key].push(edge);
|
||||
});
|
||||
|
||||
// Render each quadrant
|
||||
Object.entries(quadrants).forEach(([quadrant, edges]) => {
|
||||
if (edges.length === 0) return;
|
||||
|
||||
const isLeft = quadrant.includes('Left');
|
||||
const isTop = quadrant.includes('top');
|
||||
|
||||
let baseX = isLeft ? padding : width - padding;
|
||||
let baseY = isTop ? padding : height - padding;
|
||||
const textAnchor = isLeft ? 'start' : 'end';
|
||||
|
||||
let currentY = baseY;
|
||||
|
||||
edges.forEach(edge => {
|
||||
edge.connections.forEach(conn => {
|
||||
const arrow = getArrow(conn.from, conn.to);
|
||||
const label = `${arrow} ${conn.ip} ${conn.ifaceLabel}`;
|
||||
debugLabelsGroup.append('text')
|
||||
.attr('x', baseX)
|
||||
.attr('y', currentY)
|
||||
.attr('text-anchor', textAnchor)
|
||||
.attr('dominant-baseline', isTop ? 'hanging' : 'auto')
|
||||
.attr('font-size', fontSize)
|
||||
.attr('font-family', 'SF Mono, monospace')
|
||||
.attr('fill', conn.missingIface ? 'rgba(248,113,113,0.9)' : 'rgba(255,255,255,0.85)')
|
||||
.text(label);
|
||||
currentY += isTop ? lineHeight : -lineHeight;
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// Draw nodes
|
||||
const nodesGroup = svg.append('g').attr('class', 'nodes-group');
|
||||
|
||||
@@ -968,4 +975,5 @@ function wrapLine(text: string, maxLen: number): string[] {
|
||||
from { stroke-dashoffset: 0; }
|
||||
to { stroke-dashoffset: -10; }
|
||||
}
|
||||
|
||||
</style>
|
||||
|
||||
@@ -4,4 +4,5 @@ export { default as ChatMessages } from './ChatMessages.svelte';
|
||||
export { default as ChatAttachments } from './ChatAttachments.svelte';
|
||||
export { default as ChatSidebar } from './ChatSidebar.svelte';
|
||||
export { default as ModelCard } from './ModelCard.svelte';
|
||||
export { default as MarkdownContent } from './MarkdownContent.svelte';
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -18,6 +18,10 @@
|
||||
selectedChatModel,
|
||||
debugMode,
|
||||
toggleDebugMode,
|
||||
topologyOnlyMode,
|
||||
toggleTopologyOnlyMode,
|
||||
chatSidebarVisible,
|
||||
toggleChatSidebarVisible,
|
||||
type DownloadProgress,
|
||||
type PlacementPreview
|
||||
} from '$lib/stores/app.svelte';
|
||||
@@ -37,14 +41,92 @@
|
||||
const selectedModelId = $derived(selectedPreviewModelId());
|
||||
const loadingPreviews = $derived(isLoadingPreviews());
|
||||
const debugEnabled = $derived(debugMode());
|
||||
const topologyOnlyEnabled = $derived(topologyOnlyMode());
|
||||
const sidebarVisible = $derived(chatSidebarVisible());
|
||||
|
||||
let mounted = $state(false);
|
||||
|
||||
// Instance launch state
|
||||
let models = $state<Array<{id: string, name?: string, storage_size_megabytes?: number}>>([]);
|
||||
let models = $state<Array<{id: string, name?: string, storage_size_megabytes?: number, tasks?: string[], hugging_face_id?: string}>>([]);
|
||||
|
||||
// Model tasks lookup for ChatForm - maps both short IDs and full HuggingFace IDs
|
||||
const modelTasks = $derived(() => {
|
||||
const tasks: Record<string, string[]> = {};
|
||||
for (const model of models) {
|
||||
if (model.tasks && model.tasks.length > 0) {
|
||||
// Map by short ID
|
||||
tasks[model.id] = model.tasks;
|
||||
// Also map by hugging_face_id from the API response
|
||||
if (model.hugging_face_id) {
|
||||
tasks[model.hugging_face_id] = model.tasks;
|
||||
}
|
||||
}
|
||||
}
|
||||
return tasks;
|
||||
});
|
||||
|
||||
// Helper to check if a model supports image generation
|
||||
function modelSupportsImageGeneration(modelId: string): boolean {
|
||||
const model = models.find(m => m.id === modelId || m.hugging_face_id === modelId);
|
||||
if (!model?.tasks) return false;
|
||||
return model.tasks.includes('TextToImage') || model.tasks.includes('ImageToImage');
|
||||
}
|
||||
let selectedSharding = $state<'Pipeline' | 'Tensor'>('Pipeline');
|
||||
type InstanceMeta = 'MlxRing' | 'MlxIbv' | 'MlxJaccl';
|
||||
|
||||
// Launch defaults persistence
|
||||
const LAUNCH_DEFAULTS_KEY = 'exo-launch-defaults';
|
||||
interface LaunchDefaults {
|
||||
modelId: string | null;
|
||||
sharding: 'Pipeline' | 'Tensor';
|
||||
instanceType: InstanceMeta;
|
||||
minNodes: number;
|
||||
}
|
||||
|
||||
function saveLaunchDefaults(): void {
|
||||
const defaults: LaunchDefaults = {
|
||||
modelId: selectedPreviewModelId(),
|
||||
sharding: selectedSharding,
|
||||
instanceType: selectedInstanceType,
|
||||
minNodes: selectedMinNodes,
|
||||
};
|
||||
try {
|
||||
localStorage.setItem(LAUNCH_DEFAULTS_KEY, JSON.stringify(defaults));
|
||||
} catch (e) {
|
||||
console.warn('Failed to save launch defaults:', e);
|
||||
}
|
||||
}
|
||||
|
||||
function loadLaunchDefaults(): LaunchDefaults | null {
|
||||
try {
|
||||
const stored = localStorage.getItem(LAUNCH_DEFAULTS_KEY);
|
||||
if (!stored) return null;
|
||||
return JSON.parse(stored) as LaunchDefaults;
|
||||
} catch (e) {
|
||||
console.warn('Failed to load launch defaults:', e);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
function applyLaunchDefaults(availableModels: Array<{id: string}>, maxNodes: number): void {
|
||||
const defaults = loadLaunchDefaults();
|
||||
if (!defaults) return;
|
||||
|
||||
// Apply sharding and instance type unconditionally
|
||||
selectedSharding = defaults.sharding;
|
||||
selectedInstanceType = defaults.instanceType;
|
||||
|
||||
// Apply minNodes if valid (between 1 and maxNodes)
|
||||
if (defaults.minNodes && defaults.minNodes >= 1 && defaults.minNodes <= maxNodes) {
|
||||
selectedMinNodes = defaults.minNodes;
|
||||
}
|
||||
|
||||
// Only apply model if it exists in the available models
|
||||
if (defaults.modelId && availableModels.some(m => m.id === defaults.modelId)) {
|
||||
selectPreviewModel(defaults.modelId);
|
||||
}
|
||||
}
|
||||
|
||||
let selectedInstanceType = $state<InstanceMeta>('MlxRing');
|
||||
let selectedMinNodes = $state<number>(1);
|
||||
let minNodesInitialized = $state(false);
|
||||
@@ -292,6 +374,9 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
const data = await response.json();
|
||||
// API returns { data: [{ id, name }] } format
|
||||
models = data.data || [];
|
||||
// Restore last launch defaults if available
|
||||
const currentNodeCount = topologyData() ? Object.keys(topologyData()!.nodes).length : 1;
|
||||
applyLaunchDefaults(models, currentNodeCount);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Failed to fetch models:', error);
|
||||
@@ -472,6 +557,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
|
||||
const progress = parseDownloadProgress(downloadPayload);
|
||||
if (progress) {
|
||||
// Sum all values across nodes - each node downloads independently
|
||||
totalBytes += progress.totalBytes;
|
||||
downloadedBytes += progress.downloadedBytes;
|
||||
totalSpeed += progress.speed;
|
||||
@@ -489,13 +575,17 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
return { isDownloading: false, progress: null, perNode: [] };
|
||||
}
|
||||
|
||||
// ETA = total remaining bytes / total speed across all nodes
|
||||
const remainingBytes = totalBytes - downloadedBytes;
|
||||
const etaMs = totalSpeed > 0 ? (remainingBytes / totalSpeed) * 1000 : 0;
|
||||
|
||||
return {
|
||||
isDownloading: true,
|
||||
progress: {
|
||||
totalBytes,
|
||||
downloadedBytes,
|
||||
speed: totalSpeed,
|
||||
etaMs: totalSpeed > 0 ? ((totalBytes - downloadedBytes) / totalSpeed) * 1000 : 0,
|
||||
etaMs,
|
||||
percentage: totalBytes > 0 ? (downloadedBytes / totalBytes) * 100 : 0,
|
||||
completedFiles,
|
||||
totalFiles,
|
||||
@@ -576,6 +666,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
|
||||
const progress = parseDownloadProgress(downloadPayload);
|
||||
if (progress) {
|
||||
// Sum all values across nodes - each node downloads independently
|
||||
totalBytes += progress.totalBytes;
|
||||
downloadedBytes += progress.downloadedBytes;
|
||||
totalSpeed += progress.speed;
|
||||
@@ -596,13 +687,17 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
return { isDownloading: false, progress: null, statusText: statusInfo.statusText, perNode: [] };
|
||||
}
|
||||
|
||||
// ETA = total remaining bytes / total speed across all nodes
|
||||
const remainingBytes = totalBytes - downloadedBytes;
|
||||
const etaMs = totalSpeed > 0 ? (remainingBytes / totalSpeed) * 1000 : 0;
|
||||
|
||||
return {
|
||||
isDownloading: true,
|
||||
progress: {
|
||||
totalBytes,
|
||||
downloadedBytes,
|
||||
speed: totalSpeed,
|
||||
etaMs: totalSpeed > 0 ? ((totalBytes - downloadedBytes) / totalSpeed) * 1000 : 0,
|
||||
etaMs,
|
||||
percentage: totalBytes > 0 ? (downloadedBytes / totalBytes) * 100 : 0,
|
||||
completedFiles,
|
||||
totalFiles,
|
||||
@@ -618,10 +713,12 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
function getStatusColor(statusText: string): string {
|
||||
switch (statusText) {
|
||||
case 'FAILED': return 'text-red-400';
|
||||
case 'SHUTDOWN': return 'text-gray-400';
|
||||
case 'DOWNLOADING': return 'text-blue-400';
|
||||
case 'LOADING':
|
||||
case 'WARMING UP':
|
||||
case 'WAITING': return 'text-yellow-400';
|
||||
case 'WAITING':
|
||||
case 'INITIALIZING': return 'text-yellow-400';
|
||||
case 'RUNNING': return 'text-teal-400';
|
||||
case 'READY':
|
||||
case 'LOADED': return 'text-green-400';
|
||||
@@ -644,12 +741,15 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
if (!r) return null;
|
||||
const [kind] = getTagged(r);
|
||||
const statusMap: Record<string, string> = {
|
||||
RunnerWaitingForInitialization: 'WaitingForInitialization',
|
||||
RunnerInitializingBackend: 'InitializingBackend',
|
||||
RunnerWaitingForModel: 'WaitingForModel',
|
||||
RunnerLoading: 'Loading',
|
||||
RunnerLoaded: 'Loaded',
|
||||
RunnerWarmingUp: 'WarmingUp',
|
||||
RunnerReady: 'Ready',
|
||||
RunnerRunning: 'Running',
|
||||
RunnerShutdown: 'Shutdown',
|
||||
RunnerFailed: 'Failed',
|
||||
};
|
||||
return kind ? statusMap[kind] || null : null;
|
||||
@@ -660,12 +760,15 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
|
||||
if (statuses.length === 0) return { statusText: 'UNKNOWN', statusClass: 'inactive' };
|
||||
if (has('Failed')) return { statusText: 'FAILED', statusClass: 'failed' };
|
||||
if (has('Shutdown')) return { statusText: 'SHUTDOWN', statusClass: 'inactive' };
|
||||
if (has('Loading')) return { statusText: 'LOADING', statusClass: 'starting' };
|
||||
if (has('WarmingUp')) return { statusText: 'WARMING UP', statusClass: 'starting' };
|
||||
if (has('Running')) return { statusText: 'RUNNING', statusClass: 'running' };
|
||||
if (has('Ready')) return { statusText: 'READY', statusClass: 'loaded' };
|
||||
if (has('Loaded')) return { statusText: 'LOADED', statusClass: 'loaded' };
|
||||
if (has('WaitingForModel')) return { statusText: 'WAITING', statusClass: 'starting' };
|
||||
if (has('InitializingBackend')) return { statusText: 'INITIALIZING', statusClass: 'starting' };
|
||||
if (has('WaitingForInitialization')) return { statusText: 'INITIALIZING', statusClass: 'starting' };
|
||||
|
||||
return { statusText: 'RUNNING', statusClass: 'active' };
|
||||
}
|
||||
@@ -964,6 +1067,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
|
||||
function handleSliderMouseUp() {
|
||||
isDraggingSlider = false;
|
||||
saveLaunchDefaults();
|
||||
}
|
||||
|
||||
// Handle touch events for mobile
|
||||
@@ -983,6 +1087,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
|
||||
function handleSliderTouchEnd() {
|
||||
isDraggingSlider = false;
|
||||
saveLaunchDefaults();
|
||||
}
|
||||
|
||||
const nodeCount = $derived(data ? Object.keys(data.nodes).length : 0);
|
||||
@@ -1107,16 +1212,47 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
<div class="shooting-star" style="top: 50%; left: 40%; --duration: 45s; --delay: 30s;"></div>
|
||||
</div>
|
||||
|
||||
<HeaderNav showHome={chatStarted} onHome={handleGoHome} />
|
||||
{#if !topologyOnlyEnabled}
|
||||
<HeaderNav
|
||||
showHome={chatStarted}
|
||||
onHome={handleGoHome}
|
||||
showSidebarToggle={true}
|
||||
sidebarVisible={sidebarVisible}
|
||||
onToggleSidebar={toggleChatSidebarVisible}
|
||||
/>
|
||||
{/if}
|
||||
|
||||
<!-- Main Content -->
|
||||
<main class="flex-1 flex overflow-hidden relative">
|
||||
<!-- Left: Conversation History Sidebar (always visible) -->
|
||||
<!-- Left: Conversation History Sidebar (hidden in topology-only mode or when toggled off) -->
|
||||
{#if !topologyOnlyEnabled && sidebarVisible}
|
||||
<div class="w-80 flex-shrink-0 border-r border-exo-yellow/10">
|
||||
<ChatSidebar class="h-full" />
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
{#if !chatStarted}
|
||||
{#if topologyOnlyEnabled}
|
||||
<!-- TOPOLOGY ONLY MODE: Full-screen topology -->
|
||||
<div class="flex-1 flex flex-col min-h-0 min-w-0 p-4" in:fade={{ duration: 300 }}>
|
||||
<div class="flex-1 relative bg-exo-dark-gray/40 rounded-lg overflow-hidden">
|
||||
<TopologyGraph class="w-full h-full" highlightedNodes={highlightedNodes()} />
|
||||
<!-- Exit topology-only mode button -->
|
||||
<button
|
||||
type="button"
|
||||
onclick={toggleTopologyOnlyMode}
|
||||
class="absolute bottom-4 right-4 p-2 rounded border border-exo-yellow/30 bg-exo-dark-gray/80 hover:border-exo-yellow/50 hover:bg-exo-dark-gray transition-colors cursor-pointer backdrop-blur-sm"
|
||||
title="Exit topology only mode"
|
||||
>
|
||||
<svg class="w-5 h-5 text-exo-yellow" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
|
||||
<circle cx="12" cy="5" r="2" fill="currentColor" />
|
||||
<circle cx="5" cy="19" r="2" fill="currentColor" />
|
||||
<circle cx="19" cy="19" r="2" fill="currentColor" />
|
||||
<path stroke-linecap="round" d="M12 7v5m0 0l-5 5m5-5l5 5" />
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
{:else if !chatStarted}
|
||||
<!-- WELCOME STATE: Topology + Instance Controls (no left sidebar for cleaner look) -->
|
||||
<div class="flex-1 flex overflow-visible relative" in:fade={{ duration: 300 }} out:fade={{ duration: 200 }}>
|
||||
|
||||
@@ -1137,6 +1273,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
placeholder="Ask anything"
|
||||
showHelperText={false}
|
||||
showModelSelector={true}
|
||||
modelTasks={modelTasks()}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
@@ -1300,14 +1437,15 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
{:else}
|
||||
{#each nodeProg.progress.files as f}
|
||||
{@const filePercent = Math.min(100, Math.max(0, f.percentage ?? 0))}
|
||||
{@const isFileComplete = filePercent >= 100}
|
||||
<div class="rounded border border-exo-medium-gray/30 bg-exo-black/40 p-2">
|
||||
<div class="flex items-center justify-between text-[10px] font-mono text-exo-light-gray/90">
|
||||
<span class="truncate pr-2">{f.name}</span>
|
||||
<span class="text-white/80">{filePercent.toFixed(1)}%</span>
|
||||
<span class={isFileComplete ? 'text-green-400' : 'text-white/80'}>{filePercent.toFixed(1)}%</span>
|
||||
</div>
|
||||
<div class="relative h-1 bg-exo-black/60 rounded-sm overflow-hidden mt-1">
|
||||
<div
|
||||
class="absolute inset-y-0 left-0 bg-gradient-to-r from-exo-yellow to-exo-yellow/70 transition-all duration-300"
|
||||
class="absolute inset-y-0 left-0 bg-gradient-to-r {isFileComplete ? 'from-green-500 to-green-400' : 'from-exo-yellow to-exo-yellow/70'} transition-all duration-300"
|
||||
style="width: {filePercent.toFixed(1)}%"
|
||||
></div>
|
||||
</div>
|
||||
@@ -1357,8 +1495,18 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
{@const foundModel = models.find(m => m.id === selectedModelId)}
|
||||
{#if foundModel}
|
||||
{@const sizeGB = getModelSizeGB(foundModel)}
|
||||
<span class="flex items-center justify-between gap-2 w-full pr-4">
|
||||
<span class="text-exo-light-gray truncate">{foundModel.name || foundModel.id}</span>
|
||||
{@const isImageModel = modelSupportsImageGeneration(foundModel.id)}
|
||||
<span class="flex items-center justify-between gap-2 w-full pr-4">
|
||||
<span class="flex items-center gap-2 text-exo-light-gray truncate">
|
||||
{#if isImageModel}
|
||||
<svg class="w-4 h-4 flex-shrink-0 text-exo-yellow" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
|
||||
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
|
||||
<circle cx="8.5" cy="8.5" r="1.5"/>
|
||||
<polyline points="21 15 16 10 5 21"/>
|
||||
</svg>
|
||||
{/if}
|
||||
<span class="truncate">{foundModel.name || foundModel.id}</span>
|
||||
</span>
|
||||
<span class="text-white/50 text-xs flex-shrink-0">{sizeGB >= 1 ? sizeGB.toFixed(0) : sizeGB.toFixed(1)}GB</span>
|
||||
</span>
|
||||
{:else}
|
||||
@@ -1403,11 +1551,13 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
) as model}
|
||||
{@const sizeGB = getModelSizeGB(model)}
|
||||
{@const modelCanFit = hasEnoughMemory(model)}
|
||||
{@const isImageModel = modelSupportsImageGeneration(model.id)}
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => {
|
||||
if (modelCanFit) {
|
||||
selectPreviewModel(model.id);
|
||||
saveLaunchDefaults();
|
||||
isModelDropdownOpen = false;
|
||||
modelDropdownSearch = '';
|
||||
}
|
||||
@@ -1421,7 +1571,16 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
: 'text-white/30 cursor-default'
|
||||
}"
|
||||
>
|
||||
<span class="truncate">{model.name || model.id}</span>
|
||||
<span class="flex items-center gap-2 truncate flex-1">
|
||||
{#if isImageModel}
|
||||
<svg class="w-4 h-4 flex-shrink-0 text-exo-yellow" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2" aria-label="Image generation model">
|
||||
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
|
||||
<circle cx="8.5" cy="8.5" r="1.5"/>
|
||||
<polyline points="21 15 16 10 5 21"/>
|
||||
</svg>
|
||||
{/if}
|
||||
<span class="truncate">{model.name || model.id}</span>
|
||||
</span>
|
||||
<span class="flex-shrink-0 text-xs {modelCanFit ? 'text-white/50' : 'text-red-400/60'}">
|
||||
{sizeGB >= 1 ? sizeGB.toFixed(0) : sizeGB.toFixed(1)}GB
|
||||
</span>
|
||||
@@ -1441,7 +1600,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
<div class="text-xs text-white/70 font-mono mb-2">Sharding:</div>
|
||||
<div class="flex gap-2">
|
||||
<button
|
||||
onclick={() => selectedSharding = 'Pipeline'}
|
||||
onclick={() => { selectedSharding = 'Pipeline'; saveLaunchDefaults(); }}
|
||||
class="flex items-center gap-2 py-2 px-4 text-sm font-mono border rounded transition-all duration-200 cursor-pointer {selectedSharding === 'Pipeline' ? 'bg-transparent text-exo-yellow border-exo-yellow' : 'bg-transparent text-white/70 border-exo-medium-gray/50 hover:border-exo-yellow/50'}"
|
||||
>
|
||||
<span class="w-4 h-4 rounded-full border-2 flex items-center justify-center {selectedSharding === 'Pipeline' ? 'border-exo-yellow' : 'border-exo-medium-gray'}">
|
||||
@@ -1452,7 +1611,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
Pipeline
|
||||
</button>
|
||||
<button
|
||||
onclick={() => selectedSharding = 'Tensor'}
|
||||
onclick={() => { selectedSharding = 'Tensor'; saveLaunchDefaults(); }}
|
||||
class="flex items-center gap-2 py-2 px-4 text-sm font-mono border rounded transition-all duration-200 cursor-pointer {selectedSharding === 'Tensor' ? 'bg-transparent text-exo-yellow border-exo-yellow' : 'bg-transparent text-white/70 border-exo-medium-gray/50 hover:border-exo-yellow/50'}"
|
||||
>
|
||||
<span class="w-4 h-4 rounded-full border-2 flex items-center justify-center {selectedSharding === 'Tensor' ? 'border-exo-yellow' : 'border-exo-medium-gray'}">
|
||||
@@ -1470,7 +1629,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
<div class="text-xs text-white/70 font-mono mb-2">Instance Type:</div>
|
||||
<div class="flex gap-2">
|
||||
<button
|
||||
onclick={() => selectedInstanceType = 'MlxRing'}
|
||||
onclick={() => { selectedInstanceType = 'MlxRing'; saveLaunchDefaults(); }}
|
||||
class="flex items-center gap-2 py-2 px-4 text-sm font-mono border rounded transition-all duration-200 cursor-pointer {selectedInstanceType === 'MlxRing' ? 'bg-transparent text-exo-yellow border-exo-yellow' : 'bg-transparent text-white/70 border-exo-medium-gray/50 hover:border-exo-yellow/50'}"
|
||||
>
|
||||
<span class="w-4 h-4 rounded-full border-2 flex items-center justify-center {selectedInstanceType === 'MlxRing' ? 'border-exo-yellow' : 'border-exo-medium-gray'}">
|
||||
@@ -1481,7 +1640,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
MLX Ring
|
||||
</button>
|
||||
<button
|
||||
onclick={() => selectedInstanceType = 'MlxIbv'}
|
||||
onclick={() => { selectedInstanceType = 'MlxIbv'; saveLaunchDefaults(); }}
|
||||
class="flex items-center gap-2 py-2 px-4 text-sm font-mono border rounded transition-all duration-200 cursor-pointer {selectedInstanceType === 'MlxIbv' ? 'bg-transparent text-exo-yellow border-exo-yellow' : 'bg-transparent text-white/70 border-exo-medium-gray/50 hover:border-exo-yellow/50'}"
|
||||
>
|
||||
<span class="w-4 h-4 rounded-full border-2 flex items-center justify-center {selectedInstanceType === 'MlxIbv' ? 'border-exo-yellow' : 'border-exo-medium-gray'}">
|
||||
@@ -1611,14 +1770,14 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
in:fade={{ duration: 300, delay: 100 }}
|
||||
>
|
||||
<div class="flex-1 overflow-y-auto px-8 py-6" bind:this={chatScrollRef}>
|
||||
<div class="max-w-3xl mx-auto">
|
||||
<div class="max-w-7xl mx-auto">
|
||||
<ChatMessages scrollParent={chatScrollRef} />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="flex-shrink-0 px-8 pb-6 pt-4 bg-gradient-to-t from-exo-black via-exo-black to-transparent">
|
||||
<div class="max-w-3xl mx-auto">
|
||||
<ChatForm placeholder="Ask anything" showModelSelector={true} />
|
||||
<div class="max-w-7xl mx-auto">
|
||||
<ChatForm placeholder="Ask anything" showModelSelector={true} modelTasks={modelTasks()} />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
@@ -1655,7 +1814,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
<!-- Panel Header -->
|
||||
<div class="flex items-center gap-2 mb-4">
|
||||
<div class="w-2 h-2 bg-exo-yellow rounded-full shadow-[0_0_8px_rgba(255,215,0,0.6)] animate-pulse"></div>
|
||||
<h3 class="text-sm text-exo-yellow font-mono tracking-[0.2em] uppercase">Instances</h3>
|
||||
<h3 class="text-xs text-exo-yellow font-mono tracking-[0.2em] uppercase">Instances</h3>
|
||||
<div class="flex-1 h-px bg-gradient-to-r from-exo-yellow/30 to-transparent"></div>
|
||||
</div>
|
||||
<div class="space-y-3 max-h-72 overflow-y-auto pr-1">
|
||||
@@ -1701,28 +1860,28 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
<div class="flex justify-between items-start mb-2 pl-2">
|
||||
<div class="flex items-center gap-2">
|
||||
<div class="w-1.5 h-1.5 {isDownloading ? 'bg-blue-400 animate-pulse' : isFailed ? 'bg-red-400' : isLoading ? 'bg-yellow-400 animate-pulse' : isReady ? 'bg-green-400' : 'bg-teal-400'} rounded-full shadow-[0_0_6px_currentColor]"></div>
|
||||
<span class="text-exo-light-gray font-mono text-xs tracking-wider">{id.slice(0, 8).toUpperCase()}</span>
|
||||
<span class="text-exo-light-gray font-mono text-sm tracking-wider">{id.slice(0, 8).toUpperCase()}</span>
|
||||
</div>
|
||||
<button
|
||||
onclick={() => deleteInstance(id)}
|
||||
class="text-xs px-2 py-1 font-mono tracking-wider uppercase border border-red-500/30 text-red-400/80 hover:bg-red-500/20 hover:text-red-400 hover:border-red-500/50 transition-all duration-200 cursor-pointer"
|
||||
class="text-xs px-2 py-1 font-mono tracking-wider uppercase border border-red-500/30 text-red-400 hover:bg-red-500/20 hover:text-red-400 hover:border-red-500/50 transition-all duration-200 cursor-pointer"
|
||||
>
|
||||
DELETE
|
||||
</button>
|
||||
</div>
|
||||
<div class="pl-2">
|
||||
<div class="text-exo-yellow text-sm font-mono tracking-wide truncate">{getInstanceModelId(instance)}</div>
|
||||
<div class="text-exo-yellow text-xs font-mono tracking-wide truncate">{getInstanceModelId(instance)}</div>
|
||||
<div class="text-white/60 text-xs font-mono">Strategy: <span class="text-white/80">{instanceInfo.sharding} ({instanceInfo.instanceType})</span></div>
|
||||
{#if instanceModelId && instanceModelId !== 'Unknown' && instanceModelId !== 'Unknown Model'}
|
||||
<a
|
||||
class="inline-flex items-center gap-1 text-[10px] text-white/60 hover:text-exo-yellow transition-colors mt-0.5"
|
||||
class="inline-flex items-center gap-1 text-[11px] text-white/60 hover:text-exo-yellow transition-colors mt-1"
|
||||
href={`https://huggingface.co/${instanceModelId}`}
|
||||
target="_blank"
|
||||
rel="noreferrer noopener"
|
||||
aria-label="View model on Hugging Face"
|
||||
>
|
||||
<span>Hugging Face</span>
|
||||
<svg class="w-3 h-3" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
|
||||
<svg class="w-3.5 h-3.5" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
|
||||
<path d="M14 3h7v7"/>
|
||||
<path d="M10 14l11-11"/>
|
||||
<path d="M21 14v6a1 1 0 0 1-1 1h-16a1 1 0 0 1-1-1v-16a1 1 0 0 1 1-1h6"/>
|
||||
@@ -1733,68 +1892,84 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
<div class="text-white/60 text-xs font-mono">{instanceInfo.nodeNames.join(', ')}</div>
|
||||
{/if}
|
||||
{#if debugEnabled && instanceConnections.length > 0}
|
||||
<div class="mt-1 space-y-0.5">
|
||||
{#each instanceConnections as conn}
|
||||
<div class="text-[10px] leading-snug font-mono text-white/70">
|
||||
<span>{conn.from} -> {conn.to}: {conn.ip}</span>
|
||||
<span class="{conn.missingIface ? 'text-red-400' : 'text-white/60'}"> ({conn.ifaceLabel})</span>
|
||||
</div>
|
||||
{/each}
|
||||
<div class="mt-2 space-y-1">
|
||||
{#each instanceConnections as conn}
|
||||
<div class="text-[11px] leading-snug font-mono text-white/70">
|
||||
<span>{conn.from} -> {conn.to}: {conn.ip}</span>
|
||||
<span class="{conn.missingIface ? 'text-red-400' : 'text-white/60'}"> ({conn.ifaceLabel})</span>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Download Progress -->
|
||||
{#if downloadInfo.isDownloading && downloadInfo.progress}
|
||||
<div class="mt-2 space-y-1">
|
||||
<div class="flex justify-between text-xs font-mono">
|
||||
<span class="text-blue-400">{downloadInfo.progress.percentage.toFixed(1)}%</span>
|
||||
<span class="text-exo-light-gray">{formatBytes(downloadInfo.progress.downloadedBytes)}/{formatBytes(downloadInfo.progress.totalBytes)}</span>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Download Progress -->
|
||||
{#if downloadInfo.isDownloading && downloadInfo.progress}
|
||||
<div class="mt-2 space-y-1">
|
||||
<div class="flex justify-between text-sm font-mono">
|
||||
<span class="text-blue-400">{downloadInfo.progress.percentage.toFixed(1)}%</span>
|
||||
<span class="text-exo-light-gray">{formatBytes(downloadInfo.progress.downloadedBytes)}/{formatBytes(downloadInfo.progress.totalBytes)}</span>
|
||||
</div>
|
||||
<div class="relative h-1 bg-exo-black/60 rounded-sm overflow-hidden">
|
||||
<div
|
||||
class="absolute inset-y-0 left-0 bg-gradient-to-r from-blue-500 to-blue-400 transition-all duration-300"
|
||||
style="width: {downloadInfo.progress.percentage}%"
|
||||
></div>
|
||||
</div>
|
||||
<div class="flex justify-between text-xs font-mono text-exo-light-gray">
|
||||
<span>{formatSpeed(downloadInfo.progress.speed)}</span>
|
||||
<span>ETA: {formatEta(downloadInfo.progress.etaMs)}</span>
|
||||
<span>{downloadInfo.progress.completedFiles}/{downloadInfo.progress.totalFiles} files</span>
|
||||
</div>
|
||||
<div class="relative h-1.5 bg-exo-black/60 rounded-sm overflow-hidden">
|
||||
<div
|
||||
class="absolute inset-y-0 left-0 bg-gradient-to-r from-blue-500 to-blue-400 transition-all duration-300"
|
||||
style="width: {downloadInfo.progress.percentage}%"
|
||||
></div>
|
||||
</div>
|
||||
{#if downloadInfo.perNode.length > 0}
|
||||
<div class="mt-2 space-y-1.5 max-h-48 overflow-y-auto pr-1">
|
||||
{#each downloadInfo.perNode as nodeProg}
|
||||
<div class="rounded border border-exo-medium-gray/40 bg-exo-black/30 p-2">
|
||||
<div class="flex items-center justify-between text-[11px] font-mono text-exo-light-gray mb-1">
|
||||
<div class="flex justify-between text-xs font-mono text-exo-light-gray">
|
||||
<span>{formatSpeed(downloadInfo.progress.speed)}</span>
|
||||
<span>ETA: {formatEta(downloadInfo.progress.etaMs)}</span>
|
||||
<span>{downloadInfo.progress.completedFiles}/{downloadInfo.progress.totalFiles} files</span>
|
||||
</div>
|
||||
</div>
|
||||
{#if downloadInfo.perNode.length > 0}
|
||||
<div class="mt-2 space-y-2 max-h-48 overflow-y-auto pr-1">
|
||||
{#each downloadInfo.perNode as nodeProg}
|
||||
{@const nodePercent = Math.min(100, Math.max(0, nodeProg.progress.percentage))}
|
||||
{@const isExpanded = instanceDownloadExpandedNodes.has(nodeProg.nodeId)}
|
||||
<div class="rounded border border-exo-medium-gray/40 bg-exo-black/30 p-2">
|
||||
<button
|
||||
type="button"
|
||||
class="w-full text-left space-y-1.5"
|
||||
onclick={() => toggleInstanceDownloadDetails(nodeProg.nodeId)}
|
||||
>
|
||||
<div class="flex items-center justify-between text-[11px] font-mono text-exo-light-gray">
|
||||
<span class="text-white/80 truncate pr-2">{nodeProg.nodeName}</span>
|
||||
<span class="text-blue-300">{Math.min(100, Math.max(0, nodeProg.progress.percentage)).toFixed(1)}%</span>
|
||||
<span class="flex items-center gap-1 text-blue-300">
|
||||
{nodePercent.toFixed(1)}%
|
||||
<svg class="w-3 h-3 text-exo-light-gray" viewBox="0 0 20 20" fill="none" stroke="currentColor" stroke-width="2">
|
||||
<path d="M6 8l4 4 4-4" class={isExpanded ? 'transform rotate-180 origin-center transition-transform duration-150' : 'transition-transform duration-150'}></path>
|
||||
</svg>
|
||||
</span>
|
||||
</div>
|
||||
<div class="relative h-1 bg-exo-black/60 rounded-sm overflow-hidden mb-1.5">
|
||||
<div class="relative h-1.5 bg-exo-black/60 rounded-sm overflow-hidden">
|
||||
<div
|
||||
class="absolute inset-y-0 left-0 bg-blue-500/80 transition-all duration-300"
|
||||
style="width: {Math.min(100, Math.max(0, nodeProg.progress.percentage)).toFixed(1)}%"
|
||||
class="absolute inset-y-0 left-0 bg-gradient-to-r from-blue-500 to-blue-400 transition-all duration-300"
|
||||
style="width: {nodePercent.toFixed(1)}%"
|
||||
></div>
|
||||
</div>
|
||||
<div class="flex items-center justify-between text-[11px] font-mono text-exo-light-gray mb-1">
|
||||
<div class="flex items-center justify-between text-[11px] font-mono text-exo-light-gray">
|
||||
<span>{formatBytes(nodeProg.progress.downloadedBytes)} / {formatBytes(nodeProg.progress.totalBytes)}</span>
|
||||
<span>{formatSpeed(nodeProg.progress.speed)} • ETA {formatEta(nodeProg.progress.etaMs)}</span>
|
||||
</div>
|
||||
{#if nodeProg.progress.files.length > 0}
|
||||
{@const inProgressFiles = nodeProg.progress.files.filter(f => (f.percentage ?? 0) < 100)}
|
||||
{@const completedFiles = nodeProg.progress.files.filter(f => (f.percentage ?? 0) >= 100)}
|
||||
{#if inProgressFiles.length > 0}
|
||||
<div class="space-y-1">
|
||||
{#each inProgressFiles as f}
|
||||
<div class="text-[10px] font-mono text-exo-light-gray/80">
|
||||
<div class="flex items-center justify-between">
|
||||
</button>
|
||||
|
||||
{#if isExpanded}
|
||||
<div class="mt-2 space-y-1.5">
|
||||
{#if nodeProg.progress.files.length === 0}
|
||||
<div class="text-[11px] font-mono text-exo-light-gray/70">No file details reported.</div>
|
||||
{:else}
|
||||
{#each nodeProg.progress.files as f}
|
||||
{@const filePercent = Math.min(100, Math.max(0, f.percentage ?? 0))}
|
||||
{@const isFileComplete = filePercent >= 100}
|
||||
<div class="rounded border border-exo-medium-gray/30 bg-exo-black/40 p-2">
|
||||
<div class="flex items-center justify-between text-[10px] font-mono text-exo-light-gray/90">
|
||||
<span class="truncate pr-2">{f.name}</span>
|
||||
<span class="text-white/70">{Math.min(100, Math.max(0, f.percentage)).toFixed(1)}%</span>
|
||||
<span class={isFileComplete ? 'text-green-400' : 'text-white/80'}>{filePercent.toFixed(1)}%</span>
|
||||
</div>
|
||||
<div class="relative h-1 bg-exo-black/50 rounded-sm overflow-hidden mt-0.5">
|
||||
<div class="relative h-1 bg-exo-black/60 rounded-sm overflow-hidden mt-1">
|
||||
<div
|
||||
class="absolute inset-y-0 left-0 bg-gradient-to-r from-exo-yellow to-exo-yellow/70"
|
||||
style="width: {Math.min(100, Math.max(0, f.percentage)).toFixed(1)}%"
|
||||
class="absolute inset-y-0 left-0 bg-gradient-to-r {isFileComplete ? 'from-green-500 to-green-400' : 'from-exo-yellow to-exo-yellow/70'} transition-all duration-300"
|
||||
style="width: {filePercent.toFixed(1)}%"
|
||||
></div>
|
||||
</div>
|
||||
<div class="flex items-center justify-between text-[10px] text-exo-light-gray/70 mt-0.5">
|
||||
@@ -1803,27 +1978,17 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
</div>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
{#if completedFiles.length > 0}
|
||||
<div class="pt-1 space-y-0.5">
|
||||
{#each completedFiles as f}
|
||||
<div class="text-[10px] font-mono text-exo-light-gray/70 flex items-center justify-between">
|
||||
<span class="truncate pr-2">{f.name}</span>
|
||||
<span class="text-white/60">100%</span>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
{/if}
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
<div class="text-sm text-blue-400 font-mono tracking-wider mt-1">DOWNLOADING</div>
|
||||
{:else}
|
||||
<div class="text-sm {getStatusColor(downloadInfo.statusText)} font-mono tracking-wider mt-1">{downloadInfo.statusText}</div>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
<div class="text-xs text-blue-400 font-mono tracking-wider mt-1">DOWNLOADING</div>
|
||||
{:else}
|
||||
<div class="text-xs {getStatusColor(downloadInfo.statusText)} font-mono tracking-wider mt-1">{downloadInfo.statusText}</div>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -345,13 +345,19 @@
|
||||
<div class="rounded border border-exo-medium-gray/30 bg-exo-dark-gray/60 p-3 space-y-2">
|
||||
<div class="flex items-center justify-between gap-3">
|
||||
<div class="min-w-0 space-y-0.5">
|
||||
<div class="text-sm font-mono text-white truncate">{model.prettyName ?? model.modelId}</div>
|
||||
<div class="text-[11px] text-exo-light-gray font-mono truncate">
|
||||
{model.modelId}
|
||||
</div>
|
||||
<div class="text-[11px] text-exo-light-gray font-mono">
|
||||
{formatBytes(model.downloadedBytes)} / {formatBytes(model.totalBytes)}
|
||||
</div>
|
||||
<div
|
||||
class="text-xs font-mono text-white truncate"
|
||||
title={model.prettyName ?? model.modelId}
|
||||
>{model.prettyName ?? model.modelId}</div>
|
||||
<div
|
||||
class="text-[10px] text-exo-light-gray font-mono truncate"
|
||||
title={model.modelId}
|
||||
>{model.modelId}</div>
|
||||
{#if model.status !== 'completed'}
|
||||
<div class="text-[11px] text-exo-light-gray font-mono">
|
||||
{formatBytes(model.downloadedBytes)} / {formatBytes(model.totalBytes)}
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
<div class="flex items-center gap-2">
|
||||
<span class="text-xs font-mono {pct >= 100 ? 'text-green-400' : pct <= 0 ? 'text-red-400' : 'text-exo-yellow'}">
|
||||
@@ -426,14 +432,14 @@
|
||||
<style>
|
||||
.downloads-grid {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(auto-fill, minmax(260px, 1fr));
|
||||
grid-template-columns: repeat(auto-fill, minmax(320px, 1fr));
|
||||
}
|
||||
@media (min-width: 1024px) {
|
||||
.downloads-grid {
|
||||
grid-template-columns: repeat(3, minmax(0, 1fr));
|
||||
}
|
||||
}
|
||||
@media (min-width: 1440px) {
|
||||
@media (min-width: 1600px) {
|
||||
.downloads-grid {
|
||||
grid-template-columns: repeat(4, minmax(0, 1fr));
|
||||
}
|
||||
|
||||
@@ -29,10 +29,14 @@ dependencies = [
|
||||
"exo_pyo3_bindings", # rust bindings
|
||||
"anyio==4.11.0",
|
||||
"bidict>=0.23.1",
|
||||
"mlx>=0.30.1",
|
||||
"mlx>=0.30.1; sys_platform == 'darwin'",
|
||||
"mlx[cpu]>=0.30.1; sys_platform == 'linux'",
|
||||
"mlx-lm>=0.28.3",
|
||||
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
|
||||
"hypercorn>=0.18.0",
|
||||
"openai-harmony>=0.0.8",
|
||||
"pillow>=11.0,<12.0", # compatibility with mflux
|
||||
"mflux>=0.12.1",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import argparse
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import signal
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Self
|
||||
@@ -27,7 +28,7 @@ from exo.worker.main import Worker
|
||||
@dataclass
|
||||
class Node:
|
||||
router: Router
|
||||
worker: Worker
|
||||
worker: Worker | None
|
||||
election: Election # Every node participates in election, as we do want a node to become master even if it isn't a master candidate if no master candidates are present.
|
||||
election_result_receiver: Receiver[ElectionResult]
|
||||
master: Master | None
|
||||
@@ -61,15 +62,19 @@ class Node:
|
||||
else:
|
||||
api = None
|
||||
|
||||
worker = Worker(
|
||||
node_id,
|
||||
session_id,
|
||||
exo_shard_downloader(),
|
||||
connection_message_receiver=router.receiver(topics.CONNECTION_MESSAGES),
|
||||
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
|
||||
local_event_sender=router.sender(topics.LOCAL_EVENTS),
|
||||
command_sender=router.sender(topics.COMMANDS),
|
||||
)
|
||||
if not args.no_worker:
|
||||
worker = Worker(
|
||||
node_id,
|
||||
session_id,
|
||||
exo_shard_downloader(),
|
||||
connection_message_receiver=router.receiver(topics.CONNECTION_MESSAGES),
|
||||
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
|
||||
local_event_sender=router.sender(topics.LOCAL_EVENTS),
|
||||
command_sender=router.sender(topics.COMMANDS),
|
||||
)
|
||||
else:
|
||||
worker = None
|
||||
|
||||
# We start every node with a master
|
||||
master = Master(
|
||||
node_id,
|
||||
@@ -99,8 +104,9 @@ class Node:
|
||||
async with self._tg as tg:
|
||||
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
|
||||
tg.start_soon(self.router.run)
|
||||
tg.start_soon(self.worker.run)
|
||||
tg.start_soon(self.election.run)
|
||||
if self.worker:
|
||||
tg.start_soon(self.worker.run)
|
||||
if self.master:
|
||||
tg.start_soon(self.master.run)
|
||||
if self.api:
|
||||
@@ -194,6 +200,7 @@ def main():
|
||||
# TODO: Refactor the current verbosity system
|
||||
logger_setup(EXO_LOG, args.verbosity)
|
||||
logger.info("Starting EXO")
|
||||
logger.info(f"EXO_LIBP2P_NAMESPACE: {os.getenv('EXO_LIBP2P_NAMESPACE')}")
|
||||
|
||||
node = anyio.run(Node.create, args)
|
||||
anyio.run(node.run)
|
||||
@@ -207,6 +214,7 @@ class Args(CamelCaseModel):
|
||||
spawn_api: bool = False
|
||||
api_port: PositiveInt = 52415
|
||||
tb_only: bool = False
|
||||
no_worker: bool = False
|
||||
|
||||
@classmethod
|
||||
def parse(cls) -> Self:
|
||||
@@ -244,6 +252,10 @@ class Args(CamelCaseModel):
|
||||
dest="api_port",
|
||||
default=52415,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-worker",
|
||||
action="store_true",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
return cls(**vars(args)) # pyright: ignore[reportAny] - We are intentionally validating here, we can't do it statically
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
import base64
|
||||
import json
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import cast
|
||||
from typing import Literal, cast
|
||||
|
||||
import anyio
|
||||
from anyio import create_task_group
|
||||
from anyio.abc import TaskGroup
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
@@ -13,19 +15,32 @@ from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType
|
||||
from hypercorn.config import Config
|
||||
from hypercorn.typing import ASGIFramework
|
||||
from loguru import logger
|
||||
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
|
||||
HarmonyEncodingName,
|
||||
Role,
|
||||
StreamableParser,
|
||||
load_harmony_encoding,
|
||||
)
|
||||
|
||||
from exo.master.placement import place_instance as get_instance_placements
|
||||
from exo.shared.apply import apply
|
||||
from exo.shared.constants import EXO_MAX_CHUNK_SIZE
|
||||
from exo.shared.election import ElectionMessage
|
||||
from exo.shared.logging import InterceptLogger
|
||||
from exo.shared.models.model_cards import MODEL_CARDS
|
||||
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard
|
||||
from exo.shared.models.model_meta import get_model_meta
|
||||
from exo.shared.types.api import (
|
||||
ChatCompletionChoice,
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionResponse,
|
||||
CreateInstanceParams,
|
||||
CreateInstanceResponse,
|
||||
DeleteInstanceResponse,
|
||||
FinishReason,
|
||||
ImageData,
|
||||
ImageEditsInternalParams,
|
||||
ImageGenerationResponse,
|
||||
ImageGenerationTaskParams,
|
||||
ModelList,
|
||||
ModelListModel,
|
||||
PlaceInstanceParams,
|
||||
@@ -33,14 +48,17 @@ from exo.shared.types.api import (
|
||||
PlacementPreviewResponse,
|
||||
StreamingChoiceResponse,
|
||||
)
|
||||
from exo.shared.types.chunks import TokenChunk
|
||||
from exo.shared.types.chunks import ImageChunk, InputImageChunk, TokenChunk
|
||||
from exo.shared.types.commands import (
|
||||
ChatCompletion,
|
||||
Command,
|
||||
CreateInstance,
|
||||
DeleteInstance,
|
||||
ForwarderCommand,
|
||||
ImageEdits,
|
||||
ImageGeneration,
|
||||
PlaceInstance,
|
||||
SendInputChunk,
|
||||
TaskFinished,
|
||||
)
|
||||
from exo.shared.types.common import CommandId, NodeId, SessionId
|
||||
@@ -56,7 +74,7 @@ from exo.utils.channels import Receiver, Sender, channel
|
||||
from exo.utils.dashboard_path import find_dashboard
|
||||
from exo.utils.event_buffer import OrderedBuffer
|
||||
|
||||
HIDE_THINKING = False
|
||||
encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
|
||||
|
||||
|
||||
def chunk_to_response(
|
||||
@@ -76,12 +94,23 @@ def chunk_to_response(
|
||||
)
|
||||
|
||||
|
||||
async def resolve_model_meta(model_id: str) -> ModelMetadata:
|
||||
def get_model_card(model_id: str) -> ModelCard | None:
|
||||
if model_id in MODEL_CARDS:
|
||||
model_card = MODEL_CARDS[model_id]
|
||||
return model_card
|
||||
|
||||
for _, model_card in MODEL_CARDS.items():
|
||||
if model_id == model_card.model_id:
|
||||
return model_card
|
||||
|
||||
|
||||
async def resolve_model_meta(model_id: str) -> ModelMetadata:
|
||||
model_card = get_model_card(model_id)
|
||||
|
||||
if model_card is not None:
|
||||
return model_card.metadata
|
||||
else:
|
||||
return await get_model_meta(model_id)
|
||||
|
||||
return await get_model_meta(model_id)
|
||||
|
||||
|
||||
class API:
|
||||
@@ -125,6 +154,7 @@ class API:
|
||||
)
|
||||
|
||||
self._chat_completion_queues: dict[CommandId, Sender[TokenChunk]] = {}
|
||||
self._image_generation_queues: dict[CommandId, Sender[ImageChunk]] = {}
|
||||
self._tg: TaskGroup | None = None
|
||||
|
||||
def reset(self, new_session_id: SessionId, result_clock: int):
|
||||
@@ -133,6 +163,7 @@ class API:
|
||||
self.session_id = new_session_id
|
||||
self.event_buffer = OrderedBuffer[Event]()
|
||||
self._chat_completion_queues = {}
|
||||
self._image_generation_queues = {}
|
||||
self.unpause(result_clock)
|
||||
|
||||
def unpause(self, result_clock: int):
|
||||
@@ -161,7 +192,13 @@ class API:
|
||||
self.app.delete("/instance/{instance_id}")(self.delete_instance)
|
||||
self.app.get("/models")(self.get_models)
|
||||
self.app.get("/v1/models")(self.get_models)
|
||||
self.app.post("/v1/chat/completions")(self.chat_completions)
|
||||
self.app.post("/v1/chat/completions", response_model=None)(
|
||||
self.chat_completions
|
||||
)
|
||||
self.app.post("/v1/images/generations", response_model=None)(
|
||||
self.image_generations
|
||||
)
|
||||
self.app.post("/v1/images/edits", response_model=None)(self.image_edits)
|
||||
self.app.get("/state")(lambda: self.state)
|
||||
self.app.get("/events")(lambda: self._event_log)
|
||||
|
||||
@@ -177,17 +214,32 @@ class API:
|
||||
return CreateInstanceResponse(
|
||||
message="Command received.",
|
||||
command_id=command.command_id,
|
||||
model_meta=command.model_meta,
|
||||
)
|
||||
|
||||
async def create_instance(
|
||||
self, payload: CreateInstanceParams
|
||||
) -> CreateInstanceResponse:
|
||||
command = CreateInstance(instance=payload.instance)
|
||||
instance = payload.instance
|
||||
model_meta = await resolve_model_meta(instance.shard_assignments.model_id)
|
||||
required_memory = model_meta.storage_size
|
||||
available_memory = self._calculate_total_available_memory()
|
||||
|
||||
if required_memory > available_memory:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Insufficient memory to create instance. Required: {required_memory.in_gb:.1f}GB, Available: {available_memory.in_gb:.1f}GB",
|
||||
)
|
||||
|
||||
command = CreateInstance(
|
||||
instance=instance,
|
||||
)
|
||||
await self._send(command)
|
||||
|
||||
return CreateInstanceResponse(
|
||||
message="Command received.",
|
||||
command_id=command.command_id,
|
||||
model_meta=model_meta,
|
||||
)
|
||||
|
||||
async def get_placement(
|
||||
@@ -352,32 +404,52 @@ class API:
|
||||
instance_id=instance_id,
|
||||
)
|
||||
|
||||
async def _generate_chat_stream(
|
||||
self, command_id: CommandId
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Generate chat completion stream as JSON strings."""
|
||||
async def _process_gpt_oss(self, token_chunks: Receiver[TokenChunk]):
|
||||
stream = StreamableParser(encoding, role=Role.ASSISTANT)
|
||||
thinking = False
|
||||
|
||||
async for chunk in token_chunks:
|
||||
stream.process(chunk.token_id)
|
||||
|
||||
delta = stream.last_content_delta
|
||||
ch = stream.current_channel
|
||||
|
||||
if ch == "analysis" and not thinking:
|
||||
thinking = True
|
||||
yield chunk.model_copy(update={"text": "<think>"})
|
||||
|
||||
if ch != "analysis" and thinking:
|
||||
thinking = False
|
||||
yield chunk.model_copy(update={"text": "</think>"})
|
||||
|
||||
if delta:
|
||||
yield chunk.model_copy(update={"text": delta})
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
if thinking:
|
||||
yield chunk.model_copy(update={"text": "</think>"})
|
||||
yield chunk
|
||||
break
|
||||
|
||||
async def _chat_chunk_stream(
|
||||
self, command_id: CommandId, parse_gpt_oss: bool
|
||||
) -> AsyncGenerator[TokenChunk, None]:
|
||||
"""Yield `TokenChunk`s for a given command until completion."""
|
||||
|
||||
try:
|
||||
self._chat_completion_queues[command_id], recv = channel[TokenChunk]()
|
||||
|
||||
is_thinking = False
|
||||
with recv as token_chunks:
|
||||
async for chunk in token_chunks:
|
||||
if HIDE_THINKING:
|
||||
if chunk.text == "<think>":
|
||||
is_thinking = True
|
||||
if chunk.text == "</think>":
|
||||
is_thinking = False
|
||||
chunk_response: ChatCompletionResponse = chunk_to_response(
|
||||
chunk, command_id
|
||||
)
|
||||
if not (is_thinking and HIDE_THINKING):
|
||||
logger.debug(f"chunk_response: {chunk_response}")
|
||||
yield f"data: {chunk_response.model_dump_json()}\n\n"
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
yield "data: [DONE]\n\n"
|
||||
break
|
||||
if parse_gpt_oss:
|
||||
async for chunk in self._process_gpt_oss(token_chunks):
|
||||
yield chunk
|
||||
if chunk.finish_reason is not None:
|
||||
break
|
||||
else:
|
||||
async for chunk in token_chunks:
|
||||
yield chunk
|
||||
if chunk.finish_reason is not None:
|
||||
break
|
||||
|
||||
except anyio.get_cancelled_exc_class():
|
||||
# TODO: TaskCancelled
|
||||
@@ -392,6 +464,59 @@ class API:
|
||||
await self._send(command)
|
||||
del self._chat_completion_queues[command_id]
|
||||
|
||||
async def _generate_chat_stream(
|
||||
self, command_id: CommandId, parse_gpt_oss: bool
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Generate chat completion stream as JSON strings."""
|
||||
|
||||
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
|
||||
chunk_response: ChatCompletionResponse = chunk_to_response(
|
||||
chunk, command_id
|
||||
)
|
||||
logger.debug(f"chunk_response: {chunk_response}")
|
||||
|
||||
yield f"data: {chunk_response.model_dump_json()}\n\n"
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
async def _collect_chat_completion(
|
||||
self, command_id: CommandId, parse_gpt_oss: bool
|
||||
) -> ChatCompletionResponse:
|
||||
"""Collect all token chunks for a chat completion and return a single response."""
|
||||
|
||||
text_parts: list[str] = []
|
||||
model: str | None = None
|
||||
finish_reason: FinishReason | None = None
|
||||
|
||||
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
|
||||
if model is None:
|
||||
model = chunk.model
|
||||
|
||||
text_parts.append(chunk.text)
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
finish_reason = chunk.finish_reason
|
||||
|
||||
combined_text = "".join(text_parts)
|
||||
assert model is not None
|
||||
|
||||
return ChatCompletionResponse(
|
||||
id=command_id,
|
||||
created=int(time.time()),
|
||||
model=model,
|
||||
choices=[
|
||||
ChatCompletionChoice(
|
||||
index=0,
|
||||
message=ChatCompletionMessage(
|
||||
role="assistant",
|
||||
content=combined_text,
|
||||
),
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
async def _trigger_notify_user_to_download_model(self, model_id: str) -> None:
|
||||
logger.warning(
|
||||
"TODO: we should send a notification to the user to download the model"
|
||||
@@ -399,10 +524,12 @@ class API:
|
||||
|
||||
async def chat_completions(
|
||||
self, payload: ChatCompletionTaskParams
|
||||
) -> StreamingResponse:
|
||||
"""Handle chat completions with proper streaming response."""
|
||||
) -> ChatCompletionResponse | StreamingResponse:
|
||||
"""Handle chat completions, supporting both streaming and non-streaming responses."""
|
||||
model_meta = await resolve_model_meta(payload.model)
|
||||
payload.model = model_meta.model_id
|
||||
parse_gpt_oss = "gpt-oss" in model_meta.model_id.lower()
|
||||
logger.info(f"{parse_gpt_oss=}")
|
||||
|
||||
if not any(
|
||||
instance.shard_assignments.model_id == payload.model
|
||||
@@ -417,10 +544,332 @@ class API:
|
||||
request_params=payload,
|
||||
)
|
||||
await self._send(command)
|
||||
return StreamingResponse(
|
||||
self._generate_chat_stream(command.command_id),
|
||||
media_type="text/event-stream",
|
||||
if payload.stream:
|
||||
return StreamingResponse(
|
||||
self._generate_chat_stream(command.command_id, parse_gpt_oss),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
return await self._collect_chat_completion(command.command_id, parse_gpt_oss)
|
||||
|
||||
async def image_generations(
|
||||
self, payload: ImageGenerationTaskParams
|
||||
) -> ImageGenerationResponse | StreamingResponse:
|
||||
"""Handle image generation requests.
|
||||
|
||||
When stream=True and partial_images > 0, returns a StreamingResponse
|
||||
with SSE-formatted events for partial and final images.
|
||||
"""
|
||||
model_meta = await resolve_model_meta(payload.model)
|
||||
payload.model = model_meta.model_id
|
||||
|
||||
if not any(
|
||||
instance.shard_assignments.model_id == payload.model
|
||||
for instance in self.state.instances.values()
|
||||
):
|
||||
await self._trigger_notify_user_to_download_model(payload.model)
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"No instance found for model {payload.model}"
|
||||
)
|
||||
|
||||
command = ImageGeneration(
|
||||
request_params=payload,
|
||||
)
|
||||
await self._send(command)
|
||||
|
||||
# Check if streaming is requested
|
||||
if payload.stream and payload.partial_images and payload.partial_images > 0:
|
||||
return StreamingResponse(
|
||||
self._generate_image_stream(
|
||||
command_id=command.command_id,
|
||||
num_images=payload.n or 1,
|
||||
response_format=payload.response_format or "b64_json",
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
# Non-streaming: collect all image chunks
|
||||
return await self._collect_image_generation(
|
||||
command_id=command.command_id,
|
||||
num_images=payload.n or 1,
|
||||
response_format=payload.response_format or "b64_json",
|
||||
)
|
||||
|
||||
async def _generate_image_stream(
|
||||
self,
|
||||
command_id: CommandId,
|
||||
num_images: int,
|
||||
response_format: str,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Generate SSE stream of partial and final images."""
|
||||
# Track chunks: {(image_index, is_partial): {chunk_index: data}}
|
||||
image_chunks: dict[tuple[int, bool], dict[int, str]] = {}
|
||||
image_total_chunks: dict[tuple[int, bool], int] = {}
|
||||
image_metadata: dict[tuple[int, bool], tuple[int | None, int | None]] = {}
|
||||
images_complete = 0
|
||||
|
||||
try:
|
||||
self._image_generation_queues[command_id], recv = channel[ImageChunk]()
|
||||
|
||||
with recv as chunks:
|
||||
async for chunk in chunks:
|
||||
key = (chunk.image_index, chunk.is_partial)
|
||||
|
||||
if key not in image_chunks:
|
||||
image_chunks[key] = {}
|
||||
image_total_chunks[key] = chunk.total_chunks
|
||||
image_metadata[key] = (
|
||||
chunk.partial_index,
|
||||
chunk.total_partials,
|
||||
)
|
||||
|
||||
image_chunks[key][chunk.chunk_index] = chunk.data
|
||||
|
||||
# Check if this image is complete
|
||||
if len(image_chunks[key]) == image_total_chunks[key]:
|
||||
full_data = "".join(
|
||||
image_chunks[key][i] for i in range(len(image_chunks[key]))
|
||||
)
|
||||
|
||||
partial_idx, total_partials = image_metadata[key]
|
||||
|
||||
if chunk.is_partial:
|
||||
# Yield partial image event
|
||||
event_data = {
|
||||
"type": "partial",
|
||||
"partial_index": partial_idx,
|
||||
"total_partials": total_partials,
|
||||
"data": {
|
||||
"b64_json": full_data
|
||||
if response_format == "b64_json"
|
||||
else None,
|
||||
},
|
||||
}
|
||||
yield f"data: {json.dumps(event_data)}\n\n"
|
||||
else:
|
||||
# Final image
|
||||
event_data = {
|
||||
"type": "final",
|
||||
"image_index": chunk.image_index,
|
||||
"data": {
|
||||
"b64_json": full_data
|
||||
if response_format == "b64_json"
|
||||
else None,
|
||||
},
|
||||
}
|
||||
yield f"data: {json.dumps(event_data)}\n\n"
|
||||
images_complete += 1
|
||||
|
||||
if images_complete >= num_images:
|
||||
yield "data: [DONE]\n\n"
|
||||
break
|
||||
|
||||
# Clean up completed image chunks
|
||||
del image_chunks[key]
|
||||
del image_total_chunks[key]
|
||||
del image_metadata[key]
|
||||
|
||||
except anyio.get_cancelled_exc_class():
|
||||
raise
|
||||
finally:
|
||||
await self._send(TaskFinished(finished_command_id=command_id))
|
||||
if command_id in self._image_generation_queues:
|
||||
del self._image_generation_queues[command_id]
|
||||
|
||||
async def _collect_image_generation(
|
||||
self,
|
||||
command_id: CommandId,
|
||||
num_images: int,
|
||||
response_format: str,
|
||||
) -> ImageGenerationResponse:
|
||||
"""Collect all image chunks (non-streaming) and return a single response."""
|
||||
# Track chunks per image: {image_index: {chunk_index: data}}
|
||||
# Only track non-partial (final) images
|
||||
image_chunks: dict[int, dict[int, str]] = {}
|
||||
image_total_chunks: dict[int, int] = {}
|
||||
images_complete = 0
|
||||
|
||||
try:
|
||||
self._image_generation_queues[command_id], recv = channel[ImageChunk]()
|
||||
|
||||
while images_complete < num_images:
|
||||
with recv as chunks:
|
||||
async for chunk in chunks:
|
||||
# Skip partial images in non-streaming mode
|
||||
if chunk.is_partial:
|
||||
continue
|
||||
|
||||
if chunk.image_index not in image_chunks:
|
||||
image_chunks[chunk.image_index] = {}
|
||||
image_total_chunks[chunk.image_index] = chunk.total_chunks
|
||||
|
||||
image_chunks[chunk.image_index][chunk.chunk_index] = chunk.data
|
||||
|
||||
# Check if this image is complete
|
||||
if (
|
||||
len(image_chunks[chunk.image_index])
|
||||
== image_total_chunks[chunk.image_index]
|
||||
):
|
||||
images_complete += 1
|
||||
|
||||
if images_complete >= num_images:
|
||||
break
|
||||
|
||||
# Reassemble images in order
|
||||
images: list[ImageData] = []
|
||||
for image_idx in range(num_images):
|
||||
chunks_dict = image_chunks[image_idx]
|
||||
full_data = "".join(chunks_dict[i] for i in range(len(chunks_dict)))
|
||||
images.append(
|
||||
ImageData(
|
||||
b64_json=full_data if response_format == "b64_json" else None,
|
||||
url=None, # URL format not implemented yet
|
||||
)
|
||||
)
|
||||
|
||||
return ImageGenerationResponse(data=images)
|
||||
except anyio.get_cancelled_exc_class():
|
||||
raise
|
||||
finally:
|
||||
await self._send(TaskFinished(finished_command_id=command_id))
|
||||
if command_id in self._image_generation_queues:
|
||||
del self._image_generation_queues[command_id]
|
||||
|
||||
async def image_edits(
|
||||
self,
|
||||
image: UploadFile = File(...),
|
||||
prompt: str = Form(...),
|
||||
model: str = Form(...),
|
||||
n: int = Form(1),
|
||||
size: str = Form("1024x1024"),
|
||||
response_format: Literal["url", "b64_json"] = Form("b64_json"),
|
||||
input_fidelity: Literal["low", "high"] = Form("low"),
|
||||
stream: bool = Form(False),
|
||||
partial_images: int = Form(0),
|
||||
) -> ImageGenerationResponse | StreamingResponse:
|
||||
"""Handle image editing requests (img2img)."""
|
||||
model_meta = await resolve_model_meta(model)
|
||||
resolved_model = model_meta.model_id
|
||||
|
||||
if not any(
|
||||
instance.shard_assignments.model_id == resolved_model
|
||||
for instance in self.state.instances.values()
|
||||
):
|
||||
await self._trigger_notify_user_to_download_model(resolved_model)
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"No instance found for model {resolved_model}"
|
||||
)
|
||||
|
||||
# Read and base64 encode the uploaded image
|
||||
image_content = await image.read()
|
||||
image_data = base64.b64encode(image_content).decode("utf-8")
|
||||
|
||||
# Map input_fidelity to image_strength
|
||||
image_strength = 0.7 if input_fidelity == "high" else 0.3
|
||||
|
||||
# Split image into chunks to stay under gossipsub message size limit
|
||||
data_chunks = [
|
||||
image_data[i : i + EXO_MAX_CHUNK_SIZE]
|
||||
for i in range(0, len(image_data), EXO_MAX_CHUNK_SIZE)
|
||||
]
|
||||
total_chunks = len(data_chunks)
|
||||
|
||||
# Create command first to get command_id
|
||||
command = ImageEdits(
|
||||
request_params=ImageEditsInternalParams(
|
||||
image_data="", # Empty - will be assembled at worker from chunks
|
||||
total_input_chunks=total_chunks,
|
||||
prompt=prompt,
|
||||
model=resolved_model,
|
||||
n=n,
|
||||
size=size,
|
||||
response_format=response_format,
|
||||
image_strength=image_strength,
|
||||
stream=stream,
|
||||
partial_images=partial_images,
|
||||
),
|
||||
)
|
||||
|
||||
# Send input chunks BEFORE the command
|
||||
logger.info(
|
||||
f"Sending input image: {len(image_data)} bytes in {total_chunks} chunks"
|
||||
)
|
||||
for chunk_index, chunk_data in enumerate(data_chunks):
|
||||
await self._send(
|
||||
SendInputChunk(
|
||||
chunk=InputImageChunk(
|
||||
idx=chunk_index,
|
||||
model=resolved_model,
|
||||
command_id=command.command_id,
|
||||
data=chunk_data,
|
||||
chunk_index=chunk_index,
|
||||
total_chunks=total_chunks,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# Now send the main command
|
||||
await self._send(command)
|
||||
|
||||
num_images = n
|
||||
|
||||
# Check if streaming is requested
|
||||
if stream and partial_images and partial_images > 0:
|
||||
return StreamingResponse(
|
||||
self._generate_image_stream(
|
||||
command_id=command.command_id,
|
||||
num_images=num_images,
|
||||
response_format=response_format,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
# Track chunks per image: {image_index: {chunk_index: data}}
|
||||
image_chunks: dict[int, dict[int, str]] = {}
|
||||
image_total_chunks: dict[int, int] = {}
|
||||
images_complete = 0
|
||||
|
||||
try:
|
||||
self._image_generation_queues[command.command_id], recv = channel[
|
||||
ImageChunk
|
||||
]()
|
||||
|
||||
while images_complete < num_images:
|
||||
with recv as chunks:
|
||||
async for chunk in chunks:
|
||||
if chunk.image_index not in image_chunks:
|
||||
image_chunks[chunk.image_index] = {}
|
||||
image_total_chunks[chunk.image_index] = chunk.total_chunks
|
||||
|
||||
image_chunks[chunk.image_index][chunk.chunk_index] = chunk.data
|
||||
|
||||
if (
|
||||
len(image_chunks[chunk.image_index])
|
||||
== image_total_chunks[chunk.image_index]
|
||||
):
|
||||
images_complete += 1
|
||||
|
||||
if images_complete >= num_images:
|
||||
break
|
||||
|
||||
images: list[ImageData] = []
|
||||
for image_idx in range(num_images):
|
||||
chunks_dict = image_chunks[image_idx]
|
||||
full_data = "".join(chunks_dict[i] for i in range(len(chunks_dict)))
|
||||
images.append(
|
||||
ImageData(
|
||||
b64_json=full_data if response_format == "b64_json" else None,
|
||||
url=None, # URL format not implemented yet
|
||||
)
|
||||
)
|
||||
|
||||
return ImageGenerationResponse(data=images)
|
||||
except anyio.get_cancelled_exc_class():
|
||||
raise
|
||||
finally:
|
||||
# Send TaskFinished command
|
||||
await self._send(TaskFinished(finished_command_id=command.command_id))
|
||||
del self._image_generation_queues[command.command_id]
|
||||
|
||||
def _calculate_total_available_memory(self) -> Memory:
|
||||
"""Calculate total available memory across all nodes in bytes."""
|
||||
@@ -442,6 +891,9 @@ class API:
|
||||
name=card.name,
|
||||
description=card.description,
|
||||
tags=card.tags,
|
||||
storage_size_megabytes=int(card.metadata.storage_size.in_mb),
|
||||
supports_tensor=card.metadata.supports_tensor,
|
||||
tasks=[task.value for task in card.tasks],
|
||||
)
|
||||
for card in MODEL_CARDS.values()
|
||||
]
|
||||
@@ -458,7 +910,7 @@ class API:
|
||||
async with create_task_group() as tg:
|
||||
self._tg = tg
|
||||
logger.info("Starting API")
|
||||
tg.start_soon(self._applystate)
|
||||
tg.start_soon(self._apply_state)
|
||||
tg.start_soon(self._pause_on_new_election)
|
||||
print_startup_banner(self.port)
|
||||
await serve(
|
||||
@@ -470,7 +922,7 @@ class API:
|
||||
self.command_sender.close()
|
||||
self.global_event_receiver.close()
|
||||
|
||||
async def _applystate(self):
|
||||
async def _apply_state(self):
|
||||
with self.global_event_receiver as events:
|
||||
async for f_event in events:
|
||||
if f_event.origin != self.session_id.master_node_id:
|
||||
@@ -479,14 +931,17 @@ class API:
|
||||
for idx, event in self.event_buffer.drain_indexed():
|
||||
self._event_log.append(event)
|
||||
self.state = apply(self.state, IndexedEvent(event=event, idx=idx))
|
||||
if (
|
||||
isinstance(event, ChunkGenerated)
|
||||
and event.command_id in self._chat_completion_queues
|
||||
):
|
||||
assert isinstance(event.chunk, TokenChunk)
|
||||
await self._chat_completion_queues[event.command_id].send(
|
||||
event.chunk
|
||||
)
|
||||
if isinstance(event, ChunkGenerated):
|
||||
if event.command_id in self._chat_completion_queues:
|
||||
assert isinstance(event.chunk, TokenChunk)
|
||||
await self._chat_completion_queues[event.command_id].send(
|
||||
event.chunk
|
||||
)
|
||||
elif event.command_id in self._image_generation_queues:
|
||||
assert isinstance(event.chunk, ImageChunk)
|
||||
await self._image_generation_queues[event.command_id].send(
|
||||
event.chunk
|
||||
)
|
||||
|
||||
async def _pause_on_new_election(self):
|
||||
with self.election_receiver as ems:
|
||||
|
||||
@@ -2,6 +2,7 @@ from datetime import datetime, timedelta, timezone
|
||||
|
||||
import anyio
|
||||
from anyio.abc import TaskGroup
|
||||
from fastapi.routing import request_response
|
||||
from loguru import logger
|
||||
|
||||
from exo.master.placement import (
|
||||
@@ -11,13 +12,17 @@ from exo.master.placement import (
|
||||
place_instance,
|
||||
)
|
||||
from exo.shared.apply import apply
|
||||
from exo.shared.types.chunks import InputImageChunk
|
||||
from exo.shared.types.commands import (
|
||||
ChatCompletion,
|
||||
CreateInstance,
|
||||
DeleteInstance,
|
||||
ForwarderCommand,
|
||||
ImageEdits,
|
||||
ImageGeneration,
|
||||
PlaceInstance,
|
||||
RequestEventLog,
|
||||
SendInputChunk,
|
||||
TaskFinished,
|
||||
TestCommand,
|
||||
)
|
||||
@@ -26,6 +31,7 @@ from exo.shared.types.events import (
|
||||
Event,
|
||||
ForwarderEvent,
|
||||
IndexedEvent,
|
||||
InputChunkReceived,
|
||||
InstanceDeleted,
|
||||
NodeTimedOut,
|
||||
TaskCreated,
|
||||
@@ -35,6 +41,12 @@ from exo.shared.types.state import State
|
||||
from exo.shared.types.tasks import (
|
||||
ChatCompletion as ChatCompletionTask,
|
||||
)
|
||||
from exo.shared.types.tasks import (
|
||||
ImageEdits as ImageEditsTask,
|
||||
)
|
||||
from exo.shared.types.tasks import (
|
||||
ImageGeneration as ImageGenerationTask,
|
||||
)
|
||||
from exo.shared.types.tasks import (
|
||||
TaskId,
|
||||
TaskStatus,
|
||||
@@ -99,6 +111,7 @@ class Master:
|
||||
async for forwarder_command in commands:
|
||||
try:
|
||||
logger.info(f"Executing command: {forwarder_command.command}")
|
||||
|
||||
generated_events: list[Event] = []
|
||||
command = forwarder_command.command
|
||||
match command:
|
||||
@@ -146,6 +159,92 @@ class Master:
|
||||
)
|
||||
)
|
||||
|
||||
self.command_task_mapping[command.command_id] = task_id
|
||||
case ImageGeneration():
|
||||
instance_task_counts: dict[InstanceId, int] = {}
|
||||
for instance in self.state.instances.values():
|
||||
if (
|
||||
instance.shard_assignments.model_id
|
||||
== command.request_params.model
|
||||
):
|
||||
task_count = sum(
|
||||
1
|
||||
for task in self.state.tasks.values()
|
||||
if task.instance_id == instance.instance_id
|
||||
)
|
||||
instance_task_counts[instance.instance_id] = (
|
||||
task_count
|
||||
)
|
||||
|
||||
if not instance_task_counts:
|
||||
raise ValueError(
|
||||
f"No instance found for model {command.request_params.model}"
|
||||
)
|
||||
|
||||
available_instance_ids = sorted(
|
||||
instance_task_counts.keys(),
|
||||
key=lambda instance_id: instance_task_counts[
|
||||
instance_id
|
||||
],
|
||||
)
|
||||
|
||||
task_id = TaskId()
|
||||
generated_events.append(
|
||||
TaskCreated(
|
||||
task_id=task_id,
|
||||
task=ImageGenerationTask(
|
||||
task_id=task_id,
|
||||
command_id=command.command_id,
|
||||
instance_id=available_instance_ids[0],
|
||||
task_status=TaskStatus.Pending,
|
||||
task_params=command.request_params,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
self.command_task_mapping[command.command_id] = task_id
|
||||
case ImageEdits():
|
||||
instance_task_counts: dict[InstanceId, int] = {}
|
||||
for instance in self.state.instances.values():
|
||||
if (
|
||||
instance.shard_assignments.model_id
|
||||
== command.request_params.model
|
||||
):
|
||||
task_count = sum(
|
||||
1
|
||||
for task in self.state.tasks.values()
|
||||
if task.instance_id == instance.instance_id
|
||||
)
|
||||
instance_task_counts[instance.instance_id] = (
|
||||
task_count
|
||||
)
|
||||
|
||||
if not instance_task_counts:
|
||||
raise ValueError(
|
||||
f"No instance found for model {command.request_params.model}"
|
||||
)
|
||||
|
||||
available_instance_ids = sorted(
|
||||
instance_task_counts.keys(),
|
||||
key=lambda instance_id: instance_task_counts[
|
||||
instance_id
|
||||
],
|
||||
)
|
||||
|
||||
task_id = TaskId()
|
||||
generated_events.append(
|
||||
TaskCreated(
|
||||
task_id=task_id,
|
||||
task=ImageEditsTask(
|
||||
task_id=task_id,
|
||||
command_id=command.command_id,
|
||||
instance_id=available_instance_ids[0],
|
||||
task_status=TaskStatus.Pending,
|
||||
task_params=command.request_params,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
self.command_task_mapping[command.command_id] = task_id
|
||||
case DeleteInstance():
|
||||
placement = delete_instance(command, self.state.instances)
|
||||
@@ -173,6 +272,13 @@ class Master:
|
||||
self.state.instances, placement
|
||||
)
|
||||
generated_events.extend(transition_events)
|
||||
case SendInputChunk(chunk=chunk):
|
||||
generated_events.append(
|
||||
InputChunkReceived(
|
||||
command_id=chunk.command_id,
|
||||
chunk=chunk,
|
||||
)
|
||||
)
|
||||
case TaskFinished():
|
||||
generated_events.append(
|
||||
TaskDeleted(
|
||||
|
||||
@@ -7,9 +7,9 @@ from loguru import logger
|
||||
|
||||
from exo.master.placement_utils import (
|
||||
filter_cycles_by_memory,
|
||||
get_hosts_from_subgraph,
|
||||
get_mlx_ibv_devices_matrix,
|
||||
get_mlx_jaccl_coordinators,
|
||||
get_mlx_ring_hosts_by_node,
|
||||
get_shard_assignments,
|
||||
get_smallest_cycles,
|
||||
)
|
||||
@@ -19,9 +19,9 @@ from exo.shared.types.commands import (
|
||||
DeleteInstance,
|
||||
PlaceInstance,
|
||||
)
|
||||
from exo.shared.types.common import Host
|
||||
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.models import ModelId
|
||||
from exo.shared.types.topology import NodeInfo
|
||||
from exo.shared.types.worker.instances import (
|
||||
Instance,
|
||||
@@ -30,6 +30,7 @@ from exo.shared.types.worker.instances import (
|
||||
MlxJacclInstance,
|
||||
MlxRingInstance,
|
||||
)
|
||||
from exo.shared.types.worker.shards import Sharding
|
||||
|
||||
|
||||
def random_ephemeral_port() -> int:
|
||||
@@ -66,6 +67,28 @@ def place_instance(
|
||||
if not cycles_with_sufficient_memory:
|
||||
raise ValueError("No cycles found with sufficient memory")
|
||||
|
||||
if command.sharding == Sharding.Tensor:
|
||||
if not command.model_meta.supports_tensor:
|
||||
raise ValueError(
|
||||
f"Requested Tensor sharding but this model does not support tensor parallelism: {command.model_meta.model_id}"
|
||||
)
|
||||
# TODO: the condition here for tensor parallel is not correct, but it works good enough for now.
|
||||
cycles_with_sufficient_memory = [
|
||||
cycle
|
||||
for cycle in cycles_with_sufficient_memory
|
||||
if command.model_meta.hidden_size % len(cycle) == 0
|
||||
]
|
||||
if not cycles_with_sufficient_memory:
|
||||
raise ValueError(
|
||||
f"No tensor sharding found for model with hidden_size {command.model_meta.hidden_size} candidate cycles"
|
||||
)
|
||||
if command.sharding == Sharding.Pipeline and command.model_meta.model_id == ModelId(
|
||||
"mlx-community/DeepSeek-V3.1-8bit"
|
||||
):
|
||||
raise ValueError(
|
||||
"Pipeline parallelism is not supported for DeepSeek V3.1 (8-bit)"
|
||||
)
|
||||
|
||||
smallest_cycles = get_smallest_cycles(cycles_with_sufficient_memory)
|
||||
|
||||
smallest_tb_cycles = [
|
||||
@@ -130,17 +153,17 @@ def place_instance(
|
||||
jaccl_coordinators=mlx_jaccl_coordinators,
|
||||
)
|
||||
case InstanceMeta.MlxRing:
|
||||
hosts: list[Host] = get_hosts_from_subgraph(cycle_digraph)
|
||||
ephemeral_port = random_ephemeral_port()
|
||||
hosts_by_node = get_mlx_ring_hosts_by_node(
|
||||
selected_cycle=selected_cycle,
|
||||
cycle_digraph=cycle_digraph,
|
||||
ephemeral_port=ephemeral_port,
|
||||
)
|
||||
target_instances[instance_id] = MlxRingInstance(
|
||||
instance_id=instance_id,
|
||||
shard_assignments=shard_assignments,
|
||||
hosts=[
|
||||
Host(
|
||||
ip=host.ip,
|
||||
port=random_ephemeral_port(),
|
||||
)
|
||||
for host in hosts
|
||||
],
|
||||
hosts_by_node=hosts_by_node,
|
||||
ephemeral_port=ephemeral_port,
|
||||
)
|
||||
|
||||
return target_instances
|
||||
|
||||
@@ -215,9 +215,11 @@ def get_mlx_ibv_devices_matrix(
|
||||
continue
|
||||
|
||||
# Find the IP J uses to talk to I
|
||||
for connection_ip in _find_connection_ip(node_j, node_i, cycle_digraph):
|
||||
for connection_ip, _ in _find_connection_ip(node_j, node_i, cycle_digraph):
|
||||
# This is a local IP on I, which is attached to an interface: find that interface
|
||||
if interface_name := _find_interface_name_for_ip(connection_ip, node_i):
|
||||
if interface_name := _find_rdma_interface_name_for_ip(
|
||||
connection_ip, node_i
|
||||
):
|
||||
matrix[i][j] = interface_name
|
||||
logger.info(
|
||||
f"Interface name for {connection_ip} on {node_i.node_id}: {interface_name}"
|
||||
@@ -238,17 +240,17 @@ def _find_connection_ip(
|
||||
node_i: NodeInfo,
|
||||
node_j: NodeInfo,
|
||||
cycle_digraph: Topology,
|
||||
) -> Generator[str]:
|
||||
"""Find all IP addresses that connect node i to node j."""
|
||||
) -> Generator[tuple[str, bool]]:
|
||||
"""Find all IP addresses that connect node i to node j, with thunderbolt flag."""
|
||||
for connection in cycle_digraph.list_connections():
|
||||
if (
|
||||
connection.local_node_id == node_i.node_id
|
||||
and connection.send_back_node_id == node_j.node_id
|
||||
):
|
||||
yield connection.send_back_multiaddr.ip_address
|
||||
yield connection.send_back_multiaddr.ip_address, connection.is_thunderbolt()
|
||||
|
||||
|
||||
def _find_interface_name_for_ip(
|
||||
def _find_rdma_interface_name_for_ip(
|
||||
ip_address: str,
|
||||
node_info: NodeInfo,
|
||||
) -> str | None:
|
||||
@@ -269,6 +271,109 @@ def _find_interface_name_for_ip(
|
||||
return None
|
||||
|
||||
|
||||
def _find_interface_name_for_ip(
|
||||
ip_address: str,
|
||||
node_info: NodeInfo,
|
||||
) -> str | None:
|
||||
"""Find the interface name for an IP address on a node (any interface)."""
|
||||
if node_info.node_profile is None:
|
||||
return None
|
||||
|
||||
for interface in node_info.node_profile.network_interfaces:
|
||||
if interface.ip_address == ip_address:
|
||||
return interface.name
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _find_ip_prioritised(
|
||||
node: NodeInfo, other_node: NodeInfo, cycle_digraph: Topology
|
||||
) -> str | None:
|
||||
# TODO: Actually prioritize in the correct Ethernet > Wifi > Non-TB > TB order.
|
||||
"""Find an IP address between nodes with prioritization.
|
||||
|
||||
Priority order:
|
||||
1. en0 (Ethernet on Mac Studio, WiFi on MacBook)
|
||||
2. en1 (WiFi on Mac Studio, Ethernet on MacBook)
|
||||
3. Non-Thunderbolt connections
|
||||
4. Any other IP address
|
||||
"""
|
||||
ips = list(_find_connection_ip(node, other_node, cycle_digraph))
|
||||
# We expect a unique iface -> ip mapping
|
||||
iface_map = {_find_interface_name_for_ip(ip, other_node): ip for ip, _ in ips}
|
||||
|
||||
en0_ip = iface_map.get("en0")
|
||||
if en0_ip:
|
||||
return en0_ip
|
||||
|
||||
en1_ip = iface_map.get("en1")
|
||||
if en1_ip:
|
||||
return en1_ip
|
||||
|
||||
non_thunderbolt_ip = next(
|
||||
(ip for (ip, is_thunderbolt) in ips if not is_thunderbolt), None
|
||||
)
|
||||
|
||||
if non_thunderbolt_ip:
|
||||
return non_thunderbolt_ip
|
||||
|
||||
if ips:
|
||||
return ips[0][0]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_mlx_ring_hosts_by_node(
|
||||
selected_cycle: list[NodeInfo],
|
||||
cycle_digraph: Topology,
|
||||
ephemeral_port: int,
|
||||
) -> dict[NodeId, list[Host]]:
|
||||
"""Generate per-node host lists for MLX ring backend.
|
||||
|
||||
Each node gets a list where:
|
||||
- Self position: Host(ip="0.0.0.0", port=ephemeral_port)
|
||||
- Left/right neighbors: actual connection IPs
|
||||
- Non-neighbors: Host(ip="198.51.100.1", port=0) placeholder (RFC 5737 TEST-NET-2)
|
||||
"""
|
||||
world_size = len(selected_cycle)
|
||||
if world_size == 0:
|
||||
return {}
|
||||
|
||||
hosts_by_node: dict[NodeId, list[Host]] = {}
|
||||
|
||||
for rank, node in enumerate(selected_cycle):
|
||||
node_id = node.node_id
|
||||
left_rank = (rank - 1) % world_size
|
||||
right_rank = (rank + 1) % world_size
|
||||
|
||||
hosts_for_node: list[Host] = []
|
||||
|
||||
for idx, other_node in enumerate(selected_cycle):
|
||||
if idx == rank:
|
||||
hosts_for_node.append(Host(ip="0.0.0.0", port=ephemeral_port))
|
||||
continue
|
||||
|
||||
if idx not in {left_rank, right_rank}:
|
||||
# Placeholder IP from RFC 5737 TEST-NET-2
|
||||
hosts_for_node.append(Host(ip="198.51.100.1", port=0))
|
||||
continue
|
||||
|
||||
connection_ip = _find_ip_prioritised(node, other_node, cycle_digraph)
|
||||
if connection_ip is None:
|
||||
logger.warning(
|
||||
f"Failed to find prioritised connection IP between {node_id} and {other_node.node_id}"
|
||||
)
|
||||
raise ValueError(
|
||||
"MLX ring backend requires connectivity between neighbouring nodes"
|
||||
)
|
||||
|
||||
hosts_for_node.append(Host(ip=connection_ip, port=ephemeral_port))
|
||||
|
||||
hosts_by_node[node_id] = hosts_for_node
|
||||
|
||||
return hosts_by_node
|
||||
|
||||
|
||||
def get_mlx_jaccl_coordinators(
|
||||
selected_cycle: list[NodeInfo],
|
||||
coordinator_port: int,
|
||||
@@ -280,13 +385,14 @@ def get_mlx_jaccl_coordinators(
|
||||
address in format "X.X.X.X:PORT" per node.
|
||||
"""
|
||||
rank_0_node = selected_cycle[0]
|
||||
logger.info(f"Selecting coordinator from rank 0 node: {rank_0_node.node_id}")
|
||||
logger.debug(f"Selecting coordinator from rank 0 node: {rank_0_node.node_id}")
|
||||
|
||||
def get_ip_for_node(n: NodeInfo) -> str:
|
||||
if n.node_id == rank_0_node.node_id:
|
||||
return "0.0.0.0"
|
||||
|
||||
for ip in _find_connection_ip(n, rank_0_node, cycle_digraph):
|
||||
ip = _find_ip_prioritised(n, rank_0_node, cycle_digraph)
|
||||
if ip:
|
||||
return ip
|
||||
|
||||
logger.warning(
|
||||
|
||||
@@ -123,6 +123,8 @@ async def test_master():
|
||||
pretty_name="Llama 3.2 1B",
|
||||
n_layers=16,
|
||||
storage_size=Memory.from_bytes(678948),
|
||||
hidden_size=7168,
|
||||
supports_tensor=True,
|
||||
),
|
||||
sharding=Sharding.Pipeline,
|
||||
instance_meta=InstanceMeta.MlxRing,
|
||||
@@ -163,32 +165,38 @@ async def test_master():
|
||||
assert events[2].idx == 2
|
||||
assert isinstance(events[0].event, NodePerformanceMeasured)
|
||||
assert isinstance(events[1].event, InstanceCreated)
|
||||
runner_id = list(
|
||||
events[1].event.instance.shard_assignments.runner_to_shard.keys()
|
||||
)[0]
|
||||
assert events[1].event.instance == MlxRingInstance(
|
||||
instance_id=events[1].event.instance.instance_id,
|
||||
shard_assignments=ShardAssignments(
|
||||
model_id=ModelId("llama-3.2-1b"),
|
||||
runner_to_shard={
|
||||
(runner_id): PipelineShardMetadata(
|
||||
start_layer=0,
|
||||
end_layer=16,
|
||||
created_instance = events[1].event.instance
|
||||
assert isinstance(created_instance, MlxRingInstance)
|
||||
runner_id = list(created_instance.shard_assignments.runner_to_shard.keys())[0]
|
||||
# Validate the shard assignments
|
||||
expected_shard_assignments = ShardAssignments(
|
||||
model_id=ModelId("llama-3.2-1b"),
|
||||
runner_to_shard={
|
||||
(runner_id): PipelineShardMetadata(
|
||||
start_layer=0,
|
||||
end_layer=16,
|
||||
n_layers=16,
|
||||
model_meta=ModelMetadata(
|
||||
model_id=ModelId("llama-3.2-1b"),
|
||||
pretty_name="Llama 3.2 1B",
|
||||
n_layers=16,
|
||||
model_meta=ModelMetadata(
|
||||
model_id=ModelId("llama-3.2-1b"),
|
||||
pretty_name="Llama 3.2 1B",
|
||||
n_layers=16,
|
||||
storage_size=Memory.from_bytes(678948),
|
||||
),
|
||||
device_rank=0,
|
||||
world_size=1,
|
||||
)
|
||||
},
|
||||
node_to_runner={node_id: runner_id},
|
||||
),
|
||||
hosts=[],
|
||||
storage_size=Memory.from_bytes(678948),
|
||||
hidden_size=7168,
|
||||
supports_tensor=True,
|
||||
),
|
||||
device_rank=0,
|
||||
world_size=1,
|
||||
)
|
||||
},
|
||||
node_to_runner={node_id: runner_id},
|
||||
)
|
||||
assert created_instance.shard_assignments == expected_shard_assignments
|
||||
# For single-node, hosts_by_node should have one entry with self-binding
|
||||
assert len(created_instance.hosts_by_node) == 1
|
||||
assert node_id in created_instance.hosts_by_node
|
||||
assert len(created_instance.hosts_by_node[node_id]) == 1
|
||||
assert created_instance.hosts_by_node[node_id][0].ip == "0.0.0.0"
|
||||
assert created_instance.ephemeral_port > 0
|
||||
assert isinstance(events[2].event, TaskCreated)
|
||||
assert events[2].event.task.task_status == TaskStatus.Pending
|
||||
assert isinstance(events[2].event.task, ChatCompletionTask)
|
||||
|
||||
@@ -38,7 +38,8 @@ def instance() -> Instance:
|
||||
shard_assignments=ShardAssignments(
|
||||
model_id=ModelId("test-model"), runner_to_shard={}, node_to_runner={}
|
||||
),
|
||||
hosts=[],
|
||||
hosts_by_node={},
|
||||
ephemeral_port=50000,
|
||||
)
|
||||
|
||||
|
||||
@@ -49,6 +50,8 @@ def model_meta() -> ModelMetadata:
|
||||
storage_size=Memory.from_kb(1000),
|
||||
pretty_name="Test Model",
|
||||
n_layers=10,
|
||||
hidden_size=30,
|
||||
supports_tensor=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -92,9 +95,13 @@ def test_get_instance_placements_create_instance(
|
||||
topology.add_node(create_node(available_memory[0], node_id_a))
|
||||
topology.add_node(create_node(available_memory[1], node_id_b))
|
||||
topology.add_node(create_node(available_memory[2], node_id_c))
|
||||
# Add bidirectional connections for ring topology
|
||||
topology.add_connection(create_connection(node_id_a, node_id_b))
|
||||
topology.add_connection(create_connection(node_id_b, node_id_a))
|
||||
topology.add_connection(create_connection(node_id_b, node_id_c))
|
||||
topology.add_connection(create_connection(node_id_c, node_id_b))
|
||||
topology.add_connection(create_connection(node_id_c, node_id_a))
|
||||
topology.add_connection(create_connection(node_id_a, node_id_c))
|
||||
|
||||
# act
|
||||
placements = place_instance(cic, topology, {})
|
||||
@@ -135,6 +142,8 @@ def test_get_instance_placements_one_node_exact_fit(
|
||||
storage_size=Memory.from_kb(1000),
|
||||
pretty_name="Test Model",
|
||||
n_layers=10,
|
||||
hidden_size=1000,
|
||||
supports_tensor=True,
|
||||
),
|
||||
)
|
||||
placements = place_instance(cic, topology, {})
|
||||
@@ -160,6 +169,8 @@ def test_get_instance_placements_one_node_fits_with_extra_memory(
|
||||
storage_size=Memory.from_kb(1000),
|
||||
pretty_name="Test Model",
|
||||
n_layers=10,
|
||||
hidden_size=1000,
|
||||
supports_tensor=True,
|
||||
),
|
||||
)
|
||||
placements = place_instance(cic, topology, {})
|
||||
@@ -185,6 +196,8 @@ def test_get_instance_placements_one_node_not_fit(
|
||||
storage_size=Memory.from_kb(1001),
|
||||
pretty_name="Test Model",
|
||||
n_layers=10,
|
||||
hidden_size=1000,
|
||||
supports_tensor=True,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -234,17 +247,15 @@ def test_get_transition_events_delete_instance(instance: Instance):
|
||||
assert events[0].instance_id == instance_id
|
||||
|
||||
|
||||
def test_placement_prioritizes_leaf_cycle_with_less_memory(
|
||||
def test_placement_selects_cycle_with_most_memory(
|
||||
topology: Topology,
|
||||
model_meta: ModelMetadata,
|
||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
||||
create_connection: Callable[[NodeId, NodeId], Connection],
|
||||
):
|
||||
# Arrange two 3-node cycles. The A-B-C cycle has a leaf node (only one outgoing
|
||||
# neighbor per node). The D-E-F cycle has extra outgoing edges making its nodes
|
||||
# non-leaves. Ensure both cycles have sufficient total memory, with the A-B-C
|
||||
# cycle having LESS total memory than D-E-F. The algorithm should still choose
|
||||
# the cycle that contains a leaf node.
|
||||
# Arrange two 3-node cycles with different total memory.
|
||||
# With bidirectional connections for ring topology, both cycles have non-leaf nodes.
|
||||
# The algorithm should select the cycle with the most available memory.
|
||||
|
||||
# Model requires more than any single node but fits within a 3-node cycle
|
||||
model_meta.storage_size.in_bytes = 1500
|
||||
@@ -258,11 +269,6 @@ def test_placement_prioritizes_leaf_cycle_with_less_memory(
|
||||
node_id_e = NodeId()
|
||||
node_id_f = NodeId()
|
||||
|
||||
# Extra sink nodes to make D/E/F non-leaf via additional outgoing edges
|
||||
node_id_x = NodeId()
|
||||
node_id_y = NodeId()
|
||||
node_id_z = NodeId()
|
||||
|
||||
# A-B-C cycle total memory = 1600 (< D-E-F total)
|
||||
topology.add_node(create_node(400, node_id_a))
|
||||
topology.add_node(create_node(400, node_id_b))
|
||||
@@ -273,24 +279,20 @@ def test_placement_prioritizes_leaf_cycle_with_less_memory(
|
||||
topology.add_node(create_node(600, node_id_e))
|
||||
topology.add_node(create_node(600, node_id_f))
|
||||
|
||||
# Extra nodes with tiny memory so they can't form singleton placements
|
||||
topology.add_node(create_node(10, node_id_x))
|
||||
topology.add_node(create_node(10, node_id_y))
|
||||
topology.add_node(create_node(10, node_id_z))
|
||||
|
||||
# Build directed cycles
|
||||
# Build bidirectional cycles for ring topology
|
||||
topology.add_connection(create_connection(node_id_a, node_id_b))
|
||||
topology.add_connection(create_connection(node_id_b, node_id_a))
|
||||
topology.add_connection(create_connection(node_id_b, node_id_c))
|
||||
topology.add_connection(create_connection(node_id_c, node_id_b))
|
||||
topology.add_connection(create_connection(node_id_c, node_id_a))
|
||||
topology.add_connection(create_connection(node_id_a, node_id_c))
|
||||
|
||||
topology.add_connection(create_connection(node_id_d, node_id_e))
|
||||
topology.add_connection(create_connection(node_id_e, node_id_d))
|
||||
topology.add_connection(create_connection(node_id_e, node_id_f))
|
||||
topology.add_connection(create_connection(node_id_f, node_id_e))
|
||||
topology.add_connection(create_connection(node_id_f, node_id_d))
|
||||
|
||||
# Add extra outgoing edges from D/E/F so none of them are leaves
|
||||
topology.add_connection(create_connection(node_id_d, node_id_x))
|
||||
topology.add_connection(create_connection(node_id_e, node_id_y))
|
||||
topology.add_connection(create_connection(node_id_f, node_id_z))
|
||||
topology.add_connection(create_connection(node_id_d, node_id_f))
|
||||
|
||||
cic = place_instance_command(
|
||||
model_meta=model_meta,
|
||||
@@ -299,18 +301,17 @@ def test_placement_prioritizes_leaf_cycle_with_less_memory(
|
||||
# Act
|
||||
placements = place_instance(cic, topology, {})
|
||||
|
||||
# Assert the chosen cycle is A-B-C (contains at least one leaf node), even though
|
||||
# D-E-F has more total memory.
|
||||
# Assert: D-E-F cycle should be selected as it has more total memory
|
||||
assert len(placements) == 1
|
||||
instance_id = list(placements.keys())[0]
|
||||
instance = placements[instance_id]
|
||||
|
||||
assigned_nodes = set(instance.shard_assignments.node_to_runner.keys())
|
||||
expected_leaf_cycle_nodes = {node_id_a, node_id_b, node_id_c}
|
||||
non_leaf_cycle_nodes = {node_id_d, node_id_e, node_id_f}
|
||||
less_memory_cycle_nodes = {node_id_a, node_id_b, node_id_c}
|
||||
more_memory_cycle_nodes = {node_id_d, node_id_e, node_id_f}
|
||||
|
||||
assert expected_leaf_cycle_nodes.issubset(assigned_nodes)
|
||||
assert assigned_nodes.isdisjoint(non_leaf_cycle_nodes)
|
||||
assert more_memory_cycle_nodes.issubset(assigned_nodes)
|
||||
assert assigned_nodes.isdisjoint(less_memory_cycle_nodes)
|
||||
|
||||
|
||||
def test_tensor_rdma_backend_connectivity_matrix(
|
||||
|
||||
@@ -198,6 +198,8 @@ def test_get_shard_assignments(
|
||||
pretty_name="Test Model",
|
||||
n_layers=total_layers,
|
||||
storage_size=Memory.from_kb(1000),
|
||||
hidden_size=1000,
|
||||
supports_tensor=True,
|
||||
)
|
||||
cycles = topology.get_cycles()
|
||||
selected_cycle = cycles[0]
|
||||
|
||||
@@ -9,6 +9,7 @@ from exo.shared.types.events import (
|
||||
ChunkGenerated,
|
||||
Event,
|
||||
IndexedEvent,
|
||||
InputChunkReceived,
|
||||
InstanceCreated,
|
||||
InstanceDeleted,
|
||||
NodeCreated,
|
||||
@@ -40,8 +41,8 @@ def event_apply(event: Event, state: State) -> State:
|
||||
"""Apply an event to state."""
|
||||
match event:
|
||||
case (
|
||||
TestEvent() | ChunkGenerated() | TaskAcknowledged()
|
||||
): # TaskAcknowledged should never be sent by a worker but i dont mind if it just gets ignored
|
||||
TestEvent() | ChunkGenerated() | TaskAcknowledged() | InputChunkReceived()
|
||||
): # Pass-through events that don't modify state
|
||||
return state
|
||||
case InstanceCreated():
|
||||
return apply_instance_created(event, state)
|
||||
|
||||
@@ -44,3 +44,5 @@ LIBP2P_LOCAL_EVENTS_TOPIC = "worker_events"
|
||||
LIBP2P_GLOBAL_EVENTS_TOPIC = "global_events"
|
||||
LIBP2P_ELECTION_MESSAGES_TOPIC = "election_message"
|
||||
LIBP2P_COMMANDS_TOPIC = "commands"
|
||||
|
||||
EXO_MAX_CHUNK_SIZE = 512 * 1024
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.models import ModelId, ModelMetadata
|
||||
from exo.shared.types.models import ComponentInfo, ModelId, ModelMetadata, ModelTask
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ class ModelCard(CamelCaseModel):
|
||||
model_id: ModelId
|
||||
name: str
|
||||
description: str
|
||||
tasks: list[ModelTask]
|
||||
tags: list[str]
|
||||
metadata: ModelMetadata
|
||||
|
||||
@@ -45,12 +46,15 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
|
||||
name="DeepSeek V3.1 (4-bit)",
|
||||
description="""DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
|
||||
pretty_name="DeepSeek V3.1 (4-bit)",
|
||||
storage_size=Memory.from_gb(378),
|
||||
n_layers=61,
|
||||
hidden_size=7168,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"deepseek-v3.1-8bit": ModelCard(
|
||||
@@ -58,12 +62,15 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/DeepSeek-V3.1-8bit"),
|
||||
name="DeepSeek V3.1 (8-bit)",
|
||||
description="""DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/DeepSeek-V3.1-8bit"),
|
||||
pretty_name="DeepSeek V3.1 (8-bit)",
|
||||
storage_size=Memory.from_gb(713),
|
||||
n_layers=61,
|
||||
hidden_size=7168,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
# "deepseek-v3.2": ModelCard(
|
||||
@@ -129,12 +136,15 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Kimi-K2-Instruct-4bit"),
|
||||
name="Kimi K2 Instruct (4-bit)",
|
||||
description="""Kimi K2 is a large language model trained on the Kimi K2 dataset.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Kimi-K2-Instruct-4bit"),
|
||||
pretty_name="Kimi K2 Instruct (4-bit)",
|
||||
storage_size=Memory.from_gb(578),
|
||||
n_layers=61,
|
||||
hidden_size=7168,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"kimi-k2-thinking": ModelCard(
|
||||
@@ -142,12 +152,15 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Kimi-K2-Thinking"),
|
||||
name="Kimi K2 Thinking (4-bit)",
|
||||
description="""Kimi K2 Thinking is the latest, most capable version of open-source thinking model.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Kimi-K2-Thinking"),
|
||||
pretty_name="Kimi K2 Thinking (4-bit)",
|
||||
storage_size=Memory.from_gb(658),
|
||||
n_layers=61,
|
||||
hidden_size=7168,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
# llama-3.1
|
||||
@@ -156,12 +169,47 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
|
||||
name="Llama 3.1 8B (4-bit)",
|
||||
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
|
||||
pretty_name="Llama 3.1 8B (4-bit)",
|
||||
storage_size=Memory.from_mb(4423),
|
||||
n_layers=32,
|
||||
hidden_size=4096,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"llama-3.1-8b-8bit": ModelCard(
|
||||
short_id="llama-3.1-8b-8bit",
|
||||
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
|
||||
name="Llama 3.1 8B (8-bit)",
|
||||
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
|
||||
pretty_name="Llama 3.1 8B (8-bit)",
|
||||
storage_size=Memory.from_mb(8540),
|
||||
n_layers=32,
|
||||
hidden_size=4096,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"llama-3.1-8b-bf16": ModelCard(
|
||||
short_id="llama-3.1-8b-bf16",
|
||||
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
|
||||
name="Llama 3.1 8B (BF16)",
|
||||
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
|
||||
pretty_name="Llama 3.1 8B (BF16)",
|
||||
storage_size=Memory.from_mb(16100),
|
||||
n_layers=32,
|
||||
hidden_size=4096,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"llama-3.1-70b": ModelCard(
|
||||
@@ -169,12 +217,15 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"),
|
||||
name="Llama 3.1 70B (4-bit)",
|
||||
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"),
|
||||
pretty_name="Llama 3.1 70B (4-bit)",
|
||||
storage_size=Memory.from_mb(38769),
|
||||
n_layers=80,
|
||||
hidden_size=8192,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
# llama-3.2
|
||||
@@ -183,12 +234,15 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"),
|
||||
name="Llama 3.2 1B (4-bit)",
|
||||
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"),
|
||||
pretty_name="Llama 3.2 1B (4-bit)",
|
||||
storage_size=Memory.from_mb(696),
|
||||
n_layers=16,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"llama-3.2-3b": ModelCard(
|
||||
@@ -196,12 +250,15 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit"),
|
||||
name="Llama 3.2 3B (4-bit)",
|
||||
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit"),
|
||||
pretty_name="Llama 3.2 3B (4-bit)",
|
||||
storage_size=Memory.from_mb(1777),
|
||||
n_layers=28,
|
||||
hidden_size=3072,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"llama-3.2-3b-8bit": ModelCard(
|
||||
@@ -209,12 +266,15 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-8bit"),
|
||||
name="Llama 3.2 3B (8-bit)",
|
||||
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-8bit"),
|
||||
pretty_name="Llama 3.2 3B (8-bit)",
|
||||
storage_size=Memory.from_mb(3339),
|
||||
n_layers=28,
|
||||
hidden_size=3072,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
# llama-3.3
|
||||
@@ -223,12 +283,15 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-4bit"),
|
||||
name="Llama 3.3 70B (4-bit)",
|
||||
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-4bit"),
|
||||
pretty_name="Llama 3.3 70B",
|
||||
storage_size=Memory.from_mb(38769),
|
||||
n_layers=80,
|
||||
hidden_size=8192,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"llama-3.3-70b-8bit": ModelCard(
|
||||
@@ -236,12 +299,15 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-8bit"),
|
||||
name="Llama 3.3 70B (8-bit)",
|
||||
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-8bit"),
|
||||
pretty_name="Llama 3.3 70B (8-bit)",
|
||||
storage_size=Memory.from_mb(73242),
|
||||
n_layers=80,
|
||||
hidden_size=8192,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"llama-3.3-70b-fp16": ModelCard(
|
||||
@@ -249,26 +315,15 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/llama-3.3-70b-instruct-fp16"),
|
||||
name="Llama 3.3 70B (FP16)",
|
||||
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/llama-3.3-70b-instruct-fp16"),
|
||||
pretty_name="Llama 3.3 70B (FP16)",
|
||||
storage_size=Memory.from_mb(137695),
|
||||
n_layers=80,
|
||||
),
|
||||
),
|
||||
# phi-3
|
||||
"phi-3-mini": ModelCard(
|
||||
short_id="phi-3-mini",
|
||||
model_id=ModelId("mlx-community/Phi-3-mini-128k-instruct-4bit"),
|
||||
name="Phi 3 Mini 128k (4-bit)",
|
||||
description="""Phi 3 Mini is a large language model trained on the Phi 3 Mini dataset.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Phi-3-mini-128k-instruct-4bit"),
|
||||
pretty_name="Phi 3 Mini 128k (4-bit)",
|
||||
storage_size=Memory.from_mb(2099),
|
||||
n_layers=32,
|
||||
hidden_size=8192,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
# qwen3
|
||||
@@ -277,12 +332,15 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Qwen3-0.6B-4bit"),
|
||||
name="Qwen3 0.6B (4-bit)",
|
||||
description="""Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Qwen3-0.6B-4bit"),
|
||||
pretty_name="Qwen3 0.6B (4-bit)",
|
||||
storage_size=Memory.from_mb(327),
|
||||
n_layers=28,
|
||||
hidden_size=1024,
|
||||
supports_tensor=False,
|
||||
),
|
||||
),
|
||||
"qwen3-0.6b-8bit": ModelCard(
|
||||
@@ -290,12 +348,15 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Qwen3-0.6B-8bit"),
|
||||
name="Qwen3 0.6B (8-bit)",
|
||||
description="""Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Qwen3-0.6B-8bit"),
|
||||
pretty_name="Qwen3 0.6B (8-bit)",
|
||||
storage_size=Memory.from_mb(666),
|
||||
n_layers=28,
|
||||
hidden_size=1024,
|
||||
supports_tensor=False,
|
||||
),
|
||||
),
|
||||
"qwen3-30b": ModelCard(
|
||||
@@ -303,12 +364,15 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Qwen3-30B-A3B-4bit"),
|
||||
name="Qwen3 30B A3B (4-bit)",
|
||||
description="""Qwen3 30B is a large language model trained on the Qwen3 30B dataset.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Qwen3-30B-A3B-4bit"),
|
||||
pretty_name="Qwen3 30B A3B (4-bit)",
|
||||
storage_size=Memory.from_mb(16797),
|
||||
n_layers=48,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"qwen3-30b-8bit": ModelCard(
|
||||
@@ -316,12 +380,79 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Qwen3-30B-A3B-8bit"),
|
||||
name="Qwen3 30B A3B (8-bit)",
|
||||
description="""Qwen3 30B is a large language model trained on the Qwen3 30B dataset.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Qwen3-30B-A3B-8bit"),
|
||||
pretty_name="Qwen3 30B A3B (8-bit)",
|
||||
storage_size=Memory.from_mb(31738),
|
||||
n_layers=48,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"qwen3-80b-a3B-4bit": ModelCard(
|
||||
short_id="qwen3-80b-a3B-4bit",
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
|
||||
name="Qwen3 80B A3B (4-bit)",
|
||||
description="""Qwen3 80B""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
|
||||
pretty_name="Qwen3 80B A3B (4-bit)",
|
||||
storage_size=Memory.from_mb(44800),
|
||||
n_layers=48,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"qwen3-80b-a3B-8bit": ModelCard(
|
||||
short_id="qwen3-80b-a3B-8bit",
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
|
||||
name="Qwen3 80B A3B (8-bit)",
|
||||
description="""Qwen3 80B""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
|
||||
pretty_name="Qwen3 80B A3B (8-bit)",
|
||||
storage_size=Memory.from_mb(84700),
|
||||
n_layers=48,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"qwen3-80b-a3B-thinking-4bit": ModelCard(
|
||||
short_id="qwen3-80b-a3B-thinking-4bit",
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
|
||||
name="Qwen3 80B A3B Thinking (4-bit)",
|
||||
description="""Qwen3 80B Reasoning model""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
|
||||
pretty_name="Qwen3 80B A3B (4-bit)",
|
||||
storage_size=Memory.from_mb(84700),
|
||||
n_layers=48,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"qwen3-80b-a3B-thinking-8bit": ModelCard(
|
||||
short_id="qwen3-80b-a3B-thinking-8bit",
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
|
||||
name="Qwen3 80B A3B Thinking (8-bit)",
|
||||
description="""Qwen3 80B Reasoning model""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
|
||||
pretty_name="Qwen3 80B A3B (8-bit)",
|
||||
storage_size=Memory.from_mb(84700),
|
||||
n_layers=48,
|
||||
hidden_size=2048,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"qwen3-235b-a22b-4bit": ModelCard(
|
||||
@@ -329,12 +460,15 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"),
|
||||
name="Qwen3 235B A22B (4-bit)",
|
||||
description="""Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"),
|
||||
pretty_name="Qwen3 235B A22B (4-bit)",
|
||||
storage_size=Memory.from_gb(132),
|
||||
n_layers=94,
|
||||
hidden_size=4096,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"qwen3-235b-a22b-8bit": ModelCard(
|
||||
@@ -342,12 +476,15 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"),
|
||||
name="Qwen3 235B A22B (8-bit)",
|
||||
description="""Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"),
|
||||
pretty_name="Qwen3 235B A22B (8-bit)",
|
||||
storage_size=Memory.from_gb(250),
|
||||
n_layers=94,
|
||||
hidden_size=4096,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"qwen3-coder-480b-a35b-4bit": ModelCard(
|
||||
@@ -355,12 +492,15 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"),
|
||||
name="Qwen3 Coder 480B A35B (4-bit)",
|
||||
description="""Qwen3 Coder 480B (Active 35B) is a large language model trained on the Qwen3 Coder 480B dataset.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"),
|
||||
pretty_name="Qwen3 Coder 480B A35B (4-bit)",
|
||||
storage_size=Memory.from_gb(270),
|
||||
n_layers=62,
|
||||
hidden_size=6144,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"qwen3-coder-480b-a35b-8bit": ModelCard(
|
||||
@@ -368,84 +508,280 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"),
|
||||
name="Qwen3 Coder 480B A35B (8-bit)",
|
||||
description="""Qwen3 Coder 480B (Active 35B) is a large language model trained on the Qwen3 Coder 480B dataset.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"),
|
||||
pretty_name="Qwen3 Coder 480B A35B (8-bit)",
|
||||
storage_size=Memory.from_gb(540),
|
||||
n_layers=62,
|
||||
hidden_size=6144,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
# granite
|
||||
"granite-3.3-2b": ModelCard(
|
||||
short_id="granite-3.3-2b",
|
||||
model_id=ModelId("mlx-community/granite-3.3-2b-instruct-fp16"),
|
||||
name="Granite 3.3 2B (FP16)",
|
||||
description="""Granite-3.3-2B-Instruct is a 2-billion parameter 128K context length language model fine-tuned for improved reasoning and instruction-following capabilities.""",
|
||||
# gpt-oss
|
||||
"gpt-oss-120b-MXFP4-Q8": ModelCard(
|
||||
short_id="gpt-oss-120b-MXFP4-Q8",
|
||||
model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
|
||||
name="GPT-OSS 120B (MXFP4-Q8, MLX)",
|
||||
description="""OpenAI's GPT-OSS 120B is a 117B-parameter Mixture-of-Experts model designed for high-reasoning and general-purpose use; this variant is a 4-bit MLX conversion for Apple Silicon.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/granite-3.3-2b-instruct-fp16"),
|
||||
pretty_name="Granite 3.3 2B (FP16)",
|
||||
storage_size=Memory.from_mb(4951),
|
||||
n_layers=40,
|
||||
model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
|
||||
pretty_name="GPT-OSS 120B (MXFP4-Q8, MLX)",
|
||||
storage_size=Memory.from_kb(68_996_301),
|
||||
n_layers=36,
|
||||
hidden_size=2880,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
# "granite-3.3-8b": ModelCard(
|
||||
# short_id="granite-3.3-8b",
|
||||
# model_id=ModelId("mlx-community/granite-3.3-8b-instruct-fp16"),
|
||||
# name="Granite 3.3 8B",
|
||||
# description="""Granite-3.3-8B-Instruct is a 8-billion parameter 128K context length language model fine-tuned for improved reasoning and instruction-following capabilities.""",
|
||||
"gpt-oss-20b-4bit": ModelCard(
|
||||
short_id="gpt-oss-20b-4bit",
|
||||
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
|
||||
name="GPT-OSS 20B (MXFP4-Q4, MLX)",
|
||||
description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this MLX variant uses MXFP4 4-bit quantization.""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
|
||||
pretty_name="GPT-OSS 20B (MXFP4-Q4, MLX)",
|
||||
storage_size=Memory.from_kb(11_744_051),
|
||||
n_layers=24,
|
||||
hidden_size=2880,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
# Needs to be quantized g32 or g16.
|
||||
"glm-4.5-air-8bit": ModelCard(
|
||||
short_id="glm-4.5-air-8bit",
|
||||
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
|
||||
name="GLM 4.5 Air 8bit",
|
||||
description="""GLM 4.5 Air 8bit""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
|
||||
pretty_name="GLM 4.5 Air 8bit",
|
||||
storage_size=Memory.from_gb(114),
|
||||
n_layers=46,
|
||||
hidden_size=4096,
|
||||
supports_tensor=False,
|
||||
),
|
||||
),
|
||||
"glm-4.5-air-bf16": ModelCard(
|
||||
short_id="glm-4.5-air-bf16",
|
||||
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
|
||||
name="GLM 4.5 Air bf16",
|
||||
description="""GLM 4.5 Air bf16""",
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
|
||||
pretty_name="GLM 4.5 Air bf16",
|
||||
storage_size=Memory.from_gb(214),
|
||||
n_layers=46,
|
||||
hidden_size=4096,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
# "devstral-2-123b-instruct-2512-8bit": ModelCard(
|
||||
# short_id="devstral-2-123b-instruct-2512-8bit",
|
||||
# model_id=ModelId("mlx-community/Devstral-2-123B-Instruct-2512-8bit"),
|
||||
# name="Devstral 2 123B Instruct 2512 (8-bit, MLX)",
|
||||
# description="""Mistral AI's Devstral 2 123B Instruct (2512) is an agentic coding model.""",
|
||||
# tags=[],
|
||||
# metadata=ModelMetadata(
|
||||
# model_id=ModelId("mlx-community/granite-3.3-8b-instruct-fp16"),
|
||||
# pretty_name="Granite 3.3 8B",
|
||||
# storage_size=Memory.from_kb(15958720),
|
||||
# n_layers=40,
|
||||
# ),
|
||||
# ),
|
||||
# smol-lm
|
||||
# "smol-lm-135m": ModelCard(
|
||||
# short_id="smol-lm-135m",
|
||||
# model_id="mlx-community/SmolLM-135M-4bit",
|
||||
# name="Smol LM 135M",
|
||||
# description="""SmolLM is a series of state-of-the-art small language models available in three sizes: 135M, 360M, and 1.7B parameters. """,
|
||||
# tags=[],
|
||||
# metadata=ModelMetadata(
|
||||
# model_id=ModelId("mlx-community/SmolLM-135M-4bit"),
|
||||
# pretty_name="Smol LM 135M",
|
||||
# storage_size=Memory.from_kb(73940),
|
||||
# n_layers=30,
|
||||
# ),
|
||||
# ),
|
||||
# gpt-oss
|
||||
# "gpt-oss-120b-MXFP4-Q8": ModelCard(
|
||||
# short_id="gpt-oss-120b-MXFP4-Q8",
|
||||
# model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
|
||||
# name="GPT-OSS 120B (MXFP4-Q8, MLX)",
|
||||
# description="""OpenAI's GPT-OSS 120B is a 117B-parameter Mixture-of-Experts model designed for high-reasoning and general-purpose use; this variant is a 4-bit MLX conversion for Apple Silicon.""",
|
||||
# tags=[],
|
||||
# metadata=ModelMetadata(
|
||||
# model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
|
||||
# pretty_name="GPT-OSS 120B (MXFP4-Q8, MLX)",
|
||||
# storage_size=Memory.from_kb(68_996_301),
|
||||
# n_layers=36,
|
||||
# hidden_size=2880,
|
||||
# supports_tensor=True,
|
||||
# ),
|
||||
# ),
|
||||
# "gpt-oss-20b-4bit": ModelCard(
|
||||
# short_id="gpt-oss-20b-4bit",
|
||||
# model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
|
||||
# name="GPT-OSS 20B (MXFP4-Q4, MLX)",
|
||||
# description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this MLX variant uses MXFP4 4-bit quantization.""",
|
||||
# tags=[],
|
||||
# metadata=ModelMetadata(
|
||||
# model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
|
||||
# pretty_name="GPT-OSS 20B (MXFP4-Q4, MLX)",
|
||||
# storage_size=Memory.from_kb(11_744_051),
|
||||
# n_layers=24,
|
||||
# hidden_size=2880,
|
||||
# model_id=ModelId("mlx-community/Devstral-2-123B-Instruct-2512-8bit"),
|
||||
# pretty_name="Devstral 2 123B Instruct 2512 (8-bit, MLX)",
|
||||
# storage_size=Memory.from_kb(133_000_000),
|
||||
# n_layers=88,
|
||||
# hidden_size=12288,
|
||||
# supports_tensor=True,
|
||||
# ),
|
||||
# ),
|
||||
"flux1-schnell": ModelCard(
|
||||
short_id="flux1-schnell",
|
||||
model_id=ModelId("black-forest-labs/FLUX.1-schnell"),
|
||||
name="FLUX.1 [schnell]",
|
||||
description="""FLUX.1 [schnell] is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions""",
|
||||
tasks=[ModelTask.TextToImage],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("black-forest-labs/FLUX.1-schnell"),
|
||||
pretty_name="FLUX.1 [schnell]",
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
storage_size=Memory.from_bytes(23782357120), # + 9524621312),
|
||||
n_layers=57, # sharded layers
|
||||
components=[
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
component_path="text_encoder/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None, # Single file
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="text_encoder_2",
|
||||
component_path="text_encoder_2/",
|
||||
storage_size=Memory.from_bytes(9524621312),
|
||||
n_layers=24,
|
||||
can_shard=False,
|
||||
safetensors_index_filename="model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="transformer",
|
||||
component_path="transformer/",
|
||||
storage_size=Memory.from_bytes(23782357120),
|
||||
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
|
||||
can_shard=True,
|
||||
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="vae",
|
||||
component_path="vae/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=None,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
],
|
||||
),
|
||||
),
|
||||
"flux1-dev": ModelCard(
|
||||
short_id="flux1-dev",
|
||||
model_id=ModelId("black-forest-labs/FLUX.1-dev"),
|
||||
name="FLUX.1 [dev]",
|
||||
description="""FLUX.1 [dev] is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions""",
|
||||
tasks=[ModelTask.TextToImage],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("black-forest-labs/FLUX.1-dev"),
|
||||
pretty_name="FLUX.1 [dev]",
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
storage_size=Memory.from_bytes(23782357120 + 9524621312),
|
||||
n_layers=57, # sharded layers
|
||||
components=[
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
component_path="text_encoder/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None, # Single file
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="text_encoder_2",
|
||||
component_path="text_encoder_2/",
|
||||
storage_size=Memory.from_bytes(9524621312),
|
||||
n_layers=24,
|
||||
can_shard=False,
|
||||
safetensors_index_filename="model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="transformer",
|
||||
component_path="transformer/",
|
||||
storage_size=Memory.from_bytes(23802816640),
|
||||
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
|
||||
can_shard=True,
|
||||
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="vae",
|
||||
component_path="vae/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=None,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
],
|
||||
),
|
||||
),
|
||||
"qwen-image": ModelCard(
|
||||
short_id="qwen-image",
|
||||
model_id=ModelId("Qwen/Qwen-Image"),
|
||||
name="Qwen Image",
|
||||
description="""an image generation foundation model in the Qwen series that achieves significant advances in complex text rendering and precise image editing""",
|
||||
tasks=[ModelTask.TextToImage],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("Qwen/Qwen-Image"),
|
||||
pretty_name="Qwen Image",
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
storage_size=Memory.from_bytes(16584333312 + 40860802176),
|
||||
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
|
||||
components=[
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
component_path="text_encoder/",
|
||||
storage_size=Memory.from_kb(16584333312),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None, # Single file
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="transformer",
|
||||
component_path="transformer/",
|
||||
storage_size=Memory.from_bytes(40860802176),
|
||||
n_layers=60,
|
||||
can_shard=True,
|
||||
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="vae",
|
||||
component_path="vae/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=None,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
],
|
||||
),
|
||||
),
|
||||
"qwen-image-edit-2509": ModelCard(
|
||||
short_id="qwen-image-edit-2509",
|
||||
model_id=ModelId("Qwen/Qwen-Image-Edit-2509"),
|
||||
name="Qwen Image Edit 2509",
|
||||
description="""an image generation foundation model in the Qwen series that achieves significant advances in complex text rendering and precise image editing""",
|
||||
tasks=[ModelTask.ImageToImage],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("Qwen/Qwen-Image-Edit-2509"),
|
||||
pretty_name="Qwen Image Edit 2509",
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
storage_size=Memory.from_bytes(16584333312 + 40860802176),
|
||||
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
|
||||
components=[
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
component_path="text_encoder/",
|
||||
storage_size=Memory.from_kb(16584333312),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None, # Single file
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="transformer",
|
||||
component_path="transformer/",
|
||||
storage_size=Memory.from_bytes(40860802176),
|
||||
n_layers=60,
|
||||
can_shard=True,
|
||||
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="vae",
|
||||
component_path="vae/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=None,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
],
|
||||
),
|
||||
),
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ from huggingface_hub import model_info
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from exo.shared.models.model_cards import MODEL_CARDS
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.models import ModelId, ModelMetadata
|
||||
from exo.worker.download.download_utils import (
|
||||
@@ -25,6 +26,7 @@ class ConfigData(BaseModel):
|
||||
n_layers: Annotated[int, Field(ge=0)] | None = None # Sometimes used
|
||||
num_decoder_layers: Annotated[int, Field(ge=0)] | None = None # Transformer models
|
||||
decoder_layers: Annotated[int, Field(ge=0)] | None = None # Some architectures
|
||||
hidden_size: Annotated[int, Field(ge=0)] | None = None
|
||||
|
||||
@property
|
||||
def layer_count(self) -> int:
|
||||
@@ -106,10 +108,19 @@ async def _get_model_meta(model_id: str) -> ModelMetadata:
|
||||
config_data = await get_config_data(model_id)
|
||||
num_layers = config_data.layer_count
|
||||
mem_size_bytes = await get_safetensors_size(model_id)
|
||||
model_card = next(
|
||||
(card for card in MODEL_CARDS.values() if card.model_id == ModelId(model_id)),
|
||||
None,
|
||||
)
|
||||
|
||||
return ModelMetadata(
|
||||
model_id=ModelId(model_id),
|
||||
pretty_name=model_id,
|
||||
pretty_name=model_card.name if model_card is not None else model_id,
|
||||
storage_size=mem_size_bytes,
|
||||
n_layers=num_layers,
|
||||
hidden_size=config_data.hidden_size or 0,
|
||||
# TODO: all custom models currently do not support tensor. We could add a dynamic test for this?
|
||||
supports_tensor=model_card.metadata.supports_tensor
|
||||
if model_card is not None
|
||||
else False,
|
||||
)
|
||||
|
||||
@@ -36,6 +36,8 @@ def get_pipeline_shard_metadata(
|
||||
pretty_name=str(model_id),
|
||||
storage_size=Memory.from_mb(100000),
|
||||
n_layers=32,
|
||||
hidden_size=1000,
|
||||
supports_tensor=True,
|
||||
),
|
||||
device_rank=device_rank,
|
||||
world_size=world_size,
|
||||
|
||||
@@ -19,7 +19,7 @@ def test_apply_node_download_progress():
|
||||
NodeDownloadProgress(download_progress=event), state
|
||||
)
|
||||
|
||||
assert new_state == State(downloads={NodeId("node-1"): [event]})
|
||||
assert new_state.downloads == {NodeId("node-1"): [event]}
|
||||
|
||||
|
||||
def test_apply_two_node_download_progress():
|
||||
@@ -42,4 +42,4 @@ def test_apply_two_node_download_progress():
|
||||
# TODO: This test is failing. We should support the following:
|
||||
# 1. Downloading multiple models concurrently on the same node (one per runner is fine).
|
||||
# 2. Downloading a model, it completes, then downloading a different model on the same node.
|
||||
assert new_state == State(downloads={NodeId("node-1"): [event1, event2]})
|
||||
assert new_state.downloads == {NodeId("node-1"): [event1, event2]}
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Literal
|
||||
|
||||
from fastapi import UploadFile
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic_core import PydanticUseDefault
|
||||
|
||||
from exo.shared.types.common import CommandId
|
||||
from exo.shared.types.models import ModelId
|
||||
from exo.shared.types.models import ModelId, ModelMetadata
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
|
||||
from exo.shared.types.worker.shards import Sharding
|
||||
|
||||
@@ -27,6 +29,7 @@ class ModelListModel(BaseModel):
|
||||
tags: list[str] = Field(default=[])
|
||||
storage_size_megabytes: int = Field(default=0)
|
||||
supports_tensor: bool = Field(default=False)
|
||||
tasks: list[str] = Field(default=[])
|
||||
|
||||
|
||||
class ModelList(BaseModel):
|
||||
@@ -174,9 +177,82 @@ class DeleteInstanceTaskParams(BaseModel):
|
||||
class CreateInstanceResponse(BaseModel):
|
||||
message: str
|
||||
command_id: CommandId
|
||||
model_meta: ModelMetadata
|
||||
|
||||
|
||||
class DeleteInstanceResponse(BaseModel):
|
||||
message: str
|
||||
command_id: CommandId
|
||||
instance_id: InstanceId
|
||||
|
||||
|
||||
class ImageGenerationTaskParams(BaseModel):
|
||||
prompt: str
|
||||
# background: str | None = None
|
||||
model: str
|
||||
# moderation: str | None = None
|
||||
n: int | None = 1
|
||||
# output_compression: int | None = None
|
||||
output_format: Literal["png", "jpeg", "webp"] = "png"
|
||||
partial_images: int | None = 0
|
||||
quality: Literal["high", "medium", "low"] | None = "medium"
|
||||
response_format: Literal["url", "b64_json"] | None = "b64_json"
|
||||
size: str | None = "1024x1024"
|
||||
stream: bool | None = False
|
||||
# style: str | None = "vivid"
|
||||
# user: str | None = None
|
||||
|
||||
|
||||
class ImageEditsTaskParams(BaseModel):
|
||||
image: UploadFile
|
||||
prompt: str
|
||||
input_fidelity: float = 0.7
|
||||
model: str
|
||||
n: int | None = 1
|
||||
quality: Literal["high", "medium", "low"] | None = "medium"
|
||||
output_format: Literal["png", "jpeg", "webp"] = "png"
|
||||
response_format: Literal["url", "b64_json"] | None = "b64_json"
|
||||
size: str | None = "1024x1024"
|
||||
# user: str | None = None
|
||||
|
||||
|
||||
class ImageEditsInternalParams(BaseModel):
|
||||
"""Serializable version of ImageEditsTaskParams for distributed task execution."""
|
||||
|
||||
image_data: str = "" # Base64-encoded image (empty when using chunked transfer)
|
||||
total_input_chunks: int = 0
|
||||
prompt: str
|
||||
model: str
|
||||
n: int | None = 1
|
||||
quality: Literal["high", "medium", "low"] | None = "medium"
|
||||
output_format: Literal["png", "jpeg", "webp"] = "png"
|
||||
response_format: Literal["url", "b64_json"] | None = "b64_json"
|
||||
size: str | None = "1024x1024"
|
||||
image_strength: float = 0.7
|
||||
stream: bool = False
|
||||
partial_images: int | None = 0
|
||||
|
||||
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
|
||||
for name, value in super().__repr_args__():
|
||||
if name == "image_data":
|
||||
yield name, f"<{len(self.image_data)} chars>"
|
||||
elif name is not None:
|
||||
yield name, value
|
||||
|
||||
|
||||
class ImageData(BaseModel):
|
||||
b64_json: str | None = None
|
||||
url: str | None = None
|
||||
revised_prompt: str | None = None
|
||||
|
||||
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
|
||||
for name, value in super().__repr_args__():
|
||||
if name == "b64_json" and value is not None:
|
||||
yield name, f"<{len(value)} chars>"
|
||||
elif name is not None:
|
||||
yield name, value
|
||||
|
||||
|
||||
class ImageGenerationResponse(BaseModel):
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
data: list[ImageData]
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
from collections.abc import Generator
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from exo.utils.pydantic_ext import TaggedModel
|
||||
|
||||
from .api import FinishReason
|
||||
from .common import CommandId
|
||||
from .models import ModelId
|
||||
|
||||
|
||||
@@ -23,7 +26,34 @@ class TokenChunk(BaseChunk):
|
||||
|
||||
|
||||
class ImageChunk(BaseChunk):
|
||||
data: bytes
|
||||
data: str
|
||||
chunk_index: int
|
||||
total_chunks: int
|
||||
image_index: int
|
||||
is_partial: bool = False
|
||||
partial_index: int | None = None
|
||||
total_partials: int | None = None
|
||||
|
||||
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
|
||||
for name, value in super().__repr_args__():
|
||||
if name == "data":
|
||||
yield name, f"<{len(self.data)} chars>"
|
||||
elif name is not None:
|
||||
yield name, value
|
||||
|
||||
|
||||
class InputImageChunk(BaseChunk):
|
||||
command_id: CommandId
|
||||
data: str
|
||||
chunk_index: int
|
||||
total_chunks: int
|
||||
|
||||
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
|
||||
for name, value in super().__repr_args__():
|
||||
if name == "data":
|
||||
yield name, f"<{len(self.data)} chars>"
|
||||
elif name is not None:
|
||||
yield name, value
|
||||
|
||||
|
||||
GenerationChunk = TokenChunk | ImageChunk
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
from pydantic import Field
|
||||
|
||||
from exo.shared.types.api import ChatCompletionTaskParams
|
||||
from exo.shared.types.api import (
|
||||
ChatCompletionTaskParams,
|
||||
ImageEditsInternalParams,
|
||||
ImageGenerationTaskParams,
|
||||
)
|
||||
from exo.shared.types.chunks import InputImageChunk
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.models import ModelMetadata
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
|
||||
@@ -20,6 +25,14 @@ class ChatCompletion(BaseCommand):
|
||||
request_params: ChatCompletionTaskParams
|
||||
|
||||
|
||||
class ImageGeneration(BaseCommand):
|
||||
request_params: ImageGenerationTaskParams
|
||||
|
||||
|
||||
class ImageEdits(BaseCommand):
|
||||
request_params: ImageEditsInternalParams
|
||||
|
||||
|
||||
class PlaceInstance(BaseCommand):
|
||||
model_meta: ModelMetadata
|
||||
sharding: Sharding
|
||||
@@ -39,6 +52,12 @@ class TaskFinished(BaseCommand):
|
||||
finished_command_id: CommandId
|
||||
|
||||
|
||||
class SendInputChunk(BaseCommand):
|
||||
"""Command to send an input image chunk (converted to event by master)."""
|
||||
|
||||
chunk: InputImageChunk
|
||||
|
||||
|
||||
class RequestEventLog(BaseCommand):
|
||||
since_idx: int
|
||||
|
||||
@@ -47,10 +66,13 @@ Command = (
|
||||
TestCommand
|
||||
| RequestEventLog
|
||||
| ChatCompletion
|
||||
| ImageGeneration
|
||||
| ImageEdits
|
||||
| PlaceInstance
|
||||
| CreateInstance
|
||||
| DeleteInstance
|
||||
| TaskFinished
|
||||
| SendInputChunk
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from datetime import datetime
|
||||
from pydantic import Field
|
||||
|
||||
from exo.shared.topology import Connection, NodePerformanceProfile
|
||||
from exo.shared.types.chunks import GenerationChunk
|
||||
from exo.shared.types.chunks import GenerationChunk, InputImageChunk
|
||||
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
|
||||
from exo.shared.types.profiling import MemoryPerformanceProfile
|
||||
from exo.shared.types.tasks import Task, TaskId, TaskStatus
|
||||
@@ -106,6 +106,11 @@ class ChunkGenerated(BaseEvent):
|
||||
chunk: GenerationChunk
|
||||
|
||||
|
||||
class InputChunkReceived(BaseEvent):
|
||||
command_id: CommandId
|
||||
chunk: InputImageChunk
|
||||
|
||||
|
||||
class TopologyEdgeCreated(BaseEvent):
|
||||
edge: Connection
|
||||
|
||||
@@ -131,6 +136,7 @@ Event = (
|
||||
| NodeMemoryMeasured
|
||||
| NodeDownloadProgress
|
||||
| ChunkGenerated
|
||||
| InputChunkReceived
|
||||
| TopologyEdgeCreated
|
||||
| TopologyEdgeDeleted
|
||||
)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import PositiveInt
|
||||
|
||||
from exo.shared.types.common import Id
|
||||
@@ -9,8 +11,26 @@ class ModelId(Id):
|
||||
pass
|
||||
|
||||
|
||||
class ModelTask(str, Enum):
|
||||
TextGeneration = "TextGeneration"
|
||||
TextToImage = "TextToImage"
|
||||
ImageToImage = "ImageToImage"
|
||||
|
||||
|
||||
class ComponentInfo(CamelCaseModel):
|
||||
component_name: str
|
||||
component_path: str
|
||||
storage_size: Memory
|
||||
n_layers: PositiveInt | None
|
||||
can_shard: bool
|
||||
safetensors_index_filename: str | None
|
||||
|
||||
|
||||
class ModelMetadata(CamelCaseModel):
|
||||
model_id: ModelId
|
||||
pretty_name: str
|
||||
storage_size: Memory
|
||||
n_layers: PositiveInt
|
||||
hidden_size: PositiveInt
|
||||
supports_tensor: bool
|
||||
components: list[ComponentInfo] | None = None
|
||||
|
||||
@@ -2,7 +2,11 @@ from enum import Enum
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from exo.shared.types.api import ChatCompletionTaskParams
|
||||
from exo.shared.types.api import (
|
||||
ChatCompletionTaskParams,
|
||||
ImageEditsInternalParams,
|
||||
ImageGenerationTaskParams,
|
||||
)
|
||||
from exo.shared.types.common import CommandId, Id
|
||||
from exo.shared.types.worker.instances import BoundInstance, InstanceId
|
||||
from exo.shared.types.worker.runners import RunnerId
|
||||
@@ -40,6 +44,10 @@ class LoadModel(BaseTask): # emitted by Worker
|
||||
pass
|
||||
|
||||
|
||||
class ConnectToGroup(BaseTask): # emitted by Worker
|
||||
pass
|
||||
|
||||
|
||||
class StartWarmup(BaseTask): # emitted by Worker
|
||||
pass
|
||||
|
||||
@@ -52,10 +60,34 @@ class ChatCompletion(BaseTask): # emitted by Master
|
||||
error_message: str | None = Field(default=None)
|
||||
|
||||
|
||||
class ImageGeneration(BaseTask): # emitted by Master
|
||||
command_id: CommandId
|
||||
task_params: ImageGenerationTaskParams
|
||||
|
||||
error_type: str | None = Field(default=None)
|
||||
error_message: str | None = Field(default=None)
|
||||
|
||||
|
||||
class ImageEdits(BaseTask): # emitted by Master
|
||||
command_id: CommandId
|
||||
task_params: ImageEditsInternalParams
|
||||
|
||||
error_type: str | None = Field(default=None)
|
||||
error_message: str | None = Field(default=None)
|
||||
|
||||
|
||||
class Shutdown(BaseTask): # emitted by Worker
|
||||
runner_id: RunnerId
|
||||
|
||||
|
||||
Task = (
|
||||
CreateRunner | DownloadModel | LoadModel | StartWarmup | ChatCompletion | Shutdown
|
||||
CreateRunner
|
||||
| DownloadModel
|
||||
| ConnectToGroup
|
||||
| LoadModel
|
||||
| StartWarmup
|
||||
| ChatCompletion
|
||||
| ImageGeneration
|
||||
| ImageEdits
|
||||
| Shutdown
|
||||
)
|
||||
|
||||
@@ -25,7 +25,8 @@ class BaseInstance(TaggedModel):
|
||||
|
||||
|
||||
class MlxRingInstance(BaseInstance):
|
||||
hosts: list[Host]
|
||||
hosts_by_node: dict[NodeId, list[Host]]
|
||||
ephemeral_port: int
|
||||
|
||||
|
||||
class MlxJacclInstance(BaseInstance):
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Literal
|
||||
|
||||
from exo.shared.types.api import FinishReason
|
||||
from exo.utils.pydantic_ext import TaggedModel
|
||||
|
||||
@@ -17,5 +20,31 @@ class GenerationResponse(BaseRunnerResponse):
|
||||
finish_reason: FinishReason | None = None
|
||||
|
||||
|
||||
class ImageGenerationResponse(BaseRunnerResponse):
|
||||
image_data: bytes
|
||||
format: Literal["png", "jpeg", "webp"] = "png"
|
||||
|
||||
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
|
||||
for name, value in super().__repr_args__():
|
||||
if name == "image_data":
|
||||
yield name, f"<{len(self.image_data)} bytes>"
|
||||
elif name is not None:
|
||||
yield name, value
|
||||
|
||||
|
||||
class PartialImageResponse(BaseRunnerResponse):
|
||||
image_data: bytes
|
||||
format: Literal["png", "jpeg", "webp"] = "png"
|
||||
partial_index: int
|
||||
total_partials: int
|
||||
|
||||
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
|
||||
for name, value in super().__repr_args__():
|
||||
if name == "image_data":
|
||||
yield name, f"<{len(self.image_data)} bytes>"
|
||||
elif name is not None:
|
||||
yield name, value
|
||||
|
||||
|
||||
class FinishedResponse(BaseRunnerResponse):
|
||||
pass
|
||||
|
||||
@@ -21,7 +21,15 @@ class BaseRunnerStatus(TaggedModel):
|
||||
return isinstance(self, RunnerRunning)
|
||||
|
||||
|
||||
class RunnerWaitingForModel(BaseRunnerStatus):
|
||||
class RunnerIdle(BaseRunnerStatus):
|
||||
pass
|
||||
|
||||
|
||||
class RunnerConnecting(BaseRunnerStatus):
|
||||
pass
|
||||
|
||||
|
||||
class RunnerConnected(BaseRunnerStatus):
|
||||
pass
|
||||
|
||||
|
||||
@@ -45,6 +53,10 @@ class RunnerRunning(BaseRunnerStatus):
|
||||
pass
|
||||
|
||||
|
||||
class RunnerShuttingDown(BaseRunnerStatus):
|
||||
pass
|
||||
|
||||
|
||||
class RunnerShutdown(BaseRunnerStatus):
|
||||
pass
|
||||
|
||||
@@ -54,12 +66,15 @@ class RunnerFailed(BaseRunnerStatus):
|
||||
|
||||
|
||||
RunnerStatus = (
|
||||
RunnerWaitingForModel
|
||||
RunnerIdle
|
||||
| RunnerConnecting
|
||||
| RunnerConnected
|
||||
| RunnerLoading
|
||||
| RunnerLoaded
|
||||
| RunnerWarmingUp
|
||||
| RunnerReady
|
||||
| RunnerRunning
|
||||
| RunnerShuttingDown
|
||||
| RunnerShutdown
|
||||
| RunnerFailed
|
||||
)
|
||||
|
||||
@@ -9,6 +9,7 @@ from datetime import timedelta
|
||||
from pathlib import Path
|
||||
from typing import Callable, Literal
|
||||
from urllib.parse import urljoin
|
||||
from huggingface_hub._snapshot_download import snapshot_download
|
||||
|
||||
import aiofiles
|
||||
import aiofiles.os as aios
|
||||
@@ -441,15 +442,39 @@ def calculate_repo_progress(
|
||||
async def get_weight_map(repo_id: str, revision: str = "main") -> dict[str, str]:
|
||||
target_dir = (await ensure_models_dir()) / str(repo_id).replace("/", "--")
|
||||
await aios.makedirs(target_dir, exist_ok=True)
|
||||
index_file = await download_file_with_retry(
|
||||
repo_id, revision, "model.safetensors.index.json", target_dir
|
||||
|
||||
index_files_dir = snapshot_download(
|
||||
repo_id=repo_id, local_dir=target_dir, allow_patterns="*.safetensors.index.json"
|
||||
)
|
||||
async with aiofiles.open(index_file, "r") as f:
|
||||
index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
|
||||
return index_data.weight_map
|
||||
|
||||
index_files = list(Path(index_files_dir).glob("**/*.safetensors.index.json"))
|
||||
|
||||
weight_map: dict[str, str] = {}
|
||||
|
||||
for index_file in index_files:
|
||||
relative_dir = index_file.parent.relative_to(index_files_dir)
|
||||
|
||||
async with aiofiles.open(index_file, "r") as f:
|
||||
index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
|
||||
|
||||
if relative_dir != Path("."):
|
||||
prefixed_weight_map = {
|
||||
f"{relative_dir}/{key}": str(relative_dir / value)
|
||||
for key, value in index_data.weight_map.items()
|
||||
}
|
||||
weight_map = weight_map | prefixed_weight_map
|
||||
else:
|
||||
weight_map = weight_map | index_data.weight_map
|
||||
|
||||
return weight_map
|
||||
|
||||
|
||||
async def resolve_allow_patterns(shard: ShardMetadata) -> list[str]:
|
||||
# TODO: 'Smart' downloads are disabled because:
|
||||
# (i) We don't handle all kinds of files;
|
||||
# (ii) We don't have sticky sessions.
|
||||
# (iii) Tensor parallel requires all files.
|
||||
return ["*"]
|
||||
try:
|
||||
weight_map = await get_weight_map(str(shard.model_meta.model_id))
|
||||
return get_allow_patterns(weight_map, shard)
|
||||
@@ -546,8 +571,6 @@ async def download_shard(
|
||||
logger.info(f"Downloading {shard.model_meta.model_id=} with {allow_patterns=}")
|
||||
|
||||
all_start_time = time.time()
|
||||
# TODO: currently not recursive. Some models might require subdirectories - thus this will need to be changed.
|
||||
# Update: <- This does not seem to be the case. Yay?
|
||||
file_list = await fetch_file_list_with_cache(
|
||||
str(shard.model_meta.model_id), revision, recursive=True
|
||||
)
|
||||
|
||||
@@ -95,23 +95,73 @@ def extract_layer_num(tensor_name: str) -> int | None:
|
||||
|
||||
def get_allow_patterns(weight_map: dict[str, str], shard: ShardMetadata) -> list[str]:
|
||||
default_patterns = set(
|
||||
["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt", "*.jinja"]
|
||||
[
|
||||
"*.json",
|
||||
"*.py",
|
||||
"tokenizer.model",
|
||||
"tiktoken.model",
|
||||
"*/spiece.model",
|
||||
"*.tiktoken",
|
||||
"*.txt",
|
||||
"*.jinja",
|
||||
]
|
||||
)
|
||||
shard_specific_patterns: set[str] = set()
|
||||
if weight_map:
|
||||
for tensor_name, filename in weight_map.items():
|
||||
layer_num = extract_layer_num(tensor_name)
|
||||
if (
|
||||
layer_num is not None
|
||||
and shard.start_layer <= layer_num <= shard.end_layer
|
||||
):
|
||||
shard_specific_patterns.add(filename)
|
||||
layer_independent_files = set(
|
||||
[v for k, v in weight_map.items() if extract_layer_num(k) is None]
|
||||
|
||||
if shard.model_meta.components is not None:
|
||||
shardable_component = next(
|
||||
(c for c in shard.model_meta.components if c.can_shard), None
|
||||
)
|
||||
shard_specific_patterns.update(layer_independent_files)
|
||||
logger.debug(f"get_allow_patterns {shard=} {layer_independent_files=}")
|
||||
|
||||
if weight_map and shardable_component:
|
||||
for tensor_name, filename in weight_map.items():
|
||||
# Strip component prefix from tensor name (added by weight map namespacing)
|
||||
# E.g., "transformer/blocks.0.weight" -> "blocks.0.weight"
|
||||
if "/" in tensor_name:
|
||||
_, tensor_name_no_prefix = tensor_name.split("/", 1)
|
||||
else:
|
||||
tensor_name_no_prefix = tensor_name
|
||||
|
||||
# Determine which component this file belongs to from filename
|
||||
component_path = Path(filename).parts[0] if "/" in filename else None
|
||||
|
||||
if component_path == shardable_component.component_path.rstrip("/"):
|
||||
layer_num = extract_layer_num(tensor_name_no_prefix)
|
||||
if (
|
||||
layer_num is not None
|
||||
and shard.start_layer <= layer_num < shard.end_layer
|
||||
):
|
||||
shard_specific_patterns.add(filename)
|
||||
|
||||
if shard.is_first_layer or shard.is_last_layer:
|
||||
shard_specific_patterns.add(filename)
|
||||
else:
|
||||
shard_specific_patterns.add(filename)
|
||||
|
||||
else:
|
||||
shard_specific_patterns = set(["*.safetensors"])
|
||||
|
||||
# TODO(ciaran): temporary - Include all files from non-shardable components that have no index file
|
||||
for component in shard.model_meta.components:
|
||||
if not component.can_shard and component.safetensors_index_filename is None:
|
||||
component_pattern = f"{component.component_path.rstrip('/')}/*"
|
||||
shard_specific_patterns.add(component_pattern)
|
||||
else:
|
||||
shard_specific_patterns = set(["*.safetensors"])
|
||||
if weight_map:
|
||||
for tensor_name, filename in weight_map.items():
|
||||
layer_num = extract_layer_num(tensor_name)
|
||||
if (
|
||||
layer_num is not None
|
||||
and shard.start_layer <= layer_num < shard.end_layer
|
||||
):
|
||||
shard_specific_patterns.add(filename)
|
||||
layer_independent_files = set(
|
||||
[v for k, v in weight_map.items() if extract_layer_num(k) is None]
|
||||
)
|
||||
shard_specific_patterns.update(layer_independent_files)
|
||||
logger.debug(f"get_allow_patterns {shard=} {layer_independent_files=}")
|
||||
else:
|
||||
shard_specific_patterns = set(["*.safetensors"])
|
||||
|
||||
logger.info(f"get_allow_patterns {shard=} {shard_specific_patterns=}")
|
||||
return list(default_patterns | shard_specific_patterns)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from copy import copy
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
from typing import AsyncIterator, Callable
|
||||
@@ -12,7 +13,7 @@ from exo.shared.types.worker.shards import (
|
||||
from exo.worker.download.download_utils import RepoDownloadProgress
|
||||
|
||||
|
||||
# TODO: the PipelineShardMetadata getting reinstantiated is a bit messy. Shoudl this be a classmethod?
|
||||
# TODO: the PipelineShardMetadata getting reinstantiated is a bit messy. Should this be a classmethod?
|
||||
class ShardDownloader(ABC):
|
||||
@abstractmethod
|
||||
async def ensure_shard(
|
||||
@@ -43,34 +44,7 @@ class ShardDownloader(ABC):
|
||||
Yields:
|
||||
tuple[Path, RepoDownloadProgress]: The path and progress of a shard download.
|
||||
"""
|
||||
yield (
|
||||
Path("/tmp/noop_shard"),
|
||||
RepoDownloadProgress(
|
||||
repo_id="noop",
|
||||
repo_revision="noop",
|
||||
shard=PipelineShardMetadata(
|
||||
model_meta=ModelMetadata(
|
||||
model_id=ModelId("noop"),
|
||||
pretty_name="noope",
|
||||
storage_size=Memory.from_bytes(0),
|
||||
n_layers=1,
|
||||
),
|
||||
device_rank=0,
|
||||
world_size=1,
|
||||
start_layer=0,
|
||||
end_layer=1,
|
||||
n_layers=1,
|
||||
),
|
||||
completed_files=0,
|
||||
total_files=0,
|
||||
downloaded_bytes=Memory.from_bytes(0),
|
||||
downloaded_bytes_this_session=Memory.from_bytes(0),
|
||||
total_bytes=Memory.from_bytes(0),
|
||||
overall_speed=0,
|
||||
overall_eta=timedelta(seconds=0),
|
||||
status="complete",
|
||||
),
|
||||
)
|
||||
yield (Path("/tmp/noop_shard"), NOOP_DOWNLOAD_PROGRESS)
|
||||
|
||||
@abstractmethod
|
||||
async def get_shard_download_status_for_shard(
|
||||
@@ -94,46 +68,41 @@ class NoopShardDownloader(ShardDownloader):
|
||||
) -> AsyncIterator[tuple[Path, RepoDownloadProgress]]:
|
||||
yield (
|
||||
Path("/tmp/noop_shard"),
|
||||
RepoDownloadProgress(
|
||||
repo_id="noop",
|
||||
repo_revision="noop",
|
||||
shard=PipelineShardMetadata(
|
||||
model_meta=ModelMetadata(
|
||||
model_id=ModelId("noop"),
|
||||
pretty_name="noope",
|
||||
storage_size=Memory.from_bytes(0),
|
||||
n_layers=1,
|
||||
),
|
||||
device_rank=0,
|
||||
world_size=1,
|
||||
start_layer=0,
|
||||
end_layer=1,
|
||||
n_layers=1,
|
||||
),
|
||||
completed_files=0,
|
||||
total_files=0,
|
||||
downloaded_bytes=Memory.from_bytes(0),
|
||||
downloaded_bytes_this_session=Memory.from_bytes(0),
|
||||
total_bytes=Memory.from_bytes(0),
|
||||
overall_speed=0,
|
||||
overall_eta=timedelta(seconds=0),
|
||||
status="complete",
|
||||
),
|
||||
NOOP_DOWNLOAD_PROGRESS,
|
||||
)
|
||||
|
||||
async def get_shard_download_status_for_shard(
|
||||
self, shard: ShardMetadata
|
||||
) -> RepoDownloadProgress:
|
||||
return RepoDownloadProgress(
|
||||
repo_id="noop",
|
||||
repo_revision="noop",
|
||||
shard=shard,
|
||||
completed_files=0,
|
||||
total_files=0,
|
||||
downloaded_bytes=Memory.from_bytes(0),
|
||||
downloaded_bytes_this_session=Memory.from_bytes(0),
|
||||
total_bytes=Memory.from_bytes(0),
|
||||
overall_speed=0,
|
||||
overall_eta=timedelta(seconds=0),
|
||||
status="complete",
|
||||
)
|
||||
dp = copy(NOOP_DOWNLOAD_PROGRESS)
|
||||
dp.shard = shard
|
||||
return dp
|
||||
|
||||
|
||||
NOOP_DOWNLOAD_PROGRESS = RepoDownloadProgress(
|
||||
repo_id="noop",
|
||||
repo_revision="noop",
|
||||
shard=PipelineShardMetadata(
|
||||
model_meta=ModelMetadata(
|
||||
model_id=ModelId("noop"),
|
||||
pretty_name="noope",
|
||||
storage_size=Memory.from_bytes(0),
|
||||
n_layers=1,
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
),
|
||||
device_rank=0,
|
||||
world_size=1,
|
||||
start_layer=0,
|
||||
end_layer=1,
|
||||
n_layers=1,
|
||||
),
|
||||
completed_files=0,
|
||||
total_files=0,
|
||||
downloaded_bytes=Memory.from_bytes(0),
|
||||
downloaded_bytes_this_session=Memory.from_bytes(0),
|
||||
total_bytes=Memory.from_bytes(0),
|
||||
overall_speed=0,
|
||||
overall_eta=timedelta(seconds=0),
|
||||
status="complete",
|
||||
)
|
||||
|
||||
10
src/exo/worker/engines/image/__init__.py
Normal file
10
src/exo/worker/engines/image/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from exo.worker.engines.image.base import ImageGenerator
|
||||
from exo.worker.engines.image.distributed_model import initialize_image_model
|
||||
from exo.worker.engines.image.generate import generate_image, warmup_image_generator
|
||||
|
||||
__all__ = [
|
||||
"ImageGenerator",
|
||||
"generate_image",
|
||||
"initialize_image_model",
|
||||
"warmup_image_generator",
|
||||
]
|
||||
50
src/exo/worker/engines/image/base.py
Normal file
50
src/exo/worker/engines/image/base.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from collections.abc import Generator
|
||||
from pathlib import Path
|
||||
from typing import Literal, Protocol, runtime_checkable
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ImageGenerator(Protocol):
|
||||
@property
|
||||
def rank(self) -> int: ...
|
||||
|
||||
@property
|
||||
def is_first_stage(self) -> bool: ...
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
height: int,
|
||||
width: int,
|
||||
quality: Literal["low", "medium", "high"],
|
||||
seed: int,
|
||||
image_path: Path | None = None,
|
||||
partial_images: int = 0,
|
||||
) -> Generator[Image.Image | tuple[Image.Image, int, int], None, None]:
|
||||
"""Generate an image from a text prompt, or edit an existing image.
|
||||
|
||||
For distributed inference, only the last stage returns images.
|
||||
Other stages yield nothing after participating in the pipeline.
|
||||
|
||||
When partial_images > 0, yields intermediate images during diffusion
|
||||
as tuples of (image, partial_index, total_partials), then yields
|
||||
the final image.
|
||||
|
||||
When partial_images = 0 (default), only yields the final image.
|
||||
|
||||
Args:
|
||||
prompt: Text description of the image to generate
|
||||
height: Image height in pixels
|
||||
width: Image width in pixels
|
||||
quality: Generation quality level
|
||||
seed: Random seed for reproducibility
|
||||
image_path: Optional path to input image for image editing
|
||||
partial_images: Number of intermediate images to yield (0 for none)
|
||||
|
||||
Yields:
|
||||
Intermediate images as (Image, partial_index, total_partials) tuples
|
||||
Final PIL Image (last stage) or nothing (other stages)
|
||||
"""
|
||||
...
|
||||
74
src/exo/worker/engines/image/config.py
Normal file
74
src/exo/worker/engines/image/config.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from enum import Enum
|
||||
from math import ceil
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class BlockType(Enum):
|
||||
JOINT = "joint" # Separate image/text streams
|
||||
SINGLE = "single" # Concatenated streams
|
||||
|
||||
|
||||
class TransformerBlockConfig(BaseModel):
|
||||
model_config = {"frozen": True}
|
||||
|
||||
block_type: BlockType
|
||||
count: int
|
||||
has_separate_text_output: bool # True for joint blocks that output text separately
|
||||
|
||||
|
||||
class ImageModelConfig(BaseModel):
|
||||
model_config = {"frozen": True}
|
||||
|
||||
# Model identification
|
||||
model_family: str # "flux", "fibo", "qwen"
|
||||
model_variant: str # "schnell", "dev", etc.
|
||||
|
||||
# Architecture parameters
|
||||
hidden_dim: int
|
||||
num_heads: int
|
||||
head_dim: int
|
||||
|
||||
# Block configuration - ordered sequence of block types
|
||||
block_configs: tuple[TransformerBlockConfig, ...]
|
||||
|
||||
# Tokenization parameters
|
||||
patch_size: int # 2 for Flux/Qwen
|
||||
vae_scale_factor: int # 8 for Flux, 16 for others
|
||||
|
||||
# Inference parameters
|
||||
default_steps: dict[str, int] # {"low": X, "medium": Y, "high": Z}
|
||||
num_sync_steps_factor: float # Fraction of steps for sync phase
|
||||
|
||||
# Feature flags
|
||||
uses_attention_mask: bool # True for Fibo
|
||||
|
||||
# CFG (Classifier-Free Guidance) parameters
|
||||
guidance_scale: float | None = None # None or <= 1.0 disables CFG
|
||||
|
||||
@property
|
||||
def total_blocks(self) -> int:
|
||||
"""Total number of transformer blocks."""
|
||||
return sum(bc.count for bc in self.block_configs)
|
||||
|
||||
@property
|
||||
def joint_block_count(self) -> int:
|
||||
"""Number of joint transformer blocks."""
|
||||
return sum(
|
||||
bc.count for bc in self.block_configs if bc.block_type == BlockType.JOINT
|
||||
)
|
||||
|
||||
@property
|
||||
def single_block_count(self) -> int:
|
||||
"""Number of single transformer blocks."""
|
||||
return sum(
|
||||
bc.count for bc in self.block_configs if bc.block_type == BlockType.SINGLE
|
||||
)
|
||||
|
||||
def get_steps_for_quality(self, quality: str) -> int:
|
||||
"""Get inference steps for a quality level."""
|
||||
return self.default_steps[quality]
|
||||
|
||||
def get_num_sync_steps(self, quality: str) -> int:
|
||||
"""Get number of synchronous steps based on quality."""
|
||||
return ceil(self.default_steps[quality] * self.num_sync_steps_factor)
|
||||
228
src/exo/worker/engines/image/distributed_model.py
Normal file
228
src/exo/worker/engines/image/distributed_model.py
Normal file
@@ -0,0 +1,228 @@
|
||||
from collections.abc import Generator
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
from mflux.config.config import Config
|
||||
from PIL import Image
|
||||
|
||||
from exo.shared.types.worker.instances import BoundInstance
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata
|
||||
from exo.worker.download.download_utils import build_model_path
|
||||
from exo.worker.engines.image.config import ImageModelConfig
|
||||
from exo.worker.engines.image.models import (
|
||||
create_adapter_for_model,
|
||||
get_config_for_model,
|
||||
)
|
||||
from exo.worker.engines.image.models.base import BaseModelAdapter
|
||||
from exo.worker.engines.image.pipeline import DiffusionRunner
|
||||
from exo.worker.engines.mlx.utils_mlx import mlx_distributed_init, mx_barrier
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
|
||||
class DistributedImageModel:
|
||||
__slots__ = (
|
||||
"_config",
|
||||
"_adapter",
|
||||
"_group",
|
||||
"_shard_metadata",
|
||||
"_runner",
|
||||
)
|
||||
|
||||
_config: ImageModelConfig
|
||||
_adapter: BaseModelAdapter
|
||||
_group: Optional[mx.distributed.Group]
|
||||
_shard_metadata: PipelineShardMetadata
|
||||
_runner: DiffusionRunner
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
local_path: Path,
|
||||
shard_metadata: PipelineShardMetadata,
|
||||
group: Optional[mx.distributed.Group] = None,
|
||||
quantize: int | None = None,
|
||||
):
|
||||
# Get model config and create adapter (adapter owns the model)
|
||||
config = get_config_for_model(model_id)
|
||||
adapter = create_adapter_for_model(config, model_id, local_path, quantize)
|
||||
|
||||
if group is not None:
|
||||
adapter.slice_transformer_blocks(
|
||||
start_layer=shard_metadata.start_layer,
|
||||
end_layer=shard_metadata.end_layer,
|
||||
total_joint_blocks=config.joint_block_count,
|
||||
total_single_blocks=config.single_block_count,
|
||||
)
|
||||
|
||||
# Create diffusion runner (handles both single-node and distributed modes)
|
||||
num_sync_steps = config.get_num_sync_steps("medium") if group else 0
|
||||
runner = DiffusionRunner(
|
||||
config=config,
|
||||
adapter=adapter,
|
||||
group=group,
|
||||
shard_metadata=shard_metadata,
|
||||
num_sync_steps=num_sync_steps,
|
||||
)
|
||||
|
||||
if group is not None:
|
||||
logger.info("Initialized distributed diffusion runner")
|
||||
|
||||
mx.eval(adapter.model.parameters())
|
||||
|
||||
# TODO(ciaran): Do we need this?
|
||||
mx.eval(adapter.model)
|
||||
|
||||
# Synchronize processes before generation to avoid timeout
|
||||
mx_barrier(group)
|
||||
logger.info(f"Transformer sharded for rank {group.rank()}")
|
||||
else:
|
||||
logger.info("Single-node initialization")
|
||||
|
||||
object.__setattr__(self, "_config", config)
|
||||
object.__setattr__(self, "_adapter", adapter)
|
||||
object.__setattr__(self, "_group", group)
|
||||
object.__setattr__(self, "_shard_metadata", shard_metadata)
|
||||
object.__setattr__(self, "_runner", runner)
|
||||
|
||||
@classmethod
|
||||
def from_bound_instance(
|
||||
cls, bound_instance: BoundInstance
|
||||
) -> "DistributedImageModel":
|
||||
model_id = bound_instance.bound_shard.model_meta.model_id
|
||||
model_path = build_model_path(model_id)
|
||||
|
||||
shard_metadata = bound_instance.bound_shard
|
||||
if not isinstance(shard_metadata, PipelineShardMetadata):
|
||||
raise ValueError("Expected PipelineShardMetadata for image generation")
|
||||
|
||||
is_distributed = (
|
||||
len(bound_instance.instance.shard_assignments.node_to_runner) > 1
|
||||
)
|
||||
|
||||
if is_distributed:
|
||||
logger.info("Starting distributed init for image model")
|
||||
group = mlx_distributed_init(bound_instance)
|
||||
else:
|
||||
group = None
|
||||
|
||||
return cls(
|
||||
model_id=model_id,
|
||||
local_path=model_path,
|
||||
shard_metadata=shard_metadata,
|
||||
group=group,
|
||||
)
|
||||
|
||||
@property
|
||||
def model(self) -> Any:
|
||||
"""Return the underlying mflux model via the adapter."""
|
||||
return self._adapter.model
|
||||
|
||||
@property
|
||||
def config(self) -> ImageModelConfig:
|
||||
return self._config
|
||||
|
||||
@property
|
||||
def adapter(self) -> BaseModelAdapter:
|
||||
return self._adapter
|
||||
|
||||
@property
|
||||
def group(self) -> Optional[mx.distributed.Group]:
|
||||
return self._group
|
||||
|
||||
@property
|
||||
def shard_metadata(self) -> PipelineShardMetadata:
|
||||
return self._shard_metadata
|
||||
|
||||
@property
|
||||
def rank(self) -> int:
|
||||
return self._shard_metadata.device_rank
|
||||
|
||||
@property
|
||||
def world_size(self) -> int:
|
||||
return self._shard_metadata.world_size
|
||||
|
||||
@property
|
||||
def is_first_stage(self) -> bool:
|
||||
return self._shard_metadata.device_rank == 0
|
||||
|
||||
@property
|
||||
def is_last_stage(self) -> bool:
|
||||
return self._shard_metadata.device_rank == self._shard_metadata.world_size - 1
|
||||
|
||||
@property
|
||||
def is_distributed(self) -> bool:
|
||||
return self._shard_metadata.world_size > 1
|
||||
|
||||
@property
|
||||
def runner(self) -> DiffusionRunner:
|
||||
return self._runner
|
||||
|
||||
# Delegate attribute access to the underlying model via the adapter.
|
||||
# Guarded with TYPE_CHECKING to prevent type checker complaints
|
||||
# while still providing full delegation at runtime.
|
||||
if not TYPE_CHECKING:
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
return getattr(self._adapter.model, name)
|
||||
|
||||
def __setattr__(self, name: str, value: Any) -> None:
|
||||
if name in (
|
||||
"_config",
|
||||
"_adapter",
|
||||
"_group",
|
||||
"_shard_metadata",
|
||||
"_runner",
|
||||
):
|
||||
object.__setattr__(self, name, value)
|
||||
else:
|
||||
setattr(self._adapter.model, name, value)
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
height: int,
|
||||
width: int,
|
||||
quality: Literal["low", "medium", "high"] = "medium",
|
||||
seed: int = 2,
|
||||
image_path: Path | None = None,
|
||||
partial_images: int = 0,
|
||||
) -> Generator[Image.Image | tuple[Image.Image, int, int], None, None]:
|
||||
# Determine number of inference steps based on quality
|
||||
steps = self._config.get_steps_for_quality(quality)
|
||||
|
||||
# For edit mode: compute dimensions from input image
|
||||
# This also stores image_paths in the adapter for encode_prompt()
|
||||
if image_path is not None:
|
||||
computed_dims = self._adapter.set_image_dimensions(image_path)
|
||||
if computed_dims is not None:
|
||||
# Override user-provided dimensions with computed ones
|
||||
width, height = computed_dims
|
||||
|
||||
config = Config(
|
||||
num_inference_steps=steps,
|
||||
height=height,
|
||||
width=width,
|
||||
image_path=image_path,
|
||||
)
|
||||
|
||||
# Generate images via the runner
|
||||
for result in self._runner.generate_image(
|
||||
settings=config,
|
||||
prompt=prompt,
|
||||
seed=seed,
|
||||
partial_images=partial_images,
|
||||
):
|
||||
if isinstance(result, tuple):
|
||||
# Partial image: (GeneratedImage, partial_index, total_partials)
|
||||
generated_image, partial_idx, total_partials = result
|
||||
yield (generated_image.image, partial_idx, total_partials)
|
||||
else:
|
||||
# Final image: GeneratedImage
|
||||
logger.info("generated image")
|
||||
yield result.image
|
||||
|
||||
|
||||
def initialize_image_model(bound_instance: BoundInstance) -> DistributedImageModel:
|
||||
"""Initialize DistributedImageModel from a BoundInstance."""
|
||||
return DistributedImageModel.from_bound_instance(bound_instance)
|
||||
120
src/exo/worker/engines/image/generate.py
Normal file
120
src/exo/worker/engines/image/generate.py
Normal file
@@ -0,0 +1,120 @@
|
||||
import base64
|
||||
import io
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Generator, Literal
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from exo.shared.types.api import ImageEditsInternalParams, ImageGenerationTaskParams
|
||||
from exo.shared.types.worker.runner_response import (
|
||||
ImageGenerationResponse,
|
||||
PartialImageResponse,
|
||||
)
|
||||
from exo.worker.engines.image.base import ImageGenerator
|
||||
|
||||
|
||||
def parse_size(size_str: str | None) -> tuple[int, int]:
|
||||
"""Parse size parameter like '1024x1024' to (width, height) tuple."""
|
||||
if not size_str or size_str == "auto":
|
||||
size_str = "1024x1024"
|
||||
|
||||
try:
|
||||
parts = size_str.split("x")
|
||||
if len(parts) == 2:
|
||||
width, height = int(parts[0]), int(parts[1])
|
||||
return (width, height)
|
||||
except (ValueError, AttributeError):
|
||||
pass
|
||||
|
||||
# Default fallback
|
||||
return (1024, 1024)
|
||||
|
||||
|
||||
def warmup_image_generator(model: ImageGenerator) -> Image.Image | None:
|
||||
"""Warmup the image generator with a small image."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Create a small dummy image for warmup (needed for edit models)
|
||||
dummy_image = Image.new("RGB", (256, 256), color=(128, 128, 128))
|
||||
dummy_path = Path(tmpdir) / "warmup.png"
|
||||
dummy_image.save(dummy_path)
|
||||
|
||||
for result in model.generate(
|
||||
prompt="Warmup",
|
||||
height=256,
|
||||
width=256,
|
||||
quality="low",
|
||||
seed=2,
|
||||
image_path=dummy_path,
|
||||
):
|
||||
if not isinstance(result, tuple):
|
||||
return result
|
||||
return None
|
||||
|
||||
|
||||
def generate_image(
|
||||
model: ImageGenerator,
|
||||
task: ImageGenerationTaskParams | ImageEditsInternalParams,
|
||||
) -> Generator[ImageGenerationResponse | PartialImageResponse, None, None]:
|
||||
"""Generate image(s), optionally yielding partial results.
|
||||
|
||||
When partial_images > 0 or stream=True, yields PartialImageResponse for
|
||||
intermediate images, then ImageGenerationResponse for the final image.
|
||||
|
||||
Yields:
|
||||
PartialImageResponse for intermediate images (if partial_images > 0)
|
||||
ImageGenerationResponse for the final complete image
|
||||
"""
|
||||
width, height = parse_size(task.size)
|
||||
quality: Literal["low", "medium", "high"] = task.quality or "medium"
|
||||
seed = 2 # TODO(ciaran): Randomise when not testing anymore
|
||||
|
||||
# Handle streaming params for both generation and edit tasks
|
||||
partial_images = task.partial_images or (3 if task.stream else 0)
|
||||
|
||||
image_path: Path | None = None
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
if isinstance(task, ImageEditsInternalParams):
|
||||
# Decode base64 image data and save to temp file
|
||||
image_path = Path(tmpdir) / "input.png"
|
||||
image_path.write_bytes(base64.b64decode(task.image_data))
|
||||
|
||||
# Iterate over generator results
|
||||
for result in model.generate(
|
||||
prompt=task.prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
quality=quality,
|
||||
seed=seed,
|
||||
image_path=image_path,
|
||||
partial_images=partial_images,
|
||||
):
|
||||
if isinstance(result, tuple):
|
||||
# Partial image: (Image, partial_index, total_partials)
|
||||
image, partial_idx, total_partials = result
|
||||
buffer = io.BytesIO()
|
||||
image_format = task.output_format.upper()
|
||||
if image_format == "JPG":
|
||||
image_format = "JPEG"
|
||||
image.save(buffer, format=image_format)
|
||||
|
||||
yield PartialImageResponse(
|
||||
image_data=buffer.getvalue(),
|
||||
format=task.output_format,
|
||||
partial_index=partial_idx,
|
||||
total_partials=total_partials,
|
||||
)
|
||||
else:
|
||||
# Final image
|
||||
image = result
|
||||
buffer = io.BytesIO()
|
||||
image_format = task.output_format.upper()
|
||||
if image_format == "JPG":
|
||||
image_format = "JPEG"
|
||||
image.save(buffer, format=image_format)
|
||||
|
||||
yield ImageGenerationResponse(
|
||||
image_data=buffer.getvalue(),
|
||||
format=task.output_format,
|
||||
)
|
||||
84
src/exo/worker/engines/image/models/__init__.py
Normal file
84
src/exo/worker/engines/image/models/__init__.py
Normal file
@@ -0,0 +1,84 @@
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
from exo.worker.engines.image.config import ImageModelConfig
|
||||
from exo.worker.engines.image.models.flux import (
|
||||
FLUX_DEV_CONFIG,
|
||||
FLUX_SCHNELL_CONFIG,
|
||||
FluxModelAdapter,
|
||||
)
|
||||
from exo.worker.engines.image.models.qwen import (
|
||||
QWEN_IMAGE_CONFIG,
|
||||
QWEN_IMAGE_EDIT_CONFIG,
|
||||
QwenEditModelAdapter,
|
||||
QwenModelAdapter,
|
||||
)
|
||||
from exo.worker.engines.image.pipeline.adapter import ModelAdapter
|
||||
|
||||
__all__: list[str] = []
|
||||
|
||||
# Type alias for adapter factory functions
|
||||
# Factory takes (config, model_id, local_path, quantize) and returns a ModelAdapter
|
||||
AdapterFactory = Callable[[ImageModelConfig, str, Path, int | None], ModelAdapter]
|
||||
|
||||
# Registry maps model_family string to adapter factory
|
||||
_ADAPTER_REGISTRY: dict[str, AdapterFactory] = {
|
||||
"flux": FluxModelAdapter,
|
||||
"qwen-edit": QwenEditModelAdapter,
|
||||
"qwen": QwenModelAdapter,
|
||||
}
|
||||
|
||||
# Config registry: maps model ID patterns to configs
|
||||
_CONFIG_REGISTRY: dict[str, ImageModelConfig] = {
|
||||
"flux.1-schnell": FLUX_SCHNELL_CONFIG,
|
||||
"flux.1-dev": FLUX_DEV_CONFIG,
|
||||
"qwen-image-edit": QWEN_IMAGE_EDIT_CONFIG, # Must come before "qwen-image" for pattern matching
|
||||
"qwen-image": QWEN_IMAGE_CONFIG,
|
||||
}
|
||||
|
||||
|
||||
def get_config_for_model(model_id: str) -> ImageModelConfig:
|
||||
"""Get configuration for a model ID.
|
||||
|
||||
Args:
|
||||
model_id: The model identifier (e.g., "black-forest-labs/FLUX.1-schnell")
|
||||
|
||||
Returns:
|
||||
The model configuration
|
||||
|
||||
Raises:
|
||||
ValueError: If no configuration found for model ID
|
||||
"""
|
||||
model_id_lower = model_id.lower()
|
||||
|
||||
for pattern, config in _CONFIG_REGISTRY.items():
|
||||
if pattern in model_id_lower:
|
||||
return config
|
||||
|
||||
raise ValueError(f"No configuration found for model: {model_id}")
|
||||
|
||||
|
||||
def create_adapter_for_model(
|
||||
config: ImageModelConfig,
|
||||
model_id: str,
|
||||
local_path: Path,
|
||||
quantize: int | None = None,
|
||||
) -> ModelAdapter:
|
||||
"""Create a model adapter for the given configuration.
|
||||
|
||||
Args:
|
||||
config: The model configuration
|
||||
model_id: The model identifier
|
||||
local_path: Path to the model weights
|
||||
quantize: Optional quantization bits
|
||||
|
||||
Returns:
|
||||
A ModelAdapter instance
|
||||
|
||||
Raises:
|
||||
ValueError: If no adapter found for model family
|
||||
"""
|
||||
factory = _ADAPTER_REGISTRY.get(config.model_family)
|
||||
if factory is None:
|
||||
raise ValueError(f"No adapter found for model family: {config.model_family}")
|
||||
return factory(config, model_id, local_path, quantize)
|
||||
103
src/exo/worker/engines/image/models/base.py
Normal file
103
src/exo/worker/engines/image/models/base.py
Normal file
@@ -0,0 +1,103 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import mlx.core as mx
|
||||
from mflux.config.runtime_config import RuntimeConfig
|
||||
from mflux.models.common.latent_creator.latent_creator import Img2Img, LatentCreator
|
||||
from mflux.utils.array_util import ArrayUtil
|
||||
from mflux.utils.image_util import ImageUtil
|
||||
|
||||
|
||||
class BaseModelAdapter(ABC):
|
||||
"""Base class for model adapters with shared utilities.
|
||||
|
||||
Provides common implementations for latent creation and decoding.
|
||||
Subclasses implement model-specific prompt encoding and noise computation.
|
||||
"""
|
||||
|
||||
def create_latents(self, seed: int, runtime_config: RuntimeConfig) -> mx.array:
|
||||
"""Create initial latents. Uses model-specific latent creator."""
|
||||
return LatentCreator.create_for_txt2img_or_img2img(
|
||||
seed=seed,
|
||||
height=runtime_config.height,
|
||||
width=runtime_config.width,
|
||||
img2img=Img2Img(
|
||||
vae=self.model.vae,
|
||||
latent_creator=self._get_latent_creator(),
|
||||
sigmas=runtime_config.scheduler.sigmas,
|
||||
init_time_step=runtime_config.init_time_step,
|
||||
image_path=runtime_config.image_path,
|
||||
),
|
||||
)
|
||||
|
||||
def decode_latents(
|
||||
self,
|
||||
latents: mx.array,
|
||||
runtime_config: RuntimeConfig,
|
||||
seed: int,
|
||||
prompt: str,
|
||||
) -> Any:
|
||||
"""Decode latents to image. Shared implementation."""
|
||||
latents = ArrayUtil.unpack_latents(
|
||||
latents=latents,
|
||||
height=runtime_config.height,
|
||||
width=runtime_config.width,
|
||||
)
|
||||
decoded = self.model.vae.decode(latents)
|
||||
return ImageUtil.to_image(
|
||||
decoded_latents=decoded,
|
||||
config=runtime_config,
|
||||
seed=seed,
|
||||
prompt=prompt,
|
||||
quantization=self.model.bits,
|
||||
lora_paths=self.model.lora_paths,
|
||||
lora_scales=self.model.lora_scales,
|
||||
image_path=runtime_config.image_path,
|
||||
image_strength=runtime_config.image_strength,
|
||||
generation_time=0,
|
||||
)
|
||||
|
||||
# Abstract methods - subclasses must implement
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def model(self) -> Any:
|
||||
"""Return the underlying mflux model."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def _get_latent_creator(self) -> type:
|
||||
"""Return the latent creator class for this model."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def slice_transformer_blocks(
|
||||
self,
|
||||
start_layer: int,
|
||||
end_layer: int,
|
||||
total_joint_blocks: int,
|
||||
total_single_blocks: int,
|
||||
):
|
||||
"""Remove transformer blocks outside the assigned range.
|
||||
|
||||
This should be called BEFORE mx.eval() to avoid loading unused weights
|
||||
in distributed mode.
|
||||
|
||||
Args:
|
||||
start_layer: First layer index (inclusive) assigned to this node
|
||||
end_layer: Last layer index (exclusive) assigned to this node
|
||||
total_joint_blocks: Total number of joint blocks in the model
|
||||
total_single_blocks: Total number of single blocks in the model
|
||||
"""
|
||||
...
|
||||
|
||||
def set_image_dimensions(self, image_path: Path) -> tuple[int, int] | None:
|
||||
"""Default implementation: no dimension computation needed.
|
||||
|
||||
Override in edit adapters to compute dimensions from input image.
|
||||
|
||||
Returns:
|
||||
None (use user-specified dimensions)
|
||||
"""
|
||||
return None
|
||||
11
src/exo/worker/engines/image/models/flux/__init__.py
Normal file
11
src/exo/worker/engines/image/models/flux/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from exo.worker.engines.image.models.flux.adapter import FluxModelAdapter
|
||||
from exo.worker.engines.image.models.flux.config import (
|
||||
FLUX_DEV_CONFIG,
|
||||
FLUX_SCHNELL_CONFIG,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"FluxModelAdapter",
|
||||
"FLUX_DEV_CONFIG",
|
||||
"FLUX_SCHNELL_CONFIG",
|
||||
]
|
||||
680
src/exo/worker/engines/image/models/flux/adapter.py
Normal file
680
src/exo/worker/engines/image/models/flux/adapter.py
Normal file
@@ -0,0 +1,680 @@
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
|
||||
import mlx.core as mx
|
||||
from mflux.config.model_config import ModelConfig
|
||||
from mflux.config.runtime_config import RuntimeConfig
|
||||
from mflux.models.flux.latent_creator.flux_latent_creator import FluxLatentCreator
|
||||
from mflux.models.flux.model.flux_text_encoder.prompt_encoder import PromptEncoder
|
||||
from mflux.models.flux.model.flux_transformer.common.attention_utils import (
|
||||
AttentionUtils,
|
||||
)
|
||||
from mflux.models.flux.model.flux_transformer.joint_transformer_block import (
|
||||
JointTransformerBlock,
|
||||
)
|
||||
from mflux.models.flux.model.flux_transformer.transformer import Transformer
|
||||
from mflux.models.flux.variants.txt2img.flux import Flux1
|
||||
|
||||
from exo.worker.engines.image.config import ImageModelConfig
|
||||
from exo.worker.engines.image.models.base import BaseModelAdapter
|
||||
from exo.worker.engines.image.pipeline.adapter import (
|
||||
BlockWrapperMode,
|
||||
JointBlockInterface,
|
||||
SingleBlockInterface,
|
||||
)
|
||||
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache
|
||||
|
||||
|
||||
class FluxPromptData:
|
||||
"""Container for Flux prompt encoding results."""
|
||||
|
||||
def __init__(self, prompt_embeds: mx.array, pooled_prompt_embeds: mx.array):
|
||||
self._prompt_embeds = prompt_embeds
|
||||
self._pooled_prompt_embeds = pooled_prompt_embeds
|
||||
|
||||
@property
|
||||
def prompt_embeds(self) -> mx.array:
|
||||
return self._prompt_embeds
|
||||
|
||||
@property
|
||||
def pooled_prompt_embeds(self) -> mx.array:
|
||||
return self._pooled_prompt_embeds
|
||||
|
||||
@property
|
||||
def negative_prompt_embeds(self) -> mx.array | None:
|
||||
"""Flux does not use CFG."""
|
||||
return None
|
||||
|
||||
@property
|
||||
def negative_pooled_prompt_embeds(self) -> mx.array | None:
|
||||
"""Flux does not use CFG."""
|
||||
return None
|
||||
|
||||
def get_extra_forward_kwargs(self, positive: bool = True) -> dict[str, Any]:
|
||||
"""Flux has no extra forward kwargs."""
|
||||
return {}
|
||||
|
||||
@property
|
||||
def conditioning_latents(self) -> mx.array | None:
|
||||
"""Flux does not use conditioning latents."""
|
||||
return None
|
||||
|
||||
|
||||
class FluxModelAdapter(BaseModelAdapter):
|
||||
def __init__(
|
||||
self,
|
||||
config: ImageModelConfig,
|
||||
model_id: str,
|
||||
local_path: Path,
|
||||
quantize: int | None = None,
|
||||
):
|
||||
self._config = config
|
||||
self._model = Flux1(
|
||||
model_config=ModelConfig.from_name(model_name=model_id, base_model=None),
|
||||
local_path=str(local_path),
|
||||
quantize=quantize,
|
||||
)
|
||||
self._transformer = self._model.transformer
|
||||
|
||||
@property
|
||||
def config(self) -> ImageModelConfig:
|
||||
return self._config
|
||||
|
||||
@property
|
||||
def model(self) -> Flux1:
|
||||
return self._model
|
||||
|
||||
@property
|
||||
def transformer(self) -> Transformer:
|
||||
return self._transformer
|
||||
|
||||
@property
|
||||
def hidden_dim(self) -> int:
|
||||
return self._transformer.x_embedder.weight.shape[0]
|
||||
|
||||
def _get_latent_creator(self) -> type:
|
||||
return FluxLatentCreator
|
||||
|
||||
def encode_prompt(self, prompt: str) -> FluxPromptData:
|
||||
"""Encode prompt into FluxPromptData."""
|
||||
prompt_embeds, pooled_prompt_embeds = PromptEncoder.encode_prompt(
|
||||
prompt=prompt,
|
||||
prompt_cache=self._model.prompt_cache,
|
||||
t5_tokenizer=self._model.t5_tokenizer,
|
||||
clip_tokenizer=self._model.clip_tokenizer,
|
||||
t5_text_encoder=self._model.t5_text_encoder,
|
||||
clip_text_encoder=self._model.clip_text_encoder,
|
||||
)
|
||||
return FluxPromptData(
|
||||
prompt_embeds=prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
)
|
||||
|
||||
@property
|
||||
def needs_cfg(self) -> bool:
|
||||
return False
|
||||
|
||||
def apply_guidance(
|
||||
self,
|
||||
noise_positive: mx.array,
|
||||
noise_negative: mx.array,
|
||||
guidance_scale: float,
|
||||
) -> mx.array:
|
||||
raise NotImplementedError("Flux does not use classifier-free guidance")
|
||||
|
||||
def compute_embeddings(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
prompt_embeds: mx.array,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
embedded_hidden = self._transformer.x_embedder(hidden_states)
|
||||
embedded_encoder = self._transformer.context_embedder(prompt_embeds)
|
||||
return embedded_hidden, embedded_encoder
|
||||
|
||||
def compute_text_embeddings(
|
||||
self,
|
||||
t: int,
|
||||
runtime_config: RuntimeConfig,
|
||||
pooled_prompt_embeds: mx.array | None = None,
|
||||
hidden_states: mx.array | None = None, # Ignored by Flux
|
||||
) -> mx.array:
|
||||
if pooled_prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"pooled_prompt_embeds is required for Flux text embeddings"
|
||||
)
|
||||
|
||||
# hidden_states is ignored - Flux uses pooled_prompt_embeds instead
|
||||
return Transformer.compute_text_embeddings(
|
||||
t, pooled_prompt_embeds, self._transformer.time_text_embed, runtime_config
|
||||
)
|
||||
|
||||
def compute_rotary_embeddings(
|
||||
self,
|
||||
prompt_embeds: mx.array,
|
||||
runtime_config: RuntimeConfig,
|
||||
**kwargs: Any,
|
||||
) -> mx.array:
|
||||
kontext_image_ids = kwargs.get("kontext_image_ids")
|
||||
return Transformer.compute_rotary_embeddings(
|
||||
prompt_embeds,
|
||||
self._transformer.pos_embed,
|
||||
runtime_config,
|
||||
kontext_image_ids,
|
||||
)
|
||||
|
||||
def apply_joint_block(
|
||||
self,
|
||||
block: JointBlockInterface,
|
||||
hidden_states: mx.array,
|
||||
encoder_hidden_states: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
rotary_embeddings: Any, # mx.array for Flux
|
||||
kv_cache: ImagePatchKVCache | None,
|
||||
mode: BlockWrapperMode,
|
||||
text_seq_len: int,
|
||||
patch_start: int | None = None,
|
||||
patch_end: int | None = None,
|
||||
**kwargs: Any,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
if mode == BlockWrapperMode.CACHING:
|
||||
return self._apply_joint_block_caching(
|
||||
block=block,
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=rotary_embeddings,
|
||||
kv_cache=kv_cache,
|
||||
text_seq_len=text_seq_len,
|
||||
)
|
||||
else:
|
||||
assert patch_start is not None and patch_end is not None
|
||||
assert kv_cache is not None
|
||||
return self._apply_joint_block_patched(
|
||||
block=block,
|
||||
patch_hidden=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=rotary_embeddings,
|
||||
kv_cache=kv_cache,
|
||||
text_seq_len=text_seq_len,
|
||||
patch_start=patch_start,
|
||||
patch_end=patch_end,
|
||||
)
|
||||
|
||||
def apply_single_block(
|
||||
self,
|
||||
block: SingleBlockInterface,
|
||||
hidden_states: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
rotary_embeddings: mx.array,
|
||||
kv_cache: ImagePatchKVCache | None,
|
||||
mode: BlockWrapperMode,
|
||||
text_seq_len: int,
|
||||
patch_start: int | None = None,
|
||||
patch_end: int | None = None,
|
||||
) -> mx.array:
|
||||
if mode == BlockWrapperMode.CACHING:
|
||||
return self._apply_single_block_caching(
|
||||
block=block,
|
||||
hidden_states=hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=rotary_embeddings,
|
||||
kv_cache=kv_cache,
|
||||
text_seq_len=text_seq_len,
|
||||
)
|
||||
else:
|
||||
assert patch_start is not None and patch_end is not None
|
||||
assert kv_cache is not None
|
||||
return self._apply_single_block_patched(
|
||||
block=block,
|
||||
patch_hidden=hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=rotary_embeddings,
|
||||
kv_cache=kv_cache,
|
||||
text_seq_len=text_seq_len,
|
||||
patch_start=patch_start,
|
||||
patch_end=patch_end,
|
||||
)
|
||||
|
||||
def final_projection(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
) -> mx.array:
|
||||
hidden_states = self._transformer.norm_out(hidden_states, text_embeddings)
|
||||
return self._transformer.proj_out(hidden_states)
|
||||
|
||||
def get_joint_blocks(self) -> list[JointBlockInterface]:
|
||||
return cast(
|
||||
list[JointBlockInterface], list(self._transformer.transformer_blocks)
|
||||
)
|
||||
|
||||
def get_single_blocks(self) -> list[SingleBlockInterface]:
|
||||
return cast(
|
||||
list[SingleBlockInterface],
|
||||
list(self._transformer.single_transformer_blocks),
|
||||
)
|
||||
|
||||
def slice_transformer_blocks(
|
||||
self,
|
||||
start_layer: int,
|
||||
end_layer: int,
|
||||
total_joint_blocks: int,
|
||||
total_single_blocks: int,
|
||||
) -> None:
|
||||
if end_layer <= total_joint_blocks:
|
||||
# All assigned are joint blocks
|
||||
joint_start, joint_end = start_layer, end_layer
|
||||
single_start, single_end = 0, 0
|
||||
elif start_layer >= total_joint_blocks:
|
||||
# All assigned are single blocks
|
||||
joint_start, joint_end = 0, 0
|
||||
single_start = start_layer - total_joint_blocks
|
||||
single_end = end_layer - total_joint_blocks
|
||||
else:
|
||||
# Spans both joint and single
|
||||
joint_start, joint_end = start_layer, total_joint_blocks
|
||||
single_start = 0
|
||||
single_end = end_layer - total_joint_blocks
|
||||
|
||||
all_joint = list(self._transformer.transformer_blocks)
|
||||
self._transformer.transformer_blocks = all_joint[joint_start:joint_end]
|
||||
|
||||
all_single = list(self._transformer.single_transformer_blocks)
|
||||
self._transformer.single_transformer_blocks = all_single[
|
||||
single_start:single_end
|
||||
]
|
||||
|
||||
def merge_streams(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
encoder_hidden_states: mx.array,
|
||||
) -> mx.array:
|
||||
return mx.concatenate([encoder_hidden_states, hidden_states], axis=1)
|
||||
|
||||
def _apply_joint_block_caching(
|
||||
self,
|
||||
block: JointBlockInterface,
|
||||
hidden_states: mx.array,
|
||||
encoder_hidden_states: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
rotary_embeddings: mx.array,
|
||||
kv_cache: ImagePatchKVCache | None,
|
||||
text_seq_len: int,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
num_img_tokens = hidden_states.shape[1]
|
||||
batch_size = hidden_states.shape[0]
|
||||
attn = block.attn
|
||||
num_heads = attn.num_heads
|
||||
head_dim = attn.head_dimension
|
||||
|
||||
# 1. Compute norms
|
||||
norm_hidden, gate_msa, shift_mlp, scale_mlp, gate_mlp = block.norm1(
|
||||
hidden_states=hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
)
|
||||
norm_encoder, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = (
|
||||
block.norm1_context(
|
||||
hidden_states=encoder_hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
)
|
||||
)
|
||||
|
||||
# 2. Compute Q, K, V for full image
|
||||
img_query, img_key, img_value = AttentionUtils.process_qkv(
|
||||
hidden_states=norm_hidden,
|
||||
to_q=attn.to_q,
|
||||
to_k=attn.to_k,
|
||||
to_v=attn.to_v,
|
||||
norm_q=attn.norm_q,
|
||||
norm_k=attn.norm_k,
|
||||
num_heads=num_heads,
|
||||
head_dim=head_dim,
|
||||
)
|
||||
|
||||
# 3. Compute Q, K, V for text
|
||||
txt_query, txt_key, txt_value = AttentionUtils.process_qkv(
|
||||
hidden_states=norm_encoder,
|
||||
to_q=attn.add_q_proj,
|
||||
to_k=attn.add_k_proj,
|
||||
to_v=attn.add_v_proj,
|
||||
norm_q=attn.norm_added_q,
|
||||
norm_k=attn.norm_added_k,
|
||||
num_heads=num_heads,
|
||||
head_dim=head_dim,
|
||||
)
|
||||
|
||||
# 4. Concatenate Q, K, V: [text, image]
|
||||
query = mx.concatenate([txt_query, img_query], axis=2)
|
||||
key = mx.concatenate([txt_key, img_key], axis=2)
|
||||
value = mx.concatenate([txt_value, img_value], axis=2)
|
||||
|
||||
# 5. Apply RoPE
|
||||
query, key = AttentionUtils.apply_rope(
|
||||
xq=query, xk=key, freqs_cis=rotary_embeddings
|
||||
)
|
||||
|
||||
# 6. Store IMAGE K/V in cache for async pipeline
|
||||
if kv_cache is not None:
|
||||
kv_cache.update_image_patch(
|
||||
patch_start=0,
|
||||
patch_end=num_img_tokens,
|
||||
key=key[:, :, text_seq_len:, :],
|
||||
value=value[:, :, text_seq_len:, :],
|
||||
)
|
||||
|
||||
# 7. Compute full attention
|
||||
attn_output = AttentionUtils.compute_attention(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
batch_size=batch_size,
|
||||
num_heads=num_heads,
|
||||
head_dim=head_dim,
|
||||
)
|
||||
|
||||
# 8. Extract and project outputs
|
||||
context_attn_output = attn_output[:, :text_seq_len, :]
|
||||
attn_output = attn_output[:, text_seq_len:, :]
|
||||
|
||||
attn_output = attn.to_out[0](attn_output)
|
||||
context_attn_output = attn.to_add_out(context_attn_output)
|
||||
|
||||
# 9. Apply norm and feed forward
|
||||
hidden_states = JointTransformerBlock.apply_norm_and_feed_forward(
|
||||
hidden_states=hidden_states,
|
||||
attn_output=attn_output,
|
||||
gate_mlp=gate_mlp,
|
||||
gate_msa=gate_msa,
|
||||
scale_mlp=scale_mlp,
|
||||
shift_mlp=shift_mlp,
|
||||
norm_layer=block.norm2,
|
||||
ff_layer=block.ff,
|
||||
)
|
||||
encoder_hidden_states = JointTransformerBlock.apply_norm_and_feed_forward(
|
||||
hidden_states=encoder_hidden_states,
|
||||
attn_output=context_attn_output,
|
||||
gate_mlp=c_gate_mlp,
|
||||
gate_msa=c_gate_msa,
|
||||
scale_mlp=c_scale_mlp,
|
||||
shift_mlp=c_shift_mlp,
|
||||
norm_layer=block.norm2_context,
|
||||
ff_layer=block.ff_context,
|
||||
)
|
||||
|
||||
return encoder_hidden_states, hidden_states
|
||||
|
||||
def _apply_joint_block_patched(
|
||||
self,
|
||||
block: JointBlockInterface,
|
||||
patch_hidden: mx.array,
|
||||
encoder_hidden_states: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
rotary_embeddings: mx.array,
|
||||
kv_cache: ImagePatchKVCache,
|
||||
text_seq_len: int,
|
||||
patch_start: int,
|
||||
patch_end: int,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
batch_size = patch_hidden.shape[0]
|
||||
attn = block.attn
|
||||
num_heads = attn.num_heads
|
||||
head_dim = attn.head_dimension
|
||||
|
||||
# 1. Compute norms
|
||||
norm_hidden, gate_msa, shift_mlp, scale_mlp, gate_mlp = block.norm1(
|
||||
hidden_states=patch_hidden,
|
||||
text_embeddings=text_embeddings,
|
||||
)
|
||||
norm_encoder, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = (
|
||||
block.norm1_context(
|
||||
hidden_states=encoder_hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
)
|
||||
)
|
||||
|
||||
# 2. Compute Q, K, V for image patch
|
||||
img_query, img_key, img_value = AttentionUtils.process_qkv(
|
||||
hidden_states=norm_hidden,
|
||||
to_q=attn.to_q,
|
||||
to_k=attn.to_k,
|
||||
to_v=attn.to_v,
|
||||
norm_q=attn.norm_q,
|
||||
norm_k=attn.norm_k,
|
||||
num_heads=num_heads,
|
||||
head_dim=head_dim,
|
||||
)
|
||||
|
||||
# 3. Compute Q, K, V for text
|
||||
txt_query, txt_key, txt_value = AttentionUtils.process_qkv(
|
||||
hidden_states=norm_encoder,
|
||||
to_q=attn.add_q_proj,
|
||||
to_k=attn.add_k_proj,
|
||||
to_v=attn.add_v_proj,
|
||||
norm_q=attn.norm_added_q,
|
||||
norm_k=attn.norm_added_k,
|
||||
num_heads=num_heads,
|
||||
head_dim=head_dim,
|
||||
)
|
||||
|
||||
# 4. Concatenate Q, K, V for patch: [text, patch]
|
||||
query = mx.concatenate([txt_query, img_query], axis=2)
|
||||
patch_key = mx.concatenate([txt_key, img_key], axis=2)
|
||||
patch_value = mx.concatenate([txt_value, img_value], axis=2)
|
||||
|
||||
# 5. Extract RoPE for [text + current_patch]
|
||||
text_rope = rotary_embeddings[:, :, :text_seq_len, ...]
|
||||
patch_img_rope = rotary_embeddings[
|
||||
:, :, text_seq_len + patch_start : text_seq_len + patch_end, ...
|
||||
]
|
||||
patch_rope = mx.concatenate([text_rope, patch_img_rope], axis=2)
|
||||
|
||||
# 6. Apply RoPE
|
||||
query, patch_key = AttentionUtils.apply_rope(
|
||||
xq=query, xk=patch_key, freqs_cis=patch_rope
|
||||
)
|
||||
|
||||
# 7. Update cache with this patch's IMAGE K/V
|
||||
kv_cache.update_image_patch(
|
||||
patch_start=patch_start,
|
||||
patch_end=patch_end,
|
||||
key=patch_key[:, :, text_seq_len:, :],
|
||||
value=patch_value[:, :, text_seq_len:, :],
|
||||
)
|
||||
|
||||
# 8. Get full K, V from cache
|
||||
full_key, full_value = kv_cache.get_full_kv(
|
||||
text_key=patch_key[:, :, :text_seq_len, :],
|
||||
text_value=patch_value[:, :, :text_seq_len, :],
|
||||
)
|
||||
|
||||
# 9. Compute attention
|
||||
attn_output = AttentionUtils.compute_attention(
|
||||
query=query,
|
||||
key=full_key,
|
||||
value=full_value,
|
||||
batch_size=batch_size,
|
||||
num_heads=num_heads,
|
||||
head_dim=head_dim,
|
||||
)
|
||||
|
||||
# 10. Extract and project outputs
|
||||
context_attn_output = attn_output[:, :text_seq_len, :]
|
||||
hidden_attn_output = attn_output[:, text_seq_len:, :]
|
||||
|
||||
hidden_attn_output = attn.to_out[0](hidden_attn_output)
|
||||
context_attn_output = attn.to_add_out(context_attn_output)
|
||||
|
||||
# 11. Apply norm and feed forward
|
||||
patch_hidden = JointTransformerBlock.apply_norm_and_feed_forward(
|
||||
hidden_states=patch_hidden,
|
||||
attn_output=hidden_attn_output,
|
||||
gate_mlp=gate_mlp,
|
||||
gate_msa=gate_msa,
|
||||
scale_mlp=scale_mlp,
|
||||
shift_mlp=shift_mlp,
|
||||
norm_layer=block.norm2,
|
||||
ff_layer=block.ff,
|
||||
)
|
||||
encoder_hidden_states = JointTransformerBlock.apply_norm_and_feed_forward(
|
||||
hidden_states=encoder_hidden_states,
|
||||
attn_output=context_attn_output,
|
||||
gate_mlp=c_gate_mlp,
|
||||
gate_msa=c_gate_msa,
|
||||
scale_mlp=c_scale_mlp,
|
||||
shift_mlp=c_shift_mlp,
|
||||
norm_layer=block.norm2_context,
|
||||
ff_layer=block.ff_context,
|
||||
)
|
||||
|
||||
return encoder_hidden_states, patch_hidden
|
||||
|
||||
def _apply_single_block_caching(
|
||||
self,
|
||||
block: SingleBlockInterface,
|
||||
hidden_states: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
rotary_embeddings: mx.array,
|
||||
kv_cache: ImagePatchKVCache | None,
|
||||
text_seq_len: int,
|
||||
) -> mx.array:
|
||||
total_seq_len = hidden_states.shape[1]
|
||||
num_img_tokens = total_seq_len - text_seq_len
|
||||
batch_size = hidden_states.shape[0]
|
||||
attn = block.attn
|
||||
num_heads = attn.num_heads
|
||||
head_dim = attn.head_dimension
|
||||
|
||||
# Residual connection
|
||||
residual = hidden_states
|
||||
|
||||
# 1. Compute norm
|
||||
norm_hidden, gate = block.norm(
|
||||
hidden_states=hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
)
|
||||
|
||||
# 2. Compute Q, K, V
|
||||
query, key, value = AttentionUtils.process_qkv(
|
||||
hidden_states=norm_hidden,
|
||||
to_q=attn.to_q,
|
||||
to_k=attn.to_k,
|
||||
to_v=attn.to_v,
|
||||
norm_q=attn.norm_q,
|
||||
norm_k=attn.norm_k,
|
||||
num_heads=num_heads,
|
||||
head_dim=head_dim,
|
||||
)
|
||||
|
||||
# 3. Apply RoPE
|
||||
query, key = AttentionUtils.apply_rope(
|
||||
xq=query, xk=key, freqs_cis=rotary_embeddings
|
||||
)
|
||||
|
||||
# 4. Store IMAGE K/V in cache
|
||||
if kv_cache is not None:
|
||||
kv_cache.update_image_patch(
|
||||
patch_start=0,
|
||||
patch_end=num_img_tokens,
|
||||
key=key[:, :, text_seq_len:, :],
|
||||
value=value[:, :, text_seq_len:, :],
|
||||
)
|
||||
|
||||
# 5. Compute attention
|
||||
attn_output = AttentionUtils.compute_attention(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
batch_size=batch_size,
|
||||
num_heads=num_heads,
|
||||
head_dim=head_dim,
|
||||
)
|
||||
|
||||
# 6. Apply feed forward and projection
|
||||
hidden_states = block._apply_feed_forward_and_projection(
|
||||
norm_hidden_states=norm_hidden,
|
||||
attn_output=attn_output,
|
||||
gate=gate,
|
||||
)
|
||||
|
||||
return residual + hidden_states
|
||||
|
||||
def _apply_single_block_patched(
|
||||
self,
|
||||
block: SingleBlockInterface,
|
||||
patch_hidden: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
rotary_embeddings: mx.array,
|
||||
kv_cache: ImagePatchKVCache,
|
||||
text_seq_len: int,
|
||||
patch_start: int,
|
||||
patch_end: int,
|
||||
) -> mx.array:
|
||||
batch_size = patch_hidden.shape[0]
|
||||
attn = block.attn
|
||||
num_heads = attn.num_heads
|
||||
head_dim = attn.head_dimension
|
||||
|
||||
# Residual connection
|
||||
residual = patch_hidden
|
||||
|
||||
# 1. Compute norm
|
||||
norm_hidden, gate = block.norm(
|
||||
hidden_states=patch_hidden,
|
||||
text_embeddings=text_embeddings,
|
||||
)
|
||||
|
||||
# 2. Compute Q, K, V
|
||||
query, key, value = AttentionUtils.process_qkv(
|
||||
hidden_states=norm_hidden,
|
||||
to_q=attn.to_q,
|
||||
to_k=attn.to_k,
|
||||
to_v=attn.to_v,
|
||||
norm_q=attn.norm_q,
|
||||
norm_k=attn.norm_k,
|
||||
num_heads=num_heads,
|
||||
head_dim=head_dim,
|
||||
)
|
||||
|
||||
# 3. Extract RoPE for [text + current_patch]
|
||||
text_rope = rotary_embeddings[:, :, :text_seq_len, ...]
|
||||
patch_img_rope = rotary_embeddings[
|
||||
:, :, text_seq_len + patch_start : text_seq_len + patch_end, ...
|
||||
]
|
||||
patch_rope = mx.concatenate([text_rope, patch_img_rope], axis=2)
|
||||
|
||||
# 4. Apply RoPE
|
||||
query, key = AttentionUtils.apply_rope(xq=query, xk=key, freqs_cis=patch_rope)
|
||||
|
||||
# 5. Update cache with this patch's IMAGE K/V
|
||||
kv_cache.update_image_patch(
|
||||
patch_start=patch_start,
|
||||
patch_end=patch_end,
|
||||
key=key[:, :, text_seq_len:, :],
|
||||
value=value[:, :, text_seq_len:, :],
|
||||
)
|
||||
|
||||
# 6. Get full K, V from cache
|
||||
full_key, full_value = kv_cache.get_full_kv(
|
||||
text_key=key[:, :, :text_seq_len, :],
|
||||
text_value=value[:, :, :text_seq_len, :],
|
||||
)
|
||||
|
||||
# 7. Compute attention
|
||||
attn_output = AttentionUtils.compute_attention(
|
||||
query=query,
|
||||
key=full_key,
|
||||
value=full_value,
|
||||
batch_size=batch_size,
|
||||
num_heads=num_heads,
|
||||
head_dim=head_dim,
|
||||
)
|
||||
|
||||
# 8. Apply feed forward and projection
|
||||
hidden_states = block._apply_feed_forward_and_projection(
|
||||
norm_hidden_states=norm_hidden,
|
||||
attn_output=attn_output,
|
||||
gate=gate,
|
||||
)
|
||||
|
||||
return residual + hidden_states
|
||||
48
src/exo/worker/engines/image/models/flux/config.py
Normal file
48
src/exo/worker/engines/image/models/flux/config.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from exo.worker.engines.image.config import (
|
||||
BlockType,
|
||||
ImageModelConfig,
|
||||
TransformerBlockConfig,
|
||||
)
|
||||
|
||||
FLUX_SCHNELL_CONFIG = ImageModelConfig(
|
||||
model_family="flux",
|
||||
model_variant="schnell",
|
||||
hidden_dim=3072,
|
||||
num_heads=24,
|
||||
head_dim=128,
|
||||
block_configs=(
|
||||
TransformerBlockConfig(
|
||||
block_type=BlockType.JOINT, count=19, has_separate_text_output=True
|
||||
),
|
||||
TransformerBlockConfig(
|
||||
block_type=BlockType.SINGLE, count=38, has_separate_text_output=False
|
||||
),
|
||||
),
|
||||
patch_size=2,
|
||||
vae_scale_factor=8,
|
||||
default_steps={"low": 1, "medium": 2, "high": 4},
|
||||
num_sync_steps_factor=0.5, # 1 sync step for medium (2 steps)
|
||||
uses_attention_mask=False,
|
||||
)
|
||||
|
||||
|
||||
FLUX_DEV_CONFIG = ImageModelConfig(
|
||||
model_family="flux",
|
||||
model_variant="dev",
|
||||
hidden_dim=3072,
|
||||
num_heads=24,
|
||||
head_dim=128,
|
||||
block_configs=(
|
||||
TransformerBlockConfig(
|
||||
block_type=BlockType.JOINT, count=19, has_separate_text_output=True
|
||||
),
|
||||
TransformerBlockConfig(
|
||||
block_type=BlockType.SINGLE, count=38, has_separate_text_output=False
|
||||
),
|
||||
),
|
||||
patch_size=2,
|
||||
vae_scale_factor=8,
|
||||
default_steps={"low": 10, "medium": 25, "high": 50},
|
||||
num_sync_steps_factor=0.125, # ~3 sync steps for medium (25 steps)
|
||||
uses_attention_mask=False,
|
||||
)
|
||||
13
src/exo/worker/engines/image/models/qwen/__init__.py
Normal file
13
src/exo/worker/engines/image/models/qwen/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from exo.worker.engines.image.models.qwen.adapter import QwenModelAdapter
|
||||
from exo.worker.engines.image.models.qwen.config import (
|
||||
QWEN_IMAGE_CONFIG,
|
||||
QWEN_IMAGE_EDIT_CONFIG,
|
||||
)
|
||||
from exo.worker.engines.image.models.qwen.edit_adapter import QwenEditModelAdapter
|
||||
|
||||
__all__ = [
|
||||
"QwenModelAdapter",
|
||||
"QwenEditModelAdapter",
|
||||
"QWEN_IMAGE_CONFIG",
|
||||
"QWEN_IMAGE_EDIT_CONFIG",
|
||||
]
|
||||
519
src/exo/worker/engines/image/models/qwen/adapter.py
Normal file
519
src/exo/worker/engines/image/models/qwen/adapter.py
Normal file
@@ -0,0 +1,519 @@
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
|
||||
import mlx.core as mx
|
||||
from mflux.config.model_config import ModelConfig
|
||||
from mflux.config.runtime_config import RuntimeConfig
|
||||
from mflux.models.qwen.latent_creator.qwen_latent_creator import QwenLatentCreator
|
||||
from mflux.models.qwen.model.qwen_text_encoder.qwen_prompt_encoder import (
|
||||
QwenPromptEncoder,
|
||||
)
|
||||
from mflux.models.qwen.model.qwen_transformer.qwen_attention import QwenAttention
|
||||
from mflux.models.qwen.model.qwen_transformer.qwen_transformer import QwenTransformer
|
||||
from mflux.models.qwen.model.qwen_transformer.qwen_transformer_block import (
|
||||
QwenTransformerBlock,
|
||||
)
|
||||
from mflux.models.qwen.variants.txt2img.qwen_image import QwenImage
|
||||
|
||||
from exo.worker.engines.image.config import ImageModelConfig
|
||||
from exo.worker.engines.image.models.base import BaseModelAdapter
|
||||
from exo.worker.engines.image.pipeline.adapter import (
|
||||
BlockWrapperMode,
|
||||
JointBlockInterface,
|
||||
SingleBlockInterface,
|
||||
)
|
||||
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache
|
||||
|
||||
|
||||
class QwenPromptData:
|
||||
"""Container for Qwen prompt encoding results.
|
||||
|
||||
Implements PromptData protocol with additional Qwen-specific attributes.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prompt_embeds: mx.array,
|
||||
prompt_mask: mx.array,
|
||||
negative_prompt_embeds: mx.array,
|
||||
negative_prompt_mask: mx.array,
|
||||
):
|
||||
self._prompt_embeds = prompt_embeds
|
||||
self.prompt_mask = prompt_mask
|
||||
self._negative_prompt_embeds = negative_prompt_embeds
|
||||
self.negative_prompt_mask = negative_prompt_mask
|
||||
|
||||
@property
|
||||
def prompt_embeds(self) -> mx.array:
|
||||
"""Text embeddings from encoder."""
|
||||
return self._prompt_embeds
|
||||
|
||||
@property
|
||||
def pooled_prompt_embeds(self) -> mx.array:
|
||||
"""Placeholder for protocol compliance - Qwen doesn't use pooled embeds."""
|
||||
return self._prompt_embeds # Use prompt_embeds as placeholder
|
||||
|
||||
@property
|
||||
def negative_prompt_embeds(self) -> mx.array:
|
||||
"""Negative prompt embeddings for CFG."""
|
||||
return self._negative_prompt_embeds
|
||||
|
||||
@property
|
||||
def negative_pooled_prompt_embeds(self) -> mx.array:
|
||||
"""Placeholder - Qwen doesn't use pooled embeds."""
|
||||
return self._negative_prompt_embeds
|
||||
|
||||
def get_extra_forward_kwargs(self, positive: bool = True) -> dict[str, Any]:
|
||||
"""Return encoder_hidden_states_mask for the appropriate prompt."""
|
||||
if positive:
|
||||
return {"encoder_hidden_states_mask": self.prompt_mask}
|
||||
else:
|
||||
return {"encoder_hidden_states_mask": self.negative_prompt_mask}
|
||||
|
||||
@property
|
||||
def conditioning_latents(self) -> mx.array | None:
|
||||
"""Standard Qwen does not use conditioning latents."""
|
||||
return None
|
||||
|
||||
|
||||
class QwenModelAdapter(BaseModelAdapter):
|
||||
"""Adapter for Qwen-Image model.
|
||||
|
||||
Key differences from Flux:
|
||||
- Single text encoder (vs dual T5+CLIP)
|
||||
- 60 joint-style blocks, no single blocks
|
||||
- 3D RoPE returning ((img_cos, img_sin), (txt_cos, txt_sin))
|
||||
- Norm-preserving CFG with negative prompts
|
||||
- Uses attention mask for variable-length text
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ImageModelConfig,
|
||||
model_id: str,
|
||||
local_path: Path,
|
||||
quantize: int | None = None,
|
||||
):
|
||||
self._config = config
|
||||
self._model = QwenImage(
|
||||
model_config=ModelConfig.from_name(model_name=model_id, base_model=None),
|
||||
local_path=str(local_path),
|
||||
quantize=quantize,
|
||||
)
|
||||
self._transformer = self._model.transformer
|
||||
|
||||
@property
|
||||
def config(self) -> ImageModelConfig:
|
||||
return self._config
|
||||
|
||||
@property
|
||||
def model(self) -> QwenImage:
|
||||
return self._model
|
||||
|
||||
@property
|
||||
def transformer(self) -> QwenTransformer:
|
||||
return self._transformer
|
||||
|
||||
@property
|
||||
def hidden_dim(self) -> int:
|
||||
return self._transformer.inner_dim
|
||||
|
||||
def _get_latent_creator(self) -> type:
|
||||
return QwenLatentCreator
|
||||
|
||||
def encode_prompt(self, prompt: str) -> QwenPromptData:
|
||||
"""Encode prompt into QwenPromptData.
|
||||
|
||||
Qwen uses classifier-free guidance with explicit negative prompts.
|
||||
Returns a QwenPromptData container with all 4 tensors.
|
||||
"""
|
||||
# TODO(ciaran): empty string as default negative prompt
|
||||
negative_prompt = ""
|
||||
|
||||
prompt_embeds, prompt_mask, neg_embeds, neg_mask = (
|
||||
QwenPromptEncoder.encode_prompt(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_cache=self._model.prompt_cache,
|
||||
qwen_tokenizer=self._model.qwen_tokenizer,
|
||||
qwen_text_encoder=self._model.text_encoder,
|
||||
)
|
||||
)
|
||||
|
||||
return QwenPromptData(
|
||||
prompt_embeds=prompt_embeds,
|
||||
prompt_mask=prompt_mask,
|
||||
negative_prompt_embeds=neg_embeds,
|
||||
negative_prompt_mask=neg_mask,
|
||||
)
|
||||
|
||||
@property
|
||||
def needs_cfg(self) -> bool:
|
||||
gs = self._config.guidance_scale
|
||||
return gs is not None and gs > 1.0
|
||||
|
||||
def apply_guidance(
|
||||
self,
|
||||
noise_positive: mx.array,
|
||||
noise_negative: mx.array,
|
||||
guidance_scale: float,
|
||||
) -> mx.array:
|
||||
return self._model.compute_guided_noise(
|
||||
noise=noise_positive,
|
||||
noise_negative=noise_negative,
|
||||
guidance=guidance_scale,
|
||||
)
|
||||
|
||||
def compute_embeddings(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
prompt_embeds: mx.array,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
"""Compute image and text embeddings."""
|
||||
# Image embedding
|
||||
embedded_hidden = self._transformer.img_in(hidden_states)
|
||||
# Text embedding: first normalize, then project
|
||||
encoder_hidden_states = self._transformer.txt_norm(prompt_embeds)
|
||||
embedded_encoder = self._transformer.txt_in(encoder_hidden_states)
|
||||
return embedded_hidden, embedded_encoder
|
||||
|
||||
def compute_text_embeddings(
|
||||
self,
|
||||
t: int,
|
||||
runtime_config: RuntimeConfig,
|
||||
pooled_prompt_embeds: mx.array | None = None,
|
||||
hidden_states: mx.array | None = None,
|
||||
) -> mx.array:
|
||||
"""Compute time/text embeddings.
|
||||
|
||||
For Qwen, the time_text_embed only uses hidden_states for:
|
||||
- batch_size (shape[0])
|
||||
- dtype
|
||||
|
||||
This allows us to pass any tensor (latents, prompt_embeds) as a fallback
|
||||
when embedded hidden_states are not yet available.
|
||||
"""
|
||||
# Use hidden_states if provided, otherwise fall back to pooled_prompt_embeds
|
||||
# (which for Qwen is the same as prompt_embeds)
|
||||
ref_tensor = (
|
||||
hidden_states if hidden_states is not None else pooled_prompt_embeds
|
||||
)
|
||||
if ref_tensor is None:
|
||||
raise ValueError(
|
||||
"Either hidden_states or pooled_prompt_embeds is required "
|
||||
"for Qwen text embeddings"
|
||||
)
|
||||
|
||||
timestep = QwenTransformer._compute_timestep(t, runtime_config) # noqa: SLF001
|
||||
batch_size = ref_tensor.shape[0]
|
||||
timestep = mx.broadcast_to(timestep, (batch_size,)).astype(mx.float32)
|
||||
return self._transformer.time_text_embed(timestep, ref_tensor)
|
||||
|
||||
def compute_rotary_embeddings(
|
||||
self,
|
||||
prompt_embeds: mx.array,
|
||||
runtime_config: RuntimeConfig,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Compute 3D rotary embeddings for Qwen.
|
||||
|
||||
Qwen uses video-aware 3D RoPE with separate embeddings for image and text.
|
||||
|
||||
Returns:
|
||||
tuple[tuple[mx.array, mx.array], tuple[mx.array, mx.array]]:
|
||||
((img_cos, img_sin), (txt_cos, txt_sin))
|
||||
"""
|
||||
encoder_hidden_states_mask = kwargs.get("encoder_hidden_states_mask")
|
||||
cond_image_grid = kwargs.get("cond_image_grid")
|
||||
|
||||
if encoder_hidden_states_mask is None:
|
||||
raise ValueError(
|
||||
"encoder_hidden_states_mask is required for Qwen RoPE computation"
|
||||
)
|
||||
|
||||
return QwenTransformer._compute_rotary_embeddings( # noqa: SLF001
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
pos_embed=self._transformer.pos_embed,
|
||||
config=runtime_config,
|
||||
cond_image_grid=cond_image_grid,
|
||||
)
|
||||
|
||||
def apply_joint_block(
|
||||
self,
|
||||
block: JointBlockInterface,
|
||||
hidden_states: mx.array,
|
||||
encoder_hidden_states: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
rotary_embeddings: Any, # tuple[tuple[mx.array, mx.array], tuple[mx.array, mx.array]] for Qwen
|
||||
kv_cache: ImagePatchKVCache | None,
|
||||
mode: BlockWrapperMode,
|
||||
text_seq_len: int,
|
||||
patch_start: int | None = None,
|
||||
patch_end: int | None = None,
|
||||
**kwargs: Any,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
"""Apply Qwen joint block.
|
||||
|
||||
For caching mode, we run the full block and optionally populate the KV cache.
|
||||
For patched mode, we use the cached KV values (not yet implemented).
|
||||
"""
|
||||
encoder_hidden_states_mask = kwargs.get("encoder_hidden_states_mask")
|
||||
block_idx = kwargs.get("block_idx")
|
||||
|
||||
if mode == BlockWrapperMode.CACHING:
|
||||
return self._apply_joint_block_caching(
|
||||
block=block,
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=rotary_embeddings,
|
||||
kv_cache=kv_cache,
|
||||
text_seq_len=text_seq_len,
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
block_idx=block_idx,
|
||||
)
|
||||
else:
|
||||
# mode == BlockWrapperMode.PATCHED
|
||||
assert patch_start is not None and patch_end is not None
|
||||
assert kv_cache is not None
|
||||
return self._apply_joint_block_patched(
|
||||
block=block,
|
||||
patch_hidden=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=rotary_embeddings,
|
||||
kv_cache=kv_cache,
|
||||
text_seq_len=text_seq_len,
|
||||
patch_start=patch_start,
|
||||
patch_end=patch_end,
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
block_idx=block_idx,
|
||||
)
|
||||
|
||||
def apply_single_block(
|
||||
self,
|
||||
block: SingleBlockInterface,
|
||||
hidden_states: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
rotary_embeddings: mx.array,
|
||||
kv_cache: ImagePatchKVCache | None,
|
||||
mode: BlockWrapperMode,
|
||||
text_seq_len: int,
|
||||
patch_start: int | None = None,
|
||||
patch_end: int | None = None,
|
||||
) -> mx.array:
|
||||
"""Qwen has no single blocks."""
|
||||
raise NotImplementedError("Qwen does not have single blocks")
|
||||
|
||||
def final_projection(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
) -> mx.array:
|
||||
"""Apply final normalization and projection."""
|
||||
hidden_states = self._transformer.norm_out(hidden_states, text_embeddings)
|
||||
return self._transformer.proj_out(hidden_states)
|
||||
|
||||
def get_joint_blocks(self) -> list[JointBlockInterface]:
|
||||
"""Return all 60 transformer blocks."""
|
||||
return cast(
|
||||
list[JointBlockInterface], list(self._transformer.transformer_blocks)
|
||||
)
|
||||
|
||||
def get_single_blocks(self) -> list[SingleBlockInterface]:
|
||||
"""Qwen has no single blocks."""
|
||||
return []
|
||||
|
||||
def slice_transformer_blocks(
|
||||
self,
|
||||
start_layer: int,
|
||||
end_layer: int,
|
||||
total_joint_blocks: int,
|
||||
total_single_blocks: int,
|
||||
) -> None:
|
||||
all_blocks = list(self._transformer.transformer_blocks)
|
||||
assigned_blocks = all_blocks[start_layer:end_layer]
|
||||
self._transformer.transformer_blocks = assigned_blocks
|
||||
|
||||
def merge_streams(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
encoder_hidden_states: mx.array,
|
||||
) -> mx.array:
|
||||
"""Merge image and text streams.
|
||||
|
||||
For Qwen, this is called before final projection.
|
||||
The streams remain separate through all blocks.
|
||||
"""
|
||||
return mx.concatenate([encoder_hidden_states, hidden_states], axis=1)
|
||||
|
||||
def _apply_joint_block_caching(
|
||||
self,
|
||||
block: Any, # QwenTransformerBlock
|
||||
hidden_states: mx.array,
|
||||
encoder_hidden_states: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
rotary_embeddings: tuple[tuple[mx.array, mx.array], tuple[mx.array, mx.array]],
|
||||
kv_cache: ImagePatchKVCache | None,
|
||||
text_seq_len: int,
|
||||
encoder_hidden_states_mask: mx.array | None = None,
|
||||
block_idx: int | None = None,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
"""Apply joint block in caching mode (full attention, optionally populate cache).
|
||||
|
||||
Delegates to the QwenTransformerBlock's forward pass.
|
||||
"""
|
||||
# Call the block directly - it handles all the modulation and attention internally
|
||||
return block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
text_embeddings=text_embeddings,
|
||||
image_rotary_emb=rotary_embeddings,
|
||||
block_idx=block_idx,
|
||||
)
|
||||
|
||||
def _apply_joint_block_patched(
|
||||
self,
|
||||
block: Any, # QwenTransformerBlock
|
||||
patch_hidden: mx.array,
|
||||
encoder_hidden_states: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
rotary_embeddings: tuple[tuple[mx.array, mx.array], tuple[mx.array, mx.array]],
|
||||
kv_cache: ImagePatchKVCache,
|
||||
text_seq_len: int,
|
||||
patch_start: int,
|
||||
patch_end: int,
|
||||
encoder_hidden_states_mask: mx.array | None = None,
|
||||
block_idx: int | None = None,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
batch_size = patch_hidden.shape[0]
|
||||
attn = block.attn
|
||||
num_heads = attn.num_heads
|
||||
head_dim = attn.head_dim
|
||||
|
||||
# 1. Compute modulation parameters
|
||||
img_mod_params = block.img_mod_linear(block.img_mod_silu(text_embeddings))
|
||||
txt_mod_params = block.txt_mod_linear(block.txt_mod_silu(text_embeddings))
|
||||
|
||||
img_mod1, img_mod2 = mx.split(img_mod_params, 2, axis=-1)
|
||||
txt_mod1, txt_mod2 = mx.split(txt_mod_params, 2, axis=-1)
|
||||
|
||||
# 2. Apply normalization and modulation
|
||||
img_normed = block.img_norm1(patch_hidden)
|
||||
img_modulated, img_gate1 = QwenTransformerBlock._modulate(img_normed, img_mod1)
|
||||
|
||||
txt_normed = block.txt_norm1(encoder_hidden_states)
|
||||
txt_modulated, txt_gate1 = QwenTransformerBlock._modulate(txt_normed, txt_mod1)
|
||||
|
||||
# 3. Compute Q, K, V for image patch
|
||||
img_query = attn.to_q(img_modulated)
|
||||
img_key = attn.to_k(img_modulated)
|
||||
img_value = attn.to_v(img_modulated)
|
||||
|
||||
# 4. Compute Q, K, V for text
|
||||
txt_query = attn.add_q_proj(txt_modulated)
|
||||
txt_key = attn.add_k_proj(txt_modulated)
|
||||
txt_value = attn.add_v_proj(txt_modulated)
|
||||
|
||||
# 5. Reshape to [B, S, H, D]
|
||||
patch_len = patch_hidden.shape[1]
|
||||
img_query = mx.reshape(img_query, (batch_size, patch_len, num_heads, head_dim))
|
||||
img_key = mx.reshape(img_key, (batch_size, patch_len, num_heads, head_dim))
|
||||
img_value = mx.reshape(img_value, (batch_size, patch_len, num_heads, head_dim))
|
||||
|
||||
txt_query = mx.reshape(
|
||||
txt_query, (batch_size, text_seq_len, num_heads, head_dim)
|
||||
)
|
||||
txt_key = mx.reshape(txt_key, (batch_size, text_seq_len, num_heads, head_dim))
|
||||
txt_value = mx.reshape(
|
||||
txt_value, (batch_size, text_seq_len, num_heads, head_dim)
|
||||
)
|
||||
|
||||
# 6. Apply RMSNorm to Q, K
|
||||
img_query = attn.norm_q(img_query)
|
||||
img_key = attn.norm_k(img_key)
|
||||
txt_query = attn.norm_added_q(txt_query)
|
||||
txt_key = attn.norm_added_k(txt_key)
|
||||
|
||||
# 7. Extract RoPE for patch: slice image RoPE, keep full text RoPE
|
||||
(img_cos, img_sin), (txt_cos, txt_sin) = rotary_embeddings
|
||||
patch_img_cos = img_cos[patch_start:patch_end]
|
||||
patch_img_sin = img_sin[patch_start:patch_end]
|
||||
|
||||
# 8. Apply RoPE to Q, K
|
||||
img_query = QwenAttention._apply_rope_qwen(
|
||||
img_query, patch_img_cos, patch_img_sin
|
||||
)
|
||||
img_key = QwenAttention._apply_rope_qwen(img_key, patch_img_cos, patch_img_sin)
|
||||
txt_query = QwenAttention._apply_rope_qwen(txt_query, txt_cos, txt_sin)
|
||||
txt_key = QwenAttention._apply_rope_qwen(txt_key, txt_cos, txt_sin)
|
||||
|
||||
# 9. Transpose to [B, H, S, D] for cache operations
|
||||
img_key_bhsd = mx.transpose(img_key, (0, 2, 1, 3))
|
||||
img_value_bhsd = mx.transpose(img_value, (0, 2, 1, 3))
|
||||
|
||||
# 10. Update cache with this patch's IMAGE K/V
|
||||
kv_cache.update_image_patch(
|
||||
patch_start=patch_start,
|
||||
patch_end=patch_end,
|
||||
key=img_key_bhsd,
|
||||
value=img_value_bhsd,
|
||||
)
|
||||
|
||||
# 11. Get full K, V from cache (text + full image)
|
||||
txt_key_bhsd = mx.transpose(txt_key, (0, 2, 1, 3))
|
||||
txt_value_bhsd = mx.transpose(txt_value, (0, 2, 1, 3))
|
||||
full_key, full_value = kv_cache.get_full_kv(
|
||||
text_key=txt_key_bhsd,
|
||||
text_value=txt_value_bhsd,
|
||||
)
|
||||
|
||||
# 12. Build query: [text, patch]
|
||||
joint_query = mx.concatenate([txt_query, img_query], axis=1)
|
||||
|
||||
# 13. Build attention mask for [text + patch] query attending to [text + full_image] KV
|
||||
mask = QwenAttention._convert_mask_for_qwen(
|
||||
mask=encoder_hidden_states_mask,
|
||||
joint_seq_len=full_key.shape[2], # text + full_image
|
||||
txt_seq_len=text_seq_len,
|
||||
)
|
||||
|
||||
# 14. Compute attention
|
||||
hidden_states = attn._compute_attention_qwen(
|
||||
query=joint_query,
|
||||
key=mx.transpose(full_key, (0, 2, 1, 3)), # Back to [B, S, H, D]
|
||||
value=mx.transpose(full_value, (0, 2, 1, 3)),
|
||||
mask=mask,
|
||||
block_idx=block_idx,
|
||||
)
|
||||
|
||||
# 15. Extract text and image attention outputs
|
||||
txt_attn_output = hidden_states[:, :text_seq_len, :]
|
||||
img_attn_output = hidden_states[:, text_seq_len:, :]
|
||||
|
||||
# 16. Project outputs
|
||||
img_attn_output = attn.attn_to_out[0](img_attn_output)
|
||||
txt_attn_output = attn.to_add_out(txt_attn_output)
|
||||
|
||||
# 17. Apply residual + gate for attention
|
||||
patch_hidden = patch_hidden + img_gate1 * img_attn_output
|
||||
encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
|
||||
|
||||
# 18. Apply feed-forward for image
|
||||
img_normed2 = block.img_norm2(patch_hidden)
|
||||
img_modulated2, img_gate2 = QwenTransformerBlock._modulate(
|
||||
img_normed2, img_mod2
|
||||
)
|
||||
img_mlp_output = block.img_ff(img_modulated2)
|
||||
patch_hidden = patch_hidden + img_gate2 * img_mlp_output
|
||||
|
||||
# 19. Apply feed-forward for text
|
||||
txt_normed2 = block.txt_norm2(encoder_hidden_states)
|
||||
txt_modulated2, txt_gate2 = QwenTransformerBlock._modulate(
|
||||
txt_normed2, txt_mod2
|
||||
)
|
||||
txt_mlp_output = block.txt_ff(txt_modulated2)
|
||||
encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output
|
||||
|
||||
return encoder_hidden_states, patch_hidden
|
||||
49
src/exo/worker/engines/image/models/qwen/config.py
Normal file
49
src/exo/worker/engines/image/models/qwen/config.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from exo.worker.engines.image.config import (
|
||||
BlockType,
|
||||
ImageModelConfig,
|
||||
TransformerBlockConfig,
|
||||
)
|
||||
|
||||
# Qwen-Image has 60 joint-style blocks (no single blocks)
|
||||
# Architecture: 24 heads * 128 dim = 3072 hidden dim
|
||||
# VAE uses scale factor of 16 (vs Flux's 8)
|
||||
QWEN_IMAGE_CONFIG = ImageModelConfig(
|
||||
model_family="qwen",
|
||||
model_variant="image",
|
||||
hidden_dim=3072,
|
||||
num_heads=24,
|
||||
head_dim=128,
|
||||
block_configs=(
|
||||
TransformerBlockConfig(
|
||||
block_type=BlockType.JOINT, count=60, has_separate_text_output=True
|
||||
),
|
||||
# Qwen has no single blocks - all blocks process image and text separately
|
||||
),
|
||||
patch_size=2,
|
||||
vae_scale_factor=16,
|
||||
default_steps={"low": 10, "medium": 25, "high": 50},
|
||||
num_sync_steps_factor=0.125, # ~3 sync steps for medium (30 steps)
|
||||
uses_attention_mask=True, # Qwen uses encoder_hidden_states_mask
|
||||
guidance_scale=None, # Set to None or < 1.0 to disable CFG
|
||||
)
|
||||
|
||||
# Qwen-Image-Edit uses the same architecture but different processing pipeline
|
||||
# Uses vision-language encoding and conditioning latents
|
||||
QWEN_IMAGE_EDIT_CONFIG = ImageModelConfig(
|
||||
model_family="qwen-edit",
|
||||
model_variant="image-edit",
|
||||
hidden_dim=3072,
|
||||
num_heads=24,
|
||||
head_dim=128,
|
||||
block_configs=(
|
||||
TransformerBlockConfig(
|
||||
block_type=BlockType.JOINT, count=60, has_separate_text_output=True
|
||||
),
|
||||
),
|
||||
patch_size=2,
|
||||
vae_scale_factor=16,
|
||||
default_steps={"low": 10, "medium": 25, "high": 50},
|
||||
num_sync_steps_factor=0.125,
|
||||
uses_attention_mask=True,
|
||||
guidance_scale=None,
|
||||
)
|
||||
671
src/exo/worker/engines/image/models/qwen/edit_adapter.py
Normal file
671
src/exo/worker/engines/image/models/qwen/edit_adapter.py
Normal file
@@ -0,0 +1,671 @@
|
||||
import math
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
|
||||
import mlx.core as mx
|
||||
from mflux.config.runtime_config import RuntimeConfig
|
||||
from mflux.models.qwen.latent_creator.qwen_latent_creator import QwenLatentCreator
|
||||
from mflux.models.qwen.model.qwen_transformer.qwen_attention import QwenAttention
|
||||
from mflux.models.qwen.model.qwen_transformer.qwen_transformer import QwenTransformer
|
||||
from mflux.models.qwen.model.qwen_transformer.qwen_transformer_block import (
|
||||
QwenTransformerBlock,
|
||||
)
|
||||
from mflux.models.qwen.variants.edit.qwen_image_edit import QwenImageEdit
|
||||
from mflux.models.qwen.variants.edit.utils.qwen_edit_util import QwenEditUtil
|
||||
|
||||
from exo.worker.engines.image.config import ImageModelConfig
|
||||
from exo.worker.engines.image.models.base import BaseModelAdapter
|
||||
from exo.worker.engines.image.pipeline.adapter import (
|
||||
BlockWrapperMode,
|
||||
JointBlockInterface,
|
||||
SingleBlockInterface,
|
||||
)
|
||||
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache
|
||||
|
||||
|
||||
class QwenEditPromptData:
|
||||
"""Container for Qwen edit prompt encoding results.
|
||||
|
||||
Includes vision-language encoded embeddings and edit-specific conditioning.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prompt_embeds: mx.array,
|
||||
prompt_mask: mx.array,
|
||||
negative_prompt_embeds: mx.array,
|
||||
negative_prompt_mask: mx.array,
|
||||
conditioning_latents: mx.array,
|
||||
qwen_image_ids: mx.array,
|
||||
cond_image_grid: tuple[int, int, int] | list[tuple[int, int, int]],
|
||||
):
|
||||
self._prompt_embeds = prompt_embeds
|
||||
self.prompt_mask = prompt_mask
|
||||
self._negative_prompt_embeds = negative_prompt_embeds
|
||||
self.negative_prompt_mask = negative_prompt_mask
|
||||
self._conditioning_latents = conditioning_latents
|
||||
self._qwen_image_ids = qwen_image_ids
|
||||
self._cond_image_grid = cond_image_grid
|
||||
|
||||
@property
|
||||
def prompt_embeds(self) -> mx.array:
|
||||
"""Text embeddings from vision-language encoder."""
|
||||
return self._prompt_embeds
|
||||
|
||||
@property
|
||||
def pooled_prompt_embeds(self) -> mx.array:
|
||||
"""Placeholder for protocol compliance - Qwen doesn't use pooled embeds."""
|
||||
return self._prompt_embeds
|
||||
|
||||
@property
|
||||
def negative_prompt_embeds(self) -> mx.array:
|
||||
"""Negative prompt embeddings for CFG."""
|
||||
return self._negative_prompt_embeds
|
||||
|
||||
@property
|
||||
def negative_pooled_prompt_embeds(self) -> mx.array:
|
||||
"""Placeholder - Qwen doesn't use pooled embeds."""
|
||||
return self._negative_prompt_embeds
|
||||
|
||||
@property
|
||||
def conditioning_latents(self) -> mx.array:
|
||||
"""Static image conditioning latents to concatenate with generated latents."""
|
||||
return self._conditioning_latents
|
||||
|
||||
@property
|
||||
def qwen_image_ids(self) -> mx.array:
|
||||
"""Spatial position IDs for conditioning images."""
|
||||
return self._qwen_image_ids
|
||||
|
||||
@property
|
||||
def cond_image_grid(self) -> tuple[int, int, int] | list[tuple[int, int, int]]:
|
||||
"""Conditioning image grid dimensions."""
|
||||
return self._cond_image_grid
|
||||
|
||||
def get_extra_forward_kwargs(self, positive: bool = True) -> dict[str, Any]:
|
||||
"""Return encoder_hidden_states_mask and edit-specific params."""
|
||||
if positive:
|
||||
return {
|
||||
"encoder_hidden_states_mask": self.prompt_mask,
|
||||
"qwen_image_ids": self._qwen_image_ids,
|
||||
"cond_image_grid": self._cond_image_grid,
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"encoder_hidden_states_mask": self.negative_prompt_mask,
|
||||
"qwen_image_ids": self._qwen_image_ids,
|
||||
"cond_image_grid": self._cond_image_grid,
|
||||
}
|
||||
|
||||
@property
|
||||
def is_edit_mode(self) -> bool:
|
||||
"""Indicates this is edit mode with conditioning latents."""
|
||||
return True
|
||||
|
||||
|
||||
class QwenEditModelAdapter(BaseModelAdapter):
|
||||
"""Adapter for Qwen-Image-Edit model.
|
||||
|
||||
Key differences from standard QwenModelAdapter:
|
||||
- Uses QwenImageEdit model with vision-language components
|
||||
- Encodes prompts WITH input images via VL tokenizer/encoder
|
||||
- Creates conditioning latents from input images
|
||||
- Supports image editing with concatenated latents during diffusion
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ImageModelConfig,
|
||||
model_id: str,
|
||||
local_path: Path,
|
||||
quantize: int | None = None,
|
||||
):
|
||||
self._config = config
|
||||
self._model = QwenImageEdit(
|
||||
quantize=quantize,
|
||||
local_path=str(local_path),
|
||||
)
|
||||
self._transformer = self._model.transformer
|
||||
|
||||
# Store dimensions and image paths (set via set_image_dimensions)
|
||||
self._vl_width: int | None = None
|
||||
self._vl_height: int | None = None
|
||||
self._vae_width: int | None = None
|
||||
self._vae_height: int | None = None
|
||||
self._image_paths: list[str] | None = None
|
||||
|
||||
@property
|
||||
def config(self) -> ImageModelConfig:
|
||||
return self._config
|
||||
|
||||
@property
|
||||
def model(self) -> QwenImageEdit:
|
||||
return self._model
|
||||
|
||||
@property
|
||||
def transformer(self) -> QwenTransformer:
|
||||
return self._transformer
|
||||
|
||||
@property
|
||||
def hidden_dim(self) -> int:
|
||||
return self._transformer.inner_dim
|
||||
|
||||
def _get_latent_creator(self) -> type:
|
||||
return QwenLatentCreator
|
||||
|
||||
def _compute_dimensions_from_image(
|
||||
self, image_path: Path
|
||||
) -> tuple[int, int, int, int, int, int]:
|
||||
"""Compute VL and VAE dimensions from input image.
|
||||
|
||||
Returns:
|
||||
(vl_width, vl_height, vae_width, vae_height, output_width, output_height)
|
||||
"""
|
||||
from mflux.utils.image_util import ImageUtil
|
||||
|
||||
pil_image = ImageUtil.load_image(str(image_path)).convert("RGB")
|
||||
image_size = pil_image.size
|
||||
|
||||
# Vision-language dimensions (384x384 target area)
|
||||
condition_image_size = 384 * 384
|
||||
condition_ratio = image_size[0] / image_size[1]
|
||||
vl_width = math.sqrt(condition_image_size * condition_ratio)
|
||||
vl_height = vl_width / condition_ratio
|
||||
vl_width = round(vl_width / 32) * 32
|
||||
vl_height = round(vl_height / 32) * 32
|
||||
|
||||
# VAE dimensions (1024x1024 target area)
|
||||
vae_image_size = 1024 * 1024
|
||||
vae_ratio = image_size[0] / image_size[1]
|
||||
vae_width = math.sqrt(vae_image_size * vae_ratio)
|
||||
vae_height = vae_width / vae_ratio
|
||||
vae_width = round(vae_width / 32) * 32
|
||||
vae_height = round(vae_height / 32) * 32
|
||||
|
||||
# Output dimensions from input image aspect ratio
|
||||
target_area = 1024 * 1024
|
||||
ratio = image_size[0] / image_size[1]
|
||||
output_width = math.sqrt(target_area * ratio)
|
||||
output_height = output_width / ratio
|
||||
output_width = round(output_width / 32) * 32
|
||||
output_height = round(output_height / 32) * 32
|
||||
|
||||
# Ensure multiple of 16 for VAE
|
||||
vae_scale_factor = 8
|
||||
multiple_of = vae_scale_factor * 2
|
||||
output_width = output_width // multiple_of * multiple_of
|
||||
output_height = output_height // multiple_of * multiple_of
|
||||
|
||||
return (
|
||||
int(vl_width),
|
||||
int(vl_height),
|
||||
int(vae_width),
|
||||
int(vae_height),
|
||||
int(output_width),
|
||||
int(output_height),
|
||||
)
|
||||
|
||||
def create_latents(self, seed: int, runtime_config: RuntimeConfig) -> mx.array:
|
||||
"""Create initial noise latents (pure noise for edit mode)."""
|
||||
return QwenLatentCreator.create_noise(
|
||||
seed=seed,
|
||||
height=runtime_config.height,
|
||||
width=runtime_config.width,
|
||||
)
|
||||
|
||||
def encode_prompt(self, prompt: str) -> QwenEditPromptData:
|
||||
"""Encode prompt with input images using vision-language encoder.
|
||||
|
||||
Uses stored image_paths from set_image_dimensions() for VL encoding.
|
||||
|
||||
Args:
|
||||
prompt: Text prompt for editing
|
||||
|
||||
Returns:
|
||||
QwenEditPromptData with VL embeddings and conditioning latents
|
||||
"""
|
||||
# Ensure image_paths and dimensions were set via set_image_dimensions()
|
||||
if self._image_paths is None:
|
||||
raise RuntimeError(
|
||||
"set_image_dimensions() must be called before encode_prompt() "
|
||||
"for QwenEditModelAdapter"
|
||||
)
|
||||
|
||||
negative_prompt = ""
|
||||
image_paths = self._image_paths
|
||||
|
||||
# Use stored dimensions (computed from input image)
|
||||
vl_width = self._vl_width
|
||||
vl_height = self._vl_height
|
||||
vae_width = self._vae_width
|
||||
vae_height = self._vae_height
|
||||
|
||||
# Encode prompts with images via vision-language components
|
||||
tokenizer = self._model.qwen_vl_tokenizer
|
||||
pos_input_ids, pos_attention_mask, pos_pixel_values, pos_image_grid_thw = (
|
||||
tokenizer.tokenize_with_image(
|
||||
prompt, image_paths, vl_width=vl_width, vl_height=vl_height
|
||||
)
|
||||
)
|
||||
|
||||
pos_hidden_states = self._model.qwen_vl_encoder(
|
||||
input_ids=pos_input_ids,
|
||||
attention_mask=pos_attention_mask,
|
||||
pixel_values=pos_pixel_values,
|
||||
image_grid_thw=pos_image_grid_thw,
|
||||
)
|
||||
mx.eval(pos_hidden_states[0])
|
||||
mx.eval(pos_hidden_states[1])
|
||||
|
||||
# Encode negative prompt with images
|
||||
neg_input_ids, neg_attention_mask, neg_pixel_values, neg_image_grid_thw = (
|
||||
tokenizer.tokenize_with_image(
|
||||
negative_prompt, image_paths, vl_width=vl_width, vl_height=vl_height
|
||||
)
|
||||
)
|
||||
|
||||
neg_hidden_states = self._model.qwen_vl_encoder(
|
||||
input_ids=neg_input_ids,
|
||||
attention_mask=neg_attention_mask,
|
||||
pixel_values=neg_pixel_values,
|
||||
image_grid_thw=neg_image_grid_thw,
|
||||
)
|
||||
mx.eval(neg_hidden_states[0])
|
||||
mx.eval(neg_hidden_states[1])
|
||||
|
||||
# Create conditioning latents from input images
|
||||
# Ensure dimensions are set (should have been set via set_image_dimensions)
|
||||
assert vl_width is not None and vl_height is not None
|
||||
assert vae_width is not None and vae_height is not None
|
||||
|
||||
(
|
||||
conditioning_latents,
|
||||
qwen_image_ids,
|
||||
cond_h_patches,
|
||||
cond_w_patches,
|
||||
num_images,
|
||||
) = QwenEditUtil.create_image_conditioning_latents(
|
||||
vae=self._model.vae,
|
||||
height=vae_height,
|
||||
width=vae_width,
|
||||
image_paths=image_paths,
|
||||
vl_width=vl_width,
|
||||
vl_height=vl_height,
|
||||
)
|
||||
|
||||
# Build cond_image_grid
|
||||
if num_images > 1:
|
||||
cond_image_grid: tuple[int, int, int] | list[tuple[int, int, int]] = [
|
||||
(1, cond_h_patches, cond_w_patches) for _ in range(num_images)
|
||||
]
|
||||
else:
|
||||
cond_image_grid = (1, cond_h_patches, cond_w_patches)
|
||||
|
||||
return QwenEditPromptData(
|
||||
prompt_embeds=pos_hidden_states[0].astype(mx.float16),
|
||||
prompt_mask=pos_hidden_states[1].astype(mx.float16),
|
||||
negative_prompt_embeds=neg_hidden_states[0].astype(mx.float16),
|
||||
negative_prompt_mask=neg_hidden_states[1].astype(mx.float16),
|
||||
conditioning_latents=conditioning_latents,
|
||||
qwen_image_ids=qwen_image_ids,
|
||||
cond_image_grid=cond_image_grid,
|
||||
)
|
||||
|
||||
def set_image_dimensions(self, image_path: Path) -> tuple[int, int]:
|
||||
"""Compute and store dimensions from input image.
|
||||
|
||||
Also stores image_paths for use in encode_prompt().
|
||||
|
||||
Returns:
|
||||
(output_width, output_height) for runtime config
|
||||
"""
|
||||
vl_w, vl_h, vae_w, vae_h, out_w, out_h = self._compute_dimensions_from_image(
|
||||
image_path
|
||||
)
|
||||
self._vl_width = vl_w
|
||||
self._vl_height = vl_h
|
||||
self._vae_width = vae_w
|
||||
self._vae_height = vae_h
|
||||
self._image_paths = [str(image_path)]
|
||||
return out_w, out_h
|
||||
|
||||
@property
|
||||
def needs_cfg(self) -> bool:
|
||||
gs = self._config.guidance_scale
|
||||
return gs is not None and gs > 1.0
|
||||
|
||||
def apply_guidance(
|
||||
self,
|
||||
noise_positive: mx.array,
|
||||
noise_negative: mx.array,
|
||||
guidance_scale: float,
|
||||
) -> mx.array:
|
||||
from mflux.models.qwen.variants.txt2img.qwen_image import QwenImage
|
||||
|
||||
return QwenImage.compute_guided_noise(
|
||||
noise=noise_positive,
|
||||
noise_negative=noise_negative,
|
||||
guidance=guidance_scale,
|
||||
)
|
||||
|
||||
def compute_embeddings(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
prompt_embeds: mx.array,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
"""Compute image and text embeddings."""
|
||||
embedded_hidden = self._transformer.img_in(hidden_states)
|
||||
encoder_hidden_states = self._transformer.txt_norm(prompt_embeds)
|
||||
embedded_encoder = self._transformer.txt_in(encoder_hidden_states)
|
||||
return embedded_hidden, embedded_encoder
|
||||
|
||||
def compute_text_embeddings(
|
||||
self,
|
||||
t: int,
|
||||
runtime_config: RuntimeConfig,
|
||||
pooled_prompt_embeds: mx.array | None = None,
|
||||
hidden_states: mx.array | None = None,
|
||||
) -> mx.array:
|
||||
"""Compute time/text embeddings."""
|
||||
ref_tensor = (
|
||||
hidden_states if hidden_states is not None else pooled_prompt_embeds
|
||||
)
|
||||
if ref_tensor is None:
|
||||
raise ValueError(
|
||||
"Either hidden_states or pooled_prompt_embeds is required "
|
||||
"for Qwen text embeddings"
|
||||
)
|
||||
|
||||
timestep = QwenTransformer._compute_timestep(t, runtime_config) # noqa: SLF001
|
||||
batch_size = ref_tensor.shape[0]
|
||||
timestep = mx.broadcast_to(timestep, (batch_size,)).astype(mx.float32)
|
||||
return self._transformer.time_text_embed(timestep, ref_tensor)
|
||||
|
||||
def compute_rotary_embeddings(
|
||||
self,
|
||||
prompt_embeds: mx.array,
|
||||
runtime_config: RuntimeConfig,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Compute 3D rotary embeddings for Qwen edit."""
|
||||
encoder_hidden_states_mask = kwargs.get("encoder_hidden_states_mask")
|
||||
cond_image_grid = kwargs.get("cond_image_grid")
|
||||
|
||||
if encoder_hidden_states_mask is None:
|
||||
raise ValueError(
|
||||
"encoder_hidden_states_mask is required for Qwen RoPE computation"
|
||||
)
|
||||
|
||||
return QwenTransformer._compute_rotary_embeddings( # noqa: SLF001
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
pos_embed=self._transformer.pos_embed,
|
||||
config=runtime_config,
|
||||
cond_image_grid=cond_image_grid,
|
||||
)
|
||||
|
||||
def apply_joint_block(
|
||||
self,
|
||||
block: JointBlockInterface,
|
||||
hidden_states: mx.array,
|
||||
encoder_hidden_states: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
rotary_embeddings: Any,
|
||||
kv_cache: ImagePatchKVCache | None,
|
||||
mode: BlockWrapperMode,
|
||||
text_seq_len: int,
|
||||
patch_start: int | None = None,
|
||||
patch_end: int | None = None,
|
||||
**kwargs: Any,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
"""Apply Qwen joint block."""
|
||||
encoder_hidden_states_mask = kwargs.get("encoder_hidden_states_mask")
|
||||
block_idx = kwargs.get("block_idx")
|
||||
|
||||
if mode == BlockWrapperMode.CACHING:
|
||||
return self._apply_joint_block_caching(
|
||||
block=block,
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=rotary_embeddings,
|
||||
kv_cache=kv_cache,
|
||||
text_seq_len=text_seq_len,
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
block_idx=block_idx,
|
||||
)
|
||||
else:
|
||||
assert patch_start is not None and patch_end is not None
|
||||
assert kv_cache is not None
|
||||
return self._apply_joint_block_patched(
|
||||
block=block,
|
||||
patch_hidden=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=rotary_embeddings,
|
||||
kv_cache=kv_cache,
|
||||
text_seq_len=text_seq_len,
|
||||
patch_start=patch_start,
|
||||
patch_end=patch_end,
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
block_idx=block_idx,
|
||||
)
|
||||
|
||||
def apply_single_block(
|
||||
self,
|
||||
block: SingleBlockInterface,
|
||||
hidden_states: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
rotary_embeddings: mx.array,
|
||||
kv_cache: ImagePatchKVCache | None,
|
||||
mode: BlockWrapperMode,
|
||||
text_seq_len: int,
|
||||
patch_start: int | None = None,
|
||||
patch_end: int | None = None,
|
||||
) -> mx.array:
|
||||
"""Qwen has no single blocks."""
|
||||
raise NotImplementedError("Qwen does not have single blocks")
|
||||
|
||||
def final_projection(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
) -> mx.array:
|
||||
"""Apply final normalization and projection."""
|
||||
hidden_states = self._transformer.norm_out(hidden_states, text_embeddings)
|
||||
return self._transformer.proj_out(hidden_states)
|
||||
|
||||
def get_joint_blocks(self) -> list[JointBlockInterface]:
|
||||
"""Return all 60 transformer blocks."""
|
||||
return cast(
|
||||
list[JointBlockInterface], list(self._transformer.transformer_blocks)
|
||||
)
|
||||
|
||||
def get_single_blocks(self) -> list[SingleBlockInterface]:
|
||||
"""Qwen has no single blocks."""
|
||||
return []
|
||||
|
||||
def slice_transformer_blocks(
|
||||
self,
|
||||
start_layer: int,
|
||||
end_layer: int,
|
||||
total_joint_blocks: int,
|
||||
total_single_blocks: int,
|
||||
) -> None:
|
||||
all_blocks = list(self._transformer.transformer_blocks)
|
||||
assigned_blocks = all_blocks[start_layer:end_layer]
|
||||
self._transformer.transformer_blocks = assigned_blocks
|
||||
|
||||
def merge_streams(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
encoder_hidden_states: mx.array,
|
||||
) -> mx.array:
|
||||
"""Merge image and text streams."""
|
||||
return mx.concatenate([encoder_hidden_states, hidden_states], axis=1)
|
||||
|
||||
def _apply_joint_block_caching(
|
||||
self,
|
||||
block: Any,
|
||||
hidden_states: mx.array,
|
||||
encoder_hidden_states: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
rotary_embeddings: tuple[tuple[mx.array, mx.array], tuple[mx.array, mx.array]],
|
||||
kv_cache: ImagePatchKVCache | None,
|
||||
text_seq_len: int,
|
||||
encoder_hidden_states_mask: mx.array | None = None,
|
||||
block_idx: int | None = None,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
"""Apply joint block in caching mode."""
|
||||
return block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
text_embeddings=text_embeddings,
|
||||
image_rotary_emb=rotary_embeddings,
|
||||
block_idx=block_idx,
|
||||
)
|
||||
|
||||
def _apply_joint_block_patched(
|
||||
self,
|
||||
block: Any,
|
||||
patch_hidden: mx.array,
|
||||
encoder_hidden_states: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
rotary_embeddings: tuple[tuple[mx.array, mx.array], tuple[mx.array, mx.array]],
|
||||
kv_cache: ImagePatchKVCache,
|
||||
text_seq_len: int,
|
||||
patch_start: int,
|
||||
patch_end: int,
|
||||
encoder_hidden_states_mask: mx.array | None = None,
|
||||
block_idx: int | None = None,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
batch_size = patch_hidden.shape[0]
|
||||
attn = block.attn
|
||||
num_heads = attn.num_heads
|
||||
head_dim = attn.head_dim
|
||||
|
||||
# Modulation parameters
|
||||
img_mod_params = block.img_mod_linear(block.img_mod_silu(text_embeddings))
|
||||
txt_mod_params = block.txt_mod_linear(block.txt_mod_silu(text_embeddings))
|
||||
|
||||
img_mod1, img_mod2 = mx.split(img_mod_params, 2, axis=-1)
|
||||
txt_mod1, txt_mod2 = mx.split(txt_mod_params, 2, axis=-1)
|
||||
|
||||
# Normalization and modulation
|
||||
img_normed = block.img_norm1(patch_hidden)
|
||||
img_modulated, img_gate1 = QwenTransformerBlock._modulate(img_normed, img_mod1)
|
||||
|
||||
txt_normed = block.txt_norm1(encoder_hidden_states)
|
||||
txt_modulated, txt_gate1 = QwenTransformerBlock._modulate(txt_normed, txt_mod1)
|
||||
|
||||
# Q, K, V for image patch
|
||||
img_query = attn.to_q(img_modulated)
|
||||
img_key = attn.to_k(img_modulated)
|
||||
img_value = attn.to_v(img_modulated)
|
||||
|
||||
# Q, K, V for text
|
||||
txt_query = attn.add_q_proj(txt_modulated)
|
||||
txt_key = attn.add_k_proj(txt_modulated)
|
||||
txt_value = attn.add_v_proj(txt_modulated)
|
||||
|
||||
# Reshape to [B, S, H, D]
|
||||
patch_len = patch_hidden.shape[1]
|
||||
img_query = mx.reshape(img_query, (batch_size, patch_len, num_heads, head_dim))
|
||||
img_key = mx.reshape(img_key, (batch_size, patch_len, num_heads, head_dim))
|
||||
img_value = mx.reshape(img_value, (batch_size, patch_len, num_heads, head_dim))
|
||||
|
||||
txt_query = mx.reshape(
|
||||
txt_query, (batch_size, text_seq_len, num_heads, head_dim)
|
||||
)
|
||||
txt_key = mx.reshape(txt_key, (batch_size, text_seq_len, num_heads, head_dim))
|
||||
txt_value = mx.reshape(
|
||||
txt_value, (batch_size, text_seq_len, num_heads, head_dim)
|
||||
)
|
||||
|
||||
# RMSNorm to Q, K
|
||||
img_query = attn.norm_q(img_query)
|
||||
img_key = attn.norm_k(img_key)
|
||||
txt_query = attn.norm_added_q(txt_query)
|
||||
txt_key = attn.norm_added_k(txt_key)
|
||||
|
||||
# Extract RoPE for patch
|
||||
(img_cos, img_sin), (txt_cos, txt_sin) = rotary_embeddings
|
||||
patch_img_cos = img_cos[patch_start:patch_end]
|
||||
patch_img_sin = img_sin[patch_start:patch_end]
|
||||
|
||||
# Apply RoPE
|
||||
img_query = QwenAttention._apply_rope_qwen(
|
||||
img_query, patch_img_cos, patch_img_sin
|
||||
)
|
||||
img_key = QwenAttention._apply_rope_qwen(img_key, patch_img_cos, patch_img_sin)
|
||||
txt_query = QwenAttention._apply_rope_qwen(txt_query, txt_cos, txt_sin)
|
||||
txt_key = QwenAttention._apply_rope_qwen(txt_key, txt_cos, txt_sin)
|
||||
|
||||
# Transpose to [B, H, S, D]
|
||||
img_key_bhsd = mx.transpose(img_key, (0, 2, 1, 3))
|
||||
img_value_bhsd = mx.transpose(img_value, (0, 2, 1, 3))
|
||||
|
||||
# Update cache
|
||||
kv_cache.update_image_patch(
|
||||
patch_start=patch_start,
|
||||
patch_end=patch_end,
|
||||
key=img_key_bhsd,
|
||||
value=img_value_bhsd,
|
||||
)
|
||||
|
||||
# Get full K, V from cache
|
||||
txt_key_bhsd = mx.transpose(txt_key, (0, 2, 1, 3))
|
||||
txt_value_bhsd = mx.transpose(txt_value, (0, 2, 1, 3))
|
||||
full_key, full_value = kv_cache.get_full_kv(
|
||||
text_key=txt_key_bhsd,
|
||||
text_value=txt_value_bhsd,
|
||||
)
|
||||
|
||||
# Build query
|
||||
joint_query = mx.concatenate([txt_query, img_query], axis=1)
|
||||
|
||||
# Build attention mask
|
||||
mask = QwenAttention._convert_mask_for_qwen(
|
||||
mask=encoder_hidden_states_mask,
|
||||
joint_seq_len=full_key.shape[2],
|
||||
txt_seq_len=text_seq_len,
|
||||
)
|
||||
|
||||
# Compute attention
|
||||
hidden_states = attn._compute_attention_qwen(
|
||||
query=joint_query,
|
||||
key=mx.transpose(full_key, (0, 2, 1, 3)),
|
||||
value=mx.transpose(full_value, (0, 2, 1, 3)),
|
||||
mask=mask,
|
||||
block_idx=block_idx,
|
||||
)
|
||||
|
||||
# Extract outputs
|
||||
txt_attn_output = hidden_states[:, :text_seq_len, :]
|
||||
img_attn_output = hidden_states[:, text_seq_len:, :]
|
||||
|
||||
# Project
|
||||
img_attn_output = attn.attn_to_out[0](img_attn_output)
|
||||
txt_attn_output = attn.to_add_out(txt_attn_output)
|
||||
|
||||
# Residual + gate
|
||||
patch_hidden = patch_hidden + img_gate1 * img_attn_output
|
||||
encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
|
||||
|
||||
# Feed-forward for image
|
||||
img_normed2 = block.img_norm2(patch_hidden)
|
||||
img_modulated2, img_gate2 = QwenTransformerBlock._modulate(
|
||||
img_normed2, img_mod2
|
||||
)
|
||||
img_mlp_output = block.img_ff(img_modulated2)
|
||||
patch_hidden = patch_hidden + img_gate2 * img_mlp_output
|
||||
|
||||
# Feed-forward for text
|
||||
txt_normed2 = block.txt_norm2(encoder_hidden_states)
|
||||
txt_modulated2, txt_gate2 = QwenTransformerBlock._modulate(
|
||||
txt_normed2, txt_mod2
|
||||
)
|
||||
txt_mlp_output = block.txt_ff(txt_modulated2)
|
||||
encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output
|
||||
|
||||
return encoder_hidden_states, patch_hidden
|
||||
23
src/exo/worker/engines/image/pipeline/__init__.py
Normal file
23
src/exo/worker/engines/image/pipeline/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from exo.worker.engines.image.pipeline.adapter import (
|
||||
BlockWrapperMode,
|
||||
JointBlockInterface,
|
||||
ModelAdapter,
|
||||
SingleBlockInterface,
|
||||
)
|
||||
from exo.worker.engines.image.pipeline.block_wrapper import (
|
||||
JointBlockWrapper,
|
||||
SingleBlockWrapper,
|
||||
)
|
||||
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache
|
||||
from exo.worker.engines.image.pipeline.runner import DiffusionRunner
|
||||
|
||||
__all__ = [
|
||||
"BlockWrapperMode",
|
||||
"DiffusionRunner",
|
||||
"ImagePatchKVCache",
|
||||
"JointBlockInterface",
|
||||
"JointBlockWrapper",
|
||||
"ModelAdapter",
|
||||
"SingleBlockInterface",
|
||||
"SingleBlockWrapper",
|
||||
]
|
||||
402
src/exo/worker/engines/image/pipeline/adapter.py
Normal file
402
src/exo/worker/engines/image/pipeline/adapter.py
Normal file
@@ -0,0 +1,402 @@
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Protocol
|
||||
|
||||
import mlx.core as mx
|
||||
from mflux.config.runtime_config import RuntimeConfig
|
||||
|
||||
from exo.worker.engines.image.config import ImageModelConfig
|
||||
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache
|
||||
|
||||
|
||||
class AttentionInterface(Protocol):
|
||||
num_heads: int
|
||||
head_dimension: int
|
||||
to_q: Any
|
||||
to_k: Any
|
||||
to_v: Any
|
||||
norm_q: Any
|
||||
norm_k: Any
|
||||
to_out: list[Any]
|
||||
|
||||
|
||||
class JointAttentionInterface(AttentionInterface, Protocol):
|
||||
add_q_proj: Any
|
||||
add_k_proj: Any
|
||||
add_v_proj: Any
|
||||
norm_added_q: Any
|
||||
norm_added_k: Any
|
||||
to_add_out: Any
|
||||
|
||||
|
||||
class JointBlockInterface(Protocol):
|
||||
attn: JointAttentionInterface
|
||||
norm1: Any # Callable module: (hidden_states, text_embeddings) -> tuple[5 arrays]
|
||||
norm1_context: (
|
||||
Any # Callable module: (hidden_states, text_embeddings) -> tuple[5 arrays]
|
||||
)
|
||||
norm2: Any
|
||||
norm2_context: Any
|
||||
ff: Any
|
||||
ff_context: Any
|
||||
|
||||
|
||||
class SingleBlockInterface(Protocol):
|
||||
attn: AttentionInterface
|
||||
norm: Any # Callable module: (hidden_states, text_embeddings) -> tuple[2 arrays]
|
||||
|
||||
def _apply_feed_forward_and_projection(
|
||||
self, norm_hidden_states: mx.array, attn_output: mx.array, gate: mx.array
|
||||
) -> mx.array:
|
||||
"""Apply feed forward network and projection."""
|
||||
...
|
||||
|
||||
|
||||
class BlockWrapperMode(Enum):
|
||||
CACHING = "caching" # Sync mode: compute full attention, populate cache
|
||||
PATCHED = "patched" # Async mode: compute patch attention, use cached KV
|
||||
|
||||
|
||||
class PromptData(Protocol):
|
||||
"""Protocol for encoded prompt data.
|
||||
|
||||
All adapters must return prompt data that conforms to this protocol.
|
||||
Model-specific prompt data classes can add additional attributes
|
||||
(e.g., attention masks for Qwen).
|
||||
"""
|
||||
|
||||
@property
|
||||
def prompt_embeds(self) -> mx.array:
|
||||
"""Text embeddings from encoder."""
|
||||
...
|
||||
|
||||
@property
|
||||
def pooled_prompt_embeds(self) -> mx.array:
|
||||
"""Pooled text embeddings (for Flux) or placeholder (for Qwen)."""
|
||||
...
|
||||
|
||||
@property
|
||||
def negative_prompt_embeds(self) -> mx.array | None:
|
||||
"""Negative prompt embeddings for CFG (None if not using CFG)."""
|
||||
...
|
||||
|
||||
@property
|
||||
def negative_pooled_prompt_embeds(self) -> mx.array | None:
|
||||
"""Negative pooled embeddings for CFG (None if not using CFG)."""
|
||||
...
|
||||
|
||||
def get_extra_forward_kwargs(self, positive: bool = True) -> dict[str, Any]:
|
||||
"""Return model-specific kwargs for forward pass.
|
||||
|
||||
Args:
|
||||
positive: If True, return kwargs for positive prompt pass.
|
||||
If False, return kwargs for negative prompt pass.
|
||||
|
||||
Returns:
|
||||
Dict of extra kwargs (e.g., {"encoder_hidden_states_mask": ...} for Qwen)
|
||||
"""
|
||||
...
|
||||
|
||||
@property
|
||||
def conditioning_latents(self) -> mx.array | None:
|
||||
"""Conditioning latents for edit mode.
|
||||
|
||||
Returns:
|
||||
Conditioning latents array for image editing, None for standard generation.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class ModelAdapter(Protocol):
|
||||
@property
|
||||
def config(self) -> ImageModelConfig:
|
||||
"""Return the model configuration."""
|
||||
...
|
||||
|
||||
@property
|
||||
def model(self) -> Any:
|
||||
"""Return the underlying mflux model instance (e.g., Flux1, Fibo, Qwen)."""
|
||||
...
|
||||
|
||||
@property
|
||||
def transformer(self) -> Any:
|
||||
"""Return the transformer component of the model."""
|
||||
...
|
||||
|
||||
@property
|
||||
def hidden_dim(self) -> int:
|
||||
"""Return the hidden dimension of the transformer."""
|
||||
...
|
||||
|
||||
def compute_embeddings(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
prompt_embeds: mx.array,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
"""Compute x_embedder and context_embedder outputs.
|
||||
|
||||
Args:
|
||||
hidden_states: Input latent states
|
||||
prompt_embeds: Text embeddings from encoder
|
||||
|
||||
Returns:
|
||||
Tuple of (embedded_hidden_states, embedded_encoder_states)
|
||||
"""
|
||||
...
|
||||
|
||||
def compute_text_embeddings(
|
||||
self,
|
||||
t: int,
|
||||
runtime_config: RuntimeConfig,
|
||||
pooled_prompt_embeds: mx.array | None = None,
|
||||
hidden_states: mx.array | None = None,
|
||||
) -> mx.array:
|
||||
"""Compute time/text embeddings for conditioning.
|
||||
|
||||
Args:
|
||||
t: Current timestep
|
||||
runtime_config: Runtime configuration
|
||||
pooled_prompt_embeds: Pooled text embeddings (used by Flux)
|
||||
hidden_states: Image hidden states
|
||||
|
||||
Returns:
|
||||
Text embeddings tensor
|
||||
"""
|
||||
...
|
||||
|
||||
def compute_rotary_embeddings(
|
||||
self,
|
||||
prompt_embeds: mx.array,
|
||||
runtime_config: RuntimeConfig,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Compute rotary position embeddings.
|
||||
|
||||
Args:
|
||||
prompt_embeds: Text embeddings
|
||||
runtime_config: Runtime configuration
|
||||
**kwargs: Model-specific arguments (e.g., encoder_hidden_states_mask for Qwen)
|
||||
|
||||
Returns:
|
||||
Flux: mx.array
|
||||
Qwen: tuple[tuple[mx.array, mx.array], tuple[mx.array, mx.array]]
|
||||
"""
|
||||
...
|
||||
|
||||
def apply_joint_block(
|
||||
self,
|
||||
block: JointBlockInterface,
|
||||
hidden_states: mx.array,
|
||||
encoder_hidden_states: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
rotary_embeddings: Any, # Format varies: mx.array (Flux) or nested tuple (Qwen)
|
||||
kv_cache: ImagePatchKVCache | None,
|
||||
mode: "BlockWrapperMode",
|
||||
text_seq_len: int,
|
||||
patch_start: int | None = None,
|
||||
patch_end: int | None = None,
|
||||
**kwargs: Any,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
"""Apply a joint transformer block.
|
||||
|
||||
Args:
|
||||
block: The joint transformer block
|
||||
hidden_states: Image hidden states
|
||||
encoder_hidden_states: Text hidden states
|
||||
text_embeddings: Conditioning embeddings
|
||||
rotary_embeddings: Rotary position embeddings (format varies by model)
|
||||
kv_cache: KV cache (None if not using cache)
|
||||
mode: CACHING or PATCHED mode
|
||||
text_seq_len: Text sequence length
|
||||
patch_start: Start index for patched mode
|
||||
patch_end: End index for patched mode
|
||||
**kwargs: Additional model-specific arguments (e.g., encoder_hidden_states_mask,
|
||||
block_idx for Qwen)
|
||||
|
||||
Returns:
|
||||
Tuple of (encoder_hidden_states, hidden_states)
|
||||
"""
|
||||
...
|
||||
|
||||
def apply_single_block(
|
||||
self,
|
||||
block: SingleBlockInterface,
|
||||
hidden_states: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
rotary_embeddings: mx.array,
|
||||
kv_cache: ImagePatchKVCache | None,
|
||||
mode: "BlockWrapperMode",
|
||||
text_seq_len: int,
|
||||
patch_start: int | None = None,
|
||||
patch_end: int | None = None,
|
||||
) -> mx.array:
|
||||
"""Apply a single transformer block.
|
||||
|
||||
Args:
|
||||
block: The single transformer block
|
||||
hidden_states: Concatenated [text + image] hidden states
|
||||
text_embeddings: Conditioning embeddings
|
||||
rotary_embeddings: Rotary position embeddings
|
||||
kv_cache: KV cache (None if not using cache)
|
||||
mode: CACHING or PATCHED mode
|
||||
text_seq_len: Text sequence length
|
||||
patch_start: Start index for patched mode
|
||||
patch_end: End index for patched mode
|
||||
|
||||
Returns:
|
||||
Output hidden states
|
||||
"""
|
||||
...
|
||||
|
||||
def final_projection(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
) -> mx.array:
|
||||
"""Apply final norm and projection.
|
||||
|
||||
Args:
|
||||
hidden_states: Hidden states (image only, text already removed)
|
||||
text_embeddings: Conditioning embeddings
|
||||
|
||||
Returns:
|
||||
Projected output
|
||||
"""
|
||||
...
|
||||
|
||||
def get_joint_blocks(self) -> list[JointBlockInterface]:
|
||||
"""Get the list of joint transformer blocks from the model."""
|
||||
...
|
||||
|
||||
def get_single_blocks(self) -> list[SingleBlockInterface]:
|
||||
"""Get the list of single transformer blocks from the model."""
|
||||
...
|
||||
|
||||
def slice_transformer_blocks(
|
||||
self,
|
||||
start_layer: int,
|
||||
end_layer: int,
|
||||
total_joint_blocks: int,
|
||||
total_single_blocks: int,
|
||||
):
|
||||
"""Remove transformer blocks outside the assigned range.
|
||||
|
||||
This should be called BEFORE mx.eval() to avoid loading unused weights
|
||||
in distributed mode.
|
||||
|
||||
Args:
|
||||
start_layer: First layer index (inclusive) assigned to this node
|
||||
end_layer: Last layer index (exclusive) assigned to this node
|
||||
total_joint_blocks: Total number of joint blocks in the model
|
||||
total_single_blocks: Total number of single blocks in the model
|
||||
"""
|
||||
...
|
||||
|
||||
def merge_streams(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
encoder_hidden_states: mx.array,
|
||||
) -> mx.array:
|
||||
"""Merge image and text streams for transition to single blocks.
|
||||
|
||||
This is called at the transition point from joint blocks (which process
|
||||
image and text separately) to single blocks (which process them
|
||||
together). Override to customize the merge strategy.
|
||||
|
||||
Args:
|
||||
hidden_states: Image hidden states
|
||||
encoder_hidden_states: Text hidden states
|
||||
|
||||
Returns:
|
||||
Merged hidden states (default: concatenate [text, image])
|
||||
"""
|
||||
...
|
||||
|
||||
def create_latents(self, seed: int, runtime_config: RuntimeConfig) -> mx.array:
|
||||
"""Create initial noise latents for generation.
|
||||
|
||||
Args:
|
||||
seed: Random seed
|
||||
runtime_config: Runtime configuration with dimensions
|
||||
|
||||
Returns:
|
||||
Initial latent tensor
|
||||
"""
|
||||
...
|
||||
|
||||
def encode_prompt(self, prompt: str) -> PromptData:
|
||||
"""Encode prompt into model-specific prompt data.
|
||||
|
||||
Args:
|
||||
prompt: Text prompt
|
||||
|
||||
Returns:
|
||||
PromptData containing embeddings (and model-specific extras)
|
||||
"""
|
||||
...
|
||||
|
||||
@property
|
||||
def needs_cfg(self) -> bool:
|
||||
"""Whether this model uses classifier-free guidance.
|
||||
|
||||
Returns:
|
||||
True if model requires two forward passes with guidance (e.g., Qwen)
|
||||
False if model uses a single forward pass (e.g., Flux)
|
||||
"""
|
||||
...
|
||||
|
||||
def apply_guidance(
|
||||
self,
|
||||
noise_positive: mx.array,
|
||||
noise_negative: mx.array,
|
||||
guidance_scale: float,
|
||||
) -> mx.array:
|
||||
"""Apply classifier-free guidance to combine positive/negative predictions.
|
||||
|
||||
Only called when needs_cfg is True.
|
||||
|
||||
Args:
|
||||
noise_positive: Noise prediction from positive prompt
|
||||
noise_negative: Noise prediction from negative prompt
|
||||
guidance_scale: Guidance strength
|
||||
|
||||
Returns:
|
||||
Guided noise prediction
|
||||
"""
|
||||
...
|
||||
|
||||
def decode_latents(
|
||||
self,
|
||||
latents: mx.array,
|
||||
runtime_config: RuntimeConfig,
|
||||
seed: int,
|
||||
prompt: str,
|
||||
) -> Any:
|
||||
"""Decode latents to final image.
|
||||
|
||||
Args:
|
||||
latents: Final denoised latents
|
||||
runtime_config: Runtime configuration
|
||||
seed: Random seed (for metadata)
|
||||
prompt: Text prompt (for metadata)
|
||||
|
||||
Returns:
|
||||
GeneratedImage result
|
||||
"""
|
||||
...
|
||||
|
||||
def set_image_dimensions(self, image_path: Path) -> tuple[int, int] | None:
|
||||
"""Compute and store dimensions from input image for edit mode.
|
||||
|
||||
For edit adapters: computes dimensions from input image aspect ratio,
|
||||
stores image paths internally for encode_prompt(), returns (width, height).
|
||||
|
||||
For standard adapters: returns None (use user-specified dimensions).
|
||||
|
||||
Args:
|
||||
image_path: Path to the input image
|
||||
|
||||
Returns:
|
||||
Tuple of (width, height) if dimensions were computed, None otherwise.
|
||||
"""
|
||||
...
|
||||
146
src/exo/worker/engines/image/pipeline/block_wrapper.py
Normal file
146
src/exo/worker/engines/image/pipeline/block_wrapper.py
Normal file
@@ -0,0 +1,146 @@
|
||||
from typing import Any
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
from exo.worker.engines.image.pipeline.adapter import (
|
||||
BlockWrapperMode,
|
||||
JointBlockInterface,
|
||||
ModelAdapter,
|
||||
SingleBlockInterface,
|
||||
)
|
||||
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache
|
||||
|
||||
|
||||
class JointBlockWrapper:
|
||||
"""Unified wrapper for joint transformer blocks.
|
||||
|
||||
Handles both CACHING (sync) and PATCHED (async) modes by delegating
|
||||
to the model adapter for model-specific attention computation.
|
||||
|
||||
The wrapper is created once at initialization and reused across calls.
|
||||
Mode and KV cache are passed at call time to support switching between
|
||||
sync and async pipelines.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
block: JointBlockInterface,
|
||||
adapter: ModelAdapter,
|
||||
):
|
||||
"""Initialize the joint block wrapper.
|
||||
|
||||
Args:
|
||||
block: The joint transformer block to wrap
|
||||
adapter: Model adapter for model-specific operations
|
||||
"""
|
||||
self.block = block
|
||||
self.adapter = adapter
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
encoder_hidden_states: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
rotary_embeddings: mx.array,
|
||||
text_seq_len: int,
|
||||
kv_cache: ImagePatchKVCache | None,
|
||||
mode: BlockWrapperMode,
|
||||
patch_start: int | None = None,
|
||||
patch_end: int | None = None,
|
||||
**kwargs: Any,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
"""Apply the joint block.
|
||||
|
||||
Args:
|
||||
hidden_states: Image hidden states (full or patch depending on mode)
|
||||
encoder_hidden_states: Text hidden states
|
||||
text_embeddings: Conditioning embeddings
|
||||
rotary_embeddings: Rotary position embeddings
|
||||
text_seq_len: Text sequence length
|
||||
kv_cache: KV cache for storing/retrieving image K/V (None if not using cache)
|
||||
mode: CACHING (populate cache) or PATCHED (use cached K/V)
|
||||
patch_start: Start index for patched mode (required if mode=PATCHED)
|
||||
patch_end: End index for patched mode (required if mode=PATCHED)
|
||||
**kwargs: Additional model-specific arguments (e.g., encoder_hidden_states_mask,
|
||||
block_idx for Qwen)
|
||||
|
||||
Returns:
|
||||
Tuple of (encoder_hidden_states, hidden_states)
|
||||
"""
|
||||
return self.adapter.apply_joint_block(
|
||||
block=self.block,
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=rotary_embeddings,
|
||||
kv_cache=kv_cache,
|
||||
mode=mode,
|
||||
text_seq_len=text_seq_len,
|
||||
patch_start=patch_start,
|
||||
patch_end=patch_end,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class SingleBlockWrapper:
|
||||
"""Unified wrapper for single transformer blocks.
|
||||
|
||||
Handles both CACHING (sync) and PATCHED (async) modes by delegating
|
||||
to the model adapter for model-specific attention computation.
|
||||
|
||||
The wrapper is created once at initialization and reused across calls.
|
||||
Mode and KV cache are passed at call time to support switching between
|
||||
sync and async pipelines.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
block: SingleBlockInterface,
|
||||
adapter: ModelAdapter,
|
||||
):
|
||||
"""Initialize the single block wrapper.
|
||||
|
||||
Args:
|
||||
block: The single transformer block to wrap
|
||||
adapter: Model adapter for model-specific operations
|
||||
"""
|
||||
self.block = block
|
||||
self.adapter = adapter
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
rotary_embeddings: mx.array,
|
||||
text_seq_len: int,
|
||||
kv_cache: ImagePatchKVCache | None,
|
||||
mode: BlockWrapperMode,
|
||||
patch_start: int | None = None,
|
||||
patch_end: int | None = None,
|
||||
) -> mx.array:
|
||||
"""Apply the single block.
|
||||
|
||||
Args:
|
||||
hidden_states: [text + image] hidden states (full or patch depending on mode)
|
||||
text_embeddings: Conditioning embeddings
|
||||
rotary_embeddings: Rotary position embeddings
|
||||
text_seq_len: Text sequence length
|
||||
kv_cache: KV cache for storing/retrieving image K/V (None if not using cache)
|
||||
mode: CACHING (populate cache) or PATCHED (use cached K/V)
|
||||
patch_start: Start index for patched mode (required if mode=PATCHED)
|
||||
patch_end: End index for patched mode (required if mode=PATCHED)
|
||||
|
||||
Returns:
|
||||
Output hidden states
|
||||
"""
|
||||
return self.adapter.apply_single_block(
|
||||
block=self.block,
|
||||
hidden_states=hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=rotary_embeddings,
|
||||
kv_cache=kv_cache,
|
||||
mode=mode,
|
||||
text_seq_len=text_seq_len,
|
||||
patch_start=patch_start,
|
||||
patch_end=patch_end,
|
||||
)
|
||||
72
src/exo/worker/engines/image/pipeline/kv_cache.py
Normal file
72
src/exo/worker/engines/image/pipeline/kv_cache.py
Normal file
@@ -0,0 +1,72 @@
|
||||
import mlx.core as mx
|
||||
|
||||
|
||||
class ImagePatchKVCache:
|
||||
"""KV cache that stores only IMAGE K/V with patch-level updates.
|
||||
|
||||
Only caches image K/V since:
|
||||
- Text K/V is always computed fresh (same for all patches)
|
||||
- Only image portion needs stale/fresh cache management across patches
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
batch_size: int,
|
||||
num_heads: int,
|
||||
image_seq_len: int,
|
||||
head_dim: int,
|
||||
dtype: mx.Dtype = mx.float32,
|
||||
):
|
||||
self.batch_size = batch_size
|
||||
self.num_heads = num_heads
|
||||
self.image_seq_len = image_seq_len
|
||||
self.head_dim = head_dim
|
||||
self._dtype = dtype
|
||||
|
||||
self.key_cache = mx.zeros(
|
||||
(batch_size, num_heads, image_seq_len, head_dim), dtype=dtype
|
||||
)
|
||||
self.value_cache = mx.zeros(
|
||||
(batch_size, num_heads, image_seq_len, head_dim), dtype=dtype
|
||||
)
|
||||
|
||||
def update_image_patch(
|
||||
self, patch_start: int, patch_end: int, key: mx.array, value: mx.array
|
||||
) -> None:
|
||||
"""Update cache with fresh K/V for an image patch slice.
|
||||
|
||||
Args:
|
||||
patch_start: Start token index within image portion (0-indexed)
|
||||
patch_end: End token index within image portion
|
||||
key: Fresh key tensor [batch, heads, patch_seq_len, head_dim]
|
||||
value: Fresh value tensor [batch, heads, patch_seq_len, head_dim]
|
||||
"""
|
||||
self.key_cache[:, :, patch_start:patch_end, :] = key
|
||||
self.value_cache[:, :, patch_start:patch_end, :] = value
|
||||
|
||||
def get_full_kv(
|
||||
self, text_key: mx.array, text_value: mx.array
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
"""Return full K/V by concatenating fresh text K/V with cached image K/V.
|
||||
|
||||
Args:
|
||||
text_key: Fresh text key tensor [batch, heads, text_seq_len, head_dim]
|
||||
text_value: Fresh text value tensor [batch, heads, text_seq_len, head_dim]
|
||||
|
||||
Returns:
|
||||
Tuple of (full_key, full_value) with shape [batch, heads, text+image, head_dim]
|
||||
"""
|
||||
full_key = mx.concatenate([text_key, self.key_cache], axis=2)
|
||||
full_value = mx.concatenate([text_value, self.value_cache], axis=2)
|
||||
return full_key, full_value
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset cache to zeros."""
|
||||
self.key_cache = mx.zeros(
|
||||
(self.batch_size, self.num_heads, self.image_seq_len, self.head_dim),
|
||||
dtype=self._dtype,
|
||||
)
|
||||
self.value_cache = mx.zeros(
|
||||
(self.batch_size, self.num_heads, self.image_seq_len, self.head_dim),
|
||||
dtype=self._dtype,
|
||||
)
|
||||
975
src/exo/worker/engines/image/pipeline/runner.py
Normal file
975
src/exo/worker/engines/image/pipeline/runner.py
Normal file
@@ -0,0 +1,975 @@
|
||||
from math import ceil
|
||||
from typing import Any, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
from mflux.callbacks.callbacks import Callbacks
|
||||
from mflux.config.config import Config
|
||||
from mflux.config.runtime_config import RuntimeConfig
|
||||
from mflux.utils.exceptions import StopImageGenerationException
|
||||
from tqdm import tqdm
|
||||
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata
|
||||
from exo.worker.engines.image.config import ImageModelConfig
|
||||
from exo.worker.engines.image.pipeline.adapter import (
|
||||
BlockWrapperMode,
|
||||
ModelAdapter,
|
||||
PromptData,
|
||||
)
|
||||
from exo.worker.engines.image.pipeline.block_wrapper import (
|
||||
JointBlockWrapper,
|
||||
SingleBlockWrapper,
|
||||
)
|
||||
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache
|
||||
|
||||
|
||||
def calculate_patch_heights(latent_height: int, num_patches: int):
|
||||
patch_height = ceil(latent_height / num_patches)
|
||||
|
||||
actual_num_patches = ceil(latent_height / patch_height)
|
||||
patch_heights = [patch_height] * (actual_num_patches - 1)
|
||||
|
||||
last_height = latent_height - patch_height * (actual_num_patches - 1)
|
||||
patch_heights.append(last_height)
|
||||
|
||||
return patch_heights, actual_num_patches
|
||||
|
||||
|
||||
def calculate_token_indices(patch_heights: list[int], latent_width: int):
|
||||
tokens_per_row = latent_width
|
||||
|
||||
token_ranges = []
|
||||
cumulative_height = 0
|
||||
|
||||
for h in patch_heights:
|
||||
start_token = tokens_per_row * cumulative_height
|
||||
end_token = tokens_per_row * (cumulative_height + h)
|
||||
|
||||
token_ranges.append((start_token, end_token))
|
||||
cumulative_height += h
|
||||
|
||||
return token_ranges
|
||||
|
||||
|
||||
class DiffusionRunner:
|
||||
"""Orchestrates the diffusion loop for image generation.
|
||||
|
||||
This class owns the entire diffusion process, handling both single-node
|
||||
and distributed (PipeFusion) modes.
|
||||
|
||||
In distributed mode, it implements PipeFusion with:
|
||||
- Sync pipeline for initial timesteps (full image, all devices in lockstep)
|
||||
- Async pipeline for later timesteps (patches processed independently)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ImageModelConfig,
|
||||
adapter: ModelAdapter,
|
||||
group: Optional[mx.distributed.Group],
|
||||
shard_metadata: PipelineShardMetadata,
|
||||
num_sync_steps: int = 1,
|
||||
num_patches: Optional[int] = None,
|
||||
):
|
||||
"""Initialize the diffusion runner.
|
||||
|
||||
Args:
|
||||
config: Model configuration (architecture, block counts, etc.)
|
||||
adapter: Model adapter for model-specific operations
|
||||
group: MLX distributed group (None for single-node mode)
|
||||
shard_metadata: Pipeline shard metadata with layer assignments
|
||||
num_sync_steps: Number of synchronous timesteps before async mode
|
||||
num_patches: Number of patches for async mode (defaults to world_size)
|
||||
"""
|
||||
self.config = config
|
||||
self.adapter = adapter
|
||||
self.group = group
|
||||
|
||||
# Handle single-node vs distributed mode
|
||||
if group is None:
|
||||
self.rank = 0
|
||||
self.world_size = 1
|
||||
self.next_rank = 0
|
||||
self.prev_rank = 0
|
||||
self.start_layer = 0
|
||||
self.end_layer = config.total_blocks
|
||||
else:
|
||||
self.rank = shard_metadata.device_rank
|
||||
self.world_size = shard_metadata.world_size
|
||||
self.next_rank = (self.rank + 1) % self.world_size
|
||||
self.prev_rank = (self.rank - 1 + self.world_size) % self.world_size
|
||||
self.start_layer = shard_metadata.start_layer
|
||||
self.end_layer = shard_metadata.end_layer
|
||||
|
||||
self.num_sync_steps = num_sync_steps
|
||||
self.num_patches = num_patches if num_patches else max(1, self.world_size)
|
||||
|
||||
# Persistent KV caches (initialized on first async timestep, reused across timesteps)
|
||||
self.joint_kv_caches: list[ImagePatchKVCache] | None = None
|
||||
self.single_kv_caches: list[ImagePatchKVCache] | None = None
|
||||
|
||||
# Get block counts from config (model-agnostic)
|
||||
self.total_joint = config.joint_block_count
|
||||
self.total_single = config.single_block_count
|
||||
self.total_layers = config.total_blocks
|
||||
|
||||
self._compute_assigned_blocks()
|
||||
|
||||
def _compute_assigned_blocks(self) -> None:
|
||||
"""Determine which joint/single blocks this stage owns."""
|
||||
start = self.start_layer
|
||||
end = self.end_layer
|
||||
|
||||
if end <= self.total_joint:
|
||||
# All assigned blocks are joint blocks
|
||||
self.joint_start = start
|
||||
self.joint_end = end
|
||||
self.single_start = 0
|
||||
self.single_end = 0
|
||||
elif start >= self.total_joint:
|
||||
# All assigned blocks are single blocks
|
||||
self.joint_start = 0
|
||||
self.joint_end = 0
|
||||
self.single_start = start - self.total_joint
|
||||
self.single_end = end - self.total_joint
|
||||
else:
|
||||
# Stage spans joint→single transition
|
||||
self.joint_start = start
|
||||
self.joint_end = self.total_joint
|
||||
self.single_start = 0
|
||||
self.single_end = end - self.total_joint
|
||||
|
||||
self.has_joint_blocks = self.joint_end > self.joint_start
|
||||
self.has_single_blocks = self.single_end > self.single_start
|
||||
|
||||
self.owns_concat_stage = self.has_joint_blocks and (
|
||||
self.has_single_blocks or self.end_layer == self.total_joint
|
||||
)
|
||||
|
||||
joint_blocks = self.adapter.get_joint_blocks()
|
||||
single_blocks = self.adapter.get_single_blocks()
|
||||
|
||||
# Wrap blocks at initialization (reused across all calls)
|
||||
self.joint_block_wrappers = [
|
||||
JointBlockWrapper(block=block, adapter=self.adapter)
|
||||
for block in joint_blocks
|
||||
]
|
||||
self.single_block_wrappers = [
|
||||
SingleBlockWrapper(block=block, adapter=self.adapter)
|
||||
for block in single_blocks
|
||||
]
|
||||
|
||||
@property
|
||||
def is_first_stage(self) -> bool:
|
||||
return self.rank == 0
|
||||
|
||||
@property
|
||||
def is_last_stage(self) -> bool:
|
||||
return self.rank == self.world_size - 1
|
||||
|
||||
@property
|
||||
def is_distributed(self) -> bool:
|
||||
return self.group is not None
|
||||
|
||||
def _calculate_capture_steps(
|
||||
self,
|
||||
partial_images: int,
|
||||
init_time_step: int,
|
||||
num_inference_steps: int,
|
||||
) -> set[int]:
|
||||
"""Calculate which timesteps should produce partial images.
|
||||
|
||||
Evenly spaces `partial_images` captures across the diffusion loop.
|
||||
Does NOT include the final timestep (that's the complete image).
|
||||
|
||||
Args:
|
||||
partial_images: Number of partial images to capture
|
||||
init_time_step: Starting timestep (for img2img this may not be 0)
|
||||
num_inference_steps: Total inference steps
|
||||
|
||||
Returns:
|
||||
Set of timestep indices to capture
|
||||
"""
|
||||
if partial_images <= 0:
|
||||
return set()
|
||||
|
||||
total_steps = num_inference_steps - init_time_step
|
||||
if total_steps <= 1:
|
||||
return set()
|
||||
|
||||
if partial_images >= total_steps - 1:
|
||||
# Capture every step except final
|
||||
return set(range(init_time_step, num_inference_steps - 1))
|
||||
|
||||
# Evenly space partial captures
|
||||
step_interval = total_steps / (partial_images + 1)
|
||||
capture_steps: set[int] = set()
|
||||
for i in range(1, partial_images + 1):
|
||||
step_idx = int(init_time_step + i * step_interval)
|
||||
# Ensure we don't capture the final step
|
||||
if step_idx < num_inference_steps - 1:
|
||||
capture_steps.add(step_idx)
|
||||
|
||||
return capture_steps
|
||||
|
||||
def generate_image(
|
||||
self,
|
||||
settings: Config,
|
||||
prompt: str,
|
||||
seed: int,
|
||||
partial_images: int = 0,
|
||||
):
|
||||
"""Primary entry point for image generation.
|
||||
|
||||
Orchestrates the full generation flow:
|
||||
1. Create runtime config
|
||||
2. Create initial latents
|
||||
3. Encode prompt
|
||||
4. Run diffusion loop (yielding partials if requested)
|
||||
5. Decode to image
|
||||
|
||||
When partial_images > 0, yields (GeneratedImage, partial_index, total_partials)
|
||||
tuples for intermediate images, then yields the final GeneratedImage.
|
||||
|
||||
Args:
|
||||
settings: Generation config (steps, height, width)
|
||||
prompt: Text prompt
|
||||
seed: Random seed
|
||||
partial_images: Number of intermediate images to yield (0 for none)
|
||||
|
||||
Yields:
|
||||
Partial images as (GeneratedImage, partial_index, total_partials) tuples
|
||||
Final GeneratedImage
|
||||
"""
|
||||
runtime_config = RuntimeConfig(settings, self.adapter.model.model_config)
|
||||
latents = self.adapter.create_latents(seed, runtime_config)
|
||||
prompt_data = self.adapter.encode_prompt(prompt)
|
||||
|
||||
# Calculate which steps to capture
|
||||
capture_steps = self._calculate_capture_steps(
|
||||
partial_images=partial_images,
|
||||
init_time_step=runtime_config.init_time_step,
|
||||
num_inference_steps=runtime_config.num_inference_steps,
|
||||
)
|
||||
|
||||
# Run diffusion loop - may yield partial latents
|
||||
diffusion_gen = self._run_diffusion_loop(
|
||||
latents=latents,
|
||||
prompt_data=prompt_data,
|
||||
runtime_config=runtime_config,
|
||||
seed=seed,
|
||||
prompt=prompt,
|
||||
capture_steps=capture_steps,
|
||||
)
|
||||
|
||||
# Process partial yields and get final latents
|
||||
partial_index = 0
|
||||
total_partials = len(capture_steps)
|
||||
|
||||
if capture_steps:
|
||||
# Generator mode - iterate to get partials and final latents
|
||||
try:
|
||||
while True:
|
||||
partial_latents, _step = next(diffusion_gen)
|
||||
if self.is_last_stage:
|
||||
partial_image = self.adapter.decode_latents(
|
||||
partial_latents, runtime_config, seed, prompt
|
||||
)
|
||||
yield (partial_image, partial_index, total_partials)
|
||||
partial_index += 1
|
||||
except StopIteration as e:
|
||||
latents = e.value
|
||||
else:
|
||||
# No partials - just consume generator to get final latents
|
||||
try:
|
||||
while True:
|
||||
next(diffusion_gen)
|
||||
except StopIteration as e:
|
||||
latents = e.value
|
||||
|
||||
# Yield final image (only on last stage)
|
||||
if self.is_last_stage:
|
||||
yield self.adapter.decode_latents(latents, runtime_config, seed, prompt)
|
||||
|
||||
def _run_diffusion_loop(
|
||||
self,
|
||||
latents: mx.array,
|
||||
prompt_data: PromptData,
|
||||
runtime_config: RuntimeConfig,
|
||||
seed: int,
|
||||
prompt: str,
|
||||
capture_steps: set[int] | None = None,
|
||||
):
|
||||
"""Execute the diffusion loop, optionally yielding at capture steps.
|
||||
|
||||
When capture_steps is provided and non-empty, this becomes a generator
|
||||
that yields (latents, step_index) tuples at the specified timesteps.
|
||||
Only the last stage yields (others have incomplete latents).
|
||||
|
||||
Args:
|
||||
latents: Initial noise latents
|
||||
prompt_data: Encoded prompt data
|
||||
runtime_config: RuntimeConfig with scheduler, steps, dimensions
|
||||
seed: Random seed (for callbacks)
|
||||
prompt: Text prompt (for callbacks)
|
||||
capture_steps: Set of timestep indices to capture (None = no captures)
|
||||
|
||||
Yields:
|
||||
(latents, step_index) tuples at capture steps (last stage only)
|
||||
|
||||
Returns:
|
||||
Final denoised latents ready for VAE decoding
|
||||
"""
|
||||
if capture_steps is None:
|
||||
capture_steps = set()
|
||||
|
||||
time_steps = tqdm(range(runtime_config.num_inference_steps))
|
||||
|
||||
# Call subscribers for beginning of loop
|
||||
Callbacks.before_loop(
|
||||
seed=seed,
|
||||
prompt=prompt,
|
||||
latents=latents,
|
||||
config=runtime_config,
|
||||
)
|
||||
|
||||
for t in time_steps:
|
||||
try:
|
||||
latents = self._diffusion_step(
|
||||
t=t,
|
||||
config=runtime_config,
|
||||
latents=latents,
|
||||
prompt_data=prompt_data,
|
||||
)
|
||||
|
||||
# Call subscribers in-loop
|
||||
Callbacks.in_loop(
|
||||
t=t,
|
||||
seed=seed,
|
||||
prompt=prompt,
|
||||
latents=latents,
|
||||
config=runtime_config,
|
||||
time_steps=time_steps,
|
||||
)
|
||||
|
||||
mx.eval(latents)
|
||||
|
||||
# Yield partial latents at capture steps (only on last stage)
|
||||
if t in capture_steps and self.is_last_stage:
|
||||
yield (latents, t)
|
||||
|
||||
except KeyboardInterrupt: # noqa: PERF203
|
||||
Callbacks.interruption(
|
||||
t=t,
|
||||
seed=seed,
|
||||
prompt=prompt,
|
||||
latents=latents,
|
||||
config=runtime_config,
|
||||
time_steps=time_steps,
|
||||
)
|
||||
raise StopImageGenerationException(
|
||||
f"Stopping image generation at step {t + 1}/{len(time_steps)}"
|
||||
) from None
|
||||
|
||||
# Call subscribers after loop
|
||||
Callbacks.after_loop(
|
||||
seed=seed,
|
||||
prompt=prompt,
|
||||
latents=latents,
|
||||
config=runtime_config,
|
||||
)
|
||||
|
||||
return latents
|
||||
|
||||
def _forward_pass(
|
||||
self,
|
||||
latents: mx.array,
|
||||
prompt_embeds: mx.array,
|
||||
pooled_prompt_embeds: mx.array,
|
||||
kwargs: dict[str, Any],
|
||||
) -> mx.array:
|
||||
"""Run a single forward pass through the transformer.
|
||||
|
||||
This is the internal method called by adapters via compute_step_noise.
|
||||
Returns noise prediction without applying scheduler step.
|
||||
|
||||
For edit mode, concatenates conditioning latents with generated latents
|
||||
before the transformer, and extracts only the generated portion after.
|
||||
|
||||
Args:
|
||||
latents: Input latents (already scaled by caller)
|
||||
prompt_embeds: Text embeddings
|
||||
pooled_prompt_embeds: Pooled text embeddings (Flux) or placeholder (Qwen)
|
||||
kwargs: Model-specific arguments (e.g., encoder_hidden_states_mask, t)
|
||||
|
||||
Returns:
|
||||
Noise prediction tensor
|
||||
"""
|
||||
t = kwargs.get("t", 0)
|
||||
config = kwargs.get("config")
|
||||
if config is None:
|
||||
raise ValueError("config must be provided in kwargs")
|
||||
scaled_latents = config.scheduler.scale_model_input(latents, t)
|
||||
|
||||
# For edit mode: concatenate with conditioning latents
|
||||
conditioning_latents = kwargs.get("conditioning_latents")
|
||||
original_latent_tokens = scaled_latents.shape[1]
|
||||
if conditioning_latents is not None:
|
||||
scaled_latents = mx.concatenate(
|
||||
[scaled_latents, conditioning_latents], axis=1
|
||||
)
|
||||
|
||||
hidden_states, encoder_hidden_states = self.adapter.compute_embeddings(
|
||||
scaled_latents, prompt_embeds
|
||||
)
|
||||
text_embeddings = self.adapter.compute_text_embeddings(
|
||||
t, config, pooled_prompt_embeds, hidden_states=hidden_states
|
||||
)
|
||||
rotary_embeddings = self.adapter.compute_rotary_embeddings(
|
||||
prompt_embeds, config, **kwargs
|
||||
)
|
||||
|
||||
text_seq_len = prompt_embeds.shape[1]
|
||||
|
||||
# Run through all joint blocks
|
||||
for block_idx, wrapper in enumerate(self.joint_block_wrappers):
|
||||
encoder_hidden_states, hidden_states = wrapper(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=rotary_embeddings,
|
||||
text_seq_len=text_seq_len,
|
||||
kv_cache=None,
|
||||
mode=BlockWrapperMode.CACHING,
|
||||
block_idx=block_idx,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Merge streams
|
||||
if self.joint_block_wrappers:
|
||||
hidden_states = self.adapter.merge_streams(
|
||||
hidden_states, encoder_hidden_states
|
||||
)
|
||||
|
||||
# Run through single blocks
|
||||
for wrapper in self.single_block_wrappers:
|
||||
hidden_states = wrapper(
|
||||
hidden_states=hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=rotary_embeddings,
|
||||
text_seq_len=text_seq_len,
|
||||
kv_cache=None,
|
||||
mode=BlockWrapperMode.CACHING,
|
||||
)
|
||||
|
||||
# Extract image portion and project
|
||||
hidden_states = hidden_states[:, text_seq_len:, ...]
|
||||
|
||||
# For edit mode: extract only the generated portion (exclude conditioning latents)
|
||||
if conditioning_latents is not None:
|
||||
hidden_states = hidden_states[:, :original_latent_tokens, ...]
|
||||
|
||||
return self.adapter.final_projection(hidden_states, text_embeddings)
|
||||
|
||||
def _diffusion_step(
|
||||
self,
|
||||
t: int,
|
||||
config: RuntimeConfig,
|
||||
latents: mx.array,
|
||||
prompt_data: PromptData,
|
||||
) -> mx.array:
|
||||
"""Execute a single diffusion step.
|
||||
|
||||
Routes to single-node, sync pipeline, or async pipeline based on
|
||||
configuration and current timestep.
|
||||
"""
|
||||
if self.group is None:
|
||||
return self._single_node_step(t, config, latents, prompt_data)
|
||||
elif t < config.init_time_step + self.num_sync_steps:
|
||||
return self._sync_pipeline(
|
||||
t,
|
||||
config,
|
||||
latents,
|
||||
prompt_data,
|
||||
)
|
||||
else:
|
||||
return self._async_pipeline_step(
|
||||
t,
|
||||
config,
|
||||
latents,
|
||||
prompt_data,
|
||||
)
|
||||
|
||||
def _single_node_step(
|
||||
self,
|
||||
t: int,
|
||||
config: RuntimeConfig,
|
||||
latents: mx.array,
|
||||
prompt_data: PromptData,
|
||||
) -> mx.array:
|
||||
"""Execute a single diffusion step on a single node (no distribution)."""
|
||||
base_kwargs = {"t": t, "config": config}
|
||||
|
||||
# For edit mode: include conditioning latents
|
||||
if prompt_data.conditioning_latents is not None:
|
||||
base_kwargs["conditioning_latents"] = prompt_data.conditioning_latents
|
||||
|
||||
if self.adapter.needs_cfg:
|
||||
# Two forward passes + guidance for CFG models (e.g., Qwen)
|
||||
pos_kwargs = {
|
||||
**base_kwargs,
|
||||
**prompt_data.get_extra_forward_kwargs(positive=True),
|
||||
}
|
||||
noise_pos = self._forward_pass(
|
||||
latents,
|
||||
prompt_data.prompt_embeds,
|
||||
prompt_data.pooled_prompt_embeds,
|
||||
pos_kwargs,
|
||||
)
|
||||
|
||||
neg_kwargs = {
|
||||
**base_kwargs,
|
||||
**prompt_data.get_extra_forward_kwargs(positive=False),
|
||||
}
|
||||
noise_neg = self._forward_pass(
|
||||
latents,
|
||||
prompt_data.negative_prompt_embeds,
|
||||
prompt_data.negative_pooled_prompt_embeds,
|
||||
neg_kwargs,
|
||||
)
|
||||
|
||||
assert self.config.guidance_scale is not None
|
||||
noise = self.adapter.apply_guidance(
|
||||
noise_pos, noise_neg, guidance_scale=self.config.guidance_scale
|
||||
)
|
||||
else:
|
||||
# Single forward pass for non-CFG models (e.g., Flux)
|
||||
kwargs = {**base_kwargs, **prompt_data.get_extra_forward_kwargs()}
|
||||
noise = self._forward_pass(
|
||||
latents,
|
||||
prompt_data.prompt_embeds,
|
||||
prompt_data.pooled_prompt_embeds,
|
||||
kwargs,
|
||||
)
|
||||
|
||||
return config.scheduler.step(model_output=noise, timestep=t, sample=latents)
|
||||
|
||||
def _initialize_kv_caches(
|
||||
self,
|
||||
batch_size: int,
|
||||
num_img_tokens: int,
|
||||
dtype: mx.Dtype,
|
||||
) -> None:
|
||||
"""Initialize KV caches for both sync and async pipelines.
|
||||
|
||||
Note: Caches only store IMAGE K/V, not text K/V. Text K/V is always
|
||||
computed fresh and doesn't need caching (it's the same for all patches).
|
||||
"""
|
||||
self.joint_kv_caches = [
|
||||
ImagePatchKVCache(
|
||||
batch_size=batch_size,
|
||||
num_heads=self.config.num_heads,
|
||||
image_seq_len=num_img_tokens,
|
||||
head_dim=self.config.head_dim,
|
||||
dtype=dtype,
|
||||
)
|
||||
for _ in range(len(self.joint_block_wrappers))
|
||||
]
|
||||
self.single_kv_caches = [
|
||||
ImagePatchKVCache(
|
||||
batch_size=batch_size,
|
||||
num_heads=self.config.num_heads,
|
||||
image_seq_len=num_img_tokens,
|
||||
head_dim=self.config.head_dim,
|
||||
dtype=dtype,
|
||||
)
|
||||
for _ in range(len(self.single_block_wrappers))
|
||||
]
|
||||
|
||||
def _create_patches(
|
||||
self,
|
||||
latents: mx.array,
|
||||
config: RuntimeConfig,
|
||||
) -> tuple[list[mx.array], list[tuple[int, int]]]:
|
||||
"""Split latents into patches for async pipeline."""
|
||||
# Use 16 to match FluxLatentCreator.create_noise formula
|
||||
latent_height = config.height // 16
|
||||
latent_width = config.width // 16
|
||||
|
||||
patch_heights, _ = calculate_patch_heights(latent_height, self.num_patches)
|
||||
token_indices = calculate_token_indices(patch_heights, latent_width)
|
||||
|
||||
# Split latents into patches
|
||||
patch_latents = [latents[:, start:end, :] for start, end in token_indices]
|
||||
|
||||
return patch_latents, token_indices
|
||||
|
||||
def _sync_pipeline(
|
||||
self,
|
||||
t: int,
|
||||
config: RuntimeConfig,
|
||||
hidden_states: mx.array,
|
||||
prompt_data: PromptData,
|
||||
kontext_image_ids: mx.array | None = None,
|
||||
) -> mx.array:
|
||||
prev_latents = hidden_states
|
||||
|
||||
# Extract embeddings and extra kwargs (e.g., encoder_hidden_states_mask for Qwen)
|
||||
prompt_embeds = prompt_data.prompt_embeds
|
||||
pooled_prompt_embeds = prompt_data.pooled_prompt_embeds
|
||||
extra_kwargs = prompt_data.get_extra_forward_kwargs()
|
||||
|
||||
hidden_states = config.scheduler.scale_model_input(hidden_states, t)
|
||||
|
||||
# For edit mode: handle conditioning latents
|
||||
# All stages need to know the total token count for correct recv templates
|
||||
conditioning_latents = prompt_data.conditioning_latents
|
||||
original_latent_tokens = hidden_states.shape[1]
|
||||
if conditioning_latents is not None:
|
||||
num_img_tokens = original_latent_tokens + conditioning_latents.shape[1]
|
||||
else:
|
||||
num_img_tokens = original_latent_tokens
|
||||
|
||||
# First stage: concatenate conditioning latents before embedding
|
||||
if self.is_first_stage and conditioning_latents is not None:
|
||||
hidden_states = mx.concatenate(
|
||||
[hidden_states, conditioning_latents], axis=1
|
||||
)
|
||||
|
||||
# === PHASE 1: Embeddings ===
|
||||
if self.is_first_stage:
|
||||
hidden_states, encoder_hidden_states = self.adapter.compute_embeddings(
|
||||
hidden_states, prompt_embeds
|
||||
)
|
||||
|
||||
# All stages need these for their blocks
|
||||
text_embeddings = self.adapter.compute_text_embeddings(
|
||||
t, config, pooled_prompt_embeds
|
||||
)
|
||||
image_rotary_embeddings = self.adapter.compute_rotary_embeddings(
|
||||
prompt_embeds,
|
||||
config,
|
||||
kontext_image_ids=kontext_image_ids,
|
||||
**extra_kwargs,
|
||||
)
|
||||
|
||||
# === Initialize KV caches to populate during sync for async warmstart ===
|
||||
batch_size = prev_latents.shape[0]
|
||||
text_seq_len = prompt_embeds.shape[1]
|
||||
hidden_dim = self.adapter.hidden_dim
|
||||
|
||||
if t == config.init_time_step:
|
||||
self._initialize_kv_caches(
|
||||
batch_size=batch_size,
|
||||
num_img_tokens=num_img_tokens,
|
||||
dtype=prev_latents.dtype,
|
||||
)
|
||||
|
||||
# === PHASE 2: Joint Blocks with Communication and Caching ===
|
||||
if self.has_joint_blocks:
|
||||
# Receive from previous stage (if not first stage)
|
||||
if not self.is_first_stage:
|
||||
recv_template = mx.zeros(
|
||||
(batch_size, num_img_tokens, hidden_dim), dtype=prev_latents.dtype
|
||||
)
|
||||
hidden_states = mx.distributed.recv_like(
|
||||
recv_template, self.prev_rank, group=self.group
|
||||
)
|
||||
enc_template = mx.zeros(
|
||||
(batch_size, text_seq_len, hidden_dim), dtype=prev_latents.dtype
|
||||
)
|
||||
encoder_hidden_states = mx.distributed.recv_like(
|
||||
enc_template, self.prev_rank, group=self.group
|
||||
)
|
||||
mx.eval(hidden_states, encoder_hidden_states)
|
||||
|
||||
# Run assigned joint blocks with caching mode
|
||||
for block_idx, wrapper in enumerate(self.joint_block_wrappers):
|
||||
encoder_hidden_states, hidden_states = wrapper(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=image_rotary_embeddings,
|
||||
text_seq_len=text_seq_len,
|
||||
kv_cache=self.joint_kv_caches[block_idx],
|
||||
mode=BlockWrapperMode.CACHING,
|
||||
**extra_kwargs,
|
||||
)
|
||||
|
||||
# === PHASE 3: Joint→Single Transition ===
|
||||
if self.owns_concat_stage:
|
||||
# Merge encoder and hidden states using adapter hook
|
||||
concatenated = self.adapter.merge_streams(
|
||||
hidden_states, encoder_hidden_states
|
||||
)
|
||||
|
||||
if self.has_single_blocks or self.is_last_stage:
|
||||
# Keep locally: either for single blocks or final projection
|
||||
hidden_states = concatenated
|
||||
else:
|
||||
# Send concatenated state to next stage (which has single blocks)
|
||||
mx.eval(
|
||||
mx.distributed.send(concatenated, self.next_rank, group=self.group)
|
||||
)
|
||||
|
||||
elif self.has_joint_blocks and not self.is_last_stage:
|
||||
# Send joint block outputs to next stage (which has more joint blocks)
|
||||
mx.eval(
|
||||
mx.distributed.send(hidden_states, self.next_rank, group=self.group),
|
||||
mx.distributed.send(
|
||||
encoder_hidden_states, self.next_rank, group=self.group
|
||||
),
|
||||
)
|
||||
|
||||
# === PHASE 4: Single Blocks with Communication and Caching ===
|
||||
if self.has_single_blocks:
|
||||
# Receive from previous stage if we didn't do concatenation
|
||||
if not self.owns_concat_stage and not self.is_first_stage:
|
||||
recv_template = mx.zeros(
|
||||
(batch_size, text_seq_len + num_img_tokens, hidden_dim),
|
||||
dtype=prev_latents.dtype,
|
||||
)
|
||||
hidden_states = mx.distributed.recv_like(
|
||||
recv_template, self.prev_rank, group=self.group
|
||||
)
|
||||
mx.eval(hidden_states)
|
||||
|
||||
# Run assigned single blocks with caching mode
|
||||
for block_idx, wrapper in enumerate(self.single_block_wrappers):
|
||||
hidden_states = wrapper(
|
||||
hidden_states=hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=image_rotary_embeddings,
|
||||
text_seq_len=text_seq_len,
|
||||
kv_cache=self.single_kv_caches[block_idx],
|
||||
mode=BlockWrapperMode.CACHING,
|
||||
)
|
||||
|
||||
# Send to next stage if not last
|
||||
if not self.is_last_stage:
|
||||
mx.eval(
|
||||
mx.distributed.send(hidden_states, self.next_rank, group=self.group)
|
||||
)
|
||||
|
||||
# === PHASE 5: Last Stage - Final Projection + Scheduler ===
|
||||
# Extract image portion (remove text embeddings prefix)
|
||||
hidden_states = hidden_states[:, text_seq_len:, ...]
|
||||
|
||||
# For edit mode: extract only the generated portion (exclude conditioning latents)
|
||||
if conditioning_latents is not None:
|
||||
hidden_states = hidden_states[:, :original_latent_tokens, ...]
|
||||
|
||||
if self.is_last_stage:
|
||||
hidden_states = self.adapter.final_projection(
|
||||
hidden_states, text_embeddings
|
||||
)
|
||||
|
||||
hidden_states = config.scheduler.step(
|
||||
model_output=hidden_states,
|
||||
timestep=t,
|
||||
sample=prev_latents,
|
||||
)
|
||||
|
||||
if not self.is_first_stage:
|
||||
mx.eval(mx.distributed.send(hidden_states, 0, group=self.group))
|
||||
|
||||
elif self.is_first_stage:
|
||||
hidden_states = mx.distributed.recv_like(
|
||||
prev_latents, src=self.world_size - 1, group=self.group
|
||||
)
|
||||
|
||||
mx.eval(hidden_states)
|
||||
|
||||
else:
|
||||
# For shape correctness
|
||||
hidden_states = prev_latents
|
||||
|
||||
return hidden_states
|
||||
|
||||
def _async_pipeline_step(
|
||||
self,
|
||||
t: int,
|
||||
config: RuntimeConfig,
|
||||
latents: mx.array,
|
||||
prompt_data: PromptData,
|
||||
kontext_image_ids: mx.array | None = None,
|
||||
) -> mx.array:
|
||||
patch_latents, token_indices = self._create_patches(latents, config)
|
||||
|
||||
patch_latents = self._async_pipeline(
|
||||
t,
|
||||
config,
|
||||
patch_latents,
|
||||
token_indices,
|
||||
prompt_data,
|
||||
kontext_image_ids,
|
||||
)
|
||||
|
||||
return mx.concatenate(patch_latents, axis=1)
|
||||
|
||||
def _async_pipeline(
|
||||
self,
|
||||
t: int,
|
||||
config: RuntimeConfig,
|
||||
patch_latents: list[mx.array],
|
||||
token_indices: list[tuple[int, int]],
|
||||
prompt_data: PromptData,
|
||||
kontext_image_ids: mx.array | None = None,
|
||||
) -> list[mx.array]:
|
||||
"""Execute async pipeline for all patches."""
|
||||
assert self.joint_kv_caches is not None
|
||||
assert self.single_kv_caches is not None
|
||||
|
||||
# Extract embeddings and extra kwargs (e.g., encoder_hidden_states_mask for Qwen)
|
||||
prompt_embeds = prompt_data.prompt_embeds
|
||||
pooled_prompt_embeds = prompt_data.pooled_prompt_embeds
|
||||
extra_kwargs = prompt_data.get_extra_forward_kwargs()
|
||||
|
||||
text_embeddings = self.adapter.compute_text_embeddings(
|
||||
t, config, pooled_prompt_embeds
|
||||
)
|
||||
image_rotary_embeddings = self.adapter.compute_rotary_embeddings(
|
||||
prompt_embeds,
|
||||
config,
|
||||
kontext_image_ids=kontext_image_ids,
|
||||
**extra_kwargs,
|
||||
)
|
||||
|
||||
batch_size = patch_latents[0].shape[0]
|
||||
text_seq_len = prompt_embeds.shape[1]
|
||||
hidden_dim = self.adapter.hidden_dim
|
||||
|
||||
for patch_idx, patch in enumerate(patch_latents):
|
||||
patch_prev = patch
|
||||
|
||||
start_token, end_token = token_indices[patch_idx]
|
||||
|
||||
if self.has_joint_blocks:
|
||||
if (
|
||||
not self.is_first_stage
|
||||
or t != config.init_time_step + self.num_sync_steps
|
||||
):
|
||||
if self.is_first_stage:
|
||||
# First stage receives latent-space from last stage (scheduler output)
|
||||
recv_template = patch
|
||||
else:
|
||||
# Other stages receive hidden-space from previous stage
|
||||
patch_len = patch.shape[1]
|
||||
recv_template = mx.zeros(
|
||||
(batch_size, patch_len, hidden_dim),
|
||||
dtype=patch.dtype,
|
||||
)
|
||||
patch = mx.distributed.recv_like(
|
||||
recv_template, src=self.prev_rank, group=self.group
|
||||
)
|
||||
mx.eval(patch)
|
||||
patch_latents[patch_idx] = patch
|
||||
|
||||
if not self.is_first_stage and patch_idx == 0:
|
||||
enc_template = mx.zeros(
|
||||
(batch_size, text_seq_len, hidden_dim),
|
||||
dtype=patch_latents[0].dtype,
|
||||
)
|
||||
encoder_hidden_states = mx.distributed.recv_like(
|
||||
enc_template, src=self.prev_rank, group=self.group
|
||||
)
|
||||
mx.eval(encoder_hidden_states)
|
||||
|
||||
if self.is_first_stage:
|
||||
patch, encoder_hidden_states = self.adapter.compute_embeddings(
|
||||
patch, prompt_embeds
|
||||
)
|
||||
|
||||
# Run assigned joint blocks with patched mode
|
||||
for block_idx, wrapper in enumerate(self.joint_block_wrappers):
|
||||
encoder_hidden_states, patch = wrapper(
|
||||
hidden_states=patch,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=image_rotary_embeddings,
|
||||
text_seq_len=text_seq_len,
|
||||
kv_cache=self.joint_kv_caches[block_idx],
|
||||
mode=BlockWrapperMode.PATCHED,
|
||||
patch_start=start_token,
|
||||
patch_end=end_token,
|
||||
**extra_kwargs,
|
||||
)
|
||||
|
||||
if self.owns_concat_stage:
|
||||
patch_concat = self.adapter.merge_streams(patch, encoder_hidden_states)
|
||||
|
||||
if self.has_single_blocks or self.is_last_stage:
|
||||
# Keep locally: either for single blocks or final projection
|
||||
patch = patch_concat
|
||||
else:
|
||||
mx.eval(
|
||||
mx.distributed.send(
|
||||
patch_concat, self.next_rank, group=self.group
|
||||
)
|
||||
)
|
||||
|
||||
elif self.has_joint_blocks and not self.is_last_stage:
|
||||
mx.eval(mx.distributed.send(patch, self.next_rank, group=self.group))
|
||||
|
||||
if patch_idx == 0:
|
||||
mx.eval(
|
||||
mx.distributed.send(
|
||||
encoder_hidden_states, self.next_rank, group=self.group
|
||||
)
|
||||
)
|
||||
|
||||
if self.has_single_blocks:
|
||||
if not self.owns_concat_stage and not self.is_first_stage:
|
||||
recv_template = mx.zeros(
|
||||
[
|
||||
batch_size,
|
||||
text_seq_len + patch_latents[patch_idx].shape[1],
|
||||
hidden_dim,
|
||||
],
|
||||
dtype=patch_latents[0].dtype,
|
||||
)
|
||||
|
||||
patch = mx.distributed.recv_like(
|
||||
recv_template, src=self.prev_rank, group=self.group
|
||||
)
|
||||
mx.eval(patch)
|
||||
patch_latents[patch_idx] = patch
|
||||
|
||||
# Run assigned single blocks with patched mode
|
||||
for block_idx, wrapper in enumerate(self.single_block_wrappers):
|
||||
patch = wrapper(
|
||||
hidden_states=patch,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=image_rotary_embeddings,
|
||||
text_seq_len=text_seq_len,
|
||||
kv_cache=self.single_kv_caches[block_idx],
|
||||
mode=BlockWrapperMode.PATCHED,
|
||||
patch_start=start_token,
|
||||
patch_end=end_token,
|
||||
)
|
||||
|
||||
if not self.is_last_stage:
|
||||
mx.eval(
|
||||
mx.distributed.send(patch, self.next_rank, group=self.group)
|
||||
)
|
||||
|
||||
if self.is_last_stage:
|
||||
patch_img_only = patch[:, text_seq_len:, :]
|
||||
|
||||
patch_img_only = self.adapter.final_projection(
|
||||
patch_img_only, text_embeddings
|
||||
)
|
||||
|
||||
patch = config.scheduler.step(
|
||||
model_output=patch_img_only,
|
||||
timestep=t,
|
||||
sample=patch_prev,
|
||||
)
|
||||
|
||||
if not self.is_first_stage and t != config.num_inference_steps - 1:
|
||||
mx.eval(
|
||||
mx.distributed.send(patch, self.next_rank, group=self.group)
|
||||
)
|
||||
|
||||
patch_latents[patch_idx] = patch
|
||||
|
||||
return patch_latents
|
||||
@@ -103,6 +103,7 @@ class PipelineLastLayer(CustomMlxLayer):
|
||||
# This change happened upstream - check out mlx github somewhere??
|
||||
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
|
||||
|
||||
# TODO(ciaran): This is overkill
|
||||
output = mx.distributed.all_gather(output, group=self.group)[-output.shape[0] :]
|
||||
return output
|
||||
|
||||
|
||||
@@ -9,8 +9,7 @@ MAX_KV_SIZE: int | None = 3200
|
||||
KEEP_KV_SIZE: int | None = 1600
|
||||
QUANTIZE_MODEL_MODE: str | None = "affine"
|
||||
CACHE_GROUP_SIZE: int = 64
|
||||
KV_CACHE_BITS: int | None = 8
|
||||
TEMPERATURE: float = 1.0
|
||||
KV_CACHE_BITS: int | None = None
|
||||
|
||||
# TODO: We should really make this opt-in, but Kimi requires trust_remote_code=True
|
||||
TRUST_REMOTE_CODE: bool = True
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Any, Callable, Generator, cast, get_args
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_lm import stream_generate
|
||||
from mlx_lm.generate import stream_generate
|
||||
from mlx_lm.models.cache import KVCache
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@ from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
from exo.worker.engines.mlx.constants import (
|
||||
CACHE_GROUP_SIZE,
|
||||
KV_CACHE_BITS,
|
||||
TEMPERATURE,
|
||||
TRUST_REMOTE_CODE,
|
||||
)
|
||||
|
||||
@@ -21,6 +20,8 @@ try:
|
||||
from mlx_lm.tokenizer_utils import load_tokenizer
|
||||
except ImportError:
|
||||
from mlx_lm.tokenizer_utils import load as load_tokenizer # type: ignore
|
||||
import contextlib
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx_lm.utils import load_model
|
||||
@@ -48,6 +49,7 @@ from exo.worker.engines.mlx.auto_parallel import (
|
||||
)
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
Group = mx.distributed.Group
|
||||
# Needed for 8 bit model
|
||||
resource.setrlimit(resource.RLIMIT_NOFILE, (2048, 4096))
|
||||
|
||||
@@ -67,7 +69,7 @@ def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
|
||||
)
|
||||
|
||||
|
||||
def mx_barrier(group: mx.distributed.Group | None = None):
|
||||
def mx_barrier(group: Group | None = None):
|
||||
mx.eval(
|
||||
mx.distributed.all_sum(
|
||||
mx.array(1.0),
|
||||
@@ -77,7 +79,7 @@ def mx_barrier(group: mx.distributed.Group | None = None):
|
||||
)
|
||||
|
||||
|
||||
def broadcast_from_zero(value: int, group: mx.distributed.Group | None = None):
|
||||
def broadcast_from_zero(value: int, group: Group | None = None):
|
||||
if group is None:
|
||||
return value
|
||||
|
||||
@@ -99,91 +101,97 @@ class HostList(RootModel[list[str]]):
|
||||
|
||||
def mlx_distributed_init(
|
||||
bound_instance: BoundInstance,
|
||||
) -> mx.distributed.Group:
|
||||
) -> Group:
|
||||
"""
|
||||
Initialize the MLX distributed (runs in thread pool).
|
||||
|
||||
Either hosts or mlx_ibv_devices must be provided:
|
||||
- hosts: traditional host-based connectivity using MLX_HOSTFILE
|
||||
- mlx_ibv_devices: RDMA connectivity matrix using MLX_IBV_DEVICES
|
||||
- mlx_ibv_coordinator: coordinator address (IP:PORT) for RDMA setup
|
||||
- strict: if True, raise an error if the distributed backend is not available
|
||||
Initialize MLX distributed.
|
||||
"""
|
||||
rank = bound_instance.bound_shard.device_rank
|
||||
logger.info(f"Starting initialization for rank {rank}")
|
||||
|
||||
# TODO: singleton instances
|
||||
match bound_instance.instance:
|
||||
case MlxRingInstance(hosts=hosts):
|
||||
hostfile = f"./hosts_{rank}.json"
|
||||
hosts_json = HostList.from_hosts(hosts).model_dump_json()
|
||||
coordination_file = None
|
||||
try:
|
||||
# TODO: singleton instances
|
||||
match bound_instance.instance:
|
||||
case MlxRingInstance(hosts_by_node=hosts_by_node, ephemeral_port=_):
|
||||
coordination_file = (
|
||||
f"./hosts_{bound_instance.instance.instance_id}_{rank}.json"
|
||||
)
|
||||
hosts_for_node = hosts_by_node[bound_instance.bound_node_id]
|
||||
hosts_json = HostList.from_hosts(hosts_for_node).model_dump_json()
|
||||
|
||||
with open(hostfile, "w") as f:
|
||||
_ = f.write(hosts_json)
|
||||
with open(coordination_file, "w") as f:
|
||||
_ = f.write(hosts_json)
|
||||
|
||||
logger.info(f"rank {rank} hostfile: {hostfile} hosts: {hosts_json}")
|
||||
logger.info(
|
||||
f"rank {rank} hostfile: {coordination_file} hosts: {hosts_json}"
|
||||
)
|
||||
|
||||
os.environ["MLX_HOSTFILE"] = hostfile
|
||||
os.environ["MLX_RANK"] = str(rank)
|
||||
os.environ["MLX_RING_VERBOSE"] = "1"
|
||||
group = mx.distributed.init(backend="ring", strict=True)
|
||||
os.environ["MLX_HOSTFILE"] = coordination_file
|
||||
os.environ["MLX_RANK"] = str(rank)
|
||||
os.environ["MLX_RING_VERBOSE"] = "1"
|
||||
group = mx.distributed.init(backend="ring", strict=True)
|
||||
|
||||
case MlxJacclInstance(
|
||||
ibv_devices=ibv_devices, jaccl_coordinators=jaccl_coordinators
|
||||
):
|
||||
# Use RDMA connectivity matrix
|
||||
devices_file = f"./hosts_{rank}.json"
|
||||
ibv_devices_json = json.dumps(ibv_devices)
|
||||
case MlxJacclInstance(
|
||||
ibv_devices=ibv_devices, jaccl_coordinators=jaccl_coordinators
|
||||
):
|
||||
# Use RDMA connectivity matrix
|
||||
coordination_file = (
|
||||
f"./hosts_{bound_instance.instance.instance_id}_{rank}.json"
|
||||
)
|
||||
ibv_devices_json = json.dumps(ibv_devices)
|
||||
|
||||
with open(devices_file, "w") as f:
|
||||
_ = f.write(ibv_devices_json)
|
||||
with open(coordination_file, "w") as f:
|
||||
_ = f.write(ibv_devices_json)
|
||||
|
||||
jaccl_coordinator = jaccl_coordinators[bound_instance.bound_node_id]
|
||||
jaccl_coordinator = jaccl_coordinators[bound_instance.bound_node_id]
|
||||
|
||||
logger.info(f"rank {rank} MLX_IBV_DEVICES: {ibv_devices_json}")
|
||||
logger.info(f"rank {rank} MLX_JACCL_COORDINATOR: {jaccl_coordinator}")
|
||||
os.environ["MLX_IBV_DEVICES"] = devices_file
|
||||
os.environ["MLX_RANK"] = str(rank)
|
||||
os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
|
||||
group = mx.distributed.init(backend="jaccl", strict=True)
|
||||
logger.info(f"rank {rank} MLX_IBV_DEVICES: {ibv_devices_json}")
|
||||
logger.info(f"rank {rank} MLX_JACCL_COORDINATOR: {jaccl_coordinator}")
|
||||
os.environ["MLX_IBV_DEVICES"] = coordination_file
|
||||
os.environ["MLX_RANK"] = str(rank)
|
||||
os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
|
||||
group = mx.distributed.init(backend="jaccl", strict=True)
|
||||
|
||||
logger.info(f"Rank {rank} mlx distributed initialization complete")
|
||||
logger.info(f"Rank {rank} mlx distributed initialization complete")
|
||||
|
||||
return group
|
||||
return group
|
||||
finally:
|
||||
with contextlib.suppress(FileNotFoundError):
|
||||
if coordination_file:
|
||||
os.remove(coordination_file)
|
||||
|
||||
|
||||
def initialize_mlx(
|
||||
bound_instance: BoundInstance,
|
||||
) -> tuple[Model, TokenizerWrapper, Callable[[mx.array], mx.array]]:
|
||||
"""
|
||||
Initialize the MLX model, tokenizer, and sampler. Runs in the MLX thread.
|
||||
"""
|
||||
) -> Group:
|
||||
# should we unseed it?
|
||||
# TODO: pass in seed from params
|
||||
mx.random.seed(42)
|
||||
|
||||
set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard))
|
||||
assert len(bound_instance.instance.shard_assignments.node_to_runner) > 1, (
|
||||
"Tried to initialize mlx for a single node instance"
|
||||
)
|
||||
return mlx_distributed_init(bound_instance)
|
||||
|
||||
sampler: Callable[[mx.array], mx.array] = make_sampler(temp=TEMPERATURE)
|
||||
|
||||
def load_mlx_items(
|
||||
bound_instance: BoundInstance, group: Group | None
|
||||
) -> tuple[Model, TokenizerWrapper, Callable[[mx.array], mx.array]]:
|
||||
# TODO: pass temperature
|
||||
sampler: Callable[[mx.array], mx.array] = make_sampler(temp=0.7)
|
||||
logger.info("Created a sampler")
|
||||
|
||||
if len(bound_instance.instance.shard_assignments.node_to_runner) <= 1:
|
||||
if group is None:
|
||||
logger.info(f"Single device used for {bound_instance.instance}")
|
||||
model_path = build_model_path(bound_instance.bound_shard.model_meta.model_id)
|
||||
start_time = time.perf_counter()
|
||||
model, _ = load_model(model_path, strict=True)
|
||||
end_time = time.perf_counter()
|
||||
logger.info(f"Time taken to load model: {(end_time - start_time):.2f}s")
|
||||
if hasattr(model, "model") and isinstance(model.model, DeepseekV3Model): # type: ignore
|
||||
pass
|
||||
# model, config = quantize_model(
|
||||
# model, config, group_size=KV_GROUP_SIZE, bits=ATTENTION_KV_BITS, quant_predicate=quant_predicate, mode=QUANTIZE_MODEL_MODE
|
||||
# )
|
||||
|
||||
tokenizer = get_tokenizer(model_path, bound_instance.bound_shard)
|
||||
|
||||
else:
|
||||
logger.info("Starting distributed init")
|
||||
group = mlx_distributed_init(bound_instance)
|
||||
|
||||
start_time = time.perf_counter()
|
||||
model, tokenizer = shard_and_load(bound_instance.bound_shard, group=group)
|
||||
end_time = time.perf_counter()
|
||||
@@ -193,14 +201,12 @@ def initialize_mlx(
|
||||
|
||||
set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard))
|
||||
|
||||
logger.debug(model)
|
||||
|
||||
return cast(Model, model), tokenizer, sampler
|
||||
|
||||
|
||||
def shard_and_load(
|
||||
shard_metadata: ShardMetadata,
|
||||
group: mx.distributed.Group,
|
||||
group: Group,
|
||||
) -> tuple[nn.Module, TokenizerWrapper]:
|
||||
model_path = build_model_path(shard_metadata.model_meta.model_id)
|
||||
|
||||
@@ -389,11 +395,5 @@ def set_wired_limit_for_model(model_size: Memory):
|
||||
"MB. This can be slow. See the documentation for possible work-arounds: "
|
||||
"https://github.com/ml-explore/mlx-lm/tree/main#large-models"
|
||||
)
|
||||
kv_bytes = int(0.02 * model_bytes)
|
||||
target_cache = int(1.10 * (model_bytes + kv_bytes))
|
||||
target_cache = min(target_cache, max_rec_size)
|
||||
mx.set_cache_limit(target_cache)
|
||||
mx.set_wired_limit(max_rec_size)
|
||||
logger.info(
|
||||
f"Wired limit set to {max_rec_size}. Cache limit set to {target_cache}."
|
||||
)
|
||||
logger.info(f"Wired limit set to {max_rec_size}.")
|
||||
|
||||
@@ -8,13 +8,15 @@ from loguru import logger
|
||||
|
||||
from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
|
||||
from exo.shared.apply import apply
|
||||
from exo.shared.types.api import ImageEditsInternalParams
|
||||
from exo.shared.types.commands import ForwarderCommand, RequestEventLog
|
||||
from exo.shared.types.common import NodeId, SessionId
|
||||
from exo.shared.types.common import CommandId, NodeId, SessionId
|
||||
from exo.shared.types.events import (
|
||||
Event,
|
||||
EventId,
|
||||
ForwarderEvent,
|
||||
IndexedEvent,
|
||||
InputChunkReceived,
|
||||
NodeDownloadProgress,
|
||||
NodeMemoryMeasured,
|
||||
NodePerformanceMeasured,
|
||||
@@ -23,12 +25,14 @@ from exo.shared.types.events import (
|
||||
TopologyEdgeCreated,
|
||||
TopologyEdgeDeleted,
|
||||
)
|
||||
from exo.shared.types.models import ModelId
|
||||
from exo.shared.types.multiaddr import Multiaddr
|
||||
from exo.shared.types.profiling import MemoryPerformanceProfile, NodePerformanceProfile
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.tasks import (
|
||||
CreateRunner,
|
||||
DownloadModel,
|
||||
ImageEdits,
|
||||
Shutdown,
|
||||
Task,
|
||||
TaskStatus,
|
||||
@@ -83,7 +87,7 @@ class Worker:
|
||||
self.out_for_delivery: dict[EventId, ForwarderEvent] = {}
|
||||
|
||||
self.state: State = State()
|
||||
self.download_status: dict[ShardMetadata, DownloadProgress] = {}
|
||||
self.download_status: dict[ModelId, DownloadProgress] = {}
|
||||
self.runners: dict[RunnerId, RunnerSupervisor] = {}
|
||||
self._tg: TaskGroup | None = None
|
||||
|
||||
@@ -94,6 +98,10 @@ class Worker:
|
||||
|
||||
self.event_sender, self.event_receiver = channel[Event]()
|
||||
|
||||
# Buffer for input image chunks (for image editing)
|
||||
self.input_chunk_buffer: dict[CommandId, dict[int, str]] = {}
|
||||
self.input_chunk_counts: dict[CommandId, int] = {}
|
||||
|
||||
async def run(self):
|
||||
logger.info("Starting Worker")
|
||||
|
||||
@@ -128,6 +136,7 @@ class Worker:
|
||||
tg.start_soon(start_polling_node_metrics, resource_monitor_callback)
|
||||
|
||||
tg.start_soon(start_polling_memory_metrics, memory_monitor_callback)
|
||||
tg.start_soon(self._emit_existing_download_progress)
|
||||
tg.start_soon(self._connection_message_event_writer)
|
||||
tg.start_soon(self._resend_out_for_delivery)
|
||||
tg.start_soon(self._event_applier)
|
||||
@@ -171,6 +180,17 @@ class Worker:
|
||||
for idx, event in indexed_events:
|
||||
self.state = apply(self.state, IndexedEvent(idx=idx, event=event))
|
||||
|
||||
# Buffer input image chunks for image editing
|
||||
if isinstance(event, InputChunkReceived):
|
||||
cmd_id = event.command_id
|
||||
if cmd_id not in self.input_chunk_buffer:
|
||||
self.input_chunk_buffer[cmd_id] = {}
|
||||
self.input_chunk_counts[cmd_id] = event.chunk.total_chunks
|
||||
|
||||
self.input_chunk_buffer[cmd_id][event.chunk.chunk_index] = (
|
||||
event.chunk.data
|
||||
)
|
||||
|
||||
async def plan_step(self):
|
||||
while True:
|
||||
await anyio.sleep(0.1)
|
||||
@@ -183,6 +203,8 @@ class Worker:
|
||||
self.state.instances,
|
||||
self.state.runners,
|
||||
self.state.tasks,
|
||||
self.input_chunk_buffer,
|
||||
self.input_chunk_counts,
|
||||
)
|
||||
if task is None:
|
||||
continue
|
||||
@@ -200,11 +222,11 @@ class Worker:
|
||||
)
|
||||
)
|
||||
case DownloadModel(shard_metadata=shard):
|
||||
if shard not in self.download_status:
|
||||
if shard.model_meta.model_id not in self.download_status:
|
||||
progress = DownloadPending(
|
||||
shard_metadata=shard, node_id=self.node_id
|
||||
)
|
||||
self.download_status[shard] = progress
|
||||
self.download_status[shard.model_meta.model_id] = progress
|
||||
await self.event_sender.send(
|
||||
NodeDownloadProgress(download_progress=progress)
|
||||
)
|
||||
@@ -217,7 +239,7 @@ class Worker:
|
||||
progress = DownloadCompleted(
|
||||
shard_metadata=shard, node_id=self.node_id
|
||||
)
|
||||
self.download_status[shard] = progress
|
||||
self.download_status[shard.model_meta.model_id] = progress
|
||||
await self.event_sender.send(
|
||||
NodeDownloadProgress(download_progress=progress)
|
||||
)
|
||||
@@ -228,7 +250,7 @@ class Worker:
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.event_sender.send_nowait(
|
||||
await self.event_sender.send(
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id, task_status=TaskStatus.Running
|
||||
)
|
||||
@@ -244,6 +266,42 @@ class Worker:
|
||||
task_id=task.task_id, task_status=TaskStatus.TimedOut
|
||||
)
|
||||
)
|
||||
case ImageEdits() if task.task_params.total_input_chunks > 0:
|
||||
# Assemble image from chunks and inject into task
|
||||
cmd_id = task.command_id
|
||||
chunks = self.input_chunk_buffer.get(cmd_id, {})
|
||||
assembled = "".join(chunks[i] for i in range(len(chunks)))
|
||||
logger.info(
|
||||
f"Assembled input image from {len(chunks)} chunks, "
|
||||
f"total size: {len(assembled)} bytes"
|
||||
)
|
||||
# Create modified task with assembled image data
|
||||
modified_task = ImageEdits(
|
||||
task_id=task.task_id,
|
||||
command_id=task.command_id,
|
||||
instance_id=task.instance_id,
|
||||
task_status=task.task_status,
|
||||
task_params=ImageEditsInternalParams(
|
||||
image_data=assembled,
|
||||
total_input_chunks=task.task_params.total_input_chunks,
|
||||
prompt=task.task_params.prompt,
|
||||
model=task.task_params.model,
|
||||
n=task.task_params.n,
|
||||
quality=task.task_params.quality,
|
||||
output_format=task.task_params.output_format,
|
||||
response_format=task.task_params.response_format,
|
||||
size=task.task_params.size,
|
||||
image_strength=task.task_params.image_strength,
|
||||
),
|
||||
)
|
||||
# Cleanup buffers
|
||||
if cmd_id in self.input_chunk_buffer:
|
||||
del self.input_chunk_buffer[cmd_id]
|
||||
if cmd_id in self.input_chunk_counts:
|
||||
del self.input_chunk_counts[cmd_id]
|
||||
await self.runners[self._task_to_runner_id(task)].start_task(
|
||||
modified_task
|
||||
)
|
||||
case task:
|
||||
await self.runners[self._task_to_runner_id(task)].start_task(task)
|
||||
|
||||
@@ -349,7 +407,7 @@ class Worker:
|
||||
initial_progress
|
||||
),
|
||||
)
|
||||
self.download_status[task.shard_metadata] = status
|
||||
self.download_status[task.shard_metadata.model_meta.model_id] = status
|
||||
self.event_sender.send_nowait(NodeDownloadProgress(download_progress=status))
|
||||
|
||||
last_progress_time = 0.0
|
||||
@@ -363,7 +421,7 @@ class Worker:
|
||||
nonlocal last_progress_time
|
||||
if progress.status == "complete":
|
||||
status = DownloadCompleted(shard_metadata=shard, node_id=self.node_id)
|
||||
self.download_status[shard] = status
|
||||
self.download_status[shard.model_meta.model_id] = status
|
||||
# Footgun!
|
||||
self.event_sender.send_nowait(
|
||||
NodeDownloadProgress(download_progress=status)
|
||||
@@ -384,7 +442,7 @@ class Worker:
|
||||
progress
|
||||
),
|
||||
)
|
||||
self.download_status[shard] = status
|
||||
self.download_status[shard.model_meta.model_id] = status
|
||||
self.event_sender.send_nowait(
|
||||
NodeDownloadProgress(download_progress=status)
|
||||
)
|
||||
@@ -414,9 +472,14 @@ class Worker:
|
||||
while True:
|
||||
# TODO: EdgeDeleted
|
||||
edges = set(self.state.topology.list_connections())
|
||||
conns = await check_reachable(self.state.topology)
|
||||
conns = await check_reachable(self.state.topology, self.node_id)
|
||||
for nid in conns:
|
||||
for ip in conns[nid]:
|
||||
if "127.0.0.1" in ip or "localhost" in ip:
|
||||
logger.warning(
|
||||
f"Loopback connection should not happen: {ip=} for {nid=}"
|
||||
)
|
||||
|
||||
edge = Connection(
|
||||
local_node_id=self.node_id,
|
||||
send_back_node_id=nid,
|
||||
@@ -439,3 +502,40 @@ class Worker:
|
||||
await self.event_sender.send(TopologyEdgeDeleted(edge=conn))
|
||||
|
||||
await anyio.sleep(10)
|
||||
|
||||
async def _emit_existing_download_progress(self) -> None:
|
||||
try:
|
||||
while True:
|
||||
logger.info("Fetching and emitting existing download progress...")
|
||||
async for (
|
||||
_,
|
||||
progress,
|
||||
) in self.shard_downloader.get_shard_download_status():
|
||||
if progress.status == "complete":
|
||||
status = DownloadCompleted(
|
||||
node_id=self.node_id, shard_metadata=progress.shard
|
||||
)
|
||||
elif progress.status in ["in_progress", "not_started"]:
|
||||
if progress.downloaded_bytes_this_session.in_bytes == 0:
|
||||
status = DownloadPending(
|
||||
node_id=self.node_id, shard_metadata=progress.shard
|
||||
)
|
||||
else:
|
||||
status = DownloadOngoing(
|
||||
node_id=self.node_id,
|
||||
shard_metadata=progress.shard,
|
||||
download_progress=map_repo_download_progress_to_download_progress_data(
|
||||
progress
|
||||
),
|
||||
)
|
||||
else:
|
||||
continue
|
||||
|
||||
self.download_status[progress.shard.model_meta.model_id] = status
|
||||
await self.event_sender.send(
|
||||
NodeDownloadProgress(download_progress=status)
|
||||
)
|
||||
logger.info("Done emitting existing download progress.")
|
||||
await anyio.sleep(5 * 60) # 5 minutes
|
||||
except Exception as e:
|
||||
logger.error(f"Error emitting existing download progress: {e}")
|
||||
|
||||
@@ -2,11 +2,15 @@
|
||||
|
||||
from collections.abc import Mapping, Sequence
|
||||
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.models import ModelId
|
||||
from exo.shared.types.tasks import (
|
||||
ChatCompletion,
|
||||
ConnectToGroup,
|
||||
CreateRunner,
|
||||
DownloadModel,
|
||||
ImageEdits,
|
||||
ImageGeneration,
|
||||
LoadModel,
|
||||
Shutdown,
|
||||
StartWarmup,
|
||||
@@ -14,20 +18,25 @@ from exo.shared.types.tasks import (
|
||||
TaskId,
|
||||
TaskStatus,
|
||||
)
|
||||
from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress
|
||||
from exo.shared.types.worker.downloads import (
|
||||
DownloadCompleted,
|
||||
DownloadOngoing,
|
||||
DownloadProgress,
|
||||
)
|
||||
from exo.shared.types.worker.instances import BoundInstance, Instance, InstanceId
|
||||
from exo.shared.types.worker.runners import (
|
||||
RunnerConnected,
|
||||
RunnerConnecting,
|
||||
RunnerFailed,
|
||||
RunnerId,
|
||||
RunnerIdle,
|
||||
RunnerLoaded,
|
||||
RunnerLoading,
|
||||
RunnerReady,
|
||||
RunnerRunning,
|
||||
RunnerStatus,
|
||||
RunnerWaitingForModel,
|
||||
RunnerWarmingUp,
|
||||
)
|
||||
from exo.shared.types.worker.shards import ShardMetadata
|
||||
from exo.worker.runner.runner_supervisor import RunnerSupervisor
|
||||
|
||||
|
||||
@@ -36,21 +45,24 @@ def plan(
|
||||
# Runners is expected to be FRESH and so should not come from state
|
||||
runners: Mapping[RunnerId, RunnerSupervisor],
|
||||
# DL_status is expected to be FRESH and so should not come from state
|
||||
download_status: Mapping[ShardMetadata, DownloadProgress],
|
||||
download_status: Mapping[ModelId, DownloadProgress],
|
||||
# gdls is not expected to be fresh
|
||||
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
|
||||
instances: Mapping[InstanceId, Instance],
|
||||
all_runners: Mapping[RunnerId, RunnerStatus], # all global
|
||||
tasks: Mapping[TaskId, Task],
|
||||
input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,
|
||||
input_chunk_counts: Mapping[CommandId, int] | None = None,
|
||||
) -> Task | None:
|
||||
# Python short circuiting OR logic should evaluate these sequentially.
|
||||
return (
|
||||
_kill_runner(runners, all_runners, instances)
|
||||
or _create_runner(node_id, runners, instances)
|
||||
or _model_needs_download(runners, download_status)
|
||||
or _init_distributed_backend(runners, all_runners)
|
||||
or _load_model(runners, all_runners, global_download_status)
|
||||
or _ready_to_warmup(runners, all_runners)
|
||||
or _pending_tasks(runners, tasks, all_runners)
|
||||
or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer)
|
||||
)
|
||||
|
||||
|
||||
@@ -103,12 +115,15 @@ def _create_runner(
|
||||
|
||||
def _model_needs_download(
|
||||
runners: Mapping[RunnerId, RunnerSupervisor],
|
||||
download_status: Mapping[ShardMetadata, DownloadProgress],
|
||||
download_status: Mapping[ModelId, DownloadProgress],
|
||||
) -> DownloadModel | None:
|
||||
for runner in runners.values():
|
||||
if (
|
||||
isinstance(runner.status, RunnerWaitingForModel)
|
||||
and runner.bound_instance.bound_shard not in download_status
|
||||
model_id = runner.bound_instance.bound_shard.model_meta.model_id
|
||||
if isinstance(runner.status, RunnerIdle) and (
|
||||
model_id not in download_status
|
||||
or not isinstance(
|
||||
download_status[model_id], (DownloadOngoing, DownloadCompleted)
|
||||
)
|
||||
):
|
||||
# We don't invalidate download_status randomly in case a file gets deleted on disk
|
||||
return DownloadModel(
|
||||
@@ -117,14 +132,54 @@ def _model_needs_download(
|
||||
)
|
||||
|
||||
|
||||
""" --- TODO!
|
||||
def _init_backend(
|
||||
def _init_distributed_backend(
|
||||
runners: Mapping[RunnerId, RunnerSupervisor],
|
||||
all_runners: Mapping[RunnerId, RunnerStatus],
|
||||
) -> LoadModel | None:
|
||||
for runner in runner.values()
|
||||
pass
|
||||
"""
|
||||
):
|
||||
for runner in runners.values():
|
||||
instance = runner.bound_instance.instance
|
||||
shard_assignments = instance.shard_assignments
|
||||
|
||||
is_single_node_instance = len(shard_assignments.runner_to_shard) == 1
|
||||
if is_single_node_instance:
|
||||
continue
|
||||
|
||||
runner_is_idle = isinstance(runner.status, RunnerIdle)
|
||||
all_runners_connecting = all(
|
||||
isinstance(
|
||||
all_runners.get(global_runner_id),
|
||||
(RunnerConnecting, RunnerIdle),
|
||||
)
|
||||
for global_runner_id in shard_assignments.runner_to_shard
|
||||
)
|
||||
|
||||
if not (runner_is_idle and all_runners_connecting):
|
||||
continue
|
||||
|
||||
runner_id = runner.bound_instance.bound_runner_id
|
||||
|
||||
shard = runner.bound_instance.bound_shard
|
||||
device_rank = shard.device_rank
|
||||
world_size = shard.world_size
|
||||
|
||||
assert device_rank < world_size
|
||||
assert device_rank >= 0
|
||||
|
||||
accepting_ranks = device_rank < world_size - 1
|
||||
|
||||
# Rank = n-1
|
||||
connecting_rank_ready = device_rank == world_size - 1 and all(
|
||||
isinstance(all_runners.get(global_runner_id, None), RunnerConnecting)
|
||||
for global_runner_id in shard_assignments.runner_to_shard
|
||||
if global_runner_id != runner_id
|
||||
)
|
||||
|
||||
if not (accepting_ranks or connecting_rank_ready):
|
||||
continue
|
||||
|
||||
return ConnectToGroup(instance_id=instance.instance_id)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _load_model(
|
||||
@@ -136,31 +191,33 @@ def _load_model(
|
||||
instance = runner.bound_instance.instance
|
||||
shard_assignments = instance.shard_assignments
|
||||
|
||||
all_downloads_complete_local = all(
|
||||
all_local_downloads_complete = all(
|
||||
nid in global_download_status
|
||||
and any(
|
||||
isinstance(dp, DownloadCompleted)
|
||||
and dp.shard_metadata == shard_assignments.runner_to_shard[rid]
|
||||
and dp.shard_metadata.model_meta.model_id == shard_assignments.model_id
|
||||
for dp in global_download_status[nid]
|
||||
)
|
||||
for nid, rid in shard_assignments.node_to_runner.items()
|
||||
for nid in shard_assignments.node_to_runner
|
||||
)
|
||||
if not all_local_downloads_complete:
|
||||
continue
|
||||
|
||||
runner_is_waiting = isinstance(runner.status, RunnerWaitingForModel)
|
||||
is_single_node_instance = len(instance.shard_assignments.runner_to_shard) == 1
|
||||
if is_single_node_instance and isinstance(runner.status, RunnerIdle):
|
||||
return LoadModel(instance_id=instance.instance_id)
|
||||
|
||||
all_runners_expecting_model = all(
|
||||
is_runner_waiting = isinstance(runner.status, RunnerConnected)
|
||||
|
||||
all_ready_for_model = all(
|
||||
isinstance(
|
||||
all_runners.get(global_runner_id),
|
||||
(RunnerWaitingForModel, RunnerLoading, RunnerLoaded),
|
||||
all_runners.get(global_runner_id, None),
|
||||
(RunnerConnected, RunnerLoading, RunnerLoaded),
|
||||
)
|
||||
for global_runner_id in shard_assignments.runner_to_shard
|
||||
)
|
||||
|
||||
if (
|
||||
all_downloads_complete_local
|
||||
and runner_is_waiting
|
||||
and all_runners_expecting_model
|
||||
):
|
||||
if is_runner_waiting and all_ready_for_model:
|
||||
return LoadModel(instance_id=instance.instance_id)
|
||||
|
||||
return None
|
||||
@@ -183,8 +240,8 @@ def _ready_to_warmup(
|
||||
assert device_rank < world_size
|
||||
assert device_rank >= 0
|
||||
|
||||
# Rank != n-1
|
||||
accepting_ranks_ready = device_rank != world_size - 1 and all(
|
||||
# Rank != 0
|
||||
accepting_ranks_ready = device_rank > 0 and all(
|
||||
isinstance(
|
||||
all_runners.get(global_runner_id, None),
|
||||
(RunnerLoaded, RunnerWarmingUp),
|
||||
@@ -192,8 +249,8 @@ def _ready_to_warmup(
|
||||
for global_runner_id in shard_assignments.runner_to_shard
|
||||
)
|
||||
|
||||
# Rank = n-1
|
||||
connecting_rank_ready = device_rank == world_size - 1 and all(
|
||||
# Rank = 0
|
||||
connecting_rank_ready = device_rank == 0 and all(
|
||||
isinstance(all_runners.get(global_runner_id, None), RunnerWarmingUp)
|
||||
for global_runner_id in shard_assignments.runner_to_shard
|
||||
if global_runner_id != runner_id
|
||||
@@ -209,18 +266,40 @@ def _pending_tasks(
|
||||
runners: Mapping[RunnerId, RunnerSupervisor],
|
||||
tasks: Mapping[TaskId, Task],
|
||||
all_runners: Mapping[RunnerId, RunnerStatus],
|
||||
input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,
|
||||
) -> Task | None:
|
||||
for task in tasks.values():
|
||||
# for now, just forward chat completions
|
||||
if not isinstance(task, ChatCompletion):
|
||||
# TODO(ciaran): do this better!
|
||||
if (
|
||||
not isinstance(task, ChatCompletion)
|
||||
and not isinstance(task, ImageGeneration)
|
||||
and not isinstance(task, ImageEdits)
|
||||
):
|
||||
continue
|
||||
if task.task_status not in (TaskStatus.Pending, TaskStatus.Running):
|
||||
continue
|
||||
|
||||
# For ImageEdits tasks, verify all input chunks have been received
|
||||
if isinstance(task, ImageEdits) and task.task_params.total_input_chunks > 0:
|
||||
cmd_id = task.command_id
|
||||
expected = task.task_params.total_input_chunks
|
||||
received = len((input_chunk_buffer or {}).get(cmd_id, {}))
|
||||
if received < expected:
|
||||
continue # Wait for all chunks to arrive
|
||||
|
||||
for runner in runners.values():
|
||||
if task.instance_id != runner.bound_instance.instance.instance_id:
|
||||
continue
|
||||
|
||||
# I have a design point here; this is a state race in disguise as the task status doesn't get updated to completed fast enough
|
||||
# however, realistically the task status should be set to completed by the LAST runner, so this is a true race
|
||||
# the actual solution is somewhat deeper than this bypass - TODO!
|
||||
if task.task_id in runner.completed:
|
||||
continue
|
||||
|
||||
# TODO: Check ordering aligns with MLX distributeds expectations.
|
||||
|
||||
if isinstance(runner.status, RunnerReady) and all(
|
||||
isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))
|
||||
for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
|
||||
|
||||
@@ -2,16 +2,13 @@ import os
|
||||
|
||||
import loguru
|
||||
|
||||
from exo.shared.types.events import Event
|
||||
from exo.shared.types.events import Event, RunnerStatusUpdated
|
||||
from exo.shared.types.tasks import Task
|
||||
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
|
||||
from exo.shared.types.worker.runners import RunnerFailed
|
||||
from exo.utils.channels import MpReceiver, MpSender
|
||||
|
||||
logger: "loguru.Logger"
|
||||
|
||||
|
||||
if os.getenv("EXO_TESTS") == "1":
|
||||
logger = loguru.logger
|
||||
logger: "loguru.Logger" = loguru.logger
|
||||
|
||||
|
||||
def entrypoint(
|
||||
@@ -30,6 +27,23 @@ def entrypoint(
|
||||
logger = _logger
|
||||
|
||||
# Import main after setting global logger - this lets us just import logger from this module
|
||||
from exo.worker.runner.runner import main
|
||||
try:
|
||||
from exo.worker.runner.runner import main
|
||||
|
||||
main(bound_instance, event_sender, task_receiver)
|
||||
main(bound_instance, event_sender, task_receiver)
|
||||
except Exception as e:
|
||||
logger.opt(exception=e).warning(
|
||||
f"Runner {bound_instance.bound_runner_id} crashed with critical exception {e}"
|
||||
)
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=bound_instance.bound_runner_id,
|
||||
runner_status=RunnerFailed(error_message=str(e)),
|
||||
)
|
||||
)
|
||||
finally:
|
||||
event_sender.close()
|
||||
task_receiver.close()
|
||||
event_sender.join()
|
||||
task_receiver.join()
|
||||
logger.info("bye from the runner")
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
import base64
|
||||
import time
|
||||
|
||||
from exo.master.api import get_model_card
|
||||
from exo.shared.constants import EXO_MAX_CHUNK_SIZE
|
||||
from exo.shared.types.api import ChatCompletionMessageText
|
||||
from exo.shared.types.chunks import TokenChunk
|
||||
from exo.shared.types.chunks import ImageChunk, TokenChunk
|
||||
from exo.shared.types.events import (
|
||||
ChunkGenerated,
|
||||
Event,
|
||||
@@ -9,8 +12,12 @@ from exo.shared.types.events import (
|
||||
TaskAcknowledged,
|
||||
TaskStatusUpdated,
|
||||
)
|
||||
from exo.shared.types.models import ModelTask
|
||||
from exo.shared.types.tasks import (
|
||||
ChatCompletion,
|
||||
ConnectToGroup,
|
||||
ImageEdits,
|
||||
ImageGeneration,
|
||||
LoadModel,
|
||||
Shutdown,
|
||||
StartWarmup,
|
||||
@@ -20,22 +27,35 @@ from exo.shared.types.tasks import (
|
||||
from exo.shared.types.worker.instances import BoundInstance
|
||||
from exo.shared.types.worker.runner_response import (
|
||||
GenerationResponse,
|
||||
ImageGenerationResponse,
|
||||
PartialImageResponse,
|
||||
)
|
||||
from exo.shared.types.worker.runners import (
|
||||
RunnerConnected,
|
||||
RunnerConnecting,
|
||||
RunnerFailed,
|
||||
RunnerIdle,
|
||||
RunnerLoaded,
|
||||
RunnerLoading,
|
||||
RunnerReady,
|
||||
RunnerRunning,
|
||||
RunnerShutdown,
|
||||
RunnerShuttingDown,
|
||||
RunnerStatus,
|
||||
RunnerWaitingForModel,
|
||||
RunnerWarmingUp,
|
||||
)
|
||||
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
|
||||
from exo.worker.engines.image import (
|
||||
ImageGenerator,
|
||||
generate_image,
|
||||
initialize_image_model,
|
||||
warmup_image_generator,
|
||||
)
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
initialize_mlx,
|
||||
load_mlx_items,
|
||||
mlx_force_oom,
|
||||
)
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
@@ -63,9 +83,14 @@ def main(
|
||||
model = None
|
||||
tokenizer = None
|
||||
sampler = None
|
||||
group = None
|
||||
|
||||
current_status: RunnerStatus = RunnerWaitingForModel()
|
||||
logger.info("runner waiting for model")
|
||||
model_card = get_model_card(shard_metadata.model_meta.model_id)
|
||||
assert model_card
|
||||
model_tasks = model_card.tasks
|
||||
|
||||
current_status: RunnerStatus = RunnerIdle()
|
||||
logger.info("runner created")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
|
||||
)
|
||||
@@ -78,9 +103,26 @@ def main(
|
||||
)
|
||||
event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
||||
match task:
|
||||
case LoadModel() if isinstance(
|
||||
current_status, (RunnerWaitingForModel, RunnerFailed)
|
||||
case ConnectToGroup() if isinstance(
|
||||
current_status, (RunnerIdle, RunnerFailed)
|
||||
):
|
||||
logger.info("runner connecting")
|
||||
current_status = RunnerConnecting()
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
group = initialize_mlx(bound_instance)
|
||||
|
||||
logger.info("runner connected")
|
||||
current_status = RunnerConnected()
|
||||
|
||||
# we load the model if it's connected with a group, or idle without a group. we should never tell a model to connect if it doesn't need to
|
||||
case LoadModel() if (
|
||||
isinstance(current_status, RunnerConnected)
|
||||
and group is not None
|
||||
) or (isinstance(current_status, RunnerIdle) and group is None):
|
||||
current_status = RunnerLoading()
|
||||
logger.info("runner loading")
|
||||
event_sender.send(
|
||||
@@ -89,19 +131,26 @@ def main(
|
||||
)
|
||||
)
|
||||
|
||||
model, tokenizer, sampler = initialize_mlx(bound_instance)
|
||||
# TODO(ciaran): switch
|
||||
if ModelTask.TextGeneration in model_tasks:
|
||||
model, tokenizer, sampler = load_mlx_items(
|
||||
bound_instance, group
|
||||
)
|
||||
elif (
|
||||
ModelTask.TextToImage in model_tasks
|
||||
or ModelTask.ImageToImage in model_tasks
|
||||
):
|
||||
model = initialize_image_model(bound_instance)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown model task(s): {model_card.tasks}"
|
||||
)
|
||||
|
||||
current_status = RunnerLoaded()
|
||||
logger.info("runner loaded")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
case StartWarmup() if isinstance(current_status, RunnerLoaded):
|
||||
assert model
|
||||
assert tokenizer
|
||||
assert sampler
|
||||
|
||||
current_status = RunnerWarmingUp()
|
||||
logger.info("runner warming up")
|
||||
event_sender.send(
|
||||
@@ -111,27 +160,40 @@ def main(
|
||||
)
|
||||
|
||||
logger.info(f"warming up inference for instance: {instance}")
|
||||
toks = warmup_inference(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
sampler=sampler,
|
||||
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
|
||||
)
|
||||
logger.info(f"warmed up by generating {toks} tokens")
|
||||
logger.info(
|
||||
f"runner initialized in {time.time() - setup_start_time} seconds"
|
||||
)
|
||||
if ModelTask.TextGeneration in model_tasks:
|
||||
assert model and isinstance(model, Model)
|
||||
assert tokenizer
|
||||
assert sampler
|
||||
|
||||
toks = warmup_inference(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
sampler=sampler,
|
||||
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
|
||||
)
|
||||
logger.info(f"warmed up by generating {toks} tokens")
|
||||
logger.info(
|
||||
f"runner initialized in {time.time() - setup_start_time} seconds"
|
||||
)
|
||||
elif (
|
||||
ModelTask.TextToImage in model_tasks
|
||||
or ModelTask.ImageToImage in model_tasks
|
||||
):
|
||||
assert isinstance(model, ImageGenerator)
|
||||
image = warmup_image_generator(model=model)
|
||||
if image is not None:
|
||||
logger.info(
|
||||
f"warmed up by generating {image.size} image"
|
||||
)
|
||||
else:
|
||||
logger.info("warmup completed (non-primary node)")
|
||||
|
||||
current_status = RunnerReady()
|
||||
logger.info("runner ready")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=RunnerReady()
|
||||
)
|
||||
)
|
||||
case ChatCompletion(
|
||||
task_params=task_params, command_id=command_id
|
||||
) if isinstance(current_status, RunnerReady):
|
||||
assert model
|
||||
assert model and isinstance(model, Model)
|
||||
assert tokenizer
|
||||
assert sampler
|
||||
logger.info(f"received chat request: {str(task)[:500]}")
|
||||
@@ -172,29 +234,187 @@ def main(
|
||||
|
||||
current_status = RunnerReady()
|
||||
logger.info("runner ready")
|
||||
case ImageGeneration(
|
||||
task_params=task_params, command_id=command_id
|
||||
) if isinstance(current_status, RunnerReady):
|
||||
assert isinstance(model, ImageGenerator)
|
||||
logger.info(
|
||||
f"received image generation request: {str(task)[:500]}"
|
||||
)
|
||||
current_status = RunnerRunning()
|
||||
logger.info("runner running")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=RunnerReady()
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
|
||||
# Generate images using the image generation backend
|
||||
# Track image_index for final images only
|
||||
image_index = 0
|
||||
for response in generate_image(
|
||||
model=model,
|
||||
task=task_params,
|
||||
):
|
||||
if (
|
||||
shard_metadata.device_rank
|
||||
== shard_metadata.world_size - 1
|
||||
):
|
||||
match response:
|
||||
case PartialImageResponse():
|
||||
encoded_data = base64.b64encode(
|
||||
response.image_data
|
||||
).decode("utf-8")
|
||||
# Split into chunks to stay under gossipsub 1MB limit
|
||||
data_chunks = [
|
||||
encoded_data[i : i + EXO_MAX_CHUNK_SIZE]
|
||||
for i in range(
|
||||
0, len(encoded_data), EXO_MAX_CHUNK_SIZE
|
||||
)
|
||||
]
|
||||
total_chunks = len(data_chunks)
|
||||
logger.info(
|
||||
f"sending partial ImageChunk {response.partial_index}/{response.total_partials}: {len(encoded_data)} bytes in {total_chunks} chunks"
|
||||
)
|
||||
for chunk_index, chunk_data in enumerate(
|
||||
data_chunks
|
||||
):
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=ImageChunk(
|
||||
idx=chunk_index,
|
||||
model=shard_metadata.model_meta.model_id,
|
||||
data=chunk_data,
|
||||
chunk_index=chunk_index,
|
||||
total_chunks=total_chunks,
|
||||
image_index=response.partial_index,
|
||||
is_partial=True,
|
||||
partial_index=response.partial_index,
|
||||
total_partials=response.total_partials,
|
||||
),
|
||||
)
|
||||
)
|
||||
case ImageGenerationResponse():
|
||||
encoded_data = base64.b64encode(
|
||||
response.image_data
|
||||
).decode("utf-8")
|
||||
# Split into chunks to stay under gossipsub 1MB limit
|
||||
data_chunks = [
|
||||
encoded_data[i : i + EXO_MAX_CHUNK_SIZE]
|
||||
for i in range(
|
||||
0, len(encoded_data), EXO_MAX_CHUNK_SIZE
|
||||
)
|
||||
]
|
||||
total_chunks = len(data_chunks)
|
||||
logger.info(
|
||||
f"sending final ImageChunk: {len(encoded_data)} bytes in {total_chunks} chunks"
|
||||
)
|
||||
for chunk_index, chunk_data in enumerate(
|
||||
data_chunks
|
||||
):
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=ImageChunk(
|
||||
idx=chunk_index,
|
||||
model=shard_metadata.model_meta.model_id,
|
||||
data=chunk_data,
|
||||
chunk_index=chunk_index,
|
||||
total_chunks=total_chunks,
|
||||
image_index=image_index,
|
||||
is_partial=False,
|
||||
),
|
||||
)
|
||||
)
|
||||
image_index += 1
|
||||
|
||||
current_status = RunnerReady()
|
||||
logger.info("runner ready")
|
||||
case ImageEdits(task_params=task_params, command_id=command_id) if (
|
||||
isinstance(current_status, RunnerReady)
|
||||
):
|
||||
assert isinstance(model, ImageGenerator)
|
||||
logger.info(f"received image edits request: {str(task)[:500]}")
|
||||
current_status = RunnerRunning()
|
||||
logger.info("runner running")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
|
||||
image_index = 0
|
||||
for response in generate_image(
|
||||
model=model,
|
||||
task=task_params,
|
||||
):
|
||||
if (
|
||||
shard_metadata.device_rank
|
||||
== shard_metadata.world_size - 1
|
||||
):
|
||||
match response:
|
||||
case ImageGenerationResponse():
|
||||
encoded_data = base64.b64encode(
|
||||
response.image_data
|
||||
).decode("utf-8")
|
||||
# Split into chunks to stay under gossipsub 1MB limit
|
||||
data_chunks = [
|
||||
encoded_data[i : i + EXO_MAX_CHUNK_SIZE]
|
||||
for i in range(
|
||||
0, len(encoded_data), EXO_MAX_CHUNK_SIZE
|
||||
)
|
||||
]
|
||||
total_chunks = len(data_chunks)
|
||||
logger.info(
|
||||
f"sending ImageChunk: {len(encoded_data)} bytes in {total_chunks} chunks"
|
||||
)
|
||||
for chunk_index, chunk_data in enumerate(
|
||||
data_chunks
|
||||
):
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=ImageChunk(
|
||||
idx=chunk_index,
|
||||
model=shard_metadata.model_meta.model_id,
|
||||
data=chunk_data,
|
||||
chunk_index=chunk_index,
|
||||
total_chunks=total_chunks,
|
||||
image_index=image_index,
|
||||
is_partial=False,
|
||||
),
|
||||
)
|
||||
)
|
||||
image_index += 1
|
||||
|
||||
current_status = RunnerReady()
|
||||
logger.info("runner ready")
|
||||
case Shutdown():
|
||||
current_status = RunnerShuttingDown()
|
||||
logger.info("runner shutting down")
|
||||
event_sender.send(
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id, task_status=TaskStatus.Complete
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
break
|
||||
current_status = RunnerShutdown()
|
||||
case _:
|
||||
raise ValueError("Received task outside of state machine")
|
||||
raise ValueError(
|
||||
f"Received {task.__class__.__name__} outside of state machine in {current_status=}"
|
||||
)
|
||||
event_sender.send(
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id, task_status=TaskStatus.Complete
|
||||
)
|
||||
)
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(runner_id=runner_id, runner_status=RunnerShutdown())
|
||||
)
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
if isinstance(current_status, RunnerShutdown):
|
||||
break
|
||||
except ClosedResourceError:
|
||||
logger.warning("runner communication closed unexpectedly")
|
||||
except Exception as e:
|
||||
|
||||
@@ -14,13 +14,23 @@ from anyio import (
|
||||
from anyio.abc import TaskGroup
|
||||
from loguru import logger
|
||||
|
||||
from exo.shared.types.events import Event, RunnerStatusUpdated, TaskAcknowledged
|
||||
from exo.shared.types.tasks import Task, TaskId
|
||||
from exo.shared.types.events import (
|
||||
Event,
|
||||
RunnerStatusUpdated,
|
||||
TaskAcknowledged,
|
||||
TaskStatusUpdated,
|
||||
)
|
||||
from exo.shared.types.tasks import Task, TaskId, TaskStatus
|
||||
from exo.shared.types.worker.instances import BoundInstance
|
||||
from exo.shared.types.worker.runners import (
|
||||
RunnerConnecting,
|
||||
RunnerFailed,
|
||||
RunnerIdle,
|
||||
RunnerLoading,
|
||||
RunnerRunning,
|
||||
RunnerShuttingDown,
|
||||
RunnerStatus,
|
||||
RunnerWaitingForModel,
|
||||
RunnerWarmingUp,
|
||||
)
|
||||
from exo.shared.types.worker.shards import ShardMetadata
|
||||
from exo.utils.channels import MpReceiver, MpSender, Sender, mp_channel
|
||||
@@ -39,10 +49,10 @@ class RunnerSupervisor:
|
||||
_ev_recv: MpReceiver[Event]
|
||||
_task_sender: MpSender[Task]
|
||||
_event_sender: Sender[Event]
|
||||
# err_path: str
|
||||
_tg: TaskGroup | None = field(default=None, init=False)
|
||||
status: RunnerStatus = field(default_factory=RunnerWaitingForModel, init=False)
|
||||
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
|
||||
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
|
||||
completed: set[TaskId] = field(default_factory=set, init=False)
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
@@ -77,7 +87,6 @@ class RunnerSupervisor:
|
||||
_ev_recv=ev_recv,
|
||||
_task_sender=task_sender,
|
||||
_event_sender=event_sender,
|
||||
# err_path=err_path,
|
||||
)
|
||||
|
||||
return self
|
||||
@@ -118,6 +127,10 @@ class RunnerSupervisor:
|
||||
self._tg.cancel_scope.cancel()
|
||||
|
||||
async def start_task(self, task: Task):
|
||||
if task.task_id in self.completed:
|
||||
logger.info(
|
||||
f"Skipping invalid task {task} as it has already been completed"
|
||||
)
|
||||
logger.info(f"Starting task {task}")
|
||||
event = anyio.Event()
|
||||
self.pending[task.task_id] = event
|
||||
@@ -138,6 +151,22 @@ class RunnerSupervisor:
|
||||
if isinstance(event, TaskAcknowledged):
|
||||
self.pending.pop(event.task_id).set()
|
||||
continue
|
||||
if (
|
||||
isinstance(event, TaskStatusUpdated)
|
||||
and event.task_status == TaskStatus.Complete
|
||||
):
|
||||
# If a task has just been completed, we should be working on it.
|
||||
assert isinstance(
|
||||
self.status,
|
||||
(
|
||||
RunnerRunning,
|
||||
RunnerWarmingUp,
|
||||
RunnerLoading,
|
||||
RunnerConnecting,
|
||||
RunnerShuttingDown,
|
||||
),
|
||||
)
|
||||
self.completed.add(event.task_id)
|
||||
await self._event_sender.send(event)
|
||||
except (ClosedResourceError, BrokenResourceError) as e:
|
||||
await self._check_runner(e)
|
||||
|
||||
@@ -9,9 +9,11 @@ MASTER_NODE_ID = NodeId("ffffffff-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
|
||||
|
||||
NODE_A: Final[NodeId] = NodeId("aaaaaaaa-aaaa-4aaa-8aaa-aaaaaaaaaaaa")
|
||||
NODE_B: Final[NodeId] = NodeId("bbbbbbbb-bbbb-4bbb-8bbb-bbbbbbbbbbbb")
|
||||
NODE_C: Final[NodeId] = NodeId("cccccccc-cccc-4ccc-8ccc-cccccccccccc")
|
||||
|
||||
RUNNER_1_ID: Final[RunnerId] = RunnerId("11111111-1111-4111-8111-111111111111")
|
||||
RUNNER_2_ID: Final[RunnerId] = RunnerId("33333333-3333-4333-8333-333333333333")
|
||||
RUNNER_3_ID: Final[RunnerId] = RunnerId("Runner3")
|
||||
|
||||
INSTANCE_1_ID: Final[InstanceId] = InstanceId("22222222-2222-4222-8222-222222222222")
|
||||
INSTANCE_2_ID: Final[InstanceId] = InstanceId("44444444-4444-4444-8444-444444444444")
|
||||
@@ -24,3 +26,9 @@ TASK_2_ID: Final[TaskId] = TaskId("66666666-6666-4666-8666-666666666666")
|
||||
|
||||
COMMAND_1_ID: Final[CommandId] = CommandId("77777777-7777-4777-8777-777777777777")
|
||||
COMMAND_2_ID: Final[CommandId] = CommandId("88888888-8888-4888-8888-888888888888")
|
||||
|
||||
SHUTDOWN_TASK_ID = TaskId("shutdown")
|
||||
CHAT_COMPLETION_TASK_ID = TaskId("chat-completion")
|
||||
INITIALIZATION_TASK_ID = TaskId("initialisation")
|
||||
LOAD_TASK_ID = TaskId("load")
|
||||
WARMUP_TASK_ID = TaskId("warmup")
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.models import ModelId, ModelMetadata
|
||||
from exo.shared.types.tasks import BaseTask
|
||||
from exo.shared.types.tasks import BaseTask, TaskId
|
||||
from exo.shared.types.worker.instances import (
|
||||
BoundInstance,
|
||||
Instance,
|
||||
@@ -14,10 +14,12 @@ from exo.shared.types.worker.runners import RunnerId, RunnerStatus, ShardAssignm
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata
|
||||
|
||||
|
||||
# Runner supervisor without multiprocessing logic.
|
||||
@dataclass(frozen=True)
|
||||
class FakeRunnerSupervisor:
|
||||
bound_instance: BoundInstance
|
||||
status: RunnerStatus
|
||||
completed: set[TaskId] = field(default_factory=set)
|
||||
|
||||
|
||||
class OtherTask(BaseTask):
|
||||
@@ -35,6 +37,8 @@ def get_pipeline_shard_metadata(
|
||||
pretty_name=str(model_id),
|
||||
storage_size=Memory.from_mb(100000),
|
||||
n_layers=32,
|
||||
hidden_size=2048,
|
||||
supports_tensor=False,
|
||||
),
|
||||
device_rank=device_rank,
|
||||
world_size=world_size,
|
||||
@@ -67,5 +71,27 @@ def get_mlx_ring_instance(
|
||||
shard_assignments=get_shard_assignments(
|
||||
model_id, node_to_runner, runner_to_shard
|
||||
),
|
||||
hosts=[],
|
||||
hosts_by_node={},
|
||||
ephemeral_port=50000,
|
||||
)
|
||||
|
||||
|
||||
def get_bound_mlx_ring_instance(
|
||||
instance_id: InstanceId, model_id: ModelId, runner_id: RunnerId, node_id: NodeId
|
||||
) -> BoundInstance:
|
||||
shard = get_pipeline_shard_metadata(model_id=model_id, device_rank=0, world_size=2)
|
||||
other_shard = get_pipeline_shard_metadata(
|
||||
model_id=model_id, device_rank=1, world_size=2
|
||||
)
|
||||
instance = get_mlx_ring_instance(
|
||||
instance_id=instance_id,
|
||||
model_id=model_id,
|
||||
node_to_runner={
|
||||
node_id: runner_id,
|
||||
NodeId("other_node"): RunnerId("other_runner"),
|
||||
},
|
||||
runner_to_shard={runner_id: shard, RunnerId("other_runner"): other_shard},
|
||||
)
|
||||
return BoundInstance(
|
||||
instance=instance, bound_runner_id=runner_id, bound_node_id=node_id
|
||||
)
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
import exo.worker.plan as plan_mod
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.models import ModelId
|
||||
from exo.shared.types.tasks import LoadModel
|
||||
from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress
|
||||
from exo.shared.types.worker.instances import BoundInstance
|
||||
from exo.shared.types.worker.runners import (
|
||||
RunnerWaitingForModel,
|
||||
RunnerConnected,
|
||||
RunnerIdle,
|
||||
)
|
||||
from exo.shared.types.worker.shards import ShardMetadata
|
||||
from exo.worker.tests.constants import (
|
||||
INSTANCE_1_ID,
|
||||
MODEL_A_ID,
|
||||
@@ -38,16 +39,14 @@ def test_plan_requests_download_when_waiting_and_shard_not_downloaded():
|
||||
bound_instance = BoundInstance(
|
||||
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
|
||||
)
|
||||
runner = FakeRunnerSupervisor(
|
||||
bound_instance=bound_instance, status=RunnerWaitingForModel()
|
||||
)
|
||||
runner = FakeRunnerSupervisor(bound_instance=bound_instance, status=RunnerIdle())
|
||||
|
||||
runners = {RUNNER_1_ID: runner}
|
||||
instances = {INSTANCE_1_ID: instance}
|
||||
all_runners = {RUNNER_1_ID: RunnerWaitingForModel()}
|
||||
all_runners = {RUNNER_1_ID: RunnerIdle()}
|
||||
|
||||
# No entry for this shard -> should trigger DownloadModel
|
||||
download_status: dict[ShardMetadata, DownloadProgress] = {}
|
||||
download_status: dict[ModelId, DownloadProgress] = {}
|
||||
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
@@ -82,20 +81,20 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting():
|
||||
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
|
||||
)
|
||||
local_runner = FakeRunnerSupervisor(
|
||||
bound_instance=bound_instance, status=RunnerWaitingForModel()
|
||||
bound_instance=bound_instance, status=RunnerConnected()
|
||||
)
|
||||
|
||||
runners = {RUNNER_1_ID: local_runner}
|
||||
instances = {INSTANCE_1_ID: instance}
|
||||
|
||||
all_runners = {
|
||||
RUNNER_1_ID: RunnerWaitingForModel(),
|
||||
RUNNER_2_ID: RunnerWaitingForModel(),
|
||||
RUNNER_1_ID: RunnerConnected(),
|
||||
RUNNER_2_ID: RunnerConnected(),
|
||||
}
|
||||
|
||||
# Local node has already marked its shard as downloaded (not actually used by _load_model)
|
||||
local_download_status = {
|
||||
shard1: DownloadCompleted(shard_metadata=shard1, node_id=NODE_A) # type: ignore[reportUnhashable]
|
||||
MODEL_A_ID: DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)
|
||||
}
|
||||
|
||||
# Global view has completed downloads for both nodes
|
||||
@@ -133,17 +132,15 @@ def test_plan_does_not_request_download_when_shard_already_downloaded():
|
||||
bound_instance = BoundInstance(
|
||||
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
|
||||
)
|
||||
runner = FakeRunnerSupervisor(
|
||||
bound_instance=bound_instance, status=RunnerWaitingForModel()
|
||||
)
|
||||
runner = FakeRunnerSupervisor(bound_instance=bound_instance, status=RunnerIdle())
|
||||
|
||||
runners = {RUNNER_1_ID: runner}
|
||||
instances = {INSTANCE_1_ID: instance}
|
||||
all_runners = {RUNNER_1_ID: RunnerWaitingForModel()}
|
||||
all_runners = {RUNNER_1_ID: RunnerIdle()}
|
||||
|
||||
# Local status claims the shard is downloaded already
|
||||
local_download_status = {
|
||||
shard: DownloadCompleted(shard_metadata=shard, node_id=NODE_A) # type: ignore[reportUnhashable]
|
||||
MODEL_A_ID: DownloadCompleted(shard_metadata=shard, node_id=NODE_A)
|
||||
}
|
||||
|
||||
# Global view hasn't caught up yet (no completed shards recorded for NODE_A)
|
||||
@@ -183,19 +180,19 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
|
||||
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
|
||||
)
|
||||
local_runner = FakeRunnerSupervisor(
|
||||
bound_instance=bound_instance, status=RunnerWaitingForModel()
|
||||
bound_instance=bound_instance, status=RunnerConnected()
|
||||
)
|
||||
|
||||
runners = {RUNNER_1_ID: local_runner}
|
||||
instances = {INSTANCE_1_ID: instance}
|
||||
all_runners = {
|
||||
RUNNER_1_ID: RunnerWaitingForModel(),
|
||||
RUNNER_2_ID: RunnerWaitingForModel(),
|
||||
RUNNER_1_ID: RunnerConnected(),
|
||||
RUNNER_2_ID: RunnerConnected(),
|
||||
}
|
||||
|
||||
# Only NODE_A's shard is recorded as downloaded globally
|
||||
local_download_status = {
|
||||
shard1: DownloadCompleted(shard_metadata=shard1, node_id=NODE_A) # type: ignore[reportUnhashable]
|
||||
MODEL_A_ID: DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)
|
||||
}
|
||||
global_download_status = {
|
||||
NODE_A: [DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)],
|
||||
@@ -213,3 +210,22 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
global_download_status = {
|
||||
NODE_A: [DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)],
|
||||
NODE_B: [
|
||||
DownloadCompleted(shard_metadata=shard2, node_id=NODE_B)
|
||||
], # NODE_B has no downloads completed yet
|
||||
}
|
||||
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status=local_download_status,
|
||||
global_download_status=global_download_status,
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
tasks={},
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
|
||||
@@ -5,9 +5,9 @@ from exo.shared.types.api import ChatCompletionTaskParams
|
||||
from exo.shared.types.tasks import ChatCompletion, Task, TaskId, TaskStatus
|
||||
from exo.shared.types.worker.instances import BoundInstance, InstanceId
|
||||
from exo.shared.types.worker.runners import (
|
||||
RunnerIdle,
|
||||
RunnerReady,
|
||||
RunnerRunning,
|
||||
RunnerWaitingForModel,
|
||||
)
|
||||
from exo.worker.tests.constants import (
|
||||
COMMAND_1_ID,
|
||||
@@ -99,7 +99,7 @@ def test_plan_does_not_forward_chat_completion_if_any_runner_not_ready():
|
||||
instances = {INSTANCE_1_ID: instance}
|
||||
all_runners = {
|
||||
RUNNER_1_ID: RunnerReady(),
|
||||
RUNNER_2_ID: RunnerWaitingForModel(),
|
||||
RUNNER_2_ID: RunnerIdle(),
|
||||
}
|
||||
|
||||
task = ChatCompletion(
|
||||
|
||||
@@ -2,8 +2,9 @@ import exo.worker.plan as plan_mod
|
||||
from exo.shared.types.tasks import StartWarmup
|
||||
from exo.shared.types.worker.instances import BoundInstance
|
||||
from exo.shared.types.worker.runners import (
|
||||
RunnerIdle,
|
||||
RunnerLoaded,
|
||||
RunnerWaitingForModel,
|
||||
RunnerLoading,
|
||||
RunnerWarmingUp,
|
||||
)
|
||||
from exo.worker.tests.constants import (
|
||||
@@ -11,8 +12,10 @@ from exo.worker.tests.constants import (
|
||||
MODEL_A_ID,
|
||||
NODE_A,
|
||||
NODE_B,
|
||||
NODE_C,
|
||||
RUNNER_1_ID,
|
||||
RUNNER_2_ID,
|
||||
RUNNER_3_ID,
|
||||
)
|
||||
from exo.worker.tests.unittests.conftest import (
|
||||
FakeRunnerSupervisor,
|
||||
@@ -21,18 +24,19 @@ from exo.worker.tests.unittests.conftest import (
|
||||
)
|
||||
|
||||
|
||||
def test_plan_starts_warmup_for_non_zero_rank_when_all_loaded_or_warming():
|
||||
def test_plan_starts_warmup_for_accepting_rank_when_all_loaded_or_warming():
|
||||
"""
|
||||
For non-zero device_rank shards, StartWarmup should be emitted when all
|
||||
shards in the instance are Loaded/WarmingUp.
|
||||
"""
|
||||
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
|
||||
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
|
||||
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=3)
|
||||
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=3)
|
||||
shard2 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=2, world_size=3)
|
||||
instance = get_mlx_ring_instance(
|
||||
instance_id=INSTANCE_1_ID,
|
||||
model_id=MODEL_A_ID,
|
||||
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
|
||||
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
|
||||
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID, NODE_C: RUNNER_3_ID},
|
||||
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1, RUNNER_3_ID: shard2},
|
||||
)
|
||||
|
||||
bound_instance = BoundInstance(
|
||||
@@ -47,6 +51,7 @@ def test_plan_starts_warmup_for_non_zero_rank_when_all_loaded_or_warming():
|
||||
all_runners = {
|
||||
RUNNER_1_ID: RunnerLoaded(),
|
||||
RUNNER_2_ID: RunnerLoaded(),
|
||||
RUNNER_3_ID: RunnerWarmingUp(),
|
||||
}
|
||||
|
||||
result = plan_mod.plan(
|
||||
@@ -128,7 +133,7 @@ def test_plan_does_not_start_warmup_for_non_zero_rank_until_all_loaded_or_warmin
|
||||
runners = {RUNNER_2_ID: local_runner}
|
||||
instances = {INSTANCE_1_ID: instance}
|
||||
all_runners = {
|
||||
RUNNER_1_ID: RunnerWaitingForModel(),
|
||||
RUNNER_1_ID: RunnerIdle(),
|
||||
RUNNER_2_ID: RunnerLoaded(),
|
||||
}
|
||||
|
||||
@@ -149,6 +154,9 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
|
||||
"""
|
||||
Rank-zero shard should not start warmup until all non-zero ranks are
|
||||
already WarmingUp.
|
||||
For accepting ranks (device_rank != 0), StartWarmup should be
|
||||
emitted when all shards in the instance are Loaded/WarmingUp.
|
||||
In a 2-node setup, rank 1 is the accepting rank.
|
||||
"""
|
||||
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
|
||||
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
|
||||
@@ -159,6 +167,153 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
|
||||
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
|
||||
)
|
||||
|
||||
# Rank 1 is the accepting rank
|
||||
bound_instance = BoundInstance(
|
||||
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
|
||||
)
|
||||
local_runner = FakeRunnerSupervisor(
|
||||
bound_instance=bound_instance, status=RunnerLoaded()
|
||||
)
|
||||
|
||||
runners = {RUNNER_1_ID: local_runner}
|
||||
instances = {INSTANCE_1_ID: instance}
|
||||
all_runners = {
|
||||
RUNNER_1_ID: RunnerLoaded(),
|
||||
RUNNER_2_ID: RunnerLoaded(),
|
||||
}
|
||||
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_A: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
tasks={},
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
all_runners = {
|
||||
RUNNER_1_ID: RunnerLoaded(),
|
||||
RUNNER_2_ID: RunnerWarmingUp(),
|
||||
}
|
||||
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_A: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
tasks={},
|
||||
)
|
||||
|
||||
assert isinstance(result, StartWarmup)
|
||||
assert result.instance_id == INSTANCE_1_ID
|
||||
|
||||
|
||||
def test_plan_starts_warmup_for_connecting_rank_after_others_warming():
|
||||
"""
|
||||
For connecting rank (device_rank == world_size - 1), StartWarmup should
|
||||
only be emitted once all the other runners are already warming up.
|
||||
In a 2-node setup, rank 1 is the connecting rank.
|
||||
"""
|
||||
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
|
||||
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
|
||||
instance = get_mlx_ring_instance(
|
||||
instance_id=INSTANCE_1_ID,
|
||||
model_id=MODEL_A_ID,
|
||||
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
|
||||
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
|
||||
)
|
||||
|
||||
# Rank 1 is the connecting rank
|
||||
bound_instance = BoundInstance(
|
||||
instance=instance, bound_runner_id=RUNNER_2_ID, bound_node_id=NODE_B
|
||||
)
|
||||
local_runner = FakeRunnerSupervisor(
|
||||
bound_instance=bound_instance, status=RunnerLoaded()
|
||||
)
|
||||
|
||||
runners = {RUNNER_2_ID: local_runner}
|
||||
instances = {INSTANCE_1_ID: instance}
|
||||
all_runners = {
|
||||
RUNNER_1_ID: RunnerWarmingUp(),
|
||||
RUNNER_2_ID: RunnerLoaded(),
|
||||
}
|
||||
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_B,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_B: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
tasks={},
|
||||
)
|
||||
|
||||
assert isinstance(result, StartWarmup)
|
||||
assert result.instance_id == INSTANCE_1_ID
|
||||
|
||||
|
||||
def test_plan_does_not_start_warmup_for_accepting_rank_until_all_loaded_or_warming():
|
||||
"""
|
||||
Accepting rank should not start warmup while any shard is not Loaded/WarmingUp.
|
||||
In a 2-node setup, rank 0 is the accepting rank.
|
||||
"""
|
||||
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
|
||||
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
|
||||
instance = get_mlx_ring_instance(
|
||||
instance_id=INSTANCE_1_ID,
|
||||
model_id=MODEL_A_ID,
|
||||
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
|
||||
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
|
||||
)
|
||||
|
||||
# Rank 0 is the accepting rank
|
||||
bound_instance = BoundInstance(
|
||||
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
|
||||
)
|
||||
local_runner = FakeRunnerSupervisor(
|
||||
bound_instance=bound_instance, status=RunnerLoaded()
|
||||
)
|
||||
|
||||
runners = {RUNNER_1_ID: local_runner}
|
||||
instances = {INSTANCE_1_ID: instance}
|
||||
all_runners = {
|
||||
RUNNER_1_ID: RunnerLoaded(),
|
||||
RUNNER_2_ID: RunnerLoading(),
|
||||
}
|
||||
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_A: [], NODE_B: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
tasks={},
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_plan_does_not_start_warmup_for_connecting_rank_until_others_warming():
|
||||
"""
|
||||
Connecting rank (device_rank == 0) should not start warmup
|
||||
until all other ranks are already WarmingUp.
|
||||
"""
|
||||
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
|
||||
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=1, world_size=2)
|
||||
instance = get_mlx_ring_instance(
|
||||
instance_id=INSTANCE_1_ID,
|
||||
model_id=MODEL_A_ID,
|
||||
node_to_runner={NODE_A: RUNNER_1_ID, NODE_B: RUNNER_2_ID},
|
||||
runner_to_shard={RUNNER_1_ID: shard0, RUNNER_2_ID: shard1},
|
||||
)
|
||||
|
||||
# Rank 1 is the connecting rank
|
||||
bound_instance = BoundInstance(
|
||||
instance=instance, bound_runner_id=RUNNER_1_ID, bound_node_id=NODE_A
|
||||
)
|
||||
|
||||
@@ -0,0 +1,212 @@
|
||||
# Check tasks are complete before runner is ever ready.
|
||||
from collections.abc import Iterable
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
|
||||
import exo.worker.runner.runner as mlx_runner
|
||||
from exo.shared.types.api import ChatCompletionMessage
|
||||
from exo.shared.types.chunks import TokenChunk
|
||||
from exo.shared.types.events import (
|
||||
ChunkGenerated,
|
||||
Event,
|
||||
RunnerStatusUpdated,
|
||||
TaskAcknowledged,
|
||||
TaskStatusUpdated,
|
||||
)
|
||||
from exo.shared.types.tasks import (
|
||||
ChatCompletion,
|
||||
ChatCompletionTaskParams,
|
||||
ConnectToGroup,
|
||||
LoadModel,
|
||||
Shutdown,
|
||||
StartWarmup,
|
||||
Task,
|
||||
TaskStatus,
|
||||
)
|
||||
from exo.shared.types.worker.runner_response import GenerationResponse
|
||||
from exo.shared.types.worker.runners import (
|
||||
RunnerConnected,
|
||||
RunnerConnecting,
|
||||
RunnerIdle,
|
||||
RunnerLoaded,
|
||||
RunnerLoading,
|
||||
RunnerReady,
|
||||
RunnerRunning,
|
||||
RunnerShutdown,
|
||||
RunnerShuttingDown,
|
||||
RunnerWarmingUp,
|
||||
)
|
||||
from exo.utils.channels import mp_channel
|
||||
|
||||
from ...constants import (
|
||||
CHAT_COMPLETION_TASK_ID,
|
||||
COMMAND_1_ID,
|
||||
INITIALIZATION_TASK_ID,
|
||||
INSTANCE_1_ID,
|
||||
LOAD_TASK_ID,
|
||||
MODEL_A_ID,
|
||||
NODE_A,
|
||||
RUNNER_1_ID,
|
||||
SHUTDOWN_TASK_ID,
|
||||
WARMUP_TASK_ID,
|
||||
)
|
||||
from ..conftest import get_bound_mlx_ring_instance
|
||||
|
||||
|
||||
def make_nothin[T, U, V](res: T) -> Callable[[], T]:
|
||||
def nothin(*_1: U, **_2: V) -> T:
|
||||
return res
|
||||
|
||||
return nothin
|
||||
|
||||
|
||||
nothin = make_nothin(None)
|
||||
|
||||
|
||||
INIT_TASK = ConnectToGroup(
|
||||
task_id=INITIALIZATION_TASK_ID,
|
||||
instance_id=INSTANCE_1_ID,
|
||||
)
|
||||
|
||||
LOAD_TASK = LoadModel(
|
||||
task_id=LOAD_TASK_ID,
|
||||
instance_id=INSTANCE_1_ID,
|
||||
)
|
||||
|
||||
WARMUP_TASK = StartWarmup(
|
||||
task_id=WARMUP_TASK_ID,
|
||||
instance_id=INSTANCE_1_ID,
|
||||
)
|
||||
|
||||
SHUTDOWN_TASK = Shutdown(
|
||||
task_id=SHUTDOWN_TASK_ID,
|
||||
instance_id=INSTANCE_1_ID,
|
||||
runner_id=RUNNER_1_ID,
|
||||
)
|
||||
|
||||
CHAT_PARAMS = ChatCompletionTaskParams(
|
||||
model=str(MODEL_A_ID),
|
||||
messages=[ChatCompletionMessage(role="user", content="hello")],
|
||||
stream=True,
|
||||
max_tokens=4,
|
||||
temperature=0.0,
|
||||
)
|
||||
|
||||
CHAT_TASK = ChatCompletion(
|
||||
task_id=CHAT_COMPLETION_TASK_ID,
|
||||
command_id=COMMAND_1_ID,
|
||||
task_params=CHAT_PARAMS,
|
||||
instance_id=INSTANCE_1_ID,
|
||||
)
|
||||
|
||||
|
||||
def assert_events_equal(test_events: Iterable[Event], true_events: Iterable[Event]):
|
||||
for test_event, true_event in zip(test_events, true_events, strict=True):
|
||||
test_event.event_id = true_event.event_id
|
||||
assert test_event == true_event, f"{test_event} != {true_event}"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
|
||||
# initialize_mlx returns a "group" equal to 1
|
||||
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(1))
|
||||
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, 1, 1)))
|
||||
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
|
||||
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
|
||||
|
||||
def fake_generate(*_1: object, **_2: object):
|
||||
yield GenerationResponse(token=0, text="hi", finish_reason="stop")
|
||||
|
||||
monkeypatch.setattr(mlx_runner, "mlx_generate", fake_generate)
|
||||
|
||||
|
||||
def _run(tasks: Iterable[Task]):
|
||||
bound_instance = get_bound_mlx_ring_instance(
|
||||
instance_id=INSTANCE_1_ID,
|
||||
model_id=MODEL_A_ID,
|
||||
runner_id=RUNNER_1_ID,
|
||||
node_id=NODE_A,
|
||||
)
|
||||
|
||||
task_sender, task_receiver = mp_channel[Task]()
|
||||
event_sender, event_receiver = mp_channel[Event]()
|
||||
|
||||
with task_sender, event_receiver:
|
||||
for t in tasks:
|
||||
task_sender.send(t)
|
||||
|
||||
# worst monkeypatch known to man
|
||||
# this is some c++ nonsense
|
||||
event_sender.close = nothin
|
||||
event_sender.join = nothin
|
||||
task_receiver.close = nothin
|
||||
task_receiver.join = nothin
|
||||
|
||||
mlx_runner.main(bound_instance, event_sender, task_receiver)
|
||||
|
||||
return event_receiver.collect()
|
||||
|
||||
|
||||
def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
|
||||
events = _run([INIT_TASK, LOAD_TASK, WARMUP_TASK, CHAT_TASK, SHUTDOWN_TASK])
|
||||
|
||||
expected_chunk = ChunkGenerated(
|
||||
command_id=COMMAND_1_ID,
|
||||
chunk=TokenChunk(
|
||||
idx=0,
|
||||
model=MODEL_A_ID,
|
||||
text="hi",
|
||||
token_id=0,
|
||||
finish_reason="stop",
|
||||
),
|
||||
)
|
||||
|
||||
assert_events_equal(
|
||||
events,
|
||||
[
|
||||
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerIdle()),
|
||||
TaskStatusUpdated(
|
||||
task_id=INITIALIZATION_TASK_ID, task_status=TaskStatus.Running
|
||||
),
|
||||
TaskAcknowledged(task_id=INITIALIZATION_TASK_ID),
|
||||
RunnerStatusUpdated(
|
||||
runner_id=RUNNER_1_ID, runner_status=RunnerConnecting()
|
||||
),
|
||||
TaskStatusUpdated(
|
||||
task_id=INITIALIZATION_TASK_ID, task_status=TaskStatus.Complete
|
||||
),
|
||||
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerConnected()),
|
||||
TaskStatusUpdated(task_id=LOAD_TASK_ID, task_status=TaskStatus.Running),
|
||||
TaskAcknowledged(task_id=LOAD_TASK_ID),
|
||||
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerLoading()),
|
||||
TaskStatusUpdated(task_id=LOAD_TASK_ID, task_status=TaskStatus.Complete),
|
||||
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerLoaded()),
|
||||
TaskStatusUpdated(task_id=WARMUP_TASK_ID, task_status=TaskStatus.Running),
|
||||
TaskAcknowledged(task_id=WARMUP_TASK_ID),
|
||||
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerWarmingUp()),
|
||||
TaskStatusUpdated(task_id=WARMUP_TASK_ID, task_status=TaskStatus.Complete),
|
||||
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerReady()),
|
||||
TaskStatusUpdated(
|
||||
task_id=CHAT_COMPLETION_TASK_ID, task_status=TaskStatus.Running
|
||||
),
|
||||
TaskAcknowledged(task_id=CHAT_COMPLETION_TASK_ID),
|
||||
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerRunning()),
|
||||
expected_chunk,
|
||||
TaskStatusUpdated(
|
||||
task_id=CHAT_COMPLETION_TASK_ID, task_status=TaskStatus.Complete
|
||||
),
|
||||
# CHAT COMPLETION TASK SHOULD COMPLETE BEFORE RUNNER READY
|
||||
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerReady()),
|
||||
TaskStatusUpdated(task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Running),
|
||||
TaskAcknowledged(task_id=SHUTDOWN_TASK_ID),
|
||||
RunnerStatusUpdated(
|
||||
runner_id=RUNNER_1_ID, runner_status=RunnerShuttingDown()
|
||||
),
|
||||
TaskStatusUpdated(
|
||||
task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Complete
|
||||
),
|
||||
# SPECIAL EXCEPTION FOR RUNNER SHUTDOWN
|
||||
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerShutdown()),
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1 @@
|
||||
# TODO:
|
||||
@@ -1,33 +1,64 @@
|
||||
import socket
|
||||
import http.client
|
||||
|
||||
from anyio import create_task_group, to_thread
|
||||
from loguru import logger
|
||||
|
||||
from exo.shared.topology import Topology
|
||||
from exo.shared.types.common import NodeId
|
||||
|
||||
|
||||
# TODO: ref. api port
|
||||
async def check_reachability(
|
||||
target_ip: str, target_node_id: NodeId, out: dict[NodeId, set[str]]
|
||||
target_ip: str,
|
||||
expected_node_id: NodeId,
|
||||
self_node_id: NodeId,
|
||||
out: dict[NodeId, set[str]],
|
||||
) -> None:
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
sock.settimeout(1) # 1 second timeout
|
||||
try:
|
||||
result = await to_thread.run_sync(sock.connect_ex, (target_ip, 52415))
|
||||
except socket.gaierror:
|
||||
# seems to throw on ipv6 loopback. oh well
|
||||
# logger.warning(f"invalid {target_ip=}")
|
||||
"""Check if a node is reachable at the given IP and verify its identity."""
|
||||
|
||||
def _fetch_remote_node_id() -> NodeId | None:
|
||||
connection = http.client.HTTPConnection(target_ip, 52415, timeout=1)
|
||||
try:
|
||||
connection.request("GET", "/node_id")
|
||||
response = connection.getresponse()
|
||||
if response.status != 200:
|
||||
return None
|
||||
|
||||
body = response.read().decode("utf-8").strip()
|
||||
|
||||
# Strip quotes if present (JSON string response)
|
||||
if body.startswith('"') and body.endswith('"') and len(body) >= 2:
|
||||
body = body[1:-1]
|
||||
|
||||
return NodeId(body) or None
|
||||
except OSError:
|
||||
return None
|
||||
finally:
|
||||
connection.close()
|
||||
|
||||
remote_node_id = await to_thread.run_sync(_fetch_remote_node_id)
|
||||
if remote_node_id is None:
|
||||
return
|
||||
finally:
|
||||
sock.close()
|
||||
|
||||
if result == 0:
|
||||
if target_node_id not in out:
|
||||
out[target_node_id] = set()
|
||||
out[target_node_id].add(target_ip)
|
||||
if remote_node_id == self_node_id:
|
||||
return
|
||||
|
||||
if remote_node_id != expected_node_id:
|
||||
logger.warning(
|
||||
f"Discovered node with unexpected node_id; "
|
||||
f"ip={target_ip}, expected_node_id={expected_node_id}, "
|
||||
f"remote_node_id={remote_node_id}"
|
||||
)
|
||||
return
|
||||
|
||||
if remote_node_id not in out:
|
||||
out[remote_node_id] = set()
|
||||
out[remote_node_id].add(target_ip)
|
||||
|
||||
|
||||
async def check_reachable(topology: Topology) -> dict[NodeId, set[str]]:
|
||||
async def check_reachable(
|
||||
topology: Topology, self_node_id: NodeId
|
||||
) -> dict[NodeId, set[str]]:
|
||||
"""Check which nodes are reachable and return their IPs."""
|
||||
reachable: dict[NodeId, set[str]] = {}
|
||||
async with create_task_group() as tg:
|
||||
for node in topology.list_nodes():
|
||||
@@ -35,7 +66,11 @@ async def check_reachable(topology: Topology) -> dict[NodeId, set[str]]:
|
||||
continue
|
||||
for iface in node.node_profile.network_interfaces:
|
||||
tg.start_soon(
|
||||
check_reachability, iface.ip_address, node.node_id, reachable
|
||||
check_reachability,
|
||||
iface.ip_address,
|
||||
node.node_id,
|
||||
self_node_id,
|
||||
reachable,
|
||||
)
|
||||
|
||||
return reachable
|
||||
|
||||
Reference in New Issue
Block a user