Commit 67b317d9 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'NIFTy_5' into modular_amplitudes

parents c8e16cdf 309dd635
...@@ -181,7 +181,7 @@ class ScipyCG(Minimizer): ...@@ -181,7 +181,7 @@ class ScipyCG(Minimizer):
prec_op = scipy_linop(shape=(op.domain.size, op.target.size), prec_op = scipy_linop(shape=(op.domain.size, op.target.size),
matvec=mymatvec(preconditioner)) matvec=mymatvec(preconditioner))
res, stat = cg(sci_op, b, x0=sx, tol=self._tol, M=prec_op, res, stat = cg(sci_op, b, x0=sx, tol=self._tol, M=prec_op,
maxiter=self._maxiter) maxiter=self._maxiter, atol='legacy')
stat = (IterationController.CONVERGED if stat >= 0 else stat = (IterationController.CONVERGED if stat >= 0 else
IterationController.ERROR) IterationController.ERROR)
return energy.at(_toField(res, energy.position)), stat return energy.at(_toField(res, energy.position)), stat
...@@ -143,6 +143,9 @@ class LinearOperator(Operator): ...@@ -143,6 +143,9 @@ class LinearOperator(Operator):
""" """
return self._capability return self._capability
def force(self, x):
return self.apply(x.extract(self.domain), self.TIMES)
def apply(self, x, mode): def apply(self, x, mode):
""" Applies the Operator to a given `x`, in a specified `mode`. """ Applies the Operator to a given `x`, in a specified `mode`.
...@@ -256,5 +259,4 @@ class LinearOperator(Operator): ...@@ -256,5 +259,4 @@ class LinearOperator(Operator):
def _check_input(self, x, mode): def _check_input(self, x, mode):
self._check_mode(mode) self._check_mode(mode)
if self._dom(mode) != x.domain: self._check_domain_equality(self._dom(mode), x.domain)
raise ValueError("The operator's and field's domains don't match.")
...@@ -32,6 +32,15 @@ from .linear_operator import LinearOperator ...@@ -32,6 +32,15 @@ from .linear_operator import LinearOperator
# collect the unstructured Fields. # collect the unstructured Fields.
class MaskOperator(LinearOperator): class MaskOperator(LinearOperator):
def __init__(self, mask): def __init__(self, mask):
"""Implementation of a mask response
This operator takes a field, applies a mask and returns the values of
the field in a UnstructuredDomain. It can be used as response operator.
Parameters
----------
mask : Field
"""
if not isinstance(mask, Field): if not isinstance(mask, Field):
raise TypeError raise TypeError
......
...@@ -23,6 +23,17 @@ class Operator(NiftyMetaBase()): ...@@ -23,6 +23,17 @@ class Operator(NiftyMetaBase()):
The domain on which the Operator's output Field lives.""" The domain on which the Operator's output Field lives."""
return self._target return self._target
@staticmethod
def _check_domain_equality(dom_op, dom_field):
if dom_op != dom_field:
s = "The operator's and field's domains don't match."
from ..domain_tuple import DomainTuple
from ..multi_domain import MultiDomain
if not isinstance(dom_op, [DomainTuple, MultiDomain]):
s += " Your operator's domain is neither a `DomainTuple`" \
" nor a `MultiDomain`."
raise ValueError(s)
def scale(self, factor): def scale(self, factor):
if factor == 1: if factor == 1:
return self return self
...@@ -60,13 +71,13 @@ class Operator(NiftyMetaBase()): ...@@ -60,13 +71,13 @@ class Operator(NiftyMetaBase()):
raise NotImplementedError raise NotImplementedError
def force(self, x): def force(self, x):
"""Extract correct subset of domain of x and apply operator."""
return self.apply(x.extract(self.domain)) return self.apply(x.extract(self.domain))
def _check_input(self, x): def _check_input(self, x):
from ..linearization import Linearization from ..linearization import Linearization
d = x.target if isinstance(x, Linearization) else x.domain d = x.target if isinstance(x, Linearization) else x.domain
if self._domain != d: self._check_domain_equality(self._domain, d)
raise ValueError("The operator's and field's domains don't match.")
def __call__(self, x): def __call__(self, x):
if isinstance(x, Operator): if isinstance(x, Operator):
......
...@@ -32,5 +32,4 @@ def _custom_name_func(testcase_func, param_num, param): ...@@ -32,5 +32,4 @@ def _custom_name_func(testcase_func, param_num, param):
def expand(*args, **kwargs): def expand(*args, **kwargs):
return parameterized.expand(*args, testcase_func_name=_custom_name_func, return parameterized.expand(*args, func=_custom_name_func, **kwargs)
**kwargs)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment