Commit e6de9466 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

performance tweaks

parent b6ebd0ad
Pipeline #21872 passed with stage
in 4 minutes and 16 seconds
......@@ -417,6 +417,11 @@ class Field(object):
return self._contraction_helper('sum', spaces)
def integrate(self, spaces=None):
swgt = self.scalar_weight(spaces)
if swgt is not None:
res = self.sum(spaces)
res *= swgt
return res
tmp = self.weight(1, spaces=spaces)
return tmp.sum(spaces)
......
......@@ -9,7 +9,6 @@ class CriticalPowerCurvature(EndomorphicOperator):
CriticalPowerEnergy used in some minimization algorithms or
for error estimates of the power spectrum.
Parameters
----------
theta: Field,
......@@ -20,19 +19,19 @@ class CriticalPowerCurvature(EndomorphicOperator):
def __init__(self, theta, T):
super(CriticalPowerCurvature, self).__init__()
self.theta = DiagonalOperator(theta)
self.T = T
self._theta = DiagonalOperator(theta)
self._T = T
@property
def preconditioner(self):
return self.theta.inverse_times
return self._theta.inverse_times
def _times(self, x):
return self.T(x) + self.theta(x)
return self._T(x) + self._theta(x)
@property
def domain(self):
return self.theta.domain
return self._theta.domain
@property
def self_adjoint(self):
......
......@@ -52,8 +52,6 @@ class CriticalPowerEnergy(Energy):
default : None
"""
# ---Overwritten properties and methods---
def __init__(self, position, m, D=None, alpha=1.0, q=0.,
smoothness_prior=0., logarithmic=True, samples=3, w=None,
inverter=None):
......@@ -61,8 +59,8 @@ class CriticalPowerEnergy(Energy):
self.m = m
self.D = D
self.samples = samples
self.alpha = Field(self.position.domain, val=alpha)
self.q = Field(self.position.domain, val=q)
self.alpha = float(alpha)
self.q = float(q)
self.T = SmoothnessOperator(domain=self.position.domain[0],
strength=smoothness_prior,
logarithmic=logarithmic)
......@@ -71,8 +69,6 @@ class CriticalPowerEnergy(Energy):
self._w = w
self._inverter = inverter
# ---Mandatory properties and methods---
def at(self, position):
return self.__class__(position, self.m, D=self.D, alpha=self.alpha,
q=self.q, smoothness_prior=self.smoothness_prior,
......@@ -83,9 +79,9 @@ class CriticalPowerEnergy(Energy):
@property
@memo
def value(self):
energy = Field.ones_like(self.position).vdot(self._theta)
energy += self.position.vdot(self.alpha-0.5)
energy += 0.5 * self.position.vdot(self._Tt)
energy = self._theta.integrate()
energy += self.position.integrate()*(self.alpha-0.5)
energy += 0.5*self.position.vdot(self._Tt)
return energy.real
@property
......@@ -116,15 +112,15 @@ class CriticalPowerEnergy(Energy):
def w(self):
if self._w is None:
# self.logger.info("Initializing w")
w = Field(domain=self.position.domain, val=0., dtype=self.m.dtype)
if self.D is not None:
w = Field.zeros(self.position.domain, dtype=self.m.dtype)
for i in range(self.samples):
# self.logger.info("Drawing sample %i" % i)
posterior_sample = generate_posterior_sample(
self.m, self.D)
w += self.P(abs(posterior_sample) ** 2)
w += self.P(abs(posterior_sample)**2)
w /= float(self.samples)
w *= 1./self.samples
else:
w = self.P(abs(self.m)**2)
self._w = w
......@@ -133,7 +129,7 @@ class CriticalPowerEnergy(Energy):
@property
@memo
def _theta(self):
return exp(-self.position) * (self.q + self.w / 2.)
return exp(-self.position) * (self.q + self.w*0.5)
@property
@memo
......
......@@ -22,20 +22,13 @@ class LogNormalWienerFilterCurvature(EndomorphicOperator):
The prior signal covariance
"""
def __init__(self, R, N, S, d, position, fft4exp=None):
def __init__(self, R, N, S, position, fft4exp):
super(LogNormalWienerFilterCurvature, self).__init__()
self.R = R
self.N = N
self.S = S
self.d = d
self.position = position
if fft4exp is None:
self._fft = create_composed_fft_operator(self.domain,
all_to='position')
else:
self._fft = fft4exp
super(LogNormalWienerFilterCurvature, self).__init__()
self._fft = fft4exp
@property
def domain(self):
......@@ -51,33 +44,14 @@ class LogNormalWienerFilterCurvature(EndomorphicOperator):
def _times(self, x):
part1 = self.S.inverse_times(x)
# part2 = self._exppRNRexppd * x
part3 = self._fft.adjoint_times(self._expp_sspace * self._fft(x))
part3 = self._fft.adjoint_times(
self._expp_sspace *
self._fft(self.R.adjoint_times(
self.N.inverse_times(self.R(part3)))))
return part1 + part3 # + part2
return part1 + part3
@property
@memo
def _expp_sspace(self):
return exp(self._fft(self.position))
@property
@memo
def _Rexppd(self):
expp = self._fft.adjoint_times(self._expp_sspace)
return self.R(expp) - self.d
@property
@memo
def _NRexppd(self):
return self.N.inverse_times(self._Rexppd)
@property
@memo
def _exppRNRexppd(self):
return self._fft.adjoint_times(
self._expp_sspace *
self._fft(self.R.adjoint_times(self._NRexppd)))
......@@ -48,19 +48,19 @@ class LogNormalWienerFilterEnergy(Energy):
@memo
def value(self):
return 0.5*(self.position.vdot(self._Sp) +
self.curvature.op._Rexppd.vdot(self.curvature.op._NRexppd))
self._Rexppd.vdot(self._NRexppd))
@property
@memo
def gradient(self):
return self._Sp + self.curvature.op._exppRNRexppd
return self._Sp + self._exppRNRexppd
@property
@memo
def curvature(self):
return InversionEnabler(
LogNormalWienerFilterCurvature(R=self.R, N=self.N, S=self.S,
d=self.d, position=self.position,
position=self.position,
fft4exp=self._fft),
inverter=self._inverter)
......@@ -68,3 +68,21 @@ class LogNormalWienerFilterEnergy(Energy):
@memo
def _Sp(self):
return self.S.inverse_times(self.position)
@property
@memo
def _Rexppd(self):
expp = self._fft.adjoint_times(self.curvature.op._expp_sspace)
return self.R(expp) - self.d
@property
@memo
def _NRexppd(self):
return self.N.inverse_times(self._Rexppd)
@property
@memo
def _exppRNRexppd(self):
return self._fft.adjoint_times(
self.curvature.op._expp_sspace *
self._fft(self.R.adjoint_times(self._NRexppd)))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment