# -*- coding: utf-8 -*-
"""
Created on Wed May  9 14:56:32 2018
Version: 2.8.0
@author: Holger Niemann, Peter Drewelow, Yu Gao

Mainly intended to clean up the downloadversionIRdata code.
Tools for:
    checking IR images,
    calculating gain and offset again from cold and hot images,
    checking background frames,
    checking cold frames,
    ...
"""
import numpy as np
import matplotlib.pyplot as plt
from IR_config_constants import portcamdict,IRCamRefImagespath,IRCAMBadPixels_path
import h5py
from os.path import join, basename
import glob
import datetime

def get_OP_by_time(time_ns=None, shot_no=None, program_str=None):
    '''Derive the operation phase (OP) of W7-X from one of three identifiers.

       The first identifier that is not None is used, in the order
       time_ns, shot_no, program_str.
       IN:
          time_ns      - integer nanosecond time stamp,
                         e.g. 1511972727249834301 (OPTIONAL)
          shot_no      - integer MDSplus style shot number whose first six
                         digits are the date as yymmdd,
                         e.g. 171207022 (OPTIONAL)
          program_str  - string starting with the date as yyyymmdd
                         (CoDaQ ArchiveDB style program number or date),
                         e.g. '20171207.022' or '20171207' (OPTIONAL)
       RETURN:
          "OP1.1"   for dates from December 2015 through March 2016,
          "OP1.2a"  for dates from 2017-08-28 through 2017-12-07,
          "OP1.2b"  for any date in 2018,
          None      for every other date
       RAISES:
          Exception if none of the three identifiers is given
    '''
    # translate whichever identifier was given into a datetime
    if time_ns is not None:
        stamp = datetime.datetime.utcfromtimestamp(time_ns/1e9)
    elif shot_no is not None:
        stamp = datetime.datetime.strptime(str(shot_no)[:6], '%y%m%d')
    elif program_str is not None:
        stamp = datetime.datetime.strptime(program_str[:8], '%Y%m%d')
    else:
        raise Exception('get_OP_by_time: ERROR! neither time, shot no. or program ID provided')

    # tuples compare element by element, which expresses the date windows compactly
    year = stamp.year
    if year == 2018:
        return "OP1.2b"
    if year == 2017 and (8, 28) <= (stamp.month, stamp.day) < (12, 8):
        return "OP1.2a"
    if (2015, 12) <= (year, stamp.month) <= (2016, 3):
        return "OP1.1"
    return None

def bestimmtheitsmass_general(data,fit):
    """Coefficient of determination ("Bestimmtheitsmass") of a fit.

    Computed as E/(E+S), where E is the summed squared deviation of `fit`
    from the mean of `data` and S is the summed squared deviation between
    `data` and `fit`.

    IN:
        data - sequence of measured values
        fit  - sequence of fitted values, same length as data
    RETURN:
        the ratio described above; 0 (after printing a message) if the two
        sequences differ in length
    """
    if len(fit) != len(data):
        print("Arrays must have same dimensions")
        return 0
    data_mean = np.sum(data)/len(data)
    explained = quad_abweich_mittel(fit, data_mean)
    residual = quad_abweich(data, fit)
    return explained/(explained+residual)
    
def quad_abweich_mittel(data,mittel):
    """Sum of the squared deviations of every element of `data` from `mittel`.

    Returns 0 for an empty `data`.
    """
    total = 0
    for value in data:
        deviation = value - mittel
        total = total + deviation**2
    return total
    
def quad_abweich(data,fit):
    """Sum of the squared element-wise differences between `data` and `fit`.

    Returns 0 (after printing a message) if the two sequences differ in
    length.
    """
    if len(fit) != len(data):
        print("Arrays must have same dimensions")
        return 0
    total = 0
    for measured, fitted in zip(data, fit):
        total = total + (measured-fitted)**2
    return total
        
def find_nearest(array,value):
    """Return the index of the element of `array` that is closest to `value`.

    IN:
        array - iterable of numbers
        value - number to look for
    RETURN:
        index (int) of the closest element. If an element above and an
        element below `value` are equally close, the one above `value`
        wins; among several identical candidates the first one is returned.
    RAISES:
        ValueError if `array` is empty (raised by np.min)
    """
    offsets = [x - value for x in array]
    smallest = np.min(np.abs(offsets))
    try:
        # closest element lies at or above the requested value
        return offsets.index(smallest)
    except ValueError:
        # closest element lies below the requested value
        return offsets.index(-smallest)
        
def check_coldframe(coldframe,references=None,threshold=0.5,plot_it=False):
    '''Judge a cold frame by comparing three of its columns with reference profiles.

    The columns at 1/4, 1/2 and 3/4 of the frame width are extracted and
    their mean is subtracted. Each of them is compared with the matching
    reference profile via bestimmtheitsmass_general(). If no references are
    given, a parabola centred on the middle row is built for every column
    from the mean of its (up to) 100 central rows and the mean of its first
    50 rows.

    IN:
        coldframe   - 2D numpy array (n_rows, n_cols)
        references  - sequence of three profiles of length n_rows
                      (OPTIONAL: if None, parabolic profiles are generated)
        threshold   - minimum mean quality factor for a valid frame
                      (OPTIONAL: default 0.5)
        plot_it     - boolean, plot the frame and the compared profiles
                      (OPTIONAL: default False)
    RETURN:
        valid       - True if the mean of the quality factors is larger
                      than threshold, else False
        quality     - list with the quality factor of each of the 3 columns
    '''
    shapi=np.shape(coldframe)
    n_rows, n_cols = shapi[0], shapi[1]
    ##function  (np.arange(0,768)-384)**(2)/900-50
    datasets=[]
    for col in [n_cols//4, n_cols//2, n_cols//4*3]:
        dataline=coldframe[0:n_rows,col]
        datasets.append(dataline-np.mean(dataline))
    if references is None:
        # integer arithmetic is required here: slice bounds must not be floats
        centre = n_rows//2
        # clamp so that frames with fewer than 100 rows do not wrap around
        lower = max(centre-50, 0)
        references=[]
        for dat in datasets:
            mini=np.mean(dat[lower:centre+50])
            a=(np.mean(dat[0:50])-mini)/centre**2
            reference=a*(np.arange(0,n_rows)-centre)**(2)+mini
            references.append(reference)
    bestimmtheit=[]
    if plot_it:
        plt.figure()
        plt.imshow(coldframe,vmin=np.mean(coldframe)-500,vmax=np.mean(coldframe)+500)
        plt.figure()
    for dat, reference in zip(datasets, references):
        bestimmtheit.append(bestimmtheitsmass_general(dat,reference))
        if plot_it:
            plt.plot(dat,label='data')
            plt.plot(reference,label='reference')
            plt.legend()
    return bool(np.mean(bestimmtheit)>threshold),bestimmtheit

def check_coldframe_by_refframe(coldframe,reference_frame,threshold=0.8,plot_it=False):
    """Judge a cold frame against profiles taken from a reference frame.

    Three mean-subtracted columns of `reference_frame` are extracted and
    handed to check_coldframe() as references.

    IN:
        coldframe        - 2D numpy array to be checked
        reference_frame  - 2D numpy array supplying the reference profiles
        threshold        - minimum mean quality factor (OPTIONAL: default 0.8)
        plot_it          - boolean passed on to check_coldframe()
    RETURN:
        the result of check_coldframe(): (valid, list of quality factors)
    """
    frame_shape = np.shape(reference_frame)
    n_rows, n_cols = frame_shape[0], frame_shape[1]
    # NOTE(review): the first reference column is taken at 1/5 of the width,
    # whereas check_coldframe() samples the cold frame at 1/4 - confirm
    # that this offset is intended.
    sample_columns = (int(n_cols//5), int(n_cols//2), int(n_cols//4*3))
    references = []
    for col in sample_columns:
        profile = reference_frame[0:n_rows, col]
        references.append(profile - np.mean(profile))
    return check_coldframe(coldframe, references, threshold, plot_it)
        
def check_backgroundframe(backgroundframe,threshold=50):
    '''Judge a background frame by the signal spread along three columns.

    For the columns at 1/4, 1/2 and 3/4 of the frame width the difference
    between the largest and the smallest mean-subtracted value is computed.
    The frame is rejected if the mean of these three spreads is below
    threshold.

    IN:
        backgroundframe - 2D numpy array (n_rows, n_cols)
        threshold       - minimum mean spread (OPTIONAL: default 50)
    RETURN:
        valid           - False if the mean spread is below threshold,
                          True otherwise
        mean spread     - mean of the three column spreads
    '''
    frame_shape = np.shape(backgroundframe)
    n_rows, n_cols = frame_shape[0], frame_shape[1]
    spreads = []
    for col in (int(n_cols//4), int(n_cols//2), int(n_cols//4*3)):
        column = backgroundframe[0:n_rows, col]
        centred = column - np.mean(column)
        spreads.append(np.max(centred) - np.min(centred))
    is_valid = not (np.mean(spreads) < threshold)
    return is_valid, np.mean(spreads)
    
def read_bad_pixels_from_file(port, shot_no=None, program=None,time_ns=None):
    '''Read the bad pixels of a camera from its *.bpx file.

       The file 'badpixel_<camera>.bpx' is looked up below
       IRCAMBadPixels_path; <camera> is taken from portcamdict for the
       operation phase derived from shot_no, program or time_ns (the first
       one that is not None is used, in this order).
        IN
            port            - port number of the camera (used as 'AEF<port>')
            shot_no         - integer of MDSplus style shot number, e.g. 171207022 (OPTIONAL)
            program         - string of CoDaQ ArchiveDB style program number or date,
                              e.g. '20171207.022' or '20171207' (OPTIONAL)
            time_ns         - integer nanosecond time stamp (OPTIONAL)
        OUT
            bad_pixel_list  - list of tuples built from the second and the
                              first column of the file, in that order
                              (presumably (row, column) - verify against
                              the file format); empty list if the file
                              cannot be read or holds no pixel
        RAISES
            Exception if neither shot_no, program nor time_ns is given
    '''
    if shot_no is not None:
        OP = get_OP_by_time(shot_no=shot_no)
    elif program is not None:
        OP = get_OP_by_time(program_str=program)
    elif time_ns is not None:
        OP = get_OP_by_time(time_ns=time_ns)
    else:
        raise Exception('read_bad_pixels_from_file: ERROR! Need either shot no. or program string.')

    port_name = 'AEF{0}'.format(port)
    bad_pixel_file = 'badpixel_{0}.bpx'.format(portcamdict[OP][port_name][6:])
    try:
        data = np.genfromtxt(IRCAMBadPixels_path+bad_pixel_file, dtype=int)
    except (OSError, ValueError):
        # best effort: a missing or malformed file means "no known bad pixels"
        return []
    # a file with a single line is returned as 1D array; restore the table layout
    data = np.atleast_2d(data)
    if data.size == 0 or data.shape[1] < 2:
        return []
    return list(zip(data[:,1], data[:,0]))
    

def find_outlier_pixels(frame,tolerance=3,worry_about_edges=True,plot_it=False):
    """Find the outlier (bad) pixels of a 2D frame.

    The frame is compared with a 9x9 median-filtered copy of itself. Every
    pixel whose absolute difference to the filtered frame exceeds
    `tolerance` standard deviations of that difference is reported.

    IN:
        frame              - 2D array-like of pixel values
        tolerance          - number of standard deviations used as cutoff
                             (OPTIONAL: default 3)
        worry_about_edges  - currently unused, kept for compatibility
        plot_it            - boolean, plot a histogram of the differences
                             and the frame with the bad pixels marked
                             (OPTIONAL: default False)
    RETURN:
        bad_pixels         - list of tuples (row, column) of integers
    """
    from scipy.ndimage import median_filter
    frame = np.array(frame)
    if frame.dtype.kind in 'ub':
        # unsigned (and boolean) data cannot hold negative differences:
        # 'frame - blurred' would wrap around and distort the statistics
        frame = frame.astype(np.int64)
    blurred = median_filter(frame, size=9)
    difference = frame - blurred
    spread = np.std(difference)
    threshold = tolerance*spread
    mean = np.mean(difference)
    if plot_it:
        fig = plt.figure()
        fig.suptitle('find_outlier_pixels: histogram')
        plt.hist(difference.ravel(),50,log=True,histtype='stepfilled')
        plt.axvline(mean, linewidth=2, color='k',label='mean')
        plt.axvspan(mean-spread, mean+spread, linewidth=2, facecolor='g',alpha=0.1,label='standard deviation')
        plt.axvspan(mean-threshold, mean+threshold, linewidth=2, facecolor='r',alpha=0.1,label='threshold for bad pixel')
        plt.legend()
        plt.show()

    # collect the coordinates of all pixels beyond the threshold
    outliers = np.argwhere(np.abs(difference) > threshold)
    bad_pixels = [tuple(coords) for coords in outliers.tolist()]

    if plot_it:
        plt.figure()
        plt.imshow(frame)
        for row, col in bad_pixels:
            plt.scatter(col, row, c='None')
        plt.show()

    return bad_pixels

def correct_images(images,badpixels):
    """Restore the bad pixels of one or several images with restore_bad_pixels().

    IN:
        images     - list of images (corrected in place, one after another)
                     or any other object accepted by restore_bad_pixels()
                     (result is converted to float32)
        badpixels  - array in which bad pixels are marked with 1;
                     if an int is given, no correction is done
    RETURN:
        images     - the corrected images (the same list object for list
                     input), or the untouched input if badpixels is an int
    """
    print('correct_images: New routine restore_bad_pixels() is used and can be called directly. Check out "help(restore_bad_pixels)"')
    if type(badpixels) == int:
        return images
    # restore_bad_pixels() expects a mask with good=True, bad=False
    good_pixel_mask = np.invert(badpixels==1)
    if type(images) == list:
        # keep the list container and replace its entries one by one
        for idx, image in enumerate(images):
            images[idx] = restore_bad_pixels(image, good_pixel_mask)
    else:
        # keep shape
        images = restore_bad_pixels(images, good_pixel_mask).astype(np.float32)
    print("done")
    return images


def restore_bad_pixels(frames, bad_pixel, by_list=True, check_neighbours=True, plot_it=False, verbose=0):
    """Restore bad pixels by averaging over their adjacent pixels.

       Each bad pixel is replaced by the sum of its left, right, top and
       bottom neighbour divided by the number of neighbours inside the frame
       (4, 3 at an edge, 2 in a corner). Default is to use a list of bad
       pixels and a for loop. For many bad pixels consider the alternative
       using a bad pixel mask and array operations.
        IN:
            frames              - either list of frames as 2D numpy array,
                                  or 3D numpy array (frame number, n_rows, n_cols),
                                  or 2D numpy array (n_rows, n_cols)
            bad_pixel           - either list of tuples (row, column) of bad pixels,
                                  or mask of pixel status (good=True, bad=False)
                                  of shape (n_rows, n_cols)
            by_list             - boolean of whether to use a list and a for loop (True),
                                  or to use a mask of bad pixel and array operations (False)
                                  (OPTIONAL: if not provided, True (list) is default)
            check_neighbours    - boolean of whether to skip bad neighbours and use
                                  the nearest pixel in that direction that is not
                                  marked bad instead (works only in list mode!).
                                  If only bad pixels lie between the pixel and
                                  the frame border, 0 is used for that direction.
                                  (OPTIONAL: if not provided, check is on)
            plot_it             - boolean to decide whether to plot the first
                                  uncorrected frame with the bad pixels marked
                                  (OPTIONAL: if not provided, switched off)
            verbose             - integer of feedback level (amount of prints)
                                  (OPTIONAL: if not provided, only ERROR output)
        RETURN:
            frames              - corrected frames in the layout of the input:
                                  list of 2D arrays, 3D array or 2D array.
                                  The input frames are not modified.
        RAISES:
            Exception if frames is neither 2D nor 3D or if the bad pixel
            mask does not have the shape of a frame
        NOTE:
            List entries that are not a (row, column) pair inside the frame
            are skipped with a warning.
    """
    import warnings  # local import, only needed by this routine

    # make sure frames is a 3D array and remember the layout of the input
    if isinstance(frames, list):
        frames = np.array(frames)
        frame_shape = 'list'
    elif len(np.shape(frames)) == 2:
        frames = np.array([frames])
        frame_shape = '2D'
    elif len(np.shape(frames)) == 3:
        frame_shape = '3D'
    else:
        raise Exception('restore_bad_pixels: ERROR! Unexpected shape of frames.')
    frame_dtype = frames.dtype
    n_frames, n_rows, n_cols = np.shape(frames)
    if plot_it:
        start_frame = np.copy(frames[0])

    # make sure bad pixels are available as list and as mask (True = bad)
    if isinstance(bad_pixel, list):
        blist = []
        bmask = np.zeros([n_rows, n_cols], dtype=bool)
        for pix in bad_pixel:
            try:
                row, col = pix
                if not (0 <= row < n_rows and 0 <= col < n_cols):
                    raise IndexError('outside of the {0}x{1} frame'.format(n_rows, n_cols))
                bmask[row, col] = True
            except (IndexError, TypeError, ValueError) as err:
                # keep going with the remaining pixels, but tell the user
                warnings.warn('restore_bad_pixels: skipping bad pixel {0} ({1})'.format(pix, err))
                continue
            blist.append((row, col))
    else:
        if np.shape(bad_pixel) != (n_rows, n_cols):
            raise Exception('restore_bad_pixels: ERROR! bad_pixel in bad shape {0}'.format(np.shape(bad_pixel)))
        # logical (not bitwise) negation, so that 0/1 integer masks work as well
        bmask = np.logical_not(bad_pixel)
        blist = list(zip(*np.where(bmask)))

    if verbose > 0:
        print('restore_bad_pixels: {0} bad pixels to be restored: {1} ... '.format(len(blist), blist[:3]))

    # surround frames and mask by one row/column of zeros (= "not bad")
    # to simplify the treatment of the edges
    frames = np.pad(frames, ((0, 0), (1, 1), (1, 1)), mode='constant')
    bmask = np.pad(bmask, 1, mode='constant')

    # number of neighbours inside the frame (up to 4) for every pixel of the
    # expanded frame; identical for all frames, hence only 2D
    n_neighbours = np.full([n_rows+2, n_cols+2], 4.0)
    n_neighbours[1,:] = 3
    n_neighbours[-2,:] = 3
    n_neighbours[:,1] = 3
    n_neighbours[:,-2] = 3
    n_neighbours[1,1] = 2
    n_neighbours[1,-2] = 2
    n_neighbours[-2,1] = 2
    n_neighbours[-2,-2] = 2

    if by_list:
        # ===== correct bad pixels using the list of bad pixels =====
        for row, col in blist:
            # blist holds real frame coordinates, the arrays are expanded by 1
            erow, ecol = row+1, col+1
            if check_neighbours:
                # nearest pixel in each direction that is not marked bad;
                # the padding is never bad, so each search finds a pixel
                pos_l = np.flatnonzero(~bmask[erow,:ecol])[-1]
                pos_r = np.flatnonzero(~bmask[erow,ecol:])[0] + ecol
                pos_t = np.flatnonzero(~bmask[:erow,ecol])[-1]
                pos_b = np.flatnonzero(~bmask[erow:,ecol])[0] + erow
            else:
                # insensitive to neighbours being bad as well!
                pos_l, pos_r = ecol-1, ecol+1
                pos_t, pos_b = erow-1, erow+1
            average = (frames[:,erow,pos_l].astype(float) +
                       frames[:,erow,pos_r].astype(float) +
                       frames[:,pos_t,ecol].astype(float) +
                       frames[:,pos_b,ecol].astype(float)) / n_neighbours[erow,ecol]
            frames[:,erow,ecol] = average.astype(frame_dtype)
    else:
        # ======= correct bad pixels using the bad pixel mask =======
        # (insensitive to neighbours being bad as well!)

        # masks selecting the left, right, top and bottom neighbour of every
        # bad pixel, obtained by shifting the bad pixel mask by one pixel
        bmask_l = np.hstack([bmask[:,1:], np.zeros([n_rows+2,1], dtype=bool)])
        bmask_r = np.hstack([np.zeros([n_rows+2,1], dtype=bool), bmask[:,:-1]])
        bmask_t = np.vstack([bmask[1:,:], np.zeros([1,n_cols+2], dtype=bool)])
        bmask_b = np.vstack([np.zeros([1,n_cols+2], dtype=bool), bmask[:-1,:]])

        frames[:,bmask] = ( (frames[:,bmask_l].astype(float) +
                             frames[:,bmask_r].astype(float) +
                             frames[:,bmask_t].astype(float) +
                             frames[:,bmask_b].astype(float)) / n_neighbours[bmask] ).astype(frame_dtype)

    # remove the padding again
    frames = frames[:,1:-1,1:-1]

    # plot the uncorrected first frame with the bad pixels marked
    if plot_it:
        plt.figure()
        plt.title('bad pixel correction of first frame')
        m = np.mean(start_frame)
        s = np.std(start_frame)
        plt.imshow(start_frame, vmin=m-s, vmax=m+s)
        plt.colorbar()
        if blist:
            x,y = zip(*blist)
            plt.scatter(y,x, marker='o', s=5, c='r', linewidths=1)
        plt.tight_layout()
        plt.show()

    # hand the frames back in the layout they came in
    if frame_shape == 'list':
        frames = list(frames)
    elif frame_shape == '2D' and len(np.shape(frames))==3:
        frames = frames[0]

    return frames


def generate_new_hot_image(cold,reference_cold,reference_hot):
    """Build a hot image by shifting the reference hot image by the
    difference between `cold` and `reference_cold`.

    RETURN:
        reference_hot + (cold - reference_cold)
    RAISES:
        Exception if any of the three images is None
    """
    if any(image is None for image in (cold, reference_cold, reference_hot)):
        raise Exception("Cannot Calculate new Hot image, if images are missing!")
    cold_shift = cold - reference_cold
    return reference_hot + cold_shift
    
def calculate_gain_offset_image_pix(cold_image,hot_image=None,reference_cold=None,reference_hot=None,bose=1):
    """Calculate relative gain and offset images, referenced to the central pixel.

    The gain of each pixel is the hot-cold difference of the central pixel
    divided by the hot-cold difference of that pixel. The offset is the mean
    of the two offsets that map the hot and the cold image onto the central
    pixel values. Pixels with identical hot and cold value yield a
    non-finite gain; calculate_gain_offset_image() sets the gain of such
    pixels to 0 instead.

    IN:
        cold_image      - 2D array of the cold frame
        hot_image       - 2D array of the hot frame (OPTIONAL: if None, it is
                          generated by generate_new_hot_image() from
                          cold_image, reference_cold and reference_hot)
        reference_cold  - reference cold frame, only used if hot_image is None
        reference_hot   - reference hot frame, only used if hot_image is None
        bose            - integer of feedback level (prints if > 0)
    RETURN:
        Gain_rel, Offset_rel - 2D arrays of relative gain and offset
    """
    if hot_image is None:
        hot_image=generate_new_hot_image(cold_image,reference_cold,reference_hot)
    if bose>0:
        print("calculate gain and offset")
    hot_image = np.asarray(hot_image)
    cold_image = np.asarray(cold_image)
    # integer (in particular unsigned) data would wrap around in differences
    if hot_image.dtype.kind in 'iub':
        hot_image = hot_image.astype(float)
    if cold_image.dtype.kind in 'iub':
        cold_image = cold_image.astype(float)
    # values of the central pixel serve as reference
    # (plain integer division replaces np.int, which NumPy has removed)
    Sh_ref = hot_image[hot_image.shape[0]//2][hot_image.shape[1]//2]
    Sc_ref = cold_image[cold_image.shape[0]//2][cold_image.shape[1]//2]
    Gain_rel = (Sh_ref - Sc_ref) / (hot_image - cold_image)
    Off_h_rel = Sh_ref - hot_image*Gain_rel
    Off_c_rel = Sc_ref - cold_image*Gain_rel
    Offset_rel = (Off_h_rel + Off_c_rel) / 2
    return Gain_rel,Offset_rel

def _central_patch_mean(image):
    """Mean of the (up to) 5x5 pixels around the centre of a 2D array."""
    row_c = image.shape[0]//2
    col_c = image.shape[1]//2
    # clamp the lower bounds so that tiny images do not wrap around
    return np.mean(image[max(row_c-2, 0):row_c+3, max(col_c-2, 0):col_c+3])


def calculate_gain_offset_image(cold_image,hot_image=None,reference_cold=None,reference_hot=None,verbose=0):
    """Calculate relative gain and offset images, referenced to the image centre.

    The reference values are the means of the central 5x5 pixels of the hot
    and of the cold image. The gain of each pixel is the difference of the
    reference values divided by the hot-cold difference of that pixel;
    pixels with identical hot and cold value get the gain 0. The offset is
    the mean of the two offsets that map the hot and the cold image onto
    the reference values.

    IN:
        cold_image      - 2D array of the cold frame
        hot_image       - 2D array of the hot frame (OPTIONAL: if None, it is
                          generated by generate_new_hot_image() from
                          cold_image, reference_cold and reference_hot)
        reference_cold  - reference cold frame, only used if hot_image is None
        reference_hot   - reference hot frame, only used if hot_image is None
        verbose         - integer of feedback level (prints if > 0)
    RETURN:
        Gain_rel, Offset_rel - 2D arrays of relative gain and offset
    """
    if hot_image is None:
        hot_image=generate_new_hot_image(cold_image,reference_cold,reference_hot)
    if verbose>0:
        print("calculate gain and offset")

    hot_image = np.asarray(hot_image)
    cold_image = np.asarray(cold_image)
    # integer (in particular unsigned) data would wrap around in differences
    if hot_image.dtype.kind in 'iub':
        hot_image = hot_image.astype(float)
    if cold_image.dtype.kind in 'iub':
        cold_image = cold_image.astype(float)

    Sh_ref = _central_patch_mean(hot_image)
    Sc_ref = _central_patch_mean(cold_image)
    difference_image = hot_image - cold_image
    # avoid a division by zero: use a dummy divisor there and reset the gain afterwards
    no_contrast = difference_image == 0
    difference_image[no_contrast] = 1
    Gain_rel = (Sh_ref - Sc_ref) / difference_image
    Gain_rel[no_contrast] = 0
    Off_h_rel = Sh_ref - hot_image*Gain_rel
    Off_c_rel = Sc_ref - cold_image*Gain_rel
    Offset_rel = (Off_h_rel + Off_c_rel) / 2
    return Gain_rel,Offset_rel
    
#%% functions from Yu Gao
""" functions by Yu Gao"""
    
def load_ref_images(port, exposuretime):
    '''
    Load the reference cold and hot frame of a camera from local files.

    The files are searched in the folder
    IRCamRefImagespath/<first>_<third part of the camera name>, where the
    camera name is taken from portcamdict['OP1.2a'] and split at '_'. All
    files ending on '<exposuretime>us.h5' are considered; a file with 'hot'
    in its name is read as hot frame, otherwise one with 'cold' in its name
    as cold frame. Each frame is read from the dataset named like the file.

    @port: e.g. 'AEF10'
    @exposuretime: exposure time as used in the file name, converted to int.
    @return: coldref, hotref; each stays an empty list if no matching file
             is found.
    '''
    cameraname = portcamdict['OP1.2a'][port]
    name_parts = cameraname.split('_')
    foldername = name_parts[0] + '_' + name_parts[2]
    scanpath = join(IRCamRefImagespath, foldername)
    # join() inserts the separator of the platform (the former '\*' was Windows only)
    pattern = join(scanpath, '*' + str(int(exposuretime)) + 'us.h5')
    coldref, hotref = [], []
    for filename in glob.iglob(pattern):
        # decide by the file name only, not by the folders of the path
        dataset_name = basename(filename)
        if 'hot' in dataset_name:
            print (filename)
            with h5py.File(filename, 'r') as h5in:
                # [()] reads the complete dataset (Dataset.value no longer exists in h5py 3)
                hotref = h5in[dataset_name][()]
        elif 'cold' in dataset_name:
            print (filename)
            with h5py.File(filename, 'r') as h5in:
                coldref = h5in[dataset_name][()]
    return coldref, hotref

def reconstruct_coldframe (exposuretime, sT, a, bnew, coldref):
    """Rebuild a cold frame as  a*sT + bnew*exposuretime + coldref.

    NOTE(review): the physical meaning of sT, a and bnew is not visible
    here - presumably calibration coefficients; confirm with the caller.
    """
    scaled_sT = a * sT
    scaled_exposure = bnew * exposuretime
    return scaled_sT + scaled_exposure + coldref
  
  
#%% other functions
def check_dublicates(array):
    """Return the items that occur more than once in `array`.

    Every duplicated item is listed once; the items have to be hashable.
    """
    from collections import Counter
    occurrences = Counter(array)
    repeated = []
    for item in occurrences:
        if occurrences[item] > 1:
            repeated.append(item)
    return repeated
    
def check_dublicates_2(array):
    """Remove the duplicates from `array`.

    RETURN:
        uniq - list of the distinct items in order of first appearance
        seen - set of the distinct items
    The items have to be hashable.
    """
    # a dict keeps one key per distinct item, in insertion order
    uniq = list(dict.fromkeys(array))
    seen = set(uniq)
    return uniq,seen