Commit c2482aa1 authored by Theo Steininger's avatar Theo Steininger

Added default_spaces property to Operator classes.

parent 8ad1b902
Pipeline #12218 passed with stages
in 11 minutes and 12 seconds
......@@ -107,7 +107,7 @@ if __name__ == "__main__":
# callback=distance_measure,
# max_history_length=3)
m0 = Field(s_space, val=1)
m0 = Field(s_space, val=1.)
energy = WienerFilterEnergy(position=m0, D=D, j=j)
......
......@@ -19,6 +19,7 @@
import numpy as np
from itertools import product
def get_slice_list(shape, axes):
"""
Helper function which generates slice list(s) to traverse over all
......@@ -65,8 +66,7 @@ def get_slice_list(shape, axes):
return
def cast_axis_to_tuple(axis, length):
def cast_axis_to_tuple(axis, length=None):
if axis is None:
return None
try:
......@@ -78,16 +78,17 @@ def cast_axis_to_tuple(axis, length):
raise TypeError(
"Could not convert axis-input to tuple of ints")
# shift negative indices to positive ones
axis = tuple(item if (item >= 0) else (item + length) for item in axis)
if length is not None:
# shift negative indices to positive ones
axis = tuple(item if (item >= 0) else (item + length) for item in axis)
# Deactivated this, in order to allow for the ComposedOperator
# remove duplicate entries
# axis = tuple(set(axis))
# Deactivated this, in order to allow for the ComposedOperator
# remove duplicate entries
# axis = tuple(set(axis))
# assert that all entries are elements in [0, length]
for elem in axis:
assert (0 <= elem < length)
# assert that all entries are elements in [0, length]
for elem in axis:
assert (0 <= elem < length)
return axis
......
......@@ -30,9 +30,10 @@ class DiagonalOperator(EndomorphicOperator):
# ---Overwritten properties and methods---
def __init__(self, domain=(),
diagonal=None, bare=False, copy=True,
distribution_strategy=None):
def __init__(self, domain=(), diagonal=None, bare=False, copy=True,
distribution_strategy=None, default_spaces=None):
super(DiagonalOperator, self).__init__(default_spaces)
self._domain = self._parse_domain(domain)
if distribution_strategy is None:
......
......@@ -112,7 +112,8 @@ class FFTOperator(LinearOperator):
# ---Overwritten properties and methods---
def __init__(self, domain, target=None, module=None,
domain_dtype=None, target_dtype=None):
domain_dtype=None, target_dtype=None, default_spaces=None):
super(FFTOperator, self).__init__(default_spaces)
# Initialize domain and target
......
......@@ -29,6 +29,7 @@ class InvertibleOperatorMixin(object):
else:
self.__inverter = ConjugateGradient(
preconditioner=self.__preconditioner)
super(InvertibleOperatorMixin, self).__init__(*args, **kwargs)
def _times(self, x, spaces, x0=None):
if x0 is None:
......
......@@ -27,8 +27,8 @@ import nifty.nifty_utilities as utilities
class LinearOperator(Loggable, object):
__metaclass__ = NiftyMeta
def __init__(self):
pass
def __init__(self, default_spaces=None):
self.default_spaces = default_spaces
def _parse_domain(self, domain):
return utilities.parse_domain(domain)
......@@ -45,6 +45,14 @@ class LinearOperator(Loggable, object):
def unitary(self):
raise NotImplementedError
@property
def default_spaces(self):
return self._default_spaces
@default_spaces.setter
def default_spaces(self, spaces):
self._default_spaces = utilities.cast_axis_to_tuple(spaces)
def __call__(self, *args, **kwargs):
return self.times(*args, **kwargs)
......@@ -127,6 +135,9 @@ class LinearOperator(Loggable, object):
raise ValueError(
"supplied object is not a `nifty.Field`.")
if spaces is None:
spaces = self.default_spaces
# sanitize the `spaces` and `types` input
spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))
......
......@@ -27,7 +27,9 @@ class ProjectionOperator(EndomorphicOperator):
# ---Overwritten properties and methods---
def __init__(self, projection_field):
def __init__(self, projection_field, default_spaces=None):
super(ProjectionOperator, self).__init__(default_spaces)
if not isinstance(projection_field, Field):
raise TypeError("The projection_field must be a NIFTy-Field"
"instance.")
......
......@@ -26,7 +26,7 @@ class PropagatorOperator(InvertibleOperatorMixin, EndomorphicOperator):
# ---Overwritten properties and methods---
def __init__(self, S=None, M=None, R=None, N=None, inverter=None,
preconditioner=None):
preconditioner=None, default_spaces=None):
"""
Sets the standard operator properties and `codomain`, `_A1`, `_A2`,
and `RN` if required.
......@@ -66,7 +66,8 @@ class PropagatorOperator(InvertibleOperatorMixin, EndomorphicOperator):
preconditioner = self._S_times
super(PropagatorOperator, self).__init__(inverter=inverter,
preconditioner=preconditioner)
preconditioner=preconditioner,
default_spaces=default_spaces)
# ---Mandatory properties and methods---
......
......@@ -27,7 +27,9 @@ from d2o import STRATEGIES
class SmoothingOperator(EndomorphicOperator):
# ---Overwritten properties and methods---
def __init__(self, domain=(), sigma=0, log_distances=False):
def __init__(self, domain=(), sigma=0, log_distances=False,
default_spaces=None):
super(SmoothingOperator, self).__init__(default_spaces)
self._domain = self._parse_domain(domain)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment