Commit 57006538 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'operator_work' into 'NIFTy_5'

Operator work

See merge request ift/nifty-dev!50
parents de0a55da 59a411bb
......@@ -114,4 +114,3 @@ if __name__ == '__main__':
title='Data', name='data.png')
ift.plot(HT(m), title='Reconstruction', name='reconstruction.png')
ift.plot(mask_to_nan(mask, HT(m-MOCK_SIGNAL)), name='residuals.png')
......@@ -33,6 +33,7 @@ from .operators.endomorphic_operator import EndomorphicOperator
from .operators.exp_transform import ExpTransform
from .operators.fft_operator import FFTOperator
from .operators.fft_smoothing_operator import FFTSmoothingOperator
from .operators.field_zero_padder import FieldZeroPadder
from .operators.geometry_remover import GeometryRemover
from .operators.harmonic_transform_operator import HarmonicTransformOperator
from .operators.inversion_enabler import InversionEnabler
......
......@@ -52,7 +52,7 @@ class DomainTuple(object):
nax = len(thing.shape)
res[idx] = tuple(range(i, i+nax))
i += nax
return res
return tuple(res)
@staticmethod
def make(domain):
......
......@@ -44,6 +44,7 @@ def make_correlated_field(s_space, amplitude_model):
ht = FFTOperator(h_space, s_space)
p_space = amplitude_model.value.domain[0]
power_distributor = PowerDistributor(h_space, p_space)
# FIXME Remove tau and phi stuff from here. Should not be necessary
position = MultiField.from_dict({
'xi': Field.from_random('normal', h_space),
'tau': amplitude_model.position['tau'],
......
......@@ -37,9 +37,8 @@ class Constant(Model):
-----
Since there is no model-function associated:
- Position has no influence on value.
- There is no Jacobian.
- The Jacobian is a null matrix.
"""
# TODO Remove position
def __init__(self, position, constant):
super(Constant, self).__init__(position)
self._constant = constant
......
......@@ -242,7 +242,7 @@ for op in ["__sub__", "__rsub__",
if self._domain is not other._domain:
raise ValueError("domain mismatch")
val = tuple(getattr(v1, op)(v2)
for v1, v2 in zip (self._val, other._val))
for v1, v2 in zip(self._val, other._val))
else:
val = tuple(getattr(v1, op)(other) for v1 in self._val)
return MultiField(self._domain, val)
......
......@@ -30,17 +30,19 @@ from .linear_operator import LinearOperator
class ExpTransform(LinearOperator):
def __init__(self, target, dof):
if not ((isinstance(target, RGSpace) and target.harmonic) or
isinstance(target, PowerSpace)):
def __init__(self, target, dof, space=0):
self._target = DomainTuple.make(target)
self._space = int(space)
tgt = self._target[self._space]
if not ((isinstance(tgt, RGSpace) and tgt.harmonic) or
isinstance(tgt, PowerSpace)):
raise ValueError(
"Target must be a harmonic RGSpace or a power space.")
if np.isscalar(dof):
dof = np.full(len(target.shape), int(dof), dtype=np.int)
dof = np.full(len(tgt.shape), int(dof), dtype=np.int)
dof = np.array(dof)
ndim = len(target.shape)
ndim = len(tgt.shape)
t_mins = np.empty(ndim)
bindistances = np.empty(ndim)
......@@ -48,12 +50,12 @@ class ExpTransform(LinearOperator):
self._frac = [None] * ndim
for i in range(ndim):
if isinstance(target, RGSpace):
rng = np.arange(target.shape[i])
tmp = np.minimum(rng, target.shape[i]+1-rng)
k_array = tmp * target.distances[i]
if isinstance(tgt, RGSpace):
rng = np.arange(tgt.shape[i])
tmp = np.minimum(rng, tgt.shape[i]+1-rng)
k_array = tmp * tgt.distances[i]
else:
k_array = target.k_lengths
k_array = tgt.k_lengths
# avoid taking log of first entry
log_k_array = np.log(k_array[1:])
......@@ -77,8 +79,9 @@ class ExpTransform(LinearOperator):
from ..domains.log_rg_space import LogRGSpace
log_space = LogRGSpace(2*dof+1, bindistances,
t_mins, harmonic=False)
self._target = DomainTuple.make(target)
self._domain = DomainTuple.make(log_space)
self._domain = [dom for dom in self._target]
self._domain[self._space] = log_space
self._domain = DomainTuple.make(self._domain)
@property
def domain(self):
......@@ -94,9 +97,10 @@ class ExpTransform(LinearOperator):
ax = dobj.distaxis(x)
ndim = len(self.target.shape)
curshp = list(self._dom(mode).shape)
for d in range(ndim):
d0 = self._target.axes[self._space][0]
for d in self._target.axes[self._space]:
idx = (slice(None,),) * d
wgt = self._frac[d].reshape((1,)*d + (-1,) + (1,)*(ndim-d-1))
wgt = self._frac[d-d0].reshape((1,)*d + (-1,) + (1,)*(ndim-d-1))
if d == ax:
x = dobj.redistribute(x, nodist=(ax,))
......@@ -107,11 +111,11 @@ class ExpTransform(LinearOperator):
shp = list(x.shape)
shp[d] = self._tgt(mode).shape[d]
xnew = np.zeros(shp, dtype=x.dtype)
np.add.at(xnew, idx + (self._bindex[d],), x * (1.-wgt))
np.add.at(xnew, idx + (self._bindex[d]+1,), x * wgt)
np.add.at(xnew, idx + (self._bindex[d-d0],), x * (1.-wgt))
np.add.at(xnew, idx + (self._bindex[d-d0]+1,), x * wgt)
else: # TIMES
xnew = x[idx + (self._bindex[d],)] * (1.-wgt)
xnew += x[idx + (self._bindex[d]+1,)] * wgt
xnew = x[idx + (self._bindex[d-d0],)] * (1.-wgt)
xnew += x[idx + (self._bindex[d-d0]+1,)] * wgt
curshp[d] = self._tgt(mode).shape[d]
x = dobj.from_local_data(curshp, xnew, distaxis=curax)
......
from __future__ import absolute_import, division, print_function
import numpy as np
from .. import dobj
from ..compat import *
from ..domain_tuple import DomainTuple
from ..domains.rg_space import RGSpace
from ..field import Field
from .linear_operator import LinearOperator
class FieldZeroPadder(LinearOperator):
def __init__(self, domain, factor, space=0):
super(FieldZeroPadder, self).__init__()
self._domain = DomainTuple.make(domain)
self._space = int(space)
dom = self._domain[self._space]
if not isinstance(dom, RGSpace):
raise TypeError("RGSpace required")
if not len(dom.shape) == 1:
raise TypeError("RGSpace must be one-dimensional")
if dom.harmonic:
raise TypeError("RGSpace must not be harmonic")
tgt = RGSpace((int(factor*dom.shape[0]),), dom.distances)
self._target = list(self._domain)
self._target[self._space] = tgt
self._target = DomainTuple.make(self._target)
@property
def domain(self):
return self._domain
@property
def target(self):
return self._target
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
x = x.val
dax = dobj.distaxis(x)
shp_in = x.shape
shp_out = self._tgt(mode).shape
ax = self._target.axes[self._space][0]
if dax == ax:
x = dobj.redistribute(x, nodist=(ax,))
curax = dobj.distaxis(x)
if mode == self.ADJOINT_TIMES:
newarr = np.empty(dobj.local_shape(shp_out), dtype=x.dtype)
newarr[()] = dobj.local_data(x)[(slice(None),)*ax +
(slice(0, shp_out[ax]),)]
else:
newarr = np.zeros(dobj.local_shape(shp_out), dtype=x.dtype)
newarr[(slice(None),)*ax +
(slice(0, shp_in[ax]),)] = dobj.local_data(x)
newarr = dobj.from_local_data(shp_out, newarr, distaxis=curax)
if dax == ax:
newarr = dobj.redistribute(newarr, dist=ax)
return Field(self._tgt(mode), val=newarr)
......@@ -88,6 +88,7 @@ class LinearOperator(NiftyMetaBase()):
@abc.abstractproperty
def domain(self):
# FIXME Adopt documentation to MultiDomains
"""DomainTuple : the operator's input domain
The domain on which the Operator's input Field lives."""
......
......@@ -41,7 +41,7 @@ class NullOperator(LinearOperator):
@staticmethod
def _nullfield(dom):
if isinstance (dom, DomainTuple):
if isinstance(dom, DomainTuple):
return Field.full(dom, 0)
else:
return MultiField(dom, (None,)*len(dom))
......
......@@ -36,26 +36,34 @@ class QHTOperator(LinearOperator):
Parameters
----------
domain : LogRGSpace
The domain needs to be a LogRGSpace.
domain : domain, tuple of domains or DomainTuple
The full input domain
space : int
The index of the domain on which the operator acts.
domain[space] must be a harmonic LogRGSpace.
target : LogRGSpace
The target needs to be a LogRGSpace.
The target codomain of domain[space]
Must be a nonharmonic LogRGSpace.
"""
def __init__(self, domain, target):
if not domain.harmonic:
raise TypeError(
"HarmonicTransformOperator only works on a harmonic space")
if target.harmonic:
raise TypeError("Target is not a codomain of domain")
def __init__(self, domain, target, space=0):
self._domain = DomainTuple.make(domain)
self._space = int(space)
from ..domains.log_rg_space import LogRGSpace
if not isinstance(domain, LogRGSpace):
raise ValueError("Domain has to be a LogRGSpace!")
if not isinstance(self._domain[self._space], LogRGSpace):
raise ValueError("Domain[space] has to be a LogRGSpace!")
if not isinstance(target, LogRGSpace):
raise ValueError("Target has to be a LogRGSpace!")
self._domain = DomainTuple.make(domain)
self._target = DomainTuple.make(target)
if not self._domain[self._space].harmonic:
raise TypeError(
"HarmonicTransformOperator only works on a harmonic space")
if target.harmonic:
raise TypeError("Target is not a codomain of domain")
self._target = [dom for dom in self._domain]
self._target[self._space] = target
self._target = DomainTuple.make(self._target)
@property
def domain(self):
......@@ -67,9 +75,10 @@ class QHTOperator(LinearOperator):
def apply(self, x, mode):
self._check_input(x, mode)
x = x.val * self.domain[0].scalar_dvol()
n = len(self.domain[0].shape)
rng = range(n) if mode == self.TIMES else reversed(range(n))
dom = self._domain[self._space]
x = x.val * dom.scalar_dvol()
n = self._domain.axes[self._space]
rng = n if mode == self.TIMES else reversed(n)
ax = dobj.distaxis(x)
globshape = x.shape
for i in rng:
......
......@@ -27,11 +27,12 @@ from .endomorphic_operator import EndomorphicOperator
class SymmetrizingOperator(EndomorphicOperator):
def __init__(self, domain):
if not (isinstance(domain, LogRGSpace) and not domain.harmonic):
raise TypeError
def __init__(self, domain, space=0):
self._domain = DomainTuple.make(domain)
self._ndim = len(self.domain.shape)
self._space = int(space)
dom = self._domain[self._space]
if not (isinstance(dom, LogRGSpace) and not dom.harmonic):
raise TypeError
@property
def domain(self):
......@@ -42,7 +43,7 @@ class SymmetrizingOperator(EndomorphicOperator):
tmp = x.val.copy()
ax = dobj.distaxis(tmp)
globshape = tmp.shape
for i in range(self._ndim):
for i in self._domain.axes[self._space]:
lead = (slice(None),)*i
if i == ax:
tmp = dobj.redistribute(tmp, nodist=(ax,))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment