Commit 5de1cebf authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'NIFTy_5' into energy_cleanup

parents f5efc36b 52593191
...@@ -59,6 +59,8 @@ class data_object(object): ...@@ -59,6 +59,8 @@ class data_object(object):
distaxis = -1 distaxis = -1
self._distaxis = distaxis self._distaxis = distaxis
self._data = data self._data = data
if local_shape(self._shape, self._distaxis) != self._data.shape:
raise ValueError("shape mismatch")
# def _sanity_checks(self): # def _sanity_checks(self):
# # check whether the distaxis is consistent # # check whether the distaxis is consistent
......
...@@ -79,7 +79,7 @@ def distaxis(arr): ...@@ -79,7 +79,7 @@ def distaxis(arr):
def from_local_data(shape, arr, distaxis=-1): def from_local_data(shape, arr, distaxis=-1):
if shape != arr.shape: if tuple(shape) != arr.shape:
raise ValueError raise ValueError
return arr return arr
......
...@@ -6,3 +6,4 @@ from .point_sources import PointSources ...@@ -6,3 +6,4 @@ from .point_sources import PointSources
from .poissonian_energy import PoissonianEnergy from .poissonian_energy import PoissonianEnergy
from .wiener_filter_curvature import WienerFilterCurvature from .wiener_filter_curvature import WienerFilterCurvature
from .wiener_filter_energy import WienerFilterEnergy from .wiener_filter_energy import WienerFilterEnergy
from .smooth_sky import make_correlated_field, make_mf_correlated_field
def make_correlated_field(s_space, amplitude_model): def make_correlated_field(s_space, amplitude_model):
''' '''
Method for construction of correlated sky model Method for construction of correlated fields
Parameters Parameters
---------- ----------
s_space : domain of sky model s_space : Field domain
amplitude_model : model for correlation structure amplitude_model : model for correlation structure
''' '''
...@@ -30,16 +30,10 @@ def make_correlated_field(s_space, amplitude_model): ...@@ -30,16 +30,10 @@ def make_correlated_field(s_space, amplitude_model):
return correlated_field, internals return correlated_field, internals
def make_smooth_mf_sky_model(s_space_spatial, s_space_energy, def make_mf_correlated_field(s_space_spatial, s_space_energy,
amplitude_model_spatial, amplitude_model_energy): amplitude_model_spatial, amplitude_model_energy):
''' '''
Method for construction of correlated sky model Method for construction of correlated multi-frequency fields
Parameters
----------
s_space : domain of sky model
amplitude_model : model for correlation structure
''' '''
from .. import (DomainTuple, Field, MultiField, from .. import (DomainTuple, Field, MultiField,
PointwiseExponential, Variable) PointwiseExponential, Variable)
......
...@@ -4,6 +4,7 @@ from ..domain_tuple import DomainTuple ...@@ -4,6 +4,7 @@ from ..domain_tuple import DomainTuple
from ..domains import PowerSpace, RGSpace from ..domains import PowerSpace, RGSpace
from ..field import Field from ..field import Field
from .linear_operator import LinearOperator from .linear_operator import LinearOperator
from .. import dobj
class ExpTransform(LinearOperator): class ExpTransform(LinearOperator):
...@@ -27,7 +28,7 @@ class ExpTransform(LinearOperator): ...@@ -27,7 +28,7 @@ class ExpTransform(LinearOperator):
for i in range(ndim): for i in range(ndim):
if isinstance(target, RGSpace): if isinstance(target, RGSpace):
rng = np.arange(target.shape[i]) rng = np.arange(target.shape[i])
tmp = np.minimum(rng, target.shape[i] + 1 - rng) tmp = np.minimum(rng, target.shape[i]+1-rng)
k_array = tmp * target.distances[i] k_array = tmp * target.distances[i]
else: else:
k_array = target.k_lengths k_array = target.k_lengths
...@@ -42,8 +43,8 @@ class ExpTransform(LinearOperator): ...@@ -42,8 +43,8 @@ class ExpTransform(LinearOperator):
# Save t_min for later # Save t_min for later
t_mins[i] = t_min t_mins[i] = t_min
bindistances[i] = (t_max - t_min) / (dof[i] - 1) bindistances[i] = (t_max-t_min) / (dof[i]-1)
coord = np.append(0., 1. + (log_k_array - t_min) / bindistances[i]) coord = np.append(0., 1. + (log_k_array-t_min) / bindistances[i])
self._bindex[i] = np.floor(coord).astype(int) self._bindex[i] = np.floor(coord).astype(int)
# Interpolated value is computed via # Interpolated value is computed via
...@@ -52,7 +53,7 @@ class ExpTransform(LinearOperator): ...@@ -52,7 +53,7 @@ class ExpTransform(LinearOperator):
self._frac[i] = coord - self._bindex[i] self._frac[i] = coord - self._bindex[i]
from ..domains import LogRGSpace from ..domains import LogRGSpace
log_space = LogRGSpace(2 * dof + 1, bindistances, log_space = LogRGSpace(2*dof+1, bindistances,
t_mins, harmonic=False) t_mins, harmonic=False)
self._target = DomainTuple.make(target) self._target = DomainTuple.make(target)
self._domain = DomainTuple.make(log_space) self._domain = DomainTuple.make(log_space)
...@@ -67,30 +68,34 @@ class ExpTransform(LinearOperator): ...@@ -67,30 +68,34 @@ class ExpTransform(LinearOperator):
def apply(self, x, mode): def apply(self, x, mode):
self._check_input(x, mode) self._check_input(x, mode)
x = x.to_global_data() x = x.val
ax = dobj.distaxis(x)
ndim = len(self.target.shape) ndim = len(self.target.shape)
idx = () curshp = list(self._dom(mode).shape)
for d in range(ndim): for d in range(ndim):
fst_dims = (1,) * d idx = (slice(None,),) * d
lst_dims = (1,) * (ndim - d - 1) wgt = self._frac[d].reshape((1,)*d + (-1,) + (1,)*(ndim-d-1))
wgt = self._frac[d].reshape(fst_dims + (-1,) + lst_dims)
if d == ax:
x = dobj.redistribute(x, nodist=(ax,))
curax = dobj.distaxis(x)
x = dobj.local_data(x)
# ADJOINT_TIMES
if mode == self.ADJOINT_TIMES: if mode == self.ADJOINT_TIMES:
shp = list(x.shape) shp = list(x.shape)
shp[d] = self._tgt(mode).shape[d] shp[d] = self._tgt(mode).shape[d]
xnew = np.zeros(shp, dtype=x.dtype) xnew = np.zeros(shp, dtype=x.dtype)
np.add.at(xnew, idx + (self._bindex[d],), x * (1. - wgt)) np.add.at(xnew, idx + (self._bindex[d],), x * (1.-wgt))
np.add.at(xnew, idx + (self._bindex[d] + 1,), x * wgt) np.add.at(xnew, idx + (self._bindex[d]+1,), x * wgt)
else: # TIMES
# TIMES xnew = x[idx + (self._bindex[d],)] * (1.-wgt)
else: xnew += x[idx + (self._bindex[d]+1,)] * wgt
xnew = x[idx + (self._bindex[d],)] * (1. - wgt)
xnew += x[idx + (self._bindex[d] + 1,)] * wgt curshp[d] = self._tgt(mode).shape[d]
x = dobj.from_local_data(curshp, xnew, distaxis=curax)
x = xnew if d == ax:
idx = (slice(None),) + idx x = dobj.redistribute(x, dist=ax)
return Field.from_global_data(self._tgt(mode), x) return Field(self._tgt(mode), val=x)
@property @property
def capability(self): def capability(self):
......
...@@ -24,6 +24,9 @@ from ..field import Field ...@@ -24,6 +24,9 @@ from ..field import Field
from .linear_operator import LinearOperator from .linear_operator import LinearOperator
# MR FIXME: this needs a redesign to avoid most _global_data() calls
# Possible approach: keep everything living on `domain` distributed and only
# collect the unstructured Fields.
class MaskOperator(LinearOperator): class MaskOperator(LinearOperator):
def __init__(self, mask): def __init__(self, mask):
if not isinstance(mask, Field): if not isinstance(mask, Field):
......
...@@ -35,13 +35,19 @@ class QHTOperator(LinearOperator): ...@@ -35,13 +35,19 @@ class QHTOperator(LinearOperator):
x = x.val * self.domain[0].scalar_dvol() x = x.val * self.domain[0].scalar_dvol()
n = len(self.domain[0].shape) n = len(self.domain[0].shape)
rng = range(n) if mode == self.TIMES else reversed(range(n)) rng = range(n) if mode == self.TIMES else reversed(range(n))
# MR FIXME: this needs to be fixed properly for MPI ax = dobj.distaxis(x)
x = dobj.to_global_data(x) globshape = x.shape
for i in rng: for i in rng:
sl = (slice(None),)*i + (slice(1, None),) sl = (slice(None),)*i + (slice(1, None),)
if i == ax:
x = dobj.redistribute(x, nodist=(ax,))
curax = dobj.distaxis(x)
x = dobj.local_data(x)
x[sl] = hartley(x[sl], axes=(i,)) x[sl] = hartley(x[sl], axes=(i,))
x = dobj.from_local_data(globshape, x, distaxis=curax)
return Field.from_global_data(self._tgt(mode), x) if i == ax:
x = dobj.redistribute(x, dist=ax)
return Field(self._tgt(mode), val=x)
@property @property
def capability(self): def capability(self):
......
...@@ -25,13 +25,8 @@ class SlopeOperator(LinearOperator): ...@@ -25,13 +25,8 @@ class SlopeOperator(LinearOperator):
rng = np.arange(target.shape[i]) rng = np.arange(target.shape[i])
tmp = np.minimum( tmp = np.minimum(
rng, target.shape[i] + 1 - rng) * target.bindistances[i] rng, target.shape[i] + 1 - rng) * target.bindistances[i]
fst_dims = (1,) * i self.pos[i] += tmp.reshape(
lst_dims = (1,) * (self.ndim - i - 1) (1,)*i + (shape[i],) + (1,)*(self.ndim-i-1))
self.pos[i] += tmp.reshape(fst_dims + (shape[i],) + lst_dims)
@property
def sigmas(self):
return self._sigmas
@property @property
def domain(self): def domain(self):
...@@ -47,17 +42,17 @@ class SlopeOperator(LinearOperator): ...@@ -47,17 +42,17 @@ class SlopeOperator(LinearOperator):
# Times # Times
if mode == self.TIMES: if mode == self.TIMES:
inp = x.to_global_data() inp = x.to_global_data()
res = self.sigmas[-1] * inp[-1] res = self._sigmas[-1] * inp[-1]
for i in range(self.ndim): for i in range(self.ndim):
res += self.sigmas[i] * inp[i] * self.pos[i] res += self._sigmas[i] * inp[i] * self.pos[i]
return Field.from_global_data(self.target, res) return Field.from_global_data(self.target, res)
# Adjoint times # Adjoint times
res = np.zeros(self.domain[0].shape) res = np.zeros(self.domain[0].shape)
xglob = x.to_global_data() xglob = x.to_global_data()
res[-1] = np.sum(xglob) * self.sigmas[-1] res[-1] = np.sum(xglob) * self._sigmas[-1]
for i in range(self.ndim): for i in range(self.ndim):
res[i] = np.sum(self.pos[i] * xglob) * self.sigmas[i] res[i] = np.sum(self.pos[i] * xglob) * self._sigmas[i]
return Field.from_global_data(self.domain, res) return Field.from_global_data(self.domain, res)
@property @property
......
from ..domain_tuple import DomainTuple from ..domain_tuple import DomainTuple
from ..field import Field from ..field import Field
from .endomorphic_operator import EndomorphicOperator from .endomorphic_operator import EndomorphicOperator
from .. import dobj
class SymmetrizingOperator(EndomorphicOperator): class SymmetrizingOperator(EndomorphicOperator):
...@@ -14,12 +15,20 @@ class SymmetrizingOperator(EndomorphicOperator): ...@@ -14,12 +15,20 @@ class SymmetrizingOperator(EndomorphicOperator):
def apply(self, x, mode): def apply(self, x, mode):
self._check_input(x, mode) self._check_input(x, mode)
# FIXME Not efficient with MPI tmp = x.copy().val
tmp = x.to_global_data().copy() ax = dobj.distaxis(tmp)
globshape = tmp.shape
for i in range(self._ndim): for i in range(self._ndim):
lead = (slice(None),)*i lead = (slice(None),)*i
if i == ax:
tmp = dobj.redistribute(tmp, nodist=(ax,))
curax = dobj.distaxis(tmp)
tmp = dobj.local_data(tmp)
tmp[lead + (slice(1, None),)] -= tmp[lead + (slice(None, 0, -1),)] tmp[lead + (slice(1, None),)] -= tmp[lead + (slice(None, 0, -1),)]
return Field.from_global_data(self.target, tmp) tmp = dobj.from_local_data(globshape, tmp, distaxis=curax)
if i == ax:
tmp = dobj.redistribute(tmp, dist=ax)
return Field(self.target, val=tmp)
@property @property
def capability(self): def capability(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment