diff --git a/nomad/app/v1/models/groups.py b/nomad/app/v1/models/groups.py index dbdae164fba850f19314e7cb855e8312bf9d10d3..7b3973166708bf9c2ff5007b2a83a01e6dbe33f0 100644 --- a/nomad/app/v1/models/groups.py +++ b/nomad/app/v1/models/groups.py @@ -3,31 +3,31 @@ from pydantic_core import PydanticCustomError from .pagination import Direction, Pagination, PaginationResponse -group_name_description = 'Name of the group.' -group_members_description = 'User ids of the group members.' +GROUP_NAME_DESCRIPTION = 'Name of the group.' +GROUP_MEMBERS_DESCRIPTION = 'User ids of the group members (includes owner).' class UserGroupEdit(BaseModel): group_name: str | None = Field( default=None, - description=group_name_description, + description=GROUP_NAME_DESCRIPTION, min_length=3, max_length=32, pattern=r'^[a-zA-Z0-9][a-zA-Z0-9 ._\-]+[a-zA-Z0-9]$', ) members: set[str] | None = Field( - default=None, description=group_members_description + default=None, description=GROUP_MEMBERS_DESCRIPTION ) class UserGroup(BaseModel): group_id: str = Field(description='Unique id of the group.') group_name: str = Field( - default='Default Group Name', description=group_name_description + default='Default Group Name', description=GROUP_NAME_DESCRIPTION ) owner: str = Field(description='User id of the group owner.') members: list[str] = Field( - default_factory=list, description=group_members_description + default_factory=list, description=GROUP_MEMBERS_DESCRIPTION ) model_config = ConfigDict(from_attributes=True) @@ -39,11 +39,11 @@ class UserGroupResponse(BaseModel): class UserGroupQuery(BaseModel): - group_id: list[str] | None = Field( - None, description='Search groups by their full id.' + group_id: str | list[str] | None = Field( + None, description='Search groups by their full id (scalar or list).' ) user_id: str | None = Field( - None, description='Search groups by their owner or members ids.' + None, description="Search groups by their owner's or members' ids." ) search_terms: str | None = Field( None, description='Search groups by parts of their name.' diff --git a/nomad/app/v1/routers/entries.py b/nomad/app/v1/routers/entries.py index f974e154214bba6607ac9a7d8f79ff0a7267ab12..a046730d728d6379aa422c4344c5f255d437eaa9 100644 --- a/nomad/app/v1/routers/entries.py +++ b/nomad/app/v1/routers/entries.py @@ -41,7 +41,7 @@ from nomad.config.models.config import Reprocess from nomad.datamodel import EditableUserMetadata from nomad.datamodel.context import ServerContext from nomad.files import StreamedFile, create_zipstream_async -from nomad.groups import get_group_ids +from nomad.groups import MongoUserGroup from nomad.metainfo.elasticsearch_extension import entry_type from nomad.processing.data import Upload from nomad.search import ( @@ -1477,7 +1477,7 @@ async def post_entry_edit( writers = [writer['user_id'] for writer in entry_data.get('writers', [])] writer_groups = response.data[0].get('writer_groups', []) is_writer = user.user_id in writers or not set( - get_group_ids(user.user_id) + MongoUserGroup.get_ids_by_user_id(user.user_id) ).isdisjoint(writer_groups) if not (is_admin or is_writer): diff --git a/nomad/app/v1/routers/groups.py b/nomad/app/v1/routers/groups.py index 945b95f75706b66025c04ac8970047ae2b48e5bd..12110399215f663ed782f664acb6552ad6739d69 100644 --- a/nomad/app/v1/routers/groups.py +++ b/nomad/app/v1/routers/groups.py @@ -30,8 +30,7 @@ from nomad.app.v1.models.groups import ( from nomad.app.v1.models.pagination import PaginationResponse from nomad.app.v1.utils import parameter_dependency_from_model from nomad.datamodel import User as UserDataModel -from nomad.groups import MongoUserGroup -from nomad.groups import create_user_group as create_mongo_user_group +from nomad.groups import MongoUserGroup, create_mongo_user_group, get_mongo_user_group from nomad.utils import strip from ..models import User @@ -55,8 +54,8 @@ user_group_pagination_parameters = parameter_dependency_from_model( ) -def get_mongo_user_group(group_id: str) -> MongoUserGroup: - user_group = MongoUserGroup.objects(group_id=group_id).first() +def get_user_group_or_404(group_id: str) -> MongoUserGroup: + user_group = get_mongo_user_group(group_id) if user_group is None: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -124,7 +123,7 @@ async def get_user_groups( ) async def get_user_group(group_id: str): """Get data about user group.""" - user_group = get_mongo_user_group(group_id) + user_group = get_user_group_or_404(group_id) return user_group @@ -163,7 +162,7 @@ async def update_user_group( user: User = Depends(create_user_dependency(required=True)), ): """Update user group.""" - user_group = get_mongo_user_group(group_id) + user_group = get_user_group_or_404(group_id) check_user_may_edit_user_group(user, user_group) user_group_dict = user_group_edit.dict(exclude_none=True) @@ -171,9 +170,7 @@ async def update_user_group( if members is not None: check_user_ids(members) - user_group.update(**user_group_dict) - user_group.save() - user_group.reload() + user_group.clean_update_reload(**user_group_dict) return user_group @@ -187,7 +184,7 @@ async def delete_user_group( group_id: str, user: User = Depends(create_user_dependency(required=True)) ): """Delete user group.""" - user_group = get_mongo_user_group(group_id) + user_group = get_user_group_or_404(group_id) check_user_may_edit_user_group(user, user_group) user_group.delete() diff --git a/nomad/app/v1/routers/north.py b/nomad/app/v1/routers/north.py index 49d108b02ebe81e96c7a7a8a52481a5aa2f520a3..9b70bccd12464f9071b566ab43a3965161a2e0aa 100644 --- a/nomad/app/v1/routers/north.py +++ b/nomad/app/v1/routers/north.py @@ -27,7 +27,7 @@ from pydantic import BaseModel from nomad.app.v1.routers.auth import generate_simple_token from nomad.config import config from nomad.config.models.north import NORTHTool -from nomad.groups import get_group_ids +from nomad.groups import MongoUserGroup from nomad.processing import Upload from nomad.utils import get_logger, slugify, strip @@ -224,13 +224,11 @@ async def start_tool( upload_mount_dir = None user_id = str(user.user_id) - group_ids = get_group_ids(user.user_id, include_all=False) + group_ids = MongoUserGroup.get_ids_by_user_id(user_id, include_all=False) upload_query = Q() upload_query &= ( - Q(main_author=user_id) - | Q(coauthors=user.user_id) - | Q(coauthor_groups__in=group_ids) + Q(main_author=user_id) | Q(coauthors=user_id) | Q(coauthor_groups__in=group_ids) ) upload_query &= Q(publish_time=None) diff --git a/nomad/app/v1/routers/uploads.py b/nomad/app/v1/routers/uploads.py index fa6ff235ca0c7454634a8c64fe93f062507048f1..fa296ed0b20188a20c746d526dfcbe184684ea40 100644 --- a/nomad/app/v1/routers/uploads.py +++ b/nomad/app/v1/routers/uploads.py @@ -50,7 +50,7 @@ from nomad.config import config from nomad.config.models.config import Reprocess from nomad.config.models.plugins import ExampleUploadEntryPoint from nomad.files import PublicUploadFiles, StagingUploadFiles -from nomad.groups import get_group_ids +from nomad.groups import MongoUserGroup from nomad.processing import ( Entry, MetadataEditRequestHandler, @@ -2713,7 +2713,7 @@ def get_role_query(roles: list[UploadRole], user: User, include_all=False) -> Q: if not roles: roles = list(UploadRole) - group_ids = get_group_ids(user.user_id, include_all=include_all) + group_ids = MongoUserGroup.get_ids_by_user_id(user.user_id, include_all=include_all) role_query = Q() if UploadRole.main_author in roles: @@ -2739,7 +2739,7 @@ def is_user_upload_viewer(upload: Upload, user: User | None): if user.user_id in upload.viewers: return True - group_ids = get_group_ids(user.user_id) + group_ids = MongoUserGroup.get_ids_by_user_id(user.user_id) if not set(group_ids).isdisjoint(upload.viewer_groups): return True @@ -2753,7 +2753,7 @@ def is_user_upload_writer(upload: Upload, user: User): if user.user_id in upload.writers: return True - group_ids = get_group_ids(user.user_id) + group_ids = MongoUserGroup.get_ids_by_user_id(user.user_id) if not set(group_ids).isdisjoint(upload.writer_groups): return True diff --git a/nomad/graph/graph_reader.py b/nomad/graph/graph_reader.py index 7fd3344227b991b7048858ad8fa64c2fae4d7bbc..84ee4b7f169926fddb6321f01e24312b37421bdb 100644 --- a/nomad/graph/graph_reader.py +++ b/nomad/graph/graph_reader.py @@ -77,7 +77,7 @@ from nomad.graph.model import ( RequestConfig, ResolveType, ) -from nomad.groups import MongoUserGroup +from nomad.groups import MongoUserGroup, get_mongo_user_group from nomad.metainfo import ( Definition, Package, @@ -969,7 +969,7 @@ class GeneralReader: """ def _retrieve(): - return MongoUserGroup.objects(group_id=group_id).first() + return get_mongo_user_group(group_id) try: group: MongoUserGroup = await asyncio.to_thread(_retrieve) diff --git a/nomad/groups.py b/nomad/groups.py index ee792271d02ddaa82ef8bef3226e8a9aae1c08a9..d18de5691912290fbd1e73d0070f60e5d7696745 100644 --- a/nomad/groups.py +++ b/nomad/groups.py @@ -20,7 +20,7 @@ from __future__ import annotations from collections.abc import Iterable -from mongoengine import Document, ListField, Q, QuerySet, StringField +from mongoengine import Document, ListField, Q, QuerySet, StringField, signals from nomad.app.v1.models.groups import UserGroupQuery from nomad.utils import create_uuid @@ -28,13 +28,17 @@ from nomad.utils import create_uuid class MongoUserGroup(Document): """ - A group of users. One user is the owner, all others are members. + A group of users. Members are users, one of them is the owner. """ id_field = 'group_id' group_id = StringField(primary_key=True) group_name = StringField() + + # owner was previously not in members, now it should be + # it's enforced by calling clean() when saving, or instantiating the object + # but for filtering it must still be dealt with separately owner = StringField(required=True) members = ListField(StringField()) @@ -88,10 +92,11 @@ class MongoUserGroup(Document): if query.search_terms is not None: q &= cls.q_by_search_terms(query.search_terms) - return cls.objects(q) + groups = cls.objects(q) # pylint: disable=no-member + return groups @classmethod - def get_ids_by_user_id(cls, user_id: str | None, include_all=True) -> list[str]: + def get_ids_by_user_id(cls, user_id: str | None, *, include_all=True) -> list[str]: """ Returns ids of all user groups where user_id is owner or member. @@ -100,53 +105,72 @@ class MongoUserGroup(Document): """ group_ids = ['all'] if include_all else [] if user_id is not None: - q = cls.q_by_user_id(user_id) - groups = cls.objects(q) + query = UserGroupQuery(user_id=user_id) + groups = cls.get_by_query(query) group_ids.extend(group.group_id for group in groups) return group_ids + def clean(self): + """Add owner to members on clean.""" + self.members = list(set(self.members)) + if self.owner is not None and self.owner not in self.members: + self.members.append(self.owner) + super().clean() + + def clean_update_reload(self, **updates): + """Returns updated group after cleaning, validating, and saving to the DB. + + Use this instead of `update` or `modify` to ensure the object is cleaned. + """ + for k, v in updates.items(): + setattr(self, k, v) + return self.save() + + @classmethod + def _post_init_clean(cls, sender, document, **kwargs): # pylint: disable=unused-argument + """Clean document on retrieval.""" + if getattr(document, '_clean_on_init', True): + document.clean() + + # pylint: disable=protected-access + def reload_without_clean(self, *args, **kwargs): + """Reload document from database without running post_init.""" + func = self.__class__._post_init_clean + signals.post_init.disconnect(func, sender=self.__class__) + try: + return self.reload(*args, **kwargs) + finally: + signals.post_init.connect(func, sender=self.__class__) + -def create_user_group( +signals.post_init.connect(MongoUserGroup._post_init_clean, sender=MongoUserGroup) # pylint: disable=protected-access + + +def create_mongo_user_group( *, group_id: str | None = None, group_name: str | None = None, owner: str | None = None, members: Iterable[str] | None = None, + _clean: bool = True, ) -> MongoUserGroup: - user_group = MongoUserGroup( - group_id=group_id, group_name=group_name, owner=owner, members=members - ) + user_group = MongoUserGroup(group_id=group_id, group_name=group_name, owner=owner) + user_group._clean_on_init = _clean # pylint: disable=attribute-defined-outside-init protected-access + user_group.members = members if user_group.group_id is None: user_group.group_id = create_uuid() if user_group.group_name is None: user_group.group_name = user_group.group_id - user_group.save() + user_group.save(clean=_clean) return user_group -def get_user_ids_by_group_ids(group_ids: list[str]) -> set[str]: - user_ids = set() - - q = MongoUserGroup.q_by_ids(group_ids) - groups = MongoUserGroup.objects(q) - for group in groups: - user_ids.add(group.owner) - user_ids.update(group.members) - - return user_ids - - -def get_user_group(group_id: str) -> MongoUserGroup | None: - q = MongoUserGroup.q_by_ids(group_id) - return MongoUserGroup.objects(q).first() +def get_mongo_user_group(group_id: str) -> MongoUserGroup | None: + return MongoUserGroup.objects(group_id=group_id).first() # pylint: disable=no-member def user_group_exists(group_id: str, *, include_all=True) -> bool: if include_all and group_id == 'all': return True - return get_user_group(group_id) is not None - - -def get_group_ids(user_id: str, include_all=True) -> list[str]: - return MongoUserGroup.get_ids_by_user_id(user_id, include_all=include_all) + return get_mongo_user_group(group_id) is not None diff --git a/nomad/processing/data.py b/nomad/processing/data.py index 028f4ab1a88e012f61ee52a82e0051a0af5fd132..6fdfa7efc0a0823cb13795e5020d42d41a66cb6a 100644 --- a/nomad/processing/data.py +++ b/nomad/processing/data.py @@ -93,7 +93,7 @@ from nomad.files import ( UploadFiles, create_tmp_dir, ) -from nomad.groups import get_group_ids, user_group_exists +from nomad.groups import MongoUserGroup, user_group_exists from nomad.metainfo.data_type import Datatype, Datetime from nomad.normalizing import normalizers from nomad.parsing import Parser @@ -383,7 +383,7 @@ class MetadataEditRequestHandler: return self._error('No matching upload found', 'upload_id') is_admin = self.user.is_admin - group_ids = get_group_ids(self.user.user_id) + group_ids = MongoUserGroup.get_ids_by_user_id(self.user.user_id) for upload in self.affected_uploads: is_main_author = self.user.user_id == upload.main_author is_coauthor = self.user.user_id in upload.coauthors or ( diff --git a/pyproject.toml b/pyproject.toml index 910a96621ca0fbbcdeef2f9f804b1b40de813bc4..127747787c820c4b6bfafa60a54d05f8e10fa020 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,6 +58,7 @@ repository = 'https://gitlab.mpcdf.mpg.de/nomad-lab/nomad-FAIR' parsing = [] infrastructure = [ 'beautifulsoup4>=4,<=4.12.3', # 4.13 introduced breaking changes + 'blinker', # mongoengine needs this for signals 'celery>=5', 'dockerspawner', 'fastapi', diff --git a/requirements-dev.txt b/requirements-dev.txt index 9c01844f5098e0f2fe4cccd5798b23b05e740135..7118a0dfeca4b06d5adc16c267874d985703ee4f 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -17,6 +17,7 @@ backrefs==5.8 # via mkdocs-material beautifulsoup4==4.12.3 # via -r requirements.txt, nomad-lab (pyproject.toml) billiard==4.2.1 # via celery, -r requirements.txt bitarray==3.1.1 # via -r requirements.txt, nomad-lab (pyproject.toml) +blinker==1.9.0 # via -r requirements.txt, nomad-lab (pyproject.toml) cachetools==5.5.2 # via -r requirements.txt, nomad-lab (pyproject.toml) celery==5.4.0 # via -r requirements.txt, nomad-lab (pyproject.toml) certifi==2025.1.31 # via elasticsearch, httpcore, httpx, requests, -r requirements.txt diff --git a/requirements.txt b/requirements.txt index 01a9f6b837d0062819287501929916fa8281cf57..33eacb409020041901310aa7a8f8dadb386b9647 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,6 +12,7 @@ attrs==25.3.0 # via jsonschema, referencing beautifulsoup4==4.12.3 # via nomad-lab (pyproject.toml) billiard==4.2.1 # via celery bitarray==3.1.1 # via nomad-lab (pyproject.toml) +blinker==1.9.0 # via nomad-lab (pyproject.toml) cachetools==5.5.2 # via nomad-lab (pyproject.toml) celery==5.4.0 # via nomad-lab (pyproject.toml) certifi==2025.1.31 # via elasticsearch, httpcore, httpx, requests diff --git a/tests/app/v1/routers/test_groups.py b/tests/app/v1/routers/test_groups.py index 3b24ceea50396bbfc026f2a4c22d07c3dc5a2afc..2073b21104e658ed954e1103c5780a76e2df476f 100644 --- a/tests/app/v1/routers/test_groups.py +++ b/tests/app/v1/routers/test_groups.py @@ -1,7 +1,8 @@ import pytest from nomad.app.v1.models.groups import UserGroup, UserGroupResponse -from nomad.groups import MongoUserGroup, get_user_group, user_group_exists +from nomad.groups import MongoUserGroup, get_mongo_user_group, user_group_exists +from tests.utils import check_with_retry from .common import assert_response, perform_get, perform_post @@ -34,6 +35,9 @@ def assert_group(group, ref_group, keys=None): assert_unordered_lists(val, ref_val) +# tests using group fixtures with scope: 'module' + + def test_group_collection_name(groups_module): MongoUserGroup._get_collection_name() == 'user_group' @@ -59,7 +63,7 @@ def test_get_groups( response_groups = UserGroupResponse.parse_raw(response.content) for response_group in response_groups.data: - group = get_user_group(response_group.group_id) + group = get_mongo_user_group(response_group.group_id) assert_group(group, response_group) @@ -122,7 +126,7 @@ def test_get_filtered_groups( assert_unordered_lists(response_ids, ref_group_ids) for response_group in response_groups.data: - group = get_user_group(response_group.group_id) + group = get_mongo_user_group(response_group.group_id) assert_group(group, response_group) @@ -149,7 +153,7 @@ def test_get_group( assert_response(response, expected_status_code) response_group = UserGroup.parse_raw(response.content) - group = get_user_group(response_group.group_id) + group = get_mongo_user_group(response_group.group_id) assert_group(group, response_group) assert_group(group, ref_group) @@ -176,11 +180,58 @@ def test_get_group_invalid( assert_response(response, expected_status_code) +# tests using group fixtures with scope: 'function' (default) + + +def test_owner_not_member(auth_headers, client, group_molds, group_owner_not_member): + ref_group = group_molds['owner_not_member_ref'] + + group = group_owner_not_member + assert group.owner not in group.members + group.clean() + assert group.owner in group.members + group.reload_without_clean() + assert group.owner not in group.members + group.reload() + assert group.owner in group.members + group.reload_without_clean() + assert group.owner not in group.members + + # GET returns cleaned group but does not change db + url = f'{base_url}/{group.group_id}' + response = perform_get(client, url, auth_headers['user1']) + assert_response(response, 200) + + response_group = UserGroup.parse_raw(response.content) + assert_group(response_group, ref_group, ref_group.keys()) + + group.reload_without_clean() + assert group.owner not in group.members + + # POST cleans group in db and returns it + url = f'{base_url}/{group.group_id}/edit' + group_edit = {'group_name': group.group_name} + response = perform_post(client, url, auth_headers['user1'], json=group_edit) + assert_response(response, 200) + + response_group = UserGroup.parse_raw(response.content) + assert_group(response_group, ref_group, ref_group.keys()) + + assert group.owner not in group.members + + # cleaned group has been saved but db might not be updated yet + def condition(): + group.reload_without_clean() + return group.owner in group.members + + assert check_with_retry(condition) + + @pytest.mark.parametrize( 'user_label, new_group_label, ref_group_label, expected_status_code', [ - pytest.param('user1', 'new_group', 'new_group', 201, id='user1'), - pytest.param('user2', 'new_group', 'new_group', 201, id='user2'), + pytest.param('user1', 'new_group', 'new_group_ref1', 201, id='user1'), + pytest.param('user2', 'new_group', 'new_group_ref2', 201, id='user2'), pytest.param('invalid', 'new_group', None, 401, id='invalid-user'), pytest.param(None, 'new_group', None, 401, id='guest-user'), pytest.param('user1', 'short_name', None, 422, id='short-name-fails'), @@ -214,7 +265,7 @@ def test_create_group( return response_group = UserGroup.parse_raw(response.content) - group = get_user_group(response_group.group_id) + group = get_mongo_user_group(response_group.group_id) assert_group(group, response_group) ref_group = group_molds[ref_group_label] assert_group(group, ref_group, ref_group.keys()) @@ -226,7 +277,7 @@ def test_create_group( pytest.param(None, 'new_group', None, 401, id='guest-fails'), pytest.param('invalid', 'new_group', None, 401, id='faker-fails'), pytest.param('user2', 'new_group', None, 401, id='user2-fails'), - pytest.param('user1', 'new_group', 'new_group', 200, id='edit-ok'), + pytest.param('user1', 'new_group', 'new_group_ref1', 200, id='edit-ok'), pytest.param('user1', 'short_name', None, 422, id='short-name-fails'), pytest.param('user1', 'long_name', None, 422, id='long-name-fails'), pytest.param('user1', 'special_char', None, 422, id='special-chars-fails'), @@ -249,13 +300,13 @@ def test_update_user_group( ref_group_label, expected_status_code, ): - group_before = get_user_group(groups_function['group1'].group_id) + group_before = get_mongo_user_group(groups_function['group1'].group_id) group_edit = group_molds[group_edit_label] url = f'{base_url}/{group_before.group_id}/edit' response = perform_post(client, url, auth_headers[user_label], json=group_edit) assert_response(response, expected_status_code) - group_after = get_user_group(group_before.group_id) + group_after = get_mongo_user_group(group_before.group_id) if response.status_code != 200: assert_group(group_after, group_before) diff --git a/tests/fixtures/groups.py b/tests/fixtures/groups.py index f91bc50439ff2ce67e75a5fdb96de401c2b1f301..03e1b16175fc83fc305efc5785b4c06d20974801 100644 --- a/tests/fixtures/groups.py +++ b/tests/fixtures/groups.py @@ -7,7 +7,7 @@ Group fixtures: import pytest -from nomad.groups import create_user_group +from nomad.groups import create_mongo_user_group from tests.utils import fake_group_uuid, fake_user_uuid, generate_convert_label @@ -15,9 +15,9 @@ from tests.utils import fake_group_uuid, fake_user_uuid, generate_convert_label def group_molds(): """Return a dict: group label -> group data (dict).""" - def old_group(owner, members, group_str=None): + def default_group(owner, members, group_str=None): if group_str is None: - group_str = str(owner) + ''.join(str(m) for m in members) + group_str = str(owner) + ''.join(str(m) for m in members if m != owner) return dict( group_id=fake_group_uuid(group_str), @@ -26,41 +26,48 @@ def group_molds(): members=[fake_user_uuid(member) for member in members], ) - def new_group(group_name, members): - return dict( + def custom_group(group_name, members, owner=None): + mold = dict( group_name=group_name, members=[fake_user_uuid(member) for member in members], ) - - old_groups = { - 'group0': old_group(0, []), - 'group1': old_group(1, []), - 'group2': old_group(2, []), - 'group3': old_group(3, []), - 'group6': old_group(6, []), - 'group8': old_group(8, []), - 'group9': old_group(9, []), - 'group14': old_group(1, [4]), - 'group15': old_group(1, [5]), - 'group18': old_group(1, [8]), - 'group19': old_group(1, [9]), - 'group123': old_group(1, [2, 3]), - 'uniq': old_group(0, [], 'Uniq'), - 'twin1': old_group(0, [], 'Twin One'), - 'twin2': old_group(0, [], 'Twin Two'), - 'numerals': old_group(0, [], 'One Two Three'), + if owner is not None: + mold['owner'] = fake_user_uuid(owner) + return mold + + default_groups = { + 'group0': default_group(0, [0]), + 'group1': default_group(1, [1]), + 'group2': default_group(2, [2]), + 'group3': default_group(3, [3]), + 'group6': default_group(6, [6]), + 'group8': default_group(8, [8]), + 'group9': default_group(9, [9]), + 'group14': default_group(1, [1, 4]), + 'group15': default_group(1, [1, 5]), + 'group18': default_group(1, [1, 8]), + 'group19': default_group(1, [1, 9]), + 'group123': default_group(1, [1, 2, 3]), + 'uniq': default_group(0, [0], 'Uniq'), + 'twin1': default_group(0, [0], 'Twin One'), + 'twin2': default_group(0, [0], 'Twin Two'), + 'numerals': default_group(0, [0], 'One Two Three'), } - new_groups = { - 'new_group': new_group('New Group X23', [2, 3]), - 'short_name': new_group('GG', []), - 'long_name': new_group('G' * 33, []), - 'double_member': new_group('Double Member', [2, 3, 2]), - 'double_member_ref': new_group('Double Member', [2, 3]), - 'special_char': new_group('G!G', []), + custom_groups = { + 'new_group': custom_group('New Group X23', [2, 3]), + 'new_group_ref1': custom_group('New Group X23', [1, 2, 3]), + 'new_group_ref2': custom_group('New Group X23', [2, 3]), + 'short_name': custom_group('GG', []), + 'long_name': custom_group('G' * 33, []), + 'double_member': custom_group('Double Member', [2, 3, 2]), + 'double_member_ref': custom_group('Double Member', [1, 2, 3]), + 'special_char': custom_group('G!G', []), + 'owner_not_member': custom_group('Owner Not Member', [2, 3], owner=1), + 'owner_not_member_ref': custom_group('Owner Not Member', [1, 2, 3], owner=1), } - return {**old_groups, **new_groups} + return {**default_groups, **custom_groups} @pytest.fixture(scope='session') @@ -90,7 +97,9 @@ def create_user_groups(group_molds): def create(): groups_with_id = {k: v for k, v in group_molds.items() if 'group_id' in v} - user_groups = {k: create_user_group(**v) for k, v in groups_with_id.items()} + user_groups = { + k: create_mongo_user_group(**v) for k, v in groups_with_id.items() + } return user_groups @@ -107,3 +116,11 @@ def groups_module(mongo_module, create_user_groups): def groups_function(mongo_function, create_user_groups): """Create and return predefined user groups for testing (function scope).""" return create_user_groups() + + +@pytest.fixture +def group_owner_not_member(mongo_function, group_molds): + """Create and return a group where owner is not a member (old behavior).""" + mold = group_molds['owner_not_member'] + group = create_mongo_user_group(**mold, _clean=False) + return group diff --git a/tests/graph/test_graph_reader.py b/tests/graph/test_graph_reader.py index d6f9d75cd2a38a1a978894c08620a68dad10b005..ad033aec9c837fda3d1e9f9784d4b8a750f62803 100644 --- a/tests/graph/test_graph_reader.py +++ b/tests/graph/test_graph_reader.py @@ -35,6 +35,7 @@ from nomad.graph.graph_reader import ( from nomad.graph.lazy_wrapper import LazyWrapper from nomad.utils.exampledata import ExampleData from tests.normalizing.conftest import simulationworkflowschema +from tests.utils import ListWithSortKey def rprint(msg): @@ -54,9 +55,12 @@ def assert_time(i, j): assert i == j -def assert_list(l1, l2): - assert len(l1) == len(l2) - for i, j in zip(l1, l2): +def assert_list(observed, expected): + assert len(observed) == len(expected) + if isinstance(expected, ListWithSortKey): + observed = sorted(observed, key=expected.sort_key) + expected = sorted(expected, key=expected.sort_key) + for i, j in zip(observed, expected): if isinstance(i, LazyWrapper): i = i.to_json() if isinstance(i, dict): @@ -67,27 +71,27 @@ def assert_list(l1, l2): assert_time(i, j) -def assert_dict(d1, d2): - if GeneralReader.__CACHE__ in d1: - del d1[GeneralReader.__CACHE__] - if 'm_response' in d1: - del d1['m_response'] - if 'm_def' in d1: - del d1['m_def'] - if 'm_def' in d2: - del d2['m_def'] - assert set(d1.keys()) == set(d2.keys()) - for k, v in d1.items(): +def assert_dict(observed, expected): + if GeneralReader.__CACHE__ in observed: + del observed[GeneralReader.__CACHE__] + if 'm_response' in observed: + del observed['m_response'] + if 'm_def' in observed: + del observed['m_def'] + if 'm_def' in expected: + del expected['m_def'] + assert set(observed.keys()) == set(expected.keys()) + for k, v in observed.items(): if isinstance(v, LazyWrapper): v = v.to_json() if isinstance(v, dict): - assert_dict(v, d2[k]) + assert_dict(v, expected[k]) elif isinstance(v, list): - assert_list(v, d2[k]) + assert_list(v, expected[k]) elif k == 'upload_files_server_path': continue else: - assert_time(v, d2[k]) + assert_time(v, expected[k]) user_dict = { @@ -2189,17 +2193,30 @@ def test_group_reader(groups_function, user1): 'is_admin': False, 'is_oasis_admin': True, }, - 'members': [ - { - 'name': 'Rajesh Koothrappali', - 'first_name': 'Rajesh', - 'last_name': 'Koothrappali', - 'email': 'rajesh.koothrappali@nomad-fairdi.tests.de', - 'user_id': '00000000-0000-0000-0000-000000000004', - 'username': 'rkoothrappali', - 'is_admin': False, - } - ], + 'members': ListWithSortKey( + ( + { + 'name': 'Rajesh Koothrappali', + 'first_name': 'Rajesh', + 'last_name': 'Koothrappali', + 'email': 'rajesh.koothrappali@nomad-fairdi.tests.de', + 'user_id': '00000000-0000-0000-0000-000000000004', + 'username': 'rkoothrappali', + 'is_admin': False, + }, + { + 'name': 'Sheldon Cooper', + 'first_name': 'Sheldon', + 'last_name': 'Cooper', + 'email': 'sheldon.cooper@nomad-coe.eu', + 'user_id': '00000000-0000-0000-0000-000000000001', + 'username': 'scooper', + 'is_admin': False, + 'is_oasis_admin': True, + }, + ), + sort_key=lambda x: x['user_id'], + ), } } }, diff --git a/tests/states/groups.py b/tests/states/groups.py index 3dae035171a6b35a4e1b9bbc6de119108c486eca..1f78b5fa856feb82c7045c5a0e2b54d0eb1c81b6 100644 --- a/tests/states/groups.py +++ b/tests/states/groups.py @@ -1,16 +1,16 @@ from nomad import infrastructure -from nomad.groups import create_user_group, get_user_group +from nomad.groups import create_mongo_user_group, get_mongo_user_group def _create(group_id, group_name, owner, members=None): members = members or [] - return create_user_group( + return create_mongo_user_group( group_id=group_id, group_name=group_name, owner=owner, members=members ) def delete_group(group_id): - get_user_group(group_id).delete() + get_mongo_user_group(group_id).delete() def init_gui_test_groups(): diff --git a/tests/utils.py b/tests/utils.py index c24fe475952acda078d98d74bebfc980be751012..18f8c6db63197c36c7b3be4d4b0387785ea97556 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -19,6 +19,7 @@ """Methods to help with testing of nomad@FAIRDI.""" import os.path +import time import urllib.parse import zipfile from logging import LogRecord @@ -27,6 +28,17 @@ from typing import Any import pytest +class ListWithSortKey(list): + """List with an attribute `sort_key`. + + Use to sort two lists by the same key with sorted(). + """ + + def __init__(self, iterable=(), sort_key=None): + super().__init__(iterable) + self.sort_key = sort_key + + def assert_log( log_output, level: str, event_part: str, negate: bool = False ) -> LogRecord: @@ -191,3 +203,13 @@ def dict_to_params(d): Can be used to make the parametrize decorator more concise.""" return [pytest.param(*item, id=id) for id, item in d.items()] + + +def check_with_retry(condition_func, retries=5, delay0=0.1): + for attempt in range(retries): + if condition_func(): + return True + + time.sleep(delay0 * (attempt + 1)) + + return condition_func()