Commit 5d2241a3 authored by Martin Reinecke's avatar Martin Reinecke

consolidation

parent ba6dc97b
......@@ -25,11 +25,10 @@ from .operators.dof_distributor import DOFDistributor
from .operators.domain_distributor import DomainDistributor
from .operators.endomorphic_operator import EndomorphicOperator
from .operators.exp_transform import ExpTransform
from .operators.fft_operator import FFTOperator
from .operators.harmonic_operators import (
FFTOperator, HartleyOperator, SHTOperator, HarmonicTransformOperator,
HarmonicSmoothingOperator)
from .operators.field_zero_padder import FieldZeroPadder
from .operators.hartley_operator import HartleyOperator
from .operators.harmonic_smoothing_operator import HarmonicSmoothingOperator
from .operators.harmonic_transform_operator import HarmonicTransformOperator
from .operators.inversion_enabler import InversionEnabler
from .operators.laplace_operator import LaplaceOperator
from .operators.linear_operator import LinearOperator
......
......@@ -608,17 +608,6 @@ class Field(object):
self._domain.__str__() + \
"\n- val = " + repr(self._val)
def isEquivalentTo(self, other):
"""Determines (as quickly as possible) whether `self`'s content is
identical to `other`'s content."""
if self is other:
return True
if not isinstance(other, Field):
return False
if self._domain is not other._domain:
return False
return (self._val == other._val).all()
def extract(self, dom):
if dom is not self._domain:
raise ValueError("domain mismatch")
......
......@@ -23,7 +23,7 @@ from ..domain_tuple import DomainTuple
from ..multi_field import MultiField
from ..multi_domain import MultiDomain
from ..operators.domain_distributor import DomainDistributor
from ..operators.harmonic_transform_operator import HarmonicTransformOperator
from ..operators.harmonic_operators import HarmonicTransformOperator
from ..operators.power_distributor import PowerDistributor
from ..operators.operator import Operator
......
......@@ -199,22 +199,10 @@ class MultiField(object):
return True
return False
def isEquivalentTo(self, other):
"""Determines (as quickly as possible) whether `self`'s content is
identical to `other`'s content."""
if self is other:
return True
if not isinstance(other, MultiField):
return False
if self._domain is not other._domain:
return False
for v1, v2 in zip(self._val, other._val):
if not v1.isEquivalentTo(v2):
return False
return True
def extract(self, subset):
if isinstance(subset, MultiDomain):
if subset is self._domain:
return self
return MultiField(subset,
tuple(self[key] for key in subset.keys()))
else:
......
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import absolute_import, division, print_function
import numpy as np
from .. import dobj, utilities
from ..compat import *
from ..domain_tuple import DomainTuple
from ..domains.rg_space import RGSpace
from ..field import Field
from .linear_operator import LinearOperator
class FFTOperator(LinearOperator):
"""Transforms between a pair of position and harmonic RGSpaces.
Parameters
----------
domain: Domain, tuple of Domain or DomainTuple
The domain of the data that is input by "times" and output by
"adjoint_times".
target: Domain, optional
The target (sub-)domain of the transform operation.
If omitted, a domain will be chosen automatically.
space: int, optional
The index of the subdomain on which the operator should act
If None, it is set to 0 if `domain` contains exactly one space.
`domain[space]` must be an RGSpace.
Notes
-----
This operator performs full FFTs, which implies that its output field will
always have complex type, regardless of the type of the input field.
If a real field is desired after a forward/backward transform couple, it
must be manually cast to real.
"""
def __init__(self, domain, target=None, space=None):
super(FFTOperator, self).__init__()
# Initialize domain and target
self._domain = DomainTuple.make(domain)
self._space = utilities.infer_space(self._domain, space)
adom = self._domain[self._space]
if not isinstance(adom, RGSpace):
raise TypeError("FFTOperator only works on RGSpaces")
if target is None:
target = adom.get_default_codomain()
self._target = [dom for dom in self._domain]
self._target[self._space] = target
self._target = DomainTuple.make(self._target)
adom.check_codomain(target)
target.check_codomain(adom)
utilities.fft_prep()
def apply(self, x, mode):
from pyfftw.interfaces.numpy_fft import fftn, ifftn
self._check_input(x, mode)
ncells = x.domain[self._space].size
if x.domain[self._space].harmonic: # harmonic -> position
func = fftn
fct = 1.
else:
func = ifftn
fct = ncells
axes = x.domain.axes[self._space]
tdom = self._tgt(mode)
oldax = dobj.distaxis(x.val)
if oldax not in axes: # straightforward, no redistribution needed
ldat = x.local_data
ldat = func(ldat, axes=axes)
tmp = dobj.from_local_data(x.val.shape, ldat, distaxis=oldax)
elif len(axes) < len(x.shape) or len(axes) == 1:
# we can use one FFT pass in between the redistributions
tmp = dobj.redistribute(x.val, nodist=axes)
newax = dobj.distaxis(tmp)
ldat = dobj.local_data(tmp)
ldat = func(ldat, axes=axes)
tmp = dobj.from_local_data(tmp.shape, ldat, distaxis=newax)
tmp = dobj.redistribute(tmp, dist=oldax)
else: # two separate FFTs needed
rem_axes = tuple(i for i in axes if i != oldax)
tmp = x.val
ldat = dobj.local_data(tmp)
ldat = func(ldat, axes=rem_axes)
if oldax != 0:
raise ValueError("bad distribution")
ldat2 = ldat.reshape((ldat.shape[0],
np.prod(ldat.shape[1:])))
shp2d = (x.val.shape[0], np.prod(x.val.shape[1:]))
tmp = dobj.from_local_data(shp2d, ldat2, distaxis=0)
tmp = dobj.transpose(tmp)
ldat2 = dobj.local_data(tmp)
ldat2 = func(ldat2, axes=(1,))
tmp = dobj.from_local_data(tmp.shape, ldat2, distaxis=0)
tmp = dobj.transpose(tmp)
ldat2 = dobj.local_data(tmp).reshape(ldat.shape)
tmp = dobj.from_local_data(x.val.shape, ldat2, distaxis=0)
Tval = Field(tdom, tmp)
if mode & (LinearOperator.TIMES | LinearOperator.ADJOINT_TIMES):
fct *= self._domain[self._space].scalar_dvol
else:
fct *= self._target[self._space].scalar_dvol
return Tval if fct == 1 else Tval*fct
@property
def domain(self):
return self._domain
@property
def target(self):
return self._target
@property
def capability(self):
return self._all_ops
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import absolute_import, division, print_function
from ..compat import *
from ..domain_tuple import DomainTuple
from ..utilities import infer_space
from .diagonal_operator import DiagonalOperator
from .hartley_operator import HartleyOperator
from .scaling_operator import ScalingOperator
def HarmonicSmoothingOperator(domain, sigma, space=None):
""" This function returns an operator that carries out a smoothing with
a Gaussian kernel of width `sigma` on the part of `domain` given by
`space`.
Parameters
----------
domain : Domain, tuple of Domain, or DomainTuple
The total domain of the operator's input and output fields
sigma : float>=0
The sigma of the Gaussian used for smoothing. It has the same units as
the RGSpace the operator is working on.
If `sigma==0`, an identity operator will be returned.
space : int, optional
The index of the sub-domain on which the smoothing is performed.
Can be omitted if `domain` only has one sub-domain.
Notes
-----
The sub-domain on which the smoothing is carried out *must* be a
non-harmonic `RGSpace`.
"""
sigma = float(sigma)
if sigma < 0.:
raise ValueError("sigma must be nonnegative")
if sigma == 0.:
return ScalingOperator(1., domain)
domain = DomainTuple.make(domain)
space = infer_space(domain, space)
if domain[space].harmonic:
raise TypeError("domain must not be harmonic")
Hartley = HartleyOperator(domain, space=space)
codomain = Hartley.domain[space].get_default_codomain()
kernel = codomain.get_k_length_array()
smoother = codomain.get_fft_smoothing_kernel_function(sigma)
kernel = smoother(kernel)
ddom = list(domain)
ddom[space] = codomain
diag = DiagonalOperator(kernel, ddom, space)
return Hartley.inverse(diag(Hartley))
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import absolute_import, division, print_function
from .. import utilities
from ..compat import *
from ..domain_tuple import DomainTuple
from ..domains.rg_space import RGSpace
from .hartley_operator import HartleyOperator
from .linear_operator import LinearOperator
from .sht_operator import SHTOperator
class HarmonicTransformOperator(LinearOperator):
"""Transforms between a harmonic domain and a position domain counterpart.
Built-in domain pairs are
- a harmonic and a non-harmonic RGSpace (with matching distances)
- an LMSpace and a HPSpace
- an LMSpace and a GLSpace
The supported operations are times() and adjoint_times().
Parameters
----------
domain : Domain, tuple of Domain or DomainTuple
The domain of the data that is input by "times" and output by
"adjoint_times".
target : Domain, optional
The target domain of the transform operation.
If omitted, a domain will be chosen automatically.
Whenever the input domain of the transform is an RGSpace, the codomain
(and its parameters) are uniquely determined.
For LMSpace, a GLSpace of sufficient resolution is chosen.
space : int, optional
The index of the domain on which the operator should act
If None, it is set to 0 if domain contains exactly one subdomain.
domain[space] must be a harmonic domain.
"""
def __init__(self, domain, target=None, space=None):
super(HarmonicTransformOperator, self).__init__()
domain = DomainTuple.make(domain)
space = utilities.infer_space(domain, space)
hspc = domain[space]
if not hspc.harmonic:
raise TypeError(
"HarmonicTransformOperator only works on a harmonic space")
if isinstance(hspc, RGSpace):
self._op = HartleyOperator(domain, target, space)
else:
self._op = SHTOperator(domain, target, space)
def apply(self, x, mode):
self._check_input(x, mode)
return self._op.apply(x, mode)
@property
def domain(self):
return self._op.domain
@property
def target(self):
return self._op.target
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import absolute_import, division, print_function
import numpy as np
from .. import dobj, utilities
from ..compat import *
from ..domain_tuple import DomainTuple
from ..domains.gl_space import GLSpace
from ..domains.lm_space import LMSpace
from ..field import Field
from .linear_operator import LinearOperator
class SHTOperator(LinearOperator):
"""Transforms between a harmonic domain on the sphere and a position
domain counterpart.
Built-in domain pairs are
- an LMSpace and a HPSpace
- an LMSpace and a GLSpace
The supported operations are times() and adjoint_times().
Parameters
----------
domain : Domain, tuple of Domain or DomainTuple
The domain of the data that is input by "times" and output by
"adjoint_times".
target : Domain, optional
The target domain of the transform operation.
If omitted, a domain will be chosen automatically.
Whenever the input domain of the transform is an RGSpace, the codomain
(and its parameters) are uniquely determined.
For LMSpace, a GLSpace of sufficient resolution is chosen.
space : int, optional
The index of the domain on which the operator should act
If None, it is set to 0 if domain contains exactly one subdomain.
domain[space] must be a LMSpace.
"""
def __init__(self, domain, target=None, space=None):
super(SHTOperator, self).__init__()
# Initialize domain and target
self._domain = DomainTuple.make(domain)
self._space = utilities.infer_space(self._domain, space)
hspc = self._domain[self._space]
if not isinstance(hspc, LMSpace):
raise TypeError("SHTOperator only works on a LMSpace domain")
if target is None:
target = hspc.get_default_codomain()
self._target = [dom for dom in self._domain]
self._target[self._space] = target
self._target = DomainTuple.make(self._target)
hspc.check_codomain(target)
target.check_codomain(hspc)
from pyHealpix import sharpjob_d
self.lmax = hspc.lmax
self.mmax = hspc.mmax
self.sjob = sharpjob_d()
self.sjob.set_triangular_alm_info(self.lmax, self.mmax)
if isinstance(target, GLSpace):
self.sjob.set_Gauss_geometry(target.nlat, target.nlon)
else:
self.sjob.set_Healpix_geometry(target.nside)
def apply(self, x, mode):
self._check_input(x, mode)
if utilities.iscomplextype(x.dtype):
return (self._apply_spherical(x.real, mode) +
1j*self._apply_spherical(x.imag, mode))
else:
return self._apply_spherical(x, mode)
def _slice_p2h(self, inp):
rr = self.sjob.alm2map_adjoint(inp)
if len(rr) != ((self.mmax+1)*(self.mmax+2))//2 + \
(self.mmax+1)*(self.lmax-self.mmax):
raise ValueError("array length mismatch")
res = np.empty(2*len(rr)-self.lmax-1, dtype=rr[0].real.dtype)
res[0:self.lmax+1] = rr[0:self.lmax+1].real
res[self.lmax+1::2] = np.sqrt(2)*rr[self.lmax+1:].real
res[self.lmax+2::2] = np.sqrt(2)*rr[self.lmax+1:].imag
return res/np.sqrt(np.pi*4)
def _slice_h2p(self, inp):
res = np.empty((len(inp)+self.lmax+1)//2, dtype=(inp[0]*1j).dtype)
if len(res) != ((self.mmax+1)*(self.mmax+2))//2 + \
(self.mmax+1)*(self.lmax-self.mmax):
raise ValueError("array length mismatch")
res[0:self.lmax+1] = inp[0:self.lmax+1]
res[self.lmax+1:] = np.sqrt(0.5)*(inp[self.lmax+1::2] +
1j*inp[self.lmax+2::2])
res = self.sjob.alm2map(res)
return res/np.sqrt(np.pi*4)
def _apply_spherical(self, x, mode):
axes = x.domain.axes[self._space]
axis = axes[0]
tval = x.val
if dobj.distaxis(tval) == axis:
tval = dobj.redistribute(tval, nodist=(axis,))
distaxis = dobj.distaxis(tval)
p2h = not x.domain[self._space].harmonic
tdom = self._tgt(mode)
func = self._slice_p2h if p2h else self._slice_h2p
idat = dobj.local_data(tval)
odat = np.empty(dobj.local_shape(tdom.shape, distaxis=distaxis),
dtype=x.dtype)
for slice in utilities.get_slice_list(idat.shape, axes):
odat[slice] = func(idat[slice])
odat = dobj.from_local_data(tdom.shape, odat, distaxis)
if distaxis != dobj.distaxis(x.val):
odat = dobj.redistribute(odat, dist=dobj.distaxis(x.val))
return Field(tdom, odat)
@property
def domain(self):
return self._domain
@property
def target(self):
return self._target
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment