diff --git a/nomad/app/v1/routers/groups.py b/nomad/app/v1/routers/groups.py index b6b27d739e7aa756e1557437ac0e75cf03094fa2..0c6f24329f2605f44f356b949048eaa47666d9f3 100644 --- a/nomad/app/v1/routers/groups.py +++ b/nomad/app/v1/routers/groups.py @@ -18,19 +18,20 @@ from typing import List, Optional, Set -from fastapi import APIRouter, Depends, HTTPException, status +from fastapi import APIRouter, Depends, HTTPException, Query, status from pydantic import BaseModel, Field from nomad.datamodel import User as UserDataModel from nomad.groups import ( UserGroup as MongoUserGroup, +) +from nomad.groups import ( create_user_group as create_mongo_user_group, ) from nomad.utils import strip -from .auth import create_user_dependency from ..models import User - +from .auth import create_user_dependency router = APIRouter() default_tag = 'groups' @@ -75,7 +76,7 @@ def get_mongo_user_group(group_id: str) -> MongoUserGroup: if user_group is None: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=strip(f"User group '{group_id}' was not found."), + detail=f"User group '{group_id}' was not found.", ) return user_group @@ -107,11 +108,32 @@ def check_user_may_edit_user_group(user: User, user_group: MongoUserGroup): @router.get( - '', tags=[default_tag], summary='List user groups.', response_model=UserGroups + '', + tags=[default_tag], + summary='List user groups. Use at most one filter.', + response_model=UserGroups, ) -async def get_user_groups(): +async def get_user_groups( + group_id: Optional[List[str]] = Query( + None, description='Search groups by their full id.' + ), + search_terms: Optional[str] = Query( + None, description='Search groups by parts of their name.' + ), +): """Get data about user groups.""" - user_groups = MongoUserGroup.objects + if group_id is not None and search_terms is not None: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail='Only one of group_id or search_terms may be used at a time.', + ) + + if group_id is not None: + user_groups = MongoUserGroup.get_by_ids(group_id) + elif search_terms is not None: + user_groups = MongoUserGroup.get_by_search_terms(search_terms) + else: + user_groups = MongoUserGroup.objects data = [UserGroup.from_orm(user_group) for user_group in user_groups] return {'data': data} diff --git a/nomad/groups.py b/nomad/groups.py index 99a2085d2bfb42a9c671a7986f6c934c9c0d77c8..c0637226a440dbdcf019410791ffd80f381db9af 100644 --- a/nomad/groups.py +++ b/nomad/groups.py @@ -15,9 +15,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from typing import Iterable, List, Optional +import operator +from functools import reduce +from typing import Iterable, Optional, Union -from mongoengine import Document, StringField, ListField +from mongoengine import Document, ListField, StringField from mongoengine.queryset.visitor import Q from nomad.utils import create_uuid @@ -37,6 +39,16 @@ class UserGroup(Document): meta = {'indexes': ['group_name', 'owner', 'members']} + @classmethod + def get_by_ids(cls, group_ids: Union[str, Iterable[str]]): + """ + Returns UserGroup objects with group_ids. + """ + if not isinstance(group_ids, Iterable): + group_ids = [group_ids] + user_groups = cls.objects(group_id__in=group_ids) + return user_groups + @classmethod def get_by_user_id(cls, user_id: Optional[str]): """ @@ -48,7 +60,7 @@ class UserGroup(Document): return user_groups @classmethod - def get_ids_by_user_id(cls, user_id: Optional[str], include_all=True) -> List[str]: + def get_ids_by_user_id(cls, user_id: Optional[str], include_all=True) -> list[str]: """ Returns ids of all user groups where user_id is owner or member. Does include special group 'all', even if user_id is missing or not a user. @@ -58,6 +70,20 @@ class UserGroup(Document): group_ids.extend(group.group_id for group in cls.get_by_user_id(user_id)) return group_ids + @classmethod + def get_by_search_terms(cls, search_terms: str): + """ + Returns UserGroup objects where group_name includes search_terms (no case). + """ + split_terms = str(search_terms).split() + if not split_terms: + return [] + + query = (Q(group_name__icontains=term) for term in split_terms) + query = reduce(operator.and_, query) + user_groups = cls.objects(query) + return user_groups + def create_user_group( *, diff --git a/tests/app/v1/routers/test_groups.py b/tests/app/v1/routers/test_groups.py index 49b1fe93fa01eac120cca6f23aa9d2c7e7f21e5e..0e55c650570af89a28fc7a3611fc806578e1569e 100644 --- a/tests/app/v1/routers/test_groups.py +++ b/tests/app/v1/routers/test_groups.py @@ -58,6 +58,48 @@ def test_get_groups( assert_group(group, response_group) +@pytest.mark.parametrize( + 'filters, ref_group_labels', + [ + pytest.param({'group_id': ['group1']}, ['group1'], id='id'), + pytest.param( + {'group_id': ['group1', 'group2']}, ['group1', 'group2'], id='ids' + ), + pytest.param({'search_terms': 'Uniq'}, ['uniq'], id='uniq'), + pytest.param({'search_terms': 'iq'}, ['uniq'], id='uniq-partial'), + pytest.param({'search_terms': 'Twin'}, ['twin1', 'twin2'], id='twins'), + pytest.param({'search_terms': 'Twin One'}, ['twin1'], id='twin1'), + pytest.param( + {'search_terms': 'One'}, ['twin1', 'numerals'], id='twin1-numerals' + ), + pytest.param({'search_terms': 'One Two'}, ['numerals'], id='numerals'), + pytest.param( + {'search_terms': 'Tw'}, ['twin1', 'twin2', 'numerals'], id='tw-partial' + ), + ], +) +def test_get_filtered_groups( + auth_headers, + client, + convert_group_labels_to_ids, + groups_module, + filters, + ref_group_labels, +): + filters = convert_group_labels_to_ids(filters) + response = perform_get(client, base_url, auth_headers['user1'], **filters) + assert_response(response, 200) + + response_groups = UserGroups.parse_raw(response.content) + response_ids = [group.group_id for group in response_groups.data] + ref_group_ids = convert_group_labels_to_ids(ref_group_labels) + assert_unordered_lists(response_ids, ref_group_ids) + + for response_group in response_groups.data: + group = get_user_group(response_group.group_id) + assert_group(group, response_group) + + @pytest.mark.parametrize( 'user_label, expected_status_code', [ diff --git a/tests/fixtures/groups.py b/tests/fixtures/groups.py index 28352e974e9b15f6a00947a383d00f8f76c1be10..f91bc50439ff2ce67e75a5fdb96de401c2b1f301 100644 --- a/tests/fixtures/groups.py +++ b/tests/fixtures/groups.py @@ -15,8 +15,10 @@ from tests.utils import fake_group_uuid, fake_user_uuid, generate_convert_label def group_molds(): """Return a dict: group label -> group data (dict).""" - def old_group(owner, members): - group_str = str(owner) + ''.join(str(m) for m in members) + def old_group(owner, members, group_str=None): + if group_str is None: + group_str = str(owner) + ''.join(str(m) for m in members) + return dict( group_id=fake_group_uuid(group_str), group_name=f'Group {group_str}', @@ -43,6 +45,10 @@ def group_molds(): 'group18': old_group(1, [8]), 'group19': old_group(1, [9]), 'group123': old_group(1, [2, 3]), + 'uniq': old_group(0, [], 'Uniq'), + 'twin1': old_group(0, [], 'Twin One'), + 'twin2': old_group(0, [], 'Twin Two'), + 'numerals': old_group(0, [], 'One Two Three'), } new_groups = {