Skip to content
Snippets Groups Projects
Commit 0885b6ad authored by Jakob Knollmueller's avatar Jakob Knollmueller
Browse files

added memo to PowerEnergy and changed check in smoothness and laplace from len(domain) != 0 to 1

parent d6341dcb
Branches
Tags
1 merge request!120Working on demos
Pipeline #
from nifty import * from nifty import *
from nifty.library.wiener_filter import WienerFilterEnergy
from nifty.library.critical_filter import CriticalPowerEnergy
import plotly.offline as pl import plotly.offline as pl
import plotly.graph_objs as go import plotly.graph_objs as go
......
from nifty.energies.energy import Energy from nifty.energies.energy import Energy
from nifty.operators.smoothness_operator import SmoothnessOperator from nifty.operators.smoothness_operator import SmoothnessOperator
from nifty.library.critical_filter import CriticalPowerCurvature from nifty.library.critical_filter import CriticalPowerCurvature
from nifty.energies.memoization import memo
from nifty.sugar import generate_posterior_sample from nifty.sugar import generate_posterior_sample
from nifty import Field, exp from nifty import Field, exp
...@@ -77,22 +77,23 @@ class CriticalPowerEnergy(Energy): ...@@ -77,22 +77,23 @@ class CriticalPowerEnergy(Energy):
@property @property
def value(self): def value(self):
energy = exp(-self.position).vdot(self.q + self.w / 2., bare= True) energy = self._theta.vdot(Field(self.position.domain,val=1.), bare= True)
energy += self.position.vdot(self.alpha - 1. + self.rho / 2., bare=True) energy += self.position.vdot(self._rho_prime, bare=True)
energy += 0.5 * self.position.vdot(self.T(self.position)) energy += 0.5 * self.position.vdot(self._Tt)
return energy.real return energy.real
@property @property
def gradient(self): def gradient(self):
gradient = - self.theta.weight(-1) gradient = - self._theta.weight(-1)
gradient += (self.alpha - 1. + self.rho / 2.).weight(-1) gradient += (self._rho_prime).weight(-1)
gradient += self.T(self.position) gradient += self._Tt
gradient.val = gradient.val.real gradient.val = gradient.val.real
return gradient return gradient
@property @property
def curvature(self): def curvature(self):
curvature = CriticalPowerCurvature(theta=self.theta.weight(-1), T=self.T, inverter=self.inverter) curvature = CriticalPowerCurvature(theta=self._theta.weight(-1), T=self.T,
inverter=self.inverter)
return curvature return curvature
def _calculate_w(self, m, D, samples): def _calculate_w(self, m, D, samples):
...@@ -114,4 +115,18 @@ class CriticalPowerEnergy(Energy): ...@@ -114,4 +115,18 @@ class CriticalPowerEnergy(Energy):
return w return w
@property
@memo
def _theta(self):
return (exp(-self.position) * (self.q + self.w / 2.))
@property
@memo
def _rho_prime(self):
return self.alpha - 1. + self.rho / 2.
@property
@memo
def _Tt(self):
return self.T(self.position)
...@@ -42,7 +42,7 @@ class LaplaceOperator(EndomorphicOperator): ...@@ -42,7 +42,7 @@ class LaplaceOperator(EndomorphicOperator):
def __init__(self, domain, default_spaces=None, logarithmic=True): def __init__(self, domain, default_spaces=None, logarithmic=True):
super(LaplaceOperator, self).__init__(default_spaces) super(LaplaceOperator, self).__init__(default_spaces)
self._domain = self._parse_domain(domain) self._domain = self._parse_domain(domain)
if len(self.domain) != 0: if len(self.domain) != 1:
raise ValueError("The domain must contain exactly one PowerSpace.") raise ValueError("The domain must contain exactly one PowerSpace.")
if not isinstance(self.domain[0], PowerSpace): if not isinstance(self.domain[0], PowerSpace):
......
...@@ -31,7 +31,7 @@ class SmoothnessOperator(EndomorphicOperator): ...@@ -31,7 +31,7 @@ class SmoothnessOperator(EndomorphicOperator):
super(SmoothnessOperator, self).__init__(default_spaces=default_spaces) super(SmoothnessOperator, self).__init__(default_spaces=default_spaces)
self._domain = self._parse_domain(domain) self._domain = self._parse_domain(domain)
if len(self.domain) != 0: if len(self.domain) != 1:
raise ValueError("The domain must contain exactly one PowerSpace.") raise ValueError("The domain must contain exactly one PowerSpace.")
if not isinstance(self.domain[0], PowerSpace): if not isinstance(self.domain[0], PowerSpace):
...@@ -68,7 +68,7 @@ class SmoothnessOperator(EndomorphicOperator): ...@@ -68,7 +68,7 @@ class SmoothnessOperator(EndomorphicOperator):
return False return False
def _times(self, x, spaces): def _times(self, x, spaces):
res = self._aplace.adjoint_times(self._laplace(x, spaces), spaces) res = self._laplace.adjoint_times(self._laplace(x, spaces), spaces)
return (1./self.sigma)**2*res return (1./self.sigma)**2*res
# ---Added properties and methods--- # ---Added properties and methods---
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment