diff --git a/nifty5/library/correlated_fields.py b/nifty5/library/correlated_fields.py index a0e85ed0439a186f3ea82221d482ab030eb7928f..c1d3f4c70e23311a26bf6cfc8c986fbab54fc9ed 100644 --- a/nifty5/library/correlated_fields.py +++ b/nifty5/library/correlated_fields.py @@ -15,64 +15,95 @@ # # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. +from functools import reduce + from ..domain_tuple import DomainTuple from ..operators.contraction_operator import ContractionOperator from ..operators.distributors import PowerDistributor from ..operators.harmonic_operators import HarmonicTransformOperator from ..operators.simple_linear_operators import ducktape -from ..operators.scaling_operator import ScalingOperator -def CorrelatedField(s_space, amplitude_operator, name='xi'): - ''' - Function for construction of correlated fields +def CorrelatedField(target, amplitude_operator, name='xi'): + '''Constructs operator which turns white Gaussian excitation fields into a + correlated field. + + This function returns an operator which implements: + + ht @ (vol * A * xi), + + where `ht` is a harmonic transform operator, `A` is the sqare root of the + prior covariance an `xi` is the excitation field. Parameters ---------- - s_space : Domain - Field domain + target : Domain, DomainTuple or tuple of Domain + Target of the operator. Is not allowed to be a DomainTuple with more + than one space. amplitude_operator: Operator - operator for correlation structure name : string - MultiField component name + :class:`MultiField` key for xi-field. + + Returns + ------- + Correlated field : Operator ''' - h_space = s_space.get_default_codomain() - ht = HarmonicTransformOperator(h_space, s_space) + tgt = DomainTuple.make(target) + if len(tgt) > 1: + raise ValueError + h_space = tgt[0].get_default_codomain() + ht = HarmonicTransformOperator(h_space, tgt[0]) p_space = amplitude_operator.target[0] power_distributor = PowerDistributor(h_space, p_space) A = power_distributor(amplitude_operator) - vol = h_space.scalar_dvol - vol = ScalingOperator(vol**(-0.5), h_space) - return ht(vol(A)*ducktape(h_space, None, name)) + vol = h_space.scalar_dvol**-0.5 + return ht(vol*A*ducktape(h_space, None, name)) -def MfCorrelatedField(s_space_spatial, - s_space_energy, - amplitude_operator_spatial, - amplitude_operator_energy, - name="xi"): - ''' - Method for construction of correlated multi-frequency fields +def MfCorrelatedField(target, amplitudes, name='xi'): + '''Constructs operator which turns white Gaussian excitation fields into a + correlated field defined on a DomainTuple with two entries and two separate + correlation structures. + + This operator may be used as model for multi-frequency reconstructions + with a correlation structure in both spatial and energy direction. + + Parameters + ---------- + target : Domain, DomainTuple or tuple of Domain + Target of the operator. Is not allowed to be a DomainTuple with more + than one space. + amplitudes: iterable of Operator + List of two amplitude operators. + name : string + :class:`MultiField` key for xi-field. + + Returns + ------- + Correlated field : Operator ''' - h_space_spatial = s_space_spatial.get_default_codomain() - h_space_energy = s_space_energy.get_default_codomain() - h_space = DomainTuple.make((h_space_spatial, h_space_energy)) - ht1 = HarmonicTransformOperator(h_space, target=s_space_spatial, space=0) - ht2 = HarmonicTransformOperator(ht1.target, space=1) - ht = ht2(ht1) + tgt = DomainTuple.make(target) + if len(tgt) != 2: + raise ValueError + if len(amplitudes) != 2: + raise ValueError - p_space_spatial = amplitude_operator_spatial.target[0] - p_space_energy = amplitude_operator_energy.target[0] + hsp = DomainTuple.make([tt.get_default_codomain() for tt in tgt]) + ht1 = HarmonicTransformOperator(hsp, target=tgt[0], space=0) + ht2 = HarmonicTransformOperator(ht1.target, space=1) + ht = ht2 @ ht1 - pd_spatial = PowerDistributor(h_space, p_space_spatial, 0) - pd_energy = PowerDistributor(pd_spatial.domain, p_space_energy, 1) - pd = pd_spatial(pd_energy) + psp = [aa.target[0] for aa in amplitudes] + pd0 = PowerDistributor(hsp, psp[0], 0) + pd1 = PowerDistributor(pd0.domain, psp[1], 1) + pd = pd0 @ pd1 - dom_distr_spatial = ContractionOperator(pd.domain, 1).adjoint - dom_distr_energy = ContractionOperator(pd.domain, 0).adjoint + dd0 = ContractionOperator(pd.domain, 1).adjoint + dd1 = ContractionOperator(pd.domain, 0).adjoint + d = [dd0, dd1] - a_spatial = dom_distr_spatial(amplitude_operator_spatial) - a_energy = dom_distr_energy(amplitude_operator_energy) - a = a_spatial*a_energy - A = pd(a) - return ht(A*ducktape(h_space, None, name)) + a = [dd @ amplitudes[ii] for ii, dd in enumerate(d)] + a = reduce(lambda x, y: x*y, a) + A = pd @ a + vol = reduce(lambda x, y: x*y, [sp.scalar_dvol**-0.5 for sp in hsp]) + return ht(vol*A*ducktape(hsp, None, name)) diff --git a/test/test_model_gradients.py b/test/test_model_gradients.py index 7b5729a34bf260c42e70c4305e4f23f5a0b3fa41..e30e0ce5497683551c5c301c20301ca9ec41c667 100644 --- a/test/test_model_gradients.py +++ b/test/test_model_gradients.py @@ -103,6 +103,12 @@ def testModelLibrary(space, seed): pos = S.draw_sample() ift.extra.check_value_gradient_consistency(model2, pos, ntries=20) + domtup = ift.DomainTuple.make((space, space)) + model3 = ift.MfCorrelatedField(domtup, [model, model]) + S = ift.ScalingOperator(1., model3.domain) + pos = S.draw_sample() + ift.extra.check_value_gradient_consistency(model3, pos, ntries=20) + def testPointModel(space, seed): S = ift.ScalingOperator(1., space)