mirror of
https://github.com/exo-explore/exo.git
synced 2026-04-17 12:30:29 -04:00
## Motivation Images! TODO (in a future PR): Add audio and video support. ## Test Plan ### Manual Testing <img width="2652" height="1900" alt="image" src="https://github.com/user-attachments/assets/7d3a7137-542f-4f94-9193-2c73b7c4a5ec" /> <img width="2770" height="1956" alt="image" src="https://github.com/user-attachments/assets/e3c3a096-8029-4409-97a6-aca31a9a3f24" /> <img width="2738" height="1768" alt="image" src="https://github.com/user-attachments/assets/d70ea37f-cd1d-4a4c-ad08-3beb9fafa380" /> (And batching also works) --------- Co-authored-by: David Hind <davehind@yahoo.co.uk>
64 lines
2.2 KiB
Python
64 lines
2.2 KiB
Python
from exo.worker.engines.mlx.cache import KVPrefixCache
|
|
from exo.worker.engines.mlx.vision import MediaRegion
|
|
|
|
# Short alias for the method under test, read off the KVPrefixCache class itself
# (no instance is created). Every test below calls it with an int and two lists
# of MediaRegion and no `self`, so it is presumably a @staticmethod —
# NOTE(review): confirm in exo/worker/engines/mlx/cache.py.
validate = KVPrefixCache._validate_media_match
|
|
|
|
|
|
class TestValidateMediaMatch:
    """Cases for ``KVPrefixCache._validate_media_match`` (bound as ``validate``).

    Each case passes ``validate`` three things: an integer match length, the
    media regions recorded for the cached entry, and the media regions of the
    incoming query. It then checks the integer that comes back.

    Regions are built positionally as ``MediaRegion(<hash>, <start>, <end>)``.
    Reading the two integers as a token span follows the test names.
    NOTE(review): confirm against ``exo.worker.engines.mlx.vision.MediaRegion``.
    """

    def test_text_only_no_truncation(self):
        """No media on either side: the match length comes back unchanged."""
        outcome = validate(8000, [], [])
        assert outcome == 8000

    def test_text_prefix_before_image(self):
        """A match ending exactly where the cached image starts is kept."""
        cached_regions = [MediaRegion("hashA", 5000, 8600)]
        outcome = validate(5000, cached_regions, [])
        assert outcome == 5000

    def test_same_image_same_position(self):
        """Identical hash and span on both sides: nothing is truncated."""
        cached_regions = [MediaRegion("hashA", 5000, 8600)]
        query_regions = [MediaRegion("hashA", 5000, 8600)]
        outcome = validate(9000, cached_regions, query_regions)
        assert outcome == 9000

    def test_different_image_truncates(self):
        """Same span but a different hash: the match is cut to the region start."""
        cached_regions = [MediaRegion("hashA", 5000, 8600)]
        query_regions = [MediaRegion("hashB", 5000, 8600)]
        outcome = validate(9000, cached_regions, query_regions)
        assert outcome == 5000

    def test_match_below_region_start(self):
        """A mismatched image does not matter when the match ends before it."""
        cached_regions = [MediaRegion("hashA", 5000, 8600)]
        query_regions = [MediaRegion("hashB", 5000, 8600)]
        outcome = validate(4000, cached_regions, query_regions)
        assert outcome == 4000

    def test_text_followup_no_images_in_query(self):
        """A query with no regions keeps the full match past a cached image."""
        cached_regions = [MediaRegion("hashA", 5000, 8600)]
        outcome = validate(9000, cached_regions, [])
        assert outcome == 9000

    def test_multiple_images_first_mismatch_truncates(self):
        """With two images, the match is cut at the first one whose hash differs."""
        cached_regions = [
            MediaRegion("hashA", 2000, 4000),
            MediaRegion("hashB", 6000, 8000),
        ]
        query_regions = [
            MediaRegion("hashA", 2000, 4000),
            MediaRegion("hashC", 6000, 8000),
        ]
        outcome = validate(9000, cached_regions, query_regions)
        assert outcome == 6000

    def test_multiple_images_all_match(self):
        """With two images that both agree, the full match is kept."""
        cached_regions = [
            MediaRegion("hashA", 2000, 4000),
            MediaRegion("hashB", 6000, 8000),
        ]
        query_regions = [
            MediaRegion("hashA", 2000, 4000),
            MediaRegion("hashB", 6000, 8000),
        ]
        outcome = validate(9000, cached_regions, query_regions)
        assert outcome == 9000

    def test_no_cached_regions(self):
        """Regions present only in the query do not shorten the match."""
        query_regions = [MediaRegion("hashA", 100, 200)]
        outcome = validate(500, [], query_regions)
        assert outcome == 500

    def test_cached_region_beyond_match(self):
        """A mismatched region lying entirely past the match leaves it intact."""
        cached_regions = [MediaRegion("hashA", 10000, 12000)]
        query_regions = [MediaRegion("hashB", 10000, 12000)]
        outcome = validate(5000, cached_regions, query_regions)
        assert outcome == 5000
|