mlx.distributed.Group type stubs

rltakashige
2025-11-05 21:26:04 -08:00
committed by GitHub
parent 16f724e24c
commit 6bbb6344b6
14 changed files with 6062 additions and 72 deletions

View File

@@ -62,6 +62,33 @@ jobs:
fi
shell: bash
- name: Configure basedpyright include for local MLX
run: |
RUNNER_LABELS='${{ toJSON(runner.labels) }}'
if echo "$RUNNER_LABELS" | grep -q "local_mlx"; then
if [ -d "/Users/Shared/mlx" ]; then
echo "Updating [tool.basedpyright].include to use /Users/Shared/mlx"
awk '
BEGIN { in_bp=0 }
/^\[tool\.basedpyright\]/ { in_bp=1; print; next }
in_bp && /^\[/ { in_bp=0 }   # next section
in_bp && /^[ \t]*include[ \t]*=/ {
print "include = [\"/Users/Shared/mlx\"]"
next
}
{ print }
' pyproject.toml > pyproject.toml.tmp && mv pyproject.toml.tmp pyproject.toml
echo "New [tool.basedpyright] section:"
sed -n '/^\[tool\.basedpyright\]/,/^\[/p' pyproject.toml | sed '$d' || true
else
echo "local_mlx tag present but /Users/Shared/mlx not found; leaving pyproject unchanged."
fi
else
echo "Runner does not have 'local_mlx' tag; leaving pyproject unchanged."
fi
shell: bash
- uses: ./.github/actions/typecheck
# ci:

View File

@@ -1,5 +1,5 @@
fmt:
uv run ruff format src
uv run ruff format src typings
lint:
uv run ruff check --fix src

View File

@@ -82,6 +82,7 @@ build-backend = "uv_build"
###
[tool.basedpyright]
include = [".venv/lib/mlx", "src"]
typeCheckingMode = "strict"
failOnWarnings = true
@@ -97,15 +98,12 @@ reportUnnecessaryTypeIgnoreComment = "error"
pythonVersion = "3.13"
pythonPlatform = "Darwin"
exclude = ["**/.venv", "**/venv", "**/__pycache__", "**/exo_scripts", "**/.direnv", "**/rust", "mlx/*", "mlx-lm/*"]
exclude = ["**/.venv", "**/venv", "**/__pycache__", "**/exo_scripts", "**/.direnv", "**/rust"]
stubPath = "typings"
[[tool.basedpyright.executionEnvironments]]
root = "src"
[[tool.basedpyright.executionEnvironments]]
root = "."
###
# uv configuration
###

View File

@@ -162,9 +162,9 @@ class PipelineParallelisationStrategy(ParallelisationShardStrategy):
class TensorParallelisationStrategy(ParallelisationShardStrategy):
def __init__(self, group: mx.distributed.Group): # type: ignore
self.group = group # type: ignore
self.N = self.group.size # type: ignore
def __init__(self, group: mx.distributed.Group):
self.group = group
self.N = self.group.size
def auto_parallel(
self, model: nn.Module, model_shard_meta: ShardMetadata
@@ -174,28 +174,28 @@ class TensorParallelisationStrategy(ParallelisationShardStrategy):
all_to_sharded_linear = partial(
shard_linear,
sharding="all-to-sharded",
group=self.group, # pyright: ignore
group=self.group,
)
sharded_to_all_linear = partial(
shard_linear,
sharding="sharded-to-all",
group=self.group, # type: ignore
group=self.group,
)
all_to_sharded_linear_in_place = partial(
shard_inplace,
sharding="all-to-sharded",
group=self.group, # pyright: ignore
group=self.group,
)
sharded_to_all_linear_in_place = partial(
shard_inplace,
sharding="sharded-to-all",
group=self.group, # type: ignore
group=self.group,
)
if isinstance(model, LlamaModel):
tensor_parallel_sharding_strategy = LlamaShardingStrategy(
self.group, # type: ignore
self.group,
all_to_sharded_linear,
sharded_to_all_linear,
all_to_sharded_linear_in_place,
@@ -203,7 +203,7 @@ class TensorParallelisationStrategy(ParallelisationShardStrategy):
)
elif isinstance(model, DeepseekV3Model):
tensor_parallel_sharding_strategy = DeepSeekShardingStrategy(
self.group, # type: ignore
self.group,
all_to_sharded_linear,
sharded_to_all_linear,
all_to_sharded_linear_in_place,
@@ -211,7 +211,7 @@ class TensorParallelisationStrategy(ParallelisationShardStrategy):
)
elif isinstance(model, Qwen3MoeModel):
tensor_parallel_sharding_strategy = QwenShardingStrategy(
self.group, # type: ignore
self.group,
all_to_sharded_linear,
sharded_to_all_linear,
all_to_sharded_linear_in_place,
@@ -305,14 +305,14 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
class ShardedDeepseekV3MoE(CustomMlxLayer):
def __init__(self, layer: _LayerCallable):
super().__init__(layer)
self.sharding_group: mx.distributed.Group | None = None # type: ignore
self.sharding_group: mx.distributed.Group | None = None
def __call__(self, x: mx.array) -> mx.array:
if self.sharding_group is not None: # type: ignore
if self.sharding_group is not None:
x = sum_gradients(self.sharding_group)(x) # type: ignore
y = self.original_layer.__call__(x) # type: ignore
if self.sharding_group is not None: # type: ignore
y = mx.distributed.all_sum(y, group=self.sharding_group) # type: ignore
if self.sharding_group is not None:
y = mx.distributed.all_sum(y, group=self.sharding_group)
return y
@@ -349,12 +349,12 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
class ShardedQwenMoE(CustomMlxLayer):
def __init__(self, layer: _LayerCallable):
super().__init__(layer)
self.sharding_group: mx.distributed.Group | None = None # type: ignore
self.sharding_group: mx.distributed.Group | None = None
def __call__(self, x: mx.array) -> mx.array:
if self.sharding_group is not None: # type: ignore
if self.sharding_group is not None:
x = sum_gradients(self.sharding_group)(x) # type: ignore
y = self.original_layer.__call__(x) # type: ignore
if self.sharding_group is not None: # type: ignore
y = mx.distributed.all_sum(y, group=self.sharding_group) # type: ignore
if self.sharding_group is not None:
y = mx.distributed.all_sum(y, group=self.sharding_group)
return y
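
The pattern these sharded MoE layers rely on is a plain all-reduce over the tensor-parallel group, which the new stubs now type without ignores. A minimal sketch of that pattern, assuming an initialized group and a placeholder activation:

import mlx.core as mx

# Sketch of the all-reduce used by ShardedDeepseekV3MoE / ShardedQwenMoE above;
# init(strict=False) falls back to a singleton group when no backend is running.
group = mx.distributed.init(strict=False)
partial = mx.random.normal(shape=(4, 8))              # stands in for one rank's expert output
full = mx.distributed.all_sum(partial, group=group)   # every rank receives the summed activation
mx.eval(full)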

View File

@@ -38,17 +38,17 @@ mlx_rank: None | int = None
mlx_world_size: None | int = None
def mx_barrier(group: mx.distributed.Group | None = None): # type: ignore
mx.eval( # type: ignore
def mx_barrier(group: mx.distributed.Group | None = None):
mx.eval(
mx.distributed.all_sum(
mx.array(1.0),
stream=mx.default_stream(mx.Device(mx.cpu)),
group=group, # type: ignore[type-arg]
group=group,
)
)
def broadcast_from_zero(value: int, group: mx.distributed.Group | None = None): # type: ignore
def broadcast_from_zero(value: int, group: mx.distributed.Group | None = None):
if mlx_rank is None:
return value
@@ -57,8 +57,8 @@ def broadcast_from_zero(value: int, group: mx.distributed.Group | None = None):
else:
a = mx.array([0], dtype=mx.int32)
m = mx.distributed.all_sum(a, stream=mx.Device(mx.DeviceType.cpu), group=group) # type: ignore
mx.eval(m) # type: ignore
m = mx.distributed.all_sum(a, stream=mx.Device(mx.DeviceType.cpu), group=group)
mx.eval(m)
return int(m.item())
@@ -68,12 +68,12 @@ class HostList(RootModel[list[str]]):
return cls(root=[str(host) for host in hosts])
def mlx_distributed_init( # type: ignore[return]
def mlx_distributed_init(
rank: int,
hosts: list[Host] | None = None,
mlx_ibv_devices: list[list[str | None]] | None = None,
mlx_ibv_coordinator: str | None = None,
) -> mx.distributed.Group: # type: ignore
) -> mx.distributed.Group:
"""
Initialize the MLX distributed (runs in thread pool).
@@ -132,7 +132,9 @@ def initialize_mlx(
hosts: list[Host] | None = None,
mlx_ibv_devices: list[list[str | None]] | None = None,
mlx_ibv_coordinator: str | None = None,
) -> tuple[Model, TokenizerWrapper, Callable[[mx.array], mx.array], Any]:
) -> tuple[
Model, TokenizerWrapper, Callable[[mx.array], mx.array], mx.distributed.Group
]:
"""
Initialize the MLX model, tokenizer, and sampler. Runs in the MLX thread.
@@ -141,7 +143,7 @@ def initialize_mlx(
- mlx_ibv_devices: RDMA connectivity matrix
"""
mx.random.seed(42)
group = mlx_distributed_init( # type: ignore[misc]
group = mlx_distributed_init(
model_shard_meta.device_rank,
hosts=hosts,
mlx_ibv_devices=mlx_ibv_devices,
@@ -154,14 +156,14 @@ def initialize_mlx(
sampler: Callable[[mx.array], mx.array] = make_sampler(temp=0.7)
model, tokenizer = shard_and_load(model_shard_meta, group=group) # type: ignore[reportUnknownArgumentType]
model, tokenizer = shard_and_load(model_shard_meta, group=group)
return model, tokenizer, sampler, group # type: ignore[return-value]
def shard_and_load(
model_shard_meta: ShardMetadata,
group: mx.distributed.Group, # type: ignore
group: mx.distributed.Group,
) -> tuple[nn.Module, TokenizerWrapper]:
model_path = build_model_path(model_shard_meta.model_meta.model_id)
@@ -177,7 +179,7 @@ def shard_and_load(
assert isinstance(tokenizer, _TokenizerWrapper)
if group:
runner_print(f"Group size: {group.size()}, group rank: {group.rank()}") # type: ignore
runner_print(f"Group size: {group.size()}, group rank: {group.rank()}")
else:
runner_print("!!! No group")
@@ -189,19 +191,19 @@ def shard_and_load(
case "pipeline_rdma":
strategy = PipelineParallelisationStrategy()
case "tensor":
strategy = TensorParallelisationStrategy(group) # type: ignore[reportUnknownArgumentType]
strategy = TensorParallelisationStrategy(group)
case "tensor_rdma":
strategy = TensorParallelisationStrategy(group) # type: ignore[reportUnknownArgumentType]
strategy = TensorParallelisationStrategy(group)
model = strategy.auto_parallel(model, model_shard_meta)
runner_print(f"Model after auto_parallel: {str(model)}")
mx.eval(model.parameters()) # type: ignore
mx.eval(model) # type: ignore
mx.eval(model)
# Synchronize processes before generation to avoid timeout
mx_barrier(group) # type: ignore[reportUnknownArgumentType]
mx_barrier(group)
return model, tokenizer # type: ignore
@@ -288,15 +290,15 @@ def mlx_force_oom(size: int = 40000) -> None:
"""
Force an Out-Of-Memory (OOM) error in MLX by performing large tensor operations.
"""
mx.set_default_device(mx.gpu) # type: ignore
mx.set_default_device(mx.gpu)
a = mx.random.uniform(shape=(size, size), dtype=mx.float32)
b = mx.random.uniform(shape=(size, size), dtype=mx.float32)
mx.eval(a, b) # type: ignore
mx.eval(a, b)
c = mx.matmul(a, b)
d = mx.matmul(a, c)
e = mx.matmul(b, c)
f = mx.sigmoid(d + e)
mx.eval(f) # type: ignore
mx.eval(f)
def set_wired_limit_for_model(model_size: Memory):
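
For context, a hypothetical usage sketch of the two helpers shown above (broadcast_from_zero and mx_barrier): rank 0 picks a value, the all-sum broadcast hands it to every rank, and the barrier keeps the ranks in step. The port range is illustrative only.

import random
import mlx.core as mx
from exo.engines.mlx.utils_mlx import broadcast_from_zero, mx_barrier

group = mx.distributed.init(strict=False)
# Only rank 0 contributes a non-zero value; the all-sum inside
# broadcast_from_zero delivers it to every rank.
port = random.randint(20000, 30000) if group.rank() == 0 else 0
port = broadcast_from_zero(port, group=group)
mx_barrier(group)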

View File

@@ -14,9 +14,9 @@ from mlx_lm.models.cache import KVCache
from exo.engines.mlx import Model, TokenizerWrapper
from exo.engines.mlx.utils_mlx import (
apply_chat_template,
broadcast_from_zero, # type: ignore
broadcast_from_zero,
make_kv_cache,
mx_barrier, # type: ignore
mx_barrier,
)
from exo.shared.types.api import ChatCompletionMessage
from exo.shared.types.tasks import ChatCompletionTaskParams
@@ -62,7 +62,7 @@ def generate_step(
quantized_kv_start: int = 0,
prompt_progress_callback: Callable[[int, int], None] | None = None,
input_embeddings: mx.array | None = None,
group: mx.distributed.Group | None = None, # type: ignore[type-arg]
group: mx.distributed.Group | None = None,
) -> Generator[Tuple[int, mx.array], None, None]:
"""
A generator producing token ids based on the given prompt from the model.
@@ -213,7 +213,7 @@ def generate_step(
y, logprobs = _step(input_tokens=prompt, input_embeddings=input_embeddings)
mx.async_eval(y, logprobs) # type: ignore[type-arg]
mx.async_eval(y, logprobs)
next_y: array | None = None
next_logprobs: array | None = None
n = 0
@@ -221,7 +221,7 @@ def generate_step(
if n != max_tokens:
assert y is not None
next_y, next_logprobs = _step(y)
mx.async_eval(next_y, next_logprobs) # type: ignore[type-arg]
mx.async_eval(next_y, next_logprobs)
if n == 0:
mx.eval(y) # type: ignore[type-arg]
prompt_progress_callback(total_prompt_tokens, total_prompt_tokens)
@@ -250,7 +250,7 @@ def stream_generate(
quantized_kv_start: int = 0,
prompt_progress_callback: Callable[[int, int], None] | None = None,
input_embeddings: mx.array | None = None,
group: mx.distributed.Group | None = None, # type: ignore[type-arg]
group: mx.distributed.Group | None = None,
) -> Generator[GenerationResponse, None, None]:
# Try to infer if special tokens are needed
add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(
@@ -310,7 +310,7 @@ async def warmup_inference(
model: Model,
tokenizer: TokenizerWrapper,
sampler: Callable[[mx.array], mx.array],
group: mx.distributed.Group | None = None, # type: ignore
group: mx.distributed.Group | None = None,
) -> int:
loop = asyncio.get_running_loop()

View File

@@ -25,7 +25,7 @@ from exo.shared.types.worker.communication import (
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils import ensure_type
from exo.worker.runner.generate import mlx_generate, warmup_inference # type: ignore
from exo.worker.runner.generate import mlx_generate, warmup_inference
async def main(raw_conn: Connection):
@@ -51,7 +51,7 @@ async def main(raw_conn: Connection):
mlx_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
loop = asyncio.get_running_loop()
model, tokenizer, sampler, group = await loop.run_in_executor( # type: ignore[type-arg]
model, tokenizer, sampler, group = await loop.run_in_executor(
mlx_executor,
partial(
initialize_mlx,
@@ -70,7 +70,7 @@ async def main(raw_conn: Connection):
model=model,
tokenizer=tokenizer,
sampler=sampler,
group=group, # type: ignore[type-arg]
group=group,
)
runner_print(f"Warmed up by generating {toks} tokens")
await conn.send(InitializedResponse(time_taken=time.time() - setup_start_time))

View File

@@ -221,9 +221,7 @@ class RunnerSupervisor:
timeout = PREFILL_TIMEOUT_SECONDS
logger.info(
f"Starting chat completion with timeout {timeout}"
)
logger.info(f"Starting chat completion with timeout {timeout}")
while True:
try:

View File

File diff suppressed because it is too large.

View File

@@ -0,0 +1,2 @@
def is_available() -> bool:
"""Check if the CUDA back-end is available."""

View File

@@ -0,0 +1,216 @@
from typing import Sequence
from mlx.core import Device, Dtype, Stream, array
class Group:
"""
An :class:`mlx.core.distributed.Group` represents a group of independent mlx
processes that can communicate.
"""
def rank(self) -> int:
"""Get the rank of this process"""
def size(self) -> int:
"""Get the size of the group"""
def split(self, color: int, key: int = ...) -> Group:
"""
Split the group to subgroups based on the provided color.
Processes that use the same color go to the same group. The ``key``
argument defines the rank in the new group. The smaller the key the
smaller the rank. If the key is negative then the rank in the
current group is used.
Args:
color (int): A value to group processes into subgroups.
key (int, optional): A key to optionally change the rank ordering
of the processes.
"""
def all_gather(
x: array, *, group: Group | None = ..., stream: Stream | Device | None = ...
) -> array:
"""
Gather arrays from all processes.
Gather the ``x`` arrays from all processes in the group and concatenate
them along the first axis. The arrays should all have the same shape.
Args:
x (array): Input array.
group (Group): The group of processes that will participate in the
gather. If set to ``None`` the global group is used. Default:
``None``.
stream (Stream, optional): Stream or device. Defaults to ``None``
in which case the default stream of the default device is used.
Returns:
array: The concatenation of all ``x`` arrays.
"""
def all_max(
x: array, *, group: Group | None = ..., stream: Stream | Device | None = ...
) -> array:
"""
All reduce max.
Find the maximum of the ``x`` arrays from all processes in the group.
Args:
x (array): Input array.
group (Group): The group of processes that will participate in the
reduction. If set to ``None`` the global group is used. Default:
``None``.
stream (Stream, optional): Stream or device. Defaults to ``None``
in which case the default stream of the default device is used.
Returns:
array: The maximum of all ``x`` arrays.
"""
def all_min(
x: array, *, group: Group | None = ..., stream: Stream | Device | None = ...
) -> array:
"""
All reduce min.
Find the minimum of the ``x`` arrays from all processes in the group.
Args:
x (array): Input array.
group (Group): The group of processes that will participate in the
reduction. If set to ``None`` the global group is used. Default:
``None``.
stream (Stream, optional): Stream or device. Defaults to ``None``
in which case the default stream of the default device is used.
Returns:
array: The minimum of all ``x`` arrays.
"""
def all_sum(
x: array, *, group: Group | None = ..., stream: Stream | Device | None = ...
) -> array:
"""
All reduce sum.
Sum the ``x`` arrays from all processes in the group.
Args:
x (array): Input array.
group (Group): The group of processes that will participate in the
reduction. If set to ``None`` the global group is used. Default:
``None``.
stream (Stream, optional): Stream or device. Defaults to ``None``
in which case the default stream of the default device is used.
Returns:
array: The sum of all ``x`` arrays.
"""
def init(strict: bool = ..., backend: str = ...) -> Group:
"""
Initialize the communication backend and create the global communication group.
Example:
.. code:: python
import mlx.core as mx
group = mx.distributed.init(backend="ring")
Args:
strict (bool, optional): If set to ``False``, a singleton group is
returned when ``mx.distributed.is_available()`` returns ``False``;
if set to ``True``, a runtime error is thrown instead. Default: ``False``
backend (str, optional): Which distributed backend to initialize.
Possible values ``mpi``, ``ring``, ``nccl``, ``any``. If set to ``any`` all
available backends are tried and the first one that succeeds
becomes the global group which will be returned in subsequent
calls. Default: ``any``
Returns:
Group: The group representing all the launched processes.
"""
def is_available() -> bool:
"""Check if a communication backend is available."""
def recv(
shape: Sequence[int],
dtype: Dtype,
src: int,
*,
group: Group | None = ...,
stream: Stream | Device | None = ...,
) -> array:
"""
Recv an array with shape ``shape`` and dtype ``dtype`` from process
with rank ``src``.
Args:
shape (tuple[int]): The shape of the array we are receiving.
dtype (Dtype): The data type of the array we are receiving.
src (int): Rank of the source process in the group.
group (Group): The group of processes that will participate in the
recv. If set to ``None`` the global group is used. Default:
``None``.
stream (Stream, optional): Stream or device. Defaults to ``None``
in which case the default stream of the default device is used.
Returns:
array: The array that was received from ``src``.
"""
def recv_like(
x: array,
src: int,
*,
group: Group | None = ...,
stream: Stream | Device | None = ...,
) -> array:
"""
Recv an array with shape and type like ``x`` from process with rank
``src``.
It is equivalent to calling ``mx.distributed.recv(x.shape, x.dtype, src)``.
Args:
x (array): An array defining the shape and dtype of the array we are
receiving.
src (int): Rank of the source process in the group.
group (Group): The group of processes that will participate in the
recv. If set to ``None`` the global group is used. Default:
``None``.
stream (Stream, optional): Stream or device. Defaults to ``None``
in which case the default stream of the default device is used.
Returns:
array: The array that was received from ``src``.
"""
def send(
x: array,
dst: int,
*,
group: Group | None = ...,
stream: Stream | Device | None = ...,
) -> array:
"""
Send an array from the current process to the process that has rank
``dst`` in the group.
Args:
x (array): Input array.
dst (int): Rank of the destination process in the group.
group (Group): The group of processes that will participate in the
send. If set to ``None`` the global group is used. Default:
``None``.
stream (Stream, optional): Stream or device. Defaults to ``None``
in which case the default stream of the default device is used.
Returns:
array: An array identical to ``x``; evaluating it performs the send.
"""

View File

@@ -0,0 +1,38 @@
def clear_cache() -> None: ...
def device_info() -> dict[str, str | int]:
"""
Get information about the GPU device and system settings.
Currently returns:
* ``architecture``
* ``max_buffer_size``
* ``max_recommended_working_set_size``
* ``memory_size``
* ``resource_limit``
Returns:
dict: A dictionary with string keys and string or integer values.
"""
def get_active_memory() -> int: ...
def get_cache_memory() -> int: ...
def get_peak_memory() -> int: ...
def is_available() -> bool:
"""Check if the Metal back-end is available."""
def reset_peak_memory() -> None: ...
def set_cache_limit(limit: int) -> int: ...
def set_memory_limit(limit: int) -> int: ...
def set_wired_limit(limit: int) -> int: ...
def start_capture(path: str) -> None:
"""
Start a Metal capture.
Args:
path (str): The path to save the capture which should have
the extension ``.gputrace``.
"""
def stop_capture() -> None:
"""Stop a Metal capture."""

View File

@@ -0,0 +1,301 @@
from typing import Sequence
from mlx.core import Device, Dtype, Stream, array, scalar
from mlx.core.distributed import state as state
def bernoulli(
p: scalar | array = ...,
shape: Sequence[int] | None = ...,
key: array | None = ...,
stream: Stream | Device | None = ...,
) -> array:
"""
Generate Bernoulli random values.
The values are sampled from the bernoulli distribution with parameter
``p``. The parameter ``p`` can be a :obj:`float` or :obj:`array` and
must be broadcastable to ``shape``.
Args:
p (float or array, optional): Parameter of the Bernoulli
distribution. Default: ``0.5``.
shape (list(int), optional): Shape of the output.
Default: ``p.shape``.
key (array, optional): A PRNG key. Default: ``None``.
Returns:
array: The array of random values.
"""
def categorical(
logits: array,
axis: int = ...,
shape: Sequence[int] | None = ...,
num_samples: int | None = ...,
key: array | None = ...,
stream: Stream | Device | None = ...,
) -> array:
"""
Sample from a categorical distribution.
The values are sampled from the categorical distribution specified by
the unnormalized values in ``logits``. Note, at most one of ``shape``
or ``num_samples`` can be specified. If both are ``None``, the output
has the same shape as ``logits`` with the ``axis`` dimension removed.
Args:
logits (array): The *unnormalized* categorical distribution(s).
axis (int, optional): The axis which specifies the distribution.
Default: ``-1``.
shape (list(int), optional): The shape of the output. This must
be broadcast compatible with ``logits.shape`` with the ``axis``
dimension removed. Default: ``None``
num_samples (int, optional): The number of samples to draw from each
of the categorical distributions in ``logits``. The output will have
``num_samples`` in the last dimension. Default: ``None``.
key (array, optional): A PRNG key. Default: ``None``.
Returns:
array: The ``shape``-sized output array with type ``uint32``.
"""
def gumbel(
shape: Sequence[int] = ...,
dtype: Dtype | None = ...,
key: array | None = ...,
stream: Stream | Device | None = ...,
) -> array:
"""
Sample from the standard Gumbel distribution.
The values are sampled from a standard Gumbel distribution
whose CDF is ``exp(-exp(-x))``.
Args:
shape (list(int)): The shape of the output.
dtype (Dtype, optional): The data type of the output.
Default: ``float32``.
key (array, optional): A PRNG key. Default: ``None``.
Returns:
array:
The :class:`array` with shape ``shape`` and distributed according
to the Gumbel distribution.
"""
def key(seed: int) -> array:
"""
Get a PRNG key from a seed.
Args:
seed (int): Seed for the PRNG.
Returns:
array: The PRNG key array.
"""
def laplace(
shape: Sequence[int] = ...,
dtype: Dtype | None = ...,
loc: float = ...,
scale: float = ...,
key: array | None = ...,
stream: Stream | Device | None = ...,
) -> array:
"""
Sample numbers from a Laplace distribution.
Args:
shape (list(int), optional): Shape of the output. Default: ``()``.
dtype (Dtype, optional): Type of the output. Default: ``float32``.
loc (float, optional): Mean of the distribution. Default: ``0.0``.
scale (float, optional): The scale "b" of the Laplace distribution.
Default:``1.0``.
key (array, optional): A PRNG key. Default: ``None``.
Returns:
array: The output array of random values.
"""
def multivariate_normal(
mean: array,
cov: array,
shape: Sequence[int] = ...,
dtype: Dtype | None = ...,
key: array | None = ...,
stream: Stream | Device | None = ...,
) -> array:
"""
Generate jointly-normal random samples given a mean and covariance.
The matrix ``cov`` must be positive semi-definite. The behavior is
undefined if it is not. The only supported ``dtype`` is ``float32``.
Args:
mean (array): array of shape ``(..., n)``, the mean of the
distribution.
cov (array): array of shape ``(..., n, n)``, the covariance
matrix of the distribution. The batch shape ``...`` must be
broadcast-compatible with that of ``mean``.
shape (list(int), optional): The output shape must be
broadcast-compatible with ``mean.shape[:-1]`` and ``cov.shape[:-2]``.
If empty, the result shape is determined by broadcasting the batch
shapes of ``mean`` and ``cov``. Default: ``[]``.
dtype (Dtype, optional): The output type. Default: ``float32``.
key (array, optional): A PRNG key. Default: ``None``.
Returns:
array: The output array of random values.
"""
def normal(
shape: Sequence[int] = ...,
dtype: Dtype | None = ...,
loc: scalar | array | None = ...,
scale: scalar | array | None = ...,
key: array | None = ...,
stream: Stream | Device | None = ...,
) -> array:
r"""
Generate normally distributed random numbers.
If ``loc`` and ``scale`` are not provided the "standard" normal
distribution is used. That means $x \sim \mathcal{N}(0, 1)$ for
real numbers and $\text{Re}(x),\text{Im}(x) \sim \mathcal{N}(0,
\frac{1}{2})$ for complex numbers.
Args:
shape (list(int), optional): Shape of the output. Default: ``()``.
dtype (Dtype, optional): Type of the output. Default: ``float32``.
loc (scalar or array, optional): Mean of the distribution.
Default: ``None``.
scale (scalar or array, optional): Standard deviation of the
distribution. Default: ``None``.
key (array, optional): A PRNG key. Default: ``None``.
Returns:
array: The output array of random values.
"""
def permutation(
x: int | array,
axis: int = ...,
key: array | None = ...,
stream: Stream | Device | None = ...,
) -> array:
"""
Generate a random permutation or permute the entries of an array.
Args:
x (int or array, optional): If an integer is provided a random
permutation of ``mx.arange(x)`` is returned. Otherwise the entries
of ``x`` along the given axis are randomly permuted.
axis (int, optional): The axis to permute along. Default: ``0``.
key (array, optional): A PRNG key. Default: ``None``.
Returns:
array:
The generated random permutation or randomly permuted input array.
"""
def randint(
low: scalar | array,
high: scalar | array,
shape: Sequence[int] = ...,
dtype: Dtype | None = ...,
key: array | None = ...,
stream: Stream | Device | None = ...,
) -> array:
"""
Generate random integers from the given interval.
The values are sampled with equal probability from the integers in
half-open interval ``[low, high)``. The lower and upper bound can be
scalars or arrays and must be broadcastable to ``shape``.
Args:
low (scalar or array): Lower bound of the interval.
high (scalar or array): Upper bound of the interval.
shape (list(int), optional): Shape of the output. Default: ``()``.
dtype (Dtype, optional): Type of the output. Default: ``int32``.
key (array, optional): A PRNG key. Default: ``None``.
Returns:
array: The array of random integers.
"""
def seed(seed: int) -> None:
"""
Seed the global PRNG.
Args:
seed (int): Seed for the global PRNG.
"""
def split(key: array, num: int = ..., stream: Stream | Device | None = ...) -> array:
"""
Split a PRNG key into sub keys.
Args:
key (array): Input key to split.
num (int, optional): Number of sub keys. Default: ``2``.
Returns:
array: The array of sub keys with ``num`` as its first dimension.
"""
def truncated_normal(
lower: scalar | array,
upper: scalar | array,
shape: Sequence[int] | None = ...,
dtype: Dtype | None = ...,
key: array | None = ...,
stream: Stream | Device | None = ...,
) -> array:
"""
Generate values from a truncated normal distribution.
The values are sampled from the truncated normal distribution
on the domain ``(lower, upper)``. The bounds ``lower`` and ``upper``
can be scalars or arrays and must be broadcastable to ``shape``.
Args:
lower (scalar or array): Lower bound of the domain.
upper (scalar or array): Upper bound of the domain.
shape (list(int), optional): The shape of the output.
Default:``()``.
dtype (Dtype, optional): The data type of the output.
Default: ``float32``.
key (array, optional): A PRNG key. Default: ``None``.
Returns:
array: The output array of random values.
"""
def uniform(
low: scalar | array = ...,
high: scalar | array = ...,
shape: Sequence[int] = ...,
dtype: Dtype | None = ...,
key: array | None = ...,
stream: Stream | Device | None = ...,
) -> array:
"""
Generate uniformly distributed random numbers.
The values are sampled uniformly in the half-open interval ``[low, high)``.
The lower and upper bound can be scalars or arrays and must be
broadcastable to ``shape``.
Args:
low (scalar or array, optional): Lower bound of the distribution.
Default: ``0``.
high (scalar or array, optional): Upper bound of the distribution.
Default: ``1``.
shape (list(int), optional): Shape of the output. Default:``()``.
dtype (Dtype, optional): Type of the output. Default: ``float32``.
key (array, optional): A PRNG key. Default: ``None``.
Returns:
array: The output array of random values.
"""

uv.lock generated
View File

@@ -14,7 +14,6 @@ supported-markers = [
members = [
"exo",
"exo-pyo3-bindings",
"exo-scripts",
]
[[package]]
@@ -438,21 +437,6 @@ dev = [
{ name = "pytest-asyncio", specifier = ">=1.0.0" },
]
[[package]]
name = "exo-scripts"
version = "0.1.0"
source = { editable = "scripts" }
dependencies = [
{ name = "exo", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "huggingface-hub", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
[package.metadata]
requires-dist = [
{ name = "exo", editable = "." },
{ name = "huggingface-hub", specifier = ">=0.33.4" },
]
[[package]]
name = "fastapi"
version = "0.121.0"
@@ -561,12 +545,16 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/ee/43/3cecdc0349359e1a527cbf2e3e28e5f8f06d3343aaf82ca13437a9aa290f/greenlet-3.2.4-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:23768528f2911bcd7e475210822ffb5254ed10d71f4028387e5a99b4c6699671", size = 610497, upload-time = "2025-08-07T13:18:31.636Z" },
{ url = "https://files.pythonhosted.org/packages/b8/19/06b6cf5d604e2c382a6f31cafafd6f33d5dea706f4db7bdab184bad2b21d/greenlet-3.2.4-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:00fadb3fedccc447f517ee0d3fd8fe49eae949e1cd0f6a611818f4f6fb7dc83b", size = 1121662, upload-time = "2025-08-07T13:42:41.117Z" },
{ url = "https://files.pythonhosted.org/packages/a2/15/0d5e4e1a66fab130d98168fe984c509249c833c1a3c16806b90f253ce7b9/greenlet-3.2.4-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:d25c5091190f2dc0eaa3f950252122edbbadbb682aa7b1ef2f8af0f8c0afefae", size = 1149210, upload-time = "2025-08-07T13:18:24.072Z" },
{ url = "https://files.pythonhosted.org/packages/1c/53/f9c440463b3057485b8594d7a638bed53ba531165ef0ca0e6c364b5cc807/greenlet-3.2.4-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:6e343822feb58ac4d0a1211bd9399de2b3a04963ddeec21530fc426cc121f19b", size = 1564759, upload-time = "2025-11-04T12:42:19.395Z" },
{ url = "https://files.pythonhosted.org/packages/47/e4/3bb4240abdd0a8d23f4f88adec746a3099f0d86bfedb623f063b2e3b4df0/greenlet-3.2.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:ca7f6f1f2649b89ce02f6f229d7c19f680a6238af656f61e0115b24857917929", size = 1634288, upload-time = "2025-11-04T12:42:21.174Z" },
{ url = "https://files.pythonhosted.org/packages/22/5c/85273fd7cc388285632b0498dbbab97596e04b154933dfe0f3e68156c68c/greenlet-3.2.4-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:49a30d5fda2507ae77be16479bdb62a660fa51b1eb4928b524975b3bde77b3c0", size = 273586, upload-time = "2025-08-07T13:16:08.004Z" },
{ url = "https://files.pythonhosted.org/packages/d1/75/10aeeaa3da9332c2e761e4c50d4c3556c21113ee3f0afa2cf5769946f7a3/greenlet-3.2.4-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:299fd615cd8fc86267b47597123e3f43ad79c9d8a22bebdce535e53550763e2f", size = 686346, upload-time = "2025-08-07T13:42:59.944Z" },
{ url = "https://files.pythonhosted.org/packages/c0/aa/687d6b12ffb505a4447567d1f3abea23bd20e73a5bed63871178e0831b7a/greenlet-3.2.4-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:c17b6b34111ea72fc5a4e4beec9711d2226285f0386ea83477cbb97c30a3f3a5", size = 699218, upload-time = "2025-08-07T13:45:30.969Z" },
{ url = "https://files.pythonhosted.org/packages/dc/8b/29aae55436521f1d6f8ff4e12fb676f3400de7fcf27fccd1d4d17fd8fecd/greenlet-3.2.4-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:b4a1870c51720687af7fa3e7cda6d08d801dae660f75a76f3845b642b4da6ee1", size = 694659, upload-time = "2025-08-07T13:53:17.759Z" },
{ url = "https://files.pythonhosted.org/packages/92/2e/ea25914b1ebfde93b6fc4ff46d6864564fba59024e928bdc7de475affc25/greenlet-3.2.4-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:061dc4cf2c34852b052a8620d40f36324554bc192be474b9e9770e8c042fd735", size = 695355, upload-time = "2025-08-07T13:18:34.517Z" },
{ url = "https://files.pythonhosted.org/packages/72/60/fc56c62046ec17f6b0d3060564562c64c862948c9d4bc8aa807cf5bd74f4/greenlet-3.2.4-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:44358b9bf66c8576a9f57a590d5f5d6e72fa4228b763d0e43fee6d3b06d3a337", size = 657512, upload-time = "2025-08-07T13:18:33.969Z" },
{ url = "https://files.pythonhosted.org/packages/23/6e/74407aed965a4ab6ddd93a7ded3180b730d281c77b765788419484cdfeef/greenlet-3.2.4-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:2917bdf657f5859fbf3386b12d68ede4cf1f04c90c3a6bc1f013dd68a22e2269", size = 1612508, upload-time = "2025-11-04T12:42:23.427Z" },
{ url = "https://files.pythonhosted.org/packages/0d/da/343cd760ab2f92bac1845ca07ee3faea9fe52bee65f7bcb19f16ad7de08b/greenlet-3.2.4-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:015d48959d4add5d6c9f6c5210ee3803a830dce46356e3bc326d6776bde54681", size = 1680760, upload-time = "2025-11-04T12:42:25.341Z" },
]
[[package]]