Files
exo/tests/test_vision_cache.py
rltakashige 635801d515 Add multimodality! (#1802)
## Motivation

Images!

TODO (in a future PR): Add audio and video support.

## Test Plan

### Manual Testing
<img width="2652" height="1900" alt="image"
src="https://github.com/user-attachments/assets/7d3a7137-542f-4f94-9193-2c73b7c4a5ec"
/>

<img width="2770" height="1956" alt="image"
src="https://github.com/user-attachments/assets/e3c3a096-8029-4409-97a6-aca31a9a3f24"
/>
<img width="2738" height="1768" alt="image"
src="https://github.com/user-attachments/assets/d70ea37f-cd1d-4a4c-ad08-3beb9fafa380"
/>

(And batching also works)

---------

Co-authored-by: David Hind <davehind@yahoo.co.uk>
2026-03-30 11:52:19 +01:00

64 lines
2.2 KiB
Python

from exo.worker.engines.mlx.cache import KVPrefixCache
from exo.worker.engines.mlx.vision import MediaRegion
validate = KVPrefixCache._validate_media_match
class TestValidateMediaMatch:
def test_text_only_no_truncation(self):
assert validate(8000, [], []) == 8000
def test_text_prefix_before_image(self):
cached = [MediaRegion("hashA", 5000, 8600)]
assert validate(5000, cached, []) == 5000
def test_same_image_same_position(self):
cached = [MediaRegion("hashA", 5000, 8600)]
query = [MediaRegion("hashA", 5000, 8600)]
assert validate(9000, cached, query) == 9000
def test_different_image_truncates(self):
cached = [MediaRegion("hashA", 5000, 8600)]
query = [MediaRegion("hashB", 5000, 8600)]
assert validate(9000, cached, query) == 5000
def test_match_below_region_start(self):
cached = [MediaRegion("hashA", 5000, 8600)]
query = [MediaRegion("hashB", 5000, 8600)]
assert validate(4000, cached, query) == 4000
def test_text_followup_no_images_in_query(self):
cached = [MediaRegion("hashA", 5000, 8600)]
assert validate(9000, cached, []) == 9000
def test_multiple_images_first_mismatch_truncates(self):
cached = [
MediaRegion("hashA", 2000, 4000),
MediaRegion("hashB", 6000, 8000),
]
query = [
MediaRegion("hashA", 2000, 4000),
MediaRegion("hashC", 6000, 8000),
]
assert validate(9000, cached, query) == 6000
def test_multiple_images_all_match(self):
cached = [
MediaRegion("hashA", 2000, 4000),
MediaRegion("hashB", 6000, 8000),
]
query = [
MediaRegion("hashA", 2000, 4000),
MediaRegion("hashB", 6000, 8000),
]
assert validate(9000, cached, query) == 9000
def test_no_cached_regions(self):
query = [MediaRegion("hashA", 100, 200)]
assert validate(500, [], query) == 500
def test_cached_region_beyond_match(self):
cached = [MediaRegion("hashA", 10000, 12000)]
query = [MediaRegion("hashB", 10000, 12000)]
assert validate(5000, cached, query) == 5000