Commit e1157422 authored by Philipp Arras's avatar Philipp Arras

Fixups

parent 86b8fa17
Pipeline #31014 passed with stages
in 1 minute and 24 seconds
......@@ -4,8 +4,8 @@ from . import dobj
from .domains import *
from .domain_tuple import DomainTuple
from .field import Field
from .operators import *
from .nonlinear_operators import *
from .operators import *
from .probing.utils import probe_with_posterior_samples, probe_diagonal, \
StatCalculator
from .minimization import *
......
import nifty4 as ift
from ..operators import LinearOperator
from .selection_operator import SelectionOperator
......@@ -25,7 +24,7 @@ class NonlinearOperator(object):
def __getitem__(self, key):
sel = SelectionOperator(self.value.domain, key)
return LinearModel(self.position, self, sel)
return sel(self)
# TODO Support addition and multiplication with fields
def __add__(self, other):
......@@ -49,6 +48,7 @@ class NonlinearOperator(object):
raise NotImplementedError
def _joint_position(op1, op2):
a = op1.position._val
b = op2.position._val
......@@ -123,17 +123,18 @@ class ScalarMul(NonlinearOperator):
class LinearModel(NonlinearOperator):
def __init__(self, position, inp, lin_op):
def __init__(self, inp, lin_op):
"""
Computes lin_op(inp) where lin_op is a Linear Operator
"""
super(LinearModel, self).__init__(position)
from ..operators import LinearOperator
super(LinearModel, self).__init__(inp.position)
if not isinstance(lin_op, LinearOperator):
raise TypeError("needs a LinearOperator as input")
self._inp = inp.at(position)
self._lin_op = lin_op
self._inp = inp
# FIXME This is a dirty hack!
if isinstance(self._lin_op, SelectionOperator):
self._lin_op = SelectionOperator(self._inp.value.domain,
......@@ -143,4 +144,4 @@ class LinearModel(NonlinearOperator):
self._gradient = self._lin_op*self._inp.gradient
def at(self, position):
return self.__class__(position, self._inp, self._lin_op)
return self.__class__(self._inp.at(position), self._lin_op)
from ..multi import MultiDomain, MultiField
from ..operators import LinearOperator
from ..sugar import full
class SelectionOperator(LinearOperator):
def __init__(self, domain, key):
from ..multi import MultiDomain
if not isinstance(domain, MultiDomain):
raise TypeError("Domain must be a MultiDomain")
self._target = domain[key]
......@@ -34,4 +34,5 @@ class SelectionOperator(LinearOperator):
result[key] = full(val, 0.)
else:
result[key] = x.copy()
from ..multi import MultiField
return MultiField(result)
......@@ -17,9 +17,11 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
import abc
from ..utilities import NiftyMetaBase
import numpy as np
from ..utilities import NiftyMetaBase
class LinearOperator(NiftyMetaBase()):
"""NIFTY base class for linear operators.
......@@ -196,7 +198,10 @@ class LinearOperator(NiftyMetaBase()):
raise NotImplementedError
def __call__(self, x):
from ..nonlinear_operators import LinearModel, NonlinearOperator
"""Same as :meth:`times`"""
if isinstance(x, NonlinearOperator):
return LinearModel(x, self)
return self.apply(x, self.TIMES)
def times(self, x):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment