Commit 771d979c authored by Philipp Arras's avatar Philipp Arras
Browse files

Add ConstantEnergyOperator

parent 061c1a8a
......@@ -130,6 +130,8 @@ class MetricGaussianKL(Energy):
self._mitigate_nans = nanisinf
if not isinstance(mirror_samples, bool):
raise TypeError
if isinstance(mean, MultiField) and set(point_estimates) == set(mean.keys()):
raise RuntimeError('Point estimates for whole domain. Use EnergyAdapter instead.')
self._hamiltonian = hamiltonian
self._ham4eval = _ham4eval
......@@ -27,7 +27,7 @@ from .linear_operator import LinearOperator
from .operator import Operator
from .sampling_enabler import SamplingDtypeSetter, SamplingEnabler
from .scaling_operator import ScalingOperator
from .simple_linear_operators import VdotOperator
from .simple_linear_operators import NullOperator, VdotOperator
def _check_sampling_dtype(domain, dtypes):
......@@ -485,3 +485,24 @@ class AveragedEnergy(EnergyOperator):
mymap = map(lambda v: self._h(x+v), self._res_samples)
return utilities.my_sum(mymap)/len(self._res_samples)
class _ConstantEnergyOperator(EnergyOperator):
def __init__(self, dom, output):
from ..sugar import makeDomain
self._domain = makeDomain(dom)
if is not output.domain:
raise TypeError
self._output = output
def apply(self, x):
if x.jac is not None:
val = self._output
jac = NullOperator(self._domain, self._target)
met = NullOperator(self._domain, self._domain) if x.want_metric else None
return, jac, met)
return self._output
def __repr__(self):
return 'ConstantEnergyOperator <- {}'.format(self.domain.keys())
......@@ -269,10 +269,14 @@ class Operator(metaclass=NiftyMeta):
return self.__class__.__name__
def simplify_for_constant_input(self, c_inp):
from .energy_operators import EnergyOperator, _ConstantEnergyOperator
if c_inp is None:
return None, self
if c_inp.domain == self.domain:
op = _ConstantOperator(self.domain, self(c_inp))
if isinstance(self, EnergyOperator):
op = _ConstantEnergyOperator(self.domain, self(c_inp))
op = _ConstantOperator(self.domain, self(c_inp))
return op(c_inp), op
return self._simplify_for_constant_input_nontrivial(c_inp)
......@@ -18,7 +18,7 @@
import numpy as np
import pytest
from mpi4py import MPI
from numpy.testing import assert_, assert_equal
from numpy.testing import assert_, assert_equal, assert_raises
import nifty6 as ift
......@@ -58,6 +58,10 @@ def test_kl(constants, point_estimates, mirror_samples, mode, mf):
'n_samples': 2,
'mean': mean0,
'hamiltonian': h}
if isinstance(mean0, ift.MultiField) and set(point_estimates) == set(mean0.keys()):
with assert_raises(RuntimeError):
ift.MetricGaussianKL(**args, comm=comm)
if mode == 0:
kl0 = ift.MetricGaussianKL(**args, comm=comm)
locsamp = kl0._local_samples
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment