Commit 12e3d597 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

simplify DiagonalOperator

parent 643d308b
......@@ -3,6 +3,71 @@
import numpy as np
from numpy import ndarray as data_object
from numpy import full, empty, sqrt, ones, zeros, vdot, abs, bincount
from ..nifty_utilities import cast_iseq_to_tuple, get_slice_list
from functools import reduce
def from_object(object, dtype=None, copy=True):
return np.array(object, dtype=dtype, copy=copy)
def bincount_axis(obj, minlength=None, weights=None, axis=None):
if minlength is not None:
length = max(np.amax(obj) + 1, minlength)
else:
length = np.amax(obj) + 1
if obj.shape == ():
raise ValueError("object of too small depth for desired array")
data = obj
# if present, parse the axis keyword and transpose/reorder self.data
# such that all affected axes follow each other. Only if they are in a
# sequence flattening will be possible
if axis is not None:
# do the reordering
ndim = len(obj.shape)
axis = sorted(cast_iseq_to_tuple(axis))
reordering = [x for x in range(ndim) if x not in axis]
reordering += axis
data = np.transpose(data, reordering)
if weights is not None:
weights = np.transpose(weights, reordering)
reord_axis = list(range(ndim-len(axis), ndim))
# semi-flatten the dimensions in `axis`, i.e. after reordering
# the last ones.
semi_flat_dim = reduce(lambda x, y: x*y,
data.shape[ndim-len(reord_axis):])
flat_shape = data.shape[:ndim-len(reord_axis)] + (semi_flat_dim, )
else:
flat_shape = (reduce(lambda x, y: x*y, data.shape), )
data = np.ascontiguousarray(data.reshape(flat_shape))
if weights is not None:
weights = np.ascontiguousarray(weights.reshape(flat_shape))
# compute the local bincount results
# -> prepare the local result array
result_dtype = np.int if weights is None else np.float
local_counts = np.empty(flat_shape[:-1] + (length, ), dtype=result_dtype)
# iterate over all entries in the surviving axes and compute the local
# bincounts
for slice_list in get_slice_list(flat_shape, axes=(len(flat_shape)-1,)):
current_weights = None if weights is None else weights[slice_list]
local_counts[slice_list] = np.bincount(data[slice_list],
weights=current_weights,
minlength=length)
# restore the original ordering
# place the bincount stuff at the location of the first `axis` entry
if axis is not None:
# axis has been sorted above
insert_position = axis[0]
new_ndim = len(local_counts.shape)
return_order = (list(range(0, insert_position)) +
[new_ndim-1, ] +
list(range(insert_position, new_ndim-1)))
local_counts = np.ascontiguousarray(
local_counts.transpose(return_order))
return local_counts
......@@ -231,8 +231,8 @@ class Field(object):
new_pindex_shape[ax] = pindex.shape[i]
pindex = np.broadcast_to(pindex.reshape(new_pindex_shape), field.shape)
power_spectrum = utilities.bincount_axis(pindex, weights=field.val,
axis=axes)
power_spectrum = dobj.bincount_axis(pindex, weights=field.val,
axis=axes)
new_rho_shape = [1] * len(power_spectrum.shape)
new_rho_shape[axes[0]] = len(power_domain.rho)
power_spectrum /= power_domain.rho.reshape(new_rho_shape)
......@@ -510,7 +510,7 @@ class Field(object):
# create a diagonal operator which is capable of taking care of the
# axes-matching
from .operators.diagonal_operator import DiagonalOperator
diag = DiagonalOperator(y.domain, y.conjugate(), copy=False)
diag = DiagonalOperator(y.conjugate(), copy=False)
dotted = diag(x, spaces=spaces)
return fct*dotted.sum(spaces=spaces)
......
......@@ -19,8 +19,6 @@
from builtins import next, range
import numpy as np
from itertools import product
from functools import reduce
from .domain_object import DomainObject
def get_slice_list(shape, axes):
......@@ -74,67 +72,3 @@ def cast_iseq_to_tuple(seq):
if np.isscalar(seq):
return (int(seq),)
return tuple(int(item) for item in seq)
def bincount_axis(obj, minlength=None, weights=None, axis=None):
if minlength is not None:
length = max(np.amax(obj) + 1, minlength)
else:
length = np.amax(obj) + 1
if obj.shape == ():
raise ValueError("object of too small depth for desired array")
data = obj
# if present, parse the axis keyword and transpose/reorder self.data
# such that all affected axes follow each other. Only if they are in a
# sequence flattening will be possible
if axis is not None:
# do the reordering
ndim = len(obj.shape)
axis = sorted(cast_iseq_to_tuple(axis))
reordering = [x for x in range(ndim) if x not in axis]
reordering += axis
data = np.transpose(data, reordering)
if weights is not None:
weights = np.transpose(weights, reordering)
reord_axis = list(range(ndim-len(axis), ndim))
# semi-flatten the dimensions in `axis`, i.e. after reordering
# the last ones.
semi_flat_dim = reduce(lambda x, y: x*y,
data.shape[ndim-len(reord_axis):])
flat_shape = data.shape[:ndim-len(reord_axis)] + (semi_flat_dim, )
else:
flat_shape = (reduce(lambda x, y: x*y, data.shape), )
data = np.ascontiguousarray(data.reshape(flat_shape))
if weights is not None:
weights = np.ascontiguousarray(weights.reshape(flat_shape))
# compute the local bincount results
# -> prepare the local result array
result_dtype = np.int if weights is None else np.float
local_counts = np.empty(flat_shape[:-1] + (length, ), dtype=result_dtype)
# iterate over all entries in the surviving axes and compute the local
# bincounts
for slice_list in get_slice_list(flat_shape, axes=(len(flat_shape)-1,)):
current_weights = None if weights is None else weights[slice_list]
local_counts[slice_list] = np.bincount(data[slice_list],
weights=current_weights,
minlength=length)
# restore the original ordering
# place the bincount stuff at the location of the first `axis` entry
if axis is not None:
# axis has been sorted above
insert_position = axis[0]
new_ndim = len(local_counts.shape)
return_order = (list(range(0, insert_position)) +
[new_ndim-1, ] +
list(range(insert_position, new_ndim-1)))
local_counts = np.ascontiguousarray(
local_counts.transpose(return_order))
return local_counts
......@@ -35,9 +35,7 @@ class DiagonalOperator(EndomorphicOperator):
Parameters
----------
domain : tuple of DomainObjects, i.e. Spaces and FieldTypes
The domain on which the Operator's input Field lives.
diagonal : {scalar, list, array, Field}
diagonal : Field
The diagonal entries of the operator.
copy : boolean
Internal copy of the diagonal (default: True)
......@@ -68,15 +66,14 @@ class DiagonalOperator(EndomorphicOperator):
# ---Overwritten properties and methods---
def __init__(self, domain=(), diagonal=None, copy=True,
default_spaces=None):
def __init__(self, diagonal, copy=True, default_spaces=None):
super(DiagonalOperator, self).__init__(default_spaces)
self._domain = DomainTuple.make(domain)
if not isinstance(diagonal, Field):
raise TypeError("Field object required")
self._diagonal = diagonal if not copy else diagonal.copy()
self._self_adjoint = None
self._unitary = None
self.set_diagonal(diagonal=diagonal, copy=copy)
def _times(self, x, spaces):
return self._times_helper(x, spaces, operation=lambda z: z.__mul__)
......@@ -119,13 +116,13 @@ class DiagonalOperator(EndomorphicOperator):
The inverse of the diagonal of the Operator.
"""
return 1./self.diagonal(copy=False)
return 1./self._diagonal
# ---Mandatory properties and methods---
@property
def domain(self):
return self._domain
return self._diagonal.domain
@property
def self_adjoint(self):
......@@ -144,30 +141,6 @@ class DiagonalOperator(EndomorphicOperator):
# ---Added properties and methods---
def set_diagonal(self, diagonal, copy=True):
""" Sets the diagonal of the Operator.
Parameters
----------
diagonal : {scalar, list, array, Field}
The diagonal entries of the operator.
copy : boolean
Specifies if a copy of the input shall be made (default: True).
"""
# use the casting functionality from Field to process `diagonal`
f = Field(domain=self.domain, val=diagonal, copy=copy)
# Reset the self_adjoint property:
self._self_adjoint = None
# Reset the unitarity property
self._unitary = None
# store the diagonal-field
self._diagonal = f
def _times_helper(self, x, spaces, operation):
# if the domain matches directly
# -> multiply the fields directly
......
......@@ -59,8 +59,8 @@ class ResponseOperator(LinearOperator):
kernel_smoothing = [FFTSmoothingOperator(self._domain[x], sigma[x])
for x in range(nsigma)]
kernel_exposure = [DiagonalOperator(self._domain[x],
diagonal=exposure[x]) for x in range(nsigma)]
kernel_exposure = [DiagonalOperator(Field(self._domain[x],exposure[x]))
for x in range(nsigma)]
self._composed_kernel = ComposedOperator(kernel_smoothing)
self._composed_exposure = ComposedOperator(kernel_exposure)
......
......@@ -21,7 +21,7 @@ class Test_Minimizers(unittest.TestCase):
starting_point = ift.Field.from_random('normal', domain=space)*10
covariance_diagonal = ift.Field.from_random(
'uniform', domain=space) + 0.5
covariance = ift.DiagonalOperator(space, diagonal=covariance_diagonal)
covariance = ift.DiagonalOperator(covariance_diagonal)
required_result = ift.Field(space, val=1.)
IC = ift.DefaultIterationController(tol_abs_gradnorm=1e-5)
......
......@@ -20,8 +20,8 @@ class ComposedOperator_Tests(unittest.TestCase):
def test_property(self, space1, space2):
rand1 = Field.from_random('normal', domain=space1)
rand2 = Field.from_random('normal', domain=space2)
op1 = DiagonalOperator(space1, diagonal=rand1)
op2 = DiagonalOperator(space2, diagonal=rand2)
op1 = DiagonalOperator(rand1)
op2 = DiagonalOperator(rand2)
op = ComposedOperator((op1, op2))
if op.domain != (op1.domain[0], op2.domain[0]):
raise TypeError
......@@ -32,8 +32,8 @@ class ComposedOperator_Tests(unittest.TestCase):
def test_times_adjoint_times(self, space1, space2):
diag1 = Field.from_random('normal', domain=space1)
diag2 = Field.from_random('normal', domain=space2)
op1 = DiagonalOperator(space1, diagonal=diag1)
op2 = DiagonalOperator(space2, diagonal=diag2)
op1 = DiagonalOperator(diag1)
op2 = DiagonalOperator(diag2)
op = ComposedOperator((op1, op2))
......@@ -48,8 +48,8 @@ class ComposedOperator_Tests(unittest.TestCase):
def test_times_inverse_times(self, space1, space2):
diag1 = Field.from_random('normal', domain=space1)
diag2 = Field.from_random('normal', domain=space2)
op1 = DiagonalOperator(space1, diagonal=diag1)
op2 = DiagonalOperator(space2, diagonal=diag2)
op1 = DiagonalOperator(diag1)
op2 = DiagonalOperator(diag2)
op = ComposedOperator((op1, op2))
......
......@@ -20,7 +20,7 @@ class DiagonalOperator_Tests(unittest.TestCase):
@expand(product(spaces, [True, False]))
def test_property(self, space, copy):
diag = Field.from_random('normal', domain=space)
D = DiagonalOperator(space, diagonal=diag)
D = DiagonalOperator(diag)
if D.domain[0] != space:
raise TypeError
if D.unitary != False:
......@@ -33,7 +33,7 @@ class DiagonalOperator_Tests(unittest.TestCase):
rand1 = Field.from_random('normal', domain=space)
rand2 = Field.from_random('normal', domain=space)
diag = Field.from_random('normal', domain=space)
D = DiagonalOperator(space, diagonal=diag, copy=copy)
D = DiagonalOperator(diag, copy=copy)
tt1 = rand1.vdot(D.times(rand2))
tt2 = rand2.vdot(D.times(rand1))
assert_approx_equal(tt1, tt2)
......@@ -42,7 +42,7 @@ class DiagonalOperator_Tests(unittest.TestCase):
def test_times_inverse(self, space, copy):
rand1 = Field.from_random('normal', domain=space)
diag = Field.from_random('normal', domain=space)
D = DiagonalOperator(space, diagonal=diag, copy=copy)
D = DiagonalOperator(diag, copy=copy)
tt1 = D.times(D.inverse_times(rand1))
assert_allclose(rand1.val, tt1.val)
......@@ -50,7 +50,7 @@ class DiagonalOperator_Tests(unittest.TestCase):
def test_times(self, space, copy):
rand1 = Field.from_random('normal', domain=space)
diag = Field.from_random('normal', domain=space)
D = DiagonalOperator(space, diagonal=diag, copy=copy)
D = DiagonalOperator(diag, copy=copy)
tt = D.times(rand1)
assert_equal(tt.domain[0], space)
......@@ -58,7 +58,7 @@ class DiagonalOperator_Tests(unittest.TestCase):
def test_adjoint_times(self, space, copy):
rand1 = Field.from_random('normal', domain=space)
diag = Field.from_random('normal', domain=space)
D = DiagonalOperator(space, diagonal=diag, copy=copy)
D = DiagonalOperator(diag, copy=copy)
tt = D.adjoint_times(rand1)
assert_equal(tt.domain[0], space)
......@@ -66,7 +66,7 @@ class DiagonalOperator_Tests(unittest.TestCase):
def test_inverse_times(self, space, copy):
rand1 = Field.from_random('normal', domain=space)
diag = Field.from_random('normal', domain=space)
D = DiagonalOperator(space, diagonal=diag, copy=copy)
D = DiagonalOperator(diag, copy=copy)
tt = D.inverse_times(rand1)
assert_equal(tt.domain[0], space)
......@@ -74,20 +74,20 @@ class DiagonalOperator_Tests(unittest.TestCase):
def test_adjoint_inverse_times(self, space, copy):
rand1 = Field.from_random('normal', domain=space)
diag = Field.from_random('normal', domain=space)
D = DiagonalOperator(space, diagonal=diag, copy=copy)
D = DiagonalOperator(diag, copy=copy)
tt = D.adjoint_inverse_times(rand1)
assert_equal(tt.domain[0], space)
@expand(product(spaces, [True, False]))
def test_diagonal(self, space, copy):
diag = Field.from_random('normal', domain=space)
D = DiagonalOperator(space, diagonal=diag, copy=copy)
D = DiagonalOperator(diag, copy=copy)
diag_op = D.diagonal()
assert_allclose(diag.val, diag_op.val)
@expand(product(spaces, [True, False]))
def test_inverse(self, space, copy):
diag = Field.from_random('normal', domain=space)
D = DiagonalOperator(space, diagonal=diag, copy=copy)
D = DiagonalOperator(diag, copy=copy)
diag_op = D.inverse_diagonal()
assert_allclose(1./diag.val, diag_op.val)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment