Commit 7d380d74 authored by Lukas Platz's avatar Lukas Platz
Browse files

move lognormal_moments and value_reshaper to utilities

parent b905b9fd
......@@ -23,19 +23,10 @@ from ..operators.adder import Adder
from ..operators.simple_linear_operators import ducktape
from ..operators.diagonal_operator import DiagonalOperator
from ..sugar import makeField
from ..utilities import value_reshaper, lognormal_moments
def _reshaper(x, N):
x = np.asfarray(x)
if x.shape in [(), (1, )]:
return np.full(N, x) if N != 0 else x.reshape(())
elif x.shape == (N, ):
return x
else:
raise TypeError("Shape of parameters cannot be interpreted")
def NormalTransform(mean, sigma, key, N=0):
def NormalTransform(mean, sigma, key, N_copies=0):
"""Opchain that transforms standard normally distributed values to
normally distributed values with given mean an standard deviation.
......@@ -52,35 +43,20 @@ def NormalTransform(mean, sigma, key, N=0):
If >= 1, target will be an
:class:`~nifty6.unstructured_domain.UnstructuredDomain`.
"""
if N == 0:
if N_copies == 0:
domain = DomainTuple.scalar_domain()
mean, sigma = np.asfarray(mean), np.asfarray(sigma)
mean_adder = Adder(makeField(domain, mean))
return mean_adder @ (sigma * ducktape(domain, None, key))
domain = UnstructuredDomain(N)
mean, sigma = (_reshaper(param, N) for param in (mean, sigma))
domain = UnstructuredDomain(N_copies)
mean, sigma = (value_reshaper(param, N_copies) for param in (mean, sigma))
mean_adder = Adder(makeField(domain, mean))
sigma_op = DiagonalOperator(makeField(domain, sigma))
return mean_adder @ sigma_op @ ducktape(domain, None, key)
def _lognormal_moments(mean, sig, N=0):
if N == 0:
mean, sig = np.asfarray(mean), np.asfarray(sig)
else:
mean, sig = (_reshaper(param, N) for param in (mean, sig))
if not np.all(mean > 0):
raise ValueError("mean must be greater 0; got {!r}".format(mean))
if not np.all(sig > 0):
raise ValueError("sig must be greater 0; got {!r}".format(sig))
logsig = np.sqrt(np.log1p((sig / mean)**2))
logmean = np.log(mean) - logsig**2 / 2
return logmean, logsig
class LognormalTransform(Operator):
def LognormalTransform(mean, sigma, key, N_copies):
"""Opchain that transforms standard normally distributed values to
log-normally distributed values with given mean an standard deviation.
......@@ -97,23 +73,5 @@ class LognormalTransform(Operator):
If >= 1, target will be an
:class:`~nifty6.unstructured_domain.UnstructuredDomain`.
"""
def __init__(self, mean, sigma, key, N_copies):
key = str(key)
logmean, logsigma = _lognormal_moments(mean, sigma, N_copies)
self._mean = mean
self._sigma = sigma
op = NormalTransform(logmean, logsigma, key, N_copies).ptw("exp")
self._domain, self._target = op.domain, op.target
self.apply = op.apply
self._repr_str = f"LognormalTransform: " + op.__repr__()
@property
def mean(self):
return self._mean
@property
def std(self):
return self._sigma
def __repr__(self):
return self._repr_str
logmean, logsigma = lognormal_moments(mean, sigma, N_copies)
return NormalTransform(logmean, logsigma, key, N_copies).ptw("exp")
......@@ -24,7 +24,8 @@ import numpy as np
__all__ = ["get_slice_list", "safe_cast", "parse_spaces", "infer_space",
"memo", "NiftyMeta", "my_sum", "my_lincomb_simple",
"my_lincomb", "indent",
"my_product", "frozendict", "special_add_at", "iscomplextype"]
"my_product", "frozendict", "special_add_at", "iscomplextype",
"value_reshaper", "lognormal_moments"]
def my_sum(iterable):
......@@ -366,3 +367,31 @@ def allreduce_sum(obj, comm):
if comm is None:
return vals[0]
return comm.bcast(vals[0], root=who[0])
def value_reshaper(x, N):
"""Produce arrays of shape `(N,)`.
If `x` is a scalar or array of length one, fill the target array with it.
If `x` is an array, check if it has the right shape."""
x = np.asfarray(x)
if x.shape in [(), (1, )]:
return np.full(N, x) if N != 0 else x.reshape(())
elif x.shape == (N, ):
return x
raise TypeError("x and N are incompatible")
def lognormal_moments(mean, sigma, N=0):
"""Calculates the parameters for a normal distribution `n(x)`
such that `exp(n)(x)` has the mean and standard deviation given.
Used in :func:`~nifty6.normal_operators.LognormalTransform`."""
mean, sigma = (value_reshaper(param, N) for param in (mean, sigma))
if not np.all(mean > 0):
raise ValueError("mean must be greater 0; got {!r}".format(mean))
if not np.all(sigma > 0):
raise ValueError("sig must be greater 0; got {!r}".format(sigma))
logsigma = np.sqrt(np.log1p((sigma / mean)**2))
logmean = np.log(mean) - logsigma**2 / 2
return logmean, logsigma
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment