deepof_model_evaluation.ipynb 160 KB
Newer Older
1
2
3
4
{
 "cells": [
  {
   "cell_type": "code",
5
   "execution_count": 1,
6
   "metadata": {},
7
   "outputs": [],
8
9
10
11
12
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
13
14
  {
   "cell_type": "code",
15
   "execution_count": 2,
16
17
18
19
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
20
    "\n",
21
22
23
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# deepOF model evaluation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Given a dataset and a trained model, this notebook allows the user to:\n",
    "\n",
    "* Load and inspect the different models (encoder, decoder, grouper, gmvaep)\n",
    "* Visualize reconstruction quality for a given model\n",
    "* Visualize a static latent space\n",
    "* Visualize trajectories on the latent space for a given video\n",
    "* Sample from the latent space distributions and generate video clips showcasing the generated data"
   ]
  },
  {
   "cell_type": "code",
46
   "execution_count": 3,
47
48
49
50
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
51
    "\n",
52
53
54
55
56
    "os.chdir(os.path.dirname(\"../\"))"
   ]
  },
  {
   "cell_type": "code",
57
   "execution_count": 4,
58
59
60
61
62
63
64
   "metadata": {},
   "outputs": [],
   "source": [
    "import deepof.data\n",
    "import deepof.utils\n",
    "import numpy as np\n",
    "import pandas as pd\n",
65
    "import re\n",
66
    "import tensorflow as tf\n",
67
    "from collections import Counter\n",
68
69
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
70
71
72
73
74
    "from sklearn.manifold import TSNE\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.discriminant_analysis import LinearDiscriminantAnalysis\n",
    "import umap\n",
    "\n",
75
    "from ipywidgets import interactive, interact, HBox, Layout, VBox\n",
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
    "from IPython import display\n",
    "from matplotlib.animation import FuncAnimation\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "from ipywidgets import interact"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1. Define and run the project"
   ]
  },
  {
   "cell_type": "code",
93
   "execution_count": 76,
94
95
96
   "metadata": {},
   "outputs": [],
   "source": [
97
    "path = os.path.join(\"..\", \"..\", \"Desktop\", \"deepof-data\", \"deepof_single_topview\")\n",
98
    "trained_network = os.path.join(\"..\", \"..\", \"Desktop\", \"deepof_trained_weights_280521\", \"var_annealing\")\n",
99
    "exclude_bodyparts = tuple([\"\"])\n",
100
101
    "window_size = 22\n",
    "batch_size = 64"
102
103
104
105
   ]
  },
  {
   "cell_type": "code",
106
   "execution_count": 77,
107
   "metadata": {},
108
109
110
111
112
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
113
114
      "CPU times: user 41.9 s, sys: 3.25 s, total: 45.2 s\n",
      "Wall time: 37.5 s\n"
115
116
117
     ]
    }
   ],
118
119
120
   "source": [
    "%%time\n",
    "proj = deepof.data.project(\n",
121
    "    path=path, smooth_alpha=0.999, exclude_bodyparts=exclude_bodyparts, arena_dims=[380],\n",
122
123
124
125
126
    ")"
   ]
  },
  {
   "cell_type": "code",
127
   "execution_count": 78,
128
129
130
   "metadata": {
    "scrolled": true
   },
131
132
133
134
135
136
137
138
139
140
141
142
143
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading trajectories...\n",
      "Smoothing trajectories...\n",
      "Interpolating outliers...\n",
      "Iterative imputation of ocluded bodyparts...\n",
      "Computing distances...\n",
      "Computing angles...\n",
      "Done!\n",
      "deepof analysis of 166 videos\n",
144
145
      "CPU times: user 9min 14s, sys: 11.8 s, total: 9min 26s\n",
      "Wall time: 2min 3s\n"
146
147
148
     ]
    }
   ],
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
   "source": [
    "%%time\n",
    "proj = proj.run(verbose=True)\n",
    "print(proj)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. Load a pretrained deepOF model"
   ]
  },
  {
   "cell_type": "code",
164
   "execution_count": 79,
165
166
167
   "metadata": {},
   "outputs": [],
   "source": [
168
    "coords = proj.get_coords(center=\"Center\", align=\"Spine_1\", align_inplace=True)\n",
169
    "data_prep = coords.preprocess(test_videos=0, window_step=1, window_size=window_size, shuffle=False)[\n",
170
171
    "    0\n",
    "]"
172
173
174
175
   ]
  },
  {
   "cell_type": "code",
176
   "execution_count": 80,
177
   "metadata": {},
178
179
180
181
   "outputs": [
    {
     "data": {
      "text/plain": [
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
       "['GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=25_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=4_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=25_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=6_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=5_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=1_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=25_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=8_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=20_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=8_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=15_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=5_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=5_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=5_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=10_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=8_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=25_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=7_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=25_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=6_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=15_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=4_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=10_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=4_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=15_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=6_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=10_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=7_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=10_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=5_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=20_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=1_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=20_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=3_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=25_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=10_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=5_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=9_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=15_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=1_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=20_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=4_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=15_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=9_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=15_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=3_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=10_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=6_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=5_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=1_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=25_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=7_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=5_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=8_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=20_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=7_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=10_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=5_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=20_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=5_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=10_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=10_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=20_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=10_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=10_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=1_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=20_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=7_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=5_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=1_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=15_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=5_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=25_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=1_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=25_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=3_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=10_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=1_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=20_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=10_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=20_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=3_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=10_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=3_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=20_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=6_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=5_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=1_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=15_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=4_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=5_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=3_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=25_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=10_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=20_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=8_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=15_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=2_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=5_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=2_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=15_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=8_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=5_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=5_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=20_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=8_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=5_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=7_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=25_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=8_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=20_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=2_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=25_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=4_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=10_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=9_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=15_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=6_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=5_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=6_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=10_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=8_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=15_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=9_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=5_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=5_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=15_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=1_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=20_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=3_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=5_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=10_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=25_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=9_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=25_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=8_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=10_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=5_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=10_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=7_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=20_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=10_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=10_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=5_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=20_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=7_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=25_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=7_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=25_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=5_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=25_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=9_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=5_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=10_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=15_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=7_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=15_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=5_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=5_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=4_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=20_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=7_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=10_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=3_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=10_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=1_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=20_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=5_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=20_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=3_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=10_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=1_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=20_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=1_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=20_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=9_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=5_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=5_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=25_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=3_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=10_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=2_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=25_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=10_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=15_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=1_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=15_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=10_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=5_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=1_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=5_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=5_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=25_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=2_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=15_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=2_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=25_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=2_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=20_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=8_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=25_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=3_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=20_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=4_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=10_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=4_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=5_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=9_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=20_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=6_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=10_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=7_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=25_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=6_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=15_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=8_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=20_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=5_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=5_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=8_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=15_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=2_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=10_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=4_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=10_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=6_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=20_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=2_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=5_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=1_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=10_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=5_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=15_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=7_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=15_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=5_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=25_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=6_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=25_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=7_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=5_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=4_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=15_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=4_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=25_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=10_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=20_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=9_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=10_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=9_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=25_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=7_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=25_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=5_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=25_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=9_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=20_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=10_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=5_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=3_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=15_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=3_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=15_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=9_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=20_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=9_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=5_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=2_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=15_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=5_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=5_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=10_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=25_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=1_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=20_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=7_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=20_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=2_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=15_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=10_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=10_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=2_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=25_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=2_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=15_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=4_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=20_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=6_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=5_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=10_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=25_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=6_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=25_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=8_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=25_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=4_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=10_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=6_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=20_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=6_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=10_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=4_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=25_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=8_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=20_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=2_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=10_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=4_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=25_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=9_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=15_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=8_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=5_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=4_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=10_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=9_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=5_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=7_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=15_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=7_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=20_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=3_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=25_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=9_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=10_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=8_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=10_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=10_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=25_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=5_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=20_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=9_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=5_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=6_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=10_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=10_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=25_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=10_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=15_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=1_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=5_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=4_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=20_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=9_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=25_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=3_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=25_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=2_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=15_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=10_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=10_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=1_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=15_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=3_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=20_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=10_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=15_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=1_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=5_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=4_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=25_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=3_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=25_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=1_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=5_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=10_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=25_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=2_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=20_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=8_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=5_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=4_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=10_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=3_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=20_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=2_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=20_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=6_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=5_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=5_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=20_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=4_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=10_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=2_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=15_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=6_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=20_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=1_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=15_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=4_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=5_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=9_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=20_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=8_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=15_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=8_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=20_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=4_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=5_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=8_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=25_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=2_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=5_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=1_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=5_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=3_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=15_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=6_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=15_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=10_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=5_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=8_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=25_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=1_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=25_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=3_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=10_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=1_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=20_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=1_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=10_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=3_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=20_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=5_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=10_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=3_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=5_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=3_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=15_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=7_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=5_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=3_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=25_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=5_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=10_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=4_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=25_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=9_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=10_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=7_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=5_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=8_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=20_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=7_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=20_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=5_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=5_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=2_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=20_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=1_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=10_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=10_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=10_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=7_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=20_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=3_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=10_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=5_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=15_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=1_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=15_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=3_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=20_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=6_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=20_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=10_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=25_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=4_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=10_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=8_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=25_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=5_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=15_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=4_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=5_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=8_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=10_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=6_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=15_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=6_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=5_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=3_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=25_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=4_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=25_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=8_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=25_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=6_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=10_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=8_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=5_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=7_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=15_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=7_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=15_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=8_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=25_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=10_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=10_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=2_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=15_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=2_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=25_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=1_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=5_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=3_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=15_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=3_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=25_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=2_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=5_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=7_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=10_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=3_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=15_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=9_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=20_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=3_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=20_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=1_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=5_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=7_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=25_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=1_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=20_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=2_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=15_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=7_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=15_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=5_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=10_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=3_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=20_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=7_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=10_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=1_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=5_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=6_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=20_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=5_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=15_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=8_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=20_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=5_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=5_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=8_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=10_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=7_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=10_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=5_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=25_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=9_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=25_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=7_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=25_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=5_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=15_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=9_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=10_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=9_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=5_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=7_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=15_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=3_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=10_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=7_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=20_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=1_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=5_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=4_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=15_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=4_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=20_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=9_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=10_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=8_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=5_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=5_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=15_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=2_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=20_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=8_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=5_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=7_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=10_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=8_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=25_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=6_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=15_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=10_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=15_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=6_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=5_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=2_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=25_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=10_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=20_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=4_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=10_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=2_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=10_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=2_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=5_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=9_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=25_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=2_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=5_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=2_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=15_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=7_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=5_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=10_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=20_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=5_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=5_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=9_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=25_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=3_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=15_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=9_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=15_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=1_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=5_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=1_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=5_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=8_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=15_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=10_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=20_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=9_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=15_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=6_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=5_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=6_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=10_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=9_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=5_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=2_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=10_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=10_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=25_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=7_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=25_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=9_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=25_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=5_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=10_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=10_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=15_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=7_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=15_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=5_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=5_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=9_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=10_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=7_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=25_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=5_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=25_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=4_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=10_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=9_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=15_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=2_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=20_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=7_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=5_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=3_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=10_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=4_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=20_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=2_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=10_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=6_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=20_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=4_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=5_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=9_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=20_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=6_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=10_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=6_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=5_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=10_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=25_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=4_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=5_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=2_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=25_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=8_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=10_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=5_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=20_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=6_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=10_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=2_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=20_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=4_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=5_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=7_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=20_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=3_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=15_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=6_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=15_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=10_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=15_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=4_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=10_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=1_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=5_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=6_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=15_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=8_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=20_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=2_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=10_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=2_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=25_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=3_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=25_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=1_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=5_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=6_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=15_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=2_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=5_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=2_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=15_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=10_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=25_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=1_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=15_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=3_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=15_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=1_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=10_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=3_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=15_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=9_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=20_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=10_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=25_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=7_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=20_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=1_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=10_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=10_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=10_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=9_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=15_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=3_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=20_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=9_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=5_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=6_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=5_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=4_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=25_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=10_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=10_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=9_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=20_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=10_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=20_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=8_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=15_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=5_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=5_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=5_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=10_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=10_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=10_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=6_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=15_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=2_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=10_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=8_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=5_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=6_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=15_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=8_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=25_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=6_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=25_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=8_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=25_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=4_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=5_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=10_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=10_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=6_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=5_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=9_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=20_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=4_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=15_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=9_final_weights.h5',\n",
       " 'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.15_PhenoPred=0.0_RuleBasedPred=0.15_loss=ELBO_loss_warmup=10_warmup_mode=sigmoid_encoding=6_k=15_latreg=variance_entknn=100_run=4_final_weights.h5']"
582
583
      ]
     },
584
     "execution_count": 80,
585
586
587
588
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
589
590
591
592
593
594
   "source": [
    "[i for i in os.listdir(trained_network) if i.endswith(\"h5\")]"
   ]
  },
  {
   "cell_type": "code",
595
   "execution_count": 81,
596
   "metadata": {},
597
598
599
600
   "outputs": [
    {
     "data": {
      "text/plain": [
601
       "'GMVAE_input_type=coords_window_size=22_NextSeqPred=0.0_PhenoPred=0.0_RuleBasedPred=0.0_loss=ELBO_loss_warmup=25_warmup_mode=linear_encoding=6_k=15_latreg=variance_entknn=100_run=6_final_weights.h5'"
602
603
      ]
     },
604
     "execution_count": 81,
605
606
607
608
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
609
   "source": [
610
    "# os.listdir returns entries in arbitrary order, so sort the candidate weight\n",
    "# files before indexing; otherwise the selected model is not reproducible.\n",
    "deepof_weights = sorted(\n",
    "    i for i in os.listdir(trained_network) if i.endswith(\"h5\")\n",
    ")[1]\n",
    "deepof_weights"
   ]
  },
  {
   "cell_type": "code",
616
   "execution_count": 82,
617
618
619
   "metadata": {},
   "outputs": [],
   "source": [
620
621
622
623
    "# Set model parameters by parsing the hyperparameters encoded in the weights file name.\n",
    "# Patterns are raw strings so the backslash in \\d reaches the regex engine instead of\n",
    "# being interpreted as an (invalid) Python string escape sequence.\n",
    "encoding = int(re.findall(r\"encoding=(\\d+)_\", deepof_weights)[0])\n",
    "k = int(re.findall(r\"k=(\\d+)_\", deepof_weights)[0])\n",
    "# Non-greedy match stops at the first underscore (e.g. 'ELBO' from 'loss=ELBO_loss_warmup=...').\n",
    "loss = re.findall(r\"loss=(.+?)_\", deepof_weights)[0]\n",
    "# Weights of the auxiliary prediction terms, stored as floats in the file name.\n",
    "NextSeqPred = float(re.findall(r\"NextSeqPred=(.+?)_\", deepof_weights)[0])\n",
    "PhenoPred = float(re.findall(r\"PhenoPred=(.+?)_\", deepof_weights)[0])\n",
    "RuleBasedPred = float(re.findall(r\"RuleBasedPred=(.+?)_\", deepof_weights)[0])"
627
628
629
630
   ]
  },
  {
   "cell_type": "code",
631
   "execution_count": 164,
632
   "metadata": {},
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
   "outputs": [
    {
     "ename": "TypeError",
     "evalue": "in user code:\n\n    /Users/lucas_miranda/PycharmProjects/deepof/deepof/model_utils.py:522 call  *\n        number_of_clusters = tf.cast(\n    /Users/lucas_miranda/opt/anaconda3/envs/Machine_Learning/lib/python3.6/site-packages/tensorflow/python/util/dispatch.py:201 wrapper  **\n        return target(*args, **kwargs)\n\n    TypeError: expand_dims_v2() missing 1 required positional argument: 'axis'\n",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-164-2a49287e7c66>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     16\u001b[0m     \u001b[0mrule_based_prediction\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mRuleBasedPred\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     17\u001b[0m \u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbuild\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 18\u001b[0;31m     \u001b[0mdata_prep\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     19\u001b[0m )\n\u001b[1;32m     20\u001b[0m \u001b[0mgmvaep\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload_weights\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrained_network\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdeepof_weights\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/PycharmProjects/deepof/deepof/models.py\u001b[0m in \u001b[0;36mbuild\u001b[0;34m(self, input_shape)\u001b[0m\n\u001b[1;32m    480\u001b[0m                 \u001b[0mk\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumber_of_components\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    481\u001b[0m                 \u001b[0mloss_weight\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moverlap_loss\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 482\u001b[0;31m             )([z, z_cat])\n\u001b[0m\u001b[1;32m    483\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    484\u001b[0m         \u001b[0;31m# Define and instantiate generator\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/opt/anaconda3/envs/Machine_Learning/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m    950\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0m_in_functional_construction_mode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput_list\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    951\u001b[0m       return self._functional_construction_call(inputs, args, kwargs,\n\u001b[0;32m--> 952\u001b[0;31m                                                 input_list)\n\u001b[0m\u001b[1;32m    953\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    954\u001b[0m     \u001b[0;31m# Maintains info about the `Layer.call` stack.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/opt/anaconda3/envs/Machine_Learning/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer.py\u001b[0m in \u001b[0;36m_functional_construction_call\u001b[0;34m(self, inputs, args, kwargs, input_list)\u001b[0m\n\u001b[1;32m   1089\u001b[0m         \u001b[0;31m# Check input assumptions set after layer building, e.g. input shape.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1090\u001b[0m         outputs = self._keras_tensor_symbolic_call(\n\u001b[0;32m-> 1091\u001b[0;31m             inputs, input_masks, args, kwargs)\n\u001b[0m\u001b[1;32m   1092\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1093\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0moutputs\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/opt/anaconda3/envs/Machine_Learning/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer.py\u001b[0m in \u001b[0;36m_keras_tensor_symbolic_call\u001b[0;34m(self, inputs, input_masks, args, kwargs)\u001b[0m\n\u001b[1;32m    820\u001b[0m       \u001b[0;32mreturn\u001b[0m \u001b[0mnest\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmap_structure\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkeras_tensor\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mKerasTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutput_signature\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    821\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 822\u001b[0;31m       \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_infer_output_signature\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput_masks\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    823\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    824\u001b[0m   \u001b[0;32mdef\u001b[0m \u001b[0m_infer_output_signature\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput_masks\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/opt/anaconda3/envs/Machine_Learning/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer.py\u001b[0m in \u001b[0;36m_infer_output_signature\u001b[0;34m(self, inputs, args, kwargs, input_masks)\u001b[0m\n\u001b[1;32m    861\u001b[0m           \u001b[0;31m# TODO(kaftan): do we maybe_build here, or have we already done it?\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    862\u001b[0m           \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_maybe_build\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 863\u001b[0;31m           \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcall_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    864\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    865\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_handle_activity_regularization\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/opt/anaconda3/envs/Machine_Learning/lib/python3.6/site-packages/tensorflow/python/autograph/impl/api.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    668\u001b[0m       \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m  \u001b[0;31m# pylint:disable=broad-except\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    669\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mhasattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'ag_error_metadata'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 670\u001b[0;31m           \u001b[0;32mraise\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mag_error_metadata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto_exception\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    671\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    672\u001b[0m           \u001b[0;32mraise\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mTypeError\u001b[0m: in user code:\n\n    /Users/lucas_miranda/PycharmProjects/deepof/deepof/model_utils.py:522 call  *\n        number_of_clusters = tf.cast(\n    /Users/lucas_miranda/opt/anaconda3/envs/Machine_Learning/lib/python3.6/site-packages/tensorflow/python/util/dispatch.py:201 wrapper  **\n        return target(*args, **kwargs)\n\n    TypeError: expand_dims_v2() missing 1 required positional argument: 'axis'\n"
     ]
    }
   ],
652
   "source": [
653
    "(\n",
654
    "    encoder,\n",
655
656
657
658
659
    "    decoder,\n",
    "    grouper,\n",
    "    gmvaep,\n",
    "    prior,\n",
    "    posterior,\n",
660
    ") = deepof.models.GMVAE(\n",
661
662
663
    "    loss=loss,\n",
    "    number_of_components=k,\n",
    "    compile_model=True,\n",
664
    "    batch_size=batch_size,\n",
665
    "    encoding=encoding,\n",
666
667
668
    "    next_sequence_prediction=NextSeqPred,\n",
    "    phenotype_prediction=PhenoPred,\n",
    "    rule_based_prediction=RuleBasedPred,\n",
669
670
671
    ").build(\n",
    "    data_prep.shape\n",
    ")\n",
672
    "gmvaep.load_weights(os.path.join(trained_network, deepof_weights))"
673
674
675
676
   ]
  },
  {
   "cell_type": "code",
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
   "execution_count": 162,
   "metadata": {},
   "outputs": [
    {
     "ename": "ValueError",
     "evalue": "in user code:\n\n    /Users/lucas_miranda/opt/anaconda3/envs/Machine_Learning/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py:1478 predict_function  *\n        return step_function(self, iterator)\n    /Users/lucas_miranda/PycharmProjects/deepof/deepof/model_utils.py:520 call  *\n        neighbourhood_entropy = purity_vector * max_groups\n    /Users/lucas_miranda/opt/anaconda3/envs/Machine_Learning/lib/python3.6/site-packages/tensorflow/python/ops/math_ops.py:1180 binary_op_wrapper\n        raise e\n    /Users/lucas_miranda/opt/anaconda3/envs/Machine_Learning/lib/python3.6/site-packages/tensorflow/python/ops/math_ops.py:1164 binary_op_wrapper\n        return func(x, y, name=name)\n    /Users/lucas_miranda/opt/anaconda3/envs/Machine_Learning/lib/python3.6/site-packages/tensorflow/python/ops/math_ops.py:1496 _mul_dispatch\n        return multiply(x, y, name=name)\n    /Users/lucas_miranda/opt/anaconda3/envs/Machine_Learning/lib/python3.6/site-packages/tensorflow/python/util/dispatch.py:201 wrapper\n        return target(*args, **kwargs)\n    /Users/lucas_miranda/opt/anaconda3/envs/Machine_Learning/lib/python3.6/site-packages/tensorflow/python/ops/math_ops.py:518 multiply\n        return gen_math_ops.mul(x, y, name)\n    /Users/lucas_miranda/opt/anaconda3/envs/Machine_Learning/lib/python3.6/site-packages/tensorflow/python/ops/gen_math_ops.py:6078 mul\n        \"Mul\", x=x, y=y, name=name)\n    /Users/lucas_miranda/opt/anaconda3/envs/Machine_Learning/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py:750 _apply_op_helper\n        attrs=attr_protos, op_def=op_def)\n    /Users/lucas_miranda/opt/anaconda3/envs/Machine_Learning/lib/python3.6/site-packages/tensorflow/python/framework/func_graph.py:592 _create_op_internal\n        compute_device)\n    /Users/lucas_miranda/opt/anaconda3/envs/Machine_Learning/lib/python3.6/site-packages/tensorflow/python/framework/ops.py:3536 _create_op_internal\n        
op_def=op_def)\n    /Users/lucas_miranda/opt/anaconda3/envs/Machine_Learning/lib/python3.6/site-packages/tensorflow/python/framework/ops.py:2016 __init__\n        control_input_ops, op_def)\n    /Users/lucas_miranda/opt/anaconda3/envs/Machine_Learning/lib/python3.6/site-packages/tensorflow/python/framework/ops.py:1856 _create_c_op\n        raise ValueError(str(e))\n\n    ValueError: Dimensions must be equal, but are 64 and 32 for '{{node SEQ_2_SEQ_GMVAE/cluster_overlap_37/mul}} = Mul[T=DT_FLOAT](SEQ_2_SEQ_GMVAE/cluster_overlap_37/map/TensorArrayV2Stack/TensorListStack, SEQ_2_SEQ_GMVAE/cluster_overlap_37/Max)' with input shapes: [64], [32].\n",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-162-f54898b44d12>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mgmvaep\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpredict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata_prep\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;36m32\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m~/opt/anaconda3/envs/Machine_Learning/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py\u001b[0m in \u001b[0;36mpredict\u001b[0;34m(self, x, batch_size, verbose, steps, callbacks, max_queue_size, workers, use_multiprocessing)\u001b[0m\n\u001b[1;32m   1627\u001b[0m           \u001b[0;32mfor\u001b[0m \u001b[0mstep\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata_handler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msteps\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1628\u001b[0m             \u001b[0mcallbacks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_predict_batch_begin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1629\u001b[0;31m             \u001b[0mtmp_batch_outputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpredict_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0miterator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1630\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mdata_handler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshould_sync\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1631\u001b[0m               \u001b[0mcontext\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0masync_wait\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/opt/anaconda3/envs/Machine_Learning/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwds)\u001b[0m\n\u001b[1;32m    826\u001b[0m     \u001b[0mtracing_count\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexperimental_get_tracing_count\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    827\u001b[0m     \u001b[0;32mwith\u001b[0m \u001b[0mtrace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTrace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_name\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mtm\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 828\u001b[0;31m       \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    829\u001b[0m       \u001b[0mcompiler\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m\"xla\"\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_experimental_compile\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;34m\"nonXla\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    830\u001b[0m       \u001b[0mnew_tracing_count\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexperimental_get_tracing_count\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/opt/anaconda3/envs/Machine_Learning/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py\u001b[0m in \u001b[0;36m_call\u001b[0;34m(self, *args, **kwds)\u001b[0m\n\u001b[1;32m    869\u001b[0m       \u001b[0;31m# This is the first call of __call__, so we have to initialize.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    870\u001b[0m       \u001b[0minitializers\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 871\u001b[0;31m       \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_initialize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwds\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0madd_initializers_to\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minitializers\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    872\u001b[0m     \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    873\u001b[0m       \u001b[0;31m# At this point we know that the initialization is complete (or less\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/opt/anaconda3/envs/Machine_Learning/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py\u001b[0m in \u001b[0;36m_initialize\u001b[0;34m(self, args, kwds, add_initializers_to)\u001b[0m\n\u001b[1;32m    724\u001b[0m     self._concrete_stateful_fn = (\n\u001b[1;32m    725\u001b[0m         self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access\n\u001b[0;32m--> 726\u001b[0;31m             *args, **kwds))\n\u001b[0m\u001b[1;32m    727\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    728\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0minvalid_creator_scope\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0munused_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0munused_kwds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/opt/anaconda3/envs/Machine_Learning/lib/python3.6/site-packages/tensorflow/python/eager/function.py\u001b[0m in \u001b[0;36m_get_concrete_function_internal_garbage_collected\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   2967\u001b[0m       \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2968\u001b[0m     \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_lock\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2969\u001b[0;31m       \u001b[0mgraph_function\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_maybe_define_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   2970\u001b[0m     \u001b[0;32mreturn\u001b[0m \u001b[0mgraph_function\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2971\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/opt/anaconda3/envs/Machine_Learning/lib/python3.6/site-packages/tensorflow/python/eager/function.py\u001b[0m in \u001b[0;36m_maybe_define_function\u001b[0;34m(self, args, kwargs)\u001b[0m\n\u001b[1;32m   3359\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   3360\u001b[0m           \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_function_cache\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmissed\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcall_context_key\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3361\u001b[0;31m           \u001b[0mgraph_function\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_create_graph_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   3362\u001b[0m           \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_function_cache\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprimary\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mcache_key\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgraph_function\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   3363\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/opt/anaconda3/envs/Machine_Learning/lib/python3.6/site-packages/tensorflow/python/eager/function.py\u001b[0m in \u001b[0;36m_create_graph_function\u001b[0;34m(self, args, kwargs, override_flat_arg_shapes)\u001b[0m\n\u001b[1;32m   3204\u001b[0m             \u001b[0marg_names\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0marg_names\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   3205\u001b[0m             \u001b[0moverride_flat_arg_shapes\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0moverride_flat_arg_shapes\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3206\u001b[0;31m             capture_by_value=self._capture_by_value),\n\u001b[0m\u001b[1;32m   3207\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_function_attributes\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   3208\u001b[0m         \u001b[0mfunction_spec\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfunction_spec\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/opt/anaconda3/envs/Machine_Learning/lib/python3.6/site-packages/tensorflow/python/framework/func_graph.py\u001b[0m in \u001b[0;36mfunc_graph_from_py_func\u001b[0;34m(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)\u001b[0m\n\u001b[1;32m    988\u001b[0m         \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moriginal_func\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf_decorator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munwrap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpython_func\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    989\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 990\u001b[0;31m       \u001b[0mfunc_outputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpython_func\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mfunc_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mfunc_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    991\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    992\u001b[0m       \u001b[0;31m# invariant: `func_outputs` contains only Tensors, CompositeTensors,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/opt/anaconda3/envs/Machine_Learning/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py\u001b[0m in \u001b[0;36mwrapped_fn\u001b[0;34m(*args, **kwds)\u001b[0m\n\u001b[1;32m    632\u001b[0m             \u001b[0mxla_context\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mExit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    633\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 634\u001b[0;31m           \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mweak_wrapped_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__wrapped__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    635\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    636\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/opt/anaconda3/envs/Machine_Learning/lib/python3.6/site-packages/tensorflow/python/framework/func_graph.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    975\u001b[0m           \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m  \u001b[0;31m# pylint:disable=broad-except\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    976\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mhasattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"ag_error_metadata\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 977\u001b[0;31m               \u001b[0;32mraise\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mag_error_metadata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto_exception\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    978\u001b[0m             \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    979\u001b[0m               \u001b[0;32mraise\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mValueError\u001b[0m: in user code:\n\n    /Users/lucas_miranda/opt/anaconda3/envs/Machine_Learning/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py:1478 predict_function  *\n        return step_function(self, iterator)\n    /Users/lucas_miranda/PycharmProjects/deepof/deepof/model_utils.py:520 call  *\n        neighbourhood_entropy = purity_vector * max_groups\n    /Users/lucas_miranda/opt/anaconda3/envs/Machine_Learning/lib/python3.6/site-packages/tensorflow/python/ops/math_ops.py:1180 binary_op_wrapper\n        raise e\n    /Users/lucas_miranda/opt/anaconda3/envs/Machine_Learning/lib/python3.6/site-packages/tensorflow/python/ops/math_ops.py:1164 binary_op_wrapper\n        return func(x, y, name=name)\n    /Users/lucas_miranda/opt/anaconda3/envs/Machine_Learning/lib/python3.6/site-packages/tensorflow/python/ops/math_ops.py:1496 _mul_dispatch\n        return multiply(x, y, name=name)\n    /Users/lucas_miranda/opt/anaconda3/envs/Machine_Learning/lib/python3.6/site-packages/tensorflow/python/util/dispatch.py:201 wrapper\n        return target(*args, **kwargs)\n    /Users/lucas_miranda/opt/anaconda3/envs/Machine_Learning/lib/python3.6/site-packages/tensorflow/python/ops/math_ops.py:518 multiply\n        return gen_math_ops.mul(x, y, name)\n    /Users/lucas_miranda/opt/anaconda3/envs/Machine_Learning/lib/python3.6/site-packages/tensorflow/python/ops/gen_math_ops.py:6078 mul\n        \"Mul\", x=x, y=y, name=name)\n    /Users/lucas_miranda/opt/anaconda3/envs/Machine_Learning/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py:750 _apply_op_helper\n        attrs=attr_protos, op_def=op_def)\n    /Users/lucas_miranda/opt/anaconda3/envs/Machine_Learning/lib/python3.6/site-packages/tensorflow/python/framework/func_graph.py:592 _create_op_internal\n        compute_device)\n    /Users/lucas_miranda/opt/anaconda3/envs/Machine_Learning/lib/python3.6/site-packages/tensorflow/python/framework/ops.py:3536 
_create_op_internal\n        op_def=op_def)\n    /Users/lucas_miranda/opt/anaconda3/envs/Machine_Learning/lib/python3.6/site-packages/tensorflow/python/framework/ops.py:2016 __init__\n        control_input_ops, op_def)\n    /Users/lucas_miranda/opt/anaconda3/envs/Machine_Learning/lib/python3.6/site-packages/tensorflow/python/framework/ops.py:1856 _create_c_op\n        raise ValueError(str(e))\n\n    ValueError: Dimensions must be equal, but are 64 and 32 for '{{node SEQ_2_SEQ_GMVAE/cluster_overlap_37/mul}} = Mul[T=DT_FLOAT](SEQ_2_SEQ_GMVAE/cluster_overlap_37/map/TensorArrayV2Stack/TensorListStack, SEQ_2_SEQ_GMVAE/cluster_overlap_37/Max)' with input shapes: [64], [32].\n"
     ]
    }
   ],
   "source": [
    "gmvaep.predict(data_prep[:32])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
709
710
711
712
713
714
715
716
717
718
719
720
721
722
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Uncomment to see model summaries\n",
    "# encoder.summary()\n",
    "# decoder.summary()\n",
    "# grouper.summary()\n",
    "# gmvaep.summary()"
   ]
  },
  {
   "cell_type": "code",
723
   "execution_count": 85,
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
   "metadata": {},
   "outputs": [],
   "source": [
    "# Uncomment to plot model structure\n",
    "def plot_model(model, name):\n",
    "    tf.keras.utils.plot_model(\n",
    "        model,\n",
    "        to_file=os.path.join(\n",
    "            path,\n",
    "            \"deepof_{}_{}.png\".format(name, datetime.now().strftime(\"%Y%m%d-%H%M%S\")),\n",
    "        ),\n",
    "        show_shapes=True,\n",
    "        show_dtype=False,\n",
    "        show_layer_names=True,\n",
    "        rankdir=\"TB\",\n",
    "        expand_nested=True,\n",
    "        dpi=200,\n",
    "    )\n",
    "\n",
    "\n",
    "# plot_model(encoder, \"encoder\")\n",
    "# plot_model(decoder, \"decoder\")\n",
    "# plot_model(grouper, \"grouper\")\n",
    "# plot_model(gmvaep, \"gmvaep\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
754
    "### 4. Evaluate reconstruction (to be incorporated into deepof.evaluate)"
755
756
757
758
   ]
  },
  {
   "cell_type": "code",
759
   "execution_count": 86,
760
761
762
   "metadata": {},
   "outputs": [],
   "source": [
763
764
765
766
767
768
769
770
771
772
773
774
    "# Auxiliary animation functions\n",
    "\n",
    "\n",
    "def plot_mouse_graph(instant_x, instant_y, instant_rec_x, instant_rec_y, ax, edges):\n",
    "    \"\"\"Generates a graph plot of the mouse\"\"\"\n",
    "    plots = []\n",
    "    rec_plots = []\n",
    "    for edge in edges:\n",
    "        (temp_plot,) = ax.plot(\n",
    "            [float(instant_x[edge[0]]), float(instant_x[edge[1]])],\n",
    "            [float(instant_y[edge[0]]), float(instant_y[edge[1]])],\n",
    "            color=\"#006699\",\n",
775
    "            linewidth=2.0,\n",
776
777
778
779
    "        )\n",
    "        (temp_rec_plot,) = ax.plot(\n",
    "            [float(instant_rec_x[edge[0]]), float(instant_rec_x[edge[1]])],\n",
    "            [float(instant_rec_y[edge[0]]), float(instant_rec_y[edge[1]])],\n",
780
781
    "            color=\"red\",\n",
    "            linewidth=2.0,\n",
782
783
784
    "        )\n",
    "        plots.append(temp_plot)\n",
    "        rec_plots.append(temp_rec_plot)\n",
785
    "    return plots, rec_plots\n",
786
787
    "\n",
    "\n",
788
    "def update_mouse_graph(x, y, rec_x, rec_y, plots, rec_plots, edges):\n",
789
790
791
792
793
794
    "    \"\"\"Updates the graph plot to enable animation\"\"\"\n",
    "\n",
    "    for plot, edge in zip(plots, edges):\n",
    "        plot.set_data(\n",
    "            [float(x[edge[0]]), float(x[edge[1]])],\n",
    "            [float(y[edge[0]]), float(y[edge[1]])],\n",
795
796
797
798
799
    "        )\n",
    "    for plot, edge in zip(rec_plots, edges):\n",
    "        plot.set_data(\n",
    "            [float(rec_x[edge[0]]), float(rec_x[edge[1]])],\n",
    "            [float(rec_y[edge[0]]), float(rec_y[edge[1]])],\n",
800
801
802
803
804
    "        )"
   ]
  },
  {
   "cell_type": "code",
805
   "execution_count": 52,
806
807
808
   "metadata": {
    "scrolled": false
   },
809
810
811
812
813
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
      "Test 35_s22\n"
     ]
    },
    {
     "ename": "InvalidArgumentError",
     "evalue": " indices = 176 is not in [0, 32)\n\t [[{{node SEQ_2_SEQ_GMVAE/cluster_overlap_3/map/while/body/_10/SEQ_2_SEQ_GMVAE/cluster_overlap_3/map/while/PartitionedCall/PartitionedCall/GatherV2}}]] [Op:__inference_predict_function_76105]\n\nFunction call stack:\npredict_function\n",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mInvalidArgumentError\u001b[0m                      Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-52-9198f0263678>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     97\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     98\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 99\u001b[0;31m \u001b[0manimate_mice_across_time\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrandom_exp\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m<ipython-input-52-9198f0263678>\u001b[0m in \u001b[0;36manimate_mice_across_time\u001b[0;34m(random_exp)\u001b[0m\n\u001b[1;32m     32\u001b[0m     )[0][:100]\n\u001b[1;32m     33\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 34\u001b[0;31m     \u001b[0mdata_rec\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgmvaep\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpredict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata_prep\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     35\u001b[0m     \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     36\u001b[0m         \u001b[0mdata_rec\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mDataFrame\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcoords_rec\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_scaler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minverse_transform\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata_rec\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m6\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/opt/anaconda3/envs/Machine_Learning/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py\u001b[0m in \u001b[0;36mpredict\u001b[0;34m(self, x, batch_size, verbose, steps, callbacks, max_queue_size, workers, use_multiprocessing)\u001b[0m\n\u001b[1;32m   1627\u001b[0m           \u001b[0;32mfor\u001b[0m \u001b[0mstep\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata_handler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msteps\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1628\u001b[0m             \u001b[0mcallbacks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_predict_batch_begin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1629\u001b[0;31m             \u001b[0mtmp_batch_outputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpredict_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0miterator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1630\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mdata_handler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshould_sync\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1631\u001b[0m               \u001b[0mcontext\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0masync_wait\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/opt/anaconda3/envs/Machine_Learning/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwds)\u001b[0m\n\u001b[1;32m    826\u001b[0m     \u001b[0mtracing_count\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexperimental_get_tracing_count\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    827\u001b[0m     \u001b[0;32mwith\u001b[0m \u001b[0mtrace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTrace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_name\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mtm\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 828\u001b[0;31m       \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwds\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    829\u001b[0m       \u001b[0mcompiler\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m\"xla\"\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_experimental_compile\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;34m\"nonXla\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    830\u001b[0m       \u001b[0mnew_tracing_count\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexperimental_get_tracing_count\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/opt/anaconda3/envs/Machine_Learning/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py\u001b[0m in \u001b[0;36m_call\u001b[0;34m(self, *args, **kwds)\u001b[0m\n\u001b[1;32m    893\u001b[0m       \u001b[0;31m# If we did not create any variables the trace we have is good enough.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    894\u001b[0m       return self._concrete_stateful_fn._call_flat(\n\u001b[0;32m--> 895\u001b[0;31m           filtered_flat_args, self._concrete_stateful_fn.captured_inputs)  # pylint: disable=protected-access\n\u001b[0m\u001b[1;32m    896\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    897\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mfn_with_cond\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minner_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minner_kwds\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minner_filtered_flat_args\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/opt/anaconda3/envs/Machine_Learning/lib/python3.6/site-packages/tensorflow/python/eager/function.py\u001b[0m in \u001b[0;36m_call_flat\u001b[0;34m(self, args, captured_inputs, cancellation_manager)\u001b[0m\n\u001b[1;32m   1917\u001b[0m       \u001b[0;31m# No tape is watching; skip to running the function.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1918\u001b[0m       return self._build_call_outputs(self._inference_function.call(\n\u001b[0;32m-> 1919\u001b[0;31m           ctx, args, cancellation_manager=cancellation_manager))\n\u001b[0m\u001b[1;32m   1920\u001b[0m     forward_backward = self._select_forward_and_backward_functions(\n\u001b[1;32m   1921\u001b[0m         \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/opt/anaconda3/envs/Machine_Learning/lib/python3.6/site-packages/tensorflow/python/eager/function.py\u001b[0m in \u001b[0;36mcall\u001b[0;34m(self, ctx, args, cancellation_manager)\u001b[0m\n\u001b[1;32m    558\u001b[0m               \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    559\u001b[0m               \u001b[0mattrs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mattrs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 560\u001b[0;31m               ctx=ctx)\n\u001b[0m\u001b[1;32m    561\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    562\u001b[0m           outputs = execute.execute_with_cancellation(\n",
      "\u001b[0;32m~/opt/anaconda3/envs/Machine_Learning/lib/python3.6/site-packages/tensorflow/python/eager/execute.py\u001b[0m in \u001b[0;36mquick_execute\u001b[0;34m(op_name, num_outputs, inputs, attrs, ctx, name)\u001b[0m\n\u001b[1;32m     58\u001b[0m     \u001b[0mctx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mensure_initialized\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     59\u001b[0m     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,\n\u001b[0;32m---> 60\u001b[0;31m                                         inputs, attrs, num_outputs)\n\u001b[0m\u001b[1;32m     61\u001b[0m   \u001b[0;32mexcept\u001b[0m \u001b[0mcore\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_NotOkStatusException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     62\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mInvalidArgumentError\u001b[0m:  indices = 176 is not in [0, 32)\n\t [[{{node SEQ_2_SEQ_GMVAE/cluster_overlap_3/map/while/body/_10/SEQ_2_SEQ_GMVAE/cluster_overlap_3/map/while/PartitionedCall/PartitionedCall/GatherV2}}]] [Op:__inference_predict_function_76105]\n\nFunction call stack:\npredict_function\n"
833
834
835
836
     ]
    },
    {
     "data": {
837
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAlsAAAJDCAYAAAA8QNGHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAUgklEQVR4nO3dX4jld3nH8c9jYipotNBsQbKJCXRTTVWIHdIULwyYliQXmwtbSUCsEtybRmwVIaKoxCuVWhDiny2VVEHT6IUsuJKCjQTESFZsg0mILNGajUKixtwEjWmfXswo42R352Ryntk9yesFC/P7ne+c88CX2X3v75w5p7o7AADMeMGpHgAA4LlMbAEADBJbAACDxBYAwCCxBQAwSGwBAAzaNraq6nNV9UhVff8Et1dVfbKqjlbVPVX1uuWPCQCwmha5snVLkitPcvtVSfZt/DmQ5NPPfiwAgOeGbWOru+9M8ouTLLkmyed73V1J/rCqXr6sAQEAVtkyXrN1bpKHNh0f2zgHAPC8d+ZuPlhVHcj6U4158Ytf/OevfOUrd/PhAQB25Lvf/e7PunvPTr53GbH1cJLzNh3v3Tj3NN19MMnBJFlbW+sjR44s4eEBAGZV1f/s9HuX8TTioSRv3fitxMuSPN7dP13C/QIArLxtr2xV1ZeSXJ7knKo6luRDSV6YJN39mSSHk1yd5GiSJ5K8fWpYAIBVs21sdfd129zeSf5+aRMBADyHeAd5AIBBYgsAYJDYAgAYJLYAAAaJLQCAQWILAGCQ2AIAGCS2AAAGiS0AgEFiCwBgkNgCABgktgAABoktAIBBYgsAYJDYAgAYJLYAAAaJLQCAQWILAGCQ2AIAGCS2AAAGiS0AgEFiCwBgkNgCABgktgAABoktAIBBYgsAYJDYAgAYJLYAAAaJLQCAQWILAGCQ2AIAGCS2AAAGiS0AgEFiCwBgkNgCABgktgAABoktAIBBYgsAYJDYAgAYJLYAAAaJLQCAQWILAGCQ2AIAGCS2AAAGiS0AgEFiCwBgkNgCABgktgAABoktAIBBYgsAYJDYAgAYJLYAAAaJLQCAQWILAGCQ2AIAGCS2AAAGiS0AgEFiCwBgkNgCABgktgAABoktAIBBYgsAYJDYAgAYJLYAAAaJLQCAQWILAGCQ2AIAGCS2AAAGiS0AgEFiCwBgkNgCABgktgAABoktAIBBYgsAYJDYAgAYJLYAAAaJLQCAQWILAGCQ2AIAGCS2AAAGiS0AgEFiCwBgkNgCABgktgAABoktAIBBYgsAYJDYAgAYJLYAAAaJLQCAQWILAGCQ2AIAGCS2AAAGiS0AgEFiCwBgkNgCABgktgAABoktAIBBYgsAYJDYAgAYJLYAAAYtFFtVdWVVPVBVR6vqxuPcfn5V3VFV36uqe6rq6uWPCgCweraNrao6I8nNSa5KcnGS66rq4i3LPpDktu6+JMm1ST617EEBAFbRIle2Lk1ytLsf7O4nk9ya5JotazrJSze+flmSnyxvRACA1XXmAmvOTfLQpuNjSf5iy5oPJ/mPqnpnkhcnuWIp0wEArLhlvUD+uiS3dPfeJFcn+UJVPe2+q+pAVR2pqiOPPvrokh4aAOD0tUhsPZzkvE3HezfObXZ9ktuSpLu/neRFSc7ZekfdfbC717p7bc+ePTubGABghSwSW3cn2VdVF1bVWVl/AfyhLWt+nOSNSVJVr8p6bLl0BQA8720bW939VJIbktye5P6s/9bhvVV1U1Xt31j2niTvqKr/TvKlJG/r7p4aGgBgVSzyAvl09+Ekh7ec++Cmr+9L8vrljgYAsPq8gzwAwCCxBQAwSGwBAAwSWwAAg8QWAMAgsQUAMEhsAQAMElsAAIPEFgDAILEFADBIbAEADBJbAACDxBYAwCCxBQAwSGwBAAwSWwAAg8QWAMAgsQUAMEhsAQAMElsAAIPEFgDAILEFADBIbAEADBJb
AACDxBYAwCCxBQAwSGwBAAwSWwAAg8QWAMAgsQUAMEhsAQAMElsAAIPEFgDAILEFADBIbAEADBJbAACDxBYAwCCxBQAwSGwBAAwSWwAAg8QWAMAgsQUAMEhsAQAMElsAAIPEFgDAILEFADBIbAEADBJbAACDxBYAwCCxBQAwSGwBAAwSWwAAg8QWAMAgsQUAMEhsAQAMElsAAIPEFgDAILEFADBIbAEADBJbAACDxBYAwCCxBQAwSGwBAAwSWwAAg8QWAMAgsQUAMEhsAQAMElsAAIPEFgDAILEFADBIbAEADBJbAACDxBYAwCCxBQAwSGwBAAwSWwAAg8QWAMAgsQUAMEhsAQAMElsAAIPEFgDAILEFADBIbAEADBJbAACDxBYAwCCxBQAwSGwBAAwSWwAAg8QWAMAgsQUAMEhsAQAMElsAAIPEFgDAILEFADBIbAEADBJbAACDxBYAwCCxBQAwSGwBAAwSWwAAgxaKraq6sqoeqKqjVXXjCda8uaruq6p7q+qLyx0TAGA1nbndgqo6I8nNSf4qybEkd1fVoe6+b9OafUnel+T13f1YVf3x1MAAAKtkkStblyY52t0PdveTSW5Ncs2WNe9IcnN3P5Yk3f3IcscEAFhNi8TWuUke2nR8bOPcZhcluaiqvlVVd1XVlcsaEABglW37NOIzuJ99SS5PsjfJnVX1mu7+5eZFVXUgyYEkOf/885f00AAAp69Frmw9nOS8Tcd7N85tdizJoe7+TXf/MMkPsh5fv6e7D3b3Wnev7dmzZ6czAwCsjEVi6+4k+6rqwqo6K8m1SQ5tWfPVrF/VSlWdk/WnFR9c4pwAACtp29jq7qeS3JDk9iT3J7mtu++tqpuqav/GstuT/Lyq7ktyR5L3dvfPp4YGAFgV1d2n5IHX1tb6yJEjp+SxAQCeiar6bnev7eR7vYM8AMAgsQUAMEhsAQAMElsAAIPEFgDAILEFADBIbAEADBJbAACDxBYAwCCxBQAwSGwBAAwSWwAAg8QWAMAgsQUAMEhsAQAMElsAAIPEFgDAILEFADBIbAEADBJbAACDxBYAwCCxBQAwSGwBAAwSWwAAg8QWAMAgsQUAMEhsAQAMElsAAIPEFgDAILEFADBIbAEADBJbAACDxBYAwCCxBQAwSGwBAAwSWwAAg8QWAMAgsQUAMEhsAQAMElsAAIPEFgDAILEFADBIbAEADBJbAACDxBYAwCCxBQAwSGwBAAwSWwAAg8QWAMAgsQUAMEhsAQAMElsAAIPEFgDAILEFADBIbAEADBJbAACDxBYAwCCxBQAwSGwBAAwSWwAAg8QWAMAgsQUAMEhsAQAMElsAAIPEFgDAILEFADBIbAEADBJbAACDxBYAwCCxBQAwSGwBAAwSWwAAg8QWAMAgsQUAMEhsAQAMElsAAIPEFgDAILEFADBIbAEADBJbAACDxBYAwCCxBQAwSGwBAAwSWwAAg8QWAMAgsQUAMEhsAQAMElsAAIPEFgDAILEFADBIbAEADBJbAACDxBYAwCCxBQAwSGwBAAwSWwAAg8QWAMAgsQUAMEhsAQAMWii2qurKqnqgqo5W1Y0nWfemquqqWlveiAAAq2vb2KqqM5LcnOSqJBcnua6qLj7OurOTvCvJd5Y9JADAqlrkytalSY5294Pd/WSSW5Ncc5x1H0ny0SS/WuJ8AAArbZHYOjfJQ5uOj22c+52qel2S87r7a0ucDQBg5T3rF8hX1QuSfCLJexZYe6CqjlTVkUcfffTZPjQAwGlvkdh6OMl5m473bpz7rbOTvDrJN6vqR0kuS3LoeC+S7+6D3b3W3Wt79uzZ+dQAACtikdi6O8m+qrqwqs5Kcm2SQ7+9sbsf7+5zuvuC7r4gyV1J9nf3kZGJAQBWyLax1d1PJbkhye1J7k9yW3ffW1U3VdX+6QEBAFbZmYss6u7DSQ5vOffBE6y9/NmPBQDw3OAd5AEABoktAIBBYgsAYJDYAgAYJLYAAAaJLQCAQWIL
AGCQ2AIAGCS2AAAGiS0AgEFiCwBgkNgCABgktgAABoktAIBBYgsAYJDYAgAYJLYAAAaJLQCAQWILAGCQ2AIAGCS2AAAGiS0AgEFiCwBgkNgCABgktgAABoktAIBBYgsAYJDYAgAYJLYAAAaJLQCAQWILAGCQ2AIAGCS2AAAGiS0AgEFiCwBgkNgCABgktgAABoktAIBBYgsAYJDYAgAYJLYAAAaJLQCAQWILAGCQ2AIAGCS2AAAGiS0AgEFiCwBgkNgCABgktgAABoktAIBBYgsAYJDYAgAYJLYAAAaJLQCAQWILAGCQ2AIAGCS2AAAGiS0AgEFiCwBgkNgCABgktgAABoktAIBBYgsAYJDYAgAYJLYAAAaJLQCAQWILAGCQ2AIAGCS2AAAGiS0AgEFiCwBgkNgCABgktgAABoktAIBBYgsAYJDYAgAYJLYAAAaJLQCAQWILAGCQ2AIAGCS2AAAGiS0AgEFiCwBgkNgCABgktgAABoktAIBBYgsAYJDYAgAYJLYAAAaJLQCAQWILAGCQ2AIAGCS2AAAGiS0AgEFiCwBgkNgCABgktgAABoktAIBBYgsAYJDYAgAYtFBsVdWVVfVAVR2tqhuPc/u7q+q+qrqnqr5RVa9Y/qgAAKtn29iqqjOS3JzkqiQXJ7muqi7esux7Sda6+7VJvpLkY8seFABgFS1yZevSJEe7+8HufjLJrUmu2bygu+/o7ic2Du9Ksne5YwIArKZFYuvcJA9tOj62ce5Erk/y9WczFADAc8WZy7yzqnpLkrUkbzjB7QeSHEiS888/f5kPDQBwWlrkytbDSc7bdLx349zvqaorkrw/yf7u/vXx7qi7D3b3Wnev7dmzZyfzAgCslEVi6+4k+6rqwqo6K8m1SQ5tXlBVlyT5bNZD65HljwkAsJq2ja3ufirJDUluT3J/ktu6+96quqmq9m8s+3iSlyT5clX9V1UdOsHdAQA8ryz0mq3uPpzk8JZzH9z09RVLngsA4DnBO8gDAAwSWwAAg8QWAMAgsQUAMEhsAQAMElsAAIPEFgDAILEFADBIbAEADBJbAACDxBYAwCCxBQAwSGwBAAwSWwAAg8QWAMAgsQUAMEhsAQAMElsAAIPEFgDAILEFADBIbAEADBJbAACDxBYAwCCxBQAwSGwBAAwSWwAAg8QWAMAgsQUAMEhsAQAMElsAAIPEFgDAILEFADBIbAEADBJbAACDxBYAwCCxBQAwSGwBAAwSWwAAg8QWAMAgsQUAMEhsAQAMElsAAIPEFgDAILEFADBIbAEADBJbAACDxBYAwCCxBQAwSGwBAAwSWwAAg8QWAMAgsQUAMEhsAQAMElsAAIPEFgDAILEFADBIbAEADBJbAACDxBYAwCCxBQAwSGwBAAwSWwAAg8QWAMAgsQUAMEhsAQAMElsAAIPEFgDAILEFADBIbAEADBJbAACDxBYAwCCxBQAwSGwBAAwSWwAAg8QWAMAgsQUAMEhsAQAMElsAAIPEFgDAILEFADBIbAEADBJbAACDxBYAwCCxBQAwSGwBAAwSWwAAg8QWAMAgsQUAMEhsAQAMElsAAIPEFgDAILEFADBIbAEADBJbAACDxBYAwCCxBQAwSGwBAAwSWwAAg8QWAMAgsQUAMGih2KqqK6vqgao6WlU3Huf2P6iqf9+4/TtVdcGyBwUAWEXbxlZVnZHk5iRXJbk4yXVVdfGWZdcneay7/yTJPyf56LIHBQBYRYtc2bo0ydHufrC7n0xya5Jrtqy5Jsm/bXz9lSRvrKpa3pgAAKtpkdg6N8lDm46PbZw77prufirJ40n+aBkDAgCssjN388Gq6kCSAxuHv66q7+/m47NU5yT52akegh2xd6vN/q0ue7fa/nSn37hIbD2c5LxNx3s3zh1vzbGqOjPJy5L8fOsddffBJAeTpKqOdPfaTobm1LN/q8verTb7t7rs3WqrqiM7/d5Fnka8O8m+qrqwqs5Kcm2SQ1vWHErydxtf/02S/+zu
3ulQAADPFdte2erup6rqhiS3Jzkjyee6+96quinJke4+lORfk3yhqo4m+UXWgwwA4HlvoddsdffhJIe3nPvgpq9/leRvn+FjH3yG6zm92L/VZe9Wm/1bXfZute14/8qzfQAAc3xcDwDAoPHY8lE/q2uBvXt3Vd1XVfdU1Teq6hWnYk6Ob7v927TuTVXVVeW3pE4ji+xfVb1542fw3qr64m7PyPEt8Hfn+VV1R1V9b+Pvz6tPxZw8XVV9rqoeOdFbU9W6T27s7T1V9bpF7nc0tnzUz+pacO++l2Stu1+b9U8O+NjuTsmJLLh/qaqzk7wryXd2d0JOZpH9q6p9Sd6X5PXd/WdJ/mHXB+VpFvzZ+0CS27r7kqz/QtmndndKTuKWJFee5Parkuzb+HMgyacXudPpK1s+6md1bbt33X1Hdz+xcXhX1t+DjdPDIj97SfKRrP8H51e7ORzbWmT/3pHk5u5+LEm6+5FdnpHjW2TvOslLN75+WZKf7OJ8nER335n1d1U4kWuSfL7X3ZXkD6vq5dvd73Rs+aif1bXI3m12fZKvj07EM7Ht/m1c/j6vu7+2m4OxkEV+/i5KclFVfauq7qqqk/1vnN2zyN59OMlbqupY1n/T/527MxpL8Ez/bUyyyx/Xw3NTVb0lyVqSN5zqWVhMVb0gySeSvO0Uj8LOnZn1pzIuz/pV5Tur6jXd/ctTOhWLuC7JLd39T1X1l1l/n8pXd/f/nerBmDF9ZeuZfNRPTvZRP+y6RfYuVXVFkvcn2d/dv96l2djedvt3dpJXJ/lmVf0oyWVJDnmR/GljkZ+/Y0kOdfdvuvuHSX6Q9fji1Fpk765PcluSdPe3k7wo65+byOlvoX8bt5qOLR/1s7q23buquiTJZ7MeWl4vcno56f519+PdfU53X9DdF2T9NXf7u3vHn/3FUi3yd+dXs35VK1V1TtafVnxwN4fkuBbZux8neWOSVNWrsh5bj+7qlOzUoSRv3fitxMuSPN7dP93um0afRvRRP6trwb37eJKXJPnyxu80/Li795+yofmdBfeP09SC+3d7kr+uqvuS/G+S93a3ZwVOsQX37j1J/qWq/jHrL5Z/m4sMp4eq+lLW/xNzzsZr6j6U5IVJ0t2fyfpr7K5OcjTJE0nevtD92l8AgDneQR4AYJDYAgAYJLYAAAaJLQCAQWILAGCQ2AIAGCS2AAAGiS0AgEH/Dx30rkLcbwr/AAAAAElFTkSuQmCC\n",
838
      "text/plain": [
839
       "<Figure size 720x720 with 1 Axes>"
840
841
      ]
     },
842
843
844
     "metadata": {
      "needs_background": "light"
     },
845
846
847
     "output_type": "display_data"
    }
   ],
848
849
850
   "source": [
    "# Display a video with the original data superimposed with the reconstructions\n",
    "\n",
    "coords = proj.get_coords(center=\"Center\", align=\"Spine_1\", align_inplace=True)\n",
    "random_exp = np.random.choice(list(coords.keys()), 1)[0]\n",
    "print(random_exp)\n",
    "\n",
    "\n",
    "def animate_mice_across_time(random_exp, max_frames=250):\n",
    "    \"\"\"Render a video overlaying the observed body graph and its reconstruction.\n",
    "\n",
    "    random_exp is the key of the experiment in coords to animate. max_frames caps\n",
    "    the length of the animation, which never runs past the reconstructed rows.\n",
    "    \"\"\"\n",
    "\n",
    "    # Define canvas\n",
    "    fig, ax = plt.subplots(1, 1, figsize=(10, 10))\n",
    "\n",
    "    # Retrieve body graph\n",
    "    edges = deepof.utils.connect_mouse_topview()\n",
    "\n",
    "    for bpart in exclude_bodyparts:\n",
    "        if bpart:\n",
    "            edges.remove_node(bpart)\n",
    "\n",
    "    for limb in [\"Left_fhip\", \"Right_fhip\", \"Left_bhip\", \"Right_bhip\"]:\n",
    "        edges.remove_edge(\"Center\", limb)\n",
    "        if (\"Tail_base\", limb) in edges.edges():\n",
    "            edges.remove_edge(\"Tail_base\", limb)\n",
    "\n",
    "    edges = edges.edges()\n",
    "\n",
    "    # Compute observed and predicted data to plot\n",
    "    columns = coords[random_exp].columns\n",
    "    coords_rec = coords.filter_videos([random_exp])\n",
    "    data_prep = coords_rec.preprocess(\n",
    "        test_videos=0, window_step=1, window_size=window_size, shuffle=False\n",
    "    )[0][:100]\n",
    "\n",
    "    data_rec = gmvaep.predict(data_prep)\n",
    "    if isinstance(data_rec, (list, tuple)):\n",
    "        # Several outputs came back; the first one is used as the reconstruction\n",
    "        data_rec = data_rec[0]\n",
    "\n",
    "    # Central frame of each window, as in the latent trajectory cell, instead of\n",
    "    # a fixed index that only fits one window size\n",
    "    mid_frame = window_size // 2\n",
    "\n",
    "    def central_frames(windows):\n",
    "        \"\"\"Inverse-transform the central frame of every window into a labelled frame.\"\"\"\n",
    "        frames = pd.DataFrame(\n",
    "            coords_rec._scaler.inverse_transform(windows[:, mid_frame, :])\n",
    "        )\n",
    "        frames.columns = columns\n",
    "        # Add Central coordinate, lost during alignment\n",
    "        frames[\"Center\", \"x\"] = 0\n",
    "        frames[\"Center\", \"y\"] = 0\n",
    "        return frames\n",
    "\n",
    "    data = central_frames(data_prep)\n",
    "    data_rec = central_frames(data_rec)\n",
    "\n",
    "    def frame_xy(frames, i):\n",
    "        \"\"\"Return the x and y coordinate rows of frame i.\"\"\"\n",
    "        return (\n",
    "            frames.xs(\"x\", level=1, axis=1, drop_level=False).iloc[i, :],\n",
    "            frames.xs(\"y\", level=1, axis=1, drop_level=False).iloc[i, :],\n",
    "        )\n",
    "\n",
    "    # Plot!\n",
    "    init_x, init_y = frame_xy(data, 0)\n",
    "    init_rec_x, init_rec_y = frame_xy(data_rec, 0)\n",
    "\n",
    "    plots, rec_plots = plot_mouse_graph(\n",
    "        init_x, init_y, init_rec_x, init_rec_y, ax, edges\n",
    "    )\n",
    "    scatter = ax.scatter(\n",
    "        x=np.array(init_x), y=np.array(init_y), color=\"#006699\", label=\"Original\"\n",
    "    )\n",
    "    rec_scatter = ax.scatter(\n",
    "        x=np.array(init_rec_x),\n",
    "        y=np.array(init_rec_y),\n",
    "        color=\"red\",\n",
    "        label=\"Reconstruction\",\n",
    "    )\n",
    "\n",
    "    # Update data in main plot\n",
    "    def animation_frame(i):\n",
    "        x, y = frame_xy(data, i)\n",
    "        rec_x, rec_y = frame_xy(data_rec, i)\n",
    "\n",
    "        scatter.set_offsets(np.c_[np.array(x), np.array(y)])\n",
    "        rec_scatter.set_offsets(np.c_[np.array(rec_x), np.array(rec_y)])\n",
    "        update_mouse_graph(x, y, rec_x, rec_y, plots, rec_plots, edges)\n",
    "\n",
    "        return scatter\n",
    "\n",
    "    # At most 100 windows are kept above, so asking for more frames than there\n",
    "    # are rows would index past the end of the data\n",
    "    n_frames = min(max_frames, len(data))\n",
    "    animation = FuncAnimation(fig, func=animation_frame, frames=n_frames, interval=50)\n",
    "\n",
    "    ax.set_title(\"Original versus reconstructed data\")\n",
    "    ax.set_ylim(-100, 60)\n",
    "    ax.set_xlim(-60, 60)\n",
    "    ax.set_xlabel(\"x\")\n",
    "    ax.set_ylabel(\"y\")\n",
    "    ax.legend()\n",
    "\n",
    "    video = animation.to_html5_video()\n",
    "    html = display.HTML(video)\n",
    "    display.display(html)\n",
    "    # Close this figure explicitly rather than whichever one is current\n",
    "    plt.close(fig)\n",
    "\n",
    "\n",
    "animate_mice_across_time(random_exp)"
948
949
950
951
952
953
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
954
    "### 5. Evaluate latent space (to be incorporated into deepof.evaluate)"
955
956
957
958
   ]
  },
  {
   "cell_type": "code",
959
   "execution_count": null,
960
   "metadata": {},
961
   "outputs": [],
962
963
964
   "source": [
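    "# Preprocess all of coords (not only the random video above) into shuffled sliding\n",
    "# windows and keep the first 10000; encodings and groupings are computed in the next cell\n",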
    "data_prep = coords.preprocess(\n",
965
    "    test_videos=0, window_step=1, window_size=window_size, shuffle=True\n",
966
967
968
969
970
    ")[0][:10000]"
   ]
  },
  {
   "cell_type": "code",
971
   "execution_count": null,
972
973
974
   "metadata": {},
   "outputs": [],
   "source": [
975
    "# Run the same preprocessed windows through the encoder and the grouper\n",
    "encodings, groupings = encoder.predict(data_prep), grouper.predict(data_prep)\n",
    "\n",
    "# Hard assignment: index of the highest grouping value of every window\n",
    "hard_groups = np.argmax(groupings, axis=1)"
   ]
979
  },
980
981
  {
   "cell_type": "code",
982
   "execution_count": null,
983
   "metadata": {},
984
   "outputs": [],
985
986
   "source": [
    "@interact(minimum_confidence=(0.0, 1.0, 0.01))\n",
    "def plot_cluster_population(minimum_confidence):\n",
    "    \"\"\"Plot how many instances are assigned to each cluster.\n",
    "\n",
    "    Only instances whose highest grouping value is above minimum_confidence\n",
    "    are counted.\n",
    "    \"\"\"\n",
    "    plt.figure(figsize=(12, 8))\n",
    "\n",
    "    confident = np.max(groupings, axis=1) > minimum_confidence\n",
    "    groups = hard_groups[confident].flatten()\n",
    "\n",
    "    # Count with a fixed number of bins: clusters without any instance still get\n",
    "    # a bar, and no artificial instance is added to the real counts\n",
    "    n_clusters = groupings.shape[1]\n",
    "    counts = np.bincount(groups, minlength=n_clusters)\n",
    "\n",
    "    sns.barplot(x=np.arange(n_clusters), y=counts)\n",
    "    plt.xlabel(\"Cluster\")\n",
    "    plt.ylabel(\"count\")\n",
    "    plt.title(\"Training instances per cluster\")\n",
    "    plt.ylim(0, hard_groups.shape[0] * 1.1)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The slider in the figure above sets the minimum confidence with which the model must assign a training instance to a cluster for that instance to be included in the plot."
   ]
  },
  {
   "cell_type": "code",
1008
   "execution_count": null,
1009
   "metadata": {},
1010
   "outputs": [],
1011
1012
1013
   "source": [
    "# Plot real data in the latent space\n",
    "\n",
    "samples = np.random.choice(range(encodings.shape[0]), 10000)\n",
    "sample_enc = encodings[samples, :]\n",
    "sample_grp = groupings[samples, :]\n",
    "sample_hgr = hard_groups[samples]\n",
    "k = sample_grp.shape[1]\n",
    "\n",
    "umap_reducer = umap.UMAP(n_components=2)\n",
    "pca_reducer = PCA(n_components=2)\n",
    "tsne_reducer = TSNE(n_components=2)\n",
    "lda_reducer = LinearDiscriminantAnalysis(n_components=2)\n",
    "\n",
    "umap_enc = umap_reducer.fit_transform(sample_enc)\n",
    "pca_enc = pca_reducer.fit_transform(sample_enc)\n",
    "tsne_enc = tsne_reducer.fit_transform(sample_enc)\n",
    "try:\n",
    "    lda_enc = lda_reducer.fit_transform(sample_enc, sample_hgr)\n",
    "except ValueError:\n",
    "    # Mark LDA as unavailable so a projection left over from an earlier run is\n",
    "    # never plotted by mistake. Warnings are silenced at the top of the notebook,\n",
    "    # so the plotting function reports the problem as well.\n",
    "    lda_enc = None\n",
    "    warnings.warn(\"Only one class found. Can't use LDA\", RuntimeWarning, stacklevel=2)\n",
    "\n",
    "\n",
    "@interact(\n",
    "    minimum_confidence=(0.0, 0.99, 0.01),\n",
    "    dim_red=[\"PCA\", \"LDA\", \"umap\", \"tSNE\"],\n",
    "    highlight_clusters=False,\n",
    "    selected_cluster=(0, k - 1),\n",
    ")\n",
    "def plot_static_latent_space(\n",
    "    minimum_confidence, dim_red, highlight_clusters, selected_cluster\n",
    "):\n",
    "    \"\"\"Scatter the sampled encodings in two dimensions, coloured by cluster.\n",
    "\n",
    "    minimum_confidence hides instances whose highest grouping value is not above\n",
    "    it, dim_red picks the projection, and highlight_clusters overlays a density\n",
    "    estimate of selected_cluster.\n",
    "    \"\"\"\n",
    "\n",
    "    # Anything that is not umap, LDA or PCA falls back to tSNE\n",
    "    projections = {\"umap\": umap_enc, \"LDA\": lda_enc, \"PCA\": pca_enc}\n",
    "    enc = projections.get(dim_red, tsne_enc)\n",
    "    if enc is None:\n",
    "        print(\"LDA projection is not available: fitting it failed for this sample\")\n",
    "        return\n",
    "\n",
    "    # One mask shared by the projection, the hard labels and the grouping values\n",
    "    confident = np.max(sample_grp, axis=1) > minimum_confidence\n",
    "    enc = enc[confident]\n",
    "    hgr = sample_hgr[confident].flatten()\n",
    "    grp = sample_grp[confident]\n",
    "\n",
    "    plt.figure(figsize=(10, 10))\n",
    "\n",
    "    sns.scatterplot(\n",
    "        x=enc[:, 0],\n",
    "        y=enc[:, 1],\n",
    "        hue=hgr,\n",
    "        size=np.max(grp, axis=1),\n",
    "        sizes=(1, 100),\n",
    "        palette=sns.color_palette(\"husl\", len(set(hgr))),\n",
    "    )\n",
    "\n",
    "    if highlight_clusters:\n",
    "        in_cluster = hgr == selected_cluster\n",
    "        # A density estimate needs more than one point; skip clusters that the\n",
    "        # confidence filter has emptied\n",
    "        if np.sum(in_cluster) > 1:\n",
    "            sns.kdeplot(\n",
    "                enc[in_cluster, 0],\n",
    "                enc[in_cluster, 1],\n",
    "                color=\"red\",\n",
    "            )\n",
    "\n",
    "    plt.xlabel(\"{} 1\".format(dim_red))\n",
    "    plt.ylabel(\"{} 2\".format(dim_red))\n",
    "    plt.suptitle(\"Static view of trained latent space\")\n",
    "    plt.show()"
   ]
1084
  },
1085
1086
  {
   "cell_type": "code",
1087
   "execution_count": null,
1088
1089
   "metadata": {},
   "outputs": [],
1090
   "source": [
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
    "def plot_mouse_graph(instant_x, instant_y, ax, edges):\n",
    "    \"\"\"Draw one line per edge on ax and return the line objects in edge order.\n",
    "\n",
    "    instant_x and instant_y are indexed with the node labels stored in edges.\n",
    "    NOTE: the reconstruction cell earlier in the notebook calls a function of the\n",
    "    same name with six arguments; running this cell replaces that definition.\n",
    "    \"\"\"\n",
    "\n",
    "    def draw(edge):\n",
    "        start, end = edge[0], edge[1]\n",
    "        (line,) = ax.plot(\n",
    "            [float(instant_x[start]), float(instant_x[end])],\n",
    "            [float(instant_y[start]), float(instant_y[end])],\n",
    "            color=\"#006699\",\n",
    "            linewidth=2.0,\n",
    "        )\n",
    "        return line\n",
    "\n",
    "    return [draw(edge) for edge in edges]\n",
    "\n",
    "\n",
    "def update_mouse_graph(x, y, plots, edges):\n",
    "    \"\"\"Move every edge line in plots to the coordinates of a new frame.\n",
    "\n",
    "    x and y are indexed with the node labels stored in edges; plots is paired\n",
    "    with edges in order.\n",
    "    NOTE: a seven-argument function of the same name is defined earlier in the\n",
    "    notebook; running this cell replaces it.\n",
    "    \"\"\"\n",
    "\n",
    "    for line, edge in zip(plots, edges):\n",
    "        start, end = edge[0], edge[1]\n",
    "        xs = [float(x[start]), float(x[end])]\n",
    "        ys = [float(y[start]), float(y[end])]\n",
    "        line.set_data(xs, ys)"
1113
   ]
1114
1115
1116
  },
  {
   "cell_type": "code",
1117
   "execution_count": null,
1118
1119
1120
   "metadata": {
    "scrolled": false
   },
1121
   "outputs": [],
1122
1123
   "source": [
    "# Plot trajectory of a video in latent space\n",
    "traj_prep = coords.preprocess(\n",
    "    test_videos=0, window_step=1, window_size=window_size, shuffle=False\n",
    ")[0][:10000]\n",
    "\n",
    "traj_encodings = encode_to_vector.predict(traj_prep)\n",
    "traj_grp = grouper.predict(traj_prep)\n",
    "traj_hgr = np.argmax(traj_grp, axis=1)\n",
    "\n",
    "samples = np.random.choice(range(encodings.shape[0]), 10000)\n",
    "sample_enc = encodings[samples, :]\n",
    "sample_grp = groupings[samples, :]\n",
    "sample_hgr = hard_groups[samples]\n",
    "k = sample_grp.shape[1]\n",
    "\n",
    "umap_reducer = umap.UMAP(n_components=2)\n",
    "pca_reducer = PCA(n_components=2)\n",
    "tsne_reducer = TSNE(n_components=2)\n",
    "lda_reducer = LinearDiscriminantAnalysis(n_components=2)\n",
    "\n",
    "# Trajectory rows first, background sample after: the plotting function below\n",
    "# splits every projection at row 10000, relying on this order\n",
    "joint_enc = np.concatenate([traj_encodings, sample_enc])\n",
    "joint_hgr = np.concatenate([traj_hgr, sample_hgr])\n",
    "\n",
    "umap_enc = umap_reducer.fit_transform(joint_enc)\n",
    "pca_enc = pca_reducer.fit_transform(joint_enc)\n",
    "tsne_enc = tsne_reducer.fit_transform(joint_enc)\n",
    "try:\n",
    "    lda_enc = lda_reducer.fit_transform(joint_enc, joint_hgr)\n",
    "except ValueError:\n",
    "    # Reset explicitly: otherwise the LDA projection computed by the static\n",
    "    # latent space cell would be reused here with the wrong row layout\n",
    "    lda_enc = None\n",
    "    warnings.warn(\"Only one class found. Can't use LDA\", RuntimeWarning, stacklevel=2)\n",
    "\n",
1156
1157
    "\n",
    "@interact(\n",
1158
    "    trajectory=(100, 500), trace=False, dim_red=[\"PCA\", \"LDA\", \"umap\", \"tSNE\"],\n",
1159
    ")\n",
1160
1161
1162
1163
    "def plot_dynamic_latent_pace(trajectory, trace, dim_red):\n",
    "\n",
    "    global sample_enc, sample_grp, sample_hgr\n",
    "\n",
1164
    "    if dim_red == \"umap\":\n",
1165
    "        enc, traj_enc = umap_enc[10000:], umap_enc[:10000]\n",
1166
    "    elif dim_red == \"LDA\":\n",
1167
    "        enc, traj_enc = lda_enc[10000:], lda_enc[:10000]\n",
1168
    "    elif dim_red == \"PCA\":\n",
1169
    "        enc, traj_enc = pca_enc[10000:], pca_enc[:10000]\n",
1170
    "    else:\n",
1171
    "        enc, traj_enc = tsne_enc[10000:], tsne_enc[:10000]\n",
1172
    "\n",
1173
    "    traj_enc = traj_enc[:trajectory, :]\n",
1174
1175
1176
1177
1178
1179
1180
1181
    "\n",
    "    # Define two figures arranged horizontally\n",
    "    fig, (ax, ax2) = plt.subplots(\n",
    "        1, 2, figsize=(12, 8), gridspec_kw={\"width_ratios\": [3, 1.5]}\n",
    "    )\n",
    "\n",
    "    # Plot the animated embedding trajectory on the left\n",
    "    sns.scatterplot(\n",
1182
1183
    "        x=enc[:, 0],\n",
    "        y=enc[:, 1],\n",
1184
1185
1186
    "        hue=sample_hgr,\n",
    "        size=np.max(sample_grp, axis=1),\n",
    "        sizes=(1, 100),\n",
1187
    "        palette=sns.color_palette(\"husl\", len(set(sample_hgr))),\n",
1188
1189
1190
    "        ax=ax,\n",
    "    )\n",
    "\n",
1191
    "    traj_init = traj_enc[0, :]\n",
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
    "    scatter = ax.scatter(\n",
    "        x=[traj_init[0]], y=[traj_init[1]], s=100, color=\"red\", edgecolor=\"black\"\n",
    "    )\n",
    "    (lineplt,) = ax.plot([traj_init[0]], [traj_init[1]], color=\"red\", linewidth=2.0)\n",
    "    tracking_line_x = []\n",
    "    tracking_line_y = []\n",
    "\n",
    "    # Plot the initial data (before feeding it to the encoder) on the right\n",
    "    edges = deepof.utils.connect_mouse_topview()\n",
    "\n",
    "    for bpart in exclude_bodyparts:\n",
    "        if bpart:\n",
    "            edges.remove_node(bpart)\n",
    "\n",
    "    for limb in [\"Left_fhip\", \"Right_fhip\", \"Left_bhip\", \"Right_bhip\"]:\n",
    "        edges.remove_edge(\"Center\", limb)\n",
    "        if (\"Tail_base\", limb) in list(edges.edges()):\n",
    "            edges.remove_edge(\"Tail_base\", limb)\n",
    "\n",
    "    edges = edges.edges()\n",
    "\n",
1213
    "    inv_coords = coords._scaler.inverse_transform(traj_prep)[:, window_size // 2, :]\n",
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
    "    data = pd.DataFrame(inv_coords, columns=coords[random_exp].columns)\n",
    "\n",
    "    data[\"Center\", \"x\"] = 0\n",
    "    data[\"Center\", \"y\"] = 0\n",
    "\n",
    "    init_x = data.xs(\"x\", level=1, axis=1, drop_level=False).iloc[0, :]\n",
    "    init_y = data.xs(\"y\", level=1, axis=1, drop_level=False).iloc[0, :]\n",
    "\n",
    "    plots = plot_mouse_graph(init_x, init_y, ax2, edges)\n",
    "    track = ax2.scatter(x=np.array(init_x), y=np.array(init_y), color=\"#006699\",)\n",
    "\n",
    "    # Update data in both plots\n",
    "    def animation_frame(i):\n",
    "        # Update scatter plot\n",
1228
    "        offset = traj_enc[i, :]\n",
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
    "\n",
    "        prev_t = scatter.get_offsets()[0]\n",
    "\n",
    "        if trace:\n",
    "            tracking_line_x.append([prev_t[0], offset[0]])\n",
    "            tracking_line_y.append([prev_t[1], offset[1]])\n",
    "            lineplt.set_xdata(tracking_line_x)\n",
    "            lineplt.set_ydata(tracking_line_y)\n",
    "\n",
    "        scatter.set_offsets(np.c_[np.array(offset[0]), np.array(offset[1])])\n",
1239
    "\n",
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
    "        x = data.xs(\"x\", level=1, axis=1, drop_level=False).iloc[i, :]\n",
    "        y = data.xs(\"y\", level=1, axis=1, drop_level=False).iloc[i, :]\n",
    "        track.set_offsets(np.c_[np.array(x), np.array(y)])\n",
    "        update_mouse_graph(x, y, plots, edges)\n",
    "\n",
    "        return scatter\n",
    "\n",
    "    animation = FuncAnimation(\n",
    "        fig, func=animation_frame, frames=trajectory, interval=75,\n",
    "    )\n",
    "\n",
    "    ax.set_xlabel(\"{} 1\".format(dim_red))\n",
    "    ax.set_ylabel(\"{} 2\".format(dim_red))\n",
    "\n",
    "    ax2.set_xlabel(\"x\")\n",
    "    ax2.set_xlabel(\"y\")\n",
    "    ax2.set_ylim(-90, 60)\n",
    "    ax2.set_xlim(-60, 60)\n",
    "\n",
    "    plt.tight_layout()\n",
    "\n",
    "    video = animation.to_html5_video()\n",
    "    html = display.HTML(video)\n",
    "    display.display(html)\n",
    "    plt.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 6. Sample from latent space (to be incorporated into deepof.evaluate)"
   ]
  },
  {
   "cell_type": "code",
1276
   "execution_count": null,
lucas_miranda's avatar
lucas_miranda committed
1277
   "metadata": {},
1278
   "outputs": [],
lucas_miranda's avatar
lucas_miranda committed
1279
   "source": [
1280
    "# Get prior distribution\n",
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
    "\n",
    "# Mean and standard deviation of the prior's components distribution\n",
    "means = prior.components_distribution.mean().numpy()\n",
    "stddevs = prior.components_distribution.stddev().numpy()\n",
    "\n",
    "# Draw 500 Gaussian samples around each row of `means`, stacked row by row\n",
    "samples = np.concatenate(\n",
    "    [\n",
    "        np.random.normal(means[comp, :], stddevs[comp, :], size=(500, means.shape[1]))\n",
    "        for comp in range(means.shape[0])\n",
    "    ]\n",
    ")\n",
    "\n",
    "# Run every drawn sample through the decoder\n",
    "decodings = decoder.predict(samples)\n",
    "\n",
    "# Two-dimensional embeddings of the samples, fitted in the order umap, PCA, tSNE, LDA\n",
    "umap_reducer = umap.UMAP(n_components=2)\n",
    "umap_enc = umap_reducer.fit_transform(samples)\n",
    "\n",
    "pca_reducer = PCA(n_components=2)\n",
    "pca_enc = pca_reducer.fit_transform(samples)\n",
    "\n",
    "tsne_reducer = TSNE(n_components=2)\n",
    "tsne_enc = tsne_reducer.fit_transform(samples)\n",
    "\n",
    "# LDA is supervised: each sample is labelled with the row of `means` it was drawn from\n",
    "lda_reducer = LinearDiscriminantAnalysis(n_components=2)\n",
    "lda_enc = lda_reducer.fit_transform(\n",
    "    samples, np.repeat(range(means.shape[0]), 500)\n",
    ")\n",
    "\n",
    "\n",
    "@interact(dim_red=[\"PCA\", \"LDA\", \"umap\", \"tSNE\"], selected_cluster=(0, k - 1))\n",
    "def sample_from_prior(dim_red, selected_cluster):\n",
    "    \"\"\"Plot the embedded prior samples and animate decodings of one cluster.\n",
    "\n",
    "    Left panel: scatter of the 2D embedding of the prior samples, coloured by\n",
    "    the component each sample was drawn from, with a red density estimate\n",
    "    over the selected component. Right panel: animation built from five\n",
    "    randomly picked decodings of that component. The figure is rendered as\n",
    "    an HTML5 video and then closed.\n",
    "\n",
    "    Parameters:\n",
    "        - dim_red (str): embedding to display (\"PCA\", \"LDA\" or \"umap\");\n",
    "          any other value falls back to tSNE.\n",
    "        - selected_cluster (int): zero-based component label, i.e. one of\n",
    "          the values in range(means.shape[0]).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if no sample carries the label selected_cluster.\n",
    "    \"\"\"\n",
    "\n",
    "    # Component label of every sample (500 consecutive samples per component)\n",
    "    hue = np.repeat(range(means.shape[0]), 500)\n",
    "    in_cluster = hue == selected_cluster\n",
    "\n",
    "    if not in_cluster.any():\n",
    "        raise ValueError(\n",
    "            \"selected_cluster must be in [0, {}], got {}\".format(\n",
    "                means.shape[0] - 1, selected_cluster\n",
    "            )\n",
    "        )\n",
    "\n",
    "    # Pick the precomputed embedding; tSNE is the fallback\n",
    "    sample_enc = {\"umap\": umap_enc, \"LDA\": lda_enc, \"PCA\": pca_enc}.get(\n",
    "        dim_red, tsne_enc\n",
    "    )\n",
    "\n",
    "    # Define two figures arranged horizontally\n",
    "    fig, (ax, ax2) = plt.subplots(\n",
    "        1, 2, figsize=(12, 8), gridspec_kw={\"width_ratios\": [3, 1.5]}\n",
    "    )\n",
    "\n",
    "    # Plot the embedded samples on the left, coloured by component\n",
    "    sns.scatterplot(\n",
    "        x=sample_enc[:, 0],\n",
    "        y=sample_enc[:, 1],\n",
    "        hue=hue,\n",
    "        palette=sns.color_palette(\"husl\", k),\n",
    "        ax=ax,\n",
    "    )\n",
    "\n",
    "    # Overlay a density estimate of the selected component\n",
    "    sns.kdeplot(\n",
    "        sample_enc[in_cluster, 0], sample_enc[in_cluster, 1], color=\"red\", ax=ax,\n",
    "    )\n",
    "\n",
    "    # Get reconstructions from five random samples of the selected cluster\n",
    "    cluster_decodings = decodings[in_cluster]\n",
    "    decs = cluster_decodings[np.random.randint(0, cluster_decodings.shape[0], 5)]\n",
    "\n",
    "    # Build the list of edges used to draw the mouse on the right\n",
    "    edges = deepof.utils.connect_mouse_topview()\n",
    "\n",
    "    for bpart in exclude_bodyparts:\n",
    "        if bpart:\n",
    "            edges.remove_node(bpart)\n",
    "\n",
    "    for limb in [\"Left_fhip\", \"Right_fhip\", \"Left_bhip\", \"Right_bhip\"]:\n",
    "        edges.remove_edge(\"Center\", limb)\n",
    "        if (\"Tail_base\", limb) in list(edges.edges()):\n",
    "            edges.remove_edge(\"Tail_base\", limb)\n",
    "\n",
    "    edges = edges.edges()\n",
    "\n",
    "    # Undo the scaling and stack the five decodings along the first axis\n",
    "    # assumes decs is (5, window_size, n_features) -- TODO confirm\n",
    "    inv_coords = coords._scaler.inverse_transform(decs).reshape(\n",
    "        decs.shape[0] * decs.shape[1], decs.shape[2]\n",
    "    )\n",
    "    data = pd.DataFrame(inv_coords, columns=coords[random_exp].columns)\n",
    "\n",
    "    data[\"Center\", \"x\"] = 0\n",
    "    data[\"Center\", \"y\"] = 0\n",
    "\n",
    "    init_x = data.xs(\"x\", level=1, axis=1, drop_level=False).iloc[0, :]\n",
    "    init_y = data.xs(\"y\", level=1, axis=1, drop_level=False).iloc[0, :]\n",
    "\n",
    "    plots = plot_mouse_graph(init_x, init_y, ax2, edges)\n",
    "    track = ax2.scatter(x=np.array(init_x), y=np.array(init_y), color=\"#006699\",)\n",
    "\n",
    "    def animation_frame(i):\n",
    "        \"\"\"Move the markers and the mouse graph to row i of data.\"\"\"\n",
    "\n",
    "        x = data.xs(\"x\", level=1, axis=1, drop_level=False).iloc[i, :]\n",
    "        y = data.xs(\"y\", level=1, axis=1, drop_level=False).iloc[i, :]\n",
    "        track.set_offsets(np.c_[np.array(x), np.array(y)])\n",
    "        update_mouse_graph(x, y, plots, edges)\n",
    "\n",
    "    # Plot samples as video on the right\n",
    "    animation = FuncAnimation(\n",
    "        fig, func=animation_frame, frames=5 * window_size, interval=50,\n",
    "    )\n",
    "\n",
    "    ax.set_xlabel(\"{} 1\".format(dim_red))\n",
    "    ax.set_ylabel(\"{} 2\".format(dim_red))\n",
    "    ax.get_legend().remove()\n",
    "\n",
    "    ax2.set_xlabel(\"x\")\n",
    "    ax2.set_ylabel(\"y\")\n",
    "    ax2.set_ylim(-90, 60)\n",
    "    ax2.set_xlim(-60, 60)\n",
    "\n",
    "    plt.tight_layout()\n",
    "\n",
    "    video = animation.to_html5_video()\n",
    "    html = display.HTML(video)\n",
    "    display.display(html)\n",
    "    plt.close(fig)"
lucas_miranda's avatar
lucas_miranda committed
1399
   ]
1400
1401
1402
1403
1404
1405
1406
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}