Commit 12e3d597 by Martin Reinecke

### simplify DiagonalOperator

parent 643d308b
 ... @@ -3,6 +3,71 @@ ... @@ -3,6 +3,71 @@ import numpy as np import numpy as np from numpy import ndarray as data_object from numpy import ndarray as data_object from numpy import full, empty, sqrt, ones, zeros, vdot, abs, bincount from numpy import full, empty, sqrt, ones, zeros, vdot, abs, bincount from ..nifty_utilities import cast_iseq_to_tuple, get_slice_list from functools import reduce def from_object(object, dtype=None, copy=True): def from_object(object, dtype=None, copy=True): return np.array(object, dtype=dtype, copy=copy) return np.array(object, dtype=dtype, copy=copy) def bincount_axis(obj, minlength=None, weights=None, axis=None): if minlength is not None: length = max(np.amax(obj) + 1, minlength) else: length = np.amax(obj) + 1 if obj.shape == (): raise ValueError("object of too small depth for desired array") data = obj # if present, parse the axis keyword and transpose/reorder self.data # such that all affected axes follow each other. Only if they are in a # sequence flattening will be possible if axis is not None: # do the reordering ndim = len(obj.shape) axis = sorted(cast_iseq_to_tuple(axis)) reordering = [x for x in range(ndim) if x not in axis] reordering += axis data = np.transpose(data, reordering) if weights is not None: weights = np.transpose(weights, reordering) reord_axis = list(range(ndim-len(axis), ndim)) # semi-flatten the dimensions in `axis`, i.e. after reordering # the last ones. semi_flat_dim = reduce(lambda x, y: x*y, data.shape[ndim-len(reord_axis):]) flat_shape = data.shape[:ndim-len(reord_axis)] + (semi_flat_dim, ) else: flat_shape = (reduce(lambda x, y: x*y, data.shape), ) data = np.ascontiguousarray(data.reshape(flat_shape)) if weights is not None: weights = np.ascontiguousarray(weights.reshape(flat_shape)) # compute the local bincount results # -> prepare the local result array result_dtype = np.int if weights is None else np.float local_counts = np.empty(flat_shape[:-1] + (length, ), dtype=result_dtype) # iterate over all entries in the surviving axes and compute the local # bincounts for slice_list in get_slice_list(flat_shape, axes=(len(flat_shape)-1,)): current_weights = None if weights is None else weights[slice_list] local_counts[slice_list] = np.bincount(data[slice_list], weights=current_weights, minlength=length) # restore the original ordering # place the bincount stuff at the location of the first `axis` entry if axis is not None: # axis has been sorted above insert_position = axis[0] new_ndim = len(local_counts.shape) return_order = (list(range(0, insert_position)) + [new_ndim-1, ] + list(range(insert_position, new_ndim-1))) local_counts = np.ascontiguousarray( local_counts.transpose(return_order)) return local_counts
 ... @@ -231,8 +231,8 @@ class Field(object): ... @@ -231,8 +231,8 @@ class Field(object): new_pindex_shape[ax] = pindex.shape[i] new_pindex_shape[ax] = pindex.shape[i] pindex = np.broadcast_to(pindex.reshape(new_pindex_shape), field.shape) pindex = np.broadcast_to(pindex.reshape(new_pindex_shape), field.shape) power_spectrum = utilities.bincount_axis(pindex, weights=field.val, power_spectrum = dobj.bincount_axis(pindex, weights=field.val, axis=axes) axis=axes) new_rho_shape = [1] * len(power_spectrum.shape) new_rho_shape = [1] * len(power_spectrum.shape) new_rho_shape[axes[0]] = len(power_domain.rho) new_rho_shape[axes[0]] = len(power_domain.rho) power_spectrum /= power_domain.rho.reshape(new_rho_shape) power_spectrum /= power_domain.rho.reshape(new_rho_shape) ... @@ -510,7 +510,7 @@ class Field(object): ... @@ -510,7 +510,7 @@ class Field(object): # create a diagonal operator which is capable of taking care of the # create a diagonal operator which is capable of taking care of the # axes-matching # axes-matching from .operators.diagonal_operator import DiagonalOperator from .operators.diagonal_operator import DiagonalOperator diag = DiagonalOperator(y.domain, y.conjugate(), copy=False) diag = DiagonalOperator(y.conjugate(), copy=False) dotted = diag(x, spaces=spaces) dotted = diag(x, spaces=spaces) return fct*dotted.sum(spaces=spaces) return fct*dotted.sum(spaces=spaces) ... ...
 ... @@ -19,8 +19,6 @@ ... @@ -19,8 +19,6 @@ from builtins import next, range from builtins import next, range import numpy as np import numpy as np from itertools import product from itertools import product from functools import reduce from .domain_object import DomainObject def get_slice_list(shape, axes): def get_slice_list(shape, axes): ... @@ -74,67 +72,3 @@ def cast_iseq_to_tuple(seq): ... @@ -74,67 +72,3 @@ def cast_iseq_to_tuple(seq): if np.isscalar(seq): if np.isscalar(seq): return (int(seq),) return (int(seq),) return tuple(int(item) for item in seq) return tuple(int(item) for item in seq) def bincount_axis(obj, minlength=None, weights=None, axis=None): if minlength is not None: length = max(np.amax(obj) + 1, minlength) else: length = np.amax(obj) + 1 if obj.shape == (): raise ValueError("object of too small depth for desired array") data = obj # if present, parse the axis keyword and transpose/reorder self.data # such that all affected axes follow each other. Only if they are in a # sequence flattening will be possible if axis is not None: # do the reordering ndim = len(obj.shape) axis = sorted(cast_iseq_to_tuple(axis)) reordering = [x for x in range(ndim) if x not in axis] reordering += axis data = np.transpose(data, reordering) if weights is not None: weights = np.transpose(weights, reordering) reord_axis = list(range(ndim-len(axis), ndim)) # semi-flatten the dimensions in `axis`, i.e. after reordering # the last ones. semi_flat_dim = reduce(lambda x, y: x*y, data.shape[ndim-len(reord_axis):]) flat_shape = data.shape[:ndim-len(reord_axis)] + (semi_flat_dim, ) else: flat_shape = (reduce(lambda x, y: x*y, data.shape), ) data = np.ascontiguousarray(data.reshape(flat_shape)) if weights is not None: weights = np.ascontiguousarray(weights.reshape(flat_shape)) # compute the local bincount results # -> prepare the local result array result_dtype = np.int if weights is None else np.float local_counts = np.empty(flat_shape[:-1] + (length, ), dtype=result_dtype) # iterate over all entries in the surviving axes and compute the local # bincounts for slice_list in get_slice_list(flat_shape, axes=(len(flat_shape)-1,)): current_weights = None if weights is None else weights[slice_list] local_counts[slice_list] = np.bincount(data[slice_list], weights=current_weights, minlength=length) # restore the original ordering # place the bincount stuff at the location of the first `axis` entry if axis is not None: # axis has been sorted above insert_position = axis[0] new_ndim = len(local_counts.shape) return_order = (list(range(0, insert_position)) + [new_ndim-1, ] + list(range(insert_position, new_ndim-1))) local_counts = np.ascontiguousarray( local_counts.transpose(return_order)) return local_counts
 ... @@ -35,9 +35,7 @@ class DiagonalOperator(EndomorphicOperator): ... @@ -35,9 +35,7 @@ class DiagonalOperator(EndomorphicOperator): Parameters Parameters ---------- ---------- domain : tuple of DomainObjects, i.e. Spaces and FieldTypes diagonal : Field The domain on which the Operator's input Field lives. diagonal : {scalar, list, array, Field} The diagonal entries of the operator. The diagonal entries of the operator. copy : boolean copy : boolean Internal copy of the diagonal (default: True) Internal copy of the diagonal (default: True) ... @@ -68,15 +66,14 @@ class DiagonalOperator(EndomorphicOperator): ... @@ -68,15 +66,14 @@ class DiagonalOperator(EndomorphicOperator): # ---Overwritten properties and methods--- # ---Overwritten properties and methods--- def __init__(self, domain=(), diagonal=None, copy=True, def __init__(self, diagonal, copy=True, default_spaces=None): default_spaces=None): super(DiagonalOperator, self).__init__(default_spaces) super(DiagonalOperator, self).__init__(default_spaces) self._domain = DomainTuple.make(domain) if not isinstance(diagonal, Field): raise TypeError("Field object required") self._diagonal = diagonal if not copy else diagonal.copy() self._self_adjoint = None self._self_adjoint = None self._unitary = None self._unitary = None self.set_diagonal(diagonal=diagonal, copy=copy) def _times(self, x, spaces): def _times(self, x, spaces): return self._times_helper(x, spaces, operation=lambda z: z.__mul__) return self._times_helper(x, spaces, operation=lambda z: z.__mul__) ... @@ -119,13 +116,13 @@ class DiagonalOperator(EndomorphicOperator): ... @@ -119,13 +116,13 @@ class DiagonalOperator(EndomorphicOperator): The inverse of the diagonal of the Operator. The inverse of the diagonal of the Operator. """ """ return 1./self.diagonal(copy=False) return 1./self._diagonal # ---Mandatory properties and methods--- # ---Mandatory properties and methods--- @property @property def domain(self): def domain(self): return self._domain return self._diagonal.domain @property @property def self_adjoint(self): def self_adjoint(self): ... @@ -144,30 +141,6 @@ class DiagonalOperator(EndomorphicOperator): ... @@ -144,30 +141,6 @@ class DiagonalOperator(EndomorphicOperator): # ---Added properties and methods--- # ---Added properties and methods--- def set_diagonal(self, diagonal, copy=True): """ Sets the diagonal of the Operator. Parameters ---------- diagonal : {scalar, list, array, Field} The diagonal entries of the operator. copy : boolean Specifies if a copy of the input shall be made (default: True). """ # use the casting functionality from Field to process `diagonal` f = Field(domain=self.domain, val=diagonal, copy=copy) # Reset the self_adjoint property: self._self_adjoint = None # Reset the unitarity property self._unitary = None # store the diagonal-field self._diagonal = f def _times_helper(self, x, spaces, operation): def _times_helper(self, x, spaces, operation): # if the domain matches directly # if the domain matches directly # -> multiply the fields directly # -> multiply the fields directly ... ...
 ... @@ -59,8 +59,8 @@ class ResponseOperator(LinearOperator): ... @@ -59,8 +59,8 @@ class ResponseOperator(LinearOperator): kernel_smoothing = [FFTSmoothingOperator(self._domain[x], sigma[x]) kernel_smoothing = [FFTSmoothingOperator(self._domain[x], sigma[x]) for x in range(nsigma)] for x in range(nsigma)] kernel_exposure = [DiagonalOperator(self._domain[x], kernel_exposure = [DiagonalOperator(Field(self._domain[x],exposure[x])) diagonal=exposure[x]) for x in range(nsigma)] for x in range(nsigma)] self._composed_kernel = ComposedOperator(kernel_smoothing) self._composed_kernel = ComposedOperator(kernel_smoothing) self._composed_exposure = ComposedOperator(kernel_exposure) self._composed_exposure = ComposedOperator(kernel_exposure) ... ...
 ... @@ -21,7 +21,7 @@ class Test_Minimizers(unittest.TestCase): ... @@ -21,7 +21,7 @@ class Test_Minimizers(unittest.TestCase): starting_point = ift.Field.from_random('normal', domain=space)*10 starting_point = ift.Field.from_random('normal', domain=space)*10 covariance_diagonal = ift.Field.from_random( covariance_diagonal = ift.Field.from_random( 'uniform', domain=space) + 0.5 'uniform', domain=space) + 0.5 covariance = ift.DiagonalOperator(space, diagonal=covariance_diagonal) covariance = ift.DiagonalOperator(covariance_diagonal) required_result = ift.Field(space, val=1.) required_result = ift.Field(space, val=1.) IC = ift.DefaultIterationController(tol_abs_gradnorm=1e-5) IC = ift.DefaultIterationController(tol_abs_gradnorm=1e-5) ... ...
 ... @@ -20,8 +20,8 @@ class ComposedOperator_Tests(unittest.TestCase): ... @@ -20,8 +20,8 @@ class ComposedOperator_Tests(unittest.TestCase): def test_property(self, space1, space2): def test_property(self, space1, space2): rand1 = Field.from_random('normal', domain=space1) rand1 = Field.from_random('normal', domain=space1) rand2 = Field.from_random('normal', domain=space2) rand2 = Field.from_random('normal', domain=space2) op1 = DiagonalOperator(space1, diagonal=rand1) op1 = DiagonalOperator(rand1) op2 = DiagonalOperator(space2, diagonal=rand2) op2 = DiagonalOperator(rand2) op = ComposedOperator((op1, op2)) op = ComposedOperator((op1, op2)) if op.domain != (op1.domain[0], op2.domain[0]): if op.domain != (op1.domain[0], op2.domain[0]): raise TypeError raise TypeError ... @@ -32,8 +32,8 @@ class ComposedOperator_Tests(unittest.TestCase): ... @@ -32,8 +32,8 @@ class ComposedOperator_Tests(unittest.TestCase): def test_times_adjoint_times(self, space1, space2): def test_times_adjoint_times(self, space1, space2): diag1 = Field.from_random('normal', domain=space1) diag1 = Field.from_random('normal', domain=space1) diag2 = Field.from_random('normal', domain=space2) diag2 = Field.from_random('normal', domain=space2) op1 = DiagonalOperator(space1, diagonal=diag1) op1 = DiagonalOperator(diag1) op2 = DiagonalOperator(space2, diagonal=diag2) op2 = DiagonalOperator(diag2) op = ComposedOperator((op1, op2)) op = ComposedOperator((op1, op2)) ... @@ -48,8 +48,8 @@ class ComposedOperator_Tests(unittest.TestCase): ... @@ -48,8 +48,8 @@ class ComposedOperator_Tests(unittest.TestCase): def test_times_inverse_times(self, space1, space2): def test_times_inverse_times(self, space1, space2): diag1 = Field.from_random('normal', domain=space1) diag1 = Field.from_random('normal', domain=space1) diag2 = Field.from_random('normal', domain=space2) diag2 = Field.from_random('normal', domain=space2) op1 = DiagonalOperator(space1, diagonal=diag1) op1 = DiagonalOperator(diag1) op2 = DiagonalOperator(space2, diagonal=diag2) op2 = DiagonalOperator(diag2) op = ComposedOperator((op1, op2)) op = ComposedOperator((op1, op2)) ... ...
 ... @@ -20,7 +20,7 @@ class DiagonalOperator_Tests(unittest.TestCase): ... @@ -20,7 +20,7 @@ class DiagonalOperator_Tests(unittest.TestCase): @expand(product(spaces, [True, False])) @expand(product(spaces, [True, False])) def test_property(self, space, copy): def test_property(self, space, copy): diag = Field.from_random('normal', domain=space) diag = Field.from_random('normal', domain=space) D = DiagonalOperator(space, diagonal=diag) D = DiagonalOperator(diag) if D.domain[0] != space: if D.domain[0] != space: raise TypeError raise TypeError if D.unitary != False: if D.unitary != False: ... @@ -33,7 +33,7 @@ class DiagonalOperator_Tests(unittest.TestCase): ... @@ -33,7 +33,7 @@ class DiagonalOperator_Tests(unittest.TestCase): rand1 = Field.from_random('normal', domain=space) rand1 = Field.from_random('normal', domain=space) rand2 = Field.from_random('normal', domain=space) rand2 = Field.from_random('normal', domain=space) diag = Field.from_random('normal', domain=space) diag = Field.from_random('normal', domain=space) D = DiagonalOperator(space, diagonal=diag, copy=copy) D = DiagonalOperator(diag, copy=copy) tt1 = rand1.vdot(D.times(rand2)) tt1 = rand1.vdot(D.times(rand2)) tt2 = rand2.vdot(D.times(rand1)) tt2 = rand2.vdot(D.times(rand1)) assert_approx_equal(tt1, tt2) assert_approx_equal(tt1, tt2) ... @@ -42,7 +42,7 @@ class DiagonalOperator_Tests(unittest.TestCase): ... @@ -42,7 +42,7 @@ class DiagonalOperator_Tests(unittest.TestCase): def test_times_inverse(self, space, copy): def test_times_inverse(self, space, copy): rand1 = Field.from_random('normal', domain=space) rand1 = Field.from_random('normal', domain=space) diag = Field.from_random('normal', domain=space) diag = Field.from_random('normal', domain=space) D = DiagonalOperator(space, diagonal=diag, copy=copy) D = DiagonalOperator(diag, copy=copy) tt1 = D.times(D.inverse_times(rand1)) tt1 = D.times(D.inverse_times(rand1)) assert_allclose(rand1.val, tt1.val) assert_allclose(rand1.val, tt1.val) ... @@ -50,7 +50,7 @@ class DiagonalOperator_Tests(unittest.TestCase): ... @@ -50,7 +50,7 @@ class DiagonalOperator_Tests(unittest.TestCase): def test_times(self, space, copy): def test_times(self, space, copy): rand1 = Field.from_random('normal', domain=space) rand1 = Field.from_random('normal', domain=space) diag = Field.from_random('normal', domain=space) diag = Field.from_random('normal', domain=space) D = DiagonalOperator(space, diagonal=diag, copy=copy) D = DiagonalOperator(diag, copy=copy) tt = D.times(rand1) tt = D.times(rand1) assert_equal(tt.domain[0], space) assert_equal(tt.domain[0], space) ... @@ -58,7 +58,7 @@ class DiagonalOperator_Tests(unittest.TestCase): ... @@ -58,7 +58,7 @@ class DiagonalOperator_Tests(unittest.TestCase): def test_adjoint_times(self, space, copy): def test_adjoint_times(self, space, copy): rand1 = Field.from_random('normal', domain=space) rand1 = Field.from_random('normal', domain=space) diag = Field.from_random('normal', domain=space) diag = Field.from_random('normal', domain=space) D = DiagonalOperator(space, diagonal=diag, copy=copy) D = DiagonalOperator(diag, copy=copy) tt = D.adjoint_times(rand1) tt = D.adjoint_times(rand1) assert_equal(tt.domain[0], space) assert_equal(tt.domain[0], space) ... @@ -66,7 +66,7 @@ class DiagonalOperator_Tests(unittest.TestCase): ... @@ -66,7 +66,7 @@ class DiagonalOperator_Tests(unittest.TestCase): def test_inverse_times(self, space, copy): def test_inverse_times(self, space, copy): rand1 = Field.from_random('normal', domain=space) rand1 = Field.from_random('normal', domain=space) diag = Field.from_random('normal', domain=space) diag = Field.from_random('normal', domain=space) D = DiagonalOperator(space, diagonal=diag, copy=copy) D = DiagonalOperator(diag, copy=copy) tt = D.inverse_times(rand1) tt = D.inverse_times(rand1) assert_equal(tt.domain[0], space) assert_equal(tt.domain[0], space) ... @@ -74,20 +74,20 @@ class DiagonalOperator_Tests(unittest.TestCase): ... @@ -74,20 +74,20 @@ class DiagonalOperator_Tests(unittest.TestCase): def test_adjoint_inverse_times(self, space, copy): def test_adjoint_inverse_times(self, space, copy): rand1 = Field.from_random('normal', domain=space) rand1 = Field.from_random('normal', domain=space) diag = Field.from_random('normal', domain=space) diag = Field.from_random('normal', domain=space) D = DiagonalOperator(space, diagonal=diag, copy=copy) D = DiagonalOperator(diag, copy=copy) tt = D.adjoint_inverse_times(rand1) tt = D.adjoint_inverse_times(rand1) assert_equal(tt.domain[0], space) assert_equal(tt.domain[0], space) @expand(product(spaces, [True, False])) @expand(product(spaces, [True, False])) def test_diagonal(self, space, copy): def test_diagonal(self, space, copy): diag = Field.from_random('normal', domain=space) diag = Field.from_random('normal', domain=space) D = DiagonalOperator(space, diagonal=diag, copy=copy) D = DiagonalOperator(diag, copy=copy) diag_op = D.diagonal() diag_op = D.diagonal() assert_allclose(diag.val, diag_op.val) assert_allclose(diag.val, diag_op.val) @expand(product(spaces, [True, False])) @expand(product(spaces, [True, False])) def test_inverse(self, space, copy): def test_inverse(self, space, copy): diag = Field.from_random('normal', domain=space) diag = Field.from_random('normal', domain=space) D = DiagonalOperator(space, diagonal=diag, copy=copy) D = DiagonalOperator(diag, copy=copy) diag_op = D.inverse_diagonal() diag_op = D.inverse_diagonal() assert_allclose(1./diag.val, diag_op.val) assert_allclose(1./diag.val, diag_op.val)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment