diff --git a/nifty/__init__.py b/nifty/__init__.py index 8595f2c0ca32a18bf9fc24dbdb53b7f468884915..e122f454e0b8ae99a4dcab4aa48d3a5e5df67e3c 100644 --- a/nifty/__init__.py +++ b/nifty/__init__.py @@ -10,7 +10,7 @@ from .domain_object import DomainObject from .basic_arithmetics import * -from .nifty_utilities import * +from .utilities import * from .field_types import * @@ -31,5 +31,3 @@ from . import plotting from . import library from . import dobj - -from .memoization import memo diff --git a/nifty/domain_object.py b/nifty/domain_object.py index 22193c51b6bdf1913054213240fc478398b66734..301d4599ad548bbd097b22824412907e96f4897e 100644 --- a/nifty/domain_object.py +++ b/nifty/domain_object.py @@ -18,7 +18,7 @@ from __future__ import division import abc -from .nifty_meta import NiftyMeta +from .utilities import NiftyMeta from future.utils import with_metaclass @@ -38,7 +38,6 @@ class DomainObject(with_metaclass( raise NotImplementedError def __hash__(self): - # Extract the identifying parts from the vars(self) dict. result_hash = 0 for key in self._needed_for_hash: result_hash ^= hash(vars(self)[key]) diff --git a/nifty/domain_tuple.py b/nifty/domain_tuple.py index 87e5abb569cc56dd7026c8bc2aa800097e508011..df52725ffcf611cb01cc818c55c76e4f1ca46629 100644 --- a/nifty/domain_tuple.py +++ b/nifty/domain_tuple.py @@ -105,24 +105,10 @@ class DomainTuple(object): return self._dom == x._dom def __ne__(self, x): - if not isinstance(x, DomainTuple): - x = DomainTuple.make(x) - if self is x: - return False - return self._dom != x._dom + return not self.__eq__(x) def __str__(self): res = "DomainTuple, len: " + str(len(self.domains)) for i in self.domains: res += "\n" + str(i) return res - - def collapsed_shape_for_domain(self, ispace): - """Returns a three-component shape, with the total number of pixels - in the domains before the requested space in res[0], the total number - of pixels in the requested space in res[1], and the remaining pixels in - res[2]. - """ - return (self._accdims[ispace], - self._accdims[ispace+1]//self._accdims[ispace], - self._accdims[-1]//self._accdims[ispace+1]) diff --git a/nifty/energies/energy.py b/nifty/energies/energy.py index c6df68e1169dd869f572bcb69173627ad1849be5..d53e3c1ec9297dc00a0c91d7f6791d2836a8051c 100644 --- a/nifty/energies/energy.py +++ b/nifty/energies/energy.py @@ -16,8 +16,7 @@ # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # and financially supported by the Studienstiftung des deutschen Volkes. -from ..nifty_meta import NiftyMeta -from ..memoization import memo +from ..utilities import memo, NiftyMeta from future.utils import with_metaclass diff --git a/nifty/energies/quadratic_energy.py b/nifty/energies/quadratic_energy.py index f50e4aae4412ae2039b14c2e7759234929d534cf..b87d6bc568d8b19a7cc14cadde8213b4b47918cc 100644 --- a/nifty/energies/quadratic_energy.py +++ b/nifty/energies/quadratic_energy.py @@ -1,5 +1,5 @@ from .energy import Energy -from ..memoization import memo +from ..utilities import memo class QuadraticEnergy(Energy): diff --git a/nifty/field.py b/nifty/field.py index fa6481a19d084918c19bc1c2ff7a90e7f731fa5d..99861349d7b199e3d40dd75aa5de3d8de52ac584 100644 --- a/nifty/field.py +++ b/nifty/field.py @@ -19,7 +19,7 @@ from __future__ import division from builtins import range import numpy as np -from . import nifty_utilities as utilities +from . import utilities from .domain_tuple import DomainTuple from functools import reduce from . import dobj @@ -65,7 +65,6 @@ class Field(object): *the given domain contains something that is not a DomainObject instance *val is an array that has a different dimension than the domain - """ # ---Initialization methods--- @@ -179,7 +178,6 @@ class Field(object): out : Field The output object. """ - domain = DomainTuple.make(domain) return Field(domain=domain, val=dobj.from_random(random_type, dtype=dtype, @@ -194,10 +192,6 @@ class Field(object): def val(self): """ Returns the data object associated with this Field. No copy is made. - - Returns - ------- - out : numpy.ndarray """ return self._val @@ -211,10 +205,8 @@ class Field(object): Returns ------- - out : tuple - The output object. The tuple contains the dimensions of the spaces - in domain. - """ + Integer tuple containing the dimensions of the spaces in domain. + """ return self.domain.shape @property @@ -232,14 +224,12 @@ class Field(object): @property def real(self): - """ The real part of the field (data is not copied). - """ + """ The real part of the field (data is not copied).""" return Field(self.domain, self.val.real) @property def imag(self): - """ The imaginary part of the field (data is not copied). - """ + """ The imaginary part of the field (data is not copied).""" return Field(self.domain, self.val.imag) # ---Special unary/binary operations--- @@ -290,7 +280,6 @@ class Field(object): ------- out : Field The weighted field. - """ if out is None: out = self.copy() @@ -313,7 +302,8 @@ class Field(object): new_shape[self.domain.axes[ind][0]: self.domain.axes[ind][-1]+1] = wgt.shape wgt = wgt.reshape(new_shape) - if dobj.distaxis(self._val) >= 0 and ind == 0: # we need to distribute the weights along axis 0 + if dobj.distaxis(self._val) >= 0 and ind == 0: + # we need to distribute the weights along axis 0 wgt = dobj.local_data(dobj.from_global_data(wgt)) out *= wgt**power fct = fct**power @@ -336,8 +326,7 @@ class Field(object): Returns ------- - out : float, complex - + out : float, complex, either scalar or Field """ if not isinstance(x, Field): raise ValueError("The dot-partner must be an instance of " + @@ -354,15 +343,19 @@ class Field(object): if spaces is None: return fct*dobj.vdot(y.val, x.val) - else: - spaces = utilities.cast_iseq_to_tuple(spaces) - active_axes = [] - for i in spaces: - active_axes += self.domain.axes[i] - res = 0. - for sl in utilities.get_slice_list(self.shape, active_axes): - res += dobj.vdot(y.val, x.val[sl]) - return res*fct + + spaces = utilities.cast_iseq_to_tuple(spaces) + if spaces == tuple(range(len(self.domain))): # full contraction + return fct*dobj.vdot(y.val, x.val) + + raise NotImplementedError("special case for vdot not yet implemented") + active_axes = [] + for i in spaces: + active_axes += self.domain.axes[i] + res = 0. + for sl in utilities.get_slice_list(self.shape, active_axes): + res += dobj.vdot(y.val, x.val[sl]) + return res*fct def norm(self): """ Computes the L2-norm of the field values. @@ -371,7 +364,6 @@ class Field(object): ------- norm : float The L2-norm of the field values. - """ return np.sqrt(np.abs(self.vdot(x=self))) @@ -380,9 +372,7 @@ class Field(object): Returns ------- - cc : field - The complex conjugated field. - + The complex conjugated field. """ return Field(self.domain, self.val.conjugate(), self.dtype) diff --git a/nifty/library/critical_filter/critical_power_energy.py b/nifty/library/critical_filter/critical_power_energy.py index 0c7af23a579b6b0c5834339b2148f9fa3f2ee8a2..30d73826f70f4a36bfbd880d345fb5ec735e10e4 100644 --- a/nifty/library/critical_filter/critical_power_energy.py +++ b/nifty/library/critical_filter/critical_power_energy.py @@ -3,7 +3,7 @@ from ...operators.smoothness_operator import SmoothnessOperator from ...operators.power_projection_operator import PowerProjectionOperator from ...operators.inversion_enabler import InversionEnabler from . import CriticalPowerCurvature -from ...memoization import memo +from ...utilities import memo from ... import Field, exp from ...sugar import generate_posterior_sample diff --git a/nifty/library/log_normal_wiener_filter/log_normal_wiener_filter_curvature.py b/nifty/library/log_normal_wiener_filter/log_normal_wiener_filter_curvature.py index 74309c4f68116fc45fabe4b223abc914ac62f614..add9cc58ff255271e8fa87e066cb7126e8370918 100644 --- a/nifty/library/log_normal_wiener_filter/log_normal_wiener_filter_curvature.py +++ b/nifty/library/log_normal_wiener_filter/log_normal_wiener_filter_curvature.py @@ -1,5 +1,5 @@ from ...operators import EndomorphicOperator -from ...memoization import memo +from ...utilities import memo from ...basic_arithmetics import exp from ...sugar import create_composed_fft_operator diff --git a/nifty/library/log_normal_wiener_filter/log_normal_wiener_filter_energy.py b/nifty/library/log_normal_wiener_filter/log_normal_wiener_filter_energy.py index 56f411841d4696a99c03b8e9f7a179f1e3c7cc09..f3a7a59dc49a7cc2e717464081a9f0d2eb6960da 100644 --- a/nifty/library/log_normal_wiener_filter/log_normal_wiener_filter_energy.py +++ b/nifty/library/log_normal_wiener_filter/log_normal_wiener_filter_energy.py @@ -1,5 +1,5 @@ from ...energies.energy import Energy -from ...memoization import memo +from ...utilities import memo from . import LogNormalWienerFilterCurvature from ...sugar import create_composed_fft_operator from ...operators.inversion_enabler import InversionEnabler diff --git a/nifty/library/wiener_filter/wiener_filter_energy.py b/nifty/library/wiener_filter/wiener_filter_energy.py index a6210a88ca6cb659fde1e236a1e5885dd28ec945..51e33c2db8bad0eea8aa9fdcd9f06e32e76cccc0 100644 --- a/nifty/library/wiener_filter/wiener_filter_energy.py +++ b/nifty/library/wiener_filter/wiener_filter_energy.py @@ -1,5 +1,5 @@ from ...energies.energy import Energy -from ...memoization import memo +from ...utilities import memo from ...operators.inversion_enabler import InversionEnabler from . import WienerFilterCurvature diff --git a/nifty/memoization.py b/nifty/memoization.py deleted file mode 100644 index daa21d55c84e424589109b413e6f67b35ff968b7..0000000000000000000000000000000000000000 --- a/nifty/memoization.py +++ /dev/null @@ -1,31 +0,0 @@ -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with this program. If not, see <http://www.gnu.org/licenses/>. -# -# Copyright(C) 2013-2017 Max-Planck-Society -# -# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik -# and financially supported by the Studienstiftung des deutschen Volkes. - - -def memo(f): - name = f.__name__ - - def wrapped_f(self): - if not hasattr(self, "_cache"): - self._cache = {} - try: - return self._cache[name] - except KeyError: - self._cache[name] = f(self) - return self._cache[name] - return wrapped_f diff --git a/nifty/minimization/iteration_controlling/iteration_controller.py b/nifty/minimization/iteration_controlling/iteration_controller.py index 001f06eca4c5a9868df5d1171653b88628014269..ccdcb3899c77ddd9ba19fdd9afc149ae022bb291 100644 --- a/nifty/minimization/iteration_controlling/iteration_controller.py +++ b/nifty/minimization/iteration_controlling/iteration_controller.py @@ -18,7 +18,7 @@ from builtins import range import abc -from ...nifty_meta import NiftyMeta +from ...utilities import NiftyMeta from future.utils import with_metaclass diff --git a/nifty/minimization/minimizer.py b/nifty/minimization/minimizer.py index ee472a34509a5b7f055f191b0fe1f9225fbede0f..0809da8a32d62970a5988e095469a6a2243ae11a 100644 --- a/nifty/minimization/minimizer.py +++ b/nifty/minimization/minimizer.py @@ -17,7 +17,7 @@ # and financially supported by the Studienstiftung des deutschen Volkes. import abc -from ..nifty_meta import NiftyMeta +from ..utilities import NiftyMeta from future.utils import with_metaclass diff --git a/nifty/nifty_meta.py b/nifty/nifty_meta.py deleted file mode 100644 index 3d24bc621aca3ae0d1cf2743912b8caf09056bc9..0000000000000000000000000000000000000000 --- a/nifty/nifty_meta.py +++ /dev/null @@ -1,38 +0,0 @@ -import abc - - -class DocStringInheritor(type): - """ - A variation on - http://groups.google.com/group/comp.lang.python/msg/26f7b4fcb4d66c95 - by Paul McGuire - """ - def __new__(meta, name, bases, clsdict): - if not('__doc__' in clsdict and clsdict['__doc__']): - for mro_cls in (mro_cls for base in bases - for mro_cls in base.mro()): - doc = mro_cls.__doc__ - if doc: - clsdict['__doc__'] = doc - break - for attr, attribute in list(clsdict.items()): - if not attribute.__doc__: - for mro_cls in (mro_cls for base in bases - for mro_cls in base.mro() - if hasattr(mro_cls, attr)): - doc = getattr(getattr(mro_cls, attr), '__doc__') - if doc: - if isinstance(attribute, property): - clsdict[attr] = property(attribute.fget, - attribute.fset, - attribute.fdel, - doc) - else: - attribute.__doc__ = doc - break - return super(DocStringInheritor, meta).__new__(meta, name, - bases, clsdict) - - -class NiftyMeta(DocStringInheritor, abc.ABCMeta): - pass diff --git a/nifty/operators/diagonal_operator.py b/nifty/operators/diagonal_operator.py index 11b0ed469628c395c5619709c88e13ae0379f74d..6ff48998d146d279831180fc9361b4630f528171 100644 --- a/nifty/operators/diagonal_operator.py +++ b/nifty/operators/diagonal_operator.py @@ -21,7 +21,7 @@ import numpy as np from ..field import Field from ..domain_tuple import DomainTuple from .endomorphic_operator import EndomorphicOperator -from ..nifty_utilities import cast_iseq_to_tuple +from ..utilities import cast_iseq_to_tuple from .. import dobj diff --git a/nifty/operators/fft_operator_support.py b/nifty/operators/fft_operator_support.py index c1f170912eafee1739672748efd805a36a748227..37730afe8984bd1d5b26af8cce229397ab00156d 100644 --- a/nifty/operators/fft_operator_support.py +++ b/nifty/operators/fft_operator_support.py @@ -18,7 +18,7 @@ from __future__ import division import numpy as np -from .. import nifty_utilities as utilities +from .. import utilities from ..low_level_library import hartley from .. import dobj from ..field import Field diff --git a/nifty/operators/linear_operator.py b/nifty/operators/linear_operator.py index e6400bc2642a45c2f240d7da9c2c1cb22a1b9660..fad51f95843b59f289a0a207679ae443e91d1f41 100644 --- a/nifty/operators/linear_operator.py +++ b/nifty/operators/linear_operator.py @@ -18,7 +18,7 @@ from builtins import str import abc -from ..nifty_meta import NiftyMeta +from ..utilities import NiftyMeta from ..field import Field from future.utils import with_metaclass diff --git a/nifty/probing/prober.py b/nifty/probing/prober.py index cc9207d5da359b47f50606b58910a512a753e58d..70fd69a6abe467545c3428ff78dd3a00cb0f3bee 100644 --- a/nifty/probing/prober.py +++ b/nifty/probing/prober.py @@ -21,7 +21,7 @@ from builtins import range from builtins import object import numpy as np from ..field import Field, DomainTuple -from .. import nifty_utilities as utilities +from .. import utilities class Prober(object): diff --git a/nifty/spaces/rg_space.py b/nifty/spaces/rg_space.py index 52736a7cd1aac82db5c5d6ceaedee38a514c11f2..04272e2d839d3b34f633d44a1498ba8f7c964902 100644 --- a/nifty/spaces/rg_space.py +++ b/nifty/spaces/rg_space.py @@ -51,7 +51,9 @@ class RGSpace(Space): self._needed_for_hash += ["_distances", "_shape", "_harmonic"] self._harmonic = bool(harmonic) - self._shape = self._parse_shape(shape) + if np.isscalar(shape): + shape = (shape,) + self._shape = tuple(int(i) for i in shape) self._distances = self._parse_distances(distances) self._dvol = float(reduce(lambda x, y: x*y, self._distances)) self._dim = int(reduce(lambda x, y: x*y, self._shape)) @@ -163,17 +165,12 @@ class RGSpace(Space): @property def distances(self): - """Distance between two grid points along each axis. It is a tuple + """Distance between grid points along each axis. It is a tuple of positive floating point numbers with the n-th entry giving the - distances of grid points along the n-th dimension. + distance between neighboring grid points along the n-th dimension. """ return self._distances - def _parse_shape(self, shape): - if np.isscalar(shape): - return (shape,) - return tuple(np.array(shape, dtype=np.int)) - def _parse_distances(self, distances): if distances is None: if self.harmonic: diff --git a/nifty/sugar.py b/nifty/sugar.py index 8c168f74e7f5ce6e3d189a07dba3c83b996cb523..656e06667b5b749b6ea5228540d1d7258ce81c75 100644 --- a/nifty/sugar.py +++ b/nifty/sugar.py @@ -18,8 +18,8 @@ import numpy as np from . import Space, PowerSpace, Field, ComposedOperator, DiagonalOperator,\ - PowerProjectionOperator, FFTOperator, sqrt, DomainTuple, dobj -from . import nifty_utilities as utilities + PowerProjectionOperator, FFTOperator, sqrt, DomainTuple, dobj,\ + utilities __all__ = ['PS_field', 'power_analyze', diff --git a/nifty/nifty_utilities.py b/nifty/utilities.py similarity index 55% rename from nifty/nifty_utilities.py rename to nifty/utilities.py index c8df74b54e3eb7a04f26eff0585d4e3f8e1ec3c7..6f6899ce149c8a2ff4bafcd54d7d393928c40a56 100644 --- a/nifty/nifty_utilities.py +++ b/nifty/utilities.py @@ -19,6 +19,7 @@ from builtins import next, range import numpy as np from itertools import product +import abc def get_slice_list(shape, axes): @@ -42,10 +43,8 @@ def get_slice_list(shape, axes): ------ ValueError If shape is empty. - ValueError If axes(axis) does not match shape. """ - if shape is None: raise ValueError("shape cannot be None.") @@ -72,3 +71,54 @@ def cast_iseq_to_tuple(seq): if np.isscalar(seq): return (int(seq),) return tuple(int(item) for item in seq) + + +def memo(f): + name = f.__name__ + + def wrapped_f(self): + if not hasattr(self, "_cache"): + self._cache = {} + try: + return self._cache[name] + except KeyError: + self._cache[name] = f(self) + return self._cache[name] + return wrapped_f + + +class _DocStringInheritor(type): + """ + A variation on + http://groups.google.com/group/comp.lang.python/msg/26f7b4fcb4d66c95 + by Paul McGuire + """ + def __new__(meta, name, bases, clsdict): + if not('__doc__' in clsdict and clsdict['__doc__']): + for mro_cls in (mro_cls for base in bases + for mro_cls in base.mro()): + doc = mro_cls.__doc__ + if doc: + clsdict['__doc__'] = doc + break + for attr, attribute in list(clsdict.items()): + if not attribute.__doc__: + for mro_cls in (mro_cls for base in bases + for mro_cls in base.mro() + if hasattr(mro_cls, attr)): + doc = getattr(getattr(mro_cls, attr), '__doc__') + if doc: + if isinstance(attribute, property): + clsdict[attr] = property(attribute.fget, + attribute.fset, + attribute.fdel, + doc) + else: + attribute.__doc__ = doc + break + return super(_DocStringInheritor, meta).__new__(meta, name, + bases, clsdict) + + +class NiftyMeta(_DocStringInheritor, abc.ABCMeta): + pass diff --git a/test/test_field.py b/test/test_field.py index 9da6e095fa502fa04d1c6960fc652c53989d4f55..52f998a20a39f48ceeaf09f184bd52477a6f4931 100644 --- a/test/test_field.py +++ b/test/test_field.py @@ -127,5 +127,5 @@ class Test_Functionality(unittest.TestCase): s = ift.RGSpace((10,)) f1 = ift.Field.from_random("normal", domain=s, dtype=np.complex128) f2 = ift.Field.from_random("normal", domain=s, dtype=np.complex128) - # assert_allclose(f1.vdot(f2), f1.vdot(f2, spaces=0)) + assert_allclose(f1.vdot(f2), f1.vdot(f2, spaces=0)) assert_allclose(f1.vdot(f2), np.conj(f2.vdot(f1)))