Commit 60bf8aa9 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

progress

parent 93c14275
......@@ -210,6 +210,23 @@ class MultiField(object):
return False
return True
def extract(self, subset):
if isinstance(subset, MultiDomain):
return MultiField(subset,
tuple(self[key] for key in subset.keys()))
else:
return MultiField.from_dict({key: self[key] for key in subset})
@staticmethod
def combine(fields):
res = {}
for f in fields:
for key in f.keys():
if key in res:
res[key] = res[key]+f[key]
else:
res[key] = f[key]
return MultiField.from_dict(res)
for op in ["__add__", "__radd__",
"__sub__", "__rsub__",
......
......@@ -10,8 +10,6 @@ from .utilities import NiftyMetaBase
#from ..multi.multi_domain import MultiDomain
from .field import Field
from .multi.multi_field import MultiField
from .operators.scaling_operator import ScalingOperator
from .operators.diagonal_operator import DiagonalOperator
class Linearization(object):
......@@ -54,6 +52,7 @@ class Linearization(object):
return (-self).__add__(other)
def __mul__(self, other):
from .operators.diagonal_operator import DiagonalOperator
if isinstance(other, Linearization):
d1 = DiagonalOperator(self._val)
d2 = DiagonalOperator(other._val)
......@@ -77,9 +76,11 @@ class Linearization(object):
@staticmethod
def make_var(field):
from .operators.scaling_operator import ScalingOperator
return Linearization(field, ScalingOperator(1., field.domain))
@staticmethod
def make_const(field):
from .operators.scaling_operator import ScalingOperator
return Linearization(field, ScalingOperator(0., {}))
class Operator(NiftyMetaBase()):
......
......@@ -23,10 +23,10 @@ import abc
import numpy as np
from ..compat import *
from ..utilities import NiftyMetaBase
from ..operator import Operator, Linearization
class LinearOperator(NiftyMetaBase()):
class LinearOperator(Operator):
"""NIFTY base class for linear operators.
The base NIFTY operator class is an abstract class from which
......@@ -205,6 +205,8 @@ class LinearOperator(NiftyMetaBase()):
"""Same as :meth:`times`"""
from ..models.model import Model
from ..models.linear_model import LinearModel
if isinstance(x, Linearization):
return Linearization(self(x._val), self*x._jac)
if isinstance(x, Model):
return LinearModel(x, self)
return self.apply(x, self.TIMES)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment