diff --git a/nomad/metainfo/metainfo.py b/nomad/metainfo/metainfo.py index 9733fbb5edd8ce81988abc9b70f2e51c519615d6..c82533058630d8bec9b03d09221e701d57b053b3 100644 --- a/nomad/metainfo/metainfo.py +++ b/nomad/metainfo/metainfo.py @@ -2718,6 +2718,7 @@ class Section(Definition): def __init_metainfo__(self): super().__init_metainfo__() + # Init extending_sections if self.extends_base_section: base_sections_count = len(self.base_sections) if base_sections_count == 0: @@ -2735,6 +2736,21 @@ class Section(Definition): base_section.extending_sections = base_section.extending_sections + [self] + # Transfer properties of inherited and overwriten property definitions that + # have not been overwritten + inherited_properties: Dict[str, Property] = dict() + for base_section in self.all_base_sections: + inherited_properties.update(**base_section.all_properties) + + for property in self.quantities + self.sub_sections: + inherited_property = inherited_properties.get(property.name) + if inherited_property is None: + continue + + for m_quantity in property.m_def.all_quantities.values(): + if not property.m_is_set(m_quantity) and inherited_property.m_is_set(m_quantity): + property.m_set(m_quantity, inherited_property.m_get(m_quantity)) + # validate def validate(definition): errors, warnings = definition.m_all_validate() @@ -2756,9 +2772,8 @@ class Section(Definition): @constraint def unique_names(self): - # start with the names of all base_sections names: Set[str] = set() - for base in list(self.all_base_sections) + self.extending_sections: + for base in self.extending_sections: for quantity in base.quantities + base.sub_sections: for alias in quantity.aliases: names.add(alias) @@ -2947,25 +2962,23 @@ Section.event_handlers = Quantity( @derived(cached=True) -def inherited_sections(self) -> Set[Section]: - result: Set[Section] = set() - result.add(self) - for base_section in self.base_sections: - result.add(base_section) - for base_base_section in base_section.all_base_sections: - result.add(base_base_section) +def inherited_sections(self) -> List[Section]: + result: List[Section] = [] + for base_section in self.all_base_sections: + result.append(base_section) for extending_section in self.extending_sections: - result.add(extending_section) + result.append(extending_section) + result.append(self) return result @derived(cached=True) -def all_base_sections(self) -> Set[Section]: - result: Set[Section] = set() +def all_base_sections(self) -> List[Section]: + result: List[Section] = [] for base_section in self.base_sections: - result.add(base_section) for base_base_section in base_section.all_base_sections: - result.add(base_base_section) + result.append(base_base_section) + result.append(base_section) return result diff --git a/tests/metainfo/test_metainfo.py b/tests/metainfo/test_metainfo.py index 98b0e71238973600f2f21977fef7944dd406e1db..217454d6f7d310b830ed7e55ce9772a6a4cfb58d 100644 --- a/tests/metainfo/test_metainfo.py +++ b/tests/metainfo/test_metainfo.py @@ -191,14 +191,15 @@ class TestM2: class TestBase(MSection): name = Quantity(type=str) - with pytest.raises(MetainfoError): - class TestSection(TestBase): # pylint: disable=unused-variable - name = Quantity(type=int) + # this is possible, can overwrite existing quantity + class TestSection(TestBase): # pylint: disable=unused-variable + name = Quantity(type=int) def test_unique_names_extends(self): class TestBase(MSection): name = Quantity(type=str) + # this is not possible, cant replace existing quantity with pytest.raises(MetainfoError): class TestSection(TestBase): # pylint: disable=unused-variable m_def = Section(extends_base_section=True) diff --git a/tests/metainfo/test_sections.py b/tests/metainfo/test_sections.py new file mode 100644 index 0000000000000000000000000000000000000000..4ff970ea0332c0508d60e313cabc2147baa8ef85 --- /dev/null +++ b/tests/metainfo/test_sections.py @@ -0,0 +1,160 @@ +# +# Copyright The NOMAD Authors. +# +# This file is part of NOMAD. See https://nomad-lab.eu for further info. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Contains more general test cases that are replaced continiously by more specialized +# in-depth tests in test_* files of the same module. + +import pytest + +from nomad.metainfo import MSection +from nomad.metainfo.metainfo import Quantity, SubSection, Section + + +def test_base_section(): + class BaseSection(MSection): + pass + + class Section(BaseSection): + pass + + assert Section.m_def.base_sections == [BaseSection.m_def] + assert BaseSection.m_def.base_sections == [] + + assert isinstance(Section(), Section) + assert isinstance(Section(), BaseSection) + + +def test_quantity_inheritance(): + class BaseSection(MSection): + test_quantity = Quantity(type=str) + + class Section(BaseSection): + pass + + assert 'test_quantity' in Section.m_def.all_properties + assert Section.test_quantity == BaseSection.test_quantity + + section = Section(test_quantity='test_value') + assert section.test_quantity == 'test_value' + assert section.m_to_dict()['test_quantity'] == 'test_value' + + +def test_quantity_overwrite(): + class BaseSection(MSection): + test_quantity = Quantity(type=str) + + class Section(BaseSection): + test_quantity = Quantity(type=int) + + assert 'test_quantity' in Section.m_def.all_properties + assert Section.m_def.all_properties['test_quantity'] == Section.test_quantity + assert BaseSection.test_quantity != Section.test_quantity + assert BaseSection.m_def.all_properties['test_quantity'] != Section.m_def.all_properties['test_quantity'] + + with pytest.raises(TypeError): + Section(test_quantity='test_value') + + section = Section(test_quantity=1) + assert section.test_quantity == 1 + assert section.m_to_dict()['test_quantity'] == 1 + + +def test_quantity_partial_overwrite(): + class BaseSection(MSection): + test_quantity = Quantity(type=str, description='test_description') + + class Section(BaseSection): + test_quantity = Quantity(type=int) + + assert Section.test_quantity.description == 'test_description' + assert Section.test_quantity.type == int + + +def test_sub_section_inheritance(): + class OtherSection(MSection): + pass + + class BaseSection(MSection): + test_sub_section = SubSection(sub_section=OtherSection) + + class Section(BaseSection): + pass + + assert 'test_sub_section' in Section.m_def.all_properties + assert Section.test_sub_section == BaseSection.test_sub_section + + section = Section(test_sub_section=OtherSection()) + assert section.test_sub_section.m_def == OtherSection.m_def + assert section.m_to_dict()['test_sub_section'] == {} + + +def test_sub_section_overwrite(): + class OtherSection(MSection): + pass + + class BaseSection(MSection): + test_sub_section = SubSection(sub_section=OtherSection) + + class Section(BaseSection): + test_sub_section = SubSection(sub_section=OtherSection, repeats=True) + + assert 'test_sub_section' in Section.m_def.all_properties + assert Section.m_def.all_properties['test_sub_section'] == Section.test_sub_section + assert BaseSection.test_sub_section != Section.test_sub_section + assert BaseSection.m_def.all_properties['test_sub_section'] != Section.m_def.all_properties['test_sub_section'] + + with pytest.raises(TypeError): + Section(test_sub_section=OtherSection()) + + section = Section(test_sub_section=[OtherSection()]) + assert len(section.test_sub_section) == 1 + assert section.m_to_dict()['test_sub_section'] == [{}] + + +def test_sub_section_partial_overwrite(): + class OtherSection(MSection): + pass + + class BaseSection(MSection): + test_sub_section = SubSection(sub_section=OtherSection, description='test_description') + + class Section(BaseSection): + test_sub_section = SubSection(repeats=True) + + assert Section.test_sub_section.description == 'test_description' + assert Section.test_sub_section.sub_section == OtherSection.m_def + assert Section.test_sub_section.repeats + + +def test_overwrite_programmatic(): + class BaseSection(MSection): + test_quantity = Quantity(type=str, description='test_description') + + section_def = Section(name='Section', base_sections=[BaseSection.m_def]) + section_def.m_add_sub_section( + Section.quantities, + Quantity(name='test_quantity')) + + # This happens automatically in Python class based section defintions, but + # has to be called manually in programatic section definitions. + section_def.__init_metainfo__() + + section_cls = section_def.section_cls + + assert section_cls.test_quantity.description == 'test_description' + assert section_cls.test_quantity.type == str