Commit aa58d6ae authored by Martin Reinecke's avatar Martin Reinecke Committed by Philipp Arras

FieldAdapter -> ducktape

parent 8bf014cf
......@@ -43,8 +43,8 @@ if __name__ == '__main__':
vol = harmonic_space.scalar_dvol
vol = ift.ScalingOperator(vol**(-0.5), harmonic_space)
correlated_field = ht(vol(power_distributor(A)) *
ift.FieldAdapter(harmonic_space, "xi"))
correlated_field = ht(
vol(power_distributor(A))*ift.ducktape(harmonic_space, None, 'xi'))
# alternatively to the block above one can do:
#correlated_field = ift.CorrelatedField(position_space, A)
......
......@@ -48,7 +48,7 @@ from .operators.block_diagonal_operator import BlockDiagonalOperator
from .operators.outer_product_operator import OuterProduct
from .operators.simple_linear_operators import (
VdotOperator, ConjugationOperator, Realizer,
FieldAdapter, GeometryRemover, NullOperator)
ducktape, GeometryRemover, NullOperator)
from .operators.energy_operators import (
EnergyOperator, GaussianEnergy, PoissonianEnergy, InverseGammaLikelihood,
BernoulliEnergy, Hamiltonian, SampledKullbachLeiblerDivergence)
......
......@@ -25,7 +25,7 @@ from ..multi_field import MultiField
from ..operators.distributors import PowerDistributor
from ..operators.energy_operators import Hamiltonian, InverseGammaLikelihood
from ..operators.scaling_operator import ScalingOperator
from ..operators.simple_linear_operators import FieldAdapter
from ..operators.simple_linear_operators import ducktape
def make_adjust_variances(a,
......@@ -87,7 +87,7 @@ def do_adjust_variances(position,
h_space = position[xi_key].domain[0]
pd = PowerDistributor(h_space, amplitude_model.target[0])
a = pd(amplitude_model)
xi = FieldAdapter(h_space, xi_key)
xi = ducktape(None, position.domain, xi_key)
ham = make_adjust_variances(a, xi, position, samples=samples)
......
......@@ -135,7 +135,6 @@ def AmplitudeModel(s_space, Npixdof, ceps_a, ceps_k, sm, sv, im, iv,
'''
from ..operators.exp_transform import ExpTransform
from ..operators.simple_linear_operators import FieldAdapter
from ..operators.scaling_operator import ScalingOperator
h_space = s_space.get_default_codomain()
......@@ -143,10 +142,9 @@ def AmplitudeModel(s_space, Npixdof, ceps_a, ceps_k, sm, sv, im, iv,
logk_space = et.domain[0]
smooth = CepstrumOperator(logk_space, ceps_a, ceps_k, zero_mode)
smooth = smooth.ducktape(keys[0])
linear = SlopeModel(logk_space, sm, sv, im, iv)
fa_smooth = FieldAdapter(smooth.domain, keys[0])
fa_linear = FieldAdapter(linear.domain, keys[1])
linear = linear.ducktape(keys[1])
fac = ScalingOperator(0.5, smooth.target)
return et((fac(smooth(fa_smooth) + linear(fa_linear))).exp())
return et((fac(smooth + linear)).exp())
......@@ -24,7 +24,7 @@ from ..multi_domain import MultiDomain
from ..operators.contraction_operator import ContractionOperator
from ..operators.distributors import PowerDistributor
from ..operators.harmonic_operators import HarmonicTransformOperator
from ..operators.simple_linear_operators import FieldAdapter
from ..operators.simple_linear_operators import ducktape
from ..operators.scaling_operator import ScalingOperator
......@@ -48,7 +48,7 @@ def CorrelatedField(s_space, amplitude_model, name='xi'):
A = power_distributor(amplitude_model)
vol = h_space.scalar_dvol
vol = ScalingOperator(vol**(-0.5), h_space)
return ht(vol(A)*FieldAdapter(h_space, name))
return ht(vol(A)*ducktape(h_space, None, name))
def MfCorrelatedField(s_space_spatial, s_space_energy, amplitude_model_spatial,
......@@ -77,4 +77,4 @@ def MfCorrelatedField(s_space_spatial, s_space_energy, amplitude_model_spatial,
a_energy = dom_distr_energy(amplitude_model_energy)
a = a_spatial*a_energy
A = pd(a)
return ht(A*FieldAdapter(h_space, name))
return ht(A*ducktape(h_space, None, name))
......@@ -52,8 +52,8 @@ class Linearization(object):
return self._metric
def __getitem__(self, name):
from .operators.simple_linear_operators import FieldAdapter
return self.new(self._val[name], FieldAdapter(self.domain[name], name))
from .operators.simple_linear_operators import ducktape
return self.new(self._val[name], ducktape(None, self.domain, name))
def __neg__(self):
return self.new(-self._val, -self._jac,
......
......@@ -95,6 +95,14 @@ class Operator(NiftyMetaBase()):
return _OpChain.make((self, x))
return self.apply(x)
def ducktape(self, name):
from .simple_linear_operators import ducktape
return self(ducktape(self, None, name))
def ducktape_left(self, name):
from .simple_linear_operators import ducktape
return ducktape(None, self, name)(self)
def __repr__(self):
return self.__class__.__name__
......
......@@ -62,28 +62,65 @@ class Realizer(EndomorphicOperator):
return x.real
class FieldAdapter(LinearOperator):
"""Operator which extracts a field out of a MultiField.
class _FieldAdapter(LinearOperator):
"""Operator for conversion between Fields and MultiFields.
Parameters
----------
target : Domain, tuple of Domain or DomainTuple:
The domain which shall be extracted out of the MultiField.
tgt : Domain, tuple of Domain, DomainTuple, dict or MultiDomain:
If this is a Domain, tuple of Domain or DomainTuple, this will be the
operator's target, and its domain will be a MultiDomain consisting of
its domain with the supplied `name`
If this is a dict or MultiDomain, everything except for `name` will
be stripped out of it, and the result will be the operator's target.
Its domain will then be the DomainTuple corresponding to the single
entry in the operator's domain.
name : String
The key of the MultiField which shall be extracted.
The relevant key of the MultiDomain.
"""
def __init__(self, target, name):
self._target = DomainTuple.make(target)
self._domain = MultiDomain.make({name: self._target})
def __init__(self, tgt, name):
from ..sugar import makeDomain
tmp = makeDomain(tgt)
if isinstance(tmp, DomainTuple):
self._target = tmp
self._domain = MultiDomain.make({name: tmp})
else:
self._domain = tmp[name]
self._target = MultiDomain.make({name: tmp[name]})
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
if isinstance(x, MultiField):
return x.values()[0]
return MultiField(self._domain, (x,))
else:
return MultiField(self._tgt(mode), (x,))
def __repr__(self):
key = self.domain.keys()[0]
return 'FieldAdapter: {}'.format(key)
def ducktape(left, right, name):
from ..sugar import makeDomain
from .operator import Operator
if left is None: # need to infer left from right
if isinstance(right, Operator):
right = right.target
elif right is not None:
right = makeDomain(right)
if isinstance(right, MultiDomain):
left = right[name]
else:
left = MultiDomain.make({name: right})
else:
if isinstance(left, Operator):
left = left.domain
else:
left = makeDomain(left)
return _FieldAdapter(left, name)
class GeometryRemover(LinearOperator):
......
......@@ -62,29 +62,29 @@ class Model_Tests(unittest.TestCase):
lin2 = self.make_linearization(type2, dom2, seed)
dom = ift.MultiDomain.union((dom1, dom2))
model = ift.FieldAdapter(space, "s1")*ift.FieldAdapter(space, "s2")
select_s1 = ift.ducktape(None, dom, "s1")
select_s2 = ift.ducktape(None, dom, "s2")
model = select_s1*select_s2
pos = ift.from_random("normal", dom)
ift.extra.check_value_gradient_consistency(model, pos, ntries=20)
model = ift.FieldAdapter(space, "s1")+ift.FieldAdapter(space, "s2")
model = select_s1+select_s2
pos = ift.from_random("normal", dom)
ift.extra.check_value_gradient_consistency(model, pos, ntries=20)
model = ift.FieldAdapter(space, "s1").scale(3.)
model = select_s1.scale(3.)
pos = ift.from_random("normal", dom1)
ift.extra.check_value_gradient_consistency(model, pos, ntries=20)
model = ift.ScalingOperator(2.456, space)(
ift.FieldAdapter(space, "s1")*ift.FieldAdapter(space, "s2"))
model = ift.ScalingOperator(2.456, space)(select_s1*select_s2)
pos = ift.from_random("normal", dom)
ift.extra.check_value_gradient_consistency(model, pos, ntries=20)
model = ift.positive_tanh(ift.ScalingOperator(2.456, space)(
ift.FieldAdapter(space, "s1")*ift.FieldAdapter(space, "s2")))
select_s1*select_s2))
pos = ift.from_random("normal", dom)
ift.extra.check_value_gradient_consistency(model, pos, ntries=20)
pos = ift.from_random("normal", dom)
model = ift.OuterProduct(pos['s1'], ift.makeDomain(space))
ift.extra.check_value_gradient_consistency(model, pos['s2'], ntries=20)
if isinstance(space, ift.RGSpace):
model = ift.FFTOperator(space)(
ift.FieldAdapter(space, "s1")*ift.FieldAdapter(space, "s2"))
model = ift.FFTOperator(space)(select_s1*select_s2)
pos = ift.from_random("normal", dom)
ift.extra.check_value_gradient_consistency(model, pos, ntries=20)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment