Commit 7834aff5 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

cleanup

parent dec2f197
Pipeline #18148 passed with stage
in 5 minutes and 28 seconds
......@@ -17,20 +17,15 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import division
import abc
from future.utils import with_metaclass
import numpy as np
from ... import nifty_utilities as utilities
class Transformation(with_metaclass(abc.ABCMeta,
type('NewBase', (object,), {}))):
class Transformation(object):
def __init__(self, domain, codomain):
self.domain = domain
self.codomain = codomain
@abc.abstractproperty
def unitary(self):
raise NotImplementedError
......@@ -39,17 +34,12 @@ class Transformation(with_metaclass(abc.ABCMeta,
class RGRGTransformation(Transformation):
# ---Overwritten properties and methods---
def __init__(self, domain, codomain=None):
import pyfftw
super(RGRGTransformation, self).__init__(domain, codomain)
pyfftw.interfaces.cache.enable()
self._fwd = self.codomain.harmonic
# ---Mandatory properties and methods---
@property
def unitary(self):
return True
......@@ -77,60 +67,49 @@ class RGRGTransformation(Transformation):
The axes along which the transformation should take place
"""
fct = 1.
# correct for forward/inverse fft.
# naively one would set power to 0.5 here in order to
# apply effectively a factor of 1/sqrt(N) to the field.
# BUT: the pixel volumes of the domain and codomain are different.
# Hence, in order to produce the same scalar product, power==1.
if self.codomain.harmonic:
# correct for forward fft.
# naively one would set power to 0.5 here in order to
# apply effectively a factor of 1/sqrt(N) to the field.
# BUT: the pixel volumes of the domain and codomain are different.
# Hence, in order to produce the same scalar product, power===1.
fct *= self.domain.weight()
fct = self.domain.weight()
else:
fct = 1./self.codomain.weight()
# Perform the transformation
if issubclass(val.dtype.type, np.complexfloating):
Tval_real = self._transform_helper(val.real, axes)
Tval_imag = self._transform_helper(val.imag, axes)
Tval1 = self._transform_helper(val.real, axes)
Tval2 = self._transform_helper(val.imag, axes)
if self.codomain.harmonic:
Tval_real.real += Tval_real.imag
Tval_imag.real += Tval_imag.imag
Tval_real.imag = Tval_imag.real
Tval1.real += Tval1.imag
Tval2.real += Tval2.imag
else:
Tval_real.real -= Tval_real.imag
Tval_imag.real -= Tval_imag.imag
Tval_real.imag = Tval_imag.real
Tval = Tval_real
Tval1.real -= Tval1.imag
Tval2.real -= Tval2.imag
Tval1.imag = Tval2.real
Tval = Tval1
else:
Tval = self._transform_helper(val, axes)
if self.codomain.harmonic:
Tval.real += Tval.imag
Tval = Tval.real + Tval.imag
else:
Tval.real -= Tval.imag
Tval = Tval.real
if not self.codomain.harmonic:
# correct for inverse fft.
# See discussion above.
fct /= self.codomain.weight()
Tval = Tval.real - Tval.imag
Tval *= fct
return Tval
class SlicingTransformation(Transformation):
def transform(self, val, axes=None):
return_shape = np.array(val.shape)
return_shape[list(axes)] = self.codomain.shape
return_shape = tuple(return_shape)
return_val = np.empty(return_shape, dtype=val.dtype)
return_val = np.empty(tuple(return_shape), dtype=val.dtype)
for slice_list in utilities.get_slice_list(val.shape, axes):
return_val[slice_list] = self._transformation_of_slice(
val[slice_list])
for slice in utilities.get_slice_list(val.shape, axes):
return_val[slice] = self._transformation_of_slice(val[slice])
return return_val
@abc.abstractmethod
def _transformation_of_slice(self, inp):
raise NotImplementedError
......@@ -162,9 +141,6 @@ def buildIdx(nr, lmax):
class HPLMTransformation(SlicingTransformation):
def __init__(self, domain, codomain=None):
super(HPLMTransformation, self).__init__(domain, codomain)
@property
def unitary(self):
return False
......@@ -188,9 +164,6 @@ class HPLMTransformation(SlicingTransformation):
class LMHPTransformation(SlicingTransformation):
def __init__(self, domain, codomain=None):
super(LMHPTransformation, self).__init__(domain, codomain)
@property
def unitary(self):
return False
......@@ -215,9 +188,6 @@ class LMHPTransformation(SlicingTransformation):
class GLLMTransformation(SlicingTransformation):
def __init__(self, domain, codomain=None):
super(GLLMTransformation, self).__init__(domain, codomain)
@property
def unitary(self):
return False
......@@ -243,9 +213,6 @@ class GLLMTransformation(SlicingTransformation):
class LMGLTransformation(SlicingTransformation):
def __init__(self, domain, codomain=None):
super(LMGLTransformation, self).__init__(domain, codomain)
@property
def unitary(self):
return False
......@@ -263,7 +230,6 @@ class LMGLTransformation(SlicingTransformation):
rr = buildLm(inp.real, lmax=lmax)
ri = buildLm(inp.imag, lmax=lmax)
return sjob.alm2map(rr) + 1j*sjob.alm2map(ri)
else:
result = buildLm(inp, lmax=lmax)
return sjob.alm2map(result)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment