Skip to content
Snippets Groups Projects
Commit b09ae971 authored by Philipp Arras's avatar Philipp Arras
Browse files

Add simpler MPI likelihood example

parent 888f127c
No related tags found
1 merge request!41Draft: Lh distribution
Pipeline #144667 passed
......@@ -54,74 +54,76 @@ def distribution_strategy(distribution_shape, comm):
return comm0, comm1
comm_world = MPI.COMM_WORLD
n_lhs = 4
n_samples = 5
samplecomm, lhcomm = distribution_strategy((2 * n_samples, n_lhs), comm_world)
position_space = ift.RGSpace((128, 128))
args = {
"offset_mean": 0,
"offset_std": (1e-3, 1e-6),
"fluctuations": (1.0, 0.8),
"loglogavgslope": (-3.0, 1),
"flexibility": (2, 1.0),
"asperity": (0.5, 0.4),
}
correlated_field = ift.SimpleCorrelatedField(position_space, **args)
signal = ift.sigmoid(correlated_field)
# Create all likelihoods
n_los = 100
assert n_los > n_lhs
LOS_starts = list(ift.random.current_rng().random((n_los, 2)).T)
LOS_ends = list(0.5 + 0 * ift.random.current_rng().random((n_los, 2)).T)
mock_position = ift.from_random(signal.domain, "normal")
responses, datas, inverse_covariances = [], [], []
for ii in range(n_lhs):
lo, hi = ift.utilities.shareRange(n_los, n_lhs, ii)
R = ift.LOSResponse(
position_space,
starts=[xx[lo:hi] for xx in LOS_starts],
ends=[yy[lo:hi] for yy in LOS_ends],
if __name__ == "__main__":
comm_world = MPI.COMM_WORLD
n_lhs = 4
n_samples = 5
samplecomm, lhcomm = distribution_strategy((2 * n_samples, n_lhs), comm_world)
position_space = ift.RGSpace((128, 128))
args = {
"offset_mean": 0,
"offset_std": (1e-3, 1e-6),
"fluctuations": (1.0, 0.8),
"loglogavgslope": (-3.0, 1),
"flexibility": (2, 1.0),
"asperity": (0.5, 0.4),
}
correlated_field = ift.SimpleCorrelatedField(position_space, **args)
signal = ift.sigmoid(correlated_field)
# Create all likelihoods
n_los = 100
assert n_los > n_lhs
LOS_starts = list(ift.random.current_rng().random((n_los, 2)).T)
LOS_ends = list(0.5 + 0 * ift.random.current_rng().random((n_los, 2)).T)
mock_position = ift.from_random(signal.domain, "normal")
responses, datas, inverse_covariances = [], [], []
for ii in range(n_lhs):
lo, hi = ift.utilities.shareRange(n_los, n_lhs, ii)
R = ift.LOSResponse(
position_space,
starts=[xx[lo:hi] for xx in LOS_starts],
ends=[yy[lo:hi] for yy in LOS_ends],
)
responses.append(R)
data_space = R.target
noise = 0.001
N = ift.ScalingOperator(data_space, noise, np.float64)
datas.append(R(signal(mock_position)) + N.draw_sample())
inverse_covariances.append(N.inverse)
lhs = [
ift.GaussianEnergy(data=dd, inverse_covariance=ic) @ rr
for dd, rr, ic in zip(datas, responses, inverse_covariances)
]
# /Create all likelihoods
# Select the likelihoods relevant for the current task
lh_size, lh_rank, lh_master = ift.utilities.get_MPI_params_from_comm(lhcomm)
lo, hi = ift.utilities.shareRange(n_lhs, lh_size, lh_rank)
lhs = lhs[lo:hi]
# /Select the likelihoods relevant for the current task
likelihood_energy = rve.AllreduceSum(lhs, lhcomm) @ signal
ic_sampling = ift.GradientNormController(name="Sampling (linear)", iteration_limit=20)
ic_newton = ift.GradientNormController(name="Newton", iteration_limit=10)
minimizer = ift.NewtonCG(ic_newton)
samples = ift.optimize_kl(
likelihood_energy,
2,
n_samples,
minimizer,
ic_sampling,
None,
plottable_operators={"signal": (signal, dict(vmin=0, vmax=1))},
ground_truth_position=mock_position,
output_directory=f"results_{comm_world.Get_size()}tasks",
overwrite=True,
comm=samplecomm,
)
responses.append(R)
data_space = R.target
noise = 0.001
N = ift.ScalingOperator(data_space, noise, np.float64)
datas.append(R(signal(mock_position)) + N.draw_sample())
inverse_covariances.append(N.inverse)
lhs = [
ift.GaussianEnergy(data=dd, inverse_covariance=ic) @ rr
for dd, rr, ic in zip(datas, responses, inverse_covariances)
]
# /Create all likelihoods
# Select the likelihoods relevant for the current task
lh_size, lh_rank, lh_master = ift.utilities.get_MPI_params_from_comm(lhcomm)
lo, hi = ift.utilities.shareRange(n_lhs, lh_size, lh_rank)
lhs = lhs[lo:hi]
# /Select the likelihoods relevant for the current task
likelihood_energy = rve.AllreduceSum(lhs, lhcomm) @ signal
ic_sampling = ift.GradientNormController(name="Sampling (linear)", iteration_limit=20)
ic_newton = ift.GradientNormController(name="Newton", iteration_limit=10)
minimizer = ift.NewtonCG(ic_newton)
samples = ift.optimize_kl(
likelihood_energy,
2,
n_samples,
minimizer,
ic_sampling,
None,
plottable_operators={"signal": (signal, dict(vmin=0, vmax=1))},
ground_truth_position=mock_position,
output_directory=f"results_{comm_world.Get_size()}tasks",
overwrite=True,
comm=samplecomm,
)
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2022 Max-Planck-Society
# Copyright(C) 2022 Max-Planck-Society, Philipp Arras
import nifty8 as ift
import numpy as np
import resolve as rve
from nifty8.utilities import allreduce_sum
from mpi4py import MPI
from getting_started_3_mpilh import distribution_strategy
def get_local(lst, comm):
return lst[slice(*ift.utilities.shareRange(len(lst), *ift.utilities.get_MPI_params_from_comm(comm)[:2]))]
def eval_likelihood(x, num):
return num*num*x
def main():
n_lhs = 4
n_samples = 5
comm = MPI.COMM_WORLD
samplecomm, lhcomm = distribution_strategy((2 * n_samples, n_lhs), comm)
# Would not be present in real application
global_samples = list(np.arange(n_samples)*10)
global_lh_inds = list(range(n_lhs))
# /Would not be present in real application
# We want to compute: sum_i sum_j lh_j(sample_i)
# In this example lh_j = lambda x: j*j*x
# Get quantities that would be locally present
local_lh_inds = get_local(global_lh_inds, lhcomm)
local_samples = get_local(global_samples, samplecomm)
res = allreduce_sum([allreduce_sum([eval_likelihood(ss, i_local_lh) for i_local_lh in local_lh_inds], lhcomm) for ss in local_samples], samplecomm)
print("Total sum", res)
if __name__ == "__main__":
main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment