Commit 6c2cb2d4 authored by theos's avatar theos

Bug Fixes in ConjugateGradient, PropagatorOperator, SmoothingOperator and sugar.py

parent 15f09ea4
......@@ -4,15 +4,13 @@
from __future__ import division
import numpy as np
from nifty import Field
import logging
logger = logging.getLogger('NIFTy.CG')
class ConjugateGradient(object):
def __init__(self, convergence_tolerance=1E-4, convergence_level=1,
iteration_limit=-1, reset_count=-1,
def __init__(self, convergence_tolerance=1E-4, convergence_level=3,
iteration_limit=None, reset_count=None,
preconditioner=None, callback=None):
"""
Initializes the conjugate_gradient and sets the attributes (except
......@@ -39,8 +37,14 @@ class ConjugateGradient(object):
"""
self.convergence_tolerance = np.float(convergence_tolerance)
self.convergence_level = np.float(convergence_level)
self.iteration_limit = int(iteration_limit)
self.reset_count = int(reset_count)
if iteration_limit is not None:
iteration_limit = int(iteration_limit)
self.iteration_limit = iteration_limit
if reset_count is not None:
reset_count = int(reset_count)
self.reset_count = reset_count
if preconditioner is None:
preconditioner = lambda z: z
......@@ -48,7 +52,7 @@ class ConjugateGradient(object):
self.preconditioner = preconditioner
self.callback = callback
def __call__(self, A, b, x0=None):
def __call__(self, A, b, x0):
"""
Runs the conjugate gradient minimization.
......@@ -74,17 +78,14 @@ class ConjugateGradient(object):
has converged or not.
"""
if x0 is None:
x0 = Field(A.domain, val=0)
r = b - A(x0)
d = self.preconditioner(r)
previous_gamma = r.dot(d)
if previous_gamma == 0:
self.info("The starting guess is already perfect solution for "
"the inverse problem.")
logger.info("The starting guess is already perfect solution for "
"the inverse problem.")
return x0, self.convergence_level+1
norm_b = np.sqrt(b.dot(b))
x = x0
convergence = 0
iteration_number = 1
......@@ -107,7 +108,9 @@ class ConjugateGradient(object):
if alpha.real < 0:
logger.warn("Positive definiteness of A violated!")
reset = True
if reset or iteration_number % self.reset_count == 0:
if self.reset_count is not None:
reset += (iteration_number % self.reset_count == 0)
if reset:
logger.info("Resetting conjugate directions.")
r = b - A(x)
else:
......@@ -121,10 +124,7 @@ class ConjugateGradient(object):
beta = max(0, gamma/previous_gamma)
d = s + d * beta
#delta = previous_delta * np.sqrt(abs(gamma))
delta = abs(1-np.sqrt(abs(previous_gamma))/np.sqrt(abs(gamma)))
delta = np.sqrt(gamma)/norm_b
logger.debug("Iteration : %08u alpha = %3.1E beta = %3.1E "
"delta = %3.1E" %
......@@ -135,20 +135,23 @@ class ConjugateGradient(object):
if gamma == 0:
convergence = self.convergence_level+1
self.info("Reached infinite convergence.")
logger.info("Reached infinite convergence.")
break
elif abs(delta) < self.convergence_tolerance:
convergence += 1
self.info("Updated convergence level to: %u" % convergence)
logger.info("Updated convergence level to: %u" % convergence)
if convergence == self.convergence_level:
self.info("Reached target convergence level.")
logger.info("Reached target convergence level.")
break
else:
convergence = max(0, convergence-1)
if iteration_number == self.iteration_limit:
self.warn("Reached iteration limit. Stopping.")
break
if self.iteration_limit is not None:
if iteration_number == self.iteration_limit:
logger.warn("Reached iteration limit. Stopping.")
break
d = s + d * beta
iteration_number += 1
previous_gamma = gamma
......
# -*- coding: utf-8 -*-
import numpy as np
from nifty.minimization import ConjugateGradient
from nifty.nifty_utilities import get_default_codomain
from nifty.field import Field
from nifty.operators import EndomorphicOperator,\
FFTOperator
......@@ -40,22 +42,16 @@ class PropagatorOperator(EndomorphicOperator):
elif R is not None:
self._domain = R.domain
fft_RN = FFTOperator(self._domain, target=N.domain)
self._likelihood_times = \
lambda z: R.adjoint_times(
fft_RN.inverse_times(N.inverse_times(
fft_RN(R.times(z)))))
lambda z: R.adjoint_times(N.inverse_times(R.times(z)))
else:
self._domain = (get_default_codomain(N.domain[0]),)
fft_RN = FFTOperator(self._domain, target=N.domain)
self._likelihood_times = \
lambda z: fft_RN.inverse_times(N.inverse_times(
fft_RN(z)))
self._domain = N.domain
self._likelihood_times = lambda z: N.inverse_times(z)
fft_S = FFTOperator(S.domain, self._domain)
self._S_times = lambda z: fft_S.inverse_times(S(fft_S(z)))
self._S_inverse_times = lambda z: fft_S.inverse_times(
S.inverse_times(fft_S(z)))
fft_S = FFTOperator(S.domain, target=self._domain)
self._S_times = lambda z: fft_S(S(fft_S.inverse_times(z)))
self._S_inverse_times = lambda z: fft_S(S.inverse_times(
fft_S.inverse_times(z)))
if preconditioner is None:
preconditioner = self._S_times
......@@ -68,6 +64,8 @@ class PropagatorOperator(EndomorphicOperator):
self.inverter = ConjugateGradient(
preconditioner=self.preconditioner)
self.x0 = None
# ---Mandatory properties and methods---
@property
......@@ -93,10 +91,17 @@ class PropagatorOperator(EndomorphicOperator):
# ---Added properties and methods---
def _times(self, x, spaces, types):
(result, convergence) = self.inverter(A=self._inverse_times, b=x)
if self.x0 is None:
x0 = Field(self.domain, val=0., dtype=np.complex128)
else:
x0 = self.x0
(result, convergence) = self.inverter(A=self.inverse_times,
b=x,
x0=x0)
self.x0 = result
return result
def _inverse_multiply(self, x, **kwargs):
def _inverse_times(self, x, spaces, types):
result = self._S_inverse_times(x)
result += self._likelihood_times(x)
return result
......@@ -50,7 +50,7 @@ class SmoothingOperator(EndomorphicOperator):
@property
def symmetric(self):
return False
return True
@property
def unitary(self):
......@@ -98,7 +98,7 @@ class SmoothingOperator(EndomorphicOperator):
local_transformed_x = transformed_x.val.get_local_data(copy=False)
local_kernel = kernel.get_local_data(copy=False)
reshaper = [transformed_x.shape[i] if i in coaxes else 1
reshaper = [local_transformed_x.shape[i] if i in coaxes else 1
for i in xrange(len(transformed_x.shape))]
local_kernel = np.reshape(local_kernel, reshaper)
......
......@@ -13,7 +13,8 @@ def create_power_operator(domain, power_spectrum, distribution_strategy='not'):
fft = FFTOperator(domain)
domain = fft.target[0]
power_domain = PowerSpace(domain)
power_domain = PowerSpace(domain,
distribution_strategy=distribution_strategy)
fp = Field(power_domain,
val=power_spectrum,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment