Commit cdb27fee authored by Philipp Arras's avatar Philipp Arras

Docs and add multifrequency model to tests

parent 9d00734c
...@@ -15,64 +15,95 @@ ...@@ -15,64 +15,95 @@
# #
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from functools import reduce
from ..domain_tuple import DomainTuple from ..domain_tuple import DomainTuple
from ..operators.contraction_operator import ContractionOperator from ..operators.contraction_operator import ContractionOperator
from ..operators.distributors import PowerDistributor from ..operators.distributors import PowerDistributor
from ..operators.harmonic_operators import HarmonicTransformOperator from ..operators.harmonic_operators import HarmonicTransformOperator
from ..operators.simple_linear_operators import ducktape from ..operators.simple_linear_operators import ducktape
from ..operators.scaling_operator import ScalingOperator
def CorrelatedField(s_space, amplitude_operator, name='xi'): def CorrelatedField(target, amplitude_operator, name='xi'):
''' '''Constructs operator which turns white Gaussian excitation fields into a
Function for construction of correlated fields correlated field.
This function returns an operator which implements:
ht @ (vol * A * xi),
where `ht` is a harmonic transform operator, `A` is the sqare root of the
prior covariance an `xi` is the excitation field.
Parameters Parameters
---------- ----------
s_space : Domain target : Domain, DomainTuple or tuple of Domain
Field domain Target of the operator. Is not allowed to be a DomainTuple with more
than one space.
amplitude_operator: Operator amplitude_operator: Operator
operator for correlation structure
name : string name : string
MultiField component name :class:`MultiField` key for xi-field.
Returns
-------
Correlated field : Operator
''' '''
h_space = s_space.get_default_codomain() tgt = DomainTuple.make(target)
ht = HarmonicTransformOperator(h_space, s_space) if len(tgt) > 1:
raise ValueError
h_space = tgt[0].get_default_codomain()
ht = HarmonicTransformOperator(h_space, tgt[0])
p_space = amplitude_operator.target[0] p_space = amplitude_operator.target[0]
power_distributor = PowerDistributor(h_space, p_space) power_distributor = PowerDistributor(h_space, p_space)
A = power_distributor(amplitude_operator) A = power_distributor(amplitude_operator)
vol = h_space.scalar_dvol vol = h_space.scalar_dvol**-0.5
vol = ScalingOperator(vol**(-0.5), h_space) return ht(vol*A*ducktape(h_space, None, name))
return ht(vol(A)*ducktape(h_space, None, name))
def MfCorrelatedField(s_space_spatial, def MfCorrelatedField(target, amplitudes, name='xi'):
s_space_energy, '''Constructs operator which turns white Gaussian excitation fields into a
amplitude_operator_spatial, correlated field defined on a DomainTuple with two entries and two separate
amplitude_operator_energy, correlation structures.
name="xi"):
''' This operator may be used as model for multi-frequency reconstructions
Method for construction of correlated multi-frequency fields with a correlation structure in both spatial and energy direction.
Parameters
----------
target : Domain, DomainTuple or tuple of Domain
Target of the operator. Is not allowed to be a DomainTuple with more
than one space.
amplitudes: iterable of Operator
List of two amplitude operators.
name : string
:class:`MultiField` key for xi-field.
Returns
-------
Correlated field : Operator
''' '''
h_space_spatial = s_space_spatial.get_default_codomain() tgt = DomainTuple.make(target)
h_space_energy = s_space_energy.get_default_codomain() if len(tgt) != 2:
h_space = DomainTuple.make((h_space_spatial, h_space_energy)) raise ValueError
ht1 = HarmonicTransformOperator(h_space, target=s_space_spatial, space=0) if len(amplitudes) != 2:
ht2 = HarmonicTransformOperator(ht1.target, space=1) raise ValueError
ht = ht2(ht1)
p_space_spatial = amplitude_operator_spatial.target[0] hsp = DomainTuple.make([tt.get_default_codomain() for tt in tgt])
p_space_energy = amplitude_operator_energy.target[0] ht1 = HarmonicTransformOperator(hsp, target=tgt[0], space=0)
ht2 = HarmonicTransformOperator(ht1.target, space=1)
ht = ht2 @ ht1
pd_spatial = PowerDistributor(h_space, p_space_spatial, 0) psp = [aa.target[0] for aa in amplitudes]
pd_energy = PowerDistributor(pd_spatial.domain, p_space_energy, 1) pd0 = PowerDistributor(hsp, psp[0], 0)
pd = pd_spatial(pd_energy) pd1 = PowerDistributor(pd0.domain, psp[1], 1)
pd = pd0 @ pd1
dom_distr_spatial = ContractionOperator(pd.domain, 1).adjoint dd0 = ContractionOperator(pd.domain, 1).adjoint
dom_distr_energy = ContractionOperator(pd.domain, 0).adjoint dd1 = ContractionOperator(pd.domain, 0).adjoint
d = [dd0, dd1]
a_spatial = dom_distr_spatial(amplitude_operator_spatial) a = [dd @ amplitudes[ii] for ii, dd in enumerate(d)]
a_energy = dom_distr_energy(amplitude_operator_energy) a = reduce(lambda x, y: x*y, a)
a = a_spatial*a_energy A = pd @ a
A = pd(a) vol = reduce(lambda x, y: x*y, [sp.scalar_dvol**-0.5 for sp in hsp])
return ht(A*ducktape(h_space, None, name)) return ht(vol*A*ducktape(hsp, None, name))
...@@ -103,6 +103,12 @@ def testModelLibrary(space, seed): ...@@ -103,6 +103,12 @@ def testModelLibrary(space, seed):
pos = S.draw_sample() pos = S.draw_sample()
ift.extra.check_value_gradient_consistency(model2, pos, ntries=20) ift.extra.check_value_gradient_consistency(model2, pos, ntries=20)
domtup = ift.DomainTuple.make((space, space))
model3 = ift.MfCorrelatedField(domtup, [model, model])
S = ift.ScalingOperator(1., model3.domain)
pos = S.draw_sample()
ift.extra.check_value_gradient_consistency(model3, pos, ntries=20)
def testPointModel(space, seed): def testPointModel(space, seed):
S = ift.ScalingOperator(1., space) S = ift.ScalingOperator(1., space)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment