From 48e52d16e63229c19458fcf8ec692dddc1f7565d Mon Sep 17 00:00:00 2001
From: Vincent Eberle <veberle@mpa-garching.mpg.de>
Date: Fri, 17 Jan 2020 17:19:41 +0100
Subject: [PATCH] Work in progress. Scalable Energys by Reimar and Vince

---
 nifty6/operators/energy_operators.py | 24 ++++++++++++++++++++++++
 test/test_energy_gradients.py        | 10 ++++++++++
 2 files changed, 34 insertions(+)

diff --git a/nifty6/operators/energy_operators.py b/nifty6/operators/energy_operators.py
index b748a1fcc..83618cbcc 100644
--- a/nifty6/operators/energy_operators.py
+++ b/nifty6/operators/energy_operators.py
@@ -44,6 +44,30 @@ class EnergyOperator(Operator):
     """
     _target = DomainTuple.scalar_domain()
 
+    def scale(self, scalar):
+        if scalar >= 0:
+            return ScaledEnergy(self._target, scalar)(self)
+        return super(self, Operator).scale(scalar)
+
+      
+
+class ScaledEnergy(EnergyOperator):
+    def __init__(self, domain, scalar):
+        self._domain = domain
+        assert scalar >= 0
+        self._scalar = scalar
+
+    def apply(self, x):
+        self._check_input(x)
+        res = ScalingOperator(self._target, self._scalar)(x)
+        if isinstance(x, Linearization):
+            if x.metric is not None:
+                met = SandwichOperator.make(ScalingOperator(self._target, np.sqrt(self._scalar)), x.metric)
+                res = res.add_metric(met)
+        return res
+
+
+           
 
 class Squared2NormOperator(EnergyOperator):
     """Computes the square of the L2-norm of the output of an operator.
diff --git a/test/test_energy_gradients.py b/test/test_energy_gradients.py
index 6d295d9d4..eb69dcc7d 100644
--- a/test/test_energy_gradients.py
+++ b/test/test_energy_gradients.py
@@ -46,6 +46,16 @@ def test_gaussian(field):
     energy = ift.GaussianEnergy(domain=field.domain)
     ift.extra.check_jacobian_consistency(energy, field)
 
+def test_ScaledEnergy(field):
+    energy = ift.GaussianEnergy(domain=field.domain)
+    ift.extra.check_jacobian_consistency(energy.scale(0.3), field)
+
+    #lin =  ift.Linearization.make_var(field)
+    met1 = ift.GaussianEnergy(field)
+    met2 = ift.GaussianEnergy.scale(0.3)(field)
+    assert met1 == met2 / 0.3
+
+    assert isinstance(met2, ift.SandwichOperator)
 
 def test_studentt(field):
     energy = ift.StudentTEnergy(domain=field.domain, theta=.5)
-- 
GitLab