Commit 9c429547 authored by Theo Steininger's avatar Theo Steininger
Browse files

Finished refactoring of probing classes. Now uses mixin-classes.

parent cd593c77
...@@ -278,3 +278,36 @@ def get_default_codomain(domain): ...@@ -278,3 +278,36 @@ def get_default_codomain(domain):
return LMGLTransformation.get_codomain(domain) return LMGLTransformation.get_codomain(domain)
else: else:
raise TypeError('ERROR: unknown domain') raise TypeError('ERROR: unknown domain')
def parse_domain(domain):
from nifty.spaces.space import Space
if domain is None:
domain = ()
elif isinstance(domain, Space):
domain = (domain,)
elif not isinstance(domain, tuple):
domain = tuple(domain)
for d in domain:
if not isinstance(d, Space):
raise TypeError(
"Given object contains something that is not a "
"nifty.space.")
return domain
def parse_field_type(field_type):
from nifty.field_types import FieldType
if field_type is None:
field_type = ()
elif isinstance(field_type, FieldType):
field_type = (field_type,)
elif not isinstance(field_type, tuple):
field_type = tuple(field_type)
for ft in field_type:
if not isinstance(ft, FieldType):
raise TypeError(
"Given object is not a nifty.FieldType.")
return field_type
from transformations import * from transformations import *
from fft_operator import FFTOperator from fft_operator import FFTOperator
...@@ -4,8 +4,6 @@ import abc ...@@ -4,8 +4,6 @@ import abc
from keepers import Loggable from keepers import Loggable
from nifty.field import Field from nifty.field import Field
from nifty.spaces import Space
from nifty.field_types import FieldType
import nifty.nifty_utilities as utilities import nifty.nifty_utilities as utilities
...@@ -16,33 +14,10 @@ class LinearOperator(Loggable, object): ...@@ -16,33 +14,10 @@ class LinearOperator(Loggable, object):
pass pass
def _parse_domain(self, domain): def _parse_domain(self, domain):
if domain is None: return utilities.parse_domain(domain)
domain = ()
elif isinstance(domain, Space):
domain = (domain,)
elif not isinstance(domain, tuple):
domain = tuple(domain)
for d in domain:
if not isinstance(d, Space):
raise TypeError(
"Given object contains something that is not a "
"nifty.space.")
return domain
def _parse_field_type(self, field_type): def _parse_field_type(self, field_type):
if field_type is None: return utilities.parse_field_type(field_type)
field_type = ()
elif isinstance(field_type, FieldType):
field_type = (field_type,)
elif not isinstance(field_type, tuple):
field_type = tuple(field_type)
for ft in field_type:
if not isinstance(ft, FieldType):
raise TypeError(
"Given object is not a nifty.FieldType.")
return field_type
@abc.abstractproperty @abc.abstractproperty
def domain(self): def domain(self):
......
# -*- coding: utf-8 -*-
from prober import Prober
class DiagonalProber(Prober):
# ---Mandatory properties and methods---
def finish_probe(self, probe, pre_result):
return probe[1].conjugate()*pre_result
# -*- coding: utf-8 -*-
from prober import Prober
class TraceProber(Prober):
# ---Mandatory properties and methods---
def finish_probe(self, probe, pre_result):
return probe[1].conjugate().weight(power=-1).dot(pre_result)
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from prober import Prober from prober import Prober
from diagonal_prober import * from mixin_classes import *
from trace_prober import *
# -*- coding: utf-8 -*-
from mixin_base import MixinBase
from diagonal_prober_mixin import DiagonalProberMixin
from trace_prober_mixin import TraceProberMixin
# -*- coding: utf-8 -*-
from mixin_base import MixinBase
class DiagonalProberMixin(MixinBase):
def __init__(self):
self.reset()
super(DiagonalProberMixin, self).__init__()
def reset(self):
self.__sum_of_probings = 0
self.__sum_of_squares = 0
self.__diagonal = None
self.__diagonal_variance = None
super(DiagonalProberMixin, self).reset()
def finish_probe(self, probe, pre_result):
result = probe[1].conjugate()*pre_result
self.__sum_of_probings += result
if self.compute_variance:
self.__sum_of_squares += result.conjugate() * result
super(DiagonalProberMixin, self).finish_probe(probe, pre_result)
@property
def diagonal(self):
if self.__diagonal is None:
self.__diagonal = self.__sum_of_probings/self.probe_count
return self.__diagonal
@property
def diagonal_variance(self):
if not self.compute_variance:
raise AttributeError("self.compute_variance is set to False")
if self.__diagonal_variance is None:
# variance = 1/(n-1) (sum(x^2) - 1/n*sum(x)^2)
n = self.probe_count
sum_pr = self.__sum_of_probings
mean = self.diagonal
sum_sq = self.__sum_of_squares
self.__diagonal_variance = ((sum_sq - sum_pr*mean) / (n-1))
return self.__diagonal_variance
# -*- coding: utf-8 -*-
class MixinBase(object):
def reset(self, *args, **kwargs):
pass
def finish_probe(self, *args, **kwargs):
pass
# -*- coding: utf-8 -*-
from mixin_base import MixinBase
class TraceProberMixin(MixinBase):
def __init__(self):
self.reset()
super(TraceProberMixin, self).__init__()
def reset(self):
self.__sum_of_probings = 0
self.__sum_of_squares = 0
self.__trace = None
self.__trace_variance = None
super(TraceProberMixin, self).reset()
def finish_probe(self, probe, pre_result):
result = probe[1].dot(pre_result, bare=True)
self.__sum_of_probings += result
if self.compute_variance:
self.__sum_of_squares += result.conjugate() * result
super(TraceProberMixin, self).finish_probe(probe, pre_result)
@property
def trace(self):
if self.__trace is None:
self.__trace = self.__sum_of_probings/self.probe_count
return self.__trace
@property
def trace_variance(self):
if not self.compute_variance:
raise AttributeError("self.compute_variance is set to False")
if self.__trace_variance is None:
# variance = 1/(n-1) (sum(x^2) - 1/n*sum(x)^2)
n = self.probe_count
sum_pr = self.__sum_of_probings
mean = self.trace
sum_sq = self.__sum_of_squares
self.__trace_variance = ((sum_sq - sum_pr*mean) / (n-1))
return self.__trace_variance
# -*- coding: utf-8 -*-
from prober import Prober
...@@ -4,33 +4,43 @@ import abc ...@@ -4,33 +4,43 @@ import abc
import numpy as np import numpy as np
from nifty.field_types import FieldType
from nifty.spaces import Space
from nifty.field import Field from nifty.field import Field
from nifty.operators.endomorphic_operator import EndomorphicOperator import nifty.nifty_utilities as utilities
from nifty import nifty_configuration as nc
from d2o import STRATEGIES as DISTRIBUTION_STRATEGIES from d2o import STRATEGIES as DISTRIBUTION_STRATEGIES
class ProbingOperator(EndomorphicOperator): class Prober(object):
""" """
aka DiagonalProbingOperator See the following webpages for the principles behind the usage of
mixin-classes
https://www.python.org/download/releases/2.2.3/descrintro/#cooperation
https://rhettinger.wordpress.com/2011/05/26/super-considered-super/
""" """
# ---Overwritten properties and methods--- __metaclass__ = abc.ABCMeta
def __init__(self, domain=None, field_type=None, def __init__(self, domain=None, field_type=None,
distribution_strategy=None, probe_count=8, distribution_strategy=None, probe_count=8,
random_type='pm1', compute_variance=False): random_type='pm1', compute_variance=False):
self._domain = self._parse_domain(domain) self._domain = utilities.parse_domain(domain)
self._field_type = self._parse_field_type(field_type) self._field_type = utilities.parse_field_type(field_type)
self._distribution_strategy = \ self._distribution_strategy = \
self._parse_distribution_strategy(distribution_strategy) self._parse_distribution_strategy(distribution_strategy)
self.distribution_strategy = distribution_strategy self._probe_count = self._parse_probe_count(probe_count)
self.probe_count = probe_count self._random_type = self._parse_random_type(random_type)
self.random_type = random_type
self.compute_variance = bool(compute_variance) self.compute_variance = bool(compute_variance)
# ---Mandatory properties and methods--- super(Prober, self).__init__()
# ---Properties---
@property @property
def domain(self): def domain(self):
...@@ -40,57 +50,49 @@ class ProbingOperator(EndomorphicOperator): ...@@ -40,57 +50,49 @@ class ProbingOperator(EndomorphicOperator):
def field_type(self): def field_type(self):
return self._field_type return self._field_type
# ---Added properties and methods---
@property @property
def distribution_strategy(self): def distribution_strategy(self):
return self._distribution_strategy return self._distribution_strategy
def _parse_distribution_strategy(self, distribution_strategy): def _parse_distribution_strategy(self, distribution_strategy):
distribution_strategy = str(distribution_strategy) if distribution_strategy is None:
distribution_strategy = nc['default_distribution_strategy']
else:
distribution_strategy = str(distribution_strategy)
if distribution_strategy not in DISTRIBUTION_STRATEGIES['global']: if distribution_strategy not in DISTRIBUTION_STRATEGIES['global']:
raise ValueError("distribution_strategy must be a global-type " raise ValueError("distribution_strategy must be a global-type "
"strategy.") "strategy.")
return distribution_strategy self._distribution_strategy = distribution_strategy
@property @property
def probe_count(self): def probe_count(self):
return self._probe_count return self._probe_count
@probe_count.setter def _parse_probe_count(self, probe_count):
def probe_count(self, probe_count): return int(probe_count)
self._probe_count = int(probe_count)
@property @property
def random_type(self): def random_type(self):
return self._random_type return self._random_type
@random_type.setter def _parse_random_type(self, random_type):
def random_type(self, random_type):
if random_type not in ["pm1", "normal"]: if random_type not in ["pm1", "normal"]:
raise ValueError( raise ValueError(
"unsupported random type: '" + str(random_type) + "'.") "unsupported random type: '" + str(random_type) + "'.")
else: return random_type
self._random_type = random_type
# ---Probing methods--- # ---Probing methods---
def probing_run(self, callee): def probing_run(self, callee):
""" controls the generation, evaluation and finalization of probes """ """ controls the generation, evaluation and finalization of probes """
sum_of_probes = 0 self.reset()
sum_of_squares = 0
for index in xrange(self.probe_count): for index in xrange(self.probe_count):
current_probe = self.get_probe(index) current_probe = self.get_probe(index)
pre_result = self.process_probe(callee, current_probe, index) pre_result = self.process_probe(callee, current_probe, index)
result = self.finish_probe(current_probe, pre_result) self.finish_probe(current_probe, pre_result)
sum_of_probes += result
if self.compute_variance:
sum_of_squares += result.conjugate() * result
mean_and_variance = self.finalize(sum_of_probes, sum_of_squares) def reset(self):
return mean_and_variance super(Prober, self).reset()
def get_probe(self, index): def get_probe(self, index):
""" layer of abstraction for potential probe-caching """ """ layer of abstraction for potential probe-caching """
...@@ -113,21 +115,8 @@ class ProbingOperator(EndomorphicOperator): ...@@ -113,21 +115,8 @@ class ProbingOperator(EndomorphicOperator):
""" processes a probe """ """ processes a probe """
return callee(probe, **kwargs) return callee(probe, **kwargs)
@abc.abstractmethod
def finish_probe(self, probe, pre_result): def finish_probe(self, probe, pre_result):
return pre_result super(Prober, self).finish_probe(probe, pre_result)
def finalize(self, sum_of_probes, sum_of_squares):
probe_count = self.probe_count
mean_of_probes = sum_of_probes/probe_count
if self.compute_variance:
# variance = 1/(n-1) (sum(x^2) - 1/n*sum(x)^2)
variance = ((sum_of_squares - sum_of_probes*mean_of_probes) /
(probe_count-1))
else:
variance = None
return (mean_of_probes, variance)
def __call__(self, callee): def __call__(self, callee):
return self.probing_run(callee) return self.probing_run(callee)
...@@ -150,8 +150,7 @@ from keepers import Loggable,\ ...@@ -150,8 +150,7 @@ from keepers import Loggable,\
Versionable Versionable
class Space(Versionable, Loggable, object):
class Space(Versionable, Loggable, Plottable, object):
""" """
.. __ __ .. __ __
.. /__/ / /_ .. /__/ / /_
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment