diff --git a/train_classifier_rf.py b/train_classifier_rf.py index f0e506b5ef06911c375974a91fce9baba376e560..ca818968602637adad7fa3abed8cc11a34b742c8 100644 --- a/train_classifier_rf.py +++ b/train_classifier_rf.py @@ -16,8 +16,6 @@ from astropy import units as u from matplotlib import pyplot, colors -from MyFunctions import GetHist - def info_message(text, prefix='info'): """ @@ -35,7 +33,19 @@ def info_message(text, prefix='info'): date_str = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S") print(f"({prefix:s}) {date_str:s}: {text:s}") - + + +def GetHist(data, bins=30, range=None, weights=None): + hs, edges = scipy.histogram(data, bins=bins, range=range, weights=weights) + loc = (edges[1:] + edges[:-1]) / 2 + + hist = {} + hist['Hist'] = hs + hist['X'] = loc + hist['XEdges'] = edges + + return hist + def evaluate_performance(data, class0_name='event_class_0'): data = data.dropna() diff --git a/train_energy_rf.py b/train_energy_rf.py index edd6d12c1eda664087591837f5ed0800506d3ba8..5a91924364fde6fc0fe98ee6d133b6df4a44247d 100644 --- a/train_energy_rf.py +++ b/train_energy_rf.py @@ -24,8 +24,6 @@ from astropy import units as u from matplotlib import pyplot, colors -from MyFunctions import GetHist2D - def info_message(text, prefix='info'): """ @@ -43,7 +41,26 @@ def info_message(text, prefix='info'): date_str = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S") print(f"({prefix:s}) {date_str:s}: {text:s}") - + + +def GetHist2D(x,y, bins=30, range=None, weights=None): + hs, xedges, yedges = scipy.histogram2d(x,y, bins=bins, range=range, weights=weights) + xloc = (xedges[1:] + xedges[:-1]) / 2 + yloc = (yedges[1:] + yedges[:-1]) / 2 + + xxloc, yyloc = scipy.meshgrid( xloc, yloc, indexing='ij' ) + + hist = {} + hist['Hist'] = hs + hist['X'] = xloc + hist['Y'] = yloc + hist['XX'] = xxloc + hist['YY'] = yyloc + hist['XEdges'] = xedges + hist['YEdges'] = yedges + + return hist + def evaluate_performance(data, energy_name): valid_data = data.dropna(subset=[energy_name])