Commit aa810aa8 authored by theos's avatar theos
Browse files

Added DiagonalProber and TraceProber classes

parent 7b26ce5f
......@@ -2,3 +2,5 @@
from prober import Prober
from operator_prober import OperatorProber
from diagonal_prober import *
from trace_prober import *
......@@ -4,8 +4,26 @@ from nifty.operators import EndomorphicOperator
from operator_prober import OperatorProber
__all__ = ['DiagonalProber', 'InverseDiagonalProber',
'AdjointDiagonalProber', 'AdjointInverseDiagonalProber']
class DiagonalProber(OperatorProber):
class DiagonalTypeProber(OperatorProber):
# ---Mandatory properties and methods---
# --- ->Mandatory from OperatorProber---
@property
def valid_operator_class(self):
return EndomorphicOperator
# --- ->Mandatory from Prober---
def finish_probe(self, probe, pre_result):
return probe[1].conjugate()*pre_result
class DiagonalProber(DiagonalTypeProber):
# ---Mandatory properties and methods---
# --- ->Mandatory from OperatorProber---
......@@ -14,15 +32,56 @@ class DiagonalProber(OperatorProber):
def is_inverse(self):
return False
# --- ->Mandatory from Prober---
def evaluate_probe(self, probe):
""" processes a probe """
return self.operator.times(probe[1])
class InverseDiagonalProber(DiagonalTypeProber):
# ---Mandatory properties and methods---
# --- ->Mandatory from OperatorProber---
@property
def valid_operator_class(self):
return EndomorphicOperator
def is_inverse(self):
return True
# --- ->Mandatory from Prober---
def evaluate_probe(self, probe):
""" processes a probe """
return self.operator.times(probe[1])
return self.operator.inverse_times(probe[1])
def finish_probe(self, probe, pre_result):
return probe[1].conjugate()*pre_result
class AdjointDiagonalProber(DiagonalTypeProber):
# ---Mandatory properties and methods---
# --- ->Mandatory from OperatorProber---
@property
def is_inverse(self):
return False
# --- ->Mandatory from Prober---
def evaluate_probe(self, probe):
""" processes a probe """
return self.operator.adjoint_times(probe[1])
class AdjointInverseDiagonalProber(DiagonalTypeProber):
# ---Mandatory properties and methods---
# --- ->Mandatory from OperatorProber---
@property
def is_inverse(self):
return True
# --- ->Mandatory from Prober---
def evaluate_probe(self, probe):
""" processes a probe """
return self.operator.adjoint_inverse_times(probe[1])
......@@ -16,7 +16,6 @@ class OperatorProber(Prober):
super(OperatorProber, self).__init__(
probe_count=probe_count,
random_type=random_type,
distribution_strategy=distribution_strategy,
compute_variance=compute_variance)
if not isinstance(operator, self.valid_operator_class):
......@@ -41,6 +40,10 @@ class OperatorProber(Prober):
else:
return self.operator.field_type
@property
def distribution_strategy(self):
return self.operator.distribution_strategy
# ---Added properties and methods---
@abc.abstractproperty
......
......@@ -4,9 +4,7 @@ import abc
import numpy as np
from nifty.config import about,\
nifty_configuration as gc
from nifty.config import about
from nifty.field import Field
from d2o import STRATEGIES as DISTRIBUTION_STRATEGIES
......@@ -16,31 +14,27 @@ class Prober(object):
__metaclass__ = abc.ABCMeta
def __init__(self, probe_count=8, random_type='pm1',
distribution_strategy=None, compute_variance=False):
compute_variance=False):
self.probe_count = probe_count
self.random_type = random_type
if distribution_strategy is None:
distribution_strategy = gc['default_distribution_strategy']
self.distribution_strategy = distribution_strategy
self.compute_variance = bool(compute_variance)
# ---Properties---
@abc.abstractproperty
def domain(self):
raise NotImplemented
raise NotImplementedError
@abc.abstractproperty
def field_type(self):
raise NotImplemented
raise NotImplementedError
@property
@abc.abstractproperty
def distribution_strategy(self):
return self._distribution_strategy
raise NotImplementedError
@distribution_strategy.setter
def distribution_strategy(self, distribution_strategy):
......
# -*- coding: utf-8 -*-
from nifty.operators import EndomorphicOperator
from operator_prober import OperatorProber
__all__ = ['TraceProber', 'InverseTraceProber',
'AdjointTraceProber', 'AdjointInverseTraceProber']
class TraceTypeProber(OperatorProber):
# ---Mandatory properties and methods---
# --- ->Mandatory from OperatorProber---
@property
def valid_operator_class(self):
return EndomorphicOperator
# --- ->Mandatory from Prober---
def finish_probe(self, probe, pre_result):
return probe[1].conjugate().weight(power=-1).dot(pre_result)
class TraceProber(TraceTypeProber):
# ---Mandatory properties and methods---
# --- ->Mandatory from OperatorProber---
@property
def is_inverse(self):
return False
# --- ->Mandatory from Prober---
def evaluate_probe(self, probe):
""" processes a probe """
return self.operator.times(probe[1])
class InverseTraceProber(TraceTypeProber):
# ---Mandatory properties and methods---
# --- ->Mandatory from OperatorProber---
@property
def is_inverse(self):
return True
# --- ->Mandatory from Prober---
def evaluate_probe(self, probe):
""" processes a probe """
return self.operator.inverse_times(probe[1])
class AdjointTraceProber(TraceTypeProber):
# ---Mandatory properties and methods---
# --- ->Mandatory from OperatorProber---
@property
def is_inverse(self):
return False
# --- ->Mandatory from Prober---
def evaluate_probe(self, probe):
""" processes a probe """
return self.operator.adjoint_times(probe[1])
class AdjointInverseTraceProber(TraceTypeProber):
# ---Mandatory properties and methods---
# --- ->Mandatory from OperatorProber---
@property
def is_inverse(self):
return True
# --- ->Mandatory from Prober---
def evaluate_probe(self, probe):
""" processes a probe """
return self.operator.adjoint_inverse_times(probe[1])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment