Commit f099ab83 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

add missing file

parent 0f88177b
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <>.
# Copyright(C) 2013-2018 Max-Planck-Society
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import absolute_import, division, print_function
import numpy as np
from .. import dobj, utilities
from ..compat import *
from ..domain_tuple import DomainTuple
from import RGSpace
from ..field import Field
from .linear_operator import LinearOperator
class HartleyOperator(LinearOperator):
"""Transforms between a pair of position and harmonic RGSpaces.
domain: Domain, tuple of Domain or DomainTuple
The domain of the data that is input by "times" and output by
target: Domain, optional
The target (sub-)domain of the transform operation.
If omitted, a domain will be chosen automatically.
space: int, optional
The index of the subdomain on which the operator should act
If None, it is set to 0 if `domain` contains exactly one space.
`domain[space]` must be an RGSpace.
def __init__(self, domain, target=None, space=None):
super(HartleyOperator, self).__init__()
# Initialize domain and target
self._domain = DomainTuple.make(domain)
self._space = utilities.infer_space(self._domain, space)
adom = self._domain[self._space]
if not isinstance(adom, RGSpace):
raise TypeError("HartleyOperator only works on RGSpaces")
if target is None:
target = adom.get_default_codomain()
self._target = [dom for dom in self._domain]
self._target[self._space] = target
self._target = DomainTuple.make(self._target)
def apply(self, x, mode):
self._check_input(x, mode)
if np.issubdtype(x.dtype, np.complexfloating):
return (self._apply_cartesian(x.real, mode) +
1j*self._apply_cartesian(x.imag, mode))
return self._apply_cartesian(x, mode)
def _apply_cartesian(self, x, mode):
axes = x.domain.axes[self._space]
tdom = self._tgt(mode)
oldax = dobj.distaxis(x.val)
if oldax not in axes: # straightforward, no redistribution needed
ldat = x.local_data
ldat = utilities.hartley(ldat, axes=axes)
tmp = dobj.from_local_data(x.val.shape, ldat, distaxis=oldax)
elif len(axes) < len(x.shape) or len(axes) == 1:
# we can use one Hartley pass in between the redistributions
tmp = dobj.redistribute(x.val, nodist=axes)
newax = dobj.distaxis(tmp)
ldat = dobj.local_data(tmp)
ldat = utilities.hartley(ldat, axes=axes)
tmp = dobj.from_local_data(tmp.shape, ldat, distaxis=newax)
tmp = dobj.redistribute(tmp, dist=oldax)
else: # two separate, full FFTs needed
# ideal strategy for the moment would be:
# - do real-to-complex FFT on all local axes
# - fill up array
# - redistribute array
# - do complex-to-complex FFT on remaining axis
# - add re+im
# - redistribute back
rem_axes = tuple(i for i in axes if i != oldax)
tmp = x.val
ldat = dobj.local_data(tmp)
ldat = utilities.my_fftn_r2c(ldat, axes=rem_axes)
if oldax != 0:
raise ValueError("bad distribution")
ldat2 = ldat.reshape((ldat.shape[0],[1:])))
shp2d = (x.val.shape[0],[1:]))
tmp = dobj.from_local_data(shp2d, ldat2, distaxis=0)
tmp = dobj.transpose(tmp)
ldat2 = dobj.local_data(tmp)
ldat2 = utilities.my_fftn(ldat2, axes=(1,))
ldat2 = ldat2.real+ldat2.imag
tmp = dobj.from_local_data(tmp.shape, ldat2, distaxis=0)
tmp = dobj.transpose(tmp)
ldat2 = dobj.local_data(tmp).reshape(ldat.shape)
tmp = dobj.from_local_data(x.val.shape, ldat2, distaxis=0)
Tval = Field(tdom, tmp)
if mode & (LinearOperator.TIMES | LinearOperator.ADJOINT_TIMES):
fct = self._domain[self._space].scalar_dvol
fct = self._target[self._space].scalar_dvol
return Tval if fct == 1 else Tval*fct
def domain(self):
return self._domain
def target(self):
return self._target
def capability(self):
return self._all_ops
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment