Commit eadb48d6 authored by Martin Reinecke's avatar Martin Reinecke

consolidation

parent 91032b52
Pipeline #21614 passed with stage
in 4 minutes and 27 seconds
......@@ -10,7 +10,7 @@ from .domain_object import DomainObject
from .basic_arithmetics import *
from .nifty_utilities import *
from .utilities import *
from .field_types import *
......@@ -31,5 +31,3 @@ from . import plotting
from . import library
from . import dobj
from .memoization import memo
......@@ -18,7 +18,7 @@
from __future__ import division
import abc
from .nifty_meta import NiftyMeta
from .utilities import NiftyMeta
from future.utils import with_metaclass
......@@ -38,7 +38,6 @@ class DomainObject(with_metaclass(
raise NotImplementedError
def __hash__(self):
# Extract the identifying parts from the vars(self) dict.
result_hash = 0
for key in self._needed_for_hash:
result_hash ^= hash(vars(self)[key])
......
......@@ -105,24 +105,10 @@ class DomainTuple(object):
return self._dom == x._dom
def __ne__(self, x):
if not isinstance(x, DomainTuple):
x = DomainTuple.make(x)
if self is x:
return False
return self._dom != x._dom
return not self.__eq__(x)
def __str__(self):
res = "DomainTuple, len: " + str(len(self.domains))
for i in self.domains:
res += "\n" + str(i)
return res
def collapsed_shape_for_domain(self, ispace):
"""Returns a three-component shape, with the total number of pixels
in the domains before the requested space in res[0], the total number
of pixels in the requested space in res[1], and the remaining pixels in
res[2].
"""
return (self._accdims[ispace],
self._accdims[ispace+1]//self._accdims[ispace],
self._accdims[-1]//self._accdims[ispace+1])
......@@ -16,8 +16,7 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from ..nifty_meta import NiftyMeta
from ..memoization import memo
from ..utilities import memo, NiftyMeta
from future.utils import with_metaclass
......
from .energy import Energy
from ..memoization import memo
from ..utilities import memo
class QuadraticEnergy(Energy):
......
......@@ -19,7 +19,7 @@
from __future__ import division
from builtins import range
import numpy as np
from . import nifty_utilities as utilities
from . import utilities
from .domain_tuple import DomainTuple
from functools import reduce
from . import dobj
......@@ -65,7 +65,6 @@ class Field(object):
*the given domain contains something that is not a DomainObject
instance
*val is an array that has a different dimension than the domain
"""
# ---Initialization methods---
......@@ -179,7 +178,6 @@ class Field(object):
out : Field
The output object.
"""
domain = DomainTuple.make(domain)
return Field(domain=domain,
val=dobj.from_random(random_type, dtype=dtype,
......@@ -194,10 +192,6 @@ class Field(object):
def val(self):
""" Returns the data object associated with this Field.
No copy is made.
Returns
-------
out : numpy.ndarray
"""
return self._val
......@@ -211,10 +205,8 @@ class Field(object):
Returns
-------
out : tuple
The output object. The tuple contains the dimensions of the spaces
in domain.
"""
Integer tuple containing the dimensions of the spaces in domain.
"""
return self.domain.shape
@property
......@@ -232,14 +224,12 @@ class Field(object):
@property
def real(self):
""" The real part of the field (data is not copied).
"""
""" The real part of the field (data is not copied)."""
return Field(self.domain, self.val.real)
@property
def imag(self):
""" The imaginary part of the field (data is not copied).
"""
""" The imaginary part of the field (data is not copied)."""
return Field(self.domain, self.val.imag)
# ---Special unary/binary operations---
......@@ -290,7 +280,6 @@ class Field(object):
-------
out : Field
The weighted field.
"""
if out is None:
out = self.copy()
......@@ -313,7 +302,8 @@ class Field(object):
new_shape[self.domain.axes[ind][0]:
self.domain.axes[ind][-1]+1] = wgt.shape
wgt = wgt.reshape(new_shape)
if dobj.distaxis(self._val) >= 0 and ind == 0: # we need to distribute the weights along axis 0
if dobj.distaxis(self._val) >= 0 and ind == 0:
# we need to distribute the weights along axis 0
wgt = dobj.local_data(dobj.from_global_data(wgt))
out *= wgt**power
fct = fct**power
......@@ -336,8 +326,7 @@ class Field(object):
Returns
-------
out : float, complex
out : float, complex, either scalar or Field
"""
if not isinstance(x, Field):
raise ValueError("The dot-partner must be an instance of " +
......@@ -354,15 +343,19 @@ class Field(object):
if spaces is None:
return fct*dobj.vdot(y.val, x.val)
else:
spaces = utilities.cast_iseq_to_tuple(spaces)
active_axes = []
for i in spaces:
active_axes += self.domain.axes[i]
res = 0.
for sl in utilities.get_slice_list(self.shape, active_axes):
res += dobj.vdot(y.val, x.val[sl])
return res*fct
spaces = utilities.cast_iseq_to_tuple(spaces)
if spaces == tuple(range(len(self.domain))): # full contraction
return fct*dobj.vdot(y.val, x.val)
raise NotImplementedError("special case for vdot not yet implemented")
active_axes = []
for i in spaces:
active_axes += self.domain.axes[i]
res = 0.
for sl in utilities.get_slice_list(self.shape, active_axes):
res += dobj.vdot(y.val, x.val[sl])
return res*fct
def norm(self):
""" Computes the L2-norm of the field values.
......@@ -371,7 +364,6 @@ class Field(object):
-------
norm : float
The L2-norm of the field values.
"""
return np.sqrt(np.abs(self.vdot(x=self)))
......@@ -380,9 +372,7 @@ class Field(object):
Returns
-------
cc : field
The complex conjugated field.
The complex conjugated field.
"""
return Field(self.domain, self.val.conjugate(), self.dtype)
......
......@@ -3,7 +3,7 @@ from ...operators.smoothness_operator import SmoothnessOperator
from ...operators.power_projection_operator import PowerProjectionOperator
from ...operators.inversion_enabler import InversionEnabler
from . import CriticalPowerCurvature
from ...memoization import memo
from ...utilities import memo
from ... import Field, exp
from ...sugar import generate_posterior_sample
......
from ...operators import EndomorphicOperator
from ...memoization import memo
from ...utilities import memo
from ...basic_arithmetics import exp
from ...sugar import create_composed_fft_operator
......
from ...energies.energy import Energy
from ...memoization import memo
from ...utilities import memo
from . import LogNormalWienerFilterCurvature
from ...sugar import create_composed_fft_operator
from ...operators.inversion_enabler import InversionEnabler
......
from ...energies.energy import Energy
from ...memoization import memo
from ...utilities import memo
from ...operators.inversion_enabler import InversionEnabler
from . import WienerFilterCurvature
......
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2017 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
def memo(f):
name = f.__name__
def wrapped_f(self):
if not hasattr(self, "_cache"):
self._cache = {}
try:
return self._cache[name]
except KeyError:
self._cache[name] = f(self)
return self._cache[name]
return wrapped_f
......@@ -18,7 +18,7 @@
from builtins import range
import abc
from ...nifty_meta import NiftyMeta
from ...utilities import NiftyMeta
from future.utils import with_metaclass
......
......@@ -17,7 +17,7 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
import abc
from ..nifty_meta import NiftyMeta
from ..utilities import NiftyMeta
from future.utils import with_metaclass
......
import abc
class DocStringInheritor(type):
"""
A variation on
http://groups.google.com/group/comp.lang.python/msg/26f7b4fcb4d66c95
by Paul McGuire
"""
def __new__(meta, name, bases, clsdict):
if not('__doc__' in clsdict and clsdict['__doc__']):
for mro_cls in (mro_cls for base in bases
for mro_cls in base.mro()):
doc = mro_cls.__doc__
if doc:
clsdict['__doc__'] = doc
break
for attr, attribute in list(clsdict.items()):
if not attribute.__doc__:
for mro_cls in (mro_cls for base in bases
for mro_cls in base.mro()
if hasattr(mro_cls, attr)):
doc = getattr(getattr(mro_cls, attr), '__doc__')
if doc:
if isinstance(attribute, property):
clsdict[attr] = property(attribute.fget,
attribute.fset,
attribute.fdel,
doc)
else:
attribute.__doc__ = doc
break
return super(DocStringInheritor, meta).__new__(meta, name,
bases, clsdict)
class NiftyMeta(DocStringInheritor, abc.ABCMeta):
pass
......@@ -21,7 +21,7 @@ import numpy as np
from ..field import Field
from ..domain_tuple import DomainTuple
from .endomorphic_operator import EndomorphicOperator
from ..nifty_utilities import cast_iseq_to_tuple
from ..utilities import cast_iseq_to_tuple
from .. import dobj
......
......@@ -18,7 +18,7 @@
from __future__ import division
import numpy as np
from .. import nifty_utilities as utilities
from .. import utilities
from ..low_level_library import hartley
from .. import dobj
from ..field import Field
......
......@@ -18,7 +18,7 @@
from builtins import str
import abc
from ..nifty_meta import NiftyMeta
from ..utilities import NiftyMeta
from ..field import Field
from future.utils import with_metaclass
......
......@@ -21,7 +21,7 @@ from builtins import range
from builtins import object
import numpy as np
from ..field import Field, DomainTuple
from .. import nifty_utilities as utilities
from .. import utilities
class Prober(object):
......
......@@ -51,7 +51,9 @@ class RGSpace(Space):
self._needed_for_hash += ["_distances", "_shape", "_harmonic"]
self._harmonic = bool(harmonic)
self._shape = self._parse_shape(shape)
if np.isscalar(shape):
shape = (shape,)
self._shape = tuple(int(i) for i in shape)
self._distances = self._parse_distances(distances)
self._dvol = float(reduce(lambda x, y: x*y, self._distances))
self._dim = int(reduce(lambda x, y: x*y, self._shape))
......@@ -163,17 +165,12 @@ class RGSpace(Space):
@property
def distances(self):
"""Distance between two grid points along each axis. It is a tuple
"""Distance between grid points along each axis. It is a tuple
of positive floating point numbers with the n-th entry giving the
distances of grid points along the n-th dimension.
distance between neighboring grid points along the n-th dimension.
"""
return self._distances
def _parse_shape(self, shape):
if np.isscalar(shape):
return (shape,)
return tuple(np.array(shape, dtype=np.int))
def _parse_distances(self, distances):
if distances is None:
if self.harmonic:
......
......@@ -18,8 +18,8 @@
import numpy as np
from . import Space, PowerSpace, Field, ComposedOperator, DiagonalOperator,\
PowerProjectionOperator, FFTOperator, sqrt, DomainTuple, dobj
from . import nifty_utilities as utilities
PowerProjectionOperator, FFTOperator, sqrt, DomainTuple, dobj,\
utilities
__all__ = ['PS_field',
'power_analyze',
......
......@@ -19,6 +19,7 @@
from builtins import next, range
import numpy as np
from itertools import product
import abc
def get_slice_list(shape, axes):
......@@ -42,10 +43,8 @@ def get_slice_list(shape, axes):
------
ValueError
If shape is empty.
ValueError
If axes(axis) does not match shape.
"""
if shape is None:
raise ValueError("shape cannot be None.")
......@@ -72,3 +71,54 @@ def cast_iseq_to_tuple(seq):
if np.isscalar(seq):
return (int(seq),)
return tuple(int(item) for item in seq)
def memo(f):
name = f.__name__
def wrapped_f(self):
if not hasattr(self, "_cache"):
self._cache = {}
try:
return self._cache[name]
except KeyError:
self._cache[name] = f(self)
return self._cache[name]
return wrapped_f
class _DocStringInheritor(type):
"""
A variation on
http://groups.google.com/group/comp.lang.python/msg/26f7b4fcb4d66c95
by Paul McGuire
"""
def __new__(meta, name, bases, clsdict):
if not('__doc__' in clsdict and clsdict['__doc__']):
for mro_cls in (mro_cls for base in bases
for mro_cls in base.mro()):
doc = mro_cls.__doc__
if doc:
clsdict['__doc__'] = doc
break
for attr, attribute in list(clsdict.items()):
if not attribute.__doc__:
for mro_cls in (mro_cls for base in bases
for mro_cls in base.mro()
if hasattr(mro_cls, attr)):
doc = getattr(getattr(mro_cls, attr), '__doc__')
if doc:
if isinstance(attribute, property):
clsdict[attr] = property(attribute.fget,
attribute.fset,
attribute.fdel,
doc)
else:
attribute.__doc__ = doc
break
return super(_DocStringInheritor, meta).__new__(meta, name,
bases, clsdict)
class NiftyMeta(_DocStringInheritor, abc.ABCMeta):
pass
......@@ -127,5 +127,5 @@ class Test_Functionality(unittest.TestCase):
s = ift.RGSpace((10,))
f1 = ift.Field.from_random("normal", domain=s, dtype=np.complex128)
f2 = ift.Field.from_random("normal", domain=s, dtype=np.complex128)
# assert_allclose(f1.vdot(f2), f1.vdot(f2, spaces=0))
assert_allclose(f1.vdot(f2), f1.vdot(f2, spaces=0))
assert_allclose(f1.vdot(f2), np.conj(f2.vdot(f1)))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment