Commit d10f7a5a authored by Philipp Arras's avatar Philipp Arras

Fixups

parent 357e20ac
def make_smooth_sky_model(s_space, amplitude_model):
'''
Method for construction of correlated sky model
......@@ -43,9 +41,10 @@ def make_smooth_mf_sky_model(s_space_spatial, s_space_energy,
amplitude_model : model for correlation structure
'''
from .. import (DomainTuple, Field, HarmonicTransformOperator, MultiField,
PointwiseExponential, PowerDistributor, Variable)
from ..linear_operators import DomainDistributor
from .. import (DomainTuple, Field, MultiField,
PointwiseExponential, Variable)
from ..operators import (DomainDistributor, PowerDistributor,
HarmonicTransformOperator)
h_space_spatial = s_space_spatial.get_default_codomain()
h_space_energy = s_space_energy.get_default_codomain()
h_space = DomainTuple.make((h_space_spatial, h_space_energy))
......
from .diagonal_operator import DiagonalOperator
from .dof_distributor import DOFDistributor
from .domain_distributor import DomainDistributor
from .endomorphic_operator import EndomorphicOperator
from .exp_transform import ExpTransform
from .fft_operator import FFTOperator
......@@ -27,4 +28,4 @@ __all__ = ["LinearOperator", "EndomorphicOperator", "ScalingOperator",
"InversionEnabler", "SandwichOperator", "SamplingEnabler",
"DOFDistributor", "SelectionOperator", "MultiAdaptor",
"ExpTransform", "SymmetrizingOperator", "QHTOperator",
"SlopeOperator"]
"SlopeOperator", "DomainDistributor"]
import numpy as np
from ..field import Field
from .. import dobj
from ..domain_tuple import DomainTuple
from .linear_operator import LinearOperator
if dobj.ntask > 1:
raise NotImplementedError('UpProj class does not support MPI.')
class DomainDistributor(LinearOperator):
def __init__(self, target, axis):
assert len(target) == 2
assert axis in [0, 1]
if axis == 0:
domain = target[1]
self._size = target[0].size
else:
domain = target[0]
self._size = target[1].size
self._axis = axis
self._domain = DomainTuple.make(domain)
self._target = DomainTuple.make(target)
@property
def domain(self):
return self._domain
@property
def target(self):
return self._target
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
x = x.local_data
otherDirection = np.ones(self._size)
if self._axis == 0:
res = np.outer(otherDirection, x)
else:
res = np.outer(x, otherDirection)
res = res.reshape(dobj.local_shape(self.target.shape))
return Field.from_local_data(self.target, res)
else:
if self._axis == 0:
x = x.local_data.reshape(self._size, -1)
res = np.sum(x, axis=0)
else:
x = x.local_data.reshape(-1, self._size)
res = np.sum(x, axis=1)
res = res.reshape(dobj.local_shape(self.domain.shape))
return Field.from_local_data(self.domain, res)
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment