diff --git a/nifty6/operators/energy_operators.py b/nifty6/operators/energy_operators.py index b748a1fccb142a4e9a55fef5063e596cda855568..83618cbcc4f6893030065dfe2f4bcae72b51e281 100644 --- a/nifty6/operators/energy_operators.py +++ b/nifty6/operators/energy_operators.py @@ -44,6 +44,30 @@ class EnergyOperator(Operator): """ _target = DomainTuple.scalar_domain() + def scale(self, scalar): + if scalar >= 0: + return ScaledEnergy(self._target, scalar)(self) + return super(self, Operator).scale(scalar) + + + +class ScaledEnergy(EnergyOperator): + def __init__(self, domain, scalar): + self._domain = domain + assert scalar >= 0 + self._scalar = scalar + + def apply(self, x): + self._check_input(x) + res = ScalingOperator(self._target, self._scalar)(x) + if isinstance(x, Linearization): + if x.metric is not None: + met = SandwichOperator.make(ScalingOperator(self._target, np.sqrt(self._scalar)), x.metric) + res = res.add_metric(met) + return res + + + class Squared2NormOperator(EnergyOperator): """Computes the square of the L2-norm of the output of an operator. diff --git a/test/test_energy_gradients.py b/test/test_energy_gradients.py index 6d295d9d46356bb67d242931b9b89ef435a1d4bd..eb69dcc7d870c2875a47146234bdc195545f36c0 100644 --- a/test/test_energy_gradients.py +++ b/test/test_energy_gradients.py @@ -46,6 +46,16 @@ def test_gaussian(field): energy = ift.GaussianEnergy(domain=field.domain) ift.extra.check_jacobian_consistency(energy, field) +def test_ScaledEnergy(field): + energy = ift.GaussianEnergy(domain=field.domain) + ift.extra.check_jacobian_consistency(energy.scale(0.3), field) + + #lin = ift.Linearization.make_var(field) + met1 = ift.GaussianEnergy(field) + met2 = ift.GaussianEnergy.scale(0.3)(field) + assert met1 == met2 / 0.3 + + assert isinstance(met2, ift.SandwichOperator) def test_studentt(field): energy = ift.StudentTEnergy(domain=field.domain, theta=.5)