array_api_jit package

array_api_jit.jit(decorator: Mapping[str, Callable[[Callable[..., Any]], Callable[..., Any]]] | None = None, /, *, fail_on_error: bool = False, rerun_on_error: bool = False, decorator_args: Mapping[str, Sequence[Any]] | None = None, decorator_kwargs: Mapping[str, Mapping[str, Any]] | None = None) → Callable[[Callable[P, T]], Callable[P, T]]

Just-in-time compilation decorator with multiple backends.
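The description above leaves the per-backend dispatch step implicit. The following is a minimal sketch, assuming the wrapper looks up a decorator by the namespace of the call's first argument and applies it lazily, once per namespace. `dispatching_jit` and `namespace_name` are hypothetical names, and the module-name lookup only stands in for `array_api_compat.array_namespace`, so this is not the package's actual implementation.

```python
from collections.abc import Callable, Mapping
from functools import wraps
from typing import Any


def namespace_name(*args: Any) -> str:
    """Hypothetical stand-in for array namespace detection.

    Returns the top-level module of the first argument's type, e.g. "numpy"
    for numpy.ndarray or "builtins" for int/float.
    """
    return type(args[0]).__module__.split(".")[0]


def dispatching_jit(
    decorators: Mapping[str, Callable[[Callable[..., Any]], Callable[..., Any]]],
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    def wrap(func: Callable[..., Any]) -> Callable[..., Any]:
        # One decorated version of func per namespace, built on first use.
        cache: dict[str, Callable[..., Any]] = {}

        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            name = namespace_name(*args)
            if name not in cache:
                # Namespaces without an entry run the undecorated function.
                decorator = decorators.get(name)
                cache[name] = decorator(func) if decorator is not None else func
            return cache[name](*args, **kwargs)

        return wrapper

    return wrap


# Example: a tracing "decorator" registered for plain Python numbers.
decorated: list[str] = []


def tracing(f: Callable[..., Any]) -> Callable[..., Any]:
    decorated.append(f.__name__)  # record that decoration happened
    return f


@dispatching_jit({"builtins": tracing})
def double(x: float) -> float:
    return x * 2


print(double(3), double(4), decorated)  # 6 8 ['double']
```

Decorating lazily means the (possibly expensive) backend decorator runs only for namespaces the function is actually called with.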

Parameters:
  • decorator (Mapping[str, Callable[[Callable[P, T]], Callable[P, T]]] | None, optional) – The JIT decorator to use for each array namespace, by default None

  • fail_on_error (bool, optional) – If True, raise an error when the JIT decorator fails to apply. If False, emit a warning and return the original function instead, by default False

  • rerun_on_error (bool, optional) – If True, rerun the original function without JIT when the JIT-compiled function raises an error, by default False

  • decorator_args (Mapping[str, Sequence[Any]] | None, optional) – Additional positional arguments to be passed along with the function to the decorator for each array namespace, by default None

  • decorator_kwargs (Mapping[str, Mapping[str, Any]] | None, optional) – Additional keyword arguments to be passed along with the function to the decorator for each array namespace, by default None
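The parameters above describe two separate failure points: applying the decorator (`fail_on_error`) and calling the compiled function (`rerun_on_error`). The sketch below illustrates those semantics for a single namespace. `apply_jit` is a hypothetical helper, and passing the extra arguments as `decorator(func, *args, **kwargs)` is an assumption modeled on calls such as `jax.jit(fun, static_argnames=...)`; the package may combine these steps differently.

```python
from __future__ import annotations

import warnings
from collections.abc import Callable, Mapping, Sequence
from functools import wraps
from typing import Any


def apply_jit(
    func: Callable[..., Any],
    decorator: Callable[..., Callable[..., Any]],
    args: Sequence[Any] = (),
    kwargs: Mapping[str, Any] | None = None,
    *,
    fail_on_error: bool = False,
    rerun_on_error: bool = False,
) -> Callable[..., Any]:
    """Hypothetical helper mirroring the documented fallback flags."""
    try:
        # Assumption: extra arguments follow the function being decorated.
        compiled = decorator(func, *args, **(kwargs or {}))
    except Exception as exc:
        if fail_on_error:
            raise
        warnings.warn(f"JIT decorator failed ({exc!r}); using original function")
        return func
    if not rerun_on_error:
        return compiled

    @wraps(func)
    def wrapper(*a: Any, **kw: Any) -> Any:
        try:
            return compiled(*a, **kw)
        except Exception as exc:
            # Fall back to the undecorated function for this call.
            warnings.warn(f"JIT call failed ({exc!r}); rerunning without JIT")
            return func(*a, **kw)

    return wrapper


def square(x: float) -> float:
    return x * x


def failing_decorator(f: Callable[..., Any]) -> Callable[..., Any]:
    raise RuntimeError("cannot compile")


def failing_at_call(f: Callable[..., Any]) -> Callable[..., Any]:
    def compiled(*a: Any, **kw: Any) -> Any:
        raise RuntimeError("compiled call failed")

    return compiled


with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    f1 = apply_jit(square, failing_decorator)  # warns, falls back to square
    f2 = apply_jit(square, failing_at_call, rerun_on_error=True)
    print(f1 is square, f2(3.0))  # True 9.0
```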

Returns:

The JIT decorator that can be applied to a function.

Return type:

Callable[[Callable[P, T]], Callable[P, T]]

Example

>>> from array_api_jit import jit
>>> from array_api_compat import array_namespace
>>> from typing import Any
>>> import numba
>>> @jit(
...     {"numpy": numba.jit()},  # numba.jit is not applied to NumPy arrays by default
...     decorator_kwargs={"jax": {"static_argnames": ["n"]}},  # jax.jit needs n static: it drives a Python loop
... )
... def sin_n_times(x: Any, n: int) -> Any:
...     xp = array_namespace(x)
...     for _ in range(n):
...         x = xp.sin(x)
...     return x