From 1a8163281fb9c764938bd8fdd56a4804ce4fb7cd Mon Sep 17 00:00:00 2001
From: Sascha Klawohn <sascha.klawohn@physik.hu-berlin.de>
Date: Wed, 21 Feb 2024 16:20:10 +0100
Subject: [PATCH] Update group tests

---
 nomad/app/v1/routers/groups.py      |   8 +-
 nomad/groups.py                     |   6 +-
 nomad/processing/data.py            |   1 -
 tests/app/v1/routers/test_groups.py | 133 +++++++++++++++++-----------
 tests/conftest.py                   |  94 +++++++++++---------
 5 files changed, 144 insertions(+), 98 deletions(-)

diff --git a/nomad/app/v1/routers/groups.py b/nomad/app/v1/routers/groups.py
index 1c4dfebb01..b6b27d739e 100644
--- a/nomad/app/v1/routers/groups.py
+++ b/nomad/app/v1/routers/groups.py
@@ -24,7 +24,7 @@ from pydantic import BaseModel, Field
 from nomad.datamodel import User as UserDataModel
 from nomad.groups import (
     UserGroup as MongoUserGroup,
-    create_user_group as _create_user_group,
+    create_user_group as create_mongo_user_group,
 )
 from nomad.utils import strip
 
@@ -58,8 +58,8 @@ class UserGroup(BaseModel):
         default='Default Group Name', description=group_name_description
     )
     owner: str = Field(description='User id of the group owner.')
-    members: Set[str] = Field(
-        default_factory=set, description=group_members_description
+    members: List[str] = Field(
+        default_factory=list, description=group_members_description
     )
 
     class Config:
@@ -147,7 +147,7 @@ async def create_user_group(
         check_user_ids(members)
     user_group_dict['owner'] = user.user_id
 
-    user_group = _create_user_group(**user_group_dict)
+    user_group = create_mongo_user_group(**user_group_dict)
     return user_group
 
 
diff --git a/nomad/groups.py b/nomad/groups.py
index 58b6b5cbed..6487abeff2 100644
--- a/nomad/groups.py
+++ b/nomad/groups.py
@@ -48,8 +48,10 @@ class UserGroup(Document):
 
     @classmethod
     def get_ids_by_user_id(cls, user_id: Optional[str]) -> List[str]:
-        """Returns ids of all user groups where user_id is owner or member.
-        Does include special group 'all', even if user_id is missing or not a user."""
+        """
+        Returns ids of all user groups where user_id is owner or member.
+        Does include special group 'all', even if user_id is missing or not a user.
+        """
         group_ids = ['all']
         if user_id is not None:
             group_ids.extend(group.group_id for group in cls.get_by_user_id(user_id))
diff --git a/nomad/processing/data.py b/nomad/processing/data.py
index 0a715b0e51..d38cc0019e 100644
--- a/nomad/processing/data.py
+++ b/nomad/processing/data.py
@@ -85,7 +85,6 @@ from nomad.files import (
     create_tmp_dir,
     is_safe_relative_path,
 )
-from nomad.groups import UserGroup, get_user_ids_by_group_ids
 from nomad.processing.base import (
     Proc,
     process,
diff --git a/tests/app/v1/routers/test_groups.py b/tests/app/v1/routers/test_groups.py
index 3b710bd3b3..37e61240b2 100644
--- a/tests/app/v1/routers/test_groups.py
+++ b/tests/app/v1/routers/test_groups.py
@@ -7,12 +7,30 @@ from nomad.groups import get_user_group, user_group_exists
 base_url = 'groups'
 
 
-@pytest.fixture
-def new_group(test_user, other_test_user):
-    return {
-        'group_name': 'New Group',
-        'members': [test_user.user_id, other_test_user.user_id],
-    }
+def get_val(obj, key):
+    if isinstance(obj, dict):
+        return obj[key]
+
+    return getattr(obj, key)
+
+
+def assert_unordered_lists(list1, list2):
+    assert sorted(list1) == sorted(list2)
+
+
+def assert_group(group, ref_group, keys=None):
+    if keys is None:
+        keys = UserGroup.__fields__
+
+    excluded_fields = {'members'}
+    fields = set(keys) - excluded_fields
+    for field in fields:
+        assert get_val(group, field) == get_val(ref_group, field)
+
+    if 'members' in keys:
+        val = get_val(group, 'members')
+        ref_val = get_val(ref_group, 'members')
+        assert_unordered_lists(val, ref_val)
 
 
 @pytest.mark.parametrize(
@@ -27,20 +45,20 @@ def new_group(test_user, other_test_user):
 def test_get_groups(
     client,
     mongo_module,
-    user_label,
     test_auth_dict,
     user_groups_module,
+    user_label,
     expected_status_code,
 ):
-    user_auth, __token = test_auth_dict[user_label]
+    user_auth, _ = test_auth_dict[user_label]
 
     response = perform_get(client, base_url, user_auth)
     assert_response(response, expected_status_code)
 
-    groups = UserGroups.parse_raw(response.content)
-    for group, ex_group in zip(groups.data, user_groups_module.values()):
-        ex_group = UserGroup.from_orm(ex_group)
-        assert group == ex_group
+    response_groups = UserGroups.parse_raw(response.content)
+    for response_group in response_groups.data:
+        group = get_user_group(response_group.group_id)
+        assert_group(group, response_group)
 
 
 @pytest.mark.parametrize(
@@ -54,20 +72,21 @@ def test_get_groups(
 )
 def test_get_group(
     client,
-    user_label,
     test_auth_dict,
     user_groups_module,
-    user_owner_group,
+    user_label,
     expected_status_code,
 ):
-    user_auth, __token = test_auth_dict[user_label]
+    user_auth, _ = test_auth_dict[user_label]
+    ref_group = user_groups_module['other_owner_group']
 
-    response = perform_get(client, f'{base_url}/{user_owner_group.group_id}', user_auth)
+    response = perform_get(client, f'{base_url}/{ref_group.group_id}', user_auth)
     assert_response(response, expected_status_code)
 
-    group = UserGroup.parse_raw(response.content)
-    ex_group = UserGroup.from_orm(user_owner_group)
-    assert group == ex_group
+    response_group = UserGroup.parse_raw(response.content)
+    group = get_user_group(response_group.group_id)
+    assert_group(group, response_group)
+    assert_group(group, ref_group)
 
 
 @pytest.mark.parametrize(
@@ -82,37 +101,52 @@ def test_get_group(
 def test_get_group_invalid(
     client,
     mongo_module,
-    user_label,
     test_auth_dict,
+    user_groups_module,
+    user_label,
     expected_status_code,
 ):
-    user_auth, __token = test_auth_dict[user_label]
+    user_auth, _ = test_auth_dict[user_label]
+
     response = perform_get(client, f'{base_url}/invalid-group-id', user_auth)
     assert_response(response, expected_status_code)
 
 
 @pytest.mark.parametrize(
-    'user_label, expected_status_code',
+    'user_label, new_group_label, ref_group_label, expected_status_code',
     [
-        pytest.param('test_user', 201, id='test-user'),
-        pytest.param('other_test_user', 201, id='other-test-user'),
-        pytest.param('invalid', 401, id='invalid-user'),
-        pytest.param(None, 401, id='guest-user'),
+        pytest.param('test_user', 'new_group', 'new_group', 201, id='test-user'),
+        pytest.param(
+            'other_test_user', 'new_group', 'new_group', 201, id='other-test-user'
+        ),
+        pytest.param('invalid', 'new_group', None, 401, id='invalid-user'),
+        pytest.param(None, 'new_group', None, 401, id='guest-user'),
     ],
 )
 def test_create_group(
-    client, mongo_function, user_label, test_auth_dict, new_group, expected_status_code
+    client,
+    mongo_function,
+    request,
+    test_auth_dict,
+    test_user_groups_dict,
+    user_label,
+    new_group_label,
+    ref_group_label,
+    expected_status_code,
 ):
-    user_auth, __token = test_auth_dict[user_label]
+    user_auth, _ = test_auth_dict[user_label]
+    new_group = test_user_groups_dict[new_group_label]
+
     response = perform_post(client, base_url, user_auth, json=new_group)
     assert_response(response, expected_status_code)
 
     if response.status_code != 201:
         return
 
-    groups = UserGroup.parse_raw(response.content)
-    assert groups.group_name == new_group['group_name']
-    assert set(groups.members) == set(new_group['members'])
+    response_group = UserGroup.parse_raw(response.content)
+    assert_group(response_group, new_group, new_group.keys())
+    group = get_user_group(response_group.group_id)
+    assert_group(group, response_group)
 
 
 @pytest.mark.parametrize(
@@ -130,33 +164,32 @@ def test_create_group(
 def test_update_user_group(
     client,
     mongo_function,
-    user_label,
     test_auth_dict,
+    test_user_groups_dict,
     user_groups_function,
     user_owner_group,
-    new_group,
+    user_label,
     variation,
     expected_status_code,
 ):
-    user_auth, __token = test_auth_dict[user_label]
+    user_auth, _ = test_auth_dict[user_label]
     group_before = get_user_group(user_owner_group.group_id)
-    new_group.update(variation)
+    group_edit = test_user_groups_dict['new_group']
+    group_edit.update(variation)
 
-    response = perform_post(
-        client,
-        f'{base_url}/{user_owner_group.group_id}/edit',
-        user_auth,
-        json=new_group,
-    )
+    url = f'{base_url}/{group_before.group_id}/edit'
+    response = perform_post(client, url, user_auth, json=group_edit)
     assert_response(response, expected_status_code)
-    group_after = get_user_group(user_owner_group.group_id)
+    group_after = get_user_group(group_before.group_id)
 
     if response.status_code != 200:
-        assert group_before == group_after
+        assert_group(group_after, group_before)
         return
 
-    assert group_after.group_name == new_group['group_name']
-    assert set(group_after.members) == set(new_group['members'])
+    keys = group_edit.keys()
+    assert_group(group_after, group_edit, keys)
+    keys = group_after._fields - group_edit.keys()
+    assert_group(group_after, group_before, keys)
 
 
 @pytest.mark.parametrize(
@@ -170,14 +203,14 @@ def test_update_user_group(
 )
 def test_delete_group(
     client,
-    user_label,
+    other_owner_group,
     test_auth_dict,
     user_groups_function,
     user_owner_group,
-    other_owner_group,
+    user_label,
     expected_status_code,
 ):
-    user_auth, __token = test_auth_dict[user_label]
+    user_auth, _ = test_auth_dict[user_label]
 
     response = client.delete(
         f'{base_url}/{user_owner_group.group_id}', headers=user_auth
@@ -203,13 +236,13 @@ def test_delete_group(
 )
 def test_delete_group_invalid(
     client,
-    user_label,
     test_auth_dict,
     user_groups_function,
     user_owner_group,
+    user_label,
     expected_status_code,
 ):
-    user_auth, __token = test_auth_dict[user_label]
+    user_auth, _ = test_auth_dict[user_label]
 
     response = client.delete(f'{base_url}/invalid-group-id', headers=user_auth)
     assert_response(response, expected_status_code)
diff --git a/tests/conftest.py b/tests/conftest.py
index bea9e5cc29..4ea3b749b0 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -490,34 +490,36 @@ def test_user_group_uuid(handle: Any):
     return str(handle).rjust(22, 'G')
 
 
-test_user_groups = {
-    'admin_owner_group': dict(
-        group_name='Admin Owner Group',
-        owner=test_user_uuid(0),
-        group_id=test_user_group_uuid(0),
-    ),
-    'user_owner_group': dict(
-        group_name='Test Owner Group',
-        owner=test_user_uuid(1),
-        group_id=test_user_group_uuid(1),
-    ),
-    'other_owner_group': dict(
-        group_name='Other Owner Group',
-        owner=test_user_uuid(2),
-        group_id=test_user_group_uuid(2),
-    ),
-    'mixed_group': dict(
-        group_name='Mixed Group',
-        owner=test_user_uuid(0),
-        members=[test_user_uuid(1), test_user_uuid(2)],
-        group_id=test_user_group_uuid(3),
-    ),
-}
+@pytest.fixture(scope='session')
+def test_user_groups_dict():
+    def old_group(group_id, group_name, owner, members):
+        return dict(
+            group_id=test_user_group_uuid(group_id),
+            group_name=group_name,
+            owner=test_user_uuid(owner),
+            members=[test_user_uuid(member) for member in members],
+        )
+
+    def new_group(group_name, members):
+        return dict(
+            group_name=group_name,
+            members=[test_user_uuid(member) for member in members],
+        )
+
+    return {
+        'admin_owner_group': old_group(0, 'Admin Owner Group', 0, []),
+        'user_owner_group': old_group(1, 'User Owner Group', 1, []),
+        'other_owner_group': old_group(2, 'Other Owner Group', 2, []),
+        'mixed_group': old_group(3, 'Mixed Group', 0, [1, 2]),
+        'new_group': new_group('New Group', [0, 2]),
+    }
 
 
 @pytest.fixture(scope='session')
-def convert_group_labels_to_ids():
-    mapping = {label: group['group_id'] for label, group in test_user_groups.items()}
+def convert_group_labels_to_ids(test_user_groups_dict):
+    mapping = {
+        label: group['group_id'] for label, group in test_user_groups_dict.items()
+    }
 
     def convert(raw):
         if isinstance(raw, str):
@@ -535,37 +537,47 @@ def convert_group_labels_to_ids():
 
 
 @pytest.fixture(scope='session')
-def user_owner_group():
-    return UserGroup(**test_user_groups['user_owner_group'])
+def user_owner_group(test_user_groups_dict):
+    return UserGroup(**test_user_groups_dict['user_owner_group'])
 
 
 @pytest.fixture(scope='session')
-def other_owner_group():
-    return UserGroup(**test_user_groups['other_owner_group'])
+def other_owner_group(test_user_groups_dict):
+    return UserGroup(**test_user_groups_dict['other_owner_group'])
 
 
 @pytest.fixture(scope='session')
-def mixed_group():
-    return UserGroup(**test_user_groups['mixed_group'])
+def mixed_group(test_user_groups_dict):
+    return UserGroup(**test_user_groups_dict['mixed_group'])
+
 
+@pytest.fixture(scope='session')
+def create_user_groups(test_user_groups_dict):
+    def create():
+        user_groups = {}
+        for label in [
+            'admin_owner_group',
+            'user_owner_group',
+            'other_owner_group',
+            'mixed_group',
+        ]:
+            group = test_user_groups_dict[label]
+            user_group = create_user_group(**group)
+            user_groups[label] = user_group
 
-def _user_groups():
-    user_groups = {}
-    for label, group in test_user_groups.items():
-        user_group = create_user_group(**group)
-        user_groups[label] = user_group
+        return user_groups
 
-    return user_groups
+    return create
 
 
 @pytest.fixture(scope='module')
-def user_groups_module(mongo_module):
-    return _user_groups()
+def user_groups_module(mongo_module, create_user_groups):
+    return create_user_groups()
 
 
 @pytest.fixture(scope='function')
-def user_groups_function(mongo_function):
-    return _user_groups()
+def user_groups_function(mongo_function, create_user_groups):
+    return create_user_groups()
 
 
 @pytest.fixture(scope='function')
-- 
GitLab