Draft: GPU backend for old numpy nifty
The old numpy-based nifty is now GPU-compatible. In the best case, you can
specify device_id=0
for optimize_kl
and your optimization runs on the GPU.
The basic idea is that ift.Field.val
is not a numpy array anymore, but rather
a so-called ift.AnyArray
. It provides an abstraction layer around, e.g., a
np.ndarray
but it can also hold an ndarray of the GPU (cupy.ndarray
).
ift.Field
, ift.MultiField
and ift.AnyArray
provide the method
.copy_to(device_id)
which returns an instance of the respective class on the
specified device. device_id=-1
corresponds to the host, device_id=0
is the
first GPU, and so on. If you need the underlying np.ndarray
of a Field
,
irrespective of where it is stored, use .asnumpy()
.
In practice this means that you may need to replace many .val
calls, e.g., in
plotting routines with .asnumpy()
. This breaks the interface but is intended
to avoid silent failures.