Commit c6956c48 authored by Markus Scheidgen

Merge branch 'parallel-archive-client' into 'v1.0.0'

Parallel archive client

See merge request !157
parents a1ffb277 37ac62ea
Pipeline #80900 passed with stages in 27 minutes and 48 seconds
@@ -2,7 +2,6 @@
A simple example that uses the NOMAD client library to access the archive.
'''
from nomad import config
from nomad.client import ArchiveQuery
from nomad.metainfo import units
@@ -22,12 +21,13 @@ query = ArchiveQuery(
'section_system[-2]': '*'
}
},
per_page=10,
parallel=5,
per_page=20,
max=1000)
print(query)
for result in query[0:10]:
for result in query[0:100]:
run = result.section_run[0]
energy = run.section_single_configuration_calculation[0].energy_total
formula = run.section_system[0].chemical_composition_reduced
......
@@ -148,6 +148,9 @@ from bravado import requests_client as bravado_requests_client
import time
from keycloak import KeycloakOpenID
from io import StringIO
import math
from urllib.parse import urlencode
import multiprocessing
from nomad import config
from nomad import metainfo as mi
@@ -157,6 +160,16 @@ from nomad.datamodel import EntryArchive
from nomad import parsing # pylint: disable=unused-import
# This is only necessary to patch it during tests, because the HTTP client interface differs in tests
def get_json(response):
return response.json()
# This is only necessary to patch it during tests, because the HTTP client interface differs in tests
def get_length(response):
return len(response.content)
class QueryError(Exception):
pass
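# A minimal sketch of how QueryError surfaces, assuming an already constructed
# (hypothetical) ArchiveQuery instance `query`: the first access that triggers
# an API call, e.g. len(), raises it for invalid queries (HTTP 400) and other
# non-200 responses.
#
#     try:
#         print(len(query))
#     except QueryError as error:
#         print('invalid query: %s' % error)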
@@ -192,7 +205,7 @@ class ApiStatistics(mi.MSection):
nentries = mi.Quantity(
type=int, default=0,
description='Number queries entries')
description='Number queried entries')
last_response_nentries = mi.Quantity(
type=int, default=0,
@@ -222,6 +235,67 @@ class ApiStatistics(mi.MSection):
return out.getvalue()
class ProcState:
'''
A basic picklable data class that holds the state of one of the parallel running
processes that load archive API data.
'''
def __init__(self, archive_query: 'ArchiveQuery'):
self.url = archive_query.url
self.request: Dict[str, Any] = dict(
query=dict(**archive_query.query),
required=archive_query.required,
raise_errors=archive_query.raise_errors)
self.per_page = archive_query.per_page
self.authentication = archive_query.authentication
self.upload_ids: List[str] = []
self.nentries = 0
self.total = None
self.after = None
self.results = None
self.error: Exception = None
self.data_size = 0
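# A minimal sketch of why ProcState only holds plain data (strings, dicts,
# lists, numbers, None): instances have to survive pickling so that
# multiprocessing can ship them to worker processes and back. Assuming an
# already constructed ArchiveQuery `query`:
#
#     import pickle
#
#     state = ProcState(query)
#     state.upload_ids.append('some_upload_id')  # hypothetical upload id
#     restored = pickle.loads(pickle.dumps(state))
#     assert restored.upload_ids == ['some_upload_id']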
def _run_proc(proc_state: ProcState) -> ProcState:
'''
The main function for a process that retrieves data from the archive API based
on its state. It returns the updated state and is otherwise completely stateless.
'''
try:
url = '%s/%s/%s' % (proc_state.url, 'archive', 'query')
# create the query
proc_state.request['query']['upload_id'] = proc_state.upload_ids
aggregation = proc_state.request.setdefault('aggregation', {'per_page': proc_state.per_page})
if proc_state.after is not None:
aggregation['after'] = proc_state.after
# run the query
response = requests.post(url, headers=proc_state.authentication, json=proc_state.request)
if response.status_code != 200:
if response.status_code == 400:
message = get_json(response).get('message')
errors = get_json(response).get('errors')
if message:
raise QueryError('%s: %s' % (message, errors))
raise QueryError('The query is invalid for unknown reasons (400).')
raise QueryError(
'The query is invalid for unknown reasons (%d).' % response.status_code)
# update the state
proc_state.data_size += get_length(response)
data = get_json(response)
proc_state.results = data.get('results', [])
proc_state.after = data['aggregation'].get('after', None)
proc_state.total = data['aggregation'].get('total', 0)
except Exception as e:
proc_state.error = e
return proc_state
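# A minimal sketch of the cursor-style pagination that _run_proc relies on,
# assuming `state` is a ProcState whose upload_ids were already assigned:
#
#     results = []
#     while True:
#         state = _run_proc(state)
#         if state.error is not None:
#             raise state.error
#         results.extend(state.results)
#         if state.after is None:  # no further pages for these uploads
#             break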
class ArchiveQuery(collections.abc.Sequence):
'''
Objects of this class represent a query on the NOMAD Archive. It is solely configured
@@ -241,34 +315,34 @@ class ArchiveQuery(collections.abc.Sequence):
per_page: Determine how many results are downloaded per page (or scroll window).
Default is 10.
max: Optionally determine the maximum amount of downloaded archives. The iteration
will stop even if more results are available. Default is unlimited.
will stop if max is surpassed even if more results are available. Default is unlimited.
raise_errors: There are situations where archives for certain entries are unavailable.
If set to True, these cases will raise an Exception. Otherwise, the entries
with missing archives are simply skipped (default).
authentication: Optionally provide detailed authentication information. Usually,
providing ``username`` and ``password`` should suffice.
parallel: Number of processes to use to retrieve data in parallel. Only data
from different uploads can be retrieved in parallel. Default is 1. The
argument ``per_page`` then refers to the number of archives retrieved per process and call.
'''
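# A minimal usage sketch with the parallel options, assuming a reachable NOMAD
# installation at the default config.client.url; the search criterion
# 'dft.code_name' is a hypothetical example:
#
#     query = ArchiveQuery(
#         query={'dft.code_name': 'VASP'},
#         required={'section_run': '*'},
#         parallel=5, per_page=20, max=1000)
#     print(query.total)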
def __init__(
self,
query: dict = None, required: dict = None,
url: str = None, username: str = None, password: str = None,
per_page: int = 10, max: int = None,
parallel: int = 1, per_page: int = 10, max: int = None,
raise_errors: bool = False,
authentication: Union[Dict[str, str], KeycloakAuthenticator] = None):
self._after = None
self.page = 1
self.parallel = parallel
self.per_page = per_page
self.max = max
self.query: Dict[str, Any] = {
'query': {},
'raise_errors': raise_errors
}
if query is not None:
self.query['query'].update(query)
self.query = query if query is not None else dict()
self.raise_errors = raise_errors
self.required = required if required is not None else dict(section_run='*')
if required is not None:
self.query['query_schema'] = required
# We try to add all required properties to the query to ensure that only
# results with those properties are returned.
section_run_key = next(key for key in required if key.split('[')[0] == 'section_run')
@@ -285,8 +359,8 @@ class ArchiveQuery(collections.abc.Sequence):
if isinstance(value, dict):
stack.append(value)
quantities.add(key.split('[')[0])
self.query['query'].setdefault('dft.quantities', []).extend(quantities)
self.query['query']['domain'] = 'dft'
self.query.setdefault('dft.quantities', []).extend(quantities)
self.query['domain'] = 'dft'
self.password = password
self.username = username
@@ -294,9 +368,9 @@ class ArchiveQuery(collections.abc.Sequence):
self._authentication = authentication
self._total = -1
self._capped_total = -1
self._results: List[dict] = []
self._statistics = ApiStatistics()
self._proc_states: List[ProcState] = None
@property
def authentication(self):
@@ -320,64 +394,109 @@ class ArchiveQuery(collections.abc.Sequence):
else:
return self._authentication
def _create_initial_proc_state(self):
'''
Performs preliminary queries against the repo API to determine the distribution of
queried entries over uploads and creates the initial state for the processes that
collect data from the archive API in parallel.
'''
uploads: Dict[str, Any] = dict()
nentries = 0
# acquire all uploads and how many entries they contain
query_parameters = self.query
url = '%s/repo/quantity/upload_id?%s' % (self.url, urlencode(query_parameters, doseq=True))
after: str = None
while True:
response = requests.get(
url if after is None else '%s&after=%s' % (url, after),
# TODO size=1000,
headers=self.authentication)
response_data = get_json(response)
after = response_data['quantity']['after']
values = response_data['quantity']['values']
if len(values) == 0:
break
uploads.update(values)
for upload in values.values():
nentries += upload['total']
# distribute uploads to processes
if self.parallel is None:
self.parallel = 1
# TODO This implements a simplified distribution, where an upload is fully
# handled by an individual process. This works because of the high likelihood
# that popular analysis queries (e.g. AFLOW) have results spread over
# many uploads. In other use-cases, e.g. analysing data from an individual user,
# this might not work well, because all entries might be contained in one
# upload. A worked sketch of this distribution follows after this method.
self._proc_states = []
nentries_per_proc = math.ceil(nentries / self.parallel)
proc_state = ProcState(self)
for upload_id, upload_data in uploads.items():
if proc_state.nentries >= nentries_per_proc:
self._proc_states.append(proc_state)
proc_state = ProcState(self)
proc_state.upload_ids.append(upload_id)
proc_state.nentries += upload_data['total']
self._proc_states.append(proc_state)
self._total = nentries
self._statistics.nentries = nentries
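# A worked sketch of the distribution above with parallel=2 and hypothetical
# upload sizes: for uploads {'upload_1': 120, 'upload_2': 30, 'upload_3': 60}
# there are nentries = 210 and nentries_per_proc = math.ceil(210 / 2) = 105.
# upload_1 alone fills the first ProcState (120 >= 105 closes it); upload_2 and
# upload_3 end up in the second ProcState (90 entries). A single very large
# upload can therefore still dominate one process.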
def call_api(self):
'''
Calls the API to retrieve the next set of results. It is called automatically
when not yet downloaded entries are accessed.
'''
url = '%s/%s/%s' % (self.url, 'archive', 'query')
if self._proc_states is None:
self._create_initial_proc_state()
# run the necessary processes
nproc_states = len(self._proc_states)
if nproc_states == 1:
self._proc_states[0] = _run_proc(self._proc_states[0])
elif nproc_states > 1:
with multiprocessing.Pool(nproc_states) as pool:
self._proc_states = pool.map(_run_proc, self._proc_states)
else:
assert False, 'archive query was not stopped before running out of things to query'
aggregation = self.query.setdefault('aggregation', {'per_page': self.per_page})
if self._after is not None:
aggregation['after'] = self._after
# grab the results from the processes
new_states: List[ProcState] = []
self._statistics.last_response_nentries = 0
self._statistics.last_response_data_size = 0
for proc_state in self._proc_states:
if proc_state.error:
raise proc_state.error
response = requests.post(url, headers=self.authentication, json=self.query)
if response.status_code != 200:
if response.status_code == 400:
message = response.json().get('message')
errors = response.json().get('errors')
if message:
raise QueryError('%s: %s' % (message, errors))
self._statistics.last_response_data_size += proc_state.data_size
self._statistics.loaded_data_size += proc_state.data_size
self._statistics.last_response_nentries += len(proc_state.results)
raise QueryError('The query is invalid for unknown reasons.')
self._results.extend([
EntryArchive.m_from_dict(result['archive'])
for result in proc_state.results])
raise response.raise_for_status()
proc_state.results = None
if proc_state.after is not None:
new_states.append(proc_state)
data = response.json
if not isinstance(data, dict):
data = data()
self._proc_states = new_states
self._statistics.loaded_nentries = len(self._results)
self._statistics.napi_calls += 1
aggregation = data['aggregation']
self._after = aggregation.get('after')
self._total = aggregation['total']
if self.max is not None and len(self._results) >= self.max:
# artificially end the query
self._proc_states = []
if self.max is not None:
self._capped_total = min(self.max, self._total)
else:
self._capped_total = self._total
results = data.get('results', [])
for result in results:
archive = EntryArchive.m_from_dict(result['archive'])
self._results.append(archive)
try:
data_size = len(response.content)
self._statistics.last_response_data_size = data_size
self._statistics.loaded_data_size += data_size
self._statistics.nentries = self._total
self._statistics.last_response_nentries = len(results)
self._statistics.loaded_nentries = len(self._results)
self._statistics.napi_calls += 1
except Exception:
# fails in test due to mocked requests library
pass
if self._after is None:
# there are no more search results, we need to avoid further calls
self._capped_total = len(self._results)
if len(self._proc_states) == 0:
self._total = len(self._results)
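# A short sketch of the lazy Sequence behaviour, assuming an ArchiveQuery
# `query`: len(query) triggers call_api() if needed (see __len__ below), and
# slicing, as in the example script above, presumably keeps issuing rounds of
# parallel requests until the requested results (or all matching entries, or
# `max`) have been downloaded.
#
#     for result in query[0:100]:
#         print(result.section_run[0].section_system[0].chemical_composition_reduced)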
def __repr__(self):
@@ -399,10 +518,10 @@ class ArchiveQuery(collections.abc.Sequence):
return self._results[key]
def __len__(self): # pylint: disable=invalid-length-returned
if self._capped_total == -1:
if self._total == -1:
self.call_api()
return self._capped_total
return self._total
@property
def total(self):
......
from typing import List
from typing import List, Tuple
import pytest
from nomad.client import query_archive
from nomad.metainfo import MSection, SubSection
from nomad.datamodel import EntryArchive
from nomad.datamodel import EntryArchive, User
from nomad.datamodel.metainfo.public import section_run
from tests.app.test_app import BlueprintClient
from tests.processing import test_data as test_processing
# TODO with the existing published_wo_user_metadata fixture there is only one entry
@@ -66,3 +67,52 @@ def test_query_authentication(api, published, other_test_user_auth, test_user_au
assert_results(query_archive(authentication=other_test_user_auth), total=1)
assert_results(query_archive(authentication=test_user_auth), total=0)
@pytest.fixture(scope='function')
def many_uploads(non_empty_uploaded: Tuple[str, str], test_user: User, proc_infra):
_, upload_file = non_empty_uploaded
for index in range(0, 4):
upload = test_processing.run_processing(('test_upload_%d' % index, upload_file), test_user)
upload.publish_upload() # pylint: disable=no-member
try:
upload.block_until_complete(interval=.01)
except Exception:
pass
@pytest.fixture(scope='function', autouse=True)
def patch_multiprocessing_and_api(monkeypatch):
class TestPool:
''' A fake multiprocessing pool, because multiprocessing does not work well in pytest. '''
def __init__(self, n):
pass
def map(self, f, args):
return [f(arg) for arg in args]
def __enter__(self, *args, **kwargs):
return self
def __exit__(self, *args, **kwargs):
pass
monkeypatch.setattr('multiprocessing.Pool', TestPool)
monkeypatch.setattr('nomad.client.get_json', lambda response: response.json)
monkeypatch.setattr('nomad.client.get_length', lambda response: int(response.headers['Content-Length']))
def test_parallel_query(api, many_uploads, monkeypatch):
result = query_archive(required=dict(section_run='*'), parallel=2)
assert_results(result, total=4)
assert result._statistics.nentries == 4
assert result._statistics.loaded_nentries == 4
assert result._statistics.last_response_nentries == 4
assert result._statistics.napi_calls == 1
result = query_archive(required=dict(section_run='*'), parallel=2, per_page=1)
assert_results(result, total=4)
assert result._statistics.nentries == 4
assert result._statistics.loaded_nentries == 4
assert result._statistics.last_response_nentries == 2
assert result._statistics.napi_calls == 2
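# The expected statistics follow from the distribution logic: 4 entries spread
# over 4 single-entry uploads with parallel=2 yield two ProcStates with 2
# entries each. With the default per_page everything fits into one round of
# calls (napi_calls == 1); with per_page=1 each process needs two pages, so
# there are two rounds (napi_calls == 2) and the last round returns one entry
# per process (last_response_nentries == 2).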