Commit de074086 authored by theos's avatar theos
Browse files

Improved interface of Prober class.

Fixed a few small bugs.
parent 1c53be26
from nifty import * from nifty import *
import plotly.offline as pl #import plotly.offline as pl
import plotly.graph_objs as go #import plotly.graph_objs as go
from mpi4py import MPI from mpi4py import MPI
comm = MPI.COMM_WORLD comm = MPI.COMM_WORLD
...@@ -12,7 +12,7 @@ if __name__ == "__main__": ...@@ -12,7 +12,7 @@ if __name__ == "__main__":
distribution_strategy = 'fftw' distribution_strategy = 'fftw'
s_space = RGSpace([512, 512], dtype=np.complex128) s_space = RGSpace([512, 512], dtype=np.float64)
fft = FFTOperator(s_space) fft = FFTOperator(s_space)
h_space = fft.target[0] h_space = fft.target[0]
p_space = PowerSpace(h_space, distribution_strategy=distribution_strategy) p_space = PowerSpace(h_space, distribution_strategy=distribution_strategy)
...@@ -46,8 +46,8 @@ if __name__ == "__main__": ...@@ -46,8 +46,8 @@ if __name__ == "__main__":
m_data = m.val.get_full_data().real m_data = m.val.get_full_data().real
ss_data = ss.val.get_full_data().real ss_data = ss.val.get_full_data().real
if rank == 0: # if rank == 0:
pl.plot([go.Heatmap(z=d_data)], filename='data.html') # pl.plot([go.Heatmap(z=d_data)], filename='data.html')
pl.plot([go.Heatmap(z=m_data)], filename='map.html') # pl.plot([go.Heatmap(z=m_data)], filename='map.html')
pl.plot([go.Heatmap(z=ss_data)], filename='map_orig.html') # pl.plot([go.Heatmap(z=ss_data)], filename='map_orig.html')
#
...@@ -17,7 +17,7 @@ from nifty.random import Random ...@@ -17,7 +17,7 @@ from nifty.random import Random
from keepers import Loggable from keepers import Loggable
class Field(object, Loggable): class Field(Loggable, object):
# ---Initialization methods--- # ---Initialization methods---
def __init__(self, domain=None, val=None, dtype=None, field_type=None, def __init__(self, domain=None, val=None, dtype=None, field_type=None,
...@@ -430,7 +430,7 @@ class Field(object, Loggable): ...@@ -430,7 +430,7 @@ class Field(object, Loggable):
if copy: if copy:
new_val = new_val.copy() new_val = new_val.copy()
self._val = new_val self._val = new_val
return self._val return self
def get_val(self, copy=False): def get_val(self, copy=False):
if copy: if copy:
......
...@@ -7,7 +7,7 @@ import numpy as np ...@@ -7,7 +7,7 @@ import numpy as np
from keepers import Loggable from keepers import Loggable
class ConjugateGradient(object, Loggable): class ConjugateGradient(Loggable, object):
def __init__(self, convergence_tolerance=1E-4, convergence_level=3, def __init__(self, convergence_tolerance=1E-4, convergence_level=3,
iteration_limit=None, reset_count=None, iteration_limit=None, reset_count=None,
preconditioner=None, callback=None): preconditioner=None, callback=None):
......
...@@ -5,7 +5,7 @@ from keepers import Loggable ...@@ -5,7 +5,7 @@ from keepers import Loggable
from nifty import LineEnergy from nifty import LineEnergy
class LineSearch(object, Loggable): class LineSearch(Loggable, object):
""" """
Class for finding a step size. Class for finding a step size.
""" """
......
...@@ -9,7 +9,7 @@ from keepers import Loggable ...@@ -9,7 +9,7 @@ from keepers import Loggable
from .line_searching import LineSearchStrongWolfe from .line_searching import LineSearchStrongWolfe
class QuasiNewtonMinimizer(object, Loggable): class QuasiNewtonMinimizer(Loggable, object):
__metaclass__ = abc.ABCMeta __metaclass__ = abc.ABCMeta
def __init__(self, line_searcher=LineSearchStrongWolfe(), callback=None, def __init__(self, line_searcher=LineSearchStrongWolfe(), callback=None,
......
...@@ -60,27 +60,3 @@ class EndomorphicOperator(LinearOperator): ...@@ -60,27 +60,3 @@ class EndomorphicOperator(LinearOperator):
@abc.abstractproperty @abc.abstractproperty
def symmetric(self): def symmetric(self):
raise NotImplementedError raise NotImplementedError
def trace(self):
pass
def inverse_trace(self):
pass
def diagonal(self):
pass
def inverse_diagonal(self):
pass
def determinant(self):
pass
def inverse_determinant(self):
pass
def log_determinant(self):
pass
def trace_log(self):
pass
...@@ -95,7 +95,7 @@ class FFTOperator(LinearOperator): ...@@ -95,7 +95,7 @@ class FFTOperator(LinearOperator):
result_domain[spaces[0]] = self.target[0] result_domain[spaces[0]] = self.target[0]
result_field = x.copy_empty(domain=result_domain) result_field = x.copy_empty(domain=result_domain)
result_field.set_val(new_val=new_val) result_field.set_val(new_val=new_val, copy=False)
return result_field return result_field
...@@ -118,7 +118,7 @@ class FFTOperator(LinearOperator): ...@@ -118,7 +118,7 @@ class FFTOperator(LinearOperator):
result_domain[spaces[0]] = self.domain[0] result_domain[spaces[0]] = self.domain[0]
result_field = x.copy_empty(domain=result_domain) result_field = x.copy_empty(domain=result_domain)
result_field.set_val(new_val=new_val) result_field.set_val(new_val=new_val, copy=False)
return result_field return result_field
......
...@@ -10,7 +10,7 @@ from keepers import Loggable ...@@ -10,7 +10,7 @@ from keepers import Loggable
pyfftw = gdi.get('pyfftw') pyfftw = gdi.get('pyfftw')
class Transform(object, Loggable): class Transform(Loggable, object):
""" """
A generic fft object without any implementation. A generic fft object without any implementation.
""" """
......
...@@ -4,7 +4,7 @@ import abc ...@@ -4,7 +4,7 @@ import abc
from keepers import Loggable from keepers import Loggable
class Transformation(object, Loggable): class Transformation(Loggable, object):
""" """
A generic transformation which defines a static check_codomain A generic transformation which defines a static check_codomain
method for all transforms. method for all transforms.
......
...@@ -9,7 +9,7 @@ from nifty.field_types import FieldType ...@@ -9,7 +9,7 @@ from nifty.field_types import FieldType
import nifty.nifty_utilities as utilities import nifty.nifty_utilities as utilities
class LinearOperator(object, Loggable): class LinearOperator(Loggable, object):
__metaclass__ = abc.ABCMeta __metaclass__ = abc.ABCMeta
def __init__(self): def __init__(self):
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import numpy as np
from nifty.minimization import ConjugateGradient from nifty.minimization import ConjugateGradient
from nifty.nifty_utilities import get_default_codomain
from nifty.field import Field from nifty.field import Field
from nifty.operators import EndomorphicOperator,\ from nifty.operators import EndomorphicOperator,\
FFTOperator FFTOperator
...@@ -45,10 +43,8 @@ class PropagatorOperator(EndomorphicOperator): ...@@ -45,10 +43,8 @@ class PropagatorOperator(EndomorphicOperator):
self._domain = N.domain self._domain = N.domain
self._likelihood_times = lambda z: N.inverse_times(z) self._likelihood_times = lambda z: N.inverse_times(z)
fft_S = FFTOperator(S.domain, target=self._domain) self._S = S
self._S_times = lambda z: fft_S(S(fft_S.inverse_times(z))) self._fft_S = FFTOperator(self._domain, target=self._S.domain)
self._S_inverse_times = lambda z: fft_S(S.inverse_times(
fft_S.inverse_times(z)))
if preconditioner is None: if preconditioner is None:
preconditioner = self._S_times preconditioner = self._S_times
...@@ -61,8 +57,6 @@ class PropagatorOperator(EndomorphicOperator): ...@@ -61,8 +57,6 @@ class PropagatorOperator(EndomorphicOperator):
self.inverter = ConjugateGradient( self.inverter = ConjugateGradient(
preconditioner=self.preconditioner) preconditioner=self.preconditioner)
self.x0 = None
# ---Mandatory properties and methods--- # ---Mandatory properties and methods---
@property @property
...@@ -87,18 +81,44 @@ class PropagatorOperator(EndomorphicOperator): ...@@ -87,18 +81,44 @@ class PropagatorOperator(EndomorphicOperator):
# ---Added properties and methods--- # ---Added properties and methods---
def _times(self, x, spaces, types): def _S_times(self, x, spaces=None, types=None):
if self.x0 is None: transformed_x = self._fft_S(x,
x0 = Field(self.domain, val=0., dtype=np.complex128) spaces=spaces,
else: types=types)
x0 = self.x0 y = self._S(transformed_x, spaces=spaces, types=types)
transformed_y = self._fft_S.inverse_times(y,
spaces=spaces,
types=types)
result = x.copy_empty()
result.set_val(transformed_y, copy=False)
return result
def _S_inverse_times(self, x, spaces=None, types=None):
transformed_x = self._fft_S(x,
spaces=spaces,
types=types)
y = self._S.inverse_times(transformed_x,
spaces=spaces,
types=types)
transformed_y = self._fft_S.inverse_times(y,
spaces=spaces,
types=types)
result = x.copy_empty()
result.set_val(transformed_y, copy=False)
return result
def _times(self, x, spaces, types, x0=None):
if x0 is None:
x0 = Field(self.domain, val=0., dtype=x.dtype)
(result, convergence) = self.inverter(A=self.inverse_times, (result, convergence) = self.inverter(A=self.inverse_times,
b=x, b=x,
x0=x0) x0=x0)
self.x0 = result
return result return result
def _inverse_times(self, x, spaces, types): def _inverse_times(self, x, spaces, types):
result = self._S_inverse_times(x) pre_result = self._S_inverse_times(x, spaces, types)
result += self._likelihood_times(x) pre_result += self._likelihood_times(x)
result = x.copy_empty()
result.set_val(pre_result, copy=False)
return result return result
...@@ -54,7 +54,7 @@ class SmoothingOperator(EndomorphicOperator): ...@@ -54,7 +54,7 @@ class SmoothingOperator(EndomorphicOperator):
@property @property
def symmetric(self): def symmetric(self):
return False return True
@property @property
def unitary(self): def unitary(self):
...@@ -138,7 +138,10 @@ class SmoothingOperator(EndomorphicOperator): ...@@ -138,7 +138,10 @@ class SmoothingOperator(EndomorphicOperator):
transformed_x.val.set_local_data(local_transformed_x, copy=False) transformed_x.val.set_local_data(local_transformed_x, copy=False)
result = Transformator.inverse_times(transformed_x, spaces=spaces) smoothed_x = Transformator.inverse_times(transformed_x, spaces=spaces)
result = x.copy_empty()
result.set_val(smoothed_x, copy=False)
return result return result
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from nifty.operators import EndomorphicOperator from prober import Prober
from operator_prober import OperatorProber
__all__ = ['DiagonalProber', 'InverseDiagonalProber', class DiagonalProber(Prober):
'AdjointDiagonalProber', 'AdjointInverseDiagonalProber']
class DiagonalTypeProber(OperatorProber):
# ---Mandatory properties and methods--- # ---Mandatory properties and methods---
# --- ->Mandatory from OperatorProber---
@property
def valid_operator_class(self):
return EndomorphicOperator
# --- ->Mandatory from Prober---
def finish_probe(self, probe, pre_result): def finish_probe(self, probe, pre_result):
return probe[1].conjugate()*pre_result return probe[1].conjugate()*pre_result
class DiagonalProber(DiagonalTypeProber):
# ---Mandatory properties and methods---
# --- ->Mandatory from OperatorProber---
@property
def is_inverse(self):
return False
# --- ->Mandatory from Prober---
def evaluate_probe(self, probe):
""" processes a probe """
return self.operator.times(probe[1])
class InverseDiagonalProber(DiagonalTypeProber):
# ---Mandatory properties and methods---
# --- ->Mandatory from OperatorProber---
@property
def is_inverse(self):
return True
# --- ->Mandatory from Prober---
def evaluate_probe(self, probe):
""" processes a probe """
return self.operator.inverse_times(probe[1])
class AdjointDiagonalProber(DiagonalTypeProber):
# ---Mandatory properties and methods---
# --- ->Mandatory from OperatorProber---
@property
def is_inverse(self):
return False
# --- ->Mandatory from Prober---
def evaluate_probe(self, probe):
""" processes a probe """
return self.operator.adjoint_times(probe[1])
class AdjointInverseDiagonalProber(DiagonalTypeProber):
# ---Mandatory properties and methods---
# --- ->Mandatory from OperatorProber---
@property
def is_inverse(self):
return True
# --- ->Mandatory from Prober---
def evaluate_probe(self, probe):
""" processes a probe """
return self.operator.adjoint_inverse_times(probe[1])
# -*- coding: utf-8 -*-
import abc
from prober import Prober
class OperatorProber(Prober):
# ---Overwritten properties and methods---
def __init__(self, operator, probe_count=8, random_type='pm1',
distribution_strategy=None, compute_variance=False):
super(OperatorProber, self).__init__(
probe_count=probe_count,
random_type=random_type,
compute_variance=compute_variance)
if not isinstance(operator, self.valid_operator_class):
raise TypeError("Operator must be an instance of %s" %
str(self.valid_operator_class))
self._operator = operator
# ---Mandatory properties and methods---
@property
def domain(self):
if self.is_inverse:
return self.operator.target
else:
return self.operator.domain
@property
def field_type(self):
if self.is_inverse:
return self.operator.field_type_target
else:
return self.operator.field_type
@property
def distribution_strategy(self):
return self.operator.distribution_strategy
# ---Added properties and methods---
@abc.abstractproperty
def is_inverse(self):
raise NotImplementedError
@abc.abstractproperty
def valid_operator_class(self):
raise NotImplementedError
@property
def operator(self):
return self._operator
...@@ -4,6 +4,8 @@ import abc ...@@ -4,6 +4,8 @@ import abc
import numpy as np import numpy as np
from nifty.field_types import FieldType
from nifty.spaces import Space
from nifty.field import Field from nifty.field import Field
from d2o import STRATEGIES as DISTRIBUTION_STRATEGIES from d2o import STRATEGIES as DISTRIBUTION_STRATEGIES
...@@ -12,28 +14,67 @@ from d2o import STRATEGIES as DISTRIBUTION_STRATEGIES ...@@ -12,28 +14,67 @@ from d2o import STRATEGIES as DISTRIBUTION_STRATEGIES
class Prober(object): class Prober(object):
__metaclass__ = abc.ABCMeta __metaclass__ = abc.ABCMeta
def __init__(self, probe_count=8, random_type='pm1', def __init__(self, domain=None, field_type=None,
compute_variance=False): distribution_strategy=None, probe_count=8,
random_type='pm1', compute_variance=False):
self.domain = domain
self.field_type = field_type
self.distribution_strategy = distribution_strategy
self.probe_count = probe_count self.probe_count = probe_count
self.random_type = random_type self.random_type = random_type
self.compute_variance = bool(compute_variance) self.compute_variance = bool(compute_variance)
def _parse_domain(self, domain):
if domain is None:
domain = ()
elif isinstance(domain, Space):
domain = (domain,)
elif not isinstance(domain, tuple):
domain = tuple(domain)
for d in domain:
if not isinstance(d, Space):
raise TypeError(
"Given object contains something that is not a "
"nifty.space.")
return domain
def _parse_field_type(self, field_type):
if field_type is None:
field_type = ()
elif isinstance(field_type, FieldType):
field_type = (field_type,)
elif not isinstance(field_type, tuple):
field_type = tuple(field_type)
for ft in field_type:
if not isinstance(ft, FieldType):
raise TypeError(
"Given object is not a nifty.FieldType.")
return field_type
# ---Properties--- # ---Properties---
@abc.abstractproperty @property
def domain(self): def domain(self):
raise NotImplementedError return self._domain
@domain.setter
def domain(self, domain):
self._domain = self._parse_domain(domain)
@abc.abstractproperty @property
def field_type(self): def field_type(self):
raise NotImplementedError return self._field_type
@field_type.setter
def field_type(self, field_type):
self._field_type = self._parse_field_type(field_type)
@abc.abstractproperty @property
def distribution_strategy(self): def distribution_strategy(self):
raise NotImplementedError return self._distribution_strategy
@distribution_strategy.setter @distribution_strategy.setter
def distribution_strategy(self, distribution_strategy): def distribution_strategy(self, distribution_strategy):
...@@ -65,14 +106,14 @@ class Prober(object): ...@@ -65,14 +106,14 @@ cl