diff --git a/nomad/app/v1/routers/entries.py b/nomad/app/v1/routers/entries.py
index 3e21f5874b25843cbe473acfd8d2dd97cdb43eee..9c2d3000e9a03199a7d45d7d87bdabeb66c246f8 100644
--- a/nomad/app/v1/routers/entries.py
+++ b/nomad/app/v1/routers/entries.py
@@ -15,12 +15,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
+import math
 from datetime import datetime
-import functools
-from itertools import zip_longest
-import multiprocessing
-import operator
 from typing import Optional, Set, Union, Dict, Iterator, Any, List
 
 from fastapi import (
@@ -35,6 +31,7 @@ import json
 import orjson
 from pydantic.main import create_model
 from starlette.responses import Response
+from joblib import Parallel, delayed, parallel_backend
 
 from nomad import files, config, utils, metainfo, processing as proc
 from nomad import datamodel
@@ -695,7 +692,7 @@ def _validate_required(required: ArchiveRequired, user) -> RequiredReader:
             detail=[dict(msg=e.msg, loc=['required'] + e.loc)])
 
 
-def _read_entry_from_archive(entry, uploads, required_reader: RequiredReader):
+def _read_entry_from_archive(entry: dict, uploads, required_reader: RequiredReader):
     entry_id, upload_id = entry['entry_id'], entry['upload_id']
 
     # all other exceptions are handled by the caller `_read_entries_from_archive`
@@ -703,11 +700,8 @@ def _read_entry_from_archive(entry, uploads, required_reader: RequiredReader):
         upload_files = uploads.get_upload_files(upload_id)
 
         with upload_files.read_archive(entry_id, True) as archive:
-            return {
-                'entry_id': entry_id,
-                'upload_id': upload_id,
-                'parser_name': entry['parser_name'],
-                'archive': required_reader.read(archive, entry_id, upload_id)}
+            entry['archive'] = required_reader.read(archive, entry_id, upload_id)
+            return entry
     except ArchiveQueryError as e:
         raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=str(e))
     except KeyError as e:
@@ -716,7 +710,7 @@ def _read_entries_from_archive(entries, required, user):
         return None
 
 
-def _read_entries_from_archive(entries, required, user):
+def _read_entries_from_archive(entries: Union[list, dict], required: ArchiveRequired, user):
     '''
     Takes pickleable arguments so that it can be offloaded to worker processes.
 
@@ -725,10 +719,10 @@ def _read_entries_from_archive(entries, required, user):
     with _Uploads() as uploads:
         required_reader = _validate_required(required, user)
 
-        responses = [_read_entry_from_archive(
-            entry, uploads, required_reader) for entry in entries if entry is not None]
+        if isinstance(entries, dict):
+            return _read_entry_from_archive(entries, uploads, required_reader)
 
-        return list(filter(None, responses))
+        return [_read_entry_from_archive(entry, uploads, required_reader) for entry in entries]
 
 
 def _answer_entries_archive_request(
@@ -749,39 +743,29 @@ def _answer_entries_archive_request(
         required=MetadataRequired(include=['entry_id', 'upload_id', 'parser_name']),
         user_id=user.user_id if user is not None else None)
 
-    # fewer than 8 entries per process is not useful
+    entries: list = [{
+        'entry_id': entry['entry_id'], 'upload_id': entry['upload_id'],
+        'parser_name': entry['parser_name']} for entry in search_response.data]
+
+    # fewer than config.archive.min_entries_per_process entries per process is not useful
     # more than config.max_process_number processes is too much for the server
     number: int = min(
-        len(search_response.data) // config.archive.min_entires_per_process,
+        int(math.ceil(len(entries) / config.archive.min_entries_per_process)),
        config.archive.max_process_number)
 
     if number <= 1:
-        request_data: list = _read_entries_from_archive(search_response.data, required, user)
+        request_data: list = _read_entries_from_archive(entries, required, user)
     else:
-        entries_per_process = len(search_response.data) // number + 1
-
-        # use process pool
-        pool: multiprocessing.pool.Pool = multiprocessing.pool.Pool(processes=number)
-
-        try:
-            responses = pool.map(
-                functools.partial(_read_entries_from_archive, required=required, user=user),
-                zip_longest(*[iter(search_response.data)] * entries_per_process))
-        finally:
-            # gracefully shutdown the pool
-            pool.close()
-            pool.join()
-
-        # collect results from each process
-        # https://stackoverflow.com/a/45323085
-        request_data = functools.reduce(operator.iconcat, responses, [])
+        with parallel_backend('threading', n_jobs=number):
+            request_data = Parallel()(delayed(
+                _read_entries_from_archive)(i, required, user) for i in entries)
 
     return EntriesArchiveResponse(
         owner=search_response.owner,
         query=search_response.query,
         pagination=search_response.pagination,
         required=required,
-        data=request_data)
+        data=list(filter(None, request_data)))
 
 
 _entries_archive_docstring = strip('''
diff --git a/nomad/config.py b/nomad/config.py
index ce2467f363fb045490229d352c0d4b4704bfcb99..0596f3365ba0fc64ebaf78ca110fc7afc977e075 100644
--- a/nomad/config.py
+++ b/nomad/config.py
@@ -423,7 +423,7 @@ archive = NomadConfig(
     block_size=256 * 1024,
     read_buffer_size=256 * 1024,  # GPFS needs at least 256K to achieve decent performance
     max_process_number=20,  # maximum number of processes can be assigned to process archive query
-    min_entires_per_process=20  # minimum number of entries per process
+    min_entries_per_process=20  # minimum number of entries per process
 )
 
 
diff --git a/nomad/metainfo/metainfo.py b/nomad/metainfo/metainfo.py
index f9ebe6327af4f66713f9450709c8c579eac057f9..604a1a552333b21b9712dede6a48b755840161f2 100644
--- a/nomad/metainfo/metainfo.py
+++ b/nomad/metainfo/metainfo.py
@@ -123,10 +123,7 @@ class MEnum(Sequence):
 
         self._list = list(kwargs.values())
         self._values = set(kwargs.values())  # For allowing constant time member check
-        self._map = kwargs
-
-    def __getattr__(self, attr):
-        return self._map[attr]
+        self.__dict__.update(kwargs)
 
     def __getitem__(self, index):
         return self._list[index]
diff --git a/requirements.txt b/requirements.txt
index 114f6732b49f15ed3fdd0abdae892929a30f9b12..f7bf8d9a9d170580e8dc5dcd503aa7593eb40608 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -91,6 +91,7 @@ dockerspawner==12.1.0
 oauthenticator==14.2.0
 validators==0.18.2
 aiofiles==0.8.0
+joblib==1.1.0
 
 # [dev]
 markupsafe==2.0.1
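
For reference, the parallelisation pattern that the entries.py change adopts can be sketched outside of NOMAD as follows. This is an illustrative sketch only, not the project's code: read_one, read_all, MIN_ENTRIES_PER_JOB and MAX_JOBS are made-up stand-ins for _read_entries_from_archive and the config.archive.min_entries_per_process / max_process_number settings. It shows joblib's threading backend fanning entries out over a bounded thread pool and dropping None results afterwards, which is why the multiprocessing pool and the pickling constraint disappear from the router.

# Illustrative sketch only; names below are stand-ins, not NOMAD identifiers.
import math
from typing import Optional

from joblib import Parallel, delayed, parallel_backend

MIN_ENTRIES_PER_JOB = 20  # assumed analogue of config.archive.min_entries_per_process
MAX_JOBS = 20             # assumed analogue of config.archive.max_process_number


def read_one(entry: dict) -> Optional[dict]:
    # placeholder for the per-entry archive read; returns None when reading fails
    entry['archive'] = {'entry_id': entry['entry_id']}
    return entry


def read_all(entries: list) -> list:
    # a handful of entries is not worth parallelising; cap the number of threads
    n_jobs = min(math.ceil(len(entries) / MIN_ENTRIES_PER_JOB), MAX_JOBS)
    if n_jobs <= 1:
        results = [read_one(entry) for entry in entries]
    else:
        # threads suffice because the work is I/O bound, and unlike a process
        # pool they do not require the arguments to be picklable
        with parallel_backend('threading', n_jobs=n_jobs):
            results = Parallel()(delayed(read_one)(entry) for entry in entries)
    # filter out entries whose archive could not be read
    return list(filter(None, results))


if __name__ == '__main__':
    print(len(read_all([{'entry_id': str(i)} for i in range(50)])))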
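The metainfo.py change trades a __getattr__ indirection for plain instance attributes. A minimal stand-alone illustration of the two lookup paths, using toy classes that are not the MEnum implementation:

# Toy classes contrasting the two attribute-lookup strategies.
class WithGetattr:
    def __init__(self, **kwargs):
        self._map = kwargs

    def __getattr__(self, attr):
        # invoked only after the normal attribute lookup fails, adding an
        # extra Python-level call on every member access
        return self._map[attr]


class WithDictUpdate:
    def __init__(self, **kwargs):
        # members become ordinary instance attributes and are resolved by the
        # default attribute lookup, with no __getattr__ fallback involved
        self.__dict__.update(kwargs)


assert WithGetattr(red='red').red == 'red'
assert WithDictUpdate(red='red').red == 'red'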