Commit 0885b6ad authored by Jakob Knollmueller's avatar Jakob Knollmueller

added memo to PowerEnergy and changed check in smoothness and laplace from len(domain) != 0 to 1

parent d6341dcb
Pipeline #14717 passed with stage
in 6 minutes and 44 seconds
from nifty import *
from nifty.library.wiener_filter import WienerFilterEnergy
from nifty.library.critical_filter import CriticalPowerEnergy
import plotly.offline as pl
import plotly.graph_objs as go
......
from nifty.energies.energy import Energy
from nifty.operators.smoothness_operator import SmoothnessOperator
from nifty.library.critical_filter import CriticalPowerCurvature
from nifty.energies.memoization import memo
from nifty.sugar import generate_posterior_sample
from nifty import Field, exp
......@@ -77,22 +77,23 @@ class CriticalPowerEnergy(Energy):
@property
def value(self):
energy = exp(-self.position).vdot(self.q + self.w / 2., bare= True)
energy += self.position.vdot(self.alpha - 1. + self.rho / 2., bare=True)
energy += 0.5 * self.position.vdot(self.T(self.position))
energy = self._theta.vdot(Field(self.position.domain,val=1.), bare= True)
energy += self.position.vdot(self._rho_prime, bare=True)
energy += 0.5 * self.position.vdot(self._Tt)
return energy.real
@property
def gradient(self):
gradient = - self.theta.weight(-1)
gradient += (self.alpha - 1. + self.rho / 2.).weight(-1)
gradient += self.T(self.position)
gradient = - self._theta.weight(-1)
gradient += (self._rho_prime).weight(-1)
gradient += self._Tt
gradient.val = gradient.val.real
return gradient
@property
def curvature(self):
curvature = CriticalPowerCurvature(theta=self.theta.weight(-1), T=self.T, inverter=self.inverter)
curvature = CriticalPowerCurvature(theta=self._theta.weight(-1), T=self.T,
inverter=self.inverter)
return curvature
def _calculate_w(self, m, D, samples):
......@@ -114,4 +115,18 @@ class CriticalPowerEnergy(Energy):
return w
@property
@memo
def _theta(self):
return (exp(-self.position) * (self.q + self.w / 2.))
@property
@memo
def _rho_prime(self):
return self.alpha - 1. + self.rho / 2.
@property
@memo
def _Tt(self):
return self.T(self.position)
......@@ -42,7 +42,7 @@ class LaplaceOperator(EndomorphicOperator):
def __init__(self, domain, default_spaces=None, logarithmic=True):
super(LaplaceOperator, self).__init__(default_spaces)
self._domain = self._parse_domain(domain)
if len(self.domain) != 0:
if len(self.domain) != 1:
raise ValueError("The domain must contain exactly one PowerSpace.")
if not isinstance(self.domain[0], PowerSpace):
......
......@@ -31,7 +31,7 @@ class SmoothnessOperator(EndomorphicOperator):
super(SmoothnessOperator, self).__init__(default_spaces=default_spaces)
self._domain = self._parse_domain(domain)
if len(self.domain) != 0:
if len(self.domain) != 1:
raise ValueError("The domain must contain exactly one PowerSpace.")
if not isinstance(self.domain[0], PowerSpace):
......@@ -68,7 +68,7 @@ class SmoothnessOperator(EndomorphicOperator):
return False
def _times(self, x, spaces):
res = self._aplace.adjoint_times(self._laplace(x, spaces), spaces)
res = self._laplace.adjoint_times(self._laplace(x, spaces), spaces)
return (1./self.sigma)**2*res
# ---Added properties and methods---
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment