"""
|
|
This type stub file was generated by pyright.
|
|
"""
|
|
|
|
from typing import Any, Callable, List, Optional, Tuple, Union
|
|
|
|
import mlx.core as mx
|
|
|
|
class Module(dict):
    """Base class for building neural networks with MLX.

    All the layers provided in :mod:`layers` subclass this class and
    your models should do the same.

    A ``Module`` can contain other ``Module`` instances or :class:`mlx.core.array`
    instances in arbitrary nesting of Python lists or dicts. The ``Module``
    then allows recursively extracting all the :class:`mlx.core.array` instances
    using :meth:`Module.parameters`.

    In addition, the ``Module`` has the concept of trainable and non-trainable
    parameters (called "frozen"). When using :func:`value_and_grad`
    the gradients are returned only with respect to the trainable parameters.
    All arrays in a module are trainable unless they are added to the "frozen"
    set by calling :meth:`freeze`.

    .. code-block:: python

        import mlx.core as mx
        import mlx.nn as nn

        class MyMLP(nn.Module):
            def __init__(self, in_dims: int, out_dims: int, hidden_dims: int = 16):
                super().__init__()

                self.in_proj = nn.Linear(in_dims, hidden_dims)
                self.out_proj = nn.Linear(hidden_dims, out_dims)

            def __call__(self, x):
                x = self.in_proj(x)
                x = mx.maximum(x, 0)
                return self.out_proj(x)

        model = MyMLP(2, 1)

        # All the model parameters are created but since MLX is lazy by
        # default, they are not evaluated yet. Calling `mx.eval` actually
        # allocates memory and initializes the parameters.
        mx.eval(model.parameters())

        # Setting a parameter to a new value is as simple as accessing that
        # parameter and assigning a new array to it.
        model.in_proj.weight = model.in_proj.weight * 2
        mx.eval(model.parameters())
    """

    __call__: Callable

    def __init__(self) -> None:
        """Should be called by the subclasses of ``Module``."""

    @property
    def training(self): # -> bool:
        """Boolean indicating if the model is in training mode."""

    @property
    def state(self): # -> Self:
        """The module's state dictionary.

        The module's state dictionary contains any attribute set on the
        module, including the parameters in :meth:`Module.parameters`.

        Unlike :meth:`Module.parameters`, the :attr:`Module.state` property is
        a reference to the module's state. Updates to it will be reflected in
        the original module.
        """

    def __repr__(self): # -> str:
        ...

    def __getattr__(self, key: str): # -> None:
        ...

    def __setattr__(self, key: str, val: Any): # -> None:
        ...

    def __delattr__(self, name): # -> None:
        ...

    def load_weights(
        self,
        file_or_weights: Union[str, List[Tuple[str, mx.array]]],
        strict: bool = ...,
    ) -> Module:
"""
|
|
Update the model's weights from a ``.npz``, a ``.safetensors`` file, or a list.
|
|
|
|
Args:
|
|
file_or_weights (str or list(tuple(str, mx.array))): The path to
|
|
the weights ``.npz`` file (``.npz`` or ``.safetensors``) or a list
|
|
of pairs of parameter names and arrays.
|
|
strict (bool, optional): If ``True`` then checks that the provided
|
|
weights exactly match the parameters of the model. Otherwise,
|
|
only the weights actually contained in the model are loaded and
|
|
shapes are not checked. Default: ``True``.
|
|
|
|
Returns:
|
|
The module instance after updating the weights.
|
|
|
|
Example:
|
|
|
|
.. code-block:: python
|
|
|
|
import mlx.core as mx
|
|
import mlx.nn as nn
|
|
model = nn.Linear(10, 10)
|
|
|
|
# Load from file
|
|
model.load_weights("weights.npz")
|
|
|
|
# Load from .safetensors file
|
|
model.load_weights("weights.safetensors")
|
|
|
|
# Load from list
|
|
weights = [
|
|
("weight", mx.random.uniform(shape=(10, 10))),
|
|
("bias", mx.zeros((10,))),
|
|
]
|
|
model.load_weights(weights)
|
|
|
|
# Missing weight
|
|
weights = [
|
|
("weight", mx.random.uniform(shape=(10, 10))),
|
|
]
|
|
|
|
# Raises a ValueError exception
|
|
model.load_weights(weights)
|
|
|
|
# Ok, only updates the weight but not the bias
|
|
model.load_weights(weights, strict=False)
|
|
"""
|
|
|
|
    def save_weights(self, file: str): # -> None:
        """
        Save the model's weights to a file. The saving method is determined
        by the file extension:

        - ``.npz`` will use :func:`mx.savez`
        - ``.safetensors`` will use :func:`mx.save_safetensors`
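
        For example (a minimal sketch; the file names are placeholders):

        .. code-block:: python

            model.save_weights("model.safetensors")  # written with mx.save_safetensors
            model.save_weights("model.npz")          # written with mx.savez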
"""
|
|
|
|
    @staticmethod
    def is_module(value): # -> bool:
        ...

    @staticmethod
    def valid_child_filter(module, key, value): # -> bool:
        ...

    @staticmethod
    def valid_parameter_filter(module, key, value): # -> bool:
        ...

    @staticmethod
    def trainable_parameter_filter(module, key, value): # -> bool:
        ...

    def filter_and_map(
        self,
        filter_fn: Callable[[Module, str, Any], bool],
        map_fn: Optional[Callable] = ...,
        is_leaf_fn: Optional[Callable[[Module, str, Any], bool]] = ...,
    ): # -> dict[Any, Any | dict[Any, Any | dict[Any, Any] | list[Any]] | dict[Any, Any] | list[Any]]:
        """Recursively filter the contents of the module using ``filter_fn``,
        namely only select keys and values where ``filter_fn`` returns ``True``.

        This is used to implement :meth:`parameters` and :meth:`trainable_parameters`
        but it can also be used to extract any subset of the module's parameters.

        Args:
            filter_fn (Callable): Given a value, the key in which it is found,
                and the containing module, decide whether to keep the value or
                drop it.
            map_fn (Callable, optional): Optionally transform the value before
                returning it.
            is_leaf_fn (Callable, optional): Given a value, the key in which it
                is found, and the containing module, decide if it is a leaf.

        Returns:
            A dictionary containing the contents of the module recursively
            filtered.
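
        For example (a sketch grounded in the note above that
        :meth:`parameters` is built on this method, and assuming ``map_fn``
        is applied to the leaf arrays; ``model`` is any ``Module`` instance):

        .. code-block:: python

            # Reproduce model.parameters() with the default parameter filter.
            params = model.filter_and_map(nn.Module.valid_parameter_filter)

            # Same traversal, but map each kept array to just its shape.
            shapes = model.filter_and_map(
                nn.Module.valid_parameter_filter,
                map_fn=lambda value: value.shape,
            )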
"""
|
|
|
|
    def parameters(
        self,
    ) -> mx.MX_ARRAY_TREE:
        """Recursively return all the :class:`mlx.core.array` members of this
        Module as a dict of dicts and lists.
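
        For example, one way to inspect the returned tree (a sketch using
        :func:`mlx.utils.tree_flatten`; ``model`` is assumed to be the
        ``MyMLP`` instance from the class example above):

        .. code-block:: python

            from mlx.utils import tree_flatten

            # Flatten the nested dicts into ("in_proj.weight", array) pairs.
            for name, param in tree_flatten(model.parameters()):
                print(name, param.shape)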
        """

    def trainable_parameters(
        self,
    ) -> mx.MX_ARRAY_TREE: # -> dict[Any, Any | dict[Any, Any | dict[Any, Any] | list[Any]] | dict[Any, Any] | list[Any]]:
        """Recursively return all the non-frozen :class:`mlx.core.array` members
        of this Module as a dict of dicts and lists."""

    def children(
        self,
    ) -> mx.MX_ARRAY_TREE: # -> dict[Any, Any | dict[Any, Any | dict[Any, Any] | list[Any]] | dict[Any, Any] | list[Any]]:
        """Return the direct descendants of this Module instance."""

    def leaf_modules(
        self,
    ) -> mx.MX_ARRAY_TREE: # -> dict[Any, Any | dict[Any, Any | dict[Any, Any] | list[Any]] | dict[Any, Any] | list[Any]]:
        """Return the submodules that do not contain other modules."""

    def update(self, parameters: dict, strict: bool = ...) -> Module:
        """Replace the parameters of this Module with the provided ones in the
        dict of dicts and lists.

        Commonly used by the optimizer to change the model to the updated
        (optimized) parameters. Also used by :meth:`value_and_grad` to set the
        tracers in the model in order to compute gradients.

        The passed-in parameters dictionary need not be the full dictionary
        returned by :meth:`parameters`. Only the provided locations will be
        updated.

        Args:
            parameters (dict): A complete or partial dictionary of the module's
                parameters.
            strict (bool): If ``True`` checks that ``parameters`` is a
                subset of the module's parameters. Default: ``True``.

        Returns:
            The module instance after updating the parameters.
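
        For example (a minimal sketch; ``model`` is assumed to be the
        ``MyMLP`` instance from the class example above, so ``in_proj`` has a
        bias of shape ``(16,)``):

        .. code-block:: python

            # Overwrite one nested parameter; everything else is untouched
            # because partial dictionaries are allowed.
            model.update({"in_proj": {"bias": mx.zeros((16,))}})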
"""
|
|
|
|
    def apply(
        self,
        map_fn: Callable[[mx.array], mx.array],
        filter_fn: Optional[Callable[[Module, str, Any], bool]] = ...,
    ) -> Module:
        """Map all the parameters using the provided ``map_fn`` and immediately
        update the module with the mapped parameters.

        For instance, running ``model.apply(lambda x: x.astype(mx.float16))``
        casts all parameters to 16-bit floats.

        Args:
            map_fn (Callable): Maps an array to another array.
            filter_fn (Callable, optional): Filter to select which arrays to
                map (default: :meth:`Module.valid_parameter_filter`).

        Returns:
            The module instance after updating the parameters.
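
        As a runnable form of the cast mentioned above (a sketch; ``model``
        is any ``Module`` instance):

        .. code-block:: python

            # Cast every parameter to 16-bit floats, in place.
            model.apply(lambda x: x.astype(mx.float16))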
"""
|
|
|
|
    def update_modules(self, modules: dict, strict: bool = ...) -> Module:
        """Replace the child modules of this :class:`Module` instance with the
        provided ones in the dict of dicts and lists.

        It is the equivalent of :meth:`Module.update` but for modules instead
        of parameters, and it allows us to flexibly edit complex architectures
        by programmatically swapping layers.

        The passed-in modules dictionary need not be the full dictionary
        returned by :meth:`modules`. Only the provided locations will be
        updated.

        Args:
            modules (dict): A complete or partial dictionary of the module's
                submodules.
            strict (bool): If ``True`` checks that ``modules`` is a
                subset of the child modules of this instance. Default: ``True``.

        Returns:
            The module instance after updating the submodules.
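
        For example (a minimal sketch; ``model`` is assumed to be the
        ``MyMLP`` instance from the class example above):

        .. code-block:: python

            # Swap the output projection for a freshly initialized layer
            # with a different output size.
            model.update_modules({"out_proj": nn.Linear(16, 2)})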
"""
|
|
|
|
    def apply_to_modules(self, apply_fn: Callable[[str, Module], Any]) -> Module:
        """Apply a function to all the modules in this instance (including this
        instance).

        Args:
            apply_fn (Callable): The function to apply to the modules.

        Returns:
            The module instance after updating the submodules.
        """

    def modules(self): # -> list[Any]:
        """Return a list with all the modules in this instance.

        Returns:
            A list of :class:`Module` instances.
        """

    def named_modules(self): # -> list[Any]:
        """Return a list with all the modules in this instance and their names
        with dot notation.

        Returns:
            A list of tuples (str, :class:`Module`).
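
        For example (a sketch; ``model`` is any ``Module`` instance):

        .. code-block:: python

            # Print every submodule with its dotted path.
            for name, module in model.named_modules():
                print(name, type(module).__name__)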
"""
|
|
|
|
    def freeze(
        self,
        *,
        recurse: bool = ...,
        keys: Optional[Union[str, List[str]]] = ...,
        strict: bool = ...,
    ) -> Module:
        """Freeze the Module's parameters or some of them. Freezing a parameter
        means not computing gradients for it.

        This function is idempotent, i.e., freezing a frozen model is a no-op.

        Example:
            For instance, to only train the attention parameters from a
            Transformer:

            .. code-block:: python

                model = nn.Transformer()
                model.freeze()
                model.apply_to_modules(lambda k, v: v.unfreeze() if k.endswith("attention") else None)

        Args:
            recurse (bool, optional): If ``True`` then freeze the parameters of
                the submodules as well. Default: ``True``.
            keys (str or list[str], optional): If provided, then only these
                parameters will be frozen; otherwise, all the parameters of the
                module. For instance, freeze all biases by calling
                ``module.freeze(keys="bias")``.
            strict (bool, optional): If set to ``True``, validate that the
                passed keys exist. Default: ``False``.

        Returns:
            The module instance after freezing the parameters.
        """

    def unfreeze(
        self,
        *,
        recurse: bool = ...,
        keys: Optional[Union[str, List[str]]] = ...,
        strict: bool = ...,
    ) -> Module:
        """Unfreeze the Module's parameters or some of them.

        This function is idempotent, i.e., unfreezing a model that is not
        frozen is a no-op.

        Example:
            For instance, to only train the biases of a Transformer one can do:

            .. code-block:: python

                model = nn.Transformer()
                model.freeze()
                model.unfreeze(keys="bias")

        Args:
            recurse (bool, optional): If ``True`` then unfreeze the parameters
                of the submodules as well. Default: ``True``.
            keys (str or list[str], optional): If provided, then only these
                parameters will be unfrozen; otherwise, all the parameters of
                the module. For instance, unfreeze all biases by calling
                ``module.unfreeze(keys="bias")``.
            strict (bool, optional): If set to ``True``, validate that the
                passed keys exist. Default: ``False``.

        Returns:
            The module instance after unfreezing the parameters.
        """

    def train(self, mode: bool = ...) -> Module:
        """Set the model in or out of training mode.

        Training mode only applies to certain layers. For example,
        :obj:`Dropout` applies a random mask in training mode, but is the
        identity in evaluation mode.

        Args:
            mode (bool): Indicate if the model should be in training or
                evaluation mode. Default: ``True``.

        Returns:
            The module instance after updating the training mode.
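
        For example (a sketch; ``model`` is any ``Module`` instance):

        .. code-block:: python

            model.train()       # e.g. Dropout now applies its random mask
            model.train(False)  # back to evaluation behavior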
"""
|
|
|
|
    def eval(self) -> Module:
        """Set the model to evaluation mode.

        See :func:`train`.
        """

    def set_dtype(
        self, dtype: mx.Dtype, predicate: Optional[Callable[[mx.Dtype], bool]] = ...
    ): # -> None:
        """Set the dtype of the module's parameters.

        Args:
            dtype (Dtype): The new dtype.
            predicate (typing.Callable, optional): A predicate to select
                parameters to cast. By default, only parameters of type
                :attr:`floating` will be updated to avoid casting integer
                parameters to the new dtype.
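
        For example (a sketch; ``model`` is any ``Module`` instance, and the
        default predicate leaves integer parameters at their current dtype):

        .. code-block:: python

            # Cast the floating-point parameters to bfloat16.
            model.set_dtype(mx.bfloat16)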
"""
|