diff --git a/dependencies/parsers/workflow b/dependencies/parsers/workflow index 73f6b752b1b56b22d3b54e9d236d7b8d174cb424..e31076ec0fde4a361c36bd30790a33527c95b116 160000 --- a/dependencies/parsers/workflow +++ b/dependencies/parsers/workflow @@ -1 +1 @@ -Subproject commit 73f6b752b1b56b22d3b54e9d236d7b8d174cb424 +Subproject commit e31076ec0fde4a361c36bd30790a33527c95b116 diff --git a/gui/src/components/archive/ArchiveBrowser.js b/gui/src/components/archive/ArchiveBrowser.js index 0c3279eb6b602c1f8f40e7d333e6f4182c63d296..29fb507ba91256f957adf088dd2e5fd0be606923 100644 --- a/gui/src/components/archive/ArchiveBrowser.js +++ b/gui/src/components/archive/ArchiveBrowser.js @@ -65,6 +65,7 @@ import ReloadIcon from '@material-ui/icons/Replay' import UploadIcon from '@material-ui/icons/CloudUpload' import { apiBase } from '../../config' import { Alert } from '@material-ui/lab' +import { complex, format } from 'mathjs' export const configState = atom({ key: 'config', @@ -619,6 +620,12 @@ class AttributeAdaptor extends ArchiveAdaptor { } } +const convertComplexArray = (real, imag) => { + return Array.isArray(real) + ? real.map((r, i) => convertComplexArray(r, imag[i])) + : format(complex(real, imag), {notation: 'auto', precision: 4, lowerExp: -999, upperExp: 999}) +} + function QuantityItemPreview({value, def}) { const units = useUnits() if (isReference(def)) { @@ -635,7 +642,7 @@ function QuantityItemPreview({value, def}) { const dimensions = [] let typeLabel = 'unknown' try { - let current = value + let current = value.re || value.im || value for (let i = 0; i < def.shape.length; i++) { dimensions.push(current.length) current = current[0] @@ -664,7 +671,15 @@ function QuantityItemPreview({value, def}) { </Typography> </Box> } else { - let finalValue = (def.type.type_data === 'nomad.metainfo.metainfo._Datetime' ? formatTimestamp(value) : value) + let finalValue + if (def.type.type_data === 'nomad.metainfo.metainfo._Datetime') { + finalValue = formatTimestamp(value) + } else if (def.type.type_data.startsWith?.('complex')) { + finalValue = convertComplexArray(value.re, value.im) + } else { + finalValue = value + } + let finalUnit if (def.unit) { const a = new Q(finalValue, def.unit).toSystem(units) @@ -687,7 +702,14 @@ const QuantityValue = React.memo(function QuantityValue({value, def, ...more}) { const units = useUnits() const getRenderValue = useCallback(value => { - let finalValue = (def.type.type_data === 'nomad.metainfo.metainfo._Datetime' ? formatTimestamp(value) : value) + let finalValue + if (def.type.type_data === 'nomad.metainfo.metainfo._Datetime') { + finalValue = formatTimestamp(value) + } else if (def.type.type_data.startsWith?.('complex')) { + finalValue = convertComplexArray(value.re, value.im) + } else { + finalValue = value + } let finalUnit if (def.unit) { const systemUnitQ = new Q(finalValue, def.unit).toSystem(units) @@ -727,7 +749,15 @@ const QuantityValue = React.memo(function QuantityValue({value, def, ...more}) { } else if (def.m_annotations?.eln?.[0]?.component === 'RichTextEditQuantity') { return <div dangerouslySetInnerHTML={{__html: value}}/> } else { - if (Array.isArray(value)) { + if (def.type.type_data.startsWith?.('complex')) { + value = convertComplexArray(value.re, value.im) + + return Array.isArray(value) + ? <ul style={{margin: 0}}> + {value.map((value, index) => <li key={index}><Typography>{value}</Typography></li>)} + </ul> + : <Typography>{value}</Typography> + } else if (Array.isArray(value)) { return <ul style={{margin: 0}}> {value.map((value, index) => { const [finalValue] = getRenderValue(value) diff --git a/gui/src/components/archive/visualizations.js b/gui/src/components/archive/visualizations.js index d298d6aa7f245b0e5176270e8003faca87b9929f..cc6fccf82462465daaaa80546c6bb908ce3c2881 100644 --- a/gui/src/components/archive/visualizations.js +++ b/gui/src/components/archive/visualizations.js @@ -114,7 +114,7 @@ export function Matrix({values, shape, invert, type}) { values = values[pages[ii++]] } - const columnWidth = 92 + const columnWidth = useRef(92) const rowHeight = 24 const rowCount = invert ? values.length : shape.length > 1 ? values[0].length : 1 const columnCount = invert ? shape.length > 1 ? values[0].length : 1 : values.length @@ -124,9 +124,9 @@ export function Matrix({values, shape, invert, type}) { if (type === 'str') { matrixRef.current.style.width = '100%' } else { - matrixRef.current.style.width = Math.min( - rootRef.current.clientWidth - 4, columnCount * columnWidth) + 'px' + matrixRef.current.style.width = `${rootRef.current.clientWidth - 4}px` } + columnWidth.current = Math.max(92, (rootRef.current.clientWidth - 4) / columnCount) }) let value = shape.length > 1 ? ({rowIndex, columnIndex}) => values[columnIndex][rowIndex] : ({columnIndex}) => values[columnIndex] @@ -142,7 +142,7 @@ export function Matrix({values, shape, invert, type}) { {({width}) => ( <Grid columnCount={columnCount} - columnWidth={columnWidth} + columnWidth={columnWidth.current} height={height} rowCount={rowCount} rowHeight={rowHeight} diff --git a/gui/src/metainfo.json b/gui/src/metainfo.json index 60cba2259e8a2446d9435f677154391a86138c89..9965ab43372e79ab233c972e593518ccbde14ad2 100644 --- a/gui/src/metainfo.json +++ b/gui/src/metainfo.json @@ -339511,8 +339511,8 @@ "name": "x_mp_uncorrected_energy_per_atom", "description": "", "type": { - "type_kind": "numpy", - "type_data": "float64" + "type_kind": "custom", + "type_data": "nomad.metainfo.metainfo._Datetime" }, "shape": [] } diff --git a/nomad/metainfo/metainfo.py b/nomad/metainfo/metainfo.py index c2876933a8bdcc9b4a7ec109128bea0c26f3e05f..648edc95912b1d8e471dff2dfda4b6028b526b4b 100644 --- a/nomad/metainfo/metainfo.py +++ b/nomad/metainfo/metainfo.py @@ -38,8 +38,8 @@ from nomad.config import process from nomad.metainfo.util import ( Annotation, DefinitionAnnotation, MEnum, MQuantity, MRegEx, MSubSectionList, MTypes, ReferenceURL, SectionAnnotation, _delta_symbols, check_dimensionality, check_unit, convert_to, default_hash, dict_to_named_list, - normalize_datetime, resolve_variadic_name, retrieve_attribute, split_python_definition, to_dict, to_numpy, - to_section_def, validate_shape, validate_url) + normalize_complex, normalize_datetime, resolve_variadic_name, retrieve_attribute, serialize_complex, + split_python_definition, to_dict, to_numpy, to_section_def, validate_shape, validate_url) from nomad.units import ureg as units # todo: remove magic comment after upgrading pylint @@ -373,9 +373,10 @@ class _QuantityType(DataType): if isinstance(value, MEnum): return value + # we normalise all np.dtype to basic np.number types if isinstance(value, np.dtype): value = value.type - # we normalise all np.dtype to basic np.number types + if value in MTypes.numpy: return value @@ -1284,9 +1285,6 @@ class MSection(metaclass=MObjectMeta): # TODO find a way to make this a subclas return value - if isinstance(target_type, DataType): - return target_type.set_normalize(self, None, value) # type: ignore - if isinstance(target_type, MEnum): if value not in cast(MEnum, target_type).get_all_values(): raise TypeError(f'The value {value} is not an enum value for {quantity_def}.') @@ -1304,6 +1302,9 @@ class MSection(metaclass=MObjectMeta): # TODO find a way to make this a subclas if target_type == int and type(value) == np.float_: return int(value) + if target_type in MTypes.complex: + return normalize_complex(value, target_type, quantity_def.unit) + if type(value) != target_type: if target_type in MTypes.primitive: try: @@ -1351,8 +1352,11 @@ class MSection(metaclass=MObjectMeta): # TODO find a way to make this a subclas raise TypeError( f'The shape of {quantity_def} requires an iterable value, but {value} is not iterable.') - value = [v for v in list( - self.__set_normalize(quantity_def, item) for item in value) if v != _unset_value] + if quantity_def.type == complex: + value = normalize_complex(value, complex, quantity_def.unit) + else: + value = [v for v in list( + self.__set_normalize(quantity_def, item) for item in value) if v != _unset_value] else: raise MetainfoError( @@ -1395,7 +1399,6 @@ class MSection(metaclass=MObjectMeta): # TODO find a way to make this a subclas if not validate_shape(self, quantity_def, m_quantity.value): raise MetainfoError(f"The shape of {m_quantity} does not match {quantity_def.shape}") - # todo validate values if quantity_def.unit is None: # no prescribed unit, need to check dimensionality, no need to convert check_dimensionality(quantity_def, m_quantity.unit) @@ -1422,8 +1425,11 @@ class MSection(metaclass=MObjectMeta): # TODO find a way to make this a subclas f'The shape of {quantity_def} requires an iterable value, ' f'but {m_quantity.value} is not iterable.') - m_quantity.value = [v for v in list( - self.__set_normalize(quantity_def, item) for item in m_quantity.value) if v != _unset_value] + if quantity_def.type == complex: + m_quantity.value = normalize_complex(m_quantity.value, complex, quantity_def.unit) + else: + m_quantity.value = [v for v in list( + self.__set_normalize(quantity_def, item) for item in m_quantity.value) if v != _unset_value] else: raise MetainfoError( @@ -1610,7 +1616,10 @@ class MSection(metaclass=MObjectMeta): # TODO find a way to make this a subclas if type(attr_value) == str or not isinstance(attr_value, IterableABC): raise TypeError(f'The shape requires an iterable value, but {attr_value} is not.') - attr_value = list(self.__set_normalize(tgt_attr, item) for item in attr_value) + if tgt_attr.type == complex: + attr_value = normalize_complex(attr_value, complex, None) + else: + attr_value = list(self.__set_normalize(tgt_attr, item) for item in attr_value) else: raise MetainfoError(f'Only numpy arrays can be used for higher dimensional quantities: {tgt_attr}.') @@ -1871,6 +1880,10 @@ class MSection(metaclass=MObjectMeta): # TODO find a way to make this a subclas serialize = serialize_data_type + elif quantity_type in MTypes.complex: + + serialize = serialize_complex + elif quantity_type in MTypes.primitive: serialize = MTypes.primitive[quantity_type] @@ -1942,7 +1955,7 @@ class MSection(metaclass=MObjectMeta): # TODO find a way to make this a subclas serialize = serialize_and_transform # serialization starts here - if quantity_type in MTypes.numpy: + if quantity_type in MTypes.numpy or quantity_type in MTypes.complex: return serialize(target_value) if len(quantity.shape) == 0: @@ -1960,6 +1973,9 @@ class MSection(metaclass=MObjectMeta): # TODO find a way to make this a subclas if isinstance(attribute.type, DataType): return attribute.type.serialize(self, None, value) + if attribute.type in MTypes.complex: + return serialize_complex(value) + if attribute.type in MTypes.primitive: if len(attribute.shape) == 0: return MTypes.primitive[attribute.type](value) # type: ignore @@ -2103,6 +2119,9 @@ class MSection(metaclass=MObjectMeta): # TODO find a way to make this a subclas def __deserialize(section: MSection, quantity_def: Quantity, quantity_value: Any): tgt_type = quantity_def.type + if tgt_type in MTypes.complex: + return normalize_complex(quantity_value, tgt_type, quantity_def.unit) + if tgt_type in MTypes.numpy: if not isinstance(quantity_value, list): return tgt_type(quantity_value) @@ -3353,7 +3372,7 @@ class PrimitiveQuantity(Quantity): if hasattr(value, 'tolist'): value = value.tolist() else: - raise TypeError(f'The value {value} for quantity {self} has not shape {self.shape}') + raise TypeError(f'The value {value} for quantity {self} has no shape {self.shape}') if any(v is not None and type(v) != self._type for v in value): raise TypeError( diff --git a/nomad/metainfo/util.py b/nomad/metainfo/util.py index fd005e4e4529d7ef147665a6b440ea0a827a1ba3..c51a141395c62134a726d5a9f519b16269359d57 100644 --- a/nomad/metainfo/util.py +++ b/nomad/metainfo/util.py @@ -20,9 +20,10 @@ import email.utils import hashlib import re from dataclasses import dataclass -from datetime import datetime, date +from datetime import date, datetime from difflib import SequenceMatcher -from typing import Sequence, Dict, Any, Optional, Union, Tuple +from functools import reduce +from typing import Any, Dict, Optional, Sequence, Tuple, Union from urllib.parse import SplitResult, urlsplit, urlunsplit import aniso8601 @@ -37,7 +38,7 @@ __hash_method = 'sha1' # choose from hashlib.algorithms_guaranteed _delta_symbols = {'delta_', 'Δ'} -@dataclass +@dataclass(frozen=True) class MRegEx: # matches the range of indices, e.g., 1..3, 0..* index_range = re.compile(r'(\d)\.\.(\d|\*)') @@ -55,15 +56,121 @@ class MRegEx: r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})' r'(?::\d+)?' r'(?:/?|[/?]\S+)$', re.IGNORECASE) + complex_str = re.compile( + r'^(?=[iIjJ.\d+-])([+-]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?(?![iIjJ.\d]))?' + r'([+-]?(?:(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?)?[iIjJ])?$') -@dataclass +def normalize_complex(value, complex_type, to_unit: Union[str, ureg.Unit, None]): + ''' + Try to convert a given value to a complex number. + ''' + + def __check_precision(_type): + if isinstance(_type, type(None)): + return + + precision_error = ValueError( + f'Cannot type {_type.__name__} to complex number of type {complex_type.__name__} ' + f'due to possibility of loss of precision.') + + if complex_type in (np.complex128, complex): # 64-bit complex + if _type in (np.int64, np.uint64, np.float128, np.complex256): + raise precision_error + elif complex_type == np.complex64: # 32-bit complex + if _type in ( + int, float, + np.int32, np.int64, np.uint32, np.uint64, + np.float64, np.float128, np.complex128, np.complex256): + raise precision_error + + if isinstance(value, pint.Quantity): + scaled: np.ndarray = value.to(to_unit).magnitude if to_unit else value.magnitude + return normalize_complex(scaled, complex_type, None) + + # a list of complex numbers represented by int, float or str + if isinstance(value, list): + normalized = [normalize_complex(v, complex_type, to_unit) for v in value] + return normalized if complex_type == complex else np.array(normalized, dtype=complex_type) + + # complex or real part only + if type(value) in MTypes.num: + __check_precision(type(value)) + return complex_type(value) + + # np array + if isinstance(value, np.ndarray): + __check_precision(value.dtype.type) + return value.astype(complex_type) + + # dict representation of complex number + if isinstance(value, dict): + real = value.get('re') + imag = value.get('im') + assert real is not None or imag is not None, 'Cannot convert an empty dict to complex number.' + + def __combine(_real, _imag): + _real_list: bool = isinstance(_real, list) + _imag_list: bool = isinstance(_imag, list) + if _real_list or _real_list: + if _real is None: + return [__combine(None, i) for i in _imag] + if _imag is None: + return [__combine(r, None) for r in _real] + # leverage short-circuit evaluation, do not change order + if _real_list and _imag_list and len(_real) == len(_imag): + return [__combine(r, i) for r, i in zip(_real, _imag)] + + raise ValueError('Cannot combine real and imaginary parts of complex numbers.') + + __check_precision(type(_real)) + __check_precision(type(_imag)) + if _real is None: + return complex_type(_imag) * 1j + if _imag is None: + return complex_type(_real) + return complex_type(_real) + complex_type(_imag) * 1j + + combined = __combine(real, imag) + return combined if complex_type == complex else np.array(combined, dtype=complex_type) + + # a string, '1+2j' + # one of 'i', 'I', 'j', 'J' can be used to represent the imaginary unit + if isinstance(value, str): + match = MRegEx.complex_str.match(value) + if match is not None: + return complex_type(reduce(lambda a, b: a.replace(b, 'j'), 'iIJ', value)) + + raise ValueError(f'Cannot convert {value} to complex number.') + + +def serialize_complex(value): + ''' + Convert complex number to string. + ''' + # scalar + if type(value) in MTypes.complex: + return {'re': value.real, 'im': value.imag} + + # 1D + if isinstance(value, (list, tuple)): + return {'re': [v.real for v in value], 'im': [v.imag for v in value]} + + # ND + if isinstance(value, np.ndarray): + return {'re': value.real.tolist(), 'im': value.imag.tolist()} + + raise ValueError(f'Cannot serialize {value}.') + + +@dataclass(frozen=True) class MTypes: # todo: account for bytes which cannot be naturally serialized to JSON primitive = { str: lambda v: None if v is None else str(v), int: lambda v: None if v is None else int(v), float: lambda v: None if v is None else float(v), + complex: lambda v: None if v is None else complex(v), bool: lambda v: None if v is None else bool(v), np.bool_: lambda v: None if v is None else bool(v)} @@ -72,11 +179,14 @@ class MTypes: int_numpy = {np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64} int_python = {int} int = int_python | int_numpy - float_numpy = {np.float16, np.float32, np.float64} + float_numpy = {np.float16, np.float32, np.float64, np.float128} + complex_numpy = {np.complex64, np.complex128, np.complex256} float_python = {float} + complex_python = {complex} float = float_python | float_numpy - num_numpy = int_numpy | float_numpy - num_python = int_python | float_python + complex = complex_python | complex_numpy + num_numpy = int_numpy | float_numpy | complex_numpy + num_python = int_python | float_python | complex_python num = num_python | num_numpy str_numpy = {np.str_} bool_numpy = {np.bool_} @@ -611,10 +721,13 @@ def to_numpy(np_type, shape: list, unit: Optional[pint.Unit], definition, value: raise AttributeError( f'Could not convert value {value} of type pandas.Dataframe to a numpy array') + if np_type in MTypes.complex: + value = normalize_complex(value, np_type, unit) + if type(value) != np.ndarray: if len(shape) > 0: try: - value = np.asarray(value) + value = np.asarray(value, dtype=np_type) except TypeError: raise TypeError(f'Could not convert value {value} of {definition} to a numpy array') elif type(value) != np_type: @@ -622,6 +735,11 @@ def to_numpy(np_type, shape: list, unit: Optional[pint.Unit], definition, value: value = np_type(value) except TypeError: raise TypeError(f'Could not convert value {value} of {definition} to a numpy scalar') + elif value.dtype != np_type and np_type in MTypes.complex: + try: + value = value.astype(np_type) + except TypeError: + raise TypeError(f'Could not convert value {value} of {definition} to a numpy array') return value @@ -721,6 +839,18 @@ def __parse_datetime(datetime_str: str) -> datetime: except ValueError: pass + if 'GMT' in datetime_str: + dt_copy = datetime_str + dt_split = dt_copy.split('GMT') + tzinfo = dt_split[1].strip() + if len(tzinfo) == 2: + tzinfo = f'{tzinfo[0]}{tzinfo[1]:0>2}00' + dt_copy = f'{dt_split[0]}GMT{tzinfo}' + try: + return datetime.strptime(dt_copy, '%Y%m%d_%H:%M:%S_%Z%z') + except ValueError: + pass + try: return datetime.fromisoformat(datetime_str) except ValueError: diff --git a/tests/metainfo/test_metainfo.py b/tests/metainfo/test_metainfo.py index 45835a2a1f90301115ebda34f28468672bca6be2..47b5899f762cefb2c32f6d16d6b62ed8c6663df5 100644 --- a/tests/metainfo/test_metainfo.py +++ b/tests/metainfo/test_metainfo.py @@ -210,9 +210,7 @@ class TestM2: @pytest.mark.parametrize('dtype', [ pytest.param(np.longlong), - pytest.param(np.ulonglong), - pytest.param(np.float128), - pytest.param(np.complex128), + pytest.param(np.ulonglong) ]) def test_unsupported_type(self, dtype): with pytest.raises(MetainfoError): diff --git a/tests/metainfo/test_quantities.py b/tests/metainfo/test_quantities.py index 71141d055c74468d8946cab8e3269f35e3b9cadd..dc19959237b505b626c5469bbe6fd1fbb2f87173 100644 --- a/tests/metainfo/test_quantities.py +++ b/tests/metainfo/test_quantities.py @@ -15,30 +15,24 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import datetime +import json +import numpy as np +import pint import pytest -import json -import datetime import pytz + from nomad.metainfo.metainfo import ( - MSection, - Quantity, - Unit, - units, - JSON, - Dimension, - Datetime, - Capitalized, - Bytes, - URL, - MTypes -) + Bytes, Capitalized, Datetime, Dimension, JSON, MSection, MTypes, Quantity, URL, Unit, units) @pytest.mark.parametrize('def_type, value', [ pytest.param(str, 'hello', id='str'), pytest.param(int, 23, id='int'), pytest.param(float, 3.14e23, id='float'), + pytest.param(complex, 3.14e23 - 2j, id='complex'), + pytest.param(np.complex128, 3.14e23 - 2j, id='np.complex128'), pytest.param(bool, True, id='bool'), pytest.param(JSON, dict(key='value'), id='JSON'), pytest.param(Unit, units.parse_units('m*m/s'), id='Unit'), @@ -120,6 +114,7 @@ def test_normalization_number(def_type, unit, shape, input, output, valid): that contains both the magnitude and the unit. This way the unit information is not lost when using these values in e.g. assignments between two fields. ''' + def define(): class TestSection(MSection): @@ -135,3 +130,104 @@ def test_normalization_number(def_type, unit, shape, input, output, valid): define() else: define() + + +@pytest.mark.parametrize('unit', [ + pytest.param('m', id='has-unit'), + pytest.param(None, id='no-unit'), +]) +@pytest.mark.parametrize('quantity_type,value,shape', [ + pytest.param(complex, 1j, None, id='complex-scalar'), + pytest.param(complex, 1, None, id='complex-from-int'), + pytest.param(complex, 1.21, None, id='complex-from-float'), + pytest.param(complex, pint.Quantity(1.242, 'mm'), None, id='complex-from-pint'), + pytest.param(complex, '1j', None, id='complex-scalar-str'), + pytest.param(np.complex128, 1j, None, id='np.complex128-scalar'), + pytest.param(np.complex128, '1j', None, id='np.complex128-scalar-str'), + pytest.param(complex, [1j, 2j], ['*'], id='complex-list'), + pytest.param(np.complex128, [1j, 2j], ['*'], id='np.complex128-vector'), + pytest.param(np.complex128, np.array([1j, 2j]), ['*'], id='np.complex128-nparray'), + pytest.param(np.complex128, ['1j', '2j'], ['*'], id='np.complex128-vector-str'), +]) +def test_complex_number(unit, quantity_type, value, shape): + class TestSection(MSection): + quantity = Quantity(type=quantity_type, unit=unit, shape=shape) + + def assert_complex_equal(): + result = section.quantity.m if unit else section.quantity + if isinstance(value, (list, np.ndarray)): + for a, b in zip(result, value): + assert a == quantity_type(b) + elif not isinstance(value, pint.Quantity): + assert result == quantity_type(value) + elif unit: + assert result == quantity_type(value.to(unit).magnitude) + else: + assert result == quantity_type(value.magnitude) + + def roster(_value): + if isinstance(value, str): + for i in 'iIjJ': + yield _value.replace('j', i) + if isinstance(_value, list) and isinstance(_value[0], str): + for i in 'iIjJ': + yield [_v.replace('j', i) for _v in _value] + yield _value + + for v in roster(value): + section = TestSection() + section.quantity = v + assert_complex_equal() + + section = TestSection.m_from_dict(section.m_to_dict()) + assert_complex_equal() + + +@pytest.mark.parametrize('unit', [ + pytest.param('m', id='has-unit'), + pytest.param(None, id='no-unit'), +]) +@pytest.mark.parametrize('quantity_type,value,result,shape', [ + pytest.param(complex, {'im': -1}, -1j, None, id='complex-scalar'), + pytest.param(complex, {'re': 1.25}, complex(1.25), None, id='complex-from-float'), + pytest.param(np.complex128, {'im': 1}, np.complex128(1j), None, id='np.complex128-scalar-im'), + pytest.param(np.complex128, {'re': 1.25}, np.complex128(1.25), None, id='np.complex128-scalar-re'), + pytest.param(complex, {'re': 1.25, 'im': -2312}, 1.25 - 2312j, None, id='complex-full'), + pytest.param( + np.complex128, {'re': [[1, 2, 3], [4, 5, 6]], 'im': [[1, 2, 3], [4, 5, 6]]}, # no shape checking anyway + np.array([[1, 2, 3], [4, 5, 6]]) + 1j * np.array([[1, 2, 3], [4, 5, 6]]), ['*'], id='complex-full'), +]) +def test_complex_number_dict(unit, quantity_type, value, result, shape): + class TestSection(MSection): + quantity = Quantity(type=quantity_type, unit=unit, shape=shape) + + def assert_complex_equal(): + quantity = section.quantity.m if unit else section.quantity + if isinstance(result, np.ndarray): + assert np.all(quantity == result) + else: + assert quantity == result + + section = TestSection() + section.quantity = value + assert_complex_equal() + + section = TestSection.m_from_dict(section.m_to_dict()) + assert_complex_equal() + + +@pytest.mark.parametrize('quantity_type,value', [ + pytest.param(np.complex128, np.int64(1), id='downcast-from-int-128'), + pytest.param(np.complex128, np.complex256(1), id='downcast-from-float-128'), + pytest.param(np.complex64, {'re': 1}, id='downcast-from-int-64'), + pytest.param(np.complex64, {'re': 1.25}, id='downcast-from-float-64'), + pytest.param(np.complex128, {'re': [1.25, 1], 'im': 1}, id='mismatch-shape'), + pytest.param(np.complex128, {}, id='empty-dict'), +]) +def test_complex_number_exception(quantity_type, value): + class TestSection(MSection): + quantity = Quantity(type=quantity_type) + + section = TestSection() + with pytest.raises((ValueError, AssertionError)): + section.quantity = value