Commit 8103fbbd authored by Philipp Arras's avatar Philipp Arras
Browse files

Cosmetics

parent cc077788
......@@ -11,7 +11,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2020 Max-Planck-Society
# Copyright(C) 2013-2021 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
......@@ -97,13 +97,13 @@ def main():
data = signal_response(mock_position) + N.draw_sample_with_dtype(dtype=np.float64)
# Minimization parameters
ic_sampling = ift.AbsDeltaEnergyController(
ic_sampling = ift.AbsDeltaEnergyController(name="Sampling (linear)",
deltaE=0.05, iteration_limit=100)
ic_newton = ift.AbsDeltaEnergyController(
name='Newton', deltaE=0.5, iteration_limit=35)
ic_newton = ift.AbsDeltaEnergyController(name='Newton', deltaE=0.5,
convergence_level=2, iteration_limit=35)
minimizer = ift.NewtonCG(ic_newton)
ic_sampling_nl = ift.AbsDeltaEnergyController(
name='Sampling', deltaE=0.5, iteration_limit=15, convergence_level=2)
ic_sampling_nl = ift.AbsDeltaEnergyController(name='Sampling (nonlin)',
deltaE=0.5, iteration_limit=15, convergence_level=2)
minimizer_sampling = ift.NewtonCG(ic_sampling_nl)
# Set up likelihood and information Hamiltonian
......
......@@ -11,7 +11,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2020 Max-Planck-Society
# Copyright(C) 2013-2021 Max-Planck-Society
# Authors: Philipp Frank
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
......@@ -21,6 +21,7 @@ from functools import reduce
from .. import random
from .. import utilities
from ..domain_tuple import DomainTuple
from ..linearization import Linearization
from ..multi_field import MultiField
from ..operators.inversion_enabler import InversionEnabler
......@@ -38,17 +39,6 @@ from ..utilities import myassert
from .energy import Energy
from .descent_minimizers import DescentMinimizer, ConjugateGradient
def _is_prior_dtype_float(H):
real = True
dts = H._prior._met._dtype
if isinstance(dts, dict):
for k in dts.keys():
if not np.issubdtype(dts[k], np.float):
real = False
else:
real = np.issubdtype(dts, np.float)
return real
def _get_lo_hi(comm, n_samples):
ntask, rank, _ = utilities.get_MPI_params_from_comm(comm)
......@@ -187,8 +177,17 @@ class _GeoMetricSampler:
n_samples, mirror_samples, napprox=0, want_error=False):
if not isinstance(H, StandardHamiltonian):
raise NotImplementedError
if not _is_prior_dtype_float(H):
# Check domain dtype
dts = H._prior._met._dtype
if isinstance(H.domain, DomainTuple):
real = np.issubdtype(dts, np.float)
else:
real = all([np.issubdtype(dts[kk], np.float) for kk in dts.keys()])
if not real:
raise ValueError("_GeoMetricSampler only supports real valued latent DOFs.")
# /Check domain dtype
if isinstance(position, MultiField):
self._position = position.extract(H.domain)
else:
......@@ -206,12 +205,11 @@ class _GeoMetricSampler:
scale = SamplingDtypeSetter(scale, dtype) if sampling else scale
fl = f_lh(Linearization.make_var(self._position))
self._g = (Adder(-self._position) +
fl.jac.adjoint@Adder(-fl.val)@f_lh)
self._g = (Adder(-self._position) + fl.jac.adjoint@Adder(-fl.val)@f_lh)
self._likelihood = SandwichOperator.make(fl.jac, scale)
self._prior = SamplingDtypeSetter(ScalingOperator(fl.domain,1.), np.float64)
self._met = self._likelihood + self._prior
if napprox >=1:
if napprox >= 1:
self._approximation = makeOp(approximation2endo(self._met, napprox)).inverse
else:
self._approximation = None
......
......@@ -93,10 +93,11 @@ class LikelihoodOperator(EnergyOperator):
:func:`~nifty7.operators.operator.Operator.get_transformation`.
"""
dtp, f = self.get_transformation()
ch = ScalingOperator(f.target, 1.)
ch = None
if dtp is not None:
ch = SamplingDtypeSetter(ch, dtp)
return SandwichOperator.make(f(Linearization.make_var(x)).jac, ch)
ch = SamplingDtypeSetter(ScalingOperator(f.target, 1.), dtp)
bun = f(Linearization.make_var(x)).jac
return SandwichOperator.make(bun, ch)
class Squared2NormOperator(EnergyOperator):
......@@ -181,8 +182,8 @@ class VariableCovarianceGaussianEnergy(LikelihoodOperator):
Default is True
"""
def __init__(self, domain, residual_key, inverse_covariance_key, sampling_dtype,
use_full_fisher = True):
def __init__(self, domain, residual_key, inverse_covariance_key,
sampling_dtype, use_full_fisher=True):
self._kr = str(residual_key)
self._ki = str(inverse_covariance_key)
dom = DomainTuple.make(domain)
......@@ -190,7 +191,7 @@ class VariableCovarianceGaussianEnergy(LikelihoodOperator):
self._dt = {self._kr: sampling_dtype, self._ki: np.float64}
_check_sampling_dtype(self._domain, self._dt)
self._cplx = _iscomplex(sampling_dtype)
self._use_fisher = use_full_fisher
self._use_full_fisher = use_full_fisher
def apply(self, x):
self._check_input(x)
......@@ -201,7 +202,7 @@ class VariableCovarianceGaussianEnergy(LikelihoodOperator):
res = 0.5*(r.vdot(r*i) - i.ptw("log").sum())
if not x.want_metric:
return res
if self._use_fisher:
if self._use_full_fisher:
met = 1. if self._cplx else 0.5
met = MultiField.from_dict({self._kr: i.val, self._ki: met*i.val**(-2)},
domain=self._domain)
......@@ -616,6 +617,3 @@ class AveragedEnergy(EnergyOperator):
dtp, trafo = self._h.get_transformation()
mymap = map(lambda v: trafo@Adder(v), self._res_samples)
return dtp, utilities.my_sum(mymap)/np.sqrt(len(self._res_samples))
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment