Commit c2482aa1 by Theo Steininger

Added default_spaces property to Operator classes.

Pipeline #12218 passed with stages
in 11 minutes and 12 seconds
 ... ... @@ -107,7 +107,7 @@ if __name__ == "__main__": # callback=distance_measure, # max_history_length=3) m0 = Field(s_space, val=1) m0 = Field(s_space, val=1.) energy = WienerFilterEnergy(position=m0, D=D, j=j) ... ...
 ... ... @@ -19,6 +19,7 @@ import numpy as np from itertools import product def get_slice_list(shape, axes): """ Helper function which generates slice list(s) to traverse over all ... ... @@ -65,8 +66,7 @@ def get_slice_list(shape, axes): return def cast_axis_to_tuple(axis, length): def cast_axis_to_tuple(axis, length=None): if axis is None: return None try: ... ... @@ -78,16 +78,17 @@ def cast_axis_to_tuple(axis, length): raise TypeError( "Could not convert axis-input to tuple of ints") # shift negative indices to positive ones axis = tuple(item if (item >= 0) else (item + length) for item in axis) if length is not None: # shift negative indices to positive ones axis = tuple(item if (item >= 0) else (item + length) for item in axis) # Deactivated this, in order to allow for the ComposedOperator # remove duplicate entries # axis = tuple(set(axis)) # Deactivated this, in order to allow for the ComposedOperator # remove duplicate entries # axis = tuple(set(axis)) # assert that all entries are elements in [0, length] for elem in axis: assert (0 <= elem < length) # assert that all entries are elements in [0, length] for elem in axis: assert (0 <= elem < length) return axis ... ...
 ... ... @@ -30,9 +30,10 @@ class DiagonalOperator(EndomorphicOperator): # ---Overwritten properties and methods--- def __init__(self, domain=(), diagonal=None, bare=False, copy=True, distribution_strategy=None): def __init__(self, domain=(), diagonal=None, bare=False, copy=True, distribution_strategy=None, default_spaces=None): super(DiagonalOperator, self).__init__(default_spaces) self._domain = self._parse_domain(domain) if distribution_strategy is None: ... ...
 ... ... @@ -112,7 +112,8 @@ class FFTOperator(LinearOperator): # ---Overwritten properties and methods--- def __init__(self, domain, target=None, module=None, domain_dtype=None, target_dtype=None): domain_dtype=None, target_dtype=None, default_spaces=None): super(FFTOperator, self).__init__(default_spaces) # Initialize domain and target ... ...
 ... ... @@ -29,6 +29,7 @@ class InvertibleOperatorMixin(object): else: self.__inverter = ConjugateGradient( preconditioner=self.__preconditioner) super(InvertibleOperatorMixin, self).__init__(*args, **kwargs) def _times(self, x, spaces, x0=None): if x0 is None: ... ...
 ... ... @@ -27,8 +27,8 @@ import nifty.nifty_utilities as utilities class LinearOperator(Loggable, object): __metaclass__ = NiftyMeta def __init__(self): pass def __init__(self, default_spaces=None): self.default_spaces = default_spaces def _parse_domain(self, domain): return utilities.parse_domain(domain) ... ... @@ -45,6 +45,14 @@ class LinearOperator(Loggable, object): def unitary(self): raise NotImplementedError @property def default_spaces(self): return self._default_spaces @default_spaces.setter def default_spaces(self, spaces): self._default_spaces = utilities.cast_axis_to_tuple(spaces) def __call__(self, *args, **kwargs): return self.times(*args, **kwargs) ... ... @@ -127,6 +135,9 @@ class LinearOperator(Loggable, object): raise ValueError( "supplied object is not a `nifty.Field`.") if spaces is None: spaces = self.default_spaces # sanitize the `spaces` and `types` input spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain)) ... ...
 ... ... @@ -27,7 +27,9 @@ class ProjectionOperator(EndomorphicOperator): # ---Overwritten properties and methods--- def __init__(self, projection_field): def __init__(self, projection_field, default_spaces=None): super(ProjectionOperator, self).__init__(default_spaces) if not isinstance(projection_field, Field): raise TypeError("The projection_field must be a NIFTy-Field" "instance.") ... ...
 ... ... @@ -26,7 +26,7 @@ class PropagatorOperator(InvertibleOperatorMixin, EndomorphicOperator): # ---Overwritten properties and methods--- def __init__(self, S=None, M=None, R=None, N=None, inverter=None, preconditioner=None): preconditioner=None, default_spaces=None): """ Sets the standard operator properties and `codomain`, `_A1`, `_A2`, and `RN` if required. ... ... @@ -66,7 +66,8 @@ class PropagatorOperator(InvertibleOperatorMixin, EndomorphicOperator): preconditioner = self._S_times super(PropagatorOperator, self).__init__(inverter=inverter, preconditioner=preconditioner) preconditioner=preconditioner, default_spaces=default_spaces) # ---Mandatory properties and methods--- ... ...
 ... ... @@ -27,7 +27,9 @@ from d2o import STRATEGIES class SmoothingOperator(EndomorphicOperator): # ---Overwritten properties and methods--- def __init__(self, domain=(), sigma=0, log_distances=False): def __init__(self, domain=(), sigma=0, log_distances=False, default_spaces=None): super(SmoothingOperator, self).__init__(default_spaces) self._domain = self._parse_domain(domain) ... ...
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!