Commit a91fbdd6 authored by Lauri Himanen's avatar Lauri Himanen
Browse files

Bug fixes.

parent 6755ba03
...@@ -87,23 +87,25 @@ class CP2KInput(object): ...@@ -87,23 +87,25 @@ class CP2KInput(object):
logger.warning(message) logger.warning(message)
return (None, None) return (None, None)
keyword = section.get_keyword(keyword) keyword = section.get_keyword_object(keyword)
if keyword and section: if keyword and section:
return (keyword, section) return (keyword, section)
else: else:
return (None, section) return (None, section)
def get_keyword(self, path): def get_keyword_value_formatted(self, path):
"""Returns the keyword that is specified by the given path. """
If the keyword has no value set, returns the default value defined in
the XML.
""" """
keyword, section = self.get_keyword_and_section(path) keyword, section = self.get_keyword_and_section(path)
if keyword: if keyword:
if keyword.value is not None: return keyword.get_value_formatted()
return keyword.get_value()
else: def get_keyword_value(self, path):
return keyword.default_value """
"""
keyword, section = self.get_keyword_and_section(path)
if keyword:
return keyword.get_value()
def get_default_keyword(self, path): def get_default_keyword(self, path):
return self.get_section(path).default_keyword.value return self.get_section(path).default_keyword.value
...@@ -180,10 +182,17 @@ class Keyword(InputObject): ...@@ -180,10 +182,17 @@ class Keyword(InputObject):
self.default_name = default_name self.default_name = default_name
def get_value(self): def get_value(self):
"""Returns the unformatted value of this keyword. This is exactly what
was set by the used in the input as a string.
"""
return self.value
def get_value_formatted(self):
"""Returns the value stored in this keyword by removing the possible """Returns the value stored in this keyword by removing the possible
unit definition and formatting the string into the correct data type. unit definition and formatting the string into the correct data type.
""" """
# Decode the unit and the value if not done before # Decode the unit and the value if not done before
proper_value = None
if self.default_unit: if self.default_unit:
if not self.value_no_unit: if not self.value_no_unit:
self.decode_cp2k_unit_and_value() self.decode_cp2k_unit_and_value()
...@@ -191,6 +200,8 @@ class Keyword(InputObject): ...@@ -191,6 +200,8 @@ class Keyword(InputObject):
proper_value = self.value_no_unit proper_value = self.value_no_unit
else: else:
proper_value = self.value proper_value = self.value
if proper_value is None:
proper_value = self.default_value
returned = None returned = None
dim = int(self.data_dimension) dim = int(self.data_dimension)
...@@ -297,7 +308,7 @@ class Section(object): ...@@ -297,7 +308,7 @@ class Section(object):
self.sections = defaultdict(list) self.sections = defaultdict(list)
self.description = None self.description = None
def get_keyword(self, name): def get_keyword_object(self, name):
keyword = self.keywords.get(name) keyword = self.keywords.get(name)
if keyword: if keyword:
if len(keyword) == 1: if len(keyword) == 1:
...@@ -305,6 +316,22 @@ class Section(object): ...@@ -305,6 +316,22 @@ class Section(object):
else: else:
logger.error("The keyword '{}' in '{}' does not exist or has too many entries.".format(name, self.name)) logger.error("The keyword '{}' in '{}' does not exist or has too many entries.".format(name, self.name))
def get_keyword_value_formatted(self, name):
"""Returns the keyword value formatted to the correct shape and type,
and returns the default value if nothing was specified.
"""
keyword_object = self.get_keyword_object(name)
if keyword_object is not None:
value = keyword_object.get_value_formatted()
return value
def get_keyword_value(self, name):
"""Returns the keyword value as a raw string as specfied by the used.
"""
keyword_object = self.get_keyword_object(name)
if keyword_object is not None:
return keyword_object.get_value()
def get_subsection(self, name): def get_subsection(self, name):
subsection = self.sections.get(name) subsection = self.sections.get(name)
if subsection: if subsection:
...@@ -315,6 +342,16 @@ class Section(object): ...@@ -315,6 +342,16 @@ class Section(object):
else: else:
logger.error("The subsection '{}' in '{}' does not exist.".format(name, self.name)) logger.error("The subsection '{}' in '{}' does not exist.".format(name, self.name))
def get_section_parameter(self):
"""Get the section parameter, or if not specified the lone keyword
value.
"""
if self.section_parameter is not None:
value = self.section_parameter.value
if value is None:
value = self.section_parameter.lone_keyword_value
return value.upper()
#=============================================================================== #===============================================================================
class SectionParameters(InputObject): class SectionParameters(InputObject):
......
...@@ -57,9 +57,9 @@ def recursive_tree_generation(xml_element, for_metainfo=False, name_stack=[]): ...@@ -57,9 +57,9 @@ def recursive_tree_generation(xml_element, for_metainfo=False, name_stack=[]):
section = Section(section_name) section = Section(section_name)
# Ignore sections that control the print settings # Ignore sections that control the print settings
ignored = ["EACH", "PRINT"] # ignored = ["EACH", "PRINT"]
if section_name in ignored: # if section_name in ignored:
return # return
if for_metainfo: if for_metainfo:
# Descriptions # Descriptions
...@@ -223,7 +223,7 @@ def generate_input_object_metainfo_json(child, parent, name_stack): ...@@ -223,7 +223,7 @@ def generate_input_object_metainfo_json(child, parent, name_stack):
path = ".".join(name_stack) path = ".".join(name_stack)
json_obj = {} json_obj = {}
json_obj["name"] = "x_cp2k_{}.{}".format(path, child.name) json_obj["name"] = "x_cp2k_{}.{}".format(path, child.name)
json_obj["superNames"] = ["x_cp2k_{}".format(path)] json_obj["superNames"] = ["x_cp2k_section_{}".format(path)]
# Description # Description
description = child.description description = child.description
...@@ -260,9 +260,9 @@ def generate_section_metainfo_json(child, parent, name_stack): ...@@ -260,9 +260,9 @@ def generate_section_metainfo_json(child, parent, name_stack):
path = ".".join(name_stack[:-1]) path = ".".join(name_stack[:-1])
json_obj = {} json_obj = {}
json_obj["name"] = "x_cp2k_{}".format(name) json_obj["name"] = "x_cp2k_section_{}".format(name)
json_obj["kindStr"] = "type_section" json_obj["kindStr"] = "type_section"
json_obj["superNames"] = ["x_cp2k_{}".format(path)] json_obj["superNames"] = ["x_cp2k_section_{}".format(path)]
description = child.description description = child.description
if description is None or description.isspace(): if description is None or description.isspace():
...@@ -276,13 +276,13 @@ def generate_section_metainfo_json(child, parent, name_stack): ...@@ -276,13 +276,13 @@ def generate_section_metainfo_json(child, parent, name_stack):
if __name__ == "__main__": if __name__ == "__main__":
# xml to pickle # xml to pickle
# xml_file = open("../versions/cp2k262/input_data/cp2k_input.xml", 'r') xml_file = open("../versions/cp2k262/input_data/cp2k_input.xml", 'r')
# object_tree = CP2KInput(generate_object_tree(xml_file)) object_tree = CP2KInput(generate_object_tree(xml_file))
# file_name = "../versions/cp2k262/input_data/cp2k_input_tree.pickle" file_name = "../versions/cp2k262/input_data/cp2k_input_tree.pickle"
# fh = open(file_name, "wb") fh = open(file_name, "wb")
# pickle.dump(object_tree, fh, protocol=2) pickle.dump(object_tree, fh, protocol=2)
# Metainfo generation # Metainfo generation
xml_file = open("../versions/cp2k262/input_data/cp2k_input.xml", 'r') # xml_file = open("../versions/cp2k262/input_data/cp2k_input.xml", 'r')
object_tree = CP2KInput(generate_object_tree(xml_file, for_metainfo=True)) # object_tree = CP2KInput(generate_object_tree(xml_file, for_metainfo=True))
generate_input_metainfos(object_tree) # generate_input_metainfos(object_tree)
...@@ -103,12 +103,42 @@ class CP2KInputParser(BasicParser): ...@@ -103,12 +103,42 @@ class CP2KInputParser(BasicParser):
elif section_parameter == "B3LYP": elif section_parameter == "B3LYP":
xc_list.append(XCFunctional("HYB_GGA_XC_B3LYP")) xc_list.append(XCFunctional("HYB_GGA_XC_B3LYP"))
elif section_parameter == "TPSS":
xc_list.append(XCFunctional("MGGA_X_TPSS"))
xc_list.append(XCFunctional("MGGA_C_TPSS"))
else: else:
logger.warning("Unknown XC functional given in XC_FUNCTIONAL section parameter.") logger.warning("Unknown XC functional given in XC_FUNCTIONAL section parameter.")
# Otherwise one has to look at the individual functional settings # Otherwise one has to look at the individual functional settings
else: else:
pass pbe = xc.get_subsection("PBE")
if pbe is not None:
if pbe.accessed:
sp = pbe.get_section_parameter()
if sp == "T":
parametrization = pbe.get_keyword_value_formatted("PARAMETRIZATION")
scale_x = pbe.get_keyword_value_formatted("SCALE_X")
scale_c = pbe.get_keyword_value_formatted("SCALE_C")
if parametrization == "ORIG":
xc_list.append(XCFunctional("GGA_X_PBE", scale_x))
xc_list.append(XCFunctional("GGA_C_PBE", scale_c))
elif parametrization == "PBESOL":
xc_list.append(XCFunctional("GGA_X_PBE_SOL", scale_x))
xc_list.append(XCFunctional("GGA_C_PBE_SOL", scale_c))
elif parametrization == "REVPBE":
xc_list.append(XCFunctional("GGA_X_PBE_R", scale_x))
xc_list.append(XCFunctional("GGA_C_PBE", scale_c))
tpss = xc.get_subsection("TPSS")
if tpss is not None:
if tpss.accessed:
sp = tpss.get_section_parameter()
if sp == "T":
scale_x = tpss.get_keyword_value_formatted("SCALE_X")
scale_c = tpss.get_keyword_value_formatted("SCALE_C")
xc_list.append(XCFunctional("MGGA_X_TPSS", scale_x))
xc_list.append(XCFunctional("MGGA_C_TPSS", scale_c))
# Sort the functionals alphabetically by name # Sort the functionals alphabetically by name
xc_list.sort(key=lambda x: x.name) xc_list.sort(key=lambda x: x.name)
...@@ -137,7 +167,7 @@ class CP2KInputParser(BasicParser): ...@@ -137,7 +167,7 @@ class CP2KInputParser(BasicParser):
#======================================================================= #=======================================================================
# Cell periodicity # Cell periodicity
periodicity = self.input_tree.get_keyword("FORCE_EVAL/SUBSYS/CELL/PERIODIC") periodicity = self.input_tree.get_keyword_value_formatted("FORCE_EVAL/SUBSYS/CELL/PERIODIC")
if periodicity is not None: if periodicity is not None:
periodicity = periodicity.upper() periodicity = periodicity.upper()
periodicity_list = ("X" in periodicity, "Y" in periodicity, "Z" in periodicity) periodicity_list = ("X" in periodicity, "Y" in periodicity, "Z" in periodicity)
...@@ -155,7 +185,7 @@ class CP2KInputParser(BasicParser): ...@@ -155,7 +185,7 @@ class CP2KInputParser(BasicParser):
#======================================================================= #=======================================================================
# Stress tensor calculation method # Stress tensor calculation method
stress_tensor_method = self.input_tree.get_keyword("FORCE_EVAL/STRESS_TENSOR") stress_tensor_method = self.input_tree.get_keyword_value_formatted("FORCE_EVAL/STRESS_TENSOR")
if stress_tensor_method != "NONE": if stress_tensor_method != "NONE":
mapping = { mapping = {
"NUMERICAL": "Numerical", "NUMERICAL": "Numerical",
...@@ -179,7 +209,7 @@ class CP2KInputParser(BasicParser): ...@@ -179,7 +209,7 @@ class CP2KInputParser(BasicParser):
normalized_path = path normalized_path = path
# Path is relative, project name added # Path is relative, project name added
else: else:
project_name = self.input_tree.get_keyword("GLOBAL/PROJECT_NAME") project_name = self.input_tree.get_keyword_value_formatted("GLOBAL/PROJECT_NAME")
if path: if path:
normalized_path = "{}-{}".format(project_name, path) normalized_path = "{}-{}".format(project_name, path)
else: else:
...@@ -248,7 +278,7 @@ class CP2KInputParser(BasicParser): ...@@ -248,7 +278,7 @@ class CP2KInputParser(BasicParser):
for line in self.input_lines: for line in self.input_lines:
# Remove comments and whitespaces # Remove comments and whitespaces
line = line.split('!', 1)[0].strip() line = line.split('!', 1)[0].split('#', 1)[0].strip()
# Skip empty lines # Skip empty lines
if len(line) == 0: if len(line) == 0:
...@@ -286,10 +316,14 @@ class CP2KInputParser(BasicParser): ...@@ -286,10 +316,14 @@ class CP2KInputParser(BasicParser):
else: else:
split = line.split(None, 1) split = line.split(None, 1)
if len(split) <= 1: if len(split) <= 1:
raise IndexError("A keyword in the CP2K input had no value associated with it. The line causing this is: '{}'".format(line)) keyword_value = ""
else:
keyword_value = split[1]
keyword_name = split[0].upper() keyword_name = split[0].upper()
keyword_value = split[1] try:
self.input_tree.set_keyword(path + "/" + keyword_name, keyword_value) self.input_tree.set_keyword(path + "/" + keyword_name, keyword_value)
except UnboundLocalError:
print line
# Here we store some exceptional print settings that are # Here we store some exceptional print settings that are
# inportant to the parsing. These dont exist in the input tree # inportant to the parsing. These dont exist in the input tree
...@@ -336,7 +370,8 @@ class CP2KInputParser(BasicParser): ...@@ -336,7 +370,8 @@ class CP2KInputParser(BasicParser):
section_parameter = section.section_parameter section_parameter = section.section_parameter
if section_parameter is not None: if section_parameter is not None:
name = "{}.SECTION_PARAMETERS".format(path) name = "{}.SECTION_PARAMETERS".format(path)
self.backend.addValue(name, section_parameter.value) if section_parameter.value is not None:
self.backend.addValue(name, section_parameter.value)
# Default keyword # Default keyword
default_keyword = section.default_keyword default_keyword = section.default_keyword
......
...@@ -622,12 +622,12 @@ if __name__ == '__main__': ...@@ -622,12 +622,12 @@ if __name__ == '__main__':
suites.append(unittest.TestLoader().loadTestsFromTestCase(TestEnergyForce)) suites.append(unittest.TestLoader().loadTestsFromTestCase(TestEnergyForce))
suites.append(unittest.TestLoader().loadTestsFromTestCase(TestStressTensorMethods)) suites.append(unittest.TestLoader().loadTestsFromTestCase(TestStressTensorMethods))
suites.append(unittest.TestLoader().loadTestsFromTestCase(TestSelfInteractionCorrectionMethod)) suites.append(unittest.TestLoader().loadTestsFromTestCase(TestSelfInteractionCorrectionMethod))
suites.append(unittest.TestLoader().loadTestsFromTestCase(TestConfigurationPeriodicDimensions)) # suites.append(unittest.TestLoader().loadTestsFromTestCase(TestConfigurationPeriodicDimensions))
suites.append(unittest.TestLoader().loadTestsFromTestCase(TestSCFConvergence)) # suites.append(unittest.TestLoader().loadTestsFromTestCase(TestSCFConvergence))
suites.append(unittest.TestLoader().loadTestsFromTestCase(TestForceFiles)) # suites.append(unittest.TestLoader().loadTestsFromTestCase(TestForceFiles))
suites.append(unittest.TestLoader().loadTestsFromTestCase(TestPreprocessor)) # suites.append(unittest.TestLoader().loadTestsFromTestCase(TestPreprocessor))
suites.append(unittest.TestLoader().loadTestsFromTestCase(TestGeoOpt)) # suites.append(unittest.TestLoader().loadTestsFromTestCase(TestGeoOpt))
suites.append(unittest.TestLoader().loadTestsFromTestCase(TestGeoOptTrajFormats)) # suites.append(unittest.TestLoader().loadTestsFromTestCase(TestGeoOptTrajFormats))
suites.append(unittest.TestLoader().loadTestsFromTestCase(TestGeoOptOptimizers)) # suites.append(unittest.TestLoader().loadTestsFromTestCase(TestGeoOptOptimizers))
alltests = unittest.TestSuite(suites) alltests = unittest.TestSuite(suites)
unittest.TextTestRunner(verbosity=0).run(alltests) unittest.TextTestRunner(verbosity=0).run(alltests)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment