From 072952d29940df4d16e010f870af1d5fdb1eb584 Mon Sep 17 00:00:00 2001 From: Haoyu Yang <yanghaoyu97@outlook.com> Date: Mon, 24 Mar 2025 09:05:32 +0000 Subject: [PATCH] Added experimental API endpoint for bulk entries metadata export Changelog: Added --- nomad/app/v1/routers/entries.py | 135 +++++++++++++++++++++- nomad/app/v1/utils.py | 5 +- nomad/config/defaults.yaml | 1 + nomad/config/models/config.py | 4 + tests/app/v1/routers/test_entries.py | 164 +++++++++++++++++++++++---- 5 files changed, 283 insertions(+), 26 deletions(-) diff --git a/nomad/app/v1/routers/entries.py b/nomad/app/v1/routers/entries.py index f974e15421..65d025c220 100644 --- a/nomad/app/v1/routers/entries.py +++ b/nomad/app/v1/routers/entries.py @@ -15,17 +15,28 @@ # See the License for the specific language governing permissions and # limitations under the License. # + +import csv import io import json import os.path -from collections.abc import Iterator +from collections.abc import AsyncIterator, Iterator from datetime import datetime from enum import Enum from typing import Any import orjson import yaml -from fastapi import APIRouter, Body, Depends, HTTPException, Path, Request, status +from fastapi import ( + APIRouter, + Body, + Depends, + Header, + HTTPException, + Path, + Request, + status, +) from fastapi import Query as QueryParameter from fastapi.exceptions import RequestValidationError from fastapi.responses import ORJSONResponse, StreamingResponse @@ -474,10 +485,13 @@ def perform_search(*args, **kwargs): search_response = search(*args, **kwargs) search_response.es_query = None return search_response + except QueryValidationError as e: raise RequestValidationError(errors=e.errors) + except AuthenticationRequiredError as e: raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail=str(e)) + except SearchError as e: raise HTTPException( status.HTTP_400_BAD_REQUEST, @@ -507,7 +521,7 @@ async def post_entries_metadata_query( By default the *empty* search (that returns everything) is 
performed. Only a small page of the search results are returned at a time; use `pagination` in subsequent - requests to retrive more data. Each entry has a lot of different *metadata*, use + requests to retrieve more data. Each entry has a lot of different *metadata*, use `required` to limit the data that is returned. The `statistics` and `aggregations` keys will further allow to return statistics @@ -876,6 +890,121 @@ async def get_entries_raw( ) +@router.get( + '/export', + tags=[APITag.METADATA], + summary='Search entries and download their metadata in selected format', + response_class=StreamingResponse, + responses=create_responses(_bad_owner_response), +) +async def export_entries_metadata( + with_query: WithQuery = Depends(query_parameters), + content_type: str = Header('application/json'), + required: MetadataRequired = Depends(metadata_required_parameters), + user: User = Depends(create_user_dependency(signature_token_auth_allowed=True)), + page_size: int = QueryParameter(10_000, gt=0), +): + """(**Experimental**) Export metadata entries in a selected format. + + This endpoint allows users to export metadata entries in either JSON or CSV format. + The format must be specified via the `Content-Type` HTTP header: + - `application/json` → Returns the metadata as a JSON response. + - `text/csv` → Returns the metadata as a CSV file. + """ + if with_query.owner == Owner.all_: + raise HTTPException( + status.HTTP_401_UNAUTHORIZED, + detail=strip( + """ + The owner=all is not allowed for this operation as it will search for entries + that you might not be allowed to access. 
+ """ + ), + ) + + response = perform_search( + owner=with_query.owner, + query=with_query.query, + pagination=MetadataPagination(page_size=0), + required=MetadataRequired(include=[]), + user_id=user.user_id if user is not None else None, + ) + + if response.pagination.total > config.services.max_entry_metadata_download: + raise HTTPException( + status.HTTP_400_BAD_REQUEST, + detail=( + f'The limit of maximum number of metadata in a single download ' + f'({config.services.max_entry_metadata_download}) has been exceeded ({response.pagination.total}).' + ), + ) + + async def json_stream() -> AsyncIterator[bytes]: + """Stream metadata in JSON format.""" + first_item: bool = True + yield b'[' # Start of JSON array + + for entry_metadata in _do_exhaustive_search( + owner=with_query.owner, + query=with_query.query, + user=user, + required=required, + page_size=page_size, + ): + if not first_item: + yield b',' # Separate JSON objects + first_item = False + yield json.dumps(entry_metadata, default=str).encode('utf-8') + + yield b']' # End of JSON array + + async def csv_stream() -> AsyncIterator[bytes]: + """Stream metadata in CSV format.""" + first_row: bool = True + buffer: io.StringIO = io.StringIO() + writer: csv.DictWriter | None = None + + for entry_metadata in _do_exhaustive_search( + owner=with_query.owner, + query=with_query.query, + user=user, + required=required, + page_size=page_size, + ): + if first_row: + writer = csv.DictWriter(buffer, fieldnames=entry_metadata.keys()) + yield buffer.getvalue().encode('utf-8') # Send column headers + buffer.seek(0) + buffer.truncate(0) # Clear buffer + writer.writeheader() + first_row = False + + writer.writerow(entry_metadata) + yield buffer.getvalue().encode('utf-8') # Send row data + buffer.seek(0) + buffer.truncate(0) + + if content_type == 'text/csv': + return StreamingResponse( + csv_stream(), + media_type=content_type, + headers=browser_download_headers(filename='metadata_export.csv'), + ) + + elif content_type == 
'application/json': + return StreamingResponse( + json_stream(), + media_type=content_type, + headers=browser_download_headers(filename='metadata_export.json'), + ) + + else: + raise HTTPException( + status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE, + detail=f"Unsupported {content_type=}. Expected 'application/json' or 'text/csv'.", + ) + + def _read_archive(entry_metadata, uploads, required_reader: RequiredReader): entry_id = entry_metadata['entry_id'] upload_id = entry_metadata['upload_id'] diff --git a/nomad/app/v1/utils.py b/nomad/app/v1/utils.py index c62316fe4e..6d95f48f70 100644 --- a/nomad/app/v1/utils.py +++ b/nomad/app/v1/utils.py @@ -245,8 +245,9 @@ async def create_stream_from_string(content: str): yield x -def create_responses(*args): - return {status_code: response for status_code, response in args} +def create_responses(*args) -> dict: + """Pack status code-response pairs into a dictionary.""" + return dict(args) def browser_download_headers( diff --git a/nomad/config/defaults.yaml b/nomad/config/defaults.yaml index 9ed0dec09e..9c7c00e5e0 100644 --- a/nomad/config/defaults.yaml +++ b/nomad/config/defaults.yaml @@ -210,6 +210,7 @@ services: image_resource_http_max_age: 2592000 log_api_queries: true max_entry_download: 50000 + max_entry_metadata_download: 100000 optimade_enabled: true unavailable_value: unavailable upload_limit: 10 diff --git a/nomad/config/models/config.py b/nomad/config/models/config.py index 385f0d8ba6..245945e022 100644 --- a/nomad/config/models/config.py +++ b/nomad/config/models/config.py @@ -169,6 +169,10 @@ class Services(ConfigBaseModel): Page-after-value-based pagination is independent and can be used without limitations. 
""", ) + max_entry_metadata_download: int = Field( + 100_000, + description='The maximum amount of entries metadata that can be downloaded.', + ) unavailable_value: str = Field( 'unavailable', description=""" diff --git a/tests/app/v1/routers/test_entries.py b/tests/app/v1/routers/test_entries.py index 36985eb08f..fbefb3a3c7 100644 --- a/tests/app/v1/routers/test_entries.py +++ b/tests/app/v1/routers/test_entries.py @@ -16,9 +16,11 @@ # limitations under the License. # +import csv import io import json import zipfile +from typing import Literal from urllib.parse import urlencode import pytest @@ -49,15 +51,13 @@ from .common import ( post_query_test_parameters, ) -""" -These are the tests for all API operations below ``entries``. The tests are organized -using the following type of methods: fixtures, ``perfrom_*_test``, ``assert_*``, and -``test_*``. While some ``test_*`` methods test individual API operations, some -test methods will test multiple API operations that use common aspects like -supporting queries, pagination, or the owner parameter. The test methods will use -``perform_*_test`` methods as an parameter. Similarely, the ``assert_*`` methods allow -to assert for certain aspects in the responses. -""" +# These are the tests for all API operations below ``entries``. The tests are organized +# using the following type of methods: fixtures, ``perfrom_*_test``, ``assert_*``, and +# ``test_*``. While some ``test_*`` methods test individual API operations, some +# test methods will test multiple API operations that use common aspects like +# supporting queries, pagination, or the owner parameter. The test methods will use +# ``perform_*_test`` methods as an parameter. Similarely, the ``assert_*`` methods allow +# to assert for certain aspects in the responses. 
def perform_entries_raw_test( @@ -96,7 +96,7 @@ def perform_entries_raw_test( ) else: - assert False + pytest.fail(f'Invalid HTTP method {http_method}') assert_response(response, status_code) if status_code == 200: @@ -139,7 +139,7 @@ def perform_entries_rawdir_test( response = client.post('entries/rawdir/query', headers=headers, json=body) else: - assert False + pytest.fail(f'Invalid HTTP method {http_method}') response_json = assert_base_metadata_response(response, status_code=status_code) @@ -193,7 +193,7 @@ def perform_entries_archive_download_test( ) else: - assert False + pytest.fail(f'Invalid HTTP method {http_method}') assert_response(response, status_code) if status_code == 200: @@ -376,7 +376,7 @@ def assert_archive(archive, required=None): assert key in archive -program_name = 'results.method.simulation.program_name' +PROGRAM_NAME: str = 'results.method.simulation.program_name' def test_entries_all_metrics(client, example_data): @@ -617,7 +617,7 @@ def test_entry_metadata( id='child-entries', ), pytest.param( - None, None, {program_name: 'DOESNOTEXIST'}, {}, 0, 5, 200, id='empty' + None, None, {PROGRAM_NAME: 'DOESNOTEXIST'}, {}, 0, 5, 200, id='empty' ), ], ) @@ -663,7 +663,7 @@ def test_entries_rawdir( id='child-entries', ), pytest.param( - None, None, {program_name: 'DOESNOTEXIST'}, {}, 0, 5, 200, id='empty' + None, None, {PROGRAM_NAME: 'DOESNOTEXIST'}, {}, 0, 5, 200, id='empty' ), pytest.param(None, None, {}, {'glob_pattern': '*.json'}, 23, 1, 200, id='glob'), pytest.param( @@ -761,6 +761,128 @@ def test_entries_download_max( test_method(client, status_code=400, http_method=http_method) +class TestEntriesExportMetadata: + @pytest.mark.parametrize( + 'user, owner, query, status_code', + [ + pytest.param(None, None, {}, 200, id='all'), + pytest.param(None, 'all', {}, 401, id='owner_all'), + ], + ) + @pytest.mark.parametrize('content_type', ['application/json', 'text/csv']) + def test_entries_export_metadata( + self, + auth_headers, + client, + 
example_data, + user, + owner, + query, + status_code, + content_type, + page_size: int = 10_000, + ): + if owner == 'all': + # This operation is not allowed for owner 'all' + status_code = 401 + + params = dict(**query) + params['page_size'] = page_size + + if owner is not None: + params['owner'] = owner + + response = client.get( + f'entries/export?{urlencode(params)}', + headers={**(auth_headers[user] or {}), 'Content-Type': content_type}, + ) + + assert_response(response, status_code) + + if status_code == 200: + if content_type == 'application/json': + for i, entry in enumerate(response.json(), start=1): + assert entry['upload_id'] == 'id_published' + assert entry['entry_id'] == f'id_{i:02d}' + + else: + csv_data = io.StringIO(response.text) + rows = list(csv.reader(csv_data)) + assert len(rows) >= 2 + + assert set(rows[0]).issuperset({'upload_id', 'entry_id'}) + + for i, entry_row in enumerate(rows[1:], start=1): + assert 'id_published' in entry_row # upload_id + assert f'id_{i:02d}' in entry_row # entry_id + + @pytest.mark.parametrize( + 'user, owner, query, status_code, content_type', + [ + pytest.param(None, None, {}, 415, 'invalid', id='all'), + ], + ) + def test_invalid_content_type( + self, + auth_headers, + client, + example_data, + user, + owner, + query, + status_code, + content_type, + ): + params = dict(**query) + + if owner is not None: + params['owner'] = owner + + response = client.get( + f'entries/export?{urlencode(params)}', + headers={**(auth_headers[user] or {}), 'Content-Type': content_type}, + ) + + assert_response(response, status_code) + + @pytest.mark.parametrize( + 'user, owner, query, status_code', + [ + pytest.param(None, None, {}, 400, id='all'), + ], + ) + @pytest.mark.parametrize('content_type', ['application/json', 'text/csv']) + def test_max_entry_metadata_download( + self, + auth_headers, + client, + example_data, + monkeypatch, + user, + owner, + query, + status_code, + content_type: Literal['application/json', 'text/csv'], + ): 
+ monkeypatch.setattr('nomad.config.services.max_entry_metadata_download', 1) + + params = dict(**query) + if owner is not None: + params['owner'] = owner + + response = client.get( + 'entries/export', + headers={**(auth_headers[user] or {}), 'Content-Type': content_type}, + ) + + assert response.status_code == status_code + + expected_message = ( + 'The limit of maximum number of metadata in a single download' + ) + assert expected_message in response.json()['detail'] + + @pytest.mark.parametrize( 'user, entry_id, files_per_entry, status_code', [ @@ -1014,7 +1136,7 @@ def test_entry_raw_file( ), pytest.param(None, None, {}, {'metadata': '*'}, {}, 23, 200, id='required'), pytest.param( - None, None, {program_name: 'DOESNOTEXIST'}, None, {}, -1, 200, id='empty' + None, None, {PROGRAM_NAME: 'DOESNOTEXIST'}, None, {}, -1, 200, id='empty' ), pytest.param(None, None, {}, None, {'compress': True}, 23, 200, id='compress'), ], @@ -1154,8 +1276,8 @@ def test_entry_archive_query( assert_archive_response(response.json(), required=required) -elements = 'results.material.elements' -n_elements = 'results.material.n_elements' +ELEMENTS: str = 'results.material.elements' +N_ELEMENTS: str = 'results.material.n_elements' @pytest.mark.parametrize( @@ -1231,7 +1353,7 @@ def test_entries_post_query( 'total_gt': 22, }, int={ - 'name': 'results.material.n_elements', + 'name': N_ELEMENTS, 'values': [2, 1], 'total': 23, 'total_any': 23, @@ -1358,8 +1480,8 @@ def test_entries_owner( @pytest.mark.parametrize( 'pagination, response_pagination, status_code', pagination_test_parameters( - elements='results.material.elements', - n_elements='results.material.n_elements', + elements=ELEMENTS, + n_elements=N_ELEMENTS, crystal_system='results.material.symmetry.crystal_system', total=23, ), -- GitLab