diff --git a/nomad/app/v1/routers/groups.py b/nomad/app/v1/routers/groups.py index b6b27d739e7aa756e1557437ac0e75cf03094fa2..0c6f24329f2605f44f356b949048eaa47666d9f3 100644 --- a/nomad/app/v1/routers/groups.py +++ b/nomad/app/v1/routers/groups.py @@ -18,19 +18,20 @@ from typing import List, Optional, Set -from fastapi import APIRouter, Depends, HTTPException, status +from fastapi import APIRouter, Depends, HTTPException, Query, status from pydantic import BaseModel, Field from nomad.datamodel import User as UserDataModel from nomad.groups import ( UserGroup as MongoUserGroup, +) +from nomad.groups import ( create_user_group as create_mongo_user_group, ) from nomad.utils import strip -from .auth import create_user_dependency from ..models import User - +from .auth import create_user_dependency router = APIRouter() default_tag = 'groups' @@ -75,7 +76,7 @@ def get_mongo_user_group(group_id: str) -> MongoUserGroup: if user_group is None: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=strip(f"User group '{group_id}' was not found."), + detail=f"User group '{group_id}' was not found.", ) return user_group @@ -107,11 +108,32 @@ def check_user_may_edit_user_group(user: User, user_group: MongoUserGroup): @router.get( - '', tags=[default_tag], summary='List user groups.', response_model=UserGroups + '', + tags=[default_tag], + summary='List user groups. Use at most one filter.', + response_model=UserGroups, ) -async def get_user_groups(): +async def get_user_groups( + group_id: Optional[List[str]] = Query( + None, description='Search groups by their full id.' + ), + search_terms: Optional[str] = Query( + None, description='Search groups by parts of their name.' + ), +): """Get data about user groups.""" - user_groups = MongoUserGroup.objects + if group_id is not None and search_terms is not None: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail='Only one of group_id or search_terms may be used at a time.', + ) + + if group_id is not None: + user_groups = MongoUserGroup.get_by_ids(group_id) + elif search_terms is not None: + user_groups = MongoUserGroup.get_by_search_terms(search_terms) + else: + user_groups = MongoUserGroup.objects data = [UserGroup.from_orm(user_group) for user_group in user_groups] return {'data': data} diff --git a/nomad/groups.py b/nomad/groups.py index 99a2085d2bfb42a9c671a7986f6c934c9c0d77c8..3ae4cc3b8e525c606192de940f2bd37afe3cc29d 100644 --- a/nomad/groups.py +++ b/nomad/groups.py @@ -15,9 +15,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import operator +from functools import reduce from typing import Iterable, List, Optional -from mongoengine import Document, StringField, ListField +from mongoengine import Document, ListField, StringField from mongoengine.queryset.visitor import Q from nomad.utils import create_uuid @@ -37,6 +39,14 @@ class UserGroup(Document): meta = {'indexes': ['group_name', 'owner', 'members']} + @classmethod + def get_by_ids(cls, group_ids: Iterable[str]): + """ + Returns UserGroup objects with group_ids. + """ + user_groups = cls.objects(group_id__in=group_ids) + return user_groups + @classmethod def get_by_user_id(cls, user_id: Optional[str]): """ @@ -58,6 +68,20 @@ class UserGroup(Document): group_ids.extend(group.group_id for group in cls.get_by_user_id(user_id)) return group_ids + @classmethod + def get_by_search_terms(cls, search_terms: str): + """ + Returns UserGroup objects where group_name includes search_terms (no case). + """ + search_terms = str(search_terms).split() + if not search_terms: + return [] + + query = (Q(group_name__icontains=term) for term in search_terms) + query = reduce(operator.and_, query) + user_groups = cls.objects(query) + return user_groups + def create_user_group( *,