Skip to content
Snippets Groups Projects
Commit cd593c77 authored by Theo Steininger's avatar Theo Steininger
Browse files

Renamed probing files.

parent 27b869ee
No related branches found
No related tags found
No related merge requests found
Pipeline #
......@@ -4,85 +4,54 @@ import abc
import numpy as np
from nifty.field_types import FieldType
from nifty.spaces import Space
from nifty.field import Field
from nifty.operators.endomorphic_operator import EndomorphicOperator
from d2o import STRATEGIES as DISTRIBUTION_STRATEGIES
class Prober(object):
__metaclass__ = abc.ABCMeta
class ProbingOperator(EndomorphicOperator):
"""
aka DiagonalProbingOperator
"""
# ---Overwritten properties and methods---
def __init__(self, domain=None, field_type=None,
distribution_strategy=None, probe_count=8,
random_type='pm1', compute_variance=False):
self.domain = domain
self.field_type = field_type
self._domain = self._parse_domain(domain)
self._field_type = self._parse_field_type(field_type)
self._distribution_strategy = \
self._parse_distribution_strategy(distribution_strategy)
self.distribution_strategy = distribution_strategy
self.probe_count = probe_count
self.random_type = random_type
self.compute_variance = bool(compute_variance)
def _parse_domain(self, domain):
if domain is None:
domain = ()
elif isinstance(domain, Space):
domain = (domain,)
elif not isinstance(domain, tuple):
domain = tuple(domain)
for d in domain:
if not isinstance(d, Space):
raise TypeError(
"Given object contains something that is not a "
"nifty.space.")
return domain
def _parse_field_type(self, field_type):
if field_type is None:
field_type = ()
elif isinstance(field_type, FieldType):
field_type = (field_type,)
elif not isinstance(field_type, tuple):
field_type = tuple(field_type)
for ft in field_type:
if not isinstance(ft, FieldType):
raise TypeError(
"Given object is not a nifty.FieldType.")
return field_type
# ---Properties---
# ---Mandatory properties and methods---
@property
def domain(self):
return self._domain
@domain.setter
def domain(self, domain):
self._domain = self._parse_domain(domain)
@property
def field_type(self):
return self._field_type
@field_type.setter
def field_type(self, field_type):
self._field_type = self._parse_field_type(field_type)
# ---Added properties and methods---
@property
def distribution_strategy(self):
return self._distribution_strategy
@distribution_strategy.setter
def distribution_strategy(self, distribution_strategy):
def _parse_distribution_strategy(self, distribution_strategy):
distribution_strategy = str(distribution_strategy)
if distribution_strategy not in DISTRIBUTION_STRATEGIES['global']:
raise ValueError("distribution_strategy must be a global-type "
"strategy.")
self._distribution_strategy = distribution_strategy
return distribution_strategy
@property
def probe_count(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment