diff --git a/nomad/app/v1/routers/graph.py b/nomad/app/v1/routers/graph.py index 1d1006c53c11f2f8198d23cda71807b6c1003ee8..e24fb048ef1045c298334ef98ffe3cb637ac601e 100644 --- a/nomad/app/v1/routers/graph.py +++ b/nomad/app/v1/routers/graph.py @@ -80,9 +80,7 @@ def relocate_children(request): description='Query the database with a graph style without verification.', response_class=GraphJSONResponse, ) -async def raw_query( - query=Body(...), user: User = Depends(create_user_dependency(required=True)) -): +async def raw_query(query=Body(...), user: User = Depends(create_user_dependency())): relocate_children(query) with MongoReader(query, user=user) as reader: return GraphJSONResponse(await reader.read()) @@ -99,7 +97,7 @@ async def raw_query( ) async def basic_query( query: GraphRequest = Body(...), - user: User = Depends(create_user_dependency(required=True)), + user: User = Depends(create_user_dependency()), ): try: query_dict = query.dict( diff --git a/nomad/graph/graph_reader.py b/nomad/graph/graph_reader.py index 24139f771409b93ca2c6d3ee3cd064f12790534d..7fd3344227b991b7048858ad8fa64c2fae4d7bbc 100644 --- a/nomad/graph/graph_reader.py +++ b/nomad/graph/graph_reader.py @@ -876,6 +876,14 @@ class GeneralReader: for upload in self.upload_pool.values(): upload.close() + @property + def auth_user_id(self) -> str: + return self.user.user_id if self.user else '' + + @property + def auth_user_is_admin(self) -> bool: + return self.user.is_admin if self.user else False + def _log( self, message: str, @@ -918,7 +926,7 @@ class GeneralReader: async def retrieve_user(self, user_id: str) -> str | dict: # `me` is a convenient way to refer to the current user if user_id == 'me': - user_id = self.user.user_id + user_id = self.auth_user_id def _retrieve(): return User.get(user_id=user_id) @@ -1030,7 +1038,9 @@ class GeneralReader: async def retrieve_entry(self, entry_id: str) -> str | dict: def _search(): return perform_search( - owner='all', query={'entry_id': entry_id}, 
user_id=self.user.user_id + owner='all', + query={'entry_id': entry_id}, + user_id=self.auth_user_id or None, ) if (await asyncio.to_thread(_search)).pagination.total == 0: @@ -1056,7 +1066,7 @@ class GeneralReader: ) return dataset_id - if dataset.user_id != self.user.user_id: + if dataset.user_id != self.auth_user_id: self._log( f'No access to dataset {dataset_id}.', error_type=QueryError.NOACCESS ) @@ -1303,19 +1313,19 @@ class MongoReader(GeneralReader): @functools.cached_property def uploads(self): return Upload.objects( - Q(main_author=self.user.user_id) - | Q(reviewers=self.user.user_id) - | Q(coauthors=self.user.user_id) + Q(main_author=self.auth_user_id) + | Q(reviewers=self.auth_user_id) + | Q(coauthors=self.auth_user_id) ) @functools.cached_property def datasets(self): - return Dataset.m_def.a_mongo.objects(user_id=self.user.user_id) + return Dataset.m_def.a_mongo.objects(user_id=self.auth_user_id) async def _query_es(self, config: RequestConfig): search_params: dict = { - 'owner': 'user', - 'user_id': self.user.user_id, + 'owner': 'user' if self.auth_user_id else 'public', + 'user_id': self.auth_user_id or None, 'query': {}, # 'required': MetadataRequired(include=['entry_id']) } @@ -1410,9 +1420,9 @@ class MongoReader(GeneralReader): mongo_query &= Q(publish_time=None) if config.query.is_owned is True: - mongo_query &= Q(main_author=self.user.user_id) + mongo_query &= Q(main_author=self.auth_user_id) elif config.query.is_owned is False: - mongo_query &= Q(main_author__ne=self.user.user_id) + mongo_query &= Q(main_author__ne=self.auth_user_id) return config.query.dict(exclude_unset=True), self.uploads.filter(mongo_query) @@ -2091,7 +2101,7 @@ class EntryReader(MongoReader): class ElasticSearchReader(EntryReader): async def retrieve_entry(self, entry_id: str) -> str | dict: search_response = perform_search( - owner='all', query={'entry_id': entry_id}, user_id=self.user.user_id + owner='all', query={'entry_id': entry_id}, user_id=self.auth_user_id or None 
) if search_response.pagination.total == 0: @@ -2137,11 +2147,11 @@ class UserReader(MongoReader): | Q(coauthors=self.target_user_id) ) # self.user must have access to the upload - if self.target_user_id != self.user.user_id and not self.user.is_admin: + if self.target_user_id != self.auth_user_id and not self.auth_user_is_admin: mongo_query &= ( - Q(main_author=self.user.user_id) - | Q(reviewers=self.user.user_id) - | Q(coauthors=self.user.user_id) + Q(main_author=self.auth_user_id) + | Q(reviewers=self.auth_user_id) + | Q(coauthors=self.auth_user_id) ) return Upload.objects(mongo_query) @@ -2164,10 +2174,11 @@ class UserReader(MongoReader): user_id: str = target_user['user_id'] elif isinstance(user_id_or_dict, str): if user_id_or_dict == 'me': - user_id = self.user.user_id + user_id = self.auth_user_id else: user_id = user_id_or_dict - target_user = await self.retrieve_user(user_id) + # if user_id == '' there is no auth user thus set to empty dict + target_user = await self.retrieve_user(user_id) if user_id else {} else: # should not reach here raise NotImplementedError diff --git a/tests/app/v1/routers/test_graph.py b/tests/app/v1/routers/test_graph.py index 5a12924f66a284f8d3af88036857ed051cf8b171..5f2b96ec5b9c6c18858e2595e8274d87c3e895ad 100644 --- a/tests/app/v1/routers/test_graph.py +++ b/tests/app/v1/routers/test_graph.py @@ -41,6 +41,38 @@ def assert_path_exists(path, response): raise KeyError +def contains(subset: dict, superset: dict) -> bool: + """ + Recursively checks if the 'subset' dictionary is contained within the 'superset' dictionary. + + For every key-value pair in 'subset', the same key must exist in 'superset'. + If the value is a dictionary, then the function checks recursively that the nested dictionary is also contained. + + Parameters: + subset (dict): The dictionary to be checked for containment. + superset (dict): The dictionary that should contain the subset. 
+ + Returns: + bool: True if every key-value pair in subset is found in superset, otherwise False. + """ + for key, value in subset.items(): + # Check if the key exists in superset + if key not in superset: + return False + + # If the value is a dictionary, perform a recursive check. + if isinstance(value, dict): + if not isinstance(superset[key], dict): + return False + if not contains(value, superset[key]): + return False + else: + # Direct comparison for non-dictionary values. + if superset[key] != value: + return False + return True + + def test_graph_query_random(auth_headers, client, example_data): user_auth = auth_headers['user1'] response = client.post( @@ -68,30 +100,130 @@ def test_graph_query_random(auth_headers, client, example_data): @pytest.mark.parametrize( - 'upload_id,entry_id,user,status_code', + 'upload_id,entry_id,user,expected_value,status_code', [ - pytest.param('id_embargo', 'id_embargo_1', 'user1', 200, id='ok'), + pytest.param( + 'id_embargo', + 'id_embargo_1', + 'user1', + {'uploads': {'id_embargo': {'entries': {'id_embargo_1': {}}}}}, + 200, + id='ok', + ), pytest.param( 'id_child_entries', 'id_child_entries_child1', 'user1', + { + 'uploads': { + 'id_child_entries': {'entries': {'id_child_entries_child1': {}}} + } + }, 200, id='child-entry', ), - pytest.param('id_embargo', 'id_embargo_1', 'user0', 200, id='admin-access'), - pytest.param('id_embargo', 'id_embargo_1', None, 401, id='no-credentials'), pytest.param( - 'id_embargo', 'id_embargo_1', 'invalid', 401, id='invalid-credentials' + 'id_embargo', + 'id_embargo_1', + 'user0', + {'uploads': {'id_embargo': {'entries': {'id_embargo_1': {}}}}}, + 200, + id='admin-access', + ), + pytest.param( + 'id_embargo', + 'id_embargo_1', + None, + { + 'uploads': { + 'id_embargo': { + 'm_errors': [ + { + 'error_type': 'NOACCESS', + 'message': 'No access to upload id_embargo.', + } + ] + } + } + }, + 200, + id='no-credentials', + ), + pytest.param( + 'id_embargo', 'id_embargo_1', 'invalid', {}, 401, 
id='invalid-credentials' + ), + pytest.param( + 'id_embargo', + 'id_embargo_1', + 'user2', + { + 'uploads': { + 'id_embargo': { + 'm_errors': [ + { + 'error_type': 'NOACCESS', + 'message': 'No access to upload id_embargo.', + } + ] + } + } + }, + 200, + id='no-access', + ), + pytest.param( + 'silly_value', + 'id_embargo_1', + 'user1', + { + 'uploads': { + 'silly_value': { + 'm_errors': [ + { + 'error_type': 'NOTFOUND', + 'message': 'The value silly_value is not a valid upload id.', + } + ] + } + } + }, + 200, + id='invalid-upload_id', ), - pytest.param('id_embargo', 'id_embargo_1', 'user2', 404, id='no-access'), pytest.param( - 'silly_value', 'id_embargo_1', 'user1', 404, id='invalid-upload_id' + 'id_embargo', + 'silly_value', + 'user1', + { + 'uploads': { + 'id_embargo': { + 'entries': { + 'silly_value': { + 'm_errors': [ + { + 'error_type': 'NOACCESS', + 'message': 'The value silly_value is not a valid entry id or not visible to current user.', + } + ] + } + } + } + } + }, + 200, + id='invalid-entry_id', ), - pytest.param('id_embargo', 'silly_value', 'user1', 404, id='invalid-entry_id'), ], ) def test_graph_query( - auth_headers, client, example_data, upload_id, entry_id, user, status_code + auth_headers, + client, + example_data, + upload_id, + entry_id, + user, + expected_value, + status_code, ): user_auth = auth_headers[user] response = client.post( @@ -99,14 +231,10 @@ def test_graph_query( json={Token.UPLOADS: {upload_id: {Token.ENTRIES: {entry_id: '*'}}}}, headers={'Accept': 'application/json'} | (user_auth if user_auth else {}), ) - target_path = (Token.UPLOADS, upload_id, Token.ENTRIES, entry_id) - if 200 == status_code: - assert_path_exists(target_path, response.json()) - elif 401 == status_code: - assert response.status_code == 401 - else: - with pytest.raises(KeyError): - assert_path_exists(target_path, response.json()) + + assert response.status_code == status_code + if status_code == 200: + assert contains(expected_value, response.json()) 
@pytest.mark.parametrize( @@ -257,7 +385,10 @@ def example_upload(example_archive, user1, mongo_function, elastic_function): ), id='user2', ), - pytest.param(dict(user=None, expected_status_code=401), id='no-credentials'), + pytest.param( + dict(user=None, expected_status_code=200, expected_upload_ids=[]), + id='no-credentials', + ), pytest.param( dict(user='invalid', expected_status_code=401), id='invalid-credentials' ), @@ -410,14 +541,11 @@ def test_get_uploads_graph(auth_headers, client, example_data, kwargs): def assert_upload_ids(a, b): assert set(a.keys()) == set(b) + assert response.status_code == expected_status_code if expected_status_code == 200: - result[Token.UPLOADS].pop('m_response', None) - if expected_upload_ids: - assert_upload_ids(result[Token.UPLOADS], expected_upload_ids) - else: - assert result[Token.UPLOADS] == {} - else: - assert response.status_code == expected_status_code + uploads = result.get(Token.UPLOADS, {}) + uploads.pop('m_response', None) + assert_upload_ids(uploads, expected_upload_ids) @pytest.mark.parametrize(