From b05af557259f46ceb4e4c7c4648af8b4f0d4837d Mon Sep 17 00:00:00 2001
From: Theodore Chang <theodore.chang@physik.hu-berlin.de>
Date: Tue, 15 Oct 2024 19:10:19 +0000
Subject: [PATCH] Use wrapper to ensure safe close

---
 docs/howto/customization/hdf5.md |  5 +++-
 nomad/datamodel/context.py       | 21 +++------------
 nomad/datamodel/hdf5.py          | 45 +++++++++++++++++++++++++-------
 nomad/metainfo/metainfo.py       |  2 +-
 tests/datamodel/test_hdf5.py     |  2 --
 5 files changed, 44 insertions(+), 31 deletions(-)

diff --git a/docs/howto/customization/hdf5.md b/docs/howto/customization/hdf5.md
index 278cdf8660..3b37d32488 100644
--- a/docs/howto/customization/hdf5.md
+++ b/docs/howto/customization/hdf5.md
@@ -120,6 +120,8 @@ class LargeData(ArchiveSection):
 The assigned value will also be written to the archive HDF5 file and serialized as
 `/uploads/test_upload/archive/test_entry#/data/value`.
 
+To read the dataset, use the context manager so that the underlying HDF5 file is closed properly when you are done.
+
 ```python
 archive.data.value = np.ones(3)
 
@@ -127,7 +129,8 @@ serialized = archive.m_to_dict()
 serialized['data']['value']
 # '/uploads/test_upload/archive/test_entry#/data/value'
 deserialized = archive.m_from_dict(serialized, m_context=archive.m_context)
-deserialized.data.value
+with deserialized.data.value as dataset:
+    print(dataset[:])
 # array([1., 1., 1.])
 ```
 
diff --git a/nomad/datamodel/context.py b/nomad/datamodel/context.py
index 7782c52685..2f3ccbd1e8 100644
--- a/nomad/datamodel/context.py
+++ b/nomad/datamodel/context.py
@@ -262,9 +262,6 @@ class Context(MetainfoContext):
         self.archives[url] = archive
         self.urls[archive] = url
 
-    def close(self):
-        pass
-
 
 class ServerContext(Context):
     def __init__(self, upload=None):
@@ -412,22 +409,10 @@ class ServerContext(Context):
 
         return response.json()['data']
 
-    def open_hdf5_file(self, section: MSection):
-        upload_id, entry_id = self._get_ids(section.m_root(), required=True)
-
-        if not upload_id or not entry_id:
-            return None
-
-        return h5py.File(self.upload_files.archive_hdf5_location(entry_id), 'a')
-
-    def close(self):
-        pass
-
-    def __enter__(self):
-        return self
+    def hdf5_path(self, section: MSection):
+        _, entry_id = self._get_ids(section.m_root(), required=True)
 
-    def __exit__(self, type, value, traceback):
-        self.close()
+        return self.upload_files.archive_hdf5_location(entry_id)
 
 
 def _validate_url(url):
diff --git a/nomad/datamodel/hdf5.py b/nomad/datamodel/hdf5.py
index 121de53324..b2aa966d60 100644
--- a/nomad/datamodel/hdf5.py
+++ b/nomad/datamodel/hdf5.py
@@ -15,6 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+from __future__ import annotations
 
 from typing import Any
 import h5py
@@ -22,6 +23,7 @@ import re
 
 import numpy as np
 import pint
+from h5py import File
 
 from nomad.metainfo.data_type import NonPrimitive
 from nomad.utils import get_logger
@@ -44,6 +46,25 @@ def match_hdf5_reference(reference: str):
     return match.groupdict()
 
 
+class HDF5Wrapper:
+    def __init__(self, file: str, path: str):
+        self.file: str = file
+        self.path: str = path
+        self.handler: h5py.File | None = None
+
+    def __enter__(self):
+        self._close()
+        self.handler = h5py.File(self.file, 'a')
+        return self.handler[self.path]
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self._close()
+
+    def _close(self):
+        if self.handler:
+            self.handler.close()
+
+
 class HDF5Reference(NonPrimitive):
     @staticmethod
     def _get_upload_files(archive, path: str):
@@ -120,7 +141,7 @@ class HDF5Dataset(NonPrimitive):
         if not isinstance(value, (str, np.ndarray, h5py.Dataset, pint.Quantity)):
             raise ValueError(f'Invalid HDF5 dataset value: {value}.')
 
-        hdf5_file = section_context.open_hdf5_file(section)
+        hdf5_path: str = section_context.hdf5_path(section)
 
         if isinstance(value, str):
             if not (match := match_hdf5_reference(value)):
@@ -129,7 +150,11 @@ class HDF5Dataset(NonPrimitive):
 
             file, path = match['file_id'], match['path']
 
-            target_dataset = (hdf5_file[file] if file in hdf5_file else hdf5_file)[path]
+            with File(hdf5_path, 'a') as hdf5_file:
+                if file in hdf5_file:
+                    segment = f'{file}/{path}'
+                else:
+                    segment = path
         else:
             if isinstance(value, pint.Quantity):
                 if self._definition.unit is not None:
@@ -137,11 +162,13 @@ class HDF5Dataset(NonPrimitive):
                 else:
                     value = value.magnitude
 
-            target_dataset = hdf5_file.require_dataset(
-                f'{section.m_path()}/{self._definition.name}',
-                shape=getattr(value, 'shape', ()),
-                dtype=getattr(value, 'dtype', None),
-            )
-            target_dataset[...] = value
+            with File(hdf5_path, 'a') as hdf5_file:
+                segment = f'{section.m_path()}/{self._definition.name}'
+                target_dataset = hdf5_file.require_dataset(
+                    segment,
+                    shape=getattr(value, 'shape', ()),
+                    dtype=getattr(value, 'dtype', None),
+                )
+                target_dataset[...] = value
 
-        return target_dataset
+        return HDF5Wrapper(hdf5_path, segment)
diff --git a/nomad/metainfo/metainfo.py b/nomad/metainfo/metainfo.py
index 670ad6a78e..79717aac9b 100644
--- a/nomad/metainfo/metainfo.py
+++ b/nomad/metainfo/metainfo.py
@@ -837,7 +837,7 @@ class Context:
                 return section.section_cls
         return None
 
-    def open_hdf5_file(self, section: MSection):
+    def hdf5_path(self, section: MSection):
         raise NotImplementedError
 
 
diff --git a/tests/datamodel/test_hdf5.py b/tests/datamodel/test_hdf5.py
index 52eb29ceab..bbdb88e143 100644
--- a/tests/datamodel/test_hdf5.py
+++ b/tests/datamodel/test_hdf5.py
@@ -84,5 +84,3 @@ def test_hdf5(test_context, quantity_type, value):
         with h5py.File(test_context.upload_files.raw_file(filename, 'rb')) as f:
             quantity = HDF5Reference.read_dataset(archive, value)
             assert (quantity == f[path][()]).all()
-
-    test_context.close()
-- 
GitLab