Skip to content
Snippets Groups Projects
Commit 90e3a6a4 authored by Sascha Klawohn's avatar Sascha Klawohn
Browse files

Merge branch 'get-groups-filters' into 'develop'

Get groups filters

See merge request !1979
parents c4541fba 57ddbbb0
No related branches found
No related tags found
1 merge request!1979Get groups filters
Pipeline #213565 passed
...@@ -18,19 +18,20 @@ ...@@ -18,19 +18,20 @@
from typing import List, Optional, Set from typing import List, Optional, Set
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, Query, status
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from nomad.datamodel import User as UserDataModel from nomad.datamodel import User as UserDataModel
from nomad.groups import ( from nomad.groups import (
UserGroup as MongoUserGroup, UserGroup as MongoUserGroup,
)
from nomad.groups import (
create_user_group as create_mongo_user_group, create_user_group as create_mongo_user_group,
) )
from nomad.utils import strip from nomad.utils import strip
from .auth import create_user_dependency
from ..models import User from ..models import User
from .auth import create_user_dependency
router = APIRouter() router = APIRouter()
default_tag = 'groups' default_tag = 'groups'
...@@ -75,7 +76,7 @@ def get_mongo_user_group(group_id: str) -> MongoUserGroup: ...@@ -75,7 +76,7 @@ def get_mongo_user_group(group_id: str) -> MongoUserGroup:
if user_group is None: if user_group is None:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
detail=strip(f"User group '{group_id}' was not found."), detail=f"User group '{group_id}' was not found.",
) )
return user_group return user_group
...@@ -107,11 +108,32 @@ def check_user_may_edit_user_group(user: User, user_group: MongoUserGroup): ...@@ -107,11 +108,32 @@ def check_user_may_edit_user_group(user: User, user_group: MongoUserGroup):
@router.get( @router.get(
'', tags=[default_tag], summary='List user groups.', response_model=UserGroups '',
tags=[default_tag],
summary='List user groups. Use at most one filter.',
response_model=UserGroups,
) )
async def get_user_groups(): async def get_user_groups(
group_id: Optional[List[str]] = Query(
None, description='Search groups by their full id.'
),
search_terms: Optional[str] = Query(
None, description='Search groups by parts of their name.'
),
):
"""Get data about user groups.""" """Get data about user groups."""
user_groups = MongoUserGroup.objects if group_id is not None and search_terms is not None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail='Only one of group_id or search_terms may be used at a time.',
)
if group_id is not None:
user_groups = MongoUserGroup.get_by_ids(group_id)
elif search_terms is not None:
user_groups = MongoUserGroup.get_by_search_terms(search_terms)
else:
user_groups = MongoUserGroup.objects
data = [UserGroup.from_orm(user_group) for user_group in user_groups] data = [UserGroup.from_orm(user_group) for user_group in user_groups]
return {'data': data} return {'data': data}
......
...@@ -15,9 +15,11 @@ ...@@ -15,9 +15,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
from typing import Iterable, List, Optional import operator
from functools import reduce
from typing import Iterable, Optional, Union
from mongoengine import Document, StringField, ListField from mongoengine import Document, ListField, StringField
from mongoengine.queryset.visitor import Q from mongoengine.queryset.visitor import Q
from nomad.utils import create_uuid from nomad.utils import create_uuid
...@@ -37,6 +39,16 @@ class UserGroup(Document): ...@@ -37,6 +39,16 @@ class UserGroup(Document):
meta = {'indexes': ['group_name', 'owner', 'members']} meta = {'indexes': ['group_name', 'owner', 'members']}
@classmethod
def get_by_ids(cls, group_ids: Union[str, Iterable[str]]):
"""
Returns UserGroup objects with group_ids.
"""
if not isinstance(group_ids, Iterable):
group_ids = [group_ids]
user_groups = cls.objects(group_id__in=group_ids)
return user_groups
@classmethod @classmethod
def get_by_user_id(cls, user_id: Optional[str]): def get_by_user_id(cls, user_id: Optional[str]):
""" """
...@@ -48,7 +60,7 @@ class UserGroup(Document): ...@@ -48,7 +60,7 @@ class UserGroup(Document):
return user_groups return user_groups
@classmethod @classmethod
def get_ids_by_user_id(cls, user_id: Optional[str], include_all=True) -> List[str]: def get_ids_by_user_id(cls, user_id: Optional[str], include_all=True) -> list[str]:
""" """
Returns ids of all user groups where user_id is owner or member. Returns ids of all user groups where user_id is owner or member.
Does include special group 'all', even if user_id is missing or not a user. Does include special group 'all', even if user_id is missing or not a user.
...@@ -58,6 +70,20 @@ class UserGroup(Document): ...@@ -58,6 +70,20 @@ class UserGroup(Document):
group_ids.extend(group.group_id for group in cls.get_by_user_id(user_id)) group_ids.extend(group.group_id for group in cls.get_by_user_id(user_id))
return group_ids return group_ids
@classmethod
def get_by_search_terms(cls, search_terms: str):
"""
Returns UserGroup objects where group_name includes search_terms (no case).
"""
split_terms = str(search_terms).split()
if not split_terms:
return []
query = (Q(group_name__icontains=term) for term in split_terms)
query = reduce(operator.and_, query)
user_groups = cls.objects(query)
return user_groups
def create_user_group( def create_user_group(
*, *,
......
...@@ -58,6 +58,48 @@ def test_get_groups( ...@@ -58,6 +58,48 @@ def test_get_groups(
assert_group(group, response_group) assert_group(group, response_group)
@pytest.mark.parametrize(
'filters, ref_group_labels',
[
pytest.param({'group_id': ['group1']}, ['group1'], id='id'),
pytest.param(
{'group_id': ['group1', 'group2']}, ['group1', 'group2'], id='ids'
),
pytest.param({'search_terms': 'Uniq'}, ['uniq'], id='uniq'),
pytest.param({'search_terms': 'iq'}, ['uniq'], id='uniq-partial'),
pytest.param({'search_terms': 'Twin'}, ['twin1', 'twin2'], id='twins'),
pytest.param({'search_terms': 'Twin One'}, ['twin1'], id='twin1'),
pytest.param(
{'search_terms': 'One'}, ['twin1', 'numerals'], id='twin1-numerals'
),
pytest.param({'search_terms': 'One Two'}, ['numerals'], id='numerals'),
pytest.param(
{'search_terms': 'Tw'}, ['twin1', 'twin2', 'numerals'], id='tw-partial'
),
],
)
def test_get_filtered_groups(
auth_headers,
client,
convert_group_labels_to_ids,
groups_module,
filters,
ref_group_labels,
):
filters = convert_group_labels_to_ids(filters)
response = perform_get(client, base_url, auth_headers['user1'], **filters)
assert_response(response, 200)
response_groups = UserGroups.parse_raw(response.content)
response_ids = [group.group_id for group in response_groups.data]
ref_group_ids = convert_group_labels_to_ids(ref_group_labels)
assert_unordered_lists(response_ids, ref_group_ids)
for response_group in response_groups.data:
group = get_user_group(response_group.group_id)
assert_group(group, response_group)
@pytest.mark.parametrize( @pytest.mark.parametrize(
'user_label, expected_status_code', 'user_label, expected_status_code',
[ [
......
...@@ -15,8 +15,10 @@ from tests.utils import fake_group_uuid, fake_user_uuid, generate_convert_label ...@@ -15,8 +15,10 @@ from tests.utils import fake_group_uuid, fake_user_uuid, generate_convert_label
def group_molds(): def group_molds():
"""Return a dict: group label -> group data (dict).""" """Return a dict: group label -> group data (dict)."""
def old_group(owner, members): def old_group(owner, members, group_str=None):
group_str = str(owner) + ''.join(str(m) for m in members) if group_str is None:
group_str = str(owner) + ''.join(str(m) for m in members)
return dict( return dict(
group_id=fake_group_uuid(group_str), group_id=fake_group_uuid(group_str),
group_name=f'Group {group_str}', group_name=f'Group {group_str}',
...@@ -43,6 +45,10 @@ def group_molds(): ...@@ -43,6 +45,10 @@ def group_molds():
'group18': old_group(1, [8]), 'group18': old_group(1, [8]),
'group19': old_group(1, [9]), 'group19': old_group(1, [9]),
'group123': old_group(1, [2, 3]), 'group123': old_group(1, [2, 3]),
'uniq': old_group(0, [], 'Uniq'),
'twin1': old_group(0, [], 'Twin One'),
'twin2': old_group(0, [], 'Twin Two'),
'numerals': old_group(0, [], 'One Two Three'),
} }
new_groups = { new_groups = {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment