Commit 7b26ce5f authored by theos's avatar theos
Browse files

Finished implementation of DiagonalOperator

parent 4acfa8c4
......@@ -577,7 +577,7 @@ class Field(object):
new_field.__class__ = self.__class__
# copy domain, codomain and val
for key, value in self.__dict__.items():
if key != 'val':
if key != '_val':
new_field.__dict__[key] = value
else:
new_field.__dict__[key] = self.val.copy_empty()
......
......@@ -15,12 +15,11 @@ class DiagonalOperator(EndomorphicOperator):
# ---Overwritten properties and methods---
def __init__(self, domain=(), field_type=(), implemented=False,
def __init__(self, domain=(), field_type=(), implemented=True,
diagonal=None, bare=False, copy=True,
distribution_strategy=None):
super(DiagonalOperator, self).__init__(domain=domain,
field_type=field_type,
implemented=implemented)
field_type=field_type)
self._implemented = bool(implemented)
......@@ -30,61 +29,27 @@ class DiagonalOperator(EndomorphicOperator):
elif isinstance(diagonal, Field):
distribution_strategy = diagonal.distribution_strategy
self.distribution_strategy = self._parse_distribution_strategy(
self._distribution_strategy = self._parse_distribution_strategy(
distribution_strategy=distribution_strategy,
val=diagonal)
self.set_diagonal(diagonal=diagonal, bare=bare, copy=copy)
def _times(self, x, spaces, types):
# if the distribution_strategy of self is sub-slice compatible to
# the one of x, reshape the local data of self and apply it directly
active_axes = []
if spaces is None:
for axes in x.domain_axes:
active_axes += axes
else:
for space_index in spaces:
active_axes += x.domain_axes[space_index]
if types is None:
for axes in x.field_type_axes:
active_axes += axes
else:
for type_index in types:
active_axes += x.field_type_axes[type_index]
if x.val.get_axes_local_distribution_strategy(active_axes) == \
self.distribution_strategy:
local_data = self._diagonal.val.get_local_data(copy=False)
# check if domains match completely
# -> multiply directly
# check if axes_local_distribution_strategy matches.
# If yes, extract local data of self.diagonal and x and use numpy
# reshape.
# assert that indices in spaces and types are striktly increasing
# otherwise a wild transpose would be necessary
# build new shape (1,1,x,1,y,1,1,z)
# copy self.diagonal into new shape
# apply reshaped array to x
return self._times_helper(x, spaces, types,
operation=lambda z: z.__mul__)
def _adjoint_times(self, x, spaces, types):
pass
return self._times_helper(x, spaces, types,
operation=lambda z: z.adjoint().__mul__)
def _inverse_times(self, x, spaces, types):
pass
return self._times_helper(x, spaces, types,
operation=lambda z: z.__rdiv__)
def _adjoint_inverse_times(self, x, spaces, types):
pass
def _inverse_adjoint_times(self, x, spaces, types):
pass
return self._times_helper(x, spaces, types,
operation=lambda z: z.adjoint().__rdiv__)
def diagonal(self, bare=False, copy=True):
if bare:
......@@ -178,3 +143,49 @@ class DiagonalOperator(EndomorphicOperator):
# store the diagonal-field
self._diagonal = f
def _times_helper(self, x, spaces, types, operation):
# if the domain and field_type match directly
# -> multiply the fields directly
if x.domain == self.domain and x.field_type == self.field_type:
# here the actual multiplication takes place
return operation(self.diagonal(copy=False))(x)
# if the distribution_strategy of self is sub-slice compatible to
# the one of x, reshape the local data of self and apply it directly
active_axes = []
if spaces is None:
for axes in x.domain_axes:
active_axes += axes
else:
for space_index in spaces:
active_axes += x.domain_axes[space_index]
if types is None:
for axes in x.field_type_axes:
active_axes += axes
else:
for type_index in types:
active_axes += x.field_type_axes[type_index]
axes_local_distribution_strategy = \
x.val.get_axes_local_distribution_strategy(active_axes)
if axes_local_distribution_strategy == self.distribution_strategy:
local_diagonal = self._diagonal.val.get_local_data(copy=False)
else:
# create an array that is sub-slice compatible
redistr_diagonal_val = self._diagonal.val.copy(
distribution_strategy=axes_local_distribution_strategy)
local_diagonal = redistr_diagonal_val.get_local_data(copy=False)
reshaper = [x.shape[i] if i in active_axes else 1
for i in xrange(len(x.shape))]
reshaped_local_diagonal = np.reshape(local_diagonal, reshaper)
# here the actual multiplication takes place
local_result = operation(reshaped_local_diagonal)(
x.val.get_local_data(copy=False))
result_field = x.copy_empty(dtype=local_result.dtype)
result_field.val.set_local_data(local_result, copy=False)
return result_field
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment