From e53a09f578cf6ecccda57d7918ee27d48d08c86a Mon Sep 17 00:00:00 2001 From: Thomas Purcell <purcell@fhi-berlin.mpg.de> Date: Wed, 23 Sep 2020 22:05:50 +0200 Subject: [PATCH] Update bindings/main/tests for the new no fix_intercept policy Classification without fix_intercept now included in all files --- src/main.cpp | 2 +- src/python/bindings_docstring_keyed.cpp | 4 ++-- tests/test_classification/test_classification.py | 12 +----------- 3 files changed, 4 insertions(+), 14 deletions(-) diff --git a/src/main.cpp b/src/main.cpp index b6b12267..575046b0 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -57,7 +57,7 @@ int main(int argc, char const *argv[]) } else if(IP._calc_type.compare("classification") == 0) { - SISSOClassifier sisso(IP._feat_space, IP._prop_unit, IP._prop_train, IP._prop_test, IP._task_sizes_train, IP._task_sizes_test, IP._leave_out_inds, IP._n_dim, IP._n_residuals, IP._n_models_store, IP._fix_intercept); + SISSOClassifier sisso(IP._feat_space, IP._prop_unit, IP._prop_train, IP._prop_test, IP._task_sizes_train, IP._task_sizes_test, IP._leave_out_inds, IP._n_dim, IP._n_residuals, IP._n_models_store); sisso.fit(); if(mpi_setup::comm->rank() == 0) diff --git a/src/python/bindings_docstring_keyed.cpp b/src/python/bindings_docstring_keyed.cpp index a9e57c35..4c6bf47e 100644 --- a/src/python/bindings_docstring_keyed.cpp +++ b/src/python/bindings_docstring_keyed.cpp @@ -438,8 +438,8 @@ void sisso::descriptor_identifier::registerSISSORegressor() void sisso::descriptor_identifier::registerSISSOClassifier() { - class_<SISSOClassifier, bases<SISSO_DI>>("SISSOClassifier", init<std::shared_ptr<FeatureSpace>, Unit, np::ndarray, np::ndarray, py::list, py::list, py::list, int, int, int, optional<bool>>()) - .def(init<std::shared_ptr<FeatureSpace>, Unit, py::list, py::list, py::list, py::list, py::list, int, int, int, optional<bool>>()) + class_<SISSOClassifier, bases<SISSO_DI>>("SISSOClassifier", init<std::shared_ptr<FeatureSpace>, Unit, np::ndarray, np::ndarray, py::list, py::list, py::list, int, int, int>()) + .def(init<std::shared_ptr<FeatureSpace>, Unit, py::list, py::list, py::list, py::list, py::list, int, int, int>()) .def("fit", &SISSOClassifier::fit, "@DocString_sisso_class_fit@") .add_property("models", &SISSOClassifier::models_py, "@DocString_sisso_class_models_py@") ; diff --git a/tests/test_classification/test_classification.py b/tests/test_classification/test_classification.py index bee82dca..5de1e0aa 100644 --- a/tests/test_classification/test_classification.py +++ b/tests/test_classification/test_classification.py @@ -76,17 +76,7 @@ def test_sisso_classifier(): feat_space = generate_fs(phi_0, prop, [80], op_set, "classification", 1, 10) sisso = SISSOClassifier( - feat_space, - Unit("m"), - prop, - prop_test, - [80], - [20], - list(range(10)), - 2, - 1, - 1, - False, + feat_space, Unit("m"), prop, prop_test, [80], [20], list(range(10)), 2, 1, 1 ) sisso.fit() -- GitLab