exo/.mlx_typings/mlx/nn/utils.pyi

"""
This type stub file was generated by pyright.
"""
from typing import Any, Callable, Optional
import mlx.core as mx
from .layers.base import Module
def value_and_grad(
    model: Module, fn: Callable
):  # -> _Wrapped[..., Any, ..., tuple[Any, Any]]:
    """Transform the passed function ``fn`` to a function that computes the
    gradients of ``fn`` wrt the model's trainable parameters and also its
    value.

    Args:
        model (Module): The model whose trainable parameters to compute
            gradients for
        fn (Callable): The scalar function to compute gradients for

    Returns:
        A callable that returns the value of ``fn`` and the gradients wrt the
        trainable parameters of ``model``
    """
    ...
def checkpoint(
    module: Module, fn: Optional[Callable] = ...
):  # -> _Wrapped[..., Any, ..., Any]:
    """Transform the passed callable to one that performs gradient
    checkpointing with respect to the trainable parameters of the module (and
    the callable's inputs).

    Args:
        module (Module): The module for whose parameters we will be
            performing gradient checkpointing.
        fn (Callable, optional): The function to checkpoint. If not provided it
            defaults to the provided module.

    Returns:
        A callable that saves the inputs and outputs during the forward pass
        and recomputes all intermediate states during the backward pass.
    """
    ...
def average_gradients(
    gradients: Any,
    group: Optional[mx.distributed.Group] = ...,
    all_reduce_size: int = ...,
    communication_type: Optional[mx.Dtype] = ...,
    communication_stream: Optional[mx.Stream] = ...,
):  # -> Any:
    """Average the gradients across the distributed processes in the passed group.

    This helper enables concatenating several gradients of small arrays into
    one big all-reduce call for better networking performance.

    Args:
        gradients (Any): The Python tree containing the gradients (it should
            have the same structure across processes)
        group (Optional[mlx.core.distributed.Group]): The group of processes to
            average the gradients over. If set to ``None`` the global group is
            used. Default: ``None``.
        all_reduce_size (int): Group arrays until their size in bytes exceeds
            this number. Perform one communication step per group of arrays.
            If less than or equal to 0, array grouping is disabled.
            Default: ``32MiB``.
        communication_type (Optional[mlx.core.Dtype]): If provided, cast to this
            type before performing the communication. Typically cast to a
            smaller float to reduce the communication size. Default: ``None``.
        communication_stream (Optional[mlx.core.Stream]): The stream to use
            for the communication. If unspecified, the default communication
            stream is used, which can vary by back-end. Default: ``None``.
    """
    ...