File size: 6,544 Bytes
7885a28 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 |
import scipy._lib.uarray as ua
from . import _basic_backend
from . import _realtransforms_backend
from . import _fftlog_backend
class _ScipyBackend:
    """Default SciPy implementation of the ``numpy.scipy.fft`` uarray domain.

    Notes
    -----
    The domain is ``numpy.scipy.fft`` rather than plain ``scipy`` because
    ``uarray`` treats domains as a hierarchy: a user can install a single
    backend for ``numpy`` and have it implement ``numpy.scipy.fft`` as well.
    """

    __ua_domain__ = "numpy.scipy.fft"

    @staticmethod
    def __ua_function__(method, args, kwargs):
        """Dispatch ``method`` to the SciPy function with the same name.

        The private implementation modules are searched in a fixed order:
        ``_basic_backend``, then ``_realtransforms_backend``, then
        ``_fftlog_backend``. The first one that has an attribute named
        ``method.__name__`` (and whose value is not ``None``) is called with
        ``*args, **kwargs`` and its result is returned.

        Returns ``NotImplemented`` when no module provides the function, so
        that ``uarray`` can fall through to other backends.
        """
        wanted = method.__name__
        # The module tuple is built at call time so it always reflects the
        # current module-level bindings, exactly like the original lookups.
        for impl_module in (_basic_backend, _realtransforms_backend,
                            _fftlog_backend):
            impl = getattr(impl_module, wanted, None)
            if impl is not None:
                return impl(*args, **kwargs)
        return NotImplemented
# Registry of backends that may be referred to by name (a ``str``) in the
# public backend-control functions; names are resolved by
# ``_backend_from_arg``.
_named_backends = dict(
    scipy=_ScipyBackend,
)
def _backend_from_arg(backend):
    """Map strings to known backends and validate the backend.

    Parameters
    ----------
    backend : str or object
        Either the name of a backend registered in ``_named_backends``, or an
        object that implements the uarray protocol.

    Returns
    -------
    object
        The resolved backend object (the input itself when it was not a
        ``str``).

    Raises
    ------
    ValueError
        If ``backend`` is a string that is not a known backend name, or if
        the resolved object does not declare
        ``__ua_domain__ == 'numpy.scipy.fft'``. This includes objects that
        have no ``__ua_domain__`` attribute at all.
    """
    if isinstance(backend, str):
        try:
            backend = _named_backends[backend]
        except KeyError as e:
            raise ValueError(f'Unknown backend {backend}') from e

    # Use getattr with a default so that arbitrary objects lacking the
    # uarray protocol are rejected with the documented ValueError instead of
    # leaking an AttributeError to callers.
    if getattr(backend, '__ua_domain__', None) != 'numpy.scipy.fft':
        raise ValueError('Backend does not implement "numpy.scipy.fft"')
    return backend
def set_global_backend(backend, coerce=False, only=False, try_last=False):
    """Set the global fft backend.

    The given backend permanently replaces the current global backend. It is
    tried automatically unless a backend with the ``only`` flag takes
    precedence, and it is the first backend consulted outside of a
    :obj:`set_backend` context manager.

    Parameters
    ----------
    backend : {object, 'scipy'}
        The backend to use. Either the name of a known backend (``'scipy'``)
        as a ``str``, or an object that implements the uarray protocol.
    coerce : bool
        Whether to coerce input types when trying this backend.
    only : bool
        If ``True``, no more backends are tried if this one fails.
        Implied by ``coerce=True``.
    try_last : bool
        If ``True``, the global backend is tried after registered backends.

    Raises
    ------
    ValueError
        If the backend does not implement ``numpy.scipy.fft``.

    Notes
    -----
    This overwrites the previously set global backend, which by default is
    the SciPy implementation.

    Examples
    --------
    We can set the global fft backend:

    >>> from scipy.fft import fft, set_global_backend
    >>> set_global_backend("scipy")  # Sets global backend (default is "scipy").
    >>> fft([1])  # Calls the global backend
    array([1.+0.j])
    """
    # Resolve names such as 'scipy' and reject objects from other domains
    # before handing anything to uarray.
    validated = _backend_from_arg(backend)
    ua.set_global_backend(
        validated,
        coerce=coerce,
        only=only,
        try_last=try_last,
    )
def register_backend(backend):
    """
    Register a backend for permanent use.

    Registered backends are tried after the global backend, unless the
    global backend was set with ``try_last=True``. In that case registered
    backends are tried before it. This module installs the default SciPy
    global backend with ``try_last=True`` at import time.

    Parameters
    ----------
    backend : {object, 'scipy'}
        The backend to use. Either the name of a known backend (``'scipy'``)
        as a ``str``, or an object that implements the uarray protocol.

    Raises
    ------
    ValueError
        If the backend does not implement ``numpy.scipy.fft``.

    Examples
    --------
    We can register a new fft backend:

    >>> from scipy.fft import fft, register_backend, set_global_backend
    >>> class NoopBackend:  # Define an invalid Backend
    ...     __ua_domain__ = "numpy.scipy.fft"
    ...     def __ua_function__(self, func, args, kwargs):
    ...          return NotImplemented
    >>> set_global_backend(NoopBackend())  # Set the invalid backend as global
    >>> register_backend("scipy")  # Register a new backend
    >>> # The registered backend is called because
    >>> # the global backend returns `NotImplemented`
    >>> fft([1])
    array([1.+0.j])
    >>> set_global_backend("scipy")  # Restore global backend to default
    """
    # Validate (and resolve string names) before registering with uarray.
    backend = _backend_from_arg(backend)
    ua.register_backend(backend)
def set_backend(backend, coerce=False, only=False):
    """Context manager to set the backend within a fixed scope.

    Upon entering the ``with`` statement, the given backend is added to the
    list of available backends with the highest priority. Upon exit, the
    backends are reset to their state before entering the scope.

    Parameters
    ----------
    backend : {object, 'scipy'}
        The backend to use. Either the name of a known backend (``'scipy'``)
        as a ``str``, or an object that implements the uarray protocol.
    coerce : bool, optional
        Whether to allow expensive conversions for the ``x`` parameter, e.g.
        copying a NumPy array to the GPU for a CuPy backend. Implies ``only``.
    only : bool, optional
        If ``True`` and this backend returns ``NotImplemented``, a
        ``BackendNotImplementedError`` is raised immediately, and any
        lower-priority backends are ignored.

    Returns
    -------
    context manager
        The object returned by uarray's ``set_backend``, intended to be used
        in a ``with`` statement.

    Raises
    ------
    ValueError
        If the backend does not implement ``numpy.scipy.fft``.

    Examples
    --------
    >>> import scipy.fft as fft
    >>> with fft.set_backend('scipy', only=True):
    ...     fft.fft([1])  # Always calls the scipy implementation
    array([1.+0.j])
    """
    # Validate (and resolve string names) eagerly, so that an invalid backend
    # fails at the call site rather than later inside uarray.
    backend = _backend_from_arg(backend)
    return ua.set_backend(backend, coerce=coerce, only=only)
def skip_backend(backend):
    """Context manager to skip a backend within a fixed scope.

    Inside the ``with`` statement, the given backend is never called,
    whether it was set locally, set globally, or registered. On exit it is
    considered again.

    Parameters
    ----------
    backend : {object, 'scipy'}
        The backend to skip. Either the name of a known backend (``'scipy'``)
        as a ``str``, or an object that implements the uarray protocol.

    Returns
    -------
    context manager
        The object returned by uarray's ``skip_backend``, intended to be used
        in a ``with`` statement.

    Raises
    ------
    ValueError
        If the backend does not implement ``numpy.scipy.fft``.

    Examples
    --------
    >>> import scipy.fft as fft
    >>> fft.fft([1])  # Calls default SciPy backend
    array([1.+0.j])
    >>> with fft.skip_backend('scipy'):  # We explicitly skip the SciPy backend
    ...     fft.fft([1])                 # leaving no implementation available
    Traceback (most recent call last):
        ...
    BackendNotImplementedError: No selected backends had an implementation ...
    """
    target = _backend_from_arg(backend)
    return ua.skip_backend(target)
# Install SciPy's own implementation as the default global backend. With
# ``try_last=True`` it is tried only after any backends added through
# ``register_backend``.
set_global_backend(backend='scipy', try_last=True)
|