# wiener_process_integrated_amplitude.py
import nifty5 as ift
import numpy as np


class WienerProcessIntegratedAmplitude(ift.LinearOperator):
    """Linear operator applying two weighted cumulative sums on a PowerSpace.

    The domain is an `UnstructuredDomain` with two entries fewer than the
    target `PowerSpace`. In `TIMES` mode the input is multiplied by the
    weights `logvol`, cumulatively summed, multiplied by `logvol` again and
    cumulatively summed a second time. The result fills entries ``2:`` of
    the output; entries ``0`` and ``1`` are zero.

    The weights are built from the `k_lengths` of the target: for each pair
    of neighbouring bins (starting at index 1) the bin width is divided by
    the bin midpoint, i.e. ``logvol = (k[i+1] - k[i]) / ((k[i+1] + k[i])/2)``.

    Parameters
    ----------
    target : PowerSpace
        The space the operator maps into. Anything accepted by
        `ift.makeDomain` that resolves to exactly one `PowerSpace`.

    Raises
    ------
    TypeError
        If `target` is not a single `PowerSpace`.
    """

    def __init__(self, target):
        self._target = ift.makeDomain(target)
        # `apply` relies on `k_lengths`, which only a PowerSpace provides;
        # fail here with a clear message instead of later inside `apply`.
        if (len(self._target) != 1 or
                not isinstance(self._target[0], ift.PowerSpace)):
            raise TypeError("target must be a single PowerSpace")
        self._domain = ift.makeDomain(
            ift.UnstructuredDomain(self.target.shape[0] - 2))
        self._capability = self.TIMES | self.ADJOINT_TIMES

        # The weights depend only on the target, so compute them once
        # instead of on every call to `apply`.
        k_lengths = self._target[0].k_lengths
        vol = k_lengths[2:] - k_lengths[1:-1]
        ks = k_lengths[1:-1] + vol/2
        self._logvol = vol/ks

    def apply(self, x, mode):
        """Apply the operator or its adjoint to the field `x`.

        Parameters
        ----------
        x : Field
            Input field; checked against `mode` by `_check_input`.
        mode : int
            `self.TIMES` or `self.ADJOINT_TIMES`.

        Returns
        -------
        Field
            A field on the target (`TIMES`) or on the domain
            (`ADJOINT_TIMES`).
        """
        self._check_input(x, mode)
        logvol = self._logvol
        data = x.to_global_data()

        if mode == self.TIMES:
            integrated = np.cumsum(np.cumsum(data*logvol)*logvol)
            # Take the dtype of the result so that the output is not
            # silently cast to float64; the first two entries stay zero.
            res = np.zeros(self._target.shape, dtype=integrated.dtype)
            res[2:] = integrated
            return ift.from_global_data(self._target, res)

        # Adjoint: the transpose of a cumulative sum is a cumulative sum
        # taken from the end, and the weights are applied after each sum
        # (reverse order of the forward direction). Entries 0 and 1 of the
        # input do not contribute.
        once = np.cumsum(data[2:][::-1])[::-1]*logvol
        twice = np.cumsum(once[::-1])[::-1]*logvol
        return ift.from_global_data(self._domain, twice)


if __name__ == '__main__':
    # Self-test: build a small PowerSpace, check that the operator and its
    # adjoint are consistent, then plot the exponentiated operator applied
    # to a random field.
    np.random.seed(42)
    ndim = 2
    sspace = ift.RGSpace(
        # `np.int` was removed in NumPy 1.24; the builtin `int` is the
        # type it aliased.
        np.linspace(16, 20, num=ndim).astype(int),
        np.linspace(2.3, 7.99, num=ndim))
    hspace = sspace.get_default_codomain()
    target = ift.PowerSpace(hspace)
    op = WienerProcessIntegratedAmplitude(target)
    ift.extra.consistency_check(op)
    fld = ift.from_random('normal', op.domain)
    op = op.exp()
    ift.single_plot(op(fld), name='debug.png')