Commit 192933d1 authored by Philipp Arras's avatar Philipp Arras

Simplifications

parent 89144295
Pipeline #77385 passed with stages
in 13 minutes and 21 seconds
......@@ -19,54 +19,20 @@
import numpy as np
from ..domains.power_space import PowerSpace
from ..domains.unstructured_domain import UnstructuredDomain
from ..operators.adder import Adder
from ..operators.contraction_operator import ContractionOperator
from ..operators.distributors import PowerDistributor
from ..operators.harmonic_operators import HarmonicTransformOperator
from ..operators.linear_operator import LinearOperator
from ..operators.normal_operators import LognormalTransform, NormalTransform
from .correlated_fields import _TwoLogIntegrations
from ..operators.operator import Operator
from ..operators.simple_linear_operators import ducktape
from ..sugar import full, makeDomain, makeField, makeOp
from ..operators.value_inserter import ValueInserter
from ..sugar import full, makeField, makeOp
from .correlated_fields import (_log_vol, _Normalization,
_relative_log_k_lengths, _SlopeRemover)
class _SimpleTwoLogIntegrations(LinearOperator):
def __init__(self, target):
self._target = makeDomain(target)
assert len(self._target) == 1
tgt = self._target[0]
assert isinstance(tgt, PowerSpace)
self._domain = makeDomain(UnstructuredDomain((2, tgt.shape[0] - 2)))
self._log_vol = _log_vol(tgt)
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_input(x, mode)
from_third = slice(2, None)
no_border = slice(1, -1)
reverse = slice(None, None, -1)
if mode == self.TIMES:
x = x.val
res = np.empty(self._target.shape)
res[0] = res[1] = 0
res[from_third] = np.cumsum(x[1])
res[from_third] = (res[from_third] +
res[no_border])/2*self._log_vol + x[0]
res[from_third] = np.cumsum(res[from_third])
else:
x = x.val_rw()
res = np.zeros(self._domain.shape)
x[from_third] = np.cumsum(x[from_third][reverse])[reverse]
res[0] += x[from_third]
x[from_third] *= self._log_vol/2.
x[no_border] += x[from_third]
res[1] += np.cumsum(x[from_third][reverse])[reverse]
return makeField(self._tgt(mode), res)
class SimpleCorrelatedField(Operator):
def __init__(self,
target,
......@@ -100,7 +66,7 @@ class SimpleCorrelatedField(Operator):
prefix + 'zeromode', 0)
tgt = PowerSpace(harmonic_partner)
twolog = _SimpleTwoLogIntegrations(tgt)
twolog = _TwoLogIntegrations(tgt)
dom = twolog.domain[0]
vflex = np.zeros(dom.shape)
vasp = np.zeros(dom.shape, dtype=np.float64)
......@@ -115,7 +81,6 @@ class SimpleCorrelatedField(Operator):
expander = ContractionOperator(twolog.domain, 0).adjoint
ps_expander = ContractionOperator(tgt, 0).adjoint
h_expander = ContractionOperator(harmonic_partner, 0).adjoint
slope = vslope @ ps_expander @ avgsl
sig_flex = vflex @ expander @ flex
sig_asp = vasp @ expander @ asp
......@@ -124,24 +89,19 @@ class SimpleCorrelatedField(Operator):
smooth = _SlopeRemover(tgt, 0) @ twolog @ smooth
op = _Normalization(tgt, 0) @ (slope + smooth)
vol0 = np.zeros(tgt.shape, dtype=np.float64)
vol1 = np.zeros(tgt.shape, dtype=np.float64)
vol1[1:] = vol0[0] = target.total_volume
vol = Adder(makeField(tgt, vol0)) @ makeOp(makeField(tgt, vol1))
amp = vol @ ((ps_expander @ (zm.ptw("reciprocal")*fluct))*op)
self._a = amp
maskzm = np.ones(tgt.shape)
maskzm[0] = 0
maskzm = makeOp(makeField(tgt, maskzm))
insert = ValueInserter(tgt, (0,))
a = (maskzm @ ((ps_expander @ fluct)*op)) + insert(zm)
self._a = a.scale(target.total_volume)
ht = HarmonicTransformOperator(harmonic_partner, target)
pd = PowerDistributor(harmonic_partner, amp.target[0])
pd = PowerDistributor(harmonic_partner, tgt)
xi = ducktape(harmonic_partner, None, prefix + 'xi')
op = ht(h_expander(zm)*pd(amp)*xi*(1))
op = ht(pd(self._a)*xi)
if offset_mean is not None:
op = Adder(full(op.target, float(offset_mean))) @ op
self.apply = op.apply
self._domain = op.domain
self._target = op.target
@property
def normalized_amplitudes(self):
return [self._a]
......@@ -172,11 +172,6 @@ def test_complicated_vs_simple(seed, domain):
ift.extra.assert_allclose(scf(inp), op1(inp))
ift.extra.check_operator(scf, inp, ntries=10)
op1 = cfm.normalized_amplitudes[0]
op0 = scf.normalized_amplitudes[0]
assert_(op0.domain is op1.domain)
ift.extra.assert_allclose(op0.force(inp), op1.force(inp))
# op1 = cfm.amplitude
# op0 = scf.amplitude
# assert_(op0.domain is op1.domain)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment