Commit b07dec5a authored by Markus Scheidgen's avatar Markus Scheidgen
Browse files

Metainfo logs warning about wrong np shapes instead of error.

parent 82a8ed02
......@@ -492,9 +492,10 @@ class MResource():
Represents a collection of related metainfo data, i.e. a set of :class:`MSection` instances.
'''
def __init__(self):
def __init__(self, logger=None):
self.__data: Dict['Section', List['MSection']] = dict()
self.contents: List['MSection'] = []
self.logger = logger
def create(self, section_cls: Type[MSectionBound], *args, **kwargs) -> MSectionBound:
'''
......@@ -548,6 +549,10 @@ class MResource():
for section in self.contents
if filter(section)}
def warning(self, *args, **kwargs):
if self.logger is not None:
self.logger.warn(*args, **kwargs)
class MSection(metaclass=MObjectMeta): # TODO find a way to make this a subclass of collections.abs.Mapping
'''Base class for all section instances on all meta-info levels.
......@@ -1075,8 +1080,20 @@ class MSection(metaclass=MObjectMeta): # TODO find a way to make this a subclas
serialize = _primitive_types[quantity_type]
elif type(quantity_type) == np.dtype:
def serialize_dtype(x):
return x.item()
is_scalar = quantity.is_scalar
def serialize_dtype(value):
if isinstance(value, np.ndarray):
if is_scalar:
self.m_warning('numpy quantity has wrong shape', quantity=str(quantity))
return value.tolist()
else:
if not is_scalar:
self.m_warning('numpy quantity has wrong shape', quantity=str(quantity))
return value.item()
serialize = serialize_dtype
......@@ -1105,18 +1122,14 @@ class MSection(metaclass=MObjectMeta): # TODO find a way to make this a subclas
else:
value = quantity.default
if type(quantity_type) == np.dtype and len(quantity.shape) > 0:
serializable_value = value.tolist()
if type(quantity_type) == np.dtype:
return serialize(value)
elif len(quantity.shape) == 0:
return serialize(value)
elif len(quantity.shape) == 1:
return [serialize(i) for i in value]
else:
if len(quantity.shape) == 0:
serializable_value = serialize(value)
elif len(quantity.shape) == 1:
serializable_value = [serialize(i) for i in value]
else:
raise NotImplementedError('Higher shapes (%s) not supported: %s' % (quantity.shape, quantity))
return serializable_value
raise NotImplementedError('Higher shapes (%s) not supported: %s' % (quantity.shape, quantity))
def items() -> Iterable[Tuple[str, Any]]:
# metadata
......@@ -1147,6 +1160,8 @@ class MSection(metaclass=MObjectMeta): # TODO find a way to make this a subclas
yield name, serialize_quantity(quantity, is_set)
except ValueError as e:
import traceback
traceback.print_exc()
raise ValueError('Value error (%s) for %s' % (str(e), quantity))
# sub sections
......@@ -1533,6 +1548,10 @@ class MSection(metaclass=MObjectMeta): # TODO find a way to make this a subclas
return errors, warnings
def m_warning(self, *args, **kwargs):
if self.m_resource is not None:
self.m_resource.warning(*args, **kwargs)
def __repr__(self):
m_section_name = self.m_def.name
# name_quantity_def = self.m_def.all_quantities.get('name', None)
......
......@@ -205,7 +205,9 @@ class Backend(AbstractParserBackend):
warnings and errors.
'''
def __init__(self, metainfo: Union[str, LegacyMetainfoEnvironment], domain: str = None, logger=None):
def __init__(
self, metainfo: Union[str, LegacyMetainfoEnvironment], domain: str = None,
logger=None):
assert metainfo is not None
if logger is None:
......@@ -221,7 +223,7 @@ class Backend(AbstractParserBackend):
self.env: LegacyMetainfoEnvironment = cast(LegacyMetainfoEnvironment, metainfo)
self.__legacy_env = None
self.resource = MResource()
self.resource = MResource(logger=logger)
self.entry_archive = datamodel.EntryArchive()
self.resource.add(self.entry_archive)
......
......@@ -22,6 +22,9 @@ from nomad.metainfo.metainfo import (
MetainfoError, Environment, MResource, Datetime, units, Annotation, SectionAnnotation,
DefinitionAnnotation, Reference, MProxy, derived)
from nomad.metainfo.example import Run, VaspRun, System, SystemHash, Parsing, SCC, m_package as example_package
from nomad import utils
from tests import utils as test_utils
def assert_section_def(section_def: Section):
......@@ -545,6 +548,13 @@ class TestM1:
assert scc.an_int.__class__ == np.int32
assert scc.an_int.item() == 1 # pylint: disable=no-member
def test_np_allow_wrong_shape(self, caplog):
resource = MResource(logger=utils.get_logger(__name__))
scc = resource.create(SCC)
scc.energy_total_0 = np.array([1.0, 1.0, 1.0])
scc.m_to_dict()
test_utils.assert_log(caplog, 'WARN', 'wrong shape')
def test_proxy(self):
class OtherSection(MSection):
name = Quantity(type=str)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment