Skip to content
Snippets Groups Projects
Commit cd593c77 authored by Theo Steininger's avatar Theo Steininger
Browse files

Renamed probing files.

parent 27b869ee
No related branches found
No related tags found
No related merge requests found
Pipeline #
...@@ -4,85 +4,54 @@ import abc ...@@ -4,85 +4,54 @@ import abc
import numpy as np import numpy as np
from nifty.field_types import FieldType
from nifty.spaces import Space
from nifty.field import Field from nifty.field import Field
from nifty.operators.endomorphic_operator import EndomorphicOperator
from d2o import STRATEGIES as DISTRIBUTION_STRATEGIES from d2o import STRATEGIES as DISTRIBUTION_STRATEGIES
class Prober(object): class ProbingOperator(EndomorphicOperator):
__metaclass__ = abc.ABCMeta """
aka DiagonalProbingOperator
"""
# ---Overwritten properties and methods---
def __init__(self, domain=None, field_type=None, def __init__(self, domain=None, field_type=None,
distribution_strategy=None, probe_count=8, distribution_strategy=None, probe_count=8,
random_type='pm1', compute_variance=False): random_type='pm1', compute_variance=False):
self.domain = domain self._domain = self._parse_domain(domain)
self.field_type = field_type self._field_type = self._parse_field_type(field_type)
self._distribution_strategy = \
self._parse_distribution_strategy(distribution_strategy)
self.distribution_strategy = distribution_strategy self.distribution_strategy = distribution_strategy
self.probe_count = probe_count self.probe_count = probe_count
self.random_type = random_type self.random_type = random_type
self.compute_variance = bool(compute_variance) self.compute_variance = bool(compute_variance)
def _parse_domain(self, domain): # ---Mandatory properties and methods---
if domain is None:
domain = ()
elif isinstance(domain, Space):
domain = (domain,)
elif not isinstance(domain, tuple):
domain = tuple(domain)
for d in domain:
if not isinstance(d, Space):
raise TypeError(
"Given object contains something that is not a "
"nifty.space.")
return domain
def _parse_field_type(self, field_type):
if field_type is None:
field_type = ()
elif isinstance(field_type, FieldType):
field_type = (field_type,)
elif not isinstance(field_type, tuple):
field_type = tuple(field_type)
for ft in field_type:
if not isinstance(ft, FieldType):
raise TypeError(
"Given object is not a nifty.FieldType.")
return field_type
# ---Properties---
@property @property
def domain(self): def domain(self):
return self._domain return self._domain
@domain.setter
def domain(self, domain):
self._domain = self._parse_domain(domain)
@property @property
def field_type(self): def field_type(self):
return self._field_type return self._field_type
@field_type.setter # ---Added properties and methods---
def field_type(self, field_type):
self._field_type = self._parse_field_type(field_type)
@property @property
def distribution_strategy(self): def distribution_strategy(self):
return self._distribution_strategy return self._distribution_strategy
@distribution_strategy.setter def _parse_distribution_strategy(self, distribution_strategy):
def distribution_strategy(self, distribution_strategy):
distribution_strategy = str(distribution_strategy) distribution_strategy = str(distribution_strategy)
if distribution_strategy not in DISTRIBUTION_STRATEGIES['global']: if distribution_strategy not in DISTRIBUTION_STRATEGIES['global']:
raise ValueError("distribution_strategy must be a global-type " raise ValueError("distribution_strategy must be a global-type "
"strategy.") "strategy.")
self._distribution_strategy = distribution_strategy return distribution_strategy
@property @property
def probe_count(self): def probe_count(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment