diff --git a/demos/bernoulli_demo.py b/demos/bernoulli_demo.py index f81d4ad7fed677e3cc75e5576364dd53dad8e9c6..700b377c6b4987c0b787c7a7e77a8f26729e793b 100644 --- a/demos/bernoulli_demo.py +++ b/demos/bernoulli_demo.py @@ -73,14 +73,15 @@ if __name__ == '__main__': # Minimize the Hamiltonian H = ift.Hamiltonian(likelihood, ic_sampling) - H = ift.EnergyAdapter(position, H) + H = ift.EnergyAdapter(position, H, want_metric=True) # minimizer = ift.L_BFGS(ic_newton) H, convergence = minimizer(H) reconstruction = sky(H.position) - ift.plot(reconstruction, title='reconstruction') - ift.plot(GR.adjoint_times(data), title='data') - ift.plot(sky(mock_position), title='truth') - ift.plot_finish(nx=3, xsize=16, ysize=5, title="results", - name="bernoulli.png") + plot = ift.Plot() + plot.add(reconstruction, title='reconstruction') + plot.add(GR.adjoint_times(data), title='data') + plot.add(sky(mock_position), title='truth') + plot.output(nx=3, xsize=16, ysize=5, title="results", + name="bernoulli.png") diff --git a/demos/getting_started_1.py b/demos/getting_started_1.py index 76425bea5b55995d516d34c119731e8c71093348..410db8b8b4d658a70c8065fe37186ada31c1f119 100644 --- a/demos/getting_started_1.py +++ b/demos/getting_started_1.py @@ -103,18 +103,17 @@ if __name__ == '__main__': # PLOTTING rg = isinstance(position_space, ift.RGSpace) + plot = ift.Plot() if rg and len(position_space.shape) == 1: - ift.plot([HT(MOCK_SIGNAL), GR.adjoint(data), HT(m)], + plot.add([HT(MOCK_SIGNAL), GR.adjoint(data), HT(m)], label=['Mock signal', 'Data', 'Reconstruction'], alpha=[1, .3, 1]) - ift.plot(mask_to_nan(mask, HT(m-MOCK_SIGNAL)), title='Residuals') - ift.plot_finish(nx=2, ny=1, xsize=10, ysize=4, - title="getting_started_1") + plot.add(mask_to_nan(mask, HT(m-MOCK_SIGNAL)), title='Residuals') + plot.output(nx=2, ny=1, xsize=10, ysize=4, title="getting_started_1") else: - ift.plot(HT(MOCK_SIGNAL), title='Mock Signal') - ift.plot(mask_to_nan(mask, (GR(Mask)).adjoint(data)), + plot.add(HT(MOCK_SIGNAL), title='Mock Signal') + plot.add(mask_to_nan(mask, (GR(Mask)).adjoint(data)), title='Data') - ift.plot(HT(m), title='Reconstruction') - ift.plot(mask_to_nan(mask, HT(m-MOCK_SIGNAL)), title='Residuals') - ift.plot_finish(nx=2, ny=2, xsize=10, ysize=10, - title="getting_started_1") + plot.add(HT(m), title='Reconstruction') + plot.add(mask_to_nan(mask, HT(m-MOCK_SIGNAL)), title='Residuals') + plot.output(nx=2, ny=2, xsize=10, ysize=10, title="getting_started_1") diff --git a/demos/getting_started_2.py b/demos/getting_started_2.py index ad0ff69eadfa02d3ab749eb0ec3d6a6bfeff8f3c..467fbedc044c8902e3c4ce71d7020ac810fb574f 100644 --- a/demos/getting_started_2.py +++ b/demos/getting_started_2.py @@ -93,14 +93,15 @@ if __name__ == '__main__': # Minimize the Hamiltonian H = ift.Hamiltonian(likelihood) - H = ift.EnergyAdapter(position, H) + H = ift.EnergyAdapter(position, H, want_metric=True) H, convergence = minimizer(H) # Plot results signal = sky(mock_position) reconst = sky(H.position) - ift.plot(signal, title='Signal') - ift.plot(GR.adjoint(data), title='Data') - ift.plot(reconst, title='Reconstruction') - ift.plot(reconst - signal, title='Residuals') - ift.plot_finish(name='getting_started_2.png', xsize=16, ysize=16) + plot = ift.Plot() + plot.add(signal, title='Signal') + plot.add(GR.adjoint(data), title='Data') + plot.add(reconst, title='Reconstruction') + plot.add(reconst - signal, title='Residuals') + plot.output(name='getting_started_2.png', xsize=16, ysize=16) diff --git a/demos/getting_started_3.py b/demos/getting_started_3.py index 43387801207663fccebaf5f2a0a757956ff528fa..4b9b7e59ab180f39e522b771552e75117397964b 100644 --- a/demos/getting_started_3.py +++ b/demos/getting_started_3.py @@ -41,11 +41,12 @@ if __name__ == '__main__': power_space = A.target[0] power_distributor = ift.PowerDistributor(harmonic_space, power_space) dummy = ift.Field.from_random('normal', harmonic_space) - domain = ift.MultiDomain.union( - (A.domain, ift.MultiDomain.make({'xi': harmonic_space}))) + domain = ift.MultiDomain.union((A.domain, + ift.MultiDomain.make({ + 'xi': harmonic_space + }))) - correlated_field = ht( - power_distributor(A)*ift.FieldAdapter(domain, "xi")) + correlated_field = ht(power_distributor(A)*ift.FieldAdapter(domain, "xi")) # alternatively to the block above one can do: # correlated_field = ift.CorrelatedField(position_space, A) @@ -54,8 +55,7 @@ if __name__ == '__main__': # Building the Line of Sight response LOS_starts, LOS_ends = get_random_LOS(100) - R = ift.LOSResponse(position_space, starts=LOS_starts, - ends=LOS_ends) + R = ift.LOSResponse(position_space, starts=LOS_starts, ends=LOS_ends) # build signal response model and model likelihood signal_response = R(signal) # specify noise @@ -68,8 +68,7 @@ if __name__ == '__main__': data = signal_response(MOCK_POSITION) + N.draw_sample() # set up model likelihood - likelihood = ift.GaussianEnergy( - mean=data, covariance=N)(signal_response) + likelihood = ift.GaussianEnergy(mean=data, covariance=N)(signal_response) # set up minimization and inversion schemes ic_sampling = ift.GradientNormController(iteration_limit=100) @@ -83,34 +82,33 @@ if __name__ == '__main__': INITIAL_POSITION = ift.from_random('normal', domain) position = INITIAL_POSITION - ift.plot(signal(MOCK_POSITION), title='ground truth') - ift.plot(R.adjoint_times(data), title='data') - ift.plot([A(MOCK_POSITION)], title='power') - ift.plot_finish(nx=3, xsize=16, ysize=5, title="setup", name="setup.png") + plot = ift.Plot() + plot.add(signal(MOCK_POSITION), title='Ground Truth') + plot.add(R.adjoint_times(data), title='Data') + plot.add([A(MOCK_POSITION)], title='Power Spectrum') + plot.output(ny=1, nx=3, xsize=24, ysize=6, name="setup.png") # number of samples used to estimate the KL N_samples = 20 for i in range(2): - metric = H(ift.Linearization.make_var(position)).metric - samples = [metric.draw_sample(from_inverse=True) - for _ in range(N_samples)] - - KL = ift.SampledKullbachLeiblerDivergence(H, samples) - KL = ift.EnergyAdapter(position, KL) + KL = ift.KL_Energy(position, H, N_samples, want_metric=True) KL, convergence = minimizer(KL) position = KL.position - ift.plot(signal(position), title="reconstruction") - ift.plot([A(position), A(MOCK_POSITION)], title="power") - ift.plot_finish(nx=2, xsize=12, ysize=6, title="loop", name="loop.png") + plot = ift.Plot() + plot.add(signal(KL.position), title="reconstruction") + plot.add([A(KL.position), A(MOCK_POSITION)], title="power") + plot.output(ny=1, ysize=6, xsize=16, name="loop.png") + plot = ift.Plot() sc = ift.StatCalculator() - for sample in samples: - sc.add(signal(sample+position)) - ift.plot(sc.mean, title="mean") - ift.plot(ift.sqrt(sc.var), title="std deviation") - - powers = [A(s+position) for s in samples] - ift.plot([A(position), A(MOCK_POSITION)]+powers, title="power") - ift.plot_finish(nx=3, xsize=16, ysize=5, title="results", - name="results.png") + for sample in KL.samples: + sc.add(signal(sample+KL.position)) + plot.add(sc.mean, title="Posterior Mean") + plot.add(ift.sqrt(sc.var), title="Posterior Standard Deviation") + + powers = [A(s+KL.position) for s in KL.samples] + plot.add( + [A(KL.position), A(MOCK_POSITION)]+powers, + title="Sampled Posterior Power Spectrum") + plot.output(ny=1, nx=3, xsize=24, ysize=6, name="results.png") diff --git a/demos/plot_test.py b/demos/plot_test.py index 0138b31cefdc5e672b2c550aefec7fa0a9f8cf00..49b2e752cad45ae7ace5e894caf853236b8d8db9 100644 --- a/demos/plot_test.py +++ b/demos/plot_test.py @@ -20,21 +20,24 @@ def plot_test(): # Start various plotting tests - ift.plot(field_rg1_1, title='Single plot') - ift.plot_finish() - - ift.plot(field_rg2, title='2d rg') - ift.plot([field_rg1_1, field_rg1_2], title='list 1d rg', label=['1', '2']) - ift.plot(field_rg1_2, title='1d rg, xmin, ymin', xmin=0.5, ymin=0., + plot = ift.Plot() + plot.add(field_rg1_1, title='Single plot') + plot.output() + + plot = ift.Plot() + plot.add(field_rg2, title='2d rg') + plot.add([field_rg1_1, field_rg1_2], title='list 1d rg', label=['1', '2']) + plot.add(field_rg1_2, title='1d rg, xmin, ymin', xmin=0.5, ymin=0., xlabel='xmin=0.5', ylabel='ymin=0') - ift.plot_finish(title='Three plots') - - ift.plot(field_hp, title='HP planck-color', colormap='Planck-like') - ift.plot(field_rg1_2, title='1d rg') - ift.plot(field_ps) - ift.plot(field_gl, title='GL') - ift.plot(field_rg2, title='2d rg') - ift.plot_finish(nx=2, ny=3, title='Five plots') + plot.output(title='Three plots') + + plot = ift.Plot() + plot.add(field_hp, title='HP planck-color', colormap='Planck-like') + plot.add(field_rg1_2, title='1d rg') + plot.add(field_ps) + plot.add(field_gl, title='GL') + plot.add(field_rg2, title='2d rg') + plot.output(nx=2, ny=3, title='Five plots') if __name__ == '__main__': diff --git a/demos/polynomial_fit.py b/demos/polynomial_fit.py index f49cbec8b0a9e26810ae3e5e65cf51cd226aaaf8..403269c791db3b55c9327264cdb8c30c6d050c1c 100644 --- a/demos/polynomial_fit.py +++ b/demos/polynomial_fit.py @@ -86,15 +86,16 @@ N = ift.DiagonalOperator(ift.from_global_data(d_space, var)) IC = ift.GradientNormController(tol_abs_gradnorm=1e-8) likelihood = ift.GaussianEnergy(d, N)(R) -H = ift.Hamiltonian(likelihood, IC) -H = ift.EnergyAdapter(params, H, IC) +Ham = ift.Hamiltonian(likelihood, IC) +H = ift.EnergyAdapter(params, Ham, want_metric=True) # Minimize minimizer = ift.NewtonCG(IC) H, _ = minimizer(H) # Draw posterior samples -samples = [H.metric.draw_sample(from_inverse=True) + H.position +metric = Ham(ift.Linearization.make_var(H.position, want_metric=True)).metric +samples = [metric.draw_sample(from_inverse=True) + H.position for _ in range(N_samples)] # Plotting diff --git a/nifty5/__init__.py b/nifty5/__init__.py index 49282e10297e23ac02b19809f9471a5e5d42592b..562f462fc316c95e05d43ad22a07c7d6c8134902 100644 --- a/nifty5/__init__.py +++ b/nifty5/__init__.py @@ -22,7 +22,8 @@ from .operators.operator import Operator from .operators.central_zero_padder import CentralZeroPadder from .operators.diagonal_operator import DiagonalOperator from .operators.distributors import DOFDistributor, PowerDistributor -from .operators.domain_tuple_operators import DomainDistributor, DomainTupleFieldInserter +from .operators.domain_tuple_operators import DomainTupleFieldInserter +from .operators.contraction_operator import ContractionOperator from .operators.endomorphic_operator import EndomorphicOperator from .operators.exp_transform import ExpTransform from .operators.harmonic_operators import ( @@ -67,9 +68,10 @@ from .minimization.energy import Energy from .minimization.quadratic_energy import QuadraticEnergy from .minimization.line_energy import LineEnergy from .minimization.energy_adapter import EnergyAdapter +from .minimization.kl_energy import KL_Energy from .sugar import * -from .plotting.plot import plot, plot_finish +from .plotting.plot import Plot from .library.amplitude_model import AmplitudeModel from .library.inverse_gamma_model import InverseGammaModel diff --git a/nifty5/extra/energy_and_model_tests.py b/nifty5/extra/energy_and_model_tests.py index 60022c10129a337759b102e3c6622e7518e56692..f151e5756a659ee558128499d9af103c13f93d85 100644 --- a/nifty5/extra/energy_and_model_tests.py +++ b/nifty5/extra/energy_and_model_tests.py @@ -41,7 +41,7 @@ def _get_acceptable_location(op, loc, lin): for i in range(50): try: loc2 = loc+dir - lin2 = op(Linearization.make_var(loc2)) + lin2 = op(Linearization.make_var(loc2, lin.want_metric)) if np.isfinite(lin2.val.sum()) and abs(lin2.val.sum()) < 1e20: break except FloatingPointError: @@ -54,14 +54,14 @@ def _get_acceptable_location(op, loc, lin): def _check_consistency(op, loc, tol, ntries, do_metric): for _ in range(ntries): - lin = op(Linearization.make_var(loc)) + lin = op(Linearization.make_var(loc, do_metric)) loc2, lin2 = _get_acceptable_location(op, loc, lin) dir = loc2-loc locnext = loc2 dirnorm = dir.norm() for i in range(50): locmid = loc + 0.5*dir - linmid = op(Linearization.make_var(locmid)) + linmid = op(Linearization.make_var(locmid, do_metric)) dirder = linmid.jac(dir) numgrad = (lin2.val-lin.val) xtol = tol * dirder.norm() / np.sqrt(dirder.size) diff --git a/nifty5/library/correlated_fields.py b/nifty5/library/correlated_fields.py index 0cdf90b0af9a6602f64575c2360735313d4db2f8..2687951644fa444e8bc568e39d5bbd2a5f31da6d 100644 --- a/nifty5/library/correlated_fields.py +++ b/nifty5/library/correlated_fields.py @@ -21,8 +21,8 @@ from __future__ import absolute_import, division, print_function from ..compat import * from ..domain_tuple import DomainTuple from ..multi_domain import MultiDomain +from ..operators.contraction_operator import ContractionOperator from ..operators.distributors import PowerDistributor -from ..operators.domain_tuple_operators import DomainDistributor from ..operators.harmonic_operators import HarmonicTransformOperator from ..operators.simple_linear_operators import FieldAdapter from ..sugar import exp @@ -65,8 +65,8 @@ def MfCorrelatedField(s_space_spatial, s_space_energy, amplitude_model_spatial, pd_energy = PowerDistributor(pd_spatial.domain, p_space_energy, 1) pd = pd_spatial(pd_energy) - dom_distr_spatial = DomainDistributor(pd.domain, 0) - dom_distr_energy = DomainDistributor(pd.domain, 1) + dom_distr_spatial = ContractionOperator(pd.domain, 0).adjoint + dom_distr_energy = ContractionOperator(pd.domain, 1).adjoint a_spatial = dom_distr_spatial(amplitude_model_spatial) a_energy = dom_distr_energy(amplitude_model_energy) diff --git a/nifty5/library/inverse_gamma_model.py b/nifty5/library/inverse_gamma_model.py index 9be2acce1a7b4aad7e292849034cca8cd10a60b0..831546453062e734b0e708dd33245e5fd13304e2 100644 --- a/nifty5/library/inverse_gamma_model.py +++ b/nifty5/library/inverse_gamma_model.py @@ -53,7 +53,7 @@ class InverseGammaModel(Operator): outer = 1/outer_inv jac = makeOp(Field.from_local_data(self._domain, inner*outer)) jac = jac(x.jac) - return Linearization(points, jac) + return x.new(points, jac) @staticmethod def IG(field, alpha, q): diff --git a/nifty5/linearization.py b/nifty5/linearization.py index 41210bdeec24e31147e72080491f16699520aa01..c27dec1abc65794ab57288480f49e59b95053d12 100644 --- a/nifty5/linearization.py +++ b/nifty5/linearization.py @@ -9,13 +9,17 @@ from .sugar import makeOp class Linearization(object): - def __init__(self, val, jac, metric=None): + def __init__(self, val, jac, metric=None, want_metric=False): self._val = val self._jac = jac if self._val.domain != self._jac.target: raise ValueError("domain mismatch") + self._want_metric = want_metric self._metric = metric + def new(self, val, jac, metric=None): + return Linearization(val, jac, metric, self._want_metric) + @property def domain(self): return self._jac.domain @@ -37,6 +41,10 @@ class Linearization(object): """Only available if target is a scalar""" return self._jac.adjoint_times(Field.scalar(1.)) + @property + def want_metric(self): + return self._want_metric + @property def metric(self): """Only available if target is a scalar""" @@ -44,35 +52,34 @@ class Linearization(object): def __getitem__(self, name): from .operators.simple_linear_operators import FieldAdapter - return Linearization(self._val[name], FieldAdapter(self.domain, name)) + return self.new(self._val[name], FieldAdapter(self.domain, name)) def __neg__(self): - return Linearization( - -self._val, -self._jac, - None if self._metric is None else -self._metric) + return self.new(-self._val, -self._jac, + None if self._metric is None else -self._metric) def conjugate(self): - return Linearization( + return self.new( self._val.conjugate(), self._jac.conjugate(), None if self._metric is None else self._metric.conjugate()) @property def real(self): - return Linearization(self._val.real, self._jac.real) + return self.new(self._val.real, self._jac.real) def _myadd(self, other, neg): if isinstance(other, Linearization): met = None if self._metric is not None and other._metric is not None: met = self._metric._myadd(other._metric, neg) - return Linearization( + return self.new( self._val.flexible_addsub(other._val, neg), self._jac._myadd(other._jac, neg), met) if isinstance(other, (int, float, complex, Field, MultiField)): if neg: - return Linearization(self._val-other, self._jac, self._metric) + return self.new(self._val-other, self._jac, self._metric) else: - return Linearization(self._val+other, self._jac, self._metric) + return self.new(self._val+other, self._jac, self._metric) def __add__(self, other): return self._myadd(other, False) @@ -91,7 +98,7 @@ class Linearization(object): if isinstance(other, Linearization): if self.target != other.target: raise ValueError("domain mismatch") - return Linearization( + return self.new( self._val*other._val, (makeOp(other._val)(self._jac))._myadd( makeOp(self._val)(other._jac), False)) @@ -99,11 +106,11 @@ class Linearization(object): if other == 1: return self met = None if self._metric is None else self._metric.scale(other) - return Linearization(self._val*other, self._jac.scale(other), met) + return self.new(self._val*other, self._jac.scale(other), met) if isinstance(other, (Field, MultiField)): if self.target != other.domain: raise ValueError("domain mismatch") - return Linearization(self._val*other, makeOp(other)(self._jac)) + return self.new(self._val*other, makeOp(other)(self._jac)) def __rmul__(self, other): return self.__mul__(other) @@ -111,46 +118,48 @@ class Linearization(object): def vdot(self, other): from .operators.simple_linear_operators import VdotOperator if isinstance(other, (Field, MultiField)): - return Linearization( + return self.new( Field.scalar(self._val.vdot(other)), VdotOperator(other)(self._jac)) - return Linearization( + return self.new( Field.scalar(self._val.vdot(other._val)), VdotOperator(self._val)(other._jac) + VdotOperator(other._val)(self._jac)) def sum(self): from .operators.simple_linear_operators import SumReductionOperator - return Linearization( + return self.new( Field.scalar(self._val.sum()), SumReductionOperator(self._jac.target)(self._jac)) def exp(self): tmp = self._val.exp() - return Linearization(tmp, makeOp(tmp)(self._jac)) + return self.new(tmp, makeOp(tmp)(self._jac)) def log(self): tmp = self._val.log() - return Linearization(tmp, makeOp(1./self._val)(self._jac)) + return self.new(tmp, makeOp(1./self._val)(self._jac)) def tanh(self): tmp = self._val.tanh() - return Linearization(tmp, makeOp(1.-tmp**2)(self._jac)) + return self.new(tmp, makeOp(1.-tmp**2)(self._jac)) def positive_tanh(self): tmp = self._val.tanh() tmp2 = 0.5*(1.+tmp) - return Linearization(tmp2, makeOp(0.5*(1.-tmp**2))(self._jac)) + return self.new(tmp2, makeOp(0.5*(1.-tmp**2))(self._jac)) def add_metric(self, metric): - return Linearization(self._val, self._jac, metric) + return self.new(self._val, self._jac, metric) @staticmethod - def make_var(field): + def make_var(field, want_metric=False): from .operators.scaling_operator import ScalingOperator - return Linearization(field, ScalingOperator(1., field.domain)) + return Linearization(field, ScalingOperator(1., field.domain), + want_metric=want_metric) @staticmethod - def make_const(field): + def make_const(field, want_metric=False): from .operators.simple_linear_operators import NullOperator - return Linearization(field, NullOperator(field.domain, field.domain)) + return Linearization(field, NullOperator(field.domain, field.domain), + want_metric=want_metric) diff --git a/nifty5/minimization/conjugate_gradient.py b/nifty5/minimization/conjugate_gradient.py index 9a621917c4234c2bd8f0aa30d2a55f09e39728c8..f2487a3dde1aba9298ab2b3dccc4a2b967f8cf7c 100644 --- a/nifty5/minimization/conjugate_gradient.py +++ b/nifty5/minimization/conjugate_gradient.py @@ -75,7 +75,7 @@ class ConjugateGradient(Minimizer): return energy, controller.CONVERGED while True: - q = energy.metric(d) + q = energy.apply_metric(d) ddotq = d.vdot(q).real if ddotq == 0.: logger.error("Error: ConjugateGradient: ddotq==0.") diff --git a/nifty5/minimization/descent_minimizers.py b/nifty5/minimization/descent_minimizers.py index 5eddad19ca680e5111393709236825d8048e8ade..bc0da9017747c69e2c0a744fa116d94927c94a6c 100644 --- a/nifty5/minimization/descent_minimizers.py +++ b/nifty5/minimization/descent_minimizers.py @@ -180,7 +180,7 @@ class NewtonCG(DescentMinimizer): while True: if abs(ri).sum() <= termcond: return xsupi - Ap = energy.metric(psupi) + Ap = energy.apply_metric(psupi) # check curvature curv = psupi.vdot(Ap) if 0 <= curv <= 3*float64eps: diff --git a/nifty5/minimization/energy.py b/nifty5/minimization/energy.py index c213a6e8f71396e3a508e8abb778b5649a275db4..cfa59fb552dcfc3df8f37503bb765037834ed47a 100644 --- a/nifty5/minimization/energy.py +++ b/nifty5/minimization/energy.py @@ -109,6 +109,20 @@ class Energy(NiftyMetaBase()): """ raise NotImplementedError + def apply_metric(self, x): + """ + Parameters + ---------- + x: Field/MultiField + Argument for the metric operator + + Returns + ------- + Field/MultiField: + Output of the metric operator + """ + raise NotImplementedError + def longest_step(self, dir): """Returns the longest allowed step size along `dir` diff --git a/nifty5/minimization/energy_adapter.py b/nifty5/minimization/energy_adapter.py index f85d2e9d8c215a034319949190a5d8f2c063633e..985459cc56f162fe7b5306577d605b598129e4fc 100644 --- a/nifty5/minimization/energy_adapter.py +++ b/nifty5/minimization/energy_adapter.py @@ -8,58 +8,38 @@ from ..operators.scaling_operator import ScalingOperator class EnergyAdapter(Energy): - def __init__(self, position, op, controller=None, preconditioner=None, - constants=[]): + def __init__(self, position, op, constants=[], want_metric=False): super(EnergyAdapter, self).__init__(position) self._op = op - self._val = self._grad = self._metric = None - self._controller = controller - self._preconditioner = preconditioner self._constants = constants - - def at(self, position): - return EnergyAdapter(position, self._op, self._controller, - self._preconditioner, self._constants) - - def _fill_all(self): + self._want_metric = want_metric if len(self._constants) == 0: - tmp = self._op(Linearization.make_var(self._position)) + tmp = self._op(Linearization.make_var(self._position, want_metric)) else: ops = [ScalingOperator(0. if key in self._constants else 1., dom) for key, dom in self._position.domain.items()] bdop = BlockDiagonalOperator(self._position.domain, tuple(ops)) - tmp = self._op(Linearization(self._position, bdop)) + tmp = self._op(Linearization(self._position, bdop, + want_metric=want_metric)) self._val = tmp.val.local_data[()] self._grad = tmp.gradient - if self._controller is not None: - from ..operators.linear_operator import LinearOperator - from ..operators.inversion_enabler import InversionEnabler + self._metric = tmp._metric - if self._preconditioner is None: - precond = None - elif isinstance(self._preconditioner, LinearOperator): - precond = self._preconditioner - elif isinstance(self._preconditioner, Energy): - precond = self._preconditioner.at(self._position).metric - self._metric = InversionEnabler(tmp._metric, self._controller, - precond) - else: - self._metric = tmp._metric + def at(self, position): + return EnergyAdapter(position, self._op, self._constants, + self._want_metric) @property def value(self): - if self._val is None: - self._val = self._op(self._position).local_data[()] return self._val @property def gradient(self): - if self._grad is None: - self._fill_all() return self._grad @property def metric(self): - if self._metric is None: - self._fill_all() return self._metric + + def apply_metric(self, x): + return self._metric(x) diff --git a/nifty5/minimization/kl_energy.py b/nifty5/minimization/kl_energy.py new file mode 100644 index 0000000000000000000000000000000000000000..d7a98364d6024cbada97acca11d0d460864c8280 --- /dev/null +++ b/nifty5/minimization/kl_energy.py @@ -0,0 +1,57 @@ +from __future__ import absolute_import, division, print_function + +from ..compat import * +from .energy import Energy +from ..linearization import Linearization +from ..operators.scaling_operator import ScalingOperator +from ..operators.block_diagonal_operator import BlockDiagonalOperator +from .. import utilities + + +class KL_Energy(Energy): + def __init__(self, position, h, nsamp, constants=[], _samples=None, + want_metric=False): + super(KL_Energy, self).__init__(position) + self._h = h + self._constants = constants + self._want_metric = want_metric + if _samples is None: + met = h(Linearization.make_var(position, True)).metric + _samples = tuple(met.draw_sample(from_inverse=True) + for _ in range(nsamp)) + self._samples = _samples + if len(constants) == 0: + tmp = Linearization.make_var(position, want_metric) + else: + ops = [ScalingOperator(0. if key in constants else 1., dom) + for key, dom in position.domain.items()] + bdop = BlockDiagonalOperator(position.domain, tuple(ops)) + tmp = Linearization(position, bdop, want_metric=want_metric) + mymap = map(lambda v: self._h(tmp+v), self._samples) + tmp = utilities.my_sum(mymap) * (1./len(self._samples)) + self._val = tmp.val.local_data[()] + self._grad = tmp.gradient + self._metric = tmp.metric + + def at(self, position): + return KL_Energy(position, self._h, 0, self._constants, self._samples, + self._want_metric) + + @property + def value(self): + return self._val + + @property + def gradient(self): + return self._grad + + def apply_metric(self, x): + return self._metric(x) + + @property + def metric(self): + return self._metric + + @property + def samples(self): + return self._samples diff --git a/nifty5/minimization/quadratic_energy.py b/nifty5/minimization/quadratic_energy.py index a0949adb1c429d351c8b1959cdfaa42d255f00e7..254bc5ab3d275ec6da9f4ef143e366e73b2d7899 100644 --- a/nifty5/minimization/quadratic_energy.py +++ b/nifty5/minimization/quadratic_energy.py @@ -77,3 +77,6 @@ class QuadraticEnergy(Energy): @property def metric(self): return self._A + + def apply_metric(self, x): + return self._A(x) diff --git a/nifty5/minimization/scipy_minimizer.py b/nifty5/minimization/scipy_minimizer.py index bb94f66490606511fadcb90a02b924ba94a239ed..49e7a4ddaf5d23001ed628023a870550a904659e 100644 --- a/nifty5/minimization/scipy_minimizer.py +++ b/nifty5/minimization/scipy_minimizer.py @@ -93,7 +93,7 @@ class _MinHelper(object): def hessp(self, x, p): self._update(x) - res = self._energy.metric(_toField(p, self._energy.position)) + res = self._energy.apply_metric(_toField(p, self._energy.position)) return _toArray_rw(res) diff --git a/nifty5/operators/contraction_operator.py b/nifty5/operators/contraction_operator.py new file mode 100644 index 0000000000000000000000000000000000000000..741f34c3c2b5d31b0a12d1f7137c1e9ccb1bb678 --- /dev/null +++ b/nifty5/operators/contraction_operator.py @@ -0,0 +1,64 @@ +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +# +# Copyright(C) 2013-2018 Max-Planck-Society +# +# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik +# and financially supported by the Studienstiftung des deutschen Volkes. + +from __future__ import absolute_import, division, print_function + +import numpy as np + +from .. import utilities +from ..compat import * +from ..domain_tuple import DomainTuple +from ..field import Field +from .linear_operator import LinearOperator + + +class ContractionOperator(LinearOperator): + """A linear operator which sums up fields into the direction of subspaces. + + This ContractionOperator sums up a field with is defined on a DomainTuple + to a DomainTuple which contains the former as a subset. + + Parameters + ---------- + domain : Domain, tuple of Domain or DomainTuple + spaces : int or tuple of int + The elements of "domain" which are taken as target. + """ + + def __init__(self, domain, spaces): + self._domain = DomainTuple.make(domain) + self._spaces = utilities.parse_spaces(spaces, len(self._domain)) + self._target = [ + dom for i, dom in enumerate(self._domain) if i in self._spaces + ] + self._target = DomainTuple.make(self._target) + self._capability = self.TIMES | self.ADJOINT_TIMES + + def apply(self, x, mode): + self._check_input(x, mode) + if mode == self.ADJOINT_TIMES: + ldat = x.local_data if 0 in self._spaces else x.to_global_data() + shp = [] + for i, dom in enumerate(self._domain): + tmp = dom.shape if i > 0 else dom.local_shape + shp += tmp if i in self._spaces else (1,)*len(dom.shape) + ldat = np.broadcast_to(ldat.reshape(shp), self._domain.local_shape) + return Field.from_local_data(self._domain, ldat) + else: + return x.sum( + [s for s in range(len(x.domain)) if s not in self._spaces]) diff --git a/nifty5/operators/diagonal_operator.py b/nifty5/operators/diagonal_operator.py index 3be566b09d67d770d1d5f58fb06f44eefb0bde15..1e2fb22a7e85c2da757e79f9a2cee896ecc901a5 100644 --- a/nifty5/operators/diagonal_operator.py +++ b/nifty5/operators/diagonal_operator.py @@ -150,6 +150,8 @@ class DiagonalOperator(EndomorphicOperator): return Field.from_local_data(x.domain, x.local_data/xdiag) def _flip_modes(self, trafo): + if trafo == self.ADJOINT_BIT and not self._complex: # shortcut + return self xdiag = self._ldiag if self._complex and (trafo & self.ADJOINT_BIT): xdiag = xdiag.conj() diff --git a/nifty5/operators/domain_tuple_operators.py b/nifty5/operators/domain_tuple_operators.py index 13c017262ae7e810b4fe1f6a282dbf2619c57198..4dacdc5bda1ac7d2ce74cb7d0959f29e643e9649 100644 --- a/nifty5/operators/domain_tuple_operators.py +++ b/nifty5/operators/domain_tuple_operators.py @@ -20,52 +20,12 @@ from __future__ import absolute_import, division, print_function import numpy as np -from .. import utilities from ..compat import * from ..domain_tuple import DomainTuple from ..field import Field from .linear_operator import LinearOperator -class DomainDistributor(LinearOperator): - """A linear operator which broadcasts a field to a larger domain. - - This DomainDistributor broadcasts a field which is defined on a - DomainTuple to a DomainTuple which contains the former as a subset. The - entries of the field are copied such that they are constant in the - direction of the new spaces. - - Parameters - ---------- - target : Domain, tuple of Domain or DomainTuple - spaces : int or tuple of int - The elements of "target" which are taken as domain. - """ - - def __init__(self, target, spaces): - self._target = DomainTuple.make(target) - self._spaces = utilities.parse_spaces(spaces, len(self._target)) - self._domain = [ - tgt for i, tgt in enumerate(self._target) if i in self._spaces - ] - self._domain = DomainTuple.make(self._domain) - self._capability = self.TIMES | self.ADJOINT_TIMES - - def apply(self, x, mode): - self._check_input(x, mode) - if mode == self.TIMES: - ldat = x.local_data if 0 in self._spaces else x.to_global_data() - shp = [] - for i, tgt in enumerate(self._target): - tmp = tgt.shape if i > 0 else tgt.local_shape - shp += tmp if i in self._spaces else (1,)*len(tgt.shape) - ldat = np.broadcast_to(ldat.reshape(shp), self._target.local_shape) - return Field.from_local_data(self._target, ldat) - else: - return x.sum( - [s for s in range(len(x.domain)) if s not in self._spaces]) - - class DomainTupleFieldInserter(LinearOperator): def __init__(self, domain, new_space, ind, infront=False): '''Writes the content of a field into one slice of a DomainTuple. diff --git a/nifty5/operators/energy_operators.py b/nifty5/operators/energy_operators.py index 7d502421e295b5f3f79697b9cf187ac202e56440..a68ea92527a4850bd79976cb25eab87d2e8d9d7c 100644 --- a/nifty5/operators/energy_operators.py +++ b/nifty5/operators/energy_operators.py @@ -42,7 +42,7 @@ class SquaredNormOperator(EnergyOperator): if isinstance(x, Linearization): val = Field.scalar(x.val.vdot(x.val)) jac = VdotOperator(2*x.val)(x.jac) - return Linearization(val, jac) + return x.new(val, jac) return Field.scalar(x.vdot(x)) @@ -59,7 +59,7 @@ class QuadraticFormOperator(EnergyOperator): t1 = self._op(x.val) jac = VdotOperator(t1)(x.jac) val = Field.scalar(0.5*x.val.vdot(t1)) - return Linearization(val, jac) + return x.new(val, jac) return Field.scalar(0.5*x.vdot(self._op(x))) @@ -91,7 +91,7 @@ class GaussianEnergy(EnergyOperator): def apply(self, x): residual = x if self._mean is None else x-self._mean res = self._op(residual).real - if not isinstance(x, Linearization): + if not isinstance(x, Linearization) or not x.want_metric: return res metric = SandwichOperator.make(x.jac, self._icov) return res.add_metric(metric) @@ -107,6 +107,8 @@ class PoissonianEnergy(EnergyOperator): res = x.sum() - x.log().vdot(self._d) if not isinstance(x, Linearization): return Field.scalar(res) + if not x.want_metric: + return res metric = SandwichOperator.make(x.jac, makeOp(1./x.val)) return res.add_metric(metric) @@ -122,6 +124,8 @@ class BernoulliEnergy(EnergyOperator): v = x.log().vdot(-self._d) - (1.-x).log().vdot(1.-self._d) if not isinstance(x, Linearization): return Field.scalar(v) + if not x.want_metric: + return v met = makeOp(1./(x.val*(1.-x.val))) met = SandwichOperator.make(x.jac, met) return v.add_metric(met) @@ -135,11 +139,11 @@ class Hamiltonian(EnergyOperator): self._domain = lh.domain def apply(self, x): - if self._ic_samp is None or not isinstance(x, Linearization): + if (self._ic_samp is None or not isinstance(x, Linearization) or + not x.want_metric): return self._lh(x)+self._prior(x) else: - lhx = self._lh(x) - prx = self._prior(x) + lhx, prx = self._lh(x), self._prior(x) mtr = SamplingEnabler(lhx.metric, prx.metric.inverse, self._ic_samp, prx.metric.inverse) return (lhx+prx).add_metric(mtr) diff --git a/nifty5/operators/field_zero_padder.py b/nifty5/operators/field_zero_padder.py index 0a0617d4106964035a4fad1b3e35f02b818623ab..c4ad475904fd0c35a854597f4db5520cf303cd63 100644 --- a/nifty5/operators/field_zero_padder.py +++ b/nifty5/operators/field_zero_padder.py @@ -69,7 +69,6 @@ class FieldZeroPadder(LinearOperator): i1 = idx + (slice(None, -(Nyquist+1), -1),) xnew[i1] = x[i1] # if (x.shape[d] & 1) == 0: # even number of pixels -# print (Nyquist, x.shape[d]-Nyquist) # i1 = idx+(Nyquist,) # xnew[i1] *= 0.5 # i1 = idx+(-Nyquist,) diff --git a/nifty5/operators/linear_operator.py b/nifty5/operators/linear_operator.py index 26a0cf81dc0dfac484ce06f2ba567158198c09c6..aa330c2f4ccede83ad73de5b72810da71ed542ba 100644 --- a/nifty5/operators/linear_operator.py +++ b/nifty5/operators/linear_operator.py @@ -175,7 +175,7 @@ class LinearOperator(Operator): return self.apply(x, self.TIMES) from ..linearization import Linearization if isinstance(x, Linearization): - return Linearization(self(x._val), self(x._jac)) + return x.new(self(x._val), self(x._jac)) return self.__matmul__(x) def times(self, x): diff --git a/nifty5/operators/operator.py b/nifty5/operators/operator.py index 73ece15e64ef3e1f64b8d8e27f75e09f4ced4f89..13b7e6fb7bfb1e27f12e3f292744b239f868b8b5 100644 --- a/nifty5/operators/operator.py +++ b/nifty5/operators/operator.py @@ -144,11 +144,12 @@ class _OpProd(Operator): v2 = v.extract(self._op2.domain) if not lin: return self._op1(v1) * self._op2(v2) - lin1 = self._op1(Linearization.make_var(v1)) - lin2 = self._op2(Linearization.make_var(v2)) + wm = x.want_metric + lin1 = self._op1(Linearization.make_var(v1, wm)) + lin2 = self._op2(Linearization.make_var(v2, wm)) op = (makeOp(lin1._val)(lin2._jac))._myadd( makeOp(lin2._val)(lin1._jac), False) - return Linearization(lin1._val*lin2._val, op(x.jac)) + return lin1.new(lin1._val*lin2._val, op(x.jac)) class _OpSum(Operator): @@ -168,10 +169,11 @@ class _OpSum(Operator): res = None if not lin: return self._op1(v1).unite(self._op2(v2)) - lin1 = self._op1(Linearization.make_var(v1)) - lin2 = self._op2(Linearization.make_var(v2)) + wm = x.want_metric + lin1 = self._op1(Linearization.make_var(v1, wm)) + lin2 = self._op2(Linearization.make_var(v2, wm)) op = lin1._jac._myadd(lin2._jac, False) - res = Linearization(lin1._val+lin2._val, op(x.jac)) + res = lin1.new(lin1._val+lin2._val, op(x.jac)) if lin1._metric is not None and lin2._metric is not None: res = res.add_metric(lin1._metric + lin2._metric) return res diff --git a/nifty5/plotting/plot.py b/nifty5/plotting/plot.py index 915f84d7ef0549d0a16d261a6cc80c5aedf8cb9c..6e414f5bc5760b6bdb67d12331f0418431640131 100644 --- a/nifty5/plotting/plot.py +++ b/nifty5/plotting/plot.py @@ -262,82 +262,79 @@ def _plot(f, ax, **kwargs): raise ValueError("Field type not(yet) supported") -_plots = [] -_kwargs = [] - - -def plot(f, **kwargs): - """Add a figure to the current list of plots. - - Notes - ----- - After doing one or more calls `plot()`, one also needs to call - `plot_finish()` to output the result. - - Parameters - ---------- - f: Field, or list of Field objects - If `f` is a single Field, it must live over a single `RGSpace`, - `PowerSpace`, `HPSpace`, `GLSPace`. - If it is a list, all list members must be Fields living over the same - one-dimensional `RGSpace` or `PowerSpace`. - title: string - title of the plot - xlabel: string - label for the x axis - ylabel: string - label for the y axis - [xyz]min, [xyz]max: float - limits for the values to plot - colormap: string - color map to use for the plot (if it is a 2D plot) - linewidth: float or list of floats - line width - label: string of list of strings - annotation string - alpha: float or list of floats - transparency value - """ - _plots.append(f) - _kwargs.append(kwargs) - - -def plot_finish(**kwargs): - """Plot the accumulated list of figures. - - Parameters - ---------- - title: string - title of the full plot - nx, ny: integer (default: square root of the numer of plots, rounded up) - number of subplots to use in x- and y-direction - xsize, ysize: float (default: 6) - dimensions of the full plot in inches - name: string (default: "") - if left empty, the plot will be shown on the screen, - otherwise it will be written to a file with the given name. - Supported extensions: .png and .pdf - """ - global _plots, _kwargs - import matplotlib.pyplot as plt - nplot = len(_plots) - fig = plt.figure() - if "title" in kwargs: - plt.suptitle(kwargs.pop("title")) - nx = kwargs.pop("nx", int(np.ceil(np.sqrt(nplot)))) - ny = kwargs.pop("ny", int(np.ceil(np.sqrt(nplot)))) - if nx*ny < nplot: - raise ValueError( - 'Figure dimensions not sufficient for number of plots. ' - 'Available plot slots: {}, number of plots: {}' - .format(nx*ny, nplot)) - xsize = kwargs.pop("xsize", 6) - ysize = kwargs.pop("ysize", 6) - fig.set_size_inches(xsize, ysize) - for i in range(nplot): - ax = fig.add_subplot(ny, nx, i+1) - _plot(_plots[i], ax, **_kwargs[i]) - fig.tight_layout() - _makeplot(kwargs.pop("name", None)) - _plots = [] - _kwargs = [] +class Plot(object): + def __init__(self): + self._plots = [] + self._kwargs = [] + + def add(self, f, **kwargs): + """Add a figure to the current list of plots. + + Notes + ----- + After doing one or more calls `plot()`, one also needs to call + `plot_finish()` to output the result. + + Parameters + ---------- + f: Field, or list of Field objects + If `f` is a single Field, it must live over a single `RGSpace`, + `PowerSpace`, `HPSpace`, `GLSPace`. + If it is a list, all list members must be Fields living over the + same one-dimensional `RGSpace` or `PowerSpace`. + title: string + title of the plot + xlabel: string + label for the x axis + ylabel: string + label for the y axis + [xyz]min, [xyz]max: float + limits for the values to plot + colormap: string + color map to use for the plot (if it is a 2D plot) + linewidth: float or list of floats + line width + label: string of list of strings + annotation string + alpha: float or list of floats + transparency value + """ + self._plots.append(f) + self._kwargs.append(kwargs) + + def output(self, **kwargs): + """Plot the accumulated list of figures. + + Parameters + ---------- + title: string + title of the full plot + nx, ny: integer (default: square root of the numer of plots, rounded up) + number of subplots to use in x- and y-direction + xsize, ysize: float (default: 6) + dimensions of the full plot in inches + name: string (default: "") + if left empty, the plot will be shown on the screen, + otherwise it will be written to a file with the given name. + Supported extensions: .png and .pdf + """ + import matplotlib.pyplot as plt + nplot = len(self._plots) + fig = plt.figure() + if "title" in kwargs: + plt.suptitle(kwargs.pop("title")) + nx = kwargs.pop("nx", int(np.ceil(np.sqrt(nplot)))) + ny = kwargs.pop("ny", int(np.ceil(np.sqrt(nplot)))) + if nx*ny < nplot: + raise ValueError( + 'Figure dimensions not sufficient for number of plots. ' + 'Available plot slots: {}, number of plots: {}' + .format(nx*ny, nplot)) + xsize = kwargs.pop("xsize", 6) + ysize = kwargs.pop("ysize", 6) + fig.set_size_inches(xsize, ysize) + for i in range(nplot): + ax = fig.add_subplot(ny, nx, i+1) + _plot(self._plots[i], ax, **self._kwargs[i]) + fig.tight_layout() + _makeplot(kwargs.pop("name", None)) diff --git a/test/test_minimization/test_minimizers.py b/test/test_minimization/test_minimizers.py index 7b6df9e777c3386a8d7eed98985e91ea2b07e72e..5ec52853538d3e488fda8bc69277a21e2a120e96 100644 --- a/test/test_minimization/test_minimizers.py +++ b/test/test_minimization/test_minimizers.py @@ -113,6 +113,12 @@ class Test_Minimizers(unittest.TestCase): iteration_limit=1000) return ift.InversionEnabler(RBCurv(self._position), t1) + def apply_metric(self, x): + inp = x.to_global_data_rw() + pos = self._position.to_global_data_rw() + return ift.Field.from_global_data( + space, rosen_hess_prod(pos, inp)) + try: minimizer = eval(minimizer) energy = RBEnergy(position=starting_point) @@ -145,12 +151,11 @@ class Test_Minimizers(unittest.TestCase): return ift.Field.full(self.position.domain, 2*x*np.exp(-(x**2))) - @property - def metric(self): - x = self.position.to_global_data()[0] - v = (2 - 4*x*x)*np.exp(-x**2) + def apply_metric(self, x): + p = self.position.to_global_data()[0] + v = (2 - 4*p*p)*np.exp(-p**2) return ift.DiagonalOperator( - ift.Field.full(self.position.domain, v)) + ift.Field.full(self.position.domain, v))(x) try: minimizer = eval(minimizer) @@ -190,6 +195,12 @@ class Test_Minimizers(unittest.TestCase): return ift.DiagonalOperator( ift.Field.full(self.position.domain, v)) + def apply_metric(self, x): + p = self.position.to_global_data()[0] + v = np.cosh(p) + return ift.DiagonalOperator( + ift.Field.full(self.position.domain, v))(x) + try: minimizer = eval(minimizer) energy = CoshEnergy(position=starting_point) diff --git a/test/test_operators/test_adjoint.py b/test/test_operators/test_adjoint.py index 50397e6c6b5ed9898f8e8d62dfc0e2597289b041..8d4bc649a5a6badd0106a8286482be4058982b00 100644 --- a/test/test_operators/test_adjoint.py +++ b/test/test_operators/test_adjoint.py @@ -188,10 +188,10 @@ class Consistency_Tests(unittest.TestCase): @expand(product([0, 1, 2, 3, (0, 1), (0, 2), (0, 1, 2), (0, 2, 3), (1, 3)], [np.float64, np.complex128])) - def testDomainDistributor(self, spaces, dtype): + def testContractionOperator(self, spaces, dtype): dom = (ift.RGSpace(10), ift.UnstructuredDomain(13), ift.GLSpace(5), ift.HPSpace(4)) - op = ift.DomainDistributor(dom, spaces) + op = ift.ContractionOperator(dom, spaces) ift.extra.consistency_check(op, dtype, dtype) @expand(product([True, False]))