From b2266ec498838e03a7fb2d56e734e5f98441f0cf Mon Sep 17 00:00:00 2001
From: Sebastian Eibl <sebastian.eibl@mpcdf.mpg.de>
Date: Fri, 1 Apr 2022 15:28:12 +0200
Subject: [PATCH] set_b on gpu

---
 src/loss_function/LossFunctionPearsonRMSE.cpp | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/src/loss_function/LossFunctionPearsonRMSE.cpp b/src/loss_function/LossFunctionPearsonRMSE.cpp
index aaf18e58..07e8d34e 100644
--- a/src/loss_function/LossFunctionPearsonRMSE.cpp
+++ b/src/loss_function/LossFunctionPearsonRMSE.cpp
@@ -341,7 +341,8 @@ Kokkos::View<double*> LossFunctionPearsonRMSE::operator()(
         start += _task_sizes_train[task_idx];
     }
 
-    ::get_mean_squared_difference(_batched_scores, _training_properties, _estimated_training_properties);
+    ::get_mean_squared_difference(
+        _batched_scores, _training_properties, _estimated_training_properties);
 
     return _batched_scores;
 }
@@ -427,7 +428,6 @@ void LossFunctionPearsonRMSE::set_a(const std::vector<std::vector<int>>& feature
                                                                   models(feature_idx, model_idx));
     };
     Kokkos::parallel_for("LossFunctionPearsonRMSE::set_a", policy, kernel);
-    Kokkos::fence();
 }
 
 void LossFunctionPearsonRMSE::set_a(const std::vector<model_node_ptr>& feats,
@@ -442,10 +442,15 @@ void LossFunctionPearsonRMSE::set_a(const std::vector<model_node_ptr>& feats,
 
 void LossFunctionPearsonRMSE::set_b(int taskind, int start)
 {
-    for (size_t batch_idx = 0; batch_idx < MAX_BATCHES; ++batch_idx)
+    auto b = _b;
+    auto training_properties = _training_properties;
+    auto policy = Kokkos::MDRangePolicy<Kokkos::Rank<2>>({0, 0},
+                                                         {_task_sizes_train[taskind], MAX_BATCHES});
+    auto kernel = KOKKOS_LAMBDA(const int material_idx, const int batch_idx)
     {
-        std::copy_n(&_prop_train[start], _task_sizes_train[taskind], &_b(0, batch_idx));
-    }
+        b(material_idx, batch_idx) = training_properties(start + material_idx);
+    };
+    Kokkos::parallel_for("LossFunctionPearsonRMSE::set_b", policy, kernel);
 }
 
 int LossFunctionPearsonRMSE::least_squares(int taskind, int start)
-- 
GitLab