Skip to content
Snippets Groups Projects
Commit 13a258c5 authored by Himanen, Lauri (himanel1)'s avatar Himanen, Lauri (himanel1)
Browse files

Made the input file parsing more straightworward.

parent 30ab469a
No related branches found
No related tags found
No related merge requests found
......@@ -164,30 +164,54 @@ class InputObject(object):
self.data_dimension = None
self.default_value = None
def get_formatted_value(self):
""" Used to set the value of the keyword. The data will be transformed
into the correct data type and dimension from a simple string.
#===============================================================================
class Keyword(InputObject):
"""Information about a keyword in a CP2K calculation.
"""
__slots__ = ['unit', 'value_no_unit', 'default_unit', 'default_name']
def __init__(self, name, default_value, default_unit, default_name):
super(Keyword, self).__init__(name)
self.unit = None
self.value_no_unit = None
self.default_unit = default_unit
self.default_value = default_value
self.default_name = default_name
def get_value(self):
"""Returns the value stored in this keyword by removing the possible
unit definition and formatting the string into the correct data type.
"""
# Decode the unit and the value if not done before
if self.default_unit:
if not self.value_no_unit:
self.decode_cp2k_unit_and_value()
if self.value_no_unit is not None:
proper_value = self.value_no_unit
else:
proper_value = self.value
returned = None
dim = int(self.data_dimension)
splitted = self.value.split()
splitted = proper_value.split()
if len(splitted) != dim:
logger.error("The dimensions of the CP2K input parameter {} do not match the specification in the XML file.".format(self.name))
if dim == 1:
try:
if self.data_type == "integer":
returned = int(self.value)
returned = int(proper_value)
elif self.data_type == "real":
returned = float(self.value)
returned = float(proper_value)
elif self.data_type == "word":
returned = str(self.value)
returned = str(proper_value)
elif self.data_type == "keyword":
returned = str(self.value)
returned = str(proper_value)
elif self.data_type == "string":
returned = str(self.value)
returned = str(proper_value)
elif self.data_type == "logical":
returned = str(self.value)
returned = str(proper_value)
else:
logger.error("Unknown data type '{}'".format(self.data_type))
return
......@@ -217,22 +241,7 @@ class InputObject(object):
return returned
#===============================================================================
class Keyword(InputObject):
"""Information about a keyword in a CP2K calculation.
"""
__slots__ = ['unit', 'value_no_unit', 'default_unit', 'default_name']
def __init__(self, name, default_value, default_unit, default_name):
super(Keyword, self).__init__(name)
self.unit = None
self.value_no_unit = None
self.default_unit = default_unit
self.default_value = default_value
self.default_name = default_name
def get_value(self):
def determine_value_and_unit(self):
"""If the units of this value can be changed, return a value and the
unit separately.
"""
......@@ -244,12 +253,15 @@ class Keyword(InputObject):
return self.value
def get_unit(self):
# Decode the unit and the value if not done before
if self.default_unit:
if not self.unit:
self.decode_cp2k_unit_and_value()
return self.unit
else:
logger.error("The keyword '{}' does not have a unit.".format(self.default_name))
return None
def decode_cp2k_unit_and_value(self):
"""Given a CP2K unit name, decode it as Pint unit definition.
......
......@@ -222,8 +222,8 @@ def generate_metainfo_recursively(obj, parent, container, name_stack):
def generate_input_object_metainfo_json(child, parent, name_stack):
path = ".".join(name_stack)
json_obj = {}
json_obj["name"] = "cp2k_{}.{}".format(path, child.name)
json_obj["superNames"] = ["cp2k_{}".format(path)]
json_obj["name"] = "x_cp2k_{}.{}".format(path, child.name)
json_obj["superNames"] = ["x_cp2k_{}".format(path)]
# Description
description = child.description
......@@ -232,22 +232,23 @@ def generate_input_object_metainfo_json(child, parent, name_stack):
json_obj["description"] = description
# Shape
data_dim = int(child.data_dimension)
if data_dim == -1:
data_dim = "n"
if data_dim == 1:
json_obj["shape"] = []
else:
json_obj["shape"] = [data_dim]
# data_dim = int(child.data_dimension)
# if data_dim == -1:
# data_dim = "n"
# if data_dim == 1:
# json_obj["shape"] = []
# else:
# json_obj["shape"] = [data_dim]
json_obj["shape"] = []
# Determine data type according to xml info
mapping = {
"keyword": "C",
"logical": "C",
"string": "C",
"integer": "i",
"integer": "C",
"word": "C",
"real": "f",
"real": "C",
}
json_obj["dtypeStr"] = mapping[child.data_type]
return json_obj
......@@ -259,9 +260,9 @@ def generate_section_metainfo_json(child, parent, name_stack):
path = ".".join(name_stack[:-1])
json_obj = {}
json_obj["name"] = "cp2k_{}".format(name)
json_obj["name"] = "x_cp2k_{}".format(name)
json_obj["kindStr"] = "type_section"
json_obj["superNames"] = ["cp2k_{}".format(path)]
json_obj["superNames"] = ["x_cp2k_{}".format(path)]
description = child.description
if description is None or description.isspace():
......
......@@ -330,23 +330,20 @@ class CP2KInputParser(BasicParser):
for keyword in keywords:
if keyword.value is not None:
name = "{}.{}".format(path, keyword.default_name)
formatted_value = keyword.get_formatted_value()
self.add_formatted_value_to_backend(name, formatted_value)
self.backend.addValue(name, keyword.value)
# Section parameter
section_parameter = section.section_parameter
if section_parameter is not None:
name = "{}.SECTION_PARAMETERS".format(path)
formatted_value = section_parameter.get_formatted_value()
self.add_formatted_value_to_backend(name, formatted_value)
self.backend.addValue(name, section_parameter.value)
# Default keyword
default_keyword = section.default_keyword
if default_keyword is not None:
name = "{}.DEFAULT_KEYWORD".format(path)
formatted_value = default_keyword.get_formatted_value()
self.add_formatted_value_to_backend(name, formatted_value)
self.backend.addValue(name, default_keyword.value)
# Subsections
for name, subsections in section.sections.iteritems():
......@@ -357,13 +354,6 @@ class CP2KInputParser(BasicParser):
name_stack.pop()
def add_formatted_value_to_backend(self, name, formatted_value):
if formatted_value is not None:
if isinstance(formatted_value, np.ndarray):
self.backend.addArrayValues(name, formatted_value)
else:
self.backend.addValue(name, formatted_value)
def setup_version(self, version_number):
""" The pickle file which contains preparsed data from the
x_cp2k_input.xml is version specific. By calling this function before
......
......@@ -444,15 +444,15 @@ class TestPreprocessor(unittest.TestCase):
def test_variable_multiple(self):
result = get_result("input_preprocessing/variable_multiple", "x_cp2k_CP2K_INPUT.FORCE_EVAL.DFT.MGRID.CUTOFF", optimize=False)
self.assertEqual(result, 50)
self.assertEqual(result, "50")
def test_comments(self):
result = get_result("input_preprocessing/comments", "x_cp2k_CP2K_INPUT.FORCE_EVAL.DFT.MGRID.CUTOFF", optimize=False)
self.assertEqual(result, 120)
self.assertEqual(result, "120")
def test_tabseparator(self):
result = get_result("input_preprocessing/tabseparator", "x_cp2k_CP2K_INPUT.FORCE_EVAL.DFT.MGRID.CUTOFF", optimize=False)
self.assertEqual(result, 120)
self.assertEqual(result, "120")
#===============================================================================
......@@ -617,17 +617,17 @@ if __name__ == '__main__':
logger.setLevel(logging.ERROR)
suites = []
# suites.append(unittest.TestLoader().loadTestsFromTestCase(TestErrors))
# suites.append(unittest.TestLoader().loadTestsFromTestCase(TestXCFunctional))
# suites.append(unittest.TestLoader().loadTestsFromTestCase(TestEnergyForce))
# suites.append(unittest.TestLoader().loadTestsFromTestCase(TestStressTensorMethods))
# suites.append(unittest.TestLoader().loadTestsFromTestCase(TestSelfInteractionCorrectionMethod))
# suites.append(unittest.TestLoader().loadTestsFromTestCase(TestConfigurationPeriodicDimensions))
# suites.append(unittest.TestLoader().loadTestsFromTestCase(TestSCFConvergence))
# suites.append(unittest.TestLoader().loadTestsFromTestCase(TestForceFiles))
suites.append(unittest.TestLoader().loadTestsFromTestCase(TestErrors))
suites.append(unittest.TestLoader().loadTestsFromTestCase(TestXCFunctional))
suites.append(unittest.TestLoader().loadTestsFromTestCase(TestEnergyForce))
suites.append(unittest.TestLoader().loadTestsFromTestCase(TestStressTensorMethods))
suites.append(unittest.TestLoader().loadTestsFromTestCase(TestSelfInteractionCorrectionMethod))
suites.append(unittest.TestLoader().loadTestsFromTestCase(TestConfigurationPeriodicDimensions))
suites.append(unittest.TestLoader().loadTestsFromTestCase(TestSCFConvergence))
suites.append(unittest.TestLoader().loadTestsFromTestCase(TestForceFiles))
suites.append(unittest.TestLoader().loadTestsFromTestCase(TestPreprocessor))
# suites.append(unittest.TestLoader().loadTestsFromTestCase(TestGeoOpt))
# suites.append(unittest.TestLoader().loadTestsFromTestCase(TestGeoOptTrajFormats))
# suites.append(unittest.TestLoader().loadTestsFromTestCase(TestGeoOptOptimizers))
suites.append(unittest.TestLoader().loadTestsFromTestCase(TestGeoOpt))
suites.append(unittest.TestLoader().loadTestsFromTestCase(TestGeoOptTrajFormats))
suites.append(unittest.TestLoader().loadTestsFromTestCase(TestGeoOptOptimizers))
alltests = unittest.TestSuite(suites)
unittest.TextTestRunner(verbosity=0).run(alltests)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment