Commit 9cc3f112 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

fix invocation of curvature tests; provide some missing MultiField functionality

parent 3d2ba0d3
......@@ -61,4 +61,6 @@ class GaussianEnergy(Energy):
@property
@memo
def curvature(self):
if self._cov is None:
return SandwichOperator.make(self._inp.gradient, None)
return SandwichOperator.make(self._inp.gradient, self._cov.inverse)
......@@ -164,10 +164,25 @@ class MultiField(object):
def __neg__(self):
return MultiField({key: -val for key, val in self.items()})
def __abs__(self):
return MultiField({key: abs(val) for key, val in self.items()})
def conjugate(self):
return MultiField({key: sub_field.conjugate()
for key, sub_field in self.items()})
def all(self):
for v in self.values():
if not v.all():
return False
return True
def any(self):
for v in self.values():
if v.any():
return True
return False
def isEquivalentTo(self, other):
"""Determines (as quickly as possible) whether `self`'s content is
identical to `other`'s content."""
......
......@@ -93,7 +93,7 @@ class Energy_Tests(unittest.TestCase):
N = None
energy = ift.GaussianEnergy(d_model, d, N)
if isinstance(nonlinearity, ift.Linear):
if isinstance(nonlinearity(), ift.Linear):
ift.extra.check_value_gradient_curvature_consistency(
energy, ntries=10)
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment