Commit c8860db8 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

tweaks

parent ae534b64
Pipeline #76486 failed with stages
in 5 minutes
......@@ -85,6 +85,7 @@ def check_linear_operator(op, domain_dtype=np.float64, target_dtype=np.float64,
rtol, only_r_linear)
# MR FIXME the default tolerance is extremely small, especially for Jacobian tests
def check_operator(op, loc, tol=1e-8, ntries=100, perf_check=True,
only_r_differentiable=True, metric_sampling=True):
"""
......@@ -331,15 +332,17 @@ def _check_nontrivial_constant(op, loc, tol, ntries, only_r_differentiable,
val0 = op(loc)
_, op0 = op.simplify_for_constant_input(cstloc)
val1 = op0(loc)
val2 = op0(loc.unite(cstloc))
assert_equal(val1, val2)
# MR FIXME: This tests something we don't promise!
# val2 = op0(loc.unite(cstloc))
# assert_equal(val1, val2)
assert_equal(val0, val1)
lin = Linearization.make_var(loc, want_metric=True)
oplin = op0(lin)
if isinstance(op, EnergyOperator):
_allzero(oplin.gradient.extract(cstdom))
_allzero(oplin.jac(from_random(cstdom).unite(full(vardom, 0))))
# MR FIXME: This tests something we don't promise!
# _allzero(oplin.jac(from_random(cstdom).unite(full(vardom, 0))))
if isinstance(op, EnergyOperator) and metric_sampling:
samp0 = oplin.metric.draw_sample()
......
......@@ -290,7 +290,8 @@ class Operator(metaclass=NiftyMeta):
op = ConstantEnergyOperator(self.domain, self(c_inp))
else:
op = ConstantOperator(self.domain, self(c_inp))
op = ConstantOperator(self.domain, self(c_inp))
# MR FIXME something is redundant here
# op = ConstantOperator(self.domain, self(c_inp))
return op(c_inp), op
if not isinstance(c_inp.domain, MultiDomain):
raise RuntimeError
......
......@@ -25,23 +25,23 @@ from .simple_linear_operators import NullOperator
class ConstCollector(object):
def __init__(self):
self._const = None
self._nc = set()
self._const = None # MultiField on the part of the MultiDomain that could be constant
self._nc = set() # NoConstant - set of keys that we know cannot be constant
def mult(self, const, fulldom):
if const is None:
self._nc |= set(fulldom)
self._nc |= set(fulldom.keys())
else:
self._nc |= set(fulldom) - set(const)
from ..multi_field import MultiField
self._nc |= set(fulldom.keys()) - set(const.keys())
if self._const is None:
from ..multi_field import MultiField
self._const = MultiField.from_dict(
{key: const[key] for key in const if key not in self._nc})
else:
from ..multi_field import MultiField
{key: const[key]
for key in const.keys() if key not in self._nc})
else: # we know that the domains are identical for products
self._const = MultiField.from_dict(
{key: self._const[key]*const[key]
for key in const if key not in self._nc})
for key in const.keys() if key not in self._nc})
def add(self, const, fulldom):
if const is None:
......@@ -49,15 +49,10 @@ class ConstCollector(object):
else:
from ..multi_field import MultiField
self._nc |= set(fulldom.keys()) - set(const.keys())
if self._const is None:
self._const = MultiField.from_dict(
{key: const[key]
for key in const.keys() if key not in self._nc})
else:
self._const = self._const.unite(const)
self._const = MultiField.from_dict(
{key: self._const[key]
for key in self._const if key not in self._nc})
self._const = const if self._const is None else self._const.unite(const)
self._const = MultiField.from_dict(
{key: const[key]
for key in const.keys() if key not in self._nc})
@property
def constfield(self):
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment