Commit 9a905efc authored by Reimar Heinrich Leike's avatar Reimar Heinrich Leike

Merge branch 'fix_linearization_getitem' of...

Merge branch 'fix_linearization_getitem' of https://gitlab.mpcdf.mpg.de/ift/nifty into fix_linearization_getitem
parents b9879488 61147542
Pipeline #43308 passed with stages
in 8 minutes
...@@ -21,6 +21,8 @@ ...@@ -21,6 +21,8 @@
# 1D (set mode=0), 2D (mode=1), or on the sphere (mode=2) # 1D (set mode=0), 2D (mode=1), or on the sphere (mode=2)
############################################################################### ###############################################################################
import sys
import numpy as np import numpy as np
import nifty5 as ift import nifty5 as ift
...@@ -51,7 +53,6 @@ def mask_to_nan(mask, field): ...@@ -51,7 +53,6 @@ def mask_to_nan(mask, field):
if __name__ == '__main__': if __name__ == '__main__':
import sys
np.random.seed(42) np.random.seed(42)
# Choose space on which the signal field is defined # Choose space on which the signal field is defined
...@@ -140,7 +141,7 @@ if __name__ == '__main__': ...@@ -140,7 +141,7 @@ if __name__ == '__main__':
# Plotting # Plotting
rg = isinstance(position_space, ift.RGSpace) rg = isinstance(position_space, ift.RGSpace)
plot = ift.Plot() plot = ift.Plot()
filename = f"getting_started_1_mode_{mode}.png" filename = "getting_started_1_mode_{}.png".format(mode)
if rg and len(position_space.shape) == 1: if rg and len(position_space.shape) == 1:
plot.add( plot.add(
[HT(MOCK_SIGNAL), GR.adjoint(data), [HT(MOCK_SIGNAL), GR.adjoint(data),
...@@ -155,4 +156,4 @@ if __name__ == '__main__': ...@@ -155,4 +156,4 @@ if __name__ == '__main__':
plot.add(HT(m), title='Reconstruction') plot.add(HT(m), title='Reconstruction')
plot.add(mask_to_nan(mask, HT(m - MOCK_SIGNAL)), title='Residuals') plot.add(mask_to_nan(mask, HT(m - MOCK_SIGNAL)), title='Residuals')
plot.output(nx=2, ny=2, xsize=10, ysize=10, name=filename) plot.output(nx=2, ny=2, xsize=10, ysize=10, name=filename)
print(f"Saved results as '{filename}'.") print("Saved results as '{}'.".format(filename))
...@@ -21,6 +21,8 @@ ...@@ -21,6 +21,8 @@
# 1D (set mode=0), 2D (mode=1), or on the sphere (mode=2) # 1D (set mode=0), 2D (mode=1), or on the sphere (mode=2)
############################################################################### ###############################################################################
import sys
import numpy as np import numpy as np
import nifty5 as ift import nifty5 as ift
...@@ -42,8 +44,6 @@ def exposure_2d(): ...@@ -42,8 +44,6 @@ def exposure_2d():
if __name__ == '__main__': if __name__ == '__main__':
import sys
# FIXME All random seeds to 42
np.random.seed(42) np.random.seed(42)
# Choose space on which the signal field is defined # Choose space on which the signal field is defined
...@@ -112,11 +112,11 @@ if __name__ == '__main__': ...@@ -112,11 +112,11 @@ if __name__ == '__main__':
# Plotting # Plotting
signal = sky(mock_position) signal = sky(mock_position)
reconst = sky(H.position) reconst = sky(H.position)
filename = f"getting_started_2_mode_{mode}.png" filename = "getting_started_2_mode_{}.png".format(mode)
plot = ift.Plot() plot = ift.Plot()
plot.add(signal, title='Signal') plot.add(signal, title='Signal')
plot.add(GR.adjoint(data), title='Data') plot.add(GR.adjoint(data), title='Data')
plot.add(reconst, title='Reconstruction') plot.add(reconst, title='Reconstruction')
plot.add(reconst - signal, title='Residuals') plot.add(reconst - signal, title='Residuals')
plot.output(xsize=12, ysize=10, name=filename) plot.output(xsize=12, ysize=10, name=filename)
print(f"Saved results as '{filename}'.") print("Saved results as '{}'.".format(filename))
...@@ -25,6 +25,8 @@ ...@@ -25,6 +25,8 @@
# Demo takes a while to compute # Demo takes a while to compute
############################################################# #############################################################
import sys
import numpy as np import numpy as np
import nifty5 as ift import nifty5 as ift
...@@ -43,8 +45,7 @@ def radial_los(n_los): ...@@ -43,8 +45,7 @@ def radial_los(n_los):
if __name__ == '__main__': if __name__ == '__main__':
import sys np.random.seed(420)
np.random.seed(420) # picked for a nice field realization
# Choose between random line-of-sight response (mode=0) and radial lines # Choose between random line-of-sight response (mode=0) and radial lines
# of sight (mode=1) # of sight (mode=1)
...@@ -52,7 +53,7 @@ if __name__ == '__main__': ...@@ -52,7 +53,7 @@ if __name__ == '__main__':
mode = int(sys.argv[1]) mode = int(sys.argv[1])
else: else:
mode = 0 mode = 0
filename = f"getting_started_3_mode_{mode}_" + "{}.png" filename = "getting_started_3_mode_{}_".format(mode) + "{}.png"
position_space = ift.RGSpace([128, 128]) position_space = ift.RGSpace([128, 128])
harmonic_space = position_space.get_default_codomain() harmonic_space = position_space.get_default_codomain()
...@@ -135,7 +136,7 @@ if __name__ == '__main__': ...@@ -135,7 +136,7 @@ if __name__ == '__main__':
plot.add(signal(KL.position), title="reconstruction") plot.add(signal(KL.position), title="reconstruction")
plot.add([A.force(KL.position), A.force(mock_position)], title="power") plot.add([A.force(KL.position), A.force(mock_position)], title="power")
plot.output(ny=1, ysize=6, xsize=16, plot.output(ny=1, ysize=6, xsize=16,
name=filename.format(f"loop_{i:02}")) name=filename.format("loop_{:02d}".format(i)))
# Draw posterior samples # Draw posterior samples
KL = ift.MetricGaussianKL(mean, H, N_samples) KL = ift.MetricGaussianKL(mean, H, N_samples)
...@@ -156,4 +157,4 @@ if __name__ == '__main__': ...@@ -156,4 +157,4 @@ if __name__ == '__main__':
title="Sampled Posterior Power Spectrum", title="Sampled Posterior Power Spectrum",
linewidth=[1.]*len(powers) + [3., 3.]) linewidth=[1.]*len(powers) + [3., 3.])
plot.output(ny=1, nx=3, xsize=24, ysize=6, name=filename_res) plot.output(ny=1, nx=3, xsize=24, ysize=6, name=filename_res)
print(f"Saved results as '{filename_res}'.") print("Saved results as '{}'.".format(filename_res))
...@@ -363,7 +363,7 @@ def makeOp(input): ...@@ -363,7 +363,7 @@ def makeOp(input):
return DiagonalOperator(input) return DiagonalOperator(input)
if isinstance(input, MultiField): if isinstance(input, MultiField):
return BlockDiagonalOperator( return BlockDiagonalOperator(
input.domain, {key: makeOp(val) for key, val in enumerate(input)}) input.domain, {key: makeOp(val) for key, val in input.items()})
raise NotImplementedError raise NotImplementedError
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment