diff --git a/.gitignore b/.gitignore index 9d2c38aca1d3a1e1ca604f278370798a68926e7f..2af58ec7a6d077cd1143d4370402778929a35dcf 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ .python-version .ipynb_checkpoints/ .python-version +.coverage_html/ __pycache__ .mypy_cache *.pyc @@ -12,7 +13,8 @@ __pycache__ /data/ .volumes/ .pytest_cache/ -.coverage +.coverage* +htmlcov try.http project/ test_*/ @@ -23,8 +25,8 @@ target/ .vscode/ vscode/ nomad.yaml -./gunicorn.log.conf -./gunicorn.conf +gunicorn.log.conf +gunicorn.conf build/ dist/ setup.json @@ -32,4 +34,3 @@ parser.osio.log gui/src/metainfo.json gui/src/searchQuantities.json examples/workdir/ -gunicorn.log.conf diff --git a/.pylintrc b/.pylintrc index e356665f2e6769c80fdfd9ce96b7ab76663b250c..103669a5538093112ec6acddd70288e3f74b88c5 100644 --- a/.pylintrc +++ b/.pylintrc @@ -3,7 +3,7 @@ # A comma-separated list of package or module names from where C extensions may # be loaded. Extensions are loading into the active Python interpreter and may # run arbitrary code. -extension-pkg-whitelist= +extension-pkg-whitelist=pydantic # Add files or directories to the blacklist. They should be base names, not # paths. @@ -666,7 +666,7 @@ ignore-on-opaque-inference=yes # List of class names for which member attributes should not be checked (useful # for classes with dynamically set attributes). This supports the use of # qualified names. -ignored-classes=optparse.Values,thread._local,_thread._local +ignored-classes=optparse.Values,thread._local,_thread._local,SearchResponse # List of module names for which member attributes should not be checked # (useful for modules/projects where namespaces are manipulated during runtime diff --git a/gui/public/env.js b/gui/public/env.js index 7ec27f5cb0db525dea2779792f9a4453d9ef9763..cc1f8be774d038e75ebcec9c8d32d730b8123fea 100644 --- a/gui/public/env.js +++ b/gui/public/env.js @@ -3,7 +3,7 @@ window.nomadEnv = { 'keycloakRealm': 'fairdi_nomad_test', 'keycloakClientId': 'nomad_gui_dev', 'appBase': 'http://nomad-lab.eu/prod/rae/beta', - 'appBase': 'http://localhost:8000/fairdi/nomad/latest', + 'appBase': 'http://localhost:8000', 'debug': false, 'matomoEnabled': false, 'matomoUrl': 'https://nomad-lab.eu/fairdi/stat', diff --git a/gui/src/config.js b/gui/src/config.js index e17fb19ec341acfb14d8ca70ff1f3d9d1977961d..f9e344c831659783aab8781c443528d1b8b28ffe 100644 --- a/gui/src/config.js +++ b/gui/src/config.js @@ -22,6 +22,7 @@ export const version = window.nomadEnv.version export const appBase = window.nomadEnv.appBase.replace(/\/$/, '') // export const apiBase = 'http://nomad-lab.eu/prod/rae/api' export const apiBase = `${appBase}/api` +export const apiV1Base = `${appBase}/api/v1` export const optimadeBase = `${appBase}/optimade` export const guiBase = process.env.PUBLIC_URL export const matomoUrl = window.nomadEnv.matomoUrl diff --git a/nomad/app/__init__.py b/nomad/app/__init__.py index 9bc865ee1ff9dd63d3a9184f7b8f8a4675349d8c..b99ac0217486600d8f7c512ee1289dad42e2efd5 100644 --- a/nomad/app/__init__.py +++ b/nomad/app/__init__.py @@ -90,7 +90,6 @@ if config.services.https: app = Flask(__name__) ''' The Flask app that serves all APIs. 
''' -app.config.APPLICATION_ROOT = common.base_path # type: ignore app.config.RESTPLUS_MASK_HEADER = False # type: ignore app.config.RESTPLUS_MASK_SWAGGER = False # type: ignore app.config.SWAGGER_UI_OPERATION_ID = True # type: ignore @@ -98,17 +97,6 @@ app.config.SWAGGER_UI_REQUEST_DURATION = True # type: ignore app.config['SECRET_KEY'] = config.services.api_secret - -def api_base_path_response(env, resp): - resp('200 OK', [('Content-Type', 'text/plain')]) - return [ - ('Development nomad api server. Api is served under %s/.' % - config.services.api_base_path).encode('utf-8')] - - -app.wsgi_app = DispatcherMiddleware( # type: ignore - api_base_path_response, {config.services.api_base_path: app.wsgi_app}) - CORS(app) app.register_blueprint(api_blueprint, url_prefix='/api') @@ -182,18 +170,3 @@ def before_request(): if config.services.api_chaos > 0: if random.randint(0, 100) <= config.services.api_chaos: abort(random.choice([400, 404, 500]), 'With best wishes from the chaos monkey.') - - -@app.before_first_request -def setup(): - from nomad import infrastructure - - if not app.config['TESTING']: - # each subprocess is supposed disconnect connect again: https://jira.mongodb.org/browse/PYTHON-2090 - try: - from mongoengine import disconnect - disconnect() - except Exception: - pass - - infrastructure.setup() diff --git a/nomad/app/api/common.py b/nomad/app/api/common.py index 23930679a2f204a05db7677c4bab4ad5f8d2fc4d..4ea24caf48ca0f35eef136072ff52123d113dfd5 100644 --- a/nomad/app/api/common.py +++ b/nomad/app/api/common.py @@ -213,8 +213,8 @@ def apply_search_parameters(search_request: search.SearchRequest, args: Dict[str search_request.owner( owner, g.user.user_id if g.user is not None else None) - except ValueError as e: - abort(401, getattr(e, 'message', 'Invalid owner parameter: %s' % owner)) + except search.AuthenticationRequiredError as e: + abort(401, str(e)) except Exception as e: abort(400, getattr(e, 'message', 'Invalid owner parameter')) diff --git a/nomad/app/common.py b/nomad/app/common.py index 2cf51ddc3ad8037ac4c3298136f9d4d7a52453dd..e9231f1ab9aef6e35e4b921e3a0a77f5e69a166a 100644 --- a/nomad/app/common.py +++ b/nomad/app/common.py @@ -22,15 +22,10 @@ from datetime import datetime import pytz from contextlib import contextmanager -from nomad import config - logger: BoundLogger = None ''' A logger pre configured with information about the current request. ''' -base_path = config.services.api_base_path -''' Provides the root path of the nomad APIs. 
''' - class RFC3339DateTime(fields.DateTime): diff --git a/nomad/app/optimade/api.py b/nomad/app/optimade/api.py index 914a8818359d6b65c8cffe418af429175df8f262..a4c1f9e1dac53b24b8a150812c9042a54e84b400 100644 --- a/nomad/app/optimade/api.py +++ b/nomad/app/optimade/api.py @@ -26,7 +26,7 @@ blueprint = Blueprint('optimade', __name__) base_url = 'https://%s/%s/optimade' % ( config.services.api_host.strip('/'), - config.services.api_base_path.strip('/')) + config.services.api_prefix.strip('/')) def url(endpoint: str = None, version='v1', prefix=None, **kwargs): diff --git a/nomad/app_fastapi/main.py b/nomad/app_fastapi/main.py new file mode 100644 index 0000000000000000000000000000000000000000..b2802d715cab8f633582b416ec82923a09222ec7 --- /dev/null +++ b/nomad/app_fastapi/main.py @@ -0,0 +1,130 @@ +from fastapi import FastAPI, status, Request +from fastapi.responses import JSONResponse +from fastapi.middleware.wsgi import WSGIMiddleware +import traceback + +from nomad import config, utils +from nomad.app import app as flask_app +from nomad.app_fastapi.routers import users, entries, auth + + +logger = utils.get_logger(__name__) + +app = FastAPI( + root_path=config.services.api_prefix, + openapi_url='/api/v1/openapi.json', + docs_url='/api/v1/docs', + redoc_url='/api/v1/redoc', + swagger_ui_oauth2_redirect_url='/api/v1/docs/oauth2-redirect', + + title='NOMAD API', + version='v1, NOMAD %s@%s' % (config.meta.version, config.meta.commit), + description=utils.strip(''' + **Disclaimer!** This is the new NOMAD API. It is still under development and only includes a + part of the NOMAD API functionality. You can still use the old flask-based API + as `/api` and the optimade API as `/optimade/v1`. + + ## Getting started + + ... TODO put the examples and tutorial here ... + + ## Conventions + + ### Paths + + The various API operations are organized with the following path scheme. The first + part of the path describes the data entity that is covered by + the operations below (e.g. `entries`, `users`, `datasets`, `uploads`). For example, + everything below `entries` will be about searching entries, getting + an entry, editing entries, etc. + + The second (optional and variable) path segment allows to denote a specific entity instance, + e.g. a specific entry or dataset, usually by id. Without such a variable second + path segment, it is about all instances, e.g. searching entries or listing all datasets. + + Optional (if available) further path segments will determine the variety and format + of data. This is mostly for entries to distinguish the metadata, raw, and archive + data or distinguish between listing (i.e. paginated json) and downloading + (i.e. streaming a zip-file). + + Further, we try to adhere to the paradigm of getting and posting resources. Therefore, + when you post a complex query, you will not post it to `/entries` (a query is not an entry), + but to `/entries/query`, with *query* being a kind of virtual resource. + + ### Parameters and bodies for GET and POST operations + + We offer **GET** and **POST** versions for many complex operations. The idea is that + **GET** is easy to use, e.g. via curl or simply in the browser, while **POST** + allows to provide more complex parameters (i.e. a JSON body). For example, to + search for entries, you can use the **GET** operation `/entries` to specify simple + queries via URL, e.g. `/entries?code_name=VASP&atoms=Ti`, but you would use + **POST** `/entries/query` to provide complex nested queries, e.g. with logical + operators.
+ + Typically the **POST** version is a superset of the functionality of the **GET** + version. However, most top-level parameters in the **POST** body will be available + in the **GET** version as URL parameters with the same name and meaning. This + is especially true for recurring parameters for general API concepts like pagination + or specifying required result fields. + + ### Response layout + + Typically, a response will mirror all input parameters in the normalized form that + was used to perform the operation. + + Some of these will be augmented with result values. For example, the pagination + section of a request will be augmented with the total available number. + + The actual requested data will be placed under the key `data`. + + ## About Authentication + + NOMAD is an open data-sharing platform, and most of the API operations do not require + any authorization and can be freely used without a user or credentials. However, + to upload data, edit data, or view your own and potentially unpublished data, + the API needs to authenticate you. + + The NOMAD API uses OAuth and tokens to authenticate users. We provide simple operations + that allow you to acquire an *access token* via username- and password-based + authentication (`/auth/token`). The resulting access token can then be used on all operations + (e.g. that support or require authentication). + + To use authentication in the dashboard, simply use the Authorize button. The + dashboard GUI will manage the access token and use it while you try out the various + operations. + ''')) + + +@app.on_event('startup') +async def startup_event(): + from nomad import infrastructure + # each subprocess is supposed to disconnect and connect again: https://jira.mongodb.org/browse/PYTHON-2090 + try: + from mongoengine import disconnect + disconnect() + except Exception: + pass + + infrastructure.setup() + + +@app.exception_handler(Exception) +async def unicorn_exception_handler(request: Request, e: Exception): + logger.error('unexpected exception in API', url=request.url, exc_info=e) + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content={ + 'detail': { + 'reason': 'Unexpected exception while handling your request', + 'exception': str(e), + 'exception_class': e.__class__.__name__, + 'exception_traceback': traceback.format_exc() + } + } + ) + +app.include_router(auth.router, prefix='/api/v1/auth') +app.include_router(users.router, prefix='/api/v1/users') +app.include_router(entries.router, prefix='/api/v1/entries') + +app.mount('/', WSGIMiddleware(flask_app)) diff --git a/nomad/app_fastapi/models.py b/nomad/app_fastapi/models.py new file mode 100644 index 0000000000000000000000000000000000000000..64175667803f80a4559c97c8e9366450e257e31a --- /dev/null +++ b/nomad/app_fastapi/models.py @@ -0,0 +1,873 @@ +from typing import List, Dict, Optional, Union, Any +import enum +from fastapi import Body, Request, HTTPException, Query as FastApiQuery +import pydantic +from pydantic import BaseModel, Field, validator, root_validator +import datetime +import numpy as np +import re +import fnmatch + +from nomad import datamodel # pylint: disable=unused-import +from nomad.utils import strip +from nomad.metainfo import Datetime, MEnum +from nomad.app_fastapi.utils import parameter_dependency_from_model +from nomad.metainfo.search_extension import metrics, search_quantities + + +class User(BaseModel): + user_id: str + email: Optional[str] = None + first_name: Optional[str] = None + last_name: Optional[str] = None + + +Metric =
enum.Enum('Metric', {name: name for name in metrics}) # type: ignore +Quantity = enum.Enum('Quantity', {name: name for name in search_quantities}) # type: ignore +AggregateableQuantity = enum.Enum('AggregateableQuantity', { # type: ignore + name: name for name in search_quantities + if search_quantities[name].aggregateable}) + +AggregateableQuantity.__doc__ = ''' + Statistics and aggregations can only be computed for those search quantities that have + discrete values. For example a statistics aggregates a certain metric (e.g. the number of entries) + over all entries were this quantity has the same value (bucket aggregation, think historgam here). +''' + +Value = Union[str, int, float, bool, datetime.datetime] +ComparableValue = Union[str, int, float, datetime.datetime] + + +class AggregationOrderType(str, enum.Enum): + ''' + Allows to order statistics or aggregations by either quantity values (`values`) or number + of entries (`entries`). + ''' + values = 'values' + entries = 'entries' + + +class HTTPExceptionModel(BaseModel): + detail: str + + +class NoneEmptyBaseModel(BaseModel): + @root_validator + def check_exists(cls, values): # pylint: disable=no-self-argument + assert any(value is not None for value in values.values()) + return values + + +class All(NoneEmptyBaseModel): + op: List[Value] = Field(None, alias='all') + + +class None_(NoneEmptyBaseModel): + op: List[Value] = Field(None, alias='none') + + +class Any_(NoneEmptyBaseModel): + op: List[Value] = Field(None, alias='any') + + +class ComparisonOperator(NoneEmptyBaseModel): pass + + +class Lte(ComparisonOperator): + op: ComparableValue = Field(None, alias='lte') + + +class Lt(ComparisonOperator): + op: ComparableValue = Field(None, alias='lt') + + +class Gte(ComparisonOperator): + op: ComparableValue = Field(None, alias='gte') + + +class Gt(ComparisonOperator): + op: ComparableValue = Field(None, alias='gt') + + +class LogicalOperator(NoneEmptyBaseModel): + + @validator('op', check_fields=False) + def validate_query(cls, query): # pylint: disable=no-self-argument + if isinstance(query, list): + return [_validate_query(item) for item in query] + + return _validate_query(query) + + +class And(LogicalOperator): + op: List['Query'] = Field(None, alias='and') + + +class Or(LogicalOperator): + op: List['Query'] = Field(None, alias='or') + + +class Not(LogicalOperator): + op: 'Query' = Field(None, alias='not') + + +ops = { + 'lte': Lte, + 'lt': Lt, + 'gte': Gte, + 'gt': Gt, + 'all': All, + 'none': None_, + 'any': Any_ +} + + +QueryParameterValue = Union[Value, List[Value], Lte, Lt, Gte, Gt, Any_, All, None_] + +Query = Union[ + Dict[str, QueryParameterValue], And, Or, Not] + + +And.update_forward_refs() +Or.update_forward_refs() +Not.update_forward_refs() + + +class Owner(str, enum.Enum): + ''' + The `owner` allows to limit the scope of the searched based on entry ownership. + This is useful, if you only want to search among all publically downloadable + entries, or only among your own entries, etc. + + These are the possible owner values and their meaning: + * `all`: Consider all entries. + * `public` (default): Consider all entries that can be publically downloaded, + i.e. only published entries without embargo + * `user`: Only consider entries that belong to you. + * `shared`: Only consider entries that belong to you or are shared with you. + * `visible`: Consider all entries that are visible to you. This includes + entries with embargo or unpublished entries that belong to you or are + shared with you. 
+ * `staging`: Only search through unpublished entries. + ''' + + # There seems to be a slight bug in FastAPI. When it creates the example in OpenAPI + # it will ignore any given default or example and simply take the first enum value. + # Therefore, we put public first, which is the default and safe in most contexts. + public = 'public' + all_ = 'all' + visible = 'visible' + shared = 'shared' + user = 'user' + staging = 'staging' + admin = 'admin' + + +class WithQuery(BaseModel): + owner: Optional[Owner] = Body('public') + query: Optional[Query] = Body( + None, + embed=True, + description=strip(''' + A query can be a very simple list of parameters. Different parameters are combined + with a logical **and**, values of the same parameter are also combined with a logical **and**. + The following would search for all entries that are VASP calculations, + contain *Na* **and** *Cl*, **and** are authored by *Stefano Curtarolo* + **and** *Chris Wolverton*. + ``` + { + "atoms": ["Na", "Cl"], + "dft.code_name": "VASP", + "authors": ["Stefano Curtarolo", "Chris Wolverton"] + } + ``` + + A shortcut to change the logical combination of values in a list is to + add the suffix `:any` to the quantity: + ``` + { + "atoms": ["Na", "Cl"], + "dft.code_name": "VASP", + "authors:any": ["Stefano Curtarolo", "Chris Wolverton"] + } + ``` + + Otherwise, you can also write complex logical combinations of parameters like this: + ``` + { + "and": [ + { + "or": [ + { + "atoms": ["Cl", "Na"] + }, + { + "atoms": ["H", "O"] + } + ] + }, + { + "not": { + "dft.crystal": "cubic" + } + } + ] + } + ``` + Other shortcut suffixes are `:none` and `:any` (the default). + + By default, all quantity values have to **equal** the given values to match. For + some values you can also use comparison operators like this: + ``` + { + "upload_time": { + "gt": "2020-01-01", + "lt": "2020-08-01" + }, + "dft.workflow.section_geometry_optimization.final_energy_difference": { + "lte": 1.23e-18 + } + } + ``` + + or shorter with suffixes: + ``` + { + "upload_time:gt": "2020-01-01", + "upload_time:lt": "2020-08-01", + "dft.workflow.section_geometry_optimization.final_energy_difference:lte": 1.23e-18 + } + ``` + + The searchable quantities are a subset of the NOMAD Archive quantities defined + in the NOMAD Metainfo. The most common quantities are: %s. + ''' % ', '.join(reversed([ + '`%s`' % name + for name in search_quantities + if (name.startswith('dft') or '.'
not in name) and len(name) < 20 + ]))), + example={ + 'upload_time:gt': '2020-01-01', + 'atoms': ['Ti', 'O'], + 'dft.code_name': 'VASP', + 'dft.workflow.section_geometry_optimization.final_energy_difference:lte': 1.23e-18, + 'dft.quantities': 'section_dos', + 'dft.system:any': ['bulk', '2d'] + }) + + @validator('query') + def validate_query(cls, query): # pylint: disable=no-self-argument + return _validate_query(query) + + +def _validate_query(query: Query): + if isinstance(query, dict): + for key, value in query.items(): + if ':' in key: + quantity, qualifier = key.split(':') + else: + quantity, qualifier = key, None + + assert quantity in search_quantities, '%s is not a searchable quantity' % key + if qualifier is not None: + assert quantity not in query, 'a quantity can only appear once in a query' + assert qualifier in ops, 'unknown quantity qualifier %s' % qualifier + del(query[key]) + query[quantity] = ops[qualifier](**{qualifier: value}) # type: ignore + elif isinstance(value, list): + query[quantity] = All(all=value) + + return query + + +def query_parameters( + request: Request, + owner: Optional[Owner] = FastApiQuery( + 'public', description=strip(Owner.__doc__)), + q: Optional[List[str]] = FastApiQuery( + [], description=strip(''' + Since we cannot properly offer forms for all parameters in the OpenAPI dashboard, + you can use the parameter `q` and encode a query parameter like this + `atoms__H` or `n_atoms__gt__3`. Multiple usage of `q` will combine parameters with + logical *and*. + '''))) -> WithQuery: + + # copy parameters from request + query_params = { + key: request.query_params.getlist(key) + for key in request.query_params} + + # add the encoded parameters + for parameter in q: + fragments = parameter.split('__') + if len(fragments) == 1 or len(fragments) > 3: + raise HTTPException(422, detail=[{ + 'loc': ['query', 'q'], + 'msg': 'wrong format, use <quantity>[__<op>]__<value>'}]) + name_op, value = '__'.join(fragments[:-1]), fragments[-1] + quantity_name = name_op.split('__')[0] + + if quantity_name not in search_quantities: + raise HTTPException(422, detail=[{ + 'loc': ['query', parameter], + 'msg': '%s is not a search quantity' % quantity_name}]) + + query_params.setdefault(name_op, []).append(value) + + # transform query parameters to query + query: Dict[str, Any] = {} + for key in query_params: + op = None + if '__' in key: + quantity_name, op = key.split('__') + else: + quantity_name = key + + if quantity_name not in search_quantities: + continue + + quantity = search_quantities[quantity_name] + type_ = quantity.definition.type + if type_ is Datetime: + type_ = datetime.datetime.fromisoformat + elif isinstance(type_, MEnum): + type_ = str + elif isinstance(type_, np.dtype): + type_ = float + elif type_ not in [int, float, bool]: + type_ = str + values = query_params[key] + values = [type_(value) for value in values] + + if op is None: + if quantity.many_and: + op = 'all' + if quantity.many_or: + op = 'any' + + if op is None: + if len(values) > 1: + raise HTTPException( + status_code=422, + detail=[{ + 'loc': ['query', key], + 'msg':'search parameter %s does not support multiple values' % key}]) + query[quantity_name] = values[0] + + elif op == 'all': + query[quantity_name] = All(all=values) + elif op == 'any': + query[quantity_name] = Any_(any=values) + elif op in ops: + if len(values) > 1: + raise HTTPException( + status_code=422, + detail=[{ + 'loc': ['query', key], + 'msg': 'operator %s does not support multiple values' % op}]) + query[quantity_name] = 
ops[op](**{op: values[0]}) + else: + raise HTTPException( + 422, detail=[{'loc': ['query', key], 'msg': 'operator %s is unknown' % op}]) + + return WithQuery(query=query, owner=owner) + + +class Direction(str, enum.Enum): + ''' + Order direction, either ascending (`asc`) or descending (`desc`) + ''' + asc = 'asc' + desc = 'desc' + + +class MetadataRequired(BaseModel): + ''' Defines which metadata quantities are included or excluded in the response. ''' + + include: Optional[List[str]] = Field( + None, description=strip(''' + Quantities to include for each result. Only those quantities will be + returned. The entry id quantity `calc_id` will always be included. + ''')) + exclude: Optional[List[str]] = Field( + None, description=strip(''' + Quantities to exclude for each result. Only all other quantities will + be returned. The quantity `calc_id` cannot be excluded. + ''')) + + @validator('include', 'exclude') + def validate_include(cls, value, values, field): # pylint: disable=no-self-argument + if value is None: + return None + + for item in value: + assert item in search_quantities or item[-1] == '*', \ + 'required fields must be valid search quantities or contain wildcards' + + if field.name == 'include' and 'calc_id' not in value: + value.append('calc_id') + + if field.name == 'exclude': + if 'calc_id' in value: + value.remove('calc_id') + + return value + + +metadata_required_parameters = parameter_dependency_from_model( + 'metadata_required_parameters', MetadataRequired) + + +class Pagination(BaseModel): + ''' Defines the order, size, and page of results. ''' + + size: Optional[int] = Field( + 10, description=strip(''' + The page size, e.g. the maximum number of entries contained in one response. + A `size` of 0 will omit any results; this is useful, when there is only + interest in other data, e.g. `aggregations` or `statistics`. + ''')) + order_by: Optional[Quantity] = Field( + Quantity.calc_id, # type: ignore + description=strip(''' + The search results are ordered by the values of this quantity. The response + either contains the first `size` value or the next `size` values after `after`. + ''')) + order: Optional[Direction] = Field( + Direction.asc, description=strip(''' + The order direction of the search results based on `order_by`. Its either + ascending `asc` or decending `desc`. + ''')) + after: Optional[str] = Field( + None, description=strip(''' + A request for the page after this value, i.e. the next `size` values behind `after`. + This depends on the `order_by` and the potentially used aggregation. + Each response contains the `after` value for the *next* request following + the defined order. + + The after value and its type depends on the `order_by` quantity and its type. + The after value will always be a string encoded value. The after value will + also contain the entry id as a *tie breaker*, if + `order_by` is not the entry's id. The *tie breaker* will be `:` separated, e.g. + `<value>:<id>`. 
+ ''')) + + @validator('order_by') + def validate_order_by(cls, order_by): # pylint: disable=no-self-argument + if order_by is None: + return order_by + + assert order_by.value in search_quantities, 'order_by must be a valid search quantity' + quantity = search_quantities[order_by.value] + assert quantity.definition.is_scalar, 'the order_by quantity must be a scalar' + return order_by + + @validator('size') + def validate_size(cls, size): # pylint: disable=no-self-argument + assert size >= 0, 'size must be positive integer' + return size + + @validator('after') + def validate_after(cls, after, values): # pylint: disable=no-self-argument + order_by = values.get('order_by', Quantity.calc_id) + if after is not None and order_by is not None and order_by != Quantity.calc_id and ':' not in after: + after = '%s:' % after + return after + + +pagination_parameters = parameter_dependency_from_model( + 'pagination_parameters', Pagination) + + +class AggregationPagination(Pagination): + order_by: Optional[Quantity] = Field( + None, description=strip(''' + The search results are ordered by the values of this quantity. The response + either contains the first `size` value or the next `size` values after `after`. + ''')) + + +class AggregatedEntities(BaseModel): + size: Optional[pydantic.conint(gt=0)] = Field( # type: ignore + 1, description=strip(''' + The maximum number of entries that should be returned for each value in the + aggregation. + ''')) + required: Optional[MetadataRequired] = Field( + None, description=strip(''' + This allows to determined what fields should be returned for each entry. + ''')) + + +class Aggregation(BaseModel): + quantity: AggregateableQuantity = Field( + ..., description=strip(''' + The manatory name of the quantity for the aggregation. Aggregations + can only be computed for those search metadata that have discrete values; + an aggregation buckets entries that have the same value for this quantity.''')) + pagination: Optional[AggregationPagination] = Field( + AggregationPagination(), description=strip(''' + Only the data few values are returned for each API call. Pagination allows to + get the next set of values based on the last value in the last call. + ''')) + entries: Optional[AggregatedEntities] = Field( + None, description=strip(''' + Optionally, a set of entries can be returned for each value. + ''')) + + +class StatisticsOrder(BaseModel): + type_: Optional[AggregationOrderType] = Field(AggregationOrderType.entries, alias='type') + direction: Optional[Direction] = Field(Direction.desc) + + +class Statistic(BaseModel): + quantity: AggregateableQuantity = Field( + ..., description=strip(''' + The manatory name of the quantity that the statistic is calculated for. Statistics + can only be computed for those search metadata that have discrete values; a statistics + aggregates a certain metric (e.g. the number of entries) over all entries were + this quantity has the same value (bucket aggregation, think historgam here). + + There is one except and these are date/time values quantities (most notably `upload_time`). + Here each statistic value represents an time interval. The interval can + be determined via `datetime_interval`.''')) + metrics: Optional[List[Metric]] = Field( + [], description=strip(''' + By default the returned statistics will provide the number of entries for each + value. You can add more metrics. For each metric an additional number will be + provided for each value. Metrics are also based on search metadata. 
Depending on + the metric, the number will represent either a sum (`calculations` for the number + of individual calculations in each code run) or a count of different values + (i.e. `materials` for the number of different material hashes).''')) + datetime_interval: Optional[pydantic.conint(gt=0)] = Field( # type: ignore + None, description=strip(''' + While statistics in general are only possible for quantities with discrete values, + there is one exception: date/time valued quantities (most notably `upload_time`). + Here each statistic value represents a time interval. + + A date/time interval is a number of seconds greater than 0. This will only be used for + date/time valued quantities (e.g. `upload_time`). + ''')) + value_filter: Optional[pydantic.constr(regex=r'^[a-zA-Z0-9_\-\s]+$')] = Field( # type: ignore + None, description=strip(''' + An optional filter for values. Only values that contain the filter as a substring + will be part of the statistics. + ''')) + size: Optional[pydantic.conint(gt=0)] = Field( # type: ignore + None, description=strip(''' + An optional maximum number of values in the statistics. The default depends on the + quantity. + ''')) + order: Optional[StatisticsOrder] = Field( + StatisticsOrder(), description=strip(''' + The values in the statistics are either ordered by the entry count or by the + natural ordering of the values. + ''')) + + @root_validator(skip_on_failure=True) + def fill_default_size(cls, values): # pylint: disable=no-self-argument + if 'size' not in values or values['size'] is None: + values['size'] = search_quantities[values['quantity'].value].statistic_size + + return values + + +class WithQueryAndPagination(WithQuery): + pagination: Optional[Pagination] = Body( + None, + example={ + 'size': 5, + 'order_by': 'upload_time' + }) + + +class EntriesMetadata(WithQueryAndPagination): + required: Optional[MetadataRequired] = Body( + None, + example={ + 'include': ['calc_id', 'mainfile', 'upload_id', 'authors', 'upload_time'] + }) + statistics: Optional[Dict[str, Statistic]] = Body( + {}, + description=strip(''' + This allows to define additional statistics that should be returned. + Statistics aggregate entries that show the same quantity values for a given quantity. + A simple example is the number of entries for each `dft.code_name`. These statistics + will be computed only over the query results. This allows to get an overview of the + query results. + '''), + example={ + 'by_code_name': { + 'metrics': ['uploads', 'datasets'], + 'quantity': 'dft.code_name' + } + }) + aggregations: Optional[Dict[str, Aggregation]] = Body( + {}, + example={ + 'uploads': { + 'quantity': 'upload_id', + 'pagination': { + 'size': 10, + 'order_by': 'upload_time' + }, + 'entries': { + 'size': 1, + 'required': { + 'include': ['mainfile'] + } + } + } + }, + description=strip(''' + Defines additional aggregations to return. An aggregation lists entries + for the values of a quantity, e.g. to get all uploads and their entries. + ''')) + + +class Files(BaseModel): + ''' Configures the download of files. ''' + compress: Optional[bool] = Field( + False, description=strip(''' + By default the returned zip file is not compressed. This allows to enable compression. + Compression will reduce the rate at which data is provided, as it is often limited by + the rate of the compression itself.
Therefore, compression is only sensible if the + network connection is limited.''')) + glob_pattern: Optional[str] = Field( + None, description=strip(''' + An optional *glob* (or unix style path) pattern that is used to filter the + returned files. Only files matching the pattern are returned. The pattern is only + applied to the end of the full path. Internally + [fnmatch](https://docs.python.org/3/library/fnmatch.html) is used.''')) + re_pattern: Optional[str] = Field( + None, description=strip(''' + An optional regexp that is used to filter the returned files. Only files matching + the pattern are returned. The pattern is applied in search mode to the full + path of the files. With `$` and `^` you can control if you want to match the + whole path. + + A re pattern will replace a given glob pattern.''')) + + @validator('glob_pattern') + def validate_glob_pattern(cls, glob_pattern): # pylint: disable=no-self-argument + # compile the glob pattern into re + if glob_pattern is None: + return None + + return re.compile(fnmatch.translate(glob_pattern) + r'$') + + @validator('re_pattern') + def validate_re_pattern(cls, re_pattern): # pylint: disable=no-self-argument + # compile an re + if re_pattern is None: + return None + try: + return re.compile(re_pattern) + except re.error as e: + assert False, 'could not parse the re pattern: %s' % e + + @root_validator() + def vaildate(cls, values): # pylint: disable=no-self-argument + # use the compiled glob pattern as re + if values.get('re_pattern') is None: + values['re_pattern'] = values.get('glob_pattern') + return values + + +files_parameters = parameter_dependency_from_model( + 'files_parameters', Files) + + +ArchiveRequired = Union[str, Dict[str, Any]] + + +class EntriesArchive(WithQueryAndPagination): + required: Optional[ArchiveRequired] = Body( + '*', + embed=True, + description=strip(''' + The `required` part allows you to specify what parts of the requested archives + should be returned. The NOMAD Archive is a hierarchical data format and + you can *require* certain branches (i.e. *sections*) in the hierarchy. + By specifing certain sections with specific contents or all contents (via `"*"`), + you can determine what sections and what quantities should be returned. + The default is everything: `"*"`. + + For example to specify that you are only interested in the `section_metadata` + use: + + ``` + { + "section_run": "*" + } + ``` + + Or to only get the `energy_total` from each individual calculations, use: + ``` + { + "section_run": { + "section_single_configuration_calculation": { + "energy_total": "*" + } + } + } + ``` + + You can also request certain parts of a list, e.g. the last calculation: + ``` + { + "section_run": { + "section_single_configuration_calculation[-1]": "*" + } + } + ``` + + These required specifications are also very useful to get workflow results. + This works because we can use references (e.g. workflow to final result calculation) + and the API will resolve these references and return the respective data. 
+ For example just the total energy value and reduced formula from the resulting + calculation: + ``` + { + 'section_workflow': { + 'calculation_result_ref': { + 'energy_total': '*', + 'single_configuration_calculation_to_system_ref': { + 'chemical_composition_reduced': '*' + } + } + } + } + ``` + '''), + example={ + 'section_run': { + 'section_single_configuration_calculation[-1]': { + 'energy_total': '*' + }, + 'section_system[-1]': '*' + }, + 'section_metadata': '*' + }) + + +class EntriesArchiveDownload(WithQuery): + files: Optional[Files] = Body(None) + + +class EntriesRaw(WithQuery): + pagination: Optional[Pagination] = Body(None) + + +class EntriesRawDownload(WithQuery): + files: Optional[Files] = Body( + None, + example={ + 'glob_pattern': 'vasp*.xml*' + }) + + +class PaginationResponse(Pagination): + total: int = Field(..., description=strip(''' + The total number of entries that fit the given `query`. This is independent of + any pagination and aggregations. + ''')) + next_after: Optional[str] = Field(None, description=strip(''' + The *next* after value to be used as `after` in a follow up requests for the + next page of results. + ''')) + + +class StatisticResponse(Statistic): + data: Dict[str, Dict[str, int]] = Field( + None, description=strip(''' + The returned statistics data as dictionary. The key is a string representation of the values. + The concrete type depends on the quantity that was used to create the statistics. + Each dictionary value is a dictionary itself. The keys are the metric names the + values the metric values. The key `entries` that gives the amount of entries with + this value is always returned.''')) + + +class AggregationDataItem(BaseModel): + data: Optional[List[Dict[str, Any]]] = Field( + None, description=strip('''The entries that were requested for each value.''')) + size: int = Field( + None, description=strip('''The amount of entries with this value.''')) + + +class AggregationResponse(Aggregation): + pagination: PaginationResponse # type: ignore + data: Dict[str, AggregationDataItem] = Field( + None, description=strip(''' + The aggregation data as a dictionary. The key is a string representation of the values. + The dictionary values contain the aggregated data depending if `entries` where + requested.''')) + + +class CodeResponse(BaseModel): + curl: str + requests: str + nomad_lab: Optional[str] + + +class EntriesMetadataResponse(EntriesMetadata): + pagination: PaginationResponse + statistics: Optional[Dict[str, StatisticResponse]] # type: ignore + aggregations: Optional[Dict[str, AggregationResponse]] # type: ignore + data: List[Dict[str, Any]] = Field( + None, description=strip(''' + The entries data as a list. Each item is a dictionary with the metadata for each + entry.''')) + code: Optional[CodeResponse] + + +class EntryRawFile(BaseModel): + path: str = Field(None) + size: int = Field(None) + + +class EntryRaw(BaseModel): + calc_id: str = Field(None) + upload_id: str = Field(None) + mainfile: str = Field(None) + files: List[EntryRawFile] = Field(None) + + +class EntriesRawResponse(EntriesRaw): + pagination: PaginationResponse = Field(None) + data: List[EntryRaw] = Field(None) + + +class EntryMetadataResponse(BaseModel): + entry_id: str = Field(None) + required: MetadataRequired = Field(None) + data: Dict[str, Any] = Field( + None, description=strip('''A dictionary with the metadata of the requested entry.''')) + + +class EntryRawResponse(BaseModel): + entry_id: str = Field(...) + data: EntryRaw = Field(...) 
+ + +class EntryArchive(BaseModel): + calc_id: str = Field(None) + upload_id: str = Field(None) + parser_name: str = Field(None) + archive: Any = Field(None) + + +class EntriesArchiveResponse(EntriesArchive): + pagination: PaginationResponse = Field(None) + data: List[EntryArchive] = Field(None) + + +class EntryArchiveResponse(BaseModel): + entry_id: str = Field(...) + data: Dict[str, Any] + + +class SearchResponse(EntriesMetadataResponse): + es_query: Any = Field( + None, description=strip('''The elasticsearch query that was used to retrieve the results.''')) diff --git a/nomad/app_fastapi/routers/auth.py b/nomad/app_fastapi/routers/auth.py new file mode 100644 index 0000000000000000000000000000000000000000..3e5cd25072af77bbee32fb6a2fb3689e14e1894f --- /dev/null +++ b/nomad/app_fastapi/routers/auth.py @@ -0,0 +1,112 @@ +from fastapi import Depends, APIRouter, HTTPException, status +from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm +from pydantic import BaseModel + +from nomad import infrastructure +from nomad.utils import get_logger, strip +from nomad.app_fastapi.models import User, HTTPExceptionModel +from nomad.app_fastapi.utils import create_responses + +logger = get_logger(__name__) + +router = APIRouter() +default_tag = 'auth' + + +class Token(BaseModel): + access_token: str + token_type: str + + +oauth2_scheme = OAuth2PasswordBearer(tokenUrl='/api/v1/auth/token', auto_error=False) + + +async def get_optional_user(access_token: str = Depends(oauth2_scheme)) -> User: + ''' + A dependency that provides the authenticated (if credentials are available) or None. + ''' + if access_token is None: + return None + + try: + return User(**infrastructure.keycloak.tokenauth(access_token)) + except infrastructure.KeycloakError as e: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=str(e), headers={'WWW-Authenticate': 'Bearer'}) + + +async def get_required_user(user: User = Depends(get_optional_user)) -> User: + ''' + A dependency that provides the authenticated user or raises 401 if no user is + authenticated. + ''' + if user is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail='Authentication required', + headers={'WWW-Authenticate': 'Bearer'}) + + return user + + +_bad_credentials_response = status.HTTP_401_UNAUTHORIZED, { + 'model': HTTPExceptionModel, + 'description': strip(''' + Unauthorized. The provided credentials were not recognized.''')} + + +@router.post( + '/token', + tags=[default_tag], + summary='Get an access token', + responses=create_responses(_bad_credentials_response), + response_model=Token) +async def get_token(form_data: OAuth2PasswordRequestForm = Depends()): + ''' + This API uses OAuth as an authentication mechanism. This operation allows you to + retrieve an *access token* by posting username and password as form data. + + This token can be used on subsequent API calls to authenticate + you. Operations that support or require authentication will expect the *access token* + in an HTTP Authorization header like this: `Authorization: Bearer <access token>`. + + On the OpenAPI dashboard, you can use the *Authorize* button at the top. + + You only need to provide `username` and `password` values. You can ignore the other + parameters. 
+ ''' + + try: + access_token = infrastructure.keycloak.basicauth( + form_data.username, form_data.password) + except infrastructure.KeycloakError: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail='Incorrect username or password', + headers={'WWW-Authenticate': 'Bearer'}) + + return {'access_token': access_token, 'token_type': 'bearer'} + + +@router.get( + '/token', + tags=[default_tag], + summary='Get an access token', + responses=create_responses(_bad_credentials_response), + response_model=Token) +async def get_token_via_query(username: str, password: str): + ''' + This is an convenience alternative to the **POST** version of this operation. + It allows you to retrieve an *access token* by providing username and password. + ''' + + try: + access_token = infrastructure.keycloak.basicauth(username, password) + except infrastructure.KeycloakError: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail='Incorrect username or password', + headers={'WWW-Authenticate': 'Bearer'}) + + return {'access_token': access_token, 'token_type': 'bearer'} diff --git a/nomad/app_fastapi/routers/entries.py b/nomad/app_fastapi/routers/entries.py new file mode 100644 index 0000000000000000000000000000000000000000..df83e5b3cea59e073e9be564838fe05895810b5b --- /dev/null +++ b/nomad/app_fastapi/routers/entries.py @@ -0,0 +1,761 @@ +from typing import Dict, Iterator, Any, List, Set, cast +from fastapi import APIRouter, Depends, Path, status, HTTPException +from fastapi.responses import StreamingResponse +import os.path +import io +import json +import orjson + +from nomad import search, files, config, utils +from nomad.utils import strip +from nomad.archive import ( + query_archive, ArchiveQueryError, compute_required_with_referenced, + read_partial_archives_from_mongo, filter_archive) +from nomad.app_fastapi.utils import create_streamed_zipfile, File, create_responses +from nomad.app_fastapi.routers.auth import get_optional_user +from nomad.app_fastapi.models import ( + Pagination, WithQuery, MetadataRequired, EntriesMetadataResponse, EntriesMetadata, + EntryMetadataResponse, query_parameters, metadata_required_parameters, Files, Query, + pagination_parameters, files_parameters, User, Owner, HTTPExceptionModel, EntriesRaw, + EntriesRawResponse, EntriesRawDownload, EntryRaw, EntryRawFile, EntryRawResponse, + EntriesArchiveDownload, EntryArchiveResponse, EntriesArchive, EntriesArchiveResponse, + ArchiveRequired) + + +router = APIRouter() +default_tag = 'entries' +metadata_tag = 'entries/metadata' +raw_tag = 'entries/raw' +archive_tag = 'entries/archive' + +logger = utils.get_logger(__name__) + + +_bad_owner_response = status.HTTP_401_UNAUTHORIZED, { + 'model': HTTPExceptionModel, + 'description': strip(''' + Unauthorized. The given owner requires authorization, + but no or bad authentication credentials are given.''')} + +_bad_id_response = status.HTTP_404_NOT_FOUND, { + 'model': HTTPExceptionModel, + 'description': strip(''' + Entry not found. The given id does not match any entry.''')} + +_raw_download_response = 200, { + 'content': {'application/zip': {}}, + 'description': strip(''' + A zip file with the requested raw files. The file is streamed. + The content length is not known in advance. + ''')} + +_archive_download_response = 200, { + 'content': {'application/zip': {}}, + 'description': strip(''' + A zip file with the requested archive files. The file is streamed. + The content length is not known in advance. 
+ ''')} + + +_bad_archive_required_response = status.HTTP_400_BAD_REQUEST, { + 'model': HTTPExceptionModel, + 'description': strip(''' + The given required specification could not be understood.''')} + + +def perform_search(*args, **kwargs): + try: + return search.search(*args, **kwargs) + except search.AuthenticationRequiredError as e: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=str(e)) + except search.ElasticSearchError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail='Elasticsearch could not process your query: %s' % str(e)) + + +@router.post( + '/query', tags=['entries/metadata'], + summary='Search entries and retrieve their metadata', + response_model=EntriesMetadataResponse, + responses=create_responses(_bad_owner_response), + response_model_exclude_unset=True, + response_model_exclude_none=True) +async def post_entries_metadata_query( + data: EntriesMetadata, + user: User = Depends(get_optional_user)): + + ''' + Executes a *query* and returns a *page* of the results with *required* result data + as well as *statistics* and *aggregated* data. + + This is the basic search operation to retrieve metadata for entries that match + certain search criteria (`query` and `owner`). All parameters (including `query`, `owner`) + are optional. Look at the body schema or parameter documentation for more details. + + By default the *empty* search (that returns everything) is performed. Only a small + page of the search results are returned at a time; use `pagination` in subsequent + requests to retrive more data. Each entry has a lot of different *metadata*, use + `required` to limit the data that is returned. + + The `statistics` and `aggregations` keys will further allow to return statistics + and aggregated data over all search results. + ''' + + return perform_search( + owner=data.owner, + query=data.query, + pagination=data.pagination, + required=data.required, + statistics=data.statistics, + aggregations=data.aggregations, + user_id=user.user_id if user is not None else None) + + +@router.get( + '', tags=[metadata_tag], + summary='Search entries and retrieve their metadata', + response_model=EntriesMetadataResponse, + responses=create_responses(_bad_owner_response), + response_model_exclude_unset=True, + response_model_exclude_none=True) +async def get_entries_metadata( + with_query: WithQuery = Depends(query_parameters), + pagination: Pagination = Depends(pagination_parameters), + required: MetadataRequired = Depends(metadata_required_parameters), + user: User = Depends(get_optional_user)): + ''' + Executes a *query* and returns a *page* of the results with *required* result data. + This is a version of `/entries/query`. Queries work a little different, because + we cannot put complex queries into URL parameters. + + In addition to the `q` parameter (see parameter documentation for details), you can use all NOMAD + search quantities as parameters, e.g. `?atoms=H&atoms=O`. Those quantities can be + used with additional operators attached to their names, e.g. `?n_atoms__gte=3` for + all entries with more than 3 atoms. Operators are `all`, `any`, `none`, `gte`, + `gt`, `lt`, `lte`. 
+ ''' + + return perform_search( + owner=with_query.owner, query=with_query.query, + pagination=pagination, required=required, + user_id=user.user_id if user is not None else None) + + +def _do_exaustive_search(owner: Owner, query: Query, include: List[str], user: User) -> Iterator[Dict[str, Any]]: + after = None + while True: + response = perform_search( + owner=owner, query=query, + pagination=Pagination(size=100, after=after, order_by='upload_id'), + required=MetadataRequired(include=include), + user_id=user.user_id if user is not None else None) + + after = response.pagination.next_after + + for result in response.data: + yield result + + if after is None or len(response.data) == 0: + break + + +class _Uploads(): + ''' + A helper class that caches subsequent access to upload files the same upload. + ''' + def __init__(self): + self._upload_files = None + + def get_upload_files(self, upload_id: str) -> files.UploadFiles: + if self._upload_files is not None and self._upload_files.upload_id != upload_id: + self._upload_files.close() + + if self._upload_files is None or self._upload_files.upload_id != upload_id: + self._upload_files = files.UploadFiles.get( + upload_id, is_authorized=lambda *args, **kwargs: True) + + return self._upload_files + + def close(self): + if self._upload_files is not None: + self._upload_files.close() + + +def _create_entry_raw(entry_metadata: Dict[str, Any], uploads: _Uploads): + calc_id = entry_metadata['calc_id'] + upload_id = entry_metadata['upload_id'] + mainfile = entry_metadata['mainfile'] + + upload_files = uploads.get_upload_files(upload_id) + mainfile_dir = os.path.dirname(mainfile) + + files = [] + for file_name, file_size in upload_files.raw_file_list(directory=mainfile_dir): + path = os.path.join(mainfile_dir, file_name) + files.append(EntryRawFile(path=path, size=file_size)) + + return EntryRaw(calc_id=calc_id, upload_id=upload_id, mainfile=mainfile, files=files) + + +def _answer_entries_raw_request( + owner: Owner, query: Query, pagination: Pagination, user: User): + + if owner == Owner.all_: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=strip(''' + The owner=all is not allowed for this operation as it will search for entries + that you might now be allowed to access. + ''')) + + search_response = perform_search( + owner=owner, query=query, + pagination=pagination, + required=MetadataRequired(include=['calc_id', 'upload_id', 'mainfile']), + user_id=user.user_id if user is not None else None) + + uploads = _Uploads() + try: + response_data = [ + _create_entry_raw(entry_metadata, uploads) + for entry_metadata in search_response.data] + finally: + uploads.close() + + return EntriesRawResponse( + owner=search_response.owner, + query=search_response.query, + pagination=search_response.pagination, + data=response_data) + + +def _answer_entries_raw_download_request(owner: Owner, query: Query, files: Files, user: User): + if owner == Owner.all_: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=strip(''' + The owner=all is not allowed for this operation as it will search for entries + that you might now be allowed to access. 
+ ''')) + + response = perform_search( + owner=owner, query=query, + pagination=Pagination(size=0), + required=MetadataRequired(include=[]), + user_id=user.user_id if user is not None else None) + + if response.pagination.total > config.max_entry_download: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail='The limit of maximum number of entries in a single download (%d) has been exeeded (%d).' % ( + config.max_entry_download, response.pagination.total)) + + uploads = _Uploads() + files_params = Files() if files is None else files + manifest = [] + search_includes = ['calc_id', 'upload_id', 'mainfile'] + streamed_paths: Set[str] = set() + + try: + # a generator of File objects to create the streamed zip from + def raw_file_generator(): + # go through all entries that match the query + for entry_metadata in _do_exaustive_search(owner, query, include=search_includes, user=user): + upload_id = entry_metadata['upload_id'] + mainfile = entry_metadata['mainfile'] + + upload_files = uploads.get_upload_files(upload_id) + mainfile_dir = os.path.dirname(mainfile) + + # go through all files that belong to this entry + all_filtered = True + files = upload_files.raw_file_list(directory=mainfile_dir) + for file_name, file_size in files: + path = os.path.join(mainfile_dir, file_name) + + # apply the filter + if files_params.re_pattern is not None and not files_params.re_pattern.search(path): + continue + all_filtered = False + + # add upload_id to path used in streamed zip + streamed_path = os.path.join(upload_id, path) + + # check if already streamed + if streamed_path in streamed_paths: + continue + streamed_paths.add(streamed_path) + + # yield the file + with upload_files.raw_file(path, 'rb') as f: + yield File(path=streamed_path, f=f, size=file_size) + + if not all_filtered or len(files) == 0: + entry_metadata['mainfile'] = os.path.join(upload_id, mainfile) + manifest.append(entry_metadata) + + # add the manifest at the end + manifest_content = json.dumps(manifest).encode() + yield File(path='manifest.json', f=io.BytesIO(manifest_content), size=len(manifest_content)) + + # create the streaming response with zip file contents + content = create_streamed_zipfile(raw_file_generator(), compress=files_params.compress) + return StreamingResponse(content, media_type='application/zip') + finally: + uploads.close() + + +_entries_raw_query_docstring = strip(''' + Will perform a search and return a *page* of raw file metadata for entries fulfilling + the query. This allows you to get a complete list of all rawfiles with their full + path in their respective upload and their sizes. The first returned files for each + entry, is their respective *mainfile*. + + Each entry on NOMAD represents a set of raw files. These are the input and output + files (as well as additional auxiliary files) in their original form, i.e. as + provided by the uploader. More specifically, an entry represents a code-run identified + by a certain *mainfile*. This is usually the main output file of the code. All other + files in the same directory are considered the entries *auxiliary* no matter their role + or if they were actually parsed by NOMAD. + + This operation supports the usual `owner`, `query`, and `pagination` parameters. 
+ ''') + + +@router.post( + '/raw/query', + tags=[raw_tag], + summary='Search entries and get their raw files metadata', + description=_entries_raw_query_docstring, + response_model=EntriesRawResponse, + responses=create_responses(_bad_owner_response), + response_model_exclude_unset=True, + response_model_exclude_none=True) +async def post_entries_raw_query(data: EntriesRaw, user: User = Depends(get_optional_user)): + + return _answer_entries_raw_request( + owner=data.owner, query=data.query, pagination=data.pagination, user=user) + + +@router.get( + '/raw', + tags=[raw_tag], + summary='Search entries and get raw their raw files metadata', + description=_entries_raw_query_docstring, + response_model=EntriesRawResponse, + response_model_exclude_unset=True, + response_model_exclude_none=True, + responses=create_responses(_bad_owner_response)) +async def get_entries_raw( + with_query: WithQuery = Depends(query_parameters), + pagination: Pagination = Depends(pagination_parameters), + user: User = Depends(get_optional_user)): + + return _answer_entries_raw_request( + owner=with_query.owner, query=with_query.query, pagination=pagination, user=user) + + +_entries_raw_download_query_docstring = strip(''' + This operation will perform a search and stream a .zip file with raw input and output + files of the found entries. + + Each entry on NOMAD represents a set of raw files. These are the input and output + files (as well as additional auxiliary files) in their original form, i.e. as + provided by the uploader. More specifically, an entry represents a code-run identified + by a certain *mainfile*. This is usually the main output file of the code. All other + files in the same directory are considered the entries *auxiliary* no matter their role + or if they were actually parsed by NOMAD. + + After performing a search (that uses the same parameters as in all search operations), + NOMAD will iterate through all results and create a .zip-file with all the entries' + main and auxiliary files. The files will be organized in the same directory structure + that they were uploaded in. The respective upload root directories are further prefixed + with the `upload_id` of the respective uploads. The .zip-file will further contain + a `manifest.json` with `upload_id`, `calc_id`, and `mainfile` of each entry. 
+    ''')
+
+
+@router.post(
+    '/raw/download/query',
+    tags=[raw_tag],
+    summary='Search entries and download their raw files',
+    description=_entries_raw_download_query_docstring,
+    response_class=StreamingResponse,
+    responses=create_responses(_raw_download_response, _bad_owner_response))
+async def post_entries_raw_download_query(
+        data: EntriesRawDownload, user: User = Depends(get_optional_user)):
+
+    return _answer_entries_raw_download_request(
+        owner=data.owner, query=data.query, files=data.files, user=user)
+
+
+@router.get(
+    '/raw/download',
+    tags=[raw_tag],
+    summary='Search entries and download their raw files',
+    description=_entries_raw_download_query_docstring,
+    response_class=StreamingResponse,
+    responses=create_responses(_raw_download_response, _bad_owner_response))
+async def get_entries_raw_download(
+        with_query: WithQuery = Depends(query_parameters),
+        files: Files = Depends(files_parameters),
+        user: User = Depends(get_optional_user)):
+
+    return _answer_entries_raw_download_request(
+        owner=with_query.owner, query=with_query.query, files=files, user=user)
+
+
+def _read_archive(entry_metadata, uploads, required):
+    calc_id = entry_metadata['calc_id']
+    upload_id = entry_metadata['upload_id']
+    upload_files = uploads.get_upload_files(upload_id)
+
+    try:
+        with upload_files.read_archive(calc_id) as archive:
+            return {
+                'calc_id': calc_id,
+                'parser_name': entry_metadata['parser_name'],
+                'archive': query_archive(archive, {calc_id: required})[calc_id]
+            }
+    except ArchiveQueryError as e:
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
+
+
+def _answer_entries_archive_request(
+        owner: Owner, query: Query, pagination: Pagination, required: ArchiveRequired,
+        user: User):
+
+    if owner == Owner.all_:
+        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=strip('''
+            The owner=all is not allowed for this operation as it will search for entries
+            that you might not be allowed to access.
+            '''))
+
+    if required is None:
+        required = '*'
+
+    try:
+        required_with_references = compute_required_with_referenced(required)
+    except KeyError as e:
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=(
+            'The required specification contains an unknown quantity or section: %s' % str(e)))
+
+    search_response = perform_search(
+        owner=owner, query=query,
+        pagination=pagination,
+        required=MetadataRequired(include=['calc_id', 'upload_id', 'parser_name']),
+        user_id=user.user_id if user is not None else None)
+
+    if required_with_references is not None:
+        # We can produce all the required archive data from the partial archives stored
+        # in mongodb.
+        entry_ids = [entry['calc_id'] for entry in search_response.data]
+        partial_archives = cast(dict, read_partial_archives_from_mongo(entry_ids, as_dict=True))
+
+    uploads = _Uploads()
+    response_data = {}
+    for entry_metadata in search_response.data:
+        calc_id, upload_id = entry_metadata['calc_id'], entry_metadata['upload_id']
+
+        archive_data = None
+        if required_with_references is not None:
+            try:
+                partial_archive = partial_archives[calc_id]
+                archive_data = filter_archive(required, partial_archive, transform=lambda e: e)
+            except KeyError:
+                # the partial archive might not exist, e.g. due to processing problems
+                pass
+            except ArchiveQueryError as e:
+                detail = 'The required specification could not be understood: %s' % str(e)
+                raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=detail)
+
+        if archive_data is None:
+            try:
+                archive_data = _read_archive(entry_metadata, uploads, required)['archive']
+            except KeyError as e:
+                logger.error('missing archive', exc_info=e, calc_id=calc_id)
+                continue
+
+        response_data[calc_id] = {
+            'calc_id': calc_id,
+            'upload_id': upload_id,
+            'parser_name': entry_metadata['parser_name'],
+            'archive': archive_data}
+
+    uploads.close()
+
+    return EntriesArchiveResponse(
+        owner=search_response.owner,
+        query=search_response.query,
+        pagination=search_response.pagination,
+        data=list(response_data.values()))
+
+
+_entries_archive_docstring = strip('''
+    This operation will perform a search with the given `query` and `owner` and return
+    a *page* of `required` archive data. Look at the body schema or parameter documentation
+    for more details. The **GET** version of this operation will only return full archives.
+    ''')
+
+
+@router.post(
+    '/archive/query',
+    tags=[archive_tag],
+    summary='Search entries and access their archives',
+    description=_entries_archive_docstring,
+    response_model=EntriesArchiveResponse,
+    response_model_exclude_unset=True,
+    response_model_exclude_none=True,
+    responses=create_responses(_bad_owner_response, _bad_archive_required_response))
+async def post_entries_archive_query(
+        data: EntriesArchive, user: User = Depends(get_optional_user)):
+
+    return _answer_entries_archive_request(
+        owner=data.owner, query=data.query, pagination=data.pagination,
+        required=data.required, user=user)
+
+
+@router.get(
+    '/archive',
+    tags=[archive_tag],
+    summary='Search entries and access their archives',
+    description=_entries_archive_docstring,
+    response_model=EntriesArchiveResponse,
+    response_model_exclude_unset=True,
+    response_model_exclude_none=True,
+    responses=create_responses(_bad_owner_response, _bad_archive_required_response))
+async def get_entries_archive_query(
+        with_query: WithQuery = Depends(query_parameters),
+        pagination: Pagination = Depends(pagination_parameters),
+        user: User = Depends(get_optional_user)):
+
+    return _answer_entries_archive_request(
+        owner=with_query.owner, query=with_query.query, pagination=pagination,
+        required=None, user=user)
+
+
+def _answer_entries_archive_download_request(
+        owner: Owner, query: Query, files: Files, user: User):
+
+    if owner == Owner.all_:
+        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=strip('''
+            The owner=all is not allowed for this operation as it will search for entries
+            that you might not be allowed to access.
+            '''))
+
+    files_params = Files() if files is None else files
+
+    response = perform_search(
+        owner=owner, query=query,
+        pagination=Pagination(size=0),
+        required=MetadataRequired(include=[]),
+        user_id=user.user_id if user is not None else None)
+
+    if response.pagination.total > config.max_entry_download:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=(
+                'The limit of maximum number of entries in a single download (%d) has been '
+                'exceeded (%d).'
% (config.max_entry_download, response.pagination.total))) + + uploads = _Uploads() + manifest = [] + search_includes = ['calc_id', 'upload_id', 'parser_name'] + + # a generator of File objects to create the streamed zip from + def file_generator(): + # go through all entries that match the query + for entry_metadata in _do_exaustive_search(owner, query, include=search_includes, user=user): + path = os.path.join(entry_metadata['upload_id'], '%s.json' % entry_metadata['calc_id']) + try: + archive_data = _read_archive(entry_metadata, uploads, '*') + + f = io.BytesIO(orjson.dumps( + archive_data, option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS)) + + yield File(path=path, f=f, size=f.getbuffer().nbytes) + except KeyError as e: + logger.error('missing archive', calc_id=entry_metadata['calc_id'], exc_info=e) + + entry_metadata['path'] = path + manifest.append(entry_metadata) + + # add the manifest at the end + manifest_content = json.dumps(manifest).encode() + yield File(path='manifest.json', f=io.BytesIO(manifest_content), size=len(manifest_content)) + + try: + # create the streaming response with zip file contents + content = create_streamed_zipfile(file_generator(), compress=files_params.compress) + return StreamingResponse(content, media_type='application/zip') + finally: + uploads.close() + + +_entries_archive_download_docstring = strip(''' + This operation will perform a search with the given `query` and `owner` and stream + a .zip-file with the full archive contents for all matching entries. This is not + paginated. Look at the body schema or parameter documentation for more details. + ''') + + +@router.post( + '/archive/download/query', + tags=[archive_tag], + summary='Search entries and download their archives', + description=_entries_archive_download_docstring, + response_class=StreamingResponse, + responses=create_responses( + _archive_download_response, _bad_owner_response, _bad_archive_required_response)) +async def post_entries_archive_download_query( + data: EntriesArchiveDownload, user: User = Depends(get_optional_user)): + + return _answer_entries_archive_download_request( + owner=data.owner, query=data.query, files=data.files, user=user) + + +@router.get( + '/archive/download', + tags=[archive_tag], + summary='Search entries and download their archives', + description=_entries_archive_download_docstring, + response_class=StreamingResponse, + responses=create_responses( + _archive_download_response, _bad_owner_response, _bad_archive_required_response)) +async def get_entries_archive_download( + with_query: WithQuery = Depends(query_parameters), + files: Files = Depends(files_parameters), + user: User = Depends(get_optional_user)): + + return _answer_entries_archive_download_request( + owner=with_query.owner, query=with_query.query, files=files, user=user) + + +@router.get( + '/{entry_id}', tags=[metadata_tag], + summary='Get the metadata of an entry by its id', + response_model=EntryMetadataResponse, + responses=create_responses(_bad_id_response), + response_model_exclude_unset=True, + response_model_exclude_none=True) +async def get_entry_metadata( + entry_id: str = Path(..., description='The unique entry id of the entry to retrieve metadata from.'), + required: MetadataRequired = Depends(metadata_required_parameters), + user: User = Depends(get_optional_user)): + ''' + Retrives the entry metadata for the given id. 
+ ''' + + query = {'calc_id': entry_id} + response = perform_search(owner=Owner.all_, query=query, required=required, user_id=user.user_id if user is not None else None) + + if response.pagination.total == 0: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail='The entry with the given id does not exist or is not visible to you.') + + return { + 'entry_id': entry_id, + 'required': required, + 'data': response.data[0] + } + + +@router.get( + '/{entry_id}/raw', + tags=[raw_tag], + summary='Get the raw files metadata for an entry by its id', + response_model=EntryRawResponse, + responses=create_responses(_bad_id_response), + response_model_exclude_unset=True, + response_model_exclude_none=True) +async def get_entry_raw( + entry_id: str = Path(..., description='The unique entry id of the entry to retrieve raw data from.'), + files: Files = Depends(files_parameters), + user: User = Depends(get_optional_user)): + ''' + Returns the file metadata for all input and output files (including auxiliary files) + of the given `entry_id`. The first file will be the *mainfile*. + ''' + query = dict(calc_id=entry_id) + response = perform_search( + owner=Owner.visible, query=query, + required=MetadataRequired(include=['calc_id', 'upload_id', 'mainfile']), + user_id=user.user_id if user is not None else None) + + if response.pagination.total == 0: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail='The entry with the given id does not exist or is not visible to you.') + + uploads = _Uploads() + try: + return EntryRawResponse(entry_id=entry_id, data=_create_entry_raw(response.data[0], uploads)) + finally: + uploads.close() + + +@router.get( + '/{entry_id}/raw/download', + tags=[raw_tag], + summary='Get the raw data of an entry by its id', + response_class=StreamingResponse, + responses=create_responses(_bad_id_response, _raw_download_response)) +async def get_entry_raw_download( + entry_id: str = Path(..., description='The unique entry id of the entry to retrieve raw data from.'), + files: Files = Depends(files_parameters), + user: User = Depends(get_optional_user)): + ''' + Streams a .zip file with the raw files from the requested entry. + ''' + query = dict(calc_id=entry_id) + response = perform_search( + owner=Owner.visible, query=query, + required=MetadataRequired(include=['calc_id']), + user_id=user.user_id if user is not None else None) + + if response.pagination.total == 0: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail='The entry with the given id does not exist or is not visible to you.') + + return _answer_entries_raw_download_request(owner='public', query=query, files=files, user=user) + + +@router.get( + '/{entry_id}/archive', + tags=[archive_tag], + summary='Get the archive for an entry by its id', + response_model=EntryArchiveResponse, + response_model_exclude_unset=True, + response_model_exclude_none=True, + responses=create_responses(_bad_id_response)) +async def get_entry_archive( + entry_id: str = Path(..., description='The unique entry id of the entry to retrieve raw data from.'), + user: User = Depends(get_optional_user)): + ''' + Returns the full archive for the given `entry_id`. 
+    '''
+    query = dict(calc_id=entry_id)
+    response = perform_search(
+        owner=Owner.visible, query=query,
+        required=MetadataRequired(include=['calc_id', 'upload_id', 'parser_name']),
+        user_id=user.user_id if user is not None else None)
+
+    if response.pagination.total == 0:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail='The entry with the given id does not exist or is not visible to you.')
+
+    uploads = _Uploads()
+    try:
+        try:
+            archive_data = _read_archive(response.data[0], uploads, required='*')
+        except KeyError:
+            raise HTTPException(
+                status_code=status.HTTP_404_NOT_FOUND,
+                detail='The entry with the given id does exist, but it has no archive.')
+
+        return {
+            'entry_id': entry_id,
+            'data': archive_data['archive']}
+    finally:
+        uploads.close()
diff --git a/nomad/app_fastapi/routers/users.py b/nomad/app_fastapi/routers/users.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ac94b2020db938f20f922cae4b60f1efb638302
--- /dev/null
+++ b/nomad/app_fastapi/routers/users.py
@@ -0,0 +1,27 @@
+from fastapi import Depends, APIRouter, status
+
+from nomad.app_fastapi.routers.auth import get_required_user
+from nomad.app_fastapi.models import User, HTTPExceptionModel
+from nomad.app_fastapi.utils import create_responses
+from nomad.utils import strip
+
+router = APIRouter()
+default_tag = 'users'
+
+
+_authentication_required_response = status.HTTP_401_UNAUTHORIZED, {
+    'model': HTTPExceptionModel,
+    'description': strip('''
+        Unauthorized. The operation requires authorization,
+        but no or bad authentication credentials are given.''')}
+
+
+@router.get(
+    '/me',
+    tags=[default_tag],
+    summary='Get your account data',
+    description='Returns the account data of the authenticated user.',
+    responses=create_responses(_authentication_required_response),
+    response_model=User)
+async def read_users_me(current_user: User = Depends(get_required_user)):
+    return current_user
diff --git a/nomad/app_fastapi/utils.py b/nomad/app_fastapi/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5b07f9ad875229eea2fc098e64503c07fe93f53
--- /dev/null
+++ b/nomad/app_fastapi/utils.py
@@ -0,0 +1,101 @@
+from typing import Dict, Iterator, Any
+from types import FunctionType
+import sys
+import inspect
+from fastapi import Query, HTTPException  # pylint: disable=unused-import
+from pydantic import ValidationError, BaseModel  # pylint: disable=unused-import
+import zipstream
+
+if sys.version_info >= (3, 7):
+    import zipfile
+else:
+    import zipfile37 as zipfile  # pragma: no cover
+
+
+def parameter_dependency_from_model(name: str, model_cls):
+    '''
+    Takes a pydantic model class as input and creates a dependency with corresponding
+    Query parameter definitions that can be used for GET requests.
+
+    This will only work if the fields defined in the input model can be turned into
+    suitable query parameters. Otherwise fastapi will complain down the road.
+
+    Arguments:
+        name: Name for the dependency function.
+        model_cls: A ``BaseModel`` inheriting model class as input.
+ ''' + names = [] + annotations: Dict[str, type] = {} + defaults = [] + for field_model in model_cls.__fields__.values(): + field_info = field_model.field_info + + names.append(field_model.name) + annotations[field_model.name] = field_model.outer_type_ + defaults.append(Query(field_model.default, description=field_info.description)) + + code = inspect.cleandoc(''' + def %s(%s): + try: + return %s(%s) + except ValidationError as e: + errors = e.errors() + for error in errors: + error['loc'] = ['query'] + list(error['loc']) + raise HTTPException(422, detail=errors) + + ''' % ( + name, ', '.join(names), model_cls.__name__, + ', '.join(['%s=%s' % (name, name) for name in names]))) + + compiled = compile(code, 'string', 'exec') + env = {model_cls.__name__: model_cls} + env.update(**globals()) + func = FunctionType(compiled.co_consts[0], env, name) + func.__annotations__ = annotations + func.__defaults__ = (*defaults,) + + return func + + +class File(BaseModel): + path: str + f: Any + size: int + + +def create_streamed_zipfile( + files: Iterator[File], + compress: bool = False) -> Iterator[bytes]: + + ''' + Creates a streaming zipfile object that can be used in fastapi's ``StreamingResponse``. + ''' + + def path_to_write_generator(): + for file_obj in files: + def content_generator(): + while True: + data = file_obj.f.read(1024 * 64) + if not data: + break + yield data + + yield dict( + arcname=file_obj.path, + iterable=content_generator(), + buffer_size=file_obj.size) + + compression = zipfile.ZIP_DEFLATED if compress else zipfile.ZIP_STORED + zip_stream = zipstream.ZipFile(mode='w', compression=compression, allowZip64=True) + zip_stream.paths_to_write = path_to_write_generator() + + for chunk in zip_stream: + yield chunk + + +def create_responses(*args): + return { + status_code: response + for status_code, response in args} diff --git a/nomad/archive.py b/nomad/archive.py index 31193d754c85bddfd291adfd1fd562798c01713a..c802854ac4aabe2b3e84add129e10fb355034534 100644 --- a/nomad/archive.py +++ b/nomad/archive.py @@ -812,7 +812,7 @@ def delete_partial_archives_from_mongo(entry_ids: List[str]): def read_partial_archives_from_mongo(entry_ids: List[str], as_dict=False) -> Dict[str, Union[EntryArchive, Dict]]: ''' - Reads the partial archives for a set of entries of the same upload. + Reads the partial archives for a set of entries. Arguments: entry_ids: A list of entry ids. 
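For illustration, a minimal usage sketch of the function documented in the hunk above (assuming it is imported from nomad.archive, where it is defined, and using hypothetical entry ids):

    from nomad.archive import read_partial_archives_from_mongo

    entry_ids = ['id_01', 'id_02']  # hypothetical entry ids
    # returns a dict that maps each entry id to its partial archive (plain dicts with as_dict=True)
    partial_archives = read_partial_archives_from_mongo(entry_ids, as_dict=True)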
@@ -866,7 +866,7 @@ def compute_required_with_referenced(required): # TODO this function should be based on the metainfo if not isinstance(required, dict): - return required + return None if any(key.startswith('section_run') for key in required): return None diff --git a/nomad/config.py b/nomad/config.py index 3226bd9bd9f333695675d47393d1c593f1816f8e..934151dac372df0b6b2137be2f6395729159831e 100644 --- a/nomad/config.py +++ b/nomad/config.py @@ -149,7 +149,7 @@ logstash = NomadConfig( services = NomadConfig( api_host='localhost', api_port=8000, - api_base_path='/fairdi/nomad/latest', + api_prefix='/', api_secret='defaultApiSecret', api_chaos=0, admin_user_id='00000000-0000-0000-0000-000000000000', @@ -177,7 +177,7 @@ def api_url(ssl: bool = True, api: str = 'api'): base_url = '%s://%s/%s' % ( 'https' if services.https and ssl else 'http', services.api_host.strip('/'), - services.api_base_path.strip('/')) + services.api_prefix.strip('/')) return '%s/%s' % (base_url.strip('/'), api) @@ -305,6 +305,7 @@ parser_matching_size = 150 * 80 # 150 lines of 80 ASCII characters per line console_log_level = logging.WARNING max_upload_size = 32 * (1024 ** 3) raw_file_strip_cutoff = 1000 +max_entry_download = 500000 use_empty_parsers = False reprocess_unmatched = True metadata_file_name = 'nomad' diff --git a/nomad/datamodel/datamodel.py b/nomad/datamodel/datamodel.py index b2fefd3ce8b53993ebe04bdab328c5f297887bff..95c22a56bb9ac17add57e3cf34a91a0806ae809a 100644 --- a/nomad/datamodel/datamodel.py +++ b/nomad/datamodel/datamodel.py @@ -580,3 +580,8 @@ class EntryArchive(metainfo.MSection): processing_logs = metainfo.Quantity( type=Any, shape=['0..*'], description='The processing logs for this entry as a list of structlog entries.') + + +# preemptively create the elasticsearch document definition, which populates metrics and +# search quantities in the search_extension +EntryMetadata.m_def.a_elastic.document diff --git a/nomad/infrastructure.py b/nomad/infrastructure.py index 4778a37fb9a3c869be52d6f35481ae0aab9df364..8812c95cc740c1e8aa6c89a60e1320b3c33d5e68 100644 --- a/nomad/infrastructure.py +++ b/nomad/infrastructure.py @@ -23,6 +23,7 @@ is run once for each *api* and *worker* process. Individual functions for partia exist to facilitate testing, aspects of :py:mod:`nomad.cli`, etc. ''' +from typing import Dict, Any import os.path import os import shutil @@ -134,6 +135,9 @@ def setup_elastic(create_mappings=True): return elastic_client +class KeycloakError(Exception): pass + + class Keycloak(): ''' A class that encapsulates all keycloak related functions for easier mocking and @@ -171,6 +175,63 @@ class Keycloak(): return self.__public_keys + def basicauth(self, username: str, password: str) -> str: + ''' + Performs basic authentication and returns an access token. + + Raises: + KeycloakError + ''' + try: + token_info = self._oidc_client.token(username=username, password=password) + except KeycloakAuthenticationError as e: + raise KeycloakError(e) + except Exception as e: + logger.error('cannot perform basicauth', exc_info=e) + raise e + + return token_info['access_token'] + + def tokenauth(self, access_token: str) -> Dict[str, Any]: + ''' + Authenticates the given token and returns the user record. + + Raises: + KeycloakError + ''' + try: + kid = jwt.get_unverified_header(access_token)['kid'] + key = keycloak._public_keys.get(kid) + if key is None: + logger.error('The user provided keycloak public key does not exist. 
Does the UI use the right realm?') + raise KeycloakError(utils.strip(''' + Could not validate credentials. + The user provided keycloak public key does not exist. + Does the UI use the right realm?''')) + + options = dict(verify_aud=False, verify_exp=True, verify_iss=True) + payload = jwt.decode( + access_token, key=key, algorithms=['RS256'], options=options, + issuer='%s/realms/%s' % (config.keycloak.server_url.rstrip('/'), config.keycloak.realm_name)) + + user_id: str = payload.get('sub') + if user_id is None: + raise KeycloakError(utils.strip(''' + Could not validate credentials. + The given token does not contain a user_id.''')) + + return dict( + user_id=user_id, + email=payload.get('email', None), + first_name=payload.get('given_name', None), + last_name=payload.get('family_name', None)) + + except jwt.InvalidTokenError: + raise KeycloakError('Could not validate credentials. The given token is invalid.') + except Exception as e: + logger.error('cannot perform tokenauth', exc_info=e) + raise e + def authorize_flask(self, basic: bool = True) -> str: ''' Authorizes the current flask request with keycloak. Uses either Bearer or Basic diff --git a/nomad/metainfo/elastic_extension.py b/nomad/metainfo/elastic_extension.py index c616b2651c4874ed4932020383cfea307eda9c8a..ac031498477d24d6772fe90da7f7f793b107aa32 100644 --- a/nomad/metainfo/elastic_extension.py +++ b/nomad/metainfo/elastic_extension.py @@ -256,6 +256,8 @@ class Elastic(DefinitionAnnotation): index: A boolean that indicates if this quantity should be indexed or merely be part of the elastic document ``_source`` without being indexed for search. + aggregateable: + A boolean that determines, if this quantity can be used in aggregations ''' def __init__( self, @@ -287,3 +289,7 @@ class Elastic(DefinitionAnnotation): self.qualified_field = field else: self.qualified_field = '%s.%s' % (prefix, field) + + @property + def aggregateable(self): + return self.mapping is None or self.mapping.__class__.__name__ == 'Keyword' diff --git a/nomad/metainfo/metainfo.py b/nomad/metainfo/metainfo.py index 4be3d354bbc911e48b13e785be498348ed4e5fc5..222f257448980979d0efb87a9c482db2f3c02bfc 100644 --- a/nomad/metainfo/metainfo.py +++ b/nomad/metainfo/metainfo.py @@ -1694,7 +1694,13 @@ class MSection(metaclass=MObjectMeta): # TODO find a way to make this a subclas sub_section.m_copy(deep=True, parent=copy) for sub_section in self.m_get_sub_sections(sub_section_def)] - copy.__dict__[sub_section_def.name] = sub_sections_copy + if sub_section_def.repeats: + copy.__dict__[sub_section_def.name] = sub_sections_copy + else: + if len(sub_sections_copy) == 1: + copy.__dict__[sub_section_def.name] = sub_sections_copy[0] + else: + copy.__dict__[sub_section_def.name] = None return cast(MSectionBound, copy) diff --git a/nomad/metainfo/search_extension.py b/nomad/metainfo/search_extension.py index db6da1778087877669caa482c9152f8d54a1e707..3c3b2b82633287173d71a696be0626920b6b1b94 100644 --- a/nomad/metainfo/search_extension.py +++ b/nomad/metainfo/search_extension.py @@ -19,6 +19,7 @@ from typing import Callable, Any, Dict, List, DefaultDict from collections import defaultdict +from nomad import config from nomad.metainfo.elastic_extension import Elastic @@ -41,6 +42,12 @@ order_default_quantities_by_index: DefaultDict[str, Dict[str, 'Search']] = defau ''' The quantity for each domain (key) that is the default quantity to order search results by. 
''' +search_quantities = search_quantities_by_index[config.elastic.index_name] +groups = groups_by_index[config.elastic.index_name] +metrics = metrics_by_index[config.elastic.index_name] +order_default_quantities = order_default_quantities_by_index[config.elastic.index_name] + + # TODO multi, split are more flask related class Search(Elastic): ''' diff --git a/nomad/search.py b/nomad/search.py index 7d544a9128f44c0107d1974466f0ba8c332473cc..6d112401d38e80e510c657aca6197432c783fd93 100644 --- a/nomad/search.py +++ b/nomad/search.py @@ -23,17 +23,18 @@ This module represents calculations in elastic search. from typing import Iterable, Dict, List, Any from elasticsearch_dsl import Search, Q, A, analyzer, tokenizer import elasticsearch.helpers -from elasticsearch.exceptions import NotFoundError +from elasticsearch.exceptions import NotFoundError, RequestError from datetime import datetime import json from nomad.datamodel.material import Material from nomad import config, datamodel, infrastructure, utils -from nomad.metainfo.search_extension import search_quantities_by_index, metrics_by_index, order_default_quantities_by_index, groups_by_index -search_quantities = search_quantities_by_index[config.elastic.index_name] -groups = groups_by_index[config.elastic.index_name] -metrics = metrics_by_index[config.elastic.index_name] -order_default_quantities = order_default_quantities_by_index[config.elastic.index_name] +from nomad.metainfo.search_extension import ( # pylint: disable=unused-import + search_quantities, metrics, order_default_quantities, groups) +from nomad.app_fastapi import models as api_models +from nomad.app_fastapi.models import ( + Pagination, PaginationResponse, Query, MetadataRequired, SearchResponse, Aggregation, + Statistic, StatisticResponse, AggregationOrderType, AggregationResponse, AggregationDataItem) path_analyzer = analyzer( @@ -47,6 +48,9 @@ class AlreadyExists(Exception): pass class ElasticSearchError(Exception): pass +class AuthenticationRequiredError(Exception): pass + + class ScrollIdNotFound(Exception): pass @@ -114,6 +118,43 @@ def refresh(): infrastructure.elastic_client.indices.refresh(config.elastic.index_name) +def _owner_es_query(owner: str, user_id: str = None): + if owner == 'all': + q = Q('term', published=True) + if user_id is not None: + q = q | Q('term', owners__user_id=user_id) + elif owner == 'public': + q = Q('term', published=True) & Q('term', with_embargo=False) + elif owner == 'visible': + q = Q('term', published=True) & Q('term', with_embargo=False) + if user_id is not None: + q = q | Q('term', owners__user_id=user_id) + elif owner == 'shared': + if user_id is None: + raise AuthenticationRequiredError('Authentication required for owner value shared.') + + q = Q('term', owners__user_id=user_id) + elif owner == 'user': + if user_id is None: + raise AuthenticationRequiredError('Authentication required for owner value user.') + + q = Q('term', uploader__user_id=user_id) + elif owner == 'staging': + if user_id is None: + raise AuthenticationRequiredError('Authentication required for owner value user') + q = Q('term', published=False) & Q('term', owners__user_id=user_id) + elif owner == 'admin': + if user_id is None or not datamodel.User.get(user_id=user_id).is_admin: + raise AuthenticationRequiredError('This can only be used by the admin user.') + q = None + else: + raise KeyError('Unsupported owner value') + + if q is not None: + return q + return Q() + + class SearchRequest: ''' Represents a search request and allows to execute that request. 
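A minimal sketch of how the owner filter added above can be used (assuming a running Elasticsearch with the configured index and a hypothetical user id; owner values such as 'shared', 'user', 'staging', and 'admin' raise AuthenticationRequiredError when no user id is given):

    from elasticsearch_dsl import Search
    from nomad import config
    from nomad.search import _owner_es_query

    # 'visible': published, non-embargoed entries plus everything owned by the given user
    q = _owner_es_query(owner='visible', user_id='some-user-id')  # hypothetical user id
    results = Search(index=config.elastic.index_name).query(q).execute()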
@@ -177,40 +218,7 @@ class SearchRequest: ValueError: If the owner_type requires a user but none is given, or the given user is not allowed to use the given owner_type. ''' - if owner_type == 'all': - q = Q('term', published=True) - if user_id is not None: - q = q | Q('term', owners__user_id=user_id) - elif owner_type == 'public': - q = Q('term', published=True) & Q('term', with_embargo=False) - elif owner_type == 'visible': - q = Q('term', published=True) & Q('term', with_embargo=False) - if user_id is not None: - q = q | Q('term', owners__user_id=user_id) - elif owner_type == 'shared': - if user_id is None: - raise ValueError('Authentication required for owner value shared.') - - q = Q('term', owners__user_id=user_id) - elif owner_type == 'user': - if user_id is None: - raise ValueError('Authentication required for owner value user.') - - q = Q('term', uploader__user_id=user_id) - elif owner_type == 'staging': - if user_id is None: - raise ValueError('Authentication required for owner value user') - q = Q('term', published=False) & Q('term', owners__user_id=user_id) - elif owner_type == 'admin': - if user_id is None or not datamodel.User.get(user_id=user_id).is_admin: - raise ValueError('This can only be used by the admin user.') - q = None - else: - raise KeyError('Unsupported owner value') - - if q is not None: - self.q = self.q & q - + self.q &= _owner_es_query(owner=owner_type, user_id=user_id) return self def search_parameters(self, **kwargs): @@ -448,7 +456,7 @@ class SearchRequest: examples: Number of results to return that has each value order_by: - A sortable quantity that should be used to order. The max of each + A sortable quantity that should be used to order. By default, the max of each value bucket is used. order: "desc" or "asc" @@ -535,7 +543,7 @@ class SearchRequest: if order == 1: search = search.sort(order_by_quantity.search_field) else: - search = search.sort('-%s' % order_by_quantity.search_field) + search = search.sort('-%s' % order_by_quantity.search_field) # pylint: disable=no-member search = search.params(preserve_order=True) @@ -564,7 +572,7 @@ class SearchRequest: if order == 1: search = search.sort(order_by_quantity.search_field) else: - search = search.sort('-%s' % order_by_quantity.search_field) + search = search.sort('-%s' % order_by_quantity.search_field) # pylint: disable=no-member search = search[(page - 1) * per_page: page * per_page] es_result = search.execute() @@ -807,3 +815,305 @@ def flat(obj, prefix=None): return result else: return obj + + +def _api_to_es_query(query: api_models.Query) -> Q: + ''' + Creates an ES query based on the API's query model. This needs to be a normalized + query expression with explicit objects for logical, set, and comparison operators. + Shorthand notations ala ``quantity:operator`` are not supported here; this + needs to be resolved via the respective pydantic validator. There is also no + validation of quantities and types. + ''' + def quantity_to_es(name: str, value: api_models.Value) -> Q: + # TODO depends on keyword or not, value might need normalization, etc. 
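+        # For example (illustrative, assuming a keyword mapped quantity such as 'dft.code_name'):
+        # quantity_to_es('dft.code_name', 'VASP') resolves the search field via search_quantities
+        # and yields Q('match', **{'dft.code_name': 'VASP'}).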
+ quantity = search_quantities[name] + return Q('match', **{quantity.search_field: value}) + + def parameter_to_es(name: str, value: api_models.QueryParameterValue) -> Q: + + if isinstance(value, api_models.All): + return Q('bool', must=[ + quantity_to_es(name, item) + for item in value.op]) + + if isinstance(value, api_models.Any_): + return Q('bool', should=[ + quantity_to_es(name, item) + for item in value.op]) + + if isinstance(value, api_models.None_): + return Q('bool', must_not=[ + quantity_to_es(name, item) + for item in value.op]) + + if isinstance(value, api_models.ComparisonOperator): + quantity = search_quantities[name] + return Q('range', **{quantity.search_field: { + type(value).__name__.lower(): value.op}}) + + # list of values is treated as an "all" over the items + if isinstance(value, list): + return Q('bool', must=[ + quantity_to_es(name, item) + for item in value]) + + return quantity_to_es(name, value) + + def query_to_es(query: api_models.Query) -> Q: + if isinstance(query, api_models.LogicalOperator): + if isinstance(query, api_models.And): + return Q('bool', must=[query_to_es(operand) for operand in query.op]) + + if isinstance(query, api_models.Or): + return Q('bool', should=[query_to_es(operand) for operand in query.op]) + + if isinstance(query, api_models.Not): + return Q('bool', must_not=query_to_es(query.op)) + + raise NotImplementedError() + + if not isinstance(query, dict): + raise NotImplementedError() + + # dictionary is like an "and" of all items in the dict + if len(query) == 0: + return Q() + + if len(query) == 1: + key = next(iter(query)) + return parameter_to_es(key, query[key]) + + return Q('bool', must=[ + parameter_to_es(name, value) for name, value in query.items()]) + + return query_to_es(query) + + +def _api_to_es_statistic(es_search: Search, name: str, statistic: Statistic) -> A: + ''' + Creates an ES aggregation based on the API's statistic model. + ''' + + quantity = search_quantities[statistic.quantity.value] + if quantity.statistic_values is not None: + statistic.size = len(quantity.statistic_values) + + terms_kwargs = {} + if statistic.value_filter is not None: + terms_kwargs['include'] = '.*%s.*' % statistic.value_filter + + order_type = '_count' if statistic.order.type_ == AggregationOrderType.entries else '_key' + statistic_agg = es_search.aggs.bucket('statistic:%s' % name, A( + 'terms', + field=quantity.search_field, + size=statistic.size, + order={order_type: statistic.order.direction.value}, + **terms_kwargs)) + + for metric in statistic.metrics: + metric_quantity = metrics[metric.value] + statistic_agg.metric('metric:%s' % metric_quantity.metric_name, A( + metric_quantity.metric, + field=metric_quantity.search_field)) + + +def _es_to_api_statistics(es_response, name: str, statistic: Statistic) -> StatisticResponse: + ''' + Creates a StatisticResponse from elasticsearch response on a request executed with + the given statistics. 
+ ''' + quantity = search_quantities[statistic.quantity.value] + + es_statistic = es_response.aggs['statistic:' + name] + statistic_data = {} + for bucket in es_statistic.buckets: + value_data = dict(entries=bucket.doc_count) + for metric in statistic.metrics: + value_data[metric.value] = bucket['metric:' + metric.value].value + statistic_data[bucket.key] = value_data + + if quantity.statistic_values is not None: + for value in quantity.statistic_values: + if value not in statistic_data: + statistic_data[value] = dict(entries=0, **{ + metric.value: 0 for metric in statistic.metrics}) + + return StatisticResponse(data=statistic_data, **statistic.dict(by_alias=True)) + + +def _api_to_es_aggregation(es_search: Search, name: str, agg: Aggregation) -> A: + ''' + Creates an ES aggregation based on the API's aggregation model. + ''' + quantity = search_quantities[agg.quantity.value] + terms = A('terms', field=quantity.search_field, order=agg.pagination.order.value) + + # We are using elastic searchs 'composite aggregations' here. We do not really + # compose aggregations, but only those pseudo composites allow us to use the + # 'after' feature that allows to scan through all aggregation values. + order_by = agg.pagination.order_by + if order_by is None: + composite = dict(sources={name: terms}, size=agg.pagination.size) + else: + order_quantity = search_quantities[order_by.value] + sort_terms = A('terms', field=order_quantity.search_field, order=agg.pagination.order.value) + composite = dict(sources=[{order_by.value: sort_terms}, {quantity.name: terms}], size=agg.pagination.size) + + if agg.pagination.after is not None: + if order_by is None: + composite['after'] = {name: agg.pagination.after} + else: + order_value, quantity_value = agg.pagination.after.split(':') + composite['after'] = {quantity.name: quantity_value, order_quantity.name: order_value} + + composite_agg = es_search.aggs.bucket('agg:%s' % name, 'composite', **composite) + + if agg.entries is not None and agg.entries.size > 0: + kwargs: Dict[str, Any] = {} + if agg.entries.required is not None: + if agg.entries.required.include is not None: + kwargs.update(_source=dict(includes=agg.entries.required.include)) + else: + kwargs.update(_source=dict(excludes=agg.entries.required.exclude)) + + composite_agg.metric('entries', A('top_hits', size=agg.entries.size, **kwargs)) + + # additional cardinality to get total + es_search.aggs.metric('agg:%s:total' % name, 'cardinality', field=quantity.search_field) + + +def _es_to_api_aggregation(es_response, name: str, agg: Aggregation) -> AggregationResponse: + ''' + Creates a AggregationResponse from elasticsearch response on a request executed with + the given aggregation. 
+ ''' + order_by = agg.pagination.order_by + quantity = search_quantities[agg.quantity.value] + es_agg = es_response.aggs['agg:' + name] + + def get_entries(agg): + if 'entries' in agg: + return [item['_source'] for item in agg.entries.hits.hits] + else: + return None + + if agg.pagination.order_by is None: + agg_data = { + bucket.key[name]: AggregationDataItem(size=bucket.doc_count, data=get_entries(bucket)) + for bucket in es_agg.buckets} + else: + agg_data = { + bucket.key[quantity.search_field]: AggregationDataItem(size=bucket.doc_count, data=get_entries(bucket)) + for bucket in es_agg.buckets} + + aggregation_dict = agg.dict(by_alias=True) + pagination = PaginationResponse( + total=es_response.aggs['agg:%s:total' % name]['value'], + **aggregation_dict.pop('pagination')) + + if 'after_key' in es_agg: + after_key = es_agg['after_key'] + if order_by is None: + pagination.next_after = after_key[name] + else: + pagination.next_after = ':'.join(after_key.to_dict().values()) + + return AggregationResponse(data=agg_data, pagination=pagination, **aggregation_dict) + + +def search( + owner: str = 'public', + query: Query = None, + pagination: Pagination = None, + required: MetadataRequired = None, + aggregations: Dict[str, Aggregation] = {}, + statistics: Dict[str, Statistic] = {}, + user_id: str = None) -> SearchResponse: + + # The first half of this method creates the ES query. Then the query is run on ES. + # The second half is about transforming the ES response to a SearchResponse. + + # query and owner + if query is None: + query = {} + es_query = _api_to_es_query(query) + es_query &= _owner_es_query(owner=owner, user_id=user_id) + + # pagination + if pagination is None: + pagination = Pagination() + + search = Search(index=config.elastic.index_name) + + search = search.query(es_query) + order_field = search_quantities[pagination.order_by.value].search_field + sort = {order_field: pagination.order.value} + if order_field != 'calc_id': + sort['calc_id'] = pagination.order.value + search = search.sort(sort) + search = search.extra(size=pagination.size) + if pagination.after: + search = search.extra(search_after=pagination.after.rsplit(':', 1)) + + # required + if required: + if required.include is not None and pagination.order_by.value not in required.include: + required.include.append(pagination.order_by.value) + if required.exclude is not None and pagination.order_by.value in required.exclude: + required.exclude.remove(pagination.order_by.value) + search = search.source(includes=required.include, excludes=required.exclude) + + # statistics + for name, statistic in statistics.items(): + _api_to_es_statistic(search, name, statistic) + + # aggregations + for name, agg in aggregations.items(): + _api_to_es_aggregation(search, name, agg) + + # execute + try: + es_response = search.execute() + except RequestError as e: + raise ElasticSearchError(e) + + more_response_data = {} + + # pagination + next_after = None + if 0 < len(es_response.hits) < es_response.hits.total: + last = es_response.hits[-1] + if order_field == 'calc_id': + next_after = last['calc_id'] + else: + after_value = last + for order_field_segment in order_field.split('.'): + after_value = after_value[order_field_segment] + next_after = '%s:%s' % (after_value, last['calc_id']) + pagination_response = PaginationResponse( + total=es_response.hits.total, + next_after=next_after, + **pagination.dict()) + + # statistics + if len(statistics) > 0: + more_response_data['statistics'] = { + name: _es_to_api_statistics(es_response, name, 
statistic) + for name, statistic in statistics.items()} + + # aggregations + if len(aggregations) > 0: + more_response_data['aggregations'] = { + name: _es_to_api_aggregation(es_response, name, aggregation) + for name, aggregation in aggregations.items()} + + more_response_data['es_query'] = es_query.to_dict() + + return SearchResponse( + owner=owner, + query=query, + pagination=pagination_response, + required=required, + data=[hit.to_dict() for hit in es_response.hits], + **more_response_data) diff --git a/nomad/utils/__init__.py b/nomad/utils/__init__.py index e3d464adb4e13ab16669b15031645bad79edcb3a..d30e0c4dadf2310e22c3481a0d97ce7a8502ff97 100644 --- a/nomad/utils/__init__.py +++ b/nomad/utils/__init__.py @@ -35,6 +35,7 @@ Depending on the configuration all logs will also be send to a central logstash. .. autofunc::nomad.utils.create_uuid .. autofunc::nomad.utils.timer .. autofunc::nomad.utils.lnr +.. autofunc::nomad.utils.strip ''' from typing import List, Iterable @@ -50,6 +51,7 @@ import sys from datetime import timedelta import collections import logging +import inspect from nomad import config @@ -427,3 +429,8 @@ class RestrictedDict(OrderedDict): hash_str = json.dumps(self, sort_keys=True) return hash(hash_str) + + +def strip(docstring): + ''' Removes any unnecessary whitespaces from a multiline doc string or description. ''' + return inspect.cleandoc(docstring) diff --git a/ops/docker-compose/nomad-oasis/.gitignore b/ops/docker-compose/nomad-oasis/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..0967d032c37204860c0e46bd394cff0a00653d6f --- /dev/null +++ b/ops/docker-compose/nomad-oasis/.gitignore @@ -0,0 +1 @@ +nomad.yaml \ No newline at end of file diff --git a/ops/docker-compose/nomad-oasis/README.md b/ops/docker-compose/nomad-oasis/README.md index a4a819f707adb4f374a10c08d2edb5ea5d8672a3..51fecbd2f3c468c0d273c27dc9b300dedf7cab3a 100644 --- a/ops/docker-compose/nomad-oasis/README.md +++ b/ops/docker-compose/nomad-oasis/README.md @@ -183,7 +183,7 @@ client: url: 'http://<your-host>/nomad-oasis/api' services: - api_base_path: '/nomad-oasis' + api_prefix: '/nomad-oasis' admin_user_id: '<your admin user id>' keycloak: @@ -214,25 +214,27 @@ proxy is an nginx server and needs a configuration similar to this: ```none server { listen 80; - server_name <your-host>; + server_name www.example.com; proxy_set_header Host $host; - location / { + location /nomad-oasis/ { + rewrite ^/nomad-oasis/(.*) /$1 break; proxy_pass http://app:8000; } location ~ /nomad-oasis\/?(gui)?$ { - rewrite ^ /nomad-oasis/gui/ permanent; + rewrite ^ /gui/ permanent; } location /nomad-oasis/gui/ { proxy_intercept_errors on; error_page 404 = @redirect_to_index; + rewrite ^/nomad-oasis/(.*) /$1 break; proxy_pass http://app:8000; } location @redirect_to_index { - rewrite ^ /nomad-oasis/gui/index.html break; + rewrite ^ /gui/index.html break; proxy_pass http://app:8000; } @@ -242,23 +244,27 @@ server { if_modified_since off; expires off; etag off; + rewrite ^/nomad-oasis/(.*) /$1 break; proxy_pass http://app:8000; } location ~ \/api\/uploads\/?$ { client_max_body_size 35g; proxy_request_buffering off; + rewrite ^/nomad-oasis/(.*) /$1 break; proxy_pass http://app:8000; } location ~ \/api\/(raw|archive) { proxy_buffering off; + rewrite ^/nomad-oasis/(.*) /$1 break; proxy_pass http://app:8000; } location ~ \/api\/mirror { proxy_buffering off; proxy_read_timeout 600; + rewrite ^/nomad-oasis/(.*) /$1 break; proxy_pass http://app:8000; } } diff --git a/ops/docker-compose/nomad-oasis/nginx.conf 
b/ops/docker-compose/nomad-oasis/nginx.conf index 1c51a256e3d3edbd27fd49a5e0374d95f15a5e2b..5e608c96ad4ba792a50fb7ec4da509bfeafc3040 100644 --- a/ops/docker-compose/nomad-oasis/nginx.conf +++ b/ops/docker-compose/nomad-oasis/nginx.conf @@ -3,22 +3,24 @@ server { server_name www.example.com; proxy_set_header Host $host; - location / { + location /nomad-oasis/ { + rewrite ^/nomad-oasis/(.*) /$1 break; proxy_pass http://app:8000; } location ~ /nomad-oasis\/?(gui)?$ { - rewrite ^ /nomad-oasis/gui/ permanent; + rewrite ^ /gui/ permanent; } location /nomad-oasis/gui/ { proxy_intercept_errors on; error_page 404 = @redirect_to_index; + rewrite ^/nomad-oasis/(.*) /$1 break; proxy_pass http://app:8000; } location @redirect_to_index { - rewrite ^ /nomad-oasis/gui/index.html break; + rewrite ^ /gui/index.html break; proxy_pass http://app:8000; } @@ -28,23 +30,27 @@ server { if_modified_since off; expires off; etag off; + rewrite ^/nomad-oasis/(.*) /$1 break; proxy_pass http://app:8000; } location ~ \/api\/uploads\/?$ { client_max_body_size 35g; proxy_request_buffering off; + rewrite ^/nomad-oasis/(.*) /$1 break; proxy_pass http://app:8000; } - location ~ \/api\/(raw|archive) { + location ~ \/api.*\/(raw|archive) { proxy_buffering off; + rewrite ^/nomad-oasis/(.*) /$1 break; proxy_pass http://app:8000; } location ~ \/api\/mirror { proxy_buffering off; proxy_read_timeout 600; + rewrite ^/nomad-oasis/(.*) /$1 break; proxy_pass http://app:8000; } diff --git a/ops/helm/nomad/templates/api-deployment.yaml b/ops/helm/nomad/templates/api-deployment.yaml index 1755b029e5b7f998784d104e35bec970d1515bed..61d0ecb568385e745df3de6c8565c1584ae7d0f2 100644 --- a/ops/helm/nomad/templates/api-deployment.yaml +++ b/ops/helm/nomad/templates/api-deployment.yaml @@ -67,10 +67,6 @@ metadata: data: gunicorn.conf: | secure_scheme_headers = {'X-FORWARDED-PROTOCOL': 'ssl', 'X-FORWARDED-PROTO': 'https', 'X-FORWARDED-SSL': 'on'} - {{ if ne .Values.app.workerClass "sync" }} - worker_class = '{{ .Values.app.workerClass }}' - threads = {{ .Values.app.threads }} - {{ end }} worker_connections = 1000 workers = {{ .Values.app.worker }} --- @@ -226,14 +222,14 @@ spec: command: ["./run.sh"] livenessProbe: httpGet: - path: "{{ .Values.proxy.external.path }}/alive" + path: "/alive" port: 8000 initialDelaySeconds: 30 periodSeconds: 30 timeoutSeconds: 5 readinessProbe: httpGet: - path: "{{ .Values.proxy.external.path }}/alive" + path: "/alive" port: 8000 initialDelaySeconds: 15 periodSeconds: 15 diff --git a/ops/helm/nomad/templates/gui-deployment.yml b/ops/helm/nomad/templates/gui-deployment.yml index d014dc9ac1c717adf853e7eee6e5686bc50fde87..846b938f180e61a95ce0ba9e2acba8349d6aaa6b 100644 --- a/ops/helm/nomad/templates/gui-deployment.yml +++ b/ops/helm/nomad/templates/gui-deployment.yml @@ -36,43 +36,49 @@ data: {{ end }} location / { + rewrite ^{{ .Values.proxy.external.path }}/(.*) /$1 break; proxy_pass http://{{ include "nomad.fullname" . }}-app:8000; } location ~ {{ .Values.proxy.external.path }}\/?(gui)?$ { - rewrite ^ {{ .Values.proxy.external.path }}/gui/ permanent; + rewrite ^ /gui/ permanent; } location {{ .Values.proxy.external.path }}/gui/ { proxy_intercept_errors on; error_page 404 = @redirect_to_index; + rewrite ^{{ .Values.proxy.external.path }}/(.*) /$1 break; proxy_pass http://{{ include "nomad.fullname" . }}-app:8000; } location @redirect_to_index { - rewrite ^ {{ .Values.proxy.external.path }}/gui/index.html break; + rewrite ^ /gui/index.html break; proxy_pass http://{{ include "nomad.fullname" . 
}}-app:8000; } location {{ .Values.proxy.external.path }}/docs/ { proxy_intercept_errors on; error_page 404 = @redirect_to_index_docs; + rewrite ^{{ .Values.proxy.external.path }}/(.*) /$1 break; proxy_pass http://{{ include "nomad.fullname" . }}-app:8000; } location @redirect_to_index_docs { rewrite ^ {{ .Values.proxy.external.path }}/docs/index.html break; + rewrite ^{{ .Values.proxy.external.path }}/(.*) /$1 break; proxy_pass http://{{ include "nomad.fullname" . }}-app:8000; } location {{ .Values.proxy.external.path }}/encyclopedia/ { proxy_intercept_errors on; error_page 404 = @redirect_to_encyclopedia_index; + rewrite ^{{ .Values.proxy.external.path }}/(.*) /$1 break; proxy_pass http://{{ include "nomad.fullname" . }}-app:8000; } location @redirect_to_encyclopedia_index { rewrite ^ {{ .Values.proxy.external.path }}/encyclopedia/index.html break; + rewrite ^{{ .Values.proxy.external.path }}/(.*) /$1 break; proxy_pass http://{{ include "nomad.fullname" . }}-app:8000; } @@ -82,29 +88,34 @@ data: if_modified_since off; expires off; etag off; + rewrite ^{{ .Values.proxy.external.path }}/(.*) /$1 break; proxy_pass http://{{ include "nomad.fullname" . }}-app:8000; } location ~ \/api\/uploads\/?$ { client_max_body_size 35g; proxy_request_buffering off; + rewrite ^{{ .Values.proxy.external.path }}/(.*) /$1 break; proxy_pass http://{{ include "nomad.fullname" . }}-app:8000; } - location ~ \/api\/(raw|archive) { + location ~ \/api.*\/(raw|archive) { proxy_buffering off; + rewrite ^{{ .Values.proxy.external.path }}/(.*) /$1 break; proxy_pass http://{{ include "nomad.fullname" . }}-app:8000; } location ~ \/api\/mirror { proxy_buffering off; proxy_read_timeout {{ .Values.proxy.mirrorTimeout }}; + rewrite ^{{ .Values.proxy.external.path }}/(.*) /$1 break; proxy_pass http://{{ include "nomad.fullname" . }}-app:8000; } location ~ \/api\/repo\/edit { proxy_buffering off; proxy_read_timeout {{ .Values.proxy.editTimeout }}; + rewrite ^{{ .Values.proxy.external.path }}/(.*) /$1 break; proxy_pass http://{{ include "nomad.fullname" . }}-app:8000; } } diff --git a/ops/helm/nomad/templates/nomad-configmap.yml b/ops/helm/nomad/templates/nomad-configmap.yml index 4c47631f03bd885566552d388a5efe8b811c1235..fe1062ff7141ed7ff5d65827cd12bd1059fcd42d 100644 --- a/ops/helm/nomad/templates/nomad-configmap.yml +++ b/ops/helm/nomad/templates/nomad-configmap.yml @@ -27,7 +27,7 @@ data: services: api_host: "{{ .Values.proxy.external.host }}" api_port: {{ .Values.proxy.external.port }} - api_base_path: "{{ .Values.proxy.external.path }}" + api_prefix: "{{ .Values.proxy.external.path }}" api_secret: "{{ .Values.api.secret }}" https: {{ .Values.proxy.external.https }} upload_limit: {{ .Values.api.uploadLimit }} diff --git a/ops/helm/nomad/values.yaml b/ops/helm/nomad/values.yaml index 309de30709ca4aa9f33e0478ae76ccbfb6f3ea38..9db2353300e43d16970847368eac02f72b66cb6f 100644 --- a/ops/helm/nomad/values.yaml +++ b/ops/helm/nomad/values.yaml @@ -44,10 +44,6 @@ app: replicas: 1 ## Number of gunicorn worker. 
worker: 10 - ## Number of threads per gunicorn worker (for async workerClass) - threads: 4 - ## Gunircon worker class (http://docs.gunicorn.org/en/stable/settings.html#worker-class) - workerClass: 'gthread' console_loglevel: INFO logstash_loglevel: INFO nomadNodeType: "public" diff --git a/ops/tests/ping.py b/ops/tests/ping.py index ae57c4c2045e86c18e20c4b6029bf3b3c04e2473..73fc1762c8fda87cd4bd734a781af0650ad831a1 100644 --- a/ops/tests/ping.py +++ b/ops/tests/ping.py @@ -39,6 +39,6 @@ while True: atoms)) end = time.time() print('PING – %s – %f - %s' % (response.status_code, end - start, datetime.now())) - time.sleep(1) + time.sleep(5) except Exception as e: print('ERROR – %s' % e) diff --git a/requirements.txt b/requirements.txt index c887f77d82f894cf07413d9cfdbfa08ec19ba0b3..225f73a3d5ec17739ef158fadefba762298f5159 100644 --- a/requirements.txt +++ b/requirements.txt @@ -76,6 +76,9 @@ python-json-logger recommonmark jinja2 rdflib +fastapi +uvicorn[standard] +python-multipart # [dev] setuptools @@ -98,3 +101,4 @@ names essential_generators twine python-gitlab +devtools diff --git a/run.sh b/run.sh index 4517aebb8e8ed544821fe315d288412ef19890e5..fd69c688f67d4b15d0d9e41cc4a195634cae26f9 100644 --- a/run.sh +++ b/run.sh @@ -3,4 +3,4 @@ python -m nomad.cli admin ops gui-config params=() [ -e gunicorn.conf ] && params+=(--config gunicorn.conf) [ -e gunicorn.log.conf ] && params+=(--log-config gunicorn.log.conf) -python -m gunicorn.app.wsgiapp "${params[@]}" -b 0.0.0.0:8000 nomad.app:app \ No newline at end of file +python -m gunicorn.app.wsgiapp "${params[@]}" --worker-class=uvicorn.workers.UvicornWorker -b 0.0.0.0:8000 nomad.app_fastapi.main:app diff --git a/tests/app_fastapi/routers/conftest.py b/tests/app_fastapi/routers/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..cc7f31e1e9d057c723c83cb69c42fee11775821f --- /dev/null +++ b/tests/app_fastapi/routers/conftest.py @@ -0,0 +1,9 @@ +import pytest +from fastapi.testclient import TestClient + +from nomad.app_fastapi.main import app + + +@pytest.fixture(scope='session') +def client(): + return TestClient(app, base_url='http://testserver/api/v1/') diff --git a/tests/app_fastapi/routers/test_auth.py b/tests/app_fastapi/routers/test_auth.py new file mode 100644 index 0000000000000000000000000000000000000000..3efb27d6c6948a44a27c1aa87c22bd1726535e3c --- /dev/null +++ b/tests/app_fastapi/routers/test_auth.py @@ -0,0 +1,24 @@ +import pytest +from urllib.parse import urlencode + + +def perform_get_token_test(client, http_method, status_code, username, password): + if http_method == 'post': + response = client.post( + 'auth/token', + data=dict(username=username, password=password)) + else: + response = client.get('auth/token?%s' % urlencode( + dict(username=username, password=password))) + + assert response.status_code == status_code + + +@pytest.mark.parametrize('http_method', ['post', 'get']) +def test_get_token(client, test_user, http_method): + perform_get_token_test(client, http_method, 200, test_user.username, 'password') + + +@pytest.mark.parametrize('http_method', ['post', 'get']) +def test_get_token_bad_credentials(client, http_method): + perform_get_token_test(client, http_method, 401, 'bad', 'credentials') diff --git a/tests/app_fastapi/routers/test_entries.py b/tests/app_fastapi/routers/test_entries.py new file mode 100644 index 0000000000000000000000000000000000000000..6d5f73f5ce7352cdf1adaeccd278c13afc3a01e9 --- /dev/null +++ b/tests/app_fastapi/routers/test_entries.py @@ -0,0 +1,982 @@ +from typing import 
List +import pytest +from devtools import debug +from datetime import datetime +from urllib.parse import urlencode +import zipfile +import io +import json + +from nomad import infrastructure, config +from nomad.archive import write_partial_archive_to_mongo +from nomad.metainfo.search_extension import search_quantities +from nomad.datamodel import EntryArchive, EntryMetadata, DFTMetadata +from nomad.app_fastapi.models import AggregateableQuantity, Metric + +from tests.conftest import clear_raw_files, clear_elastic_infra +from tests.test_files import create_test_upload_files +from tests.utils import assert_at_least + +''' +These are the tests for all API operations below ``entries``. The tests are organized +using the following type of methods: fixtures, ``perfrom_*_test``, ``assert_*``, and +``test_*``. While some ``test_*`` methods test individual API operations, some +test methods will test multiple API operations that use common aspects like +supporting queries, pagination, or the owner parameter. The test methods will use +``perform_*_test`` methods as an parameter. Similarely, the ``assert_*`` methods allow +to assert for certain aspects in the responses. +''' + + +@pytest.fixture(scope='module') +def data(elastic_infra, raw_files_infra, mongo_infra, test_user, other_test_user): + ''' + Provides a couple of uploads and entries including metadata, raw-data, and + archive files. + + 23 published without embargo + 1 unpublished + 1 unpublished shared + 1 published with embargo + 1 published shared with embargo + + partial archive exists only for id_01 + raw files and archive file for id_02 are missing + id_10, id_11 reside in the same directory + ''' + archives: List[EntryArchive] = [] + archive = EntryArchive() + entry_metadata = archive.m_create( + EntryMetadata, + domain='dft', + upload_id='upload_id_1', + upload_time=datetime.now(), + uploader=test_user, + published=True, + with_embargo=False, + atoms=['H', 'O'], + n_atoms=2, + parser_name='parsers/vasp') + entry_metadata.m_create( + DFTMetadata, + code_name='VASP', + xc_functional='GGA', + system='bulk') + archive.m_update_from_dict({ + 'section_run': [{}], + 'section_workflow': {} + }) + + # one upload with two calc published with embargo, one shared + archives.clear() + entry_metadata.m_update( + upload_id='id_embargo', + calc_id='id_embargo', + mainfile='test_content/test_embargo_entry/mainfile.json', + shared_with=[], + with_embargo=True) + entry_metadata.a_elastic.index() + archives.append(archive.m_copy(deep=True)) + entry_metadata.m_update( + calc_id='id_embargo_shared', + mainfile='test_content/test_embargo_entry_shared/mainfile.json', + shared_with=[other_test_user]) + entry_metadata.a_elastic.index() + archives.append(archive.m_copy(deep=True)) + create_test_upload_files(entry_metadata.upload_id, archives) + + # one upload with two calc in staging, one shared + archives.clear() + entry_metadata.m_update( + upload_id='id_unpublished', + calc_id='id_unpublished', + mainfile='test_content/test_entry/mainfile.json', + with_embargo=False, + shared_with=[], + published=False) + entry_metadata.a_elastic.index() + archives.append(archive.m_copy(deep=True)) + entry_metadata.m_update( + calc_id='id_unpublished_shared', + mainfile='test_content/test_entry_shared/mainfile.json', + shared_with=[other_test_user]) + entry_metadata.a_elastic.index() + archives.append(archive.m_copy(deep=True)) + create_test_upload_files( + entry_metadata.upload_id, archives, published=False) + + # one upload with 23 calcs published + archives.clear() + for 
i in range(1, 24): + mainfile = 'test_content/subdir/test_entry_%02d/mainfile.json' % i + if i == 11: + mainfile = 'test_content/subdir/test_entry_10/mainfile_11.json' + entry_metadata.m_update( + upload_id='id_published', + calc_id='id_%02d' % i, + mainfile=mainfile, + with_embargo=False, + published=True, + shared_with=[]) + entry_metadata.a_elastic.index() + if i != 2: + archives.append(archive.m_copy(deep=True)) + if i == 1: + write_partial_archive_to_mongo(archive) + + infrastructure.elastic_client.indices.refresh(index=config.elastic.index_name) + create_test_upload_files(entry_metadata.upload_id, archives) + + yield + + clear_elastic_infra() + clear_raw_files() + + +def perform_entries_metadata_test( + client, owner=None, headers={}, status_code=200, + entries=None, http_method='get', **kwargs): + + if http_method == 'get': + params = {} + if owner is not None: + params['owner'] = owner + for value in kwargs.values(): + params.update(**value) + response = client.get( + 'entries?%s' % urlencode(params, doseq=True), headers=headers) + + elif http_method == 'post': + body = dict(**kwargs) + if owner is not None: + body['owner'] = owner + response = client.post('entries/query', headers=headers, json=body) + + else: + assert False + + response_json = assert_entries_metadata_response(response, status_code=status_code) + + if response_json is None: + return + + assert 'pagination' in response_json + if entries is not None and entries >= 0: + assert response_json['pagination']['total'] == entries + + return response_json + + +def perform_entries_raw_download_test( + client, headers={}, query={}, owner=None, files={}, entries=-1, files_per_entry=5, + status_code=200, http_method='get'): + + if owner == 'all': + # This operation is not allow for owner 'all' + status_code = 401 + + if http_method == 'post': + body = {'query': query, 'files': files} + if owner is not None: + body['owner'] = owner + response = client.post('entries/raw/download/query', headers=headers, json=body) + + elif http_method == 'get': + params = dict(**query) + params.update(**files) + if owner is not None: + params['owner'] = owner + response = client.get('entries/raw/download?%s' % urlencode(params, doseq=True), headers=headers) + + else: + assert False + + assert_response(response, status_code) + if status_code == 200: + assert_raw_zip_file( + response, files=entries * files_per_entry + 1, manifest_entries=entries, + compressed=files.get('compress', False)) + + +def perform_entries_raw_test( + client, owner=None, headers={}, status_code=200, + entries=None, http_method='get', files_per_entry=-1, **kwargs): + + if owner == 'all': + # This operation is not allow for owner 'all' + status_code = 401 + + if http_method == 'get': + params = {} + if owner is not None: + params['owner'] = owner + for value in kwargs.values(): + params.update(**value) + response = client.get( + 'entries/raw?%s' % urlencode(params, doseq=True), headers=headers) + + elif http_method == 'post': + body = dict(**kwargs) + if owner is not None: + body['owner'] = owner + response = client.post('entries/raw/query', headers=headers, json=body) + + else: + assert False + + response_json = assert_entries_metadata_response(response, status_code=status_code) + + if response_json is None: + return None + + assert 'pagination' in response_json + if entries is not None: + assert response_json['pagination']['total'] == entries + + assert_entries_raw_response(response_json, files_per_entry=files_per_entry) + + return response_json + + +def 
perform_entries_archive_download_test( + client, headers={}, query={}, owner=None, files={}, + entries=-1, status_code=200, http_method='get'): + + if owner == 'all': + # This operation is not allow for owner 'all' + status_code = 401 + + if http_method == 'post': + body = {'query': query, 'files': files} + if owner is not None: + body['owner'] = owner + response = client.post('entries/archive/download/query', headers=headers, json=body) + + elif http_method == 'get': + params = dict(**query) + params.update(**files) + if owner is not None: + params['owner'] = owner + response = client.get('entries/archive/download?%s' % urlencode(params, doseq=True), headers=headers) + + else: + assert False + + assert_response(response, status_code) + if status_code == 200: + assert_archive_zip_file( + response, entries=entries, + compressed=files.get('compress', False)) + + +def perform_entries_archive_test( + client, headers={}, entries=-1, status_code=200, http_method='get', **kwargs): + + if kwargs.get('owner') == 'all': + # This operation is not allow for owner 'all' + status_code = 401 + + if http_method == 'get': + assert 'required' not in kwargs + params = {} + if 'owner' in kwargs: params.update(owner=kwargs['owner']) + if 'query' in kwargs: params.update(**kwargs['query']) + if 'pagination' in kwargs: params.update(**kwargs['pagination']) + response = client.get('entries/archive?%s' % urlencode(params, doseq=True), headers=headers) + + else: + body = dict(**kwargs) + response = client.post('entries/archive/query', headers=headers, json=body) + + assert_response(response, status_code) + if status_code != 200: + return None + + json_response = response.json() + if entries >= 0: + assert json_response['pagination']['total'] == entries + for archive_data in json_response['data']: + required = kwargs.get('required', '*') + archive = archive_data['archive'] + if required == '*': + for key in ['section_metadata', 'section_run']: + assert key in archive + else: + for key in required: assert key in archive + for key in archive: assert key in required + + return json_response + + +def assert_response(response, status_code=None): + ''' General assertions for status_code and error messages ''' + if status_code and response.status_code != status_code: + try: + debug(response.json()) + except Exception: + pass + + assert status_code is None or response.status_code == status_code + + if status_code == 422: + response_json = response.json() + details = response_json['detail'] + assert len(details) > 0 + for detail in details: + assert 'loc' in detail + assert 'msg' in detail + return None + + if 400 <= status_code < 500: + response_json = response.json() + assert 'detail' in response_json + return None + + +def assert_entries_metadata_response(response, status_code=None): + assert_response(response, status_code) + + if status_code != 200 or response.status_code != 200: + return None + + response_json = response.json() + assert 'es_query' not in response_json + assert 'data' in response_json + return response_json + + +def assert_statistic(response_json, name, statistic, size=-1): + assert 'statistics' in response_json + assert name in response_json['statistics'] + statistic_response = response_json['statistics'][name] + for key in ['data', 'size', 'order', 'quantity']: + assert key in statistic_response + + assert_at_least(statistic, statistic_response) + + default_size = search_quantities[statistic['quantity']].statistic_size + assert statistic.get('size', default_size) >= len(statistic_response['data']) + + 
if size != -1: + assert len(statistic_response['data']) == size + + values = list(statistic_response['data'].keys()) + for index, value in enumerate(values): + data = statistic_response['data'][value] + assert 'entries' in data + for metric in statistic.get('metrics', []): + assert metric in data + + if index < len(values) - 1: + + def order_value(value, data): + if statistic_response['order']['type'] == 'entries': + return data['entries'] + else: + return value + + if statistic_response['order']['direction'] == 'asc': + assert order_value(value, data) <= order_value(values[index + 1], statistic_response['data'][values[index + 1]]) + else: + assert order_value(value, data) >= order_value(values[index + 1], statistic_response['data'][values[index + 1]]) + + if 'order' in statistic: + assert statistic_response['order']['type'] == statistic['order'].get('type', 'entries') + assert statistic_response['order']['direction'] == statistic['order'].get('direction', 'desc') + + +def assert_required(data, required): + if 'include' in required: + for key in data: + assert key in required['include'] or '%s.*' % key in required['include'] or key == 'calc_id' + if 'exclude' in required: + for key in required['exclude']: + assert key not in data or key == 'calc_id' + + +def assert_aggregations(response_json, name, agg, total: int, size: int): + assert 'aggregations' in response_json + assert name in response_json['aggregations'] + agg_response = response_json['aggregations'][name] + + for key in ['data', 'pagination', 'quantity']: + assert key in agg_response + + assert_at_least(agg, agg_response) + + n_data = len(agg_response['data']) + assert agg.get('pagination', {}).get('size', 10) >= n_data + assert agg_response['pagination']['total'] >= n_data + for item in agg_response['data'].values(): + for key in ['size']: + assert key in item + assert item['size'] > 0 + if size >= 0: + assert n_data == size + if total >= 0: + assert agg_response['pagination']['total'] == total + + if 'entries' in agg: + agg_data = [item['data'][0] for item in agg_response['data'].values()] + else: + agg_data = [{agg['quantity']: value} for value in agg_response['data']] + + if 'pagination' in agg: + assert_pagination(agg['pagination'], agg_response['pagination'], agg_data) + else: + assert_pagination({}, agg_response['pagination'], agg_data, order_by=agg['quantity']) + + if 'entries' in agg: + for item in agg_response['data'].values(): + assert 'data' in item + assert agg['entries'].get(size, 10) >= len(item['data']) > 0 + if 'required' in agg['entries']: + for entry in item['data']: + assert_required(entry, agg['entries']['required']) + + +def assert_pagination(pagination, pagination_response, data, order_by=None, order=None): + assert_at_least(pagination, pagination_response) + assert len(data) <= pagination_response['size'] + assert len(data) <= pagination_response['total'] + + if order is None: + order = pagination_response.get('order', 'asc') + if order_by is None: + order_by = pagination_response.get('order_by') + + if order_by is not None: + for index, item in enumerate(data): + if index < len(data) - 1 and order_by in item: + if order == 'desc': + assert item[order_by] >= data[index + 1][order_by] + else: + assert item[order_by] <= data[index + 1][order_by] + + +def assert_raw_zip_file( + response, files: int = -1, manifest_entries: int = -1, compressed: bool = False): + + manifest_keys = ['calc_id', 'upload_id', 'mainfile'] + + assert len(response.content) > 0 + with zipfile.ZipFile(io.BytesIO(response.content)) as 
zip_file: + with zip_file.open('manifest.json', 'r') as f: + manifest = json.load(f) + + with_missing_files = any(entry['calc_id'] == 'id_02' for entry in manifest) + with_overlapping_files = any(entry['calc_id'] == 'id_11' for entry in manifest) + + assert zip_file.testzip() is None + zip_files = set(zip_file.namelist()) + if files >= 0: + if with_missing_files or with_overlapping_files: + assert files - (5 if with_missing_files else 0) - (4 if with_overlapping_files else 0) <= len(zip_files) < files + else: + assert len(zip_files) == files + assert (zip_file.getinfo(zip_file.namelist()[0]).compress_type > 0) == compressed + + for path in zip_files: + assert path == 'manifest.json' or path.startswith('id_') + + if manifest_entries >= 0: + assert len(manifest) == manifest_entries + + for entry in manifest: + if 'mainfile' in manifest: + manifest['mainfile'] in zip_files + assert all(key in entry for key in manifest_keys) + assert all(key in manifest_keys for key in entry) + + +def assert_entries_raw_response(response_json, files_per_entry: int = -1): + assert 'data' in response_json + for entry in response_json['data']: + assert_entry_raw(entry, files_per_entry) + + +def assert_entry_raw_response(response_json, files_per_entry: int = -1): + for key in ['entry_id', 'data']: + assert key in response_json + assert_entry_raw(response_json['data'], files_per_entry=files_per_entry) + + +def assert_entry_raw(data, files_per_entry: int = -1): + for key in ['upload_id', 'calc_id', 'files']: + assert key in data + files = data['files'] + if files_per_entry >= 0: + if data['calc_id'] == 'id_02': + # missing files + assert len(files) == 0 + elif data['calc_id'] in ['id_10', 'id_11']: + # overlapping files + assert len(files) == files_per_entry + 1 + else: + assert len(files) == files_per_entry + for file_ in files: + assert 'size' in file_ + assert 'path' in file_ + + +def assert_archive_zip_file(response, entries: int = -1, compressed: bool = False): + manifest_keys = ['calc_id', 'upload_id', 'path', 'parser_name'] + + assert len(response.content) > 0 + with zipfile.ZipFile(io.BytesIO(response.content)) as zip_file: + assert zip_file.testzip() is None + with zip_file.open('manifest.json', 'r') as f: + manifest = json.load(f) + + with_missing_files = any(entry['calc_id'] == 'id_02' for entry in manifest) + + zip_files = set(zip_file.namelist()) + if entries >= 0: + assert len(zip_files) == entries + 1 - (1 if with_missing_files else 0) + assert (zip_file.getinfo(zip_file.namelist()[0]).compress_type > 0) == compressed + + for path in zip_files: + assert path.endswith('.json') + with zip_file.open(path, 'r') as f: + data = json.load(f) + if path != 'manifest.json': + for key in ['calc_id', 'archive']: + assert key in data + assert_archive(data['archive']) + + if entries >= 0: + assert len(manifest) == entries + + for entry in manifest: + if 'mainfile' in manifest: + manifest['path'] in zip_files + assert all(key in entry for key in manifest_keys) + assert all(key in manifest_keys for key in entry) + + +def assert_archive_response(response_json, required=None): + for key in ['entry_id', 'data']: + assert key in response_json + assert_archive(response_json['data'], required=required) + + +def assert_archive(archive, required=None): + for key in ['section_metadata']: + assert key in archive + + +n_code_names = search_quantities['dft.code_name'].statistic_size + + +@pytest.mark.parametrize('statistic, size, status_code, user', [ + pytest.param({'quantity': 'dft.code_name'}, n_code_names, 200, None, 
id='fixed-values'), + pytest.param({'quantity': 'dft.code_name', 'metrics': ['uploads']}, n_code_names, 200, None, id='metrics'), + pytest.param({'quantity': 'dft.code_name', 'metrics': ['does not exist']}, -1, 422, None, id='bad-metric'), + pytest.param({'quantity': 'calc_id', 'size': 1000}, 23, 200, None, id='size-to-large'), + pytest.param({'quantity': 'calc_id', 'size': 10}, 10, 200, None, id='size'), + pytest.param({'quantity': 'calc_id', 'size': -1}, -1, 422, None, id='bad-size-1'), + pytest.param({'quantity': 'calc_id', 'size': 0}, -1, 422, None, id='bad-size-2'), + pytest.param({'quantity': 'calc_id'}, 20, 200, None, id='size-default'), + pytest.param({'quantity': 'calc_id', 'value_filter': '_0'}, 9, 200, None, id='filter'), + pytest.param({'quantity': 'calc_id', 'value_filter': '.*_0.*'}, -1, 422, None, id='bad-filter'), + pytest.param({'quantity': 'upload_id', 'order': {'type': 'values'}}, 3, 200, 'test_user', id='order-type'), + pytest.param({'quantity': 'upload_id', 'order': {'direction': 'asc'}}, 3, 200, 'test_user', id='order-direction'), + pytest.param({'quantity': 'does not exist'}, -1, 422, None, id='bad-quantity')]) +def test_entries_statistics(client, data, test_user_auth, statistic, size, status_code, user): + statistics = {'test_statistic': statistic} + headers = {} + if user == 'test_user': + headers = test_user_auth + + response_json = perform_entries_metadata_test( + client, headers=headers, owner='visible', statistics=statistics, + status_code=status_code, http_method='post') + + if response_json is None: + return + + assert_statistic(response_json, 'test_statistic', statistic, size=size) + + +def test_entries_statistics_ignore_size(client, data): + statistic = {'quantity': 'dft.code_name', 'size': 10} + statistics = {'test_statistic': statistic} + response_json = perform_entries_metadata_test( + client, statistics=statistics, status_code=200, http_method='post') + statistic.update(size=n_code_names) + assert_statistic(response_json, 'test_statistic', statistic, size=n_code_names) + + +def test_entries_all_statistics(client, data): + statistics = { + quantity.value: {'quantity': quantity.value, 'metrics': [metric.value for metric in Metric]} + for quantity in AggregateableQuantity} + response_json = perform_entries_metadata_test( + client, statistics=statistics, status_code=200, http_method='post') + for name, statistic in statistics.items(): + assert_statistic(response_json, name, statistic) + + +@pytest.mark.parametrize('aggregation, total, size, status_code', [ + pytest.param({'quantity': 'upload_id', 'pagination': {'order_by': 'uploader'}}, 3, 3, 200, id='order'), + pytest.param({'quantity': 'dft.labels_springer_classification'}, 0, 0, 200, id='no-results'), + pytest.param({'quantity': 'upload_id', 'pagination': {'after': 'id_published'}}, 3, 1, 200, id='after'), + pytest.param({'quantity': 'upload_id', 'pagination': {'order_by': 'uploader', 'after': 'Sheldon Cooper:id_published'}}, 3, 1, 200, id='after-order'), + pytest.param({'quantity': 'upload_id', 'entries': {'size': 10}}, 3, 3, 200, id='entries'), + pytest.param({'quantity': 'upload_id', 'entries': {'size': 1}}, 3, 3, 200, id='entries-size'), + pytest.param({'quantity': 'upload_id', 'entries': {'size': 0}}, -1, -1, 422, id='bad-entries'), + pytest.param({'quantity': 'upload_id', 'entries': {'size': 10, 'required': {'include': ['calc_id', 'uploader']}}}, 3, 3, 200, id='entries-include'), + pytest.param({'quantity': 'upload_id', 'entries': {'size': 10, 'required': {'exclude': ['files', 'mainfile']}}}, 3, 3, 
200, id='entries-exclude') +]) +def test_entries_aggregations(client, data, test_user_auth, aggregation, total, size, status_code): + headers = test_user_auth + aggregations = {'test_agg_name': aggregation} + response_json = perform_entries_metadata_test( + client, headers=headers, owner='visible', aggregations=aggregations, + pagination=dict(size=0), + status_code=status_code, http_method='post') + + if response_json is None: + return + + assert_aggregations(response_json, 'test_agg_name', aggregation, total=total, size=size) + + +@pytest.mark.parametrize('required, status_code', [ + pytest.param({'include': ['calc_id', 'upload_id']}, 200, id='include'), + pytest.param({'include': ['dft.*', 'upload_id']}, 200, id='include-section'), + pytest.param({'exclude': ['upload_id']}, 200, id='exclude'), + pytest.param({'exclude': ['missspelled', 'upload_id']}, 422, id='bad-quantitiy'), + pytest.param({'exclude': ['calc_id']}, 200, id='exclude-id'), + pytest.param({'include': ['upload_id']}, 200, id='include-id') +]) +@pytest.mark.parametrize('http_method', ['post', 'get']) +def test_entries_required(client, data, required, status_code, http_method): + response_json = perform_entries_metadata_test( + client, required=required, pagination={'size': 1}, status_code=status_code, http_method=http_method) + + if response_json is None: + return + + assert_required(response_json['data'][0], required) + + +@pytest.mark.parametrize('entry_id, required, status_code', [ + pytest.param('id_01', {}, 200, id='id'), + pytest.param('doesnotexist', {}, 404, id='404'), + pytest.param('id_01', {'include': ['calc_id', 'upload_id']}, 200, id='include'), + pytest.param('id_01', {'exclude': ['upload_id']}, 200, id='exclude'), + pytest.param('id_01', {'exclude': ['calc_id', 'upload_id']}, 200, id='exclude-calc-id') +]) +def test_entry_metadata(client, data, entry_id, required, status_code): + response = client.get('entries/%s?%s' % (entry_id, urlencode(required, doseq=True))) + response_json = assert_entries_metadata_response(response, status_code=status_code) + + if response_json is None: + return + + assert_required(response_json['data'], required) + + +@pytest.mark.parametrize('query, files, entries, files_per_entry, status_code', [ + pytest.param({}, {}, 23, 5, 200, id='all'), + pytest.param({'calc_id': 'id_01'}, {}, 1, 5, 200, id='all'), + pytest.param({'dft.code_name': 'DOESNOTEXIST'}, {}, 0, 5, 200, id='empty') +]) +@pytest.mark.parametrize('http_method', ['post', 'get']) +def test_entries_raw(client, data, query, files, entries, files_per_entry, status_code, http_method): + perform_entries_raw_test( + client, status_code=status_code, query=query, files=files, entries=entries, + files_per_entry=files_per_entry, http_method=http_method) + + +@pytest.mark.parametrize('query, files, entries, files_per_entry, status_code', [ + pytest.param({}, {}, 23, 5, 200, id='all'), + pytest.param({'dft.code_name': 'DOESNOTEXIST'}, {}, 0, 5, 200, id='empty'), + pytest.param({}, {'glob_pattern': '*.json'}, 23, 1, 200, id='glob'), + pytest.param({}, {'re_pattern': '[a-z]*\\.aux'}, 23, 4, 200, id='re'), + pytest.param({}, {'re_pattern': 'test_entry_02'}, 1, 5, 200, id='re-filter-entries'), + pytest.param({}, {'re_pattern': 'test_entry_02/.*\\.json'}, 1, 1, 200, id='re-filter-entries-and-files'), + pytest.param({}, {'glob_pattern': '*.json', 're_pattern': '.*\\.aux'}, 23, 4, 200, id='re-overwrites-glob'), + pytest.param({}, {'re_pattern': '**'}, -1, -1, 422, id='bad-re-pattern'), + pytest.param({}, {'compress': True}, 23, 5, 200, 
id='compress') +]) +@pytest.mark.parametrize('http_method', ['post', 'get']) +def test_entries_download_raw(client, data, query, files, entries, files_per_entry, status_code, http_method): + perform_entries_raw_download_test( + client, status_code=status_code, query=query, files=files, entries=entries, + files_per_entry=files_per_entry, http_method=http_method) + + +@pytest.mark.parametrize('http_method', ['post', 'get']) +@pytest.mark.parametrize('test_method', [ + pytest.param(perform_entries_raw_download_test, id='raw-download'), + pytest.param(perform_entries_archive_download_test, id='archive-download')]) +def test_entries_download_max(monkeypatch, client, data, test_method, http_method): + monkeypatch.setattr('nomad.config.max_entry_download', 20) + + test_method(client, status_code=400, http_method=http_method) + + +@pytest.mark.parametrize('entry_id, files_per_entry, status_code', [ + pytest.param('id_01', 5, 200, id='id'), + pytest.param('id_embargo', -1, 404, id='404'), + pytest.param('doesnotexist', -1, 404, id='404')]) +def test_entry_raw(client, data, entry_id, files_per_entry, status_code): + response = client.get('entries/%s/raw' % entry_id) + assert_response(response, status_code) + if status_code == 200: + assert_entry_raw_response(response.json(), files_per_entry=files_per_entry) + + +@pytest.mark.parametrize('entry_id, files, files_per_entry, status_code', [ + pytest.param('id_01', {}, 5, 200, id='id'), + pytest.param('doesnotexist', {}, -1, 404, id='404'), + pytest.param('id_01', {'glob_pattern': '*.json'}, 1, 200, id='glob'), + pytest.param('id_01', {'re_pattern': '[a-z]*\\.aux'}, 4, 200, id='re'), + pytest.param('id_01', {'re_pattern': '**'}, -1, 422, id='bad-re-pattern'), + pytest.param('id_01', {'compress': True}, 5, 200, id='compress')]) +def test_entry_download_raw(client, data, entry_id, files, files_per_entry, status_code): + response = client.get('entries/%s/raw/download?%s' % (entry_id, urlencode(files, doseq=True))) + assert_response(response, status_code) + if status_code == 200: + assert_raw_zip_file( + response, files=files_per_entry + 1, manifest_entries=1, + compressed=files.get('compress', False)) + + +@pytest.mark.parametrize('query, files, entries, status_code', [ + pytest.param({}, {}, 23, 200, id='all'), + pytest.param({'dft.code_name': 'DOESNOTEXIST'}, {}, -1, 200, id='empty'), + pytest.param({}, {'compress': True}, 23, 200, id='compress') +]) +@pytest.mark.parametrize('http_method', ['post', 'get']) +def test_entries_archive_download(client, data, query, files, entries, status_code, http_method): + perform_entries_archive_download_test( + client, status_code=status_code, query=query, files=files, entries=entries, + http_method=http_method) + + +@pytest.mark.parametrize('required, status_code', [ + pytest.param('*', 200, id='full'), + pytest.param({'section_metadata': '*'}, 200, id='partial'), + pytest.param({'section_run': {'section_system[NOTANINT]': '*'}}, 400, id='bad-required-1'), + pytest.param({'section_metadata': {'owners[NOTANINT]': '*'}}, 400, id='bad-required-2'), + pytest.param({'DOESNOTEXIST': '*'}, 400, id='bad-required-3') +]) +def test_entries_archive(client, data, required, status_code): + perform_entries_archive_test( + client, status_code=status_code, required=required, http_method='post') + + +@pytest.mark.parametrize('entry_id, status_code', [ + pytest.param('id_01', 200, id='id'), + pytest.param('id_02', 404, id='404'), + pytest.param('doesnotexist', 404, id='404')]) +def test_entry_archive(client, data, entry_id, 
status_code): + response = client.get('entries/%s/archive' % entry_id) + assert_response(response, status_code) + if status_code == 200: + assert_archive_response(response.json()) + + +def perform_entries_owner_test( + client, test_user_auth, other_test_user_auth, admin_user_auth, + owner, user, status_code, total, http_method, test_method): + + headers = None + if user == 'test_user': + headers = test_user_auth + elif user == 'other_test_user': + headers = other_test_user_auth + elif user == 'admin_user': + headers = admin_user_auth + elif user == 'bad_user': + headers = {'Authorization': 'Bearer NOTATOKEN'} + + test_method( + client, headers=headers, owner=owner, status_code=status_code, entries=total, + http_method=http_method) + + +@pytest.mark.parametrize('query, status_code, total', [ + pytest.param({}, 200, 23, id='empty'), + pytest.param('str', 422, -1, id='not-dict'), + pytest.param({'calc_id': 'id_01'}, 200, 1, id='match'), + pytest.param({'mispelled': 'id_01'}, 422, -1, id='not-quantity'), + pytest.param({'calc_id': ['id_01', 'id_02']}, 200, 0, id='match-list-0'), + pytest.param({'calc_id': 'id_01', 'atoms': ['H', 'O']}, 200, 1, id='match-list-1'), + pytest.param({'calc_id:any': ['id_01', 'id_02']}, 200, 2, id='any-short'), + pytest.param({'calc_id': {'any': ['id_01', 'id_02']}}, 200, 2, id='any'), + pytest.param({'calc_id': {'any': 'id_01'}}, 422, -1, id='any-not-list'), + pytest.param({'calc_id:any': 'id_01'}, 422, -1, id='any-short-not-list'), + pytest.param({'calc_id:gt': 'id_01'}, 200, 22, id='gt-short'), + pytest.param({'calc_id': {'gt': 'id_01'}}, 200, 22, id='gt'), + pytest.param({'calc_id': {'gt': ['id_01']}}, 422, 22, id='gt-list'), + pytest.param({'calc_id': {'missspelled': 'id_01'}}, 422, -1, id='not-op'), + pytest.param({'calc_id:lt': ['id_01']}, 422, -1, id='gt-shortlist'), + pytest.param({'calc_id:misspelled': 'id_01'}, 422, -1, id='not-op-short'), + pytest.param({'or': [{'calc_id': 'id_01'}, {'calc_id': 'id_02'}]}, 200, 2, id='or'), + pytest.param({'or': {'calc_id': 'id_01', 'dft.code_name': 'VASP'}}, 422, -1, id='or-not-list'), + pytest.param({'and': [{'calc_id': 'id_01'}, {'calc_id': 'id_02'}]}, 200, 0, id='and'), + pytest.param({'not': {'calc_id': 'id_01'}}, 200, 22, id='not'), + pytest.param({'not': [{'calc_id': 'id_01'}]}, 422, -1, id='not-list'), + pytest.param({'not': {'not': {'calc_id': 'id_01'}}}, 200, 1, id='not-nested-not'), + pytest.param({'not': {'calc_id:any': ['id_01', 'id_02']}}, 200, 21, id='not-nested-any'), + pytest.param({'and': [{'calc_id:any': ['id_01', 'id_02']}, {'calc_id:any': ['id_02', 'id_03']}]}, 200, 1, id='and-nested-any'), + pytest.param({'and': [{'not': {'calc_id': 'id_01'}}, {'not': {'calc_id': 'id_02'}}]}, 200, 21, id='not-nested-not') +]) +@pytest.mark.parametrize('test_method', [ + pytest.param(perform_entries_metadata_test, id='metadata'), + pytest.param(perform_entries_raw_download_test, id='raw-download'), + pytest.param(perform_entries_raw_test, id='raw'), + pytest.param(perform_entries_archive_test, id='archive'), + pytest.param(perform_entries_archive_download_test, id='archive-download')]) +def test_entries_post_query(client, data, query, status_code, total, test_method): + response_json = test_method(client, query=query, status_code=status_code, entries=total, http_method='post') + + response = client.post('entries/query', json={'query': query}) + response_json = assert_entries_metadata_response(response, status_code=status_code) + + if response_json is None: + return + + if 'pagination' not in response_json: + return + 
+ pagination = response_json['pagination'] + assert pagination['total'] == total + assert pagination['size'] == 10 + assert pagination['order_by'] == 'calc_id' + assert pagination['order'] == 'asc' + assert ('next_after' in pagination) == (total > 10) + + +@pytest.mark.parametrize('query, status_code, total', [ + pytest.param({}, 200, 23, id='empty'), + pytest.param({'calc_id': 'id_01'}, 200, 1, id='match'), + pytest.param({'mispelled': 'id_01'}, 200, 23, id='not-quantity'), + pytest.param({'calc_id': ['id_01', 'id_02']}, 200, 2, id='match-many-or'), + pytest.param({'atoms': ['H', 'O']}, 200, 23, id='match-list-many-and-1'), + pytest.param({'atoms': ['H', 'O', 'Zn']}, 200, 0, id='match-list-many-and-2'), + pytest.param({'n_atoms': 2}, 200, 23, id='match-int'), + pytest.param({'n_atoms__gt': 2}, 200, 0, id='gt-int'), + pytest.param({'calc_id__any': ['id_01', 'id_02']}, 200, 2, id='any'), + pytest.param({'calc_id__any': 'id_01'}, 200, 1, id='any-not-list'), + pytest.param({'domain': ['dft', 'ems']}, 422, -1, id='list-not-op'), + pytest.param({'calc_id__gt': 'id_01'}, 200, 22, id='gt'), + pytest.param({'calc_id__gt': ['id_01', 'id_02']}, 422, -1, id='gt-list'), + pytest.param({'calc_id__missspelled': 'id_01'}, 422, -1, id='not-op'), + pytest.param({'q': 'calc_id__id_01'}, 200, 1, id='q-match'), + pytest.param({'q': 'missspelled__id_01'}, 422, -1, id='q-bad-quantity'), + pytest.param({'q': 'bad_encoded'}, 422, -1, id='q-bad-encode'), + pytest.param({'q': 'n_atoms__2'}, 200, 23, id='q-match-int'), + pytest.param({'q': 'n_atoms__gt__2'}, 200, 0, id='q-gt'), + pytest.param({'q': 'dft.workflow.section_geometry_optimization.final_energy_difference__1e-24'}, 200, 0, id='foat'), + pytest.param({'q': 'domain__dft'}, 200, 23, id='enum'), + pytest.param({'q': 'upload_time__gt__2014-01-01'}, 200, 23, id='datetime'), + pytest.param({'q': ['atoms__all__H', 'atoms__all__O']}, 200, 23, id='q-all'), + pytest.param({'q': ['atoms__all__H', 'atoms__all__X']}, 200, 0, id='q-all') +]) +@pytest.mark.parametrize('test_method', [ + pytest.param(perform_entries_metadata_test, id='metadata'), + pytest.param(perform_entries_raw_download_test, id='raw-download'), + pytest.param(perform_entries_raw_test, id='raw'), + pytest.param(perform_entries_archive_test, id='archive'), + pytest.param(perform_entries_archive_download_test, id='archive-download')]) +def test_entries_get_query(client, data, query, status_code, total, test_method): + response_json = test_method( + client, query=query, status_code=status_code, entries=total, http_method='get') + + if response_json is None: + return + + if 'pagination' not in response_json: + return + + response = client.get('entries?%s' % urlencode(query, doseq=True)) + + response_json = assert_entries_metadata_response(response, status_code=status_code) + + if response_json is None: + return + + pagination = response_json['pagination'] + assert pagination['total'] == total + assert pagination['size'] == 10 + assert pagination['order_by'] == 'calc_id' + assert pagination['order'] == 'asc' + assert ('next_after' in pagination) == (total > 10) + + +@pytest.mark.parametrize('owner, user, status_code, total', [ + pytest.param('user', None, 401, -1, id='user-wo-auth'), + pytest.param('staging', None, 401, -1, id='staging-wo-auth'), + pytest.param('visible', None, 200, 23, id='visible-wo-auth'), + pytest.param('admin', None, 401, -1, id='admin-wo-auth'), + pytest.param('shared', None, 401, -1, id='shared-wo-auth'), + pytest.param('public', None, 200, 23, id='public-wo-auth'), + + 
pytest.param('user', 'test_user', 200, 27, id='user-test-user'), + pytest.param('staging', 'test_user', 200, 2, id='staging-test-user'), + pytest.param('visible', 'test_user', 200, 27, id='visible-test-user'), + pytest.param('admin', 'test_user', 401, -1, id='admin-test-user'), + pytest.param('shared', 'test_user', 200, 27, id='shared-test-user'), + pytest.param('public', 'test_user', 200, 23, id='public-test-user'), + + pytest.param('user', 'other_test_user', 200, 0, id='user-other-test-user'), + pytest.param('staging', 'other_test_user', 200, 1, id='staging-other-test-user'), + pytest.param('visible', 'other_test_user', 200, 25, id='visible-other-test-user'), + pytest.param('shared', 'other_test_user', 200, 2, id='shared-other-test-user'), + pytest.param('public', 'other_test_user', 200, 23, id='public-other-test-user'), + + pytest.param('all', None, 200, 25, id='metadata-all-wo-auth'), + pytest.param('all', 'test_user', 200, 27, id='metadata-all-test-user'), + pytest.param('all', 'other_test_user', 200, 26, id='metadata-all-other-test-user'), + + pytest.param('admin', 'admin_user', 200, 27, id='admin-admin-user'), + pytest.param('all', 'bad_user', 401, -1, id='bad-credentials') +]) +@pytest.mark.parametrize('http_method', ['post', 'get']) +@pytest.mark.parametrize('test_method', [ + pytest.param(perform_entries_metadata_test, id='metadata'), + pytest.param(perform_entries_raw_download_test, id='raw-download'), + pytest.param(perform_entries_raw_test, id='raw'), + pytest.param(perform_entries_archive_test, id='archive'), + pytest.param(perform_entries_archive_download_test, id='archive-download')]) +def test_entries_owner( + client, data, test_user_auth, other_test_user_auth, admin_user_auth, + owner, user, status_code, total, http_method, test_method): + + perform_entries_owner_test( + client, test_user_auth, other_test_user_auth, admin_user_auth, + owner, user, status_code, total, http_method, test_method) + + +@pytest.mark.parametrize('pagination, response_pagination, status_code', [ + pytest.param({}, {'total': 23, 'size': 10, 'next_after': 'id_10'}, 200, id='empty'), + pytest.param({'size': 1}, {'total': 23, 'size': 1, 'next_after': 'id_01'}, 200, id='size'), + pytest.param({'size': 0}, {'total': 23, 'size': 0}, 200, id='size-0'), + pytest.param({'size': 1, 'after': 'id_01'}, {'after': 'id_01', 'next_after': 'id_02'}, 200, id='after'), + pytest.param({'size': 1, 'after': 'id_02', 'order': 'desc'}, {'next_after': 'id_01'}, 200, id='after-desc'), + pytest.param({'size': 1, 'order_by': 'n_atoms'}, {'next_after': '2:id_01'}, 200, id='order-by-after-int'), + pytest.param({'size': 1, 'order_by': 'dft.code_name'}, {'next_after': 'VASP:id_01'}, 200, id='order-by-after-nested'), + pytest.param({'size': -1}, None, 422, id='bad-size'), + pytest.param({'order': 'misspelled'}, None, 422, id='bad-order'), + pytest.param({'order_by': 'misspelled'}, None, 422, id='bad-order-by'), + pytest.param({'order_by': 'atoms', 'after': 'H:id_01'}, None, 422, id='order-by-list'), + pytest.param({'order_by': 'n_atoms', 'after': 'some'}, None, 400, id='order-by-bad-after') +]) +@pytest.mark.parametrize('http_method', ['post', 'get']) +@pytest.mark.parametrize('test_method', [ + pytest.param(perform_entries_metadata_test, id='metadata'), + pytest.param(perform_entries_raw_test, id='raw'), + pytest.param(perform_entries_archive_test, id='archive')]) +def test_entries_pagination(client, data, pagination, response_pagination, status_code, http_method, test_method): + response_json = test_method( + client, 
pagination=pagination, status_code=status_code, http_method=http_method) + + if response_json is None: + return + + assert_pagination(pagination, response_json['pagination'], response_json['data']) diff --git a/tests/app_fastapi/routers/test_users.py b/tests/app_fastapi/routers/test_users.py new file mode 100644 index 0000000000000000000000000000000000000000..83812d1bcc5de0849294cccf5b3ed14e84af9dad --- /dev/null +++ b/tests/app_fastapi/routers/test_users.py @@ -0,0 +1,14 @@ + +def test_me(client, test_user_auth): + response = client.get('users/me', headers=test_user_auth) + assert response.status_code == 200 + + +def test_me_auth_required(client): + response = client.get('users/me') + assert response.status_code == 401 + + +def test_me_auth_bad_token(client): + response = client.get('users/me', headers={'Authentication': 'Bearer NOTATOKEN'}) + assert response.status_code == 401 diff --git a/tests/app_fastapi/test_utils.py b/tests/app_fastapi/test_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d9a25496da39d85142ee966a6697ae446f99bd60 --- /dev/null +++ b/tests/app_fastapi/test_utils.py @@ -0,0 +1,43 @@ +from typing import Iterator +import os.path +import zipfile + +from nomad import config +from nomad.datamodel import EntryArchive, EntryMetadata +from nomad.app_fastapi.utils import create_streamed_zipfile, File + +from tests.conftest import clear_raw_files +from tests.test_files import create_test_upload_files + + +def test_create_streamed_zip(raw_files_infra): + # We use the files of a simple test upload to create a streamed zip with all the raw + # files. + archive = EntryArchive() + metadata = archive.m_create(EntryMetadata) + metadata.upload_id = 'test_id' + metadata.calc_id = 'test_id' + metadata.mainfile = 'root/subdir/mainfile.json' + + upload_files = create_test_upload_files('test_id', [archive]) + + def generate_files() -> Iterator[File]: + for path in upload_files.raw_file_manifest(): + with upload_files.raw_file(path) as f: + yield File( + path=path, + f=f, + size=upload_files.raw_file_size(path)) + + if not os.path.exists(config.fs.tmp): + os.makedirs(config.fs.tmp) + + zip_file_path = os.path.join(config.fs.tmp, 'results.zip') + with open(zip_file_path, 'wb') as f: + for content in create_streamed_zipfile(generate_files()): + f.write(content) + + with zipfile.ZipFile(zip_file_path) as zf: + assert zf.testzip() is None + + clear_raw_files() diff --git a/tests/conftest.py b/tests/conftest.py index 16fecf41197891435dc6240532d48bc665663efc..457ce84217fcd425851c5f212be4f4425d8a2f68 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,7 +16,7 @@ # limitations under the License. 
# -from typing import Tuple, List +from typing import Tuple, List, Dict, Any import pytest import logging from collections import namedtuple @@ -78,6 +78,7 @@ def raw_files_infra(): config.fs.staging = '.volumes/test_fs/staging' config.fs.public = '.volumes/test_fs/public' config.fs.prefix_size = 2 + clear_raw_files() @pytest.fixture(scope='function') @@ -193,10 +194,25 @@ def elastic_infra(monkeysession): return infrastructure.setup_elastic() except Exception: # try to delete index, error might be caused by changed mapping - from elasticsearch_dsl import connections - connections.create_connection(hosts=['%s:%d' % (config.elastic.host, config.elastic.port)]) \ - .indices.delete(index='nomad_fairdi_calcs_test') - return infrastructure.setup_elastic() + return clear_elastic_infra() + + +def clear_elastic_infra(): + from elasticsearch_dsl import connections + connection = connections.create_connection( + hosts=['%s:%d' % (config.elastic.host, config.elastic.port)]) + + try: + connection.indices.delete(index='nomad_fairdi_test') + except Exception: + pass + + try: + connection.indices.delete(index='nomad_fairdi_materials_test') + except Exception: + pass + + return infrastructure.setup_elastic() def clear_elastic(elastic): @@ -210,7 +226,7 @@ def clear_elastic(elastic): except elasticsearch.exceptions.NotFoundError: # it is unclear why this happens, but it happens at least once, when all tests # are executed - infrastructure.setup_elastic() + clear_elastic_infra() @pytest.fixture(scope='function') @@ -238,6 +254,12 @@ class KeycloakMock: self.id_counter = 2 self.users = dict(**test_users) + def tokenauth(self, access_token: str) -> Dict[str, Any]: + if access_token in self.users: + return self.users[access_token] + else: + raise infrastructure.KeycloakError('user does not exist') + def authorize_flask(self, *args, **kwargs): if 'Authorization' in request.headers and request.headers['Authorization'].startswith('Bearer '): user_id = request.headers['Authorization'].split(None, 1)[1].strip() @@ -269,6 +291,13 @@ class KeycloakMock: User(**test_user) for test_user in self.users.values() if query in ' '.join([str(value) for value in test_user.values()])] + def basicauth(self, username: str, password: str) -> str: + for user in self.users.values(): + if user['username'] == username: + return user['user_id'] + + raise infrastructure.KeycloakError() + @property def access_token(self): return g.oidc_access_token diff --git a/tests/test_archive.py b/tests/test_archive.py index 5d4827ca460854acfd2699f2af32cd9df198e7e0..3a6af6546b78688555a8942cd727f4e706f643b1 100644 --- a/tests/test_archive.py +++ b/tests/test_archive.py @@ -427,3 +427,7 @@ def test_compute_required_incomplete(archive): }) assert required is not None + + +def test_compute_required_full(): + assert compute_required_with_referenced('*') is None diff --git a/tests/test_files.py b/tests/test_files.py index e85363115b4221bd9eac86082545cec8b0863deb..779424c68f5dd7e4550da80b15b37ec115ab66de 100644 --- a/tests/test_files.py +++ b/tests/test_files.py @@ -16,7 +16,7 @@ # limitations under the License. 
# -from typing import Generator, Any, Dict, Tuple, Iterable +from typing import Generator, Any, Dict, Tuple, Iterable, List import os import os.path import shutil @@ -25,7 +25,7 @@ import itertools import zipfile import re -from nomad import config, datamodel +from nomad import config, datamodel, utils from nomad.files import DirectoryObject, PathObject from nomad.files import StagingUploadFiles, PublicUploadFiles, UploadFiles, Restricted, \ ArchiveBasedStagingUploadFiles @@ -480,3 +480,88 @@ def assert_upload_files( assert calc.with_embargo or isinstance(upload_files, StagingUploadFiles) upload_files.close() + + +def create_test_upload_files( + upload_id: str, + archives: List[datamodel.EntryArchive], + published: bool = True, + template_files: str = example_file, + template_mainfile: str = example_file_mainfile) -> UploadFiles: + ''' + Creates an upload_files object and the underlying files for test/mock purposes. + + Arguments: + upload_id: The upload id for the upload. Will generate a random UUID if None. + archives: A list of :class:`datamodel.EntryArchive` metainfo objects. This will + be used to determine the mainfiles. Will create respective directories and + copy the template calculation to create raw files for each archive. + Will also be used to fill the archives in the created upload. + published: Creates a :class:`PublicUploadFiles` object with published files + instead of a :class:`StagingUploadFiles` object with staging files. Default + is published. + template_files: A zip file with example files in it. One directory will be used + as a template. It will be copied for each given archive. + template_mainfile: Path of the template mainfile within the given template_files. + ''' + if upload_id is None: upload_id = utils.create_uuid() + if archives is None: archives = [] + + upload_files = ArchiveBasedStagingUploadFiles( + upload_id, upload_path=template_files, create=True) + + upload_files.extract() + + upload_raw_files = upload_files.join_dir('raw') + source = upload_raw_files.join_dir(os.path.dirname(template_mainfile)).os_path + + for archive in archives: + # create a copy of the given template files for each archive + mainfile = archive.section_metadata.mainfile + assert mainfile is not None, 'Archives to create test upload must have a mainfile' + target = upload_raw_files.join_file(os.path.dirname(mainfile)).os_path + if os.path.exists(target): + for file_ in os.listdir(source): + shutil.copy(os.path.join(source, file_), target) + else: + shutil.copytree(source, target) + os.rename( + os.path.join(target, os.path.basename(template_mainfile)), + os.path.join(target, os.path.basename(mainfile))) + + # create an archive "file" for each archive + calc_id = archive.section_metadata.calc_id + assert calc_id is not None, 'Archives to create test upload must have a calc id' + upload_files.write_archive(calc_id, archive.m_to_dict()) + + # remove the template + shutil.rmtree(source) + + if published: + upload_files.pack([archive.section_metadata for archive in archives]) + upload_files.delete() + return UploadFiles.get(upload_id) + + return upload_files + + +def test_test_upload_files(raw_files_infra): + upload_id = utils.create_uuid() + archives: List[datamodel.EntryArchive] = [] + for index in range(0, 3): + archive = datamodel.EntryArchive() + metadata = archive.m_create(datamodel.EntryMetadata) + metadata.calc_id = 'example_calc_id_%d' % index + metadata.mainfile = 'test/test/calc_%d/mainfile_%d.json' % (index, index) + archives.append(archive) + + upload_files = 
create_test_upload_files(upload_id, archives) + + try: + assert_upload_files( + upload_id, + [archive.section_metadata for archive in archives], + PublicUploadFiles) + finally: + if upload_files.exists(): + upload_files.delete() diff --git a/tests/test_search.py b/tests/test_search.py index e7b3d34c84d4403b19701e7ea23bc8cb308e834e..6a7a9e87c976c248d304dbcbad9753eec1535c51 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -20,9 +20,12 @@ from typing import List, Iterable from elasticsearch_dsl import Q import pytest from datetime import datetime +import json -from nomad import datamodel, search, processing, infrastructure, config -from nomad.search import entry_document, SearchRequest +from nomad import datamodel, processing, infrastructure, config +from nomad.metainfo import search_extension +from nomad.search import entry_document, SearchRequest, search, flat +from nomad.app_fastapi.models import WithQuery def test_init_mapping(elastic): @@ -43,7 +46,7 @@ def test_index_normalized_calc(elastic, normalized: datamodel.EntryArchive): domain='dft', upload_id='test upload id', calc_id='test id') entry_metadata.apply_domain_metadata(normalized) search_entry = create_entry(entry_metadata) - entry = search.flat(search_entry.to_dict()) + entry = flat(search_entry.to_dict()) assert 'calc_id' in entry assert 'atoms' in entry @@ -75,7 +78,7 @@ def test_index_normalized_calc_with_author( entry_metadata.apply_domain_metadata(normalized) search_entry = create_entry(entry_metadata) - search.flat(search_entry.to_dict()) + flat(search_entry.to_dict()) def test_index_upload(elastic, processed: processing.Upload): @@ -86,7 +89,7 @@ def test_index_upload(elastic, processed: processing.Upload): def example_search_data(elastic, normalized: datamodel.EntryArchive): entry_metadata = normalized.section_metadata entry_metadata.m_update( - domain='dft', upload_id='test upload id', calc_id='test id', + domain='dft', upload_id='test upload id', calc_id='test id', published=True, upload_time=datetime.now()) entry_metadata.apply_domain_metadata(normalized) create_entry(entry_metadata) @@ -187,11 +190,11 @@ def assert_metrics(container, metrics_names): def test_search_statistics(elastic, example_search_data): - assert 'authors' in search.metrics.keys() - assert 'datasets' in search.metrics.keys() - assert 'unique_entries' in search.metrics.keys() + assert 'authors' in search_extension.metrics.keys() + assert 'datasets' in search_extension.metrics.keys() + assert 'unique_entries' in search_extension.metrics.keys() - use_metrics = search.metrics.keys() + use_metrics = search_extension.metrics.keys() request = SearchRequest(domain='dft').statistic( 'dft.system', size=10, metrics_to_use=use_metrics).date_histogram(metrics_to_use=use_metrics) @@ -229,7 +232,7 @@ def test_global_statistics(elastic, example_search_data): def test_search_totals(elastic, example_search_data): - use_metrics = search.metrics.keys() + use_metrics = search_extension.metrics.keys() request = SearchRequest(domain='dft').totals(metrics_to_use=use_metrics) results = request.execute() @@ -245,18 +248,18 @@ def test_search_totals(elastic, example_search_data): def test_search_exclude(elastic, example_search_data): for item in SearchRequest().execute_paginated()['results']: - assert 'atoms' in search.flat(item) + assert 'atoms' in flat(item) for item in SearchRequest().exclude('atoms').execute_paginated()['results']: - assert 'atoms' not in search.flat(item) + assert 'atoms' not in flat(item) def test_search_include(elastic, 
example_search_data): for item in SearchRequest().execute_paginated()['results']: - assert 'atoms' in search.flat(item) + assert 'atoms' in flat(item) for item in SearchRequest().include('calc_id').execute_paginated()['results']: - item = search.flat(item) + item = flat(item) assert 'atoms' not in item assert 'calc_id' in item @@ -320,7 +323,7 @@ def assert_search_upload( assert search_results.count() == len(list(upload_entries)) if search_results.count() > 0: for hit in search_results: - hit = search.flat(hit.to_dict()) + hit = flat(hit.to_dict()) for key, value in kwargs.items(): assert hit.get(key, None) == value @@ -357,3 +360,24 @@ if __name__ == '__main__': yield calc.to_dict(include_meta=True) bulk(infrastructure.elastic_client, gen_data()) + + +@pytest.mark.parametrize('api_query, total', [ + pytest.param('{}', 1, id='empty'), + pytest.param('{"dft.code_name": "VASP"}', 1, id="match"), + pytest.param('{"dft.code_name": "VASP", "dft.xc_functional": "dne"}', 0, id="match_all"), + pytest.param('{"and": [{"dft.code_name": "VASP"}, {"dft.xc_functional": "dne"}]}', 0, id="and"), + pytest.param('{"or":[{"dft.code_name": "VASP"}, {"dft.xc_functional": "dne"}]}', 1, id="or"), + pytest.param('{"not":{"dft.code_name": "VASP"}}', 0, id="not"), + pytest.param('{"dft.code_name": {"all": ["VASP", "dne"]}}', 0, id="all"), + pytest.param('{"dft.code_name": {"any": ["VASP", "dne"]}}', 1, id="any"), + pytest.param('{"dft.code_name": {"none": ["VASP", "dne"]}}', 0, id="none"), + pytest.param('{"dft.code_name": {"gte": "VASP"}}', 1, id="gte"), + pytest.param('{"dft.code_name": {"gt": "A"}}', 1, id="gt"), + pytest.param('{"dft.code_name": {"lte": "VASP"}}', 1, id="lte"), + pytest.param('{"dft.code_name": {"lt": "A"}}', 0, id="lt"), +]) +def test_search_query(elastic, example_search_data, api_query, total): + api_query = json.loads(api_query) + results = search(owner='all', query=WithQuery(query=api_query).query) + assert results.pagination.total == total # pylint: disable=no-member diff --git a/tests/utils.py b/tests/utils.py index b1378a76756926466c5c2bef069ca08d0cffdac9..6dd3836071406b84f24db23dfaa21cb9491dab96 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -48,3 +48,17 @@ def assert_log(caplog, level: str, event_part: str) -> LogRecord: assert record is not None return record + + +def assert_at_least(source, target): + ''' + Compares two dicts recursively and asserts that all information in source equals + the same information in target. Additional information in target is ignored. + ''' + for key, value in source.items(): + assert key in target, '%s with value %s in %s is not in %s' % (key, source[key], source, target) + if isinstance(value, dict): + assert_at_least(value, target[key]) + else: + assert value == target[key], '%s with value %s in %s is not equal the target value %s in %s' % ( + key, source[key], source, target[key], target)
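
A minimal usage sketch for the assert_at_least helper defined in tests/utils.py above; the example dicts are invented for illustration and mirror how the entries tests compare expected values against API responses that may carry extra keys.

from tests.utils import assert_at_least

expected = {'pagination': {'size': 10, 'order': 'asc'}}
actual = {'pagination': {'size': 10, 'order': 'asc', 'total': 23}, 'data': []}

# Passes: every key/value in `expected` is found recursively in `actual`;
# the extra keys ('total', 'data') in `actual` are ignored.
assert_at_least(expected, actual)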
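
For reference, a sketch (assumed request bodies, not additional test code) of the equivalence between the shorthand quantity:op keys and the nested operator form that test_entries_post_query exercises; logical operators take a list (and/or) or a single clause (not).

# Both of these bodies are expected to match the same two entries
# (the 'any-short' and 'any' cases in test_entries_post_query).
short_form = {'query': {'calc_id:any': ['id_01', 'id_02']}}
nested_form = {'query': {'calc_id': {'any': ['id_01', 'id_02']}}}

# Logical operators combine clauses; 'not' takes a single clause.
negated = {'query': {'not': {'calc_id:any': ['id_01', 'id_02']}}}

# e.g. client.post('entries/query', json=short_form)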
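
A hedged sketch of how the after/next_after cursor pagination asserted in test_entries_pagination would be consumed by a caller; it assumes a client object like the test client fixture used above, and the helper name is hypothetical.

from urllib.parse import urlencode

def iterate_all_entries(client, size=10):
    # Keep passing the returned 'next_after' value back as 'after' until the
    # response no longer contains one (i.e. the last page was reached).
    after = None
    while True:
        params = {'size': size}
        if after is not None:
            params['after'] = after
        response_json = client.get('entries?%s' % urlencode(params)).json()
        yield from response_json['data']
        after = response_json['pagination'].get('next_after')
        if after is None:
            break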