Skip to content
Snippets Groups Projects
Commit b09ae971 authored by Philipp Arras's avatar Philipp Arras
Browse files

Add simpler MPI likelihood example

parent 888f127c
No related branches found
No related tags found
1 merge request!41Draft: Lh distribution
Pipeline #144667 passed
...@@ -54,6 +54,8 @@ def distribution_strategy(distribution_shape, comm): ...@@ -54,6 +54,8 @@ def distribution_strategy(distribution_shape, comm):
return comm0, comm1 return comm0, comm1
if __name__ == "__main__":
comm_world = MPI.COMM_WORLD comm_world = MPI.COMM_WORLD
n_lhs = 4 n_lhs = 4
......
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2022 Max-Planck-Society
# Copyright(C) 2022 Max-Planck-Society, Philipp Arras
import nifty8 as ift
import numpy as np
import resolve as rve
from nifty8.utilities import allreduce_sum
from mpi4py import MPI
from getting_started_3_mpilh import distribution_strategy
def get_local(lst, comm):
return lst[slice(*ift.utilities.shareRange(len(lst), *ift.utilities.get_MPI_params_from_comm(comm)[:2]))]
def eval_likelihood(x, num):
return num*num*x
def main():
n_lhs = 4
n_samples = 5
comm = MPI.COMM_WORLD
samplecomm, lhcomm = distribution_strategy((2 * n_samples, n_lhs), comm)
# Would not be present in real application
global_samples = list(np.arange(n_samples)*10)
global_lh_inds = list(range(n_lhs))
# /Would not be present in real application
# We want to compute: sum_i sum_j lh_j(sample_i)
# In this example lh_j = lambda x: j*j*x
# Get quantities that would be locally present
local_lh_inds = get_local(global_lh_inds, lhcomm)
local_samples = get_local(global_samples, samplecomm)
res = allreduce_sum([allreduce_sum([eval_likelihood(ss, i_local_lh) for i_local_lh in local_lh_inds], lhcomm) for ss in local_samples], samplecomm)
print("Total sum", res)
if __name__ == "__main__":
main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment