diff --git a/nomad/client/archive.py b/nomad/client/archive.py index a4e3ea4b34bae8903315acb57a8d42a5e3966791..e9840bfc821269362477232adc73ffea795d6ad1 100644 --- a/nomad/client/archive.py +++ b/nomad/client/archive.py @@ -19,7 +19,8 @@ from __future__ import annotations import asyncio from asyncio import Semaphore -from typing import Any, Dict, List +from itertools import islice +from typing import Any import threading from click import progressbar @@ -124,6 +125,7 @@ class ArchiveQuery: they may wish to start from a specific upload, default: '' results_max (int): maximum results to query, default: 1000 page_size (int): size of page in each query, cannot exceed the limit 10000, default: 100 + batch_size (int): size of page in each download request, default: 10 username (str): username for authenticated access, default: '' password (str): password for authenticated access, default: '' retry (int): number of retry when fetching uploads, default: 4 @@ -140,6 +142,7 @@ class ArchiveQuery: after: str = None, results_max: int = 1000, page_size: int = 100, + batch_size: int = 10, username: str = None, password: str = None, retry: int = 4, @@ -160,9 +163,11 @@ class ArchiveQuery: self._page_size: int = min(page_size, 9999) if page_size > 0 else 100 if self._page_size > self._results_max: self._page_size = self._results_max + self._batch_size: int = batch_size if batch_size > 0 else 10 self._retry: int = retry if retry >= 0 else 4 self._sleep_time: float = sleep_time if sleep_time > 0.0 else 4.0 - self._semaphore = semaphore + self._semaphore = min(10, semaphore) if semaphore > 0 else 4 + self._results_actual: int = 0 from nomad.client import Auth @@ -176,7 +181,7 @@ class ArchiveQuery: # local data storage self._entries: list[tuple[str, str]] = [] - self._entries_dict: list[Dict] = [] + self._entries_dict: list[dict] = [] self._current_after: str = self._after self._current_results: int = 0 @@ -219,17 +224,17 @@ class ArchiveQuery: return request - def _download_request(self, entry_id: str) -> dict: + def _download_request(self, entry_ids: list[str]) -> dict: """ Generate download request. """ - request: Dict[str, Any] = dict(owner=self._owner, required=self._required) + request: dict[str, Any] = dict(owner=self._owner, required=self._required) request['query'] = {'and': []} for t_list in self._query_list: request['query']['and'].append(t_list) - request['query']['and'].append({'entry_id': entry_id}) - request.setdefault('pagination', {'page_size': 1}) + request['query']['and'].append({'entry_id:any': entry_ids}) + request.setdefault('pagination', {'page_size': len(entry_ids)}) # print(f'Current request: {request}') @@ -243,6 +248,18 @@ class ArchiveQuery: self._entries = [] self._current_after = self._after self._current_results = 0 + self._results_actual = 0 + + @property + def _actual_max(self): + """ + The actual maximum number of entries available. + """ + + if self._results_actual: + return min(self._results_actual, self._results_max) + + return self._results_max async def _fetch_async(self, number: int) -> int: """ @@ -260,12 +277,12 @@ class ArchiveQuery: # if the maximum number of entries has been previously fetched # not going to fetch more entries - if self._current_results >= self._results_max: + if self._current_results >= self._actual_max: return 0 # get all entries at once if number == 0: - number = self._results_max + number = self._actual_max - self._current_results num_retry: int = 0 num_entry: int = 0 @@ -301,9 +318,9 @@ class ArchiveQuery: response_json = response.json() - self._current_after = response_json['pagination'].get( - 'next_page_after_value', None - ) + pagination = response_json['pagination'] + self._current_after = pagination.get('next_page_after_value', None) + self._results_actual = pagination.get('total', 0) data = [ (entry['entry_id'], entry['upload_id']) @@ -315,9 +332,9 @@ class ArchiveQuery: if current_size == 0: break - if self._current_results + current_size > self._results_max: + if self._current_results + current_size > self._actual_max: # current query has sufficient entries to exceed the limit - data = data[: self._results_max - self._current_results] + data = data[: self._actual_max - self._current_results] self._current_results += len(data) self._entries.extend(data) break @@ -339,7 +356,7 @@ class ArchiveQuery: return num_entry - async def _download_async(self, number: int) -> List[EntryArchive]: + async def _download_async(self, number: int) -> list[EntryArchive]: """ Download required entries asynchronously. @@ -351,26 +368,37 @@ class ArchiveQuery: """ semaphore = Semaphore(self._semaphore) + actual_number: int = min(number, len(self._entries)) + + def batched(iterable, chunk_size): + iterator = iter(iterable) + while chunk := list(islice(iterator, chunk_size)): + yield chunk + with progressbar( - length=number, label=f'Downloading {number} entries...' + length=actual_number, label=f'Downloading {actual_number} entries...' ) as bar: async with AsyncClient(timeout=Timeout(timeout=300)) as session: tasks = [ asyncio.create_task(self._acquire(ids, session, semaphore, bar)) - for ids in self._entries[:number] + for ids in batched(self._entries[:actual_number], self._batch_size) ] results = await asyncio.gather(*tasks) - return [result for result in results if result] + return [archive for result in results if result for archive in result] async def _acquire( - self, ids: tuple[str, str], session: AsyncClient, semaphore: Semaphore, bar - ) -> EntryArchive | None: + self, + ids: list[tuple[str, str]], + session: AsyncClient, + semaphore: Semaphore, + bar, + ) -> list[EntryArchive] | None: """ Perform the download task. Params: - upload (Tuple[str, int]): upload + ids (list[tuple[str, str]]): a list of tuples of entry id and upload id session (httpx.AsyncClient): httpx client @@ -379,36 +407,43 @@ class ArchiveQuery: Returns: A list of EntryArchive """ + entry_ids = [x for x, _ in ids] - entry_id, upload_id = ids - - request = self._download_request(entry_id) + request = self._download_request(entry_ids) async with semaphore: response = await session.post( self._download_url, json=request, headers=self._auth.headers() ) - bar.update(1) - self._entries.remove(ids) + bar.update(len(entry_ids)) + for item in ids: + self._entries.remove(item) if response.status_code >= 400: print( - f'Request with entry id {entry_id} returns {response.status_code},' - f' will retry in the next download call...' + f'Request returns {response.status_code}, will retry in the next download call...' ) - self._entries.append(ids) + self._entries.extend(ids) return None # successfully downloaded data - context = ClientContext(self._url, upload_id=upload_id, auth=self._auth) - result = EntryArchive.m_from_dict( - response.json()['data'][0]['archive'], m_context=context - ) + results: list = [] + + response_json: list = response.json()['data'] + + for index in range(len(ids)): + entry_id, upload_id = ids[index] + context = ClientContext(self._url, upload_id=upload_id, auth=self._auth) + result = EntryArchive.m_from_dict( + response_json[index]['archive'], m_context=context + ) + + if not result: + print(f'No result returned for id {entry_id}, is the query proper?') - if not result: - print(f'No result returned for id {entry_id}, is the query proper?') + results.append(result) - return result + return results def fetch(self, number: int = 0) -> int: """ @@ -425,7 +460,7 @@ class ArchiveQuery: return run_async(self._fetch_async, number) - def download(self, number: int = 0) -> List[EntryArchive]: + def download(self, number: int = 0) -> list[EntryArchive]: """ Download fetched entries from remote. Automatically call .fetch() if not fetched. @@ -465,7 +500,7 @@ class ArchiveQuery: return await self._fetch_async(number) - async def async_download(self, number: int = 0) -> List[EntryArchive]: + async def async_download(self, number: int = 0) -> list[EntryArchive]: """ Asynchronous interface for use in a running event loop. """ diff --git a/tests/test_client.py b/tests/test_client.py index e545519ebd61a70f1943b1116d20e6cd4635ed69..8d0764f2fc23026e72007e9c98362f6b74d614bb 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -143,6 +143,7 @@ def test_async_query_parallel(async_api_v1, many_uploads, monkeypatch): async_query = ArchiveQuery(required=dict(run='*')) assert_results(async_query.download(), total=4) + assert_results(async_query.download(), total=0) async_query = ArchiveQuery(required=dict(run='*'), page_size=1)