Commit c7d69125 authored by Philipp Arras's avatar Philipp Arras
Browse files

Simplifications

parent 04d518e2
...@@ -88,7 +88,7 @@ from .library.adjust_variances import (make_adjust_variances_hamiltonian, ...@@ -88,7 +88,7 @@ from .library.adjust_variances import (make_adjust_variances_hamiltonian,
do_adjust_variances) do_adjust_variances)
from .library.gridder import Gridder from .library.gridder import Gridder
from .library.correlated_fields import CorrelatedFieldMaker from .library.correlated_fields import CorrelatedFieldMaker
from .library.correlated_fields_simple import SimpleCorrelatedFieldMaker from .library.correlated_fields_simple import SimpleCorrelatedField
from . import extra from . import extra
......
...@@ -16,31 +16,21 @@ ...@@ -16,31 +16,21 @@
# #
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from functools import reduce
from operator import mul
import numpy as np import numpy as np
from .. import utilities
from ..domain_tuple import DomainTuple
from ..domains.power_space import PowerSpace from ..domains.power_space import PowerSpace
from ..domains.unstructured_domain import UnstructuredDomain from ..domains.unstructured_domain import UnstructuredDomain
from ..field import Field
from ..logger import logger
from ..multi_field import MultiField
from ..operators.adder import Adder from ..operators.adder import Adder
from ..operators.contraction_operator import ContractionOperator from ..operators.contraction_operator import ContractionOperator
from ..operators.diagonal_operator import DiagonalOperator
from ..operators.distributors import PowerDistributor from ..operators.distributors import PowerDistributor
from ..operators.endomorphic_operator import EndomorphicOperator
from ..operators.harmonic_operators import HarmonicTransformOperator from ..operators.harmonic_operators import HarmonicTransformOperator
from ..operators.linear_operator import LinearOperator from ..operators.linear_operator import LinearOperator
from ..operators.normal_operators import LognormalTransform, NormalTransform from ..operators.normal_operators import LognormalTransform, NormalTransform
from ..operators.operator import Operator from ..operators.operator import Operator
from ..operators.simple_linear_operators import ducktape from ..operators.simple_linear_operators import ducktape
from ..probing import StatCalculator
from ..sugar import full, makeDomain, makeField, makeOp from ..sugar import full, makeDomain, makeField, makeOp
from .correlated_fields import (_Distributor, _log_vol, _Normalization, from .correlated_fields import (_log_vol, _Normalization,
_relative_log_k_lengths, _SlopeRemover) _relative_log_k_lengths, _SlopeRemover)
...@@ -78,73 +68,7 @@ class _SimpleTwoLogIntegrations(LinearOperator): ...@@ -78,73 +68,7 @@ class _SimpleTwoLogIntegrations(LinearOperator):
return makeField(self._tgt(mode), res) return makeField(self._tgt(mode), res)
class _SimpleAmplitude(Operator): class SimpleCorrelatedField(Operator):
def __init__(self, target, fluctuations, flexibility, asperity,
loglogavgslope, azm, totvol, key):
assert isinstance(fluctuations, Operator)
assert isinstance(flexibility, Operator)
assert isinstance(asperity, Operator)
assert isinstance(loglogavgslope, Operator)
distributed_tgt = target = makeDomain(target)
azm_expander = ContractionOperator(distributed_tgt, spaces=0).adjoint
assert isinstance(target[0], PowerSpace)
twolog = _SimpleTwoLogIntegrations(target)
dom = twolog.domain
shp = dom[0].shape
expander = ContractionOperator(dom, spaces=0).adjoint
ps_expander = ContractionOperator(twolog.target, spaces=0).adjoint
# Prepare constant fields
foo = np.zeros(shp)
foo[0] = foo[1] = np.sqrt(_log_vol(target[0]))
vflex = DiagonalOperator(makeField(dom[0], foo), dom, 0)
foo = np.zeros(shp, dtype=np.float64)
foo[0] += 1
vasp = DiagonalOperator(makeField(dom[0], foo), dom, 0)
foo = np.ones(shp)
foo[0] = _log_vol(target[0])**2/12.
shift = DiagonalOperator(makeField(dom[0], foo), dom, 0)
vslope = DiagonalOperator(
makeField(target[0], _relative_log_k_lengths(target[0])), target,
0)
foo, bar = [np.zeros(target[0].shape) for _ in range(2)]
bar[1:] = foo[0] = totvol
vol0, vol1 = [
DiagonalOperator(makeField(target[0], aa), target, 0)
for aa in (foo, bar)
]
# Prepare fields for Adder
shift, vol0 = [op(full(op.domain, 1)) for op in (shift, vol0)]
# End prepare constant fields
slope = vslope @ ps_expander @ loglogavgslope
sig_flex = vflex @ expander @ flexibility
sig_asp = vasp @ expander @ asperity
sig_fluc = vol1 @ ps_expander @ fluctuations
sig_fluc = vol1 @ ps_expander @ fluctuations
xi = ducktape(dom, None, key)
sigma = sig_flex*(Adder(shift) @ sig_asp).ptw("sqrt")
smooth = _SlopeRemover(target, 0) @ twolog @ (sigma*xi)
op = _Normalization(target, 0) @ (slope + smooth)
op = Adder(vol0) @ (sig_fluc*(azm_expander @ azm.ptw("reciprocal"))*op)
self.apply = op.apply
self._domain, self._target = op.domain, op.target
self._op = op
def __repr__(self):
return self._op.__repr__
class SimpleCorrelatedFieldMaker:
def __init__(self, def __init__(self,
target, target,
offset_mean, offset_mean,
...@@ -175,18 +99,47 @@ class SimpleCorrelatedFieldMaker: ...@@ -175,18 +99,47 @@ class SimpleCorrelatedFieldMaker:
prefix + 'loglogavgslope', 0) prefix + 'loglogavgslope', 0)
zm = LognormalTransform(offset_std_mean, offset_std_std, zm = LognormalTransform(offset_std_mean, offset_std_std,
prefix + 'zeromode', 0) prefix + 'zeromode', 0)
amp = _SimpleAmplitude(PowerSpace(harmonic_partner), fluct, flex, asp,
avgsl, zm, target.total_volume, tgt = PowerSpace(harmonic_partner)
prefix + 'spectrum') azm_expander = ContractionOperator(tgt, 0).adjoint
twolog = _SimpleTwoLogIntegrations(tgt)
dom = twolog.domain[0]
vflex = np.zeros(dom.shape)
vasp = np.zeros(dom.shape, dtype=np.float64)
shift = np.ones(dom.shape, dtype=np.float64)
vol0 = np.zeros(tgt.shape, dtype=np.float64)
vol1 = np.zeros(tgt.shape, dtype=np.float64)
vflex[0] = vflex[1] = np.sqrt(_log_vol(tgt))
vasp[0] = 1
shift[0] = _log_vol(tgt)**2/12.
vol1[1:] = vol0[0] = target.total_volume
vflex = makeOp(makeField(dom, vflex))
vasp = makeOp(makeField(dom, vasp))
shift = makeOp(makeField(dom, shift))
vol0 = makeField(tgt, vol0)
vol1 = makeOp(makeField(tgt, vol1))
vslope = makeOp(makeField(tgt, _relative_log_k_lengths(tgt)))
shift = shift(full(shift.domain, 1))
expander = ContractionOperator(twolog.domain, spaces=0).adjoint
ps_expander = ContractionOperator(twolog.target, spaces=0).adjoint
slope = vslope @ ps_expander @ avgsl
sig_flex = vflex @ expander @ flex
sig_asp = vasp @ expander @ asp
sig_fluc = vol1 @ ps_expander @ fluct
sig_fluc = vol1 @ ps_expander @ fluct
xi = ducktape(dom, None, prefix + 'spectrum')
sigma = sig_flex*(Adder(shift) @ sig_asp).ptw("sqrt")
smooth = _SlopeRemover(tgt, 0) @ twolog @ (sigma*xi)
op = _Normalization(tgt, 0) @ (slope + smooth)
amp = Adder(vol0) @ (sig_fluc*(azm_expander @ zm.ptw("reciprocal"))*op)
ht = HarmonicTransformOperator(harmonic_partner, target) ht = HarmonicTransformOperator(harmonic_partner, target)
pd = PowerDistributor(harmonic_partner, amp.target[0]) pd = PowerDistributor(harmonic_partner, amp.target[0])
expander = ContractionOperator(harmonic_partner, spaces=0).adjoint expander = ContractionOperator(harmonic_partner, spaces=0).adjoint
self._op = ht( xi = ducktape(harmonic_partner, None, prefix + 'xi')
expander(zm)*pd(amp)* op = ht(expander(zm)*pd(amp)*xi)
ducktape(harmonic_partner, None, prefix + 'xi'))
if offset_mean is not None: if offset_mean is not None:
self._op = Adder(full(self._op.target, op = Adder(full(op.target, float(offset_mean))) @ op
float(offset_mean))) @ self._op self.apply = op.apply
self._domain = op.domain
def finalize(self): self._target = op.target
return self._op
...@@ -139,20 +139,20 @@ def test_complicated_vs_simple(seed, domain): ...@@ -139,20 +139,20 @@ def test_complicated_vs_simple(seed, domain):
loglogavgslope_stddev = _posrand() loglogavgslope_stddev = _posrand()
prefix = 'foobar' prefix = 'foobar'
hspace = domain.get_default_codomain() hspace = domain.get_default_codomain()
cfm_simple = ift.SimpleCorrelatedFieldMaker(domain, op0 = ift.SimpleCorrelatedField(domain,
offset_mean, offset_mean,
offset_std_mean, offset_std_mean,
offset_std_std, offset_std_std,
fluctuations_mean, fluctuations_mean,
fluctuations_stddev, fluctuations_stddev,
flexibility_mean, flexibility_mean,
flexibility_stddev, flexibility_stddev,
asperity_mean, asperity_mean,
asperity_stddev, asperity_stddev,
loglogavgslope_mean, loglogavgslope_mean,
loglogavgslope_stddev, loglogavgslope_stddev,
prefix=prefix, prefix=prefix,
harmonic_partner=hspace) harmonic_partner=hspace)
cfm = ift.CorrelatedFieldMaker.make(offset_mean, offset_std_mean, cfm = ift.CorrelatedFieldMaker.make(offset_mean, offset_std_mean,
offset_std_std, prefix) offset_std_std, prefix)
cfm.add_fluctuations(domain, cfm.add_fluctuations(domain,
...@@ -167,7 +167,6 @@ def test_complicated_vs_simple(seed, domain): ...@@ -167,7 +167,6 @@ def test_complicated_vs_simple(seed, domain):
prefix='', prefix='',
harmonic_partner=hspace) harmonic_partner=hspace)
op1 = cfm.finalize() op1 = cfm.finalize()
op0 = cfm_simple.finalize()
assert_(op0.domain is op1.domain) assert_(op0.domain is op1.domain)
inp = ift.from_random(op0.domain) inp = ift.from_random(op0.domain)
ift.extra.assert_allclose(op0(inp), op1(inp)) ift.extra.assert_allclose(op0(inp), op1(inp))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment