Skip to content
Snippets Groups Projects
Commit c15eb016 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

cleanup

parent e8f33125
No related branches found
No related tags found
1 merge request!209WIP: Byebye volume factors
Pipeline #
...@@ -19,7 +19,6 @@ ...@@ -19,7 +19,6 @@
import numpy as np import numpy as np
from .. import DomainTuple from .. import DomainTuple
from ..spaces import RGSpace from ..spaces import RGSpace
from ..utilities import infer_space
from .linear_operator import LinearOperator from .linear_operator import LinearOperator
from .. import dobj from .. import dobj
from .. import utilities from .. import utilities
...@@ -65,7 +64,7 @@ class FFTOperator(LinearOperator): ...@@ -65,7 +64,7 @@ class FFTOperator(LinearOperator):
# Initialize domain and target # Initialize domain and target
self._domain = DomainTuple.make(domain) self._domain = DomainTuple.make(domain)
self._space = infer_space(self._domain, space) self._space = utilities.infer_space(self._domain, space)
adom = self._domain[self._space] adom = self._domain[self._space]
if target is None: if target is None:
...@@ -87,8 +86,8 @@ class FFTOperator(LinearOperator): ...@@ -87,8 +86,8 @@ class FFTOperator(LinearOperator):
self._applyfunc = self._apply_spherical self._applyfunc = self._apply_spherical
hspc = adom if adom.harmonic else target hspc = adom if adom.harmonic else target
pspc = target if adom.harmonic else adom pspc = target if adom.harmonic else adom
self.lmax=hspc.lmax self.lmax = hspc.lmax
self.mmax=hspc.mmax self.mmax = hspc.mmax
self.sjob = sharpjob_d() self.sjob = sharpjob_d()
self.sjob.set_triangular_alm_info(self.lmax, self.mmax) self.sjob.set_triangular_alm_info(self.lmax, self.mmax)
if isinstance(pspc, GLSpace): if isinstance(pspc, GLSpace):
...@@ -120,7 +119,7 @@ class FFTOperator(LinearOperator): ...@@ -120,7 +119,7 @@ class FFTOperator(LinearOperator):
""" """
from pyfftw.interfaces.numpy_fft import fftn from pyfftw.interfaces.numpy_fft import fftn
axes = x.domain.axes[self._space] axes = x.domain.axes[self._space]
tdom = self._target if x.domain==self._domain else self._domain tdom = self._target if x.domain == self._domain else self._domain
oldax = dobj.distaxis(x.val) oldax = dobj.distaxis(x.val)
if oldax not in axes: # straightforward, no redistribution needed if oldax not in axes: # straightforward, no redistribution needed
ldat = dobj.local_data(x.val) ldat = dobj.local_data(x.val)
...@@ -161,18 +160,10 @@ class FFTOperator(LinearOperator): ...@@ -161,18 +160,10 @@ class FFTOperator(LinearOperator):
ldat2 = dobj.local_data(tmp).reshape(ldat.shape) ldat2 = dobj.local_data(tmp).reshape(ldat.shape)
tmp = dobj.from_local_data(x.val.shape, ldat2, distaxis=0) tmp = dobj.from_local_data(x.val.shape, ldat2, distaxis=0)
Tval = Field(tdom, tmp) Tval = Field(tdom, tmp)
if x.domain[self._space].harmonic: if mode & (LinearOperator.TIMES | LinearOperator.ADJOINT_TIMES):
if (mode == LinearOperator.TIMES or fct = self._domain[self._space].scalar_dvol()
mode == LinearOperator.ADJOINT_TIMES):
fct = self._domain[self._space].scalar_dvol()
else:
fct = 1./(self._domain[self._space].scalar_dvol()*self._domain[self._space].dim)
else: else:
if (mode == LinearOperator.TIMES or fct = self._target[self._space].scalar_dvol()
mode == LinearOperator.ADJOINT_TIMES):
fct = 1./(self._target[self._space].scalar_dvol()*self._target[self._space].dim)
else:
fct = self._target[self._space].scalar_dvol()
if fct != 1: if fct != 1:
Tval *= fct Tval *= fct
...@@ -207,7 +198,7 @@ class FFTOperator(LinearOperator): ...@@ -207,7 +198,7 @@ class FFTOperator(LinearOperator):
distaxis = dobj.distaxis(tval) distaxis = dobj.distaxis(tval)
p2h = not x.domain[self._space].harmonic p2h = not x.domain[self._space].harmonic
tdom = self._target if x.domain==self._domain else self._domain tdom = self._target if x.domain == self._domain else self._domain
func = self._slice_p2h if p2h else self._slice_h2p func = self._slice_p2h if p2h else self._slice_h2p
idat = dobj.local_data(tval) idat = dobj.local_data(tval)
odat = np.empty(dobj.local_shape(tdom.shape, distaxis=distaxis), odat = np.empty(dobj.local_shape(tdom.shape, distaxis=distaxis),
......
...@@ -17,7 +17,7 @@ class ResponseOperator_Tests(unittest.TestCase): ...@@ -17,7 +17,7 @@ class ResponseOperator_Tests(unittest.TestCase):
@expand(product(spaces, [0., 5., 1.], [0., 1., .33])) @expand(product(spaces, [0., 5., 1.], [0., 1., .33]))
def test_times_adjoint_times(self, space, sigma, sensitivity): def test_times_adjoint_times(self, space, sigma, sensitivity):
if not isinstance(space, ift.RGSpace): # no smoothing supported if not isinstance(space, ift.RGSpace): # no smoothing supported
sigma = 0. sigma = 0.
op = ift.ResponseOperator(space, sigma=[sigma], op = ift.ResponseOperator(space, sigma=[sigma],
sensitivity=[sensitivity]) sensitivity=[sensitivity])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment