diff --git a/mpcdf_common.py b/mpcdf_common.py index 88fa87c2ea9a5c4178f18a43cedc41bf7774ac16..e97fd719e7c5338e89db758244a355d6cbf41de4 100644 --- a/mpcdf_common.py +++ b/mpcdf_common.py @@ -665,6 +665,17 @@ def mpcdf_enable_repositories(api_url, project, package, verbose=False, dry_run= if is_gcc_compiler(compiler) and is_openmpi(mpi) and valid_mpi(compiler, mpi): enable(rocm + "_" + mpi + "_" + compiler) + if flag == "rocm_aware_mpi": + for rocm, mpi, compiler in product(actual_rocms(), actual_mpis(), all_compilers): + if is_gcc_compiler(compiler) and is_openmpi(mpi) and valid_mpi(compiler, mpi): + enable(rocm + "_aware_" + mpi + "_" + compiler) + + if flag == "rocm_aware_openmpi_flavors": + for rocm, mpi, compiler in product(actual_rocms(), actual_mpis(), all_compilers): + if is_gcc_compiler(compiler) and is_openmpi(mpi) and valid_mpi(compiler, mpi): + for of in actual_openmpi_flavors(): + enable(rocm + "_aware_" + mpi + "_" + compiler + "_" + of) + if flag == "pgi": for pgi in actual_pgis(): enable(pgi) @@ -857,12 +868,12 @@ Macros: root.remove(oldrepo) def repo(name: str, dependencies: tuple, compiler=False, mpi=False, cuda=False, cuda_mpi=False, cuda_aware_mpi=False, - rocm=False, rocm_mpi=False, additional_tags: Optional[tuple] = None, **macros: str) -> None: + rocm=False, rocm_mpi=False, rocm_aware_mpi=False, additional_tags: Optional[tuple] = None, **macros: str) -> None: old_repos.discard(name) - have_compiler = compiler or mpi or cuda or cuda_mpi or cuda_aware_mpi - have_mpi = mpi or cuda_mpi or cuda_aware_mpi + have_compiler = compiler or mpi or cuda or cuda_mpi or cuda_aware_mpi or rocm_mpi or rocm_aware_mpi + have_mpi = mpi or cuda_mpi or cuda_aware_mpi or rocm_mpi or rocm_aware_mpi have_cuda = cuda or cuda_mpi or cuda_aware_mpi - have_rocm = rocm or rocm_mpi + have_rocm = rocm or rocm_mpi or rocm_aware_mpi if additional_tags is None: additional_tags = () @@ -918,6 +929,7 @@ Macros: repoconf.append("%is_cuda_aware_mpi_repository {0}".format(1 if cuda_aware_mpi else 0)) repoconf.append("%is_rocm_repository {0}".format(1 if rocm else 0)) repoconf.append("%is_rocm_mpi_repository {0}".format(1 if rocm_mpi else 0)) + repoconf.append("%is_rocm_aware_mpi_repository {0}".format(1 if rocm_aware_mpi else 0)) repoconf.append("%have_mpcdf_compiler {0}".format(1 if have_compiler else 0)) repoconf.append("%have_mpcdf_mpi {0}".format(1 if have_mpi else 0)) @@ -1030,6 +1042,23 @@ Macros: rocm_mpi=True, additional_tags=("Prefer: mpcdf_" + rocm,)) + repo(rocm + "_aware_" + mpi + "_" + compiler, + ((project, rocm + "_" + mpi + "_" + compiler),), + rocm_aware_mpi=True, + mpi_module=mpi_module(mpi).replace("openmpi", "openmpi_gpu"), + additional_tags=("Prefer: mpcdf_mpi_" + mpi + "_" + rocm, + "Prefer: mpcdf_" + rocm,)) + for of in openmpi_flavors: + dependencies = ((project, rocm + "_" + mpi + "_" + compiler),) + if not parent: + dependencies = openmpi_flavor_dependencies(of, distribution) + dependencies + repo(rocm + "_aware_" + mpi + "_" + compiler + "_" + of, + dependencies, + rocm_aware_mpi=True, + openmpi_flavor=openmpi_flavor_kind(of), + openmpi_flavor_full=of, + additional_tags=("Prefer: mpcdf_" + rocm,)) + if old_repos: if remove_old: for name in old_repos: