Commit 9c429547 authored by Theo Steininger's avatar Theo Steininger
Browse files

Finished refactoring of probing classes. Now uses mixin-classes.

parent cd593c77
......@@ -278,3 +278,36 @@ def get_default_codomain(domain):
return LMGLTransformation.get_codomain(domain)
else:
raise TypeError('ERROR: unknown domain')
def parse_domain(domain):
from nifty.spaces.space import Space
if domain is None:
domain = ()
elif isinstance(domain, Space):
domain = (domain,)
elif not isinstance(domain, tuple):
domain = tuple(domain)
for d in domain:
if not isinstance(d, Space):
raise TypeError(
"Given object contains something that is not a "
"nifty.space.")
return domain
def parse_field_type(field_type):
from nifty.field_types import FieldType
if field_type is None:
field_type = ()
elif isinstance(field_type, FieldType):
field_type = (field_type,)
elif not isinstance(field_type, tuple):
field_type = tuple(field_type)
for ft in field_type:
if not isinstance(ft, FieldType):
raise TypeError(
"Given object is not a nifty.FieldType.")
return field_type
from transformations import *
from fft_operator import FFTOperator
......@@ -4,8 +4,6 @@ import abc
from keepers import Loggable
from nifty.field import Field
from nifty.spaces import Space
from nifty.field_types import FieldType
import nifty.nifty_utilities as utilities
......@@ -16,33 +14,10 @@ class LinearOperator(Loggable, object):
pass
def _parse_domain(self, domain):
if domain is None:
domain = ()
elif isinstance(domain, Space):
domain = (domain,)
elif not isinstance(domain, tuple):
domain = tuple(domain)
for d in domain:
if not isinstance(d, Space):
raise TypeError(
"Given object contains something that is not a "
"nifty.space.")
return domain
return utilities.parse_domain(domain)
def _parse_field_type(self, field_type):
if field_type is None:
field_type = ()
elif isinstance(field_type, FieldType):
field_type = (field_type,)
elif not isinstance(field_type, tuple):
field_type = tuple(field_type)
for ft in field_type:
if not isinstance(ft, FieldType):
raise TypeError(
"Given object is not a nifty.FieldType.")
return field_type
return utilities.parse_field_type(field_type)
@abc.abstractproperty
def domain(self):
......
# -*- coding: utf-8 -*-
from prober import Prober
class DiagonalProber(Prober):
# ---Mandatory properties and methods---
def finish_probe(self, probe, pre_result):
return probe[1].conjugate()*pre_result
# -*- coding: utf-8 -*-
from prober import Prober
class TraceProber(Prober):
# ---Mandatory properties and methods---
def finish_probe(self, probe, pre_result):
return probe[1].conjugate().weight(power=-1).dot(pre_result)
# -*- coding: utf-8 -*-
from prober import Prober
from diagonal_prober import *
from trace_prober import *
from mixin_classes import *
# -*- coding: utf-8 -*-
from mixin_base import MixinBase
from diagonal_prober_mixin import DiagonalProberMixin
from trace_prober_mixin import TraceProberMixin
# -*- coding: utf-8 -*-
from mixin_base import MixinBase
class DiagonalProberMixin(MixinBase):
def __init__(self):
self.reset()
super(DiagonalProberMixin, self).__init__()
def reset(self):
self.__sum_of_probings = 0
self.__sum_of_squares = 0
self.__diagonal = None
self.__diagonal_variance = None
super(DiagonalProberMixin, self).reset()
def finish_probe(self, probe, pre_result):
result = probe[1].conjugate()*pre_result
self.__sum_of_probings += result
if self.compute_variance:
self.__sum_of_squares += result.conjugate() * result
super(DiagonalProberMixin, self).finish_probe(probe, pre_result)
@property
def diagonal(self):
if self.__diagonal is None:
self.__diagonal = self.__sum_of_probings/self.probe_count
return self.__diagonal
@property
def diagonal_variance(self):
if not self.compute_variance:
raise AttributeError("self.compute_variance is set to False")
if self.__diagonal_variance is None:
# variance = 1/(n-1) (sum(x^2) - 1/n*sum(x)^2)
n = self.probe_count
sum_pr = self.__sum_of_probings
mean = self.diagonal
sum_sq = self.__sum_of_squares
self.__diagonal_variance = ((sum_sq - sum_pr*mean) / (n-1))
return self.__diagonal_variance
# -*- coding: utf-8 -*-
class MixinBase(object):
def reset(self, *args, **kwargs):
pass
def finish_probe(self, *args, **kwargs):
pass
# -*- coding: utf-8 -*-
from mixin_base import MixinBase
class TraceProberMixin(MixinBase):
def __init__(self):
self.reset()
super(TraceProberMixin, self).__init__()
def reset(self):
self.__sum_of_probings = 0
self.__sum_of_squares = 0
self.__trace = None
self.__trace_variance = None
super(TraceProberMixin, self).reset()
def finish_probe(self, probe, pre_result):
result = probe[1].dot(pre_result, bare=True)
self.__sum_of_probings += result
if self.compute_variance:
self.__sum_of_squares += result.conjugate() * result
super(TraceProberMixin, self).finish_probe(probe, pre_result)
@property
def trace(self):
if self.__trace is None:
self.__trace = self.__sum_of_probings/self.probe_count
return self.__trace
@property
def trace_variance(self):
if not self.compute_variance:
raise AttributeError("self.compute_variance is set to False")
if self.__trace_variance is None:
# variance = 1/(n-1) (sum(x^2) - 1/n*sum(x)^2)
n = self.probe_count
sum_pr = self.__sum_of_probings
mean = self.trace
sum_sq = self.__sum_of_squares
self.__trace_variance = ((sum_sq - sum_pr*mean) / (n-1))
return self.__trace_variance
# -*- coding: utf-8 -*-
from prober import Prober
......@@ -4,33 +4,43 @@ import abc
import numpy as np
from nifty.field_types import FieldType
from nifty.spaces import Space
from nifty.field import Field
from nifty.operators.endomorphic_operator import EndomorphicOperator
import nifty.nifty_utilities as utilities
from nifty import nifty_configuration as nc
from d2o import STRATEGIES as DISTRIBUTION_STRATEGIES
class ProbingOperator(EndomorphicOperator):
class Prober(object):
"""
aka DiagonalProbingOperator
See the following webpages for the principles behind the usage of
mixin-classes
https://www.python.org/download/releases/2.2.3/descrintro/#cooperation
https://rhettinger.wordpress.com/2011/05/26/super-considered-super/
"""
# ---Overwritten properties and methods---
__metaclass__ = abc.ABCMeta
def __init__(self, domain=None, field_type=None,
distribution_strategy=None, probe_count=8,
random_type='pm1', compute_variance=False):
self._domain = self._parse_domain(domain)
self._field_type = self._parse_field_type(field_type)
self._domain = utilities.parse_domain(domain)
self._field_type = utilities.parse_field_type(field_type)
self._distribution_strategy = \
self._parse_distribution_strategy(distribution_strategy)
self.distribution_strategy = distribution_strategy
self.probe_count = probe_count
self.random_type = random_type
self._probe_count = self._parse_probe_count(probe_count)
self._random_type = self._parse_random_type(random_type)
self.compute_variance = bool(compute_variance)
# ---Mandatory properties and methods---
super(Prober, self).__init__()
# ---Properties---
@property
def domain(self):
......@@ -40,57 +50,49 @@ class ProbingOperator(EndomorphicOperator):
def field_type(self):
return self._field_type
# ---Added properties and methods---
@property
def distribution_strategy(self):
return self._distribution_strategy
def _parse_distribution_strategy(self, distribution_strategy):
distribution_strategy = str(distribution_strategy)
if distribution_strategy is None:
distribution_strategy = nc['default_distribution_strategy']
else:
distribution_strategy = str(distribution_strategy)
if distribution_strategy not in DISTRIBUTION_STRATEGIES['global']:
raise ValueError("distribution_strategy must be a global-type "
"strategy.")
return distribution_strategy
self._distribution_strategy = distribution_strategy
@property
def probe_count(self):
return self._probe_count
@probe_count.setter
def probe_count(self, probe_count):
self._probe_count = int(probe_count)
def _parse_probe_count(self, probe_count):
return int(probe_count)
@property
def random_type(self):
return self._random_type
@random_type.setter
def random_type(self, random_type):
def _parse_random_type(self, random_type):
if random_type not in ["pm1", "normal"]:
raise ValueError(
"unsupported random type: '" + str(random_type) + "'.")
else:
self._random_type = random_type
return random_type
# ---Probing methods---
def probing_run(self, callee):
""" controls the generation, evaluation and finalization of probes """
sum_of_probes = 0
sum_of_squares = 0
self.reset()
for index in xrange(self.probe_count):
current_probe = self.get_probe(index)
pre_result = self.process_probe(callee, current_probe, index)
result = self.finish_probe(current_probe, pre_result)
sum_of_probes += result
if self.compute_variance:
sum_of_squares += result.conjugate() * result
self.finish_probe(current_probe, pre_result)
mean_and_variance = self.finalize(sum_of_probes, sum_of_squares)
return mean_and_variance
def reset(self):
super(Prober, self).reset()
def get_probe(self, index):
""" layer of abstraction for potential probe-caching """
......@@ -113,21 +115,8 @@ class ProbingOperator(EndomorphicOperator):
""" processes a probe """
return callee(probe, **kwargs)
@abc.abstractmethod
def finish_probe(self, probe, pre_result):
return pre_result
def finalize(self, sum_of_probes, sum_of_squares):
probe_count = self.probe_count
mean_of_probes = sum_of_probes/probe_count
if self.compute_variance:
# variance = 1/(n-1) (sum(x^2) - 1/n*sum(x)^2)
variance = ((sum_of_squares - sum_of_probes*mean_of_probes) /
(probe_count-1))
else:
variance = None
return (mean_of_probes, variance)
super(Prober, self).finish_probe(probe, pre_result)
def __call__(self, callee):
return self.probing_run(callee)
......@@ -150,8 +150,7 @@ from keepers import Loggable,\
Versionable
class Space(Versionable, Loggable, Plottable, object):
class Space(Versionable, Loggable, object):
"""
.. __ __
.. /__/ / /_
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment