Commit 92eab1e7 authored by Philipp Arras's avatar Philipp Arras
Browse files

Implement basic functionality of cfm in jax

parent 1aae596d
Pipeline #103465 failed with stages
in 13 minutes and 12 seconds
......@@ -11,7 +11,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2020 Max-Planck-Society
# Copyright(C) 2013-2021 Max-Planck-Society
# Author: Philipp Arras
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
......@@ -27,10 +27,11 @@ from ..operators.harmonic_operators import HarmonicTransformOperator
from ..operators.normal_operators import LognormalTransform, NormalTransform
from ..operators.simple_linear_operators import ducktape
from ..operators.value_inserter import ValueInserter
from ..sugar import full, makeField, makeOp
from ..sugar import full, makeField, makeOp, makeDomain
from .correlated_fields import (_log_vol, _Normalization,
_relative_log_k_lengths, _SlopeRemover,
_TwoLogIntegrations)
from ..utilities import lognormal_moments, myassert
def SimpleCorrelatedField(
......@@ -43,12 +44,16 @@ def SimpleCorrelatedField(
loglogavgslope,
prefix="",
harmonic_partner=None,
jax=False
):
"""Simplified version of :class:`~nifty8.library.correlated_fields.CorrelatedFieldMaker`.
Assumes `total_N = 0`, `dofdex = None` and the presence of only one power
spectrum, i.e. only one call of
:func:`~nifty8.library.correlated_fields.CorrelatedFieldMaker.add_fluctuations`.
The boolean parameter `jax` controls if the jax backend shall be used.
Default is False.
"""
target = DomainTuple.make(target)
if len(target) != 1:
......@@ -67,6 +72,10 @@ def SimpleCorrelatedField(
raise TypeError
if flexibility is None and asperity is not None:
raise ValueError
if jax:
return _SCFJax(target, offset_mean, offset_std, fluctuations,
flexibility, asperity, loglogavgslope, prefix,
harmonic_partner)
fluct = LognormalTransform(*fluctuations, prefix + 'fluctuations', 0)
avgsl = NormalTransform(*loglogavgslope, prefix + 'loglogavgslope', 0)
......@@ -122,3 +131,105 @@ def SimpleCorrelatedField(
op.amplitude = a
op.power_spectrum = a**2
return op
def _hartley(arr):
from jax.numpy import fft
arr = fft.fftn(arr)
return arr.real + arr.imag
def _twolog_integrate(log_vol, x):
twolog = np.empty((2 + log_vol.shape[0], ))
twolog = twolog.at[0].set(0.)
twolog = twolog.at[1].set(0.)
twolog = twolog.at[2:].set(np.cumsum(x[1], axis=0))
twolog = twolog.at[2:].set(
(twolog[2:] + twolog[1:-1]) / 2. * log_vol + x[0]
)
twolog = twolog.at[2:].set(np.cumsum(twolog[2:], axis=0))
return twolog
# def _remove_slope(rel_log_mode_dist, x):
# sc = rel_log_mode_dist / rel_log_mode_dist[-1]
# return x - x[-1] * sc
def _normalize(x, multipl):
import jax.numpy as jnp
x = jnp.exp(x)
return jnp.sqrt(x/jnp.sum(x*multipl))
def _normal(mean, sigma, key):
return lambda x: float(mean)+float(sigma)*x[key]
def _lognormal(mean, sigma, key):
import jax.numpy as jnp
logmean, logsigma = lognormal_moments(mean, sigma)
return lambda x: jnp.exp(_normal(logmean, logsigma, key)(x))
def _power_distributor(target):
target = DomainTuple.make(target)
myassert(len(target) == 1)
hspace = target[0]
power_space = PowerSpace(hspace)
pindex = PowerSpace(hspace).pindex.ravel()
def operator(x):
import jax.numpy as jnp
oarr = jnp.empty(target.size)
oarr = x[pindex]
return oarr.reshape(target.shape)
return operator
def _SCFJax(target, offset_mean, offset_std, fluctuations, flexibility,
asperity, loglogavgslope, prefix, harmonic_partner):
from ..operators.jax_operator import JaxOperator
if flexibility is not None or asperity is not None or offset_std is not None:
raise NotImplementedError
pspace = PowerSpace(harmonic_partner)
dom = {}
# Slope
kk = prefix+"loglogavgslope"
dom[kk] = DomainTuple.scalar_domain()
slope_var = _normal(*loglogavgslope, kk)
slope = lambda xi: _relative_log_k_lengths(pspace) * slope_var(xi)
# /Slope
# Spectrum normalization
multipl = PowerDistributor(harmonic_partner, pspace).adjoint(full(harmonic_partner, 1.)).val_rw()
multipl[0] = 0.
norm_a = lambda xi: _normalize(slope(xi), multipl)
kk = prefix+"fluctuations"
dom[kk] = DomainTuple.scalar_domain()
fluct = _lognormal(*fluctuations, kk)
mask = np.ones(pspace.shape)
mask[0] = 0.
ampl = lambda xi: norm_a(xi)*fluct(xi)*target.total_volume*mask
# /Spectrum normalization
# Power distribution, excitations, FT
kk = prefix+"xi"
pdom = makeDomain(dom)
dom[kk] = harmonic_partner
pd = _power_distributor(harmonic_partner)
b = lambda xi: _hartley(pd(ampl(xi))*xi[kk])*harmonic_partner.scalar_dvol
# /Power distribution, excitations, FT
# Offset
if offset_mean is None:
func = b
else:
func = lambda xi: b(xi) + offset_mean
# /Offset
op = JaxOperator(dom, target, func)
op.amplitude = JaxOperator(pdom, pspace, ampl)
op.power_spectrum = JaxOperator(pdom, pspace, lambda xi: ampl(xi)**2)
return op
......@@ -43,7 +43,7 @@ class JaxOperator(Operator):
target : DomainTuple or MultiDomain
Target of the operator.
func : callable
func : callable
The jax function that is evaluated by the operator. It has to be
implemented in terms of `jax.numpy` calls. If `domain` is a
`DomainTuple`, `func` takes a `dict` as argument and like-wise for the
......@@ -54,12 +54,13 @@ class JaxOperator(Operator):
self._domain = makeDomain(domain)
self._target = makeDomain(target)
self._func = jax.jit(func)
self._vjp = jax.jit(lambda x: jax.vjp(func, x))
def apply(self, x):
from ..sugar import is_linearization, makeField
self._check_input(x)
if is_linearization(x):
res, bwd = jax.vjp(self._func, x.val.val)
res, bwd = self._vjp(x.val.val)
fwd = lambda y: jax.jvp(self._func, (x.val.val,), (y,))[1]
jac = _JaxJacobian(self._domain, self._target, fwd, bwd)
return x.new(makeField(self._target, _jax2np(res)), jac)
......
......@@ -60,3 +60,25 @@ def test_mf_jax():
for kk in dom.keys():
np.testing.assert_allclose(np.array(func(loc.val)[kk]), op(loc)[kk].val)
ift.extra.check_operator(op, loc)
def test_cf():
dom = ift.RGSpace([13, 14], distances=(0.89, 0.9))
args = {
'offset_mean': 0,
'offset_std': None,
'fluctuations': (2.1, 0.8),
'loglogavgslope': (-2., 1.2),
'flexibility': None,
'asperity': None,
}
correlated_field = ift.SimpleCorrelatedField(dom, **args)
correlated_field_1 = ift.SimpleCorrelatedField(dom, **args, jax=True)
loc = ift.from_random(correlated_field.domain)
ift.extra.assert_allclose(correlated_field(loc), correlated_field_1(loc))
pspec = correlated_field.power_spectrum
pspec_1 = correlated_field_1.power_spectrum
ift.extra.assert_allclose(pspec.force(loc), pspec_1.force(loc))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment