Commit e9092f06 authored by Lauri Himanen's avatar Lauri Himanen
Browse files

Started doing the cell parsing, etc.

parent 8341d560
import ase.io
import logging
import MDAnalysis
logger = logging.getLogger(__name__)
......@@ -9,39 +8,24 @@ class AtomsEngine(object):
"""Used to parse various different atomic coordinate files.
See the dictionary 'formats' for all the supported formats and a brief
explanation.
Reading is primarily done by ASE or MDAnalysis, but in some cases own
implementation had to be made.
Returns all coordinates as numpy arrays.
explanation.Reading is primarily done by ASE or MDAnalysis, but in some cases own
implementation is used. Returns all coordinates as numpy arrays.
"""
formats = {
"xyz": "",
"xyz": "(.xyz): The XYZ file format.",
"cif": "(.cif): Crystallographic Information File",
"pdb-cp2k": "(.pdb): Protein Data Bank file written by CP2K, the format is a bit peculiar so a custom implementation is used",
"pdb": "(.pdb): Protein Data Bank",
#"dcd": "(.dcd): Binary trajectory file format used by CHARMM, NAMD, and X-PLOR.",
}
def __init__(self, parser):
"""
Args:
cp2k_parser: Instance of a NomadParser or it's subclass. Allows
access to e.g. unified file reading methods.
"""
self.parser = parser
def determine_tool(self, format):
"""Determines which tool to use for extracting trajectories in the
given format.
"""
ASE = "ASE"
custom = "custom"
formats = {
"xyz": ASE,
"cif": ASE,
"pdb-cp2k": custom,
"pdb": ASE,
"xyz": "ASE",
"cif": "ASE",
"pdb": "ASE",
}
result = formats.get(format)
if result:
......@@ -59,15 +43,15 @@ class AtomsEngine(object):
else:
return True
def n_atoms(self, contents, format):
def n_atoms(self, file_handle, format):
"""Read the first configuration of the coordinate file to extract the
number of atoms in it.
"""
iterator = self.iread(contents, format)
iterator = self.iread(file_handle, format)
pos = iterator.next()
return pos.shape[0]
def iread(self, contents, format, index=0):
def iread(self, file_handle, format, index=0):
"""Returns an iterator that goes through the given trajectory file one
configuration at a time. Good for e.g. streaming the contents to disc as the
whole file doesn't have to be loaded into memory.
......@@ -76,22 +60,35 @@ class AtomsEngine(object):
if not self.check_format_support(format):
return
if file_handle is None:
print "NONE"
tool = self.determine_tool(format)
# After reading the ASE source code, it seems that the ASE iread does
# actually read the entire file into memory and the yields the
# configurations from it. Should be checked at some point.
if tool == "ASE":
iterator = ase.io.iread(contents, format=format)
return self.ase_wrapper(iterator)
return self.ase_iread(file_handle, format, index)
elif tool == "custom":
if format == "pdb-cp2k":
iterator = self.parser.csvengine.iread(contents, columns=[3, 4, 5], comments=["TITLE", "AUTHOR", "REMARK", "CRYST"], separator="END")
return iterator
return self.custom_iread(file_handle, format, index)
elif tool == "MDAnalysis":
return self.mdanalysis_iread(file_handle, format, index)
def ase_wrapper(self, iterator):
def ase_iread(self, file_handle, format, index):
"""
"""
# After reading the ASE source code, it seems that the ASE iread does
# actually read the entire file into memory and the yields the
# configurations from it. Should be checked at some point.
def ase_generator(iterator):
"""Used to wrap an iterator returned by ase.io.iread so that it returns
the positions instead of the ase.Atoms object.
"""
for value in iterator:
yield value.get_positions()
iterator = ase.io.iread(file_handle, format=format)
return ase_generator(iterator)
def custom_iread(self, file_handle, format, index):
"""
"""
pass
......@@ -64,6 +64,9 @@ class CP2KInputEngine(object):
path += '/'
path += item
# Mark the section as accessed.
self.input_tree.set_section_accessed(path)
# Save the section parameters
if len(parts) > 1:
self.input_tree.set_parameter(path, parts[1].strip())
......
......@@ -5,114 +5,141 @@ because the pickling of these classes is wrong if they are defined in the same
file which is run in console (module will be then __main__).
"""
from collections import defaultdict
import logging
logger = logging.getLogger(__name__)
#===============================================================================
class Keyword(object):
"""Information about a keyword in a CP2K calculation.
"""
def __init__(self, default_name, default_value):
self.value = None
self.default_name = default_name
self.default_value = default_value
class Root(object):
def __init__(self, root_section):
self.root_section = root_section
#===============================================================================
class Section(object):
"""An input section in a CP2K calculation.
"""
def set_parameter(self, path, value):
parameter, section = self.get_parameter_and_section(path)
parameter.value = value
def __init__(self, name):
self.name = name
self.keywords = defaultdict(list)
self.default_keyword = ""
self.parameter = None
self.sections = defaultdict(list)
def set_keyword(self, path, value):
keyword, section = self.get_keyword_and_section(path)
if keyword and section:
keyword.value = value
elif section is not None:
# print "Saving default keyword at path '{}'".format(path)
split_path = path.rsplit("/", 1)
keyword = split_path[1]
section.default_keyword += keyword + " " + value + "\n"
def get_section(self, path):
split_path = path.split("/")
section = self
section = self.root_section
for part in split_path:
section = section.sections.get(part)
if section:
if len(section) == 1:
section = section[0]
else:
# print "The subsection '{}' is repeated. Not yet supported.".format(path)
return None
else:
# print "Subsection '{}' does not exist in section '{}'".format(path, self.name)
section = section.get_subsection(part)
if not section:
print "Error in getting section at path '{}'.".format(path)
return None
return section
def get_keyword_object(self, path):
def get_keyword_and_section(self, path):
split_path = path.rsplit("/", 1)
keyword = split_path[1]
section_path = split_path[0]
section = self.get_section(section_path)
keyword = section.keywords.get(keyword)
if keyword:
if len(keyword) == 1:
return keyword[0]
# print "The keyword in '{}' does not exist or has too many entries.".format(path)
return None
keyword = section.get_keyword(keyword)
if keyword and section:
return (keyword, section)
elif section:
return (None, section)
return (None, None)
def get_keyword(self, path):
"""Returns the keyword that is specified by the given path.
If the keyword has no value set, returns the default value defined in
the XML.
"""
keyword = self.get_keyword_object(path)
keyword, section = self.get_keyword_and_section(path)
if keyword:
if keyword.value is not None:
return keyword.value
else:
if section.accessed:
return keyword.default_value
def get_default_keyword(self, path):
return self.get_section(path).default_keyword
def set_keyword(self, path, value):
keyword = self.get_keyword_object(path)
if keyword:
keyword.value = value
else:
# print "Saving default keyword at path '{}'".format(path)
split_path = path.rsplit("/", 1)
keyword = split_path[1]
section_path = split_path[0]
section = self.get_section(section_path)
section.default_keyword += keyword + " " + value + "\n"
def set_section_accessed(self, path):
section = self.get_section(path)
section.accessed = True
def get_keyword_default(self, path):
keyword = self.get_keyword_object(path)
keyword, section = self.get_keyword_and_section(path)
if keyword:
return keyword.default_value
def get_parameter_object(self, path):
def get_parameter_and_section(self, path):
section = self.get_section(path)
parameter = section.parameter
if parameter:
return parameter
else:
print "The section parameters object '{}' could not be found.".format(path)
return (parameter, section)
def get_parameter(self, path):
parameter = self.get_parameter_object(path)
parameter, section = self.get_parameter_and_section(path)
if parameter:
if parameter.value:
return parameter.value
elif section and section.accessed:
return parameter.lone_value
def set_parameter(self, path, value):
parameter = self.get_parameter_object(path)
parameter.value = value
def get_parameter_lone(self, path):
parameter = self.get_parameter_object(path)
return parameter.lone_value
# def get_parameter_lone(self, path):
# parameter = self.get_parameter_object(path)
# return parameter.lone_value
# def get_parameter_default(self, path):
# parameter = self.get_parameter_object(path)
# return parameter.default_value
#===============================================================================
class Keyword(object):
"""Information about a keyword in a CP2K calculation.
"""
def __init__(self, default_name, default_value):
self.value = None
self.default_name = default_name
self.default_value = default_value
def get_parameter_default(self, path):
parameter = self.get_parameter_object(path)
return parameter.default_value
#===============================================================================
class Section(object):
"""An input section in a CP2K calculation.
"""
def __init__(self, name):
self.accessed = False
self.name = name
self.keywords = defaultdict(list)
self.default_keyword = ""
self.parameter = None
self.sections = defaultdict(list)
def get_keyword(self, name):
keyword = self.keywords.get(name)
if keyword:
if len(keyword) == 1:
return keyword[0]
else:
logger.error("The keyword '{}' in '{}' does not exist or has too many entries.".format(name, self.name))
def get_subsection(self, name):
subsection = self.sections.get(name)
if subsection:
if len(subsection) == 1:
return subsection[0]
else:
logger.error("The subsection '{}' in '{}' has too many entries.".format(name, self.name))
else:
logger.error("The subsection '{}' in '{}' does not exist.".format(name, self.name))
#===============================================================================
......
......@@ -85,7 +85,7 @@ def recursive_tree_generation(xml_element):
# Run main function by default
if __name__ == "__main__":
xml_file = open("./cp2k_262/cp2k_input.xml", 'r')
object_tree = generate_object_tree(xml_file)
object_tree = Root(generate_object_tree(xml_file))
file_name = "./cp2k_262/cp2k_input_tree.pickle"
fh = open(file_name, "wb")
pickle.dump(object_tree, fh, protocol=2)
......@@ -88,9 +88,7 @@ class CSVEngine(object):
# Start iterating
configuration = []
print contents.name
for line in contents: # This actually reads line by line and only keeps the current line in memory
print line
# If separator encountered, yield the stored configuration
if is_separator(line):
......
......@@ -82,6 +82,7 @@ class NomadParser(object):
self.metainfo_to_keep = None
self.metainfo_to_skip = None
self.file_ids = {}
self.results = {}
self.filepaths_wo_id = None
self.test_mode = test_mode
self.backend = JsonParseEventsWriterBackend(None, stream)
......@@ -178,17 +179,23 @@ class NomadParser(object):
Checks through the list given by get_supported_quantities and also
checks the metainfoToSkip parameter given in the JSON input.
"""
if name not in self.metainfos:
logger.error("The metaname '{}' was not declared on the metainfo file defined in the JSON input.".format(name))
return False
if name not in self.get_supported_quantities():
logger.error("The metaname '{}' is not available in this parser version.".format(name))
return False
if name in self.metainfo_to_skip:
logger.error("The metaname '{}' cannot be calculated as it is in the list 'metaInfoToSkip'.".format(name))
return False
return True
def parse(self):
"""Start parsing the contents.
"""
# Determine which values in metainfo are parseable
metainfos = self.metainfos.itervalues()
for metainfo in metainfos:
name = metainfo["name"]
if self.check_quantity_availability(name):
self.parse_quantity(name)
def parse_quantity(self, name):
"""Given a unique quantity id (=metaInfo name) which is supported by
the parser, parses the corresponding quantity (if available), converts
......@@ -202,7 +209,8 @@ class NomadParser(object):
if not available:
return
result = self.start_parsing(name)
# Get the result by parsing or from cache
result = self.get_result_object(name)
if result is not None:
if isinstance(result, Result):
......@@ -215,10 +223,10 @@ class NomadParser(object):
self.result_saver(result)
# In test mode just return the values directly
else:
if result. value is not None:
if result.value is not None:
if result.value_iterable is None:
return result.value
if result.value_iterable is not None:
elif result.value_iterable is not None:
values = []
for value in result.value_iterable:
values.append(value)
......@@ -226,6 +234,15 @@ class NomadParser(object):
if values.size != 0:
return values
def get_result_object(self, name):
# Check cache
result = self.results.get(name)
if result is None:
result = self.start_parsing(name)
if result.cache:
self.results[name] = result
return result
def result_saver(self, result):
"""Given a result object, saves the results to the backend.
......@@ -382,9 +399,22 @@ class Result(object):
The repeatable values can also be given as generator functions. With
generators you can easily push results from a big data file piece by piece
to the backend without loading the entire file into memory.
Attributes:
cache: Boolean indicating whether the result should be cached in memory.
name: The name of the metainfo corresponding to this result
value: The value of the result. Used for storing single results.
value_iterable: Iterable object containing multiple results.
unit: Unit of the result. Use the Pint units from UnitRegistry. e.g.
unit = ureg.newton. Used to automatically convert to SI.
dtypstr: The datatype string specified in metainfo.
shape: The expected shape of the result specified in metainfo.
repeats: A boolean indicating if this value can repeat. Specified in
metainfo.
"""
def __init__(self, meta_name=""):
def __init__(self):
self.name = None
self.value = None
self.value_iterable = None
......@@ -394,6 +424,7 @@ class Result(object):
self.dtypestr = None
self.repeats = None
self.shape = None
self.cache = False
#===============================================================================
......
......@@ -14,6 +14,7 @@ def scan_path_for_files(path):
".xyz",
".cif",
".pdb",
".dcd",
}
files = {}
for filename in os.listdir(path):
......
......@@ -33,7 +33,7 @@ class CP2KParser(NomadParser):
self.regexengine = RegexEngine(self)
self.xmlengine = XMLEngine(self)
self.inputengine = CP2KInputEngine()
self.atomsengine = AtomsEngine(self)
self.atomsengine = AtomsEngine()
self.version_number = None
self.implementation = None
......@@ -161,7 +161,10 @@ class CP2KParser(NomadParser):
file_format = self.input_tree.get_keyword("MOTION/PRINT/TRAJECTORY/FORMAT")
extension = {
"PDB": "pdb",
"XYZ": "xyz"
"XYZ": "xyz",
"XMOL": "xyz",
"ATOMIC": "xyz",
"DCD": "dcd",
}[file_format]
if path.startswith("="):
normalized_path = path[1:]
......@@ -246,7 +249,6 @@ class CP2KImplementation(object):
def decode_cp2k_unit(self, unit):
"""Given a CP2K unit name, decode it as Pint unit definition.
"""
map = {
# Length
"bohr": ureg.bohr,
......@@ -338,9 +340,11 @@ class CP2KImplementation(object):
return result
def _Q_particle_forces(self):
"""Return all the forces for every step found.
"""Return the forces that are controlled by
"FORCE_EVAL/PRINT/FORCES/FILENAME". These forces are typicalle printed
out during optimization or single point calculation.
Supports forces printed in the output file or in a single .xyz file.
Supports forces printed in the output file or in a single XYZ file.
"""
result = Result()
result.unit = ureg.force_au
......@@ -396,12 +400,17 @@ class CP2KImplementation(object):
that must be present for all calculations.
"""
result = Result()
result.cache = True
# Check where the coordinates are specified
coord_format = self.input_tree.get_keyword("FORCE_EVAL/SUBSYS/TOPOLOGY/COORD_FILE_FORMAT")
if not coord_format:
coord_format = self.input_tree.get_keyword_default("FORCE_EVAL/SUBSYS/TOPOLOGY/COORD_FILE_FORMAT")
# Check if the unit cell is multiplied programmatically
multiples = self.input_tree.get_keyword("FORCE_EVAL/SUBSYS/TOPOLOGY/MULTIPLE_UNIT_CELL")
if not multiples:
multiples = self.input_tree.get_keyword_default("FORCE_EVAL/SUBSYS/TOPOLOGY/MULTIPLE_UNIT_CELL")
factors = [int(x) for x in multiples.split()]
factor = np.prod(np.array(factors))
......@@ -442,17 +451,64 @@ class CP2KImplementation(object):
# Read the trajectory
traj_file = self.parser.get_file_handle("trajectory")
file_format = self.input_tree.get_keyword("MOTION/PRINT/TRAJECTORY/FORMAT")
input_file_format = self.input_tree.get_keyword("MOTION/PRINT/TRAJECTORY/FORMAT")
file_format = {
"XYZ": "xyz",
"PDB": "pdb-cp2k"
}[file_format]
"XMOL": "xyz",
"PDB": "pdb-cp2k",
"ATOMIC": "atomic",
}.get(input_file_format)
if file_format is None:
logger.error("Unsupported trajectory file format '{}'.".format(input_file_format))
# Use a custom implementation for the CP2K specific weird formats
if file_format == "pdb-cp2k":
traj_iter = self.parser.csvengine.iread(traj_file, columns=[3, 4, 5], comments=["TITLE", "AUTHOR", "REMARK", "CRYST"], separator="END")
elif file_format == "atomic":
n_atoms = self.parser.get_result_object("particle_number").value
def atomic_generator():
conf = []
i = 0
for line in traj_file:
line = line.strip()
components = np.array([float(x) for x in line.split()])
conf.append(components)
i += 1
if i == n_atoms:
yield np.array(conf)
conf = []
i = 0
traj_iter = atomic_generator()
else:
traj_iter = self.atomsengine.iread(traj_file, format=file_format)
# Return the iterator
result.value_iterable = traj_iter
return result
def _Q_cell(self):
# Cell given as three vectors
A = self.input_tree.get_keyword("FORCE_EVAL/SUBSYS/CELL/A")
B = self.input_tree.get_keyword("FORCE_EVAL/SUBSYS/CELL/B")
C = self.input_tree.get_keyword("FORCE_EVAL/SUBSYS/CELL/C")
if A and B and C:
return
# Cell given as three lengths and three angles
ABC = self.input_tree.get_keyword("FORCE_EVAL/SUBSYS/CELL/ABC")
abg = self.input_tree.get_keyword("FORCE_EVAL/SUBSYS/CELL/ALPHA_BETA_GAMMA")
# Cell given in external file
cell_format = self.input_tree.get_keyword("FORCE_EVAL/SUBSYS/CELL/CELL_FILE_FORMAT")
cell_file = self.input_tree.get_keyword("FORCE_EVAL/SUBSYS/CELL/CELL_FILE_NAME")
# Multiplication factor
factor = self.input_tree.get_keyword("FORCE_EVAL/SUBSYS/CELL/CELL_FILE_NAME")
#===============================================================================
class CP2K_262_Implementation(CP2KImplementation):
......
......@@ -190,6 +190,26 @@ class TestTrajectory(unittest.TestCase):
self.assertEqual(n_particles, 2)
self.assertEqual(n_dim, 3)
def test_xmol(self):
parser = getparser("trajectory/xmol")
pos = parser.parse_quantity("particle_position")
n_conf = pos.shape[0]
n_particles = pos.shape[1]
n_dim = pos.shape[2]