Commit cdb27fee authored by Philipp Arras's avatar Philipp Arras
Browse files

Docs and add multifrequency model to tests

parent 9d00734c
......@@ -15,64 +15,95 @@
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from functools import reduce
from ..domain_tuple import DomainTuple
from ..operators.contraction_operator import ContractionOperator
from ..operators.distributors import PowerDistributor
from ..operators.harmonic_operators import HarmonicTransformOperator
from ..operators.simple_linear_operators import ducktape
from ..operators.scaling_operator import ScalingOperator
def CorrelatedField(s_space, amplitude_operator, name='xi'):
'''
Function for construction of correlated fields
def CorrelatedField(target, amplitude_operator, name='xi'):
'''Constructs operator which turns white Gaussian excitation fields into a
correlated field.
This function returns an operator which implements:
ht @ (vol * A * xi),
where `ht` is a harmonic transform operator, `A` is the sqare root of the
prior covariance an `xi` is the excitation field.
Parameters
----------
s_space : Domain
Field domain
target : Domain, DomainTuple or tuple of Domain
Target of the operator. Is not allowed to be a DomainTuple with more
than one space.
amplitude_operator: Operator
operator for correlation structure
name : string
MultiField component name
:class:`MultiField` key for xi-field.
Returns
-------
Correlated field : Operator
'''
h_space = s_space.get_default_codomain()
ht = HarmonicTransformOperator(h_space, s_space)
tgt = DomainTuple.make(target)
if len(tgt) > 1:
raise ValueError
h_space = tgt[0].get_default_codomain()
ht = HarmonicTransformOperator(h_space, tgt[0])
p_space = amplitude_operator.target[0]
power_distributor = PowerDistributor(h_space, p_space)
A = power_distributor(amplitude_operator)
vol = h_space.scalar_dvol
vol = ScalingOperator(vol**(-0.5), h_space)
return ht(vol(A)*ducktape(h_space, None, name))
vol = h_space.scalar_dvol**-0.5
return ht(vol*A*ducktape(h_space, None, name))
def MfCorrelatedField(s_space_spatial,
s_space_energy,
amplitude_operator_spatial,
amplitude_operator_energy,
name="xi"):
'''
Method for construction of correlated multi-frequency fields
def MfCorrelatedField(target, amplitudes, name='xi'):
'''Constructs operator which turns white Gaussian excitation fields into a
correlated field defined on a DomainTuple with two entries and two separate
correlation structures.
This operator may be used as model for multi-frequency reconstructions
with a correlation structure in both spatial and energy direction.
Parameters
----------
target : Domain, DomainTuple or tuple of Domain
Target of the operator. Is not allowed to be a DomainTuple with more
than one space.
amplitudes: iterable of Operator
List of two amplitude operators.
name : string
:class:`MultiField` key for xi-field.
Returns
-------
Correlated field : Operator
'''
h_space_spatial = s_space_spatial.get_default_codomain()
h_space_energy = s_space_energy.get_default_codomain()
h_space = DomainTuple.make((h_space_spatial, h_space_energy))
ht1 = HarmonicTransformOperator(h_space, target=s_space_spatial, space=0)
ht2 = HarmonicTransformOperator(ht1.target, space=1)
ht = ht2(ht1)
tgt = DomainTuple.make(target)
if len(tgt) != 2:
raise ValueError
if len(amplitudes) != 2:
raise ValueError
p_space_spatial = amplitude_operator_spatial.target[0]
p_space_energy = amplitude_operator_energy.target[0]
hsp = DomainTuple.make([tt.get_default_codomain() for tt in tgt])
ht1 = HarmonicTransformOperator(hsp, target=tgt[0], space=0)
ht2 = HarmonicTransformOperator(ht1.target, space=1)
ht = ht2 @ ht1
pd_spatial = PowerDistributor(h_space, p_space_spatial, 0)
pd_energy = PowerDistributor(pd_spatial.domain, p_space_energy, 1)
pd = pd_spatial(pd_energy)
psp = [aa.target[0] for aa in amplitudes]
pd0 = PowerDistributor(hsp, psp[0], 0)
pd1 = PowerDistributor(pd0.domain, psp[1], 1)
pd = pd0 @ pd1
dom_distr_spatial = ContractionOperator(pd.domain, 1).adjoint
dom_distr_energy = ContractionOperator(pd.domain, 0).adjoint
dd0 = ContractionOperator(pd.domain, 1).adjoint
dd1 = ContractionOperator(pd.domain, 0).adjoint
d = [dd0, dd1]
a_spatial = dom_distr_spatial(amplitude_operator_spatial)
a_energy = dom_distr_energy(amplitude_operator_energy)
a = a_spatial*a_energy
A = pd(a)
return ht(A*ducktape(h_space, None, name))
a = [dd @ amplitudes[ii] for ii, dd in enumerate(d)]
a = reduce(lambda x, y: x*y, a)
A = pd @ a
vol = reduce(lambda x, y: x*y, [sp.scalar_dvol**-0.5 for sp in hsp])
return ht(vol*A*ducktape(hsp, None, name))
......@@ -103,6 +103,12 @@ def testModelLibrary(space, seed):
pos = S.draw_sample()
ift.extra.check_value_gradient_consistency(model2, pos, ntries=20)
domtup = ift.DomainTuple.make((space, space))
model3 = ift.MfCorrelatedField(domtup, [model, model])
S = ift.ScalingOperator(1., model3.domain)
pos = S.draw_sample()
ift.extra.check_value_gradient_consistency(model3, pos, ntries=20)
def testPointModel(space, seed):
S = ift.ScalingOperator(1., space)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment