Skip to content
Snippets Groups Projects
Commit 33af5777 authored by Sebastian Eibl's avatar Sebastian Eibl
Browse files

refactored k smallest

parent 230c8d34
No related branches found
No related tags found
No related merge requests found
// Copyright 2021 Thomas A. R. Purcell
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/**
* @author Sebastian Eibl <sebastian.eibl@mpcdf.mpg.de>
*/
#pragma once
#include <algorithm>
#include <cstdint>
#include <vector>
/**
 * @brief Fixed-size container that retains the k smallest items offered to it.
 *
 * The container always holds exactly k entries. Every slot starts out as a
 * copy of an initial value; each successful insert() overwrites the entry
 * that currently compares largest. Entries are stored in no particular order.
 *
 * @tparam T Item type. Must be copyable and support operator> (used by
 *           insert()) and operator< (used by std::max_element).
 */
template <typename T>
class KSmallest
{
public:
    /**
     * @brief Create a container with k slots, each set to initial_value.
     *
     * @param k Number of items to retain; expected to be non-negative.
     * @param initial_value Value placed in every slot. It takes part in the
     *        comparisons like any other item, so it should compare greater
     *        than every item that is meant to be retained.
     */
    KSmallest(const int64_t k, const T& initial_value) : _max_idx(0), _items(k, initial_value) {}

    /**
     * @brief Offer an item to the container.
     *
     * The item replaces the current largest entry only when that entry is
     * strictly greater than the item; ties leave the container unchanged.
     *
     * @param item Candidate item.
     * @return true if the item was stored, false otherwise (including when
     *         the container has zero slots).
     */
    bool insert(const T& item)
    {
        // With k == 0 there is no slot to replace; indexing would be out of bounds.
        if (_items.empty())
        {
            return false;
        }

        T& current_max = _items[_max_idx];
        if (!(current_max > item))
        {
            return false;
        }

        current_max = item;
        // The replaced slot may no longer hold the maximum, so locate it again.
        _max_idx = std::max_element(_items.begin(), _items.end()) - _items.begin();
        return true;
    }

    /// @return Mutable reference to the underlying storage (always k entries, unordered).
    auto& data() { return _items; }

    /// @return Read-only reference to the underlying storage (always k entries, unordered).
    const auto& data() const { return _items; }

private:
    int64_t _max_idx;       ///< Index of the entry in _items that currently compares largest.
    std::vector<T> _items;  ///< The k retained entries, in no particular order.
};
\ No newline at end of file
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include "descriptor_identifier/solver/SISSOSolver.hpp" #include "descriptor_identifier/solver/SISSOSolver.hpp"
#include "KSmallest.hpp"
#include "loss_function/RMSEGPU.hpp" #include "loss_function/RMSEGPU.hpp"
#include "utils/DescriptorMatrix.hpp" #include "utils/DescriptorMatrix.hpp"
#include "utils/EnumerateUniqueCombinations.hpp" #include "utils/EnumerateUniqueCombinations.hpp"
...@@ -184,12 +185,10 @@ void SISSOSolver::l0_regularization(const int n_dim) ...@@ -184,12 +185,10 @@ void SISSOSolver::l0_regularization(const int n_dim)
void SISSOSolver::l0_regularization_gpu(const int n_dim) void SISSOSolver::l0_regularization_gpu(const int n_dim)
{ {
const int n_get_models = std::max(_n_residual, _n_models_store); const int n_get_models = std::max(_n_residual, _n_models_store);
std::vector<inds_sc_pair> min_sc_inds(n_get_models); KSmallest<inds_sc_pair> best_feature_combinations(n_get_models, inds_sc_pair());
setup_regulairzation(); setup_regulairzation();
int max_error_ind = 0;
DescriptorMatrix descriptorMatrix; DescriptorMatrix descriptorMatrix;
PropertiesVector properties(_loss->prop_train()); PropertiesVector properties(_loss->prop_train());
RMSEGPU loss(descriptorMatrix.getDeviceDescriptorMatrix(), RMSEGPU loss(descriptorMatrix.getDeviceDescriptorMatrix(),
...@@ -217,19 +216,13 @@ void SISSOSolver::l0_regularization_gpu(const int n_dim) ...@@ -217,19 +216,13 @@ void SISSOSolver::l0_regularization_gpu(const int n_dim)
for (size_t model_idx = 0; model_idx < feature_indices.size(); ++model_idx) for (size_t model_idx = 0; model_idx < feature_indices.size(); ++model_idx)
{ {
if (scores[model_idx] <= min_sc_inds[max_error_ind].score()) best_feature_combinations.insert({feature_indices[model_idx], scores[model_idx]});
{
update_min_inds_scores(
feature_indices[model_idx], scores[model_idx], max_error_ind, min_sc_inds);
max_error_ind = std::max_element(min_sc_inds.begin(), min_sc_inds.end()) -
min_sc_inds.begin();
}
} }
} }
std::vector<inds_sc_pair> all_min_sc_inds(_mpi_comm->size() * n_get_models); std::vector<inds_sc_pair> all_min_sc_inds(_mpi_comm->size() * n_get_models);
mpi::all_gather(*_mpi_comm, min_sc_inds.data(), n_get_models, all_min_sc_inds); mpi::all_gather(*_mpi_comm, best_feature_combinations.data().data(), n_get_models, all_min_sc_inds);
auto inds = util_funcs::argsort<inds_sc_pair>(all_min_sc_inds); auto inds = util_funcs::argsort<inds_sc_pair>(all_min_sc_inds);
std::vector<std::vector<int>> indexes(n_get_models, std::vector<int>(n_dim)); std::vector<std::vector<int>> indexes(n_get_models, std::vector<int>(n_dim));
......
// Copyright 2021 Thomas A. R. Purcell
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <cstdint>
#include <descriptor_identifier/solver/KSmallest.hpp>
// Checks that KSmallest keeps exactly the 10 smallest of the values 1..100
// and reports through insert()'s return value whether an item was stored.
TEST(KSmallest, ksmallest)
{
    KSmallest<int64_t> ksmallest(10, std::numeric_limits<int64_t>::max());

    // Before any insert, every slot still holds the initial value.
    for (const auto& val : ksmallest.data())
    {
        EXPECT_GT(val, 10);
    }

    // Values arrive in descending order, so each one is smaller than the
    // current largest entry and has to be accepted.
    for (int64_t val = 100; val > 0; --val)
    {
        EXPECT_TRUE(ksmallest.insert(val));
    }

    // A value larger than everything retained must be rejected.
    EXPECT_FALSE(ksmallest.insert(11));

    // Unsigned literal avoids a signed/unsigned comparison against size().
    EXPECT_EQ(ksmallest.data().size(), 10u);
    for (const auto& val : ksmallest.data())
    {
        EXPECT_GE(val, 1);
        EXPECT_LE(val, 10);
    }
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment