Commit baef37b7 authored by Lauri Himanen's avatar Lauri Himanen
Browse files

Simplified the inputparser a bit.

parent 73e24472
...@@ -106,28 +106,11 @@ class CP2KInput(object): ...@@ -106,28 +106,11 @@ class CP2KInput(object):
else: else:
return (None, section) return (None, section)
def get_keyword(self, path, format_value=True): def get_keyword(self, path, raw=False, allow_default=True):
if format_value:
return self.get_keyword_value_formatted(path)
else:
return self.get_keyword_value_raw(path)
def get_keyword_value_formatted(self, path):
"""
"""
keyword, section = self.get_keyword_and_section(path)
if keyword:
return keyword.get_value_formatted()
def get_keyword_value_raw(self, path):
"""
"""
keyword, section = self.get_keyword_and_section(path) keyword, section = self.get_keyword_and_section(path)
if keyword: if keyword:
return keyword.get_value_raw() return keyword.get_value(raw, allow_default)
def get_default_keyword(self, path):
return self.get_section(path).default_keyword.value
def set_section_accessed(self, path): def set_section_accessed(self, path):
section = self.get_section(path) section = self.get_section(path)
...@@ -137,11 +120,6 @@ class CP2KInput(object): ...@@ -137,11 +120,6 @@ class CP2KInput(object):
message = "The CP2K input does not contain the section {}".format(path) message = "The CP2K input does not contain the section {}".format(path)
logger.warning(message) logger.warning(message)
def get_keyword_default(self, path):
keyword, section = self.get_keyword_and_section(path)
if keyword:
return keyword.default_value
def get_default_unit(self, path): def get_default_unit(self, path):
keyword, section = self.get_keyword_and_section(path) keyword, section = self.get_keyword_and_section(path)
if keyword: if keyword:
...@@ -171,6 +149,68 @@ class CP2KInput(object): ...@@ -171,6 +149,68 @@ class CP2KInput(object):
return parameter.lone_value return parameter.lone_value
#===============================================================================
class Section(object):
"""An input section in a CP2K calculation.
"""
__slots__ = ['accessed', 'name', 'keywords', 'default_keyword_names', 'default_keyword', 'section_parameter', 'sections', 'description']
def __init__(self, name):
self.accessed = False
self.name = name
self.keywords = defaultdict(list)
self.default_keyword_names = []
self.default_keyword = None
self.section_parameter = None
self.sections = defaultdict(list)
self.description = None
def get_keyword_object(self, name):
keyword = self.keywords.get(name)
if keyword:
if len(keyword) == 1:
return keyword[0]
else:
logger.error("The keyword '{}' in '{}' does not exist or has too many entries.".format(name, self.name))
def get_keyword(self, name, raw=False, allow_default=True):
"""Returns the keyword value for the given name.
Args:
name: The name of the keyword
raw: Boolean indicating if the raw value (not modified in any way)
should be returned.
allow_default: Boolean indicating if it is allowed to return the
default value is no actual value was set by the user in the input.
"""
keyword_object = self.get_keyword_object(name)
return keyword_object.get_value(raw, allow_default)
def get_subsection(self, name):
subsection = self.sections.get(name)
if subsection:
if len(subsection) == 1:
return subsection[0]
else:
logger.error("The subsection '{}' in '{}' has too many entries.".format(name, self.name))
else:
logger.error("The subsection '{}' in '{}' does not exist.".format(name, self.name))
def get_subsection_list(self, name):
subsection = self.sections.get(name)
return subsection
def get_section_parameter(self):
"""Get the section parameter, or if not specified the lone keyword
value.
"""
if self.section_parameter is not None:
value = self.section_parameter.value
if value is None:
value = self.section_parameter.lone_keyword_value
return value.upper()
#=============================================================================== #===============================================================================
class InputObject(object): class InputObject(object):
"""Base class for all kind of data elements in the CP2K input. """Base class for all kind of data elements in the CP2K input.
...@@ -200,15 +240,23 @@ class Keyword(InputObject): ...@@ -200,15 +240,23 @@ class Keyword(InputObject):
self.default_value = default_value self.default_value = default_value
self.default_name = default_name self.default_name = default_name
def get_value_raw(self): def get_value(self, raw=False, allow_default=True):
if raw:
return self._get_value_raw()
else:
return self._get_value_formatted(allow_default)
def _get_value_raw(self):
"""Returns the unformatted value of this keyword. This is exactly what """Returns the unformatted value of this keyword. This is exactly what
was set by the used in the input as a string. was set by the used in the input as a string.
""" """
return self.value return self.value
def get_value_formatted(self): def _get_value_formatted(self, allow_default=False):
"""Returns the value stored in this keyword by removing the possible """Returns the value stored in this keyword by removing the possible
unit definition and formatting the string into the correct data type. unit definition and formatting the string into the correct data type.
If asked, will use the default value if not actual value was set by
user.
""" """
# Decode the unit and the value if not done before # Decode the unit and the value if not done before
proper_value = None proper_value = None
...@@ -219,6 +267,7 @@ class Keyword(InputObject): ...@@ -219,6 +267,7 @@ class Keyword(InputObject):
proper_value = self.value_no_unit proper_value = self.value_no_unit
else: else:
proper_value = self.value proper_value = self.value
# if allow_default:
if proper_value is None: if proper_value is None:
proper_value = self.default_value proper_value = self.default_value
if proper_value is None: if proper_value is None:
...@@ -313,71 +362,6 @@ class Keyword(InputObject): ...@@ -313,71 +362,6 @@ class Keyword(InputObject):
self.value_no_unit = self.value self.value_no_unit = self.value
#===============================================================================
class Section(object):
"""An input section in a CP2K calculation.
"""
__slots__ = ['accessed', 'name', 'keywords', 'default_keyword_names', 'default_keyword', 'section_parameter', 'sections', 'description']
def __init__(self, name):
self.accessed = False
self.name = name
self.keywords = defaultdict(list)
self.default_keyword_names = []
self.default_keyword = None
self.section_parameter = None
self.sections = defaultdict(list)
self.description = None
def get_keyword_object(self, name):
keyword = self.keywords.get(name)
if keyword:
if len(keyword) == 1:
return keyword[0]
else:
logger.error("The keyword '{}' in '{}' does not exist or has too many entries.".format(name, self.name))
def get_keyword_value_formatted(self, name):
"""Returns the keyword value formatted to the correct shape and type,
and returns the default value if nothing was specified.
"""
keyword_object = self.get_keyword_object(name)
if keyword_object is not None:
value = keyword_object.get_value_formatted()
return value
def get_keyword_value_raw(self, name):
"""Returns the keyword value as a raw string as specfied by the used.
"""
keyword_object = self.get_keyword_object(name)
if keyword_object is not None:
return keyword_object.get_value_raw()
def get_subsection(self, name):
subsection = self.sections.get(name)
if subsection:
if len(subsection) == 1:
return subsection[0]
else:
logger.error("The subsection '{}' in '{}' has too many entries.".format(name, self.name))
else:
logger.error("The subsection '{}' in '{}' does not exist.".format(name, self.name))
def get_subsection_list(self, name):
subsection = self.sections.get(name)
return subsection
def get_section_parameter(self):
"""Get the section parameter, or if not specified the lone keyword
value.
"""
if self.section_parameter is not None:
value = self.section_parameter.value
if value is None:
value = self.section_parameter.lone_keyword_value
return value.upper()
#=============================================================================== #===============================================================================
class SectionParameters(InputObject): class SectionParameters(InputObject):
"""Section parameters in a CP2K calculation. """Section parameters in a CP2K calculation.
......
...@@ -135,9 +135,9 @@ class CP2KInputParser(BasicParser): ...@@ -135,9 +135,9 @@ class CP2KInputParser(BasicParser):
if pbe.accessed: if pbe.accessed:
sp = pbe.get_section_parameter() sp = pbe.get_section_parameter()
if sp == "T": if sp == "T":
parametrization = pbe.get_keyword("PARAMETRIZATION") parametrization = pbe.get_keyword("PARAMETRIZATION", allow_default=True)
scale_x = pbe.get_keyword("SCALE_X") scale_x = pbe.get_keyword("SCALE_X", allow_default=True)
scale_c = pbe.get_keyword("SCALE_C") scale_c = pbe.get_keyword("SCALE_C", allow_default=True)
if parametrization == "ORIG": if parametrization == "ORIG":
xc_list.append(XCFunctional("GGA_X_PBE", scale_x)) xc_list.append(XCFunctional("GGA_X_PBE", scale_x))
xc_list.append(XCFunctional("GGA_C_PBE", scale_c)) xc_list.append(XCFunctional("GGA_C_PBE", scale_c))
...@@ -152,8 +152,8 @@ class CP2KInputParser(BasicParser): ...@@ -152,8 +152,8 @@ class CP2KInputParser(BasicParser):
if tpss.accessed: if tpss.accessed:
sp = tpss.get_section_parameter() sp = tpss.get_section_parameter()
if sp == "T": if sp == "T":
scale_x = tpss.get_keyword("SCALE_X") scale_x = tpss.get_keyword("SCALE_X", allow_default=True)
scale_c = tpss.get_keyword("SCALE_C") scale_c = tpss.get_keyword("SCALE_C", allow_default=True)
xc_list.append(XCFunctional("MGGA_X_TPSS", scale_x)) xc_list.append(XCFunctional("MGGA_X_TPSS", scale_x))
xc_list.append(XCFunctional("MGGA_C_TPSS", scale_c)) xc_list.append(XCFunctional("MGGA_C_TPSS", scale_c))
......
...@@ -1022,9 +1022,9 @@ if __name__ == '__main__': ...@@ -1022,9 +1022,9 @@ if __name__ == '__main__':
logger.setLevel(logging.ERROR) logger.setLevel(logging.ERROR)
suites = [] suites = []
suites.append(unittest.TestLoader().loadTestsFromTestCase(TestErrors)) # suites.append(unittest.TestLoader().loadTestsFromTestCase(TestErrors))
suites.append(unittest.TestLoader().loadTestsFromTestCase(TestXCFunctional)) # suites.append(unittest.TestLoader().loadTestsFromTestCase(TestXCFunctional))
suites.append(unittest.TestLoader().loadTestsFromTestCase(TestEnergyForce)) # suites.append(unittest.TestLoader().loadTestsFromTestCase(TestEnergyForce))
suites.append(unittest.TestLoader().loadTestsFromTestCase(TestStressTensorMethods)) suites.append(unittest.TestLoader().loadTestsFromTestCase(TestStressTensorMethods))
suites.append(unittest.TestLoader().loadTestsFromTestCase(TestSelfInteractionCorrectionMethod)) suites.append(unittest.TestLoader().loadTestsFromTestCase(TestSelfInteractionCorrectionMethod))
suites.append(unittest.TestLoader().loadTestsFromTestCase(TestConfigurationPeriodicDimensions)) suites.append(unittest.TestLoader().loadTestsFromTestCase(TestConfigurationPeriodicDimensions))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment