Commit 85dc3a08 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

simplifications

parent e67c236f
Pipeline #22306 passed with stage
in 4 minutes and 47 seconds
......@@ -64,11 +64,37 @@ class CriticalPowerEnergy(Energy):
self.T = SmoothnessOperator(domain=self.position.domain[0],
strength=smoothness_prior,
logarithmic=logarithmic)
self.P = PowerProjectionOperator(domain=self.m.domain,
power_space=self.position.domain[0])
self._w = w
self._inverter = inverter
if w is None:
# self.logger.info("Initializing w")
P = PowerProjectionOperator(domain=self.m.domain,
power_space=self.position.domain[0])
if self.D is not None:
w = Field.zeros(self.position.domain, dtype=self.m.dtype)
for i in range(self.samples):
# self.logger.info("Drawing sample %i" % i)
sample = generate_posterior_sample(self.m, self.D)
w += P(abs(sample)**2)
w *= 1./self.samples
else:
w = P(abs(self.m)**2)
self._w = w
theta = exp(-self.position) * (self.q + self._w*0.5)
Tt = self.T(self.position)
energy = theta.integrate()
energy += self.position.integrate()*(self.alpha-0.5)
energy += 0.5*self.position.vdot(Tt)
self._value = energy.real
gradient = -theta
gradient += self.alpha-0.5
gradient += Tt
self._gradient = gradient.real
def at(self, position):
return self.__class__(position, self.m, D=self.D, alpha=self.alpha,
q=self.q, smoothness_prior=self.smoothness_prior,
......@@ -77,20 +103,12 @@ class CriticalPowerEnergy(Energy):
inverter=self._inverter)
@property
@memo
def value(self):
energy = self._theta.integrate()
energy += self.position.integrate()*(self.alpha-0.5)
energy += 0.5*self.position.vdot(self._Tt)
return energy.real
return self._value
@property
@memo
def gradient(self):
gradient = -self._theta
gradient += self.alpha-0.5
gradient += self._Tt
return gradient.real
return self._gradient
@property
@memo
......@@ -98,8 +116,6 @@ class CriticalPowerEnergy(Energy):
curv = CriticalPowerCurvature(theta=self._theta, T=self.T)
return InversionEnabler(curv, inverter=self._inverter)
# ---Added properties and methods---
@property
def logarithmic(self):
return self.T.logarithmic
......@@ -107,30 +123,3 @@ class CriticalPowerEnergy(Energy):
@property
def smoothness_prior(self):
return self.T.strength
@property
def w(self):
if self._w is None:
# self.logger.info("Initializing w")
if self.D is not None:
w = Field.zeros(self.position.domain, dtype=self.m.dtype)
for i in range(self.samples):
# self.logger.info("Drawing sample %i" % i)
sample = generate_posterior_sample(self.m, self.D)
w += self.P(abs(sample)**2)
w *= 1./self.samples
else:
w = self.P(abs(self.m)**2)
self._w = w
return self._w
@property
@memo
def _theta(self):
return exp(-self.position) * (self.q + self.w*0.5)
@property
@memo
def _Tt(self):
return self.T(self.position)
......@@ -5,26 +5,23 @@ from .response_operators import LinearizedPowerResponse
class NonlinearPowerCurvature(EndomorphicOperator):
def __init__(self, position, FFT, Instrument, nonlinearity,
def __init__(self, position, Instrument, nonlinearity,
Projection, N, T, sample_list):
self.N = N
self.FFT = FFT
self.Instrument = Instrument
self.T = T
self.sample_list = sample_list
self.samples = len(sample_list)
self.position = position
self.Projection = Projection
self.nonlinearity = nonlinearity
# if preconditioner is None:
# preconditioner = self.theta.inverse_times
self._domain = self.position.domain
super(NonlinearPowerCurvature, self).__init__()
@property
def domain(self):
return self._domain
return self.position.domain
@property
def self_adjoint(self):
......@@ -35,19 +32,14 @@ class NonlinearPowerCurvature(EndomorphicOperator):
return False
def _times(self, x):
result = Field(self.domain, val=0.)
for i in range(self.samples):
sample = self.sample_list[i]
result = Field.zeros_like(self.position, dtype=np.float64)
for sample in self.sample_list:
result += self._sample_times(x, sample)
result /= self.samples
return (result + self.T(x))
result *= 1./len(self.sample_list)
return result + self.T(x)
def _sample_times(self, x, sample):
LinearizedResponse = LinearizedPowerResponse(self.Instrument, self.nonlinearity,
self.FFT, self.Projection, self.position, sample)
result = LinearizedResponse.adjoint_times(
self.Projection, self.position, sample)
return LinearizedResponse.adjoint_times(
self.N.inverse_times(LinearizedResponse(x)))
return result
......@@ -110,6 +110,6 @@ class NonlinearPowerEnergy(Energy):
@property
@memo
def curvature(self):
curvature = NonlinearPowerCurvature(self.position, self.FFT, self.Instrument, self.nonlinearity,
curvature = NonlinearPowerCurvature(self.position, self.Instrument, self.nonlinearity,
self.Projection, self.N, self.T, self.sample_list)
return InversionEnabler(curvature, inverter=self.inverter)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment