# normal_operators.py
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2020 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

import numpy as np
from ..domain_tuple import DomainTuple
from ..domains.unstructured_domain import UnstructuredDomain
from ..operators.operator import Operator
from ..operators.adder import Adder
from ..operators.simple_linear_operators import ducktape
from ..operators.diagonal_operator import DiagonalOperator
from ..sugar import makeField
from ..utilities import value_reshaper, lognormal_moments


def NormalTransform(mean, sigma, key, N_copies=0):
    """Opchain that transforms standard normally distributed values to
    normally distributed values with given mean and standard deviation.

    The returned operator reads the entry ``key`` of its (multi-domain) input,
    scales it by `sigma` and adds `mean`, i.e. ``xi -> mean + sigma * xi``.

    Parameters
    ----------
    mean : float or array-like
        Mean of the field. For ``N_copies >= 1`` it is passed through
        ``value_reshaper`` (presumably broadcasting a scalar to ``N_copies``
        entries -- confirm in ``value_reshaper``).
    sigma : float or array-like
        Standard deviation of the field; handled like `mean`.
    key : string
        Name of the operator's domain (Multidomain)
    N_copies : integer
        If == 0, target will be a scalar field.
        If >= 1, target will be an
        :class:`~nifty6.unstructured_domain.UnstructuredDomain`.

    Returns
    -------
    Operator
        Chain of ``Adder`` (mean), scaling by sigma, and ``ducktape`` (key).
    """
    if N_copies == 0:
        domain = DomainTuple.scalar_domain()
        # `np.asfarray` was removed in NumPy 2.0. `asarray` with an explicit
        # float64 dtype is exactly what `asfarray` did with its default dtype.
        mean = np.asarray(mean, dtype=np.float64)
        sigma = np.asarray(sigma, dtype=np.float64)
        mean_adder = Adder(makeField(domain, mean))
        return mean_adder @ (sigma * ducktape(domain, None, key))

    domain = UnstructuredDomain(N_copies)
    mean, sigma = (value_reshaper(param, N_copies) for param in (mean, sigma))
    mean_adder = Adder(makeField(domain, mean))
    # Per-copy standard deviations act as a diagonal scaling on the copies.
    sigma_op = DiagonalOperator(makeField(domain, sigma))
    return mean_adder @ sigma_op @ ducktape(domain, None, key)


def LognormalTransform(mean, sigma, key, N_copies=0):
    """Opchain that transforms standard normally distributed values to
    log-normally distributed values with given mean and standard deviation.

    ``lognormal_moments`` converts `mean` and `sigma` into the mean and
    standard deviation of the underlying normal distribution. The result of
    :func:`NormalTransform` with those parameters is then exponentiated
    pointwise.

    Parameters
    ----------
    mean : float or array-like
        Mean of the field
    sigma : float or array-like
        Standard deviation of the field
    key : string
        Name of the operator's domain (Multidomain)
    N_copies : integer, optional
        If == 0 (the default, matching :func:`NormalTransform`), target will
        be a scalar field.
        If >= 1, target will be an
        :class:`~nifty6.unstructured_domain.UnstructuredDomain`.
    """
    logmean, logsigma = lognormal_moments(mean, sigma, N_copies)
    return NormalTransform(logmean, logsigma, key, N_copies).ptw("exp")