Files
ollama/x/mlxrunner/mlx/generated.h
Patrick Devine 44bdd9a2ef Add MLX runner with GLM4-MoE-Lite model support (#14185)
This change adds a new MLX based runner which includes:

  * Method-based MLX bindings
  * Subprocess-based MLX runner (x/mlxrunner)
  * KV cache with tree management
  * A basic sampler

The GLM4-MoE-Lite model has been ported to use the new bindings.

---------

Co-authored-by: Michael Yang <git@mxy.ng>
2026-02-10 14:57:57 -08:00

7135 lines
214 KiB
C

// This code is auto-generated; DO NOT EDIT.
#ifndef MLX_GENERATED_H
#define MLX_GENERATED_H
#include "dynamic.h"
// Symbol renames for the mlx-c dtype/array API.
//
// Each object-like macro maps an mlx-c function name NAME to
// NAME_mlx_gen_orig_. Until the macro is #undef'd, every later occurrence of
// NAME in this translation unit is replaced by the suffixed name. That includes
// occurrences inside headers included afterwards.
// NOTE(review): presumably this keeps the prototypes from the real mlx-c
// headers out of the way of the symbols provided via "dynamic.h". Confirm
// against the #undef/declaration section later in this generated file.
// This file is auto-generated: change the generator, not these lines.
#define mlx_dtype_size mlx_dtype_size_mlx_gen_orig_
#define mlx_array_tostring mlx_array_tostring_mlx_gen_orig_
#define mlx_array_new mlx_array_new_mlx_gen_orig_
#define mlx_array_free mlx_array_free_mlx_gen_orig_
#define mlx_array_new_bool mlx_array_new_bool_mlx_gen_orig_
#define mlx_array_new_int mlx_array_new_int_mlx_gen_orig_
#define mlx_array_new_float32 mlx_array_new_float32_mlx_gen_orig_
#define mlx_array_new_float mlx_array_new_float_mlx_gen_orig_
#define mlx_array_new_float64 mlx_array_new_float64_mlx_gen_orig_
#define mlx_array_new_double mlx_array_new_double_mlx_gen_orig_
#define mlx_array_new_complex mlx_array_new_complex_mlx_gen_orig_
#define mlx_array_new_data mlx_array_new_data_mlx_gen_orig_
#define mlx_array_set mlx_array_set_mlx_gen_orig_
#define mlx_array_set_bool mlx_array_set_bool_mlx_gen_orig_
#define mlx_array_set_int mlx_array_set_int_mlx_gen_orig_
#define mlx_array_set_float32 mlx_array_set_float32_mlx_gen_orig_
#define mlx_array_set_float mlx_array_set_float_mlx_gen_orig_
#define mlx_array_set_float64 mlx_array_set_float64_mlx_gen_orig_
#define mlx_array_set_double mlx_array_set_double_mlx_gen_orig_
#define mlx_array_set_complex mlx_array_set_complex_mlx_gen_orig_
#define mlx_array_set_data mlx_array_set_data_mlx_gen_orig_
#define mlx_array_itemsize mlx_array_itemsize_mlx_gen_orig_
#define mlx_array_size mlx_array_size_mlx_gen_orig_
#define mlx_array_nbytes mlx_array_nbytes_mlx_gen_orig_
#define mlx_array_ndim mlx_array_ndim_mlx_gen_orig_
#define mlx_array_shape mlx_array_shape_mlx_gen_orig_
#define mlx_array_strides mlx_array_strides_mlx_gen_orig_
#define mlx_array_dim mlx_array_dim_mlx_gen_orig_
#define mlx_array_dtype mlx_array_dtype_mlx_gen_orig_
#define mlx_array_eval mlx_array_eval_mlx_gen_orig_
// Per-dtype scalar accessors (mlx_array_item_<dtype>).
#define mlx_array_item_bool mlx_array_item_bool_mlx_gen_orig_
#define mlx_array_item_uint8 mlx_array_item_uint8_mlx_gen_orig_
#define mlx_array_item_uint16 mlx_array_item_uint16_mlx_gen_orig_
#define mlx_array_item_uint32 mlx_array_item_uint32_mlx_gen_orig_
#define mlx_array_item_uint64 mlx_array_item_uint64_mlx_gen_orig_
#define mlx_array_item_int8 mlx_array_item_int8_mlx_gen_orig_
#define mlx_array_item_int16 mlx_array_item_int16_mlx_gen_orig_
#define mlx_array_item_int32 mlx_array_item_int32_mlx_gen_orig_
#define mlx_array_item_int64 mlx_array_item_int64_mlx_gen_orig_
#define mlx_array_item_float32 mlx_array_item_float32_mlx_gen_orig_
#define mlx_array_item_float64 mlx_array_item_float64_mlx_gen_orig_
#define mlx_array_item_complex64 mlx_array_item_complex64_mlx_gen_orig_
#define mlx_array_item_float16 mlx_array_item_float16_mlx_gen_orig_
#define mlx_array_item_bfloat16 mlx_array_item_bfloat16_mlx_gen_orig_
// Per-dtype raw data-pointer accessors (mlx_array_data_<dtype>).
#define mlx_array_data_bool mlx_array_data_bool_mlx_gen_orig_
#define mlx_array_data_uint8 mlx_array_data_uint8_mlx_gen_orig_
#define mlx_array_data_uint16 mlx_array_data_uint16_mlx_gen_orig_
#define mlx_array_data_uint32 mlx_array_data_uint32_mlx_gen_orig_
#define mlx_array_data_uint64 mlx_array_data_uint64_mlx_gen_orig_
#define mlx_array_data_int8 mlx_array_data_int8_mlx_gen_orig_
#define mlx_array_data_int16 mlx_array_data_int16_mlx_gen_orig_
#define mlx_array_data_int32 mlx_array_data_int32_mlx_gen_orig_
#define mlx_array_data_int64 mlx_array_data_int64_mlx_gen_orig_
#define mlx_array_data_float32 mlx_array_data_float32_mlx_gen_orig_
#define mlx_array_data_float64 mlx_array_data_float64_mlx_gen_orig_
#define mlx_array_data_complex64 mlx_array_data_complex64_mlx_gen_orig_
#define mlx_array_data_float16 mlx_array_data_float16_mlx_gen_orig_
#define mlx_array_data_bfloat16 mlx_array_data_bfloat16_mlx_gen_orig_
// Leading-underscore names are mirrored verbatim from mlx-c (presumably its
// non-public helpers); they are renamed the same way as the public ones.
#define _mlx_array_is_available _mlx_array_is_available_mlx_gen_orig_
#define _mlx_array_wait _mlx_array_wait_mlx_gen_orig_
#define _mlx_array_is_contiguous _mlx_array_is_contiguous_mlx_gen_orig_
#define _mlx_array_is_row_contiguous _mlx_array_is_row_contiguous_mlx_gen_orig_
#define _mlx_array_is_col_contiguous _mlx_array_is_col_contiguous_mlx_gen_orig_
// Symbol renames for the mlx-c closure API: the plain, kwargs, value_and_grad,
// custom, custom_jvp and custom_vmap closure families. Each maps NAME to
// NAME_mlx_gen_orig_ for the rest of the translation unit (until #undef'd).
#define mlx_closure_new mlx_closure_new_mlx_gen_orig_
#define mlx_closure_free mlx_closure_free_mlx_gen_orig_
#define mlx_closure_new_func mlx_closure_new_func_mlx_gen_orig_
#define mlx_closure_new_func_payload mlx_closure_new_func_payload_mlx_gen_orig_
#define mlx_closure_set mlx_closure_set_mlx_gen_orig_
#define mlx_closure_apply mlx_closure_apply_mlx_gen_orig_
#define mlx_closure_new_unary mlx_closure_new_unary_mlx_gen_orig_
#define mlx_closure_kwargs_new mlx_closure_kwargs_new_mlx_gen_orig_
#define mlx_closure_kwargs_free mlx_closure_kwargs_free_mlx_gen_orig_
#define mlx_closure_kwargs_new_func mlx_closure_kwargs_new_func_mlx_gen_orig_
#define mlx_closure_kwargs_new_func_payload mlx_closure_kwargs_new_func_payload_mlx_gen_orig_
#define mlx_closure_kwargs_set mlx_closure_kwargs_set_mlx_gen_orig_
#define mlx_closure_kwargs_apply mlx_closure_kwargs_apply_mlx_gen_orig_
#define mlx_closure_value_and_grad_new mlx_closure_value_and_grad_new_mlx_gen_orig_
#define mlx_closure_value_and_grad_free mlx_closure_value_and_grad_free_mlx_gen_orig_
#define mlx_closure_value_and_grad_new_func mlx_closure_value_and_grad_new_func_mlx_gen_orig_
#define mlx_closure_value_and_grad_new_func_payload mlx_closure_value_and_grad_new_func_payload_mlx_gen_orig_
#define mlx_closure_value_and_grad_set mlx_closure_value_and_grad_set_mlx_gen_orig_
#define mlx_closure_value_and_grad_apply mlx_closure_value_and_grad_apply_mlx_gen_orig_
#define mlx_closure_custom_new mlx_closure_custom_new_mlx_gen_orig_
#define mlx_closure_custom_free mlx_closure_custom_free_mlx_gen_orig_
#define mlx_closure_custom_new_func mlx_closure_custom_new_func_mlx_gen_orig_
#define mlx_closure_custom_new_func_payload mlx_closure_custom_new_func_payload_mlx_gen_orig_
#define mlx_closure_custom_set mlx_closure_custom_set_mlx_gen_orig_
#define mlx_closure_custom_apply mlx_closure_custom_apply_mlx_gen_orig_
#define mlx_closure_custom_jvp_new mlx_closure_custom_jvp_new_mlx_gen_orig_
#define mlx_closure_custom_jvp_free mlx_closure_custom_jvp_free_mlx_gen_orig_
#define mlx_closure_custom_jvp_new_func mlx_closure_custom_jvp_new_func_mlx_gen_orig_
#define mlx_closure_custom_jvp_new_func_payload mlx_closure_custom_jvp_new_func_payload_mlx_gen_orig_
#define mlx_closure_custom_jvp_set mlx_closure_custom_jvp_set_mlx_gen_orig_
#define mlx_closure_custom_jvp_apply mlx_closure_custom_jvp_apply_mlx_gen_orig_
#define mlx_closure_custom_vmap_new mlx_closure_custom_vmap_new_mlx_gen_orig_
#define mlx_closure_custom_vmap_free mlx_closure_custom_vmap_free_mlx_gen_orig_
#define mlx_closure_custom_vmap_new_func mlx_closure_custom_vmap_new_func_mlx_gen_orig_
#define mlx_closure_custom_vmap_new_func_payload mlx_closure_custom_vmap_new_func_payload_mlx_gen_orig_
#define mlx_closure_custom_vmap_set mlx_closure_custom_vmap_set_mlx_gen_orig_
#define mlx_closure_custom_vmap_apply mlx_closure_custom_vmap_apply_mlx_gen_orig_
// Symbol renames for the mlx-c compile API. Each maps NAME to NAME_mlx_gen_orig_.
#define mlx_compile mlx_compile_mlx_gen_orig_
#define mlx_detail_compile mlx_detail_compile_mlx_gen_orig_
#define mlx_detail_compile_clear_cache mlx_detail_compile_clear_cache_mlx_gen_orig_
#define mlx_detail_compile_erase mlx_detail_compile_erase_mlx_gen_orig_
#define mlx_disable_compile mlx_disable_compile_mlx_gen_orig_
#define mlx_enable_compile mlx_enable_compile_mlx_gen_orig_
#define mlx_set_compile_mode mlx_set_compile_mode_mlx_gen_orig_
// Symbol renames for the mlx-c device API, including the default-device
// getter/setter. Each maps NAME to NAME_mlx_gen_orig_.
#define mlx_device_new mlx_device_new_mlx_gen_orig_
#define mlx_device_new_type mlx_device_new_type_mlx_gen_orig_
#define mlx_device_free mlx_device_free_mlx_gen_orig_
#define mlx_device_set mlx_device_set_mlx_gen_orig_
#define mlx_device_tostring mlx_device_tostring_mlx_gen_orig_
#define mlx_device_equal mlx_device_equal_mlx_gen_orig_
#define mlx_device_get_index mlx_device_get_index_mlx_gen_orig_
#define mlx_device_get_type mlx_device_get_type_mlx_gen_orig_
#define mlx_get_default_device mlx_get_default_device_mlx_gen_orig_
#define mlx_set_default_device mlx_set_default_device_mlx_gen_orig_
// Symbol renames for the mlx-c distributed API (group queries and collective
// operations). Each maps NAME to NAME_mlx_gen_orig_.
#define mlx_distributed_group_rank mlx_distributed_group_rank_mlx_gen_orig_
#define mlx_distributed_group_size mlx_distributed_group_size_mlx_gen_orig_
#define mlx_distributed_group_split mlx_distributed_group_split_mlx_gen_orig_
#define mlx_distributed_is_available mlx_distributed_is_available_mlx_gen_orig_
#define mlx_distributed_init mlx_distributed_init_mlx_gen_orig_
#define mlx_distributed_all_gather mlx_distributed_all_gather_mlx_gen_orig_
#define mlx_distributed_all_max mlx_distributed_all_max_mlx_gen_orig_
#define mlx_distributed_all_min mlx_distributed_all_min_mlx_gen_orig_
#define mlx_distributed_all_sum mlx_distributed_all_sum_mlx_gen_orig_
#define mlx_distributed_recv mlx_distributed_recv_mlx_gen_orig_
#define mlx_distributed_recv_like mlx_distributed_recv_like_mlx_gen_orig_
#define mlx_distributed_send mlx_distributed_send_mlx_gen_orig_
#define mlx_distributed_sum_scatter mlx_distributed_sum_scatter_mlx_gen_orig_
// Symbol renames for the mlx-c error-handling API. Each maps NAME to
// NAME_mlx_gen_orig_. `_mlx_error` keeps mlx-c's leading underscore verbatim.
#define mlx_set_error_handler mlx_set_error_handler_mlx_gen_orig_
#define _mlx_error _mlx_error_mlx_gen_orig_
// Symbol renames for the mlx-c function export/import API. Each maps NAME to
// NAME_mlx_gen_orig_.
#define mlx_export_function mlx_export_function_mlx_gen_orig_
#define mlx_export_function_kwargs mlx_export_function_kwargs_mlx_gen_orig_
#define mlx_function_exporter_new mlx_function_exporter_new_mlx_gen_orig_
#define mlx_function_exporter_free mlx_function_exporter_free_mlx_gen_orig_
#define mlx_function_exporter_apply mlx_function_exporter_apply_mlx_gen_orig_
#define mlx_function_exporter_apply_kwargs mlx_function_exporter_apply_kwargs_mlx_gen_orig_
#define mlx_imported_function_new mlx_imported_function_new_mlx_gen_orig_
#define mlx_imported_function_free mlx_imported_function_free_mlx_gen_orig_
#define mlx_imported_function_apply mlx_imported_function_apply_mlx_gen_orig_
#define mlx_imported_function_apply_kwargs mlx_imported_function_apply_kwargs_mlx_gen_orig_
// Symbol renames for the mlx-c `fast` API: CUDA and Metal custom-kernel
// builders, plus layer_norm, rms_norm, rope and scaled_dot_product_attention.
// Each maps NAME to NAME_mlx_gen_orig_.
#define mlx_fast_cuda_kernel_config_new mlx_fast_cuda_kernel_config_new_mlx_gen_orig_
#define mlx_fast_cuda_kernel_config_free mlx_fast_cuda_kernel_config_free_mlx_gen_orig_
#define mlx_fast_cuda_kernel_config_add_output_arg mlx_fast_cuda_kernel_config_add_output_arg_mlx_gen_orig_
#define mlx_fast_cuda_kernel_config_set_grid mlx_fast_cuda_kernel_config_set_grid_mlx_gen_orig_
#define mlx_fast_cuda_kernel_config_set_thread_group mlx_fast_cuda_kernel_config_set_thread_group_mlx_gen_orig_
#define mlx_fast_cuda_kernel_config_set_init_value mlx_fast_cuda_kernel_config_set_init_value_mlx_gen_orig_
#define mlx_fast_cuda_kernel_config_set_verbose mlx_fast_cuda_kernel_config_set_verbose_mlx_gen_orig_
#define mlx_fast_cuda_kernel_config_add_template_arg_dtype mlx_fast_cuda_kernel_config_add_template_arg_dtype_mlx_gen_orig_
#define mlx_fast_cuda_kernel_config_add_template_arg_int mlx_fast_cuda_kernel_config_add_template_arg_int_mlx_gen_orig_
#define mlx_fast_cuda_kernel_config_add_template_arg_bool mlx_fast_cuda_kernel_config_add_template_arg_bool_mlx_gen_orig_
#define mlx_fast_cuda_kernel_new mlx_fast_cuda_kernel_new_mlx_gen_orig_
#define mlx_fast_cuda_kernel_free mlx_fast_cuda_kernel_free_mlx_gen_orig_
#define mlx_fast_cuda_kernel_apply mlx_fast_cuda_kernel_apply_mlx_gen_orig_
#define mlx_fast_layer_norm mlx_fast_layer_norm_mlx_gen_orig_
#define mlx_fast_metal_kernel_config_new mlx_fast_metal_kernel_config_new_mlx_gen_orig_
#define mlx_fast_metal_kernel_config_free mlx_fast_metal_kernel_config_free_mlx_gen_orig_
#define mlx_fast_metal_kernel_config_add_output_arg mlx_fast_metal_kernel_config_add_output_arg_mlx_gen_orig_
#define mlx_fast_metal_kernel_config_set_grid mlx_fast_metal_kernel_config_set_grid_mlx_gen_orig_
#define mlx_fast_metal_kernel_config_set_thread_group mlx_fast_metal_kernel_config_set_thread_group_mlx_gen_orig_
#define mlx_fast_metal_kernel_config_set_init_value mlx_fast_metal_kernel_config_set_init_value_mlx_gen_orig_
#define mlx_fast_metal_kernel_config_set_verbose mlx_fast_metal_kernel_config_set_verbose_mlx_gen_orig_
#define mlx_fast_metal_kernel_config_add_template_arg_dtype mlx_fast_metal_kernel_config_add_template_arg_dtype_mlx_gen_orig_
#define mlx_fast_metal_kernel_config_add_template_arg_int mlx_fast_metal_kernel_config_add_template_arg_int_mlx_gen_orig_
#define mlx_fast_metal_kernel_config_add_template_arg_bool mlx_fast_metal_kernel_config_add_template_arg_bool_mlx_gen_orig_
#define mlx_fast_metal_kernel_new mlx_fast_metal_kernel_new_mlx_gen_orig_
#define mlx_fast_metal_kernel_free mlx_fast_metal_kernel_free_mlx_gen_orig_
#define mlx_fast_metal_kernel_apply mlx_fast_metal_kernel_apply_mlx_gen_orig_
#define mlx_fast_rms_norm mlx_fast_rms_norm_mlx_gen_orig_
#define mlx_fast_rope mlx_fast_rope_mlx_gen_orig_
#define mlx_fast_scaled_dot_product_attention mlx_fast_scaled_dot_product_attention_mlx_gen_orig_
// Symbol renames for the mlx-c FFT API. Each maps NAME to NAME_mlx_gen_orig_.
#define mlx_fft_fft mlx_fft_fft_mlx_gen_orig_
#define mlx_fft_fft2 mlx_fft_fft2_mlx_gen_orig_
#define mlx_fft_fftn mlx_fft_fftn_mlx_gen_orig_
#define mlx_fft_fftshift mlx_fft_fftshift_mlx_gen_orig_
#define mlx_fft_ifft mlx_fft_ifft_mlx_gen_orig_
#define mlx_fft_ifft2 mlx_fft_ifft2_mlx_gen_orig_
#define mlx_fft_ifftn mlx_fft_ifftn_mlx_gen_orig_
#define mlx_fft_ifftshift mlx_fft_ifftshift_mlx_gen_orig_
#define mlx_fft_irfft mlx_fft_irfft_mlx_gen_orig_
#define mlx_fft_irfft2 mlx_fft_irfft2_mlx_gen_orig_
#define mlx_fft_irfftn mlx_fft_irfftn_mlx_gen_orig_
#define mlx_fft_rfft mlx_fft_rfft_mlx_gen_orig_
#define mlx_fft_rfft2 mlx_fft_rfft2_mlx_gen_orig_
#define mlx_fft_rfftn mlx_fft_rfftn_mlx_gen_orig_
// Symbol renames for the mlx-c I/O API: reader/writer objects, load/save and
// the safetensors variants. Each maps NAME to NAME_mlx_gen_orig_.
#define mlx_io_reader_new mlx_io_reader_new_mlx_gen_orig_
#define mlx_io_reader_descriptor mlx_io_reader_descriptor_mlx_gen_orig_
#define mlx_io_reader_tostring mlx_io_reader_tostring_mlx_gen_orig_
#define mlx_io_reader_free mlx_io_reader_free_mlx_gen_orig_
#define mlx_io_writer_new mlx_io_writer_new_mlx_gen_orig_
#define mlx_io_writer_descriptor mlx_io_writer_descriptor_mlx_gen_orig_
#define mlx_io_writer_tostring mlx_io_writer_tostring_mlx_gen_orig_
#define mlx_io_writer_free mlx_io_writer_free_mlx_gen_orig_
#define mlx_load_reader mlx_load_reader_mlx_gen_orig_
#define mlx_load mlx_load_mlx_gen_orig_
#define mlx_load_safetensors_reader mlx_load_safetensors_reader_mlx_gen_orig_
#define mlx_load_safetensors mlx_load_safetensors_mlx_gen_orig_
#define mlx_save_writer mlx_save_writer_mlx_gen_orig_
#define mlx_save mlx_save_mlx_gen_orig_
#define mlx_save_safetensors_writer mlx_save_safetensors_writer_mlx_gen_orig_
#define mlx_save_safetensors mlx_save_safetensors_mlx_gen_orig_
// Symbol renames for the mlx-c linear-algebra API. Each maps NAME to
// NAME_mlx_gen_orig_.
#define mlx_linalg_cholesky mlx_linalg_cholesky_mlx_gen_orig_
#define mlx_linalg_cholesky_inv mlx_linalg_cholesky_inv_mlx_gen_orig_
#define mlx_linalg_cross mlx_linalg_cross_mlx_gen_orig_
#define mlx_linalg_eig mlx_linalg_eig_mlx_gen_orig_
#define mlx_linalg_eigh mlx_linalg_eigh_mlx_gen_orig_
#define mlx_linalg_eigvals mlx_linalg_eigvals_mlx_gen_orig_
#define mlx_linalg_eigvalsh mlx_linalg_eigvalsh_mlx_gen_orig_
#define mlx_linalg_inv mlx_linalg_inv_mlx_gen_orig_
#define mlx_linalg_lu mlx_linalg_lu_mlx_gen_orig_
#define mlx_linalg_lu_factor mlx_linalg_lu_factor_mlx_gen_orig_
#define mlx_linalg_norm mlx_linalg_norm_mlx_gen_orig_
#define mlx_linalg_norm_matrix mlx_linalg_norm_matrix_mlx_gen_orig_
#define mlx_linalg_norm_l2 mlx_linalg_norm_l2_mlx_gen_orig_
#define mlx_linalg_pinv mlx_linalg_pinv_mlx_gen_orig_
#define mlx_linalg_qr mlx_linalg_qr_mlx_gen_orig_
#define mlx_linalg_solve mlx_linalg_solve_mlx_gen_orig_
#define mlx_linalg_solve_triangular mlx_linalg_solve_triangular_mlx_gen_orig_
#define mlx_linalg_svd mlx_linalg_svd_mlx_gen_orig_
#define mlx_linalg_tri_inv mlx_linalg_tri_inv_mlx_gen_orig_
// Symbol renames for the mlx-c string-keyed map containers
// (string->array and string->string) and their iterators. Each maps NAME to
// NAME_mlx_gen_orig_.
#define mlx_map_string_to_array_new mlx_map_string_to_array_new_mlx_gen_orig_
#define mlx_map_string_to_array_set mlx_map_string_to_array_set_mlx_gen_orig_
#define mlx_map_string_to_array_free mlx_map_string_to_array_free_mlx_gen_orig_
#define mlx_map_string_to_array_insert mlx_map_string_to_array_insert_mlx_gen_orig_
#define mlx_map_string_to_array_get mlx_map_string_to_array_get_mlx_gen_orig_
#define mlx_map_string_to_array_iterator_new mlx_map_string_to_array_iterator_new_mlx_gen_orig_
#define mlx_map_string_to_array_iterator_free mlx_map_string_to_array_iterator_free_mlx_gen_orig_
#define mlx_map_string_to_array_iterator_next mlx_map_string_to_array_iterator_next_mlx_gen_orig_
#define mlx_map_string_to_string_new mlx_map_string_to_string_new_mlx_gen_orig_
#define mlx_map_string_to_string_set mlx_map_string_to_string_set_mlx_gen_orig_
#define mlx_map_string_to_string_free mlx_map_string_to_string_free_mlx_gen_orig_
#define mlx_map_string_to_string_insert mlx_map_string_to_string_insert_mlx_gen_orig_
#define mlx_map_string_to_string_get mlx_map_string_to_string_get_mlx_gen_orig_
#define mlx_map_string_to_string_iterator_new mlx_map_string_to_string_iterator_new_mlx_gen_orig_
#define mlx_map_string_to_string_iterator_free mlx_map_string_to_string_iterator_free_mlx_gen_orig_
#define mlx_map_string_to_string_iterator_next mlx_map_string_to_string_iterator_next_mlx_gen_orig_
// Symbol renames for the mlx-c memory/cache management API. Each maps NAME to
// NAME_mlx_gen_orig_.
#define mlx_clear_cache mlx_clear_cache_mlx_gen_orig_
#define mlx_get_active_memory mlx_get_active_memory_mlx_gen_orig_
#define mlx_get_cache_memory mlx_get_cache_memory_mlx_gen_orig_
#define mlx_get_memory_limit mlx_get_memory_limit_mlx_gen_orig_
#define mlx_get_peak_memory mlx_get_peak_memory_mlx_gen_orig_
#define mlx_reset_peak_memory mlx_reset_peak_memory_mlx_gen_orig_
#define mlx_set_cache_limit mlx_set_cache_limit_mlx_gen_orig_
#define mlx_set_memory_limit mlx_set_memory_limit_mlx_gen_orig_
#define mlx_set_wired_limit mlx_set_wired_limit_mlx_gen_orig_
// Symbol renames for the mlx-c Metal backend API. Each maps NAME to
// NAME_mlx_gen_orig_.
#define mlx_metal_device_info mlx_metal_device_info_mlx_gen_orig_
#define mlx_metal_is_available mlx_metal_is_available_mlx_gen_orig_
#define mlx_metal_start_capture mlx_metal_start_capture_mlx_gen_orig_
#define mlx_metal_stop_capture mlx_metal_stop_capture_mlx_gen_orig_
// Symbol renames for the mlx-c array operations API, listed alphabetically as
// emitted by the generator. Variants suffixed _axis/_axes are separate mlx-c
// entry points and are renamed independently. Each maps NAME to
// NAME_mlx_gen_orig_.
#define mlx_abs mlx_abs_mlx_gen_orig_
#define mlx_add mlx_add_mlx_gen_orig_
#define mlx_addmm mlx_addmm_mlx_gen_orig_
#define mlx_all_axes mlx_all_axes_mlx_gen_orig_
#define mlx_all_axis mlx_all_axis_mlx_gen_orig_
#define mlx_all mlx_all_mlx_gen_orig_
#define mlx_allclose mlx_allclose_mlx_gen_orig_
#define mlx_any_axes mlx_any_axes_mlx_gen_orig_
#define mlx_any_axis mlx_any_axis_mlx_gen_orig_
#define mlx_any mlx_any_mlx_gen_orig_
#define mlx_arange mlx_arange_mlx_gen_orig_
#define mlx_arccos mlx_arccos_mlx_gen_orig_
#define mlx_arccosh mlx_arccosh_mlx_gen_orig_
#define mlx_arcsin mlx_arcsin_mlx_gen_orig_
#define mlx_arcsinh mlx_arcsinh_mlx_gen_orig_
#define mlx_arctan mlx_arctan_mlx_gen_orig_
#define mlx_arctan2 mlx_arctan2_mlx_gen_orig_
#define mlx_arctanh mlx_arctanh_mlx_gen_orig_
#define mlx_argmax_axis mlx_argmax_axis_mlx_gen_orig_
#define mlx_argmax mlx_argmax_mlx_gen_orig_
#define mlx_argmin_axis mlx_argmin_axis_mlx_gen_orig_
#define mlx_argmin mlx_argmin_mlx_gen_orig_
#define mlx_argpartition_axis mlx_argpartition_axis_mlx_gen_orig_
#define mlx_argpartition mlx_argpartition_mlx_gen_orig_
#define mlx_argsort_axis mlx_argsort_axis_mlx_gen_orig_
#define mlx_argsort mlx_argsort_mlx_gen_orig_
#define mlx_array_equal mlx_array_equal_mlx_gen_orig_
#define mlx_as_strided mlx_as_strided_mlx_gen_orig_
#define mlx_astype mlx_astype_mlx_gen_orig_
#define mlx_atleast_1d mlx_atleast_1d_mlx_gen_orig_
#define mlx_atleast_2d mlx_atleast_2d_mlx_gen_orig_
#define mlx_atleast_3d mlx_atleast_3d_mlx_gen_orig_
#define mlx_bitwise_and mlx_bitwise_and_mlx_gen_orig_
#define mlx_bitwise_invert mlx_bitwise_invert_mlx_gen_orig_
#define mlx_bitwise_or mlx_bitwise_or_mlx_gen_orig_
#define mlx_bitwise_xor mlx_bitwise_xor_mlx_gen_orig_
#define mlx_block_masked_mm mlx_block_masked_mm_mlx_gen_orig_
#define mlx_broadcast_arrays mlx_broadcast_arrays_mlx_gen_orig_
#define mlx_broadcast_to mlx_broadcast_to_mlx_gen_orig_
#define mlx_ceil mlx_ceil_mlx_gen_orig_
#define mlx_clip mlx_clip_mlx_gen_orig_
#define mlx_concatenate_axis mlx_concatenate_axis_mlx_gen_orig_
#define mlx_concatenate mlx_concatenate_mlx_gen_orig_
#define mlx_conjugate mlx_conjugate_mlx_gen_orig_
#define mlx_contiguous mlx_contiguous_mlx_gen_orig_
#define mlx_conv1d mlx_conv1d_mlx_gen_orig_
#define mlx_conv2d mlx_conv2d_mlx_gen_orig_
#define mlx_conv3d mlx_conv3d_mlx_gen_orig_
#define mlx_conv_general mlx_conv_general_mlx_gen_orig_
#define mlx_conv_transpose1d mlx_conv_transpose1d_mlx_gen_orig_
#define mlx_conv_transpose2d mlx_conv_transpose2d_mlx_gen_orig_
#define mlx_conv_transpose3d mlx_conv_transpose3d_mlx_gen_orig_
#define mlx_copy mlx_copy_mlx_gen_orig_
#define mlx_cos mlx_cos_mlx_gen_orig_
#define mlx_cosh mlx_cosh_mlx_gen_orig_
#define mlx_cummax mlx_cummax_mlx_gen_orig_
#define mlx_cummin mlx_cummin_mlx_gen_orig_
#define mlx_cumprod mlx_cumprod_mlx_gen_orig_
#define mlx_cumsum mlx_cumsum_mlx_gen_orig_
#define mlx_degrees mlx_degrees_mlx_gen_orig_
#define mlx_depends mlx_depends_mlx_gen_orig_
#define mlx_dequantize mlx_dequantize_mlx_gen_orig_
#define mlx_diag mlx_diag_mlx_gen_orig_
#define mlx_diagonal mlx_diagonal_mlx_gen_orig_
#define mlx_divide mlx_divide_mlx_gen_orig_
#define mlx_divmod mlx_divmod_mlx_gen_orig_
#define mlx_einsum mlx_einsum_mlx_gen_orig_
#define mlx_equal mlx_equal_mlx_gen_orig_
#define mlx_erf mlx_erf_mlx_gen_orig_
#define mlx_erfinv mlx_erfinv_mlx_gen_orig_
#define mlx_exp mlx_exp_mlx_gen_orig_
#define mlx_expand_dims_axes mlx_expand_dims_axes_mlx_gen_orig_
#define mlx_expand_dims mlx_expand_dims_mlx_gen_orig_
#define mlx_expm1 mlx_expm1_mlx_gen_orig_
#define mlx_eye mlx_eye_mlx_gen_orig_
#define mlx_flatten mlx_flatten_mlx_gen_orig_
#define mlx_floor mlx_floor_mlx_gen_orig_
#define mlx_floor_divide mlx_floor_divide_mlx_gen_orig_
#define mlx_from_fp8 mlx_from_fp8_mlx_gen_orig_
#define mlx_full mlx_full_mlx_gen_orig_
#define mlx_full_like mlx_full_like_mlx_gen_orig_
#define mlx_gather mlx_gather_mlx_gen_orig_
#define mlx_gather_mm mlx_gather_mm_mlx_gen_orig_
#define mlx_gather_qmm mlx_gather_qmm_mlx_gen_orig_
#define mlx_greater mlx_greater_mlx_gen_orig_
#define mlx_greater_equal mlx_greater_equal_mlx_gen_orig_
#define mlx_hadamard_transform mlx_hadamard_transform_mlx_gen_orig_
#define mlx_identity mlx_identity_mlx_gen_orig_
#define mlx_imag mlx_imag_mlx_gen_orig_
#define mlx_inner mlx_inner_mlx_gen_orig_
#define mlx_isclose mlx_isclose_mlx_gen_orig_
#define mlx_isfinite mlx_isfinite_mlx_gen_orig_
#define mlx_isinf mlx_isinf_mlx_gen_orig_
#define mlx_isnan mlx_isnan_mlx_gen_orig_
#define mlx_isneginf mlx_isneginf_mlx_gen_orig_
#define mlx_isposinf mlx_isposinf_mlx_gen_orig_
#define mlx_kron mlx_kron_mlx_gen_orig_
#define mlx_left_shift mlx_left_shift_mlx_gen_orig_
#define mlx_less mlx_less_mlx_gen_orig_
#define mlx_less_equal mlx_less_equal_mlx_gen_orig_
#define mlx_linspace mlx_linspace_mlx_gen_orig_
#define mlx_log mlx_log_mlx_gen_orig_
#define mlx_log10 mlx_log10_mlx_gen_orig_
#define mlx_log1p mlx_log1p_mlx_gen_orig_
#define mlx_log2 mlx_log2_mlx_gen_orig_
#define mlx_logaddexp mlx_logaddexp_mlx_gen_orig_
#define mlx_logcumsumexp mlx_logcumsumexp_mlx_gen_orig_
#define mlx_logical_and mlx_logical_and_mlx_gen_orig_
#define mlx_logical_not mlx_logical_not_mlx_gen_orig_
#define mlx_logical_or mlx_logical_or_mlx_gen_orig_
#define mlx_logsumexp_axes mlx_logsumexp_axes_mlx_gen_orig_
#define mlx_logsumexp_axis mlx_logsumexp_axis_mlx_gen_orig_
#define mlx_logsumexp mlx_logsumexp_mlx_gen_orig_
#define mlx_masked_scatter mlx_masked_scatter_mlx_gen_orig_
#define mlx_matmul mlx_matmul_mlx_gen_orig_
#define mlx_max_axes mlx_max_axes_mlx_gen_orig_
#define mlx_max_axis mlx_max_axis_mlx_gen_orig_
#define mlx_max mlx_max_mlx_gen_orig_
#define mlx_maximum mlx_maximum_mlx_gen_orig_
#define mlx_mean_axes mlx_mean_axes_mlx_gen_orig_
#define mlx_mean_axis mlx_mean_axis_mlx_gen_orig_
#define mlx_mean mlx_mean_mlx_gen_orig_
#define mlx_median mlx_median_mlx_gen_orig_
#define mlx_meshgrid mlx_meshgrid_mlx_gen_orig_
#define mlx_min_axes mlx_min_axes_mlx_gen_orig_
#define mlx_min_axis mlx_min_axis_mlx_gen_orig_
#define mlx_min mlx_min_mlx_gen_orig_
#define mlx_minimum mlx_minimum_mlx_gen_orig_
#define mlx_moveaxis mlx_moveaxis_mlx_gen_orig_
#define mlx_multiply mlx_multiply_mlx_gen_orig_
#define mlx_nan_to_num mlx_nan_to_num_mlx_gen_orig_
#define mlx_negative mlx_negative_mlx_gen_orig_
#define mlx_not_equal mlx_not_equal_mlx_gen_orig_
#define mlx_number_of_elements mlx_number_of_elements_mlx_gen_orig_
#define mlx_ones mlx_ones_mlx_gen_orig_
#define mlx_ones_like mlx_ones_like_mlx_gen_orig_
#define mlx_outer mlx_outer_mlx_gen_orig_
#define mlx_pad mlx_pad_mlx_gen_orig_
#define mlx_pad_symmetric mlx_pad_symmetric_mlx_gen_orig_
#define mlx_partition_axis mlx_partition_axis_mlx_gen_orig_
#define mlx_partition mlx_partition_mlx_gen_orig_
#define mlx_power mlx_power_mlx_gen_orig_
#define mlx_prod_axes mlx_prod_axes_mlx_gen_orig_
#define mlx_prod_axis mlx_prod_axis_mlx_gen_orig_
#define mlx_prod mlx_prod_mlx_gen_orig_
#define mlx_put_along_axis mlx_put_along_axis_mlx_gen_orig_
#define mlx_quantize mlx_quantize_mlx_gen_orig_
#define mlx_quantized_matmul mlx_quantized_matmul_mlx_gen_orig_
#define mlx_radians mlx_radians_mlx_gen_orig_
#define mlx_real mlx_real_mlx_gen_orig_
#define mlx_reciprocal mlx_reciprocal_mlx_gen_orig_
#define mlx_remainder mlx_remainder_mlx_gen_orig_
#define mlx_repeat_axis mlx_repeat_axis_mlx_gen_orig_
#define mlx_repeat mlx_repeat_mlx_gen_orig_
#define mlx_reshape mlx_reshape_mlx_gen_orig_
#define mlx_right_shift mlx_right_shift_mlx_gen_orig_
#define mlx_roll_axis mlx_roll_axis_mlx_gen_orig_
#define mlx_roll_axes mlx_roll_axes_mlx_gen_orig_
#define mlx_roll mlx_roll_mlx_gen_orig_
#define mlx_round mlx_round_mlx_gen_orig_
#define mlx_rsqrt mlx_rsqrt_mlx_gen_orig_
#define mlx_scatter mlx_scatter_mlx_gen_orig_
#define mlx_scatter_add mlx_scatter_add_mlx_gen_orig_
#define mlx_scatter_add_axis mlx_scatter_add_axis_mlx_gen_orig_
#define mlx_scatter_max mlx_scatter_max_mlx_gen_orig_
#define mlx_scatter_min mlx_scatter_min_mlx_gen_orig_
#define mlx_scatter_prod mlx_scatter_prod_mlx_gen_orig_
#define mlx_segmented_mm mlx_segmented_mm_mlx_gen_orig_
#define mlx_sigmoid mlx_sigmoid_mlx_gen_orig_
#define mlx_sign mlx_sign_mlx_gen_orig_
#define mlx_sin mlx_sin_mlx_gen_orig_
#define mlx_sinh mlx_sinh_mlx_gen_orig_
#define mlx_slice mlx_slice_mlx_gen_orig_
#define mlx_slice_dynamic mlx_slice_dynamic_mlx_gen_orig_
#define mlx_slice_update mlx_slice_update_mlx_gen_orig_
#define mlx_slice_update_dynamic mlx_slice_update_dynamic_mlx_gen_orig_
#define mlx_softmax_axes mlx_softmax_axes_mlx_gen_orig_
#define mlx_softmax_axis mlx_softmax_axis_mlx_gen_orig_
#define mlx_softmax mlx_softmax_mlx_gen_orig_
#define mlx_sort_axis mlx_sort_axis_mlx_gen_orig_
#define mlx_sort mlx_sort_mlx_gen_orig_
#define mlx_split mlx_split_mlx_gen_orig_
#define mlx_split_sections mlx_split_sections_mlx_gen_orig_
#define mlx_sqrt mlx_sqrt_mlx_gen_orig_
#define mlx_square mlx_square_mlx_gen_orig_
#define mlx_squeeze_axes mlx_squeeze_axes_mlx_gen_orig_
#define mlx_squeeze_axis mlx_squeeze_axis_mlx_gen_orig_
#define mlx_squeeze mlx_squeeze_mlx_gen_orig_
#define mlx_stack_axis mlx_stack_axis_mlx_gen_orig_
#define mlx_stack mlx_stack_mlx_gen_orig_
#define mlx_std_axes mlx_std_axes_mlx_gen_orig_
#define mlx_std_axis mlx_std_axis_mlx_gen_orig_
#define mlx_std mlx_std_mlx_gen_orig_
#define mlx_stop_gradient mlx_stop_gradient_mlx_gen_orig_
#define mlx_subtract mlx_subtract_mlx_gen_orig_
#define mlx_sum_axes mlx_sum_axes_mlx_gen_orig_
#define mlx_sum_axis mlx_sum_axis_mlx_gen_orig_
#define mlx_sum mlx_sum_mlx_gen_orig_
#define mlx_swapaxes mlx_swapaxes_mlx_gen_orig_
#define mlx_take_axis mlx_take_axis_mlx_gen_orig_
#define mlx_take mlx_take_mlx_gen_orig_
#define mlx_take_along_axis mlx_take_along_axis_mlx_gen_orig_
#define mlx_tan mlx_tan_mlx_gen_orig_
#define mlx_tanh mlx_tanh_mlx_gen_orig_
#define mlx_tensordot mlx_tensordot_mlx_gen_orig_
#define mlx_tensordot_axis mlx_tensordot_axis_mlx_gen_orig_
#define mlx_tile mlx_tile_mlx_gen_orig_
#define mlx_to_fp8 mlx_to_fp8_mlx_gen_orig_
#define mlx_topk_axis mlx_topk_axis_mlx_gen_orig_
#define mlx_topk mlx_topk_mlx_gen_orig_
#define mlx_trace mlx_trace_mlx_gen_orig_
#define mlx_transpose_axes mlx_transpose_axes_mlx_gen_orig_
#define mlx_transpose mlx_transpose_mlx_gen_orig_
#define mlx_tri mlx_tri_mlx_gen_orig_
#define mlx_tril mlx_tril_mlx_gen_orig_
#define mlx_triu mlx_triu_mlx_gen_orig_
#define mlx_unflatten mlx_unflatten_mlx_gen_orig_
#define mlx_var_axes mlx_var_axes_mlx_gen_orig_
#define mlx_var_axis mlx_var_axis_mlx_gen_orig_
#define mlx_var mlx_var_mlx_gen_orig_
#define mlx_view mlx_view_mlx_gen_orig_
#define mlx_where mlx_where_mlx_gen_orig_
#define mlx_zeros mlx_zeros_mlx_gen_orig_
#define mlx_zeros_like mlx_zeros_like_mlx_gen_orig_
// Symbol renames for the mlx-c random-number API. Each maps NAME to
// NAME_mlx_gen_orig_.
#define mlx_random_bernoulli mlx_random_bernoulli_mlx_gen_orig_
#define mlx_random_bits mlx_random_bits_mlx_gen_orig_
#define mlx_random_categorical_shape mlx_random_categorical_shape_mlx_gen_orig_
#define mlx_random_categorical_num_samples mlx_random_categorical_num_samples_mlx_gen_orig_
#define mlx_random_categorical mlx_random_categorical_mlx_gen_orig_
#define mlx_random_gumbel mlx_random_gumbel_mlx_gen_orig_
#define mlx_random_key mlx_random_key_mlx_gen_orig_
#define mlx_random_laplace mlx_random_laplace_mlx_gen_orig_
#define mlx_random_multivariate_normal mlx_random_multivariate_normal_mlx_gen_orig_
#define mlx_random_normal_broadcast mlx_random_normal_broadcast_mlx_gen_orig_
#define mlx_random_normal mlx_random_normal_mlx_gen_orig_
#define mlx_random_permutation mlx_random_permutation_mlx_gen_orig_
#define mlx_random_permutation_arange mlx_random_permutation_arange_mlx_gen_orig_
#define mlx_random_randint mlx_random_randint_mlx_gen_orig_
#define mlx_random_seed mlx_random_seed_mlx_gen_orig_
#define mlx_random_split_num mlx_random_split_num_mlx_gen_orig_
#define mlx_random_split mlx_random_split_mlx_gen_orig_
#define mlx_random_truncated_normal mlx_random_truncated_normal_mlx_gen_orig_
#define mlx_random_uniform mlx_random_uniform_mlx_gen_orig_
// Symbol renames for the mlx-c stream API, including synchronize and the
// default-stream helpers. Each maps NAME to NAME_mlx_gen_orig_.
#define mlx_stream_new mlx_stream_new_mlx_gen_orig_
#define mlx_stream_new_device mlx_stream_new_device_mlx_gen_orig_
#define mlx_stream_set mlx_stream_set_mlx_gen_orig_
#define mlx_stream_free mlx_stream_free_mlx_gen_orig_
#define mlx_stream_tostring mlx_stream_tostring_mlx_gen_orig_
#define mlx_stream_equal mlx_stream_equal_mlx_gen_orig_
#define mlx_stream_get_device mlx_stream_get_device_mlx_gen_orig_
#define mlx_stream_get_index mlx_stream_get_index_mlx_gen_orig_
#define mlx_synchronize mlx_synchronize_mlx_gen_orig_
#define mlx_get_default_stream mlx_get_default_stream_mlx_gen_orig_
#define mlx_set_default_stream mlx_set_default_stream_mlx_gen_orig_
#define mlx_default_cpu_stream_new mlx_default_cpu_stream_new_mlx_gen_orig_
#define mlx_default_gpu_stream_new mlx_default_gpu_stream_new_mlx_gen_orig_
// Symbol renames for the mlx-c string object API. Each maps NAME to
// NAME_mlx_gen_orig_.
#define mlx_string_new mlx_string_new_mlx_gen_orig_
#define mlx_string_new_data mlx_string_new_data_mlx_gen_orig_
#define mlx_string_set mlx_string_set_mlx_gen_orig_
#define mlx_string_data mlx_string_data_mlx_gen_orig_
#define mlx_string_free mlx_string_free_mlx_gen_orig_
// Symbol renames for the mlx-c function-transformation API (vmap details,
// eval/async_eval, checkpoint, custom functions, jvp/vjp, value_and_grad).
// Each maps NAME to NAME_mlx_gen_orig_.
#define mlx_detail_vmap_replace mlx_detail_vmap_replace_mlx_gen_orig_
#define mlx_detail_vmap_trace mlx_detail_vmap_trace_mlx_gen_orig_
#define mlx_async_eval mlx_async_eval_mlx_gen_orig_
#define mlx_checkpoint mlx_checkpoint_mlx_gen_orig_
#define mlx_custom_function mlx_custom_function_mlx_gen_orig_
#define mlx_custom_vjp mlx_custom_vjp_mlx_gen_orig_
#define mlx_eval mlx_eval_mlx_gen_orig_
#define mlx_jvp mlx_jvp_mlx_gen_orig_
#define mlx_value_and_grad mlx_value_and_grad_mlx_gen_orig_
#define mlx_vjp mlx_vjp_mlx_gen_orig_
// Symbol renames for the mlx-c vector containers (vector_array,
// vector_vector_array, vector_int, ...). Each maps NAME to NAME_mlx_gen_orig_.
// The generated list continues past this point in the file.
#define mlx_vector_array_new mlx_vector_array_new_mlx_gen_orig_
#define mlx_vector_array_set mlx_vector_array_set_mlx_gen_orig_
#define mlx_vector_array_free mlx_vector_array_free_mlx_gen_orig_
#define mlx_vector_array_new_data mlx_vector_array_new_data_mlx_gen_orig_
#define mlx_vector_array_new_value mlx_vector_array_new_value_mlx_gen_orig_
#define mlx_vector_array_set_data mlx_vector_array_set_data_mlx_gen_orig_
#define mlx_vector_array_set_value mlx_vector_array_set_value_mlx_gen_orig_
#define mlx_vector_array_append_data mlx_vector_array_append_data_mlx_gen_orig_
#define mlx_vector_array_append_value mlx_vector_array_append_value_mlx_gen_orig_
#define mlx_vector_array_size mlx_vector_array_size_mlx_gen_orig_
#define mlx_vector_array_get mlx_vector_array_get_mlx_gen_orig_
#define mlx_vector_vector_array_new mlx_vector_vector_array_new_mlx_gen_orig_
#define mlx_vector_vector_array_set mlx_vector_vector_array_set_mlx_gen_orig_
#define mlx_vector_vector_array_free mlx_vector_vector_array_free_mlx_gen_orig_
#define mlx_vector_vector_array_new_data mlx_vector_vector_array_new_data_mlx_gen_orig_
#define mlx_vector_vector_array_new_value mlx_vector_vector_array_new_value_mlx_gen_orig_
#define mlx_vector_vector_array_set_data mlx_vector_vector_array_set_data_mlx_gen_orig_
#define mlx_vector_vector_array_set_value mlx_vector_vector_array_set_value_mlx_gen_orig_
#define mlx_vector_vector_array_append_data mlx_vector_vector_array_append_data_mlx_gen_orig_
#define mlx_vector_vector_array_append_value mlx_vector_vector_array_append_value_mlx_gen_orig_
#define mlx_vector_vector_array_size mlx_vector_vector_array_size_mlx_gen_orig_
#define mlx_vector_vector_array_get mlx_vector_vector_array_get_mlx_gen_orig_
#define mlx_vector_int_new mlx_vector_int_new_mlx_gen_orig_
#define mlx_vector_int_set mlx_vector_int_set_mlx_gen_orig_
#define mlx_vector_int_free mlx_vector_int_free_mlx_gen_orig_
#define mlx_vector_int_new_data mlx_vector_int_new_data_mlx_gen_orig_
#define mlx_vector_int_new_value mlx_vector_int_new_value_mlx_gen_orig_
#define mlx_vector_int_set_data mlx_vector_int_set_data_mlx_gen_orig_
#define mlx_vector_int_set_value mlx_vector_int_set_value_mlx_gen_orig_
#define mlx_vector_int_append_data mlx_vector_int_append_data_mlx_gen_orig_
#define mlx_vector_int_append_value mlx_vector_int_append_value_mlx_gen_orig_
#define mlx_vector_int_size mlx_vector_int_size_mlx_gen_orig_
#define mlx_vector_int_get mlx_vector_int_get_mlx_gen_orig_
#define mlx_vector_string_new mlx_vector_string_new_mlx_gen_orig_
#define mlx_vector_string_set mlx_vector_string_set_mlx_gen_orig_
#define mlx_vector_string_free mlx_vector_string_free_mlx_gen_orig_
#define mlx_vector_string_new_data mlx_vector_string_new_data_mlx_gen_orig_
#define mlx_vector_string_new_value mlx_vector_string_new_value_mlx_gen_orig_
#define mlx_vector_string_set_data mlx_vector_string_set_data_mlx_gen_orig_
#define mlx_vector_string_set_value mlx_vector_string_set_value_mlx_gen_orig_
#define mlx_vector_string_append_data mlx_vector_string_append_data_mlx_gen_orig_
#define mlx_vector_string_append_value mlx_vector_string_append_value_mlx_gen_orig_
#define mlx_vector_string_size mlx_vector_string_size_mlx_gen_orig_
#define mlx_vector_string_get mlx_vector_string_get_mlx_gen_orig_
#define mlx_version mlx_version_mlx_gen_orig_
#include "mlx/c/mlx.h"
#undef mlx_dtype_size
#undef mlx_array_tostring
#undef mlx_array_new
#undef mlx_array_free
#undef mlx_array_new_bool
#undef mlx_array_new_int
#undef mlx_array_new_float32
#undef mlx_array_new_float
#undef mlx_array_new_float64
#undef mlx_array_new_double
#undef mlx_array_new_complex
#undef mlx_array_new_data
#undef mlx_array_set
#undef mlx_array_set_bool
#undef mlx_array_set_int
#undef mlx_array_set_float32
#undef mlx_array_set_float
#undef mlx_array_set_float64
#undef mlx_array_set_double
#undef mlx_array_set_complex
#undef mlx_array_set_data
#undef mlx_array_itemsize
#undef mlx_array_size
#undef mlx_array_nbytes
#undef mlx_array_ndim
#undef mlx_array_shape
#undef mlx_array_strides
#undef mlx_array_dim
#undef mlx_array_dtype
#undef mlx_array_eval
#undef mlx_array_item_bool
#undef mlx_array_item_uint8
#undef mlx_array_item_uint16
#undef mlx_array_item_uint32
#undef mlx_array_item_uint64
#undef mlx_array_item_int8
#undef mlx_array_item_int16
#undef mlx_array_item_int32
#undef mlx_array_item_int64
#undef mlx_array_item_float32
#undef mlx_array_item_float64
#undef mlx_array_item_complex64
#undef mlx_array_item_float16
#undef mlx_array_item_bfloat16
#undef mlx_array_data_bool
#undef mlx_array_data_uint8
#undef mlx_array_data_uint16
#undef mlx_array_data_uint32
#undef mlx_array_data_uint64
#undef mlx_array_data_int8
#undef mlx_array_data_int16
#undef mlx_array_data_int32
#undef mlx_array_data_int64
#undef mlx_array_data_float32
#undef mlx_array_data_float64
#undef mlx_array_data_complex64
#undef mlx_array_data_float16
#undef mlx_array_data_bfloat16
#undef _mlx_array_is_available
#undef _mlx_array_wait
#undef _mlx_array_is_contiguous
#undef _mlx_array_is_row_contiguous
#undef _mlx_array_is_col_contiguous
#undef mlx_closure_new
#undef mlx_closure_free
#undef mlx_closure_new_func
#undef mlx_closure_new_func_payload
#undef mlx_closure_set
#undef mlx_closure_apply
#undef mlx_closure_new_unary
#undef mlx_closure_kwargs_new
#undef mlx_closure_kwargs_free
#undef mlx_closure_kwargs_new_func
#undef mlx_closure_kwargs_new_func_payload
#undef mlx_closure_kwargs_set
#undef mlx_closure_kwargs_apply
#undef mlx_closure_value_and_grad_new
#undef mlx_closure_value_and_grad_free
#undef mlx_closure_value_and_grad_new_func
#undef mlx_closure_value_and_grad_new_func_payload
#undef mlx_closure_value_and_grad_set
#undef mlx_closure_value_and_grad_apply
#undef mlx_closure_custom_new
#undef mlx_closure_custom_free
#undef mlx_closure_custom_new_func
#undef mlx_closure_custom_new_func_payload
#undef mlx_closure_custom_set
#undef mlx_closure_custom_apply
#undef mlx_closure_custom_jvp_new
#undef mlx_closure_custom_jvp_free
#undef mlx_closure_custom_jvp_new_func
#undef mlx_closure_custom_jvp_new_func_payload
#undef mlx_closure_custom_jvp_set
#undef mlx_closure_custom_jvp_apply
#undef mlx_closure_custom_vmap_new
#undef mlx_closure_custom_vmap_free
#undef mlx_closure_custom_vmap_new_func
#undef mlx_closure_custom_vmap_new_func_payload
#undef mlx_closure_custom_vmap_set
#undef mlx_closure_custom_vmap_apply
#undef mlx_compile
#undef mlx_detail_compile
#undef mlx_detail_compile_clear_cache
#undef mlx_detail_compile_erase
#undef mlx_disable_compile
#undef mlx_enable_compile
#undef mlx_set_compile_mode
#undef mlx_device_new
#undef mlx_device_new_type
#undef mlx_device_free
#undef mlx_device_set
#undef mlx_device_tostring
#undef mlx_device_equal
#undef mlx_device_get_index
#undef mlx_device_get_type
#undef mlx_get_default_device
#undef mlx_set_default_device
#undef mlx_distributed_group_rank
#undef mlx_distributed_group_size
#undef mlx_distributed_group_split
#undef mlx_distributed_is_available
#undef mlx_distributed_init
#undef mlx_distributed_all_gather
#undef mlx_distributed_all_max
#undef mlx_distributed_all_min
#undef mlx_distributed_all_sum
#undef mlx_distributed_recv
#undef mlx_distributed_recv_like
#undef mlx_distributed_send
#undef mlx_distributed_sum_scatter
#undef mlx_set_error_handler
#undef _mlx_error
#undef mlx_export_function
#undef mlx_export_function_kwargs
#undef mlx_function_exporter_new
#undef mlx_function_exporter_free
#undef mlx_function_exporter_apply
#undef mlx_function_exporter_apply_kwargs
#undef mlx_imported_function_new
#undef mlx_imported_function_free
#undef mlx_imported_function_apply
#undef mlx_imported_function_apply_kwargs
#undef mlx_fast_cuda_kernel_config_new
#undef mlx_fast_cuda_kernel_config_free
#undef mlx_fast_cuda_kernel_config_add_output_arg
#undef mlx_fast_cuda_kernel_config_set_grid
#undef mlx_fast_cuda_kernel_config_set_thread_group
#undef mlx_fast_cuda_kernel_config_set_init_value
#undef mlx_fast_cuda_kernel_config_set_verbose
#undef mlx_fast_cuda_kernel_config_add_template_arg_dtype
#undef mlx_fast_cuda_kernel_config_add_template_arg_int
#undef mlx_fast_cuda_kernel_config_add_template_arg_bool
#undef mlx_fast_cuda_kernel_new
#undef mlx_fast_cuda_kernel_free
#undef mlx_fast_cuda_kernel_apply
#undef mlx_fast_layer_norm
#undef mlx_fast_metal_kernel_config_new
#undef mlx_fast_metal_kernel_config_free
#undef mlx_fast_metal_kernel_config_add_output_arg
#undef mlx_fast_metal_kernel_config_set_grid
#undef mlx_fast_metal_kernel_config_set_thread_group
#undef mlx_fast_metal_kernel_config_set_init_value
#undef mlx_fast_metal_kernel_config_set_verbose
#undef mlx_fast_metal_kernel_config_add_template_arg_dtype
#undef mlx_fast_metal_kernel_config_add_template_arg_int
#undef mlx_fast_metal_kernel_config_add_template_arg_bool
#undef mlx_fast_metal_kernel_new
#undef mlx_fast_metal_kernel_free
#undef mlx_fast_metal_kernel_apply
#undef mlx_fast_rms_norm
#undef mlx_fast_rope
#undef mlx_fast_scaled_dot_product_attention
#undef mlx_fft_fft
#undef mlx_fft_fft2
#undef mlx_fft_fftn
#undef mlx_fft_fftshift
#undef mlx_fft_ifft
#undef mlx_fft_ifft2
#undef mlx_fft_ifftn
#undef mlx_fft_ifftshift
#undef mlx_fft_irfft
#undef mlx_fft_irfft2
#undef mlx_fft_irfftn
#undef mlx_fft_rfft
#undef mlx_fft_rfft2
#undef mlx_fft_rfftn
#undef mlx_io_reader_new
#undef mlx_io_reader_descriptor
#undef mlx_io_reader_tostring
#undef mlx_io_reader_free
#undef mlx_io_writer_new
#undef mlx_io_writer_descriptor
#undef mlx_io_writer_tostring
#undef mlx_io_writer_free
#undef mlx_load_reader
#undef mlx_load
#undef mlx_load_safetensors_reader
#undef mlx_load_safetensors
#undef mlx_save_writer
#undef mlx_save
#undef mlx_save_safetensors_writer
#undef mlx_save_safetensors
#undef mlx_linalg_cholesky
#undef mlx_linalg_cholesky_inv
#undef mlx_linalg_cross
#undef mlx_linalg_eig
#undef mlx_linalg_eigh
#undef mlx_linalg_eigvals
#undef mlx_linalg_eigvalsh
#undef mlx_linalg_inv
#undef mlx_linalg_lu
#undef mlx_linalg_lu_factor
#undef mlx_linalg_norm
#undef mlx_linalg_norm_matrix
#undef mlx_linalg_norm_l2
#undef mlx_linalg_pinv
#undef mlx_linalg_qr
#undef mlx_linalg_solve
#undef mlx_linalg_solve_triangular
#undef mlx_linalg_svd
#undef mlx_linalg_tri_inv
#undef mlx_map_string_to_array_new
#undef mlx_map_string_to_array_set
#undef mlx_map_string_to_array_free
#undef mlx_map_string_to_array_insert
#undef mlx_map_string_to_array_get
#undef mlx_map_string_to_array_iterator_new
#undef mlx_map_string_to_array_iterator_free
#undef mlx_map_string_to_array_iterator_next
#undef mlx_map_string_to_string_new
#undef mlx_map_string_to_string_set
#undef mlx_map_string_to_string_free
#undef mlx_map_string_to_string_insert
#undef mlx_map_string_to_string_get
#undef mlx_map_string_to_string_iterator_new
#undef mlx_map_string_to_string_iterator_free
#undef mlx_map_string_to_string_iterator_next
#undef mlx_clear_cache
#undef mlx_get_active_memory
#undef mlx_get_cache_memory
#undef mlx_get_memory_limit
#undef mlx_get_peak_memory
#undef mlx_reset_peak_memory
#undef mlx_set_cache_limit
#undef mlx_set_memory_limit
#undef mlx_set_wired_limit
#undef mlx_metal_device_info
#undef mlx_metal_is_available
#undef mlx_metal_start_capture
#undef mlx_metal_stop_capture
#undef mlx_abs
#undef mlx_add
#undef mlx_addmm
#undef mlx_all_axes
#undef mlx_all_axis
#undef mlx_all
#undef mlx_allclose
#undef mlx_any_axes
#undef mlx_any_axis
#undef mlx_any
#undef mlx_arange
#undef mlx_arccos
#undef mlx_arccosh
#undef mlx_arcsin
#undef mlx_arcsinh
#undef mlx_arctan
#undef mlx_arctan2
#undef mlx_arctanh
#undef mlx_argmax_axis
#undef mlx_argmax
#undef mlx_argmin_axis
#undef mlx_argmin
#undef mlx_argpartition_axis
#undef mlx_argpartition
#undef mlx_argsort_axis
#undef mlx_argsort
#undef mlx_array_equal
#undef mlx_as_strided
#undef mlx_astype
#undef mlx_atleast_1d
#undef mlx_atleast_2d
#undef mlx_atleast_3d
#undef mlx_bitwise_and
#undef mlx_bitwise_invert
#undef mlx_bitwise_or
#undef mlx_bitwise_xor
#undef mlx_block_masked_mm
#undef mlx_broadcast_arrays
#undef mlx_broadcast_to
#undef mlx_ceil
#undef mlx_clip
#undef mlx_concatenate_axis
#undef mlx_concatenate
#undef mlx_conjugate
#undef mlx_contiguous
#undef mlx_conv1d
#undef mlx_conv2d
#undef mlx_conv3d
#undef mlx_conv_general
#undef mlx_conv_transpose1d
#undef mlx_conv_transpose2d
#undef mlx_conv_transpose3d
#undef mlx_copy
#undef mlx_cos
#undef mlx_cosh
#undef mlx_cummax
#undef mlx_cummin
#undef mlx_cumprod
#undef mlx_cumsum
#undef mlx_degrees
#undef mlx_depends
#undef mlx_dequantize
#undef mlx_diag
#undef mlx_diagonal
#undef mlx_divide
#undef mlx_divmod
#undef mlx_einsum
#undef mlx_equal
#undef mlx_erf
#undef mlx_erfinv
#undef mlx_exp
#undef mlx_expand_dims_axes
#undef mlx_expand_dims
#undef mlx_expm1
#undef mlx_eye
#undef mlx_flatten
#undef mlx_floor
#undef mlx_floor_divide
#undef mlx_from_fp8
#undef mlx_full
#undef mlx_full_like
#undef mlx_gather
#undef mlx_gather_mm
#undef mlx_gather_qmm
#undef mlx_greater
#undef mlx_greater_equal
#undef mlx_hadamard_transform
#undef mlx_identity
#undef mlx_imag
#undef mlx_inner
#undef mlx_isclose
#undef mlx_isfinite
#undef mlx_isinf
#undef mlx_isnan
#undef mlx_isneginf
#undef mlx_isposinf
#undef mlx_kron
#undef mlx_left_shift
#undef mlx_less
#undef mlx_less_equal
#undef mlx_linspace
#undef mlx_log
#undef mlx_log10
#undef mlx_log1p
#undef mlx_log2
#undef mlx_logaddexp
#undef mlx_logcumsumexp
#undef mlx_logical_and
#undef mlx_logical_not
#undef mlx_logical_or
#undef mlx_logsumexp_axes
#undef mlx_logsumexp_axis
#undef mlx_logsumexp
#undef mlx_masked_scatter
#undef mlx_matmul
#undef mlx_max_axes
#undef mlx_max_axis
#undef mlx_max
#undef mlx_maximum
#undef mlx_mean_axes
#undef mlx_mean_axis
#undef mlx_mean
#undef mlx_median
#undef mlx_meshgrid
#undef mlx_min_axes
#undef mlx_min_axis
#undef mlx_min
#undef mlx_minimum
#undef mlx_moveaxis
#undef mlx_multiply
#undef mlx_nan_to_num
#undef mlx_negative
#undef mlx_not_equal
#undef mlx_number_of_elements
#undef mlx_ones
#undef mlx_ones_like
#undef mlx_outer
#undef mlx_pad
#undef mlx_pad_symmetric
#undef mlx_partition_axis
#undef mlx_partition
#undef mlx_power
#undef mlx_prod_axes
#undef mlx_prod_axis
#undef mlx_prod
#undef mlx_put_along_axis
#undef mlx_quantize
#undef mlx_quantized_matmul
#undef mlx_radians
#undef mlx_real
#undef mlx_reciprocal
#undef mlx_remainder
#undef mlx_repeat_axis
#undef mlx_repeat
#undef mlx_reshape
#undef mlx_right_shift
#undef mlx_roll_axis
#undef mlx_roll_axes
#undef mlx_roll
#undef mlx_round
#undef mlx_rsqrt
#undef mlx_scatter
#undef mlx_scatter_add
#undef mlx_scatter_add_axis
#undef mlx_scatter_max
#undef mlx_scatter_min
#undef mlx_scatter_prod
#undef mlx_segmented_mm
#undef mlx_sigmoid
#undef mlx_sign
#undef mlx_sin
#undef mlx_sinh
#undef mlx_slice
#undef mlx_slice_dynamic
#undef mlx_slice_update
#undef mlx_slice_update_dynamic
#undef mlx_softmax_axes
#undef mlx_softmax_axis
#undef mlx_softmax
#undef mlx_sort_axis
#undef mlx_sort
#undef mlx_split
#undef mlx_split_sections
#undef mlx_sqrt
#undef mlx_square
#undef mlx_squeeze_axes
#undef mlx_squeeze_axis
#undef mlx_squeeze
#undef mlx_stack_axis
#undef mlx_stack
#undef mlx_std_axes
#undef mlx_std_axis
#undef mlx_std
#undef mlx_stop_gradient
#undef mlx_subtract
#undef mlx_sum_axes
#undef mlx_sum_axis
#undef mlx_sum
#undef mlx_swapaxes
#undef mlx_take_axis
#undef mlx_take
#undef mlx_take_along_axis
#undef mlx_tan
#undef mlx_tanh
#undef mlx_tensordot
#undef mlx_tensordot_axis
#undef mlx_tile
#undef mlx_to_fp8
#undef mlx_topk_axis
#undef mlx_topk
#undef mlx_trace
#undef mlx_transpose_axes
#undef mlx_transpose
#undef mlx_tri
#undef mlx_tril
#undef mlx_triu
#undef mlx_unflatten
#undef mlx_var_axes
#undef mlx_var_axis
#undef mlx_var
#undef mlx_view
#undef mlx_where
#undef mlx_zeros
#undef mlx_zeros_like
#undef mlx_random_bernoulli
#undef mlx_random_bits
#undef mlx_random_categorical_shape
#undef mlx_random_categorical_num_samples
#undef mlx_random_categorical
#undef mlx_random_gumbel
#undef mlx_random_key
#undef mlx_random_laplace
#undef mlx_random_multivariate_normal
#undef mlx_random_normal_broadcast
#undef mlx_random_normal
#undef mlx_random_permutation
#undef mlx_random_permutation_arange
#undef mlx_random_randint
#undef mlx_random_seed
#undef mlx_random_split_num
#undef mlx_random_split
#undef mlx_random_truncated_normal
#undef mlx_random_uniform
#undef mlx_stream_new
#undef mlx_stream_new_device
#undef mlx_stream_set
#undef mlx_stream_free
#undef mlx_stream_tostring
#undef mlx_stream_equal
#undef mlx_stream_get_device
#undef mlx_stream_get_index
#undef mlx_synchronize
#undef mlx_get_default_stream
#undef mlx_set_default_stream
#undef mlx_default_cpu_stream_new
#undef mlx_default_gpu_stream_new
#undef mlx_string_new
#undef mlx_string_new_data
#undef mlx_string_set
#undef mlx_string_data
#undef mlx_string_free
#undef mlx_detail_vmap_replace
#undef mlx_detail_vmap_trace
#undef mlx_async_eval
#undef mlx_checkpoint
#undef mlx_custom_function
#undef mlx_custom_vjp
#undef mlx_eval
#undef mlx_jvp
#undef mlx_value_and_grad
#undef mlx_vjp
#undef mlx_vector_array_new
#undef mlx_vector_array_set
#undef mlx_vector_array_free
#undef mlx_vector_array_new_data
#undef mlx_vector_array_new_value
#undef mlx_vector_array_set_data
#undef mlx_vector_array_set_value
#undef mlx_vector_array_append_data
#undef mlx_vector_array_append_value
#undef mlx_vector_array_size
#undef mlx_vector_array_get
#undef mlx_vector_vector_array_new
#undef mlx_vector_vector_array_set
#undef mlx_vector_vector_array_free
#undef mlx_vector_vector_array_new_data
#undef mlx_vector_vector_array_new_value
#undef mlx_vector_vector_array_set_data
#undef mlx_vector_vector_array_set_value
#undef mlx_vector_vector_array_append_data
#undef mlx_vector_vector_array_append_value
#undef mlx_vector_vector_array_size
#undef mlx_vector_vector_array_get
#undef mlx_vector_int_new
#undef mlx_vector_int_set
#undef mlx_vector_int_free
#undef mlx_vector_int_new_data
#undef mlx_vector_int_new_value
#undef mlx_vector_int_set_data
#undef mlx_vector_int_set_value
#undef mlx_vector_int_append_data
#undef mlx_vector_int_append_value
#undef mlx_vector_int_size
#undef mlx_vector_int_get
#undef mlx_vector_string_new
#undef mlx_vector_string_set
#undef mlx_vector_string_free
#undef mlx_vector_string_new_data
#undef mlx_vector_string_new_value
#undef mlx_vector_string_set_data
#undef mlx_vector_string_set_value
#undef mlx_vector_string_append_data
#undef mlx_vector_string_append_value
#undef mlx_vector_string_size
#undef mlx_vector_string_get
#undef mlx_version
// Function-pointer slots for the mlx-c entry points that the #define/#undef
// block above renames out of the way before including "mlx/c/mlx.h".
// Each slot is the original mlx-c name with a trailing underscore. Because the
// slots are declared `extern`, each must be defined in exactly one translation
// unit.
// NOTE(review): presumably these slots are filled at runtime by the loader in
// dynamic.h -- confirm. Calling a slot that was never set is a null-pointer call.
//
// mlx_dtype / mlx_array: stringification, creation, and freeing.
extern size_t (*mlx_dtype_size_)(mlx_dtype dtype);
extern int (*mlx_array_tostring_)(mlx_string* str, const mlx_array arr);
extern mlx_array (*mlx_array_new_)(void);
extern int (*mlx_array_free_)(mlx_array arr);
// Scalar constructors. The _float/_float32 and _double/_float64 pairs have
// identical signatures.
extern mlx_array (*mlx_array_new_bool_)(bool val);
extern mlx_array (*mlx_array_new_int_)(int val);
extern mlx_array (*mlx_array_new_float32_)(float val);
extern mlx_array (*mlx_array_new_float_)(float val);
extern mlx_array (*mlx_array_new_float64_)(double val);
extern mlx_array (*mlx_array_new_double_)(double val);
extern mlx_array (*mlx_array_new_complex_)(float real_val, float imag_val);
// Builds an array from a raw buffer described by `shape` (`dim` entries) and
// `dtype`. NOTE(review): this declaration does not show whether `data` is
// copied or borrowed -- confirm before freeing the buffer.
extern mlx_array (*mlx_array_new_data_)(
const void* data,
const int* shape,
int dim,
mlx_dtype dtype);
// Setters: each writes into *arr and returns an int.
extern int (*mlx_array_set_)(mlx_array* arr, const mlx_array src);
extern int (*mlx_array_set_bool_)(mlx_array* arr, bool val);
extern int (*mlx_array_set_int_)(mlx_array* arr, int val);
extern int (*mlx_array_set_float32_)(mlx_array* arr, float val);
extern int (*mlx_array_set_float_)(mlx_array* arr, float val);
extern int (*mlx_array_set_float64_)(mlx_array* arr, double val);
extern int (*mlx_array_set_double_)(mlx_array* arr, double val);
extern int (*mlx_array_set_complex_)(mlx_array* arr, float real_val, float imag_val);
extern int (*mlx_array_set_data_)(
mlx_array* arr,
const void* data,
const int* shape,
int dim,
mlx_dtype dtype);
// Size, shape, and dtype queries.
extern size_t (*mlx_array_itemsize_)(const mlx_array arr);
extern size_t (*mlx_array_size_)(const mlx_array arr);
extern size_t (*mlx_array_nbytes_)(const mlx_array arr);
extern size_t (*mlx_array_ndim_)(const mlx_array arr);
// NOTE(review): the returned shape/strides pointers presumably reference arr's
// own metadata and stay valid only while arr is alive -- confirm.
extern const int * (*mlx_array_shape_)(const mlx_array arr);
extern const size_t * (*mlx_array_strides_)(const mlx_array arr);
extern int (*mlx_array_dim_)(const mlx_array arr, int dim);
extern mlx_dtype (*mlx_array_dtype_)(const mlx_array arr);
extern int (*mlx_array_eval_)(mlx_array arr);
// Scalar extraction: each writes one value of the named C type into *res.
// float _Complex requires C99 complex support in the including translation
// unit. float16_t and bfloat16_t come from the mlx-c headers.
extern int (*mlx_array_item_bool_)(bool* res, const mlx_array arr);
extern int (*mlx_array_item_uint8_)(uint8_t* res, const mlx_array arr);
extern int (*mlx_array_item_uint16_)(uint16_t* res, const mlx_array arr);
extern int (*mlx_array_item_uint32_)(uint32_t* res, const mlx_array arr);
extern int (*mlx_array_item_uint64_)(uint64_t* res, const mlx_array arr);
extern int (*mlx_array_item_int8_)(int8_t* res, const mlx_array arr);
extern int (*mlx_array_item_int16_)(int16_t* res, const mlx_array arr);
extern int (*mlx_array_item_int32_)(int32_t* res, const mlx_array arr);
extern int (*mlx_array_item_int64_)(int64_t* res, const mlx_array arr);
extern int (*mlx_array_item_float32_)(float* res, const mlx_array arr);
extern int (*mlx_array_item_float64_)(double* res, const mlx_array arr);
extern int (*mlx_array_item_complex64_)(float _Complex* res, const mlx_array arr);
extern int (*mlx_array_item_float16_)(float16_t* res, const mlx_array arr);
extern int (*mlx_array_item_bfloat16_)(bfloat16_t* res, const mlx_array arr);
// Typed read-only pointers to the array's data, one accessor per element type.
extern const bool * (*mlx_array_data_bool_)(const mlx_array arr);
extern const uint8_t * (*mlx_array_data_uint8_)(const mlx_array arr);
extern const uint16_t * (*mlx_array_data_uint16_)(const mlx_array arr);
extern const uint32_t * (*mlx_array_data_uint32_)(const mlx_array arr);
extern const uint64_t * (*mlx_array_data_uint64_)(const mlx_array arr);
extern const int8_t * (*mlx_array_data_int8_)(const mlx_array arr);
extern const int16_t * (*mlx_array_data_int16_)(const mlx_array arr);
extern const int32_t * (*mlx_array_data_int32_)(const mlx_array arr);
extern const int64_t * (*mlx_array_data_int64_)(const mlx_array arr);
extern const float * (*mlx_array_data_float32_)(const mlx_array arr);
extern const double * (*mlx_array_data_float64_)(const mlx_array arr);
extern const float _Complex * (*mlx_array_data_complex64_)(const mlx_array arr);
extern const float16_t * (*mlx_array_data_float16_)(const mlx_array arr);
extern const bfloat16_t * (*mlx_array_data_bfloat16_)(const mlx_array arr);
// Underscore-prefixed mlx-c helpers. NOTE(review): ISO C (C11 7.1.3) reserves
// file-scope identifiers that begin with `_`. The names mirror mlx-c and are
// kept for symbol compatibility.
extern int (*_mlx_array_is_available_)(bool* res, const mlx_array arr);
extern int (*_mlx_array_wait_)(const mlx_array arr);
extern int (*_mlx_array_is_contiguous_)(bool* res, const mlx_array arr);
extern int (*_mlx_array_is_row_contiguous_)(bool* res, const mlx_array arr);
extern int (*_mlx_array_is_col_contiguous_)(bool* res, const mlx_array arr);
// mlx_closure* slots. Each closure family below repeats the same shape:
//   _new               -> empty handle
//   _free              -> release the handle
//   _new_func          -> wrap a plain C callback
//   _new_func_payload  -> wrap a callback that also receives a user `payload`,
//                         together with a `dtor` for that payload
//   _set               -> assign from `src`
//   _apply             -> invoke the closure
// NOTE(review): `dtor` is presumably called on `payload` when the closure is
// released -- confirm against the mlx-c docs.
//
// Plain closure: one vector of arrays in, one vector of arrays out.
extern mlx_closure (*mlx_closure_new_)(void);
extern int (*mlx_closure_free_)(mlx_closure cls);
extern mlx_closure (*mlx_closure_new_func_)(
int (*fun)(mlx_vector_array*, const mlx_vector_array));
extern mlx_closure (*mlx_closure_new_func_payload_)(
int (*fun)(mlx_vector_array*, const mlx_vector_array, void*),
void* payload,
void (*dtor)(void*));
extern int (*mlx_closure_set_)(mlx_closure* cls, const mlx_closure src);
extern int (*mlx_closure_apply_)(
mlx_vector_array* res,
mlx_closure cls,
const mlx_vector_array input);
// Convenience constructor taking a single-array -> single-array callback.
extern mlx_closure (*mlx_closure_new_unary_)(int (*fun)(mlx_array*, const mlx_array));
// Keyword-argument closure: the callback also receives a string->array map.
extern mlx_closure_kwargs (*mlx_closure_kwargs_new_)(void);
extern int (*mlx_closure_kwargs_free_)(mlx_closure_kwargs cls);
extern mlx_closure_kwargs (*mlx_closure_kwargs_new_func_)(int (*fun)(
mlx_vector_array*,
const mlx_vector_array,
const mlx_map_string_to_array));
extern mlx_closure_kwargs (*mlx_closure_kwargs_new_func_payload_)(
int (*fun)(
mlx_vector_array*,
const mlx_vector_array,
const mlx_map_string_to_array,
void*),
void* payload,
void (*dtor)(void*));
extern int (*mlx_closure_kwargs_set_)(
mlx_closure_kwargs* cls,
const mlx_closure_kwargs src);
extern int (*mlx_closure_kwargs_apply_)(
mlx_vector_array* res,
mlx_closure_kwargs cls,
const mlx_vector_array input_0,
const mlx_map_string_to_array input_1);
// Value-and-grad closure: produces two output vectors (res_0, res_1) from one
// input vector.
extern mlx_closure_value_and_grad (*mlx_closure_value_and_grad_new_)(void);
extern int (*mlx_closure_value_and_grad_free_)(mlx_closure_value_and_grad cls);
extern mlx_closure_value_and_grad (*mlx_closure_value_and_grad_new_func_)(
int (*fun)(mlx_vector_array*, mlx_vector_array*, const mlx_vector_array));
extern mlx_closure_value_and_grad (*mlx_closure_value_and_grad_new_func_payload_)(
int (*fun)(
mlx_vector_array*,
mlx_vector_array*,
const mlx_vector_array,
void*),
void* payload,
void (*dtor)(void*));
extern int (*mlx_closure_value_and_grad_set_)(
mlx_closure_value_and_grad* cls,
const mlx_closure_value_and_grad src);
extern int (*mlx_closure_value_and_grad_apply_)(
mlx_vector_array* res_0,
mlx_vector_array* res_1,
mlx_closure_value_and_grad cls,
const mlx_vector_array input);
// Custom closure: three input vectors in, one output vector out.
extern mlx_closure_custom (*mlx_closure_custom_new_)(void);
extern int (*mlx_closure_custom_free_)(mlx_closure_custom cls);
extern mlx_closure_custom (*mlx_closure_custom_new_func_)(int (*fun)(
mlx_vector_array*,
const mlx_vector_array,
const mlx_vector_array,
const mlx_vector_array));
extern mlx_closure_custom (*mlx_closure_custom_new_func_payload_)(
int (*fun)(
mlx_vector_array*,
const mlx_vector_array,
const mlx_vector_array,
const mlx_vector_array,
void*),
void* payload,
void (*dtor)(void*));
extern int (*mlx_closure_custom_set_)(
mlx_closure_custom* cls,
const mlx_closure_custom src);
extern int (*mlx_closure_custom_apply_)(
mlx_vector_array* res,
mlx_closure_custom cls,
const mlx_vector_array input_0,
const mlx_vector_array input_1,
const mlx_vector_array input_2);
// Custom JVP closure: two input vectors plus an int array passed as
// (pointer, count).
extern mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_)(void);
extern int (*mlx_closure_custom_jvp_free_)(mlx_closure_custom_jvp cls);
extern mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_)(int (*fun)(
mlx_vector_array*,
const mlx_vector_array,
const mlx_vector_array,
const int*,
size_t _num));
extern mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_payload_)(
int (*fun)(
mlx_vector_array*,
const mlx_vector_array,
const mlx_vector_array,
const int*,
size_t _num,
void*),
void* payload,
void (*dtor)(void*));
extern int (*mlx_closure_custom_jvp_set_)(
mlx_closure_custom_jvp* cls,
const mlx_closure_custom_jvp src);
extern int (*mlx_closure_custom_jvp_apply_)(
mlx_vector_array* res,
mlx_closure_custom_jvp cls,
const mlx_vector_array input_0,
const mlx_vector_array input_1,
const int* input_2,
size_t input_2_num);
// Custom vmap closure: outputs both a vector of arrays and a vector of ints.
// Inputs are one vector of arrays plus an int array passed as (pointer, count).
extern mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_)(void);
extern int (*mlx_closure_custom_vmap_free_)(mlx_closure_custom_vmap cls);
extern mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_)(int (*fun)(
mlx_vector_array*,
mlx_vector_int*,
const mlx_vector_array,
const int*,
size_t _num));
extern mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_payload_)(
int (*fun)(
mlx_vector_array*,
mlx_vector_int*,
const mlx_vector_array,
const int*,
size_t _num,
void*),
void* payload,
void (*dtor)(void*));
extern int (*mlx_closure_custom_vmap_set_)(
mlx_closure_custom_vmap* cls,
const mlx_closure_custom_vmap src);
extern int (*mlx_closure_custom_vmap_apply_)(
mlx_vector_array* res_0,
mlx_vector_int* res_1,
mlx_closure_custom_vmap cls,
const mlx_vector_array input_0,
const int* input_1,
size_t input_1_num);
// Graph compilation controls.
// mlx_compile_ wraps `fun` into a compiled closure written to *res.
extern int (*mlx_compile_)(mlx_closure* res, const mlx_closure fun, bool shapeless);
// Lower-level variant keyed by an integer `fun_id`, with an optional array of
// 64-bit `constants` passed as (pointer, count).
extern int (*mlx_detail_compile_)(
mlx_closure* res,
const mlx_closure fun,
uintptr_t fun_id,
bool shapeless,
const uint64_t* constants,
size_t constants_num);
extern int (*mlx_detail_compile_clear_cache_)(void);
// Removes the cache entry associated with `fun_id`.
// NOTE(review): inferred from the name -- confirm.
extern int (*mlx_detail_compile_erase_)(uintptr_t fun_id);
extern int (*mlx_disable_compile_)(void);
extern int (*mlx_enable_compile_)(void);
extern int (*mlx_set_compile_mode_)(mlx_compile_mode mode);
// mlx_device handle lifecycle, queries, and default-device get/set.
extern mlx_device (*mlx_device_new_)(void);
extern mlx_device (*mlx_device_new_type_)(mlx_device_type type, int index);
extern int (*mlx_device_free_)(mlx_device dev);
extern int (*mlx_device_set_)(mlx_device* dev, const mlx_device src);
extern int (*mlx_device_tostring_)(mlx_string* str, mlx_device dev);
// Returns the comparison result directly as a bool, rather than an int status.
extern bool (*mlx_device_equal_)(mlx_device lhs, mlx_device rhs);
extern int (*mlx_device_get_index_)(int* index, mlx_device dev);
extern int (*mlx_device_get_type_)(mlx_device_type* type, mlx_device dev);
extern int (*mlx_get_default_device_)(mlx_device* dev);
extern int (*mlx_set_default_device_)(mlx_device dev);
// Distributed-communication slots (mlx_distributed_*).
// Parameters annotated `/* may be null */` accept a null group handle.
// NOTE(review): a null group presumably selects the default/global group --
// confirm against the mlx-c docs.
// Every collective writes its result into *res and takes the stream as `s`.
//
// Group queries and construction.
extern int (*mlx_distributed_group_rank_)(mlx_distributed_group group);
extern int (*mlx_distributed_group_size_)(mlx_distributed_group group);
extern mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key);
extern bool (*mlx_distributed_is_available_)(void);
extern mlx_distributed_group (*mlx_distributed_init_)(bool strict);
// Collectives.
// The stream parameter is spelled `s`, matching every sibling declaration
// (upstream spells it `S` here). Parameter names in a prototype do not affect
// the pointer's type. This file is generated (see its header), so mirror the
// rename in the generator or it will be lost on regeneration.
extern int (*mlx_distributed_all_gather_)(
mlx_array* res,
const mlx_array x,
const mlx_distributed_group group /* may be null */,
const mlx_stream s);
extern int (*mlx_distributed_all_max_)(
mlx_array* res,
const mlx_array x,
const mlx_distributed_group group /* may be null */,
const mlx_stream s);
extern int (*mlx_distributed_all_min_)(
mlx_array* res,
const mlx_array x,
const mlx_distributed_group group /* may be null */,
const mlx_stream s);
extern int (*mlx_distributed_all_sum_)(
mlx_array* res,
const mlx_array x,
const mlx_distributed_group group /* may be null */,
const mlx_stream s);
// Point-to-point. recv takes an explicit shape (pointer, count) and dtype;
// recv_like takes them from `x`.
extern int (*mlx_distributed_recv_)(
mlx_array* res,
const int* shape,
size_t shape_num,
mlx_dtype dtype,
int src,
const mlx_distributed_group group /* may be null */,
const mlx_stream s);
extern int (*mlx_distributed_recv_like_)(
mlx_array* res,
const mlx_array x,
int src,
const mlx_distributed_group group /* may be null */,
const mlx_stream s);
extern int (*mlx_distributed_send_)(
mlx_array* res,
const mlx_array x,
int dst,
const mlx_distributed_group group /* may be null */,
const mlx_stream s);
extern int (*mlx_distributed_sum_scatter_)(
mlx_array* res,
const mlx_array x,
const mlx_distributed_group group /* may be null */,
const mlx_stream s);
// Error reporting.
// Registers `handler` along with optional user `data` and a `dtor` for that
// data. NOTE(review): `dtor` is presumably called on `data` when the handler
// is replaced or released -- confirm.
extern void (*mlx_set_error_handler_)(
mlx_error_handler_func handler,
void* data,
void (*dtor)(void*));
// Takes a source location plus a format string and variadic arguments.
// NOTE(review): presumably printf-style. Pass untrusted text as an argument,
// never as `fmt`. The leading underscore makes this a reserved file-scope
// identifier in ISO C (C11 7.1.3); the name mirrors mlx-c.
extern void (*_mlx_error_)(const char* file, const int line, const char* fmt, ...);
// Exporting closures to a file and importing them back.
// Each *_kwargs variant additionally takes a string->array map.
extern int (*mlx_export_function_)(
const char* file,
const mlx_closure fun,
const mlx_vector_array args,
bool shapeless);
extern int (*mlx_export_function_kwargs_)(
const char* file,
const mlx_closure_kwargs fun,
const mlx_vector_array args,
const mlx_map_string_to_array kwargs,
bool shapeless);
// Exporter handle: created for a file and closure, then applied to argument
// sets. NOTE(review): presumably each apply records one traced call -- confirm.
extern mlx_function_exporter (*mlx_function_exporter_new_)(
const char* file,
const mlx_closure fun,
bool shapeless);
extern int (*mlx_function_exporter_free_)(mlx_function_exporter xfunc);
extern int (*mlx_function_exporter_apply_)(
const mlx_function_exporter xfunc,
const mlx_vector_array args);
extern int (*mlx_function_exporter_apply_kwargs_)(
const mlx_function_exporter xfunc,
const mlx_vector_array args,
const mlx_map_string_to_array kwargs);
// Imported-function handle, loaded from `file`. Apply writes outputs into *res.
extern mlx_imported_function (*mlx_imported_function_new_)(const char* file);
extern int (*mlx_imported_function_free_)(mlx_imported_function xfunc);
extern int (*mlx_imported_function_apply_)(
mlx_vector_array* res,
const mlx_imported_function xfunc,
const mlx_vector_array args);
extern int (*mlx_imported_function_apply_kwargs_)(
mlx_vector_array* res,
const mlx_imported_function xfunc,
const mlx_vector_array args,
const mlx_map_string_to_array kwargs);
// Custom CUDA kernel support: a launch-config handle plus a kernel handle.
// Both *_free_ slots in this group return void, unlike the mlx_array,
// mlx_closure, and mlx_device free slots, which return int.
extern mlx_fast_cuda_kernel_config (*mlx_fast_cuda_kernel_config_new_)(void);
extern void (*mlx_fast_cuda_kernel_config_free_)(mlx_fast_cuda_kernel_config cls);
// Declares one output: `shape` has `size` entries and elements are of `dtype`.
extern int (*mlx_fast_cuda_kernel_config_add_output_arg_)(
mlx_fast_cuda_kernel_config cls,
const int* shape,
size_t size,
mlx_dtype dtype);
// Three-component grid and thread-group dimensions.
extern int (*mlx_fast_cuda_kernel_config_set_grid_)(
mlx_fast_cuda_kernel_config cls,
int grid1,
int grid2,
int grid3);
extern int (*mlx_fast_cuda_kernel_config_set_thread_group_)(
mlx_fast_cuda_kernel_config cls,
int thread1,
int thread2,
int thread3);
extern int (*mlx_fast_cuda_kernel_config_set_init_value_)(
mlx_fast_cuda_kernel_config cls,
float value);
extern int (*mlx_fast_cuda_kernel_config_set_verbose_)(
mlx_fast_cuda_kernel_config cls,
bool verbose);
// Named template arguments, one setter per supported value kind.
extern int (*mlx_fast_cuda_kernel_config_add_template_arg_dtype_)(
mlx_fast_cuda_kernel_config cls,
const char* name,
mlx_dtype dtype);
extern int (*mlx_fast_cuda_kernel_config_add_template_arg_int_)(
mlx_fast_cuda_kernel_config cls,
const char* name,
int value);
extern int (*mlx_fast_cuda_kernel_config_add_template_arg_bool_)(
mlx_fast_cuda_kernel_config cls,
const char* name,
bool value);
// Kernel built from source text plus an optional header, with named inputs
// and outputs.
extern mlx_fast_cuda_kernel (*mlx_fast_cuda_kernel_new_)(
const char* name,
const mlx_vector_string input_names,
const mlx_vector_string output_names,
const char* source,
const char* header,
bool ensure_row_contiguous,
int shared_memory);
extern void (*mlx_fast_cuda_kernel_free_)(mlx_fast_cuda_kernel cls);
extern int (*mlx_fast_cuda_kernel_apply_)(
mlx_vector_array* outputs,
mlx_fast_cuda_kernel cls,
const mlx_vector_array inputs,
const mlx_fast_cuda_kernel_config config,
const mlx_stream stream);
// Fused layer norm of `x`; `weight` and `bias` are optional (may be null)
// and `eps` is passed as a float.
extern int (*mlx_fast_layer_norm_)(
mlx_array* res,
const mlx_array x,
const mlx_array weight /* may be null */,
const mlx_array bias /* may be null */,
float eps,
const mlx_stream s);
// Custom Metal kernel API. It mirrors the CUDA kernel API above
// parameter-for-parameter, except that kernel construction takes
// `bool atomic_outputs` where the CUDA version takes `int shared_memory`.
extern mlx_fast_metal_kernel_config (*mlx_fast_metal_kernel_config_new_)(void);
extern void (*mlx_fast_metal_kernel_config_free_)(mlx_fast_metal_kernel_config cls);
// `shape` / `size` presumably form a (pointer, element-count) pair — confirm.
extern int (*mlx_fast_metal_kernel_config_add_output_arg_)(
mlx_fast_metal_kernel_config cls,
const int* shape,
size_t size,
mlx_dtype dtype);
extern int (*mlx_fast_metal_kernel_config_set_grid_)(
mlx_fast_metal_kernel_config cls,
int grid1,
int grid2,
int grid3);
extern int (*mlx_fast_metal_kernel_config_set_thread_group_)(
mlx_fast_metal_kernel_config cls,
int thread1,
int thread2,
int thread3);
extern int (*mlx_fast_metal_kernel_config_set_init_value_)(
mlx_fast_metal_kernel_config cls,
float value);
extern int (*mlx_fast_metal_kernel_config_set_verbose_)(
mlx_fast_metal_kernel_config cls,
bool verbose);
// Named template arguments of type dtype / int / bool.
extern int (*mlx_fast_metal_kernel_config_add_template_arg_dtype_)(
mlx_fast_metal_kernel_config cls,
const char* name,
mlx_dtype dtype);
extern int (*mlx_fast_metal_kernel_config_add_template_arg_int_)(
mlx_fast_metal_kernel_config cls,
const char* name,
int value);
extern int (*mlx_fast_metal_kernel_config_add_template_arg_bool_)(
mlx_fast_metal_kernel_config cls,
const char* name,
bool value);
extern mlx_fast_metal_kernel (*mlx_fast_metal_kernel_new_)(
const char* name,
const mlx_vector_string input_names,
const mlx_vector_string output_names,
const char* source,
const char* header,
bool ensure_row_contiguous,
bool atomic_outputs);
extern void (*mlx_fast_metal_kernel_free_)(mlx_fast_metal_kernel cls);
extern int (*mlx_fast_metal_kernel_apply_)(
mlx_vector_array* outputs,
mlx_fast_metal_kernel cls,
const mlx_vector_array inputs,
const mlx_fast_metal_kernel_config config,
const mlx_stream stream);
// Fused RMS norm; `weight` is optional (may be null).
extern int (*mlx_fast_rms_norm_)(
mlx_array* res,
const mlx_array x,
const mlx_array weight /* may be null */,
float eps,
const mlx_stream s);
// Fused rotary position embedding. `base` is an mlx_optional_float, and
// `freqs` may be null; presumably exactly one of base/freqs is expected —
// confirm against mlx-c.
extern int (*mlx_fast_rope_)(
mlx_array* res,
const mlx_array x,
int dims,
bool traditional,
mlx_optional_float base,
float scale,
int offset,
const mlx_array freqs /* may be null */,
const mlx_stream s);
// Fused attention. Masking is selected by the `mask_mode` string, with an
// optional explicit `mask_arr`; `sinks` is also optional.
extern int (*mlx_fast_scaled_dot_product_attention_)(
mlx_array* res,
const mlx_array queries,
const mlx_array keys,
const mlx_array values,
float scale,
const char* mask_mode,
const mlx_array mask_arr /* may be null */,
const mlx_array sinks /* may be null */,
const mlx_stream s);
// FFT family. Each transform comes in three shapes:
//   *_fft_/ifft_/rfft_/irfft_ : scalar `n` and `axis` (1-D);
//   *2_ / *n_                 : (`n`, `n_num`) and (`axes`, `axes_num`)
//                               pointer + count pairs;
//   fftshift_/ifftshift_      : only an (`axes`, `axes_num`) pair.
extern int (*mlx_fft_fft_)(
mlx_array* res,
const mlx_array a,
int n,
int axis,
const mlx_stream s);
extern int (*mlx_fft_fft2_)(
mlx_array* res,
const mlx_array a,
const int* n,
size_t n_num,
const int* axes,
size_t axes_num,
const mlx_stream s);
extern int (*mlx_fft_fftn_)(
mlx_array* res,
const mlx_array a,
const int* n,
size_t n_num,
const int* axes,
size_t axes_num,
const mlx_stream s);
extern int (*mlx_fft_fftshift_)(
mlx_array* res,
const mlx_array a,
const int* axes,
size_t axes_num,
const mlx_stream s);
// Inverse complex transforms.
extern int (*mlx_fft_ifft_)(
mlx_array* res,
const mlx_array a,
int n,
int axis,
const mlx_stream s);
extern int (*mlx_fft_ifft2_)(
mlx_array* res,
const mlx_array a,
const int* n,
size_t n_num,
const int* axes,
size_t axes_num,
const mlx_stream s);
extern int (*mlx_fft_ifftn_)(
mlx_array* res,
const mlx_array a,
const int* n,
size_t n_num,
const int* axes,
size_t axes_num,
const mlx_stream s);
extern int (*mlx_fft_ifftshift_)(
mlx_array* res,
const mlx_array a,
const int* axes,
size_t axes_num,
const mlx_stream s);
// Inverse real transforms.
extern int (*mlx_fft_irfft_)(
mlx_array* res,
const mlx_array a,
int n,
int axis,
const mlx_stream s);
extern int (*mlx_fft_irfft2_)(
mlx_array* res,
const mlx_array a,
const int* n,
size_t n_num,
const int* axes,
size_t axes_num,
const mlx_stream s);
extern int (*mlx_fft_irfftn_)(
mlx_array* res,
const mlx_array a,
const int* n,
size_t n_num,
const int* axes,
size_t axes_num,
const mlx_stream s);
// Forward real transforms.
extern int (*mlx_fft_rfft_)(
mlx_array* res,
const mlx_array a,
int n,
int axis,
const mlx_stream s);
extern int (*mlx_fft_rfft2_)(
mlx_array* res,
const mlx_array a,
const int* n,
size_t n_num,
const int* axes,
size_t axes_num,
const mlx_stream s);
extern int (*mlx_fft_rfftn_)(
mlx_array* res,
const mlx_array a,
const int* n,
size_t n_num,
const int* axes,
size_t axes_num,
const mlx_stream s);
// Custom I/O streams. Each reader/writer is built from an opaque `void* desc`
// and an `mlx_io_vtable`. descriptor_ writes the stored descriptor through
// `desc_`, and tostring_ writes an mlx_string through `str_`.
extern mlx_io_reader (*mlx_io_reader_new_)(void* desc, mlx_io_vtable vtable);
extern int (*mlx_io_reader_descriptor_)(void** desc_, mlx_io_reader io);
extern int (*mlx_io_reader_tostring_)(mlx_string* str_, mlx_io_reader io);
extern int (*mlx_io_reader_free_)(mlx_io_reader io);
extern mlx_io_writer (*mlx_io_writer_new_)(void* desc, mlx_io_vtable vtable);
extern int (*mlx_io_writer_descriptor_)(void** desc_, mlx_io_writer io);
extern int (*mlx_io_writer_tostring_)(mlx_string* str_, mlx_io_writer io);
extern int (*mlx_io_writer_free_)(mlx_io_writer io);
// Array (de)serialization. Each operation has a *_reader_/*_writer_ variant
// that uses an mlx_io stream, plus a variant that takes a file path.
extern int (*mlx_load_reader_)(
mlx_array* res,
mlx_io_reader in_stream,
const mlx_stream s);
extern int (*mlx_load_)(mlx_array* res, const char* file, const mlx_stream s);
// Safetensors loaders produce two outputs: a name->array map (res_0) and a
// string->string map (res_1), presumably the file metadata — confirm.
extern int (*mlx_load_safetensors_reader_)(
mlx_map_string_to_array* res_0,
mlx_map_string_to_string* res_1,
mlx_io_reader in_stream,
const mlx_stream s);
extern int (*mlx_load_safetensors_)(
mlx_map_string_to_array* res_0,
mlx_map_string_to_string* res_1,
const char* file,
const mlx_stream s);
extern int (*mlx_save_writer_)(mlx_io_writer out_stream, const mlx_array a);
extern int (*mlx_save_)(const char* file, const mlx_array a);
// NOTE(review): this parameter is an mlx_io_writer but is named `in_stream`
// (mlx_save_writer_ uses `out_stream`). The name has no effect on the
// declaration; it is inherited from the generator's source.
extern int (*mlx_save_safetensors_writer_)(
mlx_io_writer in_stream,
const mlx_map_string_to_array param,
const mlx_map_string_to_string metadata);
extern int (*mlx_save_safetensors_)(
const char* file,
const mlx_map_string_to_array param,
const mlx_map_string_to_string metadata);
// Linear-algebra routines. Results are written through the leading pointer
// parameter(s): res, or res_0/res_1 for two-output decompositions, or an
// mlx_vector_array for variable-output ones (lu_, svd_).
extern int (*mlx_linalg_cholesky_)(
mlx_array* res,
const mlx_array a,
bool upper,
const mlx_stream s);
extern int (*mlx_linalg_cholesky_inv_)(
mlx_array* res,
const mlx_array a,
bool upper,
const mlx_stream s);
extern int (*mlx_linalg_cross_)(
mlx_array* res,
const mlx_array a,
const mlx_array b,
int axis,
const mlx_stream s);
// Eigen-decompositions. The Hermitian variants (eigh/eigvalsh) take a UPLO
// string selecting which triangle is used.
extern int (*mlx_linalg_eig_)(
mlx_array* res_0,
mlx_array* res_1,
const mlx_array a,
const mlx_stream s);
extern int (*mlx_linalg_eigh_)(
mlx_array* res_0,
mlx_array* res_1,
const mlx_array a,
const char* UPLO,
const mlx_stream s);
extern int (*mlx_linalg_eigvals_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_linalg_eigvalsh_)(
mlx_array* res,
const mlx_array a,
const char* UPLO,
const mlx_stream s);
extern int (*mlx_linalg_inv_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_linalg_lu_)(mlx_vector_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_linalg_lu_factor_)(
mlx_array* res_0,
mlx_array* res_1,
const mlx_array a,
const mlx_stream s);
// Norms: numeric `ord` (double), named `ord` (string), and an L2 variant with
// no ord parameter. `axis` may be null in all three.
extern int (*mlx_linalg_norm_)(
mlx_array* res,
const mlx_array a,
double ord,
const int* axis /* may be null */,
size_t axis_num,
bool keepdims,
const mlx_stream s);
extern int (*mlx_linalg_norm_matrix_)(
mlx_array* res,
const mlx_array a,
const char* ord,
const int* axis /* may be null */,
size_t axis_num,
bool keepdims,
const mlx_stream s);
extern int (*mlx_linalg_norm_l2_)(
mlx_array* res,
const mlx_array a,
const int* axis /* may be null */,
size_t axis_num,
bool keepdims,
const mlx_stream s);
extern int (*mlx_linalg_pinv_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_linalg_qr_)(
mlx_array* res_0,
mlx_array* res_1,
const mlx_array a,
const mlx_stream s);
extern int (*mlx_linalg_solve_)(
mlx_array* res,
const mlx_array a,
const mlx_array b,
const mlx_stream s);
extern int (*mlx_linalg_solve_triangular_)(
mlx_array* res,
const mlx_array a,
const mlx_array b,
bool upper,
const mlx_stream s);
extern int (*mlx_linalg_svd_)(
mlx_vector_array* res,
const mlx_array a,
bool compute_uv,
const mlx_stream s);
extern int (*mlx_linalg_tri_inv_)(
mlx_array* res,
const mlx_array a,
bool upper,
const mlx_stream s);
// String-keyed maps (to arrays and to strings). Each map type has:
// new/set/free, insert/get, and an iterator whose next_ writes the current
// key/value through out-pointers.
extern mlx_map_string_to_array (*mlx_map_string_to_array_new_)(void);
// set_ takes the destination by pointer and the source by value.
extern int (*mlx_map_string_to_array_set_)(
mlx_map_string_to_array* map,
const mlx_map_string_to_array src);
extern int (*mlx_map_string_to_array_free_)(mlx_map_string_to_array map);
extern int (*mlx_map_string_to_array_insert_)(
mlx_map_string_to_array map,
const char* key,
const mlx_array value);
extern int (*mlx_map_string_to_array_get_)(
mlx_array* value,
const mlx_map_string_to_array map,
const char* key);
extern mlx_map_string_to_array_iterator (*mlx_map_string_to_array_iterator_new_)(
mlx_map_string_to_array map);
extern int (*mlx_map_string_to_array_iterator_free_)(mlx_map_string_to_array_iterator it);
extern int (*mlx_map_string_to_array_iterator_next_)(
const char** key,
mlx_array* value,
mlx_map_string_to_array_iterator it);
extern mlx_map_string_to_string (*mlx_map_string_to_string_new_)(void);
extern int (*mlx_map_string_to_string_set_)(
mlx_map_string_to_string* map,
const mlx_map_string_to_string src);
extern int (*mlx_map_string_to_string_free_)(mlx_map_string_to_string map);
extern int (*mlx_map_string_to_string_insert_)(
mlx_map_string_to_string map,
const char* key,
const char* value);
// get_ writes a borrowed(?) C string through `value`; ownership is not
// visible here — confirm before freeing.
extern int (*mlx_map_string_to_string_get_)(
const char** value,
const mlx_map_string_to_string map,
const char* key);
extern mlx_map_string_to_string_iterator (*mlx_map_string_to_string_iterator_new_)(
mlx_map_string_to_string map);
extern int (*mlx_map_string_to_string_iterator_free_)(
mlx_map_string_to_string_iterator it);
extern int (*mlx_map_string_to_string_iterator_next_)(
const char** key,
const char** value,
mlx_map_string_to_string_iterator it);
// Memory / allocator controls. Getters write a size_t through `res`. Setters
// take the new `limit` and also write through `res`, presumably the previous
// limit — confirm against mlx-c.
extern int (*mlx_clear_cache_)(void);
extern int (*mlx_get_active_memory_)(size_t* res);
extern int (*mlx_get_cache_memory_)(size_t* res);
extern int (*mlx_get_memory_limit_)(size_t* res);
extern int (*mlx_get_peak_memory_)(size_t* res);
extern int (*mlx_reset_peak_memory_)(void);
extern int (*mlx_set_cache_limit_)(size_t* res, size_t limit);
extern int (*mlx_set_memory_limit_)(size_t* res, size_t limit);
extern int (*mlx_set_wired_limit_)(size_t* res, size_t limit);
// Metal backend helpers. device_info_ returns a struct by value (no status
// code); is_available_ writes its answer through `res`; start/stop capture
// take/return no array data.
extern mlx_metal_device_info_t (*mlx_metal_device_info_)(void);
extern int (*mlx_metal_is_available_)(bool* res);
extern int (*mlx_metal_start_capture_)(const char* path);
extern int (*mlx_metal_stop_capture_)(void);
// Core array ops (part 1: abs .. argsort).
// Naming convention visible in these signatures:
//   op_axes_ : (`axes`, `axes_num`) pointer + count pair
//   op_axis_ : a single `int axis`
//   op_      : no axis argument
// Every op writes its result through `res` and takes the target stream last.
extern int (*mlx_abs_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_add_)(
mlx_array* res,
const mlx_array a,
const mlx_array b,
const mlx_stream s);
// Takes an extra `c` operand plus float `alpha` / `beta` coefficients.
extern int (*mlx_addmm_)(
mlx_array* res,
const mlx_array c,
const mlx_array a,
const mlx_array b,
float alpha,
float beta,
const mlx_stream s);
extern int (*mlx_all_axes_)(
mlx_array* res,
const mlx_array a,
const int* axes,
size_t axes_num,
bool keepdims,
const mlx_stream s);
extern int (*mlx_all_axis_)(
mlx_array* res,
const mlx_array a,
int axis,
bool keepdims,
const mlx_stream s);
extern int (*mlx_all_)(
mlx_array* res,
const mlx_array a,
bool keepdims,
const mlx_stream s);
extern int (*mlx_allclose_)(
mlx_array* res,
const mlx_array a,
const mlx_array b,
double rtol,
double atol,
bool equal_nan,
const mlx_stream s);
extern int (*mlx_any_axes_)(
mlx_array* res,
const mlx_array a,
const int* axes,
size_t axes_num,
bool keepdims,
const mlx_stream s);
extern int (*mlx_any_axis_)(
mlx_array* res,
const mlx_array a,
int axis,
bool keepdims,
const mlx_stream s);
extern int (*mlx_any_)(
mlx_array* res,
const mlx_array a,
bool keepdims,
const mlx_stream s);
// start/stop/step are doubles even though the result dtype is selectable.
extern int (*mlx_arange_)(
mlx_array* res,
double start,
double stop,
double step,
mlx_dtype dtype,
const mlx_stream s);
extern int (*mlx_arccos_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_arccosh_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_arcsin_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_arcsinh_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_arctan_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_arctan2_)(
mlx_array* res,
const mlx_array a,
const mlx_array b,
const mlx_stream s);
extern int (*mlx_arctanh_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_argmax_axis_)(
mlx_array* res,
const mlx_array a,
int axis,
bool keepdims,
const mlx_stream s);
extern int (*mlx_argmax_)(
mlx_array* res,
const mlx_array a,
bool keepdims,
const mlx_stream s);
extern int (*mlx_argmin_axis_)(
mlx_array* res,
const mlx_array a,
int axis,
bool keepdims,
const mlx_stream s);
extern int (*mlx_argmin_)(
mlx_array* res,
const mlx_array a,
bool keepdims,
const mlx_stream s);
extern int (*mlx_argpartition_axis_)(
mlx_array* res,
const mlx_array a,
int kth,
int axis,
const mlx_stream s);
extern int (*mlx_argpartition_)(
mlx_array* res,
const mlx_array a,
int kth,
const mlx_stream s);
extern int (*mlx_argsort_axis_)(
mlx_array* res,
const mlx_array a,
int axis,
const mlx_stream s);
extern int (*mlx_argsort_)(mlx_array* res, const mlx_array a, const mlx_stream s);
// Core array ops (part 2: array_equal .. cumsum). Same conventions as
// part 1: result through `res`, stream last.
extern int (*mlx_array_equal_)(
mlx_array* res,
const mlx_array a,
const mlx_array b,
bool equal_nan,
const mlx_stream s);
// Shape entries are `int`, but strides are `int64_t` and offset is `size_t`.
extern int (*mlx_as_strided_)(
mlx_array* res,
const mlx_array a,
const int* shape,
size_t shape_num,
const int64_t* strides,
size_t strides_num,
size_t offset,
const mlx_stream s);
extern int (*mlx_astype_)(
mlx_array* res,
const mlx_array a,
mlx_dtype dtype,
const mlx_stream s);
extern int (*mlx_atleast_1d_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_atleast_2d_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_atleast_3d_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_bitwise_and_)(
mlx_array* res,
const mlx_array a,
const mlx_array b,
const mlx_stream s);
extern int (*mlx_bitwise_invert_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_bitwise_or_)(
mlx_array* res,
const mlx_array a,
const mlx_array b,
const mlx_stream s);
extern int (*mlx_bitwise_xor_)(
mlx_array* res,
const mlx_array a,
const mlx_array b,
const mlx_stream s);
// All three masks are optional (may be null).
extern int (*mlx_block_masked_mm_)(
mlx_array* res,
const mlx_array a,
const mlx_array b,
int block_size,
const mlx_array mask_out /* may be null */,
const mlx_array mask_lhs /* may be null */,
const mlx_array mask_rhs /* may be null */,
const mlx_stream s);
extern int (*mlx_broadcast_arrays_)(
mlx_vector_array* res,
const mlx_vector_array inputs,
const mlx_stream s);
extern int (*mlx_broadcast_to_)(
mlx_array* res,
const mlx_array a,
const int* shape,
size_t shape_num,
const mlx_stream s);
extern int (*mlx_ceil_)(mlx_array* res, const mlx_array a, const mlx_stream s);
// Either bound may be null.
extern int (*mlx_clip_)(
mlx_array* res,
const mlx_array a,
const mlx_array a_min /* may be null */,
const mlx_array a_max /* may be null */,
const mlx_stream s);
extern int (*mlx_concatenate_axis_)(
mlx_array* res,
const mlx_vector_array arrays,
int axis,
const mlx_stream s);
extern int (*mlx_concatenate_)(
mlx_array* res,
const mlx_vector_array arrays,
const mlx_stream s);
extern int (*mlx_conjugate_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_contiguous_)(
mlx_array* res,
const mlx_array a,
bool allow_col_major,
const mlx_stream s);
// Convolutions. The fixed-rank forms (1d/2d/3d) take one scalar per spatial
// dimension; conv_general_ takes pointer + count pairs instead.
extern int (*mlx_conv1d_)(
mlx_array* res,
const mlx_array input,
const mlx_array weight,
int stride,
int padding,
int dilation,
int groups,
const mlx_stream s);
extern int (*mlx_conv2d_)(
mlx_array* res,
const mlx_array input,
const mlx_array weight,
int stride_0,
int stride_1,
int padding_0,
int padding_1,
int dilation_0,
int dilation_1,
int groups,
const mlx_stream s);
extern int (*mlx_conv3d_)(
mlx_array* res,
const mlx_array input,
const mlx_array weight,
int stride_0,
int stride_1,
int stride_2,
int padding_0,
int padding_1,
int padding_2,
int dilation_0,
int dilation_1,
int dilation_2,
int groups,
const mlx_stream s);
extern int (*mlx_conv_general_)(
mlx_array* res,
const mlx_array input,
const mlx_array weight,
const int* stride,
size_t stride_num,
const int* padding_lo,
size_t padding_lo_num,
const int* padding_hi,
size_t padding_hi_num,
const int* kernel_dilation,
size_t kernel_dilation_num,
const int* input_dilation,
size_t input_dilation_num,
int groups,
bool flip,
const mlx_stream s);
// Transposed convolutions additionally take output_padding per dimension.
extern int (*mlx_conv_transpose1d_)(
mlx_array* res,
const mlx_array input,
const mlx_array weight,
int stride,
int padding,
int dilation,
int output_padding,
int groups,
const mlx_stream s);
extern int (*mlx_conv_transpose2d_)(
mlx_array* res,
const mlx_array input,
const mlx_array weight,
int stride_0,
int stride_1,
int padding_0,
int padding_1,
int dilation_0,
int dilation_1,
int output_padding_0,
int output_padding_1,
int groups,
const mlx_stream s);
extern int (*mlx_conv_transpose3d_)(
mlx_array* res,
const mlx_array input,
const mlx_array weight,
int stride_0,
int stride_1,
int stride_2,
int padding_0,
int padding_1,
int padding_2,
int dilation_0,
int dilation_1,
int dilation_2,
int output_padding_0,
int output_padding_1,
int output_padding_2,
int groups,
const mlx_stream s);
extern int (*mlx_copy_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_cos_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_cosh_)(mlx_array* res, const mlx_array a, const mlx_stream s);
// Cumulative scans: a single axis, plus `reverse` and `inclusive` flags.
extern int (*mlx_cummax_)(
mlx_array* res,
const mlx_array a,
int axis,
bool reverse,
bool inclusive,
const mlx_stream s);
extern int (*mlx_cummin_)(
mlx_array* res,
const mlx_array a,
int axis,
bool reverse,
bool inclusive,
const mlx_stream s);
extern int (*mlx_cumprod_)(
mlx_array* res,
const mlx_array a,
int axis,
bool reverse,
bool inclusive,
const mlx_stream s);
extern int (*mlx_cumsum_)(
mlx_array* res,
const mlx_array a,
int axis,
bool reverse,
bool inclusive,
const mlx_stream s);
// Core array ops (part 3: degrees .. masked_scatter).
extern int (*mlx_degrees_)(mlx_array* res, const mlx_array a, const mlx_stream s);
// Takes no stream argument, unlike the surrounding ops.
extern int (*mlx_depends_)(
mlx_vector_array* res,
const mlx_vector_array inputs,
const mlx_vector_array dependencies);
// group_size, bits and dtype are mlx_optional_* values; `biases` may be
// null; `mode` is a string selector.
extern int (*mlx_dequantize_)(
mlx_array* res,
const mlx_array w,
const mlx_array scales,
const mlx_array biases /* may be null */,
mlx_optional_int group_size,
mlx_optional_int bits,
const char* mode,
mlx_optional_dtype dtype,
const mlx_stream s);
extern int (*mlx_diag_)(mlx_array* res, const mlx_array a, int k, const mlx_stream s);
extern int (*mlx_diagonal_)(
mlx_array* res,
const mlx_array a,
int offset,
int axis1,
int axis2,
const mlx_stream s);
extern int (*mlx_divide_)(
mlx_array* res,
const mlx_array a,
const mlx_array b,
const mlx_stream s);
// Writes a vector of arrays (more than one result) rather than a single array.
extern int (*mlx_divmod_)(
mlx_vector_array* res,
const mlx_array a,
const mlx_array b,
const mlx_stream s);
extern int (*mlx_einsum_)(
mlx_array* res,
const char* subscripts,
const mlx_vector_array operands,
const mlx_stream s);
extern int (*mlx_equal_)(
mlx_array* res,
const mlx_array a,
const mlx_array b,
const mlx_stream s);
extern int (*mlx_erf_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_erfinv_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_exp_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_expand_dims_axes_)(
mlx_array* res,
const mlx_array a,
const int* axes,
size_t axes_num,
const mlx_stream s);
extern int (*mlx_expand_dims_)(
mlx_array* res,
const mlx_array a,
int axis,
const mlx_stream s);
extern int (*mlx_expm1_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_eye_)(
mlx_array* res,
int n,
int m,
int k,
mlx_dtype dtype,
const mlx_stream s);
extern int (*mlx_flatten_)(
mlx_array* res,
const mlx_array a,
int start_axis,
int end_axis,
const mlx_stream s);
extern int (*mlx_floor_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_floor_divide_)(
mlx_array* res,
const mlx_array a,
const mlx_array b,
const mlx_stream s);
extern int (*mlx_from_fp8_)(
mlx_array* res,
const mlx_array x,
mlx_dtype dtype,
const mlx_stream s);
// The fill value is passed as an mlx_array (`vals`), not as a scalar.
extern int (*mlx_full_)(
mlx_array* res,
const int* shape,
size_t shape_num,
const mlx_array vals,
mlx_dtype dtype,
const mlx_stream s);
extern int (*mlx_full_like_)(
mlx_array* res,
const mlx_array a,
const mlx_array vals,
mlx_dtype dtype,
const mlx_stream s);
extern int (*mlx_gather_)(
mlx_array* res,
const mlx_array a,
const mlx_vector_array indices,
const int* axes,
size_t axes_num,
const int* slice_sizes,
size_t slice_sizes_num,
const mlx_stream s);
// Index arrays are optional (may be null).
extern int (*mlx_gather_mm_)(
mlx_array* res,
const mlx_array a,
const mlx_array b,
const mlx_array lhs_indices /* may be null */,
const mlx_array rhs_indices /* may be null */,
bool sorted_indices,
const mlx_stream s);
// Quantized counterpart of gather_mm_; quantization parameters follow the
// same optional group_size/bits + `mode` string pattern as dequantize_.
extern int (*mlx_gather_qmm_)(
mlx_array* res,
const mlx_array x,
const mlx_array w,
const mlx_array scales,
const mlx_array biases /* may be null */,
const mlx_array lhs_indices /* may be null */,
const mlx_array rhs_indices /* may be null */,
bool transpose,
mlx_optional_int group_size,
mlx_optional_int bits,
const char* mode,
bool sorted_indices,
const mlx_stream s);
extern int (*mlx_greater_)(
mlx_array* res,
const mlx_array a,
const mlx_array b,
const mlx_stream s);
extern int (*mlx_greater_equal_)(
mlx_array* res,
const mlx_array a,
const mlx_array b,
const mlx_stream s);
extern int (*mlx_hadamard_transform_)(
mlx_array* res,
const mlx_array a,
mlx_optional_float scale,
const mlx_stream s);
extern int (*mlx_identity_)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s);
extern int (*mlx_imag_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_inner_)(
mlx_array* res,
const mlx_array a,
const mlx_array b,
const mlx_stream s);
extern int (*mlx_isclose_)(
mlx_array* res,
const mlx_array a,
const mlx_array b,
double rtol,
double atol,
bool equal_nan,
const mlx_stream s);
extern int (*mlx_isfinite_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_isinf_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_isnan_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_isneginf_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_isposinf_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_kron_)(
mlx_array* res,
const mlx_array a,
const mlx_array b,
const mlx_stream s);
extern int (*mlx_left_shift_)(
mlx_array* res,
const mlx_array a,
const mlx_array b,
const mlx_stream s);
extern int (*mlx_less_)(
mlx_array* res,
const mlx_array a,
const mlx_array b,
const mlx_stream s);
extern int (*mlx_less_equal_)(
mlx_array* res,
const mlx_array a,
const mlx_array b,
const mlx_stream s);
extern int (*mlx_linspace_)(
mlx_array* res,
double start,
double stop,
int num,
mlx_dtype dtype,
const mlx_stream s);
extern int (*mlx_log_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_log10_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_log1p_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_log2_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_logaddexp_)(
mlx_array* res,
const mlx_array a,
const mlx_array b,
const mlx_stream s);
extern int (*mlx_logcumsumexp_)(
mlx_array* res,
const mlx_array a,
int axis,
bool reverse,
bool inclusive,
const mlx_stream s);
extern int (*mlx_logical_and_)(
mlx_array* res,
const mlx_array a,
const mlx_array b,
const mlx_stream s);
extern int (*mlx_logical_not_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_logical_or_)(
mlx_array* res,
const mlx_array a,
const mlx_array b,
const mlx_stream s);
extern int (*mlx_logsumexp_axes_)(
mlx_array* res,
const mlx_array a,
const int* axes,
size_t axes_num,
bool keepdims,
const mlx_stream s);
extern int (*mlx_logsumexp_axis_)(
mlx_array* res,
const mlx_array a,
int axis,
bool keepdims,
const mlx_stream s);
extern int (*mlx_logsumexp_)(
mlx_array* res,
const mlx_array a,
bool keepdims,
const mlx_stream s);
extern int (*mlx_masked_scatter_)(
mlx_array* res,
const mlx_array a,
const mlx_array mask,
const mlx_array src,
const mlx_stream s);
// Core array ops (part 4: matmul .. quantized_matmul).
extern int (*mlx_matmul_)(
mlx_array* res,
const mlx_array a,
const mlx_array b,
const mlx_stream s);
// Reductions (max / mean / min / prod) follow the _axes_/_axis_/plain pattern.
extern int (*mlx_max_axes_)(
mlx_array* res,
const mlx_array a,
const int* axes,
size_t axes_num,
bool keepdims,
const mlx_stream s);
extern int (*mlx_max_axis_)(
mlx_array* res,
const mlx_array a,
int axis,
bool keepdims,
const mlx_stream s);
extern int (*mlx_max_)(
mlx_array* res,
const mlx_array a,
bool keepdims,
const mlx_stream s);
extern int (*mlx_maximum_)(
mlx_array* res,
const mlx_array a,
const mlx_array b,
const mlx_stream s);
extern int (*mlx_mean_axes_)(
mlx_array* res,
const mlx_array a,
const int* axes,
size_t axes_num,
bool keepdims,
const mlx_stream s);
extern int (*mlx_mean_axis_)(
mlx_array* res,
const mlx_array a,
int axis,
bool keepdims,
const mlx_stream s);
extern int (*mlx_mean_)(
mlx_array* res,
const mlx_array a,
bool keepdims,
const mlx_stream s);
// Only an axes-list form is declared here (no _axis_/plain variants).
extern int (*mlx_median_)(
mlx_array* res,
const mlx_array a,
const int* axes,
size_t axes_num,
bool keepdims,
const mlx_stream s);
extern int (*mlx_meshgrid_)(
mlx_vector_array* res,
const mlx_vector_array arrays,
bool sparse,
const char* indexing,
const mlx_stream s);
extern int (*mlx_min_axes_)(
mlx_array* res,
const mlx_array a,
const int* axes,
size_t axes_num,
bool keepdims,
const mlx_stream s);
extern int (*mlx_min_axis_)(
mlx_array* res,
const mlx_array a,
int axis,
bool keepdims,
const mlx_stream s);
extern int (*mlx_min_)(
mlx_array* res,
const mlx_array a,
bool keepdims,
const mlx_stream s);
extern int (*mlx_minimum_)(
mlx_array* res,
const mlx_array a,
const mlx_array b,
const mlx_stream s);
extern int (*mlx_moveaxis_)(
mlx_array* res,
const mlx_array a,
int source,
int destination,
const mlx_stream s);
extern int (*mlx_multiply_)(
mlx_array* res,
const mlx_array a,
const mlx_array b,
const mlx_stream s);
// `nan` is a plain float; posinf/neginf are optional floats.
extern int (*mlx_nan_to_num_)(
mlx_array* res,
const mlx_array a,
float nan,
mlx_optional_float posinf,
mlx_optional_float neginf,
const mlx_stream s);
extern int (*mlx_negative_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_not_equal_)(
mlx_array* res,
const mlx_array a,
const mlx_array b,
const mlx_stream s);
extern int (*mlx_number_of_elements_)(
mlx_array* res,
const mlx_array a,
const int* axes,
size_t axes_num,
bool inverted,
mlx_dtype dtype,
const mlx_stream s);
extern int (*mlx_ones_)(
mlx_array* res,
const int* shape,
size_t shape_num,
mlx_dtype dtype,
const mlx_stream s);
extern int (*mlx_ones_like_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_outer_)(
mlx_array* res,
const mlx_array a,
const mlx_array b,
const mlx_stream s);
// Per-axis low/high pad widths given as pointer + count pairs; the pad
// value is an mlx_array and `mode` is a string selector.
extern int (*mlx_pad_)(
mlx_array* res,
const mlx_array a,
const int* axes,
size_t axes_num,
const int* low_pad_size,
size_t low_pad_size_num,
const int* high_pad_size,
size_t high_pad_size_num,
const mlx_array pad_value,
const char* mode,
const mlx_stream s);
// Single `pad_width` value rather than per-axis arrays.
extern int (*mlx_pad_symmetric_)(
mlx_array* res,
const mlx_array a,
int pad_width,
const mlx_array pad_value,
const char* mode,
const mlx_stream s);
extern int (*mlx_partition_axis_)(
mlx_array* res,
const mlx_array a,
int kth,
int axis,
const mlx_stream s);
extern int (*mlx_partition_)(
mlx_array* res,
const mlx_array a,
int kth,
const mlx_stream s);
extern int (*mlx_power_)(
mlx_array* res,
const mlx_array a,
const mlx_array b,
const mlx_stream s);
extern int (*mlx_prod_axes_)(
mlx_array* res,
const mlx_array a,
const int* axes,
size_t axes_num,
bool keepdims,
const mlx_stream s);
extern int (*mlx_prod_axis_)(
mlx_array* res,
const mlx_array a,
int axis,
bool keepdims,
const mlx_stream s);
extern int (*mlx_prod_)(
mlx_array* res,
const mlx_array a,
bool keepdims,
const mlx_stream s);
extern int (*mlx_put_along_axis_)(
mlx_array* res,
const mlx_array a,
const mlx_array indices,
const mlx_array values,
int axis,
const mlx_stream s);
// Writes a vector of arrays; the exact outputs (e.g. weights/scales/biases)
// are not visible here — confirm against mlx-c.
extern int (*mlx_quantize_)(
mlx_vector_array* res,
const mlx_array w,
mlx_optional_int group_size,
mlx_optional_int bits,
const char* mode,
const mlx_stream s);
extern int (*mlx_quantized_matmul_)(
mlx_array* res,
const mlx_array x,
const mlx_array w,
const mlx_array scales,
const mlx_array biases /* may be null */,
bool transpose,
mlx_optional_int group_size,
mlx_optional_int bits,
const char* mode,
const mlx_stream s);
extern int (*mlx_radians_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_real_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_reciprocal_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_remainder_)(
mlx_array* res,
const mlx_array a,
const mlx_array b,
const mlx_stream s);
extern int (*mlx_repeat_axis_)(
mlx_array* res,
const mlx_array arr,
int repeats,
int axis,
const mlx_stream s);
extern int (*mlx_repeat_)(
mlx_array* res,
const mlx_array arr,
int repeats,
const mlx_stream s);
extern int (*mlx_reshape_)(
mlx_array* res,
const mlx_array a,
const int* shape,
size_t shape_num,
const mlx_stream s);
extern int (*mlx_right_shift_)(
mlx_array* res,
const mlx_array a,
const mlx_array b,
const mlx_stream s);
extern int (*mlx_roll_axis_)(
mlx_array* res,
const mlx_array a,
const int* shift,
size_t shift_num,
int axis,
const mlx_stream s);
/*
 * Function-pointer slots for the core array operations (roll .. zeros_like).
 *
 * Conventions visible in these signatures:
 *   - every slot returns int (presumably the mlx-c status code, 0 on
 *     success -- confirm against the upstream mlx-c headers);
 *   - the leading non-const `mlx_array* res` (or `mlx_vector_array* res`)
 *     parameter is the output the operation writes into;
 *   - a `const int* x` pointer is always followed by its `size_t x_num`
 *     element count;
 *   - the trailing `const mlx_stream s` selects the stream to run on.
 * Many ops come in up to three flavours: `_axes` (int list), `_axis`
 * (single int) and a whole-array form taking neither.
 */
extern int (*mlx_roll_axes_)(
mlx_array* res,
const mlx_array a,
const int* shift,
size_t shift_num,
const int* axes,
size_t axes_num,
const mlx_stream s);
extern int (*mlx_roll_)(
mlx_array* res,
const mlx_array a,
const int* shift,
size_t shift_num,
const mlx_stream s);
extern int (*mlx_round_)(
mlx_array* res,
const mlx_array a,
int decimals,
const mlx_stream s);
extern int (*mlx_rsqrt_)(mlx_array* res, const mlx_array a, const mlx_stream s);
/*
 * scatter and scatter_{add,max,min,prod} share one signature:
 * (res, a, indices, updates, axes, axes_num, s), where `indices` is a
 * vector of arrays. scatter_add_axis instead takes a single `indices`
 * array plus one `axis`.
 */
extern int (*mlx_scatter_)(
mlx_array* res,
const mlx_array a,
const mlx_vector_array indices,
const mlx_array updates,
const int* axes,
size_t axes_num,
const mlx_stream s);
extern int (*mlx_scatter_add_)(
mlx_array* res,
const mlx_array a,
const mlx_vector_array indices,
const mlx_array updates,
const int* axes,
size_t axes_num,
const mlx_stream s);
extern int (*mlx_scatter_add_axis_)(
mlx_array* res,
const mlx_array a,
const mlx_array indices,
const mlx_array values,
int axis,
const mlx_stream s);
extern int (*mlx_scatter_max_)(
mlx_array* res,
const mlx_array a,
const mlx_vector_array indices,
const mlx_array updates,
const int* axes,
size_t axes_num,
const mlx_stream s);
extern int (*mlx_scatter_min_)(
mlx_array* res,
const mlx_array a,
const mlx_vector_array indices,
const mlx_array updates,
const int* axes,
size_t axes_num,
const mlx_stream s);
extern int (*mlx_scatter_prod_)(
mlx_array* res,
const mlx_array a,
const mlx_vector_array indices,
const mlx_array updates,
const int* axes,
size_t axes_num,
const mlx_stream s);
extern int (*mlx_segmented_mm_)(
mlx_array* res,
const mlx_array a,
const mlx_array b,
const mlx_array segments,
const mlx_stream s);
extern int (*mlx_sigmoid_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_sign_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_sin_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_sinh_)(mlx_array* res, const mlx_array a, const mlx_stream s);
/*
 * Slicing: the plain forms take start/stop/strides as int lists; the
 * `_dynamic` forms take `start` as an mlx_array together with explicit
 * `axes`. The `_update` forms write `update` into `src`.
 */
extern int (*mlx_slice_)(
mlx_array* res,
const mlx_array a,
const int* start,
size_t start_num,
const int* stop,
size_t stop_num,
const int* strides,
size_t strides_num,
const mlx_stream s);
extern int (*mlx_slice_dynamic_)(
mlx_array* res,
const mlx_array a,
const mlx_array start,
const int* axes,
size_t axes_num,
const int* slice_size,
size_t slice_size_num,
const mlx_stream s);
extern int (*mlx_slice_update_)(
mlx_array* res,
const mlx_array src,
const mlx_array update,
const int* start,
size_t start_num,
const int* stop,
size_t stop_num,
const int* strides,
size_t strides_num,
const mlx_stream s);
extern int (*mlx_slice_update_dynamic_)(
mlx_array* res,
const mlx_array src,
const mlx_array update,
const mlx_array start,
const int* axes,
size_t axes_num,
const mlx_stream s);
extern int (*mlx_softmax_axes_)(
mlx_array* res,
const mlx_array a,
const int* axes,
size_t axes_num,
bool precise,
const mlx_stream s);
extern int (*mlx_softmax_axis_)(
mlx_array* res,
const mlx_array a,
int axis,
bool precise,
const mlx_stream s);
extern int (*mlx_softmax_)(
mlx_array* res,
const mlx_array a,
bool precise,
const mlx_stream s);
extern int (*mlx_sort_axis_)(
mlx_array* res,
const mlx_array a,
int axis,
const mlx_stream s);
extern int (*mlx_sort_)(mlx_array* res, const mlx_array a, const mlx_stream s);
/* The split variants produce several arrays, hence the mlx_vector_array output. */
extern int (*mlx_split_)(
mlx_vector_array* res,
const mlx_array a,
int num_splits,
int axis,
const mlx_stream s);
extern int (*mlx_split_sections_)(
mlx_vector_array* res,
const mlx_array a,
const int* indices,
size_t indices_num,
int axis,
const mlx_stream s);
extern int (*mlx_sqrt_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_square_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_squeeze_axes_)(
mlx_array* res,
const mlx_array a,
const int* axes,
size_t axes_num,
const mlx_stream s);
extern int (*mlx_squeeze_axis_)(
mlx_array* res,
const mlx_array a,
int axis,
const mlx_stream s);
extern int (*mlx_squeeze_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_stack_axis_)(
mlx_array* res,
const mlx_vector_array arrays,
int axis,
const mlx_stream s);
extern int (*mlx_stack_)(
mlx_array* res,
const mlx_vector_array arrays,
const mlx_stream s);
/* std / sum / var follow the _axes / _axis / whole-array pattern; std and var also take `ddof`. */
extern int (*mlx_std_axes_)(
mlx_array* res,
const mlx_array a,
const int* axes,
size_t axes_num,
bool keepdims,
int ddof,
const mlx_stream s);
extern int (*mlx_std_axis_)(
mlx_array* res,
const mlx_array a,
int axis,
bool keepdims,
int ddof,
const mlx_stream s);
extern int (*mlx_std_)(
mlx_array* res,
const mlx_array a,
bool keepdims,
int ddof,
const mlx_stream s);
extern int (*mlx_stop_gradient_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_subtract_)(
mlx_array* res,
const mlx_array a,
const mlx_array b,
const mlx_stream s);
extern int (*mlx_sum_axes_)(
mlx_array* res,
const mlx_array a,
const int* axes,
size_t axes_num,
bool keepdims,
const mlx_stream s);
extern int (*mlx_sum_axis_)(
mlx_array* res,
const mlx_array a,
int axis,
bool keepdims,
const mlx_stream s);
extern int (*mlx_sum_)(
mlx_array* res,
const mlx_array a,
bool keepdims,
const mlx_stream s);
extern int (*mlx_swapaxes_)(
mlx_array* res,
const mlx_array a,
int axis1,
int axis2,
const mlx_stream s);
extern int (*mlx_take_axis_)(
mlx_array* res,
const mlx_array a,
const mlx_array indices,
int axis,
const mlx_stream s);
extern int (*mlx_take_)(
mlx_array* res,
const mlx_array a,
const mlx_array indices,
const mlx_stream s);
extern int (*mlx_take_along_axis_)(
mlx_array* res,
const mlx_array a,
const mlx_array indices,
int axis,
const mlx_stream s);
extern int (*mlx_tan_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_tanh_)(mlx_array* res, const mlx_array a, const mlx_stream s);
extern int (*mlx_tensordot_)(
mlx_array* res,
const mlx_array a,
const mlx_array b,
const int* axes_a,
size_t axes_a_num,
const int* axes_b,
size_t axes_b_num,
const mlx_stream s);
extern int (*mlx_tensordot_axis_)(
mlx_array* res,
const mlx_array a,
const mlx_array b,
int axis,
const mlx_stream s);
extern int (*mlx_tile_)(
mlx_array* res,
const mlx_array arr,
const int* reps,
size_t reps_num,
const mlx_stream s);
extern int (*mlx_to_fp8_)(mlx_array* res, const mlx_array x, const mlx_stream s);
extern int (*mlx_topk_axis_)(
mlx_array* res,
const mlx_array a,
int k,
int axis,
const mlx_stream s);
extern int (*mlx_topk_)(mlx_array* res, const mlx_array a, int k, const mlx_stream s);
extern int (*mlx_trace_)(
mlx_array* res,
const mlx_array a,
int offset,
int axis1,
int axis2,
mlx_dtype dtype,
const mlx_stream s);
extern int (*mlx_transpose_axes_)(
mlx_array* res,
const mlx_array a,
const int* axes,
size_t axes_num,
const mlx_stream s);
extern int (*mlx_transpose_)(mlx_array* res, const mlx_array a, const mlx_stream s);
/* mlx_tri_ takes no input array: its output is described by n, m, k and `type`. */
extern int (*mlx_tri_)(
mlx_array* res,
int n,
int m,
int k,
mlx_dtype type,
const mlx_stream s);
extern int (*mlx_tril_)(mlx_array* res, const mlx_array x, int k, const mlx_stream s);
extern int (*mlx_triu_)(mlx_array* res, const mlx_array x, int k, const mlx_stream s);
extern int (*mlx_unflatten_)(
mlx_array* res,
const mlx_array a,
int axis,
const int* shape,
size_t shape_num,
const mlx_stream s);
extern int (*mlx_var_axes_)(
mlx_array* res,
const mlx_array a,
const int* axes,
size_t axes_num,
bool keepdims,
int ddof,
const mlx_stream s);
extern int (*mlx_var_axis_)(
mlx_array* res,
const mlx_array a,
int axis,
bool keepdims,
int ddof,
const mlx_stream s);
extern int (*mlx_var_)(
mlx_array* res,
const mlx_array a,
bool keepdims,
int ddof,
const mlx_stream s);
extern int (*mlx_view_)(
mlx_array* res,
const mlx_array a,
mlx_dtype dtype,
const mlx_stream s);
extern int (*mlx_where_)(
mlx_array* res,
const mlx_array condition,
const mlx_array x,
const mlx_array y,
const mlx_stream s);
extern int (*mlx_zeros_)(
mlx_array* res,
const int* shape,
size_t shape_num,
mlx_dtype dtype,
const mlx_stream s);
extern int (*mlx_zeros_like_)(mlx_array* res, const mlx_array a, const mlx_stream s);
/*
 * Function-pointer slots for the random-number API.
 *
 * Parameters the generator tagged "may be null" (notably `key`, and
 * loc/scale of normal_broadcast) accept a null handle. mlx_random_key_
 * and mlx_random_seed_ take a plain uint64_t seed and no stream.
 * mlx_random_split_ has two output slots (res_0, res_1), whereas
 * mlx_random_split_num_ has a single output plus a `num` count.
 */
extern int (*mlx_random_bernoulli_)(
mlx_array* res,
const mlx_array p,
const int* shape,
size_t shape_num,
const mlx_array key /* may be null */,
const mlx_stream s);
extern int (*mlx_random_bits_)(
mlx_array* res,
const int* shape,
size_t shape_num,
int width,
const mlx_array key /* may be null */,
const mlx_stream s);
/* categorical comes in three forms: explicit output shape, sample count, or neither. */
extern int (*mlx_random_categorical_shape_)(
mlx_array* res,
const mlx_array logits,
int axis,
const int* shape,
size_t shape_num,
const mlx_array key /* may be null */,
const mlx_stream s);
extern int (*mlx_random_categorical_num_samples_)(
mlx_array* res,
const mlx_array logits_,
int axis,
int num_samples,
const mlx_array key /* may be null */,
const mlx_stream s);
extern int (*mlx_random_categorical_)(
mlx_array* res,
const mlx_array logits,
int axis,
const mlx_array key /* may be null */,
const mlx_stream s);
extern int (*mlx_random_gumbel_)(
mlx_array* res,
const int* shape,
size_t shape_num,
mlx_dtype dtype,
const mlx_array key /* may be null */,
const mlx_stream s);
extern int (*mlx_random_key_)(mlx_array* res, uint64_t seed);
extern int (*mlx_random_laplace_)(
mlx_array* res,
const int* shape,
size_t shape_num,
mlx_dtype dtype,
float loc,
float scale,
const mlx_array key /* may be null */,
const mlx_stream s);
extern int (*mlx_random_multivariate_normal_)(
mlx_array* res,
const mlx_array mean,
const mlx_array cov,
const int* shape,
size_t shape_num,
mlx_dtype dtype,
const mlx_array key /* may be null */,
const mlx_stream s);
/* normal_broadcast takes loc/scale as (nullable) arrays; normal takes them as floats. */
extern int (*mlx_random_normal_broadcast_)(
mlx_array* res,
const int* shape,
size_t shape_num,
mlx_dtype dtype,
const mlx_array loc /* may be null */,
const mlx_array scale /* may be null */,
const mlx_array key /* may be null */,
const mlx_stream s);
extern int (*mlx_random_normal_)(
mlx_array* res,
const int* shape,
size_t shape_num,
mlx_dtype dtype,
float loc,
float scale,
const mlx_array key /* may be null */,
const mlx_stream s);
extern int (*mlx_random_permutation_)(
mlx_array* res,
const mlx_array x,
int axis,
const mlx_array key /* may be null */,
const mlx_stream s);
extern int (*mlx_random_permutation_arange_)(
mlx_array* res,
int x,
const mlx_array key /* may be null */,
const mlx_stream s);
extern int (*mlx_random_randint_)(
mlx_array* res,
const mlx_array low,
const mlx_array high,
const int* shape,
size_t shape_num,
mlx_dtype dtype,
const mlx_array key /* may be null */,
const mlx_stream s);
extern int (*mlx_random_seed_)(uint64_t seed);
extern int (*mlx_random_split_num_)(
mlx_array* res,
const mlx_array key,
int num,
const mlx_stream s);
extern int (*mlx_random_split_)(
mlx_array* res_0,
mlx_array* res_1,
const mlx_array key,
const mlx_stream s);
extern int (*mlx_random_truncated_normal_)(
mlx_array* res,
const mlx_array lower,
const mlx_array upper,
const int* shape,
size_t shape_num,
mlx_dtype dtype,
const mlx_array key /* may be null */,
const mlx_stream s);
extern int (*mlx_random_uniform_)(
mlx_array* res,
const mlx_array low,
const mlx_array high,
const int* shape,
size_t shape_num,
mlx_dtype dtype,
const mlx_array key /* may be null */,
const mlx_stream s);
/*
 * Function-pointer slots for streams and strings.
 * Unlike most of the API, the *_new* constructors here return the handle
 * by value rather than through an out-parameter; mlx_stream_equal_
 * returns bool and mlx_string_data_ returns a const char*.
 */
extern mlx_stream (*mlx_stream_new_)(void);
extern mlx_stream (*mlx_stream_new_device_)(mlx_device dev);
extern int (*mlx_stream_set_)(mlx_stream* stream, const mlx_stream src);
extern int (*mlx_stream_free_)(mlx_stream stream);
extern int (*mlx_stream_tostring_)(mlx_string* str, mlx_stream stream);
extern bool (*mlx_stream_equal_)(mlx_stream lhs, mlx_stream rhs);
extern int (*mlx_stream_get_device_)(mlx_device* dev, mlx_stream stream);
extern int (*mlx_stream_get_index_)(int* index, mlx_stream stream);
extern int (*mlx_synchronize_)(mlx_stream stream);
extern int (*mlx_get_default_stream_)(mlx_stream* stream, mlx_device dev);
extern int (*mlx_set_default_stream_)(mlx_stream stream);
extern mlx_stream (*mlx_default_cpu_stream_new_)(void);
extern mlx_stream (*mlx_default_gpu_stream_new_)(void);
/* mlx_string handle management. NOTE(review): lifetime of the pointer returned by
 * mlx_string_data_ is presumably tied to the handle -- confirm upstream. */
extern mlx_string (*mlx_string_new_)(void);
extern mlx_string (*mlx_string_new_data_)(const char* str);
extern int (*mlx_string_set_)(mlx_string* str, const mlx_string src);
extern const char * (*mlx_string_data_)(mlx_string str);
extern int (*mlx_string_free_)(mlx_string str);
/*
 * Function-pointer slots for function transforms: vmap internals
 * (mlx_detail_*), eval/async_eval, checkpoint, custom functions,
 * jvp/vjp and value_and_grad.
 * Transforms with two vector outputs (vmap_trace, jvp, vjp) write them to
 * res_0 and res_1. In mlx_custom_function_ the vjp/jvp/vmap callbacks are
 * tagged "may be null" by the generator.
 */
extern int (*mlx_detail_vmap_replace_)(
mlx_vector_array* res,
const mlx_vector_array inputs,
const mlx_vector_array s_inputs,
const mlx_vector_array s_outputs,
const int* in_axes,
size_t in_axes_num,
const int* out_axes,
size_t out_axes_num);
extern int (*mlx_detail_vmap_trace_)(
mlx_vector_array* res_0,
mlx_vector_array* res_1,
const mlx_closure fun,
const mlx_vector_array inputs,
const int* in_axes,
size_t in_axes_num);
extern int (*mlx_async_eval_)(const mlx_vector_array outputs);
extern int (*mlx_checkpoint_)(mlx_closure* res, const mlx_closure fun);
extern int (*mlx_custom_function_)(
mlx_closure* res,
const mlx_closure fun,
const mlx_closure_custom fun_vjp /* may be null */,
const mlx_closure_custom_jvp fun_jvp /* may be null */,
const mlx_closure_custom_vmap fun_vmap /* may be null */);
extern int (*mlx_custom_vjp_)(
mlx_closure* res,
const mlx_closure fun,
const mlx_closure_custom fun_vjp);
extern int (*mlx_eval_)(const mlx_vector_array outputs);
extern int (*mlx_jvp_)(
mlx_vector_array* res_0,
mlx_vector_array* res_1,
const mlx_closure fun,
const mlx_vector_array primals,
const mlx_vector_array tangents);
extern int (*mlx_value_and_grad_)(
mlx_closure_value_and_grad* res,
const mlx_closure fun,
const int* argnums,
size_t argnums_num);
extern int (*mlx_vjp_)(
mlx_vector_array* res_0,
mlx_vector_array* res_1,
const mlx_closure fun,
const mlx_vector_array primals,
const mlx_vector_array cotangents);
/*
 * Function-pointer slots for the vector containers (vector_array,
 * vector_vector_array, vector_int, vector_string), followed by
 * mlx_version_ and the loader prototype.
 * Each container exposes the same API shape: new / new_data / new_value
 * constructors returning by value; set / set_data / set_value writing
 * through a pointer; append_data / append_value taking the vector by
 * value; plus free, size and get(res, vec, idx).
 */
extern mlx_vector_array (*mlx_vector_array_new_)(void);
extern int (*mlx_vector_array_set_)(mlx_vector_array* vec, const mlx_vector_array src);
extern int (*mlx_vector_array_free_)(mlx_vector_array vec);
extern mlx_vector_array (*mlx_vector_array_new_data_)(const mlx_array* data, size_t size);
extern mlx_vector_array (*mlx_vector_array_new_value_)(const mlx_array val);
extern int (*mlx_vector_array_set_data_)(
mlx_vector_array* vec,
const mlx_array* data,
size_t size);
extern int (*mlx_vector_array_set_value_)(mlx_vector_array* vec, const mlx_array val);
extern int (*mlx_vector_array_append_data_)(
mlx_vector_array vec,
const mlx_array* data,
size_t size);
extern int (*mlx_vector_array_append_value_)(mlx_vector_array vec, const mlx_array val);
extern size_t (*mlx_vector_array_size_)(mlx_vector_array vec);
extern int (*mlx_vector_array_get_)(
mlx_array* res,
const mlx_vector_array vec,
size_t idx);
extern mlx_vector_vector_array (*mlx_vector_vector_array_new_)(void);
extern int (*mlx_vector_vector_array_set_)(
mlx_vector_vector_array* vec,
const mlx_vector_vector_array src);
extern int (*mlx_vector_vector_array_free_)(mlx_vector_vector_array vec);
extern mlx_vector_vector_array (*mlx_vector_vector_array_new_data_)(
const mlx_vector_array* data,
size_t size);
extern mlx_vector_vector_array (*mlx_vector_vector_array_new_value_)(
const mlx_vector_array val);
extern int (*mlx_vector_vector_array_set_data_)(
mlx_vector_vector_array* vec,
const mlx_vector_array* data,
size_t size);
extern int (*mlx_vector_vector_array_set_value_)(
mlx_vector_vector_array* vec,
const mlx_vector_array val);
extern int (*mlx_vector_vector_array_append_data_)(
mlx_vector_vector_array vec,
const mlx_vector_array* data,
size_t size);
extern int (*mlx_vector_vector_array_append_value_)(
mlx_vector_vector_array vec,
const mlx_vector_array val);
extern size_t (*mlx_vector_vector_array_size_)(mlx_vector_vector_array vec);
extern int (*mlx_vector_vector_array_get_)(
mlx_vector_array* res,
const mlx_vector_vector_array vec,
size_t idx);
/* NOTE(review): vector_int takes `int* data` (non-const) although the data
 * presumably is only read; kept as-is to match the upstream signature. */
extern mlx_vector_int (*mlx_vector_int_new_)(void);
extern int (*mlx_vector_int_set_)(mlx_vector_int* vec, const mlx_vector_int src);
extern int (*mlx_vector_int_free_)(mlx_vector_int vec);
extern mlx_vector_int (*mlx_vector_int_new_data_)(int* data, size_t size);
extern mlx_vector_int (*mlx_vector_int_new_value_)(int val);
extern int (*mlx_vector_int_set_data_)(mlx_vector_int* vec, int* data, size_t size);
extern int (*mlx_vector_int_set_value_)(mlx_vector_int* vec, int val);
extern int (*mlx_vector_int_append_data_)(mlx_vector_int vec, int* data, size_t size);
extern int (*mlx_vector_int_append_value_)(mlx_vector_int vec, int val);
extern size_t (*mlx_vector_int_size_)(mlx_vector_int vec);
extern int (*mlx_vector_int_get_)(int* res, const mlx_vector_int vec, size_t idx);
/* NOTE(review): mlx_vector_string_get_ yields a `char*` through `res`; who owns
 * that string is not visible here -- confirm upstream before freeing it. */
extern mlx_vector_string (*mlx_vector_string_new_)(void);
extern int (*mlx_vector_string_set_)(mlx_vector_string* vec, const mlx_vector_string src);
extern int (*mlx_vector_string_free_)(mlx_vector_string vec);
extern mlx_vector_string (*mlx_vector_string_new_data_)(const char** data, size_t size);
extern mlx_vector_string (*mlx_vector_string_new_value_)(const char* val);
extern int (*mlx_vector_string_set_data_)(
mlx_vector_string* vec,
const char** data,
size_t size);
extern int (*mlx_vector_string_set_value_)(mlx_vector_string* vec, const char* val);
extern int (*mlx_vector_string_append_data_)(
mlx_vector_string vec,
const char** data,
size_t size);
extern int (*mlx_vector_string_append_value_)(mlx_vector_string vec, const char* val);
extern size_t (*mlx_vector_string_size_)(mlx_vector_string vec);
extern int (*mlx_vector_string_get_)(char** res, const mlx_vector_string vec, size_t idx);
extern int (*mlx_version_)(mlx_string* str_);
/*
 * Loader entry point: a real function, not a pointer slot. Presumably it
 * resolves every `*_` slot declared in this header from `handle`; check the
 * meaning of the returned int in the matching .c file.
 */
int mlx_dynamic_load_symbols(mlx_dynamic_handle handle);
/*
 * Public array API.
 *
 * Every definition in this section is a trampoline. It hands its arguments,
 * in order and unchanged, to the like-named `*_` function pointer declared
 * earlier in this header, and returns whatever that pointer returns.
 * No checking happens here. In particular, a forwarder does not test whether
 * its pointer has been loaded (presumably mlx_dynamic_load_symbols() must
 * have succeeded first).
 */
static inline size_t mlx_dtype_size(mlx_dtype dtype) { return mlx_dtype_size_(dtype); }
static inline int mlx_array_tostring(mlx_string *str, const mlx_array arr) { return mlx_array_tostring_(str, arr); }

/* Handle lifetime. */
static inline mlx_array mlx_array_new(void) { return mlx_array_new_(); }
static inline int mlx_array_free(mlx_array arr) { return mlx_array_free_(arr); }

/* Constructors: host value(s) in, new mlx_array handle out. */
static inline mlx_array mlx_array_new_bool(bool val) { return mlx_array_new_bool_(val); }
static inline mlx_array mlx_array_new_int(int val) { return mlx_array_new_int_(val); }
static inline mlx_array mlx_array_new_float32(float val) { return mlx_array_new_float32_(val); }
static inline mlx_array mlx_array_new_float(float val) { return mlx_array_new_float_(val); }
static inline mlx_array mlx_array_new_float64(double val) { return mlx_array_new_float64_(val); }
static inline mlx_array mlx_array_new_double(double val) { return mlx_array_new_double_(val); }
static inline mlx_array mlx_array_new_complex(float real_val, float imag_val)
{
    return mlx_array_new_complex_(real_val, imag_val);
}
static inline mlx_array mlx_array_new_data(const void *data, const int *shape, int dim, mlx_dtype dtype)
{
    return mlx_array_new_data_(data, shape, dim, dtype);
}

/* Re-assignment through an existing handle; mirrors the constructors above. */
static inline int mlx_array_set(mlx_array *arr, const mlx_array src) { return mlx_array_set_(arr, src); }
static inline int mlx_array_set_bool(mlx_array *arr, bool val) { return mlx_array_set_bool_(arr, val); }
static inline int mlx_array_set_int(mlx_array *arr, int val) { return mlx_array_set_int_(arr, val); }
static inline int mlx_array_set_float32(mlx_array *arr, float val) { return mlx_array_set_float32_(arr, val); }
static inline int mlx_array_set_float(mlx_array *arr, float val) { return mlx_array_set_float_(arr, val); }
static inline int mlx_array_set_float64(mlx_array *arr, double val) { return mlx_array_set_float64_(arr, val); }
static inline int mlx_array_set_double(mlx_array *arr, double val) { return mlx_array_set_double_(arr, val); }
static inline int mlx_array_set_complex(mlx_array *arr, float real_val, float imag_val)
{
    return mlx_array_set_complex_(arr, real_val, imag_val);
}
static inline int mlx_array_set_data(mlx_array *arr, const void *data, const int *shape, int dim, mlx_dtype dtype)
{
    return mlx_array_set_data_(arr, data, shape, dim, dtype);
}

/* Metadata queries. */
static inline size_t mlx_array_itemsize(const mlx_array arr) { return mlx_array_itemsize_(arr); }
static inline size_t mlx_array_size(const mlx_array arr) { return mlx_array_size_(arr); }
static inline size_t mlx_array_nbytes(const mlx_array arr) { return mlx_array_nbytes_(arr); }
static inline size_t mlx_array_ndim(const mlx_array arr) { return mlx_array_ndim_(arr); }
static inline const int *mlx_array_shape(const mlx_array arr) { return mlx_array_shape_(arr); }
static inline const size_t *mlx_array_strides(const mlx_array arr) { return mlx_array_strides_(arr); }
static inline int mlx_array_dim(const mlx_array arr, int dim) { return mlx_array_dim_(arr, dim); }
static inline mlx_dtype mlx_array_dtype(const mlx_array arr) { return mlx_array_dtype_(arr); }

/* Evaluation. */
static inline int mlx_array_eval(mlx_array arr) { return mlx_array_eval_(arr); }

/* Typed scalar readout into *res; the int result is the call's status. */
static inline int mlx_array_item_bool(bool *res, const mlx_array arr) { return mlx_array_item_bool_(res, arr); }
static inline int mlx_array_item_uint8(uint8_t *res, const mlx_array arr) { return mlx_array_item_uint8_(res, arr); }
static inline int mlx_array_item_uint16(uint16_t *res, const mlx_array arr) { return mlx_array_item_uint16_(res, arr); }
static inline int mlx_array_item_uint32(uint32_t *res, const mlx_array arr) { return mlx_array_item_uint32_(res, arr); }
static inline int mlx_array_item_uint64(uint64_t *res, const mlx_array arr) { return mlx_array_item_uint64_(res, arr); }
static inline int mlx_array_item_int8(int8_t *res, const mlx_array arr) { return mlx_array_item_int8_(res, arr); }
static inline int mlx_array_item_int16(int16_t *res, const mlx_array arr) { return mlx_array_item_int16_(res, arr); }
static inline int mlx_array_item_int32(int32_t *res, const mlx_array arr) { return mlx_array_item_int32_(res, arr); }
static inline int mlx_array_item_int64(int64_t *res, const mlx_array arr) { return mlx_array_item_int64_(res, arr); }
static inline int mlx_array_item_float32(float *res, const mlx_array arr) { return mlx_array_item_float32_(res, arr); }
static inline int mlx_array_item_float64(double *res, const mlx_array arr) { return mlx_array_item_float64_(res, arr); }
static inline int mlx_array_item_complex64(float _Complex *res, const mlx_array arr)
{
    return mlx_array_item_complex64_(res, arr);
}
static inline int mlx_array_item_float16(float16_t *res, const mlx_array arr) { return mlx_array_item_float16_(res, arr); }
static inline int mlx_array_item_bfloat16(bfloat16_t *res, const mlx_array arr)
{
    return mlx_array_item_bfloat16_(res, arr);
}

/*
 * Typed read-only pointers to the array's data. NOTE(review): the pointee's
 * lifetime is presumably tied to `arr` -- confirm upstream before caching.
 */
static inline const bool *mlx_array_data_bool(const mlx_array arr) { return mlx_array_data_bool_(arr); }
static inline const uint8_t *mlx_array_data_uint8(const mlx_array arr) { return mlx_array_data_uint8_(arr); }
static inline const uint16_t *mlx_array_data_uint16(const mlx_array arr) { return mlx_array_data_uint16_(arr); }
static inline const uint32_t *mlx_array_data_uint32(const mlx_array arr) { return mlx_array_data_uint32_(arr); }
static inline const uint64_t *mlx_array_data_uint64(const mlx_array arr) { return mlx_array_data_uint64_(arr); }
static inline const int8_t *mlx_array_data_int8(const mlx_array arr) { return mlx_array_data_int8_(arr); }
static inline const int16_t *mlx_array_data_int16(const mlx_array arr) { return mlx_array_data_int16_(arr); }
static inline const int32_t *mlx_array_data_int32(const mlx_array arr) { return mlx_array_data_int32_(arr); }
static inline const int64_t *mlx_array_data_int64(const mlx_array arr) { return mlx_array_data_int64_(arr); }
static inline const float *mlx_array_data_float32(const mlx_array arr) { return mlx_array_data_float32_(arr); }
static inline const double *mlx_array_data_float64(const mlx_array arr) { return mlx_array_data_float64_(arr); }
static inline const float _Complex *mlx_array_data_complex64(const mlx_array arr)
{
    return mlx_array_data_complex64_(arr);
}
static inline const float16_t *mlx_array_data_float16(const mlx_array arr) { return mlx_array_data_float16_(arr); }
static inline const bfloat16_t *mlx_array_data_bfloat16(const mlx_array arr) { return mlx_array_data_bfloat16_(arr); }

/* Underscore-prefixed internals: availability, waiting and contiguity checks. */
static inline int _mlx_array_is_available(bool *res, const mlx_array arr) { return _mlx_array_is_available_(res, arr); }
static inline int _mlx_array_wait(const mlx_array arr) { return _mlx_array_wait_(arr); }
static inline int _mlx_array_is_contiguous(bool *res, const mlx_array arr) { return _mlx_array_is_contiguous_(res, arr); }
static inline int _mlx_array_is_row_contiguous(bool *res, const mlx_array arr)
{
    return _mlx_array_is_row_contiguous_(res, arr);
}
static inline int _mlx_array_is_col_contiguous(bool *res, const mlx_array arr)
{
    return _mlx_array_is_col_contiguous_(res, arr);
}
/*
 * Public closure API (plain, kwargs, value_and_grad, custom, custom_jvp and
 * custom_vmap flavours).
 *
 * Each definition forwards its arguments unchanged to the matching `*_`
 * function pointer and returns that pointer's result. Every flavour has
 * the same members:
 *   new()                 -> empty handle
 *   free(cls)
 *   new_func(fun)         -> wraps a C callback
 *   new_func_payload(fun, payload, dtor)
 *                         -> wraps a callback that also receives `payload`;
 *                            `dtor` is presumably run on `payload` when the
 *                            closure is released (confirm upstream)
 *   set(&cls, src)
 *   apply(res..., cls, inputs...)
 */

/* ---- mlx_closure ---- */
static inline mlx_closure mlx_closure_new(void) { return mlx_closure_new_(); }
static inline int mlx_closure_free(mlx_closure cls) { return mlx_closure_free_(cls); }
static inline mlx_closure mlx_closure_new_func(int (*fun)(mlx_vector_array *, const mlx_vector_array))
{
    return mlx_closure_new_func_(fun);
}
static inline mlx_closure mlx_closure_new_func_payload(
    int (*fun)(mlx_vector_array *, const mlx_vector_array, void *),
    void *payload,
    void (*dtor)(void *))
{
    return mlx_closure_new_func_payload_(fun, payload, dtor);
}
static inline int mlx_closure_set(mlx_closure *cls, const mlx_closure src) { return mlx_closure_set_(cls, src); }
static inline int mlx_closure_apply(mlx_vector_array *res, mlx_closure cls, const mlx_vector_array input)
{
    return mlx_closure_apply_(res, cls, input);
}
static inline mlx_closure mlx_closure_new_unary(int (*fun)(mlx_array *, const mlx_array))
{
    return mlx_closure_new_unary_(fun);
}

/* ---- mlx_closure_kwargs: callbacks also receive a string->array map ---- */
static inline mlx_closure_kwargs mlx_closure_kwargs_new(void) { return mlx_closure_kwargs_new_(); }
static inline int mlx_closure_kwargs_free(mlx_closure_kwargs cls) { return mlx_closure_kwargs_free_(cls); }
static inline mlx_closure_kwargs mlx_closure_kwargs_new_func(
    int (*fun)(mlx_vector_array *, const mlx_vector_array, const mlx_map_string_to_array))
{
    return mlx_closure_kwargs_new_func_(fun);
}
static inline mlx_closure_kwargs mlx_closure_kwargs_new_func_payload(
    int (*fun)(mlx_vector_array *, const mlx_vector_array, const mlx_map_string_to_array, void *),
    void *payload,
    void (*dtor)(void *))
{
    return mlx_closure_kwargs_new_func_payload_(fun, payload, dtor);
}
static inline int mlx_closure_kwargs_set(mlx_closure_kwargs *cls, const mlx_closure_kwargs src)
{
    return mlx_closure_kwargs_set_(cls, src);
}
static inline int mlx_closure_kwargs_apply(
    mlx_vector_array *res,
    mlx_closure_kwargs cls,
    const mlx_vector_array input_0,
    const mlx_map_string_to_array input_1)
{
    return mlx_closure_kwargs_apply_(res, cls, input_0, input_1);
}

/* ---- mlx_closure_value_and_grad: two vector outputs ---- */
static inline mlx_closure_value_and_grad mlx_closure_value_and_grad_new(void)
{
    return mlx_closure_value_and_grad_new_();
}
static inline int mlx_closure_value_and_grad_free(mlx_closure_value_and_grad cls)
{
    return mlx_closure_value_and_grad_free_(cls);
}
static inline mlx_closure_value_and_grad mlx_closure_value_and_grad_new_func(
    int (*fun)(mlx_vector_array *, mlx_vector_array *, const mlx_vector_array))
{
    return mlx_closure_value_and_grad_new_func_(fun);
}
static inline mlx_closure_value_and_grad mlx_closure_value_and_grad_new_func_payload(
    int (*fun)(mlx_vector_array *, mlx_vector_array *, const mlx_vector_array, void *),
    void *payload,
    void (*dtor)(void *))
{
    return mlx_closure_value_and_grad_new_func_payload_(fun, payload, dtor);
}
static inline int mlx_closure_value_and_grad_set(
    mlx_closure_value_and_grad *cls,
    const mlx_closure_value_and_grad src)
{
    return mlx_closure_value_and_grad_set_(cls, src);
}
static inline int mlx_closure_value_and_grad_apply(
    mlx_vector_array *res_0,
    mlx_vector_array *res_1,
    mlx_closure_value_and_grad cls,
    const mlx_vector_array input)
{
    return mlx_closure_value_and_grad_apply_(res_0, res_1, cls, input);
}

/* ---- mlx_closure_custom: three vector inputs ---- */
static inline mlx_closure_custom mlx_closure_custom_new(void) { return mlx_closure_custom_new_(); }
static inline int mlx_closure_custom_free(mlx_closure_custom cls) { return mlx_closure_custom_free_(cls); }
static inline mlx_closure_custom mlx_closure_custom_new_func(
    int (*fun)(mlx_vector_array *, const mlx_vector_array, const mlx_vector_array, const mlx_vector_array))
{
    return mlx_closure_custom_new_func_(fun);
}
static inline mlx_closure_custom mlx_closure_custom_new_func_payload(
    int (*fun)(mlx_vector_array *, const mlx_vector_array, const mlx_vector_array, const mlx_vector_array, void *),
    void *payload,
    void (*dtor)(void *))
{
    return mlx_closure_custom_new_func_payload_(fun, payload, dtor);
}
static inline int mlx_closure_custom_set(mlx_closure_custom *cls, const mlx_closure_custom src)
{
    return mlx_closure_custom_set_(cls, src);
}
static inline int mlx_closure_custom_apply(
    mlx_vector_array *res,
    mlx_closure_custom cls,
    const mlx_vector_array input_0,
    const mlx_vector_array input_1,
    const mlx_vector_array input_2)
{
    return mlx_closure_custom_apply_(res, cls, input_0, input_1, input_2);
}

/* ---- mlx_closure_custom_jvp: two vectors plus an int list and its length ---- */
static inline mlx_closure_custom_jvp mlx_closure_custom_jvp_new(void) { return mlx_closure_custom_jvp_new_(); }
static inline int mlx_closure_custom_jvp_free(mlx_closure_custom_jvp cls) { return mlx_closure_custom_jvp_free_(cls); }
static inline mlx_closure_custom_jvp mlx_closure_custom_jvp_new_func(
    int (*fun)(mlx_vector_array *, const mlx_vector_array, const mlx_vector_array, const int *, size_t _num))
{
    return mlx_closure_custom_jvp_new_func_(fun);
}
static inline mlx_closure_custom_jvp mlx_closure_custom_jvp_new_func_payload(
    int (*fun)(mlx_vector_array *, const mlx_vector_array, const mlx_vector_array, const int *, size_t _num, void *),
    void *payload,
    void (*dtor)(void *))
{
    return mlx_closure_custom_jvp_new_func_payload_(fun, payload, dtor);
}
static inline int mlx_closure_custom_jvp_set(mlx_closure_custom_jvp *cls, const mlx_closure_custom_jvp src)
{
    return mlx_closure_custom_jvp_set_(cls, src);
}
static inline int mlx_closure_custom_jvp_apply(
    mlx_vector_array *res,
    mlx_closure_custom_jvp cls,
    const mlx_vector_array input_0,
    const mlx_vector_array input_1,
    const int *input_2,
    size_t input_2_num)
{
    return mlx_closure_custom_jvp_apply_(res, cls, input_0, input_1, input_2, input_2_num);
}

/* ---- mlx_closure_custom_vmap: outputs a vector of arrays and a vector of ints ---- */
static inline mlx_closure_custom_vmap mlx_closure_custom_vmap_new(void) { return mlx_closure_custom_vmap_new_(); }
static inline int mlx_closure_custom_vmap_free(mlx_closure_custom_vmap cls)
{
    return mlx_closure_custom_vmap_free_(cls);
}
static inline mlx_closure_custom_vmap mlx_closure_custom_vmap_new_func(
    int (*fun)(mlx_vector_array *, mlx_vector_int *, const mlx_vector_array, const int *, size_t _num))
{
    return mlx_closure_custom_vmap_new_func_(fun);
}
static inline mlx_closure_custom_vmap mlx_closure_custom_vmap_new_func_payload(
    int (*fun)(mlx_vector_array *, mlx_vector_int *, const mlx_vector_array, const int *, size_t _num, void *),
    void *payload,
    void (*dtor)(void *))
{
    return mlx_closure_custom_vmap_new_func_payload_(fun, payload, dtor);
}
static inline int mlx_closure_custom_vmap_set(mlx_closure_custom_vmap *cls, const mlx_closure_custom_vmap src)
{
    return mlx_closure_custom_vmap_set_(cls, src);
}
static inline int mlx_closure_custom_vmap_apply(
    mlx_vector_array *res_0,
    mlx_vector_int *res_1,
    mlx_closure_custom_vmap cls,
    const mlx_vector_array input_0,
    const int *input_1,
    size_t input_1_num)
{
    return mlx_closure_custom_vmap_apply_(res_0, res_1, cls, input_0, input_1, input_1_num);
}
/*
 * Compilation controls. Each definition forwards its arguments unchanged to
 * the matching `*_` function pointer and returns that pointer's status.
 * The mlx_detail_* entries take a caller-chosen `fun_id` (uintptr_t) that
 * mlx_detail_compile_erase can later be given again.
 */
static inline int mlx_compile(mlx_closure *res, const mlx_closure fun, bool shapeless)
{
    return mlx_compile_(res, fun, shapeless);
}
static inline int mlx_detail_compile(
    mlx_closure *res,
    const mlx_closure fun,
    uintptr_t fun_id,
    bool shapeless,
    const uint64_t *constants,
    size_t constants_num)
{
    return mlx_detail_compile_(res, fun, fun_id, shapeless, constants, constants_num);
}
static inline int mlx_detail_compile_clear_cache(void) { return mlx_detail_compile_clear_cache_(); }
static inline int mlx_detail_compile_erase(uintptr_t fun_id) { return mlx_detail_compile_erase_(fun_id); }
static inline int mlx_disable_compile(void) { return mlx_disable_compile_(); }
static inline int mlx_enable_compile(void) { return mlx_enable_compile_(); }
static inline int mlx_set_compile_mode(mlx_compile_mode mode) { return mlx_set_compile_mode_(mode); }
/*
 * Device handles and the default-device setting. Each definition forwards
 * its arguments unchanged to the matching `*_` function pointer.
 * Constructors return the handle by value, mlx_device_equal returns bool,
 * and everything else returns the int status from the underlying call.
 */
static inline mlx_device mlx_device_new(void) { return mlx_device_new_(); }
static inline mlx_device mlx_device_new_type(mlx_device_type type, int index)
{
    return mlx_device_new_type_(type, index);
}
static inline int mlx_device_free(mlx_device dev) { return mlx_device_free_(dev); }
static inline int mlx_device_set(mlx_device *dev, const mlx_device src) { return mlx_device_set_(dev, src); }
static inline int mlx_device_tostring(mlx_string *str, mlx_device dev) { return mlx_device_tostring_(str, dev); }
static inline bool mlx_device_equal(mlx_device lhs, mlx_device rhs) { return mlx_device_equal_(lhs, rhs); }
static inline int mlx_device_get_index(int *index, mlx_device dev) { return mlx_device_get_index_(index, dev); }
static inline int mlx_device_get_type(mlx_device_type *type, mlx_device dev)
{
    return mlx_device_get_type_(type, dev);
}

/* Process default device (getter writes through `dev`). */
static inline int mlx_get_default_device(mlx_device *dev) { return mlx_get_default_device_(dev); }
static inline int mlx_set_default_device(mlx_device dev) { return mlx_set_default_device_(dev); }
/*
 * Distributed-communication API: trampolines onto the
 * `mlx_distributed_*_` function pointers declared earlier in this header.
 * Arguments are passed through unchanged and the pointer's return value is
 * returned as-is.
 *
 * The generator tags every `group` argument "may be null". What a null
 * group selects is decided by the library (presumably the global group --
 * confirm upstream). The trailing stream parameter is `s` in every
 * collective below.
 */
static inline int mlx_distributed_group_rank(mlx_distributed_group group) {
    return mlx_distributed_group_rank_(group);
}
static inline int mlx_distributed_group_size(mlx_distributed_group group) {
    return mlx_distributed_group_size_(group);
}
static inline mlx_distributed_group mlx_distributed_group_split(
    mlx_distributed_group group,
    int color,
    int key) {
    return mlx_distributed_group_split_(group, color, key);
}
static inline bool mlx_distributed_is_available(void) {
    return mlx_distributed_is_available_();
}
static inline mlx_distributed_group mlx_distributed_init(bool strict) {
    return mlx_distributed_init_(strict);
}

/* Collectives over `x`, writing into *res. */
static inline int mlx_distributed_all_gather(
    mlx_array* res,
    const mlx_array x,
    const mlx_distributed_group group /* may be null */,
    const mlx_stream s) {
    /* Stream parameter spelled `s` like every sibling wrapper; C parameter
     * names are not part of the call interface, so callers are unaffected. */
    return mlx_distributed_all_gather_(res, x, group, s);
}
static inline int mlx_distributed_all_max(
    mlx_array* res,
    const mlx_array x,
    const mlx_distributed_group group /* may be null */,
    const mlx_stream s) {
    return mlx_distributed_all_max_(res, x, group, s);
}
static inline int mlx_distributed_all_min(
    mlx_array* res,
    const mlx_array x,
    const mlx_distributed_group group /* may be null */,
    const mlx_stream s) {
    return mlx_distributed_all_min_(res, x, group, s);
}
static inline int mlx_distributed_all_sum(
    mlx_array* res,
    const mlx_array x,
    const mlx_distributed_group group /* may be null */,
    const mlx_stream s) {
    return mlx_distributed_all_sum_(res, x, group, s);
}

/* Point-to-point: recv takes an explicit shape/dtype, recv_like copies them from `x`. */
static inline int mlx_distributed_recv(
    mlx_array* res,
    const int* shape,
    size_t shape_num,
    mlx_dtype dtype,
    int src,
    const mlx_distributed_group group /* may be null */,
    const mlx_stream s) {
    return mlx_distributed_recv_(res, shape, shape_num, dtype, src, group, s);
}
static inline int mlx_distributed_recv_like(
    mlx_array* res,
    const mlx_array x,
    int src,
    const mlx_distributed_group group /* may be null */,
    const mlx_stream s) {
    return mlx_distributed_recv_like_(res, x, src, group, s);
}
static inline int mlx_distributed_send(
    mlx_array* res,
    const mlx_array x,
    int dst,
    const mlx_distributed_group group /* may be null */,
    const mlx_stream s) {
    return mlx_distributed_send_(res, x, dst, group, s);
}
static inline int mlx_distributed_sum_scatter(
    mlx_array* res,
    const mlx_array x,
    const mlx_distributed_group group /* may be null */,
    const mlx_stream s) {
    return mlx_distributed_sum_scatter_(res, x, group, s);
}
/* Register an error callback together with opaque user data and a destructor
 * for that data. All three are forwarded untouched and nothing is returned.
 * How MLX invokes them is defined upstream (mlx-c error.h). */
static inline void mlx_set_error_handler(mlx_error_handler_func handler, void* data, void (*dtor)(void*))
{ mlx_set_error_handler_(handler, data, dtor); }
/* Report an error through `_mlx_error_`, passing the caller's source location
 * (`file`, `line`) followed by a format string and any format arguments
 * (presumably printf-style).
 *
 * The format string is deliberately part of `__VA_ARGS__` rather than a named
 * `fmt` parameter. With a named `fmt`, a call that supplies only a message
 * has two problems:
 *   - it violates C11 6.10.3p4, which needs at least one variadic argument;
 *   - under the GNU extension it expands to `_mlx_error_(file, line, fmt, )`,
 *     whose dangling comma is a syntax error.
 * Calls that already pass format arguments expand exactly as before. */
#define _mlx_error(file, line, ...) _mlx_error_(file, line, __VA_ARGS__)
/* Function export / import helpers.
 * Each wrapper is a straight pass-through to its `_`-suffixed symbol. */
static inline int mlx_export_function(const char* file, const mlx_closure fun, const mlx_vector_array args, bool shapeless)
{ return mlx_export_function_(file, fun, args, shapeless); }
static inline int mlx_export_function_kwargs(
    const char* file, const mlx_closure_kwargs fun, const mlx_vector_array args,
    const mlx_map_string_to_array kwargs, bool shapeless)
{ return mlx_export_function_kwargs_(file, fun, args, kwargs, shapeless); }
static inline mlx_function_exporter mlx_function_exporter_new(const char* file, const mlx_closure fun, bool shapeless)
{ return mlx_function_exporter_new_(file, fun, shapeless); }
static inline int mlx_function_exporter_free(mlx_function_exporter xfunc)
{ return mlx_function_exporter_free_(xfunc); }
static inline int mlx_function_exporter_apply(const mlx_function_exporter xfunc, const mlx_vector_array args)
{ return mlx_function_exporter_apply_(xfunc, args); }
static inline int mlx_function_exporter_apply_kwargs(
    const mlx_function_exporter xfunc, const mlx_vector_array args, const mlx_map_string_to_array kwargs)
{ return mlx_function_exporter_apply_kwargs_(xfunc, args, kwargs); }
static inline mlx_imported_function mlx_imported_function_new(const char* file)
{ return mlx_imported_function_new_(file); }
static inline int mlx_imported_function_free(mlx_imported_function xfunc)
{ return mlx_imported_function_free_(xfunc); }
static inline int mlx_imported_function_apply(
    mlx_vector_array* res, const mlx_imported_function xfunc, const mlx_vector_array args)
{ return mlx_imported_function_apply_(res, xfunc, args); }
static inline int mlx_imported_function_apply_kwargs(
    mlx_vector_array* res, const mlx_imported_function xfunc,
    const mlx_vector_array args, const mlx_map_string_to_array kwargs)
{ return mlx_imported_function_apply_kwargs_(res, xfunc, args, kwargs); }
/* CUDA custom-kernel builder: a config object (output args, grid,
 * thread-group, template args) plus a kernel object. Every call is forwarded
 * unchanged to the `_`-suffixed symbol. The two `*_free` calls return void. */
static inline mlx_fast_cuda_kernel_config mlx_fast_cuda_kernel_config_new(void)
{ return mlx_fast_cuda_kernel_config_new_(); }
static inline void mlx_fast_cuda_kernel_config_free(mlx_fast_cuda_kernel_config cls)
{ mlx_fast_cuda_kernel_config_free_(cls); }
static inline int mlx_fast_cuda_kernel_config_add_output_arg(
    mlx_fast_cuda_kernel_config cls, const int* shape, size_t size, mlx_dtype dtype)
{ return mlx_fast_cuda_kernel_config_add_output_arg_(cls, shape, size, dtype); }
static inline int mlx_fast_cuda_kernel_config_set_grid(
    mlx_fast_cuda_kernel_config cls, int grid1, int grid2, int grid3)
{ return mlx_fast_cuda_kernel_config_set_grid_(cls, grid1, grid2, grid3); }
static inline int mlx_fast_cuda_kernel_config_set_thread_group(
    mlx_fast_cuda_kernel_config cls, int thread1, int thread2, int thread3)
{ return mlx_fast_cuda_kernel_config_set_thread_group_(cls, thread1, thread2, thread3); }
static inline int mlx_fast_cuda_kernel_config_set_init_value(mlx_fast_cuda_kernel_config cls, float value)
{ return mlx_fast_cuda_kernel_config_set_init_value_(cls, value); }
static inline int mlx_fast_cuda_kernel_config_set_verbose(mlx_fast_cuda_kernel_config cls, bool verbose)
{ return mlx_fast_cuda_kernel_config_set_verbose_(cls, verbose); }
static inline int mlx_fast_cuda_kernel_config_add_template_arg_dtype(
    mlx_fast_cuda_kernel_config cls, const char* name, mlx_dtype dtype)
{ return mlx_fast_cuda_kernel_config_add_template_arg_dtype_(cls, name, dtype); }
static inline int mlx_fast_cuda_kernel_config_add_template_arg_int(
    mlx_fast_cuda_kernel_config cls, const char* name, int value)
{ return mlx_fast_cuda_kernel_config_add_template_arg_int_(cls, name, value); }
static inline int mlx_fast_cuda_kernel_config_add_template_arg_bool(
    mlx_fast_cuda_kernel_config cls, const char* name, bool value)
{ return mlx_fast_cuda_kernel_config_add_template_arg_bool_(cls, name, value); }
static inline mlx_fast_cuda_kernel mlx_fast_cuda_kernel_new(
    const char* name, const mlx_vector_string input_names, const mlx_vector_string output_names,
    const char* source, const char* header, bool ensure_row_contiguous, int shared_memory)
{ return mlx_fast_cuda_kernel_new_(name, input_names, output_names, source, header, ensure_row_contiguous, shared_memory); }
static inline void mlx_fast_cuda_kernel_free(mlx_fast_cuda_kernel cls)
{ mlx_fast_cuda_kernel_free_(cls); }
static inline int mlx_fast_cuda_kernel_apply(
    mlx_vector_array* outputs, mlx_fast_cuda_kernel cls, const mlx_vector_array inputs,
    const mlx_fast_cuda_kernel_config config, const mlx_stream stream)
{ return mlx_fast_cuda_kernel_apply_(outputs, cls, inputs, config, stream); }
/* "fast" namespace: fused normalization, RoPE and attention kernels, plus the
 * Metal custom-kernel builder. Every wrapper forwards its arguments verbatim
 * to the `_`-suffixed symbol. Arguments tagged "may be null" are optional. */
static inline int mlx_fast_layer_norm(
    mlx_array* res, const mlx_array x,
    const mlx_array weight /* may be null */, const mlx_array bias /* may be null */,
    float eps, const mlx_stream s)
{ return mlx_fast_layer_norm_(res, x, weight, bias, eps, s); }
static inline mlx_fast_metal_kernel_config mlx_fast_metal_kernel_config_new(void)
{ return mlx_fast_metal_kernel_config_new_(); }
static inline void mlx_fast_metal_kernel_config_free(mlx_fast_metal_kernel_config cls)
{ mlx_fast_metal_kernel_config_free_(cls); }
static inline int mlx_fast_metal_kernel_config_add_output_arg(
    mlx_fast_metal_kernel_config cls, const int* shape, size_t size, mlx_dtype dtype)
{ return mlx_fast_metal_kernel_config_add_output_arg_(cls, shape, size, dtype); }
static inline int mlx_fast_metal_kernel_config_set_grid(
    mlx_fast_metal_kernel_config cls, int grid1, int grid2, int grid3)
{ return mlx_fast_metal_kernel_config_set_grid_(cls, grid1, grid2, grid3); }
static inline int mlx_fast_metal_kernel_config_set_thread_group(
    mlx_fast_metal_kernel_config cls, int thread1, int thread2, int thread3)
{ return mlx_fast_metal_kernel_config_set_thread_group_(cls, thread1, thread2, thread3); }
static inline int mlx_fast_metal_kernel_config_set_init_value(mlx_fast_metal_kernel_config cls, float value)
{ return mlx_fast_metal_kernel_config_set_init_value_(cls, value); }
static inline int mlx_fast_metal_kernel_config_set_verbose(mlx_fast_metal_kernel_config cls, bool verbose)
{ return mlx_fast_metal_kernel_config_set_verbose_(cls, verbose); }
static inline int mlx_fast_metal_kernel_config_add_template_arg_dtype(
    mlx_fast_metal_kernel_config cls, const char* name, mlx_dtype dtype)
{ return mlx_fast_metal_kernel_config_add_template_arg_dtype_(cls, name, dtype); }
static inline int mlx_fast_metal_kernel_config_add_template_arg_int(
    mlx_fast_metal_kernel_config cls, const char* name, int value)
{ return mlx_fast_metal_kernel_config_add_template_arg_int_(cls, name, value); }
static inline int mlx_fast_metal_kernel_config_add_template_arg_bool(
    mlx_fast_metal_kernel_config cls, const char* name, bool value)
{ return mlx_fast_metal_kernel_config_add_template_arg_bool_(cls, name, value); }
static inline mlx_fast_metal_kernel mlx_fast_metal_kernel_new(
    const char* name, const mlx_vector_string input_names, const mlx_vector_string output_names,
    const char* source, const char* header, bool ensure_row_contiguous, bool atomic_outputs)
{ return mlx_fast_metal_kernel_new_(name, input_names, output_names, source, header, ensure_row_contiguous, atomic_outputs); }
static inline void mlx_fast_metal_kernel_free(mlx_fast_metal_kernel cls)
{ mlx_fast_metal_kernel_free_(cls); }
static inline int mlx_fast_metal_kernel_apply(
    mlx_vector_array* outputs, mlx_fast_metal_kernel cls, const mlx_vector_array inputs,
    const mlx_fast_metal_kernel_config config, const mlx_stream stream)
{ return mlx_fast_metal_kernel_apply_(outputs, cls, inputs, config, stream); }
static inline int mlx_fast_rms_norm(
    mlx_array* res, const mlx_array x, const mlx_array weight /* may be null */,
    float eps, const mlx_stream s)
{ return mlx_fast_rms_norm_(res, x, weight, eps, s); }
static inline int mlx_fast_rope(
    mlx_array* res, const mlx_array x, int dims, bool traditional,
    mlx_optional_float base, float scale, int offset,
    const mlx_array freqs /* may be null */, const mlx_stream s)
{ return mlx_fast_rope_(res, x, dims, traditional, base, scale, offset, freqs, s); }
static inline int mlx_fast_scaled_dot_product_attention(
    mlx_array* res, const mlx_array queries, const mlx_array keys, const mlx_array values,
    float scale, const char* mask_mode,
    const mlx_array mask_arr /* may be null */, const mlx_array sinks /* may be null */,
    const mlx_stream s)
{ return mlx_fast_scaled_dot_product_attention_(res, queries, keys, values, scale, mask_mode, mask_arr, sinks, s); }
/* FFT family.
 * - 1-D variants take a scalar length `n` and an `axis`.
 * - The 2-D and N-D variants take (pointer, count) pairs for `n` and `axes`.
 * - The shift variants take only `axes`.
 * All calls are forwarded unchanged to the `_`-suffixed symbols. */
static inline int mlx_fft_fft(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s)
{ return mlx_fft_fft_(res, a, n, axis, s); }
static inline int mlx_fft_fft2(
    mlx_array* res, const mlx_array a, const int* n, size_t n_num,
    const int* axes, size_t axes_num, const mlx_stream s)
{ return mlx_fft_fft2_(res, a, n, n_num, axes, axes_num, s); }
static inline int mlx_fft_fftn(
    mlx_array* res, const mlx_array a, const int* n, size_t n_num,
    const int* axes, size_t axes_num, const mlx_stream s)
{ return mlx_fft_fftn_(res, a, n, n_num, axes, axes_num, s); }
static inline int mlx_fft_fftshift(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s)
{ return mlx_fft_fftshift_(res, a, axes, axes_num, s); }
static inline int mlx_fft_ifft(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s)
{ return mlx_fft_ifft_(res, a, n, axis, s); }
static inline int mlx_fft_ifft2(
    mlx_array* res, const mlx_array a, const int* n, size_t n_num,
    const int* axes, size_t axes_num, const mlx_stream s)
{ return mlx_fft_ifft2_(res, a, n, n_num, axes, axes_num, s); }
static inline int mlx_fft_ifftn(
    mlx_array* res, const mlx_array a, const int* n, size_t n_num,
    const int* axes, size_t axes_num, const mlx_stream s)
{ return mlx_fft_ifftn_(res, a, n, n_num, axes, axes_num, s); }
static inline int mlx_fft_ifftshift(mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, const mlx_stream s)
{ return mlx_fft_ifftshift_(res, a, axes, axes_num, s); }
static inline int mlx_fft_irfft(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s)
{ return mlx_fft_irfft_(res, a, n, axis, s); }
static inline int mlx_fft_irfft2(
    mlx_array* res, const mlx_array a, const int* n, size_t n_num,
    const int* axes, size_t axes_num, const mlx_stream s)
{ return mlx_fft_irfft2_(res, a, n, n_num, axes, axes_num, s); }
static inline int mlx_fft_irfftn(
    mlx_array* res, const mlx_array a, const int* n, size_t n_num,
    const int* axes, size_t axes_num, const mlx_stream s)
{ return mlx_fft_irfftn_(res, a, n, n_num, axes, axes_num, s); }
static inline int mlx_fft_rfft(mlx_array* res, const mlx_array a, int n, int axis, const mlx_stream s)
{ return mlx_fft_rfft_(res, a, n, axis, s); }
static inline int mlx_fft_rfft2(
    mlx_array* res, const mlx_array a, const int* n, size_t n_num,
    const int* axes, size_t axes_num, const mlx_stream s)
{ return mlx_fft_rfft2_(res, a, n, n_num, axes, axes_num, s); }
static inline int mlx_fft_rfftn(
    mlx_array* res, const mlx_array a, const int* n, size_t n_num,
    const int* axes, size_t axes_num, const mlx_stream s)
{ return mlx_fft_rfftn_(res, a, n, n_num, axes, axes_num, s); }
/* Custom IO reader/writer handles (a descriptor plus a vtable), and the
 * load/save entry points built on them or on plain file paths.
 * Every call forwards verbatim to its `_`-suffixed symbol. */
static inline mlx_io_reader mlx_io_reader_new(void* desc, mlx_io_vtable vtable)
{ return mlx_io_reader_new_(desc, vtable); }
static inline int mlx_io_reader_descriptor(void** desc_, mlx_io_reader io)
{ return mlx_io_reader_descriptor_(desc_, io); }
static inline int mlx_io_reader_tostring(mlx_string* str_, mlx_io_reader io)
{ return mlx_io_reader_tostring_(str_, io); }
static inline int mlx_io_reader_free(mlx_io_reader io)
{ return mlx_io_reader_free_(io); }
static inline mlx_io_writer mlx_io_writer_new(void* desc, mlx_io_vtable vtable)
{ return mlx_io_writer_new_(desc, vtable); }
static inline int mlx_io_writer_descriptor(void** desc_, mlx_io_writer io)
{ return mlx_io_writer_descriptor_(desc_, io); }
static inline int mlx_io_writer_tostring(mlx_string* str_, mlx_io_writer io)
{ return mlx_io_writer_tostring_(str_, io); }
static inline int mlx_io_writer_free(mlx_io_writer io)
{ return mlx_io_writer_free_(io); }
static inline int mlx_load_reader(mlx_array* res, mlx_io_reader in_stream, const mlx_stream s)
{ return mlx_load_reader_(res, in_stream, s); }
static inline int mlx_load(mlx_array* res, const char* file, const mlx_stream s)
{ return mlx_load_(res, file, s); }
static inline int mlx_load_safetensors_reader(
    mlx_map_string_to_array* res_0, mlx_map_string_to_string* res_1,
    mlx_io_reader in_stream, const mlx_stream s)
{ return mlx_load_safetensors_reader_(res_0, res_1, in_stream, s); }
static inline int mlx_load_safetensors(
    mlx_map_string_to_array* res_0, mlx_map_string_to_string* res_1,
    const char* file, const mlx_stream s)
{ return mlx_load_safetensors_(res_0, res_1, file, s); }
static inline int mlx_save_writer(mlx_io_writer out_stream, const mlx_array a)
{ return mlx_save_writer_(out_stream, a); }
static inline int mlx_save(const char* file, const mlx_array a)
{ return mlx_save_(file, a); }
static inline int mlx_save_safetensors_writer(
    mlx_io_writer in_stream, const mlx_map_string_to_array param, const mlx_map_string_to_string metadata)
{ return mlx_save_safetensors_writer_(in_stream, param, metadata); }
static inline int mlx_save_safetensors(
    const char* file, const mlx_map_string_to_array param, const mlx_map_string_to_string metadata)
{ return mlx_save_safetensors_(file, param, metadata); }
/* Linear-algebra wrappers.
 * - Two-output routines (eig, eigh, lu_factor, qr) write through
 *   `res_0`/`res_1`.
 * - lu and svd write an `mlx_vector_array`.
 * Everything is forwarded unchanged to the `_`-suffixed symbols. */
static inline int mlx_linalg_cholesky(mlx_array* res, const mlx_array a, bool upper, const mlx_stream s)
{ return mlx_linalg_cholesky_(res, a, upper, s); }
static inline int mlx_linalg_cholesky_inv(mlx_array* res, const mlx_array a, bool upper, const mlx_stream s)
{ return mlx_linalg_cholesky_inv_(res, a, upper, s); }
static inline int mlx_linalg_cross(mlx_array* res, const mlx_array a, const mlx_array b, int axis, const mlx_stream s)
{ return mlx_linalg_cross_(res, a, b, axis, s); }
static inline int mlx_linalg_eig(mlx_array* res_0, mlx_array* res_1, const mlx_array a, const mlx_stream s)
{ return mlx_linalg_eig_(res_0, res_1, a, s); }
static inline int mlx_linalg_eigh(
    mlx_array* res_0, mlx_array* res_1, const mlx_array a, const char* UPLO, const mlx_stream s)
{ return mlx_linalg_eigh_(res_0, res_1, a, UPLO, s); }
static inline int mlx_linalg_eigvals(mlx_array* res, const mlx_array a, const mlx_stream s)
{ return mlx_linalg_eigvals_(res, a, s); }
static inline int mlx_linalg_eigvalsh(mlx_array* res, const mlx_array a, const char* UPLO, const mlx_stream s)
{ return mlx_linalg_eigvalsh_(res, a, UPLO, s); }
static inline int mlx_linalg_inv(mlx_array* res, const mlx_array a, const mlx_stream s)
{ return mlx_linalg_inv_(res, a, s); }
static inline int mlx_linalg_lu(mlx_vector_array* res, const mlx_array a, const mlx_stream s)
{ return mlx_linalg_lu_(res, a, s); }
static inline int mlx_linalg_lu_factor(mlx_array* res_0, mlx_array* res_1, const mlx_array a, const mlx_stream s)
{ return mlx_linalg_lu_factor_(res_0, res_1, a, s); }
static inline int mlx_linalg_norm(
    mlx_array* res, const mlx_array a, double ord,
    const int* axis /* may be null */, size_t axis_num, bool keepdims, const mlx_stream s)
{ return mlx_linalg_norm_(res, a, ord, axis, axis_num, keepdims, s); }
static inline int mlx_linalg_norm_matrix(
    mlx_array* res, const mlx_array a, const char* ord,
    const int* axis /* may be null */, size_t axis_num, bool keepdims, const mlx_stream s)
{ return mlx_linalg_norm_matrix_(res, a, ord, axis, axis_num, keepdims, s); }
static inline int mlx_linalg_norm_l2(
    mlx_array* res, const mlx_array a,
    const int* axis /* may be null */, size_t axis_num, bool keepdims, const mlx_stream s)
{ return mlx_linalg_norm_l2_(res, a, axis, axis_num, keepdims, s); }
static inline int mlx_linalg_pinv(mlx_array* res, const mlx_array a, const mlx_stream s)
{ return mlx_linalg_pinv_(res, a, s); }
static inline int mlx_linalg_qr(mlx_array* res_0, mlx_array* res_1, const mlx_array a, const mlx_stream s)
{ return mlx_linalg_qr_(res_0, res_1, a, s); }
static inline int mlx_linalg_solve(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s)
{ return mlx_linalg_solve_(res, a, b, s); }
static inline int mlx_linalg_solve_triangular(
    mlx_array* res, const mlx_array a, const mlx_array b, bool upper, const mlx_stream s)
{ return mlx_linalg_solve_triangular_(res, a, b, upper, s); }
static inline int mlx_linalg_svd(mlx_vector_array* res, const mlx_array a, bool compute_uv, const mlx_stream s)
{ return mlx_linalg_svd_(res, a, compute_uv, s); }
static inline int mlx_linalg_tri_inv(mlx_array* res, const mlx_array a, bool upper, const mlx_stream s)
{ return mlx_linalg_tri_inv_(res, a, upper, s); }
/* String-keyed maps (to arrays and to strings) and their iterators.
 * Each call is forwarded verbatim to the `_`-suffixed symbol. */
static inline mlx_map_string_to_array mlx_map_string_to_array_new(void)
{ return mlx_map_string_to_array_new_(); }
static inline int mlx_map_string_to_array_set(mlx_map_string_to_array* map, const mlx_map_string_to_array src)
{ return mlx_map_string_to_array_set_(map, src); }
static inline int mlx_map_string_to_array_free(mlx_map_string_to_array map)
{ return mlx_map_string_to_array_free_(map); }
static inline int mlx_map_string_to_array_insert(mlx_map_string_to_array map, const char* key, const mlx_array value)
{ return mlx_map_string_to_array_insert_(map, key, value); }
static inline int mlx_map_string_to_array_get(mlx_array* value, const mlx_map_string_to_array map, const char* key)
{ return mlx_map_string_to_array_get_(value, map, key); }
static inline mlx_map_string_to_array_iterator mlx_map_string_to_array_iterator_new(mlx_map_string_to_array map)
{ return mlx_map_string_to_array_iterator_new_(map); }
static inline int mlx_map_string_to_array_iterator_free(mlx_map_string_to_array_iterator it)
{ return mlx_map_string_to_array_iterator_free_(it); }
static inline int mlx_map_string_to_array_iterator_next(
    const char** key, mlx_array* value, mlx_map_string_to_array_iterator it)
{ return mlx_map_string_to_array_iterator_next_(key, value, it); }
static inline mlx_map_string_to_string mlx_map_string_to_string_new(void)
{ return mlx_map_string_to_string_new_(); }
static inline int mlx_map_string_to_string_set(mlx_map_string_to_string* map, const mlx_map_string_to_string src)
{ return mlx_map_string_to_string_set_(map, src); }
static inline int mlx_map_string_to_string_free(mlx_map_string_to_string map)
{ return mlx_map_string_to_string_free_(map); }
static inline int mlx_map_string_to_string_insert(mlx_map_string_to_string map, const char* key, const char* value)
{ return mlx_map_string_to_string_insert_(map, key, value); }
static inline int mlx_map_string_to_string_get(const char** value, const mlx_map_string_to_string map, const char* key)
{ return mlx_map_string_to_string_get_(value, map, key); }
static inline mlx_map_string_to_string_iterator mlx_map_string_to_string_iterator_new(mlx_map_string_to_string map)
{ return mlx_map_string_to_string_iterator_new_(map); }
static inline int mlx_map_string_to_string_iterator_free(mlx_map_string_to_string_iterator it)
{ return mlx_map_string_to_string_iterator_free_(it); }
static inline int mlx_map_string_to_string_iterator_next(
    const char** key, const char** value, mlx_map_string_to_string_iterator it)
{ return mlx_map_string_to_string_iterator_next_(key, value, it); }
/* Memory accounting/limits and Metal utilities.
 * - The getters and setters write their result through `res`.
 * - mlx_metal_device_info returns its struct by value.
 * All calls forward verbatim to the `_`-suffixed symbols. */
static inline int mlx_clear_cache(void) { return mlx_clear_cache_(); }
static inline int mlx_get_active_memory(size_t* res) { return mlx_get_active_memory_(res); }
static inline int mlx_get_cache_memory(size_t* res) { return mlx_get_cache_memory_(res); }
static inline int mlx_get_memory_limit(size_t* res) { return mlx_get_memory_limit_(res); }
static inline int mlx_get_peak_memory(size_t* res) { return mlx_get_peak_memory_(res); }
static inline int mlx_reset_peak_memory(void) { return mlx_reset_peak_memory_(); }
static inline int mlx_set_cache_limit(size_t* res, size_t limit) { return mlx_set_cache_limit_(res, limit); }
static inline int mlx_set_memory_limit(size_t* res, size_t limit) { return mlx_set_memory_limit_(res, limit); }
static inline int mlx_set_wired_limit(size_t* res, size_t limit) { return mlx_set_wired_limit_(res, limit); }
static inline mlx_metal_device_info_t mlx_metal_device_info(void) { return mlx_metal_device_info_(); }
static inline int mlx_metal_is_available(bool* res) { return mlx_metal_is_available_(res); }
static inline int mlx_metal_start_capture(const char* path) { return mlx_metal_start_capture_(path); }
static inline int mlx_metal_stop_capture(void) { return mlx_metal_stop_capture_(); }
/* Core array ops (abs .. atleast_3d).
 * Each writes its output through `res` and forwards unchanged to the
 * `_`-suffixed symbol. The `*_axis` / `*_axes` forms restrict a reduction to
 * the given axis / axes; the plain form omits the axis argument. */
static inline int mlx_abs(mlx_array* res, const mlx_array a, const mlx_stream s)
{ return mlx_abs_(res, a, s); }
static inline int mlx_add(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s)
{ return mlx_add_(res, a, b, s); }
static inline int mlx_addmm(
    mlx_array* res, const mlx_array c, const mlx_array a, const mlx_array b,
    float alpha, float beta, const mlx_stream s)
{ return mlx_addmm_(res, c, a, b, alpha, beta, s); }
static inline int mlx_all_axes(
    mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s)
{ return mlx_all_axes_(res, a, axes, axes_num, keepdims, s); }
static inline int mlx_all_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s)
{ return mlx_all_axis_(res, a, axis, keepdims, s); }
static inline int mlx_all(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s)
{ return mlx_all_(res, a, keepdims, s); }
static inline int mlx_allclose(
    mlx_array* res, const mlx_array a, const mlx_array b,
    double rtol, double atol, bool equal_nan, const mlx_stream s)
{ return mlx_allclose_(res, a, b, rtol, atol, equal_nan, s); }
static inline int mlx_any_axes(
    mlx_array* res, const mlx_array a, const int* axes, size_t axes_num, bool keepdims, const mlx_stream s)
{ return mlx_any_axes_(res, a, axes, axes_num, keepdims, s); }
static inline int mlx_any_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s)
{ return mlx_any_axis_(res, a, axis, keepdims, s); }
static inline int mlx_any(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s)
{ return mlx_any_(res, a, keepdims, s); }
static inline int mlx_arange(
    mlx_array* res, double start, double stop, double step, mlx_dtype dtype, const mlx_stream s)
{ return mlx_arange_(res, start, stop, step, dtype, s); }
static inline int mlx_arccos(mlx_array* res, const mlx_array a, const mlx_stream s)
{ return mlx_arccos_(res, a, s); }
static inline int mlx_arccosh(mlx_array* res, const mlx_array a, const mlx_stream s)
{ return mlx_arccosh_(res, a, s); }
static inline int mlx_arcsin(mlx_array* res, const mlx_array a, const mlx_stream s)
{ return mlx_arcsin_(res, a, s); }
static inline int mlx_arcsinh(mlx_array* res, const mlx_array a, const mlx_stream s)
{ return mlx_arcsinh_(res, a, s); }
static inline int mlx_arctan(mlx_array* res, const mlx_array a, const mlx_stream s)
{ return mlx_arctan_(res, a, s); }
static inline int mlx_arctan2(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s)
{ return mlx_arctan2_(res, a, b, s); }
static inline int mlx_arctanh(mlx_array* res, const mlx_array a, const mlx_stream s)
{ return mlx_arctanh_(res, a, s); }
static inline int mlx_argmax_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s)
{ return mlx_argmax_axis_(res, a, axis, keepdims, s); }
static inline int mlx_argmax(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s)
{ return mlx_argmax_(res, a, keepdims, s); }
static inline int mlx_argmin_axis(mlx_array* res, const mlx_array a, int axis, bool keepdims, const mlx_stream s)
{ return mlx_argmin_axis_(res, a, axis, keepdims, s); }
static inline int mlx_argmin(mlx_array* res, const mlx_array a, bool keepdims, const mlx_stream s)
{ return mlx_argmin_(res, a, keepdims, s); }
static inline int mlx_argpartition_axis(mlx_array* res, const mlx_array a, int kth, int axis, const mlx_stream s)
{ return mlx_argpartition_axis_(res, a, kth, axis, s); }
static inline int mlx_argpartition(mlx_array* res, const mlx_array a, int kth, const mlx_stream s)
{ return mlx_argpartition_(res, a, kth, s); }
static inline int mlx_argsort_axis(mlx_array* res, const mlx_array a, int axis, const mlx_stream s)
{ return mlx_argsort_axis_(res, a, axis, s); }
static inline int mlx_argsort(mlx_array* res, const mlx_array a, const mlx_stream s)
{ return mlx_argsort_(res, a, s); }
static inline int mlx_array_equal(
    mlx_array* res, const mlx_array a, const mlx_array b, bool equal_nan, const mlx_stream s)
{ return mlx_array_equal_(res, a, b, equal_nan, s); }
static inline int mlx_as_strided(
    mlx_array* res, const mlx_array a,
    const int* shape, size_t shape_num,
    const int64_t* strides, size_t strides_num,
    size_t offset, const mlx_stream s)
{ return mlx_as_strided_(res, a, shape, shape_num, strides, strides_num, offset, s); }
static inline int mlx_astype(mlx_array* res, const mlx_array a, mlx_dtype dtype, const mlx_stream s)
{ return mlx_astype_(res, a, dtype, s); }
static inline int mlx_atleast_1d(mlx_array* res, const mlx_array a, const mlx_stream s)
{ return mlx_atleast_1d_(res, a, s); }
static inline int mlx_atleast_2d(mlx_array* res, const mlx_array a, const mlx_stream s)
{ return mlx_atleast_2d_(res, a, s); }
static inline int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s)
{ return mlx_atleast_3d_(res, a, s); }
/* Core array ops (bitwise_and .. conv_transpose3d).
 * Each call forwards verbatim to its `_`-suffixed symbol. Convolution
 * variants take per-dimension stride/padding/dilation scalars (the
 * `_0/_1/_2` suffixes); conv_general takes (pointer, count) pairs instead. */
static inline int mlx_bitwise_and(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s)
{ return mlx_bitwise_and_(res, a, b, s); }
static inline int mlx_bitwise_invert(mlx_array* res, const mlx_array a, const mlx_stream s)
{ return mlx_bitwise_invert_(res, a, s); }
static inline int mlx_bitwise_or(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s)
{ return mlx_bitwise_or_(res, a, b, s); }
static inline int mlx_bitwise_xor(mlx_array* res, const mlx_array a, const mlx_array b, const mlx_stream s)
{ return mlx_bitwise_xor_(res, a, b, s); }
static inline int mlx_block_masked_mm(
    mlx_array* res, const mlx_array a, const mlx_array b, int block_size,
    const mlx_array mask_out /* may be null */,
    const mlx_array mask_lhs /* may be null */,
    const mlx_array mask_rhs /* may be null */,
    const mlx_stream s)
{ return mlx_block_masked_mm_(res, a, b, block_size, mask_out, mask_lhs, mask_rhs, s); }
static inline int mlx_broadcast_arrays(mlx_vector_array* res, const mlx_vector_array inputs, const mlx_stream s)
{ return mlx_broadcast_arrays_(res, inputs, s); }
static inline int mlx_broadcast_to(
    mlx_array* res, const mlx_array a, const int* shape, size_t shape_num, const mlx_stream s)
{ return mlx_broadcast_to_(res, a, shape, shape_num, s); }
static inline int mlx_ceil(mlx_array* res, const mlx_array a, const mlx_stream s)
{ return mlx_ceil_(res, a, s); }
static inline int mlx_clip(
    mlx_array* res, const mlx_array a,
    const mlx_array a_min /* may be null */, const mlx_array a_max /* may be null */,
    const mlx_stream s)
{ return mlx_clip_(res, a, a_min, a_max, s); }
static inline int mlx_concatenate_axis(mlx_array* res, const mlx_vector_array arrays, int axis, const mlx_stream s)
{ return mlx_concatenate_axis_(res, arrays, axis, s); }
static inline int mlx_concatenate(mlx_array* res, const mlx_vector_array arrays, const mlx_stream s)
{ return mlx_concatenate_(res, arrays, s); }
static inline int mlx_conjugate(mlx_array* res, const mlx_array a, const mlx_stream s)
{ return mlx_conjugate_(res, a, s); }
static inline int mlx_contiguous(mlx_array* res, const mlx_array a, bool allow_col_major, const mlx_stream s)
{ return mlx_contiguous_(res, a, allow_col_major, s); }
static inline int mlx_conv1d(
    mlx_array* res, const mlx_array input, const mlx_array weight,
    int stride, int padding, int dilation, int groups, const mlx_stream s)
{ return mlx_conv1d_(res, input, weight, stride, padding, dilation, groups, s); }
static inline int mlx_conv2d(
    mlx_array* res, const mlx_array input, const mlx_array weight,
    int stride_0, int stride_1,
    int padding_0, int padding_1,
    int dilation_0, int dilation_1,
    int groups, const mlx_stream s)
{ return mlx_conv2d_(res, input, weight, stride_0, stride_1, padding_0, padding_1, dilation_0, dilation_1, groups, s); }
static inline int mlx_conv3d(
    mlx_array* res, const mlx_array input, const mlx_array weight,
    int stride_0, int stride_1, int stride_2,
    int padding_0, int padding_1, int padding_2,
    int dilation_0, int dilation_1, int dilation_2,
    int groups, const mlx_stream s)
{ return mlx_conv3d_(res, input, weight, stride_0, stride_1, stride_2, padding_0, padding_1, padding_2, dilation_0, dilation_1, dilation_2, groups, s); }
static inline int mlx_conv_general(
    mlx_array* res, const mlx_array input, const mlx_array weight,
    const int* stride, size_t stride_num,
    const int* padding_lo, size_t padding_lo_num,
    const int* padding_hi, size_t padding_hi_num,
    const int* kernel_dilation, size_t kernel_dilation_num,
    const int* input_dilation, size_t input_dilation_num,
    int groups, bool flip, const mlx_stream s)
{ return mlx_conv_general_(res, input, weight, stride, stride_num, padding_lo, padding_lo_num, padding_hi, padding_hi_num, kernel_dilation, kernel_dilation_num, input_dilation, input_dilation_num, groups, flip, s); }
static inline int mlx_conv_transpose1d(
    mlx_array* res, const mlx_array input, const mlx_array weight,
    int stride, int padding, int dilation, int output_padding, int groups, const mlx_stream s)
{ return mlx_conv_transpose1d_(res, input, weight, stride, padding, dilation, output_padding, groups, s); }
static inline int mlx_conv_transpose2d(
    mlx_array* res, const mlx_array input, const mlx_array weight,
    int stride_0, int stride_1,
    int padding_0, int padding_1,
    int dilation_0, int dilation_1,
    int output_padding_0, int output_padding_1,
    int groups, const mlx_stream s)
{ return mlx_conv_transpose2d_(res, input, weight, stride_0, stride_1, padding_0, padding_1, dilation_0, dilation_1, output_padding_0, output_padding_1, groups, s); }
static inline int mlx_conv_transpose3d(
    mlx_array* res, const mlx_array input, const mlx_array weight,
    int stride_0, int stride_1, int stride_2,
    int padding_0, int padding_1, int padding_2,
    int dilation_0, int dilation_1, int dilation_2,
    int output_padding_0, int output_padding_1, int output_padding_2,
    int groups, const mlx_stream s)
{ return mlx_conv_transpose3d_(res, input, weight, stride_0, stride_1, stride_2, padding_0, padding_1, padding_2, dilation_0, dilation_1, dilation_2, output_padding_0, output_padding_1, output_padding_2, groups, s); }
static inline int mlx_copy(mlx_array* res, const mlx_array a, const mlx_stream s) {
return mlx_copy_(res, a, s);
}
/* Thin forwarder: passes every argument to mlx_cos_ and returns its result. */
static inline int
mlx_cos(mlx_array *res, const mlx_array a, const mlx_stream s)
{
    return mlx_cos_(res, a, s);
}
/* Thin forwarder: passes every argument to mlx_cosh_ and returns its result. */
static inline int
mlx_cosh(mlx_array *res, const mlx_array a, const mlx_stream s)
{
    return mlx_cosh_(res, a, s);
}
/* Thin forwarder: passes every argument to mlx_cummax_ and returns its result. */
static inline int
mlx_cummax(
    mlx_array *res, const mlx_array a,
    int axis, bool reverse, bool inclusive,
    const mlx_stream s)
{
    return mlx_cummax_(res, a, axis, reverse, inclusive, s);
}
/* Thin forwarder: passes every argument to mlx_cummin_ and returns its result. */
static inline int
mlx_cummin(
    mlx_array *res, const mlx_array a,
    int axis, bool reverse, bool inclusive,
    const mlx_stream s)
{
    return mlx_cummin_(res, a, axis, reverse, inclusive, s);
}
/* Thin forwarder: passes every argument to mlx_cumprod_ and returns its result. */
static inline int
mlx_cumprod(
    mlx_array *res, const mlx_array a,
    int axis, bool reverse, bool inclusive,
    const mlx_stream s)
{
    return mlx_cumprod_(res, a, axis, reverse, inclusive, s);
}
/* Thin forwarder: passes every argument to mlx_cumsum_ and returns its result. */
static inline int
mlx_cumsum(
    mlx_array *res, const mlx_array a,
    int axis, bool reverse, bool inclusive,
    const mlx_stream s)
{
    return mlx_cumsum_(res, a, axis, reverse, inclusive, s);
}
/* Thin forwarder: passes every argument to mlx_degrees_ and returns its result. */
static inline int
mlx_degrees(mlx_array *res, const mlx_array a, const mlx_stream s)
{
    return mlx_degrees_(res, a, s);
}
/* Thin forwarder: passes every argument to mlx_depends_ and returns its result. */
static inline int
mlx_depends(
    mlx_vector_array *res,
    const mlx_vector_array inputs,
    const mlx_vector_array dependencies)
{
    return mlx_depends_(res, inputs, dependencies);
}
/* Thin forwarder: passes every argument to mlx_dequantize_ and returns its result. */
static inline int
mlx_dequantize(
    mlx_array *res,
    const mlx_array w,
    const mlx_array scales,
    const mlx_array biases /* may be null */,
    mlx_optional_int group_size, mlx_optional_int bits,
    const char *mode,
    mlx_optional_dtype dtype,
    const mlx_stream s)
{
    return mlx_dequantize_(
        res, w, scales, biases, group_size, bits, mode, dtype, s);
}
/* Thin forwarder: passes every argument to mlx_diag_ and returns its result. */
static inline int
mlx_diag(mlx_array *res, const mlx_array a, int k, const mlx_stream s)
{
    return mlx_diag_(res, a, k, s);
}
/* Thin forwarder: passes every argument to mlx_diagonal_ and returns its result. */
static inline int
mlx_diagonal(
    mlx_array *res, const mlx_array a,
    int offset, int axis1, int axis2,
    const mlx_stream s)
{
    return mlx_diagonal_(res, a, offset, axis1, axis2, s);
}
/* Thin forwarder: passes every argument to mlx_divide_ and returns its result. */
static inline int
mlx_divide(mlx_array *res, const mlx_array a, const mlx_array b, const mlx_stream s)
{
    return mlx_divide_(res, a, b, s);
}
/* Thin forwarder: passes every argument to mlx_divmod_ and returns its result. */
static inline int
mlx_divmod(
    mlx_vector_array *res, const mlx_array a, const mlx_array b, const mlx_stream s)
{
    return mlx_divmod_(res, a, b, s);
}
/* Thin forwarder: passes every argument to mlx_einsum_ and returns its result. */
static inline int
mlx_einsum(
    mlx_array *res,
    const char *subscripts,
    const mlx_vector_array operands,
    const mlx_stream s)
{
    return mlx_einsum_(res, subscripts, operands, s);
}
/* Thin forwarder: passes every argument to mlx_equal_ and returns its result. */
static inline int
mlx_equal(mlx_array *res, const mlx_array a, const mlx_array b, const mlx_stream s)
{
    return mlx_equal_(res, a, b, s);
}
/* Thin forwarder: passes every argument to mlx_erf_ and returns its result. */
static inline int
mlx_erf(mlx_array *res, const mlx_array a, const mlx_stream s)
{
    return mlx_erf_(res, a, s);
}
/* Thin forwarder: passes every argument to mlx_erfinv_ and returns its result. */
static inline int
mlx_erfinv(mlx_array *res, const mlx_array a, const mlx_stream s)
{
    return mlx_erfinv_(res, a, s);
}
/* Thin forwarder: passes every argument to mlx_exp_ and returns its result. */
static inline int
mlx_exp(mlx_array *res, const mlx_array a, const mlx_stream s)
{
    return mlx_exp_(res, a, s);
}
/* Thin forwarder: passes every argument to mlx_expand_dims_axes_ and returns its result. */
static inline int
mlx_expand_dims_axes(
    mlx_array *res, const mlx_array a,
    const int *axes, size_t axes_num,
    const mlx_stream s)
{
    return mlx_expand_dims_axes_(res, a, axes, axes_num, s);
}
/* Thin forwarder: passes every argument to mlx_expand_dims_ and returns its result. */
static inline int
mlx_expand_dims(mlx_array *res, const mlx_array a, int axis, const mlx_stream s)
{
    return mlx_expand_dims_(res, a, axis, s);
}
/* Thin forwarder: passes every argument to mlx_expm1_ and returns its result. */
static inline int
mlx_expm1(mlx_array *res, const mlx_array a, const mlx_stream s)
{
    return mlx_expm1_(res, a, s);
}
/* Thin forwarder: passes every argument to mlx_eye_ and returns its result. */
static inline int
mlx_eye(
    mlx_array *res,
    int n, int m, int k,
    mlx_dtype dtype,
    const mlx_stream s)
{
    return mlx_eye_(res, n, m, k, dtype, s);
}
/* Thin forwarder: passes every argument to mlx_flatten_ and returns its result. */
static inline int
mlx_flatten(
    mlx_array *res, const mlx_array a,
    int start_axis, int end_axis,
    const mlx_stream s)
{
    return mlx_flatten_(res, a, start_axis, end_axis, s);
}
/* Thin forwarder: passes every argument to mlx_floor_ and returns its result. */
static inline int
mlx_floor(mlx_array *res, const mlx_array a, const mlx_stream s)
{
    return mlx_floor_(res, a, s);
}
/* Thin forwarder: passes every argument to mlx_floor_divide_ and returns its result. */
static inline int
mlx_floor_divide(
    mlx_array *res, const mlx_array a, const mlx_array b, const mlx_stream s)
{
    return mlx_floor_divide_(res, a, b, s);
}
/* Thin forwarder: passes every argument to mlx_from_fp8_ and returns its result. */
static inline int
mlx_from_fp8(mlx_array *res, const mlx_array x, mlx_dtype dtype, const mlx_stream s)
{
    return mlx_from_fp8_(res, x, dtype, s);
}
/* Thin forwarder: passes every argument to mlx_full_ and returns its result. */
static inline int
mlx_full(
    mlx_array *res,
    const int *shape, size_t shape_num,
    const mlx_array vals,
    mlx_dtype dtype,
    const mlx_stream s)
{
    return mlx_full_(res, shape, shape_num, vals, dtype, s);
}
/* Thin forwarder: passes every argument to mlx_full_like_ and returns its result. */
static inline int
mlx_full_like(
    mlx_array *res, const mlx_array a, const mlx_array vals,
    mlx_dtype dtype,
    const mlx_stream s)
{
    return mlx_full_like_(res, a, vals, dtype, s);
}
/* Thin forwarder: passes every argument to mlx_gather_ and returns its result. */
static inline int
mlx_gather(
    mlx_array *res,
    const mlx_array a,
    const mlx_vector_array indices,
    const int *axes, size_t axes_num,
    const int *slice_sizes, size_t slice_sizes_num,
    const mlx_stream s)
{
    return mlx_gather_(
        res, a, indices, axes, axes_num, slice_sizes, slice_sizes_num, s);
}
/* Thin forwarder: passes every argument to mlx_gather_mm_ and returns its result. */
static inline int
mlx_gather_mm(
    mlx_array *res,
    const mlx_array a,
    const mlx_array b,
    const mlx_array lhs_indices /* may be null */,
    const mlx_array rhs_indices /* may be null */,
    bool sorted_indices,
    const mlx_stream s)
{
    return mlx_gather_mm_(res, a, b, lhs_indices, rhs_indices, sorted_indices, s);
}
/* Thin forwarder: passes every argument to mlx_gather_qmm_ and returns its result. */
static inline int
mlx_gather_qmm(
    mlx_array *res,
    const mlx_array x,
    const mlx_array w,
    const mlx_array scales,
    const mlx_array biases /* may be null */,
    const mlx_array lhs_indices /* may be null */,
    const mlx_array rhs_indices /* may be null */,
    bool transpose,
    mlx_optional_int group_size, mlx_optional_int bits,
    const char *mode,
    bool sorted_indices,
    const mlx_stream s)
{
    return mlx_gather_qmm_(
        res, x, w, scales, biases,
        lhs_indices, rhs_indices,
        transpose, group_size, bits, mode,
        sorted_indices, s);
}
/* Thin forwarder: passes every argument to mlx_greater_ and returns its result. */
static inline int
mlx_greater(mlx_array *res, const mlx_array a, const mlx_array b, const mlx_stream s)
{
    return mlx_greater_(res, a, b, s);
}
/* Thin forwarder: passes every argument to mlx_greater_equal_ and returns its result. */
static inline int
mlx_greater_equal(
    mlx_array *res, const mlx_array a, const mlx_array b, const mlx_stream s)
{
    return mlx_greater_equal_(res, a, b, s);
}
/* Thin forwarder: passes every argument to mlx_hadamard_transform_ and returns its result. */
static inline int
mlx_hadamard_transform(
    mlx_array *res, const mlx_array a, mlx_optional_float scale, const mlx_stream s)
{
    return mlx_hadamard_transform_(res, a, scale, s);
}
/* Thin forwarder: passes every argument to mlx_identity_ and returns its result. */
static inline int
mlx_identity(mlx_array *res, int n, mlx_dtype dtype, const mlx_stream s)
{
    return mlx_identity_(res, n, dtype, s);
}
/* Thin forwarder: passes every argument to mlx_imag_ and returns its result. */
static inline int
mlx_imag(mlx_array *res, const mlx_array a, const mlx_stream s)
{
    return mlx_imag_(res, a, s);
}
/* Thin forwarder: passes every argument to mlx_inner_ and returns its result. */
static inline int
mlx_inner(mlx_array *res, const mlx_array a, const mlx_array b, const mlx_stream s)
{
    return mlx_inner_(res, a, b, s);
}
/* Thin forwarder: passes every argument to mlx_isclose_ and returns its result. */
static inline int
mlx_isclose(
    mlx_array *res, const mlx_array a, const mlx_array b,
    double rtol, double atol, bool equal_nan,
    const mlx_stream s)
{
    return mlx_isclose_(res, a, b, rtol, atol, equal_nan, s);
}
/* Thin forwarder: passes every argument to mlx_isfinite_ and returns its result. */
static inline int
mlx_isfinite(mlx_array *res, const mlx_array a, const mlx_stream s)
{
    return mlx_isfinite_(res, a, s);
}
/* Thin forwarder: passes every argument to mlx_isinf_ and returns its result. */
static inline int
mlx_isinf(mlx_array *res, const mlx_array a, const mlx_stream s)
{
    return mlx_isinf_(res, a, s);
}
/* Thin forwarder: passes every argument to mlx_isnan_ and returns its result. */
static inline int
mlx_isnan(mlx_array *res, const mlx_array a, const mlx_stream s)
{
    return mlx_isnan_(res, a, s);
}
/* Thin forwarder: passes every argument to mlx_isneginf_ and returns its result. */
static inline int
mlx_isneginf(mlx_array *res, const mlx_array a, const mlx_stream s)
{
    return mlx_isneginf_(res, a, s);
}
/* Thin forwarder: passes every argument to mlx_isposinf_ and returns its result. */
static inline int
mlx_isposinf(mlx_array *res, const mlx_array a, const mlx_stream s)
{
    return mlx_isposinf_(res, a, s);
}
/* Thin forwarder: passes every argument to mlx_kron_ and returns its result. */
static inline int
mlx_kron(mlx_array *res, const mlx_array a, const mlx_array b, const mlx_stream s)
{
    return mlx_kron_(res, a, b, s);
}
/* Thin forwarder: passes every argument to mlx_left_shift_ and returns its result. */
static inline int
mlx_left_shift(
    mlx_array *res, const mlx_array a, const mlx_array b, const mlx_stream s)
{
    return mlx_left_shift_(res, a, b, s);
}
/* Thin forwarder: passes every argument to mlx_less_ and returns its result. */
static inline int
mlx_less(mlx_array *res, const mlx_array a, const mlx_array b, const mlx_stream s)
{
    return mlx_less_(res, a, b, s);
}
/* Thin forwarder: passes every argument to mlx_less_equal_ and returns its result. */
static inline int
mlx_less_equal(
    mlx_array *res, const mlx_array a, const mlx_array b, const mlx_stream s)
{
    return mlx_less_equal_(res, a, b, s);
}
/* Thin forwarder: passes every argument to mlx_linspace_ and returns its result. */
static inline int
mlx_linspace(
    mlx_array *res,
    double start, double stop, int num,
    mlx_dtype dtype,
    const mlx_stream s)
{
    return mlx_linspace_(res, start, stop, num, dtype, s);
}
/* Thin forwarder: passes every argument to mlx_log_ and returns its result. */
static inline int
mlx_log(mlx_array *res, const mlx_array a, const mlx_stream s)
{
    return mlx_log_(res, a, s);
}
/* Thin forwarder: passes every argument to mlx_log10_ and returns its result. */
static inline int
mlx_log10(mlx_array *res, const mlx_array a, const mlx_stream s)
{
    return mlx_log10_(res, a, s);
}
/* Thin forwarder: passes every argument to mlx_log1p_ and returns its result. */
static inline int
mlx_log1p(mlx_array *res, const mlx_array a, const mlx_stream s)
{
    return mlx_log1p_(res, a, s);
}
/* Thin forwarder: passes every argument to mlx_log2_ and returns its result. */
static inline int
mlx_log2(mlx_array *res, const mlx_array a, const mlx_stream s)
{
    return mlx_log2_(res, a, s);
}
/* Thin forwarder: passes every argument to mlx_logaddexp_ and returns its result. */
static inline int
mlx_logaddexp(
    mlx_array *res, const mlx_array a, const mlx_array b, const mlx_stream s)
{
    return mlx_logaddexp_(res, a, b, s);
}
/* Thin forwarder: passes every argument to mlx_logcumsumexp_ and returns its result. */
static inline int
mlx_logcumsumexp(
    mlx_array *res, const mlx_array a,
    int axis, bool reverse, bool inclusive,
    const mlx_stream s)
{
    return mlx_logcumsumexp_(res, a, axis, reverse, inclusive, s);
}
/* Thin forwarder: passes every argument to mlx_logical_and_ and returns its result. */
static inline int
mlx_logical_and(
    mlx_array *res, const mlx_array a, const mlx_array b, const mlx_stream s)
{
    return mlx_logical_and_(res, a, b, s);
}
/* Thin forwarder: passes every argument to mlx_logical_not_ and returns its result. */
static inline int
mlx_logical_not(mlx_array *res, const mlx_array a, const mlx_stream s)
{
    return mlx_logical_not_(res, a, s);
}
/* Thin forwarder: passes every argument to mlx_logical_or_ and returns its result. */
static inline int
mlx_logical_or(
    mlx_array *res, const mlx_array a, const mlx_array b, const mlx_stream s)
{
    return mlx_logical_or_(res, a, b, s);
}
/* Thin forwarder: passes every argument to mlx_logsumexp_axes_ and returns its result. */
static inline int
mlx_logsumexp_axes(
    mlx_array *res, const mlx_array a,
    const int *axes, size_t axes_num,
    bool keepdims,
    const mlx_stream s)
{
    return mlx_logsumexp_axes_(res, a, axes, axes_num, keepdims, s);
}
/* Thin forwarder: passes every argument to mlx_logsumexp_axis_ and returns its result. */
static inline int
mlx_logsumexp_axis(
    mlx_array *res, const mlx_array a, int axis, bool keepdims, const mlx_stream s)
{
    return mlx_logsumexp_axis_(res, a, axis, keepdims, s);
}
/* Thin forwarder: passes every argument to mlx_logsumexp_ and returns its result. */
static inline int
mlx_logsumexp(mlx_array *res, const mlx_array a, bool keepdims, const mlx_stream s)
{
    return mlx_logsumexp_(res, a, keepdims, s);
}
/* Thin forwarder: passes every argument to mlx_masked_scatter_ and returns its result. */
static inline int
mlx_masked_scatter(
    mlx_array *res,
    const mlx_array a, const mlx_array mask, const mlx_array src,
    const mlx_stream s)
{
    return mlx_masked_scatter_(res, a, mask, src, s);
}
/* Thin forwarder: passes every argument to mlx_matmul_ and returns its result. */
static inline int
mlx_matmul(mlx_array *res, const mlx_array a, const mlx_array b, const mlx_stream s)
{
    return mlx_matmul_(res, a, b, s);
}
/* Thin forwarder: passes every argument to mlx_max_axes_ and returns its result. */
static inline int
mlx_max_axes(
    mlx_array *res, const mlx_array a,
    const int *axes, size_t axes_num,
    bool keepdims,
    const mlx_stream s)
{
    return mlx_max_axes_(res, a, axes, axes_num, keepdims, s);
}
/* Thin forwarder: passes every argument to mlx_max_axis_ and returns its result. */
static inline int
mlx_max_axis(
    mlx_array *res, const mlx_array a, int axis, bool keepdims, const mlx_stream s)
{
    return mlx_max_axis_(res, a, axis, keepdims, s);
}
/* Thin forwarder: passes every argument to mlx_max_ and returns its result. */
static inline int
mlx_max(mlx_array *res, const mlx_array a, bool keepdims, const mlx_stream s)
{
    return mlx_max_(res, a, keepdims, s);
}
/* Thin forwarder: passes every argument to mlx_maximum_ and returns its result. */
static inline int
mlx_maximum(mlx_array *res, const mlx_array a, const mlx_array b, const mlx_stream s)
{
    return mlx_maximum_(res, a, b, s);
}
/* Thin forwarder: passes every argument to mlx_mean_axes_ and returns its result. */
static inline int
mlx_mean_axes(
    mlx_array *res, const mlx_array a,
    const int *axes, size_t axes_num,
    bool keepdims,
    const mlx_stream s)
{
    return mlx_mean_axes_(res, a, axes, axes_num, keepdims, s);
}
/* Thin forwarder: passes every argument to mlx_mean_axis_ and returns its result. */
static inline int
mlx_mean_axis(
    mlx_array *res, const mlx_array a, int axis, bool keepdims, const mlx_stream s)
{
    return mlx_mean_axis_(res, a, axis, keepdims, s);
}
/* Thin forwarder: passes every argument to mlx_mean_ and returns its result. */
static inline int
mlx_mean(mlx_array *res, const mlx_array a, bool keepdims, const mlx_stream s)
{
    return mlx_mean_(res, a, keepdims, s);
}
/* Thin forwarder: passes every argument to mlx_median_ and returns its result. */
static inline int
mlx_median(
    mlx_array *res, const mlx_array a,
    const int *axes, size_t axes_num,
    bool keepdims,
    const mlx_stream s)
{
    return mlx_median_(res, a, axes, axes_num, keepdims, s);
}
/* Thin forwarder: passes every argument to mlx_meshgrid_ and returns its result. */
static inline int
mlx_meshgrid(
    mlx_vector_array *res,
    const mlx_vector_array arrays,
    bool sparse,
    const char *indexing,
    const mlx_stream s)
{
    return mlx_meshgrid_(res, arrays, sparse, indexing, s);
}
/* Thin forwarder: passes every argument to mlx_min_axes_ and returns its result. */
static inline int
mlx_min_axes(
    mlx_array *res, const mlx_array a,
    const int *axes, size_t axes_num,
    bool keepdims,
    const mlx_stream s)
{
    return mlx_min_axes_(res, a, axes, axes_num, keepdims, s);
}
/* Thin forwarder: passes every argument to mlx_min_axis_ and returns its result. */
static inline int
mlx_min_axis(
    mlx_array *res, const mlx_array a, int axis, bool keepdims, const mlx_stream s)
{
    return mlx_min_axis_(res, a, axis, keepdims, s);
}
/* Thin forwarder: passes every argument to mlx_min_ and returns its result. */
static inline int
mlx_min(mlx_array *res, const mlx_array a, bool keepdims, const mlx_stream s)
{
    return mlx_min_(res, a, keepdims, s);
}
/* Thin forwarder: passes every argument to mlx_minimum_ and returns its result. */
static inline int
mlx_minimum(mlx_array *res, const mlx_array a, const mlx_array b, const mlx_stream s)
{
    return mlx_minimum_(res, a, b, s);
}
/* Thin forwarder: passes every argument to mlx_moveaxis_ and returns its result. */
static inline int
mlx_moveaxis(
    mlx_array *res, const mlx_array a,
    int source, int destination,
    const mlx_stream s)
{
    return mlx_moveaxis_(res, a, source, destination, s);
}
/* Thin forwarder: passes every argument to mlx_multiply_ and returns its result. */
static inline int
mlx_multiply(mlx_array *res, const mlx_array a, const mlx_array b, const mlx_stream s)
{
    return mlx_multiply_(res, a, b, s);
}
/* Thin forwarder: passes every argument to mlx_nan_to_num_ and returns its result. */
static inline int
mlx_nan_to_num(
    mlx_array *res, const mlx_array a,
    float nan,
    mlx_optional_float posinf, mlx_optional_float neginf,
    const mlx_stream s)
{
    return mlx_nan_to_num_(res, a, nan, posinf, neginf, s);
}
/* Thin forwarder: passes every argument to mlx_negative_ and returns its result. */
static inline int
mlx_negative(mlx_array *res, const mlx_array a, const mlx_stream s)
{
    return mlx_negative_(res, a, s);
}
/* Thin forwarder: passes every argument to mlx_not_equal_ and returns its result. */
static inline int
mlx_not_equal(
    mlx_array *res, const mlx_array a, const mlx_array b, const mlx_stream s)
{
    return mlx_not_equal_(res, a, b, s);
}
/* Thin forwarder: passes every argument to mlx_number_of_elements_ and returns its result. */
static inline int
mlx_number_of_elements(
    mlx_array *res, const mlx_array a,
    const int *axes, size_t axes_num,
    bool inverted,
    mlx_dtype dtype,
    const mlx_stream s)
{
    return mlx_number_of_elements_(res, a, axes, axes_num, inverted, dtype, s);
}
/* Thin forwarder: passes every argument to mlx_ones_ and returns its result. */
static inline int
mlx_ones(
    mlx_array *res,
    const int *shape, size_t shape_num,
    mlx_dtype dtype,
    const mlx_stream s)
{
    return mlx_ones_(res, shape, shape_num, dtype, s);
}
/* Thin forwarder: passes every argument to mlx_ones_like_ and returns its result. */
static inline int
mlx_ones_like(mlx_array *res, const mlx_array a, const mlx_stream s)
{
    return mlx_ones_like_(res, a, s);
}
/* Thin forwarder: passes every argument to mlx_outer_ and returns its result. */
static inline int
mlx_outer(mlx_array *res, const mlx_array a, const mlx_array b, const mlx_stream s)
{
    return mlx_outer_(res, a, b, s);
}
/* Thin forwarder: passes every argument to mlx_pad_ and returns its result. */
static inline int
mlx_pad(
    mlx_array *res,
    const mlx_array a,
    const int *axes, size_t axes_num,
    const int *low_pad_size, size_t low_pad_size_num,
    const int *high_pad_size, size_t high_pad_size_num,
    const mlx_array pad_value,
    const char *mode,
    const mlx_stream s)
{
    return mlx_pad_(
        res, a,
        axes, axes_num,
        low_pad_size, low_pad_size_num,
        high_pad_size, high_pad_size_num,
        pad_value, mode, s);
}
/* Thin forwarder: passes every argument to mlx_pad_symmetric_ and returns its result. */
static inline int
mlx_pad_symmetric(
    mlx_array *res, const mlx_array a,
    int pad_width,
    const mlx_array pad_value,
    const char *mode,
    const mlx_stream s)
{
    return mlx_pad_symmetric_(res, a, pad_width, pad_value, mode, s);
}
/* Thin forwarder: passes every argument to mlx_partition_axis_ and returns its result. */
static inline int
mlx_partition_axis(
    mlx_array *res, const mlx_array a, int kth, int axis, const mlx_stream s)
{
    return mlx_partition_axis_(res, a, kth, axis, s);
}
/* Thin forwarder: passes every argument to mlx_partition_ and returns its result. */
static inline int
mlx_partition(mlx_array *res, const mlx_array a, int kth, const mlx_stream s)
{
    return mlx_partition_(res, a, kth, s);
}
/* Thin forwarder: passes every argument to mlx_power_ and returns its result. */
static inline int
mlx_power(mlx_array *res, const mlx_array a, const mlx_array b, const mlx_stream s)
{
    return mlx_power_(res, a, b, s);
}
/* Thin forwarder: passes every argument to mlx_prod_axes_ and returns its result. */
static inline int
mlx_prod_axes(
    mlx_array *res, const mlx_array a,
    const int *axes, size_t axes_num,
    bool keepdims,
    const mlx_stream s)
{
    return mlx_prod_axes_(res, a, axes, axes_num, keepdims, s);
}
/* Thin forwarder: passes every argument to mlx_prod_axis_ and returns its result. */
static inline int
mlx_prod_axis(
    mlx_array *res, const mlx_array a, int axis, bool keepdims, const mlx_stream s)
{
    return mlx_prod_axis_(res, a, axis, keepdims, s);
}
/* Thin forwarder: passes every argument to mlx_prod_ and returns its result. */
static inline int
mlx_prod(mlx_array *res, const mlx_array a, bool keepdims, const mlx_stream s)
{
    return mlx_prod_(res, a, keepdims, s);
}
/* Thin forwarder: passes every argument to mlx_put_along_axis_ and returns its result. */
static inline int
mlx_put_along_axis(
    mlx_array *res,
    const mlx_array a, const mlx_array indices, const mlx_array values,
    int axis,
    const mlx_stream s)
{
    return mlx_put_along_axis_(res, a, indices, values, axis, s);
}
/* Thin forwarder: passes every argument to mlx_quantize_ and returns its result. */
static inline int
mlx_quantize(
    mlx_vector_array *res,
    const mlx_array w,
    mlx_optional_int group_size, mlx_optional_int bits,
    const char *mode,
    const mlx_stream s)
{
    return mlx_quantize_(res, w, group_size, bits, mode, s);
}
/* Thin forwarder: passes every argument to mlx_quantized_matmul_ and returns its result. */
static inline int
mlx_quantized_matmul(
    mlx_array *res,
    const mlx_array x,
    const mlx_array w,
    const mlx_array scales,
    const mlx_array biases /* may be null */,
    bool transpose,
    mlx_optional_int group_size, mlx_optional_int bits,
    const char *mode,
    const mlx_stream s)
{
    return mlx_quantized_matmul_(
        res, x, w, scales, biases, transpose, group_size, bits, mode, s);
}
/* Thin forwarder: passes every argument to mlx_radians_ and returns its result. */
static inline int
mlx_radians(mlx_array *res, const mlx_array a, const mlx_stream s)
{
    return mlx_radians_(res, a, s);
}
/* Thin forwarder: passes every argument to mlx_real_ and returns its result. */
static inline int
mlx_real(mlx_array *res, const mlx_array a, const mlx_stream s)
{
    return mlx_real_(res, a, s);
}
/* Thin forwarder: passes every argument to mlx_reciprocal_ and returns its result. */
static inline int
mlx_reciprocal(mlx_array *res, const mlx_array a, const mlx_stream s)
{
    return mlx_reciprocal_(res, a, s);
}
/* Thin forwarder: passes every argument to mlx_remainder_ and returns its result. */
static inline int
mlx_remainder(
    mlx_array *res, const mlx_array a, const mlx_array b, const mlx_stream s)
{
    return mlx_remainder_(res, a, b, s);
}
/* Thin forwarder: passes every argument to mlx_repeat_axis_ and returns its result. */
static inline int
mlx_repeat_axis(
    mlx_array *res, const mlx_array arr, int repeats, int axis, const mlx_stream s)
{
    return mlx_repeat_axis_(res, arr, repeats, axis, s);
}
/* Thin forwarder: passes every argument to mlx_repeat_ and returns its result. */
static inline int
mlx_repeat(mlx_array *res, const mlx_array arr, int repeats, const mlx_stream s)
{
    return mlx_repeat_(res, arr, repeats, s);
}
/* Thin forwarder: passes every argument to mlx_reshape_ and returns its result. */
static inline int
mlx_reshape(
    mlx_array *res, const mlx_array a,
    const int *shape, size_t shape_num,
    const mlx_stream s)
{
    return mlx_reshape_(res, a, shape, shape_num, s);
}
/* Thin forwarder: passes every argument to mlx_right_shift_ and returns its result. */
static inline int
mlx_right_shift(
    mlx_array *res, const mlx_array a, const mlx_array b, const mlx_stream s)
{
    return mlx_right_shift_(res, a, b, s);
}
/* Thin forwarder: passes every argument to mlx_roll_axis_ and returns its result. */
static inline int
mlx_roll_axis(
    mlx_array *res, const mlx_array a,
    const int *shift, size_t shift_num,
    int axis,
    const mlx_stream s)
{
    return mlx_roll_axis_(res, a, shift, shift_num, axis, s);
}
/* Thin forwarder: passes every argument to mlx_roll_axes_ and returns its result. */
static inline int
mlx_roll_axes(
    mlx_array *res, const mlx_array a,
    const int *shift, size_t shift_num,
    const int *axes, size_t axes_num,
    const mlx_stream s)
{
    return mlx_roll_axes_(res, a, shift, shift_num, axes, axes_num, s);
}
/* Thin forwarder: passes every argument to mlx_roll_ and returns its result. */
static inline int
mlx_roll(
    mlx_array *res, const mlx_array a,
    const int *shift, size_t shift_num,
    const mlx_stream s)
{
    return mlx_roll_(res, a, shift, shift_num, s);
}
/* Thin forwarder: passes every argument to mlx_round_ and returns its result. */
static inline int
mlx_round(mlx_array *res, const mlx_array a, int decimals, const mlx_stream s)
{
    return mlx_round_(res, a, decimals, s);
}
/* Thin forwarder: passes every argument to mlx_rsqrt_ and returns its result. */
static inline int
mlx_rsqrt(mlx_array *res, const mlx_array a, const mlx_stream s)
{
    return mlx_rsqrt_(res, a, s);
}
/* Thin forwarder: passes every argument to mlx_scatter_ and returns its result. */
static inline int
mlx_scatter(
    mlx_array *res,
    const mlx_array a,
    const mlx_vector_array indices,
    const mlx_array updates,
    const int *axes, size_t axes_num,
    const mlx_stream s)
{
    return mlx_scatter_(res, a, indices, updates, axes, axes_num, s);
}
/* Thin forwarder: passes every argument to mlx_scatter_add_ and returns its result. */
static inline int
mlx_scatter_add(
    mlx_array *res,
    const mlx_array a,
    const mlx_vector_array indices,
    const mlx_array updates,
    const int *axes, size_t axes_num,
    const mlx_stream s)
{
    return mlx_scatter_add_(res, a, indices, updates, axes, axes_num, s);
}
/* Thin forwarder: passes every argument to mlx_scatter_add_axis_ and returns its result. */
static inline int
mlx_scatter_add_axis(
    mlx_array *res,
    const mlx_array a, const mlx_array indices, const mlx_array values,
    int axis,
    const mlx_stream s)
{
    return mlx_scatter_add_axis_(res, a, indices, values, axis, s);
}
/* Thin forwarder: passes every argument to mlx_scatter_max_ and returns its result. */
static inline int
mlx_scatter_max(
    mlx_array *res,
    const mlx_array a,
    const mlx_vector_array indices,
    const mlx_array updates,
    const int *axes, size_t axes_num,
    const mlx_stream s)
{
    return mlx_scatter_max_(res, a, indices, updates, axes, axes_num, s);
}
/* Thin forwarder: passes every argument to mlx_scatter_min_ and returns its result. */
static inline int
mlx_scatter_min(
    mlx_array *res,
    const mlx_array a,
    const mlx_vector_array indices,
    const mlx_array updates,
    const int *axes, size_t axes_num,
    const mlx_stream s)
{
    return mlx_scatter_min_(res, a, indices, updates, axes, axes_num, s);
}
/* Thin forwarder: passes every argument to mlx_scatter_prod_ and returns its result. */
static inline int
mlx_scatter_prod(
    mlx_array *res,
    const mlx_array a,
    const mlx_vector_array indices,
    const mlx_array updates,
    const int *axes, size_t axes_num,
    const mlx_stream s)
{
    return mlx_scatter_prod_(res, a, indices, updates, axes, axes_num, s);
}
/* Thin forwarder: passes every argument to mlx_segmented_mm_ and returns its result. */
static inline int
mlx_segmented_mm(
    mlx_array *res,
    const mlx_array a, const mlx_array b, const mlx_array segments,
    const mlx_stream s)
{
    return mlx_segmented_mm_(res, a, b, segments, s);
}
/* Thin forwarder: passes every argument to mlx_sigmoid_ and returns its result. */
static inline int
mlx_sigmoid(mlx_array *res, const mlx_array a, const mlx_stream s)
{
    return mlx_sigmoid_(res, a, s);
}
/* Thin forwarder: passes every argument to mlx_sign_ and returns its result. */
static inline int
mlx_sign(mlx_array *res, const mlx_array a, const mlx_stream s)
{
    return mlx_sign_(res, a, s);
}
/* Thin forwarder: passes every argument to mlx_sin_ and returns its result. */
static inline int
mlx_sin(mlx_array *res, const mlx_array a, const mlx_stream s)
{
    return mlx_sin_(res, a, s);
}
/* Thin forwarder: passes every argument to mlx_sinh_ and returns its result. */
static inline int
mlx_sinh(mlx_array *res, const mlx_array a, const mlx_stream s)
{
    return mlx_sinh_(res, a, s);
}
/* Thin forwarder: passes every argument to mlx_slice_ and returns its result. */
static inline int
mlx_slice(
    mlx_array *res,
    const mlx_array a,
    const int *start, size_t start_num,
    const int *stop, size_t stop_num,
    const int *strides, size_t strides_num,
    const mlx_stream s)
{
    return mlx_slice_(
        res, a, start, start_num, stop, stop_num, strides, strides_num, s);
}
/* Thin forwarder: passes every argument to mlx_slice_dynamic_ and returns its result. */
static inline int
mlx_slice_dynamic(
    mlx_array *res,
    const mlx_array a,
    const mlx_array start,
    const int *axes, size_t axes_num,
    const int *slice_size, size_t slice_size_num,
    const mlx_stream s)
{
    return mlx_slice_dynamic_(
        res, a, start, axes, axes_num, slice_size, slice_size_num, s);
}
/* Thin forwarder: passes every argument to mlx_slice_update_ and returns its result. */
static inline int
mlx_slice_update(
    mlx_array *res,
    const mlx_array src,
    const mlx_array update,
    const int *start, size_t start_num,
    const int *stop, size_t stop_num,
    const int *strides, size_t strides_num,
    const mlx_stream s)
{
    return mlx_slice_update_(
        res, src, update,
        start, start_num, stop, stop_num, strides, strides_num, s);
}
/* Thin forwarder: passes every argument to mlx_slice_update_dynamic_ and returns its result. */
static inline int
mlx_slice_update_dynamic(
    mlx_array *res,
    const mlx_array src,
    const mlx_array update,
    const mlx_array start,
    const int *axes, size_t axes_num,
    const mlx_stream s)
{
    return mlx_slice_update_dynamic_(res, src, update, start, axes, axes_num, s);
}
/* Thin forwarder: passes every argument to mlx_softmax_axes_ and returns its result. */
static inline int
mlx_softmax_axes(
    mlx_array *res, const mlx_array a,
    const int *axes, size_t axes_num,
    bool precise,
    const mlx_stream s)
{
    return mlx_softmax_axes_(res, a, axes, axes_num, precise, s);
}
/* Thin forwarder: passes every argument to mlx_softmax_axis_ and returns its result. */
static inline int
mlx_softmax_axis(
    mlx_array *res, const mlx_array a, int axis, bool precise, const mlx_stream s)
{
    return mlx_softmax_axis_(res, a, axis, precise, s);
}
/* Thin forwarder: passes every argument to mlx_softmax_ and returns its result. */
static inline int
mlx_softmax(mlx_array *res, const mlx_array a, bool precise, const mlx_stream s)
{
    return mlx_softmax_(res, a, precise, s);
}
/* Thin forwarder: passes every argument to mlx_sort_axis_ and returns its result. */
static inline int
mlx_sort_axis(mlx_array *res, const mlx_array a, int axis, const mlx_stream s)
{
    return mlx_sort_axis_(res, a, axis, s);
}
/* Thin forwarder: passes every argument to mlx_sort_ and returns its result. */
static inline int
mlx_sort(mlx_array *res, const mlx_array a, const mlx_stream s)
{
    return mlx_sort_(res, a, s);
}
/* Thin forwarder: passes every argument to mlx_split_ and returns its result. */
static inline int
mlx_split(
    mlx_vector_array *res, const mlx_array a,
    int num_splits, int axis,
    const mlx_stream s)
{
    return mlx_split_(res, a, num_splits, axis, s);
}
/* Thin forwarder: passes every argument to mlx_split_sections_ and returns its result. */
static inline int
mlx_split_sections(
    mlx_vector_array *res, const mlx_array a,
    const int *indices, size_t indices_num,
    int axis,
    const mlx_stream s)
{
    return mlx_split_sections_(res, a, indices, indices_num, axis, s);
}
/* Thin forwarder: passes every argument to mlx_sqrt_ and returns its result. */
static inline int
mlx_sqrt(mlx_array *res, const mlx_array a, const mlx_stream s)
{
    return mlx_sqrt_(res, a, s);
}
/* Thin forwarder: passes every argument to mlx_square_ and returns its result. */
static inline int
mlx_square(mlx_array *res, const mlx_array a, const mlx_stream s)
{
    return mlx_square_(res, a, s);
}
/* Thin forwarder: passes every argument to mlx_squeeze_axes_ and returns its result. */
static inline int
mlx_squeeze_axes(
    mlx_array *res, const mlx_array a,
    const int *axes, size_t axes_num,
    const mlx_stream s)
{
    return mlx_squeeze_axes_(res, a, axes, axes_num, s);
}
/* Thin forwarder: passes every argument to mlx_squeeze_axis_ and returns its result. */
static inline int
mlx_squeeze_axis(mlx_array *res, const mlx_array a, int axis, const mlx_stream s)
{
    return mlx_squeeze_axis_(res, a, axis, s);
}
/* Thin forwarder: passes every argument to mlx_squeeze_ and returns its result. */
static inline int
mlx_squeeze(mlx_array *res, const mlx_array a, const mlx_stream s)
{
    return mlx_squeeze_(res, a, s);
}
/* Thin forwarder: passes every argument to mlx_stack_axis_ and returns its result. */
static inline int
mlx_stack_axis(
    mlx_array *res, const mlx_vector_array arrays, int axis, const mlx_stream s)
{
    return mlx_stack_axis_(res, arrays, axis, s);
}
/* Thin forwarder: passes every argument to mlx_stack_ and returns its result. */
static inline int
mlx_stack(mlx_array *res, const mlx_vector_array arrays, const mlx_stream s)
{
    return mlx_stack_(res, arrays, s);
}
/* Thin forwarder: passes every argument to mlx_std_axes_ and returns its result. */
static inline int
mlx_std_axes(
    mlx_array *res, const mlx_array a,
    const int *axes, size_t axes_num,
    bool keepdims, int ddof,
    const mlx_stream s)
{
    return mlx_std_axes_(res, a, axes, axes_num, keepdims, ddof, s);
}
/* Thin forwarder: passes every argument to mlx_std_axis_ and returns its result. */
static inline int
mlx_std_axis(
    mlx_array *res, const mlx_array a,
    int axis, bool keepdims, int ddof,
    const mlx_stream s)
{
    return mlx_std_axis_(res, a, axis, keepdims, ddof, s);
}
/* Thin forwarder: passes every argument to mlx_std_ and returns its result. */
static inline int
mlx_std(
    mlx_array *res, const mlx_array a, bool keepdims, int ddof, const mlx_stream s)
{
    return mlx_std_(res, a, keepdims, ddof, s);
}
/* Thin forwarder: passes every argument to mlx_stop_gradient_ and returns its result. */
static inline int
mlx_stop_gradient(mlx_array *res, const mlx_array a, const mlx_stream s)
{
    return mlx_stop_gradient_(res, a, s);
}
/* Thin forwarder: passes every argument to mlx_subtract_ and returns its result. */
static inline int
mlx_subtract(mlx_array *res, const mlx_array a, const mlx_array b, const mlx_stream s)
{
    return mlx_subtract_(res, a, b, s);
}
/* Thin forwarder: passes every argument to mlx_sum_axes_ and returns its result. */
static inline int
mlx_sum_axes(
    mlx_array *res, const mlx_array a,
    const int *axes, size_t axes_num,
    bool keepdims,
    const mlx_stream s)
{
    return mlx_sum_axes_(res, a, axes, axes_num, keepdims, s);
}
/* Thin forwarder: passes every argument to mlx_sum_axis_ and returns its result. */
static inline int
mlx_sum_axis(
    mlx_array *res, const mlx_array a, int axis, bool keepdims, const mlx_stream s)
{
    return mlx_sum_axis_(res, a, axis, keepdims, s);
}
/* Thin forwarder: passes every argument to mlx_sum_ and returns its result. */
static inline int
mlx_sum(mlx_array *res, const mlx_array a, bool keepdims, const mlx_stream s)
{
    return mlx_sum_(res, a, keepdims, s);
}
/* Thin forwarder: passes every argument to mlx_swapaxes_ and returns its result. */
static inline int
mlx_swapaxes(
    mlx_array *res, const mlx_array a, int axis1, int axis2, const mlx_stream s)
{
    return mlx_swapaxes_(res, a, axis1, axis2, s);
}
/* Thin forwarder: passes every argument to mlx_take_axis_ and returns its result. */
static inline int
mlx_take_axis(
    mlx_array *res, const mlx_array a, const mlx_array indices,
    int axis,
    const mlx_stream s)
{
    return mlx_take_axis_(res, a, indices, axis, s);
}
/* Thin forwarder: passes every argument to mlx_take_ and returns its result. */
static inline int
mlx_take(
    mlx_array *res, const mlx_array a, const mlx_array indices, const mlx_stream s)
{
    return mlx_take_(res, a, indices, s);
}
/* Thin forwarder: passes every argument to mlx_take_along_axis_ and returns its result. */
static inline int
mlx_take_along_axis(
    mlx_array *res, const mlx_array a, const mlx_array indices,
    int axis,
    const mlx_stream s)
{
    return mlx_take_along_axis_(res, a, indices, axis, s);
}
/* Thin forwarder: passes every argument to mlx_tan_ and returns its result. */
static inline int
mlx_tan(mlx_array *res, const mlx_array a, const mlx_stream s)
{
    return mlx_tan_(res, a, s);
}
/* Thin forwarder: passes every argument to mlx_tanh_ and returns its result. */
static inline int
mlx_tanh(mlx_array *res, const mlx_array a, const mlx_stream s)
{
    return mlx_tanh_(res, a, s);
}
/* Thin forwarder: passes every argument to mlx_tensordot_ and returns its result. */
static inline int
mlx_tensordot(
    mlx_array *res,
    const mlx_array a, const mlx_array b,
    const int *axes_a, size_t axes_a_num,
    const int *axes_b, size_t axes_b_num,
    const mlx_stream s)
{
    return mlx_tensordot_(res, a, b, axes_a, axes_a_num, axes_b, axes_b_num, s);
}
/* Thin forwarder: passes every argument to mlx_tensordot_axis_ and returns its result. */
static inline int
mlx_tensordot_axis(
    mlx_array *res, const mlx_array a, const mlx_array b,
    int axis,
    const mlx_stream s)
{
    return mlx_tensordot_axis_(res, a, b, axis, s);
}
/* Thin forwarder: passes every argument to mlx_tile_ and returns its result. */
static inline int
mlx_tile(
    mlx_array *res, const mlx_array arr,
    const int *reps, size_t reps_num,
    const mlx_stream s)
{
    return mlx_tile_(res, arr, reps, reps_num, s);
}
/* Thin forwarder: passes every argument to mlx_to_fp8_ and returns its result. */
static inline int
mlx_to_fp8(mlx_array *res, const mlx_array x, const mlx_stream s)
{
    return mlx_to_fp8_(res, x, s);
}
/* Thin forwarder: passes every argument to mlx_topk_axis_ and returns its result. */
static inline int
mlx_topk_axis(
    mlx_array *res, const mlx_array a, int k, int axis, const mlx_stream s)
{
    return mlx_topk_axis_(res, a, k, axis, s);
}
/* Thin forwarder: passes every argument to mlx_topk_ and returns its result. */
static inline int
mlx_topk(mlx_array *res, const mlx_array a, int k, const mlx_stream s)
{
    return mlx_topk_(res, a, k, s);
}
/* Thin forwarder: passes every argument to mlx_trace_ and returns its result. */
static inline int
mlx_trace(
    mlx_array *res, const mlx_array a,
    int offset, int axis1, int axis2,
    mlx_dtype dtype,
    const mlx_stream s)
{
    return mlx_trace_(res, a, offset, axis1, axis2, dtype, s);
}
/* Thin forwarder: passes every argument to mlx_transpose_axes_ and returns its result. */
static inline int
mlx_transpose_axes(
    mlx_array *res, const mlx_array a,
    const int *axes, size_t axes_num,
    const mlx_stream s)
{
    return mlx_transpose_axes_(res, a, axes, axes_num, s);
}
/* Thin forwarder: passes every argument to mlx_transpose_ and returns its result. */
static inline int
mlx_transpose(mlx_array *res, const mlx_array a, const mlx_stream s)
{
    return mlx_transpose_(res, a, s);
}
/* Thin forwarder: passes every argument to mlx_tri_ and returns its result. */
static inline int
mlx_tri(
    mlx_array *res,
    int n, int m, int k,
    mlx_dtype type,
    const mlx_stream s)
{
    return mlx_tri_(res, n, m, k, type, s);
}
/* Thin forwarder: passes every argument to mlx_tril_ and returns its result. */
static inline int
mlx_tril(mlx_array *res, const mlx_array x, int k, const mlx_stream s)
{
    return mlx_tril_(res, x, k, s);
}
/* Thin forwarder: passes every argument to mlx_triu_ and returns its result. */
static inline int
mlx_triu(mlx_array *res, const mlx_array x, int k, const mlx_stream s)
{
    return mlx_triu_(res, x, k, s);
}
/* Thin forwarder: passes every argument to mlx_unflatten_ and returns its result. */
static inline int
mlx_unflatten(
    mlx_array *res, const mlx_array a,
    int axis,
    const int *shape, size_t shape_num,
    const mlx_stream s)
{
    return mlx_unflatten_(res, a, axis, shape, shape_num, s);
}
/* Thin forwarder: passes every argument to mlx_var_axes_ and returns its result. */
static inline int
mlx_var_axes(
    mlx_array *res, const mlx_array a,
    const int *axes, size_t axes_num,
    bool keepdims, int ddof,
    const mlx_stream s)
{
    return mlx_var_axes_(res, a, axes, axes_num, keepdims, ddof, s);
}
/* Thin forwarder: passes every argument to mlx_var_axis_ and returns its result. */
static inline int
mlx_var_axis(
    mlx_array *res, const mlx_array a,
    int axis, bool keepdims, int ddof,
    const mlx_stream s)
{
    return mlx_var_axis_(res, a, axis, keepdims, ddof, s);
}
/* Thin forwarder: passes every argument to mlx_var_ and returns its result. */
static inline int
mlx_var(
    mlx_array *res, const mlx_array a, bool keepdims, int ddof, const mlx_stream s)
{
    return mlx_var_(res, a, keepdims, ddof, s);
}
/* Thin forwarder: passes every argument to mlx_view_ and returns its result. */
static inline int
mlx_view(mlx_array *res, const mlx_array a, mlx_dtype dtype, const mlx_stream s)
{
    return mlx_view_(res, a, dtype, s);
}
/* Thin forwarder: passes every argument to mlx_where_ and returns its result. */
static inline int
mlx_where(
    mlx_array *res,
    const mlx_array condition, const mlx_array x, const mlx_array y,
    const mlx_stream s)
{
    return mlx_where_(res, condition, x, y, s);
}
/* Thin forwarder: passes every argument to mlx_zeros_ and returns its result. */
static inline int
mlx_zeros(
    mlx_array *res,
    const int *shape, size_t shape_num,
    mlx_dtype dtype,
    const mlx_stream s)
{
    return mlx_zeros_(res, shape, shape_num, dtype, s);
}
/* Thin forwarder: passes every argument to mlx_zeros_like_ and returns its result. */
static inline int
mlx_zeros_like(mlx_array *res, const mlx_array a, const mlx_stream s)
{
    return mlx_zeros_like_(res, a, s);
}
/* Thin forwarder: passes every argument to mlx_random_bernoulli_ and returns its result. */
static inline int
mlx_random_bernoulli(
    mlx_array *res,
    const mlx_array p,
    const int *shape, size_t shape_num,
    const mlx_array key /* may be null */,
    const mlx_stream s)
{
    return mlx_random_bernoulli_(res, p, shape, shape_num, key, s);
}
/* Thin forwarder: passes every argument to mlx_random_bits_ and returns its result. */
static inline int
mlx_random_bits(
    mlx_array *res,
    const int *shape, size_t shape_num,
    int width,
    const mlx_array key /* may be null */,
    const mlx_stream s)
{
    return mlx_random_bits_(res, shape, shape_num, width, key, s);
}
/* Thin forwarder: passes every argument to mlx_random_categorical_shape_ and returns its result. */
static inline int
mlx_random_categorical_shape(
    mlx_array *res,
    const mlx_array logits,
    int axis,
    const int *shape, size_t shape_num,
    const mlx_array key /* may be null */,
    const mlx_stream s)
{
    return mlx_random_categorical_shape_(res, logits, axis, shape, shape_num, key, s);
}
/* Thin forwarder: passes every argument to mlx_random_categorical_num_samples_ and returns its result. */
static inline int
mlx_random_categorical_num_samples(
    mlx_array *res,
    const mlx_array logits_,
    int axis, int num_samples,
    const mlx_array key /* may be null */,
    const mlx_stream s)
{
    return mlx_random_categorical_num_samples_(
        res, logits_, axis, num_samples, key, s);
}
/* Thin forwarder: passes every argument to mlx_random_categorical_ and returns its result. */
static inline int
mlx_random_categorical(
    mlx_array *res,
    const mlx_array logits,
    int axis,
    const mlx_array key /* may be null */,
    const mlx_stream s)
{
    return mlx_random_categorical_(res, logits, axis, key, s);
}
/* Thin forwarder: passes every argument to mlx_random_gumbel_ and returns its result. */
static inline int
mlx_random_gumbel(
    mlx_array *res,
    const int *shape, size_t shape_num,
    mlx_dtype dtype,
    const mlx_array key /* may be null */,
    const mlx_stream s)
{
    return mlx_random_gumbel_(res, shape, shape_num, dtype, key, s);
}
/* Thin forwarder: passes every argument to mlx_random_key_ and returns its result. */
static inline int
mlx_random_key(mlx_array *res, uint64_t seed)
{
    return mlx_random_key_(res, seed);
}
/* Thin forwarder: passes every argument to mlx_random_laplace_ and returns its result. */
static inline int
mlx_random_laplace(
    mlx_array *res,
    const int *shape, size_t shape_num,
    mlx_dtype dtype,
    float loc, float scale,
    const mlx_array key /* may be null */,
    const mlx_stream s)
{
    return mlx_random_laplace_(res, shape, shape_num, dtype, loc, scale, key, s);
}
/* Thin forwarder: passes every argument to mlx_random_multivariate_normal_ and returns its result. */
static inline int
mlx_random_multivariate_normal(
    mlx_array *res,
    const mlx_array mean, const mlx_array cov,
    const int *shape, size_t shape_num,
    mlx_dtype dtype,
    const mlx_array key /* may be null */,
    const mlx_stream s)
{
    return mlx_random_multivariate_normal_(
        res, mean, cov, shape, shape_num, dtype, key, s);
}
/* Thin forwarder: passes every argument to mlx_random_normal_broadcast_ and returns its result. */
static inline int
mlx_random_normal_broadcast(
    mlx_array *res,
    const int *shape, size_t shape_num,
    mlx_dtype dtype,
    const mlx_array loc /* may be null */,
    const mlx_array scale /* may be null */,
    const mlx_array key /* may be null */,
    const mlx_stream s)
{
    return mlx_random_normal_broadcast_(
        res, shape, shape_num, dtype, loc, scale, key, s);
}
/* Thin forwarder: passes every argument to mlx_random_normal_ and returns its result. */
static inline int
mlx_random_normal(
    mlx_array *res,
    const int *shape, size_t shape_num,
    mlx_dtype dtype,
    float loc, float scale,
    const mlx_array key /* may be null */,
    const mlx_stream s)
{
    return mlx_random_normal_(res, shape, shape_num, dtype, loc, scale, key, s);
}
/* Thin forwarder: passes every argument to mlx_random_permutation_ and returns its result. */
static inline int
mlx_random_permutation(
    mlx_array *res,
    const mlx_array x,
    int axis,
    const mlx_array key /* may be null */,
    const mlx_stream s)
{
    return mlx_random_permutation_(res, x, axis, key, s);
}
// Thin forwarder to mlx_random_permutation_arange_. Here `x` is a plain int,
// not an array. `key` may be null. Returns the callee's int result unchanged.
static inline int mlx_random_permutation_arange(
    mlx_array* res,
    int x,
    const mlx_array key /* may be null */,
    const mlx_stream s) {
  const int ret = mlx_random_permutation_arange_(res, x, key, s);
  return ret;
}
// Thin forwarder to mlx_random_randint_. The `low`/`high` bound arrays, the
// shape and the dtype pass through untouched; `key` may be null. Returns the
// callee's int result as-is.
static inline int mlx_random_randint(
    mlx_array* res,
    const mlx_array low,
    const mlx_array high,
    const int* shape,
    size_t shape_num,
    mlx_dtype dtype,
    const mlx_array key /* may be null */,
    const mlx_stream s) {
  const int ret = mlx_random_randint_(
      res, low, high, shape, shape_num, dtype, key, s);
  return ret;
}
// Thin forwarder to mlx_random_seed_. Returns the callee's int result
// unchanged.
static inline int mlx_random_seed(uint64_t seed) {
  const int ret = mlx_random_seed_(seed);
  return ret;
}
// Thin forwarder to mlx_random_split_num_. The single output `res`, the `key`
// and the requested count `num` are passed through. Returns the callee's int
// result verbatim.
static inline int mlx_random_split_num(
    mlx_array* res,
    const mlx_array key,
    int num,
    const mlx_stream s) {
  const int ret = mlx_random_split_num_(res, key, num, s);
  return ret;
}
// Thin forwarder to mlx_random_split_. Both output pointers (`res_0`, `res_1`)
// are handed to the callee unchanged. Returns its int result as-is.
static inline int mlx_random_split(
    mlx_array* res_0,
    mlx_array* res_1,
    const mlx_array key,
    const mlx_stream s) {
  const int ret = mlx_random_split_(res_0, res_1, key, s);
  return ret;
}
// Thin forwarder to mlx_random_truncated_normal_. The `lower`/`upper` bound
// arrays, the shape and the dtype go through unchanged; `key` may be null.
// Returns the callee's int result verbatim.
static inline int mlx_random_truncated_normal(
    mlx_array* res,
    const mlx_array lower,
    const mlx_array upper,
    const int* shape,
    size_t shape_num,
    mlx_dtype dtype,
    const mlx_array key /* may be null */,
    const mlx_stream s) {
  const int ret = mlx_random_truncated_normal_(
      res, lower, upper, shape, shape_num, dtype, key, s);
  return ret;
}
// Thin forwarder to mlx_random_uniform_. The `low`/`high` arrays, the shape
// and the dtype pass through untouched; `key` may be null. Returns the
// callee's int result as-is.
static inline int mlx_random_uniform(
    mlx_array* res,
    const mlx_array low,
    const mlx_array high,
    const int* shape,
    size_t shape_num,
    mlx_dtype dtype,
    const mlx_array key /* may be null */,
    const mlx_stream s) {
  const int ret = mlx_random_uniform_(
      res, low, high, shape, shape_num, dtype, key, s);
  return ret;
}
// Thin forwarder: returns, by value, the mlx_stream produced by
// mlx_stream_new_().
static inline mlx_stream mlx_stream_new(void) {
  mlx_stream created = mlx_stream_new_();
  return created;
}
// Thin forwarder: passes `dev` to mlx_stream_new_device_ and returns the
// resulting mlx_stream by value.
static inline mlx_stream mlx_stream_new_device(mlx_device dev) {
  mlx_stream created = mlx_stream_new_device_(dev);
  return created;
}
// Thin forwarder to mlx_stream_set_: hands over the destination pointer and
// `src` unchanged. Returns the callee's int result verbatim.
static inline int mlx_stream_set(mlx_stream* stream, const mlx_stream src) {
  const int ret = mlx_stream_set_(stream, src);
  return ret;
}
// Thin forwarder: passes `stream` to mlx_stream_free_ and returns its int
// result as-is.
static inline int mlx_stream_free(mlx_stream stream) {
  const int ret = mlx_stream_free_(stream);
  return ret;
}
// Thin forwarder to mlx_stream_tostring_: the output `str` pointer and
// `stream` go through untouched. Returns the callee's int result unchanged.
static inline int mlx_stream_tostring(mlx_string* str, mlx_stream stream) {
  const int ret = mlx_stream_tostring_(str, stream);
  return ret;
}
// Thin forwarder: returns the bool that mlx_stream_equal_ reports for the
// pair (lhs, rhs).
static inline bool mlx_stream_equal(mlx_stream lhs, mlx_stream rhs) {
  const bool same = mlx_stream_equal_(lhs, rhs);
  return same;
}
// Thin forwarder to mlx_stream_get_device_: the output `dev` pointer and
// `stream` are passed through. Returns the callee's int result as-is.
static inline int mlx_stream_get_device(mlx_device* dev, mlx_stream stream) {
  const int ret = mlx_stream_get_device_(dev, stream);
  return ret;
}
// Thin forwarder to mlx_stream_get_index_: the output `index` pointer and
// `stream` go through unchanged. Returns the callee's int result verbatim.
static inline int mlx_stream_get_index(int* index, mlx_stream stream) {
  const int ret = mlx_stream_get_index_(index, stream);
  return ret;
}
// Thin forwarder: passes `stream` to mlx_synchronize_ and returns its int
// result unchanged.
static inline int mlx_synchronize(mlx_stream stream) {
  const int ret = mlx_synchronize_(stream);
  return ret;
}
// Thin forwarder to mlx_get_default_stream_: the output `stream` pointer and
// `dev` are handed over as-is. Returns the callee's int result verbatim.
static inline int mlx_get_default_stream(mlx_stream* stream, mlx_device dev) {
  const int ret = mlx_get_default_stream_(stream, dev);
  return ret;
}
// Thin forwarder: passes `stream` to mlx_set_default_stream_ and returns its
// int result unchanged.
static inline int mlx_set_default_stream(mlx_stream stream) {
  const int ret = mlx_set_default_stream_(stream);
  return ret;
}
// Thin forwarder: returns, by value, the mlx_stream produced by
// mlx_default_cpu_stream_new_().
static inline mlx_stream mlx_default_cpu_stream_new(void) {
  mlx_stream created = mlx_default_cpu_stream_new_();
  return created;
}
// Thin forwarder: returns, by value, the mlx_stream produced by
// mlx_default_gpu_stream_new_().
static inline mlx_stream mlx_default_gpu_stream_new(void) {
  mlx_stream created = mlx_default_gpu_stream_new_();
  return created;
}
// Thin forwarder: returns, by value, the mlx_string produced by
// mlx_string_new_().
static inline mlx_string mlx_string_new(void) {
  mlx_string created = mlx_string_new_();
  return created;
}
// Thin forwarder: passes the C string `str` to mlx_string_new_data_ and
// returns the resulting mlx_string by value.
static inline mlx_string mlx_string_new_data(const char* str) {
  mlx_string created = mlx_string_new_data_(str);
  return created;
}
// Thin forwarder to mlx_string_set_: the destination pointer and `src` go
// through unchanged. Returns the callee's int result as-is.
static inline int mlx_string_set(mlx_string* str, const mlx_string src) {
  const int ret = mlx_string_set_(str, src);
  return ret;
}
// Thin forwarder: returns the `const char *` that mlx_string_data_ yields for
// `str`. The pointer is returned as-is, without copying.
static inline const char * mlx_string_data(mlx_string str) {
  const char *text = mlx_string_data_(str);
  return text;
}
// Thin forwarder: passes `str` to mlx_string_free_ and returns its int result
// unchanged.
static inline int mlx_string_free(mlx_string str) {
  const int ret = mlx_string_free_(str);
  return ret;
}
// Thin forwarder to mlx_detail_vmap_replace_. The three vector_array inputs
// and both (pointer, count) axis lists pass through in their original order.
// Returns the callee's int result verbatim.
static inline int mlx_detail_vmap_replace(
    mlx_vector_array* res,
    const mlx_vector_array inputs,
    const mlx_vector_array s_inputs,
    const mlx_vector_array s_outputs,
    const int* in_axes,
    size_t in_axes_num,
    const int* out_axes,
    size_t out_axes_num) {
  const int ret = mlx_detail_vmap_replace_(
      res,
      inputs, s_inputs, s_outputs,
      in_axes, in_axes_num,
      out_axes, out_axes_num);
  return ret;
}
// Thin forwarder to mlx_detail_vmap_trace_. Both output vectors, the closure,
// the inputs and the (pointer, count) axis list are passed through unchanged.
// Returns the callee's int result as-is.
static inline int mlx_detail_vmap_trace(
    mlx_vector_array* res_0,
    mlx_vector_array* res_1,
    const mlx_closure fun,
    const mlx_vector_array inputs,
    const int* in_axes,
    size_t in_axes_num) {
  const int ret = mlx_detail_vmap_trace_(
      res_0, res_1, fun, inputs, in_axes, in_axes_num);
  return ret;
}
// Thin forwarder: passes `outputs` to mlx_async_eval_ and returns its int
// result unchanged.
static inline int mlx_async_eval(const mlx_vector_array outputs) {
  const int ret = mlx_async_eval_(outputs);
  return ret;
}
// Thin forwarder to mlx_checkpoint_: the output closure pointer and `fun` go
// through untouched. Returns the callee's int result verbatim.
static inline int mlx_checkpoint(mlx_closure* res, const mlx_closure fun) {
  const int ret = mlx_checkpoint_(res, fun);
  return ret;
}
// Thin forwarder to mlx_custom_function_. Each of the three optional closures
// (vjp, jvp, vmap) may be null and is passed along as given. Returns the
// callee's int result as-is.
static inline int mlx_custom_function(
    mlx_closure* res,
    const mlx_closure fun,
    const mlx_closure_custom fun_vjp /* may be null */,
    const mlx_closure_custom_jvp fun_jvp /* may be null */,
    const mlx_closure_custom_vmap fun_vmap /* may be null */) {
  const int ret = mlx_custom_function_(
      res, fun, fun_vjp, fun_jvp, fun_vmap);
  return ret;
}
// Thin forwarder to mlx_custom_vjp_. Here `fun_vjp` carries no "may be null"
// annotation, unlike in mlx_custom_function. Returns the callee's int result
// unchanged.
static inline int mlx_custom_vjp(
    mlx_closure* res,
    const mlx_closure fun,
    const mlx_closure_custom fun_vjp) {
  const int ret = mlx_custom_vjp_(res, fun, fun_vjp);
  return ret;
}
// Thin forwarder: passes `outputs` to mlx_eval_ and returns its int result
// verbatim.
static inline int mlx_eval(const mlx_vector_array outputs) {
  const int ret = mlx_eval_(outputs);
  return ret;
}
// Thin forwarder to mlx_jvp_. Both output vectors, the closure and the
// primals/tangents vectors are passed through unchanged. Returns the callee's
// int result as-is.
static inline int mlx_jvp(
    mlx_vector_array* res_0,
    mlx_vector_array* res_1,
    const mlx_closure fun,
    const mlx_vector_array primals,
    const mlx_vector_array tangents) {
  const int ret = mlx_jvp_(
      res_0, res_1, fun, primals, tangents);
  return ret;
}
// Thin forwarder to mlx_value_and_grad_. The output closure pointer, `fun`
// and the (pointer, count) argnums list go through untouched. Returns the
// callee's int result verbatim.
static inline int mlx_value_and_grad(
    mlx_closure_value_and_grad* res,
    const mlx_closure fun,
    const int* argnums,
    size_t argnums_num) {
  const int ret = mlx_value_and_grad_(
      res, fun, argnums, argnums_num);
  return ret;
}
// Thin forwarder to mlx_vjp_. Both output vectors, the closure and the
// primals/cotangents vectors are passed through unchanged. Returns the
// callee's int result as-is.
static inline int mlx_vjp(
    mlx_vector_array* res_0,
    mlx_vector_array* res_1,
    const mlx_closure fun,
    const mlx_vector_array primals,
    const mlx_vector_array cotangents) {
  const int ret = mlx_vjp_(
      res_0, res_1, fun, primals, cotangents);
  return ret;
}
// Thin forwarder: returns, by value, the mlx_vector_array produced by
// mlx_vector_array_new_().
static inline mlx_vector_array mlx_vector_array_new(void) {
  mlx_vector_array created = mlx_vector_array_new_();
  return created;
}
// Thin forwarder to mlx_vector_array_set_: the destination pointer and `src`
// are handed over as-is. Returns the callee's int result unchanged.
static inline int mlx_vector_array_set(mlx_vector_array* vec, const mlx_vector_array src) {
  const int ret = mlx_vector_array_set_(vec, src);
  return ret;
}
// Thin forwarder: passes `vec` to mlx_vector_array_free_ and returns its int
// result verbatim.
static inline int mlx_vector_array_free(mlx_vector_array vec) {
  const int ret = mlx_vector_array_free_(vec);
  return ret;
}
// Thin forwarder: passes the (data, size) pair to mlx_vector_array_new_data_
// and returns the resulting mlx_vector_array by value.
static inline mlx_vector_array mlx_vector_array_new_data(const mlx_array* data, size_t size) {
  mlx_vector_array created = mlx_vector_array_new_data_(data, size);
  return created;
}
// Thin forwarder: passes `val` to mlx_vector_array_new_value_ and returns the
// resulting mlx_vector_array by value.
static inline mlx_vector_array mlx_vector_array_new_value(const mlx_array val) {
  mlx_vector_array created = mlx_vector_array_new_value_(val);
  return created;
}
// Thin forwarder to mlx_vector_array_set_data_. The destination pointer and
// the (data, size) pair pass through unchanged. Returns the callee's int
// result as-is.
static inline int mlx_vector_array_set_data(
    mlx_vector_array* vec,
    const mlx_array* data,
    size_t size) {
  const int ret = mlx_vector_array_set_data_(vec, data, size);
  return ret;
}
// Thin forwarder to mlx_vector_array_set_value_: the destination pointer and
// `val` go through untouched. Returns the callee's int result verbatim.
static inline int mlx_vector_array_set_value(mlx_vector_array* vec, const mlx_array val) {
  const int ret = mlx_vector_array_set_value_(vec, val);
  return ret;
}
// Thin forwarder to mlx_vector_array_append_data_. Note that `vec` is taken
// by value here, not by pointer. Returns the callee's int result unchanged.
static inline int mlx_vector_array_append_data(
    mlx_vector_array vec,
    const mlx_array* data,
    size_t size) {
  const int ret = mlx_vector_array_append_data_(vec, data, size);
  return ret;
}
// Thin forwarder to mlx_vector_array_append_value_. `vec` is passed by value.
// Returns the callee's int result as-is.
static inline int mlx_vector_array_append_value(mlx_vector_array vec, const mlx_array val) {
  const int ret = mlx_vector_array_append_value_(vec, val);
  return ret;
}
// Thin forwarder: returns the size_t that mlx_vector_array_size_ reports for
// `vec`.
static inline size_t mlx_vector_array_size(mlx_vector_array vec) {
  const size_t count = mlx_vector_array_size_(vec);
  return count;
}
// Thin forwarder to mlx_vector_array_get_. The output pointer, `vec` and
// `idx` are handed over unchanged; no bounds check is done here. Returns the
// callee's int result verbatim.
static inline int mlx_vector_array_get(
    mlx_array* res,
    const mlx_vector_array vec,
    size_t idx) {
  const int ret = mlx_vector_array_get_(res, vec, idx);
  return ret;
}
// Thin forwarder: returns, by value, the mlx_vector_vector_array produced by
// mlx_vector_vector_array_new_().
static inline mlx_vector_vector_array mlx_vector_vector_array_new(void) {
  mlx_vector_vector_array created = mlx_vector_vector_array_new_();
  return created;
}
// Thin forwarder to mlx_vector_vector_array_set_: the destination pointer and
// `src` are passed through. Returns the callee's int result as-is.
static inline int mlx_vector_vector_array_set(
    mlx_vector_vector_array* vec,
    const mlx_vector_vector_array src) {
  const int ret = mlx_vector_vector_array_set_(vec, src);
  return ret;
}
// Thin forwarder: passes `vec` to mlx_vector_vector_array_free_ and returns
// its int result unchanged.
static inline int mlx_vector_vector_array_free(mlx_vector_vector_array vec) {
  const int ret = mlx_vector_vector_array_free_(vec);
  return ret;
}
// Thin forwarder: passes the (data, size) pair to
// mlx_vector_vector_array_new_data_ and returns the result by value.
static inline mlx_vector_vector_array mlx_vector_vector_array_new_data(
    const mlx_vector_array* data,
    size_t size) {
  mlx_vector_vector_array created = mlx_vector_vector_array_new_data_(data, size);
  return created;
}
// Thin forwarder: passes `val` to mlx_vector_vector_array_new_value_ and
// returns the result by value.
static inline mlx_vector_vector_array mlx_vector_vector_array_new_value(
    const mlx_vector_array val) {
  mlx_vector_vector_array created = mlx_vector_vector_array_new_value_(val);
  return created;
}
// Thin forwarder to mlx_vector_vector_array_set_data_. The destination
// pointer and the (data, size) pair go through unchanged. Returns the
// callee's int result verbatim.
static inline int mlx_vector_vector_array_set_data(
    mlx_vector_vector_array* vec,
    const mlx_vector_array* data,
    size_t size) {
  const int ret = mlx_vector_vector_array_set_data_(vec, data, size);
  return ret;
}
// Thin forwarder to mlx_vector_vector_array_set_value_: the destination
// pointer and `val` are handed over as-is. Returns the callee's int result
// unchanged.
static inline int mlx_vector_vector_array_set_value(
    mlx_vector_vector_array* vec,
    const mlx_vector_array val) {
  const int ret = mlx_vector_vector_array_set_value_(vec, val);
  return ret;
}
// Thin forwarder to mlx_vector_vector_array_append_data_. `vec` is taken by
// value. Returns the callee's int result as-is.
static inline int mlx_vector_vector_array_append_data(
    mlx_vector_vector_array vec,
    const mlx_vector_array* data,
    size_t size) {
  const int ret = mlx_vector_vector_array_append_data_(vec, data, size);
  return ret;
}
// Thin forwarder to mlx_vector_vector_array_append_value_. `vec` is passed by
// value. Returns the callee's int result verbatim.
static inline int mlx_vector_vector_array_append_value(
    mlx_vector_vector_array vec,
    const mlx_vector_array val) {
  const int ret = mlx_vector_vector_array_append_value_(vec, val);
  return ret;
}
// Thin forwarder: returns the size_t that mlx_vector_vector_array_size_
// reports for `vec`.
static inline size_t mlx_vector_vector_array_size(mlx_vector_vector_array vec) {
  const size_t count = mlx_vector_vector_array_size_(vec);
  return count;
}
// Thin forwarder to mlx_vector_vector_array_get_. The output pointer, `vec`
// and `idx` pass through unchanged; no bounds check is done here. Returns the
// callee's int result as-is.
static inline int mlx_vector_vector_array_get(
    mlx_vector_array* res,
    const mlx_vector_vector_array vec,
    size_t idx) {
  const int ret = mlx_vector_vector_array_get_(res, vec, idx);
  return ret;
}
// Thin forwarder: returns, by value, the mlx_vector_int produced by
// mlx_vector_int_new_().
static inline mlx_vector_int mlx_vector_int_new(void) {
  mlx_vector_int created = mlx_vector_int_new_();
  return created;
}
// Thin forwarder to mlx_vector_int_set_: the destination pointer and `src` go
// through untouched. Returns the callee's int result unchanged.
static inline int mlx_vector_int_set(mlx_vector_int* vec, const mlx_vector_int src) {
  const int ret = mlx_vector_int_set_(vec, src);
  return ret;
}
// Thin forwarder: passes `vec` to mlx_vector_int_free_ and returns its int
// result verbatim.
static inline int mlx_vector_int_free(mlx_vector_int vec) {
  const int ret = mlx_vector_int_free_(vec);
  return ret;
}
// Thin forwarder: passes the (data, size) pair to mlx_vector_int_new_data_
// and returns the resulting mlx_vector_int by value. `data` is non-const
// here, matching the signature of the forwarded symbol.
static inline mlx_vector_int mlx_vector_int_new_data(int* data, size_t size) {
  mlx_vector_int created = mlx_vector_int_new_data_(data, size);
  return created;
}
// Thin forwarder: passes `val` to mlx_vector_int_new_value_ and returns the
// resulting mlx_vector_int by value.
static inline mlx_vector_int mlx_vector_int_new_value(int val) {
  mlx_vector_int created = mlx_vector_int_new_value_(val);
  return created;
}
// Thin forwarder to mlx_vector_int_set_data_. The destination pointer and the
// (data, size) pair are handed over as-is. Returns the callee's int result
// unchanged.
static inline int mlx_vector_int_set_data(mlx_vector_int* vec, int* data, size_t size) {
  const int ret = mlx_vector_int_set_data_(vec, data, size);
  return ret;
}
// Thin forwarder to mlx_vector_int_set_value_: the destination pointer and
// `val` pass through. Returns the callee's int result as-is.
static inline int mlx_vector_int_set_value(mlx_vector_int* vec, int val) {
  const int ret = mlx_vector_int_set_value_(vec, val);
  return ret;
}
// Thin forwarder to mlx_vector_int_append_data_. `vec` is taken by value.
// Returns the callee's int result verbatim.
static inline int mlx_vector_int_append_data(mlx_vector_int vec, int* data, size_t size) {
  const int ret = mlx_vector_int_append_data_(vec, data, size);
  return ret;
}
// Thin forwarder to mlx_vector_int_append_value_. `vec` is passed by value.
// Returns the callee's int result unchanged.
static inline int mlx_vector_int_append_value(mlx_vector_int vec, int val) {
  const int ret = mlx_vector_int_append_value_(vec, val);
  return ret;
}
// Thin forwarder: returns the size_t that mlx_vector_int_size_ reports for
// `vec`.
static inline size_t mlx_vector_int_size(mlx_vector_int vec) {
  const size_t count = mlx_vector_int_size_(vec);
  return count;
}
// Thin forwarder to mlx_vector_int_get_. The output pointer, `vec` and `idx`
// go through unchanged; no bounds check is done here. Returns the callee's
// int result as-is.
static inline int mlx_vector_int_get(int* res, const mlx_vector_int vec, size_t idx) {
  const int ret = mlx_vector_int_get_(res, vec, idx);
  return ret;
}
// Thin forwarder: returns, by value, the mlx_vector_string produced by
// mlx_vector_string_new_().
static inline mlx_vector_string mlx_vector_string_new(void) {
  mlx_vector_string created = mlx_vector_string_new_();
  return created;
}
// Thin forwarder to mlx_vector_string_set_: the destination pointer and `src`
// are handed over as-is. Returns the callee's int result verbatim.
static inline int mlx_vector_string_set(mlx_vector_string* vec, const mlx_vector_string src) {
  const int ret = mlx_vector_string_set_(vec, src);
  return ret;
}
// Thin forwarder: passes `vec` to mlx_vector_string_free_ and returns its int
// result unchanged.
static inline int mlx_vector_string_free(mlx_vector_string vec) {
  const int ret = mlx_vector_string_free_(vec);
  return ret;
}
// Thin forwarder: passes the (data, size) pair of C strings to
// mlx_vector_string_new_data_ and returns the result by value.
static inline mlx_vector_string mlx_vector_string_new_data(const char** data, size_t size) {
  mlx_vector_string created = mlx_vector_string_new_data_(data, size);
  return created;
}
// Thin forwarder: passes the C string `val` to mlx_vector_string_new_value_
// and returns the result by value.
static inline mlx_vector_string mlx_vector_string_new_value(const char* val) {
  mlx_vector_string created = mlx_vector_string_new_value_(val);
  return created;
}
// Thin forwarder to mlx_vector_string_set_data_. The destination pointer and
// the (data, size) pair go through unchanged. Returns the callee's int result
// as-is.
static inline int mlx_vector_string_set_data(
    mlx_vector_string* vec,
    const char** data,
    size_t size) {
  const int ret = mlx_vector_string_set_data_(vec, data, size);
  return ret;
}
// Thin forwarder to mlx_vector_string_set_value_: the destination pointer and
// the C string `val` pass through. Returns the callee's int result verbatim.
static inline int mlx_vector_string_set_value(mlx_vector_string* vec, const char* val) {
  const int ret = mlx_vector_string_set_value_(vec, val);
  return ret;
}
// Thin forwarder to mlx_vector_string_append_data_. `vec` is taken by value.
// Returns the callee's int result unchanged.
static inline int mlx_vector_string_append_data(
    mlx_vector_string vec,
    const char** data,
    size_t size) {
  const int ret = mlx_vector_string_append_data_(vec, data, size);
  return ret;
}
// Thin forwarder to mlx_vector_string_append_value_. `vec` is passed by
// value. Returns the callee's int result as-is.
static inline int mlx_vector_string_append_value(mlx_vector_string vec, const char* val) {
  const int ret = mlx_vector_string_append_value_(vec, val);
  return ret;
}
// Thin forwarder: returns the size_t that mlx_vector_string_size_ reports for
// `vec`.
static inline size_t mlx_vector_string_size(mlx_vector_string vec) {
  const size_t count = mlx_vector_string_size_(vec);
  return count;
}
// Thin forwarder to mlx_vector_string_get_. The `char **` output slot, `vec`
// and `idx` are handed over unchanged; no bounds check is done here. Returns
// the callee's int result verbatim.
static inline int mlx_vector_string_get(char** res, const mlx_vector_string vec, size_t idx) {
  const int ret = mlx_vector_string_get_(res, vec, idx);
  return ret;
}
// Thin forwarder: passes the output `str_` pointer to mlx_version_ and
// returns its int result unchanged.
static inline int mlx_version(mlx_string* str_) {
  const int ret = mlx_version_(str_);
  return ret;
}
#endif // MLX_GENERATED_H