Commit 12e3d597 authored by Martin Reinecke's avatar Martin Reinecke

simplify DiagonalOperator

parent 643d308b
...@@ -3,6 +3,71 @@ ...@@ -3,6 +3,71 @@
import numpy as np import numpy as np
from numpy import ndarray as data_object from numpy import ndarray as data_object
from numpy import full, empty, sqrt, ones, zeros, vdot, abs, bincount from numpy import full, empty, sqrt, ones, zeros, vdot, abs, bincount
from ..nifty_utilities import cast_iseq_to_tuple, get_slice_list
from functools import reduce
def from_object(object, dtype=None, copy=True): def from_object(object, dtype=None, copy=True):
return np.array(object, dtype=dtype, copy=copy) return np.array(object, dtype=dtype, copy=copy)
def bincount_axis(obj, minlength=None, weights=None, axis=None):
if minlength is not None:
length = max(np.amax(obj) + 1, minlength)
else:
length = np.amax(obj) + 1
if obj.shape == ():
raise ValueError("object of too small depth for desired array")
data = obj
# if present, parse the axis keyword and transpose/reorder self.data
# such that all affected axes follow each other. Only if they are in a
# sequence flattening will be possible
if axis is not None:
# do the reordering
ndim = len(obj.shape)
axis = sorted(cast_iseq_to_tuple(axis))
reordering = [x for x in range(ndim) if x not in axis]
reordering += axis
data = np.transpose(data, reordering)
if weights is not None:
weights = np.transpose(weights, reordering)
reord_axis = list(range(ndim-len(axis), ndim))
# semi-flatten the dimensions in `axis`, i.e. after reordering
# the last ones.
semi_flat_dim = reduce(lambda x, y: x*y,
data.shape[ndim-len(reord_axis):])
flat_shape = data.shape[:ndim-len(reord_axis)] + (semi_flat_dim, )
else:
flat_shape = (reduce(lambda x, y: x*y, data.shape), )
data = np.ascontiguousarray(data.reshape(flat_shape))
if weights is not None:
weights = np.ascontiguousarray(weights.reshape(flat_shape))
# compute the local bincount results
# -> prepare the local result array
result_dtype = np.int if weights is None else np.float
local_counts = np.empty(flat_shape[:-1] + (length, ), dtype=result_dtype)
# iterate over all entries in the surviving axes and compute the local
# bincounts
for slice_list in get_slice_list(flat_shape, axes=(len(flat_shape)-1,)):
current_weights = None if weights is None else weights[slice_list]
local_counts[slice_list] = np.bincount(data[slice_list],
weights=current_weights,
minlength=length)
# restore the original ordering
# place the bincount stuff at the location of the first `axis` entry
if axis is not None:
# axis has been sorted above
insert_position = axis[0]
new_ndim = len(local_counts.shape)
return_order = (list(range(0, insert_position)) +
[new_ndim-1, ] +
list(range(insert_position, new_ndim-1)))
local_counts = np.ascontiguousarray(
local_counts.transpose(return_order))
return local_counts
...@@ -231,8 +231,8 @@ class Field(object): ...@@ -231,8 +231,8 @@ class Field(object):
new_pindex_shape[ax] = pindex.shape[i] new_pindex_shape[ax] = pindex.shape[i]
pindex = np.broadcast_to(pindex.reshape(new_pindex_shape), field.shape) pindex = np.broadcast_to(pindex.reshape(new_pindex_shape), field.shape)
power_spectrum = utilities.bincount_axis(pindex, weights=field.val, power_spectrum = dobj.bincount_axis(pindex, weights=field.val,
axis=axes) axis=axes)
new_rho_shape = [1] * len(power_spectrum.shape) new_rho_shape = [1] * len(power_spectrum.shape)
new_rho_shape[axes[0]] = len(power_domain.rho) new_rho_shape[axes[0]] = len(power_domain.rho)
power_spectrum /= power_domain.rho.reshape(new_rho_shape) power_spectrum /= power_domain.rho.reshape(new_rho_shape)
...@@ -510,7 +510,7 @@ class Field(object): ...@@ -510,7 +510,7 @@ class Field(object):
# create a diagonal operator which is capable of taking care of the # create a diagonal operator which is capable of taking care of the
# axes-matching # axes-matching
from .operators.diagonal_operator import DiagonalOperator from .operators.diagonal_operator import DiagonalOperator
diag = DiagonalOperator(y.domain, y.conjugate(), copy=False) diag = DiagonalOperator(y.conjugate(), copy=False)
dotted = diag(x, spaces=spaces) dotted = diag(x, spaces=spaces)
return fct*dotted.sum(spaces=spaces) return fct*dotted.sum(spaces=spaces)
......
...@@ -19,8 +19,6 @@ ...@@ -19,8 +19,6 @@
from builtins import next, range from builtins import next, range
import numpy as np import numpy as np
from itertools import product from itertools import product
from functools import reduce
from .domain_object import DomainObject
def get_slice_list(shape, axes): def get_slice_list(shape, axes):
...@@ -74,67 +72,3 @@ def cast_iseq_to_tuple(seq): ...@@ -74,67 +72,3 @@ def cast_iseq_to_tuple(seq):
if np.isscalar(seq): if np.isscalar(seq):
return (int(seq),) return (int(seq),)
return tuple(int(item) for item in seq) return tuple(int(item) for item in seq)
def bincount_axis(obj, minlength=None, weights=None, axis=None):
if minlength is not None:
length = max(np.amax(obj) + 1, minlength)
else:
length = np.amax(obj) + 1
if obj.shape == ():
raise ValueError("object of too small depth for desired array")
data = obj
# if present, parse the axis keyword and transpose/reorder self.data
# such that all affected axes follow each other. Only if they are in a
# sequence flattening will be possible
if axis is not None:
# do the reordering
ndim = len(obj.shape)
axis = sorted(cast_iseq_to_tuple(axis))
reordering = [x for x in range(ndim) if x not in axis]
reordering += axis
data = np.transpose(data, reordering)
if weights is not None:
weights = np.transpose(weights, reordering)
reord_axis = list(range(ndim-len(axis), ndim))
# semi-flatten the dimensions in `axis`, i.e. after reordering
# the last ones.
semi_flat_dim = reduce(lambda x, y: x*y,
data.shape[ndim-len(reord_axis):])
flat_shape = data.shape[:ndim-len(reord_axis)] + (semi_flat_dim, )
else:
flat_shape = (reduce(lambda x, y: x*y, data.shape), )
data = np.ascontiguousarray(data.reshape(flat_shape))
if weights is not None:
weights = np.ascontiguousarray(weights.reshape(flat_shape))
# compute the local bincount results
# -> prepare the local result array
result_dtype = np.int if weights is None else np.float
local_counts = np.empty(flat_shape[:-1] + (length, ), dtype=result_dtype)
# iterate over all entries in the surviving axes and compute the local
# bincounts
for slice_list in get_slice_list(flat_shape, axes=(len(flat_shape)-1,)):
current_weights = None if weights is None else weights[slice_list]
local_counts[slice_list] = np.bincount(data[slice_list],
weights=current_weights,
minlength=length)
# restore the original ordering
# place the bincount stuff at the location of the first `axis` entry
if axis is not None:
# axis has been sorted above
insert_position = axis[0]
new_ndim = len(local_counts.shape)
return_order = (list(range(0, insert_position)) +
[new_ndim-1, ] +
list(range(insert_position, new_ndim-1)))
local_counts = np.ascontiguousarray(
local_counts.transpose(return_order))
return local_counts
...@@ -35,9 +35,7 @@ class DiagonalOperator(EndomorphicOperator): ...@@ -35,9 +35,7 @@ class DiagonalOperator(EndomorphicOperator):
Parameters Parameters
---------- ----------
domain : tuple of DomainObjects, i.e. Spaces and FieldTypes diagonal : Field
The domain on which the Operator's input Field lives.
diagonal : {scalar, list, array, Field}
The diagonal entries of the operator. The diagonal entries of the operator.
copy : boolean copy : boolean
Internal copy of the diagonal (default: True) Internal copy of the diagonal (default: True)
...@@ -68,15 +66,14 @@ class DiagonalOperator(EndomorphicOperator): ...@@ -68,15 +66,14 @@ class DiagonalOperator(EndomorphicOperator):
# ---Overwritten properties and methods--- # ---Overwritten properties and methods---
def __init__(self, domain=(), diagonal=None, copy=True, def __init__(self, diagonal, copy=True, default_spaces=None):
default_spaces=None):
super(DiagonalOperator, self).__init__(default_spaces) super(DiagonalOperator, self).__init__(default_spaces)
self._domain = DomainTuple.make(domain) if not isinstance(diagonal, Field):
raise TypeError("Field object required")
self._diagonal = diagonal if not copy else diagonal.copy()
self._self_adjoint = None self._self_adjoint = None
self._unitary = None self._unitary = None
self.set_diagonal(diagonal=diagonal, copy=copy)
def _times(self, x, spaces): def _times(self, x, spaces):
return self._times_helper(x, spaces, operation=lambda z: z.__mul__) return self._times_helper(x, spaces, operation=lambda z: z.__mul__)
...@@ -119,13 +116,13 @@ class DiagonalOperator(EndomorphicOperator): ...@@ -119,13 +116,13 @@ class DiagonalOperator(EndomorphicOperator):
The inverse of the diagonal of the Operator. The inverse of the diagonal of the Operator.
""" """
return 1./self.diagonal(copy=False) return 1./self._diagonal
# ---Mandatory properties and methods--- # ---Mandatory properties and methods---
@property @property
def domain(self): def domain(self):
return self._domain return self._diagonal.domain
@property @property
def self_adjoint(self): def self_adjoint(self):
...@@ -144,30 +141,6 @@ class DiagonalOperator(EndomorphicOperator): ...@@ -144,30 +141,6 @@ class DiagonalOperator(EndomorphicOperator):
# ---Added properties and methods--- # ---Added properties and methods---
def set_diagonal(self, diagonal, copy=True):
""" Sets the diagonal of the Operator.
Parameters
----------
diagonal : {scalar, list, array, Field}
The diagonal entries of the operator.
copy : boolean
Specifies if a copy of the input shall be made (default: True).
"""
# use the casting functionality from Field to process `diagonal`
f = Field(domain=self.domain, val=diagonal, copy=copy)
# Reset the self_adjoint property:
self._self_adjoint = None
# Reset the unitarity property
self._unitary = None
# store the diagonal-field
self._diagonal = f
def _times_helper(self, x, spaces, operation): def _times_helper(self, x, spaces, operation):
# if the domain matches directly # if the domain matches directly
# -> multiply the fields directly # -> multiply the fields directly
......
...@@ -59,8 +59,8 @@ class ResponseOperator(LinearOperator): ...@@ -59,8 +59,8 @@ class ResponseOperator(LinearOperator):
kernel_smoothing = [FFTSmoothingOperator(self._domain[x], sigma[x]) kernel_smoothing = [FFTSmoothingOperator(self._domain[x], sigma[x])
for x in range(nsigma)] for x in range(nsigma)]
kernel_exposure = [DiagonalOperator(self._domain[x], kernel_exposure = [DiagonalOperator(Field(self._domain[x],exposure[x]))
diagonal=exposure[x]) for x in range(nsigma)] for x in range(nsigma)]
self._composed_kernel = ComposedOperator(kernel_smoothing) self._composed_kernel = ComposedOperator(kernel_smoothing)
self._composed_exposure = ComposedOperator(kernel_exposure) self._composed_exposure = ComposedOperator(kernel_exposure)
......
...@@ -21,7 +21,7 @@ class Test_Minimizers(unittest.TestCase): ...@@ -21,7 +21,7 @@ class Test_Minimizers(unittest.TestCase):
starting_point = ift.Field.from_random('normal', domain=space)*10 starting_point = ift.Field.from_random('normal', domain=space)*10
covariance_diagonal = ift.Field.from_random( covariance_diagonal = ift.Field.from_random(
'uniform', domain=space) + 0.5 'uniform', domain=space) + 0.5
covariance = ift.DiagonalOperator(space, diagonal=covariance_diagonal) covariance = ift.DiagonalOperator(covariance_diagonal)
required_result = ift.Field(space, val=1.) required_result = ift.Field(space, val=1.)
IC = ift.DefaultIterationController(tol_abs_gradnorm=1e-5) IC = ift.DefaultIterationController(tol_abs_gradnorm=1e-5)
......
...@@ -20,8 +20,8 @@ class ComposedOperator_Tests(unittest.TestCase): ...@@ -20,8 +20,8 @@ class ComposedOperator_Tests(unittest.TestCase):
def test_property(self, space1, space2): def test_property(self, space1, space2):
rand1 = Field.from_random('normal', domain=space1) rand1 = Field.from_random('normal', domain=space1)
rand2 = Field.from_random('normal', domain=space2) rand2 = Field.from_random('normal', domain=space2)
op1 = DiagonalOperator(space1, diagonal=rand1) op1 = DiagonalOperator(rand1)
op2 = DiagonalOperator(space2, diagonal=rand2) op2 = DiagonalOperator(rand2)
op = ComposedOperator((op1, op2)) op = ComposedOperator((op1, op2))
if op.domain != (op1.domain[0], op2.domain[0]): if op.domain != (op1.domain[0], op2.domain[0]):
raise TypeError raise TypeError
...@@ -32,8 +32,8 @@ class ComposedOperator_Tests(unittest.TestCase): ...@@ -32,8 +32,8 @@ class ComposedOperator_Tests(unittest.TestCase):
def test_times_adjoint_times(self, space1, space2): def test_times_adjoint_times(self, space1, space2):
diag1 = Field.from_random('normal', domain=space1) diag1 = Field.from_random('normal', domain=space1)
diag2 = Field.from_random('normal', domain=space2) diag2 = Field.from_random('normal', domain=space2)
op1 = DiagonalOperator(space1, diagonal=diag1) op1 = DiagonalOperator(diag1)
op2 = DiagonalOperator(space2, diagonal=diag2) op2 = DiagonalOperator(diag2)
op = ComposedOperator((op1, op2)) op = ComposedOperator((op1, op2))
...@@ -48,8 +48,8 @@ class ComposedOperator_Tests(unittest.TestCase): ...@@ -48,8 +48,8 @@ class ComposedOperator_Tests(unittest.TestCase):
def test_times_inverse_times(self, space1, space2): def test_times_inverse_times(self, space1, space2):
diag1 = Field.from_random('normal', domain=space1) diag1 = Field.from_random('normal', domain=space1)
diag2 = Field.from_random('normal', domain=space2) diag2 = Field.from_random('normal', domain=space2)
op1 = DiagonalOperator(space1, diagonal=diag1) op1 = DiagonalOperator(diag1)
op2 = DiagonalOperator(space2, diagonal=diag2) op2 = DiagonalOperator(diag2)
op = ComposedOperator((op1, op2)) op = ComposedOperator((op1, op2))
......
...@@ -20,7 +20,7 @@ class DiagonalOperator_Tests(unittest.TestCase): ...@@ -20,7 +20,7 @@ class DiagonalOperator_Tests(unittest.TestCase):
@expand(product(spaces, [True, False])) @expand(product(spaces, [True, False]))
def test_property(self, space, copy): def test_property(self, space, copy):
diag = Field.from_random('normal', domain=space) diag = Field.from_random('normal', domain=space)
D = DiagonalOperator(space, diagonal=diag) D = DiagonalOperator(diag)
if D.domain[0] != space: if D.domain[0] != space:
raise TypeError raise TypeError
if D.unitary != False: if D.unitary != False:
...@@ -33,7 +33,7 @@ class DiagonalOperator_Tests(unittest.TestCase): ...@@ -33,7 +33,7 @@ class DiagonalOperator_Tests(unittest.TestCase):
rand1 = Field.from_random('normal', domain=space) rand1 = Field.from_random('normal', domain=space)
rand2 = Field.from_random('normal', domain=space) rand2 = Field.from_random('normal', domain=space)
diag = Field.from_random('normal', domain=space) diag = Field.from_random('normal', domain=space)
D = DiagonalOperator(space, diagonal=diag, copy=copy) D = DiagonalOperator(diag, copy=copy)
tt1 = rand1.vdot(D.times(rand2)) tt1 = rand1.vdot(D.times(rand2))
tt2 = rand2.vdot(D.times(rand1)) tt2 = rand2.vdot(D.times(rand1))
assert_approx_equal(tt1, tt2) assert_approx_equal(tt1, tt2)
...@@ -42,7 +42,7 @@ class DiagonalOperator_Tests(unittest.TestCase): ...@@ -42,7 +42,7 @@ class DiagonalOperator_Tests(unittest.TestCase):
def test_times_inverse(self, space, copy): def test_times_inverse(self, space, copy):
rand1 = Field.from_random('normal', domain=space) rand1 = Field.from_random('normal', domain=space)
diag = Field.from_random('normal', domain=space) diag = Field.from_random('normal', domain=space)
D = DiagonalOperator(space, diagonal=diag, copy=copy) D = DiagonalOperator(diag, copy=copy)
tt1 = D.times(D.inverse_times(rand1)) tt1 = D.times(D.inverse_times(rand1))
assert_allclose(rand1.val, tt1.val) assert_allclose(rand1.val, tt1.val)
...@@ -50,7 +50,7 @@ class DiagonalOperator_Tests(unittest.TestCase): ...@@ -50,7 +50,7 @@ class DiagonalOperator_Tests(unittest.TestCase):
def test_times(self, space, copy): def test_times(self, space, copy):
rand1 = Field.from_random('normal', domain=space) rand1 = Field.from_random('normal', domain=space)
diag = Field.from_random('normal', domain=space) diag = Field.from_random('normal', domain=space)
D = DiagonalOperator(space, diagonal=diag, copy=copy) D = DiagonalOperator(diag, copy=copy)
tt = D.times(rand1) tt = D.times(rand1)
assert_equal(tt.domain[0], space) assert_equal(tt.domain[0], space)
...@@ -58,7 +58,7 @@ class DiagonalOperator_Tests(unittest.TestCase): ...@@ -58,7 +58,7 @@ class DiagonalOperator_Tests(unittest.TestCase):
def test_adjoint_times(self, space, copy): def test_adjoint_times(self, space, copy):
rand1 = Field.from_random('normal', domain=space) rand1 = Field.from_random('normal', domain=space)
diag = Field.from_random('normal', domain=space) diag = Field.from_random('normal', domain=space)
D = DiagonalOperator(space, diagonal=diag, copy=copy) D = DiagonalOperator(diag, copy=copy)
tt = D.adjoint_times(rand1) tt = D.adjoint_times(rand1)
assert_equal(tt.domain[0], space) assert_equal(tt.domain[0], space)
...@@ -66,7 +66,7 @@ class DiagonalOperator_Tests(unittest.TestCase): ...@@ -66,7 +66,7 @@ class DiagonalOperator_Tests(unittest.TestCase):
def test_inverse_times(self, space, copy): def test_inverse_times(self, space, copy):
rand1 = Field.from_random('normal', domain=space) rand1 = Field.from_random('normal', domain=space)
diag = Field.from_random('normal', domain=space) diag = Field.from_random('normal', domain=space)
D = DiagonalOperator(space, diagonal=diag, copy=copy) D = DiagonalOperator(diag, copy=copy)
tt = D.inverse_times(rand1) tt = D.inverse_times(rand1)
assert_equal(tt.domain[0], space) assert_equal(tt.domain[0], space)
...@@ -74,20 +74,20 @@ class DiagonalOperator_Tests(unittest.TestCase): ...@@ -74,20 +74,20 @@ class DiagonalOperator_Tests(unittest.TestCase):
def test_adjoint_inverse_times(self, space, copy): def test_adjoint_inverse_times(self, space, copy):
rand1 = Field.from_random('normal', domain=space) rand1 = Field.from_random('normal', domain=space)
diag = Field.from_random('normal', domain=space) diag = Field.from_random('normal', domain=space)
D = DiagonalOperator(space, diagonal=diag, copy=copy) D = DiagonalOperator(diag, copy=copy)
tt = D.adjoint_inverse_times(rand1) tt = D.adjoint_inverse_times(rand1)
assert_equal(tt.domain[0], space) assert_equal(tt.domain[0], space)
@expand(product(spaces, [True, False])) @expand(product(spaces, [True, False]))
def test_diagonal(self, space, copy): def test_diagonal(self, space, copy):
diag = Field.from_random('normal', domain=space) diag = Field.from_random('normal', domain=space)
D = DiagonalOperator(space, diagonal=diag, copy=copy) D = DiagonalOperator(diag, copy=copy)
diag_op = D.diagonal() diag_op = D.diagonal()
assert_allclose(diag.val, diag_op.val) assert_allclose(diag.val, diag_op.val)
@expand(product(spaces, [True, False])) @expand(product(spaces, [True, False]))
def test_inverse(self, space, copy): def test_inverse(self, space, copy):
diag = Field.from_random('normal', domain=space) diag = Field.from_random('normal', domain=space)
D = DiagonalOperator(space, diagonal=diag, copy=copy) D = DiagonalOperator(diag, copy=copy)
diag_op = D.inverse_diagonal() diag_op = D.inverse_diagonal()
assert_allclose(1./diag.val, diag_op.val) assert_allclose(1./diag.val, diag_op.val)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment