From 48e52d16e63229c19458fcf8ec692dddc1f7565d Mon Sep 17 00:00:00 2001 From: Vincent Eberle <veberle@mpa-garching.mpg.de> Date: Fri, 17 Jan 2020 17:19:41 +0100 Subject: [PATCH] Work in progress. Scalable Energys by Reimar and Vince --- nifty6/operators/energy_operators.py | 24 ++++++++++++++++++++++++ test/test_energy_gradients.py | 10 ++++++++++ 2 files changed, 34 insertions(+) diff --git a/nifty6/operators/energy_operators.py b/nifty6/operators/energy_operators.py index b748a1fcc..83618cbcc 100644 --- a/nifty6/operators/energy_operators.py +++ b/nifty6/operators/energy_operators.py @@ -44,6 +44,30 @@ class EnergyOperator(Operator): """ _target = DomainTuple.scalar_domain() + def scale(self, scalar): + if scalar >= 0: + return ScaledEnergy(self._target, scalar)(self) + return super(self, Operator).scale(scalar) + + + +class ScaledEnergy(EnergyOperator): + def __init__(self, domain, scalar): + self._domain = domain + assert scalar >= 0 + self._scalar = scalar + + def apply(self, x): + self._check_input(x) + res = ScalingOperator(self._target, self._scalar)(x) + if isinstance(x, Linearization): + if x.metric is not None: + met = SandwichOperator.make(ScalingOperator(self._target, np.sqrt(self._scalar)), x.metric) + res = res.add_metric(met) + return res + + + class Squared2NormOperator(EnergyOperator): """Computes the square of the L2-norm of the output of an operator. diff --git a/test/test_energy_gradients.py b/test/test_energy_gradients.py index 6d295d9d4..eb69dcc7d 100644 --- a/test/test_energy_gradients.py +++ b/test/test_energy_gradients.py @@ -46,6 +46,16 @@ def test_gaussian(field): energy = ift.GaussianEnergy(domain=field.domain) ift.extra.check_jacobian_consistency(energy, field) +def test_ScaledEnergy(field): + energy = ift.GaussianEnergy(domain=field.domain) + ift.extra.check_jacobian_consistency(energy.scale(0.3), field) + + #lin = ift.Linearization.make_var(field) + met1 = ift.GaussianEnergy(field) + met2 = ift.GaussianEnergy.scale(0.3)(field) + assert met1 == met2 / 0.3 + + assert isinstance(met2, ift.SandwichOperator) def test_studentt(field): energy = ift.StudentTEnergy(domain=field.domain, theta=.5) -- GitLab