Commit cc6f3d1d authored by Luigi Sbailo's avatar Luigi Sbailo
Browse files

Sisso training adjusted and visualizer creates the convex hull

parent 5e5b84e7
......@@ -45,8 +45,8 @@
"execution_count": 1,
"metadata": {
"ExecuteTime": {
"end_time": "2020-09-18T17:12:15.289136Z",
"start_time": "2020-09-18T17:12:15.266454Z"
"end_time": "2020-09-21T15:37:52.170260Z",
"start_time": "2020-09-21T15:37:52.152299Z"
}
},
"outputs": [],
......@@ -57,11 +57,11 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"end_time": "2020-09-18T17:12:17.341570Z",
"start_time": "2020-09-18T17:12:17.278901Z"
"end_time": "2020-09-21T15:39:21.528179Z",
"start_time": "2020-09-21T15:37:52.172263Z"
}
},
"outputs": [],
......@@ -70,30 +70,18 @@
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"df_train = pd.read_csv(\"./data/topological_insulators/train.csv\", index_col=0).astype(float)\n",
"from topological_insulators.visualizer import Visualizer"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"ExecuteTime": {
"end_time": "2020-09-18T16:22:32.147790Z",
"start_time": "2020-09-18T16:21:05.396108Z"
}
},
"outputs": [],
"source": [
"df_train = pd.read_csv(\"./data/topological_insulators/train.csv\", index_col=0).astype(float)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 3,
"metadata": {
"ExecuteTime": {
"end_time": "2020-09-17T14:53:59.475551Z",
"start_time": "2020-09-17T14:53:55.205322Z"
"end_time": "2020-09-21T15:39:26.048895Z",
"start_time": "2020-09-21T15:39:21.529550Z"
}
},
"outputs": [],
......@@ -128,11 +116,11 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 55,
"metadata": {
"ExecuteTime": {
"end_time": "2020-09-17T15:07:22.181242Z",
"start_time": "2020-09-17T15:07:22.111580Z"
"end_time": "2020-09-21T18:23:15.349088Z",
"start_time": "2020-09-21T18:23:15.268450Z"
}
},
"outputs": [
......@@ -758,7 +746,7 @@
"[152 rows x 160001 columns]"
]
},
"execution_count": 13,
"execution_count": 55,
"metadata": {},
"output_type": "execute_result"
}
......@@ -769,25 +757,11 @@
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"ExecuteTime": {
"end_time": "2020-09-17T14:54:16.222565Z",
"start_time": "2020-09-17T14:54:00.266030Z"
}
},
"outputs": [],
"source": [
"sisso.fit()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 5,
"metadata": {
"ExecuteTime": {
"end_time": "2020-09-17T14:54:18.151143Z",
"start_time": "2020-09-17T14:54:18.133529Z"
"end_time": "2020-09-21T15:39:40.566746Z",
"start_time": "2020-09-21T15:39:40.556019Z"
}
},
"outputs": [],
......@@ -797,11 +771,11 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 23,
"metadata": {
"ExecuteTime": {
"end_time": "2020-09-17T14:54:28.990112Z",
"start_time": "2020-09-17T14:54:28.969771Z"
"end_time": "2020-09-21T16:27:14.736332Z",
"start_time": "2020-09-21T16:27:14.708845Z"
},
"scrolled": true
},
......@@ -810,16 +784,16 @@
"feat_1=model.feats[0].value\n",
"feat_0=model.feats[1].value\n",
"compounds=df_train.index.to_list()\n",
"classified=np.concatenate([np.ones(67),np.zeros(85)])"
"classified=model.prop_train"
]
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 24,
"metadata": {
"ExecuteTime": {
"end_time": "2020-09-17T14:54:29.773658Z",
"start_time": "2020-09-17T14:54:29.751318Z"
"end_time": "2020-09-21T16:27:15.346299Z",
"start_time": "2020-09-21T16:27:15.315502Z"
}
},
"outputs": [],
......@@ -832,26 +806,19 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 90,
"metadata": {
"ExecuteTime": {
"end_time": "2020-09-17T14:54:33.791053Z",
"start_time": "2020-09-17T14:54:33.164038Z"
"end_time": "2020-09-21T19:34:32.446358Z",
"start_time": "2020-09-21T19:34:31.933201Z"
},
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"hi\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a9e759530ba04a31a45a14e91ff96737",
"model_id": "a2a0a14863e34412865fd019002bea96",
"version_major": 2,
"version_minor": 0
},
......@@ -864,8 +831,15 @@
}
],
"source": [
"Visualizer(df).show()"
"Visualizer(df, sisso).show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
......
%% Cell type:markdown id: tags:
<div id="teaser" style=' background-position: right center; background-size: 00px; background-repeat: no-repeat;
padding-top: 20px;
padding-right: 10px;
padding-bottom: 170px;
padding-left: 10px;
border-bottom: 14px double #333;
border-top: 14px double #333;' >
<div style="text-align:center">
<b><font size="6.4">Artificial intelligence for high-throughput discovery of topological insulators</font></b>
</div>
<p>
<span class="nomad--last-updated" data-version="v1.0.0">[Last updated: Sep 17, 2020]</span>
<div>
<img style="float: left;" src="assets/topological_insulators/Logo_MPG.png" width="200">
<img style="float: right;" src="assets/topological_insulators/Logo_NOMAD.png" width="250">
</div>
</div>
%% Cell type:markdown id: tags:
Insulator discovery
%% Cell type:code id: tags:
``` python
%load_ext autoreload
%autoreload 2
```
%% Cell type:code id: tags:
``` python
from cpp_sisso import get_max_number_feats, get_estimate_n_feat_next_rung, generate_fs, SISSOClassifier, generate_phi_0_from_csv, FeatureSpace
import numpy as np
import pandas as pd
from topological_insulators.visualizer import Visualizer
```
%% Cell type:code id: tags:
``` python
df_train = pd.read_csv("./data/topological_insulators/train.csv", index_col=0).astype(float)
from topological_insulators.visualizer import Visualizer
```
%% Cell type:code id: tags:
``` python
phi_0, prop_unit, prop, prop_test, task_sizes_train, task_sizes_test, leave_out_inds = generate_phi_0_from_csv(
df_train, "Class", cols="all", task_key=None, leave_out_frac=0.0
)
feat_space = generate_fs(
phi_0,
prop,
task_sizes_train,
["add", "sub", "mult", "div", "abs_diff", "sq", "cb", "sqrt", "cbrt", "inv", "abs"],
"classification",
0,
50
)
sisso = SISSOClassifier(
feat_space,
prop_unit,
prop,
prop_test,
task_sizes_train,
task_sizes_test,
leave_out_inds,
2,
10,
10,
)
```
%% Cell type:code id: tags:
``` python
df_train
```
%%%% Output: execute_result
[[[Z11A*Z11B]/sqrt[X16E]]/[[Z11A/Z11C]+[Z11B/Z11D]]] \
compound
Sb_Sb_Te_Te_Te 910.70
As_Bi_Te_Te_S 764.41
Bi_Bi_Te_Se_Se 1068.50
Bi_Sb_Te_S_Te 607.74
Bi_As_Te_Se_Te 732.90
... ...
Bi_Sb_S_S_Se 316.51
Sb_Sb_Te_Se_S 652.75
Sb_Sb_Se_Te_S 652.75
Bi_Bi_S_S_S 413.39
As_As_Te_Te_Se 537.30
[[[Z11A*Z11B]/[X16A+X16E]]/[[Z11A/Z11C]+[Z11B/Z11D]]] \
compound
Sb_Sb_Te_Te_Te 317.99
As_Bi_Te_Te_S 257.95
Bi_Bi_Te_Se_Se 373.38
Bi_Sb_Te_S_Te 213.74
Bi_As_Te_Se_Te 257.76
... ...
Bi_Sb_S_S_Se 110.60
Sb_Sb_Te_Se_S 226.45
Sb_Sb_Se_Te_S 226.45
Bi_Bi_S_S_S 144.35
As_As_Te_Te_Se 181.40
[[[Z11A*Z11B]/[X16B+X16E]]/[[Z11A/Z11C]+[Z11B/Z11D]]] \
compound
Sb_Sb_Te_Te_Te 317.99
As_Bi_Te_Te_S 266.92
Bi_Bi_Te_Se_Se 373.38
Bi_Sb_Te_S_Te 212.20
Bi_As_Te_Se_Te 248.17
... ...
Bi_Sb_S_S_Se 109.88
Sb_Sb_Te_Se_S 226.45
Sb_Sb_Se_Te_S 226.45
Bi_Bi_S_S_S 144.35
As_As_Te_Te_Se 181.40
[[[Z11A+Z11B]/sqrt[X16E]]/[[X16A/Z11D]+[X16B/Z11C]]] \
compound
Sb_Sb_Te_Te_Te 888.49
As_Bi_Te_Te_S 894.13
Bi_Bi_Te_Se_Se 1058.00
Bi_Sb_Te_S_Te 555.50
Bi_As_Te_Se_Te 786.20
... ...
Bi_Sb_S_S_Se 329.88
Sb_Sb_Te_Se_S 636.83
Sb_Sb_Se_Te_S 636.83
Bi_Bi_S_S_S 409.30
As_As_Te_Te_Se 492.94
[[[Z11A+Z11B]*[Z11C+Z11D]]-abs[[Z11A*Z11D]-[Z11B*Z11C]]] \
compound
Sb_Sb_Te_Te_Te 10608.0
As_Bi_Te_Te_S 9464.0
Bi_Bi_Te_Se_Se 12782.0
Bi_Sb_Te_S_Te 7788.0
Bi_As_Te_Se_Te 8870.0
... ...
Bi_Sb_S_S_Se 3776.0
Sb_Sb_Te_Se_S 7854.0
Sb_Sb_Se_Te_S 7854.0
Bi_Bi_S_S_S 5312.0
As_As_Te_Te_Se 6864.0
[[Z11A*Z11B]/[[Z11A/Z11C]+[Z11B/Z11D]]] \
compound
Sb_Sb_Te_Te_Te 1326.00
As_Bi_Te_Te_S 1227.80
Bi_Bi_Te_Se_Se 1706.30
Bi_Sb_Te_S_Te 884.89
Bi_As_Te_Se_Te 1067.10
... ...
Bi_Sb_S_S_Se 505.43
Sb_Sb_Te_Se_S 1048.50
Sb_Sb_Se_Te_S 1048.50
Bi_Bi_S_S_S 664.00
As_As_Te_Te_Se 858.00
[[[Z11A+Z11B]/sqrt[X16E]]/[[Z11C]^-1+[Z11D]^-1]] \
compound
Sb_Sb_Te_Te_Te 1821.40
As_Bi_Te_Te_S 1877.70
Bi_Bi_Te_Se_Se 2137.10
Bi_Sb_Te_S_Te 1126.00
Bi_As_Te_Se_Te 1637.80
... ...
Bi_Sb_S_S_Se 671.31
Sb_Sb_Te_Se_S 1305.50
Sb_Sb_Se_Te_S 1305.50
Bi_Bi_S_S_S 826.78
As_As_Te_Te_Se 1074.60
[[[Z11A/X16B]+[Z11B/X16E]]/[[X16A/Z11D]+[X16B/Z11C]]] \
compound
Sb_Sb_Te_Te_Te 620.63
As_Bi_Te_Te_S 600.57
Bi_Bi_Te_Se_Se 749.44
Bi_Sb_Te_S_Te 389.59
Bi_As_Te_Se_Te 529.33
... ...
Bi_Sb_S_S_Se 237.79
Sb_Sb_Te_Se_S 447.72
Sb_Sb_Se_Te_S 447.72
Bi_Bi_S_S_S 290.14
As_As_Te_Te_Se 334.88
[[[Z11A/X16E]+[Z11B/X16A]]/[[X16A/Z11D]+[X16B/Z11C]]] \
compound
Sb_Sb_Te_Te_Te 620.63
As_Bi_Te_Te_S 629.75
Bi_Bi_Te_Se_Se 749.44
Bi_Sb_Te_S_Te 388.71
Bi_As_Te_Se_Te 547.57
... ...
Bi_Sb_S_S_Se 227.21
Sb_Sb_Te_Se_S 447.72
Sb_Sb_Se_Te_S 447.72
Bi_Bi_S_S_S 290.14
As_As_Te_Te_Se 334.88
[[sqrt[X16D]*[Z11A+Z11B]]/[[X16D/Z11C]+[X16E/Z11D]]] ... \
compound ...
Sb_Sb_Te_Te_Te 1821.40 ...
As_Bi_Te_Te_S 1868.70 ...
Bi_Bi_Te_Se_Se 2137.10 ...
Bi_Sb_Te_S_Te 1181.90 ...
Bi_As_Te_Se_Te 1662.90 ...
... ... ...
Bi_Sb_S_S_Se 671.30 ...
Sb_Sb_Te_Se_S 1303.90 ...
Sb_Sb_Se_Te_S 1326.40 ...
Bi_Bi_S_S_S 826.78 ...
As_As_Te_Te_Se 1070.00 ...
[[abs[X16B-X16D]/log[Z11A]]/[[X16C-X16E]-abs[X16A-X16E]]] \
compound
Sb_Sb_Te_Te_Te -0.254330
As_Bi_Te_Te_S -0.033256
Bi_Bi_Te_Se_Se -0.124940
Bi_Sb_Te_S_Te -1.199400
Bi_As_Te_Se_Te -0.837320
... ...
Bi_Sb_S_S_Se -0.239880
Sb_Sb_Te_Se_S -0.128450
Sb_Sb_Se_Te_S -0.031792
Bi_Bi_S_S_S -0.226300
As_As_Te_Te_Se -0.021450
[[abs[X16A-X16C]/log[Z11B]]/[[X16D-X16E]-abs[X16B-X16E]]] \
compound
Sb_Sb_Te_Te_Te -0.254330
As_Bi_Te_Te_S -0.013312
Bi_Bi_Te_Se_Se -0.042699
Bi_Sb_Te_S_Te 0.065214
Bi_As_Te_Se_Te 0.077297
... ...
Bi_Sb_S_S_Se -0.303040
Sb_Sb_Te_Se_S -0.031792
Sb_Sb_Se_Te_S -0.128450
Bi_Bi_S_S_S -0.226300
As_As_Te_Te_Se -0.021450
[[[X16B-X16D]/log[Z11A]]/[[X16C-X16E]-abs[X16A-X16E]]] \
compound
Sb_Sb_Te_Te_Te 0.254330
As_Bi_Te_Te_S 0.033256
Bi_Bi_Te_Se_Se 0.124940
Bi_Sb_Te_S_Te 1.199400
Bi_As_Te_Se_Te 0.837320
... ...
Bi_Sb_S_S_Se 0.239880
Sb_Sb_Te_Se_S 0.128450
Sb_Sb_Se_Te_S 0.031792
Bi_Bi_S_S_S 0.226300
As_As_Te_Te_Se -0.021450
[[[X16A-X16C]/log[Z11B]]/[[X16D-X16E]-abs[X16B-X16E]]] \
compound
Sb_Sb_Te_Te_Te 0.254330
As_Bi_Te_Te_S -0.013312
Bi_Bi_Te_Se_Se 0.042699
Bi_Sb_Te_S_Te -0.065214
Bi_As_Te_Se_Te -0.077297
... ...
Bi_Sb_S_S_Se 0.303040
Sb_Sb_Te_Se_S 0.031792
Sb_Sb_Se_Te_S 0.128450
Bi_Bi_S_S_S 0.226300
As_As_Te_Te_Se -0.021450
abs[[log[X16A]*[X16A-X16C]]-[[X16C-X16D]*[Z11E/Z11D]]] \
compound
Sb_Sb_Te_Te_Te 0.050249
As_Bi_Te_Te_S 0.046759
Bi_Bi_Te_Se_Se 0.359690
Bi_Sb_Te_S_Te 1.424700
Bi_As_Te_Se_Te 0.587340
... ...
Bi_Sb_S_S_Se 0.393730
Sb_Sb_Te_Se_S 0.152100
Sb_Sb_Se_Te_S 0.491230
Bi_Bi_S_S_S 0.393730
As_As_Te_Te_Se 0.046759
abs[[log[X16B]*[X16B-X16D]]-[[X16D-X16C]*[Z11E/Z11C]]] \
compound
Sb_Sb_Te_Te_Te 0.050249
As_Bi_Te_Te_S 0.070310
Bi_Bi_Te_Se_Se 0.653800
Bi_Sb_Te_S_Te 0.840460
Bi_As_Te_Se_Te 0.718350
... ...
Bi_Sb_S_S_Se 0.380460
Sb_Sb_Te_Se_S 0.491230
Sb_Sb_Se_Te_S 0.152100
Bi_Bi_S_S_S 0.393730
As_As_Te_Te_Se 0.046759
abs[[[X16D-X16C]*[Z11E/Z11D]]-[log[X16A]*abs[X16A-X16C]]] \
compound
Sb_Sb_Te_Te_Te 0.050249
As_Bi_Te_Te_S 0.046759
Bi_Bi_Te_Se_Se 0.359690
Bi_Sb_Te_S_Te 1.424700
Bi_As_Te_Se_Te 0.587340
... ...
Bi_Sb_S_S_Se 0.393730
Sb_Sb_Te_Se_S 0.152100
Sb_Sb_Se_Te_S 0.491230
Bi_Bi_S_S_S 0.393730
As_As_Te_Te_Se 0.046759
abs[[[X16C-X16D]*[Z11E/Z11C]]-[log[X16B]*abs[X16B-X16D]]] \
compound
Sb_Sb_Te_Te_Te 0.050249
As_Bi_Te_Te_S 0.070310
Bi_Bi_Te_Se_Se 0.653800
Bi_Sb_Te_S_Te 0.840460
Bi_As_Te_Se_Te 0.718350
... ...
Bi_Sb_S_S_Se 0.380460
Sb_Sb_Te_Se_S 0.491230
Sb_Sb_Se_Te_S 0.152100
Bi_Bi_S_S_S 0.393730
As_As_Te_Te_Se 0.046759
[[[Z11D/X16A]-[Z11E/X16D]]*[[Z11C-Z11E]/abs[X16A-X16E]]] \
compound
Sb_Sb_Te_Te_Te 0.00
As_Bi_Te_Te_S 1467.50
Bi_Bi_Te_Se_Se 118.81
Bi_Sb_Te_S_Te 0.00
Bi_As_Te_Se_Te 0.00
... ...
Bi_Sb_S_S_Se 178.56
Sb_Sb_Te_Se_S 700.36
Sb_Sb_Se_Te_S 605.16
Bi_Bi_S_S_S 0.00
As_As_Te_Te_Se 380.21
Class
compound
Sb_Sb_Te_Te_Te 0.0
As_Bi_Te_Te_S 0.0
Bi_Bi_Te_Se_Se 0.0
Bi_Sb_Te_S_Te 0.0
Bi_As_Te_Se_Te 0.0
... ...
Bi_Sb_S_S_Se 1.0
Sb_Sb_Te_Se_S 1.0
Sb_Sb_Se_Te_S 1.0
Bi_Bi_S_S_S 1.0
As_As_Te_Te_Se 1.0
[152 rows x 160001 columns]
%% Cell type:code id: tags:
``` python
sisso.fit()
```
%% Cell type:code id: tags:
``` python
model = sisso.models[1][0]
```
%% Cell type:code id: tags:
``` python
feat_1=model.feats[0].value
feat_0=model.feats[1].value
compounds=df_train.index.to_list()
classified=np.concatenate([np.ones(67),np.zeros(85)])
classified=model.prop_train
```
%% Cell type:code id: tags:
``` python
df=pd.DataFrame(data={"Compound":compounds,
"Classification":classified,
"Feat_0":feat_0,
"Feat_1":feat_1})
```
%% Cell type:code id: tags:
``` python
Visualizer(df).show()
Visualizer(df, sisso).show()
```
%%%% Output: stream
%%%% Output: display_data
hi
%%%% Output: display_data
%% Cell type:code id: tags:
``` python
```
......
......@@ -5,12 +5,12 @@ import numpy as np
from IPython.display import display, HTML, FileLink
import os
import pandas as pd
from scipy.spatial import ConvexHull
class Visualizer:
def __init__(self, df_selected, sisso):
def __init__(self, df_selected):
self.sisso = sisso
self.df_selected = df_selected
self.marker_size = 7
self.marker_symbol_cls0 = 'circle'
......@@ -25,7 +25,8 @@ class Visualizer:
]
self.font_size = 12
self.cross_size = 15
self.line_width = 1
self.hullsline_width = 1
self.clsline_width = 1
self.font_families = ['Source Sans Pro',
'Helvetica',
'Open Sans',
......@@ -59,9 +60,9 @@ class Visualizer:
# self.df_selected[feat] = values
# self.target_predict = sisso.models[sisso.n_dim - 1][0].fit
# self.target_train = sisso.models[sisso.n_dim - 1][0].prop_train
#
# self.coefficients = list(reversed(sisso.models[sisso.n_dim - 1][0].coefs[0][:-1]))
# self.intercept = sisso.models[sisso.n_dim - 1][0].coefs[0][-1]
self.coefficients = list(reversed(sisso.models[sisso.n_dim - 1][0].coefs[0][:-1]))
self.intercept = sisso.models[sisso.n_dim - 1][0].coefs[0][-1]
self.total_features = 2
self.features = ['Feat_0', 'Feat_1']
......@@ -88,67 +89,20 @@ class Visualizer:
self.viewer_l = JsmolView()
self.viewer_r = JsmolView()
self.instantiate_widgets()
# line_x, line_y = self.f_x(self.features[0], self.features[1])
# self.line_x = line_x
# self.line_y = line_y
print('hi')
x_hullvx_cls0 = [0.68060E+04,
0.68060E+04,
0.69360E+04,
0.83000E+04,
0.83000E+04,
0.94640E+04,
0.17264E+05,
0.17264E+05,
0.17264E+05,
0.12272E+05,
0.94640E+04,
0.78000E+04,
0.72060E+04]
y_hullvx_cls0 = [
0.18347E+02,
0.18347E+02,
0.28195E+02,
0.66772E+02,
0.66772E+02,
0.72881E+02,
0.70416E+02,
0.32634E+02,
0.51485E+01,
0.35512E+01,
0.28624E+01,
0.28624E+01,
0.73945E+01]
x_hullvx_cls1 = [
0.32640E+04,
0.37760E+04,
0.53120E+04,
0.78000E+04,
0.78540E+04,
0.78540E+04,
0.68060E+04,
0.68060E+04,
0.61880E+04,
0.53120E+04,
0.48120E+04]
y_hullvx_cls1 = [
0.82732E+01,
0.10083E+03,
0.10083E+03,
0.72881E+02,
0.67551E+02,
0.67551E+02,
0.27366E+02,
0.27366E+02,
0.11541E+02,
0.88713E+01,
0.81101E+01,
]
line_x, line_y = self.f_x(self.features[0], self.features[1])
self.line_x = line_x
self.line_y = line_y
# Design of the convex hulls
hull_cls0 = ConvexHull(df_selected[self.df_cls0][[self.features[0], self.features[1]]].to_numpy())
vertexes_cls0 = df_selected[self.df_cls0][[self.features[0], self.features[1]]].to_numpy()[hull_cls0.vertices]
hull_cls1 = ConvexHull(df_selected[self.df_cls1][[self.features[0], self.features[1]]].to_numpy())
vertexes_cls1 = df_selected[self.df_cls1][[self.features[0], self.features[1]]].to_numpy()[hull_cls1.vertices]
x_hullvx_cls0 = vertexes_cls0[:, 0]
y_hullvx_cls0 = vertexes_cls0[:, 1]
x_hullvx_cls1 = vertexes_cls1[:, 0]
y_hullvx_cls1 = vertexes_cls1[:, 1]
n_intervals = 100
self.xhull_cls0 = np.array([x_hullvx_cls0[0]])
self.yhull_cls0 = np.array([y_hullvx_cls0[0]])
for xy in zip(x_hullvx_cls0, y_hullvx_cls0):
......@@ -156,7 +110,6 @@ class Visualizer:
self.yhull_cls0 = np.concatenate([self.yhull_cls0, np.linspace(self.yhull_cls0[-1], xy[1], n_intervals)])
self.xhull_cls0 = np.concatenate([self.xhull_cls0, np.linspace(self.xhull_cls0[-1], x_hullvx_cls0[0], n_intervals)])
self.yhull_cls0 = np.concatenate([self.yhull_cls0, np.linspace(self.yhull_cls0[-1], y_hullvx_cls0[0], n_intervals)])
self.xhull_cls1 = np.array([x_hullvx_cls1[0]])
self.yhull_cls1 = np.array([y_hullvx_cls1[0]])
for xy in zip(x_hullvx_cls1, y_hullvx_cls1):
......@@ -231,6 +184,17 @@ class Visualizer:
name=r'Convex' + '<br>' + 'hull 1',
visible=True
),
)
self.fig.add_trace(
go.Scatter(
x=self.line_x,
y=self.line_y,
line=dict(color='Black', width=1, dash='solid'),
name=r'Classification' + '<br>' + 'line',
visible=True
),
)
x_min = min(min(x_cls0), min(x_cls1))
y_min = min(min(y_cls0), min(y_cls1))
......@@ -268,13 +232,14 @@ class Visualizer:
self.scatter_cls1 = self.fig.data[1]
self.scatter_hull0 = self.fig.data[2]
self.scatter_hull1 = self.fig.data[3]
self.scatter_clsline = self.fig.data[4]
if self.total_features == 2:
self.scatter_hull0.visible = True
self.scatter_hull1.visible = True
else:
self.widg_linewidth.disabled = True
self.widg_linestyle.disabled = True
self.widg_hullslinewidth.disabled = True
self.widg_hullslinestyle.disabled = True
self.update_markers()
......@@ -497,7 +462,7 @@ class Visualizer:
structure_l = self.df_selected[self.df_selected['Chem Formula'] ==
self.widg_compound_text_l.value]['Structure'].values[0]
self.viewer_l.script(
"load data/insulators_discovery/structures/" + structure_l + "_structures/"
"load data/topological_insulators/structures/" + structure_l + "_structures/"
+ self.widg_compound_text_l.value + ".xyz")
symbols_cls0 = self.symbols_cls0
......@@ -531,7 +496,7 @@ class Visualizer:
structure_r = self.df_selected[self.df_selected['Chem Formula'] ==
self.widg_compound_text_r.value]['Structure'].values[0]
self.viewer_r.script(
"load data/insulators_discovery/structures/" + structure_r + "_structures/"
"load data/topological_insulators/structures/" + structure_r + "_structures/"
+ self.widg_compound_text_r.value + ".xyz")
symbols_cls0 = self.symbols_cls0
......@@ -605,19 +570,30 @@ class Visualizer:
self.set_markers_size(feature=self.widg_featmarker.value)
self.update_markers()
def handle_linewidth_change(self, change):
def handle_hullslinewidth_change(self, change):
self.line_width = change.new
self.hullsline_width = change.new
with self.fig.batch_update():
self.scatter_hull0.line.width = change.new
self.scatter_hull1.line.width = change.new
def handle_linestyle_change(self, change):
def handle_hullslinestyle_change(self, change):
with self.fig.batch_update():
self.scatter_hull0.line.dash = change.new
self.scatter_hull1.line.dash = change.new
def handle_clslinewidth_change(self, change):
self.clsline_width = change.new
with self.fig.batch_update():
self.scatter_clsline.line.width = change.new
def handle_clslinestyle_change(self, change):
with self.fig.batch_update():
self.scatter_clsline.line.dash = change.new
def handle_markersymbol_cls0_change(self, change):
for i, e in enumerate(self.symbols_cls0):
......@@ -657,7 +633,7 @@ class Visualizer:
text = "A download link will appear soon."
with self.widg_print_out:
print(text)
path = "./data/insulators_discovery/plots/"
path = "./data/topological_insulators/plots/"
try:
os.mkdir(path)
except:
......@@ -702,16 +678,16 @@ class Visualizer:
self.widg_checkbox_l.value = True
def view_structure_cls0_l(self, formula):
self.viewer_l.script("load data/insulators_discovery/structures/RS_structures/" + formula + ".xyz")
self.viewer_l.script("load data/topological_insulators/structures/RS_structures/" + formula + ".xyz")
def view_structure_cls0_r(self, formula):
self.viewer_r.script("load data/insulators_discovery/structures/RS_structures/" + formula + ".xyz")
self.viewer_r.script("load data/topological_insulators/structures/RS_structures/" + formula + ".xyz")
def view_structure_cls1_l(self, formula):
self.viewer_l.script("load data/insulators_discovery/structures/ZB_structures/" + formula + ".xyz")
self.viewer_l.script("load data/topological_insulators/structures/ZB_structures/" + formula + ".xyz")
def view_structure_cls1_r(self, formula):
self.viewer_r.script("load data/insulators_discovery/structures/ZB_structures/" + formula + ".xyz")
self.viewer_r.script("load data/topological_insulators/structures/ZB_structures/" + formula + ".xyz")
def update_point_cls0(self, trace, points, selector):
# changes the points labeled with a cross on the map.
......@@ -823,12 +799,14 @@ class Visualizer:
self.widg_reset_button.on_click(self.reset_button_clicked)
self.widg_print_button.on_click(self.print_button_clicked)
self.widg_bgtoggle_button.on_click(self.bgtoggle_button_clicked)
self.widg_linestyle.observe(self.handle_linestyle_change, names='value')
self.scatter_cls0.on_click(self.update_point_cls0)
self.scatter_cls1.on_click(self.update_point_cls1)
self.widg_markersize.observe(self.handle_markersize_change, names='value')
self.widg_crosssize.observe(self.handle_crossize_change, names='value')
self.widg_linewidth.observe(self.handle_linewidth_change, names='value')
self.widg_hullslinewidth.observe(self.handle_hullslinewidth_change, names='value')
self.widg_hullslinestyle.observe(self.handle_hullslinestyle_change, names='value')
self.widg_clslinewidth.observe(self.handle_clslinewidth_change, names='value')
self.widg_clslinestyle.observe(self.handle_clslinestyle_change, names='value')
self.widg_fontfamily.observe(self.handle_fontfamily_change, names='value')
self.widg_fontsize.observe(self.handle_fontsize_change, names='value')
self.widg_plotutils_button.on_click(self.plotappearance_button_clicked)
......@@ -941,16 +919,28 @@ class Visualizer:
value=str(self.font_size),
layout = widgets.Layout(left='30px', width='200px')
)
self.widg_linewidth = widgets.BoundedIntText(
placeholder=str(self.line_width),
self.widg_hullslinewidth = widgets.BoundedIntText(
placeholder=str(self.hullsline_width),
description='Hulls width',
value=str(self.hullsline_width),
layout=widgets.Layout(left='30px', width='200px')
)
self.widg_hullslinestyle = widgets.Dropdown(
options=self.line_styles,
description='Hulls style',
value=self.line_styles[0],
layout=widgets.Layout(left='30px', width='200px')
)
self.widg_clslinewidth = widgets.BoundedIntText(
placeholder=str(self.clsline_width),
description='Line width',
value=str(self.line_width),
layout = widgets.Layout(left='30px', width='200px')
value=str(self.clsline_width),
layout=widgets.Layout(left='30px', width='200px')
)
self.widg_linestyle = widgets.Dropdown(
self.widg_clslinestyle = widgets.Dropdown(
options=self.line_styles,
description='Line style',
value=self.line_styles[0],
value='solid',
layout=widgets.Layout(left='30px', width='200px')
)
self.widg_fontfamily = widgets.Dropdown(
......@@ -959,7 +949,6 @@ class Visualizer:
value=self.font_families[0],
layout=widgets.Layout(left='30px', width='200px')
)
self.widg_bgcolor = widgets.Text(
placeholder=str(self.bg_color),
description='Background',
......@@ -969,24 +958,24 @@ class Visualizer: