Source code for array_api_jit._main

import importlib.util
import warnings
from collections.abc import Callable, Mapping, Sequence
from functools import cache, wraps
from types import ModuleType
from typing import Any, ParamSpec, TypeVar

from array_api_compat import (
    array_namespace,
    is_cupy_namespace,
    is_dask_namespace,
    is_jax_namespace,
    is_numpy_namespace,
    is_torch_namespace,
)

if importlib.util.find_spec("numba"):
    # Only register the overload when numba can be found, so numba stays an
    # optional dependency of this module.
    # Registers an overload of ``array_namespace`` so that it can be called
    # from numba-compiled code; there it always resolves to plain ``numpy``.
    import numpy as np
    from numba.extending import overload

    @overload(array_namespace)
    def _array_namespace_overload(*args: Any) -> Any:
        # Overload hook: it returns the implementation to use for any call,
        # whatever the arguments (per numba's ``@overload`` convention it is
        # invoked with argument types, not values — see numba docs).
        # NOTE(review): numba may validate that this function's signature and
        # ``inner``'s signature match (including the ``*args`` name) — keep
        # them in sync when editing.
        def inner(*args: Any) -> Any:
            # Ignore the arguments and always hand back the NumPy module.
            return np

        return inner


# Parameter-spec / return-type variables used by the decorator signatures
# below: P/T for ``jit``'s result, Pin/Tin for the ``Decorator`` alias, and
# Pinner/Tinner for the function wrapped inside ``jit``.
P = ParamSpec("P")
Pin = ParamSpec("Pin")
Pinner = ParamSpec("Pinner")
T = TypeVar("T")
Tin = TypeVar("Tin")
Tinner = TypeVar("Tinner")

# Short backend name -> predicate that recognizes that backend's namespace
# module. The short names are the keys accepted in ``jit``'s per-namespace
# mappings; the dict is iterated in this insertion order.
STR_TO_IS_NAMESPACE = dict(
    numpy=is_numpy_namespace,
    jax=is_jax_namespace,
    cupy=is_cupy_namespace,
    torch=is_torch_namespace,
    dask=is_dask_namespace,
)


def _default_decorator(
    module: ModuleType,
    /,
) -> Callable[[Callable[P, T]], Callable[P, T]]:
    """Return the built-in JIT decorator for an array namespace module.

    Parameters
    ----------
    module : ModuleType
        The array namespace (e.g. as returned by ``array_namespace``).

    Returns
    -------
    Callable[[Callable[P, T]], Callable[P, T]]
        ``jax.jit`` for JAX, ``torch.compile`` for PyTorch, a no-op
        decorator for NumPy, CuPy and Dask, and otherwise the module's own
        ``jit`` attribute if it has one, falling back to a no-op decorator.
    """
    if is_jax_namespace(module):
        import jax

        return jax.jit

    if is_numpy_namespace(module) or is_cupy_namespace(module):
        # numba.jit is deliberately not the default here: it succeeds too
        # rarely on generic array-API code, so the function is left as-is.
        return lambda x: x

    if is_torch_namespace(module):
        import torch

        return torch.compile

    if is_dask_namespace(module):
        # Dask gets no JIT; pass the function through unchanged.
        return lambda x: x

    # Unknown backend: use its own ``jit`` when present, else a no-op.
    return getattr(module, "jit", lambda x: x)


# Generic alias for a decorator that takes a function and returns a function
# with the same parameters (Pin) and return type (Tin). Subscript it to fix
# those, e.g. ``Decorator[..., Any]`` for an arbitrary signature.
Decorator = Callable[[Callable[Pin, Tin]], Callable[Pin, Tin]]


def jit(
    decorator: Mapping[str, Decorator[..., Any]] | None = None,
    /,
    *,
    fail_on_error: bool = False,
    rerun_on_error: bool = False,
    decorator_args: Mapping[str, Sequence[Any]] | None = None,
    decorator_kwargs: Mapping[str, Mapping[str, Any]] | None = None,
) -> Callable[[Callable[P, T]], Callable[P, T]]:
    """
    Just-in-time compilation decorator with multiple backends.

    The backend is chosen per call from the array namespace of the positional
    arguments, and the JIT-compiled function is built once per namespace and
    cached.

    Parameters
    ----------
    decorator : Mapping[str, Callable[[Callable[P, T]], Callable[P, T]]] | None, optional
        The JIT decorator to use for each array namespace, keyed by short
        backend name ("numpy", "jax", "cupy", "torch", "dask", or the
        top-level package name of any other namespace), by default None.
        Namespaces without an entry use a built-in default.
    fail_on_error : bool, optional
        If True, raise an error if the JIT decorator fails to apply.
        If False, just warn and return the original function, by default False
    rerun_on_error : bool, optional
        If True, rerun the function without JIT if the function with JIT
        applied fails, by default False
    decorator_args : Mapping[str, Sequence[Any]] | None, optional
        Additional positional arguments to be passed along with the function
        to the decorator for each array namespace, by default None
    decorator_kwargs : Mapping[str, Mapping[str, Any]] | None, optional
        Additional keyword arguments to be passed along with the function
        to the decorator for each array namespace, by default None

    Returns
    -------
    Callable[[Callable[P, T]], Callable[P, T]]
        The JIT decorator that can be applied to a function.

    Raises
    ------
    RuntimeError
        Raised by the decorated function, not at decoration time:
        when ``fail_on_error`` is True and the JIT decorator cannot be
        applied (on the first call for that namespace), or when the
        JIT-compiled function raises and ``rerun_on_error`` is False.
        The original exception is chained as ``__cause__``.

    Examples
    --------
    >>> from array_api_jit import jit
    >>> from array_api_compat import array_namespace
    >>> from typing import Any
    >>> import numba
    >>> @jit(
    ...     {"numpy": numba.jit()},  # numba.jit is not used by default
    ...     decorator_kwargs={"jax": {"static_argnames": ["n"]}},  # jax requires static_argnames
    ... )
    ... def sin_n_times(x: Any, n: int) -> Any:
    ...     xp = array_namespace(x)
    ...     for i in range(n):
    ...         x = xp.sin(x)
    ...     return x
    """

    def new_decorator(f: Callable[Pinner, Tinner]) -> Callable[Pinner, Tinner]:
        decorator_args_ = decorator_args or {}
        decorator_kwargs_ = decorator_kwargs or {}
        decorator_ = decorator or {}

        @cache
        def jit_cached(xp: ModuleType) -> Callable[Pinner, Tinner]:
            # Resolve the short backend name used as the key into the
            # user-supplied mappings. The ``break`` is essential: a for/else
            # runs its ``else`` whenever the loop finishes without ``break``,
            # which would otherwise overwrite a successful match.
            for name_, is_namespace in STR_TO_IS_NAMESPACE.items():
                if is_namespace(xp):
                    name = name_
                    break
            else:
                # Unknown backend: fall back to the module's top-level package.
                name = xp.__name__.split(".")[0]

            decorator_args__ = decorator_args_.get(name, ())
            decorator_kwargs__ = decorator_kwargs_.get(name, {})
            if name in decorator_:
                decorator_current = decorator_[name]
            else:
                decorator_current = _default_decorator(xp)

            try:
                return decorator_current(f, *decorator_args__, **decorator_kwargs__)
            except Exception as e:
                if fail_on_error:
                    raise RuntimeError(f"Failed to apply JIT decorator for {name}") from e
                # Best effort: fall back to the un-jitted function (cached, so
                # this warning is emitted once per namespace).
                warnings.warn(
                    f"Failed to apply JIT decorator for {name}: {e}",
                    RuntimeWarning,
                    stacklevel=2,
                )
                return f

        @wraps(f)
        def inner(*args_inner: Pinner.args, **kwargs_inner: Pinner.kwargs) -> Tinner:
            # Only positional arguments are used to determine the namespace.
            try:
                xp = array_namespace(*args_inner)
            except TypeError as e:
                # Presumably raised when no positional argument is an array:
                # there is no backend to JIT for, so call the plain function.
                # Guard ``e.args`` so an argument-less TypeError is re-raised
                # instead of masked by an IndexError.
                if e.args and e.args[0] == "Unrecognized array input":
                    return f(*args_inner, **kwargs_inner)
                raise

            f_jit = jit_cached(xp)
            try:
                return f_jit(*args_inner, **kwargs_inner)
            except Exception as e:
                if rerun_on_error:
                    warnings.warn(
                        f"JIT failed for {xp.__name__}: {e}. Rerunning without JIT.",
                        RuntimeWarning,
                        stacklevel=2,
                    )
                    return f(*args_inner, **kwargs_inner)
                raise RuntimeError(f"Failed to run JIT function for {xp.__name__}") from e

        return inner

    return new_decorator