# mirror of https://github.com/exo-explore/exo.git
# synced 2025-12-23 22:27:50 -05:00
"""
|
|
This type stub file was generated by pyright.
|
|
"""
|
|
|
|
from typing import Any, Callable, Optional

import mlx.core as mx

from .layers.base import Module
def value_and_grad(
    model: Module, fn: Callable
):  # -> _Wrapped[..., Any, ..., tuple[Any, Any]]:
    """Transform the passed function ``fn`` to a function that computes the
    gradients of ``fn`` wrt the model's trainable parameters and also its
    value.

    Args:
        model (Module): The model whose trainable parameters to compute
            gradients for
        fn (Callable): The scalar function to compute gradients for

    Returns:
        A callable that returns the value of ``fn`` and the gradients wrt the
        trainable parameters of ``model``
    """
    # Stub only: implementation lives in the real mlx.nn package.
    ...
def checkpoint(
    module: Module, fn: Optional[Callable] = ...
):  # -> _Wrapped[..., Any, ..., Any]:
    """Transform the passed callable to one that performs gradient
    checkpointing with respect to the trainable parameters of the module (and
    the callable's inputs).

    Args:
        module (Module): The module for whose parameters we will be
            performing gradient checkpointing.
        fn (Callable, optional): The function to checkpoint. If not provided it
            defaults to the provided module.

    Returns:
        A callable that saves the inputs and outputs during the forward pass
        and recomputes all intermediate states during the backward pass.
    """
    # Stub only: implementation lives in the real mlx.nn package.
    ...
def average_gradients(
    gradients: Any,
    group: Optional[mx.distributed.Group] = ...,
    all_reduce_size: int = ...,
    communication_type: Optional[mx.Dtype] = ...,
    communication_stream: Optional[mx.Stream] = ...,
):  # -> Any:
    """Average the gradients across the distributed processes in the passed group.

    This helper enables concatenating several gradients of small arrays to one
    big all reduce call for better networking performance.

    Args:
        gradients (Any): The Python tree containing the gradients (it should
            have the same structure across processes)
        group (Optional[mlx.core.distributed.Group]): The group of processes to
            average the gradients. If set to ``None`` the global group is used.
            Default: ``None``.
        all_reduce_size (int): Group arrays until their size in bytes exceeds
            this number. Perform one communication step per group of arrays. If
            less or equal to 0 array grouping is disabled. Default: ``32MiB``.
        communication_type (Optional[mlx.core.Dtype]): If provided cast to this
            type before performing the communication. Typically cast to a
            smaller float to reduce the communication size. Default: ``None``.
        communication_stream (Optional[mlx.core.Stream]): The stream to use
            for the communication. If unspecified the default communication
            stream is used which can vary by back-end. Default: ``None``.
    """
    # Stub only: implementation lives in the real mlx.nn package.
    ...