Commit c9edd78f authored by Theo Steininger's avatar Theo Steininger

Added ComposedOperator class.

parent 943ce348
Pipeline #9524 failed with stages
in 18 minutes and 7 seconds
......@@ -234,8 +234,9 @@ def cast_axis_to_tuple(axis, length):
# shift negative indices to positive ones
axis = tuple(item if (item >= 0) else (item + length) for item in axis)
# Deactivated this, in order to allow for the ComposedOperator
# remove duplicate entries
axis = tuple(set(axis))
# axis = tuple(set(axis))
# assert that all entries are elements in [0, length]
for elem in axis:
......
......@@ -32,3 +32,5 @@ from smoothing_operator import SmoothingOperator
from fft_operator import *
from propagator_operator import PropagatorOperator
from composed_operator import ComposedOperator
# -*- coding: utf-8 -*-
from composed_operator import ComposedOperator
# -*- coding: utf-8 -*-
from nifty.operators.linear_operator import LinearOperator
class ComposedOperator(LinearOperator):
# ---Overwritten properties and methods---
def __init__(self, operators):
self._operator_store = ()
for op in operators:
if not isinstance(op, LinearOperator):
raise TypeError("The elements of the operator list must be"
"instances of the LinearOperator-baseclass")
self._operator_store += (op,)
def _check_input_compatibility(self, x, spaces, types, inverse=False):
"""
The input check must be disabled for the ComposedOperator, since it
is not easily forecasteable what the output of an operator-call
will look like.
"""
return (spaces, types)
# ---Mandatory properties and methods---
@property
def domain(self):
if not hasattr(self, '_domain'):
self._domain = ()
for op in self._operator_store:
self._domain += op.domain
return self._domain
@property
def target(self):
if not hasattr(self, '_target'):
self._target = ()
for op in self._operator_store:
self._target += op.target
return self._target
@property
def field_type(self):
if not hasattr(self, '_field_type'):
self._field_type = ()
for op in self._operator_store:
self._field_type += op.field_type
return self._field_type
@property
def field_type_target(self):
if not hasattr(self, '_field_type_target'):
self._field_type_target = ()
for op in self._operator_store:
self._field_type_target += op.field_type_target
return self._field_type_target
@property
def implemented(self):
return True
@property
def unitary(self):
return False
def _times(self, x, spaces, types):
space_index = 0
type_index = 0
if spaces is None:
spaces = range(len(self.domain))
if types is None:
types = range(len(self.field_type))
for op in self._operator_store:
active_spaces = spaces[space_index:space_index+len(op.domain)]
space_index += len(op.domain)
active_types = types[type_index:type_index+len(op.field_type)]
type_index += len(op.field_type)
x = op(x, spaces=active_spaces, types=active_types)
return x
# -*- coding: utf-8 -*-
from nifty import RGSpace, FFTOperator, ComposedOperator, Field, \
RGRGTransformation
x1 = RGSpace((8,))
x2 = RGSpace((6,))
y1 = RGRGTransformation.get_codomain(x1)
y2 = RGRGTransformation.get_codomain(x2)
fft1 = FFTOperator(x1)
fft2 = FFTOperator(x2)
ifft1 = FFTOperator(y1)
ifft2 = FFTOperator(y2)
com1 = ComposedOperator((fft1, fft2))
com2 = ComposedOperator((fft1, fft2, ifft1))
f = Field((x1, x2), val=0)
f.val[1,1] = 11
com1(f)
com2(f, spaces=(0,1,0))
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment