Skip to content
Snippets Groups Projects

Some fixes to fully support ROCm

Merged Tobias Melson requested to merge rocm_fixes into master
3 files
+ 26
4
Compare changes
  • Side-by-side
  • Inline
Files
3
+ 17
1
@@ -20,7 +20,7 @@ known_microarchs = {"haswell", "skylake", "znver2", "znver4"}
default_microarch = "haswell"
package_attributes = ["MPCDF:enable_repositories"]
config_attributes = ["MPCDF:compiler_modules", "MPCDF:cuda_modules", "MPCDF:mpi_modules", "MPCDF:pgi_modules", "MPCDF:openmpi_flavors"]
config_attributes = ["MPCDF:compiler_modules", "MPCDF:cuda_modules", "MPCDF:rocm_modules", "MPCDF:mpi_modules", "MPCDF:pgi_modules", "MPCDF:openmpi_flavors"]
intel_parallel_studio = {
"mpcdf_intel_parallel_studio_2017_7": {"compiler": "intel_17_0_7", "impi": "impi_2017_4", "mkl": "mkl_2017_4-module", },
@@ -552,12 +552,14 @@ def mpcdf_enable_repositories(api_url, project, package, verbose=False, dry_run=
compilers = overloaded_package_attribute(api_url, project, package, "MPCDF:compiler_modules")
mpis = overloaded_package_attribute(api_url, project, package, "MPCDF:mpi_modules")
cudas = overloaded_package_attribute(api_url, project, package, "MPCDF:cuda_modules")
rocms = overloaded_package_attribute(api_url, project, package, "MPCDF:rocm_modules")
pgis = overloaded_package_attribute(api_url, project, package, "MPCDF:pgi_modules")
openmpi_flavors = overloaded_package_attribute(api_url, project, package, "MPCDF:openmpi_flavors")
all_compilers = overloaded_project_attribute(api_url, project, "MPCDF:compiler_modules")
all_mpis = overloaded_project_attribute(api_url, project, "MPCDF:mpi_modules")
all_cudas = overloaded_project_attribute(api_url, project, "MPCDF:cuda_modules")
all_rocms = overloaded_project_attribute(api_url, project, "MPCDF:rocm_modules")
all_pgis = overloaded_project_attribute(api_url, project, "MPCDF:pgi_modules")
all_openmpi_flavors = overloaded_project_attribute(api_url, project, "MPCDF:openmpi_flavors")
@@ -603,6 +605,10 @@ def mpcdf_enable_repositories(api_url, project, package, verbose=False, dry_run=
for cuda in (c for c in cudas if c in all_cudas):
yield cuda
def actual_rocms():
for rocm in (r for r in rocms if r in all_rocms):
yield rocm
def actual_pgis():
for pgi in (p for p in pgis if p in all_pgis):
yield pgi
@@ -649,6 +655,16 @@ def mpcdf_enable_repositories(api_url, project, package, verbose=False, dry_run=
if valid_cuda(cuda, compiler) and valid_mpi(compiler, mpi):
enable(cuda + "_" + mpi + "_" + compiler)
if flag == "rocm":
for rocm, compiler in product(actual_rocms(), all_compilers):
if is_gcc_compiler(compiler):
enable(rocm + "_" + compiler)
if flag == "rocm_mpi":
for rocm, mpi, compiler in product(actual_rocms(), actual_mpis(), all_compilers):
if is_gcc_compiler(compiler) and is_openmpi(mpi) and valid_mpi(compiler, mpi):
enable(rocm + "_" + mpi + "_" + compiler)
if flag == "pgi":
for pgi in actual_pgis():
enable(pgi)
Loading