From 4d0eace3a593c2d9f72460af290824f1861c922e Mon Sep 17 00:00:00 2001 From: theos <theo.steininger@ultimanet.de> Date: Tue, 25 Oct 2016 12:10:58 +0200 Subject: [PATCH] Fixed the memoization for Energy. --- demos/wiener_filter_hamiltonian.py | 3 ++- nifty/energies/__init__.py | 1 + nifty/energies/energy.py | 12 ------------ nifty/energies/memoization.py | 13 +++++++++++++ 4 files changed, 16 insertions(+), 13 deletions(-) create mode 100644 nifty/energies/memoization.py diff --git a/demos/wiener_filter_hamiltonian.py b/demos/wiener_filter_hamiltonian.py index 80d2721ca..ce801dfdd 100644 --- a/demos/wiener_filter_hamiltonian.py +++ b/demos/wiener_filter_hamiltonian.py @@ -34,6 +34,7 @@ class WienerFilterEnergy(Energy): return_g.val = g.val.real return return_g + @memo def D_inverse_x(self): return D.inverse_times(self.position) @@ -88,7 +89,7 @@ if __name__ == "__main__": minimizer = VL_BFGS(convergence_tolerance=0, iteration_limit=50, callback=distance_measure, - max_history_length=5) + max_history_length=3) m0 = Field(s_space, val=1) diff --git a/nifty/energies/__init__.py b/nifty/energies/__init__.py index 3f5356d25..c31d21a95 100644 --- a/nifty/energies/__init__.py +++ b/nifty/energies/__init__.py @@ -2,3 +2,4 @@ from energy import Energy from line_energy import LineEnergy +from memoization import memo diff --git a/nifty/energies/energy.py b/nifty/energies/energy.py index 325e79950..ba5c7415d 100644 --- a/nifty/energies/energy.py +++ b/nifty/energies/energy.py @@ -24,15 +24,3 @@ class Energy(object): @property def curvature(self): raise NotImplementedError - - def memo(f): - name = id(f) - - def wrapped_f(self): - try: - return self._cache[name] - except KeyError: - self._cache[name] = f(self) - return self._cache[name] - return wrapped_f - diff --git a/nifty/energies/memoization.py b/nifty/energies/memoization.py new file mode 100644 index 000000000..b7bb12692 --- /dev/null +++ b/nifty/energies/memoization.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- + + +def memo(f): + name = id(f) + + def wrapped_f(self): + try: + return self._cache[name] + except KeyError: + self._cache[name] = f(self) + return self._cache[name] + return wrapped_f -- GitLab