Commit fb0a4dc3 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'nifty6to7' into 'NIFTy_7'

Nifty6to7

See merge request ift/nifty!533
parents 254dfe98 03bd7f7a
......@@ -152,10 +152,8 @@
"sigmas = [1.0, 0.5, 0.1]\n",
"\n",
"for i in range(3):\n",
" op = ift.library.correlated_fields._LognormalMomentMatching(mean=mean,\n",
" sig=sigmas[i],\n",
" key='foo',\n",
" N_copies=0)\n",
" op = ift.LognormalTransform(mean=mean, sigma=sigmas[i],\n",
" key='foo', N_copies=0)\n",
" op_samples = np.array(\n",
" [op(s).val for s in [ift.from_random(op.domain) for i in range(10000)]])\n",
"\n",
......
%% Cell type:markdown id: tags:
# Notebook showcasing the NIFTy 6 Correlated Field model
**Skip to `Parameter Showcases` for the meat/veggies ;)**
The field model roughly works like this:
`f = HT( A * zero_mode * xi ) + offset`
`A` is a spectral power field which is constructed from power spectra that hold on subdomains of the target domain.
It is scaled by a zero mode operator and then pointwise multiplied by a Gaussian excitation field, yielding
a representation of the field in harmonic space.
It is then transformed into the target real space and an offset is added.
The power spectra from which `A` is constructed are in turn built as the sum of a power law component
and an integrated Wiener process whose amplitude and roughness can be set.
## Setup code
%% Cell type:code id: tags:
``` python
import nifty7 as ift
import matplotlib.pyplot as plt
import numpy as np
ift.random.push_sseq_from_seed(43)
n_pix = 256
x_space = ift.RGSpace(n_pix)
```
%% Cell type:code id: tags:
``` python
# Plotting routine
def plot(fields, spectra, title=None):
    """Plot field realizations (left panel) and their power spectra (right).

    Parameters
    ----------
    fields : iterable
        Field realizations; assumed to share one 1D domain exposing
        `shape` and `distances` (e.g. an RGSpace).
    spectra : iterable
        Spectrum realizations whose domain exposes `k_lengths`.
    title : str, optional
        Figure-level title.
    """
    # Plotting preparation is normally handled by nifty7.Plot
    # It is done manually here to be able to tweak details
    # Fields are assumed to have identical domains
    fig = plt.figure(tight_layout=True, figsize=(12, 3.5))
    if title is not None:
        fig.suptitle(title, fontsize=14)

    # Field
    ax1 = fig.add_subplot(1, 2, 1)
    ax1.axhline(y=0., color='k', linestyle='--', alpha=0.25)
    for field in fields:
        dom = field.domain[0]
        # Pixel index -> physical coordinate via the grid spacing.
        xcoord = np.arange(dom.shape[0]) * dom.distances[0]
        ax1.plot(xcoord, field.val)
    ax1.set_xlim(xcoord[0], xcoord[-1])
    ax1.set_ylim(-5., 5.)
    ax1.set_xlabel('x')
    ax1.set_ylabel('f(x)')
    ax1.set_title('Field realizations')

    # Spectrum
    ax2 = fig.add_subplot(1, 2, 2)
    for spectrum in spectra:
        xcoord = spectrum.domain[0].k_lengths
        # Writable copy: the zero-mode entry is overwritten so it does not
        # distort the log-log plot (k=0 would sit at -inf on a log axis).
        ycoord = spectrum.val_rw()
        ycoord[0] = ycoord[1]
        ax2.plot(xcoord, ycoord)
    ax2.set_ylim(1e-6, 10.)
    ax2.set_xscale('log')
    ax2.set_yscale('log')
    ax2.set_xlabel('k')
    ax2.set_ylabel('p(k)')
    ax2.set_title('Power Spectrum')

    fig.align_labels()
    plt.show()
# Helper: draw main sample
main_sample = None
def init_model(m_pars, fl_pars):
    """Build the correlated field model and draw the shared `main_sample`.

    Stores the sample in the module-level global so later showcase cells can
    reuse identical excitations while varying one parameter at a time.
    """
    global main_sample
    cf = ift.CorrelatedFieldMaker.make(**m_pars)
    cf.add_fluctuations(**fl_pars)
    field = cf.finalize(prior_info=0)
    main_sample = ift.from_random(field.domain)
    print("model domain keys:", field.domain.keys())
# Helper: field and spectrum from parameter dictionaries + plotting
def eval_model(m_pars, fl_pars, title=None, samples=None):
    """Construct the model from parameter dicts and plot its realizations.

    Falls back to the globally shared `main_sample` when `samples` is None,
    so successive calls show the effect of parameter changes on the *same*
    excitation.
    """
    cf = ift.CorrelatedFieldMaker.make(**m_pars)
    cf.add_fluctuations(**fl_pars)
    field = cf.finalize(prior_info=0)
    spectrum = cf.amplitude
    if samples is None:
        samples = [main_sample]
    field_realizations = [field(s) for s in samples]
    # NOTE(review): `.force` presumably adapts the full model-domain sample
    # to the amplitude operator's smaller domain — confirm against nifty docs.
    spectrum_realizations = [spectrum.force(s) for s in samples]
    plot(field_realizations, spectrum_realizations, title)
def gen_samples(key_to_vary):
    """Return 8 samples that differ only in the entry `key_to_vary`.

    With `key_to_vary=None`, return just the shared `main_sample`, i.e. no
    variation at all.
    """
    if key_to_vary is None:
        return [main_sample]
    dct = main_sample.to_dict()
    # Remove the varying entry but remember its subdomain so fresh random
    # values can be drawn for it in each sample below.
    subdom_to_vary = dct.pop(key_to_vary).domain
    samples = []
    for i in range(8):
        d = dct.copy()
        d[key_to_vary] = ift.from_random(subdom_to_vary)
        samples.append(ift.MultiField.from_dict(d))
    return samples
def vary_parameter(parameter_key, values, samples_vary_in=None):
    """Evaluate and plot the model once per value of `parameter_key`.

    The key is looked up in the global `cf_make_pars` (CorrelatedFieldMaker
    parameters) first; any other key is treated as an `add_fluctuations`
    parameter and merged into `cf_x_fluct_pars`.
    """
    s = gen_samples(samples_vary_in)
    for v in values:
        if parameter_key in cf_make_pars.keys():
            m_pars = {**cf_make_pars, parameter_key: v}
            eval_model(m_pars, cf_x_fluct_pars, f"{parameter_key} = {v}", s)
        else:
            fl_pars = {**cf_x_fluct_pars, parameter_key: v}
            eval_model(cf_make_pars, fl_pars, f"{parameter_key} = {v}", s)
```
%% Cell type:markdown id: tags:
## Before the Action: The Moment-Matched Log-Normal Distribution
Many properties of the correlated field are modelled as being lognormally distributed.
The distribution models are parametrized via their means `_mean` and standard-deviations `_stddev`.
To get a feeling of how the ratio of the `mean` and `stddev` parameters influences the distribution shape,
here are a few example histograms: (observe the x-axis!)
%% Cell type:code id: tags:
``` python
fig = plt.figure(figsize=(13, 3.5))
mean = 1.0
sigmas = [1.0, 0.5, 0.1]
for i in range(3):
op = ift.library.correlated_fields._LognormalMomentMatching(mean=mean,
sig=sigmas[i],
key='foo',
N_copies=0)
op = ift.LognormalTransform(mean=mean, sigma=sigmas[i],
key='foo', N_copies=0)
op_samples = np.array(
[op(s).val for s in [ift.from_random(op.domain) for i in range(10000)]])
ax = fig.add_subplot(1, 3, i + 1)
ax.hist(op_samples, bins=50)
ax.set_title(f"mean = {mean}, sigma = {sigmas[i]}")
ax.set_xlabel('x')
del op_samples
plt.show()
```
%% Cell type:markdown id: tags:
## The Neutral Field
To demonstrate the effect of all parameters, first a 'neutral' set of parameters
is defined which then are varied one by one, showing the effect of the variation
on the generated field realizations and the underlying power spectrum from which
they were drawn.
As a neutral field, a model with a white power spectrum and vanishing spectral power was chosen.
%% Cell type:code id: tags:
``` python
# Neutral model parameters yielding a quasi-constant field
cf_make_pars = {
'offset_mean': 0.,
'offset_std_mean': 1e-3,
'offset_std_std': 1e-16,
'prefix': ''
}
cf_x_fluct_pars = {
'target_subdomain': x_space,
'fluctuations_mean': 1e-3,
'fluctuations_stddev': 1e-16,
'flexibility_mean': 1e-3,
'flexibility_stddev': 1e-16,
'asperity_mean': 1e-3,
'asperity_stddev': 1e-16,
'loglogavgslope_mean': 0.,
'loglogavgslope_stddev': 1e-16
}
init_model(cf_make_pars, cf_x_fluct_pars)
```
%% Cell type:code id: tags:
``` python
# Show neutral field
eval_model(cf_make_pars, cf_x_fluct_pars, "Neutral Field")
```
%% Cell type:markdown id: tags:
# Parameter Showcases
## The `fluctuations_` parameters of `add_fluctuations()`
determine the **amplitude of variations along the field dimension**
for which `add_fluctuations` is called.
`fluctuations_mean` sets the _average_ amplitude of the field's fluctuations along the given dimension,\
`fluctuations_stddev` sets the width and shape of the amplitude distribution.
The amplitude is modelled as being log-normally distributed,
see `The Moment-Matched Log-Normal Distribution` above for details.
#### `fluctuations_mean`:
%% Cell type:code id: tags:
``` python
vary_parameter('fluctuations_mean', [0.05, 0.5, 2.], samples_vary_in='xi')
```
%% Cell type:markdown id: tags:
#### `fluctuations_stddev`:
%% Cell type:code id: tags:
``` python
cf_x_fluct_pars['fluctuations_mean'] = 1.0
vary_parameter('fluctuations_stddev', [0.01, 0.1, 1.0], samples_vary_in='fluctuations')
```
%% Cell type:markdown id: tags:
## The `loglogavgslope_` parameters of `add_fluctuations()`
determine **the slope of the loglog-linear (power law) component of the power spectrum**.
The slope is modelled to be normally distributed.
#### `loglogavgslope_mean`:
%% Cell type:code id: tags:
``` python
vary_parameter('loglogavgslope_mean', [-3., -1., 1.], samples_vary_in='xi')
```
%% Cell type:markdown id: tags:
#### `loglogavgslope_stddev`:
%% Cell type:code id: tags:
``` python
cf_x_fluct_pars['loglogavgslope_mean'] = -1.0
vary_parameter('loglogavgslope_stddev', [0.01, 0.1, 1.0], samples_vary_in='loglogavgslope')
```
%% Cell type:markdown id: tags:
## The `flexibility_` parameters of `add_fluctuations()`
determine **the amplitude of the integrated Wiener process component of the power spectrum**
(how strong the power spectrum varies besides the power-law).
`flexibility_mean` sets the _average_ amplitude of the integrated Wiener process component,
`flexibility_stddev` sets how much the amplitude can vary.\
These two parameters feed into a moment-matched log-normal distribution model,
see above for a demo of its behavior.
#### `flexibility_mean`:
%% Cell type:code id: tags:
``` python
vary_parameter('flexibility_mean', [0.1, 1.0, 3.0], samples_vary_in='spectrum')
```
%% Cell type:markdown id: tags:
#### `flexibility_stddev`:
%% Cell type:code id: tags:
``` python
cf_x_fluct_pars['flexibility_mean'] = 2.0
vary_parameter('flexibility_stddev', [0.01, 0.1, 1.0], samples_vary_in='flexibility')
```
%% Cell type:markdown id: tags:
## The `asperity_` parameters of `add_fluctuations()`
`asperity_` determines **how rough the integrated Wiener process component of the power spectrum is**.
`asperity_mean` sets the average roughness, `asperity_stddev` sets how much the roughness can vary.\
These two parameters feed into a moment-matched log-normal distribution model,
see above for a demo of its behavior.
#### `asperity_mean`:
%% Cell type:code id: tags:
``` python
vary_parameter('asperity_mean', [0.001, 1.0, 5.0], samples_vary_in='spectrum')
```
%% Cell type:markdown id: tags:
#### `asperity_stddev`:
%% Cell type:code id: tags:
``` python
cf_x_fluct_pars['asperity_mean'] = 1.0
vary_parameter('asperity_stddev', [0.01, 0.1, 1.0], samples_vary_in='asperity')
```
%% Cell type:markdown id: tags:
## The `offset_mean` parameter of `CorrelatedFieldMaker.make()`
The `offset_mean` parameter defines a global additive offset on the field realizations.
If the field is used for a lognormal model `f = field.exp()`, this acts as a global signal magnitude offset.
%% Cell type:code id: tags:
``` python
# Reset model to neutral
cf_x_fluct_pars['fluctuations_mean'] = 1e-3
cf_x_fluct_pars['flexibility_mean'] = 1e-3
cf_x_fluct_pars['asperity_mean'] = 1e-3
cf_x_fluct_pars['loglogavgslope_mean'] = 1e-3
```
%% Cell type:code id: tags:
``` python
vary_parameter('offset_mean', [3., 0., -2.])
```
%% Cell type:markdown id: tags:
## The `offset_std_` parameters of `CorrelatedFieldMaker.make()`
Variations of the global offset of the field are modelled as being log-normally distributed.
See `The Moment-Matched Log-Normal Distribution` above for details.
The `offset_std_mean` parameter sets how much NIFTy will vary the offset *on average*.\
The `offset_std_std` parameter defines the width and shape of the offset variation distribution.
#### `offset_std_mean`:
%% Cell type:code id: tags:
``` python
vary_parameter('offset_std_mean', [1e-16, 0.5, 2.], samples_vary_in='xi')
```
%% Cell type:markdown id: tags:
#### `offset_std_std`:
%% Cell type:code id: tags:
``` python
cf_make_pars['offset_std_mean'] = 1.0
vary_parameter('offset_std_std', [0.01, 0.1, 1.0], samples_vary_in='zeromode')
```
......
......@@ -52,6 +52,7 @@ from .operators.energy_operators import (
BernoulliEnergy, StandardHamiltonian, AveragedEnergy, QuadraticFormOperator,
Squared2NormOperator, StudentTEnergy, VariableCovarianceGaussianEnergy)
from .operators.convolution_operators import FuncConvolutionOperator
from .operators.normal_operators import NormalTransform, LognormalTransform
from .probing import probe_with_posterior_samples, probe_diagonal, \
StatCalculator, approximation2endo
......
......@@ -37,48 +37,11 @@ from ..operators.harmonic_operators import HarmonicTransformOperator
from ..operators.linear_operator import LinearOperator
from ..operators.operator import Operator
from ..operators.simple_linear_operators import ducktape
from ..operators.normal_operators import NormalTransform, LognormalTransform
from ..probing import StatCalculator
from ..sugar import full, makeDomain, makeField, makeOp
def _reshaper(x, N):
    """Broadcast a scalar to shape ``(N,)`` or validate an ``(N,)`` array.

    A scalar (or length-one) input is replicated ``N`` times; with ``N == 0``
    it is returned as a 0-d float array. An already correctly shaped array is
    passed through unchanged. Anything else raises ``TypeError``.
    """
    arr = np.asarray(x, dtype=float)
    if arr.shape in [(), (1,)]:
        if N == 0:
            return arr.reshape(())
        return np.full(N, arr)
    if arr.shape == (N,):
        return arr
    raise TypeError("Shape of parameters cannot be interpreted")
def _lognormal_moments(mean, sig, N=0):
    """Return (logmean, logsig) of the normal distribution whose pointwise
    exponential has the requested `mean` and `sig`.

    Both inputs must be strictly positive; `N` selects between the scalar
    (``N == 0``) and the N-copies parameterization.
    """
    if N == 0:
        mean = np.asarray(mean, dtype=float)
        sig = np.asarray(sig, dtype=float)
    else:
        mean = _reshaper(mean, N)
        sig = _reshaper(sig, N)
    if not np.all(mean > 0):
        raise ValueError("mean must be greater 0; got {!r}".format(mean))
    if not np.all(sig > 0):
        raise ValueError("sig must be greater 0; got {!r}".format(sig))
    # Standard log-normal moment matching: match first and second moments.
    logsig = np.sqrt(np.log1p((sig/mean)**2))
    logmean = np.log(mean) - logsig**2/2
    return logmean, logsig
def _normal(mean, sig, key, N=0):
    """Build an operator chain mapping the standard-normal latent entry
    `key` of a multi-field to values distributed as N(mean, sig**2).

    With ``N == 0`` the target is the scalar domain and `sig` acts as a
    scalar factor; otherwise `N` independent copies live on an
    UnstructuredDomain and `sig` acts diagonally.
    """
    if N == 0:
        domain = DomainTuple.scalar_domain()
        mean, sig = np.asfarray(mean), np.asfarray(sig)
        return Adder(makeField(domain, mean)) @ (
            sig * ducktape(domain, None, key))
    domain = UnstructuredDomain(N)
    # Broadcast/validate parameters to shape (N,).
    mean, sig = (_reshaper(param, N) for param in (mean, sig))
    return Adder(makeField(domain, mean)) @ (DiagonalOperator(
        makeField(domain, sig)) @ ducktape(domain, None, key))
def _log_k_lengths(pspace):
    """Return log of `pspace.k_lengths`, skipping the k=0 zero mode."""
    nonzero_modes = pspace.k_lengths[1:]
    return np.log(nonzero_modes)
......@@ -120,29 +83,6 @@ def _total_fluctuation_realized(samples):
return np.sqrt(res if np.isscalar(res) else res.val)
class _LognormalMomentMatching(Operator):
    """Operator mapping standard-normal inputs to log-normally distributed
    outputs whose mean and standard deviation match `mean` and `sig`.
    """

    def __init__(self, mean, sig, key, N_copies):
        key = str(key)
        # Convert desired log-normal moments to the moments of the
        # underlying normal distribution.
        logmean, logsig = _lognormal_moments(mean, sig, N_copies)
        self._mean = mean
        self._sig = sig
        # exp(normal) yields the log-normal distribution.
        op = _normal(logmean, logsig, key, N_copies).ptw("exp")
        self._domain, self._target = op.domain, op.target
        # Delegate application to the assembled operator chain.
        self.apply = op.apply
        self._repr_str = f"_LognormalMomentMatching: " + op.__repr__()

    @property
    def mean(self):
        # Requested mean of the resulting log-normal distribution.
        return self._mean

    @property
    def std(self):
        # Requested standard deviation of the resulting distribution.
        return self._sig

    def __repr__(self):
        return self._repr_str
class _SlopeRemover(EndomorphicOperator):
def __init__(self, domain, space=0):
self._domain = makeDomain(domain)
......@@ -441,10 +381,8 @@ class CorrelatedFieldMaker:
elif len(dofdex) != total_N:
raise ValueError("length of dofdex needs to match total_N")
N = max(dofdex) + 1 if total_N > 0 else 0
zm = _LognormalMomentMatching(offset_std_mean,
offset_std_std,
prefix + 'zeromode',
N)
zm = LognormalTransform(offset_std_mean, offset_std_std,
prefix + 'zeromode', N)
if total_N > 0:
zm = _Distributor(dofdex, zm.target, UnstructuredDomain(total_N)) @ zm
return CorrelatedFieldMaker(offset_mean, zm, prefix, total_N)
......@@ -532,17 +470,15 @@ class CorrelatedFieldMaker:
prefix = str(prefix)
# assert isinstance(target_subdomain[space], (RGSpace, HPSpace, GLSpace)
fluct = _LognormalMomentMatching(fluctuations_mean,
fluctuations_stddev,
self._prefix + prefix + 'fluctuations',
N)
flex = _LognormalMomentMatching(flexibility_mean, flexibility_stddev,
self._prefix + prefix + 'flexibility',
N)
asp = _LognormalMomentMatching(asperity_mean, asperity_stddev,
self._prefix + prefix + 'asperity', N)
avgsl = _normal(loglogavgslope_mean, loglogavgslope_stddev,
self._prefix + prefix + 'loglogavgslope', N)
fluct = LognormalTransform(fluctuations_mean, fluctuations_stddev,
self._prefix + prefix + 'fluctuations', N)
flex = LognormalTransform(flexibility_mean, flexibility_stddev,
self._prefix + prefix + 'flexibility', N)
asp = LognormalTransform(asperity_mean, asperity_stddev,
self._prefix + prefix + 'asperity', N)
avgsl = NormalTransform(loglogavgslope_mean, loglogavgslope_stddev,
self._prefix + prefix + 'loglogavgslope', N)
amp = _Amplitude(PowerSpace(harmonic_partner), fluct, flex, asp, avgsl,
self._azm, target_subdomain[-1].total_volume,
self._prefix + prefix + 'spectrum', dofdex)
......
......@@ -132,14 +132,9 @@ class MetricGaussianKL(Energy):
_, self._hamiltonian = hamiltonian.simplify_for_constant_input(cstpos)
self._n_samples = int(n_samples)
if comm is not None:
self._comm = comm
ntask = self._comm.Get_size()
rank = self._comm.Get_rank()
self._lo, self._hi = utilities.shareRange(self._n_samples, ntask, rank)
else:
self._comm = None
self._lo, self._hi = 0, self._n_samples
self._comm = comm
ntask, rank, _ = utilities.get_MPI_params_from_comm(self._comm)
self._lo, self._hi = utilities.shareRange(self._n_samples, ntask, rank)
self._mirror_samples = bool(mirror_samples)
self._n_eff_samples = self._n_samples
......@@ -213,14 +208,13 @@ class MetricGaussianKL(Energy):
@property
def samples(self):
if self._comm is None:
ntask, rank, _ = utilities.get_MPI_params_from_comm(self._comm)
if ntask == 1:
for s in self._local_samples:
yield s
if self._mirror_samples:
yield -s
else:
ntask = self._comm.Get_size()
rank = self._comm.Get_rank()
rank_lo_hi = [utilities.shareRange(self._n_samples, ntask, i) for i in range(ntask)]
for itask, (l, h) in enumerate(rank_lo_hi):
for i in range(l, h):
......
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2020 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np
from ..domain_tuple import DomainTuple
from ..domains.unstructured_domain import UnstructuredDomain
from ..operators.operator import Operator
from ..operators.adder import Adder
from ..operators.simple_linear_operators import ducktape
from ..operators.diagonal_operator import DiagonalOperator
from ..sugar import makeField
from ..utilities import value_reshaper, lognormal_moments
def NormalTransform(mean, sigma, key, N_copies=0):
    """Opchain that transforms standard normally distributed values to
    normally distributed values with given mean and standard deviation.

    Parameters
    ----------
    mean : float
        Mean of the field
    sigma : float
        Standard deviation of the field
    key : string
        Name of the operators domain (Multidomain)
    N_copies : integer
        If == 0, target will be a scalar field.
        If >= 1, target will be an
        :class:`~nifty7.unstructured_domain.UnstructuredDomain`.
    """
    if N_copies == 0:
        # Scalar case: read the latent value from `key`, scale, then shift.
        domain = DomainTuple.scalar_domain()
        mean, sigma = np.asfarray(mean), np.asfarray(sigma)
        mean_adder = Adder(makeField(domain, mean))
        return mean_adder @ (sigma * ducktape(domain, None, key))
    # Vector case: N_copies independent transformed values; sigma acts
    # diagonally after broadcasting/validating the parameters to (N,).
    domain = UnstructuredDomain(N_copies)
    mean, sigma = (value_reshaper(param, N_copies) for param in (mean, sigma))
    mean_adder = Adder(makeField(domain, mean))
    sigma_op = DiagonalOperator(makeField(domain, sigma))
    return mean_adder @ sigma_op @ ducktape(domain, None, key)
def LognormalTransform(mean, sigma, key, N_copies=0):
    """Opchain that transforms standard normally distributed values to
    log-normally distributed values with given mean and standard deviation.

    Parameters
    ----------
    mean : float
        Mean of the field
    sigma : float
        Standard deviation of the field
    key : string
        Name of the domain
    N_copies : integer
        If == 0 (default, matching ``NormalTransform``), target will be a
        scalar field.
        If >= 1, target will be an
        :class:`~nifty7.unstructured_domain.UnstructuredDomain`.

    Raises
    ------
    ValueError
        If `mean` or `sigma` is not strictly positive.
    """
    # Moment matching: find the normal distribution whose pointwise exp
    # has exactly the requested mean and standard deviation.
    logmean, logsigma = lognormal_moments(mean, sigma, N_copies)
    return NormalTransform(logmean, logsigma, key, N_copies).ptw("exp")
......@@ -77,10 +77,18 @@ objects, generate new ones via :func:`spawn_sseq()`.
import numpy as np
# fix for numpy issue #16539
def _fix_seed(seed):
    """Expand an integer seed to a 4-tuple (workaround for numpy #16539).

    Rejects any non-integer seed with a TypeError so invalid seeds fail
    loudly instead of producing an ill-defined SeedSequence.
    """
    if not isinstance(seed, int):
        raise TypeError("random seed should have integer type")
    return (seed, 0, 0, 0)
# Stack of SeedSequence objects. Will always start out with a well-defined
# default. Users can change the "random seed" used by a calculation by pushing
# a different SeedSequence before invoking any other nifty7.random calls
_sseq = [np.random.SeedSequence(42)]
_sseq = [np.random.SeedSequence(_fix_seed(42))]
# Stack of random number generators associated with _sseq.
_rng = [np.random.default_rng(_sseq[-1])]
......@@ -196,7 +204,7 @@ def push_sseq_from_seed(seed):
In all other situations, it is highly recommended to use the
:class:`Context` class for managing the RNG state.
"""
_sseq.append(np.random.SeedSequence(seed))
_sseq.append(np.random.SeedSequence(_fix_seed(seed)))
_rng.append(np.random.default_rng(_sseq[-1]))
......@@ -276,7 +284,7 @@ class Context(object):
def __init__(self, inp):
if not isinstance(inp, np.random.SeedSequence):
inp = np.random.SeedSequence(inp)
inp = np.random.SeedSequence(_fix_seed(inp))
self._sseq = inp
def __enter__(self):
......
......@@ -24,7 +24,8 @@ import numpy as np
__all__ = ["get_slice_list", "safe_cast", "parse_spaces", "infer_space",
"memo", "NiftyMeta", "my_sum", "my_lincomb_simple",
"my_lincomb", "indent",
"my_product", "frozendict", "special_add_at", "iscomplextype"]
"my_product", "frozendict", "special_add_at", "iscomplextype",
"value_reshaper", "lognormal_moments"]
def my_sum(iterable):
......@@ -277,6 +278,16 @@ def shareRange(nwork, nshares, myshare):
return lo, hi
def get_MPI_params_from_comm(comm):
    """Return ``(ntask, rank, master)`` for an MPI communicator.

    ``comm=None`` denotes the serial case: one task, rank 0, master.
    Otherwise the values are queried from the communicator, with `master`
    true exactly on rank 0.
    """
    if comm is None:
        return 1, 0, True
    ntask, rank = comm.Get_size(), comm.Get_rank()
    return ntask, rank, rank == 0
def get_MPI_params():
"""Returns basic information about the MPI setup of the running script.
......@@ -366,3 +377,31 @@ def allreduce_sum(obj, comm):
if comm is None:
return vals[0]
return comm.bcast(vals[0], root=who[0])
def value_reshaper(x, N):
    """Produce arrays of shape `(N,)`.

    If `x` is a scalar or array of length one, fill the target array with it.
    If `x` is an array, check if it has the right shape.

    Parameters
    ----------
    x : scalar or array-like
        Value(s) to broadcast or validate.
    N : int
        Target length; ``N == 0`` requests a 0-d scalar array.

    Raises
    ------
    TypeError
        If `x` is neither scalar-like nor of shape ``(N,)``.
    """
    # np.asfarray was removed in NumPy 2.0; asarray with an explicit float
    # dtype is the exact, forward-compatible equivalent.
    x = np.asarray(x, dtype=float)
    if x.shape in [(), (1, )]:
        return np.full(N, x) if N != 0 else x.reshape(())
    elif x.shape == (N, ):
        return x
    raise TypeError("x and N are incompatible")
def lognormal_moments(mean, sigma, N=0):
    """Calculate the parameters for a normal distribution `n(x)`
    such that `exp(n)(x)` has the mean and standard deviation given.

    Used in :func:`~nifty7.normal_operators.LognormalTransform`.

    Parameters
    ----------
    mean, sigma : scalar or array-like
        Desired moments of the log-normal distribution; must be > 0.
    N : int
        Number of independent copies (0 means scalar).

    Returns
    -------
    logmean, logsigma
        Moments of the underlying normal distribution.

    Raises
    ------
    ValueError
        If any entry of `mean` or `sigma` is not strictly positive.
    """
    mean, sigma = (value_reshaper(param, N) for param in (mean, sigma))
    if not np.all(mean > 0):
        raise ValueError("mean must be greater 0; got {!r}".format(mean))
    if not np.all(sigma > 0):
        # Fixed: message previously referenced the old parameter name "sig".
        raise ValueError("sigma must be greater 0; got {!r}".format(sigma))
    # Standard log-normal moment matching.
    logsigma = np.sqrt(np.log1p((sigma / mean)**2))
    logmean = np.log(mean) - logsigma**2 / 2
    return logmean, logsigma
......@@ -163,8 +163,8 @@ def testNormalization(h_space, specialbinbounds, logarithmic, nbin):
ift.extra.check_jacobian_consistency(op, pos, ntries=10)
@pmp('N', [1, 3])
def testMomentMatchingJacobian(N):
    # Check the operator's Jacobian against finite differences at a random
    # position, for both a single copy and multiple copies.
    op = ift.library.correlated_fields._LognormalMomentMatching(1, 0.2, '', N)
    ift.extra.check_jacobian_consistency(op, ift.from_random(op.domain),
                                         ntries=10)
@pmp('N', [1, 20])
def testLognormalTransform(N):
    # Jacobian consistency of the public LognormalTransform opchain,
    # exercised for a small and a larger number of copies.
    op = ift.LognormalTransform(1, 0.2, '', N)
    loc = ift.from_random(op.domain)
    ift.extra.check_jacobian_consistency(op, loc, ntries=10)
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2020 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import pytest
from numpy.testing import assert_allclose
import nifty7 as ift
from ..common import setup_function, teardown_function
pmp = pytest.mark.parametrize
@pmp('mean', [3., -2])
@pmp('std', [1.5, 0.1])
@pmp('seed', [7, 21])
def test_normal_transform(mean, std, seed):
    # Scalar variant must target the scalar domain.
    op = ift.NormalTransform(mean, std, 'dom')
    assert op.target is ift.DomainTuple.make(())
    # With 500 copies, sample statistics should match the requested moments
    # to within 10% (loose tolerance for finite-sample scatter).
    op = ift.NormalTransform(mean, std, 'dom', 500)
    with ift.random.Context(seed):
        res = op(ift.from_random(op.domain))
        assert_allclose(res.val.mean(), mean, rtol=0.1)
        assert_allclose(res.val.std(), std, rtol=0.1)
        loc = ift.from_random(op.domain)
        ift.extra.check_jacobian_consistency(op, loc)
@pmp('mean', [0.01, 10.])
@pmp('std_fct', [0.01, 0.1])
@pmp('seed', [7, 21])
def test_lognormal_transform(mean, std_fct, seed):
    # Parameterize std relative to the mean so both small and large means
    # are exercised at comparable relative widths.
    std = mean * std_fct
    op = ift.LognormalTransform(mean, std, 'dom', 500)
    with ift.random.Context(seed):
        res = op(ift.from_random(op.domain))
        # Moment matching: empirical mean/std must hit the targets.
        assert_allclose(res.val.mean(), mean, rtol=0.1)
        assert_allclose(res.val.std(), std, rtol=0.1)
        loc = ift.from_random(op.domain)
        ift.extra.check_jacobian_consistency(op, loc)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment