from nifty4 import Energy, Field, log, exp, DiagonalOperator
from nifty4.library import WienerFilterCurvature


class SeparationEnergy(Energy):
    """Energy evaluated at a position field whose values are clipped to [-9, 9].

    All derived quantities are computed once in the constructor and reused
    by the ``value``, ``gradient`` and ``curvature`` properties.

    Parameters
    ----------
    position : Field
        Point at which the energy is evaluated. Its values are clipped to
        the interval [-9, 9] before being stored.
    parameters : dict
        Must provide the keys ``'inverter'``, ``'data'``, ``'FFT'``,
        ``'correlation'``, ``'alpha'``, ``'q'`` and ``'pos_tanh'``.
        ``'pos_tanh'`` has to be callable on the position and expose a
        ``derivative`` method.
        NOTE(review): judging by the names, ``a`` is presumably the fraction
        of the data attributed to a point-like component (``u``) and
        ``1 - a`` the fraction attributed to a diffuse component (``s``);
        confirm against the caller.
    """

    def __init__(self, position, parameters):
        # Rebuild the position from clipped values so that everything
        # downstream only ever sees entries within [-9, 9].
        clipped = position.val.clip(-9, 9)
        position = Field(position.domain, val=clipped)
        super(SeparationEnergy, self).__init__(position=position)

        # Keep the full dict around; `at` hands it to new instances.
        self.parameters = parameters
        self.inverter = parameters['inverter']
        self.d = parameters['data']
        self.FFT = parameters['FFT']
        self.correlation = parameters['correlation']
        self.alpha = parameters['alpha']
        self.q = parameters['q']
        squash = parameters['pos_tanh']

        # Variance used in the quadratic term on the position itself.
        self.var_x = 9.

        # Operator whose inverse enters the quadratic form on `s`.
        self.S = self.FFT * self.correlation * self.FFT.adjoint

        # `a` and its derivative with respect to the position.
        self.a = squash(self.position)
        self.a_p = squash.derivative(self.position)

        # u = log(d * a), with u_p = a_p / a.
        self.u = log(self.d * self.a)
        self.u_p = self.a_p/self.a

        # s = log(d * (1 - a)), with s_p = -a_p / (1 - a).
        complement = 1 - self.a
        self.s = log(self.d * complement)
        self.s_p = -self.a_p / complement

    def at(self, position):
        """Return a new instance of this class at ``position``.

        The new instance is built with the same ``parameters`` dict.
        """
        return self.__class__(position, parameters=self.parameters)

    @property
    def value(self):
        """Energy value: the sum of three contributions computed below."""
        # Quadratic form of `s` with the inverse of `self.S`.
        diffuse_part = 0.5 * self.s.vdot(self.S.inverse(self.s))
        # Terms depending on `u`, weighted by `alpha - 1` and `q`.
        point_part = (self.alpha-1).vdot(self.u) + self.q.vdot(exp(-self.u))
        # Integral of `s` plus the quadratic term on the position.
        extra_part = (self.s.integrate()
                      + 0.5 / self.var_x * self.position.vdot(self.position))
        return diffuse_part + point_part + extra_part

    @property
    def gradient(self):
        """Gradient assembled from the same three contributions as ``value``."""
        diffuse_part = self.S.inverse(self.s) * self.s_p
        point_part = ((self.alpha - 1) * self.u_p
                      - self.q * exp(-self.u) * self.u_p)
        extra_part = self.position / self.var_x + self.s_p
        return diffuse_part + point_part + extra_part

    @property
    def curvature(self):
        """Return a ``WienerFilterCurvature`` built from the stored fields.

        It is constructed with ``R = FFT.inverse * s_p``,
        ``N = correlation`` and ``S`` a ``DiagonalOperator`` holding
        ``1 / (point + 1 / var_x)``, where
        ``point = q * exp(-u) * u_p ** 2``; the stored inverter is passed on.
        """
        point_part = self.q * exp(-self.u) * self.u_p ** 2
        response = self.FFT.inverse * self.s_p
        diagonal = DiagonalOperator(1/(point_part + 1/self.var_x))
        return WienerFilterCurvature(R=response, N=self.correlation,
                                     S=diagonal, inverter=self.inverter)