Commit 847536fd authored by Philipp Arras's avatar Philipp Arras

Simplify FieldAdapter

parent bdef7437
......@@ -33,20 +33,14 @@ if __name__ == '__main__':
# Setting up an amplitude model
A = ift.AmplitudeModel(position_space, 16, 1, 10, -4., 1, 0., 1.)
dummy = ift.from_random('normal', A.domain)
# Building the model for a correlated signal
harmonic_space = position_space.get_default_codomain()
ht = ift.HarmonicTransformOperator(harmonic_space, position_space)
power_space = A.target[0]
power_distributor = ift.PowerDistributor(harmonic_space, power_space)
dummy = ift.Field.from_random('normal', harmonic_space)
domain = ift.MultiDomain.union((A.domain,
ift.MultiDomain.make({
'xi': harmonic_space
})))
correlated_field = ht(power_distributor(A)*ift.FieldAdapter(domain, "xi"))
correlated_field = ht(power_distributor(A)*ift.FieldAdapter(harmonic_space, "xi"))
# alternatively to the block above one can do:
# correlated_field = ift.CorrelatedField(position_space, A)
......@@ -64,7 +58,7 @@ if __name__ == '__main__':
N = ift.ScalingOperator(noise, data_space)
# generate mock data
MOCK_POSITION = ift.from_random('normal', domain)
MOCK_POSITION = ift.from_random('normal', signal_response.domain)
data = signal_response(MOCK_POSITION) + N.draw_sample()
# set up model likelihood
......@@ -79,7 +73,7 @@ if __name__ == '__main__':
# build model Hamiltonian
H = ift.Hamiltonian(likelihood, ic_sampling)
INITIAL_POSITION = ift.from_random('normal', domain)
INITIAL_POSITION = ift.from_random('normal', H.domain)
position = INITIAL_POSITION
plot = ift.Plot()
......
......@@ -86,7 +86,7 @@ def do_adjust_variances(position,
h_space = position[xi_key].domain[0]
pd = PowerDistributor(h_space, amplitude_model.target[0])
a = pd(amplitude_model)
xi = FieldAdapter(MultiDomain.make({xi_key: h_space}), xi_key)
xi = FieldAdapter(h_space, xi_key)
axi_domain = MultiDomain.union([a.domain, xi.domain])
ham = make_adjust_variances(
......
......@@ -45,7 +45,7 @@ def CorrelatedField(s_space, amplitude_model, name='xi'):
p_space = amplitude_model.target[0]
power_distributor = PowerDistributor(h_space, p_space)
A = power_distributor(amplitude_model)
return ht(A*FieldAdapter(MultiDomain.make({name: h_space}), name))
return ht(A*FieldAdapter(h_space, name))
def MfCorrelatedField(s_space_spatial, s_space_energy, amplitude_model_spatial,
......@@ -74,4 +74,4 @@ def MfCorrelatedField(s_space_spatial, s_space_energy, amplitude_model_spatial,
a_energy = dom_distr_energy(amplitude_model_energy)
a = a_spatial*a_energy
A = pd(a)
return ht(A*FieldAdapter(MultiDomain.make({name: h_space}), name))
return ht(A*FieldAdapter(h_space, name))
......@@ -52,7 +52,7 @@ class Linearization(object):
def __getitem__(self, name):
from .operators.simple_linear_operators import FieldAdapter
return self.new(self._val[name], FieldAdapter(self.domain, name))
return self.new(self._val[name], FieldAdapter(self.domain[name], name))
def __neg__(self):
return self.new(-self._val, -self._jac,
......
......@@ -64,14 +64,24 @@ class Realizer(EndomorphicOperator):
class FieldAdapter(LinearOperator):
def __init__(self, dom, name):
self._target = dom[name]
"""Operator which extracts a field out of a MultiField.
Parameters
----------
target : Domain, tuple of Domain or DomainTuple:
The domain which shall be extracted out of the MultiField.
name : String
The key of the MultiField which shall be extracted.
"""
def __init__(self, target, name):
self._target = DomainTuple.make(target)
self._domain = MultiDomain.make({name: self._target})
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
return x.values()[0]
return MultiField(self._domain, (x,))
......
......@@ -62,21 +62,21 @@ class Model_Tests(unittest.TestCase):
lin2 = self.make_linearization(type2, dom2, seed)
dom = ift.MultiDomain.union((dom1, dom2))
model = ift.FieldAdapter(dom, "s1")*ift.FieldAdapter(dom, "s2")
model = ift.FieldAdapter(space, "s1")*ift.FieldAdapter(space, "s2")
pos = ift.from_random("normal", dom)
ift.extra.check_value_gradient_consistency(model, pos, ntries=20)
model = ift.FieldAdapter(dom, "s1")+ift.FieldAdapter(dom, "s2")
model = ift.FieldAdapter(space, "s1")+ift.FieldAdapter(space, "s2")
pos = ift.from_random("normal", dom)
ift.extra.check_value_gradient_consistency(model, pos, ntries=20)
model = ift.FieldAdapter(dom, "s1").scale(3.)
model = ift.FieldAdapter(space, "s1").scale(3.)
pos = ift.from_random("normal", dom1)
ift.extra.check_value_gradient_consistency(model, pos, ntries=20)
model = ift.ScalingOperator(2.456, space)(
ift.FieldAdapter(dom, "s1")*ift.FieldAdapter(dom, "s2"))
ift.FieldAdapter(space, "s1")*ift.FieldAdapter(space, "s2"))
pos = ift.from_random("normal", dom)
ift.extra.check_value_gradient_consistency(model, pos, ntries=20)
model = ift.positive_tanh(ift.ScalingOperator(2.456, space)(
ift.FieldAdapter(dom, "s1")*ift.FieldAdapter(dom, "s2")))
ift.FieldAdapter(space, "s1")*ift.FieldAdapter(space, "s2")))
pos = ift.from_random("normal", dom)
ift.extra.check_value_gradient_consistency(model, pos, ntries=20)
pos = ift.from_random("normal", dom)
......@@ -84,7 +84,7 @@ class Model_Tests(unittest.TestCase):
ift.extra.check_value_gradient_consistency(model, pos['s2'], ntries=20)
if isinstance(space, ift.RGSpace):
model = ift.FFTOperator(space)(
ift.FieldAdapter(dom, "s1")*ift.FieldAdapter(dom, "s2"))
ift.FieldAdapter(space, "s1")*ift.FieldAdapter(space, "s2"))
pos = ift.from_random("normal", dom)
ift.extra.check_value_gradient_consistency(model, pos, ntries=20)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment