From caa2a8cfac80b4b3c7611d439b01ef21e57d9615 Mon Sep 17 00:00:00 2001
From: Thomas <purcell@fhi-berlin.mpg.de>
Date: Fri, 20 Aug 2021 15:08:32 +0200
Subject: [PATCH] Bug fix

Set _n_class for all ModelClassiers
---
 src/descriptor_identifier/model/ModelClassifier.cpp           | 4 ++++
 .../descriptor_identifier/ModelClassifier.cpp                 | 4 ++--
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/descriptor_identifier/model/ModelClassifier.cpp b/src/descriptor_identifier/model/ModelClassifier.cpp
index 74e75873..2b6a1459 100644
--- a/src/descriptor_identifier/model/ModelClassifier.cpp
+++ b/src/descriptor_identifier/model/ModelClassifier.cpp
@@ -62,6 +62,8 @@ ModelClassifier::ModelClassifier(
 ModelClassifier::ModelClassifier(const std::string train_file)
 {
     populate_model(train_file);
+    _n_class = _loss->n_class();
+
     int file_train_n_convex_overlap = _train_n_convex_overlap;
     _train_n_convex_overlap = 0;
 
@@ -80,6 +82,8 @@ ModelClassifier::ModelClassifier(const std::string train_file)
 ModelClassifier::ModelClassifier(const std::string train_file, std::string test_file)
 {
     populate_model(train_file, test_file);
+    _n_class = _loss->n_class();
+
     int file_train_n_convex_overlap = _train_n_convex_overlap;
     _train_n_convex_overlap = 0;
 
diff --git a/src/python/py_binding_cpp_def/descriptor_identifier/ModelClassifier.cpp b/src/python/py_binding_cpp_def/descriptor_identifier/ModelClassifier.cpp
index 787ca61d..33db3830 100644
--- a/src/python/py_binding_cpp_def/descriptor_identifier/ModelClassifier.cpp
+++ b/src/python/py_binding_cpp_def/descriptor_identifier/ModelClassifier.cpp
@@ -23,6 +23,7 @@
 
 ModelClassifier::ModelClassifier(const ModelClassifier& o, py::list new_coefs, np::ndarray prop_train_est, np::ndarray prop_test_est) :
     Model(o),
+    _n_class(o._n_class),
     _train_n_convex_overlap(o._train_n_convex_overlap),
     _test_n_convex_overlap(o._test_n_convex_overlap)
 {
@@ -31,7 +32,6 @@ ModelClassifier::ModelClassifier(const ModelClassifier& o, py::list new_coefs, n
     {
         _coefs.push_back(python_conv_utils::from_list<double>(coef_list));
     }
-
     std::vector<int> misclassified(_n_samp_train);
     std::transform(
         _loss->prop_train().begin(),
@@ -55,6 +55,7 @@ ModelClassifier::ModelClassifier(const ModelClassifier& o, py::list new_coefs, n
 
 ModelClassifier::ModelClassifier(const ModelClassifier& o, np::ndarray new_coefs, np::ndarray prop_train_est, np::ndarray prop_test_est) :
     Model(o),
+    _n_class(o._n_class),
     _train_n_convex_overlap(o._train_n_convex_overlap),
     _test_n_convex_overlap(o._test_n_convex_overlap)
 {
@@ -63,7 +64,6 @@ ModelClassifier::ModelClassifier(const ModelClassifier& o, np::ndarray new_coefs
     {
         _coefs.push_back(python_conv_utils::from_ndarray<double>(py::extract<np::ndarray>(new_coefs[ii])));
     }
-
     std::vector<int> misclassified(_n_samp_train);
     std::transform(
         _loss->prop_train().begin(),
-- 
GitLab