From d7b805473621f65f50d731fa56acff63c99c4e84 Mon Sep 17 00:00:00 2001
From: Martin Reinecke <martin@mpa-garching.mpg.de>
Date: Fri, 27 Jul 2018 20:47:00 +0200
Subject: [PATCH] move EnergyAdapter

---
 demos/getting_started_2.py        | 25 +------------------------
 demos/getting_started_3b.py       | 24 +-----------------------
 nifty5/__init__.py                |  1 +
 nifty5/energies/energy_adapter.py | 28 ++++++++++++++++++++++++++++
 4 files changed, 31 insertions(+), 47 deletions(-)
 create mode 100644 nifty5/energies/energy_adapter.py

diff --git a/demos/getting_started_2.py b/demos/getting_started_2.py
index f2267bcff..298aff157 100644
--- a/demos/getting_started_2.py
+++ b/demos/getting_started_2.py
@@ -20,29 +20,6 @@ import nifty5 as ift
 import numpy as np
 
 
-class EnergyAdapter(ift.Energy):
-    def __init__(self, position, op):
-        super(EnergyAdapter, self).__init__(position)
-        self._op = op
-        pvar = ift.Linearization.make_var(position)
-        self._res = op(pvar)
-
-    def at(self, position):
-        return EnergyAdapter(position, self._op)
-
-    @property
-    def value(self):
-        return self._res.val.local_data[()]
-
-    @property
-    def gradient(self):
-        return self._res.gradient
-
-    @property
-    def metric(self):
-        return self._res.metric
-
-
 def get_2D_exposure():
     x_shape, y_shape = position_space.shape
 
@@ -120,7 +97,7 @@ if __name__ == '__main__':
 
     # Minimize the Hamiltonian
     H = ift.Hamiltonian(likelihood)
-    H = EnergyAdapter(position, H)
+    H = ift.EnergyAdapter(position, H)
     #ift.extra.check_value_gradient_consistency(H)
     H = H.make_invertible(ic_cg)
     H, convergence = minimizer(H)
diff --git a/demos/getting_started_3b.py b/demos/getting_started_3b.py
index 2a7ce7d27..6cd4054c5 100644
--- a/demos/getting_started_3b.py
+++ b/demos/getting_started_3b.py
@@ -25,28 +25,6 @@ def get_random_LOS(n_los):
     ends = list(np.random.uniform(0, 1, (n_los, 2)).T)
     return starts, ends
 
-class EnergyAdapter(ift.Energy):
-    def __init__(self, position, op):
-        super(EnergyAdapter, self).__init__(position)
-        self._op = op
-        pvar = ift.Linearization.make_var(position)
-        self._res = op(pvar)
-
-    def at(self, position):
-        return EnergyAdapter(position, self._op)
-
-    @property
-    def value(self):
-        return self._res.val.local_data[()]
-
-    @property
-    def gradient(self):
-        return self._res.gradient
-
-    @property
-    def metric(self):
-        return self._res.metric
-
 if __name__ == '__main__':
     # FIXME description of the tutorial
     np.random.seed(42)
@@ -114,7 +92,7 @@ if __name__ == '__main__':
                    for _ in range(N_samples)]
 
         KL = ift.SampledKullbachLeiblerDivergence(H, samples)
-        KL = EnergyAdapter(position, KL)
+        KL = ift.EnergyAdapter(position, KL)
         KL = KL.make_invertible(ic_cg)
         KL, convergence = minimizer(KL)
         position = KL.position
diff --git a/nifty5/__init__.py b/nifty5/__init__.py
index 66bc0179b..705f12439 100644
--- a/nifty5/__init__.py
+++ b/nifty5/__init__.py
@@ -103,6 +103,7 @@ from .multi.block_diagonal_operator import BlockDiagonalOperator
 
 from .energies.kl import SampledKullbachLeiblerDivergence
 from .energies.hamiltonian import Hamiltonian
+from .energies.energy_adapter import EnergyAdapter
 
 from .operator import Operator
 from .linearization import Linearization
diff --git a/nifty5/energies/energy_adapter.py b/nifty5/energies/energy_adapter.py
new file mode 100644
index 000000000..612563716
--- /dev/null
+++ b/nifty5/energies/energy_adapter.py
@@ -0,0 +1,28 @@
+from __future__ import absolute_import, division, print_function
+
+from ..compat import *
+from ..minimization.energy import Energy
+from ..linearization import Linearization
+
+
+class EnergyAdapter(Energy):
+    def __init__(self, position, op):
+        super(EnergyAdapter, self).__init__(position)
+        self._op = op
+        pvar = Linearization.make_var(position)
+        self._res = op(pvar)
+
+    def at(self, position):
+        return EnergyAdapter(position, self._op)
+
+    @property
+    def value(self):
+        return self._res.val.local_data[()]
+
+    @property
+    def gradient(self):
+        return self._res.gradient
+
+    @property
+    def metric(self):
+        return self._res.metric
-- 
GitLab