Commit 718c4497 authored by Markus Scheidgen's avatar Markus Scheidgen
Browse files

Allow multiple sub-section of the same kind in metainfo.

parent b57e2de3
......@@ -886,16 +886,37 @@ class MSection(metaclass=MObjectMeta):
""" Returns the number of sub sections for the given sub section definition. """
return self.m_data.m_sub_section_count(self, sub_section_def)
def m_create(self, section_cls: Type[MSectionBound], **kwargs) -> MSectionBound:
def m_create(
self, section_cls: Type[MSectionBound], sub_section_def: 'SubSection' = None,
**kwargs) -> MSectionBound:
""" Creates a section instance and adds it to this section provided there is a
corresponding sub section.
Args:
section_cls: The section class for the sub-secton to create
sub_section_def: If there are multiple sub-sections for the given class,
this must be used to explicitely state the sub-section definition.
"""
section_def = section_cls.m_def
sub_section_def = self.m_def.all_sub_sections_by_section.get(section_def, None)
if sub_section_def is None:
sub_section_defs = self.m_def.all_sub_sections_by_section.get(section_def, [])
n_sub_section_defs = len(sub_section_defs)
if n_sub_section_defs == 0:
raise TypeError('There is no sub section to hold a %s in %s.' % (section_def, self.m_def))
if n_sub_section_defs > 1 and sub_section_def is None:
raise MetainfoError(
'There are multiple sub section to hold a %s in %s, '
'but no sub-section was explicitely given.' % (section_def, self.m_def))
if sub_section_def is not None and sub_section_def not in sub_section_defs:
raise MetainfoError(
'The given sub-section class %s does not match the given sub-section '
'definition %s.' % (section_cls, sub_section_def))
if sub_section_def is None:
sub_section_def = sub_section_defs[0]
sub_section = section_cls(**kwargs)
self.m_add_sub_section(sub_section_def, sub_section)
......@@ -1620,7 +1641,7 @@ class Section(Definition):
all_sub_sections_by_section:
A helper property that gives all sub-section definition including inherited ones
as a dictionary that maps section classes (i.e. Python class objects) to
:class:`SubSection`.
lists of :class:`SubSection`.
"""
section_cls: Type[MSection] = None
......@@ -1640,7 +1661,7 @@ class Section(Definition):
self.all_properties: Dict[str, Union['SubSection', Quantity]] = dict()
self.all_quantities: Dict[str, Quantity] = dict()
self.all_sub_sections: Dict[str, SubSection] = dict()
self.all_sub_sections_by_section: Dict['Section', 'SubSection'] = dict()
self.all_sub_sections_by_section: Dict['Section', List['SubSection']] = dict()
def on_add_sub_section(self, sub_section_def, sub_section):
if sub_section_def == Section.quantities:
......@@ -1650,7 +1671,8 @@ class Section(Definition):
if sub_section_def == Section.sub_sections:
self.all_properties[sub_section.name] = sub_section
self.all_sub_sections[sub_section.name] = sub_section
self.all_sub_sections_by_section[sub_section.sub_section] = sub_section
self.all_sub_sections_by_section.setdefault(
sub_section.sub_section, []).append(sub_section)
def on_set(self, quantity_def, value):
if quantity_def == Section.base_sections:
......@@ -1677,13 +1699,6 @@ class Section(Definition):
assert definition.name not in names, 'All names in a section must be unique.'
names.add(definition.name)
def c_unique_sub_sections(self):
sub_sections = set()
for sub_section in self.all_sub_sections.values():
assert sub_section.sub_section not in sub_sections, \
'The same section definition can only be used in one sub-section'
sub_sections.add(sub_section.sub_section)
class Package(Definition):
""" Packages organize metainfo defintions alongside Python modules
......
......@@ -118,7 +118,8 @@ class TestM2:
assert len(Run.m_def.sub_sections) == 3
assert Run.m_def.all_sub_sections['systems'] in Run.m_def.sub_sections
assert Run.m_def.all_sub_sections['systems'].sub_section == System.m_def
assert Run.m_def.all_sub_sections_by_section[System.m_def].sub_section == System.m_def
assert len(Run.m_def.all_sub_sections_by_section[System.m_def]) == 1
assert Run.m_def.all_sub_sections_by_section[System.m_def][0].sub_section == System.m_def
def test_properties(self):
assert len(Run.m_def.all_properties) == 6
......@@ -184,11 +185,12 @@ class TestM2:
m_def = Section(extends_base_section=True)
name = Quantity(type=int)
def test_unique_sub_sections(self):
with assert_exception(MetainfoError):
class TestSection(MSection): # pylint: disable=unused-variable
one = SubSection(sub_section=System)
two = SubSection(sub_section=System)
def test_multiple_sub_sections(self):
class TestSection(MSection): # pylint: disable=unused-variable
one = SubSection(sub_section=System)
two = SubSection(sub_section=System)
assert len(TestSection.m_def.all_sub_sections_by_section[System.m_def]) == 2
def test_dimension_exists(self):
with assert_exception(MetainfoError):
......@@ -246,13 +248,9 @@ class TestM1:
def test_defaults(self):
assert len(System().periodic_dimensions) == 3
assert System().atom_labels is None
try:
with assert_exception(AttributeError):
System().does_not_exist
assert False, 'Supposed unreachable'
except AttributeError:
pass
else:
assert False, 'Expected AttributeError'
def test_m_section(self):
assert Run().m_def == Run.m_def
......@@ -279,31 +277,16 @@ class TestM1:
assert parsing.m_parent_index == -1
def test_wrong_type(self):
try:
with assert_exception(TypeError):
Run().code_name = 1
assert False, 'Supposed unreachable'
except TypeError:
pass
else:
assert False, 'Expected TypeError'
def test_wrong_shape_1(self):
try:
with assert_exception(TypeError):
Run().code_name = ['name']
assert False, 'Supposed unreachable'
except TypeError:
pass
else:
assert False, 'Expected TypeError'
def test_wrong_shape_2(self):
try:
with assert_exception(TypeError):
System().atom_labels = 'label'
assert False, 'Supposed unreachable'
except TypeError:
pass
else:
assert False, 'Expected TypeError'
def test_np(self):
system = System()
......@@ -353,13 +336,8 @@ class TestM1:
def test_derived(self):
system = System()
try:
with assert_exception(DeriveError):
assert system.n_atoms == 3
assert False, 'supposed unreachable'
except DeriveError:
pass
else:
assert False, 'supposed unreachable'
system.atom_labels = ['H', 'H', 'O']
assert system.n_atoms == 3
......@@ -389,3 +367,16 @@ class TestM1:
system.atom_labels = ['H']
system.atom_positions = []
assert len(system.m_validate()) > 0
def test_multiple_sub_sections(self):
class TestSection(MSection): # pylint: disable=unused-variable
one = SubSection(sub_section=System)
two = SubSection(sub_section=System)
test_section = TestSection()
with assert_exception():
test_section.m_create(System)
test_section.m_create(System, TestSection.one)
assert test_section.one is not None
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment