Files
LocalAI/backend/python/diffusers/test.py
Attila Györffy 5a67b5d73c Fix image upload processing and img2img pipeline in diffusers backend (#8879)
* fix: add missing bufio.Flush in processImageFile

The processImageFile function writes decoded image data (from base64
or URL download) through a bufio.NewWriter but never calls Flush()
before closing the underlying file. Since bufio's default buffer is
4096 bytes, small images produce 0-byte files and large images are
truncated — causing PIL to fail with "cannot identify image file".

This breaks all image input paths: file, files, and ref_images
parameters in /v1/images/generations, making img2img, inpainting,
and reference image features non-functional.

Signed-off-by: Attila Györffy <attila+git@attilagyorffy.com>

* fix: merge options into kwargs in diffusers GenerateImage

The GenerateImage method builds a local `options` dict containing the
source image (PIL), negative_prompt, and num_inference_steps, but
never merges it into `kwargs` before calling self.pipe(**kwargs).
This causes img2img to fail with "Input is in incorrect format"
because the pipeline never receives the image parameter.

Signed-off-by: Attila Györffy <attila+git@attilagyorffy.com>

* test: add unit test for processImageFile base64 decoding

Verifies that a base64-encoded PNG survives the write path
(encode → decode → bufio.Write → Flush → file on disk) with
byte-for-byte fidelity. The test image is small enough to fit
entirely in bufio's 4096-byte buffer, which is the exact scenario
where the missing Flush() produced a 0-byte file.

Also tests that invalid base64 input is handled gracefully.

Signed-off-by: Attila Györffy <attila+git@attilagyorffy.com>

* test: verify GenerateImage merges options into pipeline kwargs

Mocks the diffusers pipeline and calls GenerateImage with a source
image and negative prompt. Asserts that the pipeline receives the
image, negative_prompt, and num_inference_steps via kwargs — the
exact parameters that were silently dropped before the fix.

Signed-off-by: Attila Györffy <attila+git@attilagyorffy.com>

* fix: move kwargs.update(options) earlier in GenerateImage

Move the options merge right after self.options merge (L742) so that
image, negative_prompt, and num_inference_steps are available to all
downstream code paths including img2vid and txt2vid.

Signed-off-by: Attila Györffy <attila+git@attilagyorffy.com>

* test: convert processImageFile tests to ginkgo

Replace standard testing with ginkgo/gomega to be consistent with
the rest of the test suites in the project.

Signed-off-by: Attila Györffy <attila+git@attilagyorffy.com>

---------

Signed-off-by: Attila Györffy <attila+git@attilagyorffy.com>
Co-authored-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
2026-03-11 08:05:50 +01:00

376 lines
14 KiB
Python

"""
A test script to test the gRPC service and dynamic loader
"""
import unittest
import subprocess
import time
from unittest.mock import patch, MagicMock
# Import dynamic loader for testing (these don't need gRPC)
import diffusers_dynamic_loader as loader
from diffusers import DiffusionPipeline, StableDiffusionPipeline
# Try to import gRPC modules - may not be available during unit testing
# Try to import gRPC modules - they may not be available during unit testing.
# The flag lets gRPC-dependent test classes be skipped instead of erroring.
GRPC_AVAILABLE = False
try:
    import grpc
    import backend_pb2
    import backend_pb2_grpc
except ImportError:
    pass
else:
    GRPC_AVAILABLE = True
@unittest.skipUnless(GRPC_AVAILABLE, "gRPC modules not available")
class TestBackendServicer(unittest.TestCase):
    """
    TestBackendServicer is the class that tests the gRPC service.

    unittest calls setUp before and tearDown after every test, so each test
    gets exactly one fresh ``backend.py`` server process. Tests must not call
    setUp/tearDown themselves: doing so spawns a second server and orphans
    the first one (it is never killed).
    """

    # Address the spawned server listens on and the tests connect to.
    ADDRESS = "localhost:50051"
    # Seconds to wait after spawning the server before connecting to it.
    STARTUP_DELAY = 20

    def setUp(self):
        """
        This method sets up the gRPC service by starting the server and
        waiting for it to come up.
        """
        self.service = subprocess.Popen(["python3", "backend.py", "--addr", self.ADDRESS])
        # Crude readiness wait: give the server time to bind before the
        # test opens a channel.
        time.sleep(self.STARTUP_DELAY)

    def tearDown(self) -> None:
        """
        This method tears down the gRPC service by killing the server
        process and reaping it.
        """
        self.service.kill()
        self.service.wait()

    def test_server_startup(self):
        """
        This method tests if the server starts up successfully
        """
        try:
            with grpc.insecure_channel(self.ADDRESS) as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.Health(backend_pb2.HealthMessage())
        except grpc.RpcError as err:
            # Only RPC failures mean "server did not start"; assertion
            # failures below are reported with their own details.
            self.fail(f"Server failed to start: {err}")
        self.assertEqual(response.message, b'OK')

    def test_load_model(self):
        """
        This method tests if the model is loaded successfully
        """
        try:
            with grpc.insecure_channel(self.ADDRESS) as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="Lykon/dreamshaper-8"))
        except grpc.RpcError as err:
            self.fail(f"LoadModel service failed: {err}")
        self.assertTrue(response.success)
        self.assertEqual(response.message, "Model loaded successfully")

    def test(self):
        """
        This method tests if the backend can generate images
        """
        try:
            with grpc.insecure_channel(self.ADDRESS) as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                load_response = stub.LoadModel(backend_pb2.ModelOptions(Model="Lykon/dreamshaper-8"))
                print(load_response.message)
                self.assertTrue(load_response.success)
                image_req = backend_pb2.GenerateImageRequest(
                    positive_prompt="cat", width=16, height=16, dst="test.jpg"
                )
                # Named 'result' rather than 're' to avoid shadowing the stdlib module.
                result = stub.GenerateImage(image_req)
        except grpc.RpcError as err:
            self.fail(f"Image gen service failed: {err}")
        self.assertTrue(result.success)
class TestDiffusersDynamicLoader(unittest.TestCase):
"""Test cases for the diffusers dynamic loader functionality."""
@classmethod
def setUpClass(cls):
"""Set up test fixtures - clear caches to ensure fresh discovery."""
# Reset the caches to ensure fresh discovery
loader._pipeline_registry = None
loader._task_aliases = None
def test_camel_to_kebab_conversion(self):
"""Test CamelCase to kebab-case conversion."""
test_cases = [
("StableDiffusionPipeline", "stable-diffusion-pipeline"),
("StableDiffusionXLPipeline", "stable-diffusion-xl-pipeline"),
("FluxPipeline", "flux-pipeline"),
("DiffusionPipeline", "diffusion-pipeline"),
]
for input_val, expected in test_cases:
with self.subTest(input=input_val):
result = loader._camel_to_kebab(input_val)
self.assertEqual(result, expected)
def test_extract_task_keywords(self):
"""Test task keyword extraction from class names."""
# Test text-to-image detection
aliases = loader._extract_task_keywords("StableDiffusionPipeline")
self.assertIn("stable-diffusion", aliases)
# Test img2img detection
aliases = loader._extract_task_keywords("StableDiffusionImg2ImgPipeline")
self.assertIn("image-to-image", aliases)
self.assertIn("img2img", aliases)
# Test inpainting detection
aliases = loader._extract_task_keywords("StableDiffusionInpaintPipeline")
self.assertIn("inpainting", aliases)
self.assertIn("inpaint", aliases)
# Test depth2img detection
aliases = loader._extract_task_keywords("StableDiffusionDepth2ImgPipeline")
self.assertIn("depth-to-image", aliases)
def test_discover_pipelines_finds_known_classes(self):
"""Test that pipeline discovery finds at least one known pipeline class."""
registry = loader.get_pipeline_registry()
# Check that the registry is not empty
self.assertGreater(len(registry), 0, "Pipeline registry should not be empty")
# Check for known pipeline classes
known_pipelines = [
"StableDiffusionPipeline",
"DiffusionPipeline",
]
for pipeline_name in known_pipelines:
with self.subTest(pipeline=pipeline_name):
self.assertIn(
pipeline_name,
registry,
f"Expected to find {pipeline_name} in registry"
)
def test_discover_pipelines_caches_results(self):
"""Test that pipeline discovery results are cached."""
# Get registry twice
registry1 = loader.get_pipeline_registry()
registry2 = loader.get_pipeline_registry()
# Should be the same object (cached)
self.assertIs(registry1, registry2, "Registry should be cached")
def test_get_available_pipelines(self):
"""Test getting list of available pipelines."""
available = loader.get_available_pipelines()
# Should return a list
self.assertIsInstance(available, list)
# Should contain known pipelines
self.assertIn("StableDiffusionPipeline", available)
self.assertIn("DiffusionPipeline", available)
# Should be sorted
self.assertEqual(available, sorted(available))
def test_get_available_tasks(self):
"""Test getting list of available task aliases."""
tasks = loader.get_available_tasks()
# Should return a list
self.assertIsInstance(tasks, list)
# Should be sorted
self.assertEqual(tasks, sorted(tasks))
def test_resolve_pipeline_class_by_name(self):
"""Test resolving pipeline class by exact name."""
cls = loader.resolve_pipeline_class(class_name="StableDiffusionPipeline")
self.assertEqual(cls, StableDiffusionPipeline)
def test_resolve_pipeline_class_by_name_case_insensitive(self):
"""Test that class name resolution is case-insensitive."""
cls1 = loader.resolve_pipeline_class(class_name="StableDiffusionPipeline")
cls2 = loader.resolve_pipeline_class(class_name="stablediffusionpipeline")
self.assertEqual(cls1, cls2)
def test_resolve_pipeline_class_by_task(self):
"""Test resolving pipeline class by task alias."""
# Get the registry to find available tasks
aliases = loader.get_task_aliases()
# Test with a common task that should be available
if "stable-diffusion" in aliases:
cls = loader.resolve_pipeline_class(task="stable-diffusion")
self.assertIsNotNone(cls)
def test_resolve_pipeline_class_unknown_name_raises(self):
"""Test that resolving unknown class name raises ValueError with helpful message."""
with self.assertRaises(ValueError) as ctx:
loader.resolve_pipeline_class(class_name="NonExistentPipeline")
# Check that error message includes available pipelines
error_msg = str(ctx.exception)
self.assertIn("Unknown pipeline class", error_msg)
self.assertIn("Available pipelines", error_msg)
def test_resolve_pipeline_class_unknown_task_raises(self):
"""Test that resolving unknown task raises ValueError with helpful message."""
with self.assertRaises(ValueError) as ctx:
loader.resolve_pipeline_class(task="nonexistent-task-xyz")
# Check that error message includes available tasks
error_msg = str(ctx.exception)
self.assertIn("Unknown task", error_msg)
self.assertIn("Available tasks", error_msg)
def test_resolve_pipeline_class_no_params_raises(self):
"""Test that calling with no parameters raises helpful ValueError."""
with self.assertRaises(ValueError) as ctx:
loader.resolve_pipeline_class()
error_msg = str(ctx.exception)
self.assertIn("Must provide at least one of", error_msg)
def test_get_pipeline_info(self):
"""Test getting pipeline information."""
info = loader.get_pipeline_info("StableDiffusionPipeline")
self.assertEqual(info['name'], "StableDiffusionPipeline")
self.assertIsInstance(info['aliases'], list)
self.assertIsInstance(info['supports_single_file'], bool)
def test_get_pipeline_info_unknown_raises(self):
"""Test that getting info for unknown pipeline raises ValueError."""
with self.assertRaises(ValueError) as ctx:
loader.get_pipeline_info("NonExistentPipeline")
self.assertIn("Unknown pipeline", str(ctx.exception))
def test_discover_diffusers_classes_pipelines(self):
"""Test generic class discovery for DiffusionPipeline."""
classes = loader.discover_diffusers_classes("DiffusionPipeline")
# Should return a dict
self.assertIsInstance(classes, dict)
# Should contain known pipeline classes
self.assertIn("DiffusionPipeline", classes)
self.assertIn("StableDiffusionPipeline", classes)
def test_discover_diffusers_classes_caches_results(self):
"""Test that class discovery results are cached."""
classes1 = loader.discover_diffusers_classes("DiffusionPipeline")
classes2 = loader.discover_diffusers_classes("DiffusionPipeline")
# Should be the same object (cached)
self.assertIs(classes1, classes2)
def test_discover_diffusers_classes_exclude_base(self):
"""Test discovering classes without base class."""
classes = loader.discover_diffusers_classes("DiffusionPipeline", include_base=False)
# Should still contain subclasses
self.assertIn("StableDiffusionPipeline", classes)
def test_get_available_classes(self):
"""Test getting list of available classes for a base class."""
classes = loader.get_available_classes("DiffusionPipeline")
# Should return a sorted list
self.assertIsInstance(classes, list)
self.assertEqual(classes, sorted(classes))
# Should contain known classes
self.assertIn("StableDiffusionPipeline", classes)
class TestDiffusersDynamicLoaderWithMocks(unittest.TestCase):
"""Test cases using mocks to test edge cases."""
def test_load_pipeline_requires_model_id(self):
"""Test that load_diffusers_pipeline requires model_id."""
with self.assertRaises(ValueError) as ctx:
loader.load_diffusers_pipeline(class_name="StableDiffusionPipeline")
self.assertIn("model_id is required", str(ctx.exception))
def test_resolve_with_model_id_uses_diffusion_pipeline_fallback(self):
"""Test that resolving with only model_id falls back to DiffusionPipeline."""
# When model_id is provided, if hub lookup is not successful,
# should fall back to DiffusionPipeline.
# This tests the fallback behavior - the actual hub lookup may succeed
# or fail depending on network, but the fallback path should work.
cls = loader.resolve_pipeline_class(model_id="some/nonexistent/model")
self.assertEqual(cls, DiffusionPipeline)
@unittest.skipUnless(GRPC_AVAILABLE, "gRPC modules not available")
class TestGenerateImageOptionsKwargsMerge(unittest.TestCase):
    """Test that GenerateImage merges the options dict into pipeline kwargs.

    The options dict holds image (PIL), negative_prompt, and
    num_inference_steps. Without the merge, img2img pipelines never
    receive the source image and fail with 'Input is in incorrect format'.
    """

    def test_options_merged_into_pipeline_kwargs(self):
        import os
        import tempfile

        from backend import BackendServicer
        from PIL import Image

        # Bypass __init__ so no real model is loaded; set only the minimal
        # attributes the method reads.
        svc = BackendServicer.__new__(BackendServicer)
        svc.pipe = MagicMock()
        svc.pipe.return_value.images = [Image.new("RGB", (4, 4))]
        svc.cfg_scale = 7.5
        svc.controlnet = None
        svc.img2vid = False
        svc.txt2vid = False
        svc.clip_skip = 0
        svc.PipelineType = "StableDiffusionImg2ImgPipeline"
        svc.options = {}

        # The temporary directory removes both files even if writing the
        # source image or calling GenerateImage raises.
        with tempfile.TemporaryDirectory() as tmp_dir:
            src_path = os.path.join(tmp_dir, "src.png")
            dst_path = os.path.join(tmp_dir, "dst.png")
            # Tiny source image for the request's src field.
            Image.new("RGB", (4, 4), color="red").save(src_path, format="PNG")
            # Pre-create an empty destination file, as the original setup did.
            with open(dst_path, "wb"):
                pass

            request = MagicMock()
            request.positive_prompt = "a test prompt"
            request.negative_prompt = "bad quality"
            request.step = 10
            request.seed = 0
            request.width = 0
            request.height = 0
            request.src = src_path
            request.ref_images = []
            request.dst = dst_path
            svc.GenerateImage(request, context=None)

        # The pipeline must have been called with the image kwarg
        svc.pipe.assert_called_once()
        _, call_kwargs = svc.pipe.call_args
        self.assertIn("image", call_kwargs,
                      "source image must be passed to pipeline via kwargs")
        self.assertIn("negative_prompt", call_kwargs,
                      "negative_prompt must be passed to pipeline via kwargs")
        self.assertEqual(call_kwargs["num_inference_steps"], 10)