Commit ff826e67 authored by Marcel Henrik Schubert
Browse files

added feature heatmap for gender

parent 783fa4d7
......@@ -111,7 +111,7 @@ def plot_confusion_matrix(cm, classes,normalize=True,title='Confusion matrix',cm
im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
#plt.title(title)
cbar = ax.figure.colorbar(im, ax = ax, ticks=[], label = "Heat per Row (Normalized from 0 to 1)")
cbar = ax.figure.colorbar(im, ax = ax, ticks=[], label = "Heat per Row (Normalized)")
tick_marks = np.arange(len(classes))
#ax.set_xticks(tick_marks)
#ax.set_yticks(tick_marks)
......@@ -344,7 +344,9 @@ def plotter(subsets, subana, phases, labels):
if key_len > 2:
res_dic[st][an][label]['labels'][key]['feature_vec'] = coef[key_enc]
elif label == 'gender':
res_dic[st][an][label]['labels'][key]['feature_vec'] = coef[0]
elif key_enc > 0:
res_dic[st][an][label]['labels'][key]['feature_vec'] = coef[0]
......@@ -462,7 +464,7 @@ def plotter(subsets, subana, phases, labels):
else:
most_pred[st][an][label]['feature_vecs'].append(list(np.argsort(res_dic[st][an][label]['labels']['male']['feature_vec'])))
most_pred[st][an][label]['feature_vecs'] = res_dic[st][an][label]['labels']['male']['feature_vec']
#sys.exit(1)
......@@ -545,20 +547,6 @@ def plotter(subsets, subana, phases, labels):
for vec in most_pred[st][an][label][sex]['feature_vecs_neg']:
c.update(list(vec[-25:]))
most_pred[st][an][label][sex]['count_top25_neg'] = c.most_common()
for st in subsets:
for ana in subana:
......@@ -733,8 +721,8 @@ def plotter(subsets, subana, phases, labels):
life_ph_pred_auth,
labels=phases)
f = plt.figure()
ax = f.subplots()
plot_confusion_matrix(cnf_matrix, classes=phases,title=None, ax=ax)
plt.show()
plt.tight_layout()
......@@ -891,7 +879,32 @@ def plotter(subsets, subana, phases, labels):
dic = {}
ind_pos = []
ind_neg = []
if label == 'age':
if label == 'gender':
norm = most_pred[st][an][label]['feature_vecs']/np.std(most_pred[st][an][label]['feature_vecs'])
ind_m = np.argsort(-norm)[:10]
ind_f = np.argsort(norm)[:10]
ind = [int(el) for el in list(ind_m)+ list(ind_f)]
coef = most_pred[st][an][label]['feature_vecs'][ind]
coef = [float(el) for el in list(coef)]
ind = [vocab[el].replace('§', '') for el in ind]
ind = [re.sub(r'\s', 'BLANK', el) for el in ind]
ind = [el.replace('$', r'\$') for el in ind]
ind = [el.replace('\n', 'BREAK') for el in ind]
#print(ind)
for i in range(0, len(ind)):
try:
ind[i].encode('ascii')
except:
ind[i] = ind[i].encode('unicode-escape')
df = pd.DataFrame({'gender (positive: male)': coef} , index = ind)
f,ax = plt.subplots(figsize=(18, len(ind)/10))
sns.heatmap(df, fmt= '.1f',ax=ax, center = 0, yticklabels = True)
f.savefig(savedir+'featureplots/features_heat_{}_{}_{}.pdf'.format(st, an, label))
f.savefig(savedir+'featureplots/features_heat_{}_{}_{}.png'.format(st, an, label))
elif label == 'age':
phase_key = []
for ph in phases:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment