Commit aa58d6ae authored by Martin Reinecke's avatar Martin Reinecke Committed by Philipp Arras

FieldAdapter -> ducktape

parent 8bf014cf
...@@ -43,8 +43,8 @@ if __name__ == '__main__': ...@@ -43,8 +43,8 @@ if __name__ == '__main__':
vol = harmonic_space.scalar_dvol vol = harmonic_space.scalar_dvol
vol = ift.ScalingOperator(vol**(-0.5), harmonic_space) vol = ift.ScalingOperator(vol**(-0.5), harmonic_space)
correlated_field = ht(vol(power_distributor(A)) * correlated_field = ht(
ift.FieldAdapter(harmonic_space, "xi")) vol(power_distributor(A))*ift.ducktape(harmonic_space, None, 'xi'))
# alternatively to the block above one can do: # alternatively to the block above one can do:
#correlated_field = ift.CorrelatedField(position_space, A) #correlated_field = ift.CorrelatedField(position_space, A)
......
...@@ -48,7 +48,7 @@ from .operators.block_diagonal_operator import BlockDiagonalOperator ...@@ -48,7 +48,7 @@ from .operators.block_diagonal_operator import BlockDiagonalOperator
from .operators.outer_product_operator import OuterProduct from .operators.outer_product_operator import OuterProduct
from .operators.simple_linear_operators import ( from .operators.simple_linear_operators import (
VdotOperator, ConjugationOperator, Realizer, VdotOperator, ConjugationOperator, Realizer,
FieldAdapter, GeometryRemover, NullOperator) ducktape, GeometryRemover, NullOperator)
from .operators.energy_operators import ( from .operators.energy_operators import (
EnergyOperator, GaussianEnergy, PoissonianEnergy, InverseGammaLikelihood, EnergyOperator, GaussianEnergy, PoissonianEnergy, InverseGammaLikelihood,
BernoulliEnergy, Hamiltonian, SampledKullbachLeiblerDivergence) BernoulliEnergy, Hamiltonian, SampledKullbachLeiblerDivergence)
......
...@@ -25,7 +25,7 @@ from ..multi_field import MultiField ...@@ -25,7 +25,7 @@ from ..multi_field import MultiField
from ..operators.distributors import PowerDistributor from ..operators.distributors import PowerDistributor
from ..operators.energy_operators import Hamiltonian, InverseGammaLikelihood from ..operators.energy_operators import Hamiltonian, InverseGammaLikelihood
from ..operators.scaling_operator import ScalingOperator from ..operators.scaling_operator import ScalingOperator
from ..operators.simple_linear_operators import FieldAdapter from ..operators.simple_linear_operators import ducktape
def make_adjust_variances(a, def make_adjust_variances(a,
...@@ -87,7 +87,7 @@ def do_adjust_variances(position, ...@@ -87,7 +87,7 @@ def do_adjust_variances(position,
h_space = position[xi_key].domain[0] h_space = position[xi_key].domain[0]
pd = PowerDistributor(h_space, amplitude_model.target[0]) pd = PowerDistributor(h_space, amplitude_model.target[0])
a = pd(amplitude_model) a = pd(amplitude_model)
xi = FieldAdapter(h_space, xi_key) xi = ducktape(None, position.domain, xi_key)
ham = make_adjust_variances(a, xi, position, samples=samples) ham = make_adjust_variances(a, xi, position, samples=samples)
......
...@@ -135,7 +135,6 @@ def AmplitudeModel(s_space, Npixdof, ceps_a, ceps_k, sm, sv, im, iv, ...@@ -135,7 +135,6 @@ def AmplitudeModel(s_space, Npixdof, ceps_a, ceps_k, sm, sv, im, iv,
''' '''
from ..operators.exp_transform import ExpTransform from ..operators.exp_transform import ExpTransform
from ..operators.simple_linear_operators import FieldAdapter
from ..operators.scaling_operator import ScalingOperator from ..operators.scaling_operator import ScalingOperator
h_space = s_space.get_default_codomain() h_space = s_space.get_default_codomain()
...@@ -143,10 +142,9 @@ def AmplitudeModel(s_space, Npixdof, ceps_a, ceps_k, sm, sv, im, iv, ...@@ -143,10 +142,9 @@ def AmplitudeModel(s_space, Npixdof, ceps_a, ceps_k, sm, sv, im, iv,
logk_space = et.domain[0] logk_space = et.domain[0]
smooth = CepstrumOperator(logk_space, ceps_a, ceps_k, zero_mode) smooth = CepstrumOperator(logk_space, ceps_a, ceps_k, zero_mode)
smooth = smooth.ducktape(keys[0])
linear = SlopeModel(logk_space, sm, sv, im, iv) linear = SlopeModel(logk_space, sm, sv, im, iv)
linear = linear.ducktape(keys[1])
fa_smooth = FieldAdapter(smooth.domain, keys[0])
fa_linear = FieldAdapter(linear.domain, keys[1])
fac = ScalingOperator(0.5, smooth.target) fac = ScalingOperator(0.5, smooth.target)
return et((fac(smooth(fa_smooth) + linear(fa_linear))).exp()) return et((fac(smooth + linear)).exp())
...@@ -24,7 +24,7 @@ from ..multi_domain import MultiDomain ...@@ -24,7 +24,7 @@ from ..multi_domain import MultiDomain
from ..operators.contraction_operator import ContractionOperator from ..operators.contraction_operator import ContractionOperator
from ..operators.distributors import PowerDistributor from ..operators.distributors import PowerDistributor
from ..operators.harmonic_operators import HarmonicTransformOperator from ..operators.harmonic_operators import HarmonicTransformOperator
from ..operators.simple_linear_operators import FieldAdapter from ..operators.simple_linear_operators import ducktape
from ..operators.scaling_operator import ScalingOperator from ..operators.scaling_operator import ScalingOperator
...@@ -48,7 +48,7 @@ def CorrelatedField(s_space, amplitude_model, name='xi'): ...@@ -48,7 +48,7 @@ def CorrelatedField(s_space, amplitude_model, name='xi'):
A = power_distributor(amplitude_model) A = power_distributor(amplitude_model)
vol = h_space.scalar_dvol vol = h_space.scalar_dvol
vol = ScalingOperator(vol**(-0.5), h_space) vol = ScalingOperator(vol**(-0.5), h_space)
return ht(vol(A)*FieldAdapter(h_space, name)) return ht(vol(A)*ducktape(h_space, None, name))
def MfCorrelatedField(s_space_spatial, s_space_energy, amplitude_model_spatial, def MfCorrelatedField(s_space_spatial, s_space_energy, amplitude_model_spatial,
...@@ -77,4 +77,4 @@ def MfCorrelatedField(s_space_spatial, s_space_energy, amplitude_model_spatial, ...@@ -77,4 +77,4 @@ def MfCorrelatedField(s_space_spatial, s_space_energy, amplitude_model_spatial,
a_energy = dom_distr_energy(amplitude_model_energy) a_energy = dom_distr_energy(amplitude_model_energy)
a = a_spatial*a_energy a = a_spatial*a_energy
A = pd(a) A = pd(a)
return ht(A*FieldAdapter(h_space, name)) return ht(A*ducktape(h_space, None, name))
...@@ -52,8 +52,8 @@ class Linearization(object): ...@@ -52,8 +52,8 @@ class Linearization(object):
return self._metric return self._metric
def __getitem__(self, name): def __getitem__(self, name):
from .operators.simple_linear_operators import FieldAdapter from .operators.simple_linear_operators import ducktape
return self.new(self._val[name], FieldAdapter(self.domain[name], name)) return self.new(self._val[name], ducktape(None, self.domain, name))
def __neg__(self): def __neg__(self):
return self.new(-self._val, -self._jac, return self.new(-self._val, -self._jac,
......
...@@ -95,6 +95,14 @@ class Operator(NiftyMetaBase()): ...@@ -95,6 +95,14 @@ class Operator(NiftyMetaBase()):
return _OpChain.make((self, x)) return _OpChain.make((self, x))
return self.apply(x) return self.apply(x)
def ducktape(self, name):
from .simple_linear_operators import ducktape
return self(ducktape(self, None, name))
def ducktape_left(self, name):
from .simple_linear_operators import ducktape
return ducktape(None, self, name)(self)
def __repr__(self): def __repr__(self):
return self.__class__.__name__ return self.__class__.__name__
......
...@@ -62,28 +62,65 @@ class Realizer(EndomorphicOperator): ...@@ -62,28 +62,65 @@ class Realizer(EndomorphicOperator):
return x.real return x.real
class FieldAdapter(LinearOperator): class _FieldAdapter(LinearOperator):
"""Operator which extracts a field out of a MultiField. """Operator for conversion between Fields and MultiFields.
Parameters Parameters
---------- ----------
target : Domain, tuple of Domain or DomainTuple: tgt : Domain, tuple of Domain, DomainTuple, dict or MultiDomain:
The domain which shall be extracted out of the MultiField. If this is a Domain, tuple of Domain or DomainTuple, this will be the
operator's target, and its domain will be a MultiDomain consisting of
its domain with the supplied `name`
If this is a dict or MultiDomain, everything except for `name` will
be stripped out of it, and the result will be the operator's target.
Its domain will then be the DomainTuple corresponding to the single
entry in the operator's domain.
name : String name : String
The key of the MultiField which shall be extracted. The relevant key of the MultiDomain.
""" """
def __init__(self, target, name): def __init__(self, tgt, name):
self._target = DomainTuple.make(target) from ..sugar import makeDomain
self._domain = MultiDomain.make({name: self._target}) tmp = makeDomain(tgt)
if isinstance(tmp, DomainTuple):
self._target = tmp
self._domain = MultiDomain.make({name: tmp})
else:
self._domain = tmp[name]
self._target = MultiDomain.make({name: tmp[name]})
self._capability = self.TIMES | self.ADJOINT_TIMES self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode): def apply(self, x, mode):
self._check_input(x, mode) self._check_input(x, mode)
if mode == self.TIMES: if isinstance(x, MultiField):
return x.values()[0] return x.values()[0]
return MultiField(self._domain, (x,)) else:
return MultiField(self._tgt(mode), (x,))
def __repr__(self):
key = self.domain.keys()[0]
return 'FieldAdapter: {}'.format(key)
def ducktape(left, right, name):
from ..sugar import makeDomain
from .operator import Operator
if left is None: # need to infer left from right
if isinstance(right, Operator):
right = right.target
elif right is not None:
right = makeDomain(right)
if isinstance(right, MultiDomain):
left = right[name]
else:
left = MultiDomain.make({name: right})
else:
if isinstance(left, Operator):
left = left.domain
else:
left = makeDomain(left)
return _FieldAdapter(left, name)
class GeometryRemover(LinearOperator): class GeometryRemover(LinearOperator):
......
...@@ -62,29 +62,29 @@ class Model_Tests(unittest.TestCase): ...@@ -62,29 +62,29 @@ class Model_Tests(unittest.TestCase):
lin2 = self.make_linearization(type2, dom2, seed) lin2 = self.make_linearization(type2, dom2, seed)
dom = ift.MultiDomain.union((dom1, dom2)) dom = ift.MultiDomain.union((dom1, dom2))
model = ift.FieldAdapter(space, "s1")*ift.FieldAdapter(space, "s2") select_s1 = ift.ducktape(None, dom, "s1")
select_s2 = ift.ducktape(None, dom, "s2")
model = select_s1*select_s2
pos = ift.from_random("normal", dom) pos = ift.from_random("normal", dom)
ift.extra.check_value_gradient_consistency(model, pos, ntries=20) ift.extra.check_value_gradient_consistency(model, pos, ntries=20)
model = ift.FieldAdapter(space, "s1")+ift.FieldAdapter(space, "s2") model = select_s1+select_s2
pos = ift.from_random("normal", dom) pos = ift.from_random("normal", dom)
ift.extra.check_value_gradient_consistency(model, pos, ntries=20) ift.extra.check_value_gradient_consistency(model, pos, ntries=20)
model = ift.FieldAdapter(space, "s1").scale(3.) model = select_s1.scale(3.)
pos = ift.from_random("normal", dom1) pos = ift.from_random("normal", dom1)
ift.extra.check_value_gradient_consistency(model, pos, ntries=20) ift.extra.check_value_gradient_consistency(model, pos, ntries=20)
model = ift.ScalingOperator(2.456, space)( model = ift.ScalingOperator(2.456, space)(select_s1*select_s2)
ift.FieldAdapter(space, "s1")*ift.FieldAdapter(space, "s2"))
pos = ift.from_random("normal", dom) pos = ift.from_random("normal", dom)
ift.extra.check_value_gradient_consistency(model, pos, ntries=20) ift.extra.check_value_gradient_consistency(model, pos, ntries=20)
model = ift.positive_tanh(ift.ScalingOperator(2.456, space)( model = ift.positive_tanh(ift.ScalingOperator(2.456, space)(
ift.FieldAdapter(space, "s1")*ift.FieldAdapter(space, "s2"))) select_s1*select_s2))
pos = ift.from_random("normal", dom) pos = ift.from_random("normal", dom)
ift.extra.check_value_gradient_consistency(model, pos, ntries=20) ift.extra.check_value_gradient_consistency(model, pos, ntries=20)
pos = ift.from_random("normal", dom) pos = ift.from_random("normal", dom)
model = ift.OuterProduct(pos['s1'], ift.makeDomain(space)) model = ift.OuterProduct(pos['s1'], ift.makeDomain(space))
ift.extra.check_value_gradient_consistency(model, pos['s2'], ntries=20) ift.extra.check_value_gradient_consistency(model, pos['s2'], ntries=20)
if isinstance(space, ift.RGSpace): if isinstance(space, ift.RGSpace):
model = ift.FFTOperator(space)( model = ift.FFTOperator(space)(select_s1*select_s2)
ift.FieldAdapter(space, "s1")*ift.FieldAdapter(space, "s2"))
pos = ift.from_random("normal", dom) pos = ift.from_random("normal", dom)
ift.extra.check_value_gradient_consistency(model, pos, ntries=20) ift.extra.check_value_gradient_consistency(model, pos, ntries=20)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment