# exo/.mlx_typings/mlx/nn/layers/activations.pyi
"""
This type stub file was generated by pyright.
"""
from functools import partial
from typing import Any
import mlx.core as mx
from .base import Module
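
# NOTE: the `@_make_activation_module(...)` decorators below reference a helper
# that the generated stub omitted. A minimal declaration is assumed here so the
# stub type-checks; the concrete signature in mlx.nn.layers.activations may differ.
def _make_activation_module(f: Any) -> Any: ...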
@partial(mx.compile, shapeless=True)
def sigmoid(x: mx.array) -> mx.array:
    r"""Applies the sigmoid function.

    .. math::
        \text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)}
    """

@partial(mx.compile, shapeless=True)
def relu(x: mx.array) -> mx.array:
    r"""Applies the Rectified Linear Unit.

    Simply ``mx.maximum(x, 0)``.
    """

@partial(mx.compile, shapeless=True)
def relu2(x: mx.array) -> mx.array:
    r"""Applies the ReLU² activation function.

    Applies :math:`\max(0, x)^2` element wise.
    """

@partial(mx.compile, shapeless=True)
def relu6(x: mx.array) -> mx.array:
    r"""Applies the Rectified Linear Unit 6.

    Applies :math:`\min(\max(x, 0), 6)` element wise.
    """

@partial(mx.compile, shapeless=True)
def leaky_relu(x: mx.array, negative_slope=...) -> mx.array:
    r"""Applies the Leaky Rectified Linear Unit.

    Simply ``mx.maximum(negative_slope * x, x)``.
    """
@partial(mx.compile, shapeless=True)
def log_softmax(x: mx.array, axis=...) -> mx.array:
    r"""Applies the Log Softmax function.

    Applies :math:`x - \log \sum_i e^{x_i}` element wise.
    """
@partial(mx.compile, shapeless=True)
def elu(x: mx.array, alpha=...) -> mx.array:
    r"""Applies the Exponential Linear Unit.

    Simply ``mx.where(x > 0, x, alpha * (mx.exp(x) - 1))``.
    """

@partial(mx.compile, shapeless=True)
def softmax(x: mx.array, axis=...) -> mx.array:
    r"""Applies the Softmax function.

    Applies :math:`\frac{e^{x_i}}{\sum_j e^{x_j}}` element wise.
    """

@partial(mx.compile, shapeless=True)
def softplus(x: mx.array) -> mx.array:
    r"""Applies the Softplus function.

    Applies :math:`\log(1 + \exp(x))` element wise.
    """

@partial(mx.compile, shapeless=True)
def softsign(x: mx.array) -> mx.array:
    r"""Applies the Softsign function.

    Applies :math:`\frac{x}{1 + |x|}` element wise.
    """

@partial(mx.compile, shapeless=True)
def softshrink(x: mx.array, lambd: float = ...) -> mx.array:
    r"""Applies the Softshrink activation function.

    .. math::
        \text{softshrink}(x) = \begin{cases}
        x - \lambda & \text{if } x > \lambda \\
        x + \lambda & \text{if } x < -\lambda \\
        0 & \text{otherwise}
        \end{cases}
    """

@partial(mx.compile, shapeless=True)
def celu(x: mx.array, alpha=...) -> mx.array:
    r"""Applies the Continuously Differentiable Exponential Linear Unit.

    Applies :math:`\max(0, x) + \min(0, \alpha * (\exp(x / \alpha) - 1))`
    element wise.
    """

@partial(mx.compile, shapeless=True)
def silu(x: mx.array) -> mx.array:
    r"""Applies the Sigmoid Linear Unit. Also known as Swish.

    Applies :math:`x \sigma(x)` element wise, where :math:`\sigma(\cdot)` is
    the logistic sigmoid.
    """

@partial(mx.compile, shapeless=True)
def log_sigmoid(x: mx.array) -> mx.array:
    r"""Applies the Log Sigmoid function.

    Applies :math:`\log(\sigma(x)) = -\log(1 + e^{-x})` element wise.
    """

@partial(mx.compile, shapeless=True)
def gelu(x: mx.array) -> mx.array:
    r"""Applies the Gaussian Error Linear Units function.

    .. math::
        \textrm{GELU}(x) = x * \Phi(x)

    where :math:`\Phi(x)` is the Gaussian CDF.

    See also :func:`gelu_approx` and :func:`gelu_fast_approx` for faster
    approximations.
    """
@partial(mx.compile, shapeless=True)
def gelu_approx(x: mx.array) -> mx.array:
    r"""An approximation to Gaussian Error Linear Unit.

    See :func:`gelu` for the exact computation.

    This function approximates ``gelu`` with a maximum absolute error :math:`<
    0.0005` in the range :math:`[-6, 6]` using the following

    .. math::
        x = 0.5 * x * \left(1 + \text{Tanh}\left(\sqrt{2 / \pi} * \left(x + 0.044715 * x^3\right)\right)\right)
    """

@partial(mx.compile, shapeless=True)
def gelu_fast_approx(x: mx.array) -> mx.array:
    r"""A fast approximation to Gaussian Error Linear Unit.

    See :func:`gelu` for the exact computation.

    This function approximates ``gelu`` with a maximum absolute error :math:`<
    0.015` in the range :math:`[-6, 6]` using the following

    .. math::
        x = x \sigma\left(1.702 x\right)

    where :math:`\sigma(\cdot)` is the logistic sigmoid.

    References:
        - https://github.com/hendrycks/GELUs
        - https://arxiv.org/abs/1606.08415
    """
def glu(x: mx.array, axis: int = ...) -> mx.array:
    r"""Applies the gated linear unit function.

    This function splits the ``axis`` dimension of the input into two halves
    (:math:`a` and :math:`b`) and applies :math:`a * \sigma(b)`.

    .. math::
        \textrm{GLU}(x) = a * \sigma(b)

    Args:
        axis (int): The dimension to split along. Default: ``-1``
    """
@partial(mx.compile, shapeless=True)
def step(x: mx.array, threshold: float = ...) -> mx.array:
    r"""Applies the Step Activation Function.

    This function implements a binary step activation, where the output is set
    to 1 if the input is greater than a specified threshold, and 0 otherwise.

    .. math::
        \text{step}(x) = \begin{cases}
        0 & \text{if } x < \text{threshold} \\
        1 & \text{if } x \geq \text{threshold}
        \end{cases}

    Args:
        threshold: The value to threshold at.
    """

@partial(mx.compile, shapeless=True)
def selu(x: mx.array) -> mx.array:
    r"""Applies the Scaled Exponential Linear Unit.

    .. math::
        \text{selu}(x) = \begin{cases}
        \lambda x & \text{if } x > 0 \\
        \lambda \alpha (\exp(x) - 1) & \text{if } x \leq 0
        \end{cases}

    where :math:`\lambda = 1.0507` and :math:`\alpha = 1.67326`.

    See also :func:`elu`.
    """

@partial(mx.compile, shapeless=True)
def prelu(x: mx.array, alpha: mx.array) -> mx.array:
    r"""Applies the element-wise parametric ReLU.

    .. math::
        \text{PReLU}(x) = \max(0,x) + a * \min(0,x)

    where :math:`a` is an array.
    """
@partial(mx.compile, shapeless=True)
def mish(x: mx.array) -> mx.array:
    r"""Applies the Mish function, element-wise.

    Mish: A Self Regularized Non-Monotonic Neural Activation Function.

    Reference: https://arxiv.org/abs/1908.08681

    .. math::
        \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))
    """

@partial(mx.compile, shapeless=True)
def hardswish(x: mx.array) -> mx.array:
    r"""Applies the hardswish function, element-wise.

    .. math::
        \text{Hardswish}(x) = x * \min(\max(x + 3, 0), 6) / 6
    """

@partial(mx.compile, shapeless=True)
def hard_tanh(x: mx.array, min_val=..., max_val=...) -> mx.array:
    r"""Applies the HardTanh function.

    Applies :math:`\max(\min(x, \text{max\_val}), \text{min\_val})` element-wise.
    """

@partial(mx.compile, shapeless=True)
def hard_shrink(x: mx.array, lambd=...) -> mx.array:
    r"""Applies the HardShrink activation function.

    .. math::
        \text{hardshrink}(x) = \begin{cases}
        x & \text{if } x > \lambda \\
        x & \text{if } x < -\lambda \\
        0 & \text{otherwise}
        \end{cases}
    """

@partial(mx.compile, shapeless=True)
def softmin(x: mx.array, axis=...) -> mx.array:
    r"""Applies the Softmin function.

    Applies :math:`\frac{e^{-x_i}}{\sum_j e^{-x_j}}` element-wise.
    """

def tanh(x: mx.array) -> mx.array:
    """Applies the hyperbolic tangent function.

    Simply ``mx.tanh(x)``.
    """
class GLU(Module):
    r"""Applies the gated linear unit function.

    This function splits the ``axis`` dimension of the input into two halves
    (:math:`a` and :math:`b`) and applies :math:`a * \sigma(b)`.

    .. math::
        \textrm{GLU}(x) = a * \sigma(b)

    Args:
        axis (int): The dimension to split along. Default: ``-1``
    """
    def __init__(self, axis: int = ...) -> None: ...
    def __call__(self, x) -> Any: ...
@_make_activation_module(sigmoid)
class Sigmoid(Module):
    r"""Applies the sigmoid function, element-wise.

    .. math::
        \text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)}
    """

@_make_activation_module(mish)
class Mish(Module):
    r"""Applies the Mish function, element-wise.

    Reference: https://arxiv.org/abs/1908.08681

    .. math::
        \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))
    """

@_make_activation_module(relu)
class ReLU(Module):
    r"""Applies the Rectified Linear Unit.

    Simply ``mx.maximum(x, 0)``.

    See :func:`relu` for the functional equivalent.
    """

@_make_activation_module(relu2)
class ReLU2(Module):
    r"""Applies the ReLU² activation function.

    See :func:`relu2` for the functional equivalent.
    """

@_make_activation_module(relu6)
class ReLU6(Module):
    r"""Applies the Rectified Linear Unit 6.

    See :func:`relu6` for the functional equivalent.
    """
class LeakyReLU(Module):
    r"""Applies the Leaky Rectified Linear Unit.

    Simply ``mx.maximum(negative_slope * x, x)``.

    Args:
        negative_slope: Controls the angle of the negative slope. Default: ``1e-2``
    """
    def __init__(self, negative_slope=...) -> None: ...
    def __call__(self, x): ...

class ELU(Module):
    r"""Applies the Exponential Linear Unit.

    Simply ``mx.where(x > 0, x, alpha * (mx.exp(x) - 1))``.

    See :func:`elu` for the functional equivalent.

    Args:
        alpha: the :math:`\alpha` value for the ELU formulation. Default: ``1.0``
    """
    def __init__(self, alpha=...) -> None: ...
    def __call__(self, x): ...
@_make_activation_module(softmax)
class Softmax(Module):
    r"""Applies the Softmax function.

    See :func:`softmax` for the functional equivalent.
    """

@_make_activation_module(softplus)
class Softplus(Module):
    r"""Applies the Softplus function.

    See :func:`softplus` for the functional equivalent.
    """

@_make_activation_module(softsign)
class Softsign(Module):
    r"""Applies the Softsign function.

    See :func:`softsign` for the functional equivalent.
    """

class Softshrink(Module):
    r"""Applies the Softshrink function.

    See :func:`softshrink` for the functional equivalent.

    Args:
        lambd: the :math:`\lambda` value for Softshrink. Default: ``0.5``
    """
    def __init__(self, lambd=...) -> None: ...
    def __call__(self, x): ...

class CELU(Module):
    r"""Applies the Continuously Differentiable Exponential Linear Unit.

    Applies :math:`\max(0, x) + \min(0, \alpha * (\exp(x / \alpha) - 1))`
    element wise.

    See :func:`celu` for the functional equivalent.

    Args:
        alpha: the :math:`\alpha` value for the CELU formulation. Default: ``1.0``
    """
    def __init__(self, alpha=...) -> None: ...
    def __call__(self, x): ...
@_make_activation_module(silu)
class SiLU(Module):
    r"""Applies the Sigmoid Linear Unit. Also known as Swish.

    See :func:`silu` for the functional equivalent.
    """

@_make_activation_module(log_softmax)
class LogSoftmax(Module):
    r"""Applies the Log Softmax function.

    See :func:`log_softmax` for the functional equivalent.
    """

@_make_activation_module(log_sigmoid)
class LogSigmoid(Module):
    r"""Applies the Log Sigmoid function.

    See :func:`log_sigmoid` for the functional equivalent.
    """

class PReLU(Module):
    r"""Applies the element-wise parametric ReLU.

    Applies :math:`\max(0, x) + a * \min(0, x)` element wise, where :math:`a`
    is an array.

    See :func:`prelu` for the functional equivalent.

    Args:
        num_parameters: number of :math:`a` to learn. Default: ``1``
        init: the initial value of :math:`a`. Default: ``0.25``
    """
    def __init__(self, num_parameters=..., init=...) -> None: ...
    def __call__(self, x: mx.array): ...
class GELU(Module):
    r"""Applies the Gaussian Error Linear Units.

    .. math::
        \textrm{GELU}(x) = x * \Phi(x)

    where :math:`\Phi(x)` is the Gaussian CDF.

    However, if ``approx`` is set to 'precise' or 'fast' it applies

    .. math::
        \textrm{GELUApprox}(x) &= 0.5 * x * \left(1 + \text{Tanh}\left(\sqrt{2 / \pi} * \left(x + 0.044715 * x^3\right)\right)\right) \\
        \textrm{GELUFast}(x) &= x * \sigma\left(1.702 * x\right)

    respectively.

    .. note::
        For compatibility with the PyTorch API, 'tanh' can be used as an alias
        for 'precise'.

    See :func:`gelu`, :func:`gelu_approx` and :func:`gelu_fast_approx` for the
    functional equivalents and information regarding error bounds.

    Args:
        approx ('none' | 'precise' | 'fast'): Which approximation to gelu to use if any.
    """
    def __init__(self, approx=...) -> None: ...
    def __call__(self, x): ...
@_make_activation_module(tanh)
class Tanh(Module):
    r"""Applies the hyperbolic tangent function.

    See :func:`tanh` for the functional equivalent.
    """

@_make_activation_module(hardswish)
class Hardswish(Module):
    r"""Applies the hardswish function, element-wise.

    See :func:`hardswish` for the functional equivalent.
    """

class Step(Module):
    r"""Applies the Step Activation Function.

    This function implements a binary step activation, where the output is set
    to 1 if the input is greater than a specified threshold, and 0 otherwise.

    .. math::
        \text{step}(x) = \begin{cases}
        0 & \text{if } x < \text{threshold} \\
        1 & \text{if } x \geq \text{threshold}
        \end{cases}

    Args:
        threshold: The value to threshold at.
    """
    def __init__(self, threshold: float = ...) -> None: ...
    def __call__(self, x: mx.array): ...

@_make_activation_module(selu)
class SELU(Module):
    r"""Applies the Scaled Exponential Linear Unit.

    See :func:`selu` for the functional equivalent.
    """

@_make_activation_module(hard_tanh)
class HardTanh(Module):
    r"""Applies the HardTanh function.

    See :func:`hard_tanh` for the functional equivalent.
    """

@_make_activation_module(hard_shrink)
class HardShrink(Module):
    r"""Applies the HardShrink function.

    See :func:`hard_shrink` for the functional equivalent.

    Args:
        lambd: the :math:`\lambda` value for Hardshrink. Default: ``0.5``
    """

@_make_activation_module(softmin)
class Softmin(Module):
    r"""Applies the Softmin function.

    See :func:`softmin` for the functional equivalent.
    """