Files
OpenLLM/openllm-python/src/openllm/models/auto/modeling_flax_auto.py
2023-08-22 08:55:46 -04:00

11 lines
461 B
Python

from __future__ import annotations
import typing as t
from collections import OrderedDict
from .factory import BaseAutoLLMClass, _LazyAutoMapping
from openllm_core.config import CONFIG_MAPPING_NAMES
# Ordered (key, class-name) pairs for the Flax backend. Keys are short model
# identifiers; values are the names of the Flax implementation classes.
# NOTE(review): keys presumably must match entries in CONFIG_MAPPING_NAMES —
# confirm against openllm_core.config.
MODEL_FLAX_MAPPING_NAMES = OrderedDict(
    (
        ("flan_t5", "FlaxFlanT5"),
        ("opt", "FlaxOPT"),
    )
)

# Pairs the config-name table with the Flax class-name table above.
# How names are resolved to classes is defined by _LazyAutoMapping in .factory.
MODEL_FLAX_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FLAX_MAPPING_NAMES,
)
class AutoFlaxLLM(BaseAutoLLMClass):
    """Auto-class bound to the Flax model mapping.

    This class only sets ``_model_mapping``; all lookup/construction behaviour
    is inherited from ``BaseAutoLLMClass`` (defined in ``.factory``).
    """

    # The mapping BaseAutoLLMClass consults for this backend. Parameterizing
    # ClassVar lets type checkers see the attribute's concrete type; with
    # ``from __future__ import annotations`` this has no runtime effect.
    _model_mapping: t.ClassVar[_LazyAutoMapping] = MODEL_FLAX_MAPPING