Commit 93b71d1a authored by tlc@void

Use joblib

parent 72a377e0
Pipeline #134853 passed with stages in 35 minutes and 4 seconds
@@ -15,12 +15,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
 import math
 from datetime import datetime
-import functools
-from itertools import zip_longest
-import multiprocessing
-import operator
 from typing import Optional, Set, Union, Dict, Iterator, Any, List
 from fastapi import (
@@ -35,6 +31,7 @@ import json
 import orjson
 from pydantic.main import create_model
 from starlette.responses import Response
+from joblib import Parallel, delayed, parallel_backend

 from nomad import files, config, utils, metainfo, processing as proc
 from nomad import datamodel
@@ -695,7 +692,7 @@ def _validate_required(required: ArchiveRequired, user) -> RequiredReader:
             detail=[dict(msg=e.msg, loc=['required'] + e.loc)])


-def _read_entry_from_archive(entry, uploads, required_reader: RequiredReader):
+def _read_entry_from_archive(entry: dict, uploads, required_reader: RequiredReader):
     entry_id, upload_id = entry['entry_id'], entry['upload_id']
     # all other exceptions are handled by the caller `_read_entries_from_archive`
@@ -703,11 +700,8 @@ def _read_entry_from_archive(entry, uploads, required_reader: RequiredReader):
         upload_files = uploads.get_upload_files(upload_id)

         with upload_files.read_archive(entry_id, True) as archive:
-            return {
-                'entry_id': entry_id,
-                'upload_id': upload_id,
-                'parser_name': entry['parser_name'],
-                'archive': required_reader.read(archive, entry_id, upload_id)}
+            entry['archive'] = required_reader.read(archive, entry_id, upload_id)
+            return entry
     except ArchiveQueryError as e:
         raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=str(e))
     except KeyError as e:
@@ -716,7 +710,7 @@ def _read_entry_from_archive(entry, uploads, required_reader: RequiredReader):
         return None


-def _read_entries_from_archive(entries, required, user):
+def _read_entries_from_archive(entries: Union[list, dict], required: ArchiveRequired, user):
     '''
     Takes pickleable arguments so that it can be offloaded to worker processes.
@@ -725,10 +719,10 @@ def _read_entries_from_archive(entries, required, user):
     with _Uploads() as uploads:
         required_reader = _validate_required(required, user)

-        responses = [_read_entry_from_archive(
-            entry, uploads, required_reader) for entry in entries if entry is not None]
-
-        return list(filter(None, responses))
+        if isinstance(entries, dict):
+            return _read_entry_from_archive(entries, uploads, required_reader)
+
+        return [_read_entry_from_archive(entry, uploads, required_reader) for entry in entries]


 def _answer_entries_archive_request(
@@ -749,39 +743,29 @@ def _answer_entries_archive_request(
         required=MetadataRequired(include=['entry_id', 'upload_id', 'parser_name']),
         user_id=user.user_id if user is not None else None)

-    # fewer than 8 entries per process is not useful
+    entries: list = [{
+        'entry_id': entry['entry_id'], 'upload_id': entry['upload_id'],
+        'parser_name': entry['parser_name']} for entry in search_response.data]
+
+    # fewer than config.archive.min_entries_per_process entries per process is not useful
     # more than config.max_process_number processes is too much for the server
     number: int = min(
-        len(search_response.data) // config.archive.min_entires_per_process,
+        int(math.ceil(len(entries) / config.archive.min_entries_per_process)),
         config.archive.max_process_number)

     if number <= 1:
-        request_data: list = _read_entries_from_archive(search_response.data, required, user)
+        request_data: list = _read_entries_from_archive(entries, required, user)
     else:
-        entries_per_process = len(search_response.data) // number + 1
-
-        # use process pool
-        pool: multiprocessing.pool.Pool = multiprocessing.pool.Pool(processes=number)
-
-        try:
-            responses = pool.map(
-                functools.partial(_read_entries_from_archive, required=required, user=user),
-                zip_longest(*[iter(search_response.data)] * entries_per_process))
-        finally:
-            # gracefully shutdown the pool
-            pool.close()
-            pool.join()
-
-        # collect results from each process
-        # https://stackoverflow.com/a/45323085
-        request_data = functools.reduce(operator.iconcat, responses, [])
+        with parallel_backend('threading', n_jobs=number):
+            request_data = Parallel()(delayed(
+                _read_entries_from_archive)(i, required, user) for i in entries)

     return EntriesArchiveResponse(
         owner=search_response.owner,
         query=search_response.query,
         pagination=search_response.pagination,
         required=required,
-        data=request_data)
+        data=list(filter(None, request_data)))


 _entries_archive_docstring = strip('''
......
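For reference, a minimal sketch of the joblib pattern introduced above, using the threading backend with one delayed call per entry. fake_read is an illustrative stand-in for _read_entries_from_archive and is not part of this commit:

from joblib import Parallel, delayed, parallel_backend


def fake_read(entry: dict) -> dict:
    # stand-in for _read_entries_from_archive(entry, required, user)
    entry['archive'] = {'read': True}
    return entry


entries = [{'entry_id': str(i), 'upload_id': 'upload_0'} for i in range(4)]

# Parallel() picks up the backend and n_jobs from the enclosing context manager
with parallel_backend('threading', n_jobs=2):
    results = Parallel()(delayed(fake_read)(entry) for entry in entries)

# joblib preserves input order, so results line up with entries
assert [r['entry_id'] for r in results] == ['0', '1', '2', '3']

Because the threading backend runs all delayed calls in the same process, arguments and return values are never pickled, which fits the in-place mutation of entry['archive'] in _read_entry_from_archive.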
@@ -423,7 +423,7 @@ archive = NomadConfig(
     block_size=256 * 1024,
     read_buffer_size=256 * 1024,  # GPFS needs at least 256K to achieve decent performance
     max_process_number=20,  # maximum number of processes can be assigned to process archive query
-    min_entires_per_process=20  # minimum number of entries per process
+    min_entries_per_process=20  # minimum number of entries per process
 )
......
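A quick worked example (illustrative numbers, not from the commit) of how the two settings above bound the worker count computed in the entries hunk:

import math

min_entries_per_process = 20  # config.archive.min_entries_per_process
max_process_number = 20       # config.archive.max_process_number
n_entries = 130               # hypothetical size of search_response.data

number = min(math.ceil(n_entries / min_entries_per_process), max_process_number)
assert number == 7  # ceil(130 / 20) = 7, below the cap of 20; number <= 1 skips parallelism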
@@ -123,10 +123,7 @@ class MEnum(Sequence):
         self._list = list(kwargs.values())
         self._values = set(kwargs.values())  # For allowing constant time member check
         self._map = kwargs
-
-    def __getattr__(self, attr):
-        return self._map[attr]
+        self.__dict__.update(kwargs)

     def __getitem__(self, index):
         return self._list[index]
......
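A minimal sketch (illustrative TinyEnum class, not part of this commit) of one observable difference after replacing __getattr__ with self.__dict__.update(kwargs): a missing member now raises the usual AttributeError instead of leaking a KeyError from self._map, so hasattr behaves as expected.

class TinyEnum:
    def __init__(self, **kwargs):
        self._list = list(kwargs.values())
        # members become ordinary instance attributes
        self.__dict__.update(kwargs)

    def __getitem__(self, index):
        return self._list[index]


colors = TinyEnum(red='red', blue='blue')
assert colors.red == 'red'
assert colors[1] == 'blue'
assert not hasattr(colors, 'green')  # AttributeError internally, not KeyError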
@@ -91,6 +91,7 @@ dockerspawner==12.1.0
 oauthenticator==14.2.0
 validators==0.18.2
 aiofiles==0.8.0
+joblib==1.1.0

 # [dev]
 markupsafe==2.0.1
......