diff --git a/mpcdf_common.py b/mpcdf_common.py index 9bf75c6c69a3b6d5d1e40dddd19f6b05e8400963..d50cf6267a3297fd3c3b2c55d48a08f8495050cf 100644 --- a/mpcdf_common.py +++ b/mpcdf_common.py @@ -574,6 +574,11 @@ def mpcdf_enable_repositories(api_url, project, package, verbose=False, dry_run= if valid_pgi_mpi(pgi, mpi): enable(mpi + "_" + pgi) + if flag == "cuda_aware_mpi": + for cuda, mpi, compiler in product(actual_cudas(), actual_mpis(), all_compilers): + if valid_cuda(cuda, compiler) and valid_mpi(compiler, mpi): + enable(cuda + "_aware_" + mpi + "_" + compiler) + if len(build.getchildren()) > 0: build.getchildren()[-1].tail = "\n " @@ -662,11 +667,11 @@ Macros: for oldrepo in root.findall("./repository"): root.remove(oldrepo) - def repo(name, dependencies, compiler=False, mpi=False, cuda=False, cuda_mpi=False, additional_tags=(), **macros): + def repo(name, dependencies, compiler=False, mpi=False, cuda=False, cuda_mpi=False, cuda_aware_mpi=False, additional_tags=(), **macros): old_repos.discard(name) - have_compiler = compiler or mpi or cuda or cuda_mpi - have_mpi = mpi or cuda_mpi - have_cuda = cuda or cuda_mpi + have_compiler = compiler or mpi or cuda or cuda_mpi or cuda_aware_mpi + have_mpi = mpi or cuda_mpi or cuda_aware_mpi + have_cuda = cuda or cuda_mpi or cuda_aware_mpi existing_repo = root.find("./repository[@name='{0}']".format(name)) if existing_repo is not None: @@ -713,6 +718,7 @@ Macros: repoconf.append("%is_mpi_repository {0}".format(1 if mpi else 0)) repoconf.append("%is_cuda_repository {0}".format(1 if cuda else 0)) repoconf.append("%is_cuda_mpi_repository {0}".format(1 if cuda_mpi else 0)) + repoconf.append("%is_cuda_aware_mpi_repository {0}".format(1 if cuda_aware_mpi else 0)) repoconf.append("%have_mpcdf_compiler {0}".format(1 if have_compiler else 0)) repoconf.append("%have_mpcdf_mpi {0}".format(1 if have_mpi else 0)) @@ -805,6 +811,10 @@ Macros: (project, mpi + "_" + compiler)), cuda_mpi=True) + repo(cuda + "_aware_" + mpi + "_" + compiler, + ((project, cuda + "_" + mpi + "_" + compiler),), + cuda_aware_mpi=True) + if old_repos and not remove_old: print("Warning: Keeping the prjconf sections for the following obsolete repositories:") for name in sorted(old_repos):