From 6002b92bacbf38c7ef16373a0c39279587e8b501 Mon Sep 17 00:00:00 2001
From: Berk Onat <b.onat@warwick.ac.uk>
Date: Fri, 30 Mar 2018 20:00:54 +0100
Subject: [PATCH] Added topology converter for Tinker files at MDDataAccess and
 replay settings to SmartParserCommon

---
 .../nomadcore/md_data_access/MDDataAccess.py  | 193 ++++++++++++------
 1 file changed, 131 insertions(+), 62 deletions(-)

diff --git a/common/python/nomadcore/md_data_access/MDDataAccess.py b/common/python/nomadcore/md_data_access/MDDataAccess.py
index 4e2eb98..97548a4 100644
--- a/common/python/nomadcore/md_data_access/MDDataAccess.py
+++ b/common/python/nomadcore/md_data_access/MDDataAccess.py
@@ -337,6 +337,8 @@ def pmdConvertTopoDict(topoStor, topoPMD):
     topo=None
     if(isinstance(topoPMD, pmd.charmm.CharmmParameterSet) or
        isinstance(topoPMD, pmd.charmm.CharmmPsfFile) or
+       isinstance(topoPMD, pmd.tinker.tinkerfiles.DynFile) or
+       isinstance(topoPMD, pmd.tinker.tinkerfiles.XyzFile) or
        (topoPMD is None and topoStor.charmmcoor is not None) or 
        topoStor.charmmcoortopo is not None):
         topo = topoPMD
@@ -446,16 +448,29 @@ def pmdConvertTopoDict(topoStor, topoPMD):
         nbfixList = None
         # If topo is a PsfFile, the parameter files were not supplied.
         # Hence, we can only extract topo info from Psf class in topo.
-        if isinstance(topoPMD, pmd.charmm.CharmmPsfFile):
-            structure = [[getX(x,'name'), getX(x,'type'), getX(x,'residue','name'), 
-                          getX(x,'residue','_idx'), getX(x,'residue','segid'), 
-                          getX(x,'residue','chain'), getX(x,'altloc'), 
-                          getX(x,'irotat'), getX(x,'occupancy'), getX(x,'bfactor'), 
-                          getX(x,'mass'), getX(x,'_charge'), getX(x,'solvent_radius'), 
-                          getX(x,'atomic_number'), getX(x,'tree'), getX(x,'atom_type','name'), 
-                          getX(x,'atom_type','epsilon'), getX(x,'atom_type','rmin'), 
-                          getX(x,'atom_type','epsilon_14'), getX(x,'atom_type','rmin_14'), 
-                          getX(x,'atom_type','nbfix')] for x in topoPMD.atoms]
+        if(isinstance(topoPMD, pmd.charmm.CharmmPsfFile) or
+           isinstance(topoPMD, pmd.tinker.tinkerfiles.XyzFile)):
+            unassignedAtomType = False
+            try:
+                structure = [[getX(x,'name'), getX(x,'type'), getX(x,'residue','name'), 
+                              getX(x,'residue','_idx'), getX(x,'residue','segid'), 
+                              getX(x,'residue','chain'), getX(x,'altloc'), 
+                              getX(x,'irotat'), getX(x,'occupancy'), getX(x,'bfactor'), 
+                              getX(x,'mass'), getX(x,'_charge'), getX(x,'solvent_radius'), 
+                              getX(x,'atomic_number'), getX(x,'tree'), getX(x,'atom_type','name'), 
+                              getX(x,'atom_type','epsilon'), getX(x,'atom_type','rmin'), 
+                              getX(x,'atom_type','epsilon_14'), getX(x,'atom_type','rmin_14'), 
+                              getX(x,'atom_type','nbfix')] for x in topoPMD.atoms]
+            except AttributeError:
+                unassignedAtomType = True
+                structure = [[getX(x,'name'), getX(x,'type'), getX(x,'residue','name'),
+                              getX(x,'residue','_idx'), getX(x,'residue','segid'),
+                              getX(x,'residue','chain'), getX(x,'altloc'),
+                              getX(x,'irotat'), getX(x,'occupancy'), getX(x,'bfactor'), getX(x,'mass'),
+                              getX(x,'_charge'), getX(x,'solvent_radius'), getX(x,'atomic_number'),
+                              getX(x,'tree'), '', None, None, None, None,
+                              None] for x in topoPMD.atoms]
+
             chainList = [a[5] for a in structure]
             segmentList = [a[4] for a in structure]
             residList = [a[3] for a in structure]
@@ -726,19 +741,23 @@ def pmdConvertTopoDict(topoStor, topoPMD):
         chargesList = list(chargeDict.values())
         bfactorList = list(bfactorDict.values())
 
-        if isinstance(topoPMD, pmd.charmm.CharmmPsfFile):
-            if topoPMD.bonds[0].type is not None:
-                for bi, bondtype in enumerate(topoPMD.bonds[0].type.list):
-                    topoPMD.bonds[0].type.list[bi]._idx=bi
-                topo_bond_list = np.asarray([
-                    [x.atom1._idx, x.atom2._idx, x.type._idx, 
-                     x.type.k, x.type.req] for x in topoPMD.bonds
-                    ])
+        if(isinstance(topoPMD, pmd.charmm.CharmmPsfFile) or
+           isinstance(topoPMD, pmd.tinker.tinkerfiles.XyzFile)):
+            if getattr(topoPMD, 'bonds', None):
+                if topoPMD.bonds[0].type is not None:
+                    for bi, bondtype in enumerate(topoPMD.bonds[0].type.list):
+                        topoPMD.bonds[0].type.list[bi]._idx=bi
+                    topo_bond_list = np.asarray([
+                        [x.atom1._idx, x.atom2._idx, x.type._idx, 
+                         x.type.k, x.type.req] for x in topoPMD.bonds
+                        ])
+                else:
+                    topo_bond_list = np.asarray([
+                        [x.atom1._idx, x.atom2._idx] for x in topoPMD.bonds
+                        ])
+                topbList = topo_bond_list[:,0:2]
             else:
-                topo_bond_list = np.asarray([
-                    [x.atom1._idx, x.atom2._idx] for x in topoPMD.bonds
-                    ])
-            topbList = topo_bond_list[:,0:2]
+                topbList = []
         elif(isinstance(topoPMD, pmd.charmm.CharmmParameterSet) and
             (topoStor.charmmcoor is not None or 
              topoStor.charmmcoortopo is not None)):
@@ -1410,7 +1429,10 @@ def mdtConvertTopoDict(topoStor, topoMDT):
                     radiusDict.update({atom.name : atom.element.radius})
 
         massesList = list(massesDict.values())
-        atom_type_list = list(topologyDict["name"].values())
+        try:
+            atom_type_list = list(topologyDict["name"].values())
+        except AttributeError:
+            atom_type_list = topologyDict["name"]
         atom_element_list = [atom.element.symbol for atom in topo.atoms]
         elementList = list(elementDict.values())
         radiusList = list(radiusDict.values())
@@ -2213,10 +2235,10 @@ class MDDataAccess(object):
         if self.topofile:
             topofilename = os.path.basename(self.topofile)
         if self.topoformat is None and topofilename is not None:
-            file_format = self.get_file_format(topofilename, self.topoformat)
-            self.topoformat = file_format
+            fileloadformat = self.get_file_format(topofilename, self.topoformat)
+            self.topoformat = fileloadformat
         else:
-            file_format = self.topoformat
+            fileloadformat = self.topoformat
 
         usedefault=True
         # Use the given order to check topology
@@ -2239,7 +2261,7 @@ class MDDataAccess(object):
                 if "charmmcoor" in interface:
                     topohandler_check = None
                     charmmcoor_dict = None
-                    filetopoformat = re.sub('[.]', '', file_format)
+                    filetopoformat = re.sub('[.]', '', fileloadformat)
                     if('CHARMMCOR' == filetopoformat.upper() or 
                        'CHARMMCOOR' == filetopoformat.upper() or 
                        #'COR' == file_format.upper() or 
@@ -2274,7 +2296,7 @@ class MDDataAccess(object):
                             self.topocode = "charmmcoor"
                             break
                 elif "parmed" in interface:
-                    filetopoformat = re.sub('[.]', '', file_format)
+                    filetopoformat = re.sub('[.]', '', fileloadformat)
                     self.topohandler = self.load_parmed_topology(filetopoformat, base_topo=self.topohandler)
                     if self.topohandler:
                         usedefault=False
@@ -2282,21 +2304,21 @@ class MDDataAccess(object):
                         break
                 elif "pymolfile" in interface:
                     if self.topohandler is None:
-                        self.topohandler = self.load_pymolfile_topology(file_format)
+                        self.topohandler = self.load_pymolfile_topology(fileloadformat)
                     if self.topohandler:
                         usedefault=False
                         self.topocode = "pymolfile"
                         break
                 elif "mdtraj" in interface:
                     if self.topohandler is None:
-                        self.topohandler = self.load_mdtraj_topology(file_format)
+                        self.topohandler = self.load_mdtraj_topology(fileloadformat)
                     if self.topohandler:
                         usedefault=False
                         self.topocode = "mdtraj"
                         break
                 elif "mdanalysis" in interface:
                     if self.topohandler is None:
-                        self.topohandler = self.load_mdanalysis_topology(file_format)
+                        self.topohandler = self.load_mdanalysis_topology(fileloadformat)
                     if self.topohandler:
                         usedefault=False
                         self.topocode = "mdanalysis"
@@ -2304,7 +2326,7 @@ class MDDataAccess(object):
                 elif "ase" in interface:
                     if self.topohandler is None:
                         if self.topofile:
-                            self.topohandler = self.load_ase_support(self.topofile, file_format=file_format)
+                            self.topohandler = self.load_ase_support(self.topofile, file_format=fileloadformat)
                     if self.topohandler:
                         usedefault=False
                         self.topocode = "ase"
@@ -2313,7 +2335,7 @@ class MDDataAccess(object):
                     if isinstance(self.UserSuppliedInterface, MDDataAccess.UserSuppliedInterface):
                         if self.UserSuppliedInterface.name in interface:
                             if self.topohandler is None:
-                                self.topohandler = self.UserSuppliedInterface.topology_support(self.topofile, file_format=file_format)
+                                self.topohandler = self.UserSuppliedInterface.topology_support(self.topofile, file_format=fileloadformat)
                             if self.topohandler:
                                 usedefault=False
                                 self.topocode = self.UserSuppliedInterface.name
@@ -2324,26 +2346,26 @@ class MDDataAccess(object):
         if usedefault:
             if self.topohandler is None:
                 # Nothing to lose to be heroistic here.
-                self.topohandler = self.load_pymolfile_topology(file_format)
+                self.topohandler = self.load_pymolfile_topology(fileloadformat)
             if self.topohandler is None:
                 self.topocode = "pymolfile"
 
             if self.topohandler is None:
-                self.topohandler = self.load_mdtraj_topology(file_format)
+                self.topohandler = self.load_mdtraj_topology(fileloadformat)
                 if self.topohandler:
                     self.topocode = "mdtraj"
 
             # If MDTraj does not have support for the format 
             # or can not load the topology, use MDAnalysis and ASE.
             if self.topohandler is None:
-                self.topohandler = self.load_mdanalysis_topology(file_format)
+                self.topohandler = self.load_mdanalysis_topology(fileloadformat)
                 if self.topohandler:
                     self.topocode = "mdanalysis"
 
             # Fall back to check ASE support
             if self.topohandler is None:
                 ase_support = False
-                ase_support = self.get_ase_format_support(file_format)
+                ase_support = self.get_ase_format_support(fileloadformat)
                 # May still have chance that ASE can recognize the 
                 # format with its filetype checking function
                 if ase_support is None:
@@ -2353,7 +2375,7 @@ class MDDataAccess(object):
                         self.topocode = "ase"
                 else:
                     if self.topofile:
-                        self.topohandler = self.load_ase_support(self.topofile, file_format=file_format)
+                        self.topohandler = self.load_ase_support(self.topofile, file_format=fileloadformat)
                         self.topocode = "ase"
 
         # If no success after all attempts return False
@@ -2383,16 +2405,16 @@ class MDDataAccess(object):
         if chkfile:
             chkfilename = os.path.basename(chkfile)
         if chkformat is None:
-            file_format = self.get_file_format(chkfile, chkformat)
+            fileloadformat = self.get_file_format(chkfile, chkformat)
             chkformat = file_format
             if "input" in filetype:
-                self.incoordformat = file_format
+                self.incoordformat = fileloadformat
             elif "output" in filetype:
-                self.outcoordformat = file_format
+                self.outcoordformat = fileloadformat
             else:
-                self.trajformat = file_format
+                self.trajformat = fileloadformat
         else:
-            file_format = chkformat
+            fileloadformat = chkformat
 
         usedefault=True
         # Use the given order to check topology
@@ -2445,7 +2467,7 @@ class MDDataAccess(object):
                 if "charmmcoor" in interface:
                     trajhandler_check = None
                     charmmcoor_dict = None
-                    filetrajformat = re.sub('[.]', '', file_format)
+                    filetrajformat = re.sub('[.]', '', fileloadformat)
                     if('CHARMMCOOR' == filetrajformat.upper() or 
                        'CHARMMCOR' == filetrajformat.upper() or 
                        #'COR' == file_format.upper() or 
@@ -2582,7 +2604,7 @@ class MDDataAccess(object):
                         numatoms = self.get_natoms_from_topo(self.topocode)
                     trajhandler_check = None
                     try:
-                        trajhandler_check = mdt_FormatRegistry.fileobjects[file_format]
+                        trajhandler_check = mdt_FormatRegistry.fileobjects[fileloadformat]
                     except KeyError:
                         pass
                     else:
@@ -2630,7 +2652,7 @@ class MDDataAccess(object):
                                             self.natoms = numatoms
                                     break
                 if "mdanalysis" in interface:
-                    mdanalysis_format = re.sub('[.]', '', file_format)
+                    mdanalysis_format = re.sub('[.]', '', fileloadformat)
                     mdanalysis_format = mdanalysis_format.upper()
                     if self.topohandler is not None and chkfile is not None:
                         # if the topology handler is a MDAnalysis universe, 
@@ -2787,12 +2809,19 @@ class MDDataAccess(object):
                         pass
                 if "ase" in interface:
                     ase_support = None
-                    ase_support = self.get_ase_format_support(file_format)
+                    ase_support = self.get_ase_format_support(fileloadformat)
                     if ase_support is None:
                         ase_support = ase_io.formats.filetype(self.trajfile)
                         trajhandler = None
                         trajhandler = self.ase_iread(self.trajfile, fileformat=ase_support)
-                        if trajhandler:
+                        try:
+                            trajtest = next(trajhandler)
+                        except StopIteration:
+                            trajtest = None
+                        trajhandler = None
+                        if trajtest is not None:
+                            trajhandler = self.ase_iread(self.trajfile, fileformat=ase_support)
+                        if trajhandler and trajtest is not None:
                             if self.interfacematch:
                                 if interface in self.topocode:
                                     usedefault=False
@@ -2828,8 +2857,16 @@ class MDDataAccess(object):
                                 break
                     else:
                         trajhandler = None
-                        trajhandler = self.ase_iread(self.trajfile, fileformat=file_format)
-                        if trajhandler:
+                        trajhandler = self.ase_iread(self.trajfile, fileformat=fileloadformat)
+                        trajtest = None
+                        try:
+                            trajtest = next(trajhandler)
+                        except StopIteration:
+                            trajtest = None
+                        trajhandler = None
+                        if trajtest is not None:
+                            trajhandler = self.ase_iread(self.trajfile, fileformat=fileloadformat)
+                        if trajhandler and trajtest is not None:
                             if self.interfacematch:
                                 if interface in self.topocode:
                                     usedefault=False
@@ -2868,7 +2905,7 @@ class MDDataAccess(object):
                         if isinstance(self.UserSuppliedInterface, MDDataAccess.UserSuppliedInterface):
                             if self.UserSuppliedInterface.name in interface:
                                 trajhandler = None
-                                trajhandler = self.UserSuppliedInterface.trajectory_support(self.trajfile, file_format=file_format)
+                                trajhandler = self.UserSuppliedInterface.trajectory_support(self.trajfile, file_format=fileloadformat)
                                 if trajhandler:
                                     if self.interfacematch:
                                         if interface in self.topocode:
@@ -3015,7 +3052,7 @@ class MDDataAccess(object):
                 # Second,check whether MDtraj has support for the file type
                 # trajhandler_check = mdt_FormatRegistry.loaders[file_format]
                 try:
-                    trajhandler_check = mdt_FormatRegistry.fileobjects[file_format]
+                    trajhandler_check = mdt_FormatRegistry.fileobjects[fileloadformat]
                 except KeyError:
                     pass
                 else:
@@ -3059,7 +3096,7 @@ class MDDataAccess(object):
             else:
                 trajhandler = self.trajhandler 
             if trajhandler is None:
-                mdanalysis_format = re.sub('[.]', '', file_format)
+                mdanalysis_format = re.sub('[.]', '', fileloadformat)
                 if self.topohandler is not None and chkfile is not None:
                     if isinstance(self.topohandler, mda_u.Universe):
                         try:
@@ -3185,7 +3222,7 @@ class MDDataAccess(object):
                 trajhandler = self.trajhandler 
             if trajhandler is None:
                 ase_support = None
-                ase_support = self.get_ase_format_support(file_format)
+                ase_support = self.get_ase_format_support(fileloadformat)
                 trajcode=None
                 # May still have chance that ASE can recognize the 
                 # format with its filetype checking function
@@ -3196,9 +3233,17 @@ class MDDataAccess(object):
                         except (FileNotFoundError,IOError):
                             pass
                     trajhandler = None
+                    trajtest = None
                     if chkfile is not None:
                         trajhandler = self.ase_iread(chkfile, fileformat=ase_support)
-                    if trajhandler:
+                        try:
+                            trajtest = next(trajhandler)
+                        except StopIteration:
+                            trajtest = None
+                        trajhandler = None
+                        if trajtest is not None:
+                            trajhandler = self.ase_iread(chkfile, fileformat=ase_support)
+                    if trajhandler and trajtest is not None:
                         trajcode=None
                         if self.interfacematch:
                             if "ase" in self.topocode:
@@ -3209,10 +3254,18 @@ class MDDataAccess(object):
                             trajcode = "ase"
                 else:
                     trajhandler = None
+                    trajtest = None
                     if chkfile is not None:
-                        trajhandler = self.ase_iread(chkfile, fileformat=file_format)
+                        trajhandler = self.ase_iread(chkfile, fileformat=fileloadformat)
+                        try:
+                            trajtest = next(trajhandler)
+                        except StopIteration:
+                            trajtest = None
+                        trajhandler = None
+                        if trajtest is not None:
+                            trajhandler = self.ase_iread(chkfile, fileformat=fileloadformat)
                     trajcode=None
-                    if trajhandler:
+                    if trajhandler and trajtest is not None:
                         if self.interfacematch:
                             if "ase" in self.topocode:
                                 trajcode = "ase"
@@ -3455,6 +3508,18 @@ class MDDataAccess(object):
                     topology = pmd.gromacs.GromacsTopologyFile(top)
                 except(ValueError, AttributeError, IOError):
                     pass
+        elif ext in ['tinkertop', 'tinkerxyz', 'txyz']:
+            if self.topofile:
+                try:
+                    topology = pmd.tinker.tinkerfiles.XyzFile(top)
+                except(ValueError, AttributeError, IOError):
+                    pass
+        elif ext in ['tinkercor', 'tinkerdyn', 'dyn']:
+            if self.topofile:
+                try:
+                    topology = pmd.tinker.tinkerfiles.DynFile(top)
+                except(ValueError, AttributeError, IOError):
+                    pass
         elif ext in ['rtf', 'charmmtop', 'charmmstrrtf']: # .top is used by Gromacs.
             if base_topo is None:
                 base_topo = pmd.charmm.CharmmParameterSet()
@@ -3574,7 +3639,8 @@ class MDDataAccess(object):
             elif ext in ['.gro']:
                 topology = mdt.core.trajectory.load_gro(top, **wrapkwargs).topology
             elif ext in ['.arc']:
-                topology = mdt_load.arc(top, **wrapkwargs).topology
+                #topology = mdt_load.arc(top, **wrapkwargs).topology
+                topology = mdt.core.trajectory.load_arc(top, **wrapkwargs).topology
             elif ext in ['.hoomdxml']:
                 topology = mdt_load.hoomdxml(top, **wrapkwargs).topology
             elif isinstance(top, mdt_Trajectory):
@@ -3687,7 +3753,7 @@ class MDDataAccess(object):
             handler = ase_io.iread(filename, index=":", format=fileformat)
         except (AttributeError, IOError, OSError, 
                 ValueError, ImportError, ModuleNotFoundError):
-            return
+            return handler
 
         if handler is not None:
             try:
@@ -3695,9 +3761,9 @@ class MDDataAccess(object):
                     yield value.get_positions()
             except (AttributeError, IOError, OSError, 
                     ValueError, ImportError, ModuleNotFoundError):
-                return
+                return handler
         else:
-            return
+            return handler
     
     def charmm_coor_toporead(self, coorDict=None):
         """Returns the atom list for CHARMM COOR input
@@ -4191,7 +4257,10 @@ class MDDataAccess(object):
                             if parser_ui.fileDict[fileItem].strDict:
                                 self.trajstream.update(parser_ui.fileDict[fileItem].strDict)
                         traj_loaded = self.load()
-                        self.atompositions = self.iread()
+                        try: 
+                            self.atompositions = self.iread()
+                        except TypeError:
+                            pass
         if self.atompositions is not None:
             parser_ui.trajectory = self.set_TrajectoryData()
 
-- 
GitLab