Commit 09f7ba45 authored by Martin Reinecke's avatar Martin Reinecke

Merge branch 'misc_op_work' into 'NIFTy_5'

Misc op work

See merge request ift/nifty-dev!88
parents fc9e9bde 3e395a7b
......@@ -106,23 +106,21 @@ class AmplitudeModel(Operator):
from ..operators.symmetrizing_operator import SymmetrizingOperator
h_space = s_space.get_default_codomain()
p_space = PowerSpace(h_space)
self._exp_transform = ExpTransform(p_space, Npixdof)
self._exp_transform = ExpTransform(PowerSpace(h_space), Npixdof)
logk_space = self._exp_transform.domain[0]
qht = QHTOperator(target=logk_space)
dof_space = qht.domain[0]
param_space = UnstructuredDomain(2)
sym = SymmetrizingOperator(logk_space)
phi_mean = np.array([sm, im])
phi_sig = np.array([sv, iv])
self._slope = SlopeOperator(param_space, logk_space, phi_sig)
self._norm_phi_mean = Field.from_global_data(param_space,
self._slope = SlopeOperator(logk_space, phi_sig)
self._norm_phi_mean = Field.from_global_data(self._slope.domain,
phi_mean/phi_sig)
self._domain = MultiDomain.make({keys[0]: dof_space,
keys[1]: param_space})
keys[1]: self._slope.domain})
self._target = self._exp_transform.target
kern = lambda k: _ceps_kernel(dof_space, k, ceps_a, ceps_k)
......
......@@ -133,6 +133,9 @@ class ChainOperator(LinearOperator):
x = op.apply(x, mode)
return x
def __repr__(self):
subs = "\n".join(sub.__repr__() for sub in self._ops)
return "ChainOperator:\n" + utilities.indent(subs)
# def draw_sample(self, from_inverse=False, dtype=np.float64):
# from ..sugar import from_random
......@@ -144,7 +147,3 @@ class ChainOperator(LinearOperator):
# for op in self._ops:
# samp = op.process_sample(samp, from_inverse)
# return samp
def __repr__(self):
subs = "\n".join(sub.__repr__() for sub in self._ops)
return "ChainOperator:\n" + utilities.indent(subs)
......@@ -108,6 +108,7 @@ class ExpTransform(LinearOperator):
v = x.val
ndim = len(self.target.shape)
curshp = list(self._dom(mode).shape)
tgtshp = self._tgt(mode).shape
d0 = self._target.axes[self._space][0]
for d in self._target.axes[self._space]:
idx = (slice(None),) * d
......@@ -117,7 +118,7 @@ class ExpTransform(LinearOperator):
if mode == self.ADJOINT_TIMES:
shp = list(x.shape)
shp[d] = self._tgt(mode).shape[d]
shp[d] = tgtshp[d]
xnew = np.zeros(shp, dtype=x.dtype)
xnew = special_add_at(xnew, d, self._bindex[d-d0], x*(1.-wgt))
xnew = special_add_at(xnew, d, self._bindex[d-d0]+1, x*wgt)
......@@ -125,6 +126,6 @@ class ExpTransform(LinearOperator):
xnew = x[idx + (self._bindex[d-d0],)] * (1.-wgt)
xnew += x[idx + (self._bindex[d-d0]+1,)] * wgt
curshp[d] = self._tgt(mode).shape[d]
curshp[d] = xnew.shape[d]
v = dobj.from_local_data(curshp, xnew, distaxis=dobj.distaxis(v))
return Field(self._tgt(mode), dobj.ensure_default_distributed(v))
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import absolute_import, division, print_function
import numpy as np
......@@ -11,9 +29,10 @@ from .linear_operator import LinearOperator
class FieldZeroPadder(LinearOperator):
def __init__(self, domain, new_shape, space=0):
def __init__(self, domain, new_shape, space=0, central=False):
self._domain = DomainTuple.make(domain)
self._space = utilities.infer_space(self._domain, space)
self._central = central
dom = self._domain[self._space]
if not isinstance(dom, RGSpace):
raise TypeError("RGSpace required")
......@@ -22,7 +41,7 @@ class FieldZeroPadder(LinearOperator):
if len(new_shape) != len(dom.shape):
raise ValueError("Shape mismatch")
if any([a < b for a, b in zip(new_shape, dom.shape)]):
if any([a <= b for a, b in zip(new_shape, dom.shape)]):
raise ValueError("New shape must be larger than old shape")
self._target = list(self._domain)
self._target[self._space] = RGSpace(new_shape, dom.distances)
......@@ -31,25 +50,48 @@ class FieldZeroPadder(LinearOperator):
def apply(self, x, mode):
self._check_input(x, mode)
x = x.val
dax = dobj.distaxis(x)
shp_in = x.shape
shp_out = self._tgt(mode).shape
axbefore = self._target.axes[self._space][0]
axes = self._target.axes[self._space]
if dax in axes:
x = dobj.redistribute(x, nodist=axes)
curax = dobj.distaxis(x)
if mode == self.ADJOINT_TIMES:
newarr = np.empty(dobj.local_shape(shp_out, curax), dtype=x.dtype)
sl = tuple(slice(0, shp_out[axis]) for axis in axes)
newarr[()] = dobj.local_data(x)[(slice(None),)*axbefore + sl]
v = x.val
curshp = list(self._dom(mode).shape)
tgtshp = self._tgt(mode).shape
for d in self._target.axes[self._space]:
idx = (slice(None),) * d
v, x = dobj.ensure_not_distributed(v, (d,))
if mode == self.TIMES:
shp = list(x.shape)
shp[d] = tgtshp[d]
xnew = np.zeros(shp, dtype=x.dtype)
if self._central:
Nyquist = x.shape[d]//2
i1 = idx + (slice(0, Nyquist+1),)
xnew[i1] = x[i1]
i1 = idx + (slice(None, -(Nyquist+1), -1),)
xnew[i1] = x[i1]
# if (x.shape[d] & 1) == 0: # even number of pixels
# print (Nyquist, x.shape[d]-Nyquist)
# i1 = idx+(Nyquist,)
# xnew[i1] *= 0.5
# i1 = idx+(-Nyquist,)
# xnew[i1] *= 0.5
else:
newarr = np.zeros(dobj.local_shape(shp_out, curax), dtype=x.dtype)
sl = tuple(slice(0, shp_in[axis]) for axis in axes)
newarr[(slice(None),)*axbefore + sl] = dobj.local_data(x)
newarr = dobj.from_local_data(shp_out, newarr, distaxis=curax)
if dax in axes:
newarr = dobj.redistribute(newarr, dist=dax)
return Field(self._tgt(mode), val=newarr)
xnew[idx + (slice(0, x.shape[d]),)] = x
else: # ADJOINT_TIMES
if self._central:
shp = list(x.shape)
shp[d] = tgtshp[d]
xnew = np.zeros(shp, dtype=x.dtype)
Nyquist = xnew.shape[d]//2
i1 = idx + (slice(0, Nyquist+1),)
xnew[i1] = x[i1]
i1 = idx + (slice(None, -(Nyquist+1), -1),)
xnew[i1] += x[i1]
# if (xnew.shape[d] & 1) == 0: # even number of pixels
# i1 = idx+(Nyquist,)
# xnew[i1] *= 0.5
else:
xnew = x[idx + (slice(0, tgtshp[d]),)]
curshp[d] = xnew.shape[d]
v = dobj.from_local_data(curshp, xnew, distaxis=dobj.distaxis(v))
return Field(self._tgt(mode), dobj.ensure_default_distributed(v))
......@@ -52,8 +52,7 @@ class QHTOperator(LinearOperator):
raise ValueError("target[space] has to be a LogRGSpace!")
if self._target[self._space].harmonic:
raise TypeError(
"target[space] must be a nonharmonic space")
raise TypeError("target[space] must be a nonharmonic space")
self._domain = [dom for dom in self._target]
self._domain[self._space] = \
......
......@@ -65,6 +65,7 @@ class RegriddingOperator(LinearOperator):
v = x.val
ndim = len(self.target.shape)
curshp = list(self._dom(mode).shape)
tgtshp = self._tgt(mode).shape
d0 = self._target.axes[self._space][0]
for d in self._target.axes[self._space]:
idx = (slice(None),) * d
......@@ -74,7 +75,7 @@ class RegriddingOperator(LinearOperator):
if mode == self.ADJOINT_TIMES:
shp = list(x.shape)
shp[d] = self._tgt(mode).shape[d]
shp[d] = tgtshp[d]
xnew = np.zeros(shp, dtype=x.dtype)
xnew = special_add_at(xnew, d, self._bindex[d-d0], x*(1.-wgt))
xnew = special_add_at(xnew, d, self._bindex[d-d0]+1, x*wgt)
......@@ -82,6 +83,6 @@ class RegriddingOperator(LinearOperator):
xnew = x[idx + (self._bindex[d-d0],)] * (1.-wgt)
xnew += x[idx + (self._bindex[d-d0]+1,)] * wgt
curshp[d] = self._tgt(mode).shape[d]
curshp[d] = xnew.shape[d]
v = dobj.from_local_data(curshp, xnew, distaxis=dobj.distaxis(v))
return Field(self._tgt(mode), dobj.ensure_default_distributed(v))
......@@ -34,31 +34,25 @@ class SlopeOperator(LinearOperator):
This operator creates a field on a LogRGSpace, which is created
according to a slope of given entries, (mean, y-intercept).
The slope mean is the powerlaw of the field in normal-space.
The slope mean is the power law of the field in normal-space.
Parameters
----------
domain : domain or DomainTuple, shape=(2,)
It has to be and UnstructuredDomain.
It has to be an UnstructuredDomain.
The domain of the slope mean and the y-intercept mean.
target : domain or DomainTuple
The output domain has to a LogRGSpace
sigmas : np.array, shape=(2,)
The slope variance and the y-intercept variance.
"""
def __init__(self, domain, target, sigmas):
def __init__(self, target, sigmas):
if not isinstance(target, LogRGSpace):
raise TypeError
if not (isinstance(domain, UnstructuredDomain) and domain.shape == (2,)):
raise TypeError
self._domain = DomainTuple.make(domain)
self._domain = DomainTuple.make(UnstructuredDomain((2,)))
self._target = DomainTuple.make(target)
self._capability = self.TIMES | self.ADJOINT_TIMES
if self.domain[0].shape != (len(self.target[0].shape) + 1,):
raise AssertionError("Shape mismatch!")
self._sigmas = sigmas
self.ndim = len(self.target[0].shape)
self.pos = np.zeros((self.ndim,) + self.target[0].shape)
......
......@@ -74,8 +74,7 @@ class Consistency_Tests(unittest.TestCase):
tmp = ift.ExpTransform(ift.PowerSpace(args[0]), args[1], args[2])
tgt = tmp.domain[0]
sig = np.array([0.3, 0.13])
dom = ift.UnstructuredDomain(2)
op = ift.SlopeOperator(dom, tgt, sig)
op = ift.SlopeOperator(tgt, sig)
ift.extra.consistency_check(op, dtype, dtype)
@expand(product(_h_spaces + _p_spaces + _pow_spaces,
......@@ -202,19 +201,20 @@ class Consistency_Tests(unittest.TestCase):
op = ift.SymmetrizingOperator(dom, space)
ift.extra.consistency_check(op, dtype, dtype)
@expand(product([0, 2], [2, 2.7], [np.float64, np.complex128]))
def testZeroPadder(self, space, factor, dtype):
@expand(product([0, 2], [2, 2.7], [np.float64, np.complex128],
[False, True]))
def testZeroPadder(self, space, factor, dtype, central):
dom = (ift.RGSpace(10), ift.UnstructuredDomain(13), ift.RGSpace(7, 12),
ift.HPSpace(4))
newshape = [factor*l for l in dom[space].shape]
op = ift.FieldZeroPadder(dom, newshape, space)
newshape = [int(factor*l) for l in dom[space].shape]
op = ift.FieldZeroPadder(dom, newshape, space, central)
ift.extra.consistency_check(op, dtype, dtype)
@expand(product([0, 2], [2, 2.7], [np.float64, np.complex128]))
def testZeroPadder2(self, space, factor, dtype):
dom = (ift.RGSpace(10), ift.UnstructuredDomain(13), ift.RGSpace(7, 12),
ift.HPSpace(4))
newshape = [factor*l for l in dom[space].shape]
newshape = [int(factor*l) for l in dom[space].shape]
op = ift.CentralZeroPadder(dom, newshape, space)
ift.extra.consistency_check(op, dtype, dtype)
......@@ -244,7 +244,7 @@ class Consistency_Tests(unittest.TestCase):
@expand([[ift.RGSpace((13, 52, 40)), (4, 6, 25), None],
[ift.RGSpace((128, 128)), (45, 48), 0],
[ift.RGSpace(13), (7,), None],
[(ift.HPSpace(3), ift.RGSpace((12, 24),distances=0.3)),
[(ift.HPSpace(3), ift.RGSpace((12, 24), distances=0.3)),
(12, 12), 1]])
def testRegridding(self, domain, shape, space):
op = ift.RegriddingOperator(domain, shape, space)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment