Commit cfbd598e authored by Martin Reinecke's avatar Martin Reinecke
Browse files

more DomainTuple work

parent 75145bf8
...@@ -22,6 +22,8 @@ from .version import __version__ ...@@ -22,6 +22,8 @@ from .version import __version__
from .field import Field from .field import Field
from .domain_tuple import DomainTuple
from .random import Random from .random import Random
from .basic_arithmetics import * from .basic_arithmetics import *
......
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2017 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from functools import reduce
from .domain_object import DomainObject
class DomainTuple(object):
def __init__(self, domain):
self._dom = self._parse_domain(domain)
self._axtuple = self._get_axes_tuple()
shape_tuple = tuple(sp.shape for sp in self._dom)
self._shape = reduce(lambda x, y: x + y, shape_tuple, ())
self._dim = reduce(lambda x, y: x * y, self._shape, 1)
def _get_axes_tuple(self):
i = 0
res = [None]*len(self._dom)
for idx, thing in enumerate(self._dom):
nax = len(thing.shape)
res[idx] = tuple(range(i, i+nax))
i += nax
return res
@staticmethod
def make(domain):
if isinstance(domain, DomainTuple):
return domain
return DomainTuple(domain)
@staticmethod
def _parse_domain(domain):
if domain is None:
return ()
if isinstance(domain, DomainObject):
return (domain,)
if not isinstance(domain, tuple):
domain = tuple(domain)
for d in domain:
if not isinstance(d, DomainObject):
raise TypeError(
"Given object contains something that is not an "
"instance of DomainObject class.")
return domain
def __getitem__(self, i):
return self._dom[i]
@property
def domains(self):
return self._dom
@property
def shape(self):
return self._shape
@property
def dim(self):
return self._dim
@property
def axes(self):
return self._axtuple
def __len__(self):
return len(self._dom)
def __hash__(self):
return self._dom.__hash__()
def __eq__(self, x):
if not isinstance(x, DomainTuple):
x = DomainTuple(x)
return self._dom == x._dom
def __ne__(self, x):
if not isinstance(x, DomainTuple):
x = DomainTuple(x)
return self._dom != x._dom
...@@ -22,6 +22,7 @@ import numpy as np ...@@ -22,6 +22,7 @@ import numpy as np
from .spaces.power_space import PowerSpace from .spaces.power_space import PowerSpace
from . import nifty_utilities as utilities from . import nifty_utilities as utilities
from .random import Random from .random import Random
from .domain_tuple import DomainTuple
from functools import reduce from functools import reduce
...@@ -53,10 +54,8 @@ class Field(object): ...@@ -53,10 +54,8 @@ class Field(object):
---------- ----------
val : numpy.ndarray val : numpy.ndarray
domain : DomainObject domain : DomainTuple
See Parameters. See Parameters.
domain_axes : tuple of tuples
Enumerates the axes of the Field
dtype : type dtype : type
Contains the datatype stored in the Field. Contains the datatype stored in the Field.
...@@ -74,27 +73,21 @@ class Field(object): ...@@ -74,27 +73,21 @@ class Field(object):
def __init__(self, domain=None, val=None, dtype=None, copy=False): def __init__(self, domain=None, val=None, dtype=None, copy=False):
self.domain = self._parse_domain(domain=domain, val=val) self.domain = self._parse_domain(domain=domain, val=val)
self.domain_axes = self._get_axes_tuple(self.domain)
shape_tuple = tuple(sp.shape for sp in self.domain)
if len(shape_tuple) == 0:
global_shape = ()
else:
global_shape = reduce(lambda x, y: x + y, shape_tuple)
dtype = self._infer_dtype(dtype=dtype, val=val) dtype = self._infer_dtype(dtype=dtype, val=val)
if isinstance(val, Field): if isinstance(val, Field):
if self.domain != val.domain: if self.domain != val.domain:
raise ValueError("Domain mismatch") raise ValueError("Domain mismatch")
self._val = np.array(val.val, dtype=dtype, copy=copy) self._val = np.array(val.val, dtype=dtype, copy=copy)
elif (np.isscalar(val)): elif (np.isscalar(val)):
self._val = np.full(global_shape, dtype=dtype, fill_value=val) self._val = np.full(self.domain.shape, dtype=dtype, fill_value=val)
elif isinstance(val, np.ndarray): elif isinstance(val, np.ndarray):
if global_shape == val.shape: if self.domain.shape == val.shape:
self._val = np.array(val, dtype=dtype, copy=copy) self._val = np.array(val, dtype=dtype, copy=copy)
else: else:
raise ValueError("Shape mismatch") raise ValueError("Shape mismatch")
elif val is None: elif val is None:
self._val = np.empty(global_shape, dtype=dtype) self._val = np.empty(self.domain.shape, dtype=dtype)
else: else:
raise TypeError("unknown source type") raise TypeError("unknown source type")
...@@ -104,19 +97,9 @@ class Field(object): ...@@ -104,19 +97,9 @@ class Field(object):
if isinstance(val, Field): if isinstance(val, Field):
return val.domain return val.domain
if np.isscalar(val): if np.isscalar(val):
return () # empty domain tuple return DomainTuple(()) # empty domain tuple
raise TypeError("could not infer domain from value") raise TypeError("could not infer domain from value")
return utilities.parse_domain(domain) return DomainTuple.make(domain)
@staticmethod
def _get_axes_tuple(things_with_shape):
i = 0
axes_list = [None]*len(things_with_shape)
for idx, thing in enumerate(things_with_shape):
nax = len(thing.shape)
axes_list[idx] = tuple(range(i, i+nax))
i += nax
return tuple(axes_list)
# MR: this needs some rethinking ... do we need to have at least float64? # MR: this needs some rethinking ... do we need to have at least float64?
@staticmethod @staticmethod
...@@ -155,9 +138,10 @@ class Field(object): ...@@ -155,9 +138,10 @@ class Field(object):
power_synthesize power_synthesize
""" """
domain = DomainTuple.make(domain)
generator_function = getattr(Random, random_type) generator_function = getattr(Random, random_type)
return Field(domain=domain, val=generator_function(dtype=dtype, return Field(domain=domain, val=generator_function(dtype=dtype,
shape=utilities.domains2shape(domain), **kwargs)) shape=domain.shape, **kwargs))
# ---Powerspectral methods--- # ---Powerspectral methods---
...@@ -240,7 +224,7 @@ class Field(object): ...@@ -240,7 +224,7 @@ class Field(object):
def _single_power_analyze(field, idx, binbounds): def _single_power_analyze(field, idx, binbounds):
power_domain = PowerSpace(field.domain[idx], binbounds) power_domain = PowerSpace(field.domain[idx], binbounds)
pindex = power_domain.pindex pindex = power_domain.pindex
axes = field.domain_axes[idx] axes = field.domain.axes[idx]
new_pindex_shape = [1] * len(field.shape) new_pindex_shape = [1] * len(field.shape)
for i, ax in enumerate(axes): for i, ax in enumerate(axes):
new_pindex_shape[ax] = pindex.shape[i] new_pindex_shape[ax] = pindex.shape[i]
...@@ -275,7 +259,7 @@ class Field(object): ...@@ -275,7 +259,7 @@ class Field(object):
local_blow_up = [slice(None)]*len(spec.shape) local_blow_up = [slice(None)]*len(spec.shape)
# it is important to count from behind, since spec potentially # it is important to count from behind, since spec potentially
# grows with every iteration # grows with every iteration
index = self.domain_axes[i][0]-len(self.shape) index = self.domain.axes[i][0]-len(self.shape)
local_blow_up[index] = power_space.pindex local_blow_up[index] = power_space.pindex
# here, the power_spectrum is distributed into the new shape # here, the power_spectrum is distributed into the new shape
spec = spec[local_blow_up] spec = spec[local_blow_up]
...@@ -380,7 +364,7 @@ class Field(object): ...@@ -380,7 +364,7 @@ class Field(object):
The output object. The tuple contains the dimensions of the spaces The output object. The tuple contains the dimensions of the spaces
in domain. in domain.
""" """
return self._val.shape return self.domain.shape
@property @property
def dim(self): def dim(self):
...@@ -393,7 +377,7 @@ class Field(object): ...@@ -393,7 +377,7 @@ class Field(object):
out : int out : int
The dimension of the Field. The dimension of the Field.
""" """
return self._val.size return self.domain.dim
@property @property
def total_volume(self): def total_volume(self):
...@@ -479,8 +463,8 @@ class Field(object): ...@@ -479,8 +463,8 @@ class Field(object):
fct *= wgt fct *= wgt
else: else:
new_shape = np.ones(len(self.shape), dtype=np.int) new_shape = np.ones(len(self.shape), dtype=np.int)
new_shape[self.domain_axes[ind][0]: new_shape[self.domain.axes[ind][0]:
self.domain_axes[ind][-1]+1] = wgt.shape self.domain.axes[ind][-1]+1] = wgt.shape
wgt = wgt.reshape(new_shape) wgt = wgt.reshape(new_shape)
new_field *= wgt**power new_field *= wgt**power
fct = fct**power fct = fct**power
...@@ -574,7 +558,7 @@ class Field(object): ...@@ -574,7 +558,7 @@ class Field(object):
else: else:
spaces = utilities.cast_iseq_to_tuple(spaces) spaces = utilities.cast_iseq_to_tuple(spaces)
axes_list = tuple(self.domain_axes[sp_index] for sp_index in spaces) axes_list = tuple(self.domain.axes[sp_index] for sp_index in spaces)
if len(axes_list) > 0: if len(axes_list) > 0:
axes_list = reduce(lambda x, y: x+y, axes_list) axes_list = reduce(lambda x, y: x+y, axes_list)
......
...@@ -76,31 +76,6 @@ def cast_iseq_to_tuple(seq): ...@@ -76,31 +76,6 @@ def cast_iseq_to_tuple(seq):
return tuple(int(item) for item in seq) return tuple(int(item) for item in seq)
def parse_domain(domain):
if domain is None:
return ()
if isinstance(domain, DomainObject):
return (domain,)
if not isinstance(domain, tuple):
domain = tuple(domain)
for d in domain:
if not isinstance(d, DomainObject):
raise TypeError(
"Given object contains something that is not an "
"instance of DomainObject-class.")
return domain
def domains2shape(domain):
domain = parse_domain(domain)
shape_tuple = tuple(sp.shape for sp in domain)
if len(shape_tuple) == 0:
return ()
else:
return reduce(lambda x, y: x + y, shape_tuple)
def bincount_axis(obj, minlength=None, weights=None, axis=None): def bincount_axis(obj, minlength=None, weights=None, axis=None):
if minlength is not None: if minlength is not None:
length = max(np.amax(obj) + 1, minlength) length = max(np.amax(obj) + 1, minlength)
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
from builtins import range from builtins import range
from ..linear_operator import LinearOperator from ..linear_operator import LinearOperator
from ... import DomainTuple
class ComposedOperator(LinearOperator): class ComposedOperator(LinearOperator):
""" NIFTY class for composed operators. """ NIFTY class for composed operators.
...@@ -97,17 +97,19 @@ class ComposedOperator(LinearOperator): ...@@ -97,17 +97,19 @@ class ComposedOperator(LinearOperator):
@property @property
def domain(self): def domain(self):
if not hasattr(self, '_domain'): if not hasattr(self, '_domain'):
self._domain = () dom = ()
for op in self._operator_store: for op in self._operator_store:
self._domain += op.domain dom += op.domain.domains
self._domain = DomainTuple.make(dom)
return self._domain return self._domain
@property @property
def target(self): def target(self):
if not hasattr(self, '_target'): if not hasattr(self, '_target'):
self._target = () tgt = ()
for op in self._operator_store: for op in self._operator_store:
self._target += op.target tgt += op.target.domains
self._target = DomainTuple.make(tgt)
return self._target return self._target
@property @property
......
...@@ -21,6 +21,7 @@ from builtins import range ...@@ -21,6 +21,7 @@ from builtins import range
import numpy as np import numpy as np
from ...field import Field from ...field import Field
from ...domain_tuple import DomainTuple
from ..endomorphic_operator import EndomorphicOperator from ..endomorphic_operator import EndomorphicOperator
...@@ -84,7 +85,7 @@ class DiagonalOperator(EndomorphicOperator): ...@@ -84,7 +85,7 @@ class DiagonalOperator(EndomorphicOperator):
default_spaces=None): default_spaces=None):
super(DiagonalOperator, self).__init__(default_spaces) super(DiagonalOperator, self).__init__(default_spaces)
self._domain = self._parse_domain(domain) self._domain = DomainTuple.make(domain)
self._self_adjoint = None self._self_adjoint = None
self._unitary = None self._unitary = None
...@@ -214,7 +215,7 @@ class DiagonalOperator(EndomorphicOperator): ...@@ -214,7 +215,7 @@ class DiagonalOperator(EndomorphicOperator):
else: else:
active_axes = [] active_axes = []
for space_index in spaces: for space_index in spaces:
active_axes += x.domain_axes[space_index] active_axes += x.domain.axes[space_index]
reshaper = [x.shape[i] if i in active_axes else 1 reshaper = [x.shape[i] if i in active_axes else 1
for i in range(len(x.shape))] for i in range(len(x.shape))]
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes. # and financially supported by the Studienstiftung des deutschen Volkes.
from ... import Field, nifty_utilities as utilities from ... import Field, DomainTuple, nifty_utilities as utilities
from ...spaces import RGSpace, GLSpace, HPSpace, LMSpace from ...spaces import RGSpace, GLSpace, HPSpace, LMSpace
from ..linear_operator import LinearOperator from ..linear_operator import LinearOperator
...@@ -96,14 +96,14 @@ class FFTOperator(LinearOperator): ...@@ -96,14 +96,14 @@ class FFTOperator(LinearOperator):
super(FFTOperator, self).__init__(default_spaces) super(FFTOperator, self).__init__(default_spaces)
# Initialize domain and target # Initialize domain and target
self._domain = self._parse_domain(domain) self._domain = DomainTuple.make(domain)
if len(self.domain) != 1: if len(self.domain) != 1:
raise ValueError("TransformationOperator accepts only exactly one " raise ValueError("TransformationOperator accepts only exactly one "
"space as input domain.") "space as input domain.")
if target is None: if target is None:
target = (self.domain[0].get_default_codomain(), ) target = (self.domain[0].get_default_codomain(), )
self._target = self._parse_domain(target) self._target = DomainTuple.make(target)
if len(self.target) != 1: if len(self.target) != 1:
raise ValueError("TransformationOperator accepts only exactly one " raise ValueError("TransformationOperator accepts only exactly one "
"space as output target.") "space as output target.")
...@@ -127,13 +127,13 @@ class FFTOperator(LinearOperator): ...@@ -127,13 +127,13 @@ class FFTOperator(LinearOperator):
# this case means that x lives on only one space, which is # this case means that x lives on only one space, which is
# identical to the space in the domain of `self`. Otherwise the # identical to the space in the domain of `self`. Otherwise the
# input check of LinearOperator would have failed. # input check of LinearOperator would have failed.
axes = x.domain_axes[0] axes = x.domain.axes[0]
result_domain = other result_domain = other
else: else:
spaces = utilities.cast_iseq_to_tuple(spaces) spaces = utilities.cast_iseq_to_tuple(spaces)
result_domain = list(x.domain) result_domain = list(x.domain)
result_domain[spaces[0]] = other[0] result_domain[spaces[0]] = other[0]
axes = x.domain_axes[spaces[0]] axes = x.domain.axes[spaces[0]]
new_val, fct = trafo.transform(x.val, axes=axes) new_val, fct = trafo.transform(x.val, axes=axes)
res = Field(result_domain, new_val, copy=False) res = Field(result_domain, new_val, copy=False)
......
...@@ -20,7 +20,7 @@ import numpy as np ...@@ -20,7 +20,7 @@ import numpy as np
from ...field import Field from ...field import Field
from ...spaces.power_space import PowerSpace from ...spaces.power_space import PowerSpace
from ..endomorphic_operator import EndomorphicOperator from ..endomorphic_operator import EndomorphicOperator
from ... import sqrt from ... import sqrt, DomainTuple
from ... import nifty_utilities as utilities from ... import nifty_utilities as utilities
...@@ -41,7 +41,7 @@ class LaplaceOperator(EndomorphicOperator): ...@@ -41,7 +41,7 @@ class LaplaceOperator(EndomorphicOperator):
def __init__(self, domain, default_spaces=None, logarithmic=True): def __init__(self, domain, default_spaces=None, logarithmic=True):
super(LaplaceOperator, self).__init__(default_spaces) super(LaplaceOperator, self).__init__(default_spaces)
self._domain = self._parse_domain(domain) self._domain = DomainTuple.make(domain)
if len(self.domain) != 1: if len(self.domain) != 1:
raise ValueError("The domain must contain exactly one PowerSpace.") raise ValueError("The domain must contain exactly one PowerSpace.")
...@@ -93,10 +93,10 @@ class LaplaceOperator(EndomorphicOperator): ...@@ -93,10 +93,10 @@ class LaplaceOperator(EndomorphicOperator):
# this case means that x lives on only one space, which is # this case means that x lives on only one space, which is
# identical to the space in the domain of `self`. Otherwise the # identical to the space in the domain of `self`. Otherwise the
# input check of LinearOperator would have failed. # input check of LinearOperator would have failed.
axes = x.domain_axes[0] axes = x.domain.axes[0]
else: else:
spaces = utilities.cast_iseq_to_tuple(spaces) spaces = utilities.cast_iseq_to_tuple(spaces)
axes = x.domain_axes[spaces[0]] axes = x.domain.axes[spaces[0]]
axis = axes[0] axis = axes[0]
nval = len(self._dposc) nval = len(self._dposc)
prefix = (slice(None),) * axis prefix = (slice(None),) * axis
...@@ -119,10 +119,10 @@ class LaplaceOperator(EndomorphicOperator): ...@@ -119,10 +119,10 @@ class LaplaceOperator(EndomorphicOperator):
# this case means that x lives on only one space, which is # this case means that x lives on only one space, which is
# identical to the space in the domain of `self`. Otherwise the # identical to the space in the domain of `self`. Otherwise the
# input check of LinearOperator would have failed. # input check of LinearOperator would have failed.
axes = x.domain_axes[0] axes = x.domain.axes[0]
else: else:
spaces = utilities.cast_iseq_to_tuple(spaces) spaces = utilities.cast_iseq_to_tuple(spaces)
axes = x.domain_axes[spaces[0]] axes = x.domain.axes[spaces[0]]
axis = axes[0] axis = axes[0]
nval = len(self._dposc) nval = len(self._dposc)
prefix = (slice(None),) * axis prefix = (slice(None),) * axis
......
...@@ -69,10 +69,6 @@ class LinearOperator(with_metaclass( ...@@ -69,10 +69,6 @@ class LinearOperator(with_metaclass(
def __init__(self, default_spaces=None): def __init__(self, default_spaces=None):
self._default_spaces = default_spaces self._default_spaces = default_spaces
@staticmethod
def _parse_domain(domain):
return utilities.parse_domain(domain)
@abc.abstractproperty @abc.abstractproperty
def domain(self): def domain(self):
""" """
......
...@@ -88,7 +88,7 @@ class ProjectionOperator(EndomorphicOperator): ...@@ -88,7 +88,7 @@ class ProjectionOperator(EndomorphicOperator):
active_axes = list(range(len(x.shape))) active_axes = list(range(len(x.shape)))
else: else:
for space_index in spaces: for space_index in spaces:
active_axes += x.domain_axes[space_index] active_axes += x.domain.axes[space_index]
local_projection_vector = self._projection_field.val local_projection_vector = self._projection_field.val
......