mirror of
https://github.com/exo-explore/exo.git
synced 2025-12-23 22:27:50 -05:00
190 lines
6.4 KiB
Python
190 lines
6.4 KiB
Python
"""
|
|
This type stub file was generated by pyright.
|
|
"""
|
|
|
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
|
|
|
from mlx.core import MX_ARRAY_TREE
|
|
|
|
def tree_map(
    fn: Callable,
    tree: Any,
    *rest: Any,
    is_leaf: Optional[Callable] = ...,
) -> Any:
    """Apply ``fn`` to every leaf of the Python tree ``tree`` and return a
    new collection holding the results.

    When ``rest`` is given, each extra item is assumed to be a superset of
    ``tree``; the matching leaves are handed to ``fn`` as additional
    positional arguments. In that sense :meth:`tree_map` behaves more like
    :func:`itertools.starmap` than like :func:`map`.

    As with :func:`tree_flatten`, the ``is_leaf`` keyword argument decides
    what counts as a leaf of ``tree``.

    .. code-block:: python

        import mlx.nn as nn
        from mlx.utils import tree_map

        model = nn.Linear(10, 10)
        print(model.parameters().keys())
        # dict_keys(['weight', 'bias'])

        # square the parameters
        model.update(tree_map(lambda x: x*x, model.parameters()))

    Args:
        fn (callable): Function used to process the leaves of the tree.
        tree (Any): The main Python tree to iterate over.
        rest (tuple[Any]): Additional trees iterated together with ``tree``.
        is_leaf (callable, optional): Optional callable returning ``True``
            when the object it receives should be treated as a leaf and
            ``False`` otherwise.

    Returns:
        A Python tree containing the new values returned by ``fn``.
    """
|
|
|
|
def tree_map_with_path(
    fn: Callable, tree: Any, *rest: Any,
    is_leaf: Optional[Callable] = ...,
    path: Optional[Any] = ...,
) -> Any:
    """Apply ``fn`` to the path and leaves of the Python tree ``tree`` and
    return a new collection holding the results.

    This function is the same as :func:`tree_map`, except that ``fn``
    receives the path as its first argument, followed by the remaining tree
    nodes.

    Args:
        fn (callable): Function used to process the leaves of the tree.
        tree (Any): The main Python tree to iterate over.
        rest (tuple[Any]): Additional trees iterated together with ``tree``.
        is_leaf (Optional[Callable]): Optional callable returning ``True``
            when the object it receives should be treated as a leaf and
            ``False`` otherwise.
        path (Optional[Any]): Prefix that will be added to the result.

    Returns:
        A Python tree containing the new values returned by ``fn``.

    Example:
        >>> from mlx.utils import tree_map_with_path
        >>> tree = {"model": [{"w": 0, "b": 1}, {"w": 0, "b": 1}]}
        >>> new_tree = tree_map_with_path(lambda path, _: print(path), tree)
        model.0.w
        model.0.b
        model.1.w
        model.1.b
    """
|
|
|
|
def tree_flatten(
    tree: Any, prefix: str = ..., is_leaf: Optional[Callable] = ...,
    destination: Optional[Union[List[Tuple[str, Any]], Dict[str, Any]]] = ...,
) -> Union[List[Tuple[str, Any]], Dict[str, Any]]:
    """Flatten a Python tree into a list of ``(key, value)`` tuples.

    Keys use dot notation, which lets them describe trees of arbitrary depth
    and complexity.

    .. code-block:: python

        from mlx.utils import tree_flatten

        print(tree_flatten([[[0]]]))
        # [("0.0.0", 0)]

        print(tree_flatten([[[0]]], prefix=".hello"))
        # [("hello.0.0.0", 0)]

        tree_flatten({"a": {"b": 1}}, destination={})
        # {"a.b": 1}

    .. note::
        Dictionaries should have keys that are valid Python identifiers.

    Args:
        tree (Any): The Python tree to flatten.
        prefix (str): Prefix to use for the keys. Its first character is
            always discarded.
        is_leaf (callable): Optional callable returning True when the object
            it receives should be treated as a leaf and False otherwise.
        destination (list or dict, optional): List or dictionary in which to
            store the flattened tree. If None an empty list will be used.
            Default: ``None``.

    Returns:
        Union[List[Tuple[str, Any]], Dict[str, Any]]: The flat representation
        of the Python tree.
    """
|
|
|
|
def tree_unflatten(
    tree: Union[List[Tuple[str, Any]], Dict[str, Any]],
) -> Any:
    """Rebuild a Python tree from its flat representation.

    .. code-block:: python

        from mlx.utils import tree_unflatten

        d = tree_unflatten([("hello.world", 42)])
        print(d)
        # {"hello": {"world": 42}}

        d = tree_unflatten({"hello.world": 42})
        print(d)
        # {"hello": {"world": 42}}

    Args:
        tree (list[tuple[str, Any]] or dict[str, Any]): Flat representation
            of a Python tree, for instance the value returned by
            :meth:`tree_flatten`.

    Returns:
        A Python tree.
    """
|
|
|
|
def tree_reduce(
    fn: Callable[[Any, Any], Any],
    tree: list[MX_ARRAY_TREE] | tuple[MX_ARRAY_TREE, ...] | dict[str, MX_ARRAY_TREE],
    initializer: Any = ...,
    is_leaf: Optional[Callable] = ...,
) -> Any:
    """Apply a reduction to the leaves of a Python tree.

    The tree is folded into a single accumulated result by applying the
    provided function ``fn`` to its leaves.

    Example:
        >>> from mlx.utils import tree_reduce
        >>> tree = {"a": [1, 2, 3], "b": [4, 5]}
        >>> tree_reduce(lambda acc, x: acc + x, tree, 0)
        15

    Args:
        fn (callable): The reducer function that takes two arguments
            (accumulator, current value) and returns the updated accumulator.
        tree (Any): The Python tree to reduce. It can be any nested
            combination of lists, tuples, or dictionaries.
        initializer (Any, optional): The initial value to start the
            reduction. If not provided, the first leaf value is used.
        is_leaf (callable, optional): A function to determine if an object is
            a leaf, returning ``True`` for leaf nodes and ``False`` otherwise.

    Returns:
        Any: The accumulated value.
    """
    # NOTE: the return annotation is ``Any`` (not ``None``) because the
    # function is documented to return the accumulated value, whose type is
    # whatever ``fn`` produces.
|
|
|
|
def tree_merge(
    tree_a: Any,
    tree_b: Any,
    merge_fn: Optional[Callable] = ...,
) -> Any:
    """Merge two Python trees into one containing the values of both. It can
    be thought of as a deep ``dict.update`` method.

    Args:
        tree_a (Any): The first Python tree.
        tree_b (Any): The second Python tree.
        merge_fn (callable, optional): A function to merge leaves.

    Returns:
        The Python tree containing the values of both ``tree_a`` and
        ``tree_b``.
    """
    # pyright originally inferred the return type as
    # dict[Any, Any] | list[Any] | tuple[Any, *tuple[Any, ...]] | tuple[Any, ...]
    # and left it in a comment; it is declared as ``Any`` here to match the
    # other tree utilities in this stub.
|