diff --git a/demos/wiener_filter_hamiltonian.py b/demos/wiener_filter_hamiltonian.py index 80d2721caa75781e3147d3b89694f850ca4fdf1d..ce801dfddc9734fddbe22f35a302c9aea8daaf24 100644 --- a/demos/wiener_filter_hamiltonian.py +++ b/demos/wiener_filter_hamiltonian.py @@ -34,6 +34,7 @@ class WienerFilterEnergy(Energy): return_g.val = g.val.real return return_g + @memo def D_inverse_x(self): return D.inverse_times(self.position) @@ -88,7 +89,7 @@ if __name__ == "__main__": minimizer = VL_BFGS(convergence_tolerance=0, iteration_limit=50, callback=distance_measure, - max_history_length=5) + max_history_length=3) m0 = Field(s_space, val=1) diff --git a/nifty/energies/__init__.py b/nifty/energies/__init__.py index 3f5356d255bde20b19f743aa97674b5b4d83c584..c31d21a9516ea6840c857a226cf381f0dbbccf28 100644 --- a/nifty/energies/__init__.py +++ b/nifty/energies/__init__.py @@ -2,3 +2,4 @@ from energy import Energy from line_energy import LineEnergy +from memoization import memo diff --git a/nifty/energies/energy.py b/nifty/energies/energy.py index 325e799507453180d9e892934fce56948cb188d5..ba5c7415d52982152328f86fec77b0f9608744c7 100644 --- a/nifty/energies/energy.py +++ b/nifty/energies/energy.py @@ -24,15 +24,3 @@ class Energy(object): @property def curvature(self): raise NotImplementedError - - def memo(f): - name = id(f) - - def wrapped_f(self): - try: - return self._cache[name] - except KeyError: - self._cache[name] = f(self) - return self._cache[name] - return wrapped_f - diff --git a/nifty/energies/memoization.py b/nifty/energies/memoization.py new file mode 100644 index 0000000000000000000000000000000000000000..b7bb126929d6b7c882ce41c36ce063bb0c2d3357 --- /dev/null +++ b/nifty/energies/memoization.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- + + +def memo(f): + name = id(f) + + def wrapped_f(self): + try: + return self._cache[name] + except KeyError: + self._cache[name] = f(self) + return self._cache[name] + return wrapped_f