Skip to content
Snippets Groups Projects
Commit cdb27fee authored by Philipp Arras's avatar Philipp Arras
Browse files

Docs and add multifrequency model to tests

parent 9d00734c
No related branches found
No related tags found
No related merge requests found
...@@ -15,64 +15,95 @@ ...@@ -15,64 +15,95 @@
# #
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from functools import reduce
from ..domain_tuple import DomainTuple from ..domain_tuple import DomainTuple
from ..operators.contraction_operator import ContractionOperator from ..operators.contraction_operator import ContractionOperator
from ..operators.distributors import PowerDistributor from ..operators.distributors import PowerDistributor
from ..operators.harmonic_operators import HarmonicTransformOperator from ..operators.harmonic_operators import HarmonicTransformOperator
from ..operators.simple_linear_operators import ducktape from ..operators.simple_linear_operators import ducktape
from ..operators.scaling_operator import ScalingOperator
def CorrelatedField(s_space, amplitude_operator, name='xi'): def CorrelatedField(target, amplitude_operator, name='xi'):
''' '''Constructs operator which turns white Gaussian excitation fields into a
Function for construction of correlated fields correlated field.
This function returns an operator which implements:
ht @ (vol * A * xi),
where `ht` is a harmonic transform operator, `A` is the sqare root of the
prior covariance an `xi` is the excitation field.
Parameters Parameters
---------- ----------
s_space : Domain target : Domain, DomainTuple or tuple of Domain
Field domain Target of the operator. Is not allowed to be a DomainTuple with more
than one space.
amplitude_operator: Operator amplitude_operator: Operator
operator for correlation structure
name : string name : string
MultiField component name :class:`MultiField` key for xi-field.
Returns
-------
Correlated field : Operator
''' '''
h_space = s_space.get_default_codomain() tgt = DomainTuple.make(target)
ht = HarmonicTransformOperator(h_space, s_space) if len(tgt) > 1:
raise ValueError
h_space = tgt[0].get_default_codomain()
ht = HarmonicTransformOperator(h_space, tgt[0])
p_space = amplitude_operator.target[0] p_space = amplitude_operator.target[0]
power_distributor = PowerDistributor(h_space, p_space) power_distributor = PowerDistributor(h_space, p_space)
A = power_distributor(amplitude_operator) A = power_distributor(amplitude_operator)
vol = h_space.scalar_dvol vol = h_space.scalar_dvol**-0.5
vol = ScalingOperator(vol**(-0.5), h_space) return ht(vol*A*ducktape(h_space, None, name))
return ht(vol(A)*ducktape(h_space, None, name))
def MfCorrelatedField(s_space_spatial, def MfCorrelatedField(target, amplitudes, name='xi'):
s_space_energy, '''Constructs operator which turns white Gaussian excitation fields into a
amplitude_operator_spatial, correlated field defined on a DomainTuple with two entries and two separate
amplitude_operator_energy, correlation structures.
name="xi"):
''' This operator may be used as model for multi-frequency reconstructions
Method for construction of correlated multi-frequency fields with a correlation structure in both spatial and energy direction.
Parameters
----------
target : Domain, DomainTuple or tuple of Domain
Target of the operator. Is not allowed to be a DomainTuple with more
than one space.
amplitudes: iterable of Operator
List of two amplitude operators.
name : string
:class:`MultiField` key for xi-field.
Returns
-------
Correlated field : Operator
''' '''
h_space_spatial = s_space_spatial.get_default_codomain() tgt = DomainTuple.make(target)
h_space_energy = s_space_energy.get_default_codomain() if len(tgt) != 2:
h_space = DomainTuple.make((h_space_spatial, h_space_energy)) raise ValueError
ht1 = HarmonicTransformOperator(h_space, target=s_space_spatial, space=0) if len(amplitudes) != 2:
ht2 = HarmonicTransformOperator(ht1.target, space=1) raise ValueError
ht = ht2(ht1)
p_space_spatial = amplitude_operator_spatial.target[0] hsp = DomainTuple.make([tt.get_default_codomain() for tt in tgt])
p_space_energy = amplitude_operator_energy.target[0] ht1 = HarmonicTransformOperator(hsp, target=tgt[0], space=0)
ht2 = HarmonicTransformOperator(ht1.target, space=1)
ht = ht2 @ ht1
pd_spatial = PowerDistributor(h_space, p_space_spatial, 0) psp = [aa.target[0] for aa in amplitudes]
pd_energy = PowerDistributor(pd_spatial.domain, p_space_energy, 1) pd0 = PowerDistributor(hsp, psp[0], 0)
pd = pd_spatial(pd_energy) pd1 = PowerDistributor(pd0.domain, psp[1], 1)
pd = pd0 @ pd1
dom_distr_spatial = ContractionOperator(pd.domain, 1).adjoint dd0 = ContractionOperator(pd.domain, 1).adjoint
dom_distr_energy = ContractionOperator(pd.domain, 0).adjoint dd1 = ContractionOperator(pd.domain, 0).adjoint
d = [dd0, dd1]
a_spatial = dom_distr_spatial(amplitude_operator_spatial) a = [dd @ amplitudes[ii] for ii, dd in enumerate(d)]
a_energy = dom_distr_energy(amplitude_operator_energy) a = reduce(lambda x, y: x*y, a)
a = a_spatial*a_energy A = pd @ a
A = pd(a) vol = reduce(lambda x, y: x*y, [sp.scalar_dvol**-0.5 for sp in hsp])
return ht(A*ducktape(h_space, None, name)) return ht(vol*A*ducktape(hsp, None, name))
...@@ -103,6 +103,12 @@ def testModelLibrary(space, seed): ...@@ -103,6 +103,12 @@ def testModelLibrary(space, seed):
pos = S.draw_sample() pos = S.draw_sample()
ift.extra.check_value_gradient_consistency(model2, pos, ntries=20) ift.extra.check_value_gradient_consistency(model2, pos, ntries=20)
domtup = ift.DomainTuple.make((space, space))
model3 = ift.MfCorrelatedField(domtup, [model, model])
S = ift.ScalingOperator(1., model3.domain)
pos = S.draw_sample()
ift.extra.check_value_gradient_consistency(model3, pos, ntries=20)
def testPointModel(space, seed): def testPointModel(space, seed):
S = ift.ScalingOperator(1., space) S = ift.ScalingOperator(1., space)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment