Commit c2bf2cb1 authored by Theo Steininger's avatar Theo Steininger
Browse files

Merge branch 'master' into tests

parents b7e0e994 cd593c77
Pipeline #9650 canceled with stage
......@@ -4,85 +4,54 @@ import abc
import numpy as np
from nifty.field_types import FieldType
from nifty.spaces import Space
from nifty.field import Field
from nifty.operators.endomorphic_operator import EndomorphicOperator
from d2o import STRATEGIES as DISTRIBUTION_STRATEGIES
class Prober(object):
__metaclass__ = abc.ABCMeta
class ProbingOperator(EndomorphicOperator):
"""
aka DiagonalProbingOperator
"""
# ---Overwritten properties and methods---
def __init__(self, domain=None, field_type=None,
distribution_strategy=None, probe_count=8,
random_type='pm1', compute_variance=False):
self.domain = domain
self.field_type = field_type
self._domain = self._parse_domain(domain)
self._field_type = self._parse_field_type(field_type)
self._distribution_strategy = \
self._parse_distribution_strategy(distribution_strategy)
self.distribution_strategy = distribution_strategy
self.probe_count = probe_count
self.random_type = random_type
self.compute_variance = bool(compute_variance)
def _parse_domain(self, domain):
if domain is None:
domain = ()
elif isinstance(domain, Space):
domain = (domain,)
elif not isinstance(domain, tuple):
domain = tuple(domain)
for d in domain:
if not isinstance(d, Space):
raise TypeError(
"Given object contains something that is not a "
"nifty.space.")
return domain
def _parse_field_type(self, field_type):
if field_type is None:
field_type = ()
elif isinstance(field_type, FieldType):
field_type = (field_type,)
elif not isinstance(field_type, tuple):
field_type = tuple(field_type)
for ft in field_type:
if not isinstance(ft, FieldType):
raise TypeError(
"Given object is not a nifty.FieldType.")
return field_type
# ---Properties---
# ---Mandatory properties and methods---
@property
def domain(self):
return self._domain
@domain.setter
def domain(self, domain):
self._domain = self._parse_domain(domain)
@property
def field_type(self):
return self._field_type
@field_type.setter
def field_type(self, field_type):
self._field_type = self._parse_field_type(field_type)
# ---Added properties and methods---
@property
def distribution_strategy(self):
return self._distribution_strategy
@distribution_strategy.setter
def distribution_strategy(self, distribution_strategy):
def _parse_distribution_strategy(self, distribution_strategy):
distribution_strategy = str(distribution_strategy)
if distribution_strategy not in DISTRIBUTION_STRATEGIES['global']:
raise ValueError("distribution_strategy must be a global-type "
"strategy.")
self._distribution_strategy = distribution_strategy
return distribution_strategy
@property
def probe_count(self):
......
......@@ -71,7 +71,7 @@ class GLSpace(Space):
# ---Overwritten properties and methods---
def __init__(self, nlat=2, nlon=None, dtype=np.dtype('float')):
def __init__(self, nlat=2, nlon=None, dtype=None):
"""
Sets the attributes for a gl_space class instance.
......
......@@ -96,7 +96,7 @@ class HPSpace(Space):
# ---Overwritten properties and methods---
def __init__(self, nside=2, dtype=np.dtype('float')):
def __init__(self, nside=2, dtype=None):
"""
Sets the attributes for a hp_space class instance.
......
......@@ -74,7 +74,7 @@ class LMSpace(Space):
Pixel volume of the :py:class:`lm_space`, which is always 1.
"""
def __init__(self, lmax, dtype=np.dtype('complex128')):
def __init__(self, lmax, dtype=None):
"""
Sets the attributes for an lm_space class instance.
......@@ -131,11 +131,12 @@ class LMSpace(Space):
@property
def dim(self):
l = self.lmax
m = self.mmax
# the LMSpace consist of the full triangle (including -m's!),
# minus two little triangles if mmax < lmax
# dim = (((2*(l+1)-1)+1)**2/4 - 2 * (l-m)(l-m+1)/2
return np.int((l+1)**2 - (l-m)*(l-m+1.))
# dim = np.int((l+1)**2 - (l-m)*(l-m+1.))
# We fix l == m
return np.int((l+1)**2)
@property
def total_volume(self):
......
......@@ -18,7 +18,7 @@ class PowerSpace(Space):
def __init__(self, harmonic_domain=RGSpace((1,)),
distribution_strategy='not',
log=False, nbin=None, binbounds=None,
dtype=np.dtype('float')):
dtype=None):
super(PowerSpace, self).__init__(dtype)
self._ignore_for_hash += ['_pindex', '_kindex', '_rho', '_pundex',
......
......@@ -150,7 +150,8 @@ from keepers import Loggable,\
Versionable
class Space(Versionable, Loggable, object):
class Space(Versionable, Loggable, Plottable, object):
"""
.. __ __
.. /__/ / /_
......@@ -204,7 +205,10 @@ class Space(Versionable, Loggable, object):
"""
# parse dtype
self.dtype = np.dtype(dtype)
casted_dtype = np.result_type(dtype, np.float64)
if casted_dtype != dtype:
self.Logger.warning("Input dtype reset to: %s" % str(casted_dtype))
self.dtype = casted_dtype
self._ignore_for_hash = ['_global_id']
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment