diff --git a/nifty/nifty_utilities.py b/nifty/nifty_utilities.py index 70b4bdb261bc1030ecbfa6d660b627dd17c0e261..f1ff27c9dfff60f884f33e2642e797f99721ce55 100644 --- a/nifty/nifty_utilities.py +++ b/nifty/nifty_utilities.py @@ -278,3 +278,36 @@ def get_default_codomain(domain): return LMGLTransformation.get_codomain(domain) else: raise TypeError('ERROR: unknown domain') + + +def parse_domain(domain): + from nifty.spaces.space import Space + if domain is None: + domain = () + elif isinstance(domain, Space): + domain = (domain,) + elif not isinstance(domain, tuple): + domain = tuple(domain) + + for d in domain: + if not isinstance(d, Space): + raise TypeError( + "Given object contains something that is not a " + "nifty.space.") + return domain + + +def parse_field_type(field_type): + from nifty.field_types import FieldType + if field_type is None: + field_type = () + elif isinstance(field_type, FieldType): + field_type = (field_type,) + elif not isinstance(field_type, tuple): + field_type = tuple(field_type) + + for ft in field_type: + if not isinstance(ft, FieldType): + raise TypeError( + "Given object is not a nifty.FieldType.") + return field_type diff --git a/nifty/operators/fft_operator/__init__.py b/nifty/operators/fft_operator/__init__.py index 7408a958619be60189d1da75096ebf95e6eb91ad..c7ed7094026b4d0cf6175d0794c95465734319a6 100644 --- a/nifty/operators/fft_operator/__init__.py +++ b/nifty/operators/fft_operator/__init__.py @@ -1,2 +1,3 @@ + from transformations import * from fft_operator import FFTOperator diff --git a/nifty/operators/linear_operator/linear_operator.py b/nifty/operators/linear_operator/linear_operator.py index 1cf5b6f17a0ac7a81024fdc151d1984b3222c89a..1a1361892eb49ebf5157b28f51546a8feab1f1c3 100644 --- a/nifty/operators/linear_operator/linear_operator.py +++ b/nifty/operators/linear_operator/linear_operator.py @@ -4,8 +4,6 @@ import abc from keepers import Loggable from nifty.field import Field -from nifty.spaces import Space -from nifty.field_types import FieldType import nifty.nifty_utilities as utilities @@ -16,33 +14,10 @@ class LinearOperator(Loggable, object): pass def _parse_domain(self, domain): - if domain is None: - domain = () - elif isinstance(domain, Space): - domain = (domain,) - elif not isinstance(domain, tuple): - domain = tuple(domain) - - for d in domain: - if not isinstance(d, Space): - raise TypeError( - "Given object contains something that is not a " - "nifty.space.") - return domain + return utilities.parse_domain(domain) def _parse_field_type(self, field_type): - if field_type is None: - field_type = () - elif isinstance(field_type, FieldType): - field_type = (field_type,) - elif not isinstance(field_type, tuple): - field_type = tuple(field_type) - - for ft in field_type: - if not isinstance(ft, FieldType): - raise TypeError( - "Given object is not a nifty.FieldType.") - return field_type + return utilities.parse_field_type(field_type) @abc.abstractproperty def domain(self): diff --git a/nifty/operators/probing_operator/__init__.py b/nifty/operators/probing_operator/__init__.py deleted file mode 100644 index 8d200af309ef84fa57449318502f31a302117783..0000000000000000000000000000000000000000 --- a/nifty/operators/probing_operator/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# -*- coding: utf-8 -*- - -from prober import Prober -from diagonal_prober import * -from trace_prober import * diff --git a/nifty/operators/probing_operator/diagonal_prober.py b/nifty/operators/probing_operator/diagonal_prober.py deleted file mode 100644 index 251c116975b8ad851fd44c4246c7c2b0fed77486..0000000000000000000000000000000000000000 --- a/nifty/operators/probing_operator/diagonal_prober.py +++ /dev/null @@ -1,10 +0,0 @@ -# -*- coding: utf-8 -*- - -from prober import Prober - - -class DiagonalProber(Prober): - - # ---Mandatory properties and methods--- - def finish_probe(self, probe, pre_result): - return probe[1].conjugate()*pre_result diff --git a/nifty/operators/probing_operator/trace_prober.py b/nifty/operators/probing_operator/trace_prober.py deleted file mode 100644 index dab7c50cdab5302e8b20d5177ae8a1bdb0a4eca4..0000000000000000000000000000000000000000 --- a/nifty/operators/probing_operator/trace_prober.py +++ /dev/null @@ -1,10 +0,0 @@ -# -*- coding: utf-8 -*- - -from prober import Prober - - -class TraceProber(Prober): - - # ---Mandatory properties and methods--- - def finish_probe(self, probe, pre_result): - return probe[1].conjugate().weight(power=-1).dot(pre_result) diff --git a/nifty/probing/__init__.py b/nifty/probing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..518372d0250dccbd1d5b0e233a248d8ade364e9c --- /dev/null +++ b/nifty/probing/__init__.py @@ -0,0 +1,4 @@ +# -*- coding: utf-8 -*- + +from prober import Prober +from mixin_classes import * diff --git a/nifty/probing/mixin_classes/__init__.py b/nifty/probing/mixin_classes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cc58e32c51f0dfe432fa303fe4fac2efe2e2c270 --- /dev/null +++ b/nifty/probing/mixin_classes/__init__.py @@ -0,0 +1,5 @@ +# -*- coding: utf-8 -*- + +from mixin_base import MixinBase +from diagonal_prober_mixin import DiagonalProberMixin +from trace_prober_mixin import TraceProberMixin diff --git a/nifty/probing/mixin_classes/diagonal_prober_mixin.py b/nifty/probing/mixin_classes/diagonal_prober_mixin.py new file mode 100644 index 0000000000000000000000000000000000000000..4b9136d96391c288510b12065324667681be05d7 --- /dev/null +++ b/nifty/probing/mixin_classes/diagonal_prober_mixin.py @@ -0,0 +1,43 @@ +# -*- coding: utf-8 -*- + +from mixin_base import MixinBase + + +class DiagonalProberMixin(MixinBase): + def __init__(self): + self.reset() + super(DiagonalProberMixin, self).__init__() + + def reset(self): + self.__sum_of_probings = 0 + self.__sum_of_squares = 0 + self.__diagonal = None + self.__diagonal_variance = None + super(DiagonalProberMixin, self).reset() + + def finish_probe(self, probe, pre_result): + result = probe[1].conjugate()*pre_result + self.__sum_of_probings += result + if self.compute_variance: + self.__sum_of_squares += result.conjugate() * result + super(DiagonalProberMixin, self).finish_probe(probe, pre_result) + + @property + def diagonal(self): + if self.__diagonal is None: + self.__diagonal = self.__sum_of_probings/self.probe_count + return self.__diagonal + + @property + def diagonal_variance(self): + if not self.compute_variance: + raise AttributeError("self.compute_variance is set to False") + if self.__diagonal_variance is None: + # variance = 1/(n-1) (sum(x^2) - 1/n*sum(x)^2) + n = self.probe_count + sum_pr = self.__sum_of_probings + mean = self.diagonal + sum_sq = self.__sum_of_squares + + self.__diagonal_variance = ((sum_sq - sum_pr*mean) / (n-1)) + return self.__diagonal_variance diff --git a/nifty/probing/mixin_classes/mixin_base.py b/nifty/probing/mixin_classes/mixin_base.py new file mode 100644 index 0000000000000000000000000000000000000000..c54320d4252be87c1fbe997a5341413b5cda5d81 --- /dev/null +++ b/nifty/probing/mixin_classes/mixin_base.py @@ -0,0 +1,9 @@ +# -*- coding: utf-8 -*- + + +class MixinBase(object): + def reset(self, *args, **kwargs): + pass + + def finish_probe(self, *args, **kwargs): + pass diff --git a/nifty/probing/mixin_classes/trace_prober_mixin.py b/nifty/probing/mixin_classes/trace_prober_mixin.py new file mode 100644 index 0000000000000000000000000000000000000000..a32760216aeee9c36db7d8ef2c267b816a5ff20c --- /dev/null +++ b/nifty/probing/mixin_classes/trace_prober_mixin.py @@ -0,0 +1,43 @@ +# -*- coding: utf-8 -*- + +from mixin_base import MixinBase + + +class TraceProberMixin(MixinBase): + def __init__(self): + self.reset() + super(TraceProberMixin, self).__init__() + + def reset(self): + self.__sum_of_probings = 0 + self.__sum_of_squares = 0 + self.__trace = None + self.__trace_variance = None + super(TraceProberMixin, self).reset() + + def finish_probe(self, probe, pre_result): + result = probe[1].dot(pre_result, bare=True) + self.__sum_of_probings += result + if self.compute_variance: + self.__sum_of_squares += result.conjugate() * result + super(TraceProberMixin, self).finish_probe(probe, pre_result) + + @property + def trace(self): + if self.__trace is None: + self.__trace = self.__sum_of_probings/self.probe_count + return self.__trace + + @property + def trace_variance(self): + if not self.compute_variance: + raise AttributeError("self.compute_variance is set to False") + if self.__trace_variance is None: + # variance = 1/(n-1) (sum(x^2) - 1/n*sum(x)^2) + n = self.probe_count + sum_pr = self.__sum_of_probings + mean = self.trace + sum_sq = self.__sum_of_squares + + self.__trace_variance = ((sum_sq - sum_pr*mean) / (n-1)) + return self.__trace_variance diff --git a/nifty/probing/prober/__init__.py b/nifty/probing/prober/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..813e0ec0bcf95dcd6bae5a88d38b634963bede56 --- /dev/null +++ b/nifty/probing/prober/__init__.py @@ -0,0 +1,3 @@ +# -*- coding: utf-8 -*- + +from prober import Prober diff --git a/nifty/operators/probing_operator/probing_operator.py b/nifty/probing/prober/prober.py similarity index 60% rename from nifty/operators/probing_operator/probing_operator.py rename to nifty/probing/prober/prober.py index 5e84bf7c2324b01434ffe9c600e112dca997d125..e563c72b15f6bc21df537795eba964781d9be6aa 100644 --- a/nifty/operators/probing_operator/probing_operator.py +++ b/nifty/probing/prober/prober.py @@ -4,33 +4,43 @@ import abc import numpy as np +from nifty.field_types import FieldType +from nifty.spaces import Space from nifty.field import Field -from nifty.operators.endomorphic_operator import EndomorphicOperator +import nifty.nifty_utilities as utilities + +from nifty import nifty_configuration as nc from d2o import STRATEGIES as DISTRIBUTION_STRATEGIES -class ProbingOperator(EndomorphicOperator): +class Prober(object): """ - aka DiagonalProbingOperator + See the following webpages for the principles behind the usage of + mixin-classes + + https://www.python.org/download/releases/2.2.3/descrintro/#cooperation + https://rhettinger.wordpress.com/2011/05/26/super-considered-super/ + """ - # ---Overwritten properties and methods--- + __metaclass__ = abc.ABCMeta def __init__(self, domain=None, field_type=None, distribution_strategy=None, probe_count=8, random_type='pm1', compute_variance=False): - self._domain = self._parse_domain(domain) - self._field_type = self._parse_field_type(field_type) + self._domain = utilities.parse_domain(domain) + self._field_type = utilities.parse_field_type(field_type) self._distribution_strategy = \ self._parse_distribution_strategy(distribution_strategy) - self.distribution_strategy = distribution_strategy - self.probe_count = probe_count - self.random_type = random_type + self._probe_count = self._parse_probe_count(probe_count) + self._random_type = self._parse_random_type(random_type) self.compute_variance = bool(compute_variance) - # ---Mandatory properties and methods--- + super(Prober, self).__init__() + + # ---Properties--- @property def domain(self): @@ -40,57 +50,49 @@ class ProbingOperator(EndomorphicOperator): def field_type(self): return self._field_type - # ---Added properties and methods--- - @property def distribution_strategy(self): return self._distribution_strategy def _parse_distribution_strategy(self, distribution_strategy): - distribution_strategy = str(distribution_strategy) + if distribution_strategy is None: + distribution_strategy = nc['default_distribution_strategy'] + else: + distribution_strategy = str(distribution_strategy) if distribution_strategy not in DISTRIBUTION_STRATEGIES['global']: raise ValueError("distribution_strategy must be a global-type " "strategy.") - return distribution_strategy + self._distribution_strategy = distribution_strategy @property def probe_count(self): return self._probe_count - @probe_count.setter - def probe_count(self, probe_count): - self._probe_count = int(probe_count) + def _parse_probe_count(self, probe_count): + return int(probe_count) @property def random_type(self): return self._random_type - @random_type.setter - def random_type(self, random_type): + def _parse_random_type(self, random_type): if random_type not in ["pm1", "normal"]: raise ValueError( "unsupported random type: '" + str(random_type) + "'.") - else: - self._random_type = random_type + return random_type # ---Probing methods--- def probing_run(self, callee): """ controls the generation, evaluation and finalization of probes """ - sum_of_probes = 0 - sum_of_squares = 0 - + self.reset() for index in xrange(self.probe_count): current_probe = self.get_probe(index) pre_result = self.process_probe(callee, current_probe, index) - result = self.finish_probe(current_probe, pre_result) - - sum_of_probes += result - if self.compute_variance: - sum_of_squares += result.conjugate() * result + self.finish_probe(current_probe, pre_result) - mean_and_variance = self.finalize(sum_of_probes, sum_of_squares) - return mean_and_variance + def reset(self): + super(Prober, self).reset() def get_probe(self, index): """ layer of abstraction for potential probe-caching """ @@ -113,21 +115,8 @@ class ProbingOperator(EndomorphicOperator): """ processes a probe """ return callee(probe, **kwargs) - @abc.abstractmethod def finish_probe(self, probe, pre_result): - return pre_result - - def finalize(self, sum_of_probes, sum_of_squares): - probe_count = self.probe_count - mean_of_probes = sum_of_probes/probe_count - if self.compute_variance: - # variance = 1/(n-1) (sum(x^2) - 1/n*sum(x)^2) - variance = ((sum_of_squares - sum_of_probes*mean_of_probes) / - (probe_count-1)) - else: - variance = None - - return (mean_of_probes, variance) + super(Prober, self).finish_probe(probe, pre_result) def __call__(self, callee): return self.probing_run(callee) diff --git a/nifty/spaces/space/space.py b/nifty/spaces/space/space.py index 80d65ad3f882e1a53c88d5c1f001febe4ecb68b3..72628a040ca6826a069a2dbb74decbe01bb6a5dc 100644 --- a/nifty/spaces/space/space.py +++ b/nifty/spaces/space/space.py @@ -150,8 +150,7 @@ from keepers import Loggable,\ Versionable - -class Space(Versionable, Loggable, Plottable, object): +class Space(Versionable, Loggable, object): """ .. __ __ .. /__/ / /_