feat(diffusers): add experimental support for sd_embed-style prompt embedding (#8504)

* add experimental support for sd_embed-style prompt embedding

Signed-off-by: Austen Dicken <cvpcsm@gmail.com>

* add documentation for sd_embed, equivalent to the existing compel docs

Signed-off-by: Austen Dicken <cvpcsm@gmail.com>

* need to use flux1 embedding function for flux model

Signed-off-by: Austen Dicken <cvpcsm@gmail.com>

---------

Signed-off-by: Austen Dicken <cvpcsm@gmail.com>
Author: Austen
Date: 2026-02-11 15:58:19 -06:00 (committed by GitHub)
Parent: 79a25f7ae9
Commit: cff972094c
10 changed files with 58 additions and 0 deletions


@@ -40,6 +40,7 @@ from compel import Compel, ReturnedEmbeddingsType
from optimum.quanto import freeze, qfloat8, quantize
from transformers import T5EncoderModel
from safetensors.torch import load_file
from sd_embed.embedding_funcs import get_weighted_text_embeddings_sd15, get_weighted_text_embeddings_sdxl, get_weighted_text_embeddings_sd3, get_weighted_text_embeddings_flux1
# Import LTX-2 specific utilities
from diffusers.pipelines.ltx2.export_utils import encode_video as ltx2_encode_video
@@ -47,6 +48,7 @@ from diffusers import LTX2VideoTransformer3DModel, GGUFQuantizationConfig
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
COMPEL = os.environ.get("COMPEL", "0") == "1"
SD_EMBED = os.environ.get("SD_EMBED", "0") == "1"
XPU = os.environ.get("XPU", "0") == "1"
CLIPSKIP = os.environ.get("CLIPSKIP", "1") == "1"
SAFETENSORS = os.environ.get("SAFETENSORS", "1") == "1"
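
The new SD_EMBED flag mirrors the existing COMPEL toggle: it is read from the process environment once at import time, and the generation code further down checks COMPEL first, so compel takes precedence if both are set to "1". A minimal sketch of enabling the new path for the current process (how the backend process itself is launched is outside this diff):

    import os

    # Must be set before the backend module is imported, because the flag is
    # evaluated at import time as SD_EMBED = os.environ.get("SD_EMBED", "0") == "1".
    os.environ["SD_EMBED"] = "1"
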
@@ -737,6 +739,51 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
kwargs["prompt_embeds"] = conditioning
kwargs["pooled_prompt_embeds"] = pooled
# pass the kwargs dictionary to the self.pipe method
image = self.pipe(
guidance_scale=self.cfg_scale,
**kwargs
).images[0]
elif SD_EMBED:
if self.PipelineType == "StableDiffusionPipeline":
(
kwargs["prompt_embeds"],
kwargs["negative_prompt_embeds"],
) = get_weighted_text_embeddings_sd15(
pipe = self.pipe,
prompt = prompt,
neg_prompt = request.negative_prompt if hasattr(request, 'negative_prompt') else None,
)
if self.PipelineType == "StableDiffusionXLPipeline":
(
kwargs["prompt_embeds"],
kwargs["negative_prompt_embeds"],
kwargs["pooled_prompt_embeds"],
kwargs["negative_pooled_prompt_embeds"],
) = get_weighted_text_embeddings_sdxl(
pipe = self.pipe,
prompt = prompt,
neg_prompt = request.negative_prompt if hasattr(request, 'negative_prompt') else None
)
if self.PipelineType == "StableDiffusion3Pipeline":
(
kwargs["prompt_embeds"],
kwargs["negative_prompt_embeds"],
kwargs["pooled_prompt_embeds"],
kwargs["negative_pooled_prompt_embeds"],
) = get_weighted_text_embeddings_sd3(
pipe = self.pipe,
prompt = prompt,
neg_prompt = request.negative_prompt if hasattr(request, 'negative_prompt') else None
)
if self.PipelineType == "FluxTransformer2DModel":
(
kwargs["prompt_embeds"],
kwargs["pooled_prompt_embeds"],
) = get_weighted_text_embeddings_flux1(
pipe = self.pipe,
prompt = prompt,
)
image = self.pipe(
guidance_scale=self.cfg_scale,
**kwargs
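
For reference, the same sd_embed helpers can be exercised outside the backend. The sketch below is a minimal, hypothetical SDXL example (the model id, prompts, and output path are illustrative, and it assumes sd_embed's (token:weight) emphasis syntax; the unpacking order mirrors the SDXL branch above):

    import torch
    from diffusers import StableDiffusionXLPipeline
    from sd_embed.embedding_funcs import get_weighted_text_embeddings_sdxl

    # Load any SDXL checkpoint; this model id is only an example.
    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=torch.float16,
    ).to("cuda")  # or "cpu"

    prompt = "a photo of a (red:1.3) fox in a snowy forest, highly detailed"
    neg_prompt = "blurry, low quality"

    # Same return order as used in the SDXL branch of the diff above.
    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = get_weighted_text_embeddings_sdxl(
        pipe=pipe,
        prompt=prompt,
        neg_prompt=neg_prompt,
    )

    # Pass the precomputed embeddings instead of raw prompt strings.
    image = pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        guidance_scale=7.0,
    ).images[0]
    image.save("output.png")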