Commit 7a3ef9b3 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'scalable_energies2' into 'NIFTy_6'

Scalable energies2

See merge request !402
parents 7d91c08e bbb24ee2
Pipeline #67726 passed with stages
in 15 minutes and 22 seconds
......@@ -98,5 +98,18 @@ class ScalingOperator(EndomorphicOperator):
return from_random(random_type="normal", domain=self._domain,
std=self._get_fct(from_inverse), dtype=dtype)
def __call__(self, other):
res = EndomorphicOperator.__call__(self, other)
if np.isreal(self._factor) and self._factor >= 0:
from ..linearization import Linearization
if isinstance(other, Linearization):
if other.metric is not None:
from .sandwich_operator import SandwichOperator
sqrt_fac = np.sqrt(self._factor)
newop = ScalingOperator(other.metric.domain, sqrt_fac)
met = SandwichOperator.make(newop, other.metric)
res = res.add_metric(met)
return res
def __repr__(self):
return "ScalingOperator ({})".format(self._factor)
......@@ -32,6 +32,7 @@ SPACES = [
]
SEEDS = [4, 78, 23]
PARAMS = product(SEEDS, SPACES)
pmp = pytest.mark.parametrize
@pytest.fixture(params=PARAMS)
......@@ -46,6 +47,18 @@ def test_gaussian(field):
energy = ift.GaussianEnergy(domain=field.domain)
ift.extra.check_jacobian_consistency(energy, field)
@pmp('icov', [lambda dom:ift.ScalingOperator(dom, 1.),
lambda dom:ift.SandwichOperator.make(ift.GeometryRemover(dom))])
def test_ScaledEnergy(field, icov):
icov = icov(field.domain)
energy = ift.GaussianEnergy(inverse_covariance=icov)
ift.extra.check_jacobian_consistency(energy.scale(0.3), field)
lin = ift.Linearization.make_var(field, want_metric=True)
met1 = energy(lin).metric
met2 = energy.scale(0.3)(lin).metric
np.testing.assert_allclose(met1(field).val, met2(field).val / 0.3, rtol=1e-12)
met2.draw_sample()
def test_studentt(field):
energy = ift.StudentTEnergy(domain=field.domain, theta=.5)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment