From e09521eb5209d70bf2630bb62836f20942c3c32f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Lorenz=20H=C3=BCdepohl?= <dev@stellardeath.org>
Date: Thu, 8 Apr 2021 14:56:30 +0200
Subject: [PATCH] Add cuda_aware_mpi repository type

---
 mpcdf_common.py | 18 ++++++++++++++----
 1 file changed, 14 insertions(+), 4 deletions(-)

diff --git a/mpcdf_common.py b/mpcdf_common.py
index 9bf75c6..d50cf62 100644
--- a/mpcdf_common.py
+++ b/mpcdf_common.py
@@ -574,6 +574,11 @@ def mpcdf_enable_repositories(api_url, project, package, verbose=False, dry_run=
                     if valid_pgi_mpi(pgi, mpi):
                         enable(mpi + "_" + pgi)
 
+            if flag == "cuda_aware_mpi":
+                for cuda, mpi, compiler in product(actual_cudas(), actual_mpis(), all_compilers):
+                    if valid_cuda(cuda, compiler) and valid_mpi(compiler, mpi):
+                        enable(cuda + "_aware_" + mpi + "_" + compiler)
+
     if len(build.getchildren()) > 0:
         build.getchildren()[-1].tail = "\n  "
 
@@ -662,11 +667,11 @@ Macros:
         for oldrepo in root.findall("./repository"):
             root.remove(oldrepo)
 
-    def repo(name, dependencies, compiler=False, mpi=False, cuda=False, cuda_mpi=False, additional_tags=(), **macros):
+    def repo(name, dependencies, compiler=False, mpi=False, cuda=False, cuda_mpi=False, cuda_aware_mpi=False, additional_tags=(), **macros):
         old_repos.discard(name)
-        have_compiler = compiler or mpi or cuda or cuda_mpi
-        have_mpi = mpi or cuda_mpi
-        have_cuda = cuda or cuda_mpi
+        have_compiler = compiler or mpi or cuda or cuda_mpi or cuda_aware_mpi
+        have_mpi = mpi or cuda_mpi or cuda_aware_mpi
+        have_cuda = cuda or cuda_mpi or cuda_aware_mpi
 
         existing_repo = root.find("./repository[@name='{0}']".format(name))
         if existing_repo is not None:
@@ -713,6 +718,7 @@ Macros:
         repoconf.append("%is_mpi_repository {0}".format(1 if mpi else 0))
         repoconf.append("%is_cuda_repository {0}".format(1 if cuda else 0))
         repoconf.append("%is_cuda_mpi_repository {0}".format(1 if cuda_mpi else 0))
+        repoconf.append("%is_cuda_aware_mpi_repository {0}".format(1 if cuda_aware_mpi else 0))
 
         repoconf.append("%have_mpcdf_compiler {0}".format(1 if have_compiler else 0))
         repoconf.append("%have_mpcdf_mpi {0}".format(1 if have_mpi else 0))
@@ -805,6 +811,10 @@ Macros:
                       (project, mpi + "_" + compiler)),
                      cuda_mpi=True)
 
+                repo(cuda + "_aware_" + mpi + "_" + compiler,
+                     ((project, cuda + "_" + mpi + "_" + compiler),),
+                     cuda_aware_mpi=True)
+
     if old_repos and not remove_old:
         print("Warning: Keeping the prjconf sections for the following obsolete repositories:")
         for name in sorted(old_repos):
-- 
GitLab