Commit e1157422 authored by Philipp Arras's avatar Philipp Arras
Browse files

Fixups

parent 86b8fa17
Pipeline #31014 passed with stages
in 1 minute and 24 seconds
...@@ -4,8 +4,8 @@ from . import dobj ...@@ -4,8 +4,8 @@ from . import dobj
from .domains import * from .domains import *
from .domain_tuple import DomainTuple from .domain_tuple import DomainTuple
from .field import Field from .field import Field
from .operators import *
from .nonlinear_operators import * from .nonlinear_operators import *
from .operators import *
from .probing.utils import probe_with_posterior_samples, probe_diagonal, \ from .probing.utils import probe_with_posterior_samples, probe_diagonal, \
StatCalculator StatCalculator
from .minimization import * from .minimization import *
......
import nifty4 as ift import nifty4 as ift
from ..operators import LinearOperator
from .selection_operator import SelectionOperator from .selection_operator import SelectionOperator
...@@ -25,7 +24,7 @@ class NonlinearOperator(object): ...@@ -25,7 +24,7 @@ class NonlinearOperator(object):
def __getitem__(self, key): def __getitem__(self, key):
sel = SelectionOperator(self.value.domain, key) sel = SelectionOperator(self.value.domain, key)
return LinearModel(self.position, self, sel) return sel(self)
# TODO Support addition and multiplication with fields # TODO Support addition and multiplication with fields
def __add__(self, other): def __add__(self, other):
...@@ -49,6 +48,7 @@ class NonlinearOperator(object): ...@@ -49,6 +48,7 @@ class NonlinearOperator(object):
raise NotImplementedError raise NotImplementedError
def _joint_position(op1, op2): def _joint_position(op1, op2):
a = op1.position._val a = op1.position._val
b = op2.position._val b = op2.position._val
...@@ -123,17 +123,18 @@ class ScalarMul(NonlinearOperator): ...@@ -123,17 +123,18 @@ class ScalarMul(NonlinearOperator):
class LinearModel(NonlinearOperator): class LinearModel(NonlinearOperator):
def __init__(self, position, inp, lin_op): def __init__(self, inp, lin_op):
""" """
Computes lin_op(inp) where lin_op is a Linear Operator Computes lin_op(inp) where lin_op is a Linear Operator
""" """
super(LinearModel, self).__init__(position) from ..operators import LinearOperator
super(LinearModel, self).__init__(inp.position)
if not isinstance(lin_op, LinearOperator): if not isinstance(lin_op, LinearOperator):
raise TypeError("needs a LinearOperator as input") raise TypeError("needs a LinearOperator as input")
self._inp = inp.at(position)
self._lin_op = lin_op self._lin_op = lin_op
self._inp = inp
# FIXME This is a dirty hack! # FIXME This is a dirty hack!
if isinstance(self._lin_op, SelectionOperator): if isinstance(self._lin_op, SelectionOperator):
self._lin_op = SelectionOperator(self._inp.value.domain, self._lin_op = SelectionOperator(self._inp.value.domain,
...@@ -143,4 +144,4 @@ class LinearModel(NonlinearOperator): ...@@ -143,4 +144,4 @@ class LinearModel(NonlinearOperator):
self._gradient = self._lin_op*self._inp.gradient self._gradient = self._lin_op*self._inp.gradient
def at(self, position): def at(self, position):
return self.__class__(position, self._inp, self._lin_op) return self.__class__(self._inp.at(position), self._lin_op)
from ..multi import MultiDomain, MultiField
from ..operators import LinearOperator from ..operators import LinearOperator
from ..sugar import full from ..sugar import full
class SelectionOperator(LinearOperator): class SelectionOperator(LinearOperator):
def __init__(self, domain, key): def __init__(self, domain, key):
from ..multi import MultiDomain
if not isinstance(domain, MultiDomain): if not isinstance(domain, MultiDomain):
raise TypeError("Domain must be a MultiDomain") raise TypeError("Domain must be a MultiDomain")
self._target = domain[key] self._target = domain[key]
...@@ -34,4 +34,5 @@ class SelectionOperator(LinearOperator): ...@@ -34,4 +34,5 @@ class SelectionOperator(LinearOperator):
result[key] = full(val, 0.) result[key] = full(val, 0.)
else: else:
result[key] = x.copy() result[key] = x.copy()
from ..multi import MultiField
return MultiField(result) return MultiField(result)
...@@ -17,9 +17,11 @@ ...@@ -17,9 +17,11 @@
# and financially supported by the Studienstiftung des deutschen Volkes. # and financially supported by the Studienstiftung des deutschen Volkes.
import abc import abc
from ..utilities import NiftyMetaBase
import numpy as np import numpy as np
from ..utilities import NiftyMetaBase
class LinearOperator(NiftyMetaBase()): class LinearOperator(NiftyMetaBase()):
"""NIFTY base class for linear operators. """NIFTY base class for linear operators.
...@@ -196,7 +198,10 @@ class LinearOperator(NiftyMetaBase()): ...@@ -196,7 +198,10 @@ class LinearOperator(NiftyMetaBase()):
raise NotImplementedError raise NotImplementedError
def __call__(self, x): def __call__(self, x):
from ..nonlinear_operators import LinearModel, NonlinearOperator
"""Same as :meth:`times`""" """Same as :meth:`times`"""
if isinstance(x, NonlinearOperator):
return LinearModel(x, self)
return self.apply(x, self.TIMES) return self.apply(x, self.TIMES)
def times(self, x): def times(self, x):
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment