Commit baef37b7 authored by Lauri Himanen's avatar Lauri Himanen
Browse files

Simplified the inputparser a bit.

parent 73e24472
......@@ -106,28 +106,11 @@ class CP2KInput(object):
else:
return (None, section)
def get_keyword(self, path, format_value=True):
if format_value:
return self.get_keyword_value_formatted(path)
else:
return self.get_keyword_value_raw(path)
def get_keyword_value_formatted(self, path):
"""
"""
keyword, section = self.get_keyword_and_section(path)
if keyword:
return keyword.get_value_formatted()
def get_keyword(self, path, raw=False, allow_default=True):
def get_keyword_value_raw(self, path):
"""
"""
keyword, section = self.get_keyword_and_section(path)
if keyword:
return keyword.get_value_raw()
def get_default_keyword(self, path):
return self.get_section(path).default_keyword.value
return keyword.get_value(raw, allow_default)
def set_section_accessed(self, path):
section = self.get_section(path)
......@@ -137,11 +120,6 @@ class CP2KInput(object):
message = "The CP2K input does not contain the section {}".format(path)
logger.warning(message)
def get_keyword_default(self, path):
keyword, section = self.get_keyword_and_section(path)
if keyword:
return keyword.default_value
def get_default_unit(self, path):
keyword, section = self.get_keyword_and_section(path)
if keyword:
......@@ -171,6 +149,68 @@ class CP2KInput(object):
return parameter.lone_value
#===============================================================================
class Section(object):
"""An input section in a CP2K calculation.
"""
__slots__ = ['accessed', 'name', 'keywords', 'default_keyword_names', 'default_keyword', 'section_parameter', 'sections', 'description']
def __init__(self, name):
self.accessed = False
self.name = name
self.keywords = defaultdict(list)
self.default_keyword_names = []
self.default_keyword = None
self.section_parameter = None
self.sections = defaultdict(list)
self.description = None
def get_keyword_object(self, name):
keyword = self.keywords.get(name)
if keyword:
if len(keyword) == 1:
return keyword[0]
else:
logger.error("The keyword '{}' in '{}' does not exist or has too many entries.".format(name, self.name))
def get_keyword(self, name, raw=False, allow_default=True):
"""Returns the keyword value for the given name.
Args:
name: The name of the keyword
raw: Boolean indicating if the raw value (not modified in any way)
should be returned.
allow_default: Boolean indicating if it is allowed to return the
default value is no actual value was set by the user in the input.
"""
keyword_object = self.get_keyword_object(name)
return keyword_object.get_value(raw, allow_default)
def get_subsection(self, name):
subsection = self.sections.get(name)
if subsection:
if len(subsection) == 1:
return subsection[0]
else:
logger.error("The subsection '{}' in '{}' has too many entries.".format(name, self.name))
else:
logger.error("The subsection '{}' in '{}' does not exist.".format(name, self.name))
def get_subsection_list(self, name):
subsection = self.sections.get(name)
return subsection
def get_section_parameter(self):
"""Get the section parameter, or if not specified the lone keyword
value.
"""
if self.section_parameter is not None:
value = self.section_parameter.value
if value is None:
value = self.section_parameter.lone_keyword_value
return value.upper()
#===============================================================================
class InputObject(object):
"""Base class for all kind of data elements in the CP2K input.
......@@ -200,15 +240,23 @@ class Keyword(InputObject):
self.default_value = default_value
self.default_name = default_name
def get_value_raw(self):
def get_value(self, raw=False, allow_default=True):
if raw:
return self._get_value_raw()
else:
return self._get_value_formatted(allow_default)
def _get_value_raw(self):
"""Returns the unformatted value of this keyword. This is exactly what
was set by the used in the input as a string.
"""
return self.value
def get_value_formatted(self):
def _get_value_formatted(self, allow_default=False):
"""Returns the value stored in this keyword by removing the possible
unit definition and formatting the string into the correct data type.
If asked, will use the default value if not actual value was set by
user.
"""
# Decode the unit and the value if not done before
proper_value = None
......@@ -219,6 +267,7 @@ class Keyword(InputObject):
proper_value = self.value_no_unit
else:
proper_value = self.value
# if allow_default:
if proper_value is None:
proper_value = self.default_value
if proper_value is None:
......@@ -313,71 +362,6 @@ class Keyword(InputObject):
self.value_no_unit = self.value
#===============================================================================
class Section(object):
"""An input section in a CP2K calculation.
"""
__slots__ = ['accessed', 'name', 'keywords', 'default_keyword_names', 'default_keyword', 'section_parameter', 'sections', 'description']
def __init__(self, name):
self.accessed = False
self.name = name
self.keywords = defaultdict(list)
self.default_keyword_names = []
self.default_keyword = None
self.section_parameter = None
self.sections = defaultdict(list)
self.description = None
def get_keyword_object(self, name):
keyword = self.keywords.get(name)
if keyword:
if len(keyword) == 1:
return keyword[0]
else:
logger.error("The keyword '{}' in '{}' does not exist or has too many entries.".format(name, self.name))
def get_keyword_value_formatted(self, name):
"""Returns the keyword value formatted to the correct shape and type,
and returns the default value if nothing was specified.
"""
keyword_object = self.get_keyword_object(name)
if keyword_object is not None:
value = keyword_object.get_value_formatted()
return value
def get_keyword_value_raw(self, name):
"""Returns the keyword value as a raw string as specfied by the used.
"""
keyword_object = self.get_keyword_object(name)
if keyword_object is not None:
return keyword_object.get_value_raw()
def get_subsection(self, name):
subsection = self.sections.get(name)
if subsection:
if len(subsection) == 1:
return subsection[0]
else:
logger.error("The subsection '{}' in '{}' has too many entries.".format(name, self.name))
else:
logger.error("The subsection '{}' in '{}' does not exist.".format(name, self.name))
def get_subsection_list(self, name):
subsection = self.sections.get(name)
return subsection
def get_section_parameter(self):
"""Get the section parameter, or if not specified the lone keyword
value.
"""
if self.section_parameter is not None:
value = self.section_parameter.value
if value is None:
value = self.section_parameter.lone_keyword_value
return value.upper()
#===============================================================================
class SectionParameters(InputObject):
"""Section parameters in a CP2K calculation.
......
......@@ -135,9 +135,9 @@ class CP2KInputParser(BasicParser):
if pbe.accessed:
sp = pbe.get_section_parameter()
if sp == "T":
parametrization = pbe.get_keyword("PARAMETRIZATION")
scale_x = pbe.get_keyword("SCALE_X")
scale_c = pbe.get_keyword("SCALE_C")
parametrization = pbe.get_keyword("PARAMETRIZATION", allow_default=True)
scale_x = pbe.get_keyword("SCALE_X", allow_default=True)
scale_c = pbe.get_keyword("SCALE_C", allow_default=True)
if parametrization == "ORIG":
xc_list.append(XCFunctional("GGA_X_PBE", scale_x))
xc_list.append(XCFunctional("GGA_C_PBE", scale_c))
......@@ -152,8 +152,8 @@ class CP2KInputParser(BasicParser):
if tpss.accessed:
sp = tpss.get_section_parameter()
if sp == "T":
scale_x = tpss.get_keyword("SCALE_X")
scale_c = tpss.get_keyword("SCALE_C")
scale_x = tpss.get_keyword("SCALE_X", allow_default=True)
scale_c = tpss.get_keyword("SCALE_C", allow_default=True)
xc_list.append(XCFunctional("MGGA_X_TPSS", scale_x))
xc_list.append(XCFunctional("MGGA_C_TPSS", scale_c))
......
......@@ -1022,9 +1022,9 @@ if __name__ == '__main__':
logger.setLevel(logging.ERROR)
suites = []
suites.append(unittest.TestLoader().loadTestsFromTestCase(TestErrors))
suites.append(unittest.TestLoader().loadTestsFromTestCase(TestXCFunctional))
suites.append(unittest.TestLoader().loadTestsFromTestCase(TestEnergyForce))
# suites.append(unittest.TestLoader().loadTestsFromTestCase(TestErrors))
# suites.append(unittest.TestLoader().loadTestsFromTestCase(TestXCFunctional))
# suites.append(unittest.TestLoader().loadTestsFromTestCase(TestEnergyForce))
suites.append(unittest.TestLoader().loadTestsFromTestCase(TestStressTensorMethods))
suites.append(unittest.TestLoader().loadTestsFromTestCase(TestSelfInteractionCorrectionMethod))
suites.append(unittest.TestLoader().loadTestsFromTestCase(TestConfigurationPeriodicDimensions))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment