From be959bde8ddfb2ba5aff709315ae5c0b6d1490ef Mon Sep 17 00:00:00 2001
From: Markus Scheidgen <markus.scheidgen@gmail.com>
Date: Tue, 5 May 2020 17:10:26 +0200
Subject: [PATCH] Fixed bugs in handling indexes in the required part of the
 archive query api. Client lib now adds required quantities to query. #334

---
 examples/client.py    | 11 +++++++----
 nomad/archive.py      | 22 +++++++++++++++++++---
 nomad/client.py       | 41 ++++++++++++++++++++++++++++-------------
 tests/test_archive.py |  3 ++-
 4 files changed, 56 insertions(+), 21 deletions(-)

diff --git a/examples/client.py b/examples/client.py
index 27ba7bf5e2..a9b86e4b31 100644
--- a/examples/client.py
+++ b/examples/client.py
@@ -1,5 +1,5 @@
 from nomad import config
-from nomad.client import query_archive
+from nomad.client import ArchiveQuery
 from nomad.metainfo import units
 
 # this will not be necessary, once this is the official NOMAD version
@@ -16,7 +16,8 @@ query = ArchiveQuery(
         'section_run': {
             'section_single_configuration_calculation[0]': {
                 'energy_total': '*'
-            }
+            },
+            'section_system[0]': '*'
         }
     },
     per_page=10,
@@ -25,5 +26,7 @@ query = ArchiveQuery(
 print(query)
 
 for result in query[0:10]:
-    energy = result.section_run[0].section_single_configuration_calculation[0].energy_total
-    print('Energy %s' % energy.to(units.hartree))
+    run = result.section_run[0]
+    energy = run.section_single_configuration_calculation[0].energy_total
+    formula = run.section_system[0].chemical_composition_reduced
+    print('%s: energy %s' % (formula, energy.to(units.hartree)))
diff --git a/nomad/archive.py b/nomad/archive.py
index 6641eebaec..a9be785ac9 100644
--- a/nomad/archive.py
+++ b/nomad/archive.py
@@ -538,6 +538,13 @@ __query_archive_key_pattern = re.compile(r'^([\s\w\-]+)(\[([-?0-9]*)(:([-?0-9]*)
 
 
 def query_archive(f_or_archive_reader: Union[str, ArchiveReader, BytesIO], query_dict: dict, **kwargs):
+    def _fix_index(index, length):
+        if index is None:
+            return index
+        if index < 0:
+            return max(-(length), index)
+        else:
+            return min(length, index)
 
     def _to_son(data):
         if isinstance(data, (ArchiveList, List)):
@@ -558,10 +565,11 @@ def query_archive(f_or_archive_reader: Union[str, ArchiveReader, BytesIO], query
 
             # process array indices
             match = __query_archive_key_pattern.match(key)
-            index: Tuple[int, int] = None
+            index: Union[Tuple[int, int], int] = None
             if match:
                 key = match.group(1)
 
+                # check if we have indices
                 if match.group(2) is not None:
                     first_index, last_index = None, None
                     group = match.group(3)
@@ -573,7 +581,7 @@ def query_archive(f_or_archive_reader: Union[str, ArchiveReader, BytesIO], query
                         index = (0 if first_index is None else first_index, last_index)
 
                     else:
-                        index = (first_index, first_index + 1)  # one item
+                        index = first_index  # one item
 
                 else:
                     index = None
@@ -599,7 +607,15 @@ def query_archive(f_or_archive_reader: Union[str, ArchiveReader, BytesIO], query
                 if index is None:
                     pass
                 else:
-                    archive_child = archive_child[index[0]: index[1]]
+                    length = len(archive_child)
+                    if isinstance(index, list):
+                        index = (_fix_index(index[0], length), _fix_index(index[1], length))
+                        if index[0] == index[1]:
+                            archive_child = [archive_child[index[0]]]
+                        else:
+                            archive_child = archive_child[index[0]: index[1]]
+                    else:
+                        archive_child = [archive_child[_fix_index(index, length)]]
 
                 if isinstance(archive_child, (ArchiveList, list)):
                     result[key] = [_load_data(val, item) for item in archive_child]
diff --git a/nomad/client.py b/nomad/client.py
index cbeed1f19e..9721c9610a 100644
--- a/nomad/client.py
+++ b/nomad/client.py
@@ -40,23 +40,23 @@ This script should yield a result like this:
 
 .. code::
 
-    Number queries entries: 7667
+    Number queries entries: 7628
     Number of entries loaded in the last api call: 10
-    Bytes loaded in the last api call: 3579
-    Bytes loaded from this query: 3579
+    Bytes loaded in the last api call: 118048
+    Bytes loaded from this query: 118048
     Number of downloaded entries: 10
     Number of made api calls: 1
 
-    Energy -178.6990610734937 hartree
-    Energy -6551.45699684026 hartree
-    Energy -6551.461104765451 hartree
-    Energy -548.9736595672932 hartree
-    Energy -548.9724185656775 hartree
-    Energy -1510.3938165430286 hartree
-    Energy -1510.3937761449583 hartree
-    Energy -11467.827149010665 hartree
-    Energy -16684.667362890417 hartree
-    Energy -1510.3908614326358 hartree
+    Cd2O2: energy -11467.827149010665 hartree
+    Sr2O2: energy -6551.45699684026 hartree
+    Sr2O2: energy -6551.461104765451 hartree
+    Be2O2: energy -178.6990610734937 hartree
+    Ca2O2: energy -1510.3938165430286 hartree
+    Ca2O2: energy -1510.3937761449583 hartree
+    Ba2O2: energy -16684.667362890417 hartree
+    Mg2O2: energy -548.9736595672932 hartree
+    Mg2O2: energy -548.9724185656775 hartree
+    Ca2O2: energy -1510.3908614326358 hartree
 
 Let's discuss the different elements here. First, we have a set of imports. The NOMAD source
 codes comes with various sub-modules. The `client` module contains everything related
@@ -266,6 +266,21 @@ class ArchiveQuery(collections.abc.Sequence):
             self.query['query'].update(query)
         if required is not None:
             self.query['query_schema'] = required
+            # We try to add all required properties to the query to ensure that only
+            # results with those properties are returned.
+            section_run_key = next(key for key in required if key.split('[')[0] == 'section_run')
+            if section_run_key is not None:
+                # add all quantities in required to the query part
+                quantities = set()
+                stack = [required[section_run_key]]
+                while len(stack) > 0:
+                    required_dict = stack.pop()
+                    for key, value in required_dict.items():
+                        if isinstance(value, dict):
+                            stack.append(value)
+                        quantities.add(key.split('[')[0])
+                self.query['query'].setdefault('dft.quantities', []).extend(quantities)
+                self.query['query']['domain'] = 'dft'
 
         self.password = password
         self.username = username
diff --git a/tests/test_archive.py b/tests/test_archive.py
index b9065c9b26..660a1a4308 100644
--- a/tests/test_archive.py
+++ b/tests/test_archive.py
@@ -236,9 +236,10 @@ test_query_example: Dict[Any, Any] = {
     ({'c1': {'s1': {'ss1[:2]': '*'}}}, {'c1': {'s1': {'ss1': test_query_example['c1']['s1']['ss1'][:2]}}}),
     ({'c1': {'s1': {'ss1[0:2]': '*'}}}, {'c1': {'s1': {'ss1': test_query_example['c1']['s1']['ss1'][0:2]}}}),
     ({'c1': {'s1': {'ss1[-2]': '*'}}}, {'c1': {'s1': {'ss1': test_query_example['c1']['s1']['ss1'][-2:-1]}}}),
+    ({'c1': {'s1': {'ss1[-10]': '*'}}}, {'c1': {'s1': {'ss1': test_query_example['c1']['s1']['ss1'][-2:-1]}}}),
     ({'c1': {'s1': {'ss1[:-1]': '*'}}}, {'c1': {'s1': {'ss1': test_query_example['c1']['s1']['ss1'][:-1]}}}),
     ({'c1': {'s1': {'ss1[1:-1]': '*'}}}, {'c1': {'s1': {'ss1': test_query_example['c1']['s1']['ss1'][1:-1]}}}),
-    ({'c2': {'s1': {'ss1[-3:-1]': '*'}}}, {'c2': {'s1': {'ss1': test_query_example['c2']['s1']['ss1'][-3:-1]}}}),
+    ({'c2': {'s1': {'ss1[-3:-1]': '*'}}}, {'c2': {'s1': {'ss1': [test_query_example['c2']['s1']['ss1'][-1]]}}}),
     ({'c1': {'s2[0]': {'p1': '*'}}}, {'c1': {'s2': [{'p1': test_query_example['c1']['s2'][0]['p1']}]}}),
     ({'c1': {'s3': '*'}}, {'c1': {}}),
     ({'c1': {'s1[0]': '*'}}, ArchiveQueryError())
-- 
GitLab