separation_energy.py 1.95 KB
Newer Older
Martin Reinecke's avatar
Martin Reinecke committed
1
2
from nifty4 import Energy, Field, log, exp, DiagonalOperator
from nifty4.library import WienerFilterCurvature
Jakob Knollmueller's avatar
Jakob Knollmueller committed
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58


class SeparationEnergy(Energy):
    """Energy for splitting data into two multiplicative components.

    The latent field ``position`` (x) is mapped through ``pos_tanh`` to a
    fraction ``a``.  The data ``d`` are decomposed as ``d*a`` and
    ``d*(1-a)``, whose logarithms are ``u`` and ``s`` respectively.  The
    energy is the sum of

    * ``0.5 * s . S^{-1} s`` with ``S = FFT^dagger * correlation * FFT``
      (``diffuse`` in the code),
    * ``(alpha-1) . u + q . exp(-u)`` (``point`` in the code),
    * ``s.integrate() + 0.5/var_x * x . x`` (``det`` in the code).

    Parameters
    ----------
    position : Field
        Latent field x.  It is clipped element-wise to
        ``[-_CLIP_BOUND, _CLIP_BOUND]`` before use, so ``self.position``
        may differ from the argument.
    parameters : dict
        Must provide the keys ``'inverter'``, ``'data'``, ``'FFT'``,
        ``'correlation'``, ``'alpha'``, ``'q'`` and ``'pos_tanh'``.  The
        same dict is handed on unchanged by :meth:`at`.
    """

    # Clipping bound for x.  Presumably keeps pos_tanh(x) away from exactly
    # 0 and 1 so that log(d*a) and log(d*(1-a)) stay finite — confirm.
    # Kept as an int so that ``clip`` behaves exactly as with literal 9.
    _CLIP_BOUND = 9
    # Variance of the Gaussian prior term 0.5/var_x * x.x on the latent field.
    _PRIOR_VAR_X = 9.

    def __init__(self, position, parameters):
        x = position.val.clip(-self._CLIP_BOUND, self._CLIP_BOUND)
        position = Field(position.domain, val=x)
        super(SeparationEnergy, self).__init__(position=position)

        self.parameters = parameters
        self.inverter = parameters['inverter']
        self.d = parameters['data']
        self.FFT = parameters['FFT']
        self.correlation = parameters['correlation']
        self.alpha = parameters['alpha']
        self.q = parameters['q']
        pos_tanh = parameters['pos_tanh']

        self.S = self.FFT.adjoint * self.correlation * self.FFT
        self.a = pos_tanh(self.position)
        self.a_p = pos_tanh.derivative(self.position)

        # u = log(d*a) and its derivative w.r.t. x.
        self.u = log(self.d * self.a)
        self.u_p = self.a_p / self.a
        # s = log(d*(1-a)) and its derivative w.r.t. x.
        one_m_a = 1 - self.a
        self.s = log(self.d * one_m_a)
        self.s_p = - self.a_p / one_m_a
        self.var_x = self._PRIOR_VAR_X

        # Lazily computed S^{-1} s; shared by `value` and `gradient` so the
        # (FFT-based) inverse is applied at most once per energy instance.
        self._S_inv_s = None

    def _inverse_S_times_s(self):
        """Return ``S^{-1} s``, computing it on first use and caching it."""
        if self._S_inv_s is None:
            self._S_inv_s = self.S.inverse(self.s)
        return self._S_inv_s

    def at(self, position):
        """Return a new energy of the same class at ``position``.

        The same ``parameters`` dict is reused.
        """
        return self.__class__(position, parameters=self.parameters)

    @property
    def value(self):
        """Scalar energy at ``self.position`` (sum of the three terms)."""
        diffuse = 0.5 * self.s.vdot(self._inverse_S_times_s())
        point = (self.alpha - 1).vdot(self.u) + self.q.vdot(exp(-self.u))
        # NOTE(review): ``integrate`` weights by pixel volume in NIFTy, while
        # ``gradient`` adds the unweighted ``s_p``; the two agree only for
        # unit pixel volume — confirm this is intended.
        det = self.s.integrate()
        det += 0.5 / self.var_x * self.position.vdot(self.position)
        return diffuse + point + det

    @property
    def gradient(self):
        """Derivative of :attr:`value` w.r.t. the latent field x.

        Each term applies the chain rule through ``u_p`` or ``s_p``.
        """
        diffuse = self._inverse_S_times_s() * self.s_p
        point = (self.alpha - 1) * self.u_p - self.q * exp(-self.u) * self.u_p
        det = self.position / self.var_x
        det += self.s_p
        return diffuse + point + det

    @property
    def curvature(self):
        """Approximate curvature operator as a ``WienerFilterCurvature``.

        ``R = FFT * s_p`` with ``N = correlation`` carries the diffuse term.
        A diagonal ``S`` whose inverse is ``q*exp(-u)*u_p**2 + 1/var_x``
        carries the ``point`` term and the prior.  Assuming
        WienerFilterCurvature represents ``R^dagger N^{-1} R + S^{-1}``,
        second-derivative (``u_pp``/``s_pp``) contributions are dropped,
        i.e. a Gauss-Newton-style approximation — confirm.
        """
        point = self.q * exp(-self.u) * self.u_p ** 2
        R = self.FFT * self.s_p
        N = self.correlation
        S = DiagonalOperator(1 / (point + 1 / self.var_x))
        return WienerFilterCurvature(R=R, N=N, S=S, inverter=self.inverter)