array_api_jit package
- array_api_jit.jit(decorator: Mapping[str, Callable[[Callable[..., Any]], Callable[..., Any]]] | None = None, /, *, fail_on_error: bool = False, rerun_on_error: bool = False, decorator_args: Mapping[str, Sequence[Any]] | None = None, decorator_kwargs: Mapping[str, Mapping[str, Any]] | None = None) → Callable[[Callable[P, T]], Callable[P, T]]
Just-in-time (JIT) compilation decorator with multiple backends: the JIT decorator that is applied is chosen per array namespace (e.g. "numpy", "jax").
- Parameters:
decorator (Mapping[str, Callable[[Callable[P, T]], Callable[P, T]]] | None, optional) – Mapping from array namespace name to the JIT decorator to use for that namespace, by default None
fail_on_error (bool, optional) – If True, raise an error if the JIT decorator fails to apply. If False, emit a warning and return the original function instead, by default False
rerun_on_error (bool, optional) – If True, rerun the function without JIT when the JIT-compiled function raises an error, by default False
decorator_args (Mapping[str, Sequence[Any]] | None, optional) – Additional positional arguments, per array namespace, passed to the decorator together with the function, by default None
decorator_kwargs (Mapping[str, Mapping[str, Any]] | None, optional) – Additional keyword arguments, per array namespace, passed to the decorator together with the function, by default None
- Returns:
The JIT decorator that can be applied to a function.
- Return type:
Callable[[Callable[P, T]], Callable[P, T]]
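The parameter descriptions above imply a lookup-then-apply step: pick the decorator registered for a namespace, call it with the function plus any extra positional and keyword arguments, and either raise or warn-and-fall-back if that fails. The sketch below illustrates that reading with the standard library only. `apply_jit`, `tracing_jit` and `broken_jit` are hypothetical names chosen for illustration; they are not part of array_api_jit, and the library's actual implementation may differ.

```python
from __future__ import annotations

import warnings
from collections.abc import Callable, Mapping, Sequence
from typing import Any


def apply_jit(
    func: Callable[..., Any],
    name: str,
    decorators: Mapping[str, Callable[..., Any]],
    decorator_args: Mapping[str, Sequence[Any]] | None = None,
    decorator_kwargs: Mapping[str, Mapping[str, Any]] | None = None,
    fail_on_error: bool = False,
) -> Callable[..., Any]:
    """Hypothetical helper: apply the JIT decorator registered for `name`."""
    decorator = decorators.get(name)
    if decorator is None:
        # No decorator registered for this namespace: keep the plain function.
        return func
    args = (decorator_args or {}).get(name, ())
    kwargs = (decorator_kwargs or {}).get(name, {})
    try:
        # Extra arguments are passed to the decorator along with the function.
        return decorator(func, *args, **kwargs)
    except Exception as exc:
        if fail_on_error:
            raise
        # fail_on_error=False: warn and return the original function.
        warnings.warn(f"JIT decorator for {name!r} failed: {exc}", stacklevel=2)
        return func


def tracing_jit(func: Callable[..., Any], *, tag: str = "jit") -> Callable[..., Any]:
    # Toy stand-in for a real JIT compiler: wraps func and records a tag.
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        return func(*args, **kwargs)

    wrapper.tag = tag  # type: ignore[attr-defined]
    return wrapper


def broken_jit(func: Callable[..., Any]) -> Callable[..., Any]:
    # Toy stand-in for a compiler that cannot handle the function at all.
    raise RuntimeError("cannot compile")


def square(x: float) -> float:
    return x * x


fast = apply_jit(
    square, "toy", {"toy": tracing_jit}, decorator_kwargs={"toy": {"tag": "compiled"}}
)
print(fast.tag, fast(3))  # compiled 9

fallback = apply_jit(square, "bad", {"bad": broken_jit})  # emits a warning
print(fallback is square)  # True
```

Passing the extra arguments in the same call as the function matches how, for example, `jax.jit(func, static_argnames=[...])` is invoked, which is what the `decorator_kwargs={"jax": {...}}` entry in the example relies on.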
Example
>>> from array_api_jit import jit
>>> from array_api_compat import array_namespace
>>> from typing import Any
>>> import numba
>>> @jit(
...     {"numpy": numba.jit()},  # numba.jit is not used by default
...     decorator_kwargs={"jax": {"static_argnames": ["n"]}},  # n drives a Python loop, so jax.jit must treat it as static
... )
... def sin_n_times(x: Any, n: int) -> Any:
...     xp = array_namespace(x)
...     for _ in range(n):
...         x = xp.sin(x)
...     return x
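The `rerun_on_error` option described under Parameters can be read as a wrapper around the compiled function: try the JIT-compiled version first and, if it raises, call the original function with the same arguments. The sketch below assumes exactly that behavior and additionally emits a warning on fallback. `rerun_without_jit` and `failing_compiled` are hypothetical illustrations, not array_api_jit API, and `math.sin` stands in for `xp.sin` so the block runs without any array library.

```python
from __future__ import annotations

import functools
import math
import warnings
from collections.abc import Callable
from typing import Any, TypeVar

T = TypeVar("T")


def rerun_without_jit(
    compiled: Callable[..., T], original: Callable[..., T]
) -> Callable[..., T]:
    """Hypothetical helper: fall back to `original` if `compiled` raises."""

    @functools.wraps(original)
    def wrapper(*args: Any, **kwargs: Any) -> T:
        try:
            return compiled(*args, **kwargs)
        except Exception as exc:
            # Assumed rerun_on_error=True behavior: warn, then run uncompiled.
            warnings.warn(
                f"JIT-compiled call failed ({exc}); rerunning without JIT",
                stacklevel=2,
            )
            return original(*args, **kwargs)

    return wrapper


def sin_n_times(x: float, n: int) -> float:
    # Plain-Python analogue of the example above.
    for _ in range(n):
        x = math.sin(x)
    return x


def failing_compiled(x: float, n: int) -> float:
    # Toy stand-in for a compiled function that fails at call time.
    raise TypeError("loop bound must be a compile-time constant")


safe = rerun_without_jit(failing_compiled, sin_n_times)
print(safe(0.0, 3))  # 0.0, computed by the uncompiled fallback
```

Because the fallback reruns the whole call, it is only safe for functions without side effects that a partially executed compiled call could have left behind.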