Commit 68731e4f authored by Reimar Leike's avatar Reimar Leike
Browse files

reverted changes in linearOperator, put changes in scalingoperator instead,...

reverted changes in linearOperator, put changes in scalingoperator instead, fixed bugs and wrote tests:
parent e72f8abf
Pipeline #67716 passed with stages
in 15 minutes and 30 seconds
......@@ -174,10 +174,7 @@ class LinearOperator(Operator):
return self.apply(x, self.TIMES)
from ..linearization import Linearization
if isinstance(x, Linearization):
res = x.new(self(x._val), self(x._jac))
if x.metric is not None:
res = res.add_metric(self(x.metric))
return res
return x.new(self(x._val), self(x._jac))
return self.__matmul__(x)
def times(self, x):
......
......@@ -98,14 +98,19 @@ class ScalingOperator(EndomorphicOperator):
return from_random(random_type="normal", domain=self._domain,
std=self._get_fct(from_inverse), dtype=dtype)
def __matmul__(self, other):
if np.isreal(self._factor) and self._factor > 0:
def __call__(self, other):
if np.isreal(self._factor) and self._factor >= 0:
from ..linearization import Linearization
if isinstance(other, Linearization):
res = EndomorphicOperator.__call__(self, other)
if other.metric is not None:
from .sandwich_operator import SandwichOperator
if isinstance(other, SandwichOperator):
sqrt_fac = np.sqrt(self._factor)
newop = ScalingOperator(other.domain, sqrt_fac)
return SandwichOperator.make(newop, other)
return EndomorphicOperator.__matmul__(self, other)
newop = ScalingOperator(other.metric.domain, sqrt_fac)
met = SandwichOperator.make(newop, other.metric)
res = res.add_metric(met)
return res
return EndomorphicOperator.__call__(self, other)
def __repr__(self):
return "ScalingOperator ({})".format(self._factor)
......@@ -32,6 +32,7 @@ SPACES = [
]
SEEDS = [4, 78, 23]
PARAMS = product(SEEDS, SPACES)
pmp = pytest.mark.parametrize
@pytest.fixture(params=PARAMS)
......@@ -46,8 +47,11 @@ def test_gaussian(field):
energy = ift.GaussianEnergy(domain=field.domain)
ift.extra.check_jacobian_consistency(energy, field)
def test_ScaledEnergy(field):
energy = ift.GaussianEnergy(domain=field.domain)
@pmp('icov', [lambda dom:ift.ScalingOperator(dom, 1.),
lambda dom:ift.SandwichOperator.make(ift.GeometryRemover(dom))])
def test_ScaledEnergy(field, icov):
icov = icov(field.domain)
energy = ift.GaussianEnergy(inverse_covariance=icov)
ift.extra.check_jacobian_consistency(energy.scale(0.3), field)
lin = ift.Linearization.make_var(field, want_metric=True)
......@@ -55,8 +59,8 @@ def test_ScaledEnergy(field):
sE = energy.scale(0.3)
linn = sE(lin)
met2 = linn.metric
assert np.assert_allclose(met1(field), met2(field) / 0.3, rtol=1e-12)
assert isinstance(met2, ift.SandwichOperator)
np.testing.assert_allclose(met1(field).val, met2(field).val / 0.3, rtol=1e-12)
met2.draw_sample()
def test_studentt(field):
energy = ift.StudentTEnergy(domain=field.domain, theta=.5)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment