Skip to content
Snippets Groups Projects
Commit e6c74a63 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

cleanup

parent 67d2098c
No related branches found
No related tags found
No related merge requests found
...@@ -21,7 +21,7 @@ from .multi_field import MultiField ...@@ -21,7 +21,7 @@ from .multi_field import MultiField
from .operators.operator import Operator from .operators.operator import Operator
from .operators.central_zero_padder import CentralZeroPadder from .operators.central_zero_padder import CentralZeroPadder
from .operators.diagonal_operator import DiagonalOperator from .operators.diagonal_operator import DiagonalOperator
from .operators.dof_distributor import DOFDistributor from .operators.distributors import DOFDistributor, PowerDistributor
from .operators.domain_distributor import DomainDistributor from .operators.domain_distributor import DomainDistributor
from .operators.endomorphic_operator import EndomorphicOperator from .operators.endomorphic_operator import EndomorphicOperator
from .operators.exp_transform import ExpTransform from .operators.exp_transform import ExpTransform
...@@ -33,7 +33,6 @@ from .operators.inversion_enabler import InversionEnabler ...@@ -33,7 +33,6 @@ from .operators.inversion_enabler import InversionEnabler
from .operators.laplace_operator import LaplaceOperator from .operators.laplace_operator import LaplaceOperator
from .operators.linear_operator import LinearOperator from .operators.linear_operator import LinearOperator
from .operators.mask_operator import MaskOperator from .operators.mask_operator import MaskOperator
from .operators.power_distributor import PowerDistributor
from .operators.qht_operator import QHTOperator from .operators.qht_operator import QHTOperator
from .operators.sampling_enabler import SamplingEnabler from .operators.sampling_enabler import SamplingEnabler
from .operators.sandwich_operator import SandwichOperator from .operators.sandwich_operator import SandwichOperator
......
...@@ -24,7 +24,7 @@ from ..multi_field import MultiField ...@@ -24,7 +24,7 @@ from ..multi_field import MultiField
from ..multi_domain import MultiDomain from ..multi_domain import MultiDomain
from ..operators.domain_distributor import DomainDistributor from ..operators.domain_distributor import DomainDistributor
from ..operators.harmonic_operators import HarmonicTransformOperator from ..operators.harmonic_operators import HarmonicTransformOperator
from ..operators.power_distributor import PowerDistributor from ..operators.distributors import PowerDistributor
from ..operators.operator import Operator from ..operators.operator import Operator
from ..operators.simple_linear_operators import FieldAdapter from ..operators.simple_linear_operators import FieldAdapter
......
...@@ -24,6 +24,7 @@ from .. import dobj ...@@ -24,6 +24,7 @@ from .. import dobj
from ..compat import * from ..compat import *
from ..domain_tuple import DomainTuple from ..domain_tuple import DomainTuple
from ..domains.dof_space import DOFSpace from ..domains.dof_space import DOFSpace
from ..domains.power_space import PowerSpace
from ..field import Field from ..field import Field
from ..utilities import infer_space, special_add_at from ..utilities import infer_space, special_add_at
from .linear_operator import LinearOperator from .linear_operator import LinearOperator
...@@ -117,14 +118,11 @@ class DOFDistributor(LinearOperator): ...@@ -117,14 +118,11 @@ class DOFDistributor(LinearOperator):
oarr = np.zeros(self._hshape, dtype=x.dtype) oarr = np.zeros(self._hshape, dtype=x.dtype)
oarr = special_add_at(oarr, 1, self._dofdex, arr) oarr = special_add_at(oarr, 1, self._dofdex, arr)
if dobj.distaxis(x.val) in x.domain.axes[self._space]: if dobj.distaxis(x.val) in x.domain.axes[self._space]:
oarr = dobj.np_allreduce_sum(oarr).reshape(self._domain.shape) oarr = oarr.reshape(self._domain.shape)
res = Field.from_global_data(self._domain, oarr) res = Field.from_global_data(self._domain, oarr, sum_up=True)
else: else:
oarr = oarr.reshape(dobj.local_shape(self._domain.shape, oarr = oarr.reshape(self._domain.local_shape)
dobj.distaxis(x.val))) res = Field.from_local_data(self._domain, oarr)
res = Field(self._domain,
dobj.from_local_data(self._domain.shape, oarr,
dobj.default_distaxis()))
return res return res
def _times(self, x): def _times(self, x):
...@@ -141,3 +139,37 @@ class DOFDistributor(LinearOperator): ...@@ -141,3 +139,37 @@ class DOFDistributor(LinearOperator):
def apply(self, x, mode): def apply(self, x, mode):
self._check_input(x, mode) self._check_input(x, mode)
return self._times(x) if mode == self.TIMES else self._adjoint_times(x) return self._times(x) if mode == self.TIMES else self._adjoint_times(x)
class PowerDistributor(DOFDistributor):
"""Operator which transforms between a PowerSpace and a harmonic domain.
Parameters
----------
target: Domain, tuple of Domain, or DomainTuple
the total *target* domain of the operator.
power_space: PowerSpace, optional
the input sub-domain on which the operator acts.
If not supplied, a matching PowerSpace with natural binbounds will be
used.
space: int, optional:
The index of the sub-domain on which the operator acts.
Can be omitted if `target` only has one sub-domain.
"""
def __init__(self, target, power_space=None, space=None):
# Initialize domain and target
self._target = DomainTuple.make(target)
self._space = infer_space(self._target, space)
hspace = self._target[self._space]
if not hspace.harmonic:
raise ValueError("Operator requires harmonic target space")
if power_space is None:
power_space = PowerSpace(hspace)
else:
if not isinstance(power_space, PowerSpace):
raise TypeError("power_space argument must be a PowerSpace")
if power_space.harmonic_partner != hspace:
raise ValueError("power_space does not match its partner")
self._init2(power_space.pindex, self._space, power_space)
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import absolute_import, division, print_function
from ..compat import *
from ..domain_tuple import DomainTuple
from ..domains.power_space import PowerSpace
from ..utilities import infer_space
from .dof_distributor import DOFDistributor
class PowerDistributor(DOFDistributor):
"""Operator which transforms between a PowerSpace and a harmonic domain.
Parameters
----------
target: Domain, tuple of Domain, or DomainTuple
the total *target* domain of the operator.
power_space: PowerSpace, optional
the input sub-domain on which the operator acts.
If not supplied, a matching PowerSpace with natural binbounds will be
used.
space: int, optional:
The index of the sub-domain on which the operator acts.
Can be omitted if `target` only has one sub-domain.
"""
def __init__(self, target, power_space=None, space=None):
# Initialize domain and target
self._target = DomainTuple.make(target)
self._space = infer_space(self._target, space)
hspace = self._target[self._space]
if not hspace.harmonic:
raise ValueError("Operator requires harmonic target space")
if power_space is None:
power_space = PowerSpace(hspace)
else:
if not isinstance(power_space, PowerSpace):
raise TypeError("power_space argument must be a PowerSpace")
if power_space.harmonic_partner != hspace:
raise ValueError("power_space does not match its partner")
self._init2(power_space.pindex, self._space, power_space)
...@@ -28,8 +28,8 @@ from .scaling_operator import ScalingOperator ...@@ -28,8 +28,8 @@ from .scaling_operator import ScalingOperator
class SandwichOperator(EndomorphicOperator): class SandwichOperator(EndomorphicOperator):
"""Operator which is equivalent to the expression `bun.adjoint*cheese*bun`. """Operator which is equivalent to the expression
""" `bun.adjoint(cheese(bun))`."""
def __init__(self, bun, cheese, op, _callingfrommake=False): def __init__(self, bun, cheese, op, _callingfrommake=False):
if not _callingfrommake: if not _callingfrommake:
...@@ -54,7 +54,7 @@ class SandwichOperator(EndomorphicOperator): ...@@ -54,7 +54,7 @@ class SandwichOperator(EndomorphicOperator):
if not isinstance(bun, LinearOperator): if not isinstance(bun, LinearOperator):
raise TypeError("bun must be a linear operator") raise TypeError("bun must be a linear operator")
if cheese is not None and not isinstance(cheese, LinearOperator): if cheese is not None and not isinstance(cheese, LinearOperator):
raise TypeError("cheese must be a linear operator") raise TypeError("cheese must be a linear operator or None")
if cheese is None: if cheese is None:
cheese = ScalingOperator(1., bun.target) cheese = ScalingOperator(1., bun.target)
op = bun.adjoint(bun) op = bun.adjoint(bun)
...@@ -70,7 +70,7 @@ class SandwichOperator(EndomorphicOperator): ...@@ -70,7 +70,7 @@ class SandwichOperator(EndomorphicOperator):
return self._op.apply(x, mode) return self._op.apply(x, mode)
def draw_sample(self, from_inverse=False, dtype=np.float64): def draw_sample(self, from_inverse=False, dtype=np.float64):
# Inverse samples from general sandwiches is not possible # Inverse samples from general sandwiches are not possible
if from_inverse: if from_inverse:
if self._bun.capabilities & self._bun.INVERSE_TIMES: if self._bun.capabilities & self._bun.INVERSE_TIMES:
try: try:
......
...@@ -32,7 +32,7 @@ from .operators.block_diagonal_operator import BlockDiagonalOperator ...@@ -32,7 +32,7 @@ from .operators.block_diagonal_operator import BlockDiagonalOperator
from .multi_domain import MultiDomain from .multi_domain import MultiDomain
from .multi_field import MultiField from .multi_field import MultiField
from .operators.diagonal_operator import DiagonalOperator from .operators.diagonal_operator import DiagonalOperator
from .operators.power_distributor import PowerDistributor from .operators.distributors import PowerDistributor
__all__ = ['PS_field', 'power_analyze', 'create_power_operator', __all__ = ['PS_field', 'power_analyze', 'create_power_operator',
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment