From 24e843be3b8a7198dfc0529cf1878c7867e4df94 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lorenz=20H=C3=BCdepohl?= <dev@stellardeath.org> Date: Tue, 2 Jul 2024 12:58:07 +0200 Subject: [PATCH] First support for ROCm repositories --- mpcdf_common.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/mpcdf_common.py b/mpcdf_common.py index 9f94328..35dc0dc 100644 --- a/mpcdf_common.py +++ b/mpcdf_common.py @@ -838,11 +838,13 @@ Macros: if re.search(remove_old_matching, oldrepo.attrib["name"]): root.remove(oldrepo) - def repo(name, dependencies, compiler=False, mpi=False, cuda=False, cuda_mpi=False, cuda_aware_mpi=False, additional_tags=(), **macros): + def repo(name: str, dependencies: tuple, compiler=False, mpi=False, cuda=False, cuda_mpi=False, cuda_aware_mpi=False, + rocm=False, rocm_mpi=False, additional_tags: Optional[tuple] = None, **macros: str) -> None: old_repos.discard(name) have_compiler = compiler or mpi or cuda or cuda_mpi or cuda_aware_mpi have_mpi = mpi or cuda_mpi or cuda_aware_mpi have_cuda = cuda or cuda_mpi or cuda_aware_mpi + have_rocm = rocm or rocm_mpi if additional_tags is None: additional_tags = () @@ -896,10 +898,13 @@ Macros: repoconf.append("%is_cuda_repository {0}".format(1 if cuda else 0)) repoconf.append("%is_cuda_mpi_repository {0}".format(1 if cuda_mpi else 0)) repoconf.append("%is_cuda_aware_mpi_repository {0}".format(1 if cuda_aware_mpi else 0)) + repoconf.append("%is_rocm_repository {0}".format(1 if rocm else 0)) + repoconf.append("%is_rocm_mpi_repository {0}".format(1 if rocm_mpi else 0)) repoconf.append("%have_mpcdf_compiler {0}".format(1 if have_compiler else 0)) repoconf.append("%have_mpcdf_mpi {0}".format(1 if have_mpi else 0)) repoconf.append("%have_mpcdf_cuda {0}".format(1 if have_cuda else 0)) + repoconf.append("%have_mpcdf_rocm {0}".format(1 if have_rocm else 0)) if matching_mkl: single_matching_mkl, = matching_mkl @@ -934,6 +939,7 @@ Macros: compilers = overloaded_project_attribute(api_url, project, "MPCDF:compiler_modules") mpis = overloaded_project_attribute(api_url, project, "MPCDF:mpi_modules") cudas = overloaded_project_attribute(api_url, project, "MPCDF:cuda_modules") + rocms = overloaded_project_attribute(api_url, project, "MPCDF:rocm_modules") pgis = overloaded_project_attribute(api_url, project, "MPCDF:pgi_modules") openmpi_flavors = overloaded_project_attribute(api_url, project, "MPCDF:openmpi_flavors") @@ -995,6 +1001,17 @@ Macros: openmpi_flavor_full=of, additional_tags=("Prefer: mpcdf_" + cuda_from_compiler(cuda, compiler),)) + for rocm in rocms: + for compiler in filter(is_gcc_compiler, compilers): + repo(rocm + "_" + compiler, ((project, compiler),), rocm=True, + additional_tags=("Prefer: mpcdf_" + rocm,)) + for mpi in filter(partial(valid_mpi, compiler), filter(is_openmpi, mpis)): + repo(rocm + "_" + mpi + "_" + compiler, + ((project, rocm + "_" + compiler), + (project, mpi + "_" + compiler)), + rocm_mpi=True, + additional_tags=("Prefer: mpcdf_" + rocm,)) + if old_repos: if remove_old: for name in old_repos: -- GitLab