Commit 330802d5 authored by Philipp Arras's avatar Philipp Arras
Browse files

Fix getting_started_3

parent fa7c9f88
...@@ -85,7 +85,7 @@ if __name__ == '__main__': ...@@ -85,7 +85,7 @@ if __name__ == '__main__':
plot = ift.Plot() plot = ift.Plot()
plot.add(signal(MOCK_POSITION), title='Ground Truth') plot.add(signal(MOCK_POSITION), title='Ground Truth')
plot.add(R.adjoint_times(data), title='Data') plot.add(R.adjoint_times(data), title='Data')
plot.add([A(MOCK_POSITION)], title='Power Spectrum') plot.add([A(MOCK_POSITION.extract(A.domain))], title='Power Spectrum')
plot.output(ny=1, nx=3, xsize=24, ysize=6, name="setup.png") plot.output(ny=1, nx=3, xsize=24, ysize=6, name="setup.png")
# number of samples used to estimate the KL # number of samples used to estimate the KL
...@@ -97,18 +97,24 @@ if __name__ == '__main__': ...@@ -97,18 +97,24 @@ if __name__ == '__main__':
plot = ift.Plot() plot = ift.Plot()
plot.add(signal(KL.position), title="reconstruction") plot.add(signal(KL.position), title="reconstruction")
plot.add([A(KL.position), A(MOCK_POSITION)], title="power") plot.add(
[
A(KL.position.extract(A.domain)),
A(MOCK_POSITION.extract(A.domain))
],
title="power")
plot.output(ny=1, ysize=6, xsize=16, name="loop.png") plot.output(ny=1, ysize=6, xsize=16, name="loop.png")
plot = ift.Plot() plot = ift.Plot()
sc = ift.StatCalculator() sc = ift.StatCalculator()
for sample in KL.samples: for sample in KL.samples:
sc.add(signal(sample+KL.position)) sc.add(signal(sample + KL.position))
plot.add(sc.mean, title="Posterior Mean") plot.add(sc.mean, title="Posterior Mean")
plot.add(ift.sqrt(sc.var), title="Posterior Standard Deviation") plot.add(ift.sqrt(sc.var), title="Posterior Standard Deviation")
powers = [A(s+KL.position) for s in KL.samples] powers = [A((s + KL.position).extract(A.domain)) for s in KL.samples]
plot.add( plot.add(
[A(KL.position), A(MOCK_POSITION)]+powers, [A(KL.position.extract(A.domain)),
A(MOCK_POSITION.extract(A.domain))] + powers,
title="Sampled Posterior Power Spectrum") title="Sampled Posterior Power Spectrum")
plot.output(ny=1, nx=3, xsize=24, ysize=6, name="results.png") plot.output(ny=1, nx=3, xsize=24, ysize=6, name="results.png")
...@@ -21,6 +21,7 @@ from __future__ import absolute_import, division, print_function ...@@ -21,6 +21,7 @@ from __future__ import absolute_import, division, print_function
from .. import utilities from .. import utilities
from ..compat import * from ..compat import *
from ..domain_tuple import DomainTuple from ..domain_tuple import DomainTuple
from ..domains.domain import Domain
from ..field import Field from ..field import Field
from ..linearization import Linearization from ..linearization import Linearization
from ..sugar import makeOp from ..sugar import makeOp
...@@ -84,10 +85,12 @@ class GaussianEnergy(EnergyOperator): ...@@ -84,10 +85,12 @@ class GaussianEnergy(EnergyOperator):
self._icov = None if covariance is None else covariance.inverse self._icov = None if covariance is None else covariance.inverse
def _checkEquivalence(self, newdom): def _checkEquivalence(self, newdom):
if isinstance(newdom, Domain):
newdom = DomainTuple.make(newdom)
if self._domain is None: if self._domain is None:
self._domain = DomainTuple.make(newdom) self._domain = newdom
else: else:
if self._domain != DomainTuple.make(newdom): if self._domain != newdom:
raise ValueError("domain mismatch") raise ValueError("domain mismatch")
def apply(self, x): def apply(self, x):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment