Commit f97a9502 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'ducktape' into 'NIFTy_5'

FieldAdapter -> ducktape

See merge request ift/nifty-dev!137
parents e7c025ed 53afdd10
......@@ -43,8 +43,8 @@ if __name__ == '__main__':
vol = harmonic_space.scalar_dvol
vol = ift.ScalingOperator(vol**(-0.5), harmonic_space)
correlated_field = ht(vol(power_distributor(A)) *
ift.FieldAdapter(harmonic_space, "xi"))
correlated_field = ht(
vol(power_distributor(A))*ift.ducktape(harmonic_space, None, 'xi'))
# alternatively to the block above one can do:
#correlated_field = ift.CorrelatedField(position_space, A)
......@@ -48,7 +48,7 @@ from .operators.block_diagonal_operator import BlockDiagonalOperator
from .operators.outer_product_operator import OuterProduct
from .operators.simple_linear_operators import (
VdotOperator, ConjugationOperator, Realizer,
FieldAdapter, GeometryRemover, NullOperator)
FieldAdapter, ducktape, GeometryRemover, NullOperator)
from .operators.energy_operators import (
EnergyOperator, GaussianEnergy, PoissonianEnergy, InverseGammaLikelihood,
BernoulliEnergy, Hamiltonian, SampledKullbachLeiblerDivergence)
......@@ -25,7 +25,7 @@ from ..multi_field import MultiField
from ..operators.distributors import PowerDistributor
from ..operators.energy_operators import Hamiltonian, InverseGammaLikelihood
from ..operators.scaling_operator import ScalingOperator
from ..operators.simple_linear_operators import FieldAdapter
from ..operators.simple_linear_operators import ducktape
def make_adjust_variances(a,
......@@ -87,7 +87,7 @@ def do_adjust_variances(position,
h_space = position[xi_key].domain[0]
pd = PowerDistributor(h_space,[0])
a = pd(amplitude_model)
xi = FieldAdapter(h_space, xi_key)
xi = ducktape(None, position.domain, xi_key)
ham = make_adjust_variances(a, xi, position, samples=samples)
......@@ -135,7 +135,6 @@ def AmplitudeModel(s_space, Npixdof, ceps_a, ceps_k, sm, sv, im, iv,
from ..operators.exp_transform import ExpTransform
from ..operators.simple_linear_operators import FieldAdapter
from ..operators.scaling_operator import ScalingOperator
h_space = s_space.get_default_codomain()
......@@ -143,10 +142,9 @@ def AmplitudeModel(s_space, Npixdof, ceps_a, ceps_k, sm, sv, im, iv,
logk_space = et.domain[0]
smooth = CepstrumOperator(logk_space, ceps_a, ceps_k, zero_mode)
smooth = smooth.ducktape(keys[0])
linear = SlopeModel(logk_space, sm, sv, im, iv)
fa_smooth = FieldAdapter(smooth.domain, keys[0])
fa_linear = FieldAdapter(linear.domain, keys[1])
linear = linear.ducktape(keys[1])
fac = ScalingOperator(0.5,
return et((fac(smooth(fa_smooth) + linear(fa_linear))).exp())
return et((fac(smooth + linear)).exp())
......@@ -24,7 +24,7 @@ from ..multi_domain import MultiDomain
from ..operators.contraction_operator import ContractionOperator
from ..operators.distributors import PowerDistributor
from ..operators.harmonic_operators import HarmonicTransformOperator
from ..operators.simple_linear_operators import FieldAdapter
from ..operators.simple_linear_operators import ducktape
from ..operators.scaling_operator import ScalingOperator
......@@ -48,7 +48,7 @@ def CorrelatedField(s_space, amplitude_model, name='xi'):
A = power_distributor(amplitude_model)
vol = h_space.scalar_dvol
vol = ScalingOperator(vol**(-0.5), h_space)
return ht(vol(A)*FieldAdapter(h_space, name))
return ht(vol(A)*ducktape(h_space, None, name))
def MfCorrelatedField(s_space_spatial, s_space_energy, amplitude_model_spatial,
......@@ -77,4 +77,4 @@ def MfCorrelatedField(s_space_spatial, s_space_energy, amplitude_model_spatial,
a_energy = dom_distr_energy(amplitude_model_energy)
a = a_spatial*a_energy
A = pd(a)
return ht(A*FieldAdapter(h_space, name))
return ht(A*ducktape(h_space, None, name))
......@@ -52,8 +52,8 @@ class Linearization(object):
return self._metric
def __getitem__(self, name):
from .operators.simple_linear_operators import FieldAdapter
return[name], FieldAdapter(self.domain[name], name))
from .operators.simple_linear_operators import ducktape
return[name], ducktape(None, self.domain, name))
def __neg__(self):
return, -self._jac,
......@@ -10,6 +10,8 @@ class KL_Energy(Energy):
def __init__(self, position, h, nsamp, constants=[],
constants_samples=None, _samples=None):
super(KL_Energy, self).__init__(position)
if h.domain is not position.domain:
raise TypeError
self._h = h
self._constants = constants
if constants_samples is None:
......@@ -95,6 +95,14 @@ class Operator(NiftyMetaBase()):
return _OpChain.make((self, x))
return self.apply(x)
def ducktape(self, name):
from .simple_linear_operators import ducktape
return self(ducktape(self, None, name))
def ducktape_left(self, name):
from .simple_linear_operators import ducktape
return ducktape(None, self, name)(self)
def __repr__(self):
return self.__class__.__name__
......@@ -63,27 +63,64 @@ class Realizer(EndomorphicOperator):
class FieldAdapter(LinearOperator):
"""Operator which extracts a field out of a MultiField.
"""Operator for conversion between Fields and MultiFields.
target : Domain, tuple of Domain or DomainTuple:
The domain which shall be extracted out of the MultiField.
tgt : Domain, tuple of Domain, DomainTuple, dict or MultiDomain:
If this is a Domain, tuple of Domain or DomainTuple, this will be the
operator's target, and its domain will be a MultiDomain consisting of
its domain with the supplied `name`
If this is a dict or MultiDomain, everything except for `name` will
be stripped out of it, and the result will be the operator's target.
Its domain will then be the DomainTuple corresponding to the single
entry in the operator's domain.
name : String
The key of the MultiField which shall be extracted.
The relevant key of the MultiDomain.
def __init__(self, target, name):
self._target = DomainTuple.make(target)
self._domain = MultiDomain.make({name: self._target})
def __init__(self, tgt, name):
from ..sugar import makeDomain
tmp = makeDomain(tgt)
if isinstance(tmp, DomainTuple):
self._target = tmp
self._domain = MultiDomain.make({name: tmp})
self._domain = tmp[name]
self._target = MultiDomain.make({name: tmp[name]})
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
if isinstance(x, MultiField):
return x.values()[0]
return MultiField(self._domain, (x,))
return MultiField(self._tgt(mode), (x,))
def __repr__(self):
key = self.domain.keys()[0]
return 'FieldAdapter: {}'.format(key)
def ducktape(left, right, name):
from ..sugar import makeDomain
from .operator import Operator
if left is None: # need to infer left from right
if isinstance(right, Operator):
right =
elif right is not None:
right = makeDomain(right)
if isinstance(right, MultiDomain):
left = right[name]
left = MultiDomain.make({name: right})
if isinstance(left, Operator):
left = left.domain
left = makeDomain(left)
return FieldAdapter(left, name)
class GeometryRemover(LinearOperator):
......@@ -62,29 +62,29 @@ class Model_Tests(unittest.TestCase):
lin2 = self.make_linearization(type2, dom2, seed)
dom = ift.MultiDomain.union((dom1, dom2))
model = ift.FieldAdapter(space, "s1")*ift.FieldAdapter(space, "s2")
select_s1 = ift.ducktape(None, dom, "s1")
select_s2 = ift.ducktape(None, dom, "s2")
model = select_s1*select_s2
pos = ift.from_random("normal", dom)
ift.extra.check_value_gradient_consistency(model, pos, ntries=20)
model = ift.FieldAdapter(space, "s1")+ift.FieldAdapter(space, "s2")
model = select_s1+select_s2
pos = ift.from_random("normal", dom)
ift.extra.check_value_gradient_consistency(model, pos, ntries=20)
model = ift.FieldAdapter(space, "s1").scale(3.)
model = select_s1.scale(3.)
pos = ift.from_random("normal", dom1)
ift.extra.check_value_gradient_consistency(model, pos, ntries=20)
model = ift.ScalingOperator(2.456, space)(
ift.FieldAdapter(space, "s1")*ift.FieldAdapter(space, "s2"))
model = ift.ScalingOperator(2.456, space)(select_s1*select_s2)
pos = ift.from_random("normal", dom)
ift.extra.check_value_gradient_consistency(model, pos, ntries=20)
model = ift.positive_tanh(ift.ScalingOperator(2.456, space)(
ift.FieldAdapter(space, "s1")*ift.FieldAdapter(space, "s2")))
pos = ift.from_random("normal", dom)
ift.extra.check_value_gradient_consistency(model, pos, ntries=20)
pos = ift.from_random("normal", dom)
model = ift.OuterProduct(pos['s1'], ift.makeDomain(space))
ift.extra.check_value_gradient_consistency(model, pos['s2'], ntries=20)
if isinstance(space, ift.RGSpace):
model = ift.FFTOperator(space)(
ift.FieldAdapter(space, "s1")*ift.FieldAdapter(space, "s2"))
model = ift.FFTOperator(space)(select_s1*select_s2)
pos = ift.from_random("normal", dom)
ift.extra.check_value_gradient_consistency(model, pos, ntries=20)
......@@ -59,10 +59,10 @@ class Consistency_Tests(unittest.TestCase):
ift.extra.consistency_check(op, dtype, dtype)
def testLinearInterpolator(self):
sp = ift.RGSpace((10,8), distances=(0.1, 3.5))
sp = ift.RGSpace((10, 8), distances=(0.1, 3.5))
pos = np.random.rand(2, 23)
pos[0,:] *= 0.9
pos[1,:] *= 7*3.5
pos[0, :] *= 0.9
pos[1, :] *= 7*3.5
op = ift.LinearInterpolator(sp, pos)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment