Commit 3e05d6d0 authored by Philipp Arras's avatar Philipp Arras
Browse files

Remove jax correlated field again

parent fc4e3f45
Pipeline #104717 canceled with stages
in 4 minutes and 58 seconds
......@@ -11,7 +11,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2021 Max-Planck-Society
# Copyright(C) 2013-2020 Max-Planck-Society
# Author: Philipp Arras
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
......@@ -27,11 +27,10 @@ from ..operators.harmonic_operators import HarmonicTransformOperator
from ..operators.normal_operators import LognormalTransform, NormalTransform
from ..operators.simple_linear_operators import ducktape
from ..operators.value_inserter import ValueInserter
from ..sugar import full, makeField, makeOp, makeDomain
from ..sugar import full, makeField, makeOp
from .correlated_fields import (_log_vol, _Normalization,
_relative_log_k_lengths, _SlopeRemover,
_TwoLogIntegrations)
from ..utilities import lognormal_moments, myassert
def SimpleCorrelatedField(
......@@ -44,16 +43,12 @@ def SimpleCorrelatedField(
loglogavgslope,
prefix="",
harmonic_partner=None,
jax=False
):
"""Simplified version of :class:`~nifty8.library.correlated_fields.CorrelatedFieldMaker`.
Assumes `total_N = 0`, `dofdex = None` and the presence of only one power
spectrum, i.e. only one call of
:func:`~nifty8.library.correlated_fields.CorrelatedFieldMaker.add_fluctuations`.
The boolean parameter `jax` controls if the jax backend shall be used.
Default is False.
"""
target = DomainTuple.make(target)
if len(target) != 1:
......@@ -72,10 +67,6 @@ def SimpleCorrelatedField(
raise TypeError
if flexibility is None and asperity is not None:
raise ValueError
if jax:
return _SCFJax(target, offset_mean, offset_std, fluctuations,
flexibility, asperity, loglogavgslope, prefix,
harmonic_partner)
fluct = LognormalTransform(*fluctuations, prefix + 'fluctuations', 0)
avgsl = NormalTransform(*loglogavgslope, prefix + 'loglogavgslope', 0)
......@@ -131,105 +122,3 @@ def SimpleCorrelatedField(
op.amplitude = a
op.power_spectrum = a**2
return op
def _hartley(arr):
from jax.numpy import fft
arr = fft.fftn(arr)
return arr.real + arr.imag
def _twolog_integrate(log_vol, x):
twolog = np.empty((2 + log_vol.shape[0], ))
twolog = twolog.at[0].set(0.)
twolog = twolog.at[1].set(0.)
twolog = twolog.at[2:].set(np.cumsum(x[1], axis=0))
twolog = twolog.at[2:].set(
(twolog[2:] + twolog[1:-1]) / 2. * log_vol + x[0]
)
twolog = twolog.at[2:].set(np.cumsum(twolog[2:], axis=0))
return twolog
# def _remove_slope(rel_log_mode_dist, x):
# sc = rel_log_mode_dist / rel_log_mode_dist[-1]
# return x - x[-1] * sc
def _normalize(x, multipl):
import jax.numpy as jnp
x = jnp.exp(x)
return jnp.sqrt(x/jnp.sum(x*multipl))
def _normal(mean, sigma, key):
return lambda x: float(mean)+float(sigma)*x[key]
def _lognormal(mean, sigma, key):
import jax.numpy as jnp
logmean, logsigma = lognormal_moments(mean, sigma)
return lambda x: jnp.exp(_normal(logmean, logsigma, key)(x))
def _power_distributor(target):
target = DomainTuple.make(target)
myassert(len(target) == 1)
hspace = target[0]
power_space = PowerSpace(hspace)
pindex = PowerSpace(hspace).pindex.ravel()
def operator(x):
import jax.numpy as jnp
oarr = jnp.empty(target.size)
oarr = x[pindex]
return oarr.reshape(target.shape)
return operator
def _SCFJax(target, offset_mean, offset_std, fluctuations, flexibility,
asperity, loglogavgslope, prefix, harmonic_partner):
from ..operators.jax_operator import JaxOperator
if flexibility is not None or asperity is not None or offset_std is not None:
raise NotImplementedError
pspace = PowerSpace(harmonic_partner)
dom = {}
# Slope
kk = prefix+"loglogavgslope"
dom[kk] = DomainTuple.scalar_domain()
slope_var = _normal(*loglogavgslope, kk)
slope = lambda xi: _relative_log_k_lengths(pspace) * slope_var(xi)
# /Slope
# Spectrum normalization
multipl = PowerDistributor(harmonic_partner, pspace).adjoint(full(harmonic_partner, 1.)).val_rw()
multipl[0] = 0.
norm_a = lambda xi: _normalize(slope(xi), multipl)
kk = prefix+"fluctuations"
dom[kk] = DomainTuple.scalar_domain()
fluct = _lognormal(*fluctuations, kk)
mask = np.ones(pspace.shape)
mask[0] = 0.
ampl = lambda xi: norm_a(xi)*fluct(xi)*target.total_volume*mask
# /Spectrum normalization
# Power distribution, excitations, FT
kk = prefix+"xi"
pdom = makeDomain(dom)
dom[kk] = harmonic_partner
pd = _power_distributor(harmonic_partner)
b = lambda xi: _hartley(pd(ampl(xi))*xi[kk])*harmonic_partner.scalar_dvol
# /Power distribution, excitations, FT
# Offset
if offset_mean is None:
func = b
else:
func = lambda xi: b(xi) + offset_mean
# /Offset
op = JaxOperator(dom, target, func)
op.amplitude = JaxOperator(pdom, pspace, ampl)
op.power_spectrum = JaxOperator(pdom, pspace, lambda xi: ampl(xi)**2)
return op
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment