Skip to content
Snippets Groups Projects
Commit ca8a91eb authored by Philipp Arras's avatar Philipp Arras
Browse files

Add simplify for constant input

parent b3ea2e41
No related branches found
No related tags found
2 merge requests!535Nifty 7,!509Support complex data in `VariableCovarianceGaussianEnergy` and use simplify for constant input for KL and `EnergyAdapter`
Pipeline #75629 canceled
......@@ -11,12 +11,13 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
# Copyright(C) 2013-2020 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from ..linearization import Linearization
from ..minimization.energy import Energy
from ..sugar import makeDomain
class EnergyAdapter(Energy):
......@@ -40,10 +41,20 @@ class EnergyAdapter(Energy):
additional resources. Default: False.
"""
def __init__(self, position, op, constants=[], want_metric=False):
def __init__(self, position, op, constants=[], want_metric=False,
_op4eval=None):
super(EnergyAdapter, self).__init__(position)
self._op = op
self._constants = constants
if self._op4eval is None:
if len(constants) > 0:
dom = {kk: vv for kk, vv in position.domain.items()
if kk in constants}
dom = makeDomain(dom)
cstpos = position.extract(dom)
_, self._op4eval = op.simplify_for_constant_input(cstpos)
else:
self._op4eval = op
self._want_metric = want_metric
lin = Linearization.make_partial_var(position, constants, want_metric)
tmp = self._op(lin)
......@@ -53,7 +64,7 @@ class EnergyAdapter(Energy):
def at(self, position):
return EnergyAdapter(position, self._op, self._constants,
self._want_metric)
self._want_metric, self._op4eval)
@property
def value(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment