diff --git a/nifty6/minimization/energy_adapter.py b/nifty6/minimization/energy_adapter.py index 2624151f7b5421fffdcad205e16bb40ac988842f..c1dcaa345834a6c75962f6a6a9d44d1fc241956b 100644 --- a/nifty6/minimization/energy_adapter.py +++ b/nifty6/minimization/energy_adapter.py @@ -11,12 +11,13 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see <http://www.gnu.org/licenses/>. # -# Copyright(C) 2013-2019 Max-Planck-Society +# Copyright(C) 2013-2020 Max-Planck-Society # # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. from ..linearization import Linearization from ..minimization.energy import Energy +from ..sugar import makeDomain class EnergyAdapter(Energy): @@ -40,10 +41,20 @@ class EnergyAdapter(Energy): additional resources. Default: False. """ - def __init__(self, position, op, constants=[], want_metric=False): + def __init__(self, position, op, constants=[], want_metric=False, + _op4eval=None): super(EnergyAdapter, self).__init__(position) self._op = op self._constants = constants + if self._op4eval is None: + if len(constants) > 0: + dom = {kk: vv for kk, vv in position.domain.items() + if kk in constants} + dom = makeDomain(dom) + cstpos = position.extract(dom) + _, self._op4eval = op.simplify_for_constant_input(cstpos) + else: + self._op4eval = op self._want_metric = want_metric lin = Linearization.make_partial_var(position, constants, want_metric) tmp = self._op(lin) @@ -53,7 +64,7 @@ class EnergyAdapter(Energy): def at(self, position): return EnergyAdapter(position, self._op, self._constants, - self._want_metric) + self._want_metric, self._op4eval) @property def value(self):