Commit 7d380d74 authored by Lukas Platz's avatar Lukas Platz
Browse files

move lognormal_moments and value_reshaper to utilities

parent b905b9fd
...@@ -23,19 +23,10 @@ from ..operators.adder import Adder ...@@ -23,19 +23,10 @@ from ..operators.adder import Adder
from ..operators.simple_linear_operators import ducktape from ..operators.simple_linear_operators import ducktape
from ..operators.diagonal_operator import DiagonalOperator from ..operators.diagonal_operator import DiagonalOperator
from ..sugar import makeField from ..sugar import makeField
from ..utilities import value_reshaper, lognormal_moments
def _reshaper(x, N): def NormalTransform(mean, sigma, key, N_copies=0):
x = np.asfarray(x)
if x.shape in [(), (1, )]:
return np.full(N, x) if N != 0 else x.reshape(())
elif x.shape == (N, ):
return x
else:
raise TypeError("Shape of parameters cannot be interpreted")
def NormalTransform(mean, sigma, key, N=0):
"""Opchain that transforms standard normally distributed values to """Opchain that transforms standard normally distributed values to
normally distributed values with given mean an standard deviation. normally distributed values with given mean an standard deviation.
...@@ -52,35 +43,20 @@ def NormalTransform(mean, sigma, key, N=0): ...@@ -52,35 +43,20 @@ def NormalTransform(mean, sigma, key, N=0):
If >= 1, target will be an If >= 1, target will be an
:class:`~nifty6.unstructured_domain.UnstructuredDomain`. :class:`~nifty6.unstructured_domain.UnstructuredDomain`.
""" """
if N == 0: if N_copies == 0:
domain = DomainTuple.scalar_domain() domain = DomainTuple.scalar_domain()
mean, sigma = np.asfarray(mean), np.asfarray(sigma) mean, sigma = np.asfarray(mean), np.asfarray(sigma)
mean_adder = Adder(makeField(domain, mean)) mean_adder = Adder(makeField(domain, mean))
return mean_adder @ (sigma * ducktape(domain, None, key)) return mean_adder @ (sigma * ducktape(domain, None, key))
domain = UnstructuredDomain(N) domain = UnstructuredDomain(N_copies)
mean, sigma = (_reshaper(param, N) for param in (mean, sigma)) mean, sigma = (value_reshaper(param, N_copies) for param in (mean, sigma))
mean_adder = Adder(makeField(domain, mean)) mean_adder = Adder(makeField(domain, mean))
sigma_op = DiagonalOperator(makeField(domain, sigma)) sigma_op = DiagonalOperator(makeField(domain, sigma))
return mean_adder @ sigma_op @ ducktape(domain, None, key) return mean_adder @ sigma_op @ ducktape(domain, None, key)
def _lognormal_moments(mean, sig, N=0): def LognormalTransform(mean, sigma, key, N_copies):
if N == 0:
mean, sig = np.asfarray(mean), np.asfarray(sig)
else:
mean, sig = (_reshaper(param, N) for param in (mean, sig))
if not np.all(mean > 0):
raise ValueError("mean must be greater 0; got {!r}".format(mean))
if not np.all(sig > 0):
raise ValueError("sig must be greater 0; got {!r}".format(sig))
logsig = np.sqrt(np.log1p((sig / mean)**2))
logmean = np.log(mean) - logsig**2 / 2
return logmean, logsig
class LognormalTransform(Operator):
"""Opchain that transforms standard normally distributed values to """Opchain that transforms standard normally distributed values to
log-normally distributed values with given mean an standard deviation. log-normally distributed values with given mean an standard deviation.
...@@ -97,23 +73,5 @@ class LognormalTransform(Operator): ...@@ -97,23 +73,5 @@ class LognormalTransform(Operator):
If >= 1, target will be an If >= 1, target will be an
:class:`~nifty6.unstructured_domain.UnstructuredDomain`. :class:`~nifty6.unstructured_domain.UnstructuredDomain`.
""" """
def __init__(self, mean, sigma, key, N_copies): logmean, logsigma = lognormal_moments(mean, sigma, N_copies)
key = str(key) return NormalTransform(logmean, logsigma, key, N_copies).ptw("exp")
logmean, logsigma = _lognormal_moments(mean, sigma, N_copies)
self._mean = mean
self._sigma = sigma
op = NormalTransform(logmean, logsigma, key, N_copies).ptw("exp")
self._domain, self._target = op.domain, op.target
self.apply = op.apply
self._repr_str = f"LognormalTransform: " + op.__repr__()
@property
def mean(self):
return self._mean
@property
def std(self):
return self._sigma
def __repr__(self):
return self._repr_str
...@@ -24,7 +24,8 @@ import numpy as np ...@@ -24,7 +24,8 @@ import numpy as np
__all__ = ["get_slice_list", "safe_cast", "parse_spaces", "infer_space", __all__ = ["get_slice_list", "safe_cast", "parse_spaces", "infer_space",
"memo", "NiftyMeta", "my_sum", "my_lincomb_simple", "memo", "NiftyMeta", "my_sum", "my_lincomb_simple",
"my_lincomb", "indent", "my_lincomb", "indent",
"my_product", "frozendict", "special_add_at", "iscomplextype"] "my_product", "frozendict", "special_add_at", "iscomplextype",
"value_reshaper", "lognormal_moments"]
def my_sum(iterable): def my_sum(iterable):
...@@ -366,3 +367,31 @@ def allreduce_sum(obj, comm): ...@@ -366,3 +367,31 @@ def allreduce_sum(obj, comm):
if comm is None: if comm is None:
return vals[0] return vals[0]
return comm.bcast(vals[0], root=who[0]) return comm.bcast(vals[0], root=who[0])
def value_reshaper(x, N):
"""Produce arrays of shape `(N,)`.
If `x` is a scalar or array of length one, fill the target array with it.
If `x` is an array, check if it has the right shape."""
x = np.asfarray(x)
if x.shape in [(), (1, )]:
return np.full(N, x) if N != 0 else x.reshape(())
elif x.shape == (N, ):
return x
raise TypeError("x and N are incompatible")
def lognormal_moments(mean, sigma, N=0):
"""Calculates the parameters for a normal distribution `n(x)`
such that `exp(n)(x)` has the mean and standard deviation given.
Used in :func:`~nifty6.normal_operators.LognormalTransform`."""
mean, sigma = (value_reshaper(param, N) for param in (mean, sigma))
if not np.all(mean > 0):
raise ValueError("mean must be greater 0; got {!r}".format(mean))
if not np.all(sigma > 0):
raise ValueError("sig must be greater 0; got {!r}".format(sigma))
logsigma = np.sqrt(np.log1p((sigma / mean)**2))
logmean = np.log(mean) - logsigma**2 / 2
return logmean, logsigma
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment