"""
|
|
This type stub file was generated by pyright.
|
|
"""
|
|
|
|
from typing import Callable, Optional
|
|
|
|
import mlx.core as mx
|
|
from mlx.nn.layers.base import Module


class RNN(Module):
    r"""An Elman recurrent layer.

    The input is a sequence of shape ``NLD`` or ``LD`` where:

    * ``N`` is the optional batch dimension
    * ``L`` is the sequence length
    * ``D`` is the input's feature dimension

    Concretely, for each element along the sequence length axis, this
    layer applies the function:

    .. math::

        h_{t + 1} = \text{tanh} (W_{ih}x_t + W_{hh}h_t + b)

    The hidden state :math:`h` has shape ``NH`` or ``H``, depending on
    whether the input is batched or not. Returns the hidden state at each
    time step, of shape ``NLH`` or ``LH``.

    Args:
        input_size (int): Dimension of the input, ``D``.
        hidden_size (int): Dimension of the hidden state, ``H``.
        bias (bool, optional): Whether to use a bias. Default: ``True``.
        nonlinearity (callable, optional): Non-linearity to use. If ``None``,
            then :func:`tanh` is used. Default: ``None``.
    """
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        bias: bool = ...,
        nonlinearity: Optional[Callable] = ...,
    ) -> None: ...
    def __call__(self, x: mx.array, hidden=...) -> mx.array: ...
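
# Usage sketch (an illustrative comment, not part of the generated stub): it
# assumes this class is exposed as ``mlx.nn.RNN``, as in the MLX library this
# stub mirrors; shapes follow the docstring above.
#
#     import mlx.core as mx
#     import mlx.nn as nn
#
#     rnn = nn.RNN(input_size=8, hidden_size=16)
#     x = mx.random.normal((4, 10, 8))  # NLD input: batch 4, length 10, features 8
#     h = rnn(x)                        # hidden states, shape NLH == (4, 10, 16)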


class GRU(Module):
    r"""A gated recurrent unit (GRU) RNN layer.

    The input has shape ``NLD`` or ``LD`` where:

    * ``N`` is the optional batch dimension
    * ``L`` is the sequence length
    * ``D`` is the input's feature dimension

    Concretely, for each element of the sequence, this layer computes:

    .. math::

        \begin{aligned}
        r_t &= \sigma (W_{xr}x_t + W_{hr}h_t + b_{r}) \\
        z_t &= \sigma (W_{xz}x_t + W_{hz}h_t + b_{z}) \\
        n_t &= \text{tanh}(W_{xn}x_t + b_{n} + r_t \odot (W_{hn}h_t + b_{hn})) \\
        h_{t + 1} &= (1 - z_t) \odot n_t + z_t \odot h_t
        \end{aligned}

    The hidden state :math:`h` has shape ``NH`` or ``H`` depending on
    whether the input is batched or not. Returns the hidden state at each
    time step of shape ``NLH`` or ``LH``.

    Args:
        input_size (int): Dimension of the input, ``D``.
        hidden_size (int): Dimension of the hidden state, ``H``.
        bias (bool): Whether to use biases or not. Default: ``True``.
    """
    def __init__(self, input_size: int, hidden_size: int, bias: bool = ...) -> None: ...
    def __call__(self, x: mx.array, hidden=...) -> mx.array: ...
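
# Usage sketch (illustrative comment): a GRU run over an unbatched sequence,
# assuming the ``mlx.nn.GRU`` export from the MLX library this stub mirrors.
#
#     import mlx.core as mx
#     import mlx.nn as nn
#
#     gru = nn.GRU(input_size=8, hidden_size=16)
#     x = mx.random.normal((10, 8))  # LD input: length 10, features 8
#     h = gru(x)                     # hidden states, shape LH == (10, 16)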


class LSTM(Module):
    r"""An LSTM recurrent layer.

    The input has shape ``NLD`` or ``LD`` where:

    * ``N`` is the optional batch dimension
    * ``L`` is the sequence length
    * ``D`` is the input's feature dimension

    Concretely, for each element of the sequence, this layer computes:

    .. math::

        \begin{aligned}
        i_t &= \sigma (W_{xi}x_t + W_{hi}h_t + b_{i}) \\
        f_t &= \sigma (W_{xf}x_t + W_{hf}h_t + b_{f}) \\
        g_t &= \text{tanh} (W_{xg}x_t + W_{hg}h_t + b_{g}) \\
        o_t &= \sigma (W_{xo}x_t + W_{ho}h_t + b_{o}) \\
        c_{t + 1} &= f_t \odot c_t + i_t \odot g_t \\
        h_{t + 1} &= o_t \odot \text{tanh}(c_{t + 1})
        \end{aligned}

    The hidden state :math:`h` and cell state :math:`c` have shape ``NH``
    or ``H``, depending on whether the input is batched or not.

    The layer returns two arrays, the hidden state and the cell state at
    each time step, both of shape ``NLH`` or ``LH``.

    Args:
        input_size (int): Dimension of the input, ``D``.
        hidden_size (int): Dimension of the hidden state, ``H``.
        bias (bool): Whether to use biases or not. Default: ``True``.
    """
    def __init__(self, input_size: int, hidden_size: int, bias: bool = ...) -> None: ...
    def __call__(
        self, x: mx.array, hidden=..., cell=...
    ) -> tuple[mx.array, mx.array]: ...
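
# Usage sketch (illustrative comment): unlike RNN and GRU, the LSTM returns
# both the hidden and the cell states, per the docstring above. Assumes the
# ``mlx.nn.LSTM`` export from the MLX library this stub mirrors.
#
#     import mlx.core as mx
#     import mlx.nn as nn
#
#     lstm = nn.LSTM(input_size=8, hidden_size=16)
#     x = mx.random.normal((4, 10, 8))  # NLD input
#     h, c = lstm(x)                    # each of shape NLH == (4, 10, 16)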