Commit ca8a91eb authored by Philipp Arras's avatar Philipp Arras
Browse files

Add simplify for constant input

parent b3ea2e41
Pipeline #75629 canceled with stages
in 1 minute and 18 seconds
...@@ -11,12 +11,13 @@ ...@@ -11,12 +11,13 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
# #
# Copyright(C) 2013-2019 Max-Planck-Society # Copyright(C) 2013-2020 Max-Planck-Society
# #
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from ..linearization import Linearization from ..linearization import Linearization
from ..minimization.energy import Energy from ..minimization.energy import Energy
from ..sugar import makeDomain
class EnergyAdapter(Energy): class EnergyAdapter(Energy):
...@@ -40,10 +41,20 @@ class EnergyAdapter(Energy): ...@@ -40,10 +41,20 @@ class EnergyAdapter(Energy):
additional resources. Default: False. additional resources. Default: False.
""" """
def __init__(self, position, op, constants=[], want_metric=False): def __init__(self, position, op, constants=[], want_metric=False,
_op4eval=None):
super(EnergyAdapter, self).__init__(position) super(EnergyAdapter, self).__init__(position)
self._op = op self._op = op
self._constants = constants self._constants = constants
if self._op4eval is None:
if len(constants) > 0:
dom = {kk: vv for kk, vv in position.domain.items()
if kk in constants}
dom = makeDomain(dom)
cstpos = position.extract(dom)
_, self._op4eval = op.simplify_for_constant_input(cstpos)
else:
self._op4eval = op
self._want_metric = want_metric self._want_metric = want_metric
lin = Linearization.make_partial_var(position, constants, want_metric) lin = Linearization.make_partial_var(position, constants, want_metric)
tmp = self._op(lin) tmp = self._op(lin)
...@@ -53,7 +64,7 @@ class EnergyAdapter(Energy): ...@@ -53,7 +64,7 @@ class EnergyAdapter(Energy):
def at(self, position): def at(self, position):
return EnergyAdapter(position, self._op, self._constants, return EnergyAdapter(position, self._op, self._constants,
self._want_metric) self._want_metric, self._op4eval)
@property @property
def value(self): def value(self):
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment