Commit 744ce661 authored by Philipp Arras's avatar Philipp Arras
Browse files

Merge remote-tracking branch 'origin/NIFTy_6' into nifty6to7

parents 254dfe98 ccbf5012
Pipeline #76390 failed with stages
in 45 seconds
......@@ -152,10 +152,8 @@
"sigmas = [1.0, 0.5, 0.1]\n",
"\n",
"for i in range(3):\n",
" op = ift.library.correlated_fields._LognormalMomentMatching(mean=mean,\n",
" sig=sigmas[i],\n",
" key='foo',\n",
" N_copies=0)\n",
" op = ift.LognormalTransform(mean=mean, sigma=sigmas[i],\n",
" key='foo', N_copies=0)\n",
" op_samples = np.array(\n",
" [op(s).val for s in [ift.from_random(op.domain) for i in range(10000)]])\n",
"\n",
......
......@@ -52,6 +52,7 @@ from .operators.energy_operators import (
BernoulliEnergy, StandardHamiltonian, AveragedEnergy, QuadraticFormOperator,
Squared2NormOperator, StudentTEnergy, VariableCovarianceGaussianEnergy)
from .operators.convolution_operators import FuncConvolutionOperator
from .operators.normal_operators import NormalTransform, LognormalTransform
from .probing import probe_with_posterior_samples, probe_diagonal, \
StatCalculator, approximation2endo
......
......@@ -37,48 +37,11 @@ from ..operators.harmonic_operators import HarmonicTransformOperator
from ..operators.linear_operator import LinearOperator
from ..operators.operator import Operator
from ..operators.simple_linear_operators import ducktape
from ..operators.normal_operators import NormalTransform, LognormalTransform
from ..probing import StatCalculator
from ..sugar import full, makeDomain, makeField, makeOp
def _reshaper(x, N):
x = np.asfarray(x)
if x.shape in [(), (1,)]:
return np.full(N, x) if N != 0 else x.reshape(())
elif x.shape == (N,):
return x
else:
raise TypeError("Shape of parameters cannot be interpreted")
def _lognormal_moments(mean, sig, N=0):
if N == 0:
mean, sig = np.asfarray(mean), np.asfarray(sig)
else:
mean, sig = (_reshaper(param, N) for param in (mean, sig))
if not np.all(mean > 0):
raise ValueError("mean must be greater 0; got {!r}".format(mean))
if not np.all(sig > 0):
raise ValueError("sig must be greater 0; got {!r}".format(sig))
logsig = np.sqrt(np.log1p((sig/mean)**2))
logmean = np.log(mean) - logsig**2/2
return logmean, logsig
def _normal(mean, sig, key, N=0):
if N == 0:
domain = DomainTuple.scalar_domain()
mean, sig = np.asfarray(mean), np.asfarray(sig)
return Adder(makeField(domain, mean)) @ (
sig * ducktape(domain, None, key))
domain = UnstructuredDomain(N)
mean, sig = (_reshaper(param, N) for param in (mean, sig))
return Adder(makeField(domain, mean)) @ (DiagonalOperator(
makeField(domain, sig)) @ ducktape(domain, None, key))
def _log_k_lengths(pspace):
"""Log(k_lengths) without zeromode"""
return np.log(pspace.k_lengths[1:])
......@@ -120,29 +83,6 @@ def _total_fluctuation_realized(samples):
return np.sqrt(res if np.isscalar(res) else res.val)
class _LognormalMomentMatching(Operator):
def __init__(self, mean, sig, key, N_copies):
key = str(key)
logmean, logsig = _lognormal_moments(mean, sig, N_copies)
self._mean = mean
self._sig = sig
op = _normal(logmean, logsig, key, N_copies).ptw("exp")
self._domain, self._target = op.domain, op.target
self.apply = op.apply
self._repr_str = f"_LognormalMomentMatching: " + op.__repr__()
@property
def mean(self):
return self._mean
@property
def std(self):
return self._sig
def __repr__(self):
return self._repr_str
class _SlopeRemover(EndomorphicOperator):
def __init__(self, domain, space=0):
self._domain = makeDomain(domain)
......@@ -441,10 +381,8 @@ class CorrelatedFieldMaker:
elif len(dofdex) != total_N:
raise ValueError("length of dofdex needs to match total_N")
N = max(dofdex) + 1 if total_N > 0 else 0
zm = _LognormalMomentMatching(offset_std_mean,
offset_std_std,
prefix + 'zeromode',
N)
zm = LognormalTransform(offset_std_mean, offset_std_std,
prefix + 'zeromode', N)
if total_N > 0:
zm = _Distributor(dofdex, zm.target, UnstructuredDomain(total_N)) @ zm
return CorrelatedFieldMaker(offset_mean, zm, prefix, total_N)
......@@ -532,17 +470,15 @@ class CorrelatedFieldMaker:
prefix = str(prefix)
# assert isinstance(target_subdomain[space], (RGSpace, HPSpace, GLSpace)
fluct = _LognormalMomentMatching(fluctuations_mean,
fluctuations_stddev,
self._prefix + prefix + 'fluctuations',
N)
flex = _LognormalMomentMatching(flexibility_mean, flexibility_stddev,
self._prefix + prefix + 'flexibility',
N)
asp = _LognormalMomentMatching(asperity_mean, asperity_stddev,
fluct = LognormalTransform(fluctuations_mean, fluctuations_stddev,
self._prefix + prefix + 'fluctuations', N)
flex = LognormalTransform(flexibility_mean, flexibility_stddev,
self._prefix + prefix + 'flexibility', N)
asp = LognormalTransform(asperity_mean, asperity_stddev,
self._prefix + prefix + 'asperity', N)
avgsl = _normal(loglogavgslope_mean, loglogavgslope_stddev,
avgsl = NormalTransform(loglogavgslope_mean, loglogavgslope_stddev,
self._prefix + prefix + 'loglogavgslope', N)
amp = _Amplitude(PowerSpace(harmonic_partner), fluct, flex, asp, avgsl,
self._azm, target_subdomain[-1].total_volume,
self._prefix + prefix + 'spectrum', dofdex)
......
......@@ -132,14 +132,9 @@ class MetricGaussianKL(Energy):
_, self._hamiltonian = hamiltonian.simplify_for_constant_input(cstpos)
self._n_samples = int(n_samples)
if comm is not None:
self._comm = comm
ntask = self._comm.Get_size()
rank = self._comm.Get_rank()
ntask, rank, _ = utilities.get_MPI_params_from_comm(self._comm)
self._lo, self._hi = utilities.shareRange(self._n_samples, ntask, rank)
else:
self._comm = None
self._lo, self._hi = 0, self._n_samples
self._mirror_samples = bool(mirror_samples)
self._n_eff_samples = self._n_samples
......@@ -213,14 +208,13 @@ class MetricGaussianKL(Energy):
@property
def samples(self):
if self._comm is None:
ntask, rank, _ = utilities.get_MPI_params_from_comm(self._comm)
if ntask == 1:
for s in self._local_samples:
yield s
if self._mirror_samples:
yield -s
else:
ntask = self._comm.Get_size()
rank = self._comm.Get_rank()
rank_lo_hi = [utilities.shareRange(self._n_samples, ntask, i) for i in range(ntask)]
for itask, (l, h) in enumerate(rank_lo_hi):
for i in range(l, h):
......
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2020 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np
from ..domain_tuple import DomainTuple
from ..domains.unstructured_domain import UnstructuredDomain
from ..operators.operator import Operator
from ..operators.adder import Adder
from ..operators.simple_linear_operators import ducktape
from ..operators.diagonal_operator import DiagonalOperator
from ..sugar import makeField
from ..utilities import value_reshaper, lognormal_moments
def NormalTransform(mean, sigma, key, N_copies=0):
"""Opchain that transforms standard normally distributed values to
normally distributed values with given mean an standard deviation.
Parameters:
-----------
mean : float
Mean of the field
sigma : float
Standard deviation of the field
key : string
Name of the operators domain (Multidomain)
N_copies : integer
If == 0, target will be a scalar field.
If >= 1, target will be an
:class:`~nifty6.unstructured_domain.UnstructuredDomain`.
"""
if N_copies == 0:
domain = DomainTuple.scalar_domain()
mean, sigma = np.asfarray(mean), np.asfarray(sigma)
mean_adder = Adder(makeField(domain, mean))
return mean_adder @ (sigma * ducktape(domain, None, key))
domain = UnstructuredDomain(N_copies)
mean, sigma = (value_reshaper(param, N_copies) for param in (mean, sigma))
mean_adder = Adder(makeField(domain, mean))
sigma_op = DiagonalOperator(makeField(domain, sigma))
return mean_adder @ sigma_op @ ducktape(domain, None, key)
def LognormalTransform(mean, sigma, key, N_copies):
"""Opchain that transforms standard normally distributed values to
log-normally distributed values with given mean an standard deviation.
Parameters:
-----------
mean : float
Mean of the field
sigma : float
Standard deviation of the field
key : string
Name of the domain
N_copies : integer
If == 0, target will be a scalar field.
If >= 1, target will be an
:class:`~nifty6.unstructured_domain.UnstructuredDomain`.
"""
logmean, logsigma = lognormal_moments(mean, sigma, N_copies)
return NormalTransform(logmean, logsigma, key, N_copies).ptw("exp")
......@@ -77,10 +77,18 @@ objects, generate new ones via :func:`spawn_sseq()`.
import numpy as np
# fix for numpy issue #16539
def _fix_seed(seed):
if isinstance(seed, int):
return (seed, 0, 0, 0)
raise TypeError("random seed should have integer type")
# Stack of SeedSequence objects. Will always start out with a well-defined
# default. Users can change the "random seed" used by a calculation by pushing
# a different SeedSequence before invoking any other nifty7.random calls
_sseq = [np.random.SeedSequence(42)]
_sseq = [np.random.SeedSequence(_fix_seed(42))]
# Stack of random number generators associated with _sseq.
_rng = [np.random.default_rng(_sseq[-1])]
......@@ -196,7 +204,7 @@ def push_sseq_from_seed(seed):
In all other situations, it is highly recommended to use the
:class:`Context` class for managing the RNG state.
"""
_sseq.append(np.random.SeedSequence(seed))
_sseq.append(np.random.SeedSequence(_fix_seed(seed)))
_rng.append(np.random.default_rng(_sseq[-1]))
......@@ -276,7 +284,7 @@ class Context(object):
def __init__(self, inp):
if not isinstance(inp, np.random.SeedSequence):
inp = np.random.SeedSequence(inp)
inp = np.random.SeedSequence(_fix_seed(inp))
self._sseq = inp
def __enter__(self):
......
......@@ -24,7 +24,8 @@ import numpy as np
__all__ = ["get_slice_list", "safe_cast", "parse_spaces", "infer_space",
"memo", "NiftyMeta", "my_sum", "my_lincomb_simple",
"my_lincomb", "indent",
"my_product", "frozendict", "special_add_at", "iscomplextype"]
"my_product", "frozendict", "special_add_at", "iscomplextype",
"value_reshaper", "lognormal_moments"]
def my_sum(iterable):
......@@ -277,6 +278,16 @@ def shareRange(nwork, nshares, myshare):
return lo, hi
def get_MPI_params_from_comm(comm):
if comm is None:
return 1, 0, True
size = comm.Get_size()
rank = comm.Get_rank()
return size, rank, rank == 0
def get_MPI_params():
"""Returns basic information about the MPI setup of the running script.
......@@ -366,3 +377,31 @@ def allreduce_sum(obj, comm):
if comm is None:
return vals[0]
return comm.bcast(vals[0], root=who[0])
def value_reshaper(x, N):
"""Produce arrays of shape `(N,)`.
If `x` is a scalar or array of length one, fill the target array with it.
If `x` is an array, check if it has the right shape."""
x = np.asfarray(x)
if x.shape in [(), (1, )]:
return np.full(N, x) if N != 0 else x.reshape(())
elif x.shape == (N, ):
return x
raise TypeError("x and N are incompatible")
def lognormal_moments(mean, sigma, N=0):
"""Calculates the parameters for a normal distribution `n(x)`
such that `exp(n)(x)` has the mean and standard deviation given.
Used in :func:`~nifty6.normal_operators.LognormalTransform`."""
mean, sigma = (value_reshaper(param, N) for param in (mean, sigma))
if not np.all(mean > 0):
raise ValueError("mean must be greater 0; got {!r}".format(mean))
if not np.all(sigma > 0):
raise ValueError("sig must be greater 0; got {!r}".format(sigma))
logsigma = np.sqrt(np.log1p((sigma / mean)**2))
logmean = np.log(mean) - logsigma**2 / 2
return logmean, logsigma
......@@ -163,8 +163,8 @@ def testNormalization(h_space, specialbinbounds, logarithmic, nbin):
ift.extra.check_jacobian_consistency(op, pos, ntries=10)
@pmp('N', [1, 3])
def testMomentMatchingJacobian(N):
op = ift.library.correlated_fields._LognormalMomentMatching(1, 0.2, '', N)
ift.extra.check_jacobian_consistency(op, ift.from_random(op.domain),
ntries=10)
@pmp('N', [1, 20])
def testLognormalTransform(N):
op = ift.LognormalTransform(1, 0.2, '', N)
loc = ift.from_random(op.domain)
ift.extra.check_jacobian_consistency(op, loc, ntries=10)
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2020 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import pytest
from numpy.testing import assert_allclose
import nifty6 as ift
from ..common import setup_function, teardown_function
pmp = pytest.mark.parametrize
@pmp('mean', [3., -2])
@pmp('std', [1.5, 0.1])
@pmp('seed', [7, 21])
def test_normal_transform(mean, std, seed):
op = ift.NormalTransform(mean, std, 'dom')
assert op.target is ift.DomainTuple.make(())
op = ift.NormalTransform(mean, std, 'dom', 500)
with ift.random.Context(seed):
res = op(ift.from_random(op.domain))
assert_allclose(res.val.mean(), mean, rtol=0.1)
assert_allclose(res.val.std(), std, rtol=0.1)
loc = ift.from_random(op.domain)
ift.extra.check_jacobian_consistency(op, loc)
@pmp('mean', [0.01, 10.])
@pmp('std_fct', [0.01, 0.1])
@pmp('seed', [7, 21])
def test_lognormal_transform(mean, std_fct, seed):
std = mean * std_fct
op = ift.LognormalTransform(mean, std, 'dom', 500)
with ift.random.Context(seed):
res = op(ift.from_random(op.domain))
assert_allclose(res.val.mean(), mean, rtol=0.1)
assert_allclose(res.val.std(), std, rtol=0.1)
loc = ift.from_random(op.domain)
ift.extra.check_jacobian_consistency(op, loc)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment