Commit de074086 authored by theos's avatar theos

Improved interface of Prober class.

Fixed a few small bugs.
parent 1c53be26
from nifty import *
import plotly.offline as pl
import plotly.graph_objs as go
#import plotly.offline as pl
#import plotly.graph_objs as go
from mpi4py import MPI
comm = MPI.COMM_WORLD
......@@ -12,7 +12,7 @@ if __name__ == "__main__":
distribution_strategy = 'fftw'
s_space = RGSpace([512, 512], dtype=np.complex128)
s_space = RGSpace([512, 512], dtype=np.float64)
fft = FFTOperator(s_space)
h_space = fft.target[0]
p_space = PowerSpace(h_space, distribution_strategy=distribution_strategy)
......@@ -46,8 +46,8 @@ if __name__ == "__main__":
m_data = m.val.get_full_data().real
ss_data = ss.val.get_full_data().real
if rank == 0:
pl.plot([go.Heatmap(z=d_data)], filename='data.html')
pl.plot([go.Heatmap(z=m_data)], filename='map.html')
pl.plot([go.Heatmap(z=ss_data)], filename='map_orig.html')
# if rank == 0:
# pl.plot([go.Heatmap(z=d_data)], filename='data.html')
# pl.plot([go.Heatmap(z=m_data)], filename='map.html')
# pl.plot([go.Heatmap(z=ss_data)], filename='map_orig.html')
#
......@@ -17,7 +17,7 @@ from nifty.random import Random
from keepers import Loggable
class Field(object, Loggable):
class Field(Loggable, object):
# ---Initialization methods---
def __init__(self, domain=None, val=None, dtype=None, field_type=None,
......@@ -430,7 +430,7 @@ class Field(object, Loggable):
if copy:
new_val = new_val.copy()
self._val = new_val
return self._val
return self
def get_val(self, copy=False):
if copy:
......
......@@ -7,7 +7,7 @@ import numpy as np
from keepers import Loggable
class ConjugateGradient(object, Loggable):
class ConjugateGradient(Loggable, object):
def __init__(self, convergence_tolerance=1E-4, convergence_level=3,
iteration_limit=None, reset_count=None,
preconditioner=None, callback=None):
......
......@@ -5,7 +5,7 @@ from keepers import Loggable
from nifty import LineEnergy
class LineSearch(object, Loggable):
class LineSearch(Loggable, object):
"""
Class for finding a step size.
"""
......
......@@ -9,7 +9,7 @@ from keepers import Loggable
from .line_searching import LineSearchStrongWolfe
class QuasiNewtonMinimizer(object, Loggable):
class QuasiNewtonMinimizer(Loggable, object):
__metaclass__ = abc.ABCMeta
def __init__(self, line_searcher=LineSearchStrongWolfe(), callback=None,
......
......@@ -60,27 +60,3 @@ class EndomorphicOperator(LinearOperator):
@abc.abstractproperty
def symmetric(self):
raise NotImplementedError
def trace(self):
pass
def inverse_trace(self):
pass
def diagonal(self):
pass
def inverse_diagonal(self):
pass
def determinant(self):
pass
def inverse_determinant(self):
pass
def log_determinant(self):
pass
def trace_log(self):
pass
......@@ -95,7 +95,7 @@ class FFTOperator(LinearOperator):
result_domain[spaces[0]] = self.target[0]
result_field = x.copy_empty(domain=result_domain)
result_field.set_val(new_val=new_val)
result_field.set_val(new_val=new_val, copy=False)
return result_field
......@@ -118,7 +118,7 @@ class FFTOperator(LinearOperator):
result_domain[spaces[0]] = self.domain[0]
result_field = x.copy_empty(domain=result_domain)
result_field.set_val(new_val=new_val)
result_field.set_val(new_val=new_val, copy=False)
return result_field
......
......@@ -10,7 +10,7 @@ from keepers import Loggable
pyfftw = gdi.get('pyfftw')
class Transform(object, Loggable):
class Transform(Loggable, object):
"""
A generic fft object without any implementation.
"""
......
......@@ -4,7 +4,7 @@ import abc
from keepers import Loggable
class Transformation(object, Loggable):
class Transformation(Loggable, object):
"""
A generic transformation which defines a static check_codomain
method for all transforms.
......
......@@ -9,7 +9,7 @@ from nifty.field_types import FieldType
import nifty.nifty_utilities as utilities
class LinearOperator(object, Loggable):
class LinearOperator(Loggable, object):
__metaclass__ = abc.ABCMeta
def __init__(self):
......
# -*- coding: utf-8 -*-
import numpy as np
from nifty.minimization import ConjugateGradient
from nifty.nifty_utilities import get_default_codomain
from nifty.field import Field
from nifty.operators import EndomorphicOperator,\
FFTOperator
......@@ -45,10 +43,8 @@ class PropagatorOperator(EndomorphicOperator):
self._domain = N.domain
self._likelihood_times = lambda z: N.inverse_times(z)
fft_S = FFTOperator(S.domain, target=self._domain)
self._S_times = lambda z: fft_S(S(fft_S.inverse_times(z)))
self._S_inverse_times = lambda z: fft_S(S.inverse_times(
fft_S.inverse_times(z)))
self._S = S
self._fft_S = FFTOperator(self._domain, target=self._S.domain)
if preconditioner is None:
preconditioner = self._S_times
......@@ -61,8 +57,6 @@ class PropagatorOperator(EndomorphicOperator):
self.inverter = ConjugateGradient(
preconditioner=self.preconditioner)
self.x0 = None
# ---Mandatory properties and methods---
@property
......@@ -87,18 +81,44 @@ class PropagatorOperator(EndomorphicOperator):
# ---Added properties and methods---
def _times(self, x, spaces, types):
if self.x0 is None:
x0 = Field(self.domain, val=0., dtype=np.complex128)
else:
x0 = self.x0
def _S_times(self, x, spaces=None, types=None):
transformed_x = self._fft_S(x,
spaces=spaces,
types=types)
y = self._S(transformed_x, spaces=spaces, types=types)
transformed_y = self._fft_S.inverse_times(y,
spaces=spaces,
types=types)
result = x.copy_empty()
result.set_val(transformed_y, copy=False)
return result
def _S_inverse_times(self, x, spaces=None, types=None):
transformed_x = self._fft_S(x,
spaces=spaces,
types=types)
y = self._S.inverse_times(transformed_x,
spaces=spaces,
types=types)
transformed_y = self._fft_S.inverse_times(y,
spaces=spaces,
types=types)
result = x.copy_empty()
result.set_val(transformed_y, copy=False)
return result
def _times(self, x, spaces, types, x0=None):
if x0 is None:
x0 = Field(self.domain, val=0., dtype=x.dtype)
(result, convergence) = self.inverter(A=self.inverse_times,
b=x,
x0=x0)
self.x0 = result
return result
def _inverse_times(self, x, spaces, types):
result = self._S_inverse_times(x)
result += self._likelihood_times(x)
pre_result = self._S_inverse_times(x, spaces, types)
pre_result += self._likelihood_times(x)
result = x.copy_empty()
result.set_val(pre_result, copy=False)
return result
......@@ -54,7 +54,7 @@ class SmoothingOperator(EndomorphicOperator):
@property
def symmetric(self):
return False
return True
@property
def unitary(self):
......@@ -138,7 +138,10 @@ class SmoothingOperator(EndomorphicOperator):
transformed_x.val.set_local_data(local_transformed_x, copy=False)
result = Transformator.inverse_times(transformed_x, spaces=spaces)
smoothed_x = Transformator.inverse_times(transformed_x, spaces=spaces)
result = x.copy_empty()
result.set_val(smoothed_x, copy=False)
return result
......
# -*- coding: utf-8 -*-
from nifty.operators import EndomorphicOperator
from prober import Prober
from operator_prober import OperatorProber
__all__ = ['DiagonalProber', 'InverseDiagonalProber',
'AdjointDiagonalProber', 'AdjointInverseDiagonalProber']
class DiagonalTypeProber(OperatorProber):
class DiagonalProber(Prober):
# ---Mandatory properties and methods---
# --- ->Mandatory from OperatorProber---
@property
def valid_operator_class(self):
return EndomorphicOperator
# --- ->Mandatory from Prober---
def finish_probe(self, probe, pre_result):
return probe[1].conjugate()*pre_result
class DiagonalProber(DiagonalTypeProber):
# ---Mandatory properties and methods---
# --- ->Mandatory from OperatorProber---
@property
def is_inverse(self):
return False
# --- ->Mandatory from Prober---
def evaluate_probe(self, probe):
""" processes a probe """
return self.operator.times(probe[1])
class InverseDiagonalProber(DiagonalTypeProber):
# ---Mandatory properties and methods---
# --- ->Mandatory from OperatorProber---
@property
def is_inverse(self):
return True
# --- ->Mandatory from Prober---
def evaluate_probe(self, probe):
""" processes a probe """
return self.operator.inverse_times(probe[1])
class AdjointDiagonalProber(DiagonalTypeProber):
# ---Mandatory properties and methods---
# --- ->Mandatory from OperatorProber---
@property
def is_inverse(self):
return False
# --- ->Mandatory from Prober---
def evaluate_probe(self, probe):
""" processes a probe """
return self.operator.adjoint_times(probe[1])
class AdjointInverseDiagonalProber(DiagonalTypeProber):
# ---Mandatory properties and methods---
# --- ->Mandatory from OperatorProber---
@property
def is_inverse(self):
return True
# --- ->Mandatory from Prober---
def evaluate_probe(self, probe):
""" processes a probe """
return self.operator.adjoint_inverse_times(probe[1])
# -*- coding: utf-8 -*-
import abc
from prober import Prober
class OperatorProber(Prober):
# ---Overwritten properties and methods---
def __init__(self, operator, probe_count=8, random_type='pm1',
distribution_strategy=None, compute_variance=False):
super(OperatorProber, self).__init__(
probe_count=probe_count,
random_type=random_type,
compute_variance=compute_variance)
if not isinstance(operator, self.valid_operator_class):
raise TypeError("Operator must be an instance of %s" %
str(self.valid_operator_class))
self._operator = operator
# ---Mandatory properties and methods---
@property
def domain(self):
if self.is_inverse:
return self.operator.target
else:
return self.operator.domain
@property
def field_type(self):
if self.is_inverse:
return self.operator.field_type_target
else:
return self.operator.field_type
@property
def distribution_strategy(self):
return self.operator.distribution_strategy
# ---Added properties and methods---
@abc.abstractproperty
def is_inverse(self):
raise NotImplementedError
@abc.abstractproperty
def valid_operator_class(self):
raise NotImplementedError
@property
def operator(self):
return self._operator
......@@ -4,6 +4,8 @@ import abc
import numpy as np
from nifty.field_types import FieldType
from nifty.spaces import Space
from nifty.field import Field
from d2o import STRATEGIES as DISTRIBUTION_STRATEGIES
......@@ -12,28 +14,67 @@ from d2o import STRATEGIES as DISTRIBUTION_STRATEGIES
class Prober(object):
__metaclass__ = abc.ABCMeta
def __init__(self, probe_count=8, random_type='pm1',
compute_variance=False):
def __init__(self, domain=None, field_type=None,
distribution_strategy=None, probe_count=8,
random_type='pm1', compute_variance=False):
self.domain = domain
self.field_type = field_type
self.distribution_strategy = distribution_strategy
self.probe_count = probe_count
self.random_type = random_type
self.compute_variance = bool(compute_variance)
def _parse_domain(self, domain):
if domain is None:
domain = ()
elif isinstance(domain, Space):
domain = (domain,)
elif not isinstance(domain, tuple):
domain = tuple(domain)
for d in domain:
if not isinstance(d, Space):
raise TypeError(
"Given object contains something that is not a "
"nifty.space.")
return domain
def _parse_field_type(self, field_type):
if field_type is None:
field_type = ()
elif isinstance(field_type, FieldType):
field_type = (field_type,)
elif not isinstance(field_type, tuple):
field_type = tuple(field_type)
for ft in field_type:
if not isinstance(ft, FieldType):
raise TypeError(
"Given object is not a nifty.FieldType.")
return field_type
# ---Properties---
@abc.abstractproperty
@property
def domain(self):
raise NotImplementedError
return self._domain
@domain.setter
def domain(self, domain):
self._domain = self._parse_domain(domain)
@abc.abstractproperty
@property
def field_type(self):
raise NotImplementedError
return self._field_type
@field_type.setter
def field_type(self, field_type):
self._field_type = self._parse_field_type(field_type)
@abc.abstractproperty
@property
def distribution_strategy(self):
raise NotImplementedError
return self._distribution_strategy
@distribution_strategy.setter
def distribution_strategy(self, distribution_strategy):
......@@ -65,14 +106,14 @@ class Prober(object):
# ---Probing methods---
def probing_run(self):
def probing_run(self, callee):
""" controls the generation, evaluation and finalization of probes """
sum_of_probes = 0
sum_of_squares = 0
for index in xrange(self.probe_count):
current_probe = self.get_probe(index)
pre_result = self.process_probe(current_probe, index)
pre_result = self.process_probe(callee, current_probe, index)
result = self.finish_probe(current_probe, pre_result)
sum_of_probes += result
......@@ -95,13 +136,13 @@ class Prober(object):
uid = np.random.randint(1e18)
return (uid, f)
def process_probe(self, probe, index):
return self.evaluate_probe(probe)
def process_probe(self, callee, probe, index):
""" layer of abstraction for potential result-caching/recycling """
return self.evaluate_probe(callee, probe[1])
@abc.abstractmethod
def evaluate_probe(self, probe):
def evaluate_probe(self, callee, probe, **kwargs):
""" processes a probe """
raise NotImplementedError
return callee(probe, **kwargs)
@abc.abstractmethod
def finish_probe(self, probe, pre_result):
......@@ -119,5 +160,5 @@ class Prober(object):
return (mean_of_probes, variance)
def __call__(self):
return self.probe()
def __call__(self, callee):
return self.probing_run(callee)
# -*- coding: utf-8 -*-
from nifty.operators import EndomorphicOperator
from prober import Prober
from operator_prober import OperatorProber
__all__ = ['TraceProber', 'InverseTraceProber',
'AdjointTraceProber', 'AdjointInverseTraceProber']
class TraceTypeProber(OperatorProber):
class TraceProber(Prober):
# ---Mandatory properties and methods---
# --- ->Mandatory from OperatorProber---
@property
def valid_operator_class(self):
return EndomorphicOperator
# --- ->Mandatory from Prober---
def finish_probe(self, probe, pre_result):
return probe[1].conjugate().weight(power=-1).dot(pre_result)
class TraceProber(TraceTypeProber):
# ---Mandatory properties and methods---
# --- ->Mandatory from OperatorProber---
@property
def is_inverse(self):
return False
# --- ->Mandatory from Prober---
def evaluate_probe(self, probe):
""" processes a probe """
return self.operator.times(probe[1])
class InverseTraceProber(TraceTypeProber):
# ---Mandatory properties and methods---
# --- ->Mandatory from OperatorProber---
@property
def is_inverse(self):
return True
# --- ->Mandatory from Prober---
def evaluate_probe(self, probe):
""" processes a probe """
return self.operator.inverse_times(probe[1])
class AdjointTraceProber(TraceTypeProber):
# ---Mandatory properties and methods---
# --- ->Mandatory from OperatorProber---
@property
def is_inverse(self):
return False
# --- ->Mandatory from Prober---
def evaluate_probe(self, probe):
""" processes a probe """
return self.operator.adjoint_times(probe[1])
class AdjointInverseTraceProber(TraceTypeProber):
# ---Mandatory properties and methods---
# --- ->Mandatory from OperatorProber---
@property
def is_inverse(self):
return True
# --- ->Mandatory from Prober---
def evaluate_probe(self, probe):
""" processes a probe """
return self.operator.adjoint_inverse_times(probe[1])
......@@ -149,7 +149,7 @@ import numpy as np
from keepers import Loggable
class Space(object, Loggable):
class Space(Loggable, object):
"""
.. __ __
.. /__/ / /_
......
......@@ -25,3 +25,4 @@ def create_power_operator(domain, power_spectrum, distribution_strategy='not'):
power_operator = DiagonalOperator(domain, diagonal=f)