Commit e6c74a63 authored by Martin Reinecke's avatar Martin Reinecke

cleanup

parent 67d2098c
......@@ -21,7 +21,7 @@ from .multi_field import MultiField
from .operators.operator import Operator
from .operators.central_zero_padder import CentralZeroPadder
from .operators.diagonal_operator import DiagonalOperator
from .operators.dof_distributor import DOFDistributor
from .operators.distributors import DOFDistributor, PowerDistributor
from .operators.domain_distributor import DomainDistributor
from .operators.endomorphic_operator import EndomorphicOperator
from .operators.exp_transform import ExpTransform
......@@ -33,7 +33,6 @@ from .operators.inversion_enabler import InversionEnabler
from .operators.laplace_operator import LaplaceOperator
from .operators.linear_operator import LinearOperator
from .operators.mask_operator import MaskOperator
from .operators.power_distributor import PowerDistributor
from .operators.qht_operator import QHTOperator
from .operators.sampling_enabler import SamplingEnabler
from .operators.sandwich_operator import SandwichOperator
......
......@@ -24,7 +24,7 @@ from ..multi_field import MultiField
from ..multi_domain import MultiDomain
from ..operators.domain_distributor import DomainDistributor
from ..operators.harmonic_operators import HarmonicTransformOperator
from ..operators.power_distributor import PowerDistributor
from ..operators.distributors import PowerDistributor
from ..operators.operator import Operator
from ..operators.simple_linear_operators import FieldAdapter
......
......@@ -24,6 +24,7 @@ from .. import dobj
from ..compat import *
from ..domain_tuple import DomainTuple
from ..domains.dof_space import DOFSpace
from ..domains.power_space import PowerSpace
from ..field import Field
from ..utilities import infer_space, special_add_at
from .linear_operator import LinearOperator
......@@ -117,14 +118,11 @@ class DOFDistributor(LinearOperator):
oarr = np.zeros(self._hshape, dtype=x.dtype)
oarr = special_add_at(oarr, 1, self._dofdex, arr)
if dobj.distaxis(x.val) in x.domain.axes[self._space]:
oarr = dobj.np_allreduce_sum(oarr).reshape(self._domain.shape)
res = Field.from_global_data(self._domain, oarr)
oarr = oarr.reshape(self._domain.shape)
res = Field.from_global_data(self._domain, oarr, sum_up=True)
else:
oarr = oarr.reshape(dobj.local_shape(self._domain.shape,
dobj.distaxis(x.val)))
res = Field(self._domain,
dobj.from_local_data(self._domain.shape, oarr,
dobj.default_distaxis()))
oarr = oarr.reshape(self._domain.local_shape)
res = Field.from_local_data(self._domain, oarr)
return res
def _times(self, x):
......@@ -141,3 +139,37 @@ class DOFDistributor(LinearOperator):
def apply(self, x, mode):
self._check_input(x, mode)
return self._times(x) if mode == self.TIMES else self._adjoint_times(x)
class PowerDistributor(DOFDistributor):
"""Operator which transforms between a PowerSpace and a harmonic domain.
Parameters
----------
target: Domain, tuple of Domain, or DomainTuple
the total *target* domain of the operator.
power_space: PowerSpace, optional
the input sub-domain on which the operator acts.
If not supplied, a matching PowerSpace with natural binbounds will be
used.
space: int, optional:
The index of the sub-domain on which the operator acts.
Can be omitted if `target` only has one sub-domain.
"""
def __init__(self, target, power_space=None, space=None):
# Initialize domain and target
self._target = DomainTuple.make(target)
self._space = infer_space(self._target, space)
hspace = self._target[self._space]
if not hspace.harmonic:
raise ValueError("Operator requires harmonic target space")
if power_space is None:
power_space = PowerSpace(hspace)
else:
if not isinstance(power_space, PowerSpace):
raise TypeError("power_space argument must be a PowerSpace")
if power_space.harmonic_partner != hspace:
raise ValueError("power_space does not match its partner")
self._init2(power_space.pindex, self._space, power_space)
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import absolute_import, division, print_function
from ..compat import *
from ..domain_tuple import DomainTuple
from ..domains.power_space import PowerSpace
from ..utilities import infer_space
from .dof_distributor import DOFDistributor
class PowerDistributor(DOFDistributor):
"""Operator which transforms between a PowerSpace and a harmonic domain.
Parameters
----------
target: Domain, tuple of Domain, or DomainTuple
the total *target* domain of the operator.
power_space: PowerSpace, optional
the input sub-domain on which the operator acts.
If not supplied, a matching PowerSpace with natural binbounds will be
used.
space: int, optional:
The index of the sub-domain on which the operator acts.
Can be omitted if `target` only has one sub-domain.
"""
def __init__(self, target, power_space=None, space=None):
# Initialize domain and target
self._target = DomainTuple.make(target)
self._space = infer_space(self._target, space)
hspace = self._target[self._space]
if not hspace.harmonic:
raise ValueError("Operator requires harmonic target space")
if power_space is None:
power_space = PowerSpace(hspace)
else:
if not isinstance(power_space, PowerSpace):
raise TypeError("power_space argument must be a PowerSpace")
if power_space.harmonic_partner != hspace:
raise ValueError("power_space does not match its partner")
self._init2(power_space.pindex, self._space, power_space)
......@@ -28,8 +28,8 @@ from .scaling_operator import ScalingOperator
class SandwichOperator(EndomorphicOperator):
"""Operator which is equivalent to the expression `bun.adjoint*cheese*bun`.
"""
"""Operator which is equivalent to the expression
`bun.adjoint(cheese(bun))`."""
def __init__(self, bun, cheese, op, _callingfrommake=False):
if not _callingfrommake:
......@@ -54,7 +54,7 @@ class SandwichOperator(EndomorphicOperator):
if not isinstance(bun, LinearOperator):
raise TypeError("bun must be a linear operator")
if cheese is not None and not isinstance(cheese, LinearOperator):
raise TypeError("cheese must be a linear operator")
raise TypeError("cheese must be a linear operator or None")
if cheese is None:
cheese = ScalingOperator(1., bun.target)
op = bun.adjoint(bun)
......@@ -70,7 +70,7 @@ class SandwichOperator(EndomorphicOperator):
return self._op.apply(x, mode)
def draw_sample(self, from_inverse=False, dtype=np.float64):
# Inverse samples from general sandwiches is not possible
# Inverse samples from general sandwiches are not possible
if from_inverse:
if self._bun.capabilities & self._bun.INVERSE_TIMES:
try:
......
......@@ -32,7 +32,7 @@ from .operators.block_diagonal_operator import BlockDiagonalOperator
from .multi_domain import MultiDomain
from .multi_field import MultiField
from .operators.diagonal_operator import DiagonalOperator
from .operators.power_distributor import PowerDistributor
from .operators.distributors import PowerDistributor
__all__ = ['PS_field', 'power_analyze', 'create_power_operator',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment