Commit 4e44b0a4 authored by Lauri Himanen

Started merging with the push interface, better base class for parsers.

parent eccf31ab
......@@ -8,15 +8,18 @@ logger = logging.getLogger(__name__)
class AtomsEngine(object):
"""Used to parse various different atomic coordinate files.
Supports the following file formats:
- xyz (.xyz):
- cif (.cif): Crystallographic Information File
- cp2k-pdb (.pdb): Protein Data Bank file written by CP2K, the format
is a bit peculiar so a custom implementation is used
See the dictionary 'formats' for all the supported formats and a brief
explanation.
Reading is primarily done by ASE or MDAnalysis, but in some cases a custom
implementation had to be made.
"""
formats = {
"xyz": "",
"cif": "(.cif): Crystallographic Information File",
"pdb-cp2k": "(.pdb): Protein Data Bank file written by CP2K, the format is a bit peculiar so a custom implementation is used",
"pdb": "(.pdb): Protein Data Bank",
}
def __init__(self, parser):
"""
......@@ -27,14 +30,16 @@ class AtomsEngine(object):
self.parser = parser
def determine_tool(self, format):
"""Determines which tool to use for extracting trajectories in the
given format.
"""
ASE = "ASE"
custom = "custom"
# MDAnalysis = "MDAnalysis"
formats = {
"xyz": ASE,
"cif": ASE,
"pdb-cp2k": custom,
"pdb": ASE,
"cp2k-pdb": custom,
}
result = formats.get(format)
if result:
......@@ -42,23 +47,23 @@ class AtomsEngine(object):
else:
logger.warning("The format '{}' is not supported by AtomsEngine.".format(format))
def check_format_support(self, format):
"""Check if the given format is supported.
"""
if format not in AtomsEngine.formats:
logger.error("The format '{}' is not supported by AtomsEngine.".format(format))
return False
else:
return True
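# Illustrative sketch (not part of the commit): a standalone version of the
# format -> tool dispatch that determine_tool() and check_format_support()
# above implement. The format names mirror the diff; the function is hypothetical.
def tool_for(fmt):
    """Return the reader to use for a coordinate file format, or None."""
    return {"xyz": "ASE", "cif": "ASE", "pdb": "ASE", "pdb-cp2k": "custom"}.get(fmt)

assert tool_for("cif") == "ASE"
assert tool_for("dcd") is None  # unsupported formats simply fall through to None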
def n_atoms(self, contents, format):
"""Read the first configuration of the coordinate file to extract the
number of atoms in it.
"""
# Figure out which tool to use
tool = self.determine_tool(format)
n_atoms = None
if tool == "ASE":
atoms = ase.io.read(contents, index=0, format=format)
n_atoms = atoms.get_number_of_atoms()
return n_atoms
if tool == "MDAnalysis":
u = MDAnalysis.Universe(contents.name)
n_atoms = len(u.atoms)
return n_atoms
iterator = self.iread(contents, format)
pos = iterator.next()
return pos.shape[0]
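# Standalone sketch of the idea behind n_atoms() above: read only the first
# configuration of an .xyz trajectory and report how many atoms it contains,
# without loading the rest of the file. The helper name is hypothetical.
def first_frame_atom_count(xyz_path):
    with open(xyz_path) as fh:
        # In the XYZ format the first line of every frame is the atom count.
        return int(fh.readline().split()[0])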
def iread(self, contents, format, index=0):
"""Returns an iterator that goes through the given trajectory file one
......@@ -66,15 +71,28 @@ class AtomsEngine(object):
whole file doesn't have to be loaded into memory.
"""
if not self.check_format_support(format):
return
tool = self.determine_tool(format)
# After reading the ASE source code, it seems that the ASE iread does
# actually read the entire file into memory and then yields the
# configurations from it. Should be checked at some point.
if tool == "ASE":
iterator = ase.io.iread(contents, None, format)
return iterator
iterator = ase.io.iread(contents, format=format)
return self.ase_wrapper(iterator)
elif tool == "custom":
if format == "cp2k-pdb":
if format == "pdb-cp2k":
iterator = self.parser.csvengine.iread(contents, columns=[3, 4, 5], comments=["TITLE", "AUTHOR", "REMARK", "CRYST"], separator="END")
return iterator
if format == "xyz":
iterator = self.parser.csvengine.iread(contents, columns=[-3, -2, -1], comments=["^i ="], separator="^\s*\d+$")
return iterator
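# Hedged, self-contained illustration of what the csvengine.iread call above
# does for .xyz trajectories: split the file into frames at the atom-count
# lines, skip CP2K-style "i = ..." comment lines, and keep the last three
# columns of each atom line as coordinates. Nothing from the parser is used here.
import re
import numpy as np

def iter_xyz_frames(path):
    frame = []
    with open(path) as fh:
        for line in fh:
            if re.match(r"^\s*\d+\s*$", line):  # atom-count line starts a new frame
                if frame:
                    yield np.array(frame, dtype=float)
                    frame = []
            elif re.match(r"^\s*i\s*=", line):  # assumed CP2K comment-line layout
                continue
            else:
                parts = line.split()
                if len(parts) >= 4:
                    frame.append(parts[-3:])
    if frame:
        yield np.array(frame, dtype=float)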
def ase_wrapper(self, iterator):
"""Used to wrap an iterator returned by ase.io.iread so that it returns
the positions instead of the ase.Atoms object.
"""
for value in iterator:
yield value.get_positions()
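# Minimal sketch of the wrapping pattern that ase_wrapper() uses above: a
# generator that applies a transform to every item yielded by another iterator.
# The names below are made up for the example.
def mapped(iterator, transform):
    for item in iterator:
        yield transform(item)

# e.g. positions = mapped(ase.io.iread("traj.xyz"), lambda atoms: atoms.get_positions())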
......@@ -14,20 +14,16 @@ def scan_path_for_files(path):
".cif",
".pdb",
}
files = []
files = {}
for filename in os.listdir(path):
extension = os.path.splitext(filename)[1]
if extension in extensions:
file_object = {
"path": os.path.join(path, filename),
"file_id": "",
}
files.append(file_object)
files[os.path.join(path, filename)] = ""
return files
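# Illustrative sketch of the new return value of scan_path_for_files(): a plain
# mapping from file path to a (possibly empty) file id, replacing the earlier
# list of {"path": ..., "file_id": ...} dictionaries. The paths are hypothetical.
files = {
    "/calc/run.inp": "",
    "/calc/run.out": "",
    "/calc/run-pos-1.xyz": "",
}
for path, file_id in files.items():
    print(path, file_id or "<unresolved>")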
#===============================================================================
def get_parser(path):
def get_parser(path, test=True):
files = scan_path_for_files(path)
json_input = {
"version": "nomadparsein.json 1.0",
......@@ -36,7 +32,7 @@ def get_parser(path):
"metainfoToSkip": [],
"files": files
}
parser = CP2KParser(json.dumps(json_input))
parser = CP2KParser(json.dumps(json_input), test=test)
return parser
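# Self-contained sketch of the JSON payload assembled in get_parser() above;
# only the keys visible in this diff are included and the paths are hypothetical.
import json
example_input = json.dumps({
    "version": "nomadparsein.json 1.0",
    "metainfoToSkip": [],
    "files": {"/calc/run.inp": "", "/calc/run.out": ""},
})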
......
import os
import re
from cp2kparser.generics.nomadparser import NomadParser, Result, ResultKind, ResultCode
from cp2kparser.generics.nomadparser import NomadParser, Result, ResultCode
from cp2kparser.implementation.regexs import *
from cp2kparser.engines.regexengine import RegexEngine
from cp2kparser.engines.csvengine import CSVEngine
......@@ -9,6 +9,7 @@ from cp2kparser.engines.xmlengine import XMLEngine
from cp2kparser.engines.atomsengine import AtomsEngine
import numpy as np
import logging
import sys
logger = logging.getLogger(__name__)
from cp2kparser import ureg
......@@ -22,10 +23,10 @@ class CP2KParser(NomadParser):
implementation. For other versions there should be classes that extend from
this.
"""
def __init__(self, input_json_string):
NomadParser.__init__(self, input_json_string)
def __init__(self, input_json_string, stream=sys.stdout, test=False):
self.version_number = None
# Initialize the base class
NomadParser.__init__(self, input_json_string, stream, test)
# Engines are created here
self.csvengine = CSVEngine(self)
......@@ -34,14 +35,14 @@ class CP2KParser(NomadParser):
self.inputengine = CP2KInputEngine()
self.atomsengine = AtomsEngine(self)
self.version_number = None
self.input_tree = None
self.regexs = None
self.analyse_input_json()
self.check_resolved_file_ids()
self.determine_file_ids_from_extension()
# Use some convenient functions from base
self.determine_file_ids_pre_setup()
self.setup_version()
self.determine_file_ids()
# self.open_files()
self.determine_file_ids_post_setup()
def setup_version(self):
"""Inherited from NomadParser.
......@@ -75,42 +76,20 @@ class CP2KParser(NomadParser):
self.implementation = globals()["CP2KImplementation"](self)
def read_part_of_file(self, file_id, size=1024):
fh = self.file_handles[file_id]
fh.seek(0, os.SEEK_SET)
fh = self.get_file_handle(file_id)
buffer = fh.read(size)
return buffer
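# Minimal standalone sketch of the same pattern as read_part_of_file() above:
# open a file and read only the first 'size' bytes instead of the whole thing.
# The function name and path are hypothetical.
def head_of_file(path, size=1024):
    with open(path, "rb") as fh:
        return fh.read(size)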
def check_resolved_file_ids(self):
"""Save the file id's that were given in the JSON input.
"""
resolved = {}
resolvable = []
for file_object in self.files:
path = file_object.get("path")
file_id = file_object.get("file_id")
if not file_id:
resolvable.append(path)
else:
resolved[file_id] = path
for id, path in resolved.iteritems():
self.file_ids[id] = path
self.get_file_handle(id)
self.resolvable = resolvable
def determine_file_ids_from_extension(self):
def determine_file_ids_pre_setup(self):
"""First resolve the files that can be identified by extension.
"""
for file_path in self.resolvable:
for file_path in self.files.iterkeys():
if file_path.endswith(".inp"):
self.file_ids["input"] = file_path
self.get_file_handle("input")
self.setup_file_id(file_path, "input")
if file_path.endswith(".out"):
self.file_ids["output"] = file_path
self.get_file_handle("output")
self.setup_file_id(file_path, "output")
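# Self-contained sketch of the extension-based resolution performed in
# determine_file_ids_pre_setup() above; the id names mirror the diff, while the
# helper itself and the example paths are hypothetical.
def resolve_by_extension(paths):
    ids = {}
    for path in paths:
        if path.endswith(".inp"):
            ids["input"] = path
        elif path.endswith(".out"):
            ids["output"] = path
    return ids

# resolve_by_extension(["/calc/run.inp", "/calc/run.out"]) -> {"input": ..., "output": ...}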
def determine_file_ids(self):
def determine_file_ids_post_setup(self):
"""Inherited from NomadParser.
"""
# Determine the presence of force file
......@@ -157,7 +136,10 @@ class CP2KParser(NomadParser):
logger.debug("Normalizing trajectory path")
project_name = self.input_tree.get_keyword("GLOBAL/PROJECT_NAME")
file_format = self.input_tree.get_keyword("MOTION/PRINT/TRAJECTORY/FORMAT")
extension = {"PDB": "pdb"}[file_format]
extension = {
"PDB": "pdb",
"XYZ": "xyz"
}[file_format]
if path.startswith("="):
normalized_path = path[1:]
elif re.match(r"./", path):
......@@ -209,10 +191,9 @@ class CP2KParser(NomadParser):
break
# folders.reverse()
return folders
def get_quantity_unformatted(self, name):
def start_parsing(self, name):
"""Inherited from NomadParser. The timing and caching is already
implemented in the superclass.
"""
......@@ -253,7 +234,6 @@ class CP2KImplementation(object):
def _Q_energy_total(self):
"""Return the total energy from the bottom of the input file"""
result = Result()
result.return_type = ResultKind.energy
result.unit = ureg.hartree
result.value = float(self.regexengine.parse(self.regexs.energy_total, self.parser.get_file_handle("output")))
return result
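# Standalone sketch of attaching a unit to a parsed value with pint, mirroring
# the ureg.hartree usage above; the registry here is local and the value made up.
import pint
ureg = pint.UnitRegistry()
energy = -37.79 * ureg.hartree
print(energy.to("eV"))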
......@@ -269,7 +249,6 @@ class CP2KImplementation(object):
belong to the list defined in the NoMaD wiki.
"""
result = Result()
result.return_type = ResultKind.text
# First try to look at the shortcut
xc_shortcut = self.input_tree.get_parameter("FORCE_EVAL/DFT/XC/XC_FUNCTIONAL")
......@@ -333,7 +312,6 @@ class CP2KImplementation(object):
Supports forces printed in the output file or in a single .xyz file.
"""
result = Result()
result.return_type = ResultKind.force
result.unit = ureg.force_au
# Determine whether a separate force file is used or the forces are printed
......@@ -400,7 +378,6 @@ class CP2KImplementation(object):
that must be present for all calculations.
"""
result = Result()
result.return_type = ResultKind.number
# Check where the coordinates are specified
coord_format = self.input_tree.get_keyword("FORCE_EVAL/SUBSYS/TOPOLOGY/COORD_FILE_FORMAT")
......@@ -445,12 +422,19 @@ class CP2KImplementation(object):
# Read the trajectory
traj_file = self.parser.get_file_handle("trajectory")
traj_iter = self.atomsengine.iread(traj_file, format="cp2k-pdb")
positions = []
file_format = self.input_tree.get_keyword("MOTION/PRINT/TRAJECTORY/FORMAT")
file_format = {
"XYZ": "xyz",
"PDB": "pdb-cp2k"
}[file_format]
traj_iter = self.atomsengine.iread(traj_file, format=file_format)
# Loop through the iterator to get all configurations
positions = []
for configuration in traj_iter:
positions.append(configuration)
result.value = np.array(positions)
if positions:
result.value = np.array(positions)
return result
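# Self-contained sketch of the stacking step above: configurations collected
# one by one from the iterator become an (n_conf, n_atoms, 3) array. The frames
# here are synthetic.
import numpy as np
frames = [np.zeros((2, 3)), np.ones((2, 3))]
positions = np.array(frames)
assert positions.shape == (2, 2, 3)  # 2 configurations, 2 atoms, 3 coordinates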
......
......@@ -11,7 +11,7 @@ import pstats
def getparser(folder):
path = os.path.dirname(os.path.realpath(__file__))
path = os.path.join(path, folder)
parser = get_parser(path)
parser = get_parser(path, test=True)
return parser
......@@ -20,8 +20,8 @@ class TestFunctionals(unittest.TestCase):
def getxc(self, folder, result):
path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "functionals", folder)
parser = get_parser(path)
xc = parser.get_quantity_unformatted("XC_functional")
parser = getparser(path)
xc = parser.parse_quantity("XC_functional")
self.assertEqual(xc.value, result)
def test_pade(self):
......@@ -42,7 +42,7 @@ class TestForces(unittest.TestCase):
def test_forces_in_outputfile_n(self):
parser = getparser("forces/outputfile/n")
forces = parser.get_quantity_unformatted("particle_forces").value
forces = parser.parse_quantity("particle_forces").value
n_conf = forces.shape[0]
n_particles = forces.shape[1]
n_dim = forces.shape[2]
......@@ -52,7 +52,7 @@ class TestForces(unittest.TestCase):
def test_forces_in_outputfile_1(self):
parser = getparser("forces/outputfile/1")
forces = parser.get_quantity_unformatted("particle_forces").value
forces = parser.parse_quantity("particle_forces").value
n_conf = forces.shape[0]
n_particles = forces.shape[1]
n_dim = forces.shape[2]
......@@ -62,12 +62,12 @@ class TestForces(unittest.TestCase):
def test_forces_in_outputfile_0(self):
parser = getparser("forces/outputfile/0")
forces = parser.get_quantity_unformatted("particle_forces").value
forces = parser.parse_quantity("particle_forces").value
self.assertEqual(forces, None)
def test_forces_in_singlexyzfile_n(self):
parser = getparser("forces/singlexyzfile/n")
forces = parser.get_quantity_unformatted("particle_forces").value
forces = parser.parse_quantity("particle_forces").value
n_conf = forces.shape[0]
n_particles = forces.shape[1]
n_dim = forces.shape[2]
......@@ -77,7 +77,7 @@ class TestForces(unittest.TestCase):
def test_forces_in_singlexyzfile_1(self):
parser = getparser("forces/singlexyzfile/1")
forces = parser.get_quantity_unformatted("particle_forces").value
forces = parser.parse_quantity("particle_forces").value
n_conf = forces.shape[0]
n_particles = forces.shape[1]
n_dim = forces.shape[2]
......@@ -87,7 +87,7 @@ class TestForces(unittest.TestCase):
def test_forces_in_singlexyzfile_0(self):
parser = getparser("forces/singlexyzfile/0")
forces = parser.get_quantity_unformatted("particle_forces").value
forces = parser.parse_quantity("particle_forces").value
self.assertEqual(forces, None)
......@@ -96,42 +96,42 @@ class TestParticleNumber(unittest.TestCase):
def test_input_n(self):
parser = getparser("particle_number/inputfile/n")
n = parser.get_quantity_unformatted("particle_number").value
n = parser.parse_quantity("particle_number").value
self.assertEqual(n, 2)
def test_input_1(self):
parser = getparser("particle_number/inputfile/1")
n = parser.get_quantity_unformatted("particle_number").value
n = parser.parse_quantity("particle_number").value
self.assertEqual(n, 1)
def test_input_extra_lines(self):
parser = getparser("particle_number/inputfile/extra_lines")
n = parser.get_quantity_unformatted("particle_number").value
n = parser.parse_quantity("particle_number").value
self.assertEqual(n, 2)
def test_input_multiplication(self):
parser = getparser("particle_number/inputfile/multiplication")
n = parser.get_quantity_unformatted("particle_number").value
n = parser.parse_quantity("particle_number").value
self.assertEqual(n, 12)
def test_xyz_n(self):
parser = getparser("particle_number/xyz/n")
n = parser.get_quantity_unformatted("particle_number").value
n = parser.parse_quantity("particle_number").value
self.assertEqual(n, 2)
def test_xyz_multiplication(self):
parser = getparser("particle_number/xyz/multiplication")
n = parser.get_quantity_unformatted("particle_number").value
n = parser.parse_quantity("particle_number").value
self.assertEqual(n, 12)
def test_cif_n(self):
parser = getparser("particle_number/cif/n")
n = parser.get_quantity_unformatted("particle_number").value
n = parser.parse_quantity("particle_number").value
self.assertEqual(n, 2)
def test_pdb_n(self):
parser = getparser("particle_number/pdb/n")
n = parser.get_quantity_unformatted("particle_number").value
n = parser.parse_quantity("particle_number").value
self.assertEqual(n, 2)
......@@ -140,7 +140,47 @@ class TestTrajectory(unittest.TestCase):
def test_filenames_bare(self):
parser = getparser("trajectory/filenames/bare")
pos = parser.get_quantity_unformatted("particle_position").value
pos = parser.parse_quantity("particle_position").value
n_conf = pos.shape[0]
n_particles = pos.shape[1]
n_dim = pos.shape[2]
self.assertEqual(n_conf, 11)
self.assertEqual(n_particles, 2)
self.assertEqual(n_dim, 3)
def test_filenames_dotslash(self):
parser = getparser("trajectory/filenames/dotslash")
pos = parser.parse_quantity("particle_position").value
n_conf = pos.shape[0]
n_particles = pos.shape[1]
n_dim = pos.shape[2]
self.assertEqual(n_conf, 11)
self.assertEqual(n_particles, 2)
self.assertEqual(n_dim, 3)
def test_filenames_equals(self):
parser = getparser("trajectory/filenames/equals")
pos = parser.parse_quantity("particle_position").value
n_conf = pos.shape[0]
n_particles = pos.shape[1]
n_dim = pos.shape[2]
self.assertEqual(n_conf, 11)
self.assertEqual(n_particles, 2)
self.assertEqual(n_dim, 3)
def test_pdb(self):
parser = getparser("trajectory/pdb")
pos = parser.parse_quantity("particle_position").value
n_conf = pos.shape[0]
n_particles = pos.shape[1]
n_dim = pos.shape[2]
self.assertEqual(n_conf, 11)
self.assertEqual(n_particles, 2)
self.assertEqual(n_dim, 3)
def test_xyz(self):
parser = getparser("trajectory/xyz")
pos = parser.parse_quantity("particle_position").value
n_conf = pos.shape[0]
n_particles = pos.shape[1]
n_dim = pos.shape[2]
......@@ -151,6 +191,8 @@ class TestTrajectory(unittest.TestCase):
if __name__ == '__main__':
logger = logging.getLogger("cp2kparser")
logger.setLevel(logging.ERROR)
logger = logging.getLogger("nomadparser")
logger.setLevel(logging.ERROR)
# unittest.main()
suites = []
......