Skip to content

Draft: GPU backend for old numpy nifty

Philipp Arras requested to merge g-philipp/nifty:cupy_backend into NIFTy_8

The old numpy-based nifty is now GPU-compatible. In the best case, you can specify device_id=0 for optimize_kl and your optimization runs on the GPU.

The basic idea is that ift.Field.val is not a numpy array anymore, but rather a so-called ift.AnyArray. It provides an abstraction layer around, e.g., a np.ndarray but it can also hold an ndarray of the GPU (cupy.ndarray).

ift.Field, ift.MultiField and ift.AnyArray provide the method .copy_to(device_id) which returns an instance of the respective class on the specified device. device_id=-1 corresponds to the host, device_id=0 is the first GPU, and so on. If you need the underlying np.ndarray of a Field, irrespective of where it is stored, use .asnumpy().

In practice this means that you may need to replace many .val calls, e.g., in plotting routines with .asnumpy(). This breaks the interface but is intended to avoid silent failures.

Merge request reports

Loading