Add adaptor registry

This commit is contained in:
ciaranbor
2025-12-20 17:06:22 +00:00
parent 46181a35ae
commit 3bcdd46bb1
2 changed files with 49 additions and 16 deletions

View File

@@ -21,30 +21,15 @@ from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.download.download_utils import build_model_path
from exo.worker.engines.mflux.config import get_config_for_model
from exo.worker.engines.mflux.config.model_config import ImageModelConfig
from exo.worker.engines.mflux.pipefusion import get_adapter_for_model
from exo.worker.engines.mflux.pipefusion.adapter import ModelAdapter
from exo.worker.engines.mflux.pipefusion.distributed_denoising import (
DistributedDenoising,
)
from exo.worker.engines.mflux.pipefusion.flux_adapter import FluxModelAdapter
from exo.worker.engines.mlx.utils_mlx import mlx_distributed_init, mx_barrier
from exo.worker.runner.bootstrap import logger
def get_adapter_for_model(config: ImageModelConfig) -> ModelAdapter:
"""Get the appropriate adapter for a model configuration.
Args:
config: The model configuration
Returns:
A ModelAdapter instance for the model family
"""
if config.model_family == "flux":
return FluxModelAdapter(config)
else:
raise ValueError(f"No adapter found for model family: {config.model_family}")
class DistributedImageModel:
"""
Model-agnostic wrapper for distributed image generation.

View File

@@ -0,0 +1,48 @@
"""
Adapter registry for model-specific operations.
This module provides a registry pattern for managing model adapters,
allowing new model families to be added without modifying core code.
"""
from typing import Callable
from exo.worker.engines.mflux.config.model_config import ImageModelConfig
from exo.worker.engines.mflux.pipefusion.adapter import ModelAdapter
from exo.worker.engines.mflux.pipefusion.flux_adapter import FluxModelAdapter
# Type alias for adapter factory functions
AdapterFactory = Callable[[ImageModelConfig], ModelAdapter]
# Registry maps model_family string to adapter factory
_ADAPTER_REGISTRY: dict[str, AdapterFactory] = {
"flux": FluxModelAdapter,
}
def get_adapter_for_model(config: ImageModelConfig) -> ModelAdapter:
"""Get the appropriate adapter for a model configuration.
Args:
config: The model configuration
Returns:
A ModelAdapter instance for the model family
Raises:
ValueError: If no adapter is registered for the model family
"""
factory = _ADAPTER_REGISTRY.get(config.model_family)
if factory is None:
raise ValueError(f"No adapter found for model family: {config.model_family}")
return factory(config)
def register_adapter(model_family: str, factory: AdapterFactory) -> None:
"""Register a new adapter factory for a model family.
Args:
model_family: The model family identifier (e.g., "flux", "fibo", "qwen")
factory: A callable that takes an ImageModelConfig and returns a ModelAdapter
"""
_ADAPTER_REGISTRY[model_family] = factory