diff --git a/nifty/probing/__init__.py b/nifty/operators/probing_operator/__init__.py similarity index 100% rename from nifty/probing/__init__.py rename to nifty/operators/probing_operator/__init__.py diff --git a/nifty/probing/diagonal_prober.py b/nifty/operators/probing_operator/diagonal_prober.py similarity index 100% rename from nifty/probing/diagonal_prober.py rename to nifty/operators/probing_operator/diagonal_prober.py diff --git a/nifty/probing/prober.py b/nifty/operators/probing_operator/probing_operator.py similarity index 72% rename from nifty/probing/prober.py rename to nifty/operators/probing_operator/probing_operator.py index a572a97f6f002d0abeb437e320c630cdef96b93f..5e84bf7c2324b01434ffe9c600e112dca997d125 100644 --- a/nifty/probing/prober.py +++ b/nifty/operators/probing_operator/probing_operator.py @@ -4,85 +4,54 @@ import abc import numpy as np -from nifty.field_types import FieldType -from nifty.spaces import Space from nifty.field import Field +from nifty.operators.endomorphic_operator import EndomorphicOperator from d2o import STRATEGIES as DISTRIBUTION_STRATEGIES -class Prober(object): - __metaclass__ = abc.ABCMeta +class ProbingOperator(EndomorphicOperator): + """ + aka DiagonalProbingOperator + """ + + # ---Overwritten properties and methods--- def __init__(self, domain=None, field_type=None, distribution_strategy=None, probe_count=8, random_type='pm1', compute_variance=False): - self.domain = domain - self.field_type = field_type + self._domain = self._parse_domain(domain) + self._field_type = self._parse_field_type(field_type) + self._distribution_strategy = \ + self._parse_distribution_strategy(distribution_strategy) self.distribution_strategy = distribution_strategy self.probe_count = probe_count self.random_type = random_type self.compute_variance = bool(compute_variance) - def _parse_domain(self, domain): - if domain is None: - domain = () - elif isinstance(domain, Space): - domain = (domain,) - elif not isinstance(domain, tuple): - domain = tuple(domain) - - for d in domain: - if not isinstance(d, Space): - raise TypeError( - "Given object contains something that is not a " - "nifty.space.") - return domain - - def _parse_field_type(self, field_type): - if field_type is None: - field_type = () - elif isinstance(field_type, FieldType): - field_type = (field_type,) - elif not isinstance(field_type, tuple): - field_type = tuple(field_type) - - for ft in field_type: - if not isinstance(ft, FieldType): - raise TypeError( - "Given object is not a nifty.FieldType.") - return field_type - - # ---Properties--- + # ---Mandatory properties and methods--- @property def domain(self): return self._domain - @domain.setter - def domain(self, domain): - self._domain = self._parse_domain(domain) - @property def field_type(self): return self._field_type - @field_type.setter - def field_type(self, field_type): - self._field_type = self._parse_field_type(field_type) + # ---Added properties and methods--- @property def distribution_strategy(self): return self._distribution_strategy - @distribution_strategy.setter - def distribution_strategy(self, distribution_strategy): + def _parse_distribution_strategy(self, distribution_strategy): distribution_strategy = str(distribution_strategy) if distribution_strategy not in DISTRIBUTION_STRATEGIES['global']: raise ValueError("distribution_strategy must be a global-type " "strategy.") - self._distribution_strategy = distribution_strategy + return distribution_strategy @property def probe_count(self): diff --git a/nifty/probing/trace_prober.py b/nifty/operators/probing_operator/trace_prober.py similarity index 100% rename from nifty/probing/trace_prober.py rename to nifty/operators/probing_operator/trace_prober.py