diff --git a/mpcdf_common.py b/mpcdf_common.py index 3e4344bc9fa37f41f11ffa273e508f8e64f1c86a..4cefbcff1041e9cd33db0f80435d02d6847a9a7f 100644 --- a/mpcdf_common.py +++ b/mpcdf_common.py @@ -300,6 +300,8 @@ def mpcdf_setup_repositories(api_url, project, distribution=None, parent=None, p is_cuda = kwargs.pop("cuda", False) is_cuda_mpi = kwargs.pop("cuda_mpi", False) + cuda_repo = kwargs.pop("cuda_repo", "") + have_compiler = is_compiler or is_mpi or is_cuda or is_cuda_mpi have_mpi = is_mpi or is_cuda_mpi have_cuda = is_cuda or is_cuda_mpi @@ -338,6 +340,9 @@ def mpcdf_setup_repositories(api_url, project, distribution=None, parent=None, p prjconf.append("%have_mpcdf_mpi {0}".format(1 if have_mpi else 0)) prjconf.append("%have_mpcdf_cuda {0}".format(1 if have_cuda else 0)) + if is_cuda: + prjconf.append("%cuda_repository {0}".format(cuda_repo)) + for macro, value in kwargs.items(): prjconf.append("%{0} {1}".format(macro, value)) @@ -358,7 +363,7 @@ def mpcdf_setup_repositories(api_url, project, distribution=None, parent=None, p for cuda in cudas: for compiler in filter(partial(valid_cuda, cuda), compilers): - repo(cuda + "_" + compiler, (project, compiler), cuda=True) + repo(cuda + "_" + compiler, (project, compiler), cuda=True, cuda_repo=cuda) for mpi in filter(partial(valid_mpi, compiler), mpis): repo(cuda + "_" + mpi + "_" + compiler, (project, cuda + "_" + compiler),