Commit 68731e4f authored by Reimar Leike's avatar Reimar Leike
Browse files

reverted changes in linearOperator, put changes in scalingoperator instead,...

reverted changes in linearOperator, put changes in scalingoperator instead, fixed bugs and wrote tests:
parent e72f8abf
Pipeline #67716 passed with stages
in 15 minutes and 30 seconds
...@@ -174,10 +174,7 @@ class LinearOperator(Operator): ...@@ -174,10 +174,7 @@ class LinearOperator(Operator):
return self.apply(x, self.TIMES) return self.apply(x, self.TIMES)
from ..linearization import Linearization from ..linearization import Linearization
if isinstance(x, Linearization): if isinstance(x, Linearization):
res = x.new(self(x._val), self(x._jac)) return x.new(self(x._val), self(x._jac))
if x.metric is not None:
res = res.add_metric(self(x.metric))
return res
return self.__matmul__(x) return self.__matmul__(x)
def times(self, x): def times(self, x):
......
...@@ -98,14 +98,19 @@ class ScalingOperator(EndomorphicOperator): ...@@ -98,14 +98,19 @@ class ScalingOperator(EndomorphicOperator):
return from_random(random_type="normal", domain=self._domain, return from_random(random_type="normal", domain=self._domain,
std=self._get_fct(from_inverse), dtype=dtype) std=self._get_fct(from_inverse), dtype=dtype)
def __matmul__(self, other): def __call__(self, other):
if np.isreal(self._factor) and self._factor > 0: if np.isreal(self._factor) and self._factor >= 0:
from .sandwich_operator import SandwichOperator from ..linearization import Linearization
if isinstance(other, SandwichOperator): if isinstance(other, Linearization):
sqrt_fac = np.sqrt(self._factor) res = EndomorphicOperator.__call__(self, other)
newop = ScalingOperator(other.domain, sqrt_fac) if other.metric is not None:
return SandwichOperator.make(newop, other) from .sandwich_operator import SandwichOperator
return EndomorphicOperator.__matmul__(self, other) sqrt_fac = np.sqrt(self._factor)
newop = ScalingOperator(other.metric.domain, sqrt_fac)
met = SandwichOperator.make(newop, other.metric)
res = res.add_metric(met)
return res
return EndomorphicOperator.__call__(self, other)
def __repr__(self): def __repr__(self):
return "ScalingOperator ({})".format(self._factor) return "ScalingOperator ({})".format(self._factor)
...@@ -32,6 +32,7 @@ SPACES = [ ...@@ -32,6 +32,7 @@ SPACES = [
] ]
SEEDS = [4, 78, 23] SEEDS = [4, 78, 23]
PARAMS = product(SEEDS, SPACES) PARAMS = product(SEEDS, SPACES)
pmp = pytest.mark.parametrize
@pytest.fixture(params=PARAMS) @pytest.fixture(params=PARAMS)
...@@ -46,8 +47,11 @@ def test_gaussian(field): ...@@ -46,8 +47,11 @@ def test_gaussian(field):
energy = ift.GaussianEnergy(domain=field.domain) energy = ift.GaussianEnergy(domain=field.domain)
ift.extra.check_jacobian_consistency(energy, field) ift.extra.check_jacobian_consistency(energy, field)
def test_ScaledEnergy(field): @pmp('icov', [lambda dom:ift.ScalingOperator(dom, 1.),
energy = ift.GaussianEnergy(domain=field.domain) lambda dom:ift.SandwichOperator.make(ift.GeometryRemover(dom))])
def test_ScaledEnergy(field, icov):
icov = icov(field.domain)
energy = ift.GaussianEnergy(inverse_covariance=icov)
ift.extra.check_jacobian_consistency(energy.scale(0.3), field) ift.extra.check_jacobian_consistency(energy.scale(0.3), field)
lin = ift.Linearization.make_var(field, want_metric=True) lin = ift.Linearization.make_var(field, want_metric=True)
...@@ -55,8 +59,8 @@ def test_ScaledEnergy(field): ...@@ -55,8 +59,8 @@ def test_ScaledEnergy(field):
sE = energy.scale(0.3) sE = energy.scale(0.3)
linn = sE(lin) linn = sE(lin)
met2 = linn.metric met2 = linn.metric
assert np.assert_allclose(met1(field), met2(field) / 0.3, rtol=1e-12) np.testing.assert_allclose(met1(field).val, met2(field).val / 0.3, rtol=1e-12)
assert isinstance(met2, ift.SandwichOperator) met2.draw_sample()
def test_studentt(field): def test_studentt(field):
energy = ift.StudentTEnergy(domain=field.domain, theta=.5) energy = ift.StudentTEnergy(domain=field.domain, theta=.5)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment