Commit cfbd598e authored by Martin Reinecke's avatar Martin Reinecke
Browse files

more DomainTuple work

parent 75145bf8
......@@ -22,6 +22,8 @@ from .version import __version__
from .field import Field
from .domain_tuple import DomainTuple
from .random import Random
from .basic_arithmetics import *
......
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2017 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from functools import reduce
from .domain_object import DomainObject
class DomainTuple(object):
def __init__(self, domain):
self._dom = self._parse_domain(domain)
self._axtuple = self._get_axes_tuple()
shape_tuple = tuple(sp.shape for sp in self._dom)
self._shape = reduce(lambda x, y: x + y, shape_tuple, ())
self._dim = reduce(lambda x, y: x * y, self._shape, 1)
def _get_axes_tuple(self):
i = 0
res = [None]*len(self._dom)
for idx, thing in enumerate(self._dom):
nax = len(thing.shape)
res[idx] = tuple(range(i, i+nax))
i += nax
return res
@staticmethod
def make(domain):
if isinstance(domain, DomainTuple):
return domain
return DomainTuple(domain)
@staticmethod
def _parse_domain(domain):
if domain is None:
return ()
if isinstance(domain, DomainObject):
return (domain,)
if not isinstance(domain, tuple):
domain = tuple(domain)
for d in domain:
if not isinstance(d, DomainObject):
raise TypeError(
"Given object contains something that is not an "
"instance of DomainObject class.")
return domain
def __getitem__(self, i):
return self._dom[i]
@property
def domains(self):
return self._dom
@property
def shape(self):
return self._shape
@property
def dim(self):
return self._dim
@property
def axes(self):
return self._axtuple
def __len__(self):
return len(self._dom)
def __hash__(self):
return self._dom.__hash__()
def __eq__(self, x):
if not isinstance(x, DomainTuple):
x = DomainTuple(x)
return self._dom == x._dom
def __ne__(self, x):
if not isinstance(x, DomainTuple):
x = DomainTuple(x)
return self._dom != x._dom
......@@ -22,6 +22,7 @@ import numpy as np
from .spaces.power_space import PowerSpace
from . import nifty_utilities as utilities
from .random import Random
from .domain_tuple import DomainTuple
from functools import reduce
......@@ -53,10 +54,8 @@ class Field(object):
----------
val : numpy.ndarray
domain : DomainObject
domain : DomainTuple
See Parameters.
domain_axes : tuple of tuples
Enumerates the axes of the Field
dtype : type
Contains the datatype stored in the Field.
......@@ -74,27 +73,21 @@ class Field(object):
def __init__(self, domain=None, val=None, dtype=None, copy=False):
self.domain = self._parse_domain(domain=domain, val=val)
self.domain_axes = self._get_axes_tuple(self.domain)
shape_tuple = tuple(sp.shape for sp in self.domain)
if len(shape_tuple) == 0:
global_shape = ()
else:
global_shape = reduce(lambda x, y: x + y, shape_tuple)
dtype = self._infer_dtype(dtype=dtype, val=val)
if isinstance(val, Field):
if self.domain != val.domain:
raise ValueError("Domain mismatch")
self._val = np.array(val.val, dtype=dtype, copy=copy)
elif (np.isscalar(val)):
self._val = np.full(global_shape, dtype=dtype, fill_value=val)
self._val = np.full(self.domain.shape, dtype=dtype, fill_value=val)
elif isinstance(val, np.ndarray):
if global_shape == val.shape:
if self.domain.shape == val.shape:
self._val = np.array(val, dtype=dtype, copy=copy)
else:
raise ValueError("Shape mismatch")
elif val is None:
self._val = np.empty(global_shape, dtype=dtype)
self._val = np.empty(self.domain.shape, dtype=dtype)
else:
raise TypeError("unknown source type")
......@@ -104,19 +97,9 @@ class Field(object):
if isinstance(val, Field):
return val.domain
if np.isscalar(val):
return () # empty domain tuple
return DomainTuple(()) # empty domain tuple
raise TypeError("could not infer domain from value")
return utilities.parse_domain(domain)
@staticmethod
def _get_axes_tuple(things_with_shape):
i = 0
axes_list = [None]*len(things_with_shape)
for idx, thing in enumerate(things_with_shape):
nax = len(thing.shape)
axes_list[idx] = tuple(range(i, i+nax))
i += nax
return tuple(axes_list)
return DomainTuple.make(domain)
# MR: this needs some rethinking ... do we need to have at least float64?
@staticmethod
......@@ -155,9 +138,10 @@ class Field(object):
power_synthesize
"""
domain = DomainTuple.make(domain)
generator_function = getattr(Random, random_type)
return Field(domain=domain, val=generator_function(dtype=dtype,
shape=utilities.domains2shape(domain), **kwargs))
shape=domain.shape, **kwargs))
# ---Powerspectral methods---
......@@ -240,7 +224,7 @@ class Field(object):
def _single_power_analyze(field, idx, binbounds):
power_domain = PowerSpace(field.domain[idx], binbounds)
pindex = power_domain.pindex
axes = field.domain_axes[idx]
axes = field.domain.axes[idx]
new_pindex_shape = [1] * len(field.shape)
for i, ax in enumerate(axes):
new_pindex_shape[ax] = pindex.shape[i]
......@@ -275,7 +259,7 @@ class Field(object):
local_blow_up = [slice(None)]*len(spec.shape)
# it is important to count from behind, since spec potentially
# grows with every iteration
index = self.domain_axes[i][0]-len(self.shape)
index = self.domain.axes[i][0]-len(self.shape)
local_blow_up[index] = power_space.pindex
# here, the power_spectrum is distributed into the new shape
spec = spec[local_blow_up]
......@@ -380,7 +364,7 @@ class Field(object):
The output object. The tuple contains the dimensions of the spaces
in domain.
"""
return self._val.shape
return self.domain.shape
@property
def dim(self):
......@@ -393,7 +377,7 @@ class Field(object):
out : int
The dimension of the Field.
"""
return self._val.size
return self.domain.dim
@property
def total_volume(self):
......@@ -479,8 +463,8 @@ class Field(object):
fct *= wgt
else:
new_shape = np.ones(len(self.shape), dtype=np.int)
new_shape[self.domain_axes[ind][0]:
self.domain_axes[ind][-1]+1] = wgt.shape
new_shape[self.domain.axes[ind][0]:
self.domain.axes[ind][-1]+1] = wgt.shape
wgt = wgt.reshape(new_shape)
new_field *= wgt**power
fct = fct**power
......@@ -574,7 +558,7 @@ class Field(object):
else:
spaces = utilities.cast_iseq_to_tuple(spaces)
axes_list = tuple(self.domain_axes[sp_index] for sp_index in spaces)
axes_list = tuple(self.domain.axes[sp_index] for sp_index in spaces)
if len(axes_list) > 0:
axes_list = reduce(lambda x, y: x+y, axes_list)
......
......@@ -76,31 +76,6 @@ def cast_iseq_to_tuple(seq):
return tuple(int(item) for item in seq)
def parse_domain(domain):
if domain is None:
return ()
if isinstance(domain, DomainObject):
return (domain,)
if not isinstance(domain, tuple):
domain = tuple(domain)
for d in domain:
if not isinstance(d, DomainObject):
raise TypeError(
"Given object contains something that is not an "
"instance of DomainObject-class.")
return domain
def domains2shape(domain):
domain = parse_domain(domain)
shape_tuple = tuple(sp.shape for sp in domain)
if len(shape_tuple) == 0:
return ()
else:
return reduce(lambda x, y: x + y, shape_tuple)
def bincount_axis(obj, minlength=None, weights=None, axis=None):
if minlength is not None:
length = max(np.amax(obj) + 1, minlength)
......
......@@ -18,7 +18,7 @@
from builtins import range
from ..linear_operator import LinearOperator
from ... import DomainTuple
class ComposedOperator(LinearOperator):
""" NIFTY class for composed operators.
......@@ -97,17 +97,19 @@ class ComposedOperator(LinearOperator):
@property
def domain(self):
if not hasattr(self, '_domain'):
self._domain = ()
dom = ()
for op in self._operator_store:
self._domain += op.domain
dom += op.domain.domains
self._domain = DomainTuple.make(dom)
return self._domain
@property
def target(self):
if not hasattr(self, '_target'):
self._target = ()
tgt = ()
for op in self._operator_store:
self._target += op.target
tgt += op.target.domains
self._target = DomainTuple.make(tgt)
return self._target
@property
......
......@@ -21,6 +21,7 @@ from builtins import range
import numpy as np
from ...field import Field
from ...domain_tuple import DomainTuple
from ..endomorphic_operator import EndomorphicOperator
......@@ -84,7 +85,7 @@ class DiagonalOperator(EndomorphicOperator):
default_spaces=None):
super(DiagonalOperator, self).__init__(default_spaces)
self._domain = self._parse_domain(domain)
self._domain = DomainTuple.make(domain)
self._self_adjoint = None
self._unitary = None
......@@ -214,7 +215,7 @@ class DiagonalOperator(EndomorphicOperator):
else:
active_axes = []
for space_index in spaces:
active_axes += x.domain_axes[space_index]
active_axes += x.domain.axes[space_index]
reshaper = [x.shape[i] if i in active_axes else 1
for i in range(len(x.shape))]
......
......@@ -16,7 +16,7 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from ... import Field, nifty_utilities as utilities
from ... import Field, DomainTuple, nifty_utilities as utilities
from ...spaces import RGSpace, GLSpace, HPSpace, LMSpace
from ..linear_operator import LinearOperator
......@@ -96,14 +96,14 @@ class FFTOperator(LinearOperator):
super(FFTOperator, self).__init__(default_spaces)
# Initialize domain and target
self._domain = self._parse_domain(domain)
self._domain = DomainTuple.make(domain)
if len(self.domain) != 1:
raise ValueError("TransformationOperator accepts only exactly one "
"space as input domain.")
if target is None:
target = (self.domain[0].get_default_codomain(), )
self._target = self._parse_domain(target)
self._target = DomainTuple.make(target)
if len(self.target) != 1:
raise ValueError("TransformationOperator accepts only exactly one "
"space as output target.")
......@@ -127,13 +127,13 @@ class FFTOperator(LinearOperator):
# this case means that x lives on only one space, which is
# identical to the space in the domain of `self`. Otherwise the
# input check of LinearOperator would have failed.
axes = x.domain_axes[0]
axes = x.domain.axes[0]
result_domain = other
else:
spaces = utilities.cast_iseq_to_tuple(spaces)
result_domain = list(x.domain)
result_domain[spaces[0]] = other[0]
axes = x.domain_axes[spaces[0]]
axes = x.domain.axes[spaces[0]]
new_val, fct = trafo.transform(x.val, axes=axes)
res = Field(result_domain, new_val, copy=False)
......
......@@ -20,7 +20,7 @@ import numpy as np
from ...field import Field
from ...spaces.power_space import PowerSpace
from ..endomorphic_operator import EndomorphicOperator
from ... import sqrt
from ... import sqrt, DomainTuple
from ... import nifty_utilities as utilities
......@@ -41,7 +41,7 @@ class LaplaceOperator(EndomorphicOperator):
def __init__(self, domain, default_spaces=None, logarithmic=True):
super(LaplaceOperator, self).__init__(default_spaces)
self._domain = self._parse_domain(domain)
self._domain = DomainTuple.make(domain)
if len(self.domain) != 1:
raise ValueError("The domain must contain exactly one PowerSpace.")
......@@ -93,10 +93,10 @@ class LaplaceOperator(EndomorphicOperator):
# this case means that x lives on only one space, which is
# identical to the space in the domain of `self`. Otherwise the
# input check of LinearOperator would have failed.
axes = x.domain_axes[0]
axes = x.domain.axes[0]
else:
spaces = utilities.cast_iseq_to_tuple(spaces)
axes = x.domain_axes[spaces[0]]
axes = x.domain.axes[spaces[0]]
axis = axes[0]
nval = len(self._dposc)
prefix = (slice(None),) * axis
......@@ -119,10 +119,10 @@ class LaplaceOperator(EndomorphicOperator):
# this case means that x lives on only one space, which is
# identical to the space in the domain of `self`. Otherwise the
# input check of LinearOperator would have failed.
axes = x.domain_axes[0]
axes = x.domain.axes[0]
else:
spaces = utilities.cast_iseq_to_tuple(spaces)
axes = x.domain_axes[spaces[0]]
axes = x.domain.axes[spaces[0]]
axis = axes[0]
nval = len(self._dposc)
prefix = (slice(None),) * axis
......
......@@ -69,10 +69,6 @@ class LinearOperator(with_metaclass(
def __init__(self, default_spaces=None):
self._default_spaces = default_spaces
@staticmethod
def _parse_domain(domain):
return utilities.parse_domain(domain)
@abc.abstractproperty
def domain(self):
"""
......
......@@ -88,7 +88,7 @@ class ProjectionOperator(EndomorphicOperator):
active_axes = list(range(len(x.shape)))
else:
for space_index in spaces:
active_axes += x.domain_axes[space_index]
active_axes += x.domain.axes[space_index]
local_projection_vector = self._projection_field.val
......
from builtins import range
from ... import Field,\
FieldArray
FieldArray,\
DomainTuple
from ..linear_operator import LinearOperator
from ..smoothing_operator import FFTSmoothingOperator
from ..composed_operator import ComposedOperator
......@@ -54,7 +55,7 @@ class ResponseOperator(LinearOperator):
"exposure do not match")
nsigma = len(sigma)
self._domain = self._parse_domain(domain)
self._domain = DomainTuple.make(domain)
kernel_smoothing = [FFTSmoothingOperator(self._domain[x], sigma[x])
for x in range(nsigma)]
......@@ -65,7 +66,7 @@ class ResponseOperator(LinearOperator):
self._composed_exposure = ComposedOperator(kernel_exposure)
target_list = [FieldArray(x.shape) for x in self.domain]
self._target = self._parse_domain(target_list)
self._target = DomainTuple.make(target_list)
@property
def domain(self):
......
......@@ -6,7 +6,7 @@ import numpy as np
from ..endomorphic_operator import EndomorphicOperator
from ... import nifty_utilities as utilities
from ... import Field
from ... import Field, DomainTuple
class DirectSmoothingOperator(EndomorphicOperator):
......@@ -14,7 +14,7 @@ class DirectSmoothingOperator(EndomorphicOperator):
default_spaces=None):
super(DirectSmoothingOperator, self).__init__(default_spaces)
self._domain = self._parse_domain(domain)
self._domain = DomainTuple.make(domain)
if len(self._domain) != 1:
raise ValueError("DirectSmoothingOperator only accepts exactly one"
" space as input domain.")
......@@ -93,7 +93,7 @@ class DirectSmoothingOperator(EndomorphicOperator):
def _smooth(self, x, spaces):
# infer affected axes
# we rely on the knowledge that `spaces` is a tuple with length 1.
affected_axes = x.domain_axes[spaces[0]]
affected_axes = x.domain.axes[spaces[0]]
if len(affected_axes) != 1:
raise ValueError("By this implementation only one-dimensional "
"spaces can be smoothed directly.")
......
......@@ -5,7 +5,7 @@ import numpy as np
from ..endomorphic_operator import EndomorphicOperator
from ..fft_operator import FFTOperator
from ... import DomainTuple
class FFTSmoothingOperator(EndomorphicOperator):
......@@ -13,7 +13,7 @@ class FFTSmoothingOperator(EndomorphicOperator):
default_spaces=None):
super(FFTSmoothingOperator, self).__init__(default_spaces)
self._domain = self._parse_domain(domain)
self._domain = DomainTuple.make(domain)
if len(self._domain) != 1:
raise ValueError("SmoothingOperator only accepts exactly one "
"space as input domain.")
......@@ -57,7 +57,7 @@ class FFTSmoothingOperator(EndomorphicOperator):
# transform to the (global-)default codomain and perform all remaining
# steps therein
transformed_x = self._transformator(x, spaces=spaces)
coaxes = transformed_x.domain_axes[spaces[0]]
coaxes = transformed_x.domain.axes[spaces[0]]
# now, apply the kernel to transformed_x
# this is done node-locally utilizing numpy's reshaping in order to
......
from ...spaces.power_space import PowerSpace
from ..endomorphic_operator import EndomorphicOperator
from ..laplace_operator import LaplaceOperator
from ... import Field
from ... import Field, DomainTuple
class SmoothnessOperator(EndomorphicOperator):
......@@ -34,7 +34,7 @@ class SmoothnessOperator(EndomorphicOperator):
super(SmoothnessOperator, self).__init__(default_spaces=default_spaces)
self._domain = self._parse_domain(domain)
self._domain = DomainTuple.make(domain)
if len(self.domain) != 1:
raise ValueError("The domain must contain exactly one PowerSpace.")
......
......@@ -21,7 +21,7 @@ from builtins import range
from builtins import object
import numpy as np
from ...field import Field
from ...field import Field, DomainTuple
from ... import nifty_utilities as utilities
......@@ -39,7 +39,7 @@ class Prober(object):
random_type='pm1', probe_dtype=np.float,
compute_variance=False, ncpu=1):
self._domain = utilities.parse_domain(domain)
self._domain = DomainTuple.make(domain)
self._probe_count = self._parse_probe_count(probe_count)
self._ncpu = self._parse_probe_count(ncpu)
self._random_type = self._parse_random_type(random_type)
......
......@@ -29,7 +29,8 @@ from itertools import product
from nifty2go import Field,\
RGSpace,\
LMSpace,\
PowerSpace
PowerSpace,\
DomainTuple
from test.common import expand
......@@ -40,11 +41,10 @@ SPACE_COMBINATIONS = [(), SPACES[0], SPACES[1], SPACES]
class Test_Interface(unittest.TestCase):
@expand(product(SPACE_COMBINATIONS,
[['domain', tuple],
['domain_axes', tuple],
[['domain', DomainTuple],
['val', np.ndarray],
['shape', tuple],
['dim', np.int],
['dim', (np.int, np.int64)],
['total_volume', np.float]]))
def test_return_types(self, domain, attribute_desired_type):
attribute = attribute_desired_type[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment