Commit a0ab986b authored by Philipp Arras's avatar Philipp Arras
Browse files

Merge remote-tracking branch 'nifty-dev/NIFTy_5' into misc_work

parents 0f7cc8c4 73532911
......@@ -96,3 +96,12 @@ run_getting_started_3:
script:
- python demos/getting_started_3.py
- python3 demos/getting_started_3.py
run_bernoulli:
stage: demo_runs
script:
- python demos/bernoulli_demo.py
- python3 demos/bernoulli_demo.py
artifacts:
paths:
- '*.png'
......@@ -58,7 +58,7 @@
"### Posterior\n",
"The Posterior is given by:\n",
"\n",
"$$\\mathcal P (s|d) \\propto P(s,d) = \\mathcal G(d-Rs,N) \\,\\mathcal G(s,S) \\propto \\mathcal G (m,D) $$\n",
"$$\\mathcal P (s|d) \\propto P(s,d) = \\mathcal G(d-Rs,N) \\,\\mathcal G(s,S) \\propto \\mathcal G (s-m,D) $$\n",
"\n",
"where\n",
"$$\\begin{align}\n",
......
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
import nifty5 as ift
import numpy as np
if __name__ == '__main__':
# ABOUT THIS CODE
# FIXME ABOUT THIS CODE
np.random.seed(41)
# Set up the position space of the signal
......@@ -70,6 +88,6 @@ if __name__ == '__main__':
reconstruction = sky.at(H.position).value
ift.plot(reconstruction, title='reconstruction', name='reconstruction.pdf')
ift.plot(GR.adjoint_times(data), title='data', name='data.pdf')
ift.plot(sky.at(mock_position).value, title='truth', name='truth.pdf')
ift.plot(reconstruction, title='reconstruction', name='reconstruction.png')
ift.plot(GR.adjoint_times(data), title='data', name='data.png')
ift.plot(sky.at(mock_position).value, title='truth', name='truth.png')
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
import nifty5 as ift
import numpy as np
def make_chess_mask():
def make_chess_mask(position_space):
mask = np.ones(position_space.shape)
for i in range(4):
for j in range(4):
......@@ -17,27 +35,35 @@ def make_random_mask():
return mask.to_global_data()
if __name__ == '__main__':
# # description of the tutorial ###
# Choose problem geometry and masking
def mask_to_nan(mask, field):
masked_data = field.local_data.copy()
masked_data[mask.local_data == 0] = np.nan
return ift.from_local_data(field.domain, masked_data)
# One dimensional regular grid
position_space = ift.RGSpace([1024])
mask = np.ones(position_space.shape)
# # Two dimensional regular grid with chess mask
# position_space = ift.RGSpace([128,128])
# mask = make_chess_mask()
if __name__ == '__main__':
np.random.seed(42)
# FIXME description of the tutorial
# # Sphere with half of its locations randomly masked
# position_space = ift.HPSpace(128)
# mask = make_random_mask()
# Choose problem geometry and masking
mode = 0
if mode == 0:
# One dimensional regular grid
position_space = ift.RGSpace([1024])
mask = np.ones(position_space.shape)
elif mode == 1:
# Two dimensional regular grid with chess mask
position_space = ift.RGSpace([128, 128])
mask = make_chess_mask(position_space)
else:
# Sphere with half of its locations randomly masked
position_space = ift.HPSpace(128)
mask = make_random_mask()
harmonic_space = position_space.get_default_codomain()
HT = ift.HarmonicTransformOperator(harmonic_space, target=position_space)
# set correlation structure with a power spectrum and build
# Set correlation structure with a power spectrum and build
# prior correlation covariance
def power_spectrum(k):
return 100. / (20.+k**3)
......@@ -47,7 +73,7 @@ if __name__ == '__main__':
S = ift.DiagonalOperator(prior_correlation_structure)
# build instrument response consisting of a discretization, mask
# Build instrument response consisting of a discretization, mask
# and harmonic transformaion
GR = ift.GeometryRemover(position_space)
mask = ift.Field.from_global_data(position_space, mask)
......@@ -56,19 +82,19 @@ if __name__ == '__main__':
data_space = GR.target
# setting the noise covariance
# Set the noise covariance
noise = 5.
N = ift.ScalingOperator(noise, data_space)
# creating mock data
# Create mock data
MOCK_SIGNAL = S.draw_sample()
MOCK_NOISE = N.draw_sample()
data = R(MOCK_SIGNAL) + MOCK_NOISE
# building propagator D and information source j
# Build propagator D and information source j
j = R.adjoint_times(N.inverse_times(data))
D_inv = R.adjoint * N.inverse * R + S.inverse
# make it invertible
# Make it invertible
IC = ift.GradientNormController(iteration_limit=500, tol_abs_gradnorm=1e-3)
D = ift.InversionEnabler(D_inv, IC, approximation=S.inverse).inverse
......@@ -76,4 +102,16 @@ if __name__ == '__main__':
m = D(j)
# PLOTTING
# Truth, data, reconstruction, residuals
rg = isinstance(position_space, ift.RGSpace)
if rg and len(position_space.shape) == 1:
ift.plot([HT(MOCK_SIGNAL), GR.adjoint(data), HT(m)],
label=['Mock signal', 'Data', 'Reconstruction'],
alpha=[1, .3, 1],
name='getting_started_1.png')
else:
ift.plot(HT(MOCK_SIGNAL), title='Mock Signal', name='mock_signal.png')
ift.plot(mask_to_nan(mask, (GR*Mask).adjoint(data)),
title='Data', name='data.png')
ift.plot(HT(m), title='Reconstruction', name='reconstruction.png')
ift.plot(mask_to_nan(mask, HT(m-MOCK_SIGNAL)), name='residuals.png')
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
import nifty5 as ift
import numpy as np
......@@ -18,7 +36,7 @@ def get_2D_exposure():
if __name__ == '__main__':
# ABOUT THIS CODE
# FIXME description of the tutorial
np.random.seed(41)
# Set up the position space of the signal
......
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
import nifty5 as ift
import numpy as np
......@@ -9,7 +27,7 @@ def get_random_LOS(n_los):
if __name__ == '__main__':
# ## ABOUT THIS TUTORIAL
# FIXME description of the tutorial
np.random.seed(42)
position_space = ift.RGSpace([128, 128])
......@@ -65,9 +83,9 @@ if __name__ == '__main__':
INITIAL_POSITION = ift.from_random('normal', H.position.domain)
position = INITIAL_POSITION
ift.plot(signal.at(MOCK_POSITION).value, name='truth.pdf')
ift.plot(R.adjoint_times(data), name='data.pdf')
ift.plot([A.at(MOCK_POSITION).value], name='power.pdf')
ift.plot(signal.at(MOCK_POSITION).value, name='truth.png')
ift.plot(R.adjoint_times(data), name='data.png')
ift.plot([A.at(MOCK_POSITION).value], name='power.png')
# number of samples used to estimate the KL
N_samples = 20
......@@ -81,17 +99,17 @@ if __name__ == '__main__':
KL, convergence = minimizer(KL)
position = KL.position
ift.plot(signal.at(position).value, name='reconstruction.pdf')
ift.plot(signal.at(position).value, name='reconstruction.png')
ift.plot([A.at(position).value, A.at(MOCK_POSITION).value],
name='power.pdf')
name='power.png')
sc = ift.StatCalculator()
for sample in samples:
sc.add(signal.at(sample+position).value)
ift.plot(sc.mean, name='avrg.pdf')
ift.plot(ift.sqrt(sc.var), name='std.pdf')
ift.plot(sc.mean, name='avrg.png')
ift.plot(ift.sqrt(sc.var), name='std.png')
powers = [A.at(s+position).value for s in samples]
ift.plot([A.at(position).value, A.at(MOCK_POSITION).value]+powers,
name='power.pdf')
name='power.png')
......@@ -17,11 +17,14 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import absolute_import, division, print_function
from ..compat import *
import sys
import numpy as np
from .random import Random
from mpi4py import MPI
import sys
from ..compat import *
from .random import Random
_comm = MPI.COMM_WORLD
ntask = _comm.Get_size()
......
......@@ -19,9 +19,10 @@
# Data object module for NIFTy that uses simple numpy ndarrays.
import numpy as np
from numpy import empty, empty_like, exp, full, log
from numpy import ndarray as data_object
from numpy import full, empty, empty_like, sqrt, ones, zeros, vdot, \
exp, log, tanh
from numpy import ones, sqrt, tanh, vdot, zeros
from .random import Random
ntask = 1
......
......@@ -17,9 +17,11 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import absolute_import, division, print_function
from ..compat import *
import numpy as np
from ..compat import *
class Random(object):
@staticmethod
......
......@@ -17,8 +17,8 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import absolute_import, division, print_function
from .compat import *
from .compat import *
try:
from mpi4py import MPI
......
......@@ -17,6 +17,7 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import absolute_import, division, print_function
from .compat import *
from .domains.domain import Domain
......
......@@ -17,8 +17,10 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import absolute_import, division, print_function
from ..compat import *
import numpy as np
from ..compat import *
from .structured_domain import StructuredDomain
......
......@@ -17,8 +17,10 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import absolute_import, division, print_function
from ..compat import *
import abc
from ..compat import *
from ..utilities import NiftyMetaBase
......
......@@ -17,8 +17,10 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import absolute_import, division, print_function
from ..compat import *
import numpy as np
from ..compat import *
from .structured_domain import StructuredDomain
......
......@@ -17,8 +17,10 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import absolute_import, division, print_function
from ..compat import *
import numpy as np
from ..compat import *
from .structured_domain import StructuredDomain
......
......@@ -17,10 +17,12 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import absolute_import, division, print_function
from ..compat import *
import numpy as np
from .structured_domain import StructuredDomain
from ..compat import *
from ..field import Field
from .structured_domain import StructuredDomain
class LMSpace(StructuredDomain):
......
......@@ -17,11 +17,13 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import absolute_import, division, print_function
from ..compat import *
from ..sugar import exp
import numpy as np
from .. import dobj
from ..compat import *
from ..field import Field
from ..sugar import exp
from .structured_domain import StructuredDomain
......
......@@ -17,10 +17,12 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import absolute_import, division, print_function
from ..compat import *
import numpy as np
from .structured_domain import StructuredDomain
from .. import dobj
from ..compat import *
from .structured_domain import StructuredDomain
class PowerSpace(StructuredDomain):
......
......@@ -17,11 +17,13 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import absolute_import, division, print_function
from ..compat import *
import numpy as np
from .structured_domain import StructuredDomain
from ..field import Field
from .. import dobj
from ..compat import *
from ..field import Field
from .structured_domain import StructuredDomain
class RGSpace(StructuredDomain):
......
......@@ -17,11 +17,14 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import absolute_import, division, print_function
from ..compat import *
import abc
from .domain import Domain
import numpy as np
from ..compat import *
from .domain import Domain
class StructuredDomain(Domain):
"""The abstract base class for all structured NIFTy domains.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment