Commit cd593c77 authored by Theo Steininger's avatar Theo Steininger

Renamed probing files.

parent 27b869ee
Pipeline #9649 failed with stages
in 23 minutes and 35 seconds
......@@ -4,85 +4,54 @@ import abc
import numpy as np
from nifty.field_types import FieldType
from nifty.spaces import Space
from nifty.field import Field
from nifty.operators.endomorphic_operator import EndomorphicOperator
from d2o import STRATEGIES as DISTRIBUTION_STRATEGIES
class Prober(object):
__metaclass__ = abc.ABCMeta
class ProbingOperator(EndomorphicOperator):
"""
aka DiagonalProbingOperator
"""
# ---Overwritten properties and methods---
def __init__(self, domain=None, field_type=None,
distribution_strategy=None, probe_count=8,
random_type='pm1', compute_variance=False):
self.domain = domain
self.field_type = field_type
self._domain = self._parse_domain(domain)
self._field_type = self._parse_field_type(field_type)
self._distribution_strategy = \
self._parse_distribution_strategy(distribution_strategy)
self.distribution_strategy = distribution_strategy
self.probe_count = probe_count
self.random_type = random_type
self.compute_variance = bool(compute_variance)
def _parse_domain(self, domain):
if domain is None:
domain = ()
elif isinstance(domain, Space):
domain = (domain,)
elif not isinstance(domain, tuple):
domain = tuple(domain)
for d in domain:
if not isinstance(d, Space):
raise TypeError(
"Given object contains something that is not a "
"nifty.space.")
return domain
def _parse_field_type(self, field_type):
if field_type is None:
field_type = ()
elif isinstance(field_type, FieldType):
field_type = (field_type,)
elif not isinstance(field_type, tuple):
field_type = tuple(field_type)
for ft in field_type:
if not isinstance(ft, FieldType):
raise TypeError(
"Given object is not a nifty.FieldType.")
return field_type
# ---Properties---
# ---Mandatory properties and methods---
@property
def domain(self):
return self._domain
@domain.setter
def domain(self, domain):
self._domain = self._parse_domain(domain)
@property
def field_type(self):
return self._field_type
@field_type.setter
def field_type(self, field_type):
self._field_type = self._parse_field_type(field_type)
# ---Added properties and methods---
@property
def distribution_strategy(self):
return self._distribution_strategy
@distribution_strategy.setter
def distribution_strategy(self, distribution_strategy):
def _parse_distribution_strategy(self, distribution_strategy):
distribution_strategy = str(distribution_strategy)
if distribution_strategy not in DISTRIBUTION_STRATEGIES['global']:
raise ValueError("distribution_strategy must be a global-type "
"strategy.")
self._distribution_strategy = distribution_strategy
return distribution_strategy
@property
def probe_count(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment