diff --git a/mpcdf_common.py b/mpcdf_common.py index 9f943280891807bc3d1e82df0590b600687e541e..35dc0dcbcd2232513551f83304b480ac36077a5e 100644 --- a/mpcdf_common.py +++ b/mpcdf_common.py @@ -838,11 +838,13 @@ Macros: if re.search(remove_old_matching, oldrepo.attrib["name"]): root.remove(oldrepo) - def repo(name, dependencies, compiler=False, mpi=False, cuda=False, cuda_mpi=False, cuda_aware_mpi=False, additional_tags=(), **macros): + def repo(name: str, dependencies: tuple, compiler=False, mpi=False, cuda=False, cuda_mpi=False, cuda_aware_mpi=False, + rocm=False, rocm_mpi=False, additional_tags: Optional[tuple] = None, **macros: str) -> None: old_repos.discard(name) have_compiler = compiler or mpi or cuda or cuda_mpi or cuda_aware_mpi have_mpi = mpi or cuda_mpi or cuda_aware_mpi have_cuda = cuda or cuda_mpi or cuda_aware_mpi + have_rocm = rocm or rocm_mpi if additional_tags is None: additional_tags = () @@ -896,10 +898,13 @@ Macros: repoconf.append("%is_cuda_repository {0}".format(1 if cuda else 0)) repoconf.append("%is_cuda_mpi_repository {0}".format(1 if cuda_mpi else 0)) repoconf.append("%is_cuda_aware_mpi_repository {0}".format(1 if cuda_aware_mpi else 0)) + repoconf.append("%is_rocm_repository {0}".format(1 if rocm else 0)) + repoconf.append("%is_rocm_mpi_repository {0}".format(1 if rocm_mpi else 0)) repoconf.append("%have_mpcdf_compiler {0}".format(1 if have_compiler else 0)) repoconf.append("%have_mpcdf_mpi {0}".format(1 if have_mpi else 0)) repoconf.append("%have_mpcdf_cuda {0}".format(1 if have_cuda else 0)) + repoconf.append("%have_mpcdf_rocm {0}".format(1 if have_rocm else 0)) if matching_mkl: single_matching_mkl, = matching_mkl @@ -934,6 +939,7 @@ Macros: compilers = overloaded_project_attribute(api_url, project, "MPCDF:compiler_modules") mpis = overloaded_project_attribute(api_url, project, "MPCDF:mpi_modules") cudas = overloaded_project_attribute(api_url, project, "MPCDF:cuda_modules") + rocms = overloaded_project_attribute(api_url, project, "MPCDF:rocm_modules") pgis = overloaded_project_attribute(api_url, project, "MPCDF:pgi_modules") openmpi_flavors = overloaded_project_attribute(api_url, project, "MPCDF:openmpi_flavors") @@ -995,6 +1001,17 @@ Macros: openmpi_flavor_full=of, additional_tags=("Prefer: mpcdf_" + cuda_from_compiler(cuda, compiler),)) + for rocm in rocms: + for compiler in filter(is_gcc_compiler, compilers): + repo(rocm + "_" + compiler, ((project, compiler),), rocm=True, + additional_tags=("Prefer: mpcdf_" + rocm,)) + for mpi in filter(partial(valid_mpi, compiler), filter(is_openmpi, mpis)): + repo(rocm + "_" + mpi + "_" + compiler, + ((project, rocm + "_" + compiler), + (project, mpi + "_" + compiler)), + rocm_mpi=True, + additional_tags=("Prefer: mpcdf_" + rocm,)) + if old_repos: if remove_old: for name in old_repos: