diff --git a/src/exo/worker/engines/mlx/auto_parallel.py b/src/exo/worker/engines/mlx/auto_parallel.py index 28e82f73..1e470399 100644 --- a/src/exo/worker/engines/mlx/auto_parallel.py +++ b/src/exo/worker/engines/mlx/auto_parallel.py @@ -635,7 +635,7 @@ class WrappedMiniMaxAttention(CustomMlxLayer): self, x: mx.array, mask: mx.array | None = None, - cache: Cache | None = None, + cache: "Cache | None" = None, ) -> mx.array: batch_dim, seq_dim, _ = x.shape