From b05af557259f46ceb4e4c7c4648af8b4f0d4837d Mon Sep 17 00:00:00 2001
From: Theodore Chang <theodore.chang@physik.hu-berlin.de>
Date: Tue, 15 Oct 2024 19:10:19 +0000
Subject: [PATCH] Use wrapper to ensure safe close

---
 docs/howto/customization/hdf5.md |  5 +++-
 nomad/datamodel/context.py       | 21 +++-------
 nomad/datamodel/hdf5.py          | 45 +++++++++++++++++++++++++-------
 nomad/metainfo/metainfo.py       |  2 +-
 tests/datamodel/test_hdf5.py     |  2 --
 5 files changed, 44 insertions(+), 31 deletions(-)

diff --git a/docs/howto/customization/hdf5.md b/docs/howto/customization/hdf5.md
index 278cdf8660..3b37d32488 100644
--- a/docs/howto/customization/hdf5.md
+++ b/docs/howto/customization/hdf5.md
@@ -120,6 +120,8 @@ class LargeData(ArchiveSection):
 The assigned value will also be written to the archive HDF5 file and serialized as
 `/uploads/test_upload/archive/test_entry#/data/value`.
 
+To read the dataset, use the value as a context manager to ensure the file is closed properly when done.
+
 ```python
 archive.data.value = np.ones(3)
 
@@ -127,7 +129,8 @@ serialized = archive.m_to_dict()
 serialized['data']['value'] # '/uploads/test_upload/archive/test_entry#/data/value'
 
 deserialized = archive.m_from_dict(serialized, m_context=archive.m_context)
-deserialized.data.value
+with deserialized.data.value as dataset:
+    print(dataset[:])
 # array([1., 1., 1.])
 ```
 
diff --git a/nomad/datamodel/context.py b/nomad/datamodel/context.py
index 7782c52685..2f3ccbd1e8 100644
--- a/nomad/datamodel/context.py
+++ b/nomad/datamodel/context.py
@@ -262,9 +262,6 @@ class Context(MetainfoContext):
         self.archives[url] = archive
         self.urls[archive] = url
 
-    def close(self):
-        pass
-
 
 class ServerContext(Context):
     def __init__(self, upload=None):
@@ -412,22 +409,10 @@
 
         return response.json()['data']
 
-    def open_hdf5_file(self, section: MSection):
-        upload_id, entry_id = self._get_ids(section.m_root(), required=True)
-
-        if not upload_id or not entry_id:
-            return None
-
-        return h5py.File(self.upload_files.archive_hdf5_location(entry_id), 'a')
-
-    def close(self):
-        pass
-
-    def __enter__(self):
-        return self
+    def hdf5_path(self, section: MSection):
+        _, entry_id = self._get_ids(section.m_root(), required=True)
 
-    def __exit__(self, type, value, traceback):
-        self.close()
+        return self.upload_files.archive_hdf5_location(entry_id)
 
 
 def _validate_url(url):
diff --git a/nomad/datamodel/hdf5.py b/nomad/datamodel/hdf5.py
index 121de53324..b2aa966d60 100644
--- a/nomad/datamodel/hdf5.py
+++ b/nomad/datamodel/hdf5.py
@@ -15,6 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+from __future__ import annotations
 from typing import Any
 
 import h5py
@@ -22,6 +23,7 @@ import re
 
 import numpy as np
 import pint
+from h5py import File
 
 from nomad.metainfo.data_type import NonPrimitive
 from nomad.utils import get_logger
@@ -44,6 +46,25 @@ def match_hdf5_reference(reference: str):
     return match.groupdict()
 
 
+class HDF5Wrapper:
+    def __init__(self, file: str, path: str):
+        self.file: str = file
+        self.path: str = path
+        self.handler: h5py.File | None = None
+
+    def __enter__(self):
+        self._close()
+        self.handler = h5py.File(self.file, 'a')
+        return self.handler[self.path]
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self._close()
+
+    def _close(self):
+        if self.handler:
+            self.handler.close()
+
+
 class HDF5Reference(NonPrimitive):
     @staticmethod
     def _get_upload_files(archive, path: str):
@@ -120,7 +141,7 @@
         if not isinstance(value, (str, np.ndarray, h5py.Dataset, pint.Quantity)):
             raise ValueError(f'Invalid HDF5 dataset value: {value}.')
 
-        hdf5_file = section_context.open_hdf5_file(section)
+        hdf5_path: str = section_context.hdf5_path(section)
 
         if isinstance(value, str):
             if not (match := match_hdf5_reference(value)):
@@ -129,7 +150,11 @@
 
             file, path = match['file_id'], match['path']
 
-            target_dataset = (hdf5_file[file] if file in hdf5_file else hdf5_file)[path]
+            with File(hdf5_path, 'a') as hdf5_file:
+                if file in hdf5_file:
+                    segment = f'{file}/{path}'
+                else:
+                    segment = path
         else:
             if isinstance(value, pint.Quantity):
                 if self._definition.unit is not None:
@@ -137,11 +162,13 @@
             else:
                 value = value.magnitude
 
-            target_dataset = hdf5_file.require_dataset(
-                f'{section.m_path()}/{self._definition.name}',
-                shape=getattr(value, 'shape', ()),
-                dtype=getattr(value, 'dtype', None),
-            )
-            target_dataset[...] = value
+            with File(hdf5_path, 'a') as hdf5_file:
+                segment = f'{section.m_path()}/{self._definition.name}'
+                target_dataset = hdf5_file.require_dataset(
+                    segment,
+                    shape=getattr(value, 'shape', ()),
+                    dtype=getattr(value, 'dtype', None),
+                )
+                target_dataset[...] = value
 
-        return target_dataset
+        return HDF5Wrapper(hdf5_path, segment)
diff --git a/nomad/metainfo/metainfo.py b/nomad/metainfo/metainfo.py
index 670ad6a78e..79717aac9b 100644
--- a/nomad/metainfo/metainfo.py
+++ b/nomad/metainfo/metainfo.py
@@ -837,7 +837,7 @@ class Context:
             return section.section_cls
 
         return None
 
-    def open_hdf5_file(self, section: MSection):
+    def hdf5_path(self, section: MSection):
         raise NotImplementedError
 
diff --git a/tests/datamodel/test_hdf5.py b/tests/datamodel/test_hdf5.py
index 52eb29ceab..bbdb88e143 100644
--- a/tests/datamodel/test_hdf5.py
+++ b/tests/datamodel/test_hdf5.py
@@ -84,5 +84,3 @@ def test_hdf5(test_context, quantity_type, value):
     with h5py.File(test_context.upload_files.raw_file(filename, 'rb')) as f:
         quantity = HDF5Reference.read_dataset(archive, value)
         assert (quantity == f[path][()]).all()
-
-    test_context.close()
-- 
GitLab
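
Below is a minimal, self-contained sketch of the context-manager pattern this patch adopts, mirroring the `HDF5Wrapper` added in `nomad/datamodel/hdf5.py`. It only assumes `h5py` and `numpy`; the file name `example.h5` and the dataset path `/data/value` are illustrative, not part of NOMAD:

```python
from __future__ import annotations

import h5py
import numpy as np


class HDF5Wrapper:
    """Stores only a file name and a dataset path; the file is opened
    lazily on __enter__ and closed again on __exit__."""

    def __init__(self, file: str, path: str):
        self.file: str = file
        self.path: str = path
        self.handler: h5py.File | None = None

    def __enter__(self):
        self._close()  # guard against a handle left over from a previous use
        self.handler = h5py.File(self.file, 'a')
        return self.handler[self.path]

    def __exit__(self, exc_type, exc_value, traceback):
        self._close()

    def _close(self):
        if self.handler:  # a closed h5py.File is falsy, so this is safe to repeat
            self.handler.close()


# Write a dataset, then read it back through the wrapper.
with h5py.File('example.h5', 'a') as f:
    f.require_dataset('/data/value', shape=(3,), dtype=float)[...] = np.ones(3)

with HDF5Wrapper('example.h5', '/data/value') as dataset:
    print(dataset[:])  # [1. 1. 1.]
# no handle outlives the with block, so other processes can open the file
```

Returning this wrapper instead of a live `h5py.Dataset` means no open handle escapes the normalization step, which is presumably what allows the `close()` bookkeeping on the contexts to be dropped.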