Commit 9b1dd4a5 authored by Marcel Henrik Schubert's avatar Marcel Henrik Schubert
Browse files

added canvas

parent 3b1463e3
......@@ -27,6 +27,8 @@ from sklearn.linear_model import PassiveAggressiveClassifier, SGDClassifier
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import matplotlib
matplotlib.use('agg')
import matplotlib.colors as cl
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
......@@ -54,6 +56,11 @@ labels_l = ['age', 'gender', 'author']
phases_l = ['child_21', 'young_adult_35', 'adult_50', 'old_adult_65', 'retiree']
#get_ipython().run_line_magic('matplotlib', 'inline')
#import matplotlib
cmap = cl.LinearSegmentedColormap.from_list("", ["skyblue","cadetblue","darkblue", "steelblue"]) #define colors
# In[3]:
......@@ -80,7 +87,7 @@ phases_l = ['child_21', 'young_adult_35', 'adult_50', 'old_adult_65', 'retiree']
def identity_tokenizer(text):
return text
def plot_confusion_matrix(cm, classes,normalize=True,title='Confusion matrix',cmap=plt.cm.Blues, ax = None):
def plot_confusion_matrix(cm, classes,normalize=True,title='Confusion matrix',cmap=plt.cm.Blues, ax = None, fig = None):
"""
This function prints and plots the confusion matrix.
Normalization can be applied by setting `normalize=True`.
......@@ -90,7 +97,6 @@ def plot_confusion_matrix(cm, classes,normalize=True,title='Confusion matrix',cm
cm_old = cm
print(cm_old)
if normalize:
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
print("Normalized confusion matrix")
......@@ -100,7 +106,7 @@ def plot_confusion_matrix(cm, classes,normalize=True,title='Confusion matrix',cm
im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
plt.title(title)
#plt.title(title)
cbar = ax.figure.colorbar(im, ax = ax, ticks=[], label = "Heat per Row (Normalized from 0 to 1)")
tick_marks = np.arange(len(classes))
#ax.set_xticks(tick_marks)
......@@ -131,7 +137,7 @@ def plot_confusion_matrix(cm, classes,normalize=True,title='Confusion matrix',cm
color="white" if cm[i, j] > thresh else "black")
plt.tight_layout()
return fig, ax
# Start with analysis of precision and recall as well as heatmaps/confusion matrices:
......@@ -542,27 +548,12 @@ def plotter(subsets, subana, phases, labels):
print(len(most_pred[200]['org']['author']['retiree']['feature_vecs_pos']))
pos = most_pred[200]['org']['author']['retiree']['count_tot_pos'][0:26]
arr = most_pred[200]['org']['author']['retiree']['val_array']
maxi = most_pred[200]['org']['author']['retiree']['max_array']
print(pos)
print([(el[0], arr[el[0]], maxi[el[0]]) for el in pos])
print(len(most_pred[200]['org']['author']['retiree']['feature_vecs_neg']))
neg = most_pred[200]['org']['author']['retiree']['count_tot_neg'][0:26]
mini = most_pred[200]['org']['author']['retiree']['min_array']
print(neg)
print([(el[0], arr[el[0]], mini[el[0]]) for el in neg])
arr
#get_ipython().run_line_magic('matplotlib', 'inline')
#import matplotlib
cmap = cl.LinearSegmentedColormap.from_list("", ["skyblue","cadetblue","darkblue", "steelblue"]) #define colors
......@@ -613,14 +604,15 @@ def plotter(subsets, subana, phases, labels):
labels=phases)
f = plt.figure()
ax = f.subplots()
plot_confusion_matrix(cnf_matrix, classes=phases,title=None, ax=ax)
plt.show()
f, ax plot_confusion_matrix(cnf_matrix, classes=phases,title=None, ax=ax, plt = f)
plt.tight_layout()
#plt.show()
#f.savefig('../Data/results/heatmaps/test.png')
f.savefig(savedir+ 'heatmaps/cm_{st}_{an}_{label}_{group}.jpg'.format(st = st,
f.savefig(savedir+ 'heatmaps/cm_{st}_{an}_{label}_{group}.png'.format(st = st,
an=an,
label=label,
group='life_phase'))
plt.close()
plt.close(fig = f)
......@@ -662,8 +654,8 @@ def plotter(subsets, subana, phases, labels):
life_ph_pred_auth,
labels=phases)
f = plt.figure()
plot_confusion_matrix(cnf_matrix, classes=phases,title=None)
plt.show()
f, ax = plot_confusion_matrix(cnf_matrix, classes=phases,title=None)
plt.tight_layout()
f.savefig(savedir+'heatmaps/cm_{st}_{an}_{label}_{group}.png'.format(st = st, an=an,
label=label,
group='life_phase'))
......@@ -688,8 +680,8 @@ def plotter(subsets, subana, phases, labels):
gen_pred_auth,
labels=['female', 'male'])
f = plt.figure()
plot_confusion_matrix(cnf_matrix, classes=['female', 'male'],title=None)
plt.show()
f, ax = plot_confusion_matrix(cnf_matrix, classes=['female', 'male'],title=None)
plt.tight_layout()
f.savefig(savedir+'heatmaps/cm_{st}_{an}_{label}_{group}.png'.format(st = st, an=an,
label=label,
group='gender'))
......@@ -700,8 +692,8 @@ def plotter(subsets, subana, phases, labels):
life_ph_pred_auth_wrong,
labels=phases)
f = plt.figure()
plot_confusion_matrix(cnf_matrix, classes=phases,title=None)
plt.show()
f, ax = plot_confusion_matrix(cnf_matrix, classes=phases,title=None)
plt.tight_layout()
f.savefig(savedir+'heatmaps/cm_{st}_{an}_{label}_{group}_false.png'.format(st = st, an=an,
label=label,
group='life_phase'))
......@@ -711,8 +703,8 @@ def plotter(subsets, subana, phases, labels):
gen_pred_auth_wrong,
labels=['female', 'male'])
f = plt.figure()
plot_confusion_matrix(cnf_matrix, classes=['female', 'male'],title=None)
plt.show()
f, ax = plot_confusion_matrix(cnf_matrix, classes=['female', 'male'],title=None)
plt.tight_layout()
f.savefig(savedir+'heatmaps/cm_{st}_{an}_{label}_{group}_false.png'.format(st = st, an=an,
label=label,
group='gender'))
......@@ -780,7 +772,7 @@ def plotter(subsets, subana, phases, labels):
#print([vocab[el] for el in ind_pos + ind_neg])
df = pd.DataFrame(tmp, index = [vocab[el].replace('§', '').encode('unicode-escape') for el in ind_pos + ind_neg])
f,ax = plt.subplots(figsize=(18, len(ind_pos+ind_neg)/6))
sns.heatmap(df, fmt= '.1f',ax=ax, center = 0, yticklabels = True)
ax = sns.heatmap(df, fmt= '.1f',ax=ax, center = 0, yticklabels = True)
f.savefig(savedir+'featureplots/features_heat_{}_{}_{}_phases.png'.format(st, an, label))
elif label == 'author':
......@@ -827,7 +819,7 @@ def plotter(subsets, subana, phases, labels):
#print([vocab[el] for el in ind_pos + ind_neg])
df = pd.DataFrame(tmp, index = [vocab[el].replace('§', '').encode('unicode-escape') for el in ind])
f,ax = plt.subplots(figsize=(18, len(ind_pos+ind_neg)/6))
sns.heatmap(df, fmt= '.1f',ax=ax, center = 0, yticklabels = True)
ax = sns.heatmap(df, fmt= '.1f',ax=ax, center = 0, yticklabels = True)
f.savefig(savedir+'featureplots/features_heat_{}_{}_{}_phases.png'.format(st, an, label))
dic = {}
......@@ -861,7 +853,7 @@ def plotter(subsets, subana, phases, labels):
ind = ind_pos+ind_neg
df = pd.DataFrame(tmp, index = [vocab[el].replace('§', '').encode('unicode-escape') for el in ind])
f,ax = plt.subplots(figsize=(18, len(ind_pos+ind_neg)/6))
sns.heatmap(df, fmt= '.1f',ax=ax, center = 0, yticklabels = True)
ax = sns.heatmap(df, fmt= '.1f',ax=ax, center = 0, yticklabels = True)
f.savefig(savedir+'featureplots/features_heat_{}_{}_{}_gender_phases.png'.format(st, an, label))
return 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment