Commit d7bc23ea authored by Markus Scheidgen's avatar Markus Scheidgen
Browse files

Added dimension checks and constraints.

parent 14ea3639
......@@ -107,7 +107,11 @@ class DataType:
return value
class __Dimension(DataType):
range_re = re.compile(r'(\d)\.\.(\d|\*)')
class _Dimension(DataType):
def set_normalize(self, section, quantity_def: 'Quantity', value):
if isinstance(value, int):
return value
......@@ -115,7 +119,7 @@ class __Dimension(DataType):
if isinstance(value, str):
if value.isidentifier():
return value
if re.match(r'(\d)\.\.(\d|\*)', value):
if re.match(range_re, value):
return value
if isinstance(value, Section):
......@@ -126,8 +130,21 @@ class __Dimension(DataType):
raise TypeError('%s is not a valid dimension' % str(value))
@staticmethod
def check_dimension(section, dimension, length):
    """ Checks whether ``length`` satisfies the given shape ``dimension``.

    Arguments:
        section: The section that holds the value. It is only used to resolve
            dimensions that are given by name.
        dimension: Either an int (exact length), the name of an attribute of
            ``section`` that holds the expected length, or a range string such
            as ``'0..*'`` or ``'1..3'``.
        length: The actual length of the value along this dimension.

    Returns:
        True if the length is allowed by the dimension, False otherwise.
        Dimensions of any other type or form never match.
    """
    if isinstance(dimension, int):
        return dimension == length

    if not isinstance(dimension, str):
        return False

    if dimension.isidentifier():
        # The name refers to an attribute of the section that holds the expected
        # length; the length (not the name itself) has to be compared with it.
        return length == getattr(section, dimension)

    # Same pattern as the class level ``range_re``. It is spelled out here because
    # class attributes are not visible as bare names from within a method body.
    m = re.match(r'(\d)\.\.(\d|\*)', dimension)
    if m is None:
        return False

    start = int(m.group(1))
    end = -1 if m.group(2) == '*' else int(m.group(2))  # -1 marks an open range
    return start <= length and (end == -1 or length <= end)
class _Unit(DataType):
def set_normalize(self, section, quantity_def: 'Quantity', value):
if isinstance(value, str):
value = units.parse_units(value)
......@@ -148,7 +165,7 @@ units = pint.UnitRegistry()
""" The default pint unit registry that should be used to give units to quantity definitions. """
class __Callable(DataType):
class _Callable(DataType):
def serialize(self, section, quantity_def: 'Quantity', value):
raise MetainfoError('Callables cannot be serialized')
......@@ -156,7 +173,7 @@ class __Callable(DataType):
raise MetainfoError('Callables cannot be serialized')
class __QuantityType(DataType):
class _QuantityType(DataType):
""" Data type for defining the type of a metainfo quantity.
A metainfo quantity type can be one of
......@@ -279,7 +296,7 @@ class Reference(DataType):
return MProxy(value)
class __Datetime(DataType):
class _Datetime(DataType):
def __parse(self, datetime_str: str) -> datetime:
try:
......@@ -308,11 +325,11 @@ class __Datetime(DataType):
return self.__parse(value)
Dimension = __Dimension()
Unit = __Unit()
QuantityType = __QuantityType()
Callable = __Callable()
Datetime = __Datetime()
Dimension = _Dimension()
Unit = _Unit()
QuantityType = _QuantityType()
Callable = _Callable()
Datetime = _Datetime()
# Metainfo data storage and reflection interface
......@@ -567,6 +584,7 @@ class MSection(metaclass=MObjectMeta):
m_def.section_cls = cls
# add base sections
extended_base_section = None
if m_def.extends_base_section:
base_sections_count = len(cls.__bases__)
if base_sections_count == 0:
......@@ -578,7 +596,7 @@ class MSection(metaclass=MObjectMeta):
'Section %s extend the base section, but has more than one base section' % m_def)
base_section_cls = cls.__bases__[0]
base_section = getattr(base_section_cls, 'm_def', None)
extended_base_section = base_section = getattr(base_section_cls, 'm_def', None)
if base_section is None:
raise MetainfoError(
'The base section of %s is not a section class.' % m_def)
......@@ -679,8 +697,25 @@ class MSection(metaclass=MObjectMeta):
if prop.description is None:
prop.description = param.description
def __check_np(self, quantity_ref: 'Quantity', value: np.ndarray) -> np.ndarray:
# TODO
# validate
def validate(definition):
errors = definition.m_all_validate()
if len(errors) > 0:
raise MetainfoError(
'%s. The section definition %s violates %d more constraints' %
(str(errors[0]).strip('.'), definition, len(errors) - 1))
if extended_base_section is not None:
validate(extended_base_section)
validate(m_def)
def __check_np(self, quantity_def: 'Quantity', value: np.ndarray) -> np.ndarray:
# TODO this feels expensive, first check, then possible convert very often?
# if quantity_ref.type != value.dtype:
# raise MetainfoError(
# 'Quantity dtype %s and value dtype %s do not match.' %
# (quantity_ref.type, value.dtype))
return value
def __set_normalize(self, quantity_def: 'Quantity', value: Any) -> Any:
......@@ -950,7 +985,7 @@ class MSection(metaclass=MObjectMeta):
'Do not know how to serialize data with type %s for quantity %s' %
(quantity.type, quantity))
value = self.m_data.dct[name]
value = cast(MDataDict, self.m_data).dct[name]
if type(quantity.type) == np.dtype:
serializable_value = value.tolist()
......@@ -1132,8 +1167,30 @@ class MSection(metaclass=MObjectMeta):
return cast(MSectionBound, context)
def __validate_shape(self, quantity_def: 'Quantity', value):
    """ Checks that the shape of ``value`` matches the shape of ``quantity_def``.

    Arguments:
        quantity_def: The quantity definition whose ``shape`` is checked against.
        value: The value to check. Numpy arrays contribute their ``shape``, lists
            count as one dimensional, everything else as dimensionless.

    Returns:
        True if the number of dimensions agrees and every dimension satisfies
        ``_Dimension.check_dimension``, False otherwise.
    """
    if quantity_def == Quantity.default:
        # values of the ``default`` quantity are never shape checked
        return True

    quantity_shape = quantity_def.shape

    # The branches must be mutually exclusive. With two independent ``if``
    # statements the shape taken from a numpy array is overwritten by the
    # trailing ``else`` and arrays are treated as dimensionless.
    if isinstance(value, np.ndarray):
        value_shape = value.shape
    elif isinstance(value, list) and not isinstance(value, Enum):
        value_shape = [len(value)]
    else:
        value_shape = []

    if len(value_shape) != len(quantity_shape):
        return False

    return all(
        _Dimension.check_dimension(self, dimension, length)
        for dimension, length in zip(quantity_shape, value_shape))
def m_validate(self):
""" Evaluates all constraints of this section and returns a list of errors. """
""" Evaluates all constraints and shapes of this section and returns a list of errors. """
errors: List[str] = []
for constraint_name in self.m_def.constraints:
constraint = getattr(self, 'c_%s' % constraint_name, None)
......@@ -1150,6 +1207,12 @@ class MSection(metaclass=MObjectMeta):
error_str = 'Constraint %s violated.' % constraint_name
errors.append(error_str)
for quantity in self.m_def.all_quantities.values():
if self.m_is_set(quantity):
if not self.__validate_shape(quantity, self.m_get(quantity)):
errors.append(
'The shape of quantity %s does not match its value.' % quantity)
return errors
def m_all_validate(self):
......@@ -1400,6 +1463,20 @@ class Quantity(Property):
# object (instance) case
raise NotImplementedError('Deleting quantity values is not supported.')
def c_dimensions(self):
    """ Constraint: named dimensions must refer to shapeless, int typed quantities of the same section. """
    named_dimensions = (
        dimension for dimension in self.shape  # pylint: disable=not-an-iterable
        if isinstance(dimension, str) and dimension.isidentifier())

    for name in named_dimensions:
        # only quantities of the parent section can serve as dimensions
        dim_quantity = self.m_parent.all_quantities.get(name, None)
        assert dim_quantity is not None, 'Dimensions must be quantities of the same section.'

        is_shapeless = len(dim_quantity.shape) == 0
        assert is_shapeless and dim_quantity.type == int, \
            'Dimensions must be shapeless and int typed.'
def c_higher_shapes_require_dtype(self):
    """ Constraint: quantities with more than one dimension must use a numpy dtype as type. """
    if len(self.shape) > 1:
        # isinstance instead of an exact type comparison: recent numpy versions
        # create dtype instances of np.dtype subclasses, which an exact type
        # comparison would wrongly reject.
        assert isinstance(self.type, np.dtype), \
            'Higher dimensional quantities need a dtype and will be treated as numpy arrays.'
class DirectQuantity(Quantity):
""" Used for quantities that would cause indefinite loops due to bootstrapping. """
......@@ -1588,6 +1665,25 @@ class Section(Definition):
self.all_sub_sections.update(**base_section.all_sub_sections)
self.all_sub_sections_by_section.update(**base_section.all_sub_sections_by_section)
def c_unique_names(self):
    """ Constraint: property names must be unique within the section and across all its base sections. """
    # the names of all properties of all base sections are already taken
    names: Set[str] = set()
    for base in self.all_base_sections:
        names.update(base.all_properties.keys())

    # quantities first, then sub sections; every name may only show up once
    for definition in [*self.quantities, *self.sub_sections]:
        name = definition.name
        assert name not in names, 'All names in a section must be unique.'
        names.add(name)
def c_unique_sub_sections(self):
    """ Constraint: a section definition must not be the target of more than one sub-section. """
    seen = set()
    for sub_section_def in self.all_sub_sections.values():
        target = sub_section_def.sub_section
        assert target not in seen, \
            'The same section definition can only be used in one sub-section'
        seen.add(target)
class Package(Definition):
""" Packages organize metainfo definitions alongside Python modules
......@@ -1674,7 +1770,7 @@ Section.sub_sections = SubSection(
Section.base_sections = Quantity(
type=Reference(Section.m_def), shape=['0..*'], default=[], name='base_sections')
Section.extends_base_section = Quantity(type=bool, default=False, name='extends_base_section')
Section.constraints = Quantity(type=str, shape=['0..*'], name='constraints')
Section.constraints = Quantity(type=str, shape=['0..*'], default=[], name='constraints')
Section.event_handlers = Quantity(
type=Callable, shape=['0..*'], name='event_handlers', virtual=True, default=[])
......
......@@ -16,9 +16,11 @@ import pytest
import numpy as np
import pint.quantity
from nomad.metainfo.metainfo import MSection, MCategory, Section, Quantity, Definition, Package, DeriveError, units
from nomad.metainfo.metainfo import MSection, MCategory, Section, Quantity, SubSection, Definition, Package, DeriveError, MetainfoError, units
from nomad.metainfo.example import Run, VaspRun, System, SystemHash, Parsing, m_package as example_package
from tests.utils import assert_exception
def assert_section_def(section_def: Section):
assert isinstance(section_def, Section)
......@@ -165,6 +167,56 @@ class TestM2:
def test_constraints(self):
assert len(Run.m_def.constraints) > 0
def test_unique_names(self):
    """ Expects a MetainfoError when a sub class redefines a quantity name of its base section. """
    class TestBase(MSection):
        name = Quantity(type=str)
    with assert_exception(MetainfoError):
        # same quantity name as in TestBase, defined while the class is created
        class TestSection(TestBase): # pylint: disable=unused-variable
            name = Quantity(type=int)
def test_unique_names_extends(self):
    """ Expects a MetainfoError when a section that extends its base section redefines a quantity name. """
    class TestBase(MSection):
        name = Quantity(type=str)
    with assert_exception(MetainfoError):
        class TestSection(TestBase): # pylint: disable=unused-variable
            # extends_base_section=True marks this class as an extension of TestBase
            m_def = Section(extends_base_section=True)
            name = Quantity(type=int)
def test_unique_sub_sections(self):
    """ Expects a MetainfoError when two sub-sections of one section use the same section definition. """
    with assert_exception(MetainfoError):
        class TestSection(MSection): # pylint: disable=unused-variable
            # both sub-sections point to System
            one = SubSection(sub_section=System)
            two = SubSection(sub_section=System)
def test_dimension_exists(self):
    """ Expects a MetainfoError when a shape names a dimension that is not defined in the section. """
    with assert_exception(MetainfoError):
        class TestSection(MSection): # pylint: disable=unused-variable
            test = Quantity(type=str, shape=['does_not_exist'])
def test_dimension_is_int(self):
    """ Expects a MetainfoError when the quantity used as dimension is not int typed. """
    with assert_exception(MetainfoError):
        class TestSection(MSection): # pylint: disable=unused-variable
            # str typed quantity used as dimension
            dim = Quantity(type=str)
            test = Quantity(type=str, shape=['dim'])
def test_dimension_is_shapeless(self):
    """ Expects a MetainfoError when the quantity used as dimension has a shape itself. """
    with assert_exception(MetainfoError):
        class TestSection(MSection): # pylint: disable=unused-variable
            # int typed, but with a non empty shape
            dim = Quantity(type=int, shape=[1])
            test = Quantity(type=str, shape=['dim'])
def test_higher_shapes_require_dtype(self):
    """ Expects a MetainfoError for a quantity with a two dimensional shape and the plain type int. """
    with assert_exception(MetainfoError):
        class TestSection(MSection): # pylint: disable=unused-variable
            test = Quantity(type=int, shape=[3, 3])
def test_only_extends_one_base(self):
    """ Expects a MetainfoError when a section with extends_base_section=True has two base classes. """
    with assert_exception(MetainfoError):
        class TestSection(Run, System): # pylint: disable=unused-variable
            m_def = Section(extends_base_section=True)
class TestM1:
""" Test for meta-info instances. """
......@@ -331,3 +383,9 @@ class TestM1:
errors = run.m_validate()
assert len(errors) == 1
def test_validate_dimension(self):
    """ Expects m_validate to report at least one error for a System with one atom label and empty atom positions. """
    system = System()
    system.atom_labels = ['H']
    system.atom_positions = []
    assert len(system.m_validate()) > 0
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment