diff --git a/nifty2go/__init__.py b/nifty2go/__init__.py index 5a72fe9188db4e964eefdda98fef0892386a3478..ab6af5ac209259e0c38f12813ce7f15afe963186 100644 --- a/nifty2go/__init__.py +++ b/nifty2go/__init__.py @@ -22,6 +22,8 @@ from .version import __version__ from .field import Field +from .domain_tuple import DomainTuple + from .random import Random from .basic_arithmetics import * diff --git a/nifty2go/domain_tuple.py b/nifty2go/domain_tuple.py new file mode 100644 index 0000000000000000000000000000000000000000..8408bbe76f2dfdaff9d43509b4fa402ccaae8c7f --- /dev/null +++ b/nifty2go/domain_tuple.py @@ -0,0 +1,94 @@ +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +# +# Copyright(C) 2013-2017 Max-Planck-Society +# +# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik +# and financially supported by the Studienstiftung des deutschen Volkes. + +from functools import reduce +from .domain_object import DomainObject + +class DomainTuple(object): + def __init__(self, domain): + self._dom = self._parse_domain(domain) + self._axtuple = self._get_axes_tuple() + shape_tuple = tuple(sp.shape for sp in self._dom) + self._shape = reduce(lambda x, y: x + y, shape_tuple, ()) + self._dim = reduce(lambda x, y: x * y, self._shape, 1) + + def _get_axes_tuple(self): + i = 0 + res = [None]*len(self._dom) + for idx, thing in enumerate(self._dom): + nax = len(thing.shape) + res[idx] = tuple(range(i, i+nax)) + i += nax + return res + + @staticmethod + def make(domain): + if isinstance(domain, DomainTuple): + return domain + return DomainTuple(domain) + + @staticmethod + def _parse_domain(domain): + if domain is None: + return () + if isinstance(domain, DomainObject): + return (domain,) + + if not isinstance(domain, tuple): + domain = tuple(domain) + for d in domain: + if not isinstance(d, DomainObject): + raise TypeError( + "Given object contains something that is not an " + "instance of DomainObject class.") + return domain + + def __getitem__(self, i): + return self._dom[i] + + @property + def domains(self): + return self._dom + + @property + def shape(self): + return self._shape + + @property + def dim(self): + return self._dim + + @property + def axes(self): + return self._axtuple + + def __len__(self): + return len(self._dom) + + def __hash__(self): + return self._dom.__hash__() + + def __eq__(self, x): + if not isinstance(x, DomainTuple): + x = DomainTuple(x) + return self._dom == x._dom + + def __ne__(self, x): + if not isinstance(x, DomainTuple): + x = DomainTuple(x) + return self._dom != x._dom diff --git a/nifty2go/field.py b/nifty2go/field.py index 09f029f27ec26f6d6e5567da66d3629a82a431ef..e11352c22a5e1e686c5a52971ee9fdd583a85800 100644 --- a/nifty2go/field.py +++ b/nifty2go/field.py @@ -22,6 +22,7 @@ import numpy as np from .spaces.power_space import PowerSpace from . import nifty_utilities as utilities from .random import Random +from .domain_tuple import DomainTuple from functools import reduce @@ -53,10 +54,8 @@ class Field(object): ---------- val : numpy.ndarray - domain : DomainObject + domain : DomainTuple See Parameters. - domain_axes : tuple of tuples - Enumerates the axes of the Field dtype : type Contains the datatype stored in the Field. @@ -74,27 +73,21 @@ class Field(object): def __init__(self, domain=None, val=None, dtype=None, copy=False): self.domain = self._parse_domain(domain=domain, val=val) - self.domain_axes = self._get_axes_tuple(self.domain) - shape_tuple = tuple(sp.shape for sp in self.domain) - if len(shape_tuple) == 0: - global_shape = () - else: - global_shape = reduce(lambda x, y: x + y, shape_tuple) dtype = self._infer_dtype(dtype=dtype, val=val) if isinstance(val, Field): if self.domain != val.domain: raise ValueError("Domain mismatch") self._val = np.array(val.val, dtype=dtype, copy=copy) elif (np.isscalar(val)): - self._val = np.full(global_shape, dtype=dtype, fill_value=val) + self._val = np.full(self.domain.shape, dtype=dtype, fill_value=val) elif isinstance(val, np.ndarray): - if global_shape == val.shape: + if self.domain.shape == val.shape: self._val = np.array(val, dtype=dtype, copy=copy) else: raise ValueError("Shape mismatch") elif val is None: - self._val = np.empty(global_shape, dtype=dtype) + self._val = np.empty(self.domain.shape, dtype=dtype) else: raise TypeError("unknown source type") @@ -104,19 +97,9 @@ class Field(object): if isinstance(val, Field): return val.domain if np.isscalar(val): - return () # empty domain tuple + return DomainTuple(()) # empty domain tuple raise TypeError("could not infer domain from value") - return utilities.parse_domain(domain) - - @staticmethod - def _get_axes_tuple(things_with_shape): - i = 0 - axes_list = [None]*len(things_with_shape) - for idx, thing in enumerate(things_with_shape): - nax = len(thing.shape) - axes_list[idx] = tuple(range(i, i+nax)) - i += nax - return tuple(axes_list) + return DomainTuple.make(domain) # MR: this needs some rethinking ... do we need to have at least float64? @staticmethod @@ -155,9 +138,10 @@ class Field(object): power_synthesize """ + domain = DomainTuple.make(domain) generator_function = getattr(Random, random_type) return Field(domain=domain, val=generator_function(dtype=dtype, - shape=utilities.domains2shape(domain), **kwargs)) + shape=domain.shape, **kwargs)) # ---Powerspectral methods--- @@ -240,7 +224,7 @@ class Field(object): def _single_power_analyze(field, idx, binbounds): power_domain = PowerSpace(field.domain[idx], binbounds) pindex = power_domain.pindex - axes = field.domain_axes[idx] + axes = field.domain.axes[idx] new_pindex_shape = [1] * len(field.shape) for i, ax in enumerate(axes): new_pindex_shape[ax] = pindex.shape[i] @@ -275,7 +259,7 @@ class Field(object): local_blow_up = [slice(None)]*len(spec.shape) # it is important to count from behind, since spec potentially # grows with every iteration - index = self.domain_axes[i][0]-len(self.shape) + index = self.domain.axes[i][0]-len(self.shape) local_blow_up[index] = power_space.pindex # here, the power_spectrum is distributed into the new shape spec = spec[local_blow_up] @@ -380,7 +364,7 @@ class Field(object): The output object. The tuple contains the dimensions of the spaces in domain. """ - return self._val.shape + return self.domain.shape @property def dim(self): @@ -393,7 +377,7 @@ class Field(object): out : int The dimension of the Field. """ - return self._val.size + return self.domain.dim @property def total_volume(self): @@ -479,8 +463,8 @@ class Field(object): fct *= wgt else: new_shape = np.ones(len(self.shape), dtype=np.int) - new_shape[self.domain_axes[ind][0]: - self.domain_axes[ind][-1]+1] = wgt.shape + new_shape[self.domain.axes[ind][0]: + self.domain.axes[ind][-1]+1] = wgt.shape wgt = wgt.reshape(new_shape) new_field *= wgt**power fct = fct**power @@ -574,7 +558,7 @@ class Field(object): else: spaces = utilities.cast_iseq_to_tuple(spaces) - axes_list = tuple(self.domain_axes[sp_index] for sp_index in spaces) + axes_list = tuple(self.domain.axes[sp_index] for sp_index in spaces) if len(axes_list) > 0: axes_list = reduce(lambda x, y: x+y, axes_list) diff --git a/nifty2go/nifty_utilities.py b/nifty2go/nifty_utilities.py index 2b6ba0765ba9787efae50ab6065aeeb43bd6bf6d..b072de5d1d86a4f3e70131720c7d8c144fa95206 100644 --- a/nifty2go/nifty_utilities.py +++ b/nifty2go/nifty_utilities.py @@ -76,31 +76,6 @@ def cast_iseq_to_tuple(seq): return tuple(int(item) for item in seq) -def parse_domain(domain): - if domain is None: - return () - if isinstance(domain, DomainObject): - return (domain,) - - if not isinstance(domain, tuple): - domain = tuple(domain) - for d in domain: - if not isinstance(d, DomainObject): - raise TypeError( - "Given object contains something that is not an " - "instance of DomainObject-class.") - return domain - - -def domains2shape(domain): - domain = parse_domain(domain) - shape_tuple = tuple(sp.shape for sp in domain) - if len(shape_tuple) == 0: - return () - else: - return reduce(lambda x, y: x + y, shape_tuple) - - def bincount_axis(obj, minlength=None, weights=None, axis=None): if minlength is not None: length = max(np.amax(obj) + 1, minlength) diff --git a/nifty2go/operators/composed_operator/composed_operator.py b/nifty2go/operators/composed_operator/composed_operator.py index 98f97ae19cc456877954f77f464680149eed7766..ade9fd85c6d79a712a22fcf7bef25513f45f0e96 100644 --- a/nifty2go/operators/composed_operator/composed_operator.py +++ b/nifty2go/operators/composed_operator/composed_operator.py @@ -18,7 +18,7 @@ from builtins import range from ..linear_operator import LinearOperator - +from ... import DomainTuple class ComposedOperator(LinearOperator): """ NIFTY class for composed operators. @@ -97,17 +97,19 @@ class ComposedOperator(LinearOperator): @property def domain(self): if not hasattr(self, '_domain'): - self._domain = () + dom = () for op in self._operator_store: - self._domain += op.domain + dom += op.domain.domains + self._domain = DomainTuple.make(dom) return self._domain @property def target(self): if not hasattr(self, '_target'): - self._target = () + tgt = () for op in self._operator_store: - self._target += op.target + tgt += op.target.domains + self._target = DomainTuple.make(tgt) return self._target @property diff --git a/nifty2go/operators/diagonal_operator/diagonal_operator.py b/nifty2go/operators/diagonal_operator/diagonal_operator.py index 0ce6753289bd3a1d43da972ef6059eb9916701b2..34a34b80623b8ad124c4745e661352fca6ae07e9 100644 --- a/nifty2go/operators/diagonal_operator/diagonal_operator.py +++ b/nifty2go/operators/diagonal_operator/diagonal_operator.py @@ -21,6 +21,7 @@ from builtins import range import numpy as np from ...field import Field +from ...domain_tuple import DomainTuple from ..endomorphic_operator import EndomorphicOperator @@ -84,7 +85,7 @@ class DiagonalOperator(EndomorphicOperator): default_spaces=None): super(DiagonalOperator, self).__init__(default_spaces) - self._domain = self._parse_domain(domain) + self._domain = DomainTuple.make(domain) self._self_adjoint = None self._unitary = None @@ -214,7 +215,7 @@ class DiagonalOperator(EndomorphicOperator): else: active_axes = [] for space_index in spaces: - active_axes += x.domain_axes[space_index] + active_axes += x.domain.axes[space_index] reshaper = [x.shape[i] if i in active_axes else 1 for i in range(len(x.shape))] diff --git a/nifty2go/operators/fft_operator/fft_operator.py b/nifty2go/operators/fft_operator/fft_operator.py index 5b44657e17bd46c597aef38bac48c7893e5d144c..19f4023042785c64eb843573b8521801baf0d984 100644 --- a/nifty2go/operators/fft_operator/fft_operator.py +++ b/nifty2go/operators/fft_operator/fft_operator.py @@ -16,7 +16,7 @@ # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # and financially supported by the Studienstiftung des deutschen Volkes. -from ... import Field, nifty_utilities as utilities +from ... import Field, DomainTuple, nifty_utilities as utilities from ...spaces import RGSpace, GLSpace, HPSpace, LMSpace from ..linear_operator import LinearOperator @@ -96,14 +96,14 @@ class FFTOperator(LinearOperator): super(FFTOperator, self).__init__(default_spaces) # Initialize domain and target - self._domain = self._parse_domain(domain) + self._domain = DomainTuple.make(domain) if len(self.domain) != 1: raise ValueError("TransformationOperator accepts only exactly one " "space as input domain.") if target is None: target = (self.domain[0].get_default_codomain(), ) - self._target = self._parse_domain(target) + self._target = DomainTuple.make(target) if len(self.target) != 1: raise ValueError("TransformationOperator accepts only exactly one " "space as output target.") @@ -127,13 +127,13 @@ class FFTOperator(LinearOperator): # this case means that x lives on only one space, which is # identical to the space in the domain of `self`. Otherwise the # input check of LinearOperator would have failed. - axes = x.domain_axes[0] + axes = x.domain.axes[0] result_domain = other else: spaces = utilities.cast_iseq_to_tuple(spaces) result_domain = list(x.domain) result_domain[spaces[0]] = other[0] - axes = x.domain_axes[spaces[0]] + axes = x.domain.axes[spaces[0]] new_val, fct = trafo.transform(x.val, axes=axes) res = Field(result_domain, new_val, copy=False) diff --git a/nifty2go/operators/laplace_operator/laplace_operator.py b/nifty2go/operators/laplace_operator/laplace_operator.py index 943ae912f595cac09968d9b6c46c070f74135cf1..fe52bedeaa11602b55fe85665652f03a52c03c32 100644 --- a/nifty2go/operators/laplace_operator/laplace_operator.py +++ b/nifty2go/operators/laplace_operator/laplace_operator.py @@ -20,7 +20,7 @@ import numpy as np from ...field import Field from ...spaces.power_space import PowerSpace from ..endomorphic_operator import EndomorphicOperator -from ... import sqrt +from ... import sqrt, DomainTuple from ... import nifty_utilities as utilities @@ -41,7 +41,7 @@ class LaplaceOperator(EndomorphicOperator): def __init__(self, domain, default_spaces=None, logarithmic=True): super(LaplaceOperator, self).__init__(default_spaces) - self._domain = self._parse_domain(domain) + self._domain = DomainTuple.make(domain) if len(self.domain) != 1: raise ValueError("The domain must contain exactly one PowerSpace.") @@ -93,10 +93,10 @@ class LaplaceOperator(EndomorphicOperator): # this case means that x lives on only one space, which is # identical to the space in the domain of `self`. Otherwise the # input check of LinearOperator would have failed. - axes = x.domain_axes[0] + axes = x.domain.axes[0] else: spaces = utilities.cast_iseq_to_tuple(spaces) - axes = x.domain_axes[spaces[0]] + axes = x.domain.axes[spaces[0]] axis = axes[0] nval = len(self._dposc) prefix = (slice(None),) * axis @@ -119,10 +119,10 @@ class LaplaceOperator(EndomorphicOperator): # this case means that x lives on only one space, which is # identical to the space in the domain of `self`. Otherwise the # input check of LinearOperator would have failed. - axes = x.domain_axes[0] + axes = x.domain.axes[0] else: spaces = utilities.cast_iseq_to_tuple(spaces) - axes = x.domain_axes[spaces[0]] + axes = x.domain.axes[spaces[0]] axis = axes[0] nval = len(self._dposc) prefix = (slice(None),) * axis diff --git a/nifty2go/operators/linear_operator/linear_operator.py b/nifty2go/operators/linear_operator/linear_operator.py index 8477caf60b2ea8008b1b88aff5bfc13b7597e077..39b1487b2d8f8ead6ba18eff3c9166f109738395 100644 --- a/nifty2go/operators/linear_operator/linear_operator.py +++ b/nifty2go/operators/linear_operator/linear_operator.py @@ -69,10 +69,6 @@ class LinearOperator(with_metaclass( def __init__(self, default_spaces=None): self._default_spaces = default_spaces - @staticmethod - def _parse_domain(domain): - return utilities.parse_domain(domain) - @abc.abstractproperty def domain(self): """ diff --git a/nifty2go/operators/projection_operator/projection_operator.py b/nifty2go/operators/projection_operator/projection_operator.py index bb8e2049f0a6ba66e4d3784d7c1735698d88154e..377128fa0c4eb8794ee58823b8b91088743c6757 100644 --- a/nifty2go/operators/projection_operator/projection_operator.py +++ b/nifty2go/operators/projection_operator/projection_operator.py @@ -88,7 +88,7 @@ class ProjectionOperator(EndomorphicOperator): active_axes = list(range(len(x.shape))) else: for space_index in spaces: - active_axes += x.domain_axes[space_index] + active_axes += x.domain.axes[space_index] local_projection_vector = self._projection_field.val diff --git a/nifty2go/operators/response_operator/response_operator.py b/nifty2go/operators/response_operator/response_operator.py index 9ab5510fd0bacbd802bc4618cd4f5f11b17bfc58..f8e89e95872aa8694ec0699b223e1b734edc303e 100644 --- a/nifty2go/operators/response_operator/response_operator.py +++ b/nifty2go/operators/response_operator/response_operator.py @@ -1,6 +1,7 @@ from builtins import range from ... import Field,\ - FieldArray + FieldArray,\ + DomainTuple from ..linear_operator import LinearOperator from ..smoothing_operator import FFTSmoothingOperator from ..composed_operator import ComposedOperator @@ -54,7 +55,7 @@ class ResponseOperator(LinearOperator): "exposure do not match") nsigma = len(sigma) - self._domain = self._parse_domain(domain) + self._domain = DomainTuple.make(domain) kernel_smoothing = [FFTSmoothingOperator(self._domain[x], sigma[x]) for x in range(nsigma)] @@ -65,7 +66,7 @@ class ResponseOperator(LinearOperator): self._composed_exposure = ComposedOperator(kernel_exposure) target_list = [FieldArray(x.shape) for x in self.domain] - self._target = self._parse_domain(target_list) + self._target = DomainTuple.make(target_list) @property def domain(self): diff --git a/nifty2go/operators/smoothing_operator/direct_smoothing_operator.py b/nifty2go/operators/smoothing_operator/direct_smoothing_operator.py index 6e3affe29c5a113d0c7fc01c8fe1dbde628dc4fc..0efa637a4af9031bf1a6afc0946583e3f8881884 100644 --- a/nifty2go/operators/smoothing_operator/direct_smoothing_operator.py +++ b/nifty2go/operators/smoothing_operator/direct_smoothing_operator.py @@ -6,7 +6,7 @@ import numpy as np from ..endomorphic_operator import EndomorphicOperator from ... import nifty_utilities as utilities -from ... import Field +from ... import Field, DomainTuple class DirectSmoothingOperator(EndomorphicOperator): @@ -14,7 +14,7 @@ class DirectSmoothingOperator(EndomorphicOperator): default_spaces=None): super(DirectSmoothingOperator, self).__init__(default_spaces) - self._domain = self._parse_domain(domain) + self._domain = DomainTuple.make(domain) if len(self._domain) != 1: raise ValueError("DirectSmoothingOperator only accepts exactly one" " space as input domain.") @@ -93,7 +93,7 @@ class DirectSmoothingOperator(EndomorphicOperator): def _smooth(self, x, spaces): # infer affected axes # we rely on the knowledge that `spaces` is a tuple with length 1. - affected_axes = x.domain_axes[spaces[0]] + affected_axes = x.domain.axes[spaces[0]] if len(affected_axes) != 1: raise ValueError("By this implementation only one-dimensional " "spaces can be smoothed directly.") diff --git a/nifty2go/operators/smoothing_operator/fft_smoothing_operator.py b/nifty2go/operators/smoothing_operator/fft_smoothing_operator.py index d9f66bdf8c3e0800e5d02eb8f1e5e9a3d1a815e9..6a257ac834ed31f08e57c955455c50f34e9d9e15 100644 --- a/nifty2go/operators/smoothing_operator/fft_smoothing_operator.py +++ b/nifty2go/operators/smoothing_operator/fft_smoothing_operator.py @@ -5,7 +5,7 @@ import numpy as np from ..endomorphic_operator import EndomorphicOperator from ..fft_operator import FFTOperator - +from ... import DomainTuple class FFTSmoothingOperator(EndomorphicOperator): @@ -13,7 +13,7 @@ class FFTSmoothingOperator(EndomorphicOperator): default_spaces=None): super(FFTSmoothingOperator, self).__init__(default_spaces) - self._domain = self._parse_domain(domain) + self._domain = DomainTuple.make(domain) if len(self._domain) != 1: raise ValueError("SmoothingOperator only accepts exactly one " "space as input domain.") @@ -57,7 +57,7 @@ class FFTSmoothingOperator(EndomorphicOperator): # transform to the (global-)default codomain and perform all remaining # steps therein transformed_x = self._transformator(x, spaces=spaces) - coaxes = transformed_x.domain_axes[spaces[0]] + coaxes = transformed_x.domain.axes[spaces[0]] # now, apply the kernel to transformed_x # this is done node-locally utilizing numpy's reshaping in order to diff --git a/nifty2go/operators/smoothness_operator/smoothness_operator.py b/nifty2go/operators/smoothness_operator/smoothness_operator.py index 0067ede9ac08b35e0b1c89b3ee76f4be8e7a3565..cc27f9153aa3b45253f2156ce7bc36b80593c65c 100644 --- a/nifty2go/operators/smoothness_operator/smoothness_operator.py +++ b/nifty2go/operators/smoothness_operator/smoothness_operator.py @@ -1,7 +1,7 @@ from ...spaces.power_space import PowerSpace from ..endomorphic_operator import EndomorphicOperator from ..laplace_operator import LaplaceOperator -from ... import Field +from ... import Field, DomainTuple class SmoothnessOperator(EndomorphicOperator): @@ -34,7 +34,7 @@ class SmoothnessOperator(EndomorphicOperator): super(SmoothnessOperator, self).__init__(default_spaces=default_spaces) - self._domain = self._parse_domain(domain) + self._domain = DomainTuple.make(domain) if len(self.domain) != 1: raise ValueError("The domain must contain exactly one PowerSpace.") diff --git a/nifty2go/probing/prober/prober.py b/nifty2go/probing/prober/prober.py index 8b2754217f07c8620c8f7dc1870e71f65c381886..0aa504f8207c745c96c27e5202848808714a7c1e 100644 --- a/nifty2go/probing/prober/prober.py +++ b/nifty2go/probing/prober/prober.py @@ -21,7 +21,7 @@ from builtins import range from builtins import object import numpy as np -from ...field import Field +from ...field import Field, DomainTuple from ... import nifty_utilities as utilities @@ -39,7 +39,7 @@ class Prober(object): random_type='pm1', probe_dtype=np.float, compute_variance=False, ncpu=1): - self._domain = utilities.parse_domain(domain) + self._domain = DomainTuple.make(domain) self._probe_count = self._parse_probe_count(probe_count) self._ncpu = self._parse_probe_count(ncpu) self._random_type = self._parse_random_type(random_type) diff --git a/test/test_field.py b/test/test_field.py index 4b37a02c4e6a6e03224f5fab601cb450362e01b1..48bf15b38bb172d04a960fb4d75c67584a9c265f 100644 --- a/test/test_field.py +++ b/test/test_field.py @@ -29,7 +29,8 @@ from itertools import product from nifty2go import Field,\ RGSpace,\ LMSpace,\ - PowerSpace + PowerSpace,\ + DomainTuple from test.common import expand @@ -40,11 +41,10 @@ SPACE_COMBINATIONS = [(), SPACES[0], SPACES[1], SPACES] class Test_Interface(unittest.TestCase): @expand(product(SPACE_COMBINATIONS, - [['domain', tuple], - ['domain_axes', tuple], + [['domain', DomainTuple], ['val', np.ndarray], ['shape', tuple], - ['dim', np.int], + ['dim', (np.int, np.int64)], ['total_volume', np.float]])) def test_return_types(self, domain, attribute_desired_type): attribute = attribute_desired_type[0]