Commit 67b317d9 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'NIFTy_5' into modular_amplitudes

parents c8e16cdf 309dd635
......@@ -181,7 +181,7 @@ class ScipyCG(Minimizer):
prec_op = scipy_linop(shape=(op.domain.size,,
res, stat = cg(sci_op, b, x0=sx, tol=self._tol, M=prec_op,
maxiter=self._maxiter, atol='legacy')
stat = (IterationController.CONVERGED if stat >= 0 else
return, energy.position)), stat
......@@ -143,6 +143,9 @@ class LinearOperator(Operator):
return self._capability
def force(self, x):
return self.apply(x.extract(self.domain), self.TIMES)
def apply(self, x, mode):
""" Applies the Operator to a given `x`, in a specified `mode`.
......@@ -256,5 +259,4 @@ class LinearOperator(Operator):
def _check_input(self, x, mode):
if self._dom(mode) != x.domain:
raise ValueError("The operator's and field's domains don't match.")
self._check_domain_equality(self._dom(mode), x.domain)
......@@ -32,6 +32,15 @@ from .linear_operator import LinearOperator
# collect the unstructured Fields.
class MaskOperator(LinearOperator):
def __init__(self, mask):
"""Implementation of a mask response
This operator takes a field, applies a mask and returns the values of
the field in a UnstructuredDomain. It can be used as response operator.
mask : Field
if not isinstance(mask, Field):
raise TypeError
......@@ -23,6 +23,17 @@ class Operator(NiftyMetaBase()):
The domain on which the Operator's output Field lives."""
return self._target
def _check_domain_equality(dom_op, dom_field):
if dom_op != dom_field:
s = "The operator's and field's domains don't match."
from ..domain_tuple import DomainTuple
from ..multi_domain import MultiDomain
if not isinstance(dom_op, [DomainTuple, MultiDomain]):
s += " Your operator's domain is neither a `DomainTuple`" \
" nor a `MultiDomain`."
raise ValueError(s)
def scale(self, factor):
if factor == 1:
return self
......@@ -60,13 +71,13 @@ class Operator(NiftyMetaBase()):
raise NotImplementedError
def force(self, x):
"""Extract correct subset of domain of x and apply operator."""
return self.apply(x.extract(self.domain))
def _check_input(self, x):
from ..linearization import Linearization
d = if isinstance(x, Linearization) else x.domain
if self._domain != d:
raise ValueError("The operator's and field's domains don't match.")
self._check_domain_equality(self._domain, d)
def __call__(self, x):
if isinstance(x, Operator):
......@@ -32,5 +32,4 @@ def _custom_name_func(testcase_func, param_num, param):
def expand(*args, **kwargs):
return parameterized.expand(*args, testcase_func_name=_custom_name_func,
return parameterized.expand(*args, func=_custom_name_func, **kwargs)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment