separation_energy.py 1.96 KB
Newer Older
Martin Reinecke's avatar
Martin Reinecke committed
1 2
from nifty4 import Energy, Field, log, exp, DiagonalOperator
from nifty4.library import WienerFilterCurvature
Jakob Knollmueller's avatar
Jakob Knollmueller committed
3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21


class SeparationEnergy(Energy):
    """Energy for separating data into two log-components ``s`` and ``u``.

    The latent ``position`` ``x`` is mapped through ``parameters['pos_tanh']``
    to a field ``a``. The two components are then::

        u = log(d * a)          s = log(d * (1 - a))

    where ``d`` is ``parameters['data']``. NOTE(review): this reads as a
    point-source (``u``) / diffuse (``s``) split with ``a`` in (0, 1); confirm
    against the ``pos_tanh`` object supplied by the caller.

    The energy is the sum of
      * ``0.5 * s^T S^{-1} s`` with ``S = FFT * correlation * FFT^dagger``,
      * ``(alpha - 1)^T u + q^T exp(-u)``,
      * the pixel sum of ``s`` plus ``0.5 * x^T x / var_x``.

    Parameters
    ----------
    position : Field
        Latent field ``x``. Its values are clipped to [-9, 9] before use,
        so ``self.position`` may differ from the argument.
    parameters : mapping
        Must provide the keys 'inverter', 'data', 'FFT', 'correlation',
        'alpha', 'q' and 'pos_tanh'. :meth:`at` passes it on unchanged.
    """

    def __init__(self, position, parameters):
        # Clip the latent values to [-9, 9]. Value, gradient and curvature
        # are all evaluated at the clipped position, not at the argument.
        x = position.val.clip(-9, 9)
        position = Field(position.domain, val=x)
        super(SeparationEnergy, self).__init__(position=position)

        self.parameters = parameters
        self.inverter = parameters['inverter']
        self.d = parameters['data']
        self.FFT = parameters['FFT']
        self.correlation = parameters['correlation']
        self.alpha = parameters['alpha']
        self.q = parameters['q']
        pos_tanh = parameters['pos_tanh']

        # Operator whose inverse weights s in the quadratic term of `value`
        # (i.e. a Gaussian-prior covariance for s).
        self.S = self.FFT * self.correlation * self.FFT.adjoint
        # Mixing field a and its derivative with respect to the position.
        self.a = pos_tanh(self.position)
        self.a_p = pos_tanh.derivative(self.position)

        # u = log(d * a) and its derivative du/dx.
        self.u = log(self.d * self.a)
        self.u_p = self.a_p/self.a
        # s = log(d * (1 - a)) and its derivative ds/dx.
        one_m_a = 1 - self.a
        self.s = log(self.d * one_m_a)
        self.s_p = - self.a_p / one_m_a
        # Variance of the Gaussian term on the position x.
        self.var_x = 9.

    def at(self, position):
        """Return a new energy of this class at ``position``, same parameters."""
        return self.__class__(position, parameters=self.parameters)

    @property
    def value(self):
        """Energy at the current (clipped) position."""
        diffuse = 0.5 * self.s.vdot(self.S.inverse(self.s))
        point = (self.alpha-1).vdot(self.u) + self.q.vdot(exp(-self.u))
        # Plain pixel sum, not ``integrate()``. Its derivative is exactly
        # ``s_p``, which is what ``gradient`` adds. ``integrate()`` weights
        # by pixel volume, so value and gradient would disagree whenever the
        # pixel volumes differ from 1.
        det = self.s.sum()
        det += 0.5 / self.var_x * self.position.vdot(self.position)
        return diffuse + point + det

    @property
    def gradient(self):
        """Gradient of :attr:`value` with respect to the (clipped) position.

        Each term is obtained by the chain rule through ``u_p`` and ``s_p``.
        """
        # d/dx of 0.5 s^T S^-1 s; NOTE(review): this form assumes S^-1 is
        # self-adjoint (true if `correlation` is).
        diffuse = self.S.inverse(self.s) * self.s_p
        point = (self.alpha - 1) * self.u_p - self.q * exp(-self.u) * self.u_p
        det = self.position / self.var_x
        det += self.s_p
        return diffuse + point + det

    @property
    def curvature(self):
        """Approximate curvature of :attr:`value`, as a WienerFilterCurvature.

        Terms involving second derivatives of ``u`` and ``s`` (and the
        second derivative of the pixel-sum term) are omitted. The diffuse
        part enters through the response ``R = FFT.inverse * s_p`` with
        ``N = correlation``. The point-source part and the Gaussian term on x
        enter through the diagonal ``S``. NOTE(review): WienerFilterCurvature
        conventionally represents ``R^dagger N^-1 R + S^-1``.
        """
        point = self.q * exp(-self.u) * self.u_p ** 2
        R = self.FFT.inverse * self.s_p
        N = self.correlation
        S = DiagonalOperator(1/(point + 1/self.var_x))
        return WienerFilterCurvature(R=R, N=N, S=S, inverter=self.inverter)