Commit bce18b72 authored by Martin Reinecke's avatar Martin Reinecke

Merge branch 'working_on_demos' of gitlab.mpcdf.mpg.de:ift/NIFTy into working_on_demos

parents 4804882e 0885b6ad
Pipeline #14718 passed with stage
in 6 minutes and 46 seconds
from nifty import *
from nifty.library.wiener_filter import WienerFilterEnergy
from nifty.library.critical_filter import CriticalPowerEnergy
import plotly.offline as pl
import plotly.graph_objs as go
......
from nifty.energies.energy import Energy
from nifty.operators.smoothness_operator import SmoothnessOperator
from nifty.library.critical_filter import CriticalPowerCurvature
from nifty.energies.memoization import memo
from nifty.sugar import generate_posterior_sample
from nifty import Field, exp
......@@ -77,22 +77,23 @@ class CriticalPowerEnergy(Energy):
@property
def value(self):
energy = exp(-self.position).vdot(self.q + self.w / 2., bare= True)
energy += self.position.vdot(self.alpha - 1. + self.rho / 2., bare=True)
energy += 0.5 * self.position.vdot(self.T(self.position))
energy = self._theta.vdot(Field(self.position.domain,val=1.), bare= True)
energy += self.position.vdot(self._rho_prime, bare=True)
energy += 0.5 * self.position.vdot(self._Tt)
return energy.real
@property
def gradient(self):
gradient = - self.theta.weight(-1)
gradient += (self.alpha - 1. + self.rho / 2.).weight(-1)
gradient += self.T(self.position)
gradient = - self._theta.weight(-1)
gradient += (self._rho_prime).weight(-1)
gradient += self._Tt
gradient.val = gradient.val.real
return gradient
@property
def curvature(self):
curvature = CriticalPowerCurvature(theta=self.theta.weight(-1), T=self.T, inverter=self.inverter)
curvature = CriticalPowerCurvature(theta=self._theta.weight(-1), T=self.T,
inverter=self.inverter)
return curvature
def _calculate_w(self, m, D, samples):
......@@ -114,4 +115,18 @@ class CriticalPowerEnergy(Energy):
return w
@property
@memo
def _theta(self):
return (exp(-self.position) * (self.q + self.w / 2.))
@property
@memo
def _rho_prime(self):
return self.alpha - 1. + self.rho / 2.
@property
@memo
def _Tt(self):
return self.T(self.position)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment