From 9ab0794ed65bdfdebc1db971de683eef156a4d20 Mon Sep 17 00:00:00 2001
From: Sascha Klawohn <sascha.klawohn@physik.hu-berlin.de>
Date: Mon, 15 Jan 2024 10:53:40 +0100
Subject: [PATCH] Add test 'add groups to upload'

---
 tests/app/v1/routers/test_groups.py           |  2 +-
 tests/app/v1/routers/uploads/common.py        |  2 +-
 .../v1/routers/uploads/test_group_uploads.py  | 67 ++++++++++++++++++-
 tests/conftest.py                             | 17 ++++-
 4 files changed, 82 insertions(+), 6 deletions(-)

diff --git a/tests/app/v1/routers/test_groups.py b/tests/app/v1/routers/test_groups.py
index c0d969495e..e444152d2e 100644
--- a/tests/app/v1/routers/test_groups.py
+++ b/tests/app/v1/routers/test_groups.py
@@ -38,7 +38,7 @@ def test_get_groups(
     assert_response(response, expected_status_code)
 
     groups = UserGroups.parse_raw(response.content)
-    for group, ex_group in zip(groups.data, user_groups_module):
+    for group, ex_group in zip(groups.data, user_groups_module.values()):
         ex_group = UserGroup.from_orm(ex_group)
         assert group == ex_group
 
diff --git a/tests/app/v1/routers/uploads/common.py b/tests/app/v1/routers/uploads/common.py
index c915b21892..a3432721be 100644
--- a/tests/app/v1/routers/uploads/common.py
+++ b/tests/app/v1/routers/uploads/common.py
@@ -18,4 +18,4 @@ def assert_upload(response_json, **kwargs):
 
     for key, value in kwargs.items():
         assert data.get(key, None) == value
-    return data
\ No newline at end of file
+    return data
diff --git a/tests/app/v1/routers/uploads/test_group_uploads.py b/tests/app/v1/routers/uploads/test_group_uploads.py
index 74b35b8e14..b1de5eba45 100644
--- a/tests/app/v1/routers/uploads/test_group_uploads.py
+++ b/tests/app/v1/routers/uploads/test_group_uploads.py
@@ -1,7 +1,10 @@
 import pytest
-from ..common import assert_response, perform_get
+
+from nomad.processing.data import Upload
+from ..common import assert_response, perform_get, perform_post
 from .common import assert_upload
 
+
 @pytest.mark.parametrize(
     'kwargs',
     [
@@ -124,3 +127,65 @@ def test_get_group_upload(
     assert_response(response, expected_status_code)
     if expected_status_code == 200:
         assert_upload(response.json())
+
+
+@pytest.mark.parametrize(
+    'user, expected_status_code, group_quantity, new_groups',
+    [
+        pytest.param(
+            'test_user',
+            200,
+            'coauthor_groups',
+            ['other_owner_group'],
+            id='coauthor-other-group',
+        ),
+        pytest.param(
+            'test_user',
+            200,
+            'coauthor_groups',
+            ['user_owner_group', 'other_owner_group', 'mixed_group'],
+            id='coauthor-multiple-groups',
+        ),
+        pytest.param(
+            'test_user',
+            200,
+            'reviewer_groups',
+            ['other_owner_group'],
+            id='reviewer-other-group',
+        ),
+        pytest.param(
+            'other_test_user',
+            422,
+            'reviewer_groups',
+            ['other_owner_group'],
+            id='other-user-reviewer-other-group',
+        ),
+    ],
+)
+def test_add_groups_to_upload(
+    client,
+    user_groups_module,
+    proc_infra,
+    upload_no_group,
+    test_auth_dict,
+    user,
+    expected_status_code,
+    group_quantity,
+    new_groups,
+):
+    user_auth, __token = test_auth_dict[user]
+    upload_id = list(upload_no_group.uploads)[0]
+    new_group_ids = [user_groups_module[label].group_id for label in new_groups]
+
+    url = f'uploads/{upload_id}/edit'
+    metadata = {group_quantity: new_group_ids}
+    edit_request = dict(metadata=metadata)
+    response = perform_post(client, url, user_auth, json=edit_request)
+
+    assert_response(response, expected_status_code)
+    if expected_status_code != 200:
+        return
+
+    upload = Upload.get(upload_id)
+    upload.block_until_complete()
+    assert getattr(upload, group_quantity) == new_group_ids
diff --git a/tests/conftest.py b/tests/conftest.py
index 66975740b4..845dcc3b7e 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -531,10 +531,10 @@ def mixed_group():
 
 
 def _user_groups():
-    user_groups = []
-    for group in test_user_groups.values():
+    user_groups = {}
+    for label, group in test_user_groups.items():
         user_group = create_user_group(**group)
-        user_groups.append(user_group)
+        user_groups[label] = user_group
 
     return user_groups
 
@@ -1332,6 +1332,17 @@ def example_data_groups(
     data.save(with_files=False)
 
 
+@pytest.fixture(scope='function')
+def upload_no_group(mongo_function, test_user):
+    data = ExampleData(main_author=test_user)
+    data.create_upload(upload_id='id_no_group')
+    data.save()
+
+    yield data
+
+    data.delete()
+
+
 @pytest.fixture(scope='function')
 def example_datasets(mongo_function, test_user, other_test_user):
     dataset_specs = (
-- 
GitLab