diff --git a/src/library/correlated_fields_simple.py b/src/library/correlated_fields_simple.py index f6c1529af1a994db9d309c17f653492f39098961..0ffb2edba4975ab06bad282cb534d11d5574d8bf 100644 --- a/src/library/correlated_fields_simple.py +++ b/src/library/correlated_fields_simple.py @@ -19,54 +19,20 @@ import numpy as np from ..domains.power_space import PowerSpace -from ..domains.unstructured_domain import UnstructuredDomain from ..operators.adder import Adder from ..operators.contraction_operator import ContractionOperator from ..operators.distributors import PowerDistributor from ..operators.harmonic_operators import HarmonicTransformOperator -from ..operators.linear_operator import LinearOperator from ..operators.normal_operators import LognormalTransform, NormalTransform +from .correlated_fields import _TwoLogIntegrations from ..operators.operator import Operator from ..operators.simple_linear_operators import ducktape -from ..sugar import full, makeDomain, makeField, makeOp +from ..operators.value_inserter import ValueInserter +from ..sugar import full, makeField, makeOp from .correlated_fields import (_log_vol, _Normalization, _relative_log_k_lengths, _SlopeRemover) -class _SimpleTwoLogIntegrations(LinearOperator): - def __init__(self, target): - self._target = makeDomain(target) - assert len(self._target) == 1 - tgt = self._target[0] - assert isinstance(tgt, PowerSpace) - self._domain = makeDomain(UnstructuredDomain((2, tgt.shape[0] - 2))) - self._log_vol = _log_vol(tgt) - self._capability = self.TIMES | self.ADJOINT_TIMES - - def apply(self, x, mode): - self._check_input(x, mode) - from_third = slice(2, None) - no_border = slice(1, -1) - reverse = slice(None, None, -1) - if mode == self.TIMES: - x = x.val - res = np.empty(self._target.shape) - res[0] = res[1] = 0 - res[from_third] = np.cumsum(x[1]) - res[from_third] = (res[from_third] + - res[no_border])/2*self._log_vol + x[0] - res[from_third] = np.cumsum(res[from_third]) - else: - x = x.val_rw() - res = np.zeros(self._domain.shape) - x[from_third] = np.cumsum(x[from_third][reverse])[reverse] - res[0] += x[from_third] - x[from_third] *= self._log_vol/2. - x[no_border] += x[from_third] - res[1] += np.cumsum(x[from_third][reverse])[reverse] - return makeField(self._tgt(mode), res) - - class SimpleCorrelatedField(Operator): def __init__(self, target, @@ -100,7 +66,7 @@ class SimpleCorrelatedField(Operator): prefix + 'zeromode', 0) tgt = PowerSpace(harmonic_partner) - twolog = _SimpleTwoLogIntegrations(tgt) + twolog = _TwoLogIntegrations(tgt) dom = twolog.domain[0] vflex = np.zeros(dom.shape) vasp = np.zeros(dom.shape, dtype=np.float64) @@ -115,7 +81,6 @@ class SimpleCorrelatedField(Operator): expander = ContractionOperator(twolog.domain, 0).adjoint ps_expander = ContractionOperator(tgt, 0).adjoint - h_expander = ContractionOperator(harmonic_partner, 0).adjoint slope = vslope @ ps_expander @ avgsl sig_flex = vflex @ expander @ flex sig_asp = vasp @ expander @ asp @@ -124,24 +89,19 @@ class SimpleCorrelatedField(Operator): smooth = _SlopeRemover(tgt, 0) @ twolog @ smooth op = _Normalization(tgt, 0) @ (slope + smooth) - vol0 = np.zeros(tgt.shape, dtype=np.float64) - vol1 = np.zeros(tgt.shape, dtype=np.float64) - vol1[1:] = vol0[0] = target.total_volume - vol = Adder(makeField(tgt, vol0)) @ makeOp(makeField(tgt, vol1)) - - amp = vol @ ((ps_expander @ (zm.ptw("reciprocal")*fluct))*op) - self._a = amp + maskzm = np.ones(tgt.shape) + maskzm[0] = 0 + maskzm = makeOp(makeField(tgt, maskzm)) + insert = ValueInserter(tgt, (0,)) + a = (maskzm @ ((ps_expander @ fluct)*op)) + insert(zm) + self._a = a.scale(target.total_volume) ht = HarmonicTransformOperator(harmonic_partner, target) - pd = PowerDistributor(harmonic_partner, amp.target[0]) + pd = PowerDistributor(harmonic_partner, tgt) xi = ducktape(harmonic_partner, None, prefix + 'xi') - op = ht(h_expander(zm)*pd(amp)*xi*(1)) + op = ht(pd(self._a)*xi) if offset_mean is not None: op = Adder(full(op.target, float(offset_mean))) @ op self.apply = op.apply self._domain = op.domain self._target = op.target - - @property - def normalized_amplitudes(self): - return [self._a] diff --git a/test/test_operators/test_correlated_fields.py b/test/test_operators/test_correlated_fields.py index 56f1736402c47984e1963de9c710ed08afc1ecc0..20411883480f76664a5cfec817d03f5aa97e3700 100644 --- a/test/test_operators/test_correlated_fields.py +++ b/test/test_operators/test_correlated_fields.py @@ -172,11 +172,6 @@ def test_complicated_vs_simple(seed, domain): ift.extra.assert_allclose(scf(inp), op1(inp)) ift.extra.check_operator(scf, inp, ntries=10) - op1 = cfm.normalized_amplitudes[0] - op0 = scf.normalized_amplitudes[0] - assert_(op0.domain is op1.domain) - ift.extra.assert_allclose(op0.force(inp), op1.force(inp)) - # op1 = cfm.amplitude # op0 = scf.amplitude # assert_(op0.domain is op1.domain)