Commit e6de9466 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

performance tweaks

parent b6ebd0ad
Pipeline #21872 passed with stage
in 4 minutes and 16 seconds
...@@ -417,6 +417,11 @@ class Field(object): ...@@ -417,6 +417,11 @@ class Field(object):
return self._contraction_helper('sum', spaces) return self._contraction_helper('sum', spaces)
def integrate(self, spaces=None): def integrate(self, spaces=None):
swgt = self.scalar_weight(spaces)
if swgt is not None:
res = self.sum(spaces)
res *= swgt
return res
tmp = self.weight(1, spaces=spaces) tmp = self.weight(1, spaces=spaces)
return tmp.sum(spaces) return tmp.sum(spaces)
......
...@@ -9,7 +9,6 @@ class CriticalPowerCurvature(EndomorphicOperator): ...@@ -9,7 +9,6 @@ class CriticalPowerCurvature(EndomorphicOperator):
CriticalPowerEnergy used in some minimization algorithms or CriticalPowerEnergy used in some minimization algorithms or
for error estimates of the power spectrum. for error estimates of the power spectrum.
Parameters Parameters
---------- ----------
theta: Field, theta: Field,
...@@ -20,19 +19,19 @@ class CriticalPowerCurvature(EndomorphicOperator): ...@@ -20,19 +19,19 @@ class CriticalPowerCurvature(EndomorphicOperator):
def __init__(self, theta, T): def __init__(self, theta, T):
super(CriticalPowerCurvature, self).__init__() super(CriticalPowerCurvature, self).__init__()
self.theta = DiagonalOperator(theta) self._theta = DiagonalOperator(theta)
self.T = T self._T = T
@property @property
def preconditioner(self): def preconditioner(self):
return self.theta.inverse_times return self._theta.inverse_times
def _times(self, x): def _times(self, x):
return self.T(x) + self.theta(x) return self._T(x) + self._theta(x)
@property @property
def domain(self): def domain(self):
return self.theta.domain return self._theta.domain
@property @property
def self_adjoint(self): def self_adjoint(self):
......
...@@ -52,8 +52,6 @@ class CriticalPowerEnergy(Energy): ...@@ -52,8 +52,6 @@ class CriticalPowerEnergy(Energy):
default : None default : None
""" """
# ---Overwritten properties and methods---
def __init__(self, position, m, D=None, alpha=1.0, q=0., def __init__(self, position, m, D=None, alpha=1.0, q=0.,
smoothness_prior=0., logarithmic=True, samples=3, w=None, smoothness_prior=0., logarithmic=True, samples=3, w=None,
inverter=None): inverter=None):
...@@ -61,8 +59,8 @@ class CriticalPowerEnergy(Energy): ...@@ -61,8 +59,8 @@ class CriticalPowerEnergy(Energy):
self.m = m self.m = m
self.D = D self.D = D
self.samples = samples self.samples = samples
self.alpha = Field(self.position.domain, val=alpha) self.alpha = float(alpha)
self.q = Field(self.position.domain, val=q) self.q = float(q)
self.T = SmoothnessOperator(domain=self.position.domain[0], self.T = SmoothnessOperator(domain=self.position.domain[0],
strength=smoothness_prior, strength=smoothness_prior,
logarithmic=logarithmic) logarithmic=logarithmic)
...@@ -71,8 +69,6 @@ class CriticalPowerEnergy(Energy): ...@@ -71,8 +69,6 @@ class CriticalPowerEnergy(Energy):
self._w = w self._w = w
self._inverter = inverter self._inverter = inverter
# ---Mandatory properties and methods---
def at(self, position): def at(self, position):
return self.__class__(position, self.m, D=self.D, alpha=self.alpha, return self.__class__(position, self.m, D=self.D, alpha=self.alpha,
q=self.q, smoothness_prior=self.smoothness_prior, q=self.q, smoothness_prior=self.smoothness_prior,
...@@ -83,9 +79,9 @@ class CriticalPowerEnergy(Energy): ...@@ -83,9 +79,9 @@ class CriticalPowerEnergy(Energy):
@property @property
@memo @memo
def value(self): def value(self):
energy = Field.ones_like(self.position).vdot(self._theta) energy = self._theta.integrate()
energy += self.position.vdot(self.alpha-0.5) energy += self.position.integrate()*(self.alpha-0.5)
energy += 0.5 * self.position.vdot(self._Tt) energy += 0.5*self.position.vdot(self._Tt)
return energy.real return energy.real
@property @property
...@@ -116,15 +112,15 @@ class CriticalPowerEnergy(Energy): ...@@ -116,15 +112,15 @@ class CriticalPowerEnergy(Energy):
def w(self): def w(self):
if self._w is None: if self._w is None:
# self.logger.info("Initializing w") # self.logger.info("Initializing w")
w = Field(domain=self.position.domain, val=0., dtype=self.m.dtype)
if self.D is not None: if self.D is not None:
w = Field.zeros(self.position.domain, dtype=self.m.dtype)
for i in range(self.samples): for i in range(self.samples):
# self.logger.info("Drawing sample %i" % i) # self.logger.info("Drawing sample %i" % i)
posterior_sample = generate_posterior_sample( posterior_sample = generate_posterior_sample(
self.m, self.D) self.m, self.D)
w += self.P(abs(posterior_sample) ** 2) w += self.P(abs(posterior_sample)**2)
w /= float(self.samples) w *= 1./self.samples
else: else:
w = self.P(abs(self.m)**2) w = self.P(abs(self.m)**2)
self._w = w self._w = w
...@@ -133,7 +129,7 @@ class CriticalPowerEnergy(Energy): ...@@ -133,7 +129,7 @@ class CriticalPowerEnergy(Energy):
@property @property
@memo @memo
def _theta(self): def _theta(self):
return exp(-self.position) * (self.q + self.w / 2.) return exp(-self.position) * (self.q + self.w*0.5)
@property @property
@memo @memo
......
...@@ -22,20 +22,13 @@ class LogNormalWienerFilterCurvature(EndomorphicOperator): ...@@ -22,20 +22,13 @@ class LogNormalWienerFilterCurvature(EndomorphicOperator):
The prior signal covariance The prior signal covariance
""" """
def __init__(self, R, N, S, d, position, fft4exp=None): def __init__(self, R, N, S, position, fft4exp):
super(LogNormalWienerFilterCurvature, self).__init__()
self.R = R self.R = R
self.N = N self.N = N
self.S = S self.S = S
self.d = d
self.position = position self.position = position
self._fft = fft4exp
if fft4exp is None:
self._fft = create_composed_fft_operator(self.domain,
all_to='position')
else:
self._fft = fft4exp
super(LogNormalWienerFilterCurvature, self).__init__()
@property @property
def domain(self): def domain(self):
...@@ -51,33 +44,14 @@ class LogNormalWienerFilterCurvature(EndomorphicOperator): ...@@ -51,33 +44,14 @@ class LogNormalWienerFilterCurvature(EndomorphicOperator):
def _times(self, x): def _times(self, x):
part1 = self.S.inverse_times(x) part1 = self.S.inverse_times(x)
# part2 = self._exppRNRexppd * x
part3 = self._fft.adjoint_times(self._expp_sspace * self._fft(x)) part3 = self._fft.adjoint_times(self._expp_sspace * self._fft(x))
part3 = self._fft.adjoint_times( part3 = self._fft.adjoint_times(
self._expp_sspace * self._expp_sspace *
self._fft(self.R.adjoint_times( self._fft(self.R.adjoint_times(
self.N.inverse_times(self.R(part3))))) self.N.inverse_times(self.R(part3)))))
return part1 + part3 # + part2 return part1 + part3
@property @property
@memo @memo
def _expp_sspace(self): def _expp_sspace(self):
return exp(self._fft(self.position)) return exp(self._fft(self.position))
@property
@memo
def _Rexppd(self):
expp = self._fft.adjoint_times(self._expp_sspace)
return self.R(expp) - self.d
@property
@memo
def _NRexppd(self):
return self.N.inverse_times(self._Rexppd)
@property
@memo
def _exppRNRexppd(self):
return self._fft.adjoint_times(
self._expp_sspace *
self._fft(self.R.adjoint_times(self._NRexppd)))
...@@ -48,19 +48,19 @@ class LogNormalWienerFilterEnergy(Energy): ...@@ -48,19 +48,19 @@ class LogNormalWienerFilterEnergy(Energy):
@memo @memo
def value(self): def value(self):
return 0.5*(self.position.vdot(self._Sp) + return 0.5*(self.position.vdot(self._Sp) +
self.curvature.op._Rexppd.vdot(self.curvature.op._NRexppd)) self._Rexppd.vdot(self._NRexppd))
@property @property
@memo @memo
def gradient(self): def gradient(self):
return self._Sp + self.curvature.op._exppRNRexppd return self._Sp + self._exppRNRexppd
@property @property
@memo @memo
def curvature(self): def curvature(self):
return InversionEnabler( return InversionEnabler(
LogNormalWienerFilterCurvature(R=self.R, N=self.N, S=self.S, LogNormalWienerFilterCurvature(R=self.R, N=self.N, S=self.S,
d=self.d, position=self.position, position=self.position,
fft4exp=self._fft), fft4exp=self._fft),
inverter=self._inverter) inverter=self._inverter)
...@@ -68,3 +68,21 @@ class LogNormalWienerFilterEnergy(Energy): ...@@ -68,3 +68,21 @@ class LogNormalWienerFilterEnergy(Energy):
@memo @memo
def _Sp(self): def _Sp(self):
return self.S.inverse_times(self.position) return self.S.inverse_times(self.position)
@property
@memo
def _Rexppd(self):
expp = self._fft.adjoint_times(self.curvature.op._expp_sspace)
return self.R(expp) - self.d
@property
@memo
def _NRexppd(self):
return self.N.inverse_times(self._Rexppd)
@property
@memo
def _exppRNRexppd(self):
return self._fft.adjoint_times(
self.curvature.op._expp_sspace *
self._fft(self.R.adjoint_times(self._NRexppd)))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment