Commit d7bc23ea authored by Markus Scheidgen's avatar Markus Scheidgen
Browse files

Added dimension checks and constraints.

parent 14ea3639
...@@ -107,7 +107,11 @@ class DataType: ...@@ -107,7 +107,11 @@ class DataType:
return value return value
class __Dimension(DataType): range_re = re.compile(r'(\d)\.\.(\d|\*)')
class _Dimension(DataType):
def set_normalize(self, section, quantity_def: 'Quantity', value): def set_normalize(self, section, quantity_def: 'Quantity', value):
if isinstance(value, int): if isinstance(value, int):
return value return value
...@@ -115,7 +119,7 @@ class __Dimension(DataType): ...@@ -115,7 +119,7 @@ class __Dimension(DataType):
if isinstance(value, str): if isinstance(value, str):
if value.isidentifier(): if value.isidentifier():
return value return value
if re.match(r'(\d)\.\.(\d|\*)', value): if re.match(range_re, value):
return value return value
if isinstance(value, Section): if isinstance(value, Section):
...@@ -126,8 +130,21 @@ class __Dimension(DataType): ...@@ -126,8 +130,21 @@ class __Dimension(DataType):
raise TypeError('%s is not a valid dimension' % str(value)) raise TypeError('%s is not a valid dimension' % str(value))
@staticmethod
def check_dimension(section, dimension, length):
if isinstance(dimension, int):
return dimension == length
if isinstance(dimension, str):
if dimension.isidentifier():
return dimension == getattr(section, dimension)
class __Unit(DataType): m = re.match(range_re, dimension)
start = int(m.group(1))
end = -1 if m.group(2) == '*' else int(m.group(2))
return start <= length and (end == -1 or length <= end)
class _Unit(DataType):
def set_normalize(self, section, quantity_def: 'Quantity', value): def set_normalize(self, section, quantity_def: 'Quantity', value):
if isinstance(value, str): if isinstance(value, str):
value = units.parse_units(value) value = units.parse_units(value)
...@@ -148,7 +165,7 @@ units = pint.UnitRegistry() ...@@ -148,7 +165,7 @@ units = pint.UnitRegistry()
""" The default pint unit registry that should be used to give units to quantity definitions. """ """ The default pint unit registry that should be used to give units to quantity definitions. """
class __Callable(DataType): class _Callable(DataType):
def serialize(self, section, quantity_def: 'Quantity', value): def serialize(self, section, quantity_def: 'Quantity', value):
raise MetainfoError('Callables cannot be serialized') raise MetainfoError('Callables cannot be serialized')
...@@ -156,7 +173,7 @@ class __Callable(DataType): ...@@ -156,7 +173,7 @@ class __Callable(DataType):
raise MetainfoError('Callables cannot be serialized') raise MetainfoError('Callables cannot be serialized')
class __QuantityType(DataType): class _QuantityType(DataType):
""" Data type for defining the type of a metainfo quantity. """ Data type for defining the type of a metainfo quantity.
A metainfo quantity type can be one of A metainfo quantity type can be one of
...@@ -279,7 +296,7 @@ class Reference(DataType): ...@@ -279,7 +296,7 @@ class Reference(DataType):
return MProxy(value) return MProxy(value)
class __Datetime(DataType): class _Datetime(DataType):
def __parse(self, datetime_str: str) -> datetime: def __parse(self, datetime_str: str) -> datetime:
try: try:
...@@ -308,11 +325,11 @@ class __Datetime(DataType): ...@@ -308,11 +325,11 @@ class __Datetime(DataType):
return self.__parse(value) return self.__parse(value)
Dimension = __Dimension() Dimension = _Dimension()
Unit = __Unit() Unit = _Unit()
QuantityType = __QuantityType() QuantityType = _QuantityType()
Callable = __Callable() Callable = _Callable()
Datetime = __Datetime() Datetime = _Datetime()
# Metainfo data storage and reflection interface # Metainfo data storage and reflection interface
...@@ -567,6 +584,7 @@ class MSection(metaclass=MObjectMeta): ...@@ -567,6 +584,7 @@ class MSection(metaclass=MObjectMeta):
m_def.section_cls = cls m_def.section_cls = cls
# add base sections # add base sections
extended_base_section = None
if m_def.extends_base_section: if m_def.extends_base_section:
base_sections_count = len(cls.__bases__) base_sections_count = len(cls.__bases__)
if base_sections_count == 0: if base_sections_count == 0:
...@@ -578,7 +596,7 @@ class MSection(metaclass=MObjectMeta): ...@@ -578,7 +596,7 @@ class MSection(metaclass=MObjectMeta):
'Section %s extend the base section, but has more than one base section' % m_def) 'Section %s extend the base section, but has more than one base section' % m_def)
base_section_cls = cls.__bases__[0] base_section_cls = cls.__bases__[0]
base_section = getattr(base_section_cls, 'm_def', None) extended_base_section = base_section = getattr(base_section_cls, 'm_def', None)
if base_section is None: if base_section is None:
raise MetainfoError( raise MetainfoError(
'The base section of %s is not a section class.' % m_def) 'The base section of %s is not a section class.' % m_def)
...@@ -679,8 +697,25 @@ class MSection(metaclass=MObjectMeta): ...@@ -679,8 +697,25 @@ class MSection(metaclass=MObjectMeta):
if prop.description is None: if prop.description is None:
prop.description = param.description prop.description = param.description
def __check_np(self, quantity_ref: 'Quantity', value: np.ndarray) -> np.ndarray: # validate
# TODO def validate(definition):
errors = definition.m_all_validate()
if len(errors) > 0:
raise MetainfoError(
'%s. The section definition %s violates %d more constraints' %
(str(errors[0]).strip('.'), definition, len(errors) - 1))
if extended_base_section is not None:
validate(extended_base_section)
validate(m_def)
def __check_np(self, quantity_def: 'Quantity', value: np.ndarray) -> np.ndarray:
# TODO this feels expensive, first check, then possible convert very often?
# if quantity_ref.type != value.dtype:
# raise MetainfoError(
# 'Quantity dtype %s and value dtype %s do not match.' %
# (quantity_ref.type, value.dtype))
return value return value
def __set_normalize(self, quantity_def: 'Quantity', value: Any) -> Any: def __set_normalize(self, quantity_def: 'Quantity', value: Any) -> Any:
...@@ -950,7 +985,7 @@ class MSection(metaclass=MObjectMeta): ...@@ -950,7 +985,7 @@ class MSection(metaclass=MObjectMeta):
'Do not know how to serialize data with type %s for quantity %s' % 'Do not know how to serialize data with type %s for quantity %s' %
(quantity.type, quantity)) (quantity.type, quantity))
value = self.m_data.dct[name] value = cast(MDataDict, self.m_data).dct[name]
if type(quantity.type) == np.dtype: if type(quantity.type) == np.dtype:
serializable_value = value.tolist() serializable_value = value.tolist()
...@@ -1132,8 +1167,30 @@ class MSection(metaclass=MObjectMeta): ...@@ -1132,8 +1167,30 @@ class MSection(metaclass=MObjectMeta):
return cast(MSectionBound, context) return cast(MSectionBound, context)
def __validate_shape(self, quantity_def: 'Quantity', value):
if quantity_def == Quantity.default:
return True
quantity_shape = quantity_def.shape
if type(value) == np.ndarray:
value_shape = value.shape
if isinstance(value, list) and not isinstance(value, Enum):
value_shape = [len(value)]
else:
value_shape = []
if len(value_shape) != len(quantity_shape):
return False
for i in range(0, len(value_shape)):
if not _Dimension.check_dimension(self, quantity_shape[i], value_shape[i]):
return False
return True
def m_validate(self): def m_validate(self):
""" Evaluates all constraints of this section and returns a list of errors. """ """ Evaluates all constraints and shapes of this section and returns a list of errors. """
errors: List[str] = [] errors: List[str] = []
for constraint_name in self.m_def.constraints: for constraint_name in self.m_def.constraints:
constraint = getattr(self, 'c_%s' % constraint_name, None) constraint = getattr(self, 'c_%s' % constraint_name, None)
...@@ -1150,6 +1207,12 @@ class MSection(metaclass=MObjectMeta): ...@@ -1150,6 +1207,12 @@ class MSection(metaclass=MObjectMeta):
error_str = 'Constraint %s violated.' % constraint_name error_str = 'Constraint %s violated.' % constraint_name
errors.append(error_str) errors.append(error_str)
for quantity in self.m_def.all_quantities.values():
if self.m_is_set(quantity):
if not self.__validate_shape(quantity, self.m_get(quantity)):
errors.append(
'The shape of quantity %s does not match its value.' % quantity)
return errors return errors
def m_all_validate(self): def m_all_validate(self):
...@@ -1400,6 +1463,20 @@ class Quantity(Property): ...@@ -1400,6 +1463,20 @@ class Quantity(Property):
# object (instance) case # object (instance) case
raise NotImplementedError('Deleting quantity values is not supported.') raise NotImplementedError('Deleting quantity values is not supported.')
def c_dimensions(self):
for dimension in self.shape: # pylint: disable=not-an-iterable
if isinstance(dimension, str):
if dimension.isidentifier():
dim_quantity = self.m_parent.all_quantities.get(dimension, None)
assert dim_quantity is not None, 'Dimensions must be quantities of the same section.'
assert len(dim_quantity.shape) == 0 and dim_quantity.type == int, \
'Dimensions must be shapeless and int typed.'
def c_higher_shapes_require_dtype(self):
if len(self.shape) > 1:
assert type(self.type) == np.dtype, \
'Higher dimensional quantities need a dtype and will be treated as numpy arrays.'
class DirectQuantity(Quantity): class DirectQuantity(Quantity):
""" Used for quantities that would cause indefinite loops due to bootstrapping. """ """ Used for quantities that would cause indefinite loops due to bootstrapping. """
...@@ -1588,6 +1665,25 @@ class Section(Definition): ...@@ -1588,6 +1665,25 @@ class Section(Definition):
self.all_sub_sections.update(**base_section.all_sub_sections) self.all_sub_sections.update(**base_section.all_sub_sections)
self.all_sub_sections_by_section.update(**base_section.all_sub_sections_by_section) self.all_sub_sections_by_section.update(**base_section.all_sub_sections_by_section)
def c_unique_names(self):
# start with the names of all base_sections
names: Set[str] = set(
name
for base in self.all_base_sections
for name in base.all_properties.keys())
for def_list in [self.quantities, self.sub_sections]:
for definition in def_list:
assert definition.name not in names, 'All names in a section must be unique.'
names.add(definition.name)
def c_unique_sub_sections(self):
sub_sections = set()
for sub_section in self.all_sub_sections.values():
assert sub_section.sub_section not in sub_sections, \
'The same section definition can only be used in one sub-section'
sub_sections.add(sub_section.sub_section)
class Package(Definition): class Package(Definition):
""" Packages organize metainfo defintions alongside Python modules """ Packages organize metainfo defintions alongside Python modules
...@@ -1674,7 +1770,7 @@ Section.sub_sections = SubSection( ...@@ -1674,7 +1770,7 @@ Section.sub_sections = SubSection(
Section.base_sections = Quantity( Section.base_sections = Quantity(
type=Reference(Section.m_def), shape=['0..*'], default=[], name='base_sections') type=Reference(Section.m_def), shape=['0..*'], default=[], name='base_sections')
Section.extends_base_section = Quantity(type=bool, default=False, name='extends_base_section') Section.extends_base_section = Quantity(type=bool, default=False, name='extends_base_section')
Section.constraints = Quantity(type=str, shape=['0..*'], name='constraints') Section.constraints = Quantity(type=str, shape=['0..*'], default=[], name='constraints')
Section.event_handlers = Quantity( Section.event_handlers = Quantity(
type=Callable, shape=['0..*'], name='event_handlers', virtual=True, default=[]) type=Callable, shape=['0..*'], name='event_handlers', virtual=True, default=[])
......
...@@ -16,9 +16,11 @@ import pytest ...@@ -16,9 +16,11 @@ import pytest
import numpy as np import numpy as np
import pint.quantity import pint.quantity
from nomad.metainfo.metainfo import MSection, MCategory, Section, Quantity, Definition, Package, DeriveError, units from nomad.metainfo.metainfo import MSection, MCategory, Section, Quantity, SubSection, Definition, Package, DeriveError, MetainfoError, units
from nomad.metainfo.example import Run, VaspRun, System, SystemHash, Parsing, m_package as example_package from nomad.metainfo.example import Run, VaspRun, System, SystemHash, Parsing, m_package as example_package
from tests.utils import assert_exception
def assert_section_def(section_def: Section): def assert_section_def(section_def: Section):
assert isinstance(section_def, Section) assert isinstance(section_def, Section)
...@@ -165,6 +167,56 @@ class TestM2: ...@@ -165,6 +167,56 @@ class TestM2:
def test_constraints(self): def test_constraints(self):
assert len(Run.m_def.constraints) > 0 assert len(Run.m_def.constraints) > 0
def test_unique_names(self):
class TestBase(MSection):
name = Quantity(type=str)
with assert_exception(MetainfoError):
class TestSection(TestBase): # pylint: disable=unused-variable
name = Quantity(type=int)
def test_unique_names_extends(self):
class TestBase(MSection):
name = Quantity(type=str)
with assert_exception(MetainfoError):
class TestSection(TestBase): # pylint: disable=unused-variable
m_def = Section(extends_base_section=True)
name = Quantity(type=int)
def test_unique_sub_sections(self):
with assert_exception(MetainfoError):
class TestSection(MSection): # pylint: disable=unused-variable
one = SubSection(sub_section=System)
two = SubSection(sub_section=System)
def test_dimension_exists(self):
with assert_exception(MetainfoError):
class TestSection(MSection): # pylint: disable=unused-variable
test = Quantity(type=str, shape=['does_not_exist'])
def test_dimension_is_int(self):
with assert_exception(MetainfoError):
class TestSection(MSection): # pylint: disable=unused-variable
dim = Quantity(type=str)
test = Quantity(type=str, shape=['dim'])
def test_dimension_is_shapeless(self):
with assert_exception(MetainfoError):
class TestSection(MSection): # pylint: disable=unused-variable
dim = Quantity(type=int, shape=[1])
test = Quantity(type=str, shape=['dim'])
def test_higher_shapes_require_dtype(self):
with assert_exception(MetainfoError):
class TestSection(MSection): # pylint: disable=unused-variable
test = Quantity(type=int, shape=[3, 3])
def test_only_extends_one_base(self):
with assert_exception(MetainfoError):
class TestSection(Run, System): # pylint: disable=unused-variable
m_def = Section(extends_base_section=True)
class TestM1: class TestM1:
""" Test for meta-info instances. """ """ Test for meta-info instances. """
...@@ -331,3 +383,9 @@ class TestM1: ...@@ -331,3 +383,9 @@ class TestM1:
errors = run.m_validate() errors = run.m_validate()
assert len(errors) == 1 assert len(errors) == 1
def test_validate_dimension(self):
system = System()
system.atom_labels = ['H']
system.atom_positions = []
assert len(system.m_validate()) > 0
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment