diff --git a/src/descriptor_identifier/solver/KSmallest.hpp b/src/descriptor_identifier/solver/KSmallest.hpp new file mode 100644 index 0000000000000000000000000000000000000000..63f7d0fdee2b050f73dcb93b057a9d7ae4a1fa94 --- /dev/null +++ b/src/descriptor_identifier/solver/KSmallest.hpp @@ -0,0 +1,46 @@ +// Copyright 2021 Thomas A. R. Purcell +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/** + * @author Sebastian Eibl <sebastian.eibl@mpcdf.mpg.de> + */ + +#pragma once + +#include <cstdint> +#include <vector> + +template <typename T> +class KSmallest +{ +public: + KSmallest(const int64_t k, const T& initial_value) : _max_idx(0), _items(k, initial_value) {} + bool insert(const T& item) + { + if (_items[_max_idx] > item) + { + _items[_max_idx] = item; + _max_idx = std::max_element(_items.begin(), _items.end()) - + _items.begin(); + return true; + } + return false; + } + + auto& data() {return _items;} + +private: + int64_t _max_idx; + std::vector<T> _items; +}; \ No newline at end of file diff --git a/src/descriptor_identifier/solver/SISSOSolver.cpp b/src/descriptor_identifier/solver/SISSOSolver.cpp index 5011a2a0c47b81f6f6637d782c48446f3b5a2b27..af2f9425ee4a20150649968e73dab8174c7b3ad1 100644 --- a/src/descriptor_identifier/solver/SISSOSolver.cpp +++ b/src/descriptor_identifier/solver/SISSOSolver.cpp @@ -21,6 +21,7 @@ #include "descriptor_identifier/solver/SISSOSolver.hpp" +#include "KSmallest.hpp" #include "loss_function/RMSEGPU.hpp" #include "utils/DescriptorMatrix.hpp" #include "utils/EnumerateUniqueCombinations.hpp" @@ -184,12 +185,10 @@ void SISSOSolver::l0_regularization(const int n_dim) void SISSOSolver::l0_regularization_gpu(const int n_dim) { const int n_get_models = std::max(_n_residual, _n_models_store); - std::vector<inds_sc_pair> min_sc_inds(n_get_models); + KSmallest<inds_sc_pair> best_feature_combinations(n_get_models, inds_sc_pair()); setup_regulairzation(); - int max_error_ind = 0; - DescriptorMatrix descriptorMatrix; PropertiesVector properties(_loss->prop_train()); RMSEGPU loss(descriptorMatrix.getDeviceDescriptorMatrix(), @@ -217,19 +216,13 @@ void SISSOSolver::l0_regularization_gpu(const int n_dim) for (size_t model_idx = 0; model_idx < feature_indices.size(); ++model_idx) { - if (scores[model_idx] <= min_sc_inds[max_error_ind].score()) - { - update_min_inds_scores( - feature_indices[model_idx], scores[model_idx], max_error_ind, min_sc_inds); - max_error_ind = std::max_element(min_sc_inds.begin(), min_sc_inds.end()) - - min_sc_inds.begin(); - } + best_feature_combinations.insert({feature_indices[model_idx], scores[model_idx]}); } } std::vector<inds_sc_pair> all_min_sc_inds(_mpi_comm->size() * n_get_models); - mpi::all_gather(*_mpi_comm, min_sc_inds.data(), n_get_models, all_min_sc_inds); + mpi::all_gather(*_mpi_comm, best_feature_combinations.data().data(), n_get_models, all_min_sc_inds); auto inds = util_funcs::argsort<inds_sc_pair>(all_min_sc_inds); std::vector<std::vector<int>> indexes(n_get_models, std::vector<int>(n_dim)); diff --git a/tests/googletest/descriptor_identification/solver/test_sisso_ksmallest.cc b/tests/googletest/descriptor_identification/solver/test_sisso_ksmallest.cc new file mode 100644 index 0000000000000000000000000000000000000000..5f73053c6aa4c6b22396801135e851eea643b4ed --- /dev/null +++ b/tests/googletest/descriptor_identification/solver/test_sisso_ksmallest.cc @@ -0,0 +1,36 @@ +// Copyright 2021 Thomas A. R. Purcell +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include <gtest/gtest.h> + +#include <cstdint> +#include <descriptor_identifier/solver/KSmallest.hpp> + +TEST(KSmallest, ksmallest) +{ + KSmallest<int64_t> ksmallest(10, std::numeric_limits<int64_t>::max()); + + for (auto& val : ksmallest.data()) + { + EXPECT_GT(val, 10); + } + + for (int64_t val = 100; val > 0; --val) + ksmallest.insert(val); + + EXPECT_EQ(ksmallest.data().size(), 10); + for (auto& val : ksmallest.data()) + { + EXPECT_LE(val, 10); + } +} \ No newline at end of file