Commit f228e47e authored by Markus Scheidgen's avatar Markus Scheidgen
Browse files

Added extensions and derived mechanisms.

parent a224ea3b
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
import numpy as np import numpy as np
from nomad.metainfo import MSection, MCategory, Section, Quantity, Package, SubSection, units from nomad.metainfo import MSection, MCategory, Section, Quantity, Package, SubSection, Enum, units
m_package = Package(links=['http://metainfo.nomad-coe.eu']) m_package = Package(links=['http://metainfo.nomad-coe.eu'])
...@@ -24,7 +24,7 @@ class System(MSection): ...@@ -24,7 +24,7 @@ class System(MSection):
""" All data that describes a simulated system. """ """ All data that describes a simulated system. """
n_atoms = Quantity( n_atoms = Quantity(
type=int, default=0, type=int, derived=lambda system: len(system.atom_labels),
description='Number of atoms in the simulated system.') description='Number of atoms in the simulated system.')
atom_labels = Quantity( atom_labels = Quantity(
...@@ -42,7 +42,7 @@ class System(MSection): ...@@ -42,7 +42,7 @@ class System(MSection):
unit_cell = Quantity(synonym_for='lattice_vectors') unit_cell = Quantity(synonym_for='lattice_vectors')
periodic_dimensions = Quantity( periodic_dimensions = Quantity(
type=bool, shape=[3], categories=[SystemHash.m_def], type=bool, shape=[3], default=[False, False, False], categories=[SystemHash.m_def],
description='A vector of booleans indicating in which dimensions the unit cell is repeated.') description='A vector of booleans indicating in which dimensions the unit cell is repeated.')
...@@ -56,13 +56,13 @@ class Run(MSection): ...@@ -56,13 +56,13 @@ class Run(MSection):
parsing = SubSection(sub_section=Parsing.m_def) parsing = SubSection(sub_section=Parsing.m_def)
# class VaspRun(MSection): class VaspRun(Run):
# """ All VASP specific quantities for section Run. """ """ All VASP specific quantities for section Run. """
# m_def = Section(extends=Run.m_def) m_def = Section(extends_base_section=True)
# x_vasp_raw_format = Quantity( x_vasp_raw_format = Quantity(
# type=Enum(['xml', 'outcar']), type=Enum(['xml', 'outcar']),
# description='The file format of the parsed VASP mainfile.') description='The file format of the parsed VASP mainfile.')
if __name__ == '__main__': if __name__ == '__main__':
...@@ -85,16 +85,18 @@ if __name__ == '__main__': ...@@ -85,16 +85,18 @@ if __name__ == '__main__':
run = Run() run = Run()
run.code_name = 'VASP' run.code_name = 'VASP'
run.code_version = '1.0.0' run.code_version = '1.0.0'
run.m_as(VaspRun).x_vasp_raw_format = 'outcar'
# The same as
run.x_vasp_raw_format = 'outcar' # type: ignore
system = run.m_create(System) system = run.m_create(System)
system.n_atoms = 3
system.atom_labels = ['H', 'H', 'O'] system.atom_labels = ['H', 'H', 'O']
# Or to read data from existing metainfo data: # Or to read data from existing metainfo data:
print(system.atom_labels) print(system.atom_labels)
print(system.n_atoms)
# To serialize the data: # To serialize the data:
print(run.m_to_json(indent=2)) print(run.m_to_json(indent=2))
print('###########') # print(m_package.m_to_json(indent=2)) # type: ignore, pylint: disable=undefined-variable
print(m_package.m_to_json(indent=2)) # type: ignore, pylint: disable=undefined-variable
...@@ -160,6 +160,11 @@ class MetainfoError(Exception): ...@@ -160,6 +160,11 @@ class MetainfoError(Exception):
pass pass
class DeriveError(MetainfoError):
""" An error occurred while computing a derived value. """
pass
# Reflection # Reflection
class Enum(list): class Enum(list):
...@@ -475,8 +480,8 @@ class MDataTypeAndShapeChecks(MDataDelegating): ...@@ -475,8 +480,8 @@ class MDataTypeAndShapeChecks(MDataDelegating):
'The value %s is not an enum value for quantity %s.' % 'The value %s is not an enum value for quantity %s.' %
(value, quantity_def)) (value, quantity_def))
elif quantity_def == Quantity.type: elif quantity_def in [Quantity.type, Quantity.derived]:
# TODO check this special case of values used as quantity types # TODO check these special cases for Quantity quantities
pass pass
elif quantity_def.type == Any: elif quantity_def.type == Any:
...@@ -576,9 +581,17 @@ class MSection(metaclass=MObjectMeta): ...@@ -576,9 +581,17 @@ class MSection(metaclass=MObjectMeta):
if self.m_def is None: if self.m_def is None:
self.m_def = cls.m_def self.m_def = cls.m_def
# check m_def
if cls.m_def is not None: if cls.m_def is not None:
assert self.m_def == cls.m_def, \ if self.m_def != cls.m_def:
'Section class and section definition must match' MetainfoError('Section class and section definition must match.')
if self.m_def.extends_base_section:
MetainfoError('Section extends another section and cannot be instantiated.')
else:
if not is_bootstrapping:
MetainfoError('Section has not m_def.')
# get annotations from kwargs # get annotations from kwargs
self.m_annotations: Dict[str, Any] = {} self.m_annotations: Dict[str, Any] = {}
...@@ -614,6 +627,43 @@ class MSection(metaclass=MObjectMeta): ...@@ -614,6 +627,43 @@ class MSection(metaclass=MObjectMeta):
m_def.description = inspect.cleandoc(cls.__doc__).strip() m_def.description = inspect.cleandoc(cls.__doc__).strip()
m_def.section_cls = cls m_def.section_cls = cls
# add base sections
if m_def.extends_base_section:
base_sections_count = len(cls.__bases__)
if base_sections_count == 0:
raise MetainfoError(
'Section %s extend the base section, but has no base section.' % m_def)
elif base_sections_count > 1:
raise MetainfoError(
'Section %s extend the base section, but has more than one base section' % m_def)
base_section_cls = cls.__bases__[0]
base_section = getattr(base_section_cls, 'm_def', None)
if base_section is None:
raise MetainfoError(
'The base section of %s is not a section class.' % m_def)
for name, attr in cls.__dict__.items():
if isinstance(attr, Property):
setattr(base_section_cls, name, attr)
section_to_add_properties_to = base_section
else:
for base_cls in cls.__bases__:
if base_cls != MSection:
base_section = getattr(base_cls, 'm_def')
if base_section is None:
raise TypeError(
'Section defining classes must have MSection or a decendant as '
'base classes.')
base_sections = list(m_def.m_get(Section.base_sections))
base_sections.append(base_section)
m_def.m_set(Section.base_sections, base_sections)
section_to_add_properties_to = m_def
for name, attr in cls.__dict__.items(): for name, attr in cls.__dict__.items():
# transfer names and descriptions for properties # transfer names and descriptions for properties
if isinstance(attr, Property): if isinstance(attr, Property):
...@@ -623,25 +673,12 @@ class MSection(metaclass=MObjectMeta): ...@@ -623,25 +673,12 @@ class MSection(metaclass=MObjectMeta):
attr.__doc__ = attr.description attr.__doc__ = attr.description
if isinstance(attr, Quantity): if isinstance(attr, Quantity):
m_def.m_add_sub_section(Section.quantities, attr) section_to_add_properties_to.m_add_sub_section(Section.quantities, attr)
elif isinstance(attr, SubSection): elif isinstance(attr, SubSection):
m_def.m_add_sub_section(Section.sub_sections, attr) section_to_add_properties_to.m_add_sub_section(Section.sub_sections, attr)
else: else:
raise NotImplementedError('Unknown property kind.') raise NotImplementedError('Unknown property kind.')
# add base sections
for base_cls in cls.__bases__:
if base_cls != MSection:
base_section = getattr(base_cls, 'm_def')
if base_section is None:
raise TypeError(
'Section defining classes must have MSection or a decendant as '
'base classes.')
base_sections = list(m_def.m_get(Section.base_sections))
base_sections.append(base_section)
m_def.m_set(Section.base_sections, base_sections)
# add section cls' section to the module's package # add section cls' section to the module's package
module_name = cls.__module__ module_name = cls.__module__
pkg = Package.from_module(module_name) pkg = Package.from_module(module_name)
...@@ -655,15 +692,26 @@ class MSection(metaclass=MObjectMeta): ...@@ -655,15 +692,26 @@ class MSection(metaclass=MObjectMeta):
def m_set(self, quantity_def: 'Quantity', value: Any) -> None: def m_set(self, quantity_def: 'Quantity', value: Any) -> None:
""" Set the given value for the given quantity. """ """ Set the given value for the given quantity. """
quantity_def = self.__resolve_synonym(quantity_def) quantity_def = self.__resolve_synonym(quantity_def)
if quantity_def.derived is not None:
raise MetainfoError('The quantity %s is derived and cannot be set.' % quantity_def)
self.m_data.m_set(self, quantity_def, value) self.m_data.m_set(self, quantity_def, value)
def m_get(self, quantity_def: 'Quantity') -> Any: def m_get(self, quantity_def: 'Quantity') -> Any:
""" Retrieve the given value for the given quantity. """ """ Retrieve the given value for the given quantity. """
quantity_def = self.__resolve_synonym(quantity_def) quantity_def = self.__resolve_synonym(quantity_def)
if quantity_def.derived is not None:
try:
return quantity_def.derived(self)
except Exception as e:
raise DeriveError('Could not derive value for %s: %s' % (quantity_def, str(e)))
return self.m_data.m_get(self, quantity_def) return self.m_data.m_get(self, quantity_def)
def m_is_set(self, quantity_def: 'Quantity') -> bool: def m_is_set(self, quantity_def: 'Quantity') -> bool:
quantity_def = self.__resolve_synonym(quantity_def) quantity_def = self.__resolve_synonym(quantity_def)
if quantity_def.derived is not None:
return True
return self.m_data.m_is_set(self, quantity_def) return self.m_data.m_is_set(self, quantity_def)
def m_add_values(self, quantity_def: 'Quantity', values: Any, offset: int) -> None: def m_add_values(self, quantity_def: 'Quantity', values: Any, offset: int) -> None:
...@@ -729,6 +777,10 @@ class MSection(metaclass=MObjectMeta): ...@@ -729,6 +777,10 @@ class MSection(metaclass=MObjectMeta):
else: else:
self.m_set(prop, value) self.m_set(prop, value)
def m_as(self, section_cls: Type[MSectionBound]) -> MSectionBound:
""" 'Casts' this section to the given extending sections. """
return cast(MSectionBound, self)
def m_follows(self, definition: 'Section') -> bool: def m_follows(self, definition: 'Section') -> bool:
""" Determines if this section's definition is or is derived from the given definition. """ """ Determines if this section's definition is or is derived from the given definition. """
return self.m_def == definition or self.m_def in definition.all_base_sections return self.m_def == definition or self.m_def in definition.all_base_sections
...@@ -753,7 +805,7 @@ class MSection(metaclass=MObjectMeta): ...@@ -753,7 +805,7 @@ class MSection(metaclass=MObjectMeta):
yield name, sub_section.m_to_dict() yield name, sub_section.m_to_dict()
for name, quantity in self.m_def.all_quantities.items(): for name, quantity in self.m_def.all_quantities.items():
if self.m_is_set(quantity): if self.m_is_set(quantity) and quantity.derived is None:
to_json_serializable: Callable[[Any], Any] = str to_json_serializable: Callable[[Any], Any] = str
if isinstance(quantity.type, DataType): if isinstance(quantity.type, DataType):
...@@ -979,6 +1031,7 @@ class Quantity(Property): ...@@ -979,6 +1031,7 @@ class Quantity(Property):
unit: 'Quantity' = None unit: 'Quantity' = None
default: 'Quantity' = None default: 'Quantity' = None
synonym_for: 'Quantity' = None synonym_for: 'Quantity' = None
derived: 'Quantity' = None
# TODO derived_from = Quantity(type=Quantity, shape=['0..*']) # TODO derived_from = Quantity(type=Quantity, shape=['0..*'])
# TODO categories = Quantity(type=Category, shape=['0..*']) # TODO categories = Quantity(type=Category, shape=['0..*'])
...@@ -1081,8 +1134,8 @@ class Section(Definition): ...@@ -1081,8 +1134,8 @@ class Section(Definition):
sub_sections: 'SubSection' = None sub_sections: 'SubSection' = None
base_sections: 'Quantity' = None base_sections: 'Quantity' = None
# TODO extends = Quantity(type=bool), denotes this section as a container for extends_base_section: 'Quantity' = None
# new quantities that belong to the base-class section definitions
@cached_property @cached_property
def all_base_sections(self) -> Set['Section']: def all_base_sections(self) -> Set['Section']:
all_base_sections: Set['Section'] = set() all_base_sections: Set['Section'] = set()
...@@ -1217,7 +1270,12 @@ Section.base_sections = Quantity( ...@@ -1217,7 +1270,12 @@ Section.base_sections = Quantity(
Inherit all quantity and sub section definitions from the given sections. Inherit all quantity and sub section definitions from the given sections.
Will be derived from Python base classes. Will be derived from Python base classes.
''') ''')
Section.extends_base_section = Quantity(
type=bool, default=False, name='extends_base_section',
description='''
If True, the quantity definitions of this section will be added to the base section.
Only one base section is allowed.
''')
SubSection.repeats = Quantity( SubSection.repeats = Quantity(
type=bool, name='repeats', default=False, type=bool, name='repeats', default=False,
...@@ -1276,10 +1334,15 @@ Quantity.default = DirectQuantity( ...@@ -1276,10 +1334,15 @@ Quantity.default = DirectQuantity(
''') ''')
Quantity.synonym_for = DirectQuantity( Quantity.synonym_for = DirectQuantity(
type=str, name='synonym_for', description=''' type=str, name='synonym_for', description='''
With this set, the quantitiy will become a virtual quantity and its data is not stored With this set, the quantity will become a virtual quantity and its data is not stored
directly. Setting and getting quantity, will change the *synonym* quantity instead. Use directly. Setting and getting quantity, will change the *synonym* quantity instead. Use
the name of the quantity as value. the name of the quantity as value.
''') ''')
Quantity.derived = DirectQuantity(
type=Callable, default=None, name='derived', description='''
Derived quantities are computed from other quantities of the same section. The value
of derived needs to be a callable that takes the section and returns a value.
''')
Package.section_definitions = SubSection( Package.section_definitions = SubSection(
sub_section=Section.m_def, name='section_definitions', repeats=True, sub_section=Section.m_def, name='section_definitions', repeats=True,
...@@ -1302,4 +1365,4 @@ Quantity.__init_cls__() ...@@ -1302,4 +1365,4 @@ Quantity.__init_cls__()
print('Metainfo initialization took %d ms' % ((time.time() - start_time) * 1000)) print('Metainfo initialization took %d ms' % ((time.time() - start_time) * 1000))
units = UnitRegistry() units = UnitRegistry()
""" The default pint unit registry that should be used to give units to quantity definitions. """ """ The default pint unit registry that should be used to give units to quantity definitions. """
\ No newline at end of file
...@@ -15,8 +15,8 @@ ...@@ -15,8 +15,8 @@
import pytest import pytest
import numpy as np import numpy as np
from nomad.metainfo.metainfo import MSection, MCategory, Section, Quantity, Definition, Package from nomad.metainfo.metainfo import MSection, MCategory, Section, Quantity, Definition, Package, DeriveError
from nomad.metainfo.example import Run, System, SystemHash, Parsing, m_package as example_package from nomad.metainfo.example import Run, VaspRun, System, SystemHash, Parsing, m_package as example_package
def assert_section_def(section_def: Section): def assert_section_def(section_def: Section):
...@@ -102,7 +102,7 @@ class TestM2: ...@@ -102,7 +102,7 @@ class TestM2:
assert Run.m_def.name == 'Run' assert Run.m_def.name == 'Run'
def test_quantities(self): def test_quantities(self):
assert len(Run.m_def.quantities) == 2 assert len(Run.m_def.quantities) == 3
assert Run.m_def.all_quantities['code_name'] in Run.m_def.quantities assert Run.m_def.all_quantities['code_name'] in Run.m_def.quantities
assert Run.m_def.all_quantities['code_name'] == Run.__dict__['code_name'] assert Run.m_def.all_quantities['code_name'] == Run.__dict__['code_name']
...@@ -113,7 +113,7 @@ class TestM2: ...@@ -113,7 +113,7 @@ class TestM2:
assert Run.m_def.all_sub_sections_by_section[System.m_def].sub_section == System.m_def assert Run.m_def.all_sub_sections_by_section[System.m_def].sub_section == System.m_def
def test_properties(self): def test_properties(self):
assert len(Run.m_def.all_properties) == 4 assert len(Run.m_def.all_properties) == 5
def test_get_quantity_def(self): def test_get_quantity_def(self):
assert System.n_atoms == System.m_def.all_properties['n_atoms'] assert System.n_atoms == System.m_def.all_properties['n_atoms']
...@@ -141,18 +141,21 @@ class TestM2: ...@@ -141,18 +141,21 @@ class TestM2:
def test_package(self): def test_package(self):
assert example_package.name == 'nomad.metainfo.example' assert example_package.name == 'nomad.metainfo.example'
assert example_package.description == 'An example metainfo package.' assert example_package.description == 'An example metainfo package.'
assert example_package.m_sub_section_count(Package.section_definitions) == 3 assert example_package.m_sub_section_count(Package.section_definitions) == 4
assert example_package.m_sub_section_count(Package.category_definitions) == 1 assert example_package.m_sub_section_count(Package.category_definitions) == 1
def test_base_sections(self): def test_base_sections(self):
assert Definition.m_def in iter(Section.m_def.base_sections) assert Definition.m_def in iter(Section.m_def.base_sections)
print(Section.m_def.base_sections)
assert 'name' in Section.m_def.all_quantities assert 'name' in Section.m_def.all_quantities
assert 'name' in Quantity.m_def.all_quantities assert 'name' in Quantity.m_def.all_quantities
def test_unit(self): def test_unit(self):
assert System.lattice_vectors.unit is not None assert System.lattice_vectors.unit is not None
def test_extension(self):
assert getattr(Run, 'x_vasp_raw_format', None) is not None
assert 'x_vasp_raw_format' in Run.m_def.all_quantities
class TestM1: class TestM1:
""" Test for meta-info instances. """ """ Test for meta-info instances. """
...@@ -180,7 +183,7 @@ class TestM1: ...@@ -180,7 +183,7 @@ class TestM1:
assert_section_instance(system) assert_section_instance(system)
def test_defaults(self): def test_defaults(self):
assert System().n_atoms == 0 assert len(System().periodic_dimensions) == 3
assert System().atom_labels is None assert System().atom_labels is None
try: try:
System().does_not_exist System().does_not_exist
...@@ -258,7 +261,6 @@ class TestM1: ...@@ -258,7 +261,6 @@ class TestM1:
run = Run() run = Run()
run.code_name = 'test code name' run.code_name = 'test code name'
system: System = run.m_create(System) system: System = run.m_create(System)
system.n_atoms = 3
system.atom_labels = ['H', 'H', 'O'] system.atom_labels = ['H', 'H', 'O']
system.atom_positions = np.array([[1.2e-10, 0, 0], [0, 1.2e-10, 0], [0, 0, 1.2e-10]]) system.atom_positions = np.array([[1.2e-10, 0, 0], [0, 1.2e-10, 0], [0, 0, 1.2e-10]])
...@@ -280,3 +282,23 @@ class TestM1: ...@@ -280,3 +282,23 @@ class TestM1:
new_example_data = Run.m_from_dict(dct) new_example_data = Run.m_from_dict(dct)
self.assert_example_data(new_example_data) self.assert_example_data(new_example_data)
def test_derived(self):
system = System()
try:
assert system.n_atoms == 3
assert False, 'supposed unreachable'
except DeriveError:
pass
else:
assert False, 'supposed unreachable'
system.atom_labels = ['H', 'H', 'O']
assert system.n_atoms == 3
pass
def test_extension(self):
run = Run()
run.m_as(VaspRun).x_vasp_raw_format = 'outcar'
assert run.m_as(VaspRun).x_vasp_raw_format == 'outcar'
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment