diff --git a/demos/getting_started_3.py b/demos/getting_started_3.py index 65d5c66537c9201f251c0e9c667ccfd441c83031..3be2a0c8ea97fddf8b229b4f2b9c3c309555f13e 100644 --- a/demos/getting_started_3.py +++ b/demos/getting_started_3.py @@ -85,7 +85,7 @@ if __name__ == '__main__': plot = ift.Plot() plot.add(signal(MOCK_POSITION), title='Ground Truth') plot.add(R.adjoint_times(data), title='Data') - plot.add([A(MOCK_POSITION)], title='Power Spectrum') + plot.add([A(MOCK_POSITION.extract(A.domain))], title='Power Spectrum') plot.output(ny=1, nx=3, xsize=24, ysize=6, name="setup.png") # number of samples used to estimate the KL @@ -97,18 +97,24 @@ if __name__ == '__main__': plot = ift.Plot() plot.add(signal(KL.position), title="reconstruction") - plot.add([A(KL.position), A(MOCK_POSITION)], title="power") + plot.add( + [ + A(KL.position.extract(A.domain)), + A(MOCK_POSITION.extract(A.domain)) + ], + title="power") plot.output(ny=1, ysize=6, xsize=16, name="loop.png") plot = ift.Plot() sc = ift.StatCalculator() for sample in KL.samples: - sc.add(signal(sample+KL.position)) + sc.add(signal(sample + KL.position)) plot.add(sc.mean, title="Posterior Mean") plot.add(ift.sqrt(sc.var), title="Posterior Standard Deviation") - powers = [A(s+KL.position) for s in KL.samples] + powers = [A((s + KL.position).extract(A.domain)) for s in KL.samples] plot.add( - [A(KL.position), A(MOCK_POSITION)]+powers, + [A(KL.position.extract(A.domain)), + A(MOCK_POSITION.extract(A.domain))] + powers, title="Sampled Posterior Power Spectrum") plot.output(ny=1, nx=3, xsize=24, ysize=6, name="results.png") diff --git a/nifty5/operators/energy_operators.py b/nifty5/operators/energy_operators.py index 9986bbd4435a9e77811cad78f6c7171aa14c4290..809c3226b04978e7ecb579491bccb6ce78781bf9 100644 --- a/nifty5/operators/energy_operators.py +++ b/nifty5/operators/energy_operators.py @@ -21,6 +21,7 @@ from __future__ import absolute_import, division, print_function from .. import utilities from ..compat import * from ..domain_tuple import DomainTuple +from ..domains.domain import Domain from ..field import Field from ..linearization import Linearization from ..sugar import makeOp @@ -84,10 +85,12 @@ class GaussianEnergy(EnergyOperator): self._icov = None if covariance is None else covariance.inverse def _checkEquivalence(self, newdom): + if isinstance(newdom, Domain): + newdom = DomainTuple.make(newdom) if self._domain is None: - self._domain = DomainTuple.make(newdom) + self._domain = newdom else: - if self._domain != DomainTuple.make(newdom): + if self._domain != newdom: raise ValueError("domain mismatch") def apply(self, x):