diff --git a/nomad/app/dcat/common.py b/nomad/app/dcat/common.py index 131dbad6b6645ff075af9aee428f6da3880b964b..9a91b70c9feb5bc6c6314f217c015c26852d4337 100644 --- a/nomad/app/dcat/common.py +++ b/nomad/app/dcat/common.py @@ -76,7 +76,7 @@ response_types = [ def rdf_response( - format: Optional[Formats] = Query(None), accept: Optional[str] = Header(None) + format: Formats | None = Query(None), accept: str | None = Header(None) ): format_ = format.value if format else None if format_ is None: diff --git a/nomad/app/dcat/routers/dcat.py b/nomad/app/dcat/routers/dcat.py index 8a027c15698918ce4c86354798c30962946e469d..7dc52a01be1fa3dea138428aebff2653e6552b61 100644 --- a/nomad/app/dcat/routers/dcat.py +++ b/nomad/app/dcat/routers/dcat.py @@ -89,7 +89,7 @@ async def get_dataset( ) async def get_catalog( after: str = Query(None, description='return entries after the given entry_id'), - modified_since: Union[datetime, date] = Query( + modified_since: datetime | date = Query( None, description='maximum entry time (e.g. upload time)' ), rdf_respose=Depends(rdf_response), diff --git a/nomad/app/h5grove_app.py b/nomad/app/h5grove_app.py index ac88d08339eb35cbec355c52569b232ea63e3763..110eb589a7cdac8d0c436fb84d5b4ddbae3c9896 100644 --- a/nomad/app/h5grove_app.py +++ b/nomad/app/h5grove_app.py @@ -24,7 +24,8 @@ import traceback import re import urllib.parse import h5py -from typing import Callable, Dict, Any, IO +from typing import Dict, Any, IO +from collections.abc import Callable from h5grove import fastapi_utils as h5grove_router, utils as h5grove_utils @@ -40,7 +41,7 @@ logger = utils.get_logger(__name__) def open_zipped_h5_file( filepath: str, create_error: Callable[[int, str], Exception], - h5py_options: Dict[str, Any] = {}, + h5py_options: dict[str, Any] = {}, ) -> h5py.File: import re import io diff --git a/nomad/app/optimade/__init__.py b/nomad/app/optimade/__init__.py index d51b47d99d9595c4fd283a1d2b7f0cdbe36d3b52..c10f5c65ca70489080833e924e65d579edf26078 100644 --- a/nomad/app/optimade/__init__.py +++ b/nomad/app/optimade/__init__.py @@ -105,7 +105,7 @@ from nomad.config import config from optimade.server.config import CONFIG # nopep8 CONFIG.root_path = '%s/optimade' % config.services.api_base_path -CONFIG.base_url = '%s://%s' % ( +CONFIG.base_url = '{}://{}'.format( 'https' if config.services.https else 'http', config.services.api_host.strip('/'), ) diff --git a/nomad/app/optimade/common.py b/nomad/app/optimade/common.py index d040ef37eebf9bb4d5146c39fe5535fb75359933..623cd3847cdf3cb2ebe36363bef0ba18ed62841b 100644 --- a/nomad/app/optimade/common.py +++ b/nomad/app/optimade/common.py @@ -23,7 +23,7 @@ from nomad.metainfo.metainfo import Quantity, Reference from nomad.metainfo.elasticsearch_extension import SearchQuantity, entry_type -_provider_specific_fields: Dict[str, SearchQuantity] = None +_provider_specific_fields: dict[str, SearchQuantity] = None def create_provider_field(name, definition): @@ -43,7 +43,7 @@ def create_provider_field(name, definition): return dict(name=name, description=description, type=optimade_type, sortable=False) -def provider_specific_fields() -> Dict[str, SearchQuantity]: +def provider_specific_fields() -> dict[str, SearchQuantity]: global _provider_specific_fields if _provider_specific_fields is not None: diff --git a/nomad/app/optimade/elasticsearch.py b/nomad/app/optimade/elasticsearch.py index fd741fe32aeb1f7e1b632f911b445a6b294eb88e..6c8836638faacb847ab7a83f7c0cabfcc3baca12 100644 --- a/nomad/app/optimade/elasticsearch.py +++ 
b/nomad/app/optimade/elasticsearch.py @@ -36,7 +36,7 @@ class NomadStructureMapper(StructureMapper): return doc @classproperty - def ALL_ATTRIBUTES(cls) -> Set[str]: # pylint: disable=no-self-argument + def ALL_ATTRIBUTES(cls) -> set[str]: # pylint: disable=no-self-argument result = getattr(cls, '_ALL_ATTRIBUTES', None) if result is None: result = StructureMapper.ALL_ATTRIBUTES # pylint: disable=no-member @@ -110,8 +110,8 @@ class StructureCollection(EntryCollection): def _es_to_optimade_result( self, es_result: dict, - response_fields: Set[str], - upload_files_cache: Dict[str, files.UploadFiles] = None, + response_fields: set[str], + upload_files_cache: dict[str, files.UploadFiles] = None, ) -> StructureResource: if upload_files_cache is None: upload_files_cache = {} @@ -226,9 +226,9 @@ class StructureCollection(EntryCollection): return attrs def _es_to_optimade_results( - self, es_results: List[dict], response_fields: Set[str] + self, es_results: list[dict], response_fields: set[str] ): - upload_files_cache: Dict[str, files.UploadFiles] = {} + upload_files_cache: dict[str, files.UploadFiles] = {} optimade_results = [] try: for es_result in es_results: @@ -243,7 +243,7 @@ class StructureCollection(EntryCollection): return optimade_results - def _run_db_query(self, criteria: Dict[str, Any], single_entry=False): + def _run_db_query(self, criteria: dict[str, Any], single_entry=False): sort, order = criteria.get('sort', (('chemical_formula_reduced', 1),))[0] sort_quantity = datamodel.OptimadeEntry.m_def.all_quantities.get(sort, None) if sort_quantity is None: diff --git a/nomad/app/optimade/filterparser.py b/nomad/app/optimade/filterparser.py index ca85b236e5d20441930d3a9ca70c6d6916274583..75cec2fd210ebd7665df4007a270362419ca8e84 100644 --- a/nomad/app/optimade/filterparser.py +++ b/nomad/app/optimade/filterparser.py @@ -42,7 +42,7 @@ class FilterException(Exception): def _get_transformer(without_prefix, **kwargs): from nomad.datamodel import OptimadeEntry - quantities: Dict[str, Quantity] = { + quantities: dict[str, Quantity] = { q.name: Quantity( q.name, backend_field='optimade.%s' % q.name, diff --git a/nomad/app/resources/main.py b/nomad/app/resources/main.py index 9fa7e4980730913b0ed451ee141872cc79ffab82..e5c4b6992de2ffdd2ea96610a917f511023567f9 100644 --- a/nomad/app/resources/main.py +++ b/nomad/app/resources/main.py @@ -38,7 +38,7 @@ app = FastAPI( redoc_url='/extensions/redoc', swagger_ui_oauth2_redirect_url='/extensions/docs/oauth2-redirect', title='Resources API', - version='v1, NOMAD %s@%s' % (config.meta.version, config.meta.commit), + version=f'v1, NOMAD {config.meta.version}@{config.meta.commit}', description="NOMAD's API for serving related external resources", ) diff --git a/nomad/app/resources/routers/resources.py b/nomad/app/resources/routers/resources.py index c36276a9ee7f8563f50d90a94260d6b828db5229..878d3a3ea335a78e2c7d92bcab4566c9a0d7aaf6 100644 --- a/nomad/app/resources/routers/resources.py +++ b/nomad/app/resources/routers/resources.py @@ -98,7 +98,7 @@ optimade_providers = { # ) } -optimade_dbs: List[str] = [ +optimade_dbs: list[str] = [ str(details['name']) for details in optimade_providers.values() ] @@ -191,7 +191,7 @@ class ResourceModel(BaseModel): # data: Dict[str, Any] = Field( # {}, description=''' Value of the data referenced by the entry. 
# ''') - available_data: List[str] = Field( + available_data: list[str] = Field( [], description="""List of available data referenced by the entry""" ) url: str = Field( @@ -212,25 +212,25 @@ class ResourceModel(BaseModel): Date the data was downloaded. """, ) - database_name: Optional[str] = Field( + database_name: str | None = Field( None, description=""" Name to identify the referenced data. """, ) - kind: Optional[str] = Field( + kind: str | None = Field( None, description=""" Kind of the reference data, e.g. journal, online, book. """, ) - comment: Optional[str] = Field( + comment: str | None = Field( None, description=""" Annotations on the reference. """, ) - database_version: Optional[str] = Field( + database_version: str | None = Field( None, description=""" Version of the database. @@ -246,7 +246,7 @@ class ResourceModel(BaseModel): class ResourcesModel(BaseModel): - data: List[ResourceModel] = Field( + data: list[ResourceModel] = Field( [], description='The list of resources, currently in our database.' ) @@ -274,7 +274,7 @@ async def _download(session: httpx.AsyncClient, path: str) -> httpx.Response: return None -def _update_dict(target: Dict[str, float], source: Dict[str, float]): +def _update_dict(target: dict[str, float], source: dict[str, float]): for key, val in source.items(): if key in target: target[key] += val @@ -282,11 +282,11 @@ def _update_dict(target: Dict[str, float], source: Dict[str, float]): target[key] = val -def _components(formula_str: str, multiplier: float = 1.0) -> Dict[str, float]: +def _components(formula_str: str, multiplier: float = 1.0) -> dict[str, float]: # match atoms and molecules (in brackets) components = formula_re.findall(formula_str) - symbol_amount: Dict[str, float] = {} + symbol_amount: dict[str, float] = {} for component in components: element, amount_e, molecule, amount_m = component if element: @@ -317,7 +317,7 @@ def _normalize_formula(formula_str: str) -> str: return ''.join(formula_sorted) -def parse_springer_entry(htmltext: str) -> Dict[str, str]: +def parse_springer_entry(htmltext: str) -> dict[str, str]: """ Parse the springer entry quantities in required_items from an html text. """ @@ -374,7 +374,7 @@ def parse_springer_entry(htmltext: str) -> Dict[str, str]: return results -def parse_aflow_prototype(text: str) -> Dict[str, Any]: +def parse_aflow_prototype(text: str) -> dict[str, Any]: """ Parse information from aflow prototype structure entry. 
""" @@ -389,7 +389,7 @@ def parse_aflow_prototype(text: str) -> Dict[str, Any]: async def _get_urls_aflow_prototypes( session: httpx.AsyncClient, space_group_number: int -) -> List[str]: +) -> list[str]: if space_group_number is None or space_group_number == 0: return [] @@ -412,7 +412,7 @@ async def _get_urls_aflow_prototypes( async def _get_resources_aflow_prototypes( session: httpx.AsyncClient, path: str, chemical_formula: str -) -> List[Resource]: +) -> list[Resource]: response = await _download(session, path) if response is None: return [] @@ -451,7 +451,7 @@ async def _get_resources_aflow_prototypes( async def _get_urls_springer_materials( session: httpx.AsyncClient, chemical_formula: str -) -> List[str]: +) -> list[str]: if chemical_formula is None: return [] @@ -480,7 +480,7 @@ async def _get_urls_springer_materials( async def _get_resources_springer_materials( session: httpx.AsyncClient, path: str -) -> List[Resource]: +) -> list[Resource]: resource = Resource() resource.url = path resource.id = os.path.basename(path) @@ -514,8 +514,8 @@ async def _get_resources_springer_materials( async def _get_urls_optimade( chemical_formula_hill: str, chemical_formula_reduced: str, - providers: List[str] = None, -) -> List[str]: + providers: list[str] = None, +) -> list[str]: filter_hill = ( f'chemical_formula_hill = "{chemical_formula_hill}"' if chemical_formula_hill is not None @@ -549,13 +549,13 @@ async def _get_urls_optimade( async def _get_resources_optimade( session: httpx.AsyncClient, path: str -) -> List[Resource]: +) -> list[Resource]: response = await _download(session, path) if response is None: logger.error(f'Error accessing optimade resources.', data=dict(path=path)) return [] data = response.json() - resources: List[Resource] = [] + resources: list[Resource] = [] meta = data.get('meta', dict()) provider = meta.get('provider', dict()).get('name', '') base_url = path.split('structures?filter')[0] @@ -592,7 +592,7 @@ async def _get_resources_optimade( @app.task def retrieve_resources( status_resource_id, - urls_to_ignore: List[str], + urls_to_ignore: list[str], space_group_number, chemical_formula, chemical_formula_hill, @@ -618,9 +618,9 @@ def retrieve_resources( ) ) - aflow_urls: List[str] - springer_urls: List[str] - optimade_urls: List[str] + aflow_urls: list[str] + springer_urls: list[str] + optimade_urls: list[str] aflow_urls, springer_urls, optimade_urls = await asyncio.gather( aflow_task, springer_task, optimade_task ) @@ -672,7 +672,7 @@ def retrieve_resources( ) async def get_resources( space_group_number: int = FastApiQuery(None), - wyckoff_letters: List[str] = FastApiQuery(None), + wyckoff_letters: list[str] = FastApiQuery(None), n_sites: int = FastApiQuery(None), chemical_formula_reduced: str = FastApiQuery(None), ): @@ -692,11 +692,11 @@ async def get_resources( wyckoff_letters = list(set(wyckoff_letters)) wyckoff_letters.sort() - sources: Dict[str, int] = dict() + sources: dict[str, int] = dict() - def convert_resources_to_models(resources) -> List[ResourceModel]: - data: List[ResourceModel] = [] - additional_data: List[ResourceModel] = [] + def convert_resources_to_models(resources) -> list[ResourceModel]: + data: list[ResourceModel] = [] + additional_data: list[ResourceModel] = [] for resource in resources: if ( resource is None diff --git a/nomad/app/v1/models/graph/graph_models.py b/nomad/app/v1/models/graph/graph_models.py index d199c6c1b62872b329fe483f920ad010cd623a5d..acb7b59ffc1fb17c93a989a2ef918fa21a714f10 100644 --- 
a/nomad/app/v1/models/graph/graph_models.py +++ b/nomad/app/v1/models/graph/graph_models.py @@ -63,20 +63,20 @@ class DirectoryResponseOptions(BaseModel): class GraphDirectory(BaseModel): - m_errors: List[Error] + m_errors: list[Error] m_is: Literal['Directory'] m_request: DirectoryRequestOptions m_response: DirectoryResponseOptions - m_children: Union[GraphDirectory, GraphFile] + m_children: GraphDirectory | GraphFile class GraphFile(BaseModel): - m_errors: List[Error] + m_errors: list[Error] m_is: Literal['File'] m_request: DirectoryRequestOptions path: str size: int - entry: Optional[GraphEntry] = None + entry: GraphEntry | None = None # The old API also had those, but they can be grabbed from entry: # parser_name, entry_id, archive # This is similar to the question for "m_parent" in Directory. At least we need @@ -86,7 +86,7 @@ class GraphFile(BaseModel): class MSection(BaseModel): - m_errors: List[Error] + m_errors: list[Error] m_request: RecursionOptions m_def: MDef m_children: Any = None @@ -98,7 +98,7 @@ class MDef(MSection): class GraphEntry(mapped(EntryProcData, mainfile='mainfile_path', entry_metadata=None)): # type: ignore - m_errors: List[Error] + m_errors: list[Error] mainfile: GraphFile upload: GraphUpload archive: MSection @@ -107,11 +107,11 @@ class GraphEntry(mapped(EntryProcData, mainfile='mainfile_path', entry_metadata= class EntriesRequestOptions(BaseModel): # The old API does not support any queries - pagination: Optional[EntryProcDataPagination] = None + pagination: EntryProcDataPagination | None = None class EntriesResponseOptions(BaseModel): - pagination: Optional[PaginationResponse] = None + pagination: PaginationResponse | None = None # The "upload" was only necessary, because in the old API you would not get the upload. # In the graph API, the upload would be the parent anyways # upload: Upload @@ -130,15 +130,15 @@ class GraphUser( # This would only refer to uploads with the user as main_author. 
# For many clients and use-cases uploads.m_request.query will be the # more generic or only option - uploads: Optional[GraphUploads] - datasets: Optional[GraphDatasets] + uploads: GraphUploads | None + datasets: GraphDatasets | None model_config = ConfigDict( extra='forbid', ) class GraphUsers(BaseModel): - m_errors: List[Error] + m_errors: list[Error] m_children: GraphUser @@ -147,10 +147,10 @@ class GraphUpload( UploadProcData, entries='n_entries', main_author=GraphUser, - coauthors=List[GraphUser], - reviewers=List[GraphUser], - viewers=List[GraphUser], - writers=List[GraphUser], + coauthors=list[GraphUser], + reviewers=list[GraphUser], + viewers=list[GraphUser], + writers=list[GraphUser], ), ): # The old API includes some extra data here: @@ -172,19 +172,19 @@ class GraphUpload( class UploadRequestOptions(BaseModel): - pagination: Optional[UploadProcDataPagination] = None - query: Optional[UploadProcDataQuery] = None + pagination: UploadProcDataPagination | None = None + query: UploadProcDataQuery | None = None class UploadResponseOptions(BaseModel): - pagination: Optional[PaginationResponse] = None - query: Optional[UploadProcDataQuery] = None + pagination: PaginationResponse | None = None + query: UploadProcDataQuery | None = None class GraphUploads(BaseModel): m_request: UploadRequestOptions m_response: UploadResponseOptions - m_errors: List[Error] + m_errors: list[Error] m_children: GraphUpload @@ -194,17 +194,17 @@ class GraphEntryMetadata(BaseModel, extra=Extra.allow): class SearchRequestOptions(BaseModel): - query: Optional[Metadata] = None + query: Metadata | None = None class SearchResponseOptions(BaseModel): - query: Optional[MetadataResponse] = None + query: MetadataResponse | None = None class GraphSearch(BaseModel): m_request: SearchRequestOptions m_response: SearchResponseOptions - m_errors: List[Error] + m_errors: list[Error] m_children: GraphEntryMetadata @@ -213,55 +213,55 @@ class GraphDataset(mapped(DatasetV1, query=None, entries=None)): # type: ignore class DatasetRequestOptions(BaseModel): - pagination: Optional[DatasetPagination] = None - query: Optional[DatasetQuery] = None + pagination: DatasetPagination | None = None + query: DatasetQuery | None = None class DatasetResponseOptions(BaseModel): - pagination: Optional[PaginationResponse] = None - query: Optional[DatasetQuery] = None + pagination: PaginationResponse | None = None + query: DatasetQuery | None = None class GraphDatasets(BaseModel): m_request: DatasetRequestOptions m_response: DatasetResponseOptions - m_errors: List[Error] + m_errors: list[Error] m_children: GraphDataset class MetainfoRequestOptions(BaseModel): - pagination: Optional[MetainfoPagination] = None - query: Optional[MetainfoQuery] = None + pagination: MetainfoPagination | None = None + query: MetainfoQuery | None = None class MetainfoResponseOptions(BaseModel): - pagination: Optional[PaginationResponse] = None - query: Optional[MetainfoQuery] = None + pagination: PaginationResponse | None = None + query: MetainfoQuery | None = None class GraphMetainfo(BaseModel): m_request: MetainfoRequestOptions m_response: MetainfoResponseOptions - m_errors: List[Error] + m_errors: list[Error] m_children: MSection -class GraphGroup(mapped(UserGroup, owner=GraphUser, members=List[GraphUser])): # type: ignore - m_errors: List[Error] +class GraphGroup(mapped(UserGroup, owner=GraphUser, members=list[GraphUser])): # type: ignore + m_errors: list[Error] class GroupRequestOptions(BaseModel): - pagination: Optional[UserGroupPagination] - query: 
Optional[UserGroupQuery] + pagination: UserGroupPagination | None + query: UserGroupQuery | None class GroupResponseOptions(BaseModel): - pagination: Optional[PaginationResponse] - query: Optional[UserGroupQuery] + pagination: PaginationResponse | None + query: UserGroupQuery | None class GraphGroups(BaseModel): - m_errors: List[Error] + m_errors: list[Error] m_children: GraphGroup m_request: GroupRequestOptions m_response: GroupResponseOptions diff --git a/nomad/app/v1/models/graph/utils.py b/nomad/app/v1/models/graph/utils.py index 4452f0af5ef32f7d2abb43aba110f57bb8c98c29..f7d59b63dab1452064ffb718be455eb8ce7189fd 100644 --- a/nomad/app/v1/models/graph/utils.py +++ b/nomad/app/v1/models/graph/utils.py @@ -18,20 +18,18 @@ from __future__ import annotations from typing import ( - Dict, - List, Optional, - Type, Literal, Union, Any, - Callable, ForwardRef, get_type_hints, get_origin, get_args, cast, ) +from collections.abc import Callable +from types import UnionType from datetime import datetime from pydantic import ( BaseModel, @@ -52,7 +50,7 @@ response_suffix = 'Response' graph_model_export = False -def json_schema_extra(schema: dict[str, Any], model: Type[_DictModel]) -> None: +def json_schema_extra(schema: dict[str, Any], model: type[_DictModel]) -> None: if 'm_children' not in model.__annotations__: raise TypeError( f'No m_children field defined for dict model {model.__name__}. ' @@ -64,7 +62,7 @@ def json_schema_extra(schema: dict[str, Any], model: Type[_DictModel]) -> None: f"Could not determine m_children's type. Did you miss to call update_forward_refs()?" ) - if get_origin(value_type) == Union: + if get_origin(value_type) in (Union, UnionType): value_types = get_args(value_type) else: value_types = (value_type,) @@ -174,13 +172,16 @@ def _get_request_type(type_hint: Any, ns: ModelNamespace) -> Any: if origin is dict: key_type, value_type = args - return Dict[key_type, _get_request_type(value_type, ns)] # type: ignore + return dict[key_type, _get_request_type(value_type, ns)] # type: ignore # This is about Optional[T], which is translated to Union[None, T] - if origin is Union and len(args) == 2 and isinstance(None, args[1]): - return _get_request_type(args[0], ns) + if origin in (Union, UnionType) and len(args) == 2: + if isinstance(None, args[1]): + return _get_request_type(args[0], ns) + if isinstance(None, args[0]): + return _get_request_type(args[1], ns) - if origin is Union: + if origin in (Union, UnionType): union_types = tuple(_get_request_type(type_, ns) for type_ in args) return Union[union_types] # type: ignore @@ -197,7 +198,7 @@ def _get_response_type(type_hint: Any, ns: ModelNamespace) -> Any: if origin is list: value_type = args[0] - return List[_get_response_type(value_type, ns)] # type: ignore + return list[_get_response_type(value_type, ns)] # type: ignore if origin is dict: key_type, value_type = args @@ -206,24 +207,27 @@ def _get_response_type(type_hint: Any, ns: ModelNamespace) -> Any: # We have detected direct type recursion, like in # Path = Dict[str, 'Path'] return type_hint - return Dict[key_type, _get_response_type(value_type, ns)] # type: ignore + return dict[key_type, _get_response_type(value_type, ns)] # type: ignore # This is about Optional[T], which is translated to Union[None, T] - if origin is Union and len(args) == 2 and isinstance(None, args[1]): - return _get_response_type(args[0], ns) + if origin in (Union, UnionType) and len(args) == 2: + if isinstance(None, args[1]): + return _get_response_type(args[0], ns) + if isinstance(None, args[0]): + 
return _get_response_type(args[1], ns) - if origin is Union: + if origin in (Union, UnionType): union_types = tuple(_get_response_type(type_, ns) for type_ in args) return Union[union_types] # type: ignore raise NotImplementedError(type_hint) -ModelNamespace = Dict[str, Union[Type[BaseModel], ForwardRef]] +ModelNamespace = dict[str, Union[type[BaseModel], ForwardRef]] def _generate_model( - source_model: Union[Type[BaseModel], Any], + source_model: type[BaseModel] | Any, suffix: str, generate_type: Callable[[type, ModelNamespace], type], ns: ModelNamespace, @@ -256,7 +260,7 @@ def _generate_model( if field_name == 'm_children': origin, args = get_origin(type_hint), get_args(type_hint) - if origin is Union: + if origin in (Union, UnionType): types = args else: types = (type_hint,) @@ -274,7 +278,7 @@ def _generate_model( # TODO we always add Literal['*'] at the end. Maybe it should be configurable # which models want to support '*' values for their children? value_type = Union[value_types + (Literal['*'],)] # type: ignore - fields['m_children'] = (Optional[Dict[str, cast(Type, value_type)]], None) # type: ignore + fields['m_children'] = (Optional[dict[str, cast(type, value_type)]], None) # type: ignore continue if field_name == 'm_request': @@ -367,7 +371,7 @@ def _generate_model( return result_model -def mapped(model: Type[BaseModel], **mapping: Union[str, type]) -> Type[BaseModel]: +def mapped(model: type[BaseModel], **mapping: str | type) -> type[BaseModel]: """ Creates a new pydantic model based on the given model. The mapping argument allows to either change the name of a field in the input model or change the type of a field @@ -416,9 +420,9 @@ def mapped(model: Type[BaseModel], **mapping: Union[str, type]) -> Type[BaseMode ) -def generate_request_model(source_model: Type[BaseModel]): +def generate_request_model(source_model: type[BaseModel]): return _generate_model(source_model, request_suffix, _get_request_type, dict()) -def generate_response_model(source_model: Type[BaseModel]): +def generate_response_model(source_model: type[BaseModel]): return _generate_model(source_model, response_suffix, _get_response_type, dict()) diff --git a/nomad/app/v1/models/groups.py b/nomad/app/v1/models/groups.py index 3718f1b281b3a31452901e1715ec1db57a1f7c41..1e7c34bc4d6f6af56489a2ff73ba7deb4a80728e 100644 --- a/nomad/app/v1/models/groups.py +++ b/nomad/app/v1/models/groups.py @@ -17,14 +17,14 @@ group_members_description = 'User ids of the group members.' 
class UserGroupEdit(BaseModel): - group_name: Optional[str] = Field( + group_name: str | None = Field( default=None, description=group_name_description, min_length=3, max_length=32, pattern=r'^[a-zA-Z0-9][a-zA-Z0-9 ._\-]+[a-zA-Z0-9]$', ) - members: Optional[Set[str]] = Field( + members: set[str] | None = Field( default=None, description=group_members_description ) @@ -35,7 +35,7 @@ class UserGroup(BaseModel): default='Default Group Name', description=group_name_description ) owner: str = Field(description='User id of the group owner.') - members: List[str] = Field( + members: list[str] = Field( default_factory=list, description=group_members_description ) @@ -43,18 +43,18 @@ class UserGroup(BaseModel): class UserGroupResponse(BaseModel): - pagination: Optional[PaginationResponse] = Field(None) - data: List[UserGroup] + pagination: PaginationResponse | None = Field(None) + data: list[UserGroup] class UserGroupQuery(BaseModel): - group_id: Optional[List[str]] = Field( + group_id: list[str] | None = Field( None, description='Search groups by their full id.' ) - user_id: Optional[str] = Field( + user_id: str | None = Field( None, description='Search groups by their owner or members ids.' ) - search_terms: Optional[str] = Field( + search_terms: str | None = Field( None, description='Search groups by parts of their name.' ) diff --git a/nomad/app/v1/models/models.py b/nomad/app/v1/models/models.py index 730cf964ffc8287be8b2c2ea8aa3203d924353a9..8921dfe90b2e28e0cac42be4eaca444e102b43bb 100644 --- a/nomad/app/v1/models/models.py +++ b/nomad/app/v1/models/models.py @@ -20,7 +20,8 @@ import enum import fnmatch import json import re -from typing import Any, Dict, List, Mapping, Optional, Union +from typing import Any, Dict, List, Optional, Union +from collections.abc import Mapping import pydantic from fastapi import Body, HTTPException, Request @@ -49,7 +50,7 @@ from nomad.metainfo.elasticsearch_extension import ( from nomad.utils import strip from .pagination import Pagination, PaginationResponse -from typing_extensions import Annotated +from typing import Annotated User: Any = datamodel.User.m_def.a_pydantic.model @@ -118,19 +119,19 @@ class NoneEmptyBaseModel(BaseModel): class All(NoneEmptyBaseModel): - op: List[Value] = Field(None, alias='all') + op: list[Value] = Field(None, alias='all') model_config = ConfigDict(extra='forbid') class None_(NoneEmptyBaseModel): - op: List[Value] = Field(None, alias='none') + op: list[Value] = Field(None, alias='none') model_config = ConfigDict(extra='forbid') class Any_(NoneEmptyBaseModel): - op: List[Value] = Field(None, alias='any') + op: list[Value] = Field(None, alias='any') model_config = ConfigDict(extra='forbid') @@ -169,10 +170,10 @@ class Range(BaseModel): return values - lt: Optional[ComparableValue] = Field(None) - lte: Optional[ComparableValue] = Field(None) - gt: Optional[ComparableValue] = Field(None) - gte: Optional[ComparableValue] = Field(None) + lt: ComparableValue | None = Field(None) + lte: ComparableValue | None = Field(None) + gt: ComparableValue | None = Field(None) + gte: ComparableValue | None = Field(None) model_config = ConfigDict(extra='forbid') @@ -187,7 +188,7 @@ ops = { 'any': Any_, } -CriteriaValue = Union[Value, List[Value], Range, Any_, All, None_, Dict[str, Any]] +CriteriaValue = Union[Value, list[Value], Range, Any_, All, None_, dict[str, Any]] class LogicalOperator(NoneEmptyBaseModel): @@ -201,7 +202,7 @@ class LogicalOperator(NoneEmptyBaseModel): class And(LogicalOperator): - op: List['Query'] = Field(None, alias='and') + 
op: list['Query'] = Field(None, alias='and') @model_validator(mode='before') @classmethod @@ -213,7 +214,7 @@ class And(LogicalOperator): class Or(LogicalOperator): - op: List['Query'] = Field(None, alias='or') + op: list['Query'] = Field(None, alias='or') @model_validator(mode='before') @classmethod @@ -379,8 +380,8 @@ def restrict_query_to_upload(query: Query, upload_id: str): class WithQuery(BaseModel): - owner: Optional[Owner] = Body('public') - query: Optional[Query] = Body( + owner: Owner | None = Body('public') + query: Query | None = Body( None, embed=True, description=query_documentation, @@ -442,10 +443,8 @@ class QueryParameters: def __call__( self, request: Request, - owner: Optional[Owner] = FastApiQuery( - 'public', description=strip(Owner.__doc__) - ), - json_query: Optional[str] = FastApiQuery( + owner: Owner | None = FastApiQuery('public', description=strip(Owner.__doc__)), + json_query: str | None = FastApiQuery( None, description=strip( """ @@ -453,7 +452,7 @@ class QueryParameters: """ ), ), - q: Optional[List[str]] = FastApiQuery( + q: list[str] | None = FastApiQuery( [], description=strip( """ @@ -515,7 +514,7 @@ class QueryParameters: query_params.setdefault(name_op, []).append(value) # transform query parameters to query - query: Dict[str, Any] = {} + query: dict[str, Any] = {} for key, query_value in query_params.items(): op = None if '__' in key: @@ -588,14 +587,14 @@ class QueryParameters: class MetadataRequired(BaseModel): """Defines which metadata quantities are included or excluded in the response.""" - include: Optional[List[str]] = Field( + include: list[str] | None = Field( None, description=strip(""" Quantities to include for each result. Only those quantities will be returned. At least one id quantity (e.g. `entry_id`) will always be included. """), ) - exclude: Optional[List[str]] = Field( + exclude: list[str] | None = Field( None, description=strip(""" Quantities to exclude for each result. 
Only all other quantities will @@ -611,7 +610,7 @@ metadata_required_parameters = parameter_dependency_from_model( # type: ignore class MetadataBasedPagination(Pagination): - order_by: Optional[str] = Field( + order_by: str | None = Field( None, description=strip( """ @@ -642,7 +641,7 @@ class MetadataBasedPagination(Pagination): class MetadataPagination(MetadataBasedPagination): - page: Optional[int] = Field( + page: int | None = Field( None, description=strip( """ @@ -656,7 +655,7 @@ class MetadataPagination(MetadataBasedPagination): ), ) - page_offset: Optional[int] = Field( + page_offset: int | None = Field( None, description=strip( """ @@ -707,7 +706,7 @@ metadata_pagination_parameters = parameter_dependency_from_model( class AggregationPagination(MetadataBasedPagination): - order_by: Optional[str] = Field( + order_by: str | None = Field( None, # type: ignore description=strip( """ @@ -749,7 +748,7 @@ class AggregationPagination(MetadataBasedPagination): class AggregatedEntities(BaseModel): - size: Optional[Annotated[int, Field(gt=0)]] = Field( # type: ignore + size: Annotated[int, Field(gt=0)] | None = Field( # type: ignore 1, description=strip( """ @@ -758,7 +757,7 @@ class AggregatedEntities(BaseModel): """ ), ) - required: Optional[MetadataRequired] = Field( + required: MetadataRequired | None = Field( None, description=strip( """ @@ -801,7 +800,7 @@ class QuantityAggregation(AggregationBase): class BucketAggregation(QuantityAggregation): - metrics: Optional[List[str]] = Field( # type: ignore + metrics: list[str] | None = Field( # type: ignore [], description=strip( """ @@ -816,7 +815,7 @@ class BucketAggregation(QuantityAggregation): class TermsAggregation(BucketAggregation): - pagination: Optional[AggregationPagination] = Field( + pagination: AggregationPagination | None = Field( None, description=strip( """ @@ -829,7 +828,7 @@ class TermsAggregation(BucketAggregation): """ ), ) - size: Optional[Annotated[int, Field(gt=0)]] = Field( # type: ignore + size: Annotated[int, Field(gt=0)] | None = Field( # type: ignore None, description=strip( """ @@ -839,11 +838,9 @@ class TermsAggregation(BucketAggregation): """ ), ) - include: Optional[ # type: ignore - Union[ - List[str], Annotated[str, StringConstraints(pattern=r'^[a-zA-Z0-9_\-\s]+$')] - ] - ] = Field( + include: None | ( # type: ignore + list[str] | Annotated[str, StringConstraints(pattern=r'^[a-zA-Z0-9_\-\s]+$')] + ) = Field( None, description=strip( """ @@ -855,7 +852,7 @@ class TermsAggregation(BucketAggregation): """ ), ) - entries: Optional[AggregatedEntities] = Field( + entries: AggregatedEntities | None = Field( None, description=strip( """ @@ -867,7 +864,7 @@ class TermsAggregation(BucketAggregation): class Bounds(BaseModel): - min: Optional[float] = Field( + min: float | None = Field( None, description=strip( """ @@ -875,7 +872,7 @@ class Bounds(BaseModel): """ ), ) - max: Optional[float] = Field( + max: float | None = Field( None, description=strip( """ @@ -897,7 +894,7 @@ class Bounds(BaseModel): class HistogramAggregation(BucketAggregation): - interval: Optional[float] = Field( + interval: float | None = Field( None, gt=0, description=strip( @@ -907,7 +904,7 @@ class HistogramAggregation(BucketAggregation): """ ), ) - buckets: Optional[int] = Field( + buckets: int | None = Field( None, gt=0, description=strip( @@ -921,8 +918,8 @@ class HistogramAggregation(BucketAggregation): """ ), ) - offset: Optional[float] = Field(None, gte=0) - extended_bounds: Optional[Bounds] = None + offset: float | None = Field(None, gte=0) 
+ extended_bounds: Bounds | None = None @model_validator(mode='before') def check_bucketing(cls, values): # pylint: disable=no-self-argument @@ -955,7 +952,7 @@ class MinMaxAggregation(QuantityAggregation): class StatisticsAggregation(AggregationBase): - metrics: Optional[List[str]] = Field( # type: ignore + metrics: list[str] | None = Field( # type: ignore [], description=strip( """ @@ -968,7 +965,7 @@ class StatisticsAggregation(AggregationBase): class Aggregation(BaseModel): - terms: Optional[TermsAggregation] = Body( + terms: TermsAggregation | None = Body( None, description=strip( """ @@ -1010,7 +1007,7 @@ class Aggregation(BaseModel): ), ) - histogram: Optional[HistogramAggregation] = Body( + histogram: HistogramAggregation | None = Body( None, description=strip( """ @@ -1036,7 +1033,7 @@ class Aggregation(BaseModel): ), ) - date_histogram: Optional[DateHistogramAggregation] = Body( + date_histogram: DateHistogramAggregation | None = Body( None, description=strip( """ @@ -1062,7 +1059,7 @@ class Aggregation(BaseModel): ), ) - auto_date_histogram: Optional[AutoDateHistogramAggregation] = Body( + auto_date_histogram: AutoDateHistogramAggregation | None = Body( None, description=strip( """ @@ -1091,7 +1088,7 @@ class Aggregation(BaseModel): ), ) - min_max: Optional[MinMaxAggregation] = Body( + min_max: MinMaxAggregation | None = Body( None, description=strip( """ @@ -1114,7 +1111,7 @@ class Aggregation(BaseModel): ), ) - statistics: Optional[StatisticsAggregation] = Body( + statistics: StatisticsAggregation | None = Body( None, description=strip( """ @@ -1138,13 +1135,13 @@ class Aggregation(BaseModel): class WithQueryAndPagination(WithQuery): - pagination: Optional[MetadataPagination] = Body( + pagination: MetadataPagination | None = Body( None, example={'page_size': 5, 'order_by': 'upload_create_time'} ) class Metadata(WithQueryAndPagination): - required: Optional[MetadataRequired] = Body( + required: MetadataRequired | None = Body( None, example={ 'include': [ @@ -1156,7 +1153,7 @@ class Metadata(WithQueryAndPagination): ] }, ) - aggregations: Optional[Dict[str, Aggregation]] = Body( + aggregations: dict[str, Aggregation] | None = Body( {}, example={ 'all_codes': { @@ -1200,7 +1197,7 @@ class MetadataEditListAction(BaseModel): Defines an action to perform on a list quantity. This enables users to add and remove values. 
""" - set: Optional[Union[str, List[str]]] = Field( + set: str | list[str] | None = Field( None, description=strip( """ @@ -1209,14 +1206,14 @@ class MetadataEditListAction(BaseModel): add- or remove-operation.""" ), ) - add: Optional[Union[str, List[str]]] = Field( + add: str | list[str] | None = Field( None, description=strip( """ Value(s) to add to the list""" ), ) - remove: Optional[Union[str, List[str]]] = Field( + remove: str | list[str] | None = Field( None, description=strip( """ @@ -1233,7 +1230,7 @@ for quantity in datamodel.EditableUserMetadata.m_def.definitions: quantity.type if quantity.type in (str, int, float, bool) else str ) else: - pydantic_type = Union[str, List[str], MetadataEditListAction] + pydantic_type = Union[str, list[str], MetadataEditListAction] if getattr(quantity, 'a_auth_level', None) == datamodel.AuthLevel.admin: description = '**NOTE:** Only editable by admin user' else: @@ -1253,14 +1250,14 @@ MetadataEditActions = create_model( class MetadataEditRequest(WithQuery): """Defines a request to edit metadata.""" - metadata: Optional[MetadataEditActions] = Field( # type: ignore + metadata: MetadataEditActions | None = Field( # type: ignore None, description=strip( """ Metadata to set, on the upload and/or selected entries.""" ), ) - entries: Optional[Dict[str, MetadataEditActions]] = Field( # type: ignore + entries: dict[str, MetadataEditActions] | None = Field( # type: ignore None, description=strip( """ @@ -1269,14 +1266,14 @@ class MetadataEditRequest(WithQuery): the entries. Note, only quantities defined on the entry level can be set using this method.""" ), ) - entries_key: Optional[str] = Field( + entries_key: str | None = Field( default='entry_id', description=strip( """ Defines which type of key is used in `entries_metadata`. 
Default is `entry_id`.""" ), ) - verify_only: Optional[bool] = Field( + verify_only: bool | None = Field( default=False, description=strip( """ @@ -1289,7 +1286,7 @@ class MetadataEditRequest(WithQuery): class Files(BaseModel): """Configures the download of files.""" - compress: Optional[bool] = Field( + compress: bool | None = Field( False, description=strip( """ @@ -1299,7 +1296,7 @@ class Files(BaseModel): network connection is limited.""" ), ) - glob_pattern: Optional[str] = Field( + glob_pattern: str | None = Field( None, description=strip( """ @@ -1309,7 +1306,7 @@ class Files(BaseModel): [fnmatch](https://docs.python.org/3/library/fnmatch.html) is used.""" ), ) - re_pattern: Optional[str] = Field( + re_pattern: str | None = Field( None, description=strip( """ @@ -1321,7 +1318,7 @@ class Files(BaseModel): A re pattern will replace a given glob pattern.""" ), ) - include_files: Optional[List[str]] = Field( + include_files: list[str] | None = Field( None, description=strip( """ @@ -1372,7 +1369,7 @@ files_parameters = parameter_dependency_from_model('files_parameters', Files) # class Bucket(BaseModel): - entries: Optional[List[Dict[str, Any]]] = Field( + entries: list[dict[str, Any]] | None = Field( None, description=strip("""The entries that were requested for each value.""") ) count: int = Field( @@ -1386,19 +1383,19 @@ class Bucket(BaseModel): aggregations on non nested quantities.""" ), ) - metrics: Optional[Dict[str, int]] = None + metrics: dict[str, int] | None = None - value: Union[StrictBool, float, str] + value: StrictBool | float | str class BucketAggregationResponse(BaseModel): - data: List[Bucket] = Field( + data: list[Bucket] = Field( None, description=strip("""The aggregation data as a list.""") ) class TermsAggregationResponse(BucketAggregationResponse, TermsAggregation): - pagination: Optional[PaginationResponse] = None # type: ignore + pagination: PaginationResponse | None = None # type: ignore class HistogramAggregationResponse(BucketAggregationResponse, HistogramAggregation): @@ -1420,33 +1417,33 @@ class AutoDateHistogramAggregationResponse( class MinMaxAggregationResponse(MinMaxAggregation): - data: List[Union[float, None]] + data: list[float | None] class StatisticsAggregationResponse(StatisticsAggregation): - data: Optional[Dict[str, int]] = None + data: dict[str, int] | None = None class AggregationResponse(Aggregation): - terms: Optional[TermsAggregationResponse] = None - histogram: Optional[HistogramAggregationResponse] = None - date_histogram: Optional[DateHistogramAggregationResponse] = None - auto_date_histogram: Optional[AutoDateHistogramAggregationResponse] = None - min_max: Optional[MinMaxAggregationResponse] = None - statistics: Optional[StatisticsAggregationResponse] = None + terms: TermsAggregationResponse | None = None + histogram: HistogramAggregationResponse | None = None + date_histogram: DateHistogramAggregationResponse | None = None + auto_date_histogram: AutoDateHistogramAggregationResponse | None = None + min_max: MinMaxAggregationResponse | None = None + statistics: StatisticsAggregationResponse | None = None class CodeResponse(BaseModel): curl: str requests: str - nomad_lab: Optional[str] = None + nomad_lab: str | None = None class MetadataResponse(Metadata): pagination: PaginationResponse = None # type: ignore - aggregations: Optional[Dict[str, AggregationResponse]] = None # type: ignore + aggregations: dict[str, AggregationResponse] | None = None # type: ignore - data: List[Dict[str, Any]] = Field( + data: list[dict[str, Any]] = Field( 
None, description=strip( """ @@ -1455,7 +1452,7 @@ class MetadataResponse(Metadata): ), ) - code: Optional[CodeResponse] = None + code: CodeResponse | None = None es_query: Any = Field( None, description=strip( diff --git a/nomad/app/v1/models/pagination.py b/nomad/app/v1/models/pagination.py index faa08946079ae66ac535acd160063eb5af65ac08..2a127218625f3a7a4a552e72e1e351406cbbe592 100644 --- a/nomad/app/v1/models/pagination.py +++ b/nomad/app/v1/models/pagination.py @@ -27,28 +27,28 @@ class Direction(str, enum.Enum): class Pagination(BaseModel): """Defines the order, size, and page of results.""" - page_size: Optional[int] = Field( + page_size: int | None = Field( 10, description=strip(""" The page size, e.g. the maximum number of items contained in one response. A `page_size` of 0 will return no results. """), ) - order_by: Optional[str] = Field( + order_by: str | None = Field( None, description=strip(""" The results are ordered by the values of this field. If omitted, default ordering is applied. """), ) - order: Optional[Direction] = Field( + order: Direction | None = Field( Direction.asc, description=strip(""" The ordering direction of the results based on `order_by`. Its either ascending `asc` or descending `desc`. Default is `asc`. """), ) - page_after_value: Optional[str] = Field( + page_after_value: str | None = Field( None, description=strip(""" This attribute defines the position after which the page begins, and is used @@ -67,7 +67,7 @@ class Pagination(BaseModel): `page_after_value` and `next_page_after_value` to iterate through the results. """), ) - page: Optional[int] = Field( + page: int | None = Field( None, description=strip(""" The number of the page (1-based). When provided in a request, this attribute @@ -81,7 +81,7 @@ class Pagination(BaseModel): **NOTE #2**: Only one, `page`, `page_offset` or `page_after_value`, can be used. """), ) - page_offset: Optional[int] = Field( + page_offset: int | None = Field( None, description=strip(""" The number of skipped entries. 
When provided in a request, this attribute @@ -232,7 +232,7 @@ class PaginationResponse(Pagination): """ ), ) - next_page_after_value: Optional[str] = Field( + next_page_after_value: str | None = Field( None, description=strip( """ @@ -242,7 +242,7 @@ class PaginationResponse(Pagination): """ ), ) - page_url: Optional[str] = Field( + page_url: str | None = Field( None, description=strip( """ @@ -250,7 +250,7 @@ class PaginationResponse(Pagination): """ ), ) - next_page_url: Optional[str] = Field( + next_page_url: str | None = Field( None, description=strip( """ @@ -258,7 +258,7 @@ class PaginationResponse(Pagination): """ ), ) - prev_page_url: Optional[str] = Field( + prev_page_url: str | None = Field( None, description=strip( """ @@ -267,7 +267,7 @@ class PaginationResponse(Pagination): """ ), ) - first_page_url: Optional[str] = Field( + first_page_url: str | None = Field( None, description=strip( """ diff --git a/nomad/app/v1/routers/auth.py b/nomad/app/v1/routers/auth.py index f758d6e80c4c33fe98a701bc9230e220d7bbaf11..318aa5682fdb23d6b0f7ed513f51b453eead8a01 100644 --- a/nomad/app/v1/routers/auth.py +++ b/nomad/app/v1/routers/auth.py @@ -20,7 +20,8 @@ import hmac import hashlib import uuid import requests -from typing import Callable, cast, Union +from typing import cast, Union +from collections.abc import Callable from inspect import Parameter, signature from functools import wraps from fastapi import ( @@ -228,7 +229,7 @@ def _get_user_bearer_token_auth(bearer_token: str) -> User: unverified_payload = jwt.decode( bearer_token, options={'verify_signature': False} ) - if unverified_payload.keys() == set(['user', 'exp']): + if unverified_payload.keys() == {'user', 'exp'}: user = _get_user_from_simple_token(bearer_token) return user except jwt.exceptions.DecodeError: @@ -410,7 +411,7 @@ async def get_token_via_query(username: str, password: str): response_model=SignatureToken, ) async def get_signature_token( - user: Union[User, None] = Depends(create_user_dependency(required=True)), + user: User | None = Depends(create_user_dependency(required=True)), ): """ Generates and returns a signature token for the authenticated user. Authentication @@ -462,7 +463,7 @@ def generate_upload_token(user): bytes(config.services.api_secret, 'utf-8'), msg=payload, digestmod=hashlib.sha1 ) - return '%s.%s' % ( + return '{}.{}'.format( utils.base64_encode(payload), utils.base64_encode(signature.digest()), ) diff --git a/nomad/app/v1/routers/datasets.py b/nomad/app/v1/routers/datasets.py index 5d94cfc40fe57e25c1ff00a70a8be563dd662298..6cecffb685f91c813fdf57e6ca4b4223186f9ef8 100644 --- a/nomad/app/v1/routers/datasets.py +++ b/nomad/app/v1/routers/datasets.py @@ -249,7 +249,7 @@ dataset_pagination_parameters = parameter_dependency_from_model( class DatasetsResponse(BaseModel): pagination: PaginationResponse = Field(None) - data: List[Dataset] = Field(None) # type: ignore + data: list[Dataset] = Field(None) # type: ignore class DatasetResponse(BaseModel): @@ -263,12 +263,10 @@ class DatasetType(str, enum.Enum): class DatasetCreate(BaseModel): # type: ignore - dataset_name: Optional[str] = Field( - None, description='The new name for the dataset.' 
- ) - dataset_type: Optional[DatasetType] = Field(None) - query: Optional[Query] = Field(None) - entries: Optional[List[str]] = Field(None) + dataset_name: str | None = Field(None, description='The new name for the dataset.') + dataset_type: DatasetType | None = Field(None) + query: Query | None = Field(None) + entries: list[str] | None = Field(None) @router.get( @@ -283,7 +281,7 @@ async def get_datasets( request: Request, dataset_id: str = FastApiQuery(None), dataset_name: str = FastApiQuery(None), - user_id: List[str] = FastApiQuery(None), + user_id: list[str] = FastApiQuery(None), dataset_type: str = FastApiQuery(None), doi: str = FastApiQuery(None), prefix: str = FastApiQuery(None), diff --git a/nomad/app/v1/routers/entries.py b/nomad/app/v1/routers/entries.py index a89248d2b5d9590052e3c9d1fb8a2a7ae43b934b..0f7abd0f340e1a1bf517708f159f6fc5d2fab679 100644 --- a/nomad/app/v1/routers/entries.py +++ b/nomad/app/v1/routers/entries.py @@ -18,7 +18,8 @@ from datetime import datetime from enum import Enum -from typing import Annotated, Optional, Set, Union, Dict, Iterator, Any, List, Type +from typing import Annotated, Optional, Set, Union, Dict, Any, List, Type +from collections.abc import Iterator from fastapi import ( APIRouter, Depends, @@ -196,7 +197,7 @@ replace the references: ) -ArchiveRequired = Union[str, Dict[str, Any]] +ArchiveRequired = Union[str, dict[str, Any]] _archive_required_field = Body( '*', @@ -210,23 +211,23 @@ _archive_required_field = Body( class EntriesArchive(WithQueryAndPagination): - required: Optional[ArchiveRequired] = _archive_required_field + required: ArchiveRequired | None = _archive_required_field class EntryArchiveRequest(BaseModel): - required: Optional[ArchiveRequired] = _archive_required_field + required: ArchiveRequired | None = _archive_required_field class EntriesArchiveDownload(WithQuery, EntryArchiveRequest): - files: Optional[Files] = Body(None) + files: Files | None = Body(None) class EntriesRawDir(WithQuery): - pagination: Optional[MetadataPagination] = Body(None) + pagination: MetadataPagination | None = Body(None) class EntriesRaw(WithQuery): - files: Optional[Files] = Body(None, example={'glob_pattern': 'vasp*.xml*'}) + files: Files | None = Body(None, example={'glob_pattern': 'vasp*.xml*'}) class EntryRawDirFile(BaseModel): @@ -238,13 +239,13 @@ class EntryRawDir(BaseModel): entry_id: str = Field(None) upload_id: str = Field(None) mainfile: str = Field(None) - mainfile_key: Optional[str] = Field(None) - files: List[EntryRawDirFile] = Field(None) + mainfile_key: str | None = Field(None) + files: list[EntryRawDirFile] = Field(None) class EntriesRawDirResponse(EntriesRawDir): pagination: PaginationResponse = Field(None) # type: ignore - data: List[EntryRawDir] = Field(None) + data: list[EntryRawDir] = Field(None) class EntryRawDirResponse(BaseModel): @@ -256,12 +257,12 @@ class EntryArchive(BaseModel): entry_id: str = Field(None) upload_id: str = Field(None) parser_name: str = Field(None) - archive: Dict[str, Any] = Field(None) + archive: dict[str, Any] = Field(None) class EntriesArchiveResponse(EntriesArchive): pagination: PaginationResponse = Field(None) # type: ignore - data: List[EntryArchive] = Field(None) + data: list[EntryArchive] = Field(None) class EntryArchiveResponse(EntryArchiveRequest): @@ -277,10 +278,10 @@ class EntryMetadataResponse(BaseModel): class EntryMetadataEditActionField(BaseModel): value: str = Field(None, description='The value/values that is set as a string.') - success: Optional[bool] = Field( + success: bool | None 
= Field( None, description='If this can/could be done. Only in API response.' ) - message: Optional[str] = Field( + message: str | None = Field( None, descriptin='A message that details the action result. Only in API response.', ) @@ -292,7 +293,7 @@ EntryMetadataEditActions = create_model( quantity.name: ( Optional[EntryMetadataEditActionField] if quantity.is_scalar - else Optional[List[EntryMetadataEditActionField]], + else Optional[list[EntryMetadataEditActionField]], None, ) for quantity in EditableUserMetadata.m_def.definitions @@ -302,9 +303,7 @@ EntryMetadataEditActions = create_model( class EntryMetadataEdit(WithQuery): - verify: Optional[bool] = Field( - False, description='If true, no action is performed.' - ) + verify: bool | None = Field(False, description='If true, no action is performed.') actions: EntryMetadataEditActions = Field( # type: ignore None, description='Each action specifies a single value (even for multi valued quantities).', @@ -331,7 +330,7 @@ class ArchiveChangeAction(Enum): remove = 'remove' -def json_schema_extra(schema: dict[str, Any], model: Type['ArchiveChange']): +def json_schema_extra(schema: dict[str, Any], model: type['ArchiveChange']): schema['properties']['new_value'] = {} @@ -344,7 +343,7 @@ class ArchiveChange(BaseModel): class EntryEdit(BaseModel): - changes: List[ArchiveChange] + changes: list[ArchiveChange] class EntryEditResponse(EntryEdit): @@ -575,8 +574,8 @@ async def get_entries_metadata( def _do_exhaustive_search( - owner: Owner, query: Query, include: List[str], user: User -) -> Iterator[Dict[str, Any]]: + owner: Owner, query: Query, include: list[str], user: User +) -> Iterator[dict[str, Any]]: page_after_value = None while True: response = perform_search( @@ -591,8 +590,7 @@ def _do_exhaustive_search( page_after_value = response.pagination.next_page_after_value - for result in response.data: - yield result + yield from response.data if page_after_value is None or len(response.data) == 0: break @@ -626,7 +624,7 @@ class _Uploads: self._upload_files.close() -def _create_entry_rawdir(entry_metadata: Dict[str, Any], uploads: _Uploads): +def _create_entry_rawdir(entry_metadata: dict[str, Any], uploads: _Uploads): entry_id = entry_metadata['entry_id'] upload_id = entry_metadata['upload_id'] mainfile = entry_metadata['mainfile'] @@ -1315,7 +1313,7 @@ async def get_entry_raw_file( ..., description="A relative path to a file based on the directory of the entry's mainfile.", ), - offset: Optional[int] = QueryParameter( + offset: int | None = QueryParameter( 0, ge=0, description=strip( @@ -1324,7 +1322,7 @@ async def get_entry_raw_file( is the start of the file.""" ), ), - length: Optional[int] = QueryParameter( + length: int | None = QueryParameter( -1, ge=-1, description=strip( @@ -1333,7 +1331,7 @@ async def get_entry_raw_file( the file is streamed.""" ), ), - decompress: Optional[bool] = QueryParameter( + decompress: bool | None = QueryParameter( False, description=strip( """ @@ -1521,7 +1519,7 @@ async def post_entry_edit( key = to_key(path_segment) repeated_sub_section = isinstance(next_key, int) - next_value: Union[list, dict] = [] if repeated_sub_section else {} + next_value: list | dict = [] if repeated_sub_section else {} if isinstance(section_data, list): if section_data[key] is None: @@ -1634,11 +1632,11 @@ async def post_entry_archive_query( def edit( - query: Query, user: User, mongo_update: Dict[str, Any] = None, re_index=True -) -> List[str]: + query: Query, user: User, mongo_update: dict[str, Any] = None, re_index=True +) -> 
list[str]: # get all entries that have to change - entry_ids: List[str] = [] - upload_ids: Set[str] = set() + entry_ids: list[str] = [] + upload_ids: set[str] = set() with utils.timer(logger, 'edit query executed'): all_entries = _do_exhaustive_search( owner=Owner.user, query=query, include=['entry_id', 'upload_id'], user=user @@ -1662,7 +1660,7 @@ def edit( # re-index the affected entries in elastic search with utils.timer(logger, 'edit elastic update executed', size=len(entry_ids)): if re_index: - updated_metadata: List[datamodel.EntryMetadata] = [] + updated_metadata: list[datamodel.EntryMetadata] = [] for entry in proc.Entry.objects(entry_id__in=entry_ids): entry_metadata = entry.mongo_metadata(entry.upload) # Ensure that updated fields are marked as "set", even if they are cleared diff --git a/nomad/app/v1/routers/info.py b/nomad/app/v1/routers/info.py index a44825ecc92ce1d070f717939266b1325b3c34cc..8f123c6ec80955219fa4d1e30eee11b1e4b15158 100644 --- a/nomad/app/v1/routers/info.py +++ b/nomad/app/v1/routers/info.py @@ -76,24 +76,22 @@ class StatisticsModel(BaseModel): class CodeInfoModel(BaseModel): - code_name: Optional[str] = Field( - None, description='Name of the code or input format' - ) - code_homepage: Optional[str] = Field( + code_name: str | None = Field(None, description='Name of the code or input format') + code_homepage: str | None = Field( None, description='Homepage of the code or input format' ) class InfoModel(BaseModel): - parsers: List[str] - metainfo_packages: List[str] - codes: List[CodeInfoModel] - normalizers: List[str] - plugin_entry_points: List[dict] = Field( + parsers: list[str] + metainfo_packages: list[str] + codes: list[CodeInfoModel] + normalizers: list[str] + plugin_entry_points: list[dict] = Field( None, desciption='List of plugin entry points that are activated in this deployment.', ) - plugin_packages: List[dict] = Field( + plugin_packages: list[dict] = Field( None, desciption='List of plugin packages that are installed in this deployment.', ) @@ -115,7 +113,7 @@ class InfoModel(BaseModel): ) -_statistics: Dict[str, Any] = None +_statistics: dict[str, Any] = None def statistics(): diff --git a/nomad/app/v1/routers/metainfo.py b/nomad/app/v1/routers/metainfo.py index c9a557cc0d47f53814f130a9dd806bbcdb92f884..fe3477ce5b1452b1d45904d97342101d0b937e39 100644 --- a/nomad/app/v1/routers/metainfo.py +++ b/nomad/app/v1/routers/metainfo.py @@ -140,7 +140,7 @@ _not_authorized_to_upload = ( class PackageDefinitionResponse(BaseModel): section_definition_id: str = Field(None) - data: Dict[str, Any] = Field(None) + data: dict[str, Any] = Field(None) def get_package_by_section_definition_id(section_definition_id: str) -> dict: diff --git a/nomad/app/v1/routers/north.py b/nomad/app/v1/routers/north.py index 60ac7ca7ab379a49377183a6ffbd5bb3dad68893..0ad4cb4105f3952479a09455a42f50636e4247e2 100644 --- a/nomad/app/v1/routers/north.py +++ b/nomad/app/v1/routers/north.py @@ -54,18 +54,18 @@ class ToolStateEnum(str, Enum): class ToolModel(NORTHTool): name: str - state: Optional[ToolStateEnum] = None + state: ToolStateEnum | None = None class ToolResponseModel(BaseModel): tool: str username: str - upload_mount_dir: Optional[str] = None + upload_mount_dir: str | None = None data: ToolModel class ToolsResponseModel(BaseModel): - data: List[ToolModel] = [] + data: list[ToolModel] = [] _bad_tool_response = ( @@ -162,7 +162,7 @@ async def get_tool( async def start_tool( tool: ToolModel = Depends(tool), user: User = Depends(create_user_dependency(required=True)), - upload_id: 
Optional[str] = None, + upload_id: str | None = None, ): tool.state = ToolStateEnum.stopped @@ -198,7 +198,7 @@ async def start_tool( ) upload_query &= Q(publish_time=None) - uploads: List[Dict] = [] + uploads: list[dict] = [] for upload in Upload.objects.filter(upload_query): if not hasattr(upload.upload_files, 'external_os_path'): # In case the files are missing for one reason or another @@ -224,7 +224,7 @@ async def start_tool( } ) - external_mounts: List[Dict[str, str]] = [] + external_mounts: list[dict[str, str]] = [] for ext_mount in tool.external_mounts: external_mounts.append( { diff --git a/nomad/app/v1/routers/suggestions.py b/nomad/app/v1/routers/suggestions.py index d6877804f961672829c66bf5ee7e75e947010b2c..a528fde90bf9788b78131c7acde5948207713367 100644 --- a/nomad/app/v1/routers/suggestions.py +++ b/nomad/app/v1/routers/suggestions.py @@ -34,7 +34,7 @@ router = APIRouter() # This is a dynamically create enum class for enumerating all allowed # quantities. FastAPI uses python enums to validate and document options. -suggestable_quantities: Set[str] = None +suggestable_quantities: set[str] = None class SuggestionError(Exception): @@ -43,7 +43,7 @@ class SuggestionError(Exception): class Suggestion(BaseModel): value: str = Field(None, description='The returned suggestion.') - weight: Optional[float] = Field(None, description='The suggestion weight.') + weight: float | None = Field(None, description='The suggestion weight.') class Quantity(BaseModel): @@ -55,7 +55,7 @@ class Quantity(BaseModel): class SuggestionsRequest(BaseModel): - quantities: List[Quantity] = Field( # type: ignore + quantities: list[Quantity] = Field( # type: ignore None, description='List of quantities for which the suggestions are retrieved.' ) input: str = Field( @@ -68,7 +68,7 @@ class SuggestionsRequest(BaseModel): '', tags=['suggestions'], summary='Get a list of suggestions for the given quantity names and input.', - response_model=Dict[str, List[Suggestion]], + response_model=dict[str, list[Suggestion]], response_model_exclude_unset=True, response_model_exclude_none=True, ) @@ -123,8 +123,8 @@ async def get_suggestions( raise SuggestionError from e # We return the original field in the source document. 
- response: Dict[str, List[Suggestion]] = defaultdict(list) - suggestions: Dict[str, Dict[str, float]] = defaultdict(dict) + response: dict[str, list[Suggestion]] = defaultdict(list) + suggestions: dict[str, dict[str, float]] = defaultdict(dict) def add_suggestion(name, value, weight): values = suggestions[name] @@ -168,7 +168,7 @@ async def get_suggestions( for item in original: options.append(item) - options: List[str] = [] + options: list[str] = [] parts = quantity_path.split('.') gather_options(option._source, parts, options) diff --git a/nomad/app/v1/routers/systems.py b/nomad/app/v1/routers/systems.py index 8c7e375aa9563d1e1a76590d6ca7576591014230..d392014a7fb30e1d0665c71b38140a0b0a731076 100644 --- a/nomad/app/v1/routers/systems.py +++ b/nomad/app/v1/routers/systems.py @@ -119,7 +119,7 @@ class FormatFeature(str, Enum): PBC = 'Periodic boundary conditions (PBC)' -format_map: Dict[str, dict] = OrderedDict( +format_map: dict[str, dict] = OrderedDict( { 'cif': { 'label': 'cif', @@ -204,7 +204,7 @@ FormatEnum = TempFormatEnum( ) # type: ignore -wrap_mode_map: Dict[str, dict] = OrderedDict( +wrap_mode_map: dict[str, dict] = OrderedDict( { 'original': {'description': 'The original positions as set in the data'}, 'wrap': { @@ -351,7 +351,7 @@ Here is a brief rundown of the different features each format supports: path = path[len(prefix) :] # Add indexing - query_list: List[Union[str, int]] = [] + query_list: list[str | int] = [] paths = [x for x in path.split('/') if x != ''] i = 0 while i < len(paths): diff --git a/nomad/app/v1/routers/uploads.py b/nomad/app/v1/routers/uploads.py index 54e3a2f6bb50f524cc2b6a1b7b074da844212556..3043cc64b41cdf0d5c4b2918d80876bb77ad3635 100644 --- a/nomad/app/v1/routers/uploads.py +++ b/nomad/app/v1/routers/uploads.py @@ -124,27 +124,27 @@ class UploadRole(str, Enum): class ProcData(BaseModel): process_running: bool = Field(description='If a process is running') - current_process: Optional[str] = Field( + current_process: str | None = Field( None, description='Name of the current or last completed process' ) process_status: str = Field( ProcessStatus.READY, description='The status of the current or last completed process', ) - last_status_message: Optional[str] = Field( + last_status_message: str | None = Field( None, description='A short, human readable message from the current process, with ' 'information about what the current process is doing, or information ' 'about the completion (successful or not) of the last process, if no ' 'process is currently running.', ) - errors: List[str] = Field( + errors: list[str] = Field( descriptions='A list of error messages that occurred during the last processing' ) - warnings: List[str] = Field( + warnings: list[str] = Field( description='A list of warning messages that occurred during the last processing' ) - complete_time: Optional[datetime] = Field( + complete_time: datetime | None = Field( None, description='Date and time of the completion of the last process' ) model_config = ConfigDict(from_attributes=True) @@ -152,51 +152,51 @@ class ProcData(BaseModel): class UploadProcData(ProcData): upload_id: str = Field(description='The unique id for the upload.') - upload_name: Optional[str] = Field( + upload_name: str | None = Field( None, description='The name of the upload. This can be provided during upload ' 'using the `upload_name` query parameter.', ) - upload_create_time: Optional[datetime] = Field( + upload_create_time: datetime | None = Field( None, description='Date and time of the creation of the upload.' 
) - main_author: Optional[str] = Field( + main_author: str | None = Field( None, description=strip('The main author of the upload.') ) - coauthors: Optional[List[str]] = Field( + coauthors: list[str] | None = Field( None, description=strip('A list of upload coauthors.') ) - coauthor_groups: Optional[List[str]] = Field( + coauthor_groups: list[str] | None = Field( None, description=strip('A list of upload coauthor groups.') ) - reviewers: Optional[List[str]] = Field( + reviewers: list[str] | None = Field( None, description=strip('A list of upload reviewers.') ) - reviewer_groups: Optional[List[str]] = Field( + reviewer_groups: list[str] | None = Field( None, description=strip('A list of upload reviewer groups.') ) - writers: Optional[List[str]] = Field( + writers: list[str] | None = Field( None, description=strip('All writer users (main author, upload coauthors).') ) - writer_groups: Optional[List[str]] = Field( + writer_groups: list[str] | None = Field( None, description=strip('All writer groups (coauthor groups).') ) - viewers: Optional[List[str]] = Field( + viewers: list[str] | None = Field( None, description=strip( 'All viewer users (main author, upload coauthors, and reviewers)' ), ) - viewer_groups: Optional[List[str]] = Field( + viewer_groups: list[str] | None = Field( None, description=strip('All viewer groups (coauthor groups, reviewer groups).'), ) published: bool = Field(False, description='If this upload is already published.') - published_to: Optional[List[str]] = Field( + published_to: list[str] | None = Field( None, description='A list of other NOMAD deployments that this upload was uploaded to already.', ) - publish_time: Optional[datetime] = Field( + publish_time: datetime | None = Field( None, description='Date and time of publication, if the upload has been published.', ) @@ -212,7 +212,7 @@ class UploadProcData(ProcData): entries: int = Field( 0, description='The number of identified entries in this upload.' ) - upload_files_server_path: Optional[str] = Field( + upload_files_server_path: str | None = Field( None, description='The path to the uploads files on the server.' ) @@ -221,10 +221,10 @@ class EntryProcData(ProcData): entry_id: str = Field() entry_create_time: datetime = Field() mainfile: str = Field() - mainfile_key: Optional[str] = Field(None) + mainfile_key: str | None = Field(None) upload_id: str = Field() parser_name: str = Field() - entry_metadata: Optional[dict] = Field(None) + entry_metadata: dict | None = Field(None) class UploadProcDataPagination(Pagination): @@ -339,15 +339,15 @@ class UploadProcDataResponse(BaseModel): class UploadProcDataQuery(BaseModel): - upload_id: Optional[List[str]] = Field( + upload_id: list[str] | None = Field( None, description='Search for uploads matching the given id. Multiple values can be specified.', ) - upload_name: Optional[List[str]] = Field( + upload_name: list[str] | None = Field( None, description='Search for uploads matching the given upload_name. 
Multiple values can be specified.', ) - is_processing: Optional[bool] = Field( + is_processing: bool | None = Field( None, description=strip( """ @@ -356,7 +356,7 @@ class UploadProcDataQuery(BaseModel): If unset, include everything.""" ), ) - is_published: Optional[bool] = Field( + is_published: bool | None = Field( None, description=strip( """ @@ -365,10 +365,10 @@ class UploadProcDataQuery(BaseModel): If unset: include everything.""" ), ) - process_status: Optional[str] = Field( + process_status: str | None = Field( None, description=strip('Search by the process status.') ) - is_owned: Optional[bool] = Field( + is_owned: bool | None = Field( None, description=strip( """ @@ -393,7 +393,7 @@ upload_proc_data_query_parameters = parameter_dependency_from_model( class UploadProcDataQueryResponse(BaseModel): query: UploadProcDataQuery = Field() pagination: PaginationResponse = Field() - data: List[UploadProcData] = Field( + data: list[UploadProcData] = Field( None, description=strip( """ @@ -434,7 +434,7 @@ class EntryProcDataQueryResponse(BaseModel): """ ), ) - data: List[EntryProcData] = Field( + data: list[EntryProcData] = Field( None, description=strip( """ @@ -472,15 +472,15 @@ class RawDirFileMetadata(BaseModel): """Metadata about a file""" name: str = Field() - size: Optional[int] = Field(None) - entry_id: Optional[str] = Field( + size: int | None = Field(None) + entry_id: str | None = Field( None, description=strip( """ If this is a mainfile: the ID of the corresponding entry.""" ), ) - parser_name: Optional[str] = Field( + parser_name: str | None = Field( None, description=strip( """ @@ -499,8 +499,8 @@ class RawDirDirectoryMetadata(BaseModel): """Metadata about a directory""" name: str = Field() - size: Optional[int] = Field(None) - content: List[RawDirElementMetadata] = Field( + size: int | None = Field(None) + content: list[RawDirElementMetadata] = Field( examples=[ [ {'name': 'a_directory', 'is_file': False, 'size': 456}, @@ -519,18 +519,18 @@ class RawDirDirectoryMetadata(BaseModel): class RawDirResponse(BaseModel): path: str = Field(examples=['The/requested/path']) access: str = Field() - file_metadata: Optional[RawDirFileMetadata] = Field(None) - directory_metadata: Optional[RawDirDirectoryMetadata] = Field(None) - pagination: Optional[PaginationResponse] = Field(None) + file_metadata: RawDirFileMetadata | None = Field(None) + directory_metadata: RawDirDirectoryMetadata | None = Field(None) + pagination: PaginationResponse | None = Field(None) class ProcessingData(BaseModel): upload_id: str = Field() path: str = Field() - entry_id: Optional[str] = Field(None) - parser_name: Optional[str] = Field(None) - entry: Optional[EntryProcData] = Field(None) - archive: Optional[Dict[str, Any]] = Field(None) + entry_id: str | None = Field(None) + parser_name: str | None = Field(None) + entry: EntryProcData | None = Field(None) + archive: dict[str, Any] | None = Field(None) class PutRawFileResponse(BaseModel): @@ -548,7 +548,7 @@ class PutRawFileResponse(BaseModel): The upload data as a dictionary.""" ), ) - processing: Optional[ProcessingData] = Field( + processing: ProcessingData | None = Field( None, description=strip( """ @@ -561,8 +561,8 @@ class PutRawFileResponse(BaseModel): class DeleteEntryFilesRequest(WithQuery): """Defines a request to delete entry files.""" - owner: Optional[Owner] = Body('all') - include_parent_folders: Optional[bool] = Field( + owner: Owner | None = Body('all') + include_parent_folders: bool | None = Field( False, description=strip( """ @@ -777,7 +777,7 @@ 
async def get_command_examples( ) async def get_uploads( request: Request, - roles: List[UploadRole] = FastApiQuery( + roles: list[UploadRole] = FastApiQuery( None, description='Only return uploads where the user has one of the given roles.', ), @@ -1042,7 +1042,7 @@ async def get_upload_rawdir_path( directory_list = upload_files.raw_directory_list(path) upload_files.close() content = [] - path_to_element: Dict[str, RawDirElementMetadata] = {} + path_to_element: dict[str, RawDirElementMetadata] = {} total = 0 total_size = 0 for i, path_info in enumerate(directory_list): @@ -1149,7 +1149,7 @@ async def get_upload_raw_path( upload_id: str = Path(..., description='The unique id of the upload.'), path: str = Path(..., description='The path within the upload raw files.'), files_params: Files = Depends(files_parameters), - offset: Optional[int] = FastApiQuery( + offset: int | None = FastApiQuery( 0, description=strip( """ @@ -1158,7 +1158,7 @@ async def get_upload_raw_path( is 0, i.e. the start of the file.""" ), ), - length: Optional[int] = FastApiQuery( + length: int | None = FastApiQuery( -1, description=strip( """ @@ -1324,7 +1324,7 @@ async def put_upload_raw_path( request: Request, upload_id: str = Path(..., description='The unique id of the upload.'), path: str = Path(..., description='The path within the upload raw files.'), - file: List[UploadFile] = File(None), + file: list[UploadFile] = File(None), local_path: str = FastApiQuery( None, description=strip("""Internal/Admin use only."""), @@ -1733,7 +1733,7 @@ async def get_upload_entry_archive_mainfile( mainfile: str = Path( ..., description="The mainfile path within the upload's raw files." ), - mainfile_key: Optional[str] = FastApiQuery( + mainfile_key: str | None = FastApiQuery( None, description='The mainfile_key, for accessing child entries.' 
), user: User = Depends(create_user_dependency(required=False)), @@ -1784,7 +1784,7 @@ async def get_upload_entry_archive( ) async def post_upload( request: Request, - file: List[UploadFile] = File(None), + file: list[UploadFile] = File(None), local_path: str = FastApiQuery( None, description=strip( @@ -1792,7 +1792,7 @@ async def post_upload( Internal/Admin use only.""" ), ), - example_upload_id: Optional[str] = FastApiQuery( + example_upload_id: str | None = FastApiQuery( None, description=strip( """ @@ -2231,7 +2231,7 @@ async def post_upload_action_delete_entry_files( # Determine paths to delete try: - paths_to_delete: Set[str] = set() + paths_to_delete: set[str] = set() for es_entry in es_entries: mainfile = es_entry['mainfile'] path_to_delete = ( @@ -2328,14 +2328,14 @@ async def post_upload_action_lift_embargo( ) async def get_upload_bundle( upload_id: str = Path(..., description='The unique id of the upload.'), - include_raw_files: Optional[bool] = FastApiQuery( + include_raw_files: bool | None = FastApiQuery( True, description=strip( """ If raw files should be included in the bundle (true by default).""" ), ), - include_archive_files: Optional[bool] = FastApiQuery( + include_archive_files: bool | None = FastApiQuery( True, description=strip( """ @@ -2343,7 +2343,7 @@ async def get_upload_bundle( (true by default).""" ), ), - include_datasets: Optional[bool] = FastApiQuery( + include_datasets: bool | None = FastApiQuery( True, description=strip( """ @@ -2397,7 +2397,7 @@ async def get_upload_bundle( ) async def post_upload_bundle( request: Request, - file: List[UploadFile] = File(None), + file: list[UploadFile] = File(None), local_path: str = FastApiQuery( None, description=strip( @@ -2405,7 +2405,7 @@ async def post_upload_bundle( Internal/Admin use only.""" ), ), - embargo_length: Optional[int] = FastApiQuery( + embargo_length: int | None = FastApiQuery( None, description=strip( """ @@ -2414,7 +2414,7 @@ async def post_upload_bundle( embargo.""" ), ), - include_raw_files: Optional[bool] = FastApiQuery( + include_raw_files: bool | None = FastApiQuery( None, description=strip( """ @@ -2422,7 +2422,7 @@ async def post_upload_bundle( *(only admins can change this setting)*.""" ), ), - include_archive_files: Optional[bool] = FastApiQuery( + include_archive_files: bool | None = FastApiQuery( None, description=strip( """ @@ -2430,7 +2430,7 @@ async def post_upload_bundle( *(only admins can change this setting)*.""" ), ), - include_datasets: Optional[bool] = FastApiQuery( + include_datasets: bool | None = FastApiQuery( None, description=strip( """ @@ -2438,7 +2438,7 @@ async def post_upload_bundle( *(only admins can change this setting)*.""" ), ), - include_bundle_info: Optional[bool] = FastApiQuery( + include_bundle_info: bool | None = FastApiQuery( None, description=strip( """ @@ -2446,7 +2446,7 @@ async def post_upload_bundle( *(only admins can change this setting)*.""" ), ), - keep_original_timestamps: Optional[bool] = FastApiQuery( + keep_original_timestamps: bool | None = FastApiQuery( None, description=strip( """ @@ -2455,7 +2455,7 @@ async def post_upload_bundle( *(only admins can change this setting)*.""" ), ), - set_from_oasis: Optional[bool] = FastApiQuery( + set_from_oasis: bool | None = FastApiQuery( None, description=strip( """ @@ -2463,7 +2463,7 @@ async def post_upload_bundle( *(only admins can change this setting)*.""" ), ), - trigger_processing: Optional[bool] = FastApiQuery( + trigger_processing: bool | None = FastApiQuery( None, description=strip( """ @@ -2565,11 
+2565,11 @@ async def post_upload_bundle( async def _get_files_if_provided( tmp_dir_prefix: str, request: Request, - file: List[UploadFile], + file: list[UploadFile], local_path: str, file_name: str, user: User, -) -> Tuple[List[str], List[str], Union[None, int]]: +) -> tuple[list[str], list[str], None | int]: """ If the user provides one or more files with the api call, load and save them to a temporary folder (or, if method 0 is used, just "forward" the file path). The method thus needs to identify @@ -2579,7 +2579,7 @@ async def _get_files_if_provided( data was provided with the api call. """ # Determine the source data stream - sources: List[Tuple[Any, str]] = [] # List of tuples (source, filename) + sources: list[tuple[Any, str]] = [] # List of tuples (source, filename) if local_path: # Method 0: Local file - only for admin use. if not user.is_admin: @@ -2713,7 +2713,7 @@ def _query_mongodb(**kwargs): return Upload.objects(**kwargs) -def get_role_query(roles: List[UploadRole], user: User, include_all=False) -> Q: +def get_role_query(roles: list[UploadRole], user: User, include_all=False) -> Q: """ Create MongoDB filter query for user with given roles (default: all roles) """ @@ -2733,7 +2733,7 @@ def get_role_query(roles: List[UploadRole], user: User, include_all=False) -> Q: return role_query -def is_user_upload_viewer(upload: Upload, user: Optional[User]): +def is_user_upload_viewer(upload: Upload, user: User | None): if 'all' in upload.reviewer_groups: return True @@ -2768,7 +2768,7 @@ def is_user_upload_writer(upload: Upload, user: User): def get_upload_with_read_access( - upload_id: str, user: Optional[User], include_others: bool = False + upload_id: str, user: User | None, include_others: bool = False ) -> Upload: """ Determines if the user has read access to the upload. If so, the corresponding Upload diff --git a/nomad/app/v1/routers/users.py b/nomad/app/v1/routers/users.py index 58c8df11e972736024a8800fc61a1c03fa52a9b8..e5f92534f4312dd824ceb7b489a62e172f43f5a8 100644 --- a/nomad/app/v1/routers/users.py +++ b/nomad/app/v1/routers/users.py @@ -57,7 +57,7 @@ _bad_invite_response = ( class Users(BaseModel): - data: List[User] + data: list[User] @router.get( @@ -91,7 +91,7 @@ async def read_users_me( response_model=Users, ) async def get_users( - prefix: Optional[str] = Query( + prefix: str | None = Query( None, description=strip( """ @@ -99,7 +99,7 @@ async def get_users( """ ), ), - user_id: Union[List[str], None] = Query( + user_id: list[str] | None = Query( None, description=strip( """ @@ -107,7 +107,7 @@ async def get_users( """ ), ), - username: Union[List[str], None] = Query( + username: list[str] | None = Query( None, description=strip( """ @@ -115,7 +115,7 @@ async def get_users( """ ), ), - email: Union[List[str], None] = Query( + email: list[str] | None = Query( None, description=strip( """ @@ -124,7 +124,7 @@ async def get_users( ), ), ): - users: List[User] = [] + users: list[User] = [] for key, values in dict(user_id=user_id, username=username, email=email).items(): if not values: continue diff --git a/nomad/app/v1/utils.py b/nomad/app/v1/utils.py index e242d26765382c2903e6186bf9892d55b479390a..319958b76750f4ce4e5398b4e1986551a5495dad 100644 --- a/nomad/app/v1/utils.py +++ b/nomad/app/v1/utils.py @@ -16,7 +16,8 @@ # limitations under the License. 
# -from typing import List, Dict, Set, Iterator, Any, Optional, Union +from typing import List, Dict, Set, Any, Optional, Union +from collections.abc import Iterator from types import FunctionType import urllib import io @@ -31,7 +32,7 @@ from nomad.files import UploadFiles, StreamedFile, create_zipstream def parameter_dependency_from_model( - name: str, model_cls: BaseModel, exclude: List[str] = [] + name: str, model_cls: BaseModel, exclude: list[str] = [] ) -> FunctionType: """ Takes a pydantic model class as input and creates a dependency with corresponding @@ -46,7 +47,7 @@ def parameter_dependency_from_model( model_cls: A ``BaseModel`` inheriting model class as input. """ names = [] - annotations: Dict[str, type] = {} + annotations: dict[str, type] = {} defaults = [] for field_name, field_model in model_cls.model_fields.items(): try: @@ -75,7 +76,7 @@ def parameter_dependency_from_model( name, ', '.join(names), model_cls.__name__, # type: ignore - ', '.join(['%s=%s' % (name, name) for name in names]), + ', '.join([f'{name}={name}' for name in names]), ) ) @@ -95,11 +96,11 @@ class DownloadItem(BaseModel): upload_id: str raw_path: str zip_path: str - entry_metadata: Optional[Dict[str, Any]] = None + entry_metadata: dict[str, Any] | None = None async def create_download_stream_zipped( - download_items: Union[DownloadItem, Iterator[DownloadItem]], + download_items: DownloadItem | Iterator[DownloadItem], upload_files: UploadFiles = None, re_pattern: Any = None, recursive: bool = False, @@ -126,7 +127,7 @@ async def create_download_stream_zipped( if isinstance(download_items, DownloadItem) else download_items ) - streamed_paths: Set[str] = set() + streamed_paths: set[str] = set() for download_item in items: if upload_files and upload_files.upload_id != download_item.upload_id: @@ -249,7 +250,7 @@ def create_responses(*args): def browser_download_headers( filename: str, media_type: str = 'application/octet-stream' -) -> Dict[str, str]: +) -> dict[str, str]: """ Creates standardized headers which tells browsers that they should download the data to a file with the specified filename. Note, the `media_type` should normally be @@ -282,7 +283,7 @@ def update_url_query_arguments(original_url: str, **kwargs) -> str: return urllib.parse.urlunparse((scheme, netloc, path, params, query, fragment)) -def convert_data_to_dict(data: Any) -> Dict[str, Any]: +def convert_data_to_dict(data: Any) -> dict[str, Any]: """ Converts a pydantic model or a dictionary containing pydantic models to a dictionary. 
diff --git a/nomad/archive/converter.py b/nomad/archive/converter.py index 4b672f67a65f17db136e4dd961425ad56e4acc95..2f581b032bde61db701ec18dd621fc9688ffae02 100644 --- a/nomad/archive/converter.py +++ b/nomad/archive/converter.py @@ -23,7 +23,8 @@ import os.path import signal from concurrent.futures import ProcessPoolExecutor from multiprocessing import Manager -from typing import Iterable, Callable +from collections.abc import Callable +from collections.abc import Iterable from nomad.config import config from nomad.archive import to_json, read_archive diff --git a/nomad/archive/partial.py b/nomad/archive/partial.py index 40475b6d480aa0ef9b4322c18a4e74b5da92383a..a1b0e44a244bbb432380916cf742b128d9fc1194 100644 --- a/nomad/archive/partial.py +++ b/nomad/archive/partial.py @@ -32,7 +32,7 @@ from nomad.datamodel import EntryArchive from nomad.datamodel.metainfo.common import FastAccess -def create_partial_archive(archive: EntryArchive) -> Dict: +def create_partial_archive(archive: EntryArchive) -> dict: """ Creates a partial archive JSON serializable dict that can be stored directly. The given archive is filtered based on the metainfo category ``FastAccess``. @@ -46,10 +46,10 @@ def create_partial_archive(archive: EntryArchive) -> Dict: """ # A list with all referenced sections that might not yet been ensured to be in the # resulting partial archive - referenceds: List[MSection] = [] + referenceds: list[MSection] = [] # contents keeps track of all sections in the partial archive by keeping their # JSON serializable form and placeholder status in a dict - contents: Dict[MSection, Tuple[dict, bool]] = dict() + contents: dict[MSection, tuple[dict, bool]] = dict() def partial(definition: Definition, section: MSection) -> bool: """ @@ -87,7 +87,7 @@ def create_partial_archive(archive: EntryArchive) -> Dict: the section's serialization is added (or replacing an existing placeholder). Otherwise, an empty dict is added as a placeholder for referenced children. """ - result: Dict[str, Any] = None + result: dict[str, Any] = None content, content_is_placeholder = contents.get(section, (None, True)) if content is not None: if content_is_placeholder and not placeholder: @@ -141,7 +141,7 @@ def write_partial_archive_to_mongo(archive: EntryArchive): def read_partial_archive_from_mongo( entry_id: str, as_dict=False -) -> Union[EntryArchive, Dict]: +) -> EntryArchive | dict: """ Reads the partial archive for the given id from mongodb. @@ -160,15 +160,15 @@ def read_partial_archive_from_mongo( return EntryArchive.m_from_dict(archive_dict) -def delete_partial_archives_from_mongo(entry_ids: List[str]): +def delete_partial_archives_from_mongo(entry_ids: list[str]): mongo_db = infrastructure.mongo_client[config.mongo.db_name] mongo_collection = mongo_db['archive'] mongo_collection.delete_many(dict(_id={'$in': entry_ids})) def read_partial_archives_from_mongo( - entry_ids: List[str], as_dict=False -) -> Dict[str, Union[EntryArchive, Dict]]: + entry_ids: list[str], as_dict=False +) -> dict[str, EntryArchive | dict]: """ Reads the partial archives for a set of entries. 
@@ -193,7 +193,7 @@ def read_partial_archives_from_mongo( } -__all_parent_sections: Dict[Section, Tuple[str, Section]] = {} +__all_parent_sections: dict[Section, tuple[str, Section]] = {} def _all_parent_sections(): @@ -261,7 +261,7 @@ def compute_required_with_referenced(required): return result - def traverse(current: Union[dict, str], parent: Section = EntryArchive.m_def): + def traverse(current: dict | str, parent: Section = EntryArchive.m_def): if isinstance(current, str): return diff --git a/nomad/archive/query.py b/nomad/archive/query.py index 84de7db850b8256bc3f14a488691579d478c122d..822392656dbb0ca330da8a34ad6b3f52ff774e61 100644 --- a/nomad/archive/query.py +++ b/nomad/archive/query.py @@ -18,7 +18,8 @@ import functools import re -from typing import Any, Dict, Callable, Union, Tuple +from typing import Any, Dict, Union, Tuple +from collections.abc import Callable from io import BytesIO from nomad import utils @@ -37,11 +38,11 @@ def _fix_index(index, length): @functools.lru_cache(maxsize=1024) -def _extract_key_and_index(match) -> Tuple[str, Union[Tuple[int, int], int]]: +def _extract_key_and_index(match) -> tuple[str, tuple[int, int] | int]: key = match.group(1) # noinspection PyTypeChecker - index: Union[Tuple[int, int], int] = None + index: tuple[int, int] | int = None # check if we have indices if match.group(2) is not None: @@ -59,7 +60,7 @@ def _extract_key_and_index(match) -> Tuple[str, Union[Tuple[int, int], int]]: # @cached(thread_safe=False, max_size=1024) -def _extract_child(archive_item, prop, index) -> Union[dict, list]: +def _extract_child(archive_item, prop, index) -> dict | list: archive_child = archive_item[prop] from .storage_v2 import ArchiveList as ArchiveListNew @@ -94,8 +95,8 @@ class ArchiveQueryError(Exception): def query_archive( - f_or_archive_reader: Union[str, ArchiveReader, BytesIO], query_dict: dict, **kwargs -) -> Dict: + f_or_archive_reader: str | ArchiveReader | BytesIO, query_dict: dict, **kwargs +) -> dict: """ Takes an open msg-pack based archive (either as str, reader, or BytesIO) and returns the archive as JSON serializable dictionary filtered based on the given required @@ -141,7 +142,7 @@ def query_archive( ) -def _load_data(query_dict: Dict[str, Any], archive_item: ArchiveDict) -> Dict: +def _load_data(query_dict: dict[str, Any], archive_item: ArchiveDict) -> dict: query_dict_with_fixed_ids = { utils.adjust_uuid_size(key): value for key, value in query_dict.items() } @@ -149,12 +150,12 @@ def _load_data(query_dict: Dict[str, Any], archive_item: ArchiveDict) -> Dict: def filter_archive( - required: Union[str, Dict[str, Any]], - archive_item: Union[Dict, ArchiveDict, str], + required: str | dict[str, Any], + archive_item: dict | ArchiveDict | str, transform: Callable, - result_root: Dict = None, + result_root: dict = None, resolve_inplace: bool = False, -) -> Dict: +) -> dict: if archive_item is None: return None @@ -185,7 +186,7 @@ def filter_archive( f'resolving references in non partial archives is not yet implemented' ) - result: Dict[str, Any] = {} + result: dict[str, Any] = {} for key, val in required.items(): key = key.strip() diff --git a/nomad/archive/required.py b/nomad/archive/required.py index 905cd07378695fdd7c25a072ac544fa00e2b5c13..f3ef389c4549f54879ca4a05bcc098164b1a9efe 100644 --- a/nomad/archive/required.py +++ b/nomad/archive/required.py @@ -56,7 +56,7 @@ class RequiredValidationError(Exception): @functools.lru_cache(maxsize=1024) -def _parse_required_key(key: str) -> Tuple[str, Union[Tuple[int, int], int]]: +def 
_parse_required_key(key: str) -> tuple[str, tuple[int, int] | int]: key = key.strip() match = _query_archive_key_pattern.match(key) @@ -66,7 +66,7 @@ def _parse_required_key(key: str) -> Tuple[str, Union[Tuple[int, int], int]]: return _extract_key_and_index(match) -def _setdefault(target: Union[dict, list], key, value_type: type): +def _setdefault(target: dict | list, key, value_type: type): if isinstance(target, list): if target[key] is None: target[key] = value_type() @@ -141,7 +141,7 @@ class RequiredReader: def __init__( self, - required: Union[dict, str], + required: dict | str, root_section_def: Section = None, resolve_inplace: bool = False, user=None, @@ -453,9 +453,9 @@ class RequiredReader: return resolved_result path_stack.reverse() - target_container: Union[dict, list] = dataset.result_root + target_container: dict | list = dataset.result_root # noinspection PyTypeChecker - prop_or_index: Union[str, int] = None + prop_or_index: str | int = None while len(path_stack) > 0: if prop_or_index is not None: target_container = _setdefault(target_container, prop_or_index, dict) @@ -540,9 +540,9 @@ class RequiredReader: def _apply_required( self, required: dict | str, - archive_item: Union[dict, str], + archive_item: dict | str, dataset: RequiredReferencedArchive, - ) -> Union[Dict, str]: + ) -> dict | str: if archive_item is None: return None # type: ignore diff --git a/nomad/archive/storage.py b/nomad/archive/storage.py index 68cb2f47d2e472a8c4f7234a5fcff7983518d8f3..c0baeb75c87900e9f73fa249ada9802d52022aaf 100644 --- a/nomad/archive/storage.py +++ b/nomad/archive/storage.py @@ -17,7 +17,8 @@ # from __future__ import annotations -from typing import Any, Tuple, Dict, Union, cast, Generator +from typing import Any, Tuple, Dict, Union, cast +from collections.abc import Generator from io import BytesIO, BufferedReader from collections.abc import Mapping, Sequence @@ -38,13 +39,13 @@ def unpackb(o): return msgspec.msgpack.decode(o) -def _decode(position: bytes) -> Tuple[int, int]: +def _decode(position: bytes) -> tuple[int, int]: return int.from_bytes( position[:5], byteorder='little', signed=False ), int.from_bytes(position[5:], byteorder='little', signed=False) -def _unpack_entry(data: bytes) -> Tuple[Any, Tuple[Any, Any]]: +def _unpack_entry(data: bytes) -> tuple[Any, tuple[Any, Any]]: entry_uuid = unpackb(data[:_toc_uuid_size]) positions_encoded = unpackb(data[_toc_uuid_size:]) return entry_uuid, (_decode(positions_encoded[0]), _decode(positions_encoded[1])) @@ -73,7 +74,7 @@ class ArchiveItem: self._f.seek(offset) return self._f.read(size) - def _read(self, position: Tuple[int, int]): + def _read(self, position: tuple[int, int]): start, end = position raw_data = self._direct_read(end - start, start + self._offset) return unpackb(raw_data) @@ -140,7 +141,7 @@ class ArchiveDict(ArchiveItem, Mapping): class ArchiveReader(ArchiveDict): - def __init__(self, file_or_path: Union[str, BytesIO], use_blocked_toc=True): + def __init__(self, file_or_path: str | BytesIO, use_blocked_toc=True): self._file_or_path = file_or_path if isinstance(self._file_or_path, str): @@ -189,7 +190,7 @@ class ArchiveReader(ArchiveDict): 'Archive top-level TOC is not a msgpack map (dictionary).' 
) - self._toc: Dict[str, Any] = {} + self._toc: dict[str, Any] = {} self._toc_block_info = [None] * (self._toc_number // _entries_per_block + 1) def __enter__(self): @@ -313,7 +314,7 @@ class ArchiveReader(ArchiveDict): return self._f.closed if isinstance(self._file_or_path, str) else True -def read_archive(file_or_path: Union[str, BytesIO], **kwargs) -> ArchiveReader: +def read_archive(file_or_path: str | BytesIO, **kwargs) -> ArchiveReader: """ Allows to read a msgpack-based archive. diff --git a/nomad/atomutils.py b/nomad/atomutils.py index 07e9959918b2373c75426b692d8cd765d50c5c86..b6c089282cdb753994a1c5b71bc291da265937e9 100644 --- a/nomad/atomutils.py +++ b/nomad/atomutils.py @@ -29,12 +29,12 @@ from typing import ( TYPE_CHECKING, Any, Dict, - Iterable, List, Tuple, Union, cast, ) +from collections.abc import Iterable import ase.data import ase.geometry @@ -94,7 +94,7 @@ def get_summed_mass(atomic_numbers=None, masses=None, indices=None, atom_labels= def get_masses_from_computational_model( archive, repr_system: System = None, method_index: int = -1 -) -> Union[List[float], Dict[str, float]]: +) -> list[float] | dict[str, float]: """ Gets the masses based on the masses provided in atom parameters of the computational model. Only returns the mass list in case @@ -205,7 +205,7 @@ def is_valid_basis(basis: NDArray[Any]) -> bool: def translate_pretty( - fractional: NDArray[Any], pbc: Union[bool, NDArray[Any]] + fractional: NDArray[Any], pbc: bool | NDArray[Any] ) -> NDArray[Any]: """Translates atoms such that fractional positions are minimized.""" pbc = pbc2pbc(pbc) @@ -226,7 +226,7 @@ def translate_pretty( def get_center_of_positions( positions: NDArray[Any], cell: NDArray[Any] = None, - pbc: Union[bool, NDArray[Any]] = True, + pbc: bool | NDArray[Any] = True, weights=None, relative=False, ) -> NDArray[Any]: @@ -281,7 +281,7 @@ def get_center_of_positions( def wrap_positions( positions: NDArray[Any], cell: NDArray[Any] = None, - pbc: Union[bool, NDArray[Any]] = True, + pbc: bool | NDArray[Any] = True, center: NDArray[Any] = [0.5, 0.5, 0.5], pretty_translation=False, eps: float = 1e-12, @@ -329,7 +329,7 @@ def wrap_positions( def unwrap_positions( positions: NDArray[Any], cell: NDArray[Any], - pbc: Union[bool, NDArray[Any]], + pbc: bool | NDArray[Any], relative=False, ) -> NDArray[Any]: """ @@ -363,7 +363,7 @@ def unwrap_positions( ) -def chemical_symbols(atomic_numbers: Iterable[int]) -> List[str]: +def chemical_symbols(atomic_numbers: Iterable[int]) -> list[str]: """ Converts atomic numbers to chemical_symbols. @@ -435,9 +435,7 @@ def reciprocal_cell(cell: NDArray[Any]) -> NDArray[Any]: return np.linalg.pinv(cell).transpose() -def find_match( - pos: NDArray[Any], positions: NDArray[Any], eps: float -) -> Union[int, None]: +def find_match(pos: NDArray[Any], positions: NDArray[Any], eps: float) -> int | None: """ Attempts to find a position within a larger list of positions. @@ -537,7 +535,7 @@ def cell_to_cellpar(cell: NDArray[Any], degrees=False) -> NDArray[Any]: def get_symmetry_string( - space_group: int, wyckoff_sets: List[Any], is_2d: bool = False + space_group: int, wyckoff_sets: list[Any], is_2d: bool = False ) -> str: """ Used to serialize symmetry information into a string. 
The Wyckoff @@ -566,20 +564,20 @@ def get_symmetry_string( element = group.element wyckoff_letter = group.wyckoff_letter n_atoms = len(group.indices) - i_string = '{} {} {}'.format(element, wyckoff_letter, n_atoms) + i_string = f'{element} {wyckoff_letter} {n_atoms}' wyckoff_strings.append(i_string) wyckoff_string = ', '.join(sorted(wyckoff_strings)) if is_2d: - string = '2D {} {}'.format(space_group, wyckoff_string) + string = f'2D {space_group} {wyckoff_string}' else: - string = '{} {}'.format(space_group, wyckoff_string) + string = f'{space_group} {wyckoff_string}' return string def get_hill_decomposition( atom_labels: NDArray[Any], reduced: bool = False -) -> Tuple[List[str], List[int]]: +) -> tuple[list[str], list[int]]: """ Given a list of atomic labels, returns the chemical formula using the Hill system (https://en.wikipedia.org/wiki/Hill_system) with an exception @@ -694,7 +692,7 @@ def get_formula_string(symbols: Iterable[str], counts: Iterable[int]) -> str: def get_normalized_wyckoff( atomic_numbers: NDArray[Any], wyckoff_letters: NDArray[Any] -) -> Dict[str, Dict[str, int]]: +) -> dict[str, dict[str, int]]: """ Returns a normalized Wyckoff sequence for the given atomic numbers and corresponding wyckoff letters. In a normalized sequence the chemical @@ -711,7 +709,7 @@ def get_normalized_wyckoff( 'X_<index>'. """ # Count the occurrence of each chemical species - atom_count: Dict[int, int] = {} + atom_count: dict[int, int] = {} for atomic_number in atomic_numbers: atom_count[atomic_number] = atom_count.get(atomic_number, 0) + 1 @@ -960,7 +958,7 @@ class Formula: ) self._original_formula = formula - def count(self) -> Dict[str, int]: + def count(self) -> dict[str, int]: """Return dictionary that maps chemical symbol to number of atoms.""" return self._count.copy() @@ -990,11 +988,11 @@ class Formula: else: raise ValueError(f'Invalid format option "{fmt}"') - def elements(self) -> List[str]: + def elements(self) -> list[str]: """Returns the list of chemical elements present in the formula.""" return sorted(self.count().keys()) - def atomic_fractions(self) -> Dict[str, float]: + def atomic_fractions(self) -> dict[str, float]: """ Returns dictionary that maps chemical symbol to atomic fraction. @@ -1007,7 +1005,7 @@ class Formula: atomic_fractions = {key: value / total_count for key, value in count.items()} return atomic_fractions - def mass_fractions(self) -> Dict[str, float]: + def mass_fractions(self) -> dict[str, float]: """ Returns a dictionary that maps chemical symbol to mass fraction. @@ -1026,7 +1024,7 @@ class Formula: } return mass_fractions - def elemental_composition(self) -> List[ElementalComposition]: + def elemental_composition(self) -> list[ElementalComposition]: """ Returns the atomic and mass fractions as a list of `ElementalComposition` objects. Any unknown elements are ignored. 
@@ -1052,8 +1050,8 @@ class Formula: def populate( self, - section: Union[Material, System], - descriptive_format: Union[str, None] = 'original', + section: Material | System, + descriptive_format: str | None = 'original', overwrite: bool = False, ) -> None: """ @@ -1115,9 +1113,7 @@ class Formula: n_matched_chars = sum([len(match[0]) for match in matches]) n_formula = len(formula.strip()) if n_matched_chars == n_formula: - formula = ''.join( - ['{}{}'.format(match[1], match[2]) for match in matches] - ) + formula = ''.join([f'{match[1]}{match[2]}' for match in matches]) return formula def _formula_hill(self) -> str: @@ -1250,7 +1246,7 @@ class Formula: } return self._dict2str(count_anonymous) - def _dict2str(self, dct: Dict[str, int]) -> str: + def _dict2str(self, dct: dict[str, int]) -> str: """Convert symbol-to-count dict to a string. Omits the chemical proportion number 1. diff --git a/nomad/bundles.py b/nomad/bundles.py index 09adea0e2fe4d3fba2dee36c2818ae6bdf0a3169..126f0a2c71aa003a72ee277beab12c621f05efff 100644 --- a/nomad/bundles.py +++ b/nomad/bundles.py @@ -8,7 +8,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from typing import cast, Any, Tuple, List, Set, Dict, Iterable +from typing import cast, Any, Tuple, List, Set, Dict +from collections.abc import Iterable import os import json from datetime import datetime, timedelta @@ -133,14 +134,11 @@ class BundleExporter: ) # 2. Files from the upload dir - for file_source in self.upload.upload_files.files_to_bundle( - self.export_settings - ): - yield file_source + yield from self.upload.upload_files.files_to_bundle(self.export_settings) def _create_bundle_info(self): """Create the bundle_info.json data""" - bundle_info: Dict[str, Any] = dict( + bundle_info: dict[str, Any] = dict( upload_id=self.upload.upload_id, source=config.meta.dict(), # Information about the source system, i.e. 
this NOMAD installation export_settings=self.export_settings.dict(), @@ -150,7 +148,7 @@ class BundleExporter: ], ) # Handle datasets - dataset_ids: Set[str] = set() + dataset_ids: set[str] = set() for entry_dict in bundle_info['entries']: entry_datasets = entry_dict.get('datasets') if entry_datasets: @@ -198,7 +196,7 @@ class BundleImporter: self.bundle: BrowsableFileSource = None self.upload: Upload = None self.upload_files: UploadFiles = None - self._bundle_info: Dict[str, Any] = None + self._bundle_info: dict[str, Any] = None @classmethod def looks_like_a_bundle(cls, path): @@ -301,9 +299,9 @@ class BundleImporter: self.upload = upload logger = self.upload.get_logger(bundle_path=self.bundle_path) current_time = datetime.utcnow() - new_datasets: List[datamodel.Dataset] = [] - dataset_id_mapping: Dict[str, str] = {} - entry_data_to_index: List[datamodel.EntryArchive] = [] # Data to index in ES + new_datasets: list[datamodel.Dataset] = [] + dataset_id_mapping: dict[str, str] = {} + entry_data_to_index: list[datamodel.EntryArchive] = [] # Data to index in ES try: self._check_bundle_and_settings(running_locally) self._import_upload_mongo_data(current_time) @@ -490,7 +488,7 @@ class BundleImporter: and 0 <= self.upload.embargo_length <= 36 ), 'Invalid embargo_length, must be between 0 and 36 months' - def _import_datasets(self) -> Tuple[List[datamodel.Dataset], Dict[str, str]]: + def _import_datasets(self) -> tuple[list[datamodel.Dataset], dict[str, str]]: """Creates datasets from the bundle.""" required_keys_datasets = ('dataset_id', 'dataset_name', 'user_id') @@ -498,8 +496,8 @@ class BundleImporter: 'Missing datasets definition in bundle_info.json' ) datasets = self.bundle_info['datasets'] - new_datasets: List[datamodel.Dataset] = [] - dataset_id_mapping: Dict[ + new_datasets: list[datamodel.Dataset] = [] + dataset_id_mapping: dict[ str, str ] = {} # Map from old to new id (usually the same) for dataset_dict in datasets: @@ -537,7 +535,7 @@ class BundleImporter: def _import_entries_mongo_data( self, current_time, dataset_id_mapping - ) -> List[Entry]: + ) -> list[Entry]: """Creates mongo entries from the data in the bundle_info""" required_keys_entry_level = ( '_id', @@ -640,8 +638,8 @@ class BundleImporter: raise def _get_entry_data_to_index( - self, entries: List[Entry] - ) -> List[datamodel.EntryArchive]: + self, entries: list[Entry] + ) -> list[datamodel.EntryArchive]: entry_data_to_index = [] if self.import_settings.include_archive_files: for entry in entries: @@ -655,7 +653,7 @@ class BundleImporter: self.upload_files.close() # Because full_entry_metadata reads the archive files. return entry_data_to_index - def _index_search(self, entry_data_to_index: List[datamodel.EntryArchive]): + def _index_search(self, entry_data_to_index: list[datamodel.EntryArchive]): # Index in elastic search if entry_data_to_index: search.index( @@ -670,7 +668,7 @@ class BundleImporter: ) -def keys_exist(data: Dict[str, Any], required_keys: Iterable[str], error_message: str): +def keys_exist(data: dict[str, Any], required_keys: Iterable[str], error_message: str): """ Checks if the specified keys exist in the provided dictionary structure `data`. Supports dot-notation to access subkeys. 
diff --git a/nomad/cli/admin/admin.py b/nomad/cli/admin/admin.py index 5f09940322f202033674dbd84d01cef5a11d893f..5871edd593e1d16707e443e069a4353207586b4f 100644 --- a/nomad/cli/admin/admin.py +++ b/nomad/cli/admin/admin.py @@ -445,7 +445,7 @@ def migrate_mongo( print('Cannot specify a query when using --ids-from-file.') return -1 try: - with open(ids_from_file, 'r') as f: + with open(ids_from_file) as f: upload_ids = [line.strip() for line in f.readlines() if line.strip()] except FileNotFoundError: print(f'Could not open file {ids_from_file}', file=sys.stderr) @@ -536,5 +536,5 @@ def rewrite_doi_urls(dois, dry, save_existing_records): edit_doi_url(doi) finally: if save_existing_records: - with open(save_existing_records, 'wt') as f: + with open(save_existing_records, 'w') as f: json.dump(existing_records, f, indent=2) diff --git a/nomad/cli/admin/migrate.py b/nomad/cli/admin/migrate.py index cce076a4f72914bf8d79ed420906613a380a8b92..e857c013e657121fb330afa213aee002ea31e4f1 100644 --- a/nomad/cli/admin/migrate.py +++ b/nomad/cli/admin/migrate.py @@ -76,8 +76,8 @@ class _UpgradeStatistics(BaseModel): class _DatasetCacheItem(BaseModel): - converted_dataset_dict: Optional[Dict[str, Any]] = None - converted_doi_dict: Optional[Dict[str, Any]] = None + converted_dataset_dict: dict[str, Any] | None = None + converted_doi_dict: dict[str, Any] | None = None ready_to_commit: bool = False @@ -99,8 +99,8 @@ def migrate_mongo_uploads( db_dst: Database, uploads_query: Any, failed_ids_to_file: bool, - upload_update: Dict[str, Any], - entry_update: Dict[str, Any], + upload_update: dict[str, Any], + entry_update: dict[str, Any], overwrite: str, fix_problems: bool, dry: bool, @@ -114,7 +114,7 @@ def migrate_mongo_uploads( src_entry_collection = ( db_src.calc if 'calc' in db_src.list_collection_names() else db_src.entry ) - dataset_cache: Dict[str, _DatasetCacheItem] = {} + dataset_cache: dict[str, _DatasetCacheItem] = {} stats = _UpgradeStatistics() stats.uploads.total = number_of_uploads count_treated = count_failures = count_processing = 0 @@ -232,11 +232,11 @@ def migrate_mongo_uploads( def _convert_mongo_upload( db_src: Database, src_entry_collection: Collection, - upload_dict: Dict[str, Any], - upload_update: Dict[str, Any], - entry_update: Dict[str, Any], + upload_dict: dict[str, Any], + upload_update: dict[str, Any], + entry_update: dict[str, Any], fix_problems: bool, - dataset_cache: Dict[str, _DatasetCacheItem], + dataset_cache: dict[str, _DatasetCacheItem], stats: _UpgradeStatistics, logger, ): @@ -280,12 +280,12 @@ def _convert_mongo_upload( first_entry_uploader = first_metadata.get('uploader') first_external_db = first_metadata.get('external_db') first_entry_coauthors = first_metadata.get('coauthors', ()) - common_coauthors = set(_wrap_author(ca) for ca in first_entry_coauthors) + common_coauthors = {_wrap_author(ca) for ca in first_entry_coauthors} fixed_external_db = False for entry_dict in entry_dicts: assert 'metadata' in entry_dict, 'Entry dict has no metadata key' - entry_metadata_dict: Dict[str, Any] = entry_dict['metadata'] + entry_metadata_dict: dict[str, Any] = entry_dict['metadata'] with_embargo = entry_metadata_dict.get('with_embargo') assert with_embargo == first_with_embargo, ( 'Inconsistent embargo settings for entries' @@ -344,9 +344,7 @@ def _convert_mongo_upload( if _wrap_author(ca) in common_coauthors ] else: - common_coauthors = set( - _wrap_author(ca) for ca in upload_dict.get('coauthors', ()) - ) + common_coauthors = {_wrap_author(ca) for ca in upload_dict.get('coauthors', 
())} # Check that all required fields are there for field in ( @@ -364,7 +362,7 @@ def _convert_mongo_upload( upload_dict.update(upload_update) # migrate entries - newly_encountered_dataset_ids: Set[str] = set() + newly_encountered_dataset_ids: set[str] = set() for entry_dict in entry_dicts: assert not _is_processing(entry_dict), ( f'the entry {entry_dict["_id"]} has status processing, but the upload is not processing.' @@ -410,8 +408,8 @@ def _convert_mongo_upload( entry_dict.update(entry_update) # All conversion successful! Ready to migrate - dataset_dicts: List[Dict[str, Any]] = [] - doi_dicts: List[Dict[str, Any]] = [] + dataset_dicts: list[dict[str, Any]] = [] + doi_dicts: list[dict[str, Any]] = [] for dataset_id in newly_encountered_dataset_ids: ds_cache = dataset_cache[dataset_id] if not ds_cache.ready_to_commit: @@ -428,7 +426,7 @@ def _convert_mongo_upload( def _convert_mongo_entry( - entry_dict: Dict[str, Any], common_coauthors: Set, fix_problems: bool, logger + entry_dict: dict[str, Any], common_coauthors: set, fix_problems: bool, logger ): _convert_mongo_proc(entry_dict) # Validate the id and possibly fix problems @@ -490,7 +488,7 @@ def _convert_mongo_entry( assert parser_name in parser_dict, f'Parser does not exist: {parser_name}' -def _convert_mongo_proc(proc_dict: Dict[str, Any]): +def _convert_mongo_proc(proc_dict: dict[str, Any]): if 'tasks_status' in proc_dict: # Old v0 version process_status = proc_dict['tasks_status'] @@ -504,7 +502,7 @@ def _convert_mongo_proc(proc_dict: Dict[str, Any]): if not last_status_message: # Generate a nicer last_status_message current_process: str = proc_dict.get('current_process') - errors: List[str] = proc_dict.get('errors') + errors: list[str] = proc_dict.get('errors') if errors: last_status_message = f'Process {current_process} failed: {errors[-1]}' elif current_process and process_status == ProcessStatus.SUCCESS: @@ -521,7 +519,7 @@ def _convert_mongo_proc(proc_dict: Dict[str, Any]): proc_dict.pop(field, None) -def _convert_mongo_dataset(dataset_dict: Dict[str, Any]): +def _convert_mongo_dataset(dataset_dict: dict[str, Any]): _rename_key(dataset_dict, 'name', 'dataset_name') _rename_key(dataset_dict, 'created', 'dataset_create_time') _rename_key(dataset_dict, 'modified', 'dataset_modified_time') @@ -532,7 +530,7 @@ def _convert_mongo_dataset(dataset_dict: Dict[str, Any]): ) -def _convert_mongo_doi(doi_dict: Dict[str, Any]): +def _convert_mongo_doi(doi_dict: dict[str, Any]): pass @@ -567,7 +565,7 @@ def _get_dataset_cache_data( ) -def _is_processing(proc_dict: Dict[str, Any]) -> bool: +def _is_processing(proc_dict: dict[str, Any]) -> bool: process_status = proc_dict.get('tasks_status') # Used in v0 if not process_status: process_status = proc_dict['process_status'] @@ -575,10 +573,10 @@ def _is_processing(proc_dict: Dict[str, Any]) -> bool: def _commit_upload( - upload_dict: Dict[str, Any], - entry_dicts: List[Dict[str, Any]], - dataset_dicts: List[Dict[str, Any]], - doi_dicts: List[Dict[str, Any]], + upload_dict: dict[str, Any], + entry_dicts: list[dict[str, Any]], + dataset_dicts: list[dict[str, Any]], + doi_dicts: list[dict[str, Any]], db_dst: Database, stats: _UpgradeStatistics, ): @@ -614,7 +612,7 @@ def _commit_upload( stats.entries.migrated += len(entry_dicts) -def _rename_key(d: Dict[str, Any], old_name: str, new_name: str): +def _rename_key(d: dict[str, Any], old_name: str, new_name: str): """ Renames a key in the provided dictionary `d`, from `old_name` to `new_name`. We may use "point notation" in `old_name`, i.e. 
"metadata.external_id" will look for a diff --git a/nomad/cli/admin/run.py b/nomad/cli/admin/run.py index f3a2373b2af1854e39af2a819cca48f05d1e99bc..ab846f3e0e203bbe039094be8d0393eac403915d 100644 --- a/nomad/cli/admin/run.py +++ b/nomad/cli/admin/run.py @@ -106,12 +106,12 @@ def run_app( os.path.join(run_gui_folder, source_file_glob), recursive=True ) for source_file in source_files: - with open(source_file, 'rt') as f: + with open(source_file) as f: file_data = f.read() file_data = file_data.replace( '/fairdi/nomad/latest', config.services.api_base_path ) - with open(source_file, 'wt') as f: + with open(source_file, 'w') as f: f.write(file_data) # App and gui are served from the same server, same port. Replace the base urls with diff --git a/nomad/cli/admin/springer.py b/nomad/cli/admin/springer.py index ccd5c4d28ed2ebcbdb84440d1d52acc7748a7312..74dde0e7350ba1b56a4806f8736156e80f4d2540 100644 --- a/nomad/cli/admin/springer.py +++ b/nomad/cli/admin/springer.py @@ -49,7 +49,7 @@ search_re = re.compile(' href="(/isp/[^"]+)') formula_re = re.compile(r'([A-Z][a-z]?)([0-9.]*)|\[(.*?)\]([0-9]+)') -def _update_dict(dict0: Dict[str, float], dict1: Dict[str, float]): +def _update_dict(dict0: dict[str, float], dict1: dict[str, float]): for key, val in dict1.items(): if key in dict0: dict0[key] += val @@ -57,11 +57,11 @@ def _update_dict(dict0: Dict[str, float], dict1: Dict[str, float]): dict0[key] = val -def _components(formula_str: str, multiplier: float = 1.0) -> Dict[str, float]: +def _components(formula_str: str, multiplier: float = 1.0) -> dict[str, float]: # match atoms and molecules (in brackets) components = formula_re.findall(formula_str) - symbol_amount: Dict[str, float] = {} + symbol_amount: dict[str, float] = {} for component in components: element, amount_e, molecule, amount_m = component if element: @@ -93,7 +93,7 @@ def normalize_formula(formula_str: str) -> str: return ''.join(formula_sorted) -def parse(htmltext: str) -> Dict[str, str]: +def parse(htmltext: str) -> dict[str, str]: """ Parser the quantities in required_items from an html text. 
""" @@ -143,7 +143,7 @@ def parse(htmltext: str) -> Dict[str, str]: return results -def _merge_dict(dict0: Dict[str, Any], dict1: Dict[str, Any]) -> Dict[str, Any]: +def _merge_dict(dict0: dict[str, Any], dict1: dict[str, Any]) -> dict[str, Any]: if not isinstance(dict1, dict) or not isinstance(dict0, dict): return dict1 @@ -188,7 +188,7 @@ def update_springer(max_n_query: int = 10, retry_time: int = 120): config.normalize.springer_db_path, {spg: '*' for spg in archive_keys} ) - sp_ids: List[str] = [] + sp_ids: list[str] = [] for spg in sp_data: if not isinstance(sp_data[spg], dict): continue diff --git a/nomad/cli/admin/uploads.py b/nomad/cli/admin/uploads.py index b36ccfb772fa82690ab9f6d2e0dd8ae997321a2c..55d35f783e400a01d6d915893a9981e7a62b534d 100644 --- a/nomad/cli/admin/uploads.py +++ b/nomad/cli/admin/uploads.py @@ -47,7 +47,7 @@ def _run_parallel( ) # copy the whole mongo query set to avoid cursor timeouts cv = threading.Condition() - threads: typing.List[threading.Thread] = [] + threads: list[threading.Thread] = [] state = dict(completed_count=0, skipped_count=0, available_threads_count=parallel) @@ -259,7 +259,7 @@ def _query_uploads( if entries_mongo_query: entries_mongo_query_q = Q(**json.loads(entries_mongo_query)) - entries_query_uploads: Set[str] = None + entries_query_uploads: set[str] = None if entries_es_query is not None: entries_es_query_dict = json.loads(entries_es_query) @@ -278,12 +278,10 @@ def _query_uploads( }, ) - entries_query_uploads = set( - [ - cast(str, bucket.value) - for bucket in results.aggregations['uploads'].terms.data - ] - ) # pylint: disable=no-member + entries_query_uploads = { + cast(str, bucket.value) + for bucket in results.aggregations['uploads'].terms.data + } # pylint: disable=no-member if outdated: entries_mongo_query_q &= Q(nomad_version={'$ne': config.meta.version}) @@ -710,11 +708,11 @@ def process( uploads, parallel: int, process_running: bool, - setting: typing.List[str], + setting: list[str], print_progress: int, ): _, uploads = _query_uploads(uploads, **ctx.obj.uploads_kwargs) - settings: typing.Dict[str, bool] = {} + settings: dict[str, bool] = {} for settings_str in setting: key, value = settings_str.split('=') settings[key] = bool(value) diff --git a/nomad/cli/admin/users.py b/nomad/cli/admin/users.py index 4cf652bb1e6be0d60a56f293e60a4d1329ae4f16..49a54132d73b9ebcebc7fe38783ca6a47e683af0 100644 --- a/nomad/cli/admin/users.py +++ b/nomad/cli/admin/users.py @@ -34,7 +34,7 @@ def import_command(path_to_users_file): from nomad import infrastructure, datamodel, utils - with open(path_to_users_file, 'rt') as f: + with open(path_to_users_file) as f: users = json.load(f) logger = utils.get_logger(__name__) diff --git a/nomad/cli/aflow.py b/nomad/cli/aflow.py index 63f61583417337f7626a8e35d85d2403072c4af7..0548025e18bc522a880de218459857a4e7cf7c06 100644 --- a/nomad/cli/aflow.py +++ b/nomad/cli/aflow.py @@ -66,7 +66,7 @@ class DbUpdater: self.parallel = 2 self.max_depth = None self.target_file = None - self.uids: List[str] = [] + self.uids: list[str] = [] self._set(**kwargs) self.auth = auth @@ -103,7 +103,7 @@ class DbUpdater: self._session = requests.Session() - def _get_paths(self, root: str) -> typing.List[str]: + def _get_paths(self, root: str) -> list[str]: response = self._session.get(root, verify=False) if not response.ok: response.raise_for_status() @@ -143,13 +143,13 @@ class DbUpdater: line = f.readline() return data - def _write_to_file(self, data: typing.List, filename: str): + def _write_to_file(self, data: list, filename: 
str): with open(filename, 'w') as f: for i in range(len(data)): if isinstance(data[i], str): f.write('%s\n' % data[i]) else: - f.write('%s %s \n' % (data[i][0], data[i][1])) + f.write(f'{data[i][0]} {data[i][1]} \n') def get_db_list(self): if self.dbfile is not None and os.path.isfile(self.dbfile): @@ -237,7 +237,7 @@ class DbUpdater: Identify the difference between the nomad list and db list """ - def reduce_list(ilist: typing.List[str]): + def reduce_list(ilist: list[str]): olist = [] for e in ilist: p = urllib_parse.urlparse(e).path.strip('/') @@ -264,7 +264,7 @@ class DbUpdater: # add the root back u = urllib_parse.urlparse(self.root_url) up = u.path.strip('/').split('/')[0] - root = '%s://%s/%s' % (u.scheme, u.netloc, up) + root = f'{u.scheme}://{u.netloc}/{up}' self.update_list = [os.path.join(root, e) for e in self.update_list] self.is_updated_list = [False] * len(self.update_list) print('Found %d entries to be added in NOMAD' % len(self.update_list)) @@ -273,7 +273,7 @@ class DbUpdater: data = [self.update_list[i] for i in range(len(self.update_list))] self._write_to_file(data, self.outfile) - def _get_files(self, path: str) -> typing.Tuple[str, float]: + def _get_files(self, path: str) -> tuple[str, float]: def is_dir(path: str) -> bool: path = path.strip() if path[-1] == '/' and self.root_url in path: @@ -338,18 +338,18 @@ class DbUpdater: return dirname, size - def _make_name(self, dirs: typing.List[str]) -> typing.Tuple[str, str]: + def _make_name(self, dirs: list[str]) -> tuple[str, str]: # name will be first and last entries d1 = self._to_namesafe(dirs[0].lstrip(self._local_path)) d2 = self._to_namesafe(dirs[-1].lstrip(self._local_path)) - tarname = '%s-%s' % (d1, d2) - uploadname = '%s_%s' % (self.db_name.upper(), tarname) + tarname = f'{d1}-{d2}' + uploadname = f'{self.db_name.upper()}_{tarname}' tarname = os.path.join(self._local_path, '%s.tar' % tarname) return tarname, uploadname - def _cleanup(self, ilist: typing.Union[str, typing.List[str]]): + def _cleanup(self, ilist: str | list[str]): if isinstance(ilist, str): ilist = [ilist] for name in ilist: @@ -412,8 +412,8 @@ class DbUpdater: assert uid is not None return uid - def _download_proc(self, plist: typing.List[str]): - def tar_files(dirs: typing.List[str], tarname: str): + def _download_proc(self, plist: list[str]): + def tar_files(dirs: list[str], tarname: str): if os.path.isfile(tarname): return @@ -426,9 +426,9 @@ class DbUpdater: except Exception as e: os.remove(tarname) - print('Error writing tar file %s. %s' % (tarname, e)) + print(f'Error writing tar file {tarname}. 
{e}') - def get_status_upload(uploadname: str) -> typing.Tuple[str, str]: + def get_status_upload(uploadname: str) -> tuple[str, str]: response = api.get(f'uploads', params=dict(name=uploadname), auth=self.auth) assert response.status_code == 200 response_json = response.json() @@ -541,7 +541,7 @@ def write_prototype_data_file(aflow_prototypes: dict, filepath) -> None: aflow_prototypes """ - class NoIndent(object): + class NoIndent: def __init__(self, value): self.value = value @@ -551,7 +551,7 @@ def write_prototype_data_file(aflow_prototypes: dict, filepath) -> None: """ def __init__(self, *args, **kwargs): - super(NoIndentEncoder, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) self.kwargs = dict(kwargs) del self.kwargs['indent'] self._replacement_map = {} @@ -560,14 +560,14 @@ def write_prototype_data_file(aflow_prototypes: dict, filepath) -> None: if isinstance(o, NoIndent): key = uuid.uuid4().hex self._replacement_map[key] = json.dumps(o.value, **self.kwargs) - return '@@%s@@' % (key,) + return f'@@{key}@@' else: - return super(NoIndentEncoder, self).default(o) + return super().default(o) def encode(self, o): - result = super(NoIndentEncoder, self).encode(o) + result = super().encode(o) for k, v in self._replacement_map.items(): - result = result.replace('"@@%s@@"' % (k,), v) + result = result.replace(f'"@@{k}@@"', v) return result prototype_dict = aflow_prototypes['prototypes_by_spacegroup'] @@ -585,7 +585,7 @@ def write_prototype_data_file(aflow_prototypes: dict, filepath) -> None: pass # Save the updated data - with io.open(filepath, 'w', encoding='utf8') as f: + with open(filepath, 'w', encoding='utf8') as f: json_dump = json.dumps( aflow_prototypes, ensure_ascii=False, @@ -596,7 +596,7 @@ def write_prototype_data_file(aflow_prototypes: dict, filepath) -> None: json_dump = re.sub( r'\"(-?\d+(?:[\.,]\d+)?)\"', r'\1', json_dump ) # Removes quotes around numbers - f.write('aflow_prototypes = {}\n'.format(json_dump)) + f.write(f'aflow_prototypes = {json_dump}\n') def update_prototypes(ctx, filepath, matches_only): @@ -690,7 +690,7 @@ def update_prototypes(ctx, filepath, matches_only): newdict['atom_labels'] = atom_labels newdictarray.append(newdict) - print('Processed: {}'.format(len(newdictarray))) + print(f'Processed: {len(newdictarray)}') # Sort prototype dictionaries by spacegroup and make dictionary structure_types_by_spacegroup = {} diff --git a/nomad/cli/dev.py b/nomad/cli/dev.py index fe7894fe9d4bbed73cd0c261b0444dde010ef239..074c88bd1093d68b0d23d459c51164efe2bdd926 100644 --- a/nomad/cli/dev.py +++ b/nomad/cli/dev.py @@ -362,7 +362,7 @@ def update_parser_readmes(parser): parser_path = './dependencies/parsers/' # Open general template - with open(generic_fn, 'r') as generic: # read only + with open(generic_fn) as generic: # read only generic_contents = generic.read() # Replace the comment at the top of the general template @@ -376,7 +376,7 @@ def update_parser_readmes(parser): def open_metadata(path): # read local yaml metadata file - with open(path, 'r') as metadata_f: + with open(path) as metadata_f: try: metadata = yaml.load(metadata_f, Loader=yaml.SafeLoader) except Exception as e: @@ -489,7 +489,7 @@ def example_data(username: str): return data -def _generate_units_json() -> Tuple[Any, Any]: +def _generate_units_json() -> tuple[Any, Any]: from pint.converters import ScaleConverter from collections import defaultdict from nomad.units import ureg diff --git a/nomad/client/archive.py b/nomad/client/archive.py index 
671765afd2d2673d5f7c1ff63453bcd5669c0686..f68d6f88757dc1614259829dbee625f140542bc6 100644 --- a/nomad/client/archive.py +++ b/nomad/client/archive.py @@ -572,7 +572,7 @@ class ArchiveQuery: self, keys_to_filter: list[str] = None, resolve_references: bool = False, - query_selection: Union[str, list[str]] = 'last', + query_selection: str | list[str] = 'last', ): """ Interface to convert the archives to pandas dataframe. @@ -585,7 +585,7 @@ class ArchiveQuery: Returns: pandas dataframe of the downloaded (and selected) archives """ - t_list: Union[list[Any], dict] = [] + t_list: list[Any] | dict = [] if query_selection == 'all': t_list = [item for sublist in self._entries_dict for item in sublist] elif query_selection == 'last': diff --git a/nomad/client/processing.py b/nomad/client/processing.py index 7c84791e4bd4289d94a6acc2987164384136702d..97af552bb41fc83642d9a884ffad4c4bb1cc7cce 100644 --- a/nomad/client/processing.py +++ b/nomad/client/processing.py @@ -37,7 +37,7 @@ def parse( server_context: bool = False, username: str = None, password: str = None, -) -> typing.List[datamodel.EntryArchive]: +) -> list[datamodel.EntryArchive]: """ Run the given parser on the provided mainfile. If parser_name is given, we only try to match this parser, otherwise we try to match all parsers. @@ -73,9 +73,7 @@ def parse( return entry_archives -def normalize( - normalizer: typing.Union[str, typing.Callable], entry_archive, logger=None -): +def normalize(normalizer: str | typing.Callable, entry_archive, logger=None): from nomad import normalizing if logger is None: @@ -206,9 +204,7 @@ class LocalEntryProcessing: if exception: sys.exit(1) - def parse( - self, parser_name: str = None, **kwargs - ) -> typing.List[datamodel.EntryArchive]: + def parse(self, parser_name: str = None, **kwargs) -> list[datamodel.EntryArchive]: """ Run the given parser on the downloaded entry. If no parser is given, do parser matching and use the respective parser. @@ -222,9 +218,7 @@ class LocalEntryProcessing: **kwargs, ) - def normalize( - self, normalizer: typing.Union[str, typing.Callable], entry_archive=None - ): + def normalize(self, normalizer: str | typing.Callable, entry_archive=None): """ Parse the downloaded entry and run the given normalizer. """ diff --git a/nomad/common.py b/nomad/common.py index f9be021bd4f51091a9676d22b8228102cd1eb101..ce0d536ff66afc9088d9e245b69b1567f6010d30 100644 --- a/nomad/common.py +++ b/nomad/common.py @@ -27,7 +27,7 @@ import shutil import zipfile import tarfile from typing import Optional -from typing_extensions import Literal +from typing import Literal from tempfile import TemporaryDirectory import httpx @@ -61,7 +61,7 @@ def get_package_path(package_name: str) -> str: return package_path -def download_file(url: str, filepath: str) -> Optional[str]: +def download_file(url: str, filepath: str) -> str | None: """Used to download a file from the given URL to the given directory. Arg: @@ -117,7 +117,7 @@ decompress_file_extensions = { } -def get_compression_format(path: str) -> Optional[Literal['zip', 'tar', 'error']]: +def get_compression_format(path: str) -> Literal['zip', 'tar', 'error'] | None: """ Returns the decompression format ('zip', 'tar' or 'error') if `path` specifies a file which should be automatically decompressed before adding it to an upload. 
If `path` diff --git a/nomad/config/__init__.py b/nomad/config/__init__.py index b189ec9c4073c08305654c47f6b7e4e35f889911..303d2c0d4c63ba26481c65728d8d667be1dd2e2f 100644 --- a/nomad/config/__init__.py +++ b/nomad/config/__init__.py @@ -44,15 +44,15 @@ from nomad.config.models.config import Config logger = logging.getLogger(__name__) -def _load_config_yaml() -> Dict[str, Any]: +def _load_config_yaml() -> dict[str, Any]: """ Loads the configuration from a YAML file. """ config_file = os.environ.get('NOMAD_CONFIG', 'nomad.yaml') - config_data: Dict[str, Any] = {} + config_data: dict[str, Any] = {} if os.path.exists(config_file): - with open(config_file, 'r') as stream: + with open(config_file) as stream: try: config_data = yaml.load(stream, Loader=yaml.SafeLoader) except yaml.YAMLError as e: @@ -61,7 +61,7 @@ def _load_config_yaml() -> Dict[str, Any]: return config_data -def _load_config_env() -> Dict[str, Any]: +def _load_config_env() -> dict[str, Any]: """ Loads the configuration from environment variables. @@ -84,7 +84,7 @@ def _load_config_env() -> Dict[str, Any]: root[part] = new root = new - config_data: Dict[str, Any] = {} + config_data: dict[str, Any] = {} prefix = 'NOMAD_' for key, value in os.environ.items(): if key == 'NOMAD_CONFIG' or not key.startswith(prefix): @@ -102,7 +102,7 @@ def _load_config_env() -> Dict[str, Any]: return config_data -def _merge(*args) -> Dict[str, Any]: +def _merge(*args) -> dict[str, Any]: """ Recursively merge the given dictionaries one by one. @@ -138,7 +138,7 @@ def load_config() -> Config: """Custom config loader. Used instead of Pydantic BaseSettings because of custom merging logic and custom loading of environment variables. """ - with open(os.path.join(os.path.dirname(__file__), 'defaults.yaml'), 'r') as stream: + with open(os.path.join(os.path.dirname(__file__), 'defaults.yaml')) as stream: config_default = yaml.load(stream, Loader=yaml.SafeLoader) config_yaml = _load_config_yaml() config_env = _load_config_env() diff --git a/nomad/config/models/common.py b/nomad/config/models/common.py index 2aa9ec4380b5751d0a4cd3001cc8aacbea0a37ba..f091f34cf60fb03b5dde1fe12447935c70dd42bb 100644 --- a/nomad/config/models/common.py +++ b/nomad/config/models/common.py @@ -28,7 +28,7 @@ class ConfigBaseModel(BaseModel): def customize( self: ConfigBaseModelBound, - custom_settings: Union[ConfigBaseModelBound, Dict[str, Any]], + custom_settings: ConfigBaseModelBound | dict[str, Any], ) -> ConfigBaseModelBound: """ Returns a new config object, created by taking a copy of the current config and @@ -82,14 +82,14 @@ class ConfigBaseModel(BaseModel): class OptionsBase(ConfigBaseModel): """The most basic model for defining the availability of different options.""" - include: Optional[List[str]] = Field( + include: list[str] | None = Field( None, description=""" List of included options. If not explicitly defined, all of the options will be included by default. """, ) - exclude: Optional[List[str]] = Field( + exclude: list[str] | None = Field( None, description=""" List of excluded options. Has higher precedence than include. @@ -109,13 +109,13 @@ class OptionsGlob(ConfigBaseModel): using glob/wildcard syntax. """ - include: Optional[List[str]] = Field( + include: list[str] | None = Field( None, description=""" List of included options. Supports glob/wildcard syntax. """, ) - exclude: Optional[List[str]] = Field( + exclude: list[str] | None = Field( None, description=""" List of excluded options. Supports glob/wildcard syntax. Has higher precedence than include. 
@@ -128,11 +128,11 @@ class Options(OptionsBase): elements and defining the configuration of each element. """ - options: Optional[Dict[str, Any]] = Field( # type: ignore + options: dict[str, Any] | None = Field( # type: ignore {}, description='Contains the available options.' ) - def filtered_keys(self) -> List[str]: + def filtered_keys(self) -> list[str]: """Returns a list of keys that fulfill the include/exclude requirements. """ @@ -146,7 +146,7 @@ class Options(OptionsBase): exclude = self.exclude or [] return [key for key in include if key not in exclude] - def filtered_values(self) -> List[Any]: + def filtered_values(self) -> list[Any]: """Returns a list of values that fulfill the include/exclude requirements. """ @@ -154,7 +154,7 @@ class Options(OptionsBase): self.options[key] for key in self.filtered_keys() if key in self.options ] - def filtered_items(self) -> List[Tuple[str, Any]]: + def filtered_items(self) -> list[tuple[str, Any]]: """Returns a list of key/value pairs that fulfill the include/exclude requirements. """ @@ -174,4 +174,4 @@ class OptionsSingle(Options): class OptionsMulti(Options): """Represents options where multiple values can be selected.""" - selected: List[str] = Field(description='Selected options.') + selected: list[str] = Field(description='Selected options.') diff --git a/nomad/config/models/config.py b/nomad/config/models/config.py index 13fcefc14e7a26883fed9782e98bd91bd9682c1a..12c424e4ec985913e6dd26e846ef79a2bfd23e49 100644 --- a/nomad/config/models/config.py +++ b/nomad/config/models/config.py @@ -39,10 +39,7 @@ except Exception: # noqa # package is not installed pass -if sys.version_info < (3, 10): - from importlib_metadata import entry_points -else: - from importlib.metadata import entry_points +from importlib.metadata import entry_points from .common import ( @@ -149,7 +146,7 @@ class Services(ConfigBaseModel): True, description="""If true the app will serve the h5grove API.""" ) - console_log_level: Union[int, str] = Field( + console_log_level: int | str = Field( logging.WARNING, description=""" The log level that controls console logging for all NOMAD services (app, worker, north). @@ -312,7 +309,7 @@ class Oasis(ConfigBaseModel): False, description='Set to `True` to indicate that this deployment is a NOMAD Oasis.', ) - allowed_users: List[str] = Field( + allowed_users: list[str] = Field( None, description=""" A list of usernames or user account emails. 
These represent a white-list of @@ -361,7 +358,7 @@ class Celery(ConfigBaseModel): timeout: int = 1800 # 1/2hr acks_late: bool = False routing: str = CELERY_QUEUE_ROUTING - priorities: Dict[str, int] = { + priorities: dict[str, int] = { 'Upload.process_upload': 5, 'Upload.delete_upload': 9, 'Upload.publish_upload': 10, @@ -378,7 +375,7 @@ class FS(ConfigBaseModel): north_home_external: str = None local_tmp: str = '/tmp' prefix_size: int = 2 - archive_version_suffix: Union[str, List[str]] = Field( + archive_version_suffix: str | list[str] = Field( ['v1.2', 'v1'], description=""" This allows to add an additional segment to the names of archive files and @@ -453,15 +450,15 @@ class Mongo(ConfigBaseModel): ) port: int = Field(27017, description='The port to connect with mongodb.') db_name: str = Field('nomad_v1', description='The used mongodb database name.') - username: Optional[str] = None - password: Optional[str] = None + username: str | None = None + password: str | None = None class Logstash(ConfigBaseModel): enabled: bool = False host: str = 'localhost' tcp_port: str = '5000' - level: Union[int, str] = logging.DEBUG + level: int | str = logging.DEBUG # Validators _level = field_validator('level', mode='before')(normalize_loglevel) @@ -497,7 +494,7 @@ class Logtransfer(ConfigBaseModel): 600, description='Time interval in seconds after which stored logs are potentially transferred.', ) - level: Union[int, str] = Field( + level: int | str = Field( logging.INFO, description='The min log level for logs to be transferred.' ) log_file: str = Field( @@ -535,7 +532,7 @@ class Mail(ConfigBaseModel): user: str = '' password: str = '' from_address: str = 'support@nomad-lab.eu' - cc_address: Optional[str] = None + cc_address: str | None = None class Normalize(ConfigBaseModel): @@ -627,7 +624,7 @@ class Normalize(ConfigBaseModel): level in order to still detect a gap. Unit: Joule. """, ) - springer_db_path: Optional[str] = Field( + springer_db_path: str | None = Field( os.path.join( os.path.dirname(os.path.abspath(__file__)), 'normalizing/data/springer.msg' ) @@ -696,7 +693,7 @@ class Process(ConfigBaseModel): index_materials: bool = True reuse_parser: bool = True metadata_file_name: str = 'nomad' - metadata_file_extensions: Tuple[str, ...] = ('json', 'yaml', 'yml') + metadata_file_extensions: tuple[str, ...] 
= ('json', 'yaml', 'yml') auxfile_cutoff: int = 100 parser_matching_size: int = 150 * 80 max_upload_size: int = 32 * (1024**3) @@ -946,7 +943,7 @@ class Config(ConfigBaseModel): bundle_import: BundleImport = BundleImport() archive: Archive = Archive() ui: UI = UI() - plugins: Optional[Plugins] = None + plugins: Plugins | None = None def api_url( self, @@ -967,12 +964,12 @@ class Config(ConfigBaseModel): base = base[:-1] if page is not None: - return '%s/gui/%s' % (base, page) + return f'{base}/gui/{page}' return '%s/gui' % base def rabbitmq_url(self): - return 'pyamqp://%s:%s@%s//' % ( + return 'pyamqp://{}:{}@{}//'.format( self.rabbitmq.user, self.rabbitmq.password, self.rabbitmq.host, @@ -1119,7 +1116,7 @@ class Config(ConfigBaseModel): _plugins['plugin_packages'] = plugin_packages # Handle plugins defined in nomad.yaml (old plugin mechanism) - def load_plugin_yaml(name, values: Dict[str, Any]): + def load_plugin_yaml(name, values: dict[str, Any]): """Loads plugin metadata from nomad_plugin.yaml""" python_package = values.get('python_package') if not python_package: @@ -1135,7 +1132,7 @@ class Config(ConfigBaseModel): metadata_path = os.path.join(package_path, 'nomad_plugin.yaml') if os.path.exists(metadata_path): try: - with open(metadata_path, 'r', encoding='UTF-8') as f: + with open(metadata_path, encoding='UTF-8') as f: metadata = yaml.load(f, Loader=yaml.SafeLoader) except Exception as e: raise ValueError( diff --git a/nomad/config/models/north.py b/nomad/config/models/north.py index ddc109a54a467cf0322e17fe0eb0e1bb780e7f18..1f5f0f2bab73854011e0b4c2da83a959928ad801 100644 --- a/nomad/config/models/north.py +++ b/nomad/config/models/north.py @@ -46,17 +46,17 @@ class NORTHExternalMount(BaseModel): class NORTHTool(BaseModel): - short_description: Optional[str] = Field( + short_description: str | None = Field( None, description='A short description of the tool, e.g. shown in the NOMAD GUI.', ) - description: Optional[str] = Field( + description: str | None = Field( None, description='A description of the tool, e.g. shown in the NOMAD GUI.' ) - image: Optional[str] = Field( + image: str | None = Field( None, description='The docker image (incl. tags) to use for the tool.' ) - cmd: Optional[str] = Field( + cmd: str | None = Field( None, description='The container cmd that is passed to the spawner.' ) image_pull_policy: str = Field( @@ -65,14 +65,14 @@ class NORTHTool(BaseModel): privileged: bool = Field( False, description='Whether the tool needs to run in privileged mode.' ) - default_url: Optional[str] = Field( + default_url: str | None = Field( None, description=( 'An optional path prefix that is added to the container URL to ' 'reach the tool, e.g. "/lab" for jupyterlab.' ), ) - path_prefix: Optional[str] = Field( + path_prefix: str | None = Field( None, description=( 'An optional path prefix that is added to the container URL to ' @@ -86,11 +86,11 @@ class NORTHTool(BaseModel): 'This also enables tools to be launched from files in the NOMAD UI.' 
), ) - file_extensions: List[str] = Field( + file_extensions: list[str] = Field( [], description='The file extensions of files that this tool should be launchable for.', ) - mount_path: Optional[str] = Field( + mount_path: str | None = Field( None, description=( 'The path in the container where uploads and work directories will be mounted, ' @@ -101,16 +101,16 @@ class NORTHTool(BaseModel): None, description='A URL to an icon that is used to represent the tool in the NOMAD UI.', ) - maintainer: List[NORTHToolMaintainer] = Field( + maintainer: list[NORTHToolMaintainer] = Field( [], description='The maintainers of the tool.' ) - external_mounts: List[NORTHExternalMount] = Field( + external_mounts: list[NORTHExternalMount] = Field( [], description='Additional mounts to be added to tool containers.' ) class NORTHTools(Options): - options: Dict[str, NORTHTool] = Field(dict(), description='The available plugin.') + options: dict[str, NORTHTool] = Field(dict(), description='The available plugin.') class NORTH(ConfigBaseModel): @@ -118,7 +118,7 @@ class NORTH(ConfigBaseModel): Settings related to the operation of the NOMAD remote tools hub service *north*. """ - enabled: Optional[bool] = Field( + enabled: bool | None = Field( True, description=""" Enables or disables the NORTH API and UI views. This is independent of @@ -144,7 +144,7 @@ class NORTH(ConfigBaseModel): The internal host name that NOMAD services use to connect to the jupyterhub API. """, ) - hub_port: Union[int, str] = Field( + hub_port: int | str = Field( 9000, description=""" The internal port that NOMAD services use to connect to the jupyterhub API. diff --git a/nomad/config/models/plugins.py b/nomad/config/models/plugins.py index b4dac54919ca535304efb735aee96568e6687e80..9657130e6cc68f1bb1e3ff5fa1ce5c69c033d211 100644 --- a/nomad/config/models/plugins.py +++ b/nomad/config/models/plugins.py @@ -41,16 +41,16 @@ if TYPE_CHECKING: class EntryPoint(BaseModel): """Base model for a NOMAD plugin entry points.""" - id: Optional[str] = Field( + id: str | None = Field( None, description='Unique identifier corresponding to the entry point name. Automatically set to the plugin entry point name in pyproject.toml.', ) entry_point_type: str = Field(description='Determines the entry point type.') - name: Optional[str] = Field(None, description='Name of the plugin entry point.') - description: Optional[str] = Field( + name: str | None = Field(None, description='Name of the plugin entry point.') + description: str | None = Field( None, description='A human readable description of the plugin entry point.' ) - plugin_package: Optional[str] = Field( + plugin_package: str | None = Field( None, description='The plugin package from which this entry points comes from.' ) @@ -127,8 +127,8 @@ class ParserEntryPoint(EntryPoint, metaclass=ABCMeta): level will attempt to match raw files first. """, ) - aliases: List[str] = Field([], description="""List of alternative parser names.""") - mainfile_contents_re: Optional[str] = Field( + aliases: list[str] = Field([], description="""List of alternative parser names.""") + mainfile_contents_re: str | None = Field( None, description=""" A regular expression that is applied the content of a potential mainfile. @@ -152,13 +152,13 @@ class ParserEntryPoint(EntryPoint, metaclass=ABCMeta): for a file, if the expression matches. """, ) - mainfile_binary_header: Optional[bytes] = Field( + mainfile_binary_header: bytes | None = Field( None, description=""" Matches a binary file if the given bytes are included in the file. 
""", ) - mainfile_binary_header_re: Optional[bytes] = Field( + mainfile_binary_header_re: bytes | None = Field( None, description=""" Matches a binary file if the given binary regular expression bytes matches the @@ -172,13 +172,13 @@ class ParserEntryPoint(EntryPoint, metaclass=ABCMeta): matches a parser. """, ) - mainfile_contents_dict: Optional[dict] = Field( + mainfile_contents_dict: dict | None = Field( None, description=""" Is used to match structured data files like JSON or HDF5. """, ) - supported_compressions: List[str] = Field( + supported_compressions: list[str] = Field( [], description=""" Files compressed with the given formats (e.g. xz, gz) are uncompressed and @@ -226,26 +226,22 @@ class ExampleUploadEntryPoint(EntryPoint): entry_point_type: Literal['example_upload'] = Field( 'example_upload', description='Determines the entry point type.' ) - category: Optional[str] = Field(description='Category for the example upload.') - title: Optional[str] = Field(description='Title of the example upload.') - description: Optional[str] = Field( + category: str | None = Field(description='Category for the example upload.') + title: str | None = Field(description='Title of the example upload.') + description: str | None = Field( description='Longer description of the example upload.' ) - resources: Optional[ + resources: None | ( # Note that the order here matters: pydantic may interpret a dictionary # as a list of strings instead of an UploadResource object if the order # is wrong here. - Union[ - List[Union[UploadResource, str]], - UploadResource, - str, - ] - ] = Field(None, description='List of data resources for this example upload.') - path: Optional[str] = Field( + list[UploadResource | str] | UploadResource | str + ) = Field(None, description='List of data resources for this example upload.') + path: str | None = Field( None, deprecated='"path" is deprecated, use "resources" instead.', ) - url: Optional[str] = Field( + url: str | None = Field( None, deprecated='"url" is deprecated, use "resources" instead.', ) @@ -453,19 +449,17 @@ class PluginBase(BaseModel): plugin_type: str = Field( description='The type of the plugin.', ) - id: Optional[str] = Field( - None, description='The unique identifier for this plugin.' - ) + id: str | None = Field(None, description='The unique identifier for this plugin.') name: str = Field( description='A short descriptive human readable name for the plugin.' ) - description: Optional[str] = Field( + description: str | None = Field( None, description='A human readable description of the plugin.' ) - plugin_documentation_url: Optional[str] = Field( + plugin_documentation_url: str | None = Field( None, description='The URL to the plugins main documentation page.' ) - plugin_source_code_url: Optional[str] = Field( + plugin_source_code_url: str | None = Field( None, description='The URL of the plugins main source code repository.' ) @@ -501,11 +495,11 @@ class Schema(PythonPluginBase): A Schema describes a NOMAD Python schema that can be loaded as a plugin. """ - package_path: Optional[str] = Field( + package_path: str | None = Field( None, description='Path of the plugin package. 
Will be determined using python_package if not explicitly defined.', ) - key: Optional[str] = Field(None, description='Key used to identify this plugin.') + key: str | None = Field(None, description='Key used to identify this plugin.') plugin_type: Literal['schema'] = Field( 'schema', description=""" @@ -577,7 +571,7 @@ class Parser(PythonPluginBase): parser class directly for parsing and matching. """, ) - mainfile_contents_re: Optional[str] = Field( + mainfile_contents_re: str | None = Field( None, description=""" A regular expression that is applied the content of a potential mainfile. @@ -601,13 +595,13 @@ class Parser(PythonPluginBase): expression matches. """, ) - mainfile_binary_header: Optional[bytes] = Field( + mainfile_binary_header: bytes | None = Field( None, description=""" Matches a binary file if the given bytes are included in the file. """, ) - mainfile_binary_header_re: Optional[bytes] = Field( + mainfile_binary_header_re: bytes | None = Field( None, description=""" Matches a binary file if the given binary regular expression bytes matches the @@ -621,7 +615,7 @@ class Parser(PythonPluginBase): matches a parser. """, ) - mainfile_contents_dict: Optional[dict] = Field( + mainfile_contents_dict: dict | None = Field( None, description=""" Is used to match structured data files like JSON, HDF5 or csv/excel files. In case of a csv/excel file @@ -638,7 +632,7 @@ class Parser(PythonPluginBase): <i>__has_comment: str<i> (only for csv/xlsx files) """, ) - supported_compressions: List[str] = Field( + supported_compressions: list[str] = Field( [], description=""" Files compressed with the given formats (e.g. xz, gz) are uncompressed and @@ -657,10 +651,10 @@ class Parser(PythonPluginBase): The order by which the parser is executed with respect to other parsers. """, ) - code_name: Optional[str] = None - code_homepage: Optional[str] = None - code_category: Optional[str] = None - metadata: Optional[dict] = Field( + code_name: str | None = None + code_homepage: str | None = None + code_category: str | None = None + metadata: dict | None = Field( None, description=""" Metadata passed to the UI. Deprecated.""", @@ -701,7 +695,7 @@ EntryPointType = Union[ class EntryPoints(Options): - options: Dict[str, EntryPointType] = Field( + options: dict[str, EntryPointType] = Field( dict(), description='The available plugin entry points.' ) @@ -710,25 +704,25 @@ class PluginPackage(BaseModel): name: str = Field( description='Name of the plugin Python package, read from pyproject.toml.' ) - description: Optional[str] = Field( + description: str | None = Field( None, description='Package description, read from pyproject.toml.' ) - version: Optional[str] = Field( + version: str | None = Field( None, description='Plugin package version, read from pyproject.toml.' 
) - homepage: Optional[str] = Field( + homepage: str | None = Field( None, description='Link to the plugin package homepage, read from pyproject.toml.', ) - documentation: Optional[str] = Field( + documentation: str | None = Field( None, description='Link to the plugin package documentation page, read from pyproject.toml.', ) - repository: Optional[str] = Field( + repository: str | None = Field( None, description='Link to the plugin package source code repository, read from pyproject.toml.', ) - entry_points: List[str] = Field( + entry_points: list[str] = Field( description='List of entry point ids contained in this package, read from pyproject.toml' ) @@ -737,7 +731,7 @@ class Plugins(BaseModel): entry_points: EntryPoints = Field( description='Used to control plugin entry points.' ) - plugin_packages: Dict[str, PluginPackage] = Field( + plugin_packages: dict[str, PluginPackage] = Field( description=""" Contains the installed plugin packages with the package name used as a key. This is autogenerated and should not be modified. diff --git a/nomad/config/models/ui.py b/nomad/config/models/ui.py index 60b772d240a6528abbfaeebadb22c479a81f464b..a0ad71defd404aa4e8ed64e079e2802a4bbe0e84 100644 --- a/nomad/config/models/ui.py +++ b/nomad/config/models/ui.py @@ -18,7 +18,8 @@ from enum import Enum from typing import List, Dict, Union, Optional -from typing_extensions import Literal, Annotated +from typing import Literal +from typing import Annotated from pydantic import BaseModel, ConfigDict, model_validator, Field from .common import ( @@ -54,7 +55,7 @@ class UnitSystemUnit(ConfigBaseModel): registered in the NOMAD unit registry (`nomad.units.ureg`). """ ) - locked: Optional[bool] = Field( + locked: bool | None = Field( False, description='Whether the unit is locked in the unit system it is defined in.', ) @@ -98,7 +99,7 @@ class UnitSystem(ConfigBaseModel): label: str = Field( description='Short, descriptive label used for this unit system.' ) - units: Optional[Dict[str, UnitSystemUnit]] = Field( + units: dict[str, UnitSystemUnit] | None = Field( None, description=f""" Contains a mapping from each dimension to a unit. If a unit is not @@ -177,7 +178,7 @@ class UnitSystem(ConfigBaseModel): class UnitSystems(OptionsSingle): """Controls the available unit systems.""" - options: Optional[Dict[str, UnitSystem]] = Field( + options: dict[str, UnitSystem] | None = Field( None, description='Contains the available unit systems.' ) @@ -213,7 +214,7 @@ class Card(ConfigBaseModel): class Cards(Options): """Contains the overview page card definitions and controls their visibility.""" - options: Optional[Dict[str, Card]] = Field( + options: dict[str, Card] | None = Field( None, description='Contains the available card options.' ) @@ -267,7 +268,7 @@ class Column(ConfigBaseModel): - Show instance that matches a criterion: `repeating_section[?label=='target'].quantity` """ - search_quantity: Optional[str] = Field( + search_quantity: str | None = Field( None, description=""" Path of the targeted quantity. Note that you can use most of the features @@ -276,26 +277,26 @@ class Column(ConfigBaseModel): statistical values. """, ) - quantity: Optional[str] = Field( + quantity: str | None = Field( None, deprecated='The "quantity" field is deprecated, use "search_quantity" instead.', ) selected: bool = Field( False, description="""Is this column initially selected to be shown.""" ) - title: Optional[str] = Field( + title: str | None = Field( None, description='Label shown in the header. 
Defaults to the quantity name.' ) - label: Optional[str] = Field(None, description='Alias for title.') + label: str | None = Field(None, description='Alias for title.') align: AlignEnum = Field(AlignEnum.LEFT, description='Alignment in the table.') - unit: Optional[str] = Field( + unit: str | None = Field( None, description=""" Unit to convert to when displaying. If not given will be displayed in using the default unit in the active unit system. """, ) - format: Optional[Format] = Field( + format: Format | None = Field( None, description='Controls the formatting of the values.' ) @@ -325,7 +326,7 @@ class Columns(OptionsMulti): selection. """ - options: Optional[Dict[str, Column]] = Field( + options: dict[str, Column] | None = Field( None, description=""" All available column options. Note here that the key must correspond to a @@ -337,7 +338,7 @@ class Columns(OptionsMulti): class RowAction(ConfigBaseModel): """Common configuration for all row actions.""" - description: Optional[str] = Field( + description: str | None = Field( None, description="""Description of the action shown to the user.""" ) type: str = Field(description='Used to identify the action type.') @@ -377,10 +378,10 @@ class RowActions(Options): """Controls the visualization of row actions that are shown at the end of each row.""" enabled: bool = Field(True, description='Whether to enable row actions.') - options: Optional[Dict[str, RowActionURL]] = Field( + options: dict[str, RowActionURL] | None = Field( None, deprecated="""Deprecated, use 'items' instead.""" ) - items: Optional[List[RowActionURL]] = Field( + items: list[RowActionURL] | None = Field( None, description='List of actions to show for each row.' ) @@ -457,7 +458,7 @@ class FilterMenuActionCheckbox(FilterMenuAction): class FilterMenuActions(Options): """Contains filter menu action definitions and controls their availability.""" - options: Optional[Dict[str, FilterMenuActionCheckbox]] = Field( + options: dict[str, FilterMenuActionCheckbox] | None = Field( None, description='Contains options for filter menu actions.' ) @@ -476,19 +477,19 @@ class FilterMenuSizeEnum(str, Enum): class FilterMenu(ConfigBaseModel): """Defines the layout and functionality for a filter menu.""" - label: Optional[str] = Field(None, description='Menu label to show in the UI.') - level: Optional[int] = Field(0, description='Indentation level of the menu.') - size: Optional[FilterMenuSizeEnum] = Field( + label: str | None = Field(None, description='Menu label to show in the UI.') + level: int | None = Field(0, description='Indentation level of the menu.') + size: FilterMenuSizeEnum | None = Field( FilterMenuSizeEnum.S, description='Width of the menu.' ) - actions: Optional[FilterMenuActions] = None + actions: FilterMenuActions | None = None # Deprecated class FilterMenus(Options): """Contains filter menu definitions and controls their availability.""" - options: Optional[Dict[str, FilterMenu]] = Field( + options: dict[str, FilterMenu] | None = Field( None, description='Contains the available filter menu options.' ) @@ -500,7 +501,7 @@ class FilterMenus(Options): class AxisScale(ConfigBaseModel): """Basic configuration for a plot axis.""" - scale: Optional[ScaleEnum] = Field( + scale: ScaleEnum | None = Field( ScaleEnum.LINEAR, description="""Defines the axis scaling. 
Defaults to linear scaling.""", ) @@ -509,13 +510,13 @@ class AxisScale(ConfigBaseModel): class AxisQuantity(ConfigBaseModel): """Configuration for a plot axis.""" - title: Optional[str] = Field( + title: str | None = Field( None, description="""Custom title to show for the axis.""" ) - unit: Optional[str] = Field( + unit: str | None = Field( None, description="""Custom unit used for displaying the values.""" ) - quantity: Optional[str] = Field( + quantity: str | None = Field( None, deprecated='The "quantity" field is deprecated, use "search_quantity" instead.', ) @@ -550,7 +551,7 @@ class Axis(AxisScale, AxisQuantity): class TermsBase(ConfigBaseModel): """Base model for configuring terms components.""" - quantity: Optional[str] = Field( + quantity: str | None = Field( None, deprecated='The "quantity" field is deprecated, use "search_quantity" instead.', ) @@ -560,7 +561,7 @@ class TermsBase(ConfigBaseModel): ) scale: ScaleEnum = Field(ScaleEnum.LINEAR, description='Statistics scaling.') show_input: bool = Field(True, description='Whether to show text input field.') - showinput: Optional[bool] = Field( + showinput: bool | None = Field( None, deprecated='The "showinput" field is deprecated, use "show_input" instead.', ) @@ -593,37 +594,37 @@ class HistogramBase(ConfigBaseModel): type: Literal['histogram'] = Field( description='Set as `histogram` to get this widget type.' ) - quantity: Optional[str] = Field( + quantity: str | None = Field( None, deprecated='The "quantity" field is deprecated, use "x.search_quantity" instead.', ) - scale: Optional[ScaleEnum] = Field( + scale: ScaleEnum | None = Field( None, deprecated='The "scale" field is deprecated, use "y.scale" instead.' ) show_input: bool = Field(True, description='Whether to show text input field.') - showinput: Optional[bool] = Field( + showinput: bool | None = Field( None, deprecated='The "showinput" field is deprecated, use "show_input" instead.', ) - x: Union[Axis, str] = Field( + x: Axis | str = Field( description='Configures the information source and display options for the x-axis.' ) - y: Union[AxisScale, str] = Field( + y: AxisScale | str = Field( description='Configures the information source and display options for the y-axis.' ) autorange: bool = Field( False, description='Whether to automatically set the range according to the data limits.', ) - n_bins: Optional[int] = Field( + n_bins: int | None = Field( None, description=""" Maximum number of histogram bins. Notice that the actual number of bins may be smaller if there are fewer data items available. """, ) - nbins: Optional[int] = Field( + nbins: int | None = Field( None, deprecated='The "nbins" field is deprecated, use "n_bins" instead.' ) @@ -673,14 +674,12 @@ class PeriodicTableBase(ConfigBaseModel): type: Literal['periodic_table'] = Field( description='Set as `periodic_table` to get this widget type.' ) - quantity: Optional[str] = Field( + quantity: str | None = Field( None, deprecated='The "quantity" field is deprecated, use "search_quantity" instead.', ) search_quantity: str = Field(description='The targeted search quantity.') - scale: Optional[ScaleEnum] = Field( - ScaleEnum.LINEAR, description='Statistics scaling.' - ) + scale: ScaleEnum | None = Field(ScaleEnum.LINEAR, description='Statistics scaling.') @model_validator(mode='before') @classmethod @@ -714,14 +713,14 @@ class MenuItem(ConfigBaseModel): description='Width of the item, 12 means maximum width. 
Note that the menu size can be changed.', ) show_header: bool = Field(True, description='Whether to show the header.') - title: Optional[str] = Field(None, description='Custom item title.') + title: str | None = Field(None, description='Custom item title.') class MenuItemOption(ConfigBaseModel): """Represents an option shown for a filter.""" - label: Optional[str] = Field(None, description='The label to show for this option.') - description: Optional[str] = Field( + label: str | None = Field(None, description='The label to show for this option.') + description: str | None = Field( None, description='Detailed description for this option.' ) @@ -731,7 +730,7 @@ class MenuItemTerms(MenuItem, TermsBase): quantities. """ - options: Optional[Union[int | bool, Dict[str, MenuItemOption]]] = Field( + options: int | bool | dict[str, MenuItemOption] | None = Field( None, description=""" Used to control the displayed options: @@ -882,7 +881,7 @@ class MenuItemNestedObject(MenuItem): path: str = Field( description='Path of the nested object. Typically a section name.' ) - items: Optional[List[MenuItemTypeNested]] = Field( + items: list[MenuItemTypeNested] | None = Field( None, description='Items that are grouped by this nested object.' ) @@ -923,15 +922,15 @@ class Menu(MenuItem): type: Literal['menu'] = Field( description='Set as `nested_object` to get this menu item type.', ) - size: Optional[Union[MenuSizeEnum, str]] = Field( + size: MenuSizeEnum | str | None = Field( MenuSizeEnum.SM, description=""" Size of the menu. Either use presets as defined by MenuSizeEnum, or then provide valid CSS widths. """, ) - indentation: Optional[int] = Field(0, description='Indentation level for the menu.') - items: Optional[List[MenuItemType]] = Field( + indentation: int | None = Field(0, description='Indentation level for the menu.') + items: list[MenuItemType] | None = Field( None, description='List of items in the menu.' ) @@ -955,13 +954,13 @@ class SearchQuantities(OptionsGlob): `*.#myschema.schema.MySchema`. """ - include: Optional[List[str]] = Field( + include: list[str] | None = Field( None, description=""" List of included options. Supports glob/wildcard syntax. """, ) - exclude: Optional[List[str]] = Field( + exclude: list[str] | None = Field( None, description=""" List of excluded options. Supports glob/wildcard syntax. Has higher precedence than include. @@ -988,7 +987,7 @@ class SearchSyntaxes(ConfigBaseModel): - `free_text`: For inexact, free-text queries. Requires that a set of keywords has been filled in the entry. """ - exclude: Optional[List[str]] = Field( + exclude: list[str] | None = Field( None, description=""" List of excluded options. 
@@ -1003,8 +1002,8 @@ class Layout(ConfigBaseModel): w: int = Field(description='Width in grid units.') x: int = Field(description='Horizontal start location in the grid.') y: int = Field(description='Vertical start location in the grid.') - minH: Optional[int] = Field(3, description='Minimum height in grid units.') - minW: Optional[int] = Field(3, description='Minimum width in grid units.') + minH: int | None = Field(3, description='Minimum height in grid units.') + minW: int | None = Field(3, description='Minimum width in grid units.') class BreakpointEnum(str, Enum): @@ -1018,7 +1017,7 @@ class BreakpointEnum(str, Enum): class AxisLimitedScale(AxisQuantity): """Configuration for a plot axis with limited scaling options.""" - scale: Optional[ScaleEnumPlot] = Field( + scale: ScaleEnumPlot | None = Field( ScaleEnumPlot.LINEAR, description="""Defines the axis scaling. Defaults to linear scaling.""", ) @@ -1027,7 +1026,7 @@ class AxisLimitedScale(AxisQuantity): class Markers(ConfigBaseModel): """Configuration for plot markers.""" - color: Optional[Axis] = Field( + color: Axis | None = Field( None, description='Configures the information source and display options for the marker colors.', ) @@ -1036,12 +1035,12 @@ class Markers(ConfigBaseModel): class Widget(ConfigBaseModel): """Common configuration for all widgets.""" - title: Optional[str] = Field( + title: str | None = Field( None, description='Custom widget title. If not specified, a widget-specific default title is used.', ) type: str = Field(description='Used to identify the widget type.') - layout: Dict[BreakpointEnum, Layout] = Field( + layout: dict[BreakpointEnum, Layout] = Field( description=""" Defines widget size and grid positioning for different breakpoints. The following breakpoints are supported: `sm`, `md`, `lg`, `xl` and `xxl`. @@ -1109,17 +1108,17 @@ class WidgetScatterPlot(Widget): type: Literal['scatter_plot'] = Field( description='Set as `scatter_plot` to get this widget type.' ) - x: Union[AxisLimitedScale, str] = Field( + x: AxisLimitedScale | str = Field( description='Configures the information source and display options for the x-axis.' ) - y: Union[AxisLimitedScale, str] = Field( + y: AxisLimitedScale | str = Field( description='Configures the information source and display options for the y-axis.' ) - markers: Optional[Markers] = Field( + markers: Markers | None = Field( None, description='Configures the information source and display options for the markers.', ) - color: Optional[str] = Field( + color: str | None = Field( None, description=""" Quantity used for coloring points. Note that this field is deprecated @@ -1215,7 +1214,7 @@ WidgetAnnotated = Annotated[ class Dashboard(ConfigBaseModel): """Dashboard configuration.""" - widgets: List[WidgetAnnotated] = Field( + widgets: list[WidgetAnnotated] = Field( description='List of widgets contained in the dashboard.' ) @@ -1231,47 +1230,43 @@ class App(ConfigBaseModel): label: str = Field(description='Name of the App.') path: str = Field(description='Path used in the browser address bar.') resource: ResourceEnum = Field('entries', description='Targeted resource.') # type: ignore - breadcrumb: Optional[str] = Field( + breadcrumb: str | None = Field( None, description='Name displayed in the breadcrumb, by default the label will be used.', ) category: str = Field( description='Category used to organize Apps in the explore menu.' ) - description: Optional[str] = Field( - None, description='Short description of the App.' 
- ) - readme: Optional[str] = Field( + description: str | None = Field(None, description='Short description of the App.') + readme: str | None = Field( None, description='Longer description of the App that can also use markdown.' ) pagination: Pagination = Field( Pagination(), description='Default result pagination.' ) - columns: Optional[List[Column]] = Field( + columns: list[Column] | None = Field( None, description='List of columns for the results table.' ) - rows: Optional[Rows] = Field( + rows: Rows | None = Field( Rows(), description='Controls the display of entry rows in the results table.', ) - menu: Optional[Menu] = Field( + menu: Menu | None = Field( None, description='Filter menu displayed on the left side of the screen.' ) - filter_menus: Optional[FilterMenus] = Field( + filter_menus: FilterMenus | None = Field( None, deprecated='The "filter_menus" field is deprecated, use "menu" instead.' ) - filters: Optional[Filters] = Field( + filters: Filters | None = Field( None, deprecated='The "filters" field is deprecated, use "search_quantities" instead.', ) - search_quantities: Optional[SearchQuantities] = Field( + search_quantities: SearchQuantities | None = Field( SearchQuantities(exclude=['mainfile', 'entry_name', 'combine']), description='Controls the quantities that are available for search in this app.', ) - dashboard: Optional[Dashboard] = Field( - None, description='Default dashboard layout.' - ) - filters_locked: Optional[dict] = Field( + dashboard: Dashboard | None = Field(None, description='Default dashboard layout.') + filters_locked: dict | None = Field( None, description=""" Fixed query object that is applied for this search context. This filter @@ -1279,7 +1274,7 @@ class App(ConfigBaseModel): user by default. """, ) - search_syntaxes: Optional[SearchSyntaxes] = Field( + search_syntaxes: SearchSyntaxes | None = Field( None, description='Controls which types of search syntax are available.' ) @@ -1997,7 +1992,7 @@ class App(ConfigBaseModel): class Apps(Options): """Contains App definitions and controls their availability.""" - options: Optional[Dict[str, App]] = Field( + options: dict[str, App] | None = Field( None, description='Contains the available app options.' 
) diff --git a/nomad/datamodel/context.py b/nomad/datamodel/context.py index de2138dcdf001d59acb0e2fe6df12c44e1e2e420..dca4c701c1bd5bd8615b059889505c7021f4682f 100644 --- a/nomad/datamodel/context.py +++ b/nomad/datamodel/context.py @@ -58,8 +58,8 @@ class Context(MetainfoContext): '/v1' if self.installation_url.endswith('/api') else '/api/v1' ) - self.archives: Dict[str, MSection] = {} - self.urls: Dict[MSection, str] = {} + self.archives: dict[str, MSection] = {} + self.urls: dict[MSection, str] = {} @property def upload_id(self): @@ -495,7 +495,7 @@ class ClientContext(Context): if os.path.exists(file_path): from nomad.parsing.parser import ArchiveParser - with open(file_path, 'rt') as f: + with open(file_path) as f: archive = EntryArchive(m_context=self) ArchiveParser().parse_file(file_path, f, archive) return archive diff --git a/nomad/datamodel/data.py b/nomad/datamodel/data.py index fb35e9d3f687fbb94fb84180c5842a6dfba158de..36b55b2d65bd0802c0921e2c1292e4a238e15e02 100644 --- a/nomad/datamodel/data.py +++ b/nomad/datamodel/data.py @@ -94,7 +94,7 @@ class EntryData(ArchiveSection): """ def normalize(self, archive, logger): - super(EntryData, self).normalize(archive, logger) + super().normalize(archive, logger) from nomad.datamodel.results import Results from nomad.datamodel import EntryArchive @@ -119,7 +119,7 @@ class Author(MSection): name = Quantity( type=str, - derived=lambda user: ('%s %s' % (user.first_name, user.last_name)).strip(), + derived=lambda user: (f'{user.first_name} {user.last_name}').strip(), a_elasticsearch=[ Elasticsearch(material_entry_type, _es_field='keyword'), Elasticsearch( @@ -265,7 +265,7 @@ class Query(JSON): from nomad.app.v1.models import MetadataResponse class QueryResult(MetadataResponse): - filters: Optional[Dict[str, Any]] = Field(None) + filters: dict[str, Any] | None = Field(None) return QueryResult().parse_obj(value).dict() diff --git a/nomad/datamodel/datamodel.py b/nomad/datamodel/datamodel.py index 875fc3d91489ea15c19af245c37bf853b3b3d39a..a90eecef2abe75a849fcfa55749e220d917e3568 100644 --- a/nomad/datamodel/datamodel.py +++ b/nomad/datamodel/datamodel.py @@ -272,14 +272,14 @@ def derive_origin(entry: 'EntryMetadata') -> str: return None -def derive_authors(entry: 'EntryMetadata') -> List[User]: +def derive_authors(entry: 'EntryMetadata') -> list[User]: if entry.external_db == 'EELS Data Base': return list(entry.entry_coauthors) if entry.entry_type == 'AIToolkitNotebook': return list(entry.entry_coauthors) - authors: List[User] = [entry.main_author] + authors: list[User] = [entry.main_author] if entry.coauthors: authors.extend(entry.coauthors) if entry.entry_coauthors: @@ -1221,7 +1221,7 @@ class EntryArchive(ArchiveSection): definitions = SubSection(sub_section=Package) def normalize(self, archive, logger): - super(EntryArchive, self).normalize(archive, logger) + super().normalize(archive, logger) if not archive.metadata.entry_type: if archive.definitions is not None: diff --git a/nomad/datamodel/metainfo/annotations.py b/nomad/datamodel/metainfo/annotations.py index d8f2c19e31db1c76eab2ab94c46a387af0abd3ec..80105db0a52640816c64113ed2a5f69d8737255e 100644 --- a/nomad/datamodel/metainfo/annotations.py +++ b/nomad/datamodel/metainfo/annotations.py @@ -105,7 +105,7 @@ valid_eln_components = { class Filter(BaseModel): """A filter defined by an include list or and exclude list of the quantities or subsections.""" - include: Optional[List[str]] = Field( + include: list[str] | None = Field( None, description=strip( """ @@ -113,7 +113,7 @@ class 
Filter(BaseModel): """ ), ) - exclude: Optional[List[str]] = Field( + exclude: list[str] | None = Field( None, description=strip( """ @@ -126,7 +126,7 @@ class Filter(BaseModel): class DisplayAnnotation(BaseModel): """The display settings defined by an include list or an exclude list of the quantities and subsections.""" - visible: Optional[Filter] = Field( # type: ignore + visible: Filter | None = Field( # type: ignore 1, description=strip( """ @@ -134,7 +134,7 @@ class DisplayAnnotation(BaseModel): """ ), ) - editable: Optional[Filter] = Field( + editable: Filter | None = Field( None, description=strip( """ @@ -167,7 +167,7 @@ class QuantityDisplayAnnotation(DisplayAnnotation): ``` """ - unit: Optional[str] = Field( + unit: str | None = Field( None, description=strip( """ @@ -200,7 +200,7 @@ class SectionDisplayAnnotation(DisplayAnnotation): ``` """ - order: Optional[List[str]] = Field( + order: list[str] | None = Field( None, description=strip( """ @@ -213,7 +213,7 @@ class SectionDisplayAnnotation(DisplayAnnotation): class SectionProperties(BaseModel): """The display settings for quantities and subsections. (Deprecated)""" - visible: Optional[Filter] = Field( # type: ignore + visible: Filter | None = Field( # type: ignore 1, description=strip( """ @@ -221,7 +221,7 @@ class SectionProperties(BaseModel): """ ), ) - editable: Optional[Filter] = Field( + editable: Filter | None = Field( None, description=strip( """ @@ -229,7 +229,7 @@ class SectionProperties(BaseModel): """ ), ) - order: Optional[List[str]] = Field( + order: list[str] | None = Field( None, description=strip( """ @@ -303,7 +303,7 @@ class ELNAnnotation(AnnotationModel): """, ) - props: Dict[str, Any] = Field( + props: dict[str, Any] = Field( None, description=""" A dictionary with additional props that are passed to the edit component. @@ -331,7 +331,7 @@ class ELNAnnotation(AnnotationModel): deprecated=True, ) - minValue: Union[int, float] = Field( + minValue: int | float = Field( None, description=""" Allows to specify a minimum value for quantity annotations with number type. @@ -340,7 +340,7 @@ class ELNAnnotation(AnnotationModel): """, ) - maxValue: Union[int, float] = Field( + maxValue: int | float = Field( None, description=""" Allows to specify a maximum value for quantity annotations with number type. @@ -357,7 +357,7 @@ class ELNAnnotation(AnnotationModel): """, ) - hide: List[str] = Field( + hide: list[str] = Field( None, description=""" This attribute is deprecated. Use `visible` attribute of `display` annotation instead. 
@@ -375,7 +375,7 @@ class ELNAnnotation(AnnotationModel): section annotations.""", ) - lane_width: Union[str, int] = Field( + lane_width: str | int = Field( None, description=""" Value to overwrite the css width of the lane used to render the annotation @@ -561,7 +561,7 @@ class TabularMode(str, Enum): class TabularParsingOptions(BaseModel): - skiprows: Union[List[int], int] = Field(None, description='Number of rows to skip') + skiprows: list[int] | int = Field(None, description='Number of rows to skip') sep: str = Field(None, description='Character identifier of a separator') comment: str = Field(None, description='Character identifier of a commented line') separator: str = Field(None, description='Alias for `sep`') @@ -612,7 +612,7 @@ class TabularMappingOptions(BaseModel): `multiple_new_entries`: Creating many new entries and processing the data into these new NOMAD entries.<br/> """, ) - sections: List[str] = Field( + sections: list[str] = Field( None, description=""" A `list` of paths to the (sub)sections where the tabular quantities are to be filled from the data @@ -642,7 +642,7 @@ class TabularParserAnnotation(AnnotationModel): `separator`: An alias for `sep`.<br/> """, ) - mapping_options: List[TabularMappingOptions] = Field( + mapping_options: list[TabularMappingOptions] = Field( [], description=""" A list of directives on how to map the extracted data from the csv/xlsx file to NOMAD. Each directive @@ -667,12 +667,12 @@ class PlotlyExpressTraceAnnotation(BaseModel): """ method: str = Field(None, description='Plotly express plot method') - layout: Dict = Field(None, description='Plotly layout') + layout: dict = Field(None, description='Plotly layout') - x: Union[List[float], List[str], str] = Field(None, description='Plotly express x') - y: Union[List[float], List[str], str] = Field(None, description='Plotly express y') - z: Union[List[float], List[str], str] = Field(None, description='Plotly express z') - color: Union[List[float], List[str], str] = Field( + x: list[float] | list[str] | str = Field(None, description='Plotly express x') + y: list[float] | list[str] | str = Field(None, description='Plotly express y') + z: list[float] | list[str] | str = Field(None, description='Plotly express z') + color: list[float] | list[str] | str = Field( None, description='Plotly express color' ) symbol: str = Field(None, description='Plotly express symbol') @@ -713,7 +713,7 @@ class PlotlyExpressAnnotation(PlotlyExpressTraceAnnotation): """ label: str = Field(None, description='Figure label') - traces: List[PlotlyExpressTraceAnnotation] = Field( + traces: list[PlotlyExpressTraceAnnotation] = Field( [], description=""" List of traces added to the main trace defined by plotly_express method @@ -744,12 +744,12 @@ class PlotlyGraphObjectAnnotation(BaseModel): """ label: str = Field(None, description='Figure label') - data: Dict = Field(None, description='Plotly data') - layout: Dict = Field(None, description='Plotly layout') - config: Dict = Field(None, description='Plotly config') + data: dict = Field(None, description='Plotly data') + layout: dict = Field(None, description='Plotly layout') + config: dict = Field(None, description='Plotly config') def __init__(self, *args, **kwargs): - super(PlotlyGraphObjectAnnotation, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) if not self.data or not isinstance(self.data, dict): raise PlotlyError('data should be a dictionary containing plotly data.') @@ -802,15 +802,15 @@ class PlotlySubplotsAnnotation(BaseModel): """ label: str = 
Field(None, description='Figure label') - layout: Dict = Field(None, description='Plotly layout') - parameters: Dict = Field( + layout: dict = Field(None, description='Plotly layout') + parameters: dict = Field( None, description=""" plotly.subplots.make_subplots parameters i.e. rows, cols, shared_xaxes, shared_xaxes, horizontal_spacing , ... See [plotly make_subplots documentation](https://plotly.com/python-api-reference/generated/plotly.subplots.make_subplots.html) for more information. """, ) - plotly_express: List[PlotlyExpressAnnotation] = Field( + plotly_express: list[PlotlyExpressAnnotation] = Field( [], description=""" List of subplots defined by plotly_express method @@ -903,7 +903,7 @@ class PlotAnnotation(AnnotationModel): def __init__(self, *args, **kwargs): # pydantic does not seem to support multiple aliases per field - super(PlotAnnotation, self).__init__( + super().__init__( *args, x=kwargs.pop('x', None) or kwargs.pop('xAxis', None) @@ -917,7 +917,7 @@ class PlotAnnotation(AnnotationModel): label: str = Field( None, description='Is passed to plotly to define the label of the plot.' ) - x: Union[List[str], str] = Field( + x: list[str] | str = Field( ..., description=""" A path or list of paths to the x-axes values. Each path is a `/` separated @@ -926,7 +926,7 @@ class PlotAnnotation(AnnotationModel): integer or a slice `start:stop`. """, ) - y: Union[List[str], str] = Field( + y: list[str] | str = Field( ..., description=""" A path or list of paths to the y-axes values. list of sub-section and quantity @@ -934,7 +934,7 @@ class PlotAnnotation(AnnotationModel): sections are indexed between two `/`s with an integer or a slice `start:stop`. """, ) - lines: List[dict] = Field( + lines: list[dict] = Field( None, description=""" A list of dicts passed as `traces` to plotly to configure the lines of the plot. @@ -986,7 +986,7 @@ class PlotAnnotation(AnnotationModel): class RegexCondition(BaseModel): - regex_path: Optional[str] = Field( + regex_path: str | None = Field( None, description=""" The JMESPath to the target key in the dictionary. If not set, the path is assumed to be the same as @@ -1003,19 +1003,19 @@ class Condition(BaseModel): class Rule(BaseModel): - source: Optional[str] = Field( + source: str | None = Field( None, description='JMESPath to the source value in the source dictionary.' ) target: str = Field( ..., description='JMESPath to the target value in the target dictionary.' ) - conditions: Optional[list[Condition]] = Field( + conditions: list[Condition] | None = Field( None, description='Conditions to check prior to applying the transformation.' ) - default_value: Optional[Any] = Field( + default_value: Any | None = Field( None, description='Default value to set if source is not found.' ) - use_rule: Optional[str] = Field( + use_rule: str | None = Field( None, description='Reference to another rule using #mapping_name.rule_name.' ) @@ -1041,10 +1041,10 @@ Rule.update_forward_refs() class Rules(BaseModel): - name: Optional[str] = Field( + name: str | None = Field( None, description='Name for the rule set, for identification.' ) - other_metadata: Optional[str] = Field( + other_metadata: str | None = Field( None, description='Placeholder for other metadata.' ) rules: dict[str, Rule] = Field(..., description='Dictionary of named rules.') @@ -1060,7 +1060,7 @@ class H5WebAnnotation(AnnotationModel): of the annotation fields. 
""" - axes: Union[str, List[str]] = Field( + axes: str | list[str] = Field( None, description=""" Names of the HDF5Dataset quantities to plot on the independent axes. @@ -1078,7 +1078,7 @@ class H5WebAnnotation(AnnotationModel): Label for the hdf5 dataset. Note: this attribute will overwrite also the unit. """, ) - auxiliary_signals: List[str] = Field( + auxiliary_signals: list[str] = Field( None, description=""" Additional datasets to include in plot as signal. @@ -1090,7 +1090,7 @@ class H5WebAnnotation(AnnotationModel): Title of the plot """, ) - paths: List[str] = Field([], description="""List of section paths to visualize.""") + paths: list[str] = Field([], description="""List of section paths to visualize.""") class SchemaAnnotation(AnnotationModel): diff --git a/nomad/datamodel/metainfo/basesections/v1.py b/nomad/datamodel/metainfo/basesections/v1.py index de77df19945ad1ab3aefbc8a83f56154762817a9..5604e971d4830f0ca9911bf7a3d27c01c8e3a6ec 100644 --- a/nomad/datamodel/metainfo/basesections/v1.py +++ b/nomad/datamodel/metainfo/basesections/v1.py @@ -20,7 +20,8 @@ import os import random import re import time -from typing import TYPE_CHECKING, Dict, Iterable, List +from typing import TYPE_CHECKING, Dict, List +from collections.abc import Iterable import h5py import numpy as np @@ -239,7 +240,7 @@ class BaseSection(ArchiveSection): normalized. logger ('BoundLogger'): A structlog logger. """ - super(BaseSection, self).normalize(archive, logger) + super().normalize(archive, logger) if isinstance(self, EntryData): if archive.data == self and self.name: @@ -391,7 +392,7 @@ class Activity(BaseSection): normalized. logger ('BoundLogger'): A structlog logger. """ - super(Activity, self).normalize(archive, logger) + super().normalize(archive, logger) if archive.results.eln.methods is None: archive.results.eln.methods = [] @@ -457,7 +458,7 @@ class EntityReference(SectionReference): normalized. logger ('BoundLogger'): A structlog logger. """ - super(EntityReference, self).normalize(archive, logger) + super().normalize(archive, logger) if self.reference is None and self.lab_id is not None: from nomad.search import MetadataPagination, search @@ -522,7 +523,7 @@ class ExperimentStep(ActivityStep): normalized. logger ('BoundLogger'): A structlog logger. """ - super(ExperimentStep, self).normalize(archive, logger) + super().normalize(archive, logger) if self.activity is None and self.lab_id is not None: from nomad.search import MetadataPagination, search @@ -629,7 +630,7 @@ class ElementalComposition(ArchiveSection): normalized. logger ('BoundLogger'): A structlog logger. """ - super(ElementalComposition, self).normalize(archive, logger) + super().normalize(archive, logger) if self.element: if not archive.results: @@ -747,7 +748,7 @@ class System(Entity): normalized. logger ('BoundLogger'): A structlog logger. """ - super(System, self).normalize(archive, logger) + super().normalize(archive, logger) if len(self.elemental_composition) > 0: self._fill_fractions(archive, logger) @@ -767,7 +768,7 @@ class Instrument(Entity): normalized. logger ('BoundLogger'): A structlog logger. """ - super(Instrument, self).normalize(archive, logger) + super().normalize(archive, logger) if self.name: if archive.results.eln.instruments is None: @@ -835,7 +836,7 @@ class SystemComponent(Component): normalized. logger ('BoundLogger'): A structlog logger. 
""" - super(SystemComponent, self).normalize(archive, logger) + super().normalize(archive, logger) if self.name is None and self.system is not None: self.name = self.system.name @@ -958,14 +959,14 @@ class PureSubstanceComponent(Component): normalized. logger ('BoundLogger'): A structlog logger. """ - super(PureSubstanceComponent, self).normalize(archive, logger) + super().normalize(archive, logger) if self.substance_name and self.pure_substance is None: self.pure_substance = PureSubstanceSection(name=self.substance_name) if self.name is None and self.pure_substance is not None: self.name = self.pure_substance.molecular_formula -def elemental_composition_from_formula(formula: Formula) -> List[ElementalComposition]: +def elemental_composition_from_formula(formula: Formula) -> list[ElementalComposition]: """ Help function for generating list of `ElementalComposition` instances from `nomad.atomutils.Formula` item @@ -1004,8 +1005,8 @@ class CompositeSystem(System): @staticmethod def _atomic_to_mass( - composition: List[ElementalComposition], mass: float - ) -> Dict[str, float]: + composition: list[ElementalComposition], mass: float + ) -> dict[str, float]: """ Private static method for converting list of ElementalComposition objects to dictionary of element masses with the element symbol as key and mass as value. @@ -1033,7 +1034,7 @@ class CompositeSystem(System): atom_masses.append(atomic_masses[atomic_numbers[comp.element]]) masses = np.array(atomic_fractions) * np.array(atom_masses) mass_array = mass * masses / masses.sum() - mass_dict: Dict[str, float] = {} + mass_dict: dict[str, float] = {} for c, m in zip(composition, mass_array): if c.element in mass_dict: mass_dict[c.element] += m @@ -1042,7 +1043,7 @@ class CompositeSystem(System): return mass_dict @staticmethod - def _mass_to_atomic(mass_dict: Dict[str, float]) -> List[ElementalComposition]: + def _mass_to_atomic(mass_dict: dict[str, float]) -> list[ElementalComposition]: """ Private static method for converting dictionary of elements with their masses to a list of ElementalComposition objects containing atomic fractions. @@ -1096,7 +1097,7 @@ class CompositeSystem(System): mass_fractions.pop(empty_index) self.components[empty_index].mass_fraction = 1 - sum(mass_fractions) if not self.elemental_composition: - mass_dict: Dict[str, float] = {} + mass_dict: dict[str, float] = {} if any(component.mass is None for component in self.components): if all(component.mass is None for component in self.components): masses = [component.mass_fraction for component in self.components] @@ -1148,7 +1149,7 @@ class CompositeSystem(System): for comp in self.elemental_composition: comp.normalize(archive, logger) - super(CompositeSystem, self).normalize(archive, logger) + super().normalize(archive, logger) class CompositeSystemReference(EntityReference): @@ -1235,7 +1236,7 @@ class Process(Activity): normalized. logger ('BoundLogger'): A structlog logger. """ - super(Process, self).normalize(archive, logger) + super().normalize(archive, logger) if ( self.datetime is not None and all(step.duration is not None for step in self.steps) @@ -1303,7 +1304,7 @@ class Analysis(Activity): normalized. logger ('BoundLogger'): A structlog logger. """ - super(Analysis, self).normalize(archive, logger) + super().normalize(archive, logger) archive.workflow2.inputs = [ Link(name=input.name, section=input.reference) for input in self.inputs ] @@ -1370,7 +1371,7 @@ class Measurement(Activity): normalized. logger ('BoundLogger'): A structlog logger. 
""" - super(Measurement, self).normalize(archive, logger) + super().normalize(archive, logger) archive.workflow2.inputs = [ Link(name=sample.name, section=sample.reference) for sample in self.samples ] @@ -1427,7 +1428,7 @@ class PureSubstance(System): archive (EntryArchive): The archive that is being normalized. logger ('BoundLogger'): A structlog logger. """ - super(PureSubstance, self).normalize(archive, logger) + super().normalize(archive, logger) if logger is None: logger = utils.get_logger(__name__) if self.pure_substance and self.pure_substance.molecular_formula: @@ -1650,7 +1651,7 @@ class PubChemPureSubstanceSection(PureSubstanceSection): else: self._find_cid(logger) - super(PubChemPureSubstanceSection, self).normalize(archive, logger) + super().normalize(archive, logger) class CASExperimentalProperty(ArchiveSection): @@ -1907,7 +1908,7 @@ class CASPureSubstanceSection(PureSubstanceSection): else: self._find_cas(archive, logger) - super(CASPureSubstanceSection, self).normalize(archive, logger) + super().normalize(archive, logger) class ReadableIdentifiers(ArchiveSection): @@ -1985,7 +1986,7 @@ class ReadableIdentifiers(ArchiveSection): normalized. logger ('BoundLogger'): A structlog logger. """ - super(ReadableIdentifiers, self).normalize(archive, logger) + super().normalize(archive, logger) if self.owner is None or self.institute is None: author = archive.metadata.main_author @@ -2112,7 +2113,7 @@ class PublicationReference(ArchiveSection): normalized. logger ('BoundLogger'): A structlog logger. """ - super(PublicationReference, self).normalize(archive, logger) + super().normalize(archive, logger) import dateutil.parser import requests from nomad.datamodel.datamodel import EntryMetadata @@ -2157,7 +2158,7 @@ class PublicationReference(ArchiveSection): class HDF5Normalizer(ArchiveSection): def normalize(self, archive, logger): - super(HDF5Normalizer, self).normalize(archive, logger) + super().normalize(archive, logger) h5_re = re.compile(r'.*\.h5$') for quantity_name, quantity_def in self.m_def.all_quantities.items(): diff --git a/nomad/datamodel/metainfo/basesections/v2.py b/nomad/datamodel/metainfo/basesections/v2.py index a3464eac40340c646e6e296a5840d5ba9cd18195..b2b442bbb7fcc6fe4d48dd55e61ad9a30eaf7222 100644 --- a/nomad/datamodel/metainfo/basesections/v2.py +++ b/nomad/datamodel/metainfo/basesections/v2.py @@ -16,7 +16,8 @@ # limitations under the License. # import os -from typing import TYPE_CHECKING, Iterable +from typing import TYPE_CHECKING +from collections.abc import Iterable import random import time import datetime @@ -349,7 +350,7 @@ class Activity(BaseSection): normalized. logger ('BoundLogger'): A structlog logger. """ - super(Activity, self).normalize(archive, logger) + super().normalize(archive, logger) if archive.results.eln.methods is None: archive.results.eln.methods = [] @@ -415,7 +416,7 @@ class EntityReference(SectionReference): normalized. logger ('BoundLogger'): A structlog logger. """ - super(EntityReference, self).normalize(archive, logger) + super().normalize(archive, logger) if self.reference is None and self.lab_id is not None: from nomad.search import search, MetadataPagination @@ -480,7 +481,7 @@ class ExperimentStep(ActivityStep): normalized. logger ('BoundLogger'): A structlog logger. 
""" - super(ExperimentStep, self).normalize(archive, logger) + super().normalize(archive, logger) if self.activity is None and self.lab_id is not None: from nomad.search import search, MetadataPagination @@ -587,7 +588,7 @@ class ElementalComposition(ArchiveSection): normalized. logger ('BoundLogger'): A structlog logger. """ - super(ElementalComposition, self).normalize(archive, logger) + super().normalize(archive, logger) if self.element: if not archive.results: @@ -998,7 +999,7 @@ class PureSubstance(System): archive (EntryArchive): The archive that is being normalized. logger ('BoundLogger'): A structlog logger. """ - super(PureSubstance, self).normalize(archive, logger) + super().normalize(archive, logger) # if logger is None: # logger = utils.get_logger(__name__) # if self.molecular_formula: @@ -1168,7 +1169,7 @@ class Instrument(Entity): normalized. logger ('BoundLogger'): A structlog logger. """ - super(Instrument, self).normalize(archive, logger) + super().normalize(archive, logger) if self.name: if archive.results.eln.instruments is None: @@ -1260,7 +1261,7 @@ class Process(Activity): normalized. logger ('BoundLogger'): A structlog logger. """ - super(Process, self).normalize(archive, logger) + super().normalize(archive, logger) if ( self.datetime is not None and all(step.duration is not None for step in self.steps) @@ -1328,7 +1329,7 @@ class Analysis(Activity): normalized. logger ('BoundLogger'): A structlog logger. """ - super(Analysis, self).normalize(archive, logger) + super().normalize(archive, logger) archive.workflow2.inputs = [ Link(name=input.name, section=input.reference) for input in self.inputs ] @@ -1395,7 +1396,7 @@ class Measurement(Activity): normalized. logger ('BoundLogger'): A structlog logger. """ - super(Measurement, self).normalize(archive, logger) + super().normalize(archive, logger) archive.workflow2.inputs = [ Link(name=sample.name, section=sample.reference) for sample in self.samples ] @@ -1479,7 +1480,7 @@ class ReadableIdentifiers(ArchiveSection): normalized. logger ('BoundLogger'): A structlog logger. """ - super(ReadableIdentifiers, self).normalize(archive, logger) + super().normalize(archive, logger) if self.owner is None or self.institute is None: author = archive.metadata.main_author @@ -1606,7 +1607,7 @@ class PublicationReference(ArchiveSection): normalized. logger ('BoundLogger'): A structlog logger. 
""" - super(PublicationReference, self).normalize(archive, logger) + super().normalize(archive, logger) from nomad.datamodel.datamodel import EntryMetadata import dateutil.parser import requests @@ -1651,7 +1652,7 @@ class PublicationReference(ArchiveSection): class HDF5Normalizer(ArchiveSection): def normalize(self, archive, logger): - super(HDF5Normalizer, self).normalize(archive, logger) + super().normalize(archive, logger) h5_re = re.compile(r'.*\.h5$') for quantity_name, quantity_def in self.m_def.all_quantities.items(): diff --git a/nomad/datamodel/metainfo/downloads.py b/nomad/datamodel/metainfo/downloads.py index 067a73234d7020100202ad27e042a3bd9b1caae0..9ff64389d79966909ddf158a38a2cc249d61f781 100644 --- a/nomad/datamodel/metainfo/downloads.py +++ b/nomad/datamodel/metainfo/downloads.py @@ -95,7 +95,7 @@ class Downloads(ArchiveSection): ) def normalize(self, archive, logger): - super(Downloads, self).normalize(archive, logger) + super().normalize(archive, logger) from nomad.datamodel import EntryArchive, ServerContext diff --git a/nomad/datamodel/metainfo/eln/__init__.py b/nomad/datamodel/metainfo/eln/__init__.py index 69417aceee076e5723cede1194ef32b1f1b22deb..319780330ac46b02638b985932d49f33e2b26f12 100644 --- a/nomad/datamodel/metainfo/eln/__init__.py +++ b/nomad/datamodel/metainfo/eln/__init__.py @@ -167,7 +167,7 @@ class ElnBaseSection(ArchiveSection): ) def normalize(self, archive, logger): - super(ElnBaseSection, self).normalize(archive, logger) + super().normalize(archive, logger) if isinstance(self, EntryData): if archive.data == self and self.name: @@ -562,7 +562,7 @@ class SampleID(ArchiveSection): normalized. logger ('BoundLogger'): A structlog logger. """ - super(SampleID, self).normalize(archive, logger) + super().normalize(archive, logger) if self.sample_owner is None or self.institute is None: author = archive.metadata.main_author @@ -917,7 +917,7 @@ class Substance(System): archive (EntryArchive): The archive that is being normalized. logger (Any): A structlog logger. 
""" - super(Substance, self).normalize(archive, logger) + super().normalize(archive, logger) if logger is None: logger = utils.get_logger(__name__) logger.warn( @@ -956,7 +956,7 @@ class Substance(System): except Exception as e: logger.warn('Could not analyse chemical formula.', exc_info=e) - super(Substance, self).normalize(archive, logger) + super().normalize(archive, logger) class ElnWithFormulaBaseSection(ElnBaseSection): @@ -976,7 +976,7 @@ class ElnWithFormulaBaseSection(ElnBaseSection): ) def normalize(self, archive, logger): - super(ElnWithFormulaBaseSection, self).normalize(archive, logger) + super().normalize(archive, logger) if logger is None: logger = utils.get_logger(__name__) @@ -1026,7 +1026,7 @@ class ElnWithStructureFile(ArchiveSection): ) def normalize(self, archive, logger): - super(ElnWithStructureFile, self).normalize(archive, logger) + super().normalize(archive, logger) if self.structure_file: from ase.io import read @@ -1145,7 +1145,7 @@ class SolarCellDefinition(ArchiveSection): ) def normalize(self, archive, logger): - super(SolarCellDefinition, self).normalize(archive, logger) + super().normalize(archive, logger) add_solar_cell(archive) if self.stack_sequence: if '/' in self.stack_sequence: @@ -1219,7 +1219,7 @@ class SolarCellLayer(ArchiveSection): ) def normalize(self, archive, logger): - super(SolarCellLayer, self).normalize(archive, logger) + super().normalize(archive, logger) add_solar_cell(archive) if self.layer_name: if self.solar_cell_layer_type == 'Absorber': @@ -1256,9 +1256,7 @@ class SolarCellBaseSectionWithOptoelectronicProperties(ArchiveSection): ) def normalize(self, archive, logger): - super(SolarCellBaseSectionWithOptoelectronicProperties, self).normalize( - archive, logger - ) + super().normalize(archive, logger) add_solar_cell(archive) add_band_gap(archive, self.bandgap) @@ -1446,7 +1444,7 @@ class SolarCellJV(PlotSection): archive.results.properties.optoelectronic.solar_cell.illumination_intensity = self.light_intensity def normalize(self, archive, logger): - super(SolarCellJV, self).normalize(archive, logger) + super().normalize(archive, logger) add_solar_cell(archive) self.update_results(archive) @@ -1503,7 +1501,7 @@ class SolarCellJVCurve(SolarCellJV): ) def normalize(self, archive, logger): - super(SolarCellJVCurve, self).normalize(archive, logger) + super().normalize(archive, logger) if self.current_density is not None: if self.voltage is not None: ( @@ -1677,7 +1675,7 @@ class SolarCellEQE(PlotSection): ) def normalize(self, archive, logger): - super(SolarCellEQE, self).normalize(archive, logger) + super().normalize(archive, logger) if self.eqe_data_file: with archive.m_context.raw_file(self.eqe_data_file) as f: diff --git a/nomad/datamodel/metainfo/plot.py b/nomad/datamodel/metainfo/plot.py index 3d68afaec0aa52d2cb40d1fb37c24abb1470f089..3618263f8e78b5f19f434f6af17791de836de6a4 100644 --- a/nomad/datamodel/metainfo/plot.py +++ b/nomad/datamodel/metainfo/plot.py @@ -227,7 +227,7 @@ class PlotSection(ArchiveSection): ) def normalize(self, archive, logger): - super(PlotSection, self).normalize(archive, logger) + super().normalize(archive, logger) all_figures = [] plotly_express_annotations = deepcopy( diff --git a/nomad/datamodel/metainfo/simulation/method.py b/nomad/datamodel/metainfo/simulation/method.py index 7dc86ff9a64a45bffffd967d23323ed6883088f6..3cb23695ca1cfda7c8d102f097550f9e697ffe5f 100644 --- a/nomad/datamodel/metainfo/simulation/method.py +++ b/nomad/datamodel/metainfo/simulation/method.py @@ -582,7 +582,7 @@ class 
SingleElectronState(MSection): # inherit from AtomicOrbitalState? ) super().__setattr__(name, value) - def normalize(self, archive, logger: typing.Optional[Logger]): + def normalize(self, archive, logger: Logger | None): # self.set_degeneracy() pass @@ -735,7 +735,7 @@ class CoreHole(SingleElectronState): self.degeneracy = 1 super().__setattr__(name, value) - def normalize(self, archive, logger: typing.Optional[Logger]): + def normalize(self, archive, logger: Logger | None): super().normalize(archive, logger) self.set_occupation() diff --git a/nomad/datamodel/metainfo/simulation/run.py b/nomad/datamodel/metainfo/simulation/run.py index a7ec930952431411802f520412bde39013b2d70b..99dab5d713c0d0f2740b665e81c4d2fab70c14f6 100644 --- a/nomad/datamodel/metainfo/simulation/run.py +++ b/nomad/datamodel/metainfo/simulation/run.py @@ -281,7 +281,7 @@ class Run(ArchiveSection): calculation = SubSection(sub_section=Calculation.m_def, repeats=True) def normalize(self, archive, logger): - super(Run, self).normalize(archive, logger) + super().normalize(archive, logger) m_package.__init_metainfo__() diff --git a/nomad/datamodel/metainfo/simulation/workflow.py b/nomad/datamodel/metainfo/simulation/workflow.py index a53db33ecc43112b05ad8fe2fbcce57ee243b732..36374441b67d139e489cc44509c4381fa312271f 100644 --- a/nomad/datamodel/metainfo/simulation/workflow.py +++ b/nomad/datamodel/metainfo/simulation/workflow.py @@ -160,9 +160,9 @@ class SimulationWorkflow(Workflow): def normalize(self, archive, logger): super().normalize(archive, logger) - self._calculations: List[Calculation] = [] - self._systems: List[System] = [] - self._methods: List[Method] = [] + self._calculations: list[Calculation] = [] + self._systems: list[System] = [] + self._methods: list[Method] = [] try: self._calculations = archive.run[-1].calculation self._systems = archive.run[-1].system diff --git a/nomad/datamodel/metainfo/workflow.py b/nomad/datamodel/metainfo/workflow.py index 03a223820ebcfea17454f8a0b2f5f91b59c4e1a6..a77f0af398ff271e2fae073b16230995a311bfd1 100644 --- a/nomad/datamodel/metainfo/workflow.py +++ b/nomad/datamodel/metainfo/workflow.py @@ -45,7 +45,7 @@ class Link(ArchiveSection): ) def normalize(self, archive, logger): - super(Link, self).normalize(archive, logger) + super().normalize(archive, logger) if not self.name and self.section: self.name = getattr(self.section, 'name', None) @@ -97,7 +97,7 @@ class TaskReference(Task): ) def normalize(self, archive, logger): - super(TaskReference, self).normalize(archive, logger) + super().normalize(archive, logger) if not self.name and self.task: self.name = self.task.name @@ -123,7 +123,7 @@ class Workflow(Task, EntryData): ) def normalize(self, archive, logger): - super(Workflow, self).normalize(archive, logger) + super().normalize(archive, logger) from nomad.datamodel import EntryArchive diff --git a/nomad/datamodel/results.py b/nomad/datamodel/results.py index e5a4135e61b474dbdb75e0225c7020c9688fe8a3..75c1e962798ff5f3357a217bb18d405d4b995f3e 100644 --- a/nomad/datamodel/results.py +++ b/nomad/datamodel/results.py @@ -183,7 +183,7 @@ def get_formula_iupac(formula: str) -> str: return None if formula is None else Formula(formula).format('iupac') -def available_properties(root: MSection) -> List[str]: +def available_properties(root: MSection) -> list[str]: """Returns a list of property names that are available in results.properties. 
Args: @@ -220,7 +220,7 @@ def available_properties(root: MSection) -> List[str]: 'optoelectronic.solar_cell': 'solar_cell', 'electronic.density_charge': 'density_charge', } - available_properties: List[str] = [] + available_properties: list[str] = [] for path, shortcut in available_property_names.items(): for _ in traverse_reversed(root, path.split('.')): available_properties.append(shortcut) @@ -982,7 +982,7 @@ class CoreHole(CoreHoleRun): 'ms_quantum_symbol', ] - def normalize(self, archive, logger: Optional[Logger]): + def normalize(self, archive, logger: Logger | None): super().normalize(archive, logger) # TODO: replace this for a more dynamic mapping self.set_l_quantum_symbol() diff --git a/nomad/datamodel/util.py b/nomad/datamodel/util.py index 96a1e7a12ca9774867a425e005881a4da124e1ae..ce8b7580cbe13ae45c052fb82297489796edca69 100644 --- a/nomad/datamodel/util.py +++ b/nomad/datamodel/util.py @@ -17,7 +17,8 @@ # import math import re -from typing import Callable, Any +from typing import Any +from collections.abc import Callable import numpy as np diff --git a/nomad/doi.py b/nomad/doi.py index c5b111678e735e223aa6d6847387c9975c4648f9..994a840896793da98985af5dd003af1aea74963a 100644 --- a/nomad/doi.py +++ b/nomad/doi.py @@ -54,7 +54,7 @@ def edit_doi_url(doi: str, url: str = None): if url is None: url = _create_dataset_url(doi) - doi_url = '%s/doi/%s' % (config.datacite.mds_host, doi) + doi_url = f'{config.datacite.mds_host}/doi/{doi}' headers = {'Content-Type': 'text/plain;charset=UTF-8'} data = f'doi={doi}\nurl={url}' response = requests.put(doi_url, headers=headers, data=data, **_requests_args()) @@ -140,8 +140,8 @@ class DOI(Document): except NotUniqueError: counter += 1 - doi.metadata_url = '%s/metadata/%s' % (config.datacite.mds_host, doi_str) - doi.doi_url = '%s/doi/%s' % (config.datacite.mds_host, doi_str) + doi.metadata_url = f'{config.datacite.mds_host}/metadata/{doi_str}' + doi.doi_url = f'{config.datacite.mds_host}/doi/{doi_str}' doi.state = 'created' doi.create_time = create_time doi.url = _create_dataset_url(doi_str) @@ -233,7 +233,7 @@ class DOI(Document): def make_findable(self): if config.datacite.enabled: assert self.state == 'draft', 'can only make drafts findable' - body = ('doi=%s\nurl=%s' % (self.doi, self.url)).encode('utf-8') + body = (f'doi={self.doi}\nurl={self.url}').encode() response = None try: diff --git a/nomad/files.py b/nomad/files.py index 1ef3f62257bca8d52692c009c433591ffec3d9c2..f9c98ecb86b644bf68119ba866f8b07806222848 100644 --- a/nomad/files.py +++ b/nomad/files.py @@ -50,14 +50,13 @@ from typing import ( IO, Set, Dict, - Iterable, - Iterator, List, Tuple, Any, NamedTuple, - Callable, ) +from collections.abc import Callable +from collections.abc import Iterable, Iterator from pydantic import BaseModel from datetime import datetime import os.path @@ -172,7 +171,7 @@ class DirectoryObject(PathObject): if create and not os.path.isdir(self.os_path): os.makedirs(self.os_path) - def join_dir(self, path, create: bool = False) -> 'DirectoryObject': + def join_dir(self, path, create: bool = False) -> DirectoryObject: return DirectoryObject(os.path.join(self.os_path, path), create) def join_file(self, path, create_dir: bool = False) -> PathObject: @@ -284,14 +283,14 @@ class BrowsableFileSource(FileSource, metaclass=ABCMeta): """Opens a file by the specified path.""" raise NotImplementedError() - def directory_list(self, path: str) -> List[str]: + def directory_list(self, path: str) -> list[str]: """ Returns a list of directory contents, located in the 
directory denoted by `path` in this file source. """ raise NotImplementedError() - def sub_source(self, path: str) -> 'BrowsableFileSource': + def sub_source(self, path: str) -> BrowsableFileSource: """ Creates a new instance of :class:`BrowsableFileSource` which just contains the files located under the specified path. @@ -378,12 +377,12 @@ class DiskFileSource(BrowsableFileSource): assert is_safe_relative_path(path) return open(os.path.join(self.base_path, path), mode) - def directory_list(self, path: str) -> List[str]: + def directory_list(self, path: str) -> list[str]: assert is_safe_relative_path(path) sub_path = os.path.join(self.base_path, path) return os.listdir(sub_path) - def sub_source(self, path: str) -> 'DiskFileSource': + def sub_source(self, path: str) -> DiskFileSource: assert is_safe_relative_path(path) return DiskFileSource(self.base_path, path) @@ -398,7 +397,7 @@ class ZipFileSource(BrowsableFileSource): assert is_safe_relative_path(sub_path) self.zip_file = zip_file self.sub_path = sub_path - self._namelist: List[str] = zip_file.namelist() + self._namelist: list[str] = zip_file.namelist() def to_streamed_files(self) -> Iterable[StreamedFile]: path_prefix = '' if not self.sub_path else self.sub_path + os.path.sep @@ -421,7 +420,7 @@ class ZipFileSource(BrowsableFileSource): return io.TextIOWrapper(f) return f - def directory_list(self, path: str) -> List[str]: + def directory_list(self, path: str) -> list[str]: path_prefix = '' if not path else path + os.path.sep found = set() for path2 in self._namelist: @@ -429,7 +428,7 @@ class ZipFileSource(BrowsableFileSource): found.add(path2.split(os.path.sep)[0]) return sorted(found) - def sub_source(self, path: str) -> 'ZipFileSource': + def sub_source(self, path: str) -> ZipFileSource: assert is_safe_relative_path(path), 'Unsafe path provided' if self.sub_path: assert path.startswith(self.sub_path + os.path.sep), ( @@ -452,8 +451,7 @@ class CombinedFileSource(FileSource): def to_streamed_files(self) -> Iterable[StreamedFile]: for file_source in self.file_sources: - for streamed_file in file_source.to_streamed_files(): - yield streamed_file + yield from file_source.to_streamed_files() def to_disk( self, destination_dir: str, move_files: bool = False, overwrite: bool = False @@ -484,13 +482,13 @@ class StandardJSONDecoder(json.JSONDecoder): return d -def json_to_streamed_file(json_dict: Dict[str, Any], path: str) -> StreamedFile: +def json_to_streamed_file(json_dict: dict[str, Any], path: str) -> StreamedFile: """Converts a json dictionary structure to a :class:`StreamedFile`.""" json_bytes = json.dumps(json_dict, indent=2, cls=StandardJSONEncoder).encode() return StreamedFile(path=path, f=io.BytesIO(json_bytes), size=len(json_bytes)) -def create_zipstream_content(streamed_files: Iterable[StreamedFile]) -> Iterable[Dict]: +def create_zipstream_content(streamed_files: Iterable[StreamedFile]) -> Iterable[dict]: """ Generator which "casts" a sequence of StreamedFiles to a sequence of dictionaries, of the form which is required by the `zipstream` library, i.e. 
dictionaries with keys @@ -621,7 +619,7 @@ class UploadFiles(DirectoryObject, metaclass=ABCMeta): def to_staging_upload_files( self, create: bool = False, include_archive: bool = False - ) -> 'StagingUploadFiles': + ) -> StagingUploadFiles: """Casts to or creates corresponding staging upload files or returns None.""" raise NotImplementedError() @@ -787,7 +785,7 @@ class StagingUploadFiles(UploadFiles): def to_staging_upload_files( self, create: bool = False, include_archive: bool = False - ) -> 'StagingUploadFiles': + ) -> StagingUploadFiles: return self @property @@ -929,7 +927,7 @@ class StagingUploadFiles(UploadFiles): path: str, target_dir: str = '', cleanup_source_file_and_dir: bool = False, - updated_files: Set[str] = None, + updated_files: set[str] = None, ) -> None: """ Adds the file or folder specified by `path` to this upload, in the raw directory @@ -974,7 +972,7 @@ class StagingUploadFiles(UploadFiles): extract_file(path, tmp_dir, compression_format, remove_archive=False) # Determine what to merge - elements_to_merge: Iterable[Tuple[str, List[str], List[str]]] = [] + elements_to_merge: Iterable[tuple[str, list[str], list[str]]] = [] if is_dir: # Directory source_dir = path @@ -1054,7 +1052,7 @@ class StagingUploadFiles(UploadFiles): if os.path.exists(parent_dir) and not os.listdir(parent_dir): shutil.rmtree(parent_dir) - def delete_rawfiles(self, path, updated_files: Set[str] = None): + def delete_rawfiles(self, path, updated_files: set[str] = None): assert is_safe_relative_path(path) raw_os_path = os.path.join(self.os_path, 'raw') os_path = os.path.join(raw_os_path, path) @@ -1083,7 +1081,7 @@ class StagingUploadFiles(UploadFiles): path_to_existing_file, path_to_target_file, copy_or_move, - updated_files: Set[str] = None, + updated_files: set[str] = None, ): assert is_safe_relative_path(path_to_existing_file) assert is_safe_relative_path(path_to_target_file) @@ -1145,7 +1143,7 @@ class StagingUploadFiles(UploadFiles): def pack( self, - entries: List[datamodel.EntryMetadata], + entries: list[datamodel.EntryMetadata], with_embargo: bool, create: bool = True, include_raw: bool = True, @@ -1173,7 +1171,7 @@ class StagingUploadFiles(UploadFiles): # freeze the upload assert not self.is_frozen, 'Cannot pack an upload that is packed, or packing.' 
- with open(self._frozen_file.os_path, 'wt') as f: + with open(self._frozen_file.os_path, 'w') as f: f.write('frozen') # Check embargo flag consistency @@ -1211,7 +1209,7 @@ class StagingUploadFiles(UploadFiles): def _pack_archive_files( self, target_dir: DirectoryObject, - entries: List[datamodel.EntryMetadata], + entries: list[datamodel.EntryMetadata], access: str, other_access: str, ): @@ -1307,7 +1305,7 @@ class StagingUploadFiles(UploadFiles): entry_relative_dir = entry_dir[len(self._raw_dir.os_path) + 1 :] file_count = 0 - aux_files: List[str] = [] + aux_files: list[str] = [] dir_elements = os.listdir(entry_dir) dir_elements.sort() for dir_element in dir_elements: @@ -1378,7 +1376,7 @@ class StagingUploadFiles(UploadFiles): class PublicUploadFiles(UploadFiles): def __init__(self, upload_id: str, create: bool = False): super().__init__(upload_id, create) - self._directories: Dict[str, Dict[str, RawPathInfo]] = None + self._directories: dict[str, dict[str, RawPathInfo]] = None self._raw_zip_file_object: PathObject = None self._raw_zip_file: zipfile.ZipFile = None self._archive_msg_file_object: PathObject = None @@ -1536,7 +1534,7 @@ class PublicUploadFiles(UploadFiles): def to_staging_upload_files( self, create: bool = False, include_archive: bool = False - ) -> 'StagingUploadFiles': + ) -> StagingUploadFiles: exists = StagingUploadFiles.exists_for(self.upload_id) if exists: if create: @@ -1581,7 +1579,7 @@ class PublicUploadFiles(UploadFiles): if self._directories is None: self._directories = dict() self._directories[''] = {} # Root folder - directory_sizes: Dict[str, int] = {} + directory_sizes: dict[str, int] = {} # Add file RawPathInfo objects and calculate directory sizes try: zf = self._open_raw_zip_file() diff --git a/nomad/graph/graph_reader.py b/nomad/graph/graph_reader.py index bf4cfc52692797197eda342d2e2a806325d7fb49..9370c4bf2dd9df4416c508c57aadb59e73baf8a9 100644 --- a/nomad/graph/graph_reader.py +++ b/nomad/graph/graph_reader.py @@ -27,7 +27,8 @@ import re from collections.abc import AsyncIterator, Iterator from contextlib import contextmanager from threading import Lock -from typing import Any, Callable, Type, Union +from typing import Any, Type, Union +from collections.abc import Callable import orjson from cachetools import TTLCache @@ -597,7 +598,7 @@ def _normalise_required( config: RequestConfig, *, key: str = None, - reader_type: Type[GeneralReader] = None, + reader_type: type[GeneralReader] = None, ): """ Normalise the required dictionary. 
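The long run of normalize() and __init__() hunks further above (basesections v1/v2, eln, plot, workflow, downloads) is a single mechanical change: the Python 2 style super(Class, self).method(...) call becomes the zero-argument super(). A minimal sketch of why the two forms are equivalent inside a class body; the class names here are stand-ins, not taken from the patch.

    class Base:
        def normalize(self, archive, logger):
            return 'base'


    class Derived(Base):
        def normalize(self, archive, logger):
            # super() with no arguments uses the implicit __class__ cell created
            # by the compiler, so it behaves like super(Derived, self) but stays
            # correct if the class is renamed or the method is copied elsewhere.
            return super().normalize(archive, logger)


    assert Derived().normalize(None, None) == 'base'

The zero-argument form also avoids the copy-paste bug where the explicit class name ends up pointing at the wrong class.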
@@ -715,7 +716,7 @@ def _normalise_required( def _parse_required( - required_query: dict | str, reader_type: Type[GeneralReader] + required_query: dict | str, reader_type: type[GeneralReader] ) -> tuple[dict | RequestConfig, RequestConfig]: # extract global config if present # do not modify the original dict as the same dict may be used elsewhere @@ -1624,7 +1625,7 @@ class MongoReader(GeneralReader): continue async def offload_read( - reader_cls: Type[GeneralReader], *args, read_list=False + reader_cls: type[GeneralReader], *args, read_list=False ): try: with ( @@ -2164,9 +2165,7 @@ class UserReader(MongoReader): @functools.cached_property def datasets(self): return Dataset.m_def.a_mongo.objects( - dataset_id__in=set( - v for e in self.entries if e.datasets for v in e.datasets - ) + dataset_id__in={v for e in self.entries if e.datasets for v in e.datasets} ) # noinspection PyMethodOverriding diff --git a/nomad/graph/model.py b/nomad/graph/model.py index bc780263c9d104e56ec67a7dc44287d26e63fc2b..0358df9756165728654fbdacb0a01ef764718a28 100644 --- a/nomad/graph/model.py +++ b/nomad/graph/model.py @@ -137,7 +137,7 @@ class ResolveType(Enum): return self.value -def check_pattern(data: Optional[frozenset[str]]) -> Optional[frozenset[str]]: +def check_pattern(data: frozenset[str] | None) -> frozenset[str] | None: if data is not None: for value in data: assert re.match(r'^[*?+a-zA-z_\d./]*$', value) is not None @@ -155,7 +155,7 @@ class RequestConfig(BaseModel): Each field can be handled differently. """ - property_name: Optional[str] = Field( + property_name: str | None = Field( None, description=""" The name of the current field, either a quantity or a subsection. @@ -172,7 +172,7 @@ class RequestConfig(BaseModel): The `*` is a shortcut of `plain`. """, ) - include: Annotated[Optional[frozenset[str]], AfterValidator(check_pattern)] = Field( + include: Annotated[frozenset[str] | None, AfterValidator(check_pattern)] = Field( None, description=""" A list of patterns to match the quantities and subsections of the current section. @@ -181,7 +181,7 @@ class RequestConfig(BaseModel): """, ) - exclude: Annotated[Optional[frozenset[str]], AfterValidator(check_pattern)] = Field( + exclude: Annotated[frozenset[str] | None, AfterValidator(check_pattern)] = Field( None, description=""" A list of patterns to match the quantities and subsections of the current section. @@ -263,7 +263,7 @@ class RequestConfig(BaseModel): If `none`, no definition will be included. """, ) - index: Optional[Union[tuple[int], tuple[Optional[int], Optional[int]]]] = Field( + index: tuple[int] | tuple[int | None, int | None] | None = Field( None, description=""" The start and end index of the current field if it is a list. @@ -279,18 +279,16 @@ class RequestConfig(BaseModel): This field only applies to the target section only, i.e., it does not propagate to its children. """, ) - pagination: Optional[ - Union[ - dict, - DatasetPagination, - EntryProcDataPagination, - MetadataPagination, - MetainfoPagination, - RawDirPagination, - UploadProcDataPagination, - UserGroupPagination, - ] - ] = Field( + pagination: None | ( + dict + | DatasetPagination + | EntryProcDataPagination + | MetadataPagination + | MetainfoPagination + | RawDirPagination + | UploadProcDataPagination + | UserGroupPagination + ) = Field( None, description=""" The pagination configuration used for MongoDB search. @@ -299,17 +297,15 @@ class RequestConfig(BaseModel): Please refer to `DatasetPagination`, `UploadProcDataPagination`, `MetadataPagination` for details. 
""", ) - query: Optional[ - Union[ - dict, - DatasetQuery, - EntryQuery, - Metadata, - MetainfoQuery, - UploadProcDataQuery, - UserGroupQuery, - ] - ] = Field( + query: None | ( + dict + | DatasetQuery + | EntryQuery + | Metadata + | MetainfoQuery + | UploadProcDataQuery + | UserGroupQuery + ) = Field( None, description=""" The query configuration used for either mongo or elastic search. diff --git a/nomad/groups.py b/nomad/groups.py index 2805273b434b9d254737a8dd4231c40b2d98efb9..b8d89b612f25a148f425a48f39c7925763dcb867 100644 --- a/nomad/groups.py +++ b/nomad/groups.py @@ -18,7 +18,8 @@ from __future__ import annotations -from typing import Iterable, Optional, Union +from typing import Optional, Union +from collections.abc import Iterable from mongoengine import Document, ListField, Q, QuerySet, StringField @@ -41,7 +42,7 @@ class MongoUserGroup(Document): meta = {'indexes': ['group_name', 'owner', 'members']} @classmethod - def q_by_ids(cls, group_ids: Union[str, Iterable[str]]) -> Q: + def q_by_ids(cls, group_ids: str | Iterable[str]) -> Q: """ Returns UserGroup Q for group_ids. """ @@ -51,7 +52,7 @@ class MongoUserGroup(Document): return Q(group_id__in=group_ids) @classmethod - def q_by_user_id(cls, user_id: Optional[str]) -> Q: + def q_by_user_id(cls, user_id: str | None) -> Q: """ Returns UserGroup Q where user_id is owner or member, or None. @@ -88,7 +89,7 @@ class MongoUserGroup(Document): return cls.objects(q) @classmethod - def get_ids_by_user_id(cls, user_id: Optional[str], include_all=True) -> list[str]: + def get_ids_by_user_id(cls, user_id: str | None, include_all=True) -> list[str]: """ Returns ids of all user groups where user_id is owner or member. @@ -105,10 +106,10 @@ class MongoUserGroup(Document): def create_user_group( *, - group_id: Optional[str] = None, - group_name: Optional[str] = None, - owner: Optional[str] = None, - members: Optional[Iterable[str]] = None, + group_id: str | None = None, + group_name: str | None = None, + owner: str | None = None, + members: Iterable[str] | None = None, ) -> MongoUserGroup: user_group = MongoUserGroup( group_id=group_id, group_name=group_name, owner=owner, members=members @@ -134,7 +135,7 @@ def get_user_ids_by_group_ids(group_ids: list[str]) -> set[str]: return user_ids -def get_user_group(group_id: str) -> Optional[MongoUserGroup]: +def get_user_group(group_id: str) -> MongoUserGroup | None: q = MongoUserGroup.q_by_ids(group_id) return MongoUserGroup.objects(q).first() diff --git a/nomad/infrastructure.py b/nomad/infrastructure.py index 861407157991d3eaa273418072f20d2d49392923..ae53636d9657050d72d3396e91cb6dd49c942f58 100644 --- a/nomad/infrastructure.py +++ b/nomad/infrastructure.py @@ -336,7 +336,7 @@ class KeycloakUserManagement(UserManagement): def __create_username(self, user): if user.first_name is not None and user.last_name is not None: - user.username = '%s%s' % (user.first_name[:1], user.last_name) + user.username = f'{user.first_name[:1]}{user.last_name}' elif user.last_name is not None: user.username = user.last_name elif '@' in user.username: diff --git a/nomad/metainfo/elasticsearch_extension.py b/nomad/metainfo/elasticsearch_extension.py index 9efaccd506fff925a1fe515219f71c4c1038c452..f3a353820051400745b2749c4f9d795391cec211 100644 --- a/nomad/metainfo/elasticsearch_extension.py +++ b/nomad/metainfo/elasticsearch_extension.py @@ -162,7 +162,6 @@ from collections import defaultdict from typing import ( TYPE_CHECKING, Any, - Callable, DefaultDict, Dict, List, @@ -172,6 +171,7 @@ from typing import ( Union, cast, 
) +from collections.abc import Callable from elasticsearch_dsl import Q from nomad import utils @@ -226,13 +226,13 @@ class DocumentType: self.name = name self.id_field = id_field self.root_section_def = None - self.mapping: Dict[str, Any] = None - self.indexed_properties: Set[Definition] = set() - self.nested_object_keys: List[str] = [] - self.nested_sections: List[SearchQuantity] = [] - self.quantities: Dict[str, SearchQuantity] = {} - self.suggestions: Dict[str, Elasticsearch] = {} - self.metrics: Dict[str, Tuple[str, SearchQuantity]] = {} + self.mapping: dict[str, Any] = None + self.indexed_properties: set[Definition] = set() + self.nested_object_keys: list[str] = [] + self.nested_sections: list[SearchQuantity] = [] + self.quantities: dict[str, SearchQuantity] = {} + self.suggestions: dict[str, Elasticsearch] = {} + self.metrics: dict[str, tuple[str, SearchQuantity]] = {} def _reset(self): self.indexed_properties.clear() @@ -302,7 +302,7 @@ class DocumentType: return False - kwargs: Dict[str, Any] = dict( + kwargs: dict[str, Any] = dict( with_meta=False, include_defaults=True, include_derived=True, @@ -380,7 +380,7 @@ class DocumentType: auto_include_subsections: bool = False, repeats: bool = False, ): - mappings: Dict[str, Any] = {} + mappings: dict[str, Any] = {} if self == material_type and prefix is None: mappings['n_entries'] = {'type': 'integer'} @@ -541,14 +541,13 @@ class DocumentType: name = sub_section_def.name repeats = sub_section_def.repeats full_name = f'{prefix}.{name}' if prefix else name - for item in get_all_quantities( + yield from get_all_quantities( sub_section_def.sub_section, full_name, new_branch, repeats, max_level - 1, - ): - yield item + ) quantities_dynamic = {} for package in packages_from_plugins.values(): @@ -850,19 +849,19 @@ class Elasticsearch(DefinitionAnnotation): def __init__( self, doc_type: DocumentType = entry_type, - mapping: Union[str, Dict[str, Any]] = None, + mapping: str | dict[str, Any] = None, field: str = None, es_field: str = None, value: Callable[[MSectionBound], Any] = None, index: bool = True, - values: List[str] = None, + values: list[str] = None, default_aggregation_size: int = None, - metrics: Dict[str, str] = None, + metrics: dict[str, str] = None, many_all: bool = False, auto_include_subsections: bool = False, nested: bool = False, - suggestion: Union[str, Callable[[MSectionBound], Any]] = None, - variants: Optional[Callable[[str], List[str]]] = None, + suggestion: str | Callable[[MSectionBound], Any] = None, + variants: Callable[[str], list[str]] | None = None, normalizer: Callable[[Any], Any] = None, es_query: str = 'match', _es_field: str = None, @@ -910,7 +909,7 @@ class Elasticsearch(DefinitionAnnotation): self.doc_type = doc_type self.value = value self.index = index - self._mapping: Dict[str, Any] = None + self._mapping: dict[str, Any] = None self.default_aggregation_size = default_aggregation_size self.values = values self.metrics = metrics @@ -933,11 +932,11 @@ class Elasticsearch(DefinitionAnnotation): self.default_aggregation_size = len(self._values) @property - def mapping(self) -> Dict[str, Any]: + def mapping(self) -> dict[str, Any]: if self._mapping is not None: return self._mapping - def compute_mapping(quantity: Quantity) -> Dict[str, Any]: + def compute_mapping(quantity: Quantity) -> dict[str, Any]: """Used to generate an ES mapping based on the quantity definition if no custom mapping is provided. 
""" @@ -991,7 +990,7 @@ class Elasticsearch(DefinitionAnnotation): return self._mapping @property - def fields(self) -> Dict[str, Any]: + def fields(self) -> dict[str, Any]: if self._es_field == '' or self._es_field is None: return {} @@ -1216,12 +1215,12 @@ def index_entry(entry: MSection, **kwargs): index_entries([entry], **kwargs) -def index_entries_with_materials(entries: List, refresh: bool = False): +def index_entries_with_materials(entries: list, refresh: bool = False): index_entries(entries, refresh=refresh) update_materials(entries, refresh=refresh) -def index_entries(entries: List, refresh: bool = False) -> Dict[str, str]: +def index_entries(entries: list, refresh: bool = False) -> dict[str, str]: """ Upserts the given entries in the entry index. Optionally updates the materials index as well. Returns a dictionary of the format {entry_id: error_message} for all entries @@ -1258,7 +1257,7 @@ def index_entries(entries: List, refresh: bool = False) -> Dict[str, str]: exc_info=e, ) - timer_kwargs: Dict[str, Any] = {} + timer_kwargs: dict[str, Any] = {} try: import json @@ -1287,7 +1286,7 @@ def index_entries(entries: List, refresh: bool = False) -> Dict[str, str]: return rv -def update_materials(entries: List, refresh: bool = False): +def update_materials(entries: list, refresh: bool = False): # split into reasonably sized problems if len(entries) > config.elastic.bulk_size: for entries_part in [ @@ -1381,7 +1380,7 @@ def update_materials(entries: List, refresh: bool = False): # have the ammount of entries in all these materials roughly match the desired bulk size. # Using materials as a measure might not be good enough, if a single material has # lots of nested entries. - _actions_and_docs_bulks: List[List[Any]] = [] + _actions_and_docs_bulks: list[list[Any]] = [] _n_entries_in_bulk = [0] def add_action_or_doc(action_or_doc): @@ -1509,7 +1508,7 @@ def update_materials(entries: List, refresh: bool = False): all_n_entries += material_doc['n_entries'] # Execute the created actions in bulk. - timer_kwargs: Dict[str, Any] = {} + timer_kwargs: dict[str, Any] = {} try: import json @@ -1570,7 +1569,7 @@ def get_searchable_quantity_value_field( def create_dynamic_quantity_annotation( quantity_def: Quantity, doc_type: DocumentType = None -) -> Optional[Elasticsearch]: +) -> Elasticsearch | None: """Given a quantity definition, this function will return the corresponding ES annotation if one can be built. 
""" @@ -1672,7 +1671,7 @@ def create_searchable_quantity( return searchable_quantity -def parse_quantity_name(name: str) -> Tuple[str, Optional[str], Optional[str]]: +def parse_quantity_name(name: str) -> tuple[str, str | None, str | None]: """Used to parse a quantity name into three parts: - path: Path in the schema - schema (optional): Schema identifider diff --git a/nomad/metainfo/metainfo.py b/nomad/metainfo/metainfo.py index e957cd65af8c153b75886eee20478228f3f59ec4..3c155c8d6af0ec6e6a2ed4c060468a72dbf68031 100644 --- a/nomad/metainfo/metainfo.py +++ b/nomad/metainfo/metainfo.py @@ -29,11 +29,11 @@ from copy import deepcopy from functools import wraps from typing import ( Any, - Callable as TypingCallable, TypeVar, cast, Literal, ) +from collections.abc import Callable as TypingCallable from urllib.parse import urlsplit, urlunsplit import docstring_parser diff --git a/nomad/metainfo/mongoengine_extension.py b/nomad/metainfo/mongoengine_extension.py index d34e158aaef013fde73d70f3d41deb09ac7f3a66..3d98ea4a8b63bb61f36cfbdf3832982c4bb74d51 100644 --- a/nomad/metainfo/mongoengine_extension.py +++ b/nomad/metainfo/mongoengine_extension.py @@ -108,8 +108,8 @@ class MongoDocument(SectionAnnotation): raise NotImplementedError def create_model_recursive(section, level): - indexes: List[str] = [] - dct: Dict[str, Any] = {} + indexes: list[str] = [] + dct: dict[str, Any] = {} # Add quantities to model for quantity in section.all_quantities.values(): diff --git a/nomad/metainfo/pydantic_extension.py b/nomad/metainfo/pydantic_extension.py index 0013b809d2f550f5f224f96deb9c1972f8ea30a8..8aee87b63121022518bb76e467e030017e8f8dee 100644 --- a/nomad/metainfo/pydantic_extension.py +++ b/nomad/metainfo/pydantic_extension.py @@ -55,7 +55,7 @@ class PydanticModel(DefinitionAnnotation): """ def __init__(self): - self.model: Type[BaseModel] = None + self.model: type[BaseModel] = None def to_pydantic(self, section): """Returns the pydantic model instance for the given section.""" diff --git a/nomad/metainfo/util.py b/nomad/metainfo/util.py index 435fb123ac5c8c597350b0e929c3d38505744db4..f542325e24cc13fc0f1a51a9fda358c7936cf501 100644 --- a/nomad/metainfo/util.py +++ b/nomad/metainfo/util.py @@ -312,7 +312,7 @@ def get_namefit(name: str, concept_name: str, name_any: bool = False) -> int: return len(concept_name) + match_count - uppercase_count -def resolve_variadic_name(definitions: dict, name: str, hint: Optional[str] = None): +def resolve_variadic_name(definitions: dict, name: str, hint: str | None = None): """ Resolves a property name with variadic patterns to its corresponding definition in the schema. diff --git a/nomad/mkdocs.py b/nomad/mkdocs.py index bc3d8eaefb642f339a8f67a86954930a4230c0fb..25efa82810ed9dcc27304ed57fde48178ae40dbb 100644 --- a/nomad/mkdocs.py +++ b/nomad/mkdocs.py @@ -26,7 +26,8 @@ from enum import Enum from pydantic import BaseModel import os.path from typing import List, Set, Tuple, Any, Optional, Dict -from typing_extensions import Literal, _AnnotatedAlias # type: ignore +from typing_extensions import _AnnotatedAlias # type: ignore +from typing import Literal from inspect import isclass from markdown.extensions.toc import slugify @@ -48,7 +49,7 @@ doc_snippets = { } -def get_field_type_info(field) -> Tuple[str, Set[Any]]: +def get_field_type_info(field) -> tuple[str, set[Any]]: """Used to recursively walk through a type definition, building up a cleaned up type name and returning all of the classes that were used. 
@@ -62,7 +63,7 @@ def get_field_type_info(field) -> Tuple[str, Set[Any]]: # Notice that pydantic does not store the full type in field.type_, but instead in # field.outer_type_ type_ = field.annotation - type_name: List[str] = [] + type_name: list[str] = [] models = set() def fetch_models(type_, type_name): @@ -129,7 +130,7 @@ def get_field_type_info(field) -> Tuple[str, Set[Any]]: return ''.join(type_name), models -def get_field_description(field) -> Optional[str]: +def get_field_description(field) -> str | None: """Retrieves the description for a pydantic field as a markdown string. Args: @@ -146,7 +147,7 @@ def get_field_description(field) -> Optional[str]: return value -def get_field_default(field) -> Optional[str]: +def get_field_default(field) -> str | None: """Retrieves the default value from a pydantic field as a markdown string. Args: @@ -166,7 +167,7 @@ def get_field_default(field) -> Optional[str]: return default_value -def get_field_options(field) -> Dict[str, Optional[str]]: +def get_field_options(field) -> dict[str, str | None]: """Retrieves a dictionary of value-description pairs from a pydantic field. Args: @@ -176,7 +177,7 @@ def get_field_options(field) -> Dict[str, Optional[str]]: Dictionary containing the possible options and their description for this field. The description may be None indicating that it does not exist. """ - options: Dict[str, Optional[str]] = {} + options: dict[str, str | None] = {} if isclass(field.annotation) and issubclass(field.annotation, Enum): for x in field.annotation: options[str(x.value)] = None @@ -202,7 +203,7 @@ class MyYamlDumper(yaml.Dumper): """ def represent_mapping(self, *args, **kwargs): - node = super(MyYamlDumper, self).represent_mapping(*args, **kwargs) + node = super().represent_mapping(*args, **kwargs) node.flow_style = False return node @@ -227,7 +228,7 @@ def define_env(env): @env.macro def file_contents(path): # pylint: disable=unused-variable - with open(path, 'r') as f: + with open(path) as f: return f.read() @env.macro @@ -249,7 +250,7 @@ def define_env(env): file_path, json_path = path.split(':') file_path = os.path.join(os.path.dirname(__file__), '..', file_path) - with open(file_path, 'rt') as f: + with open(file_path) as f: if file_path.endswith('.yaml'): data = yaml.load(f, Loader=yaml.SafeLoader) elif file_path.endswith('.json'): @@ -267,7 +268,7 @@ def define_env(env): data = data[segment] if filter is not None: - filter = set([item.strip() for item in filter.split(',')]) + filter = {item.strip() for item in filter.split(',')} to_remove = [] for key in data.keys(): if key in filter: @@ -418,13 +419,13 @@ def define_env(env): return metadata - categories: Dict[str, List[ParserEntryPoint]] = {} + categories: dict[str, list[ParserEntryPoint]] = {} for parser in parsers: category_name = getattr(parser, 'code_category', None) category = categories.setdefault(category_name, []) category.append(parser) - def render_category(name: str, category: List[ParserEntryPoint]) -> str: + def render_category(name: str, category: list[ParserEntryPoint]) -> str: return f'## {name}s\n\n' + '\n\n'.join( [render_parser(parser) for parser in category] ) diff --git a/nomad/normalizing/__init__.py b/nomad/normalizing/__init__.py index bcef4f00147eb7a13bb5c05472ec874e0a7693b0..5153fb06a92bc208eff1fe6f06713b6fe4263670 100644 --- a/nomad/normalizing/__init__.py +++ b/nomad/normalizing/__init__.py @@ -37,7 +37,8 @@ There is one ABC for all normalizer: """ import importlib -from typing import Any, Iterator +from typing import Any +from 
collections.abc import Iterator from collections import UserList from nomad.config import config diff --git a/nomad/normalizing/common.py b/nomad/normalizing/common.py index fc29ab05e55e3d161ae35d739afbc5243282f6e8..75d5c2bdf7935570c29893928482bfa0252baa7c 100644 --- a/nomad/normalizing/common.py +++ b/nomad/normalizing/common.py @@ -38,7 +38,7 @@ from nomad.datamodel.results import ( ) -def wyckoff_sets_from_matid(wyckoff_sets: List[WyckoffSetMatID]) -> List[WyckoffSet]: +def wyckoff_sets_from_matid(wyckoff_sets: list[WyckoffSetMatID]) -> list[WyckoffSet]: """Given a dictionary of wyckoff sets, returns the metainfo equivalent. Args: @@ -65,8 +65,8 @@ def wyckoff_sets_from_matid(wyckoff_sets: List[WyckoffSetMatID]) -> List[Wyckoff def species( - labels: List[str], atomic_numbers: List[int], logger=None -) -> Optional[List[Species]]: + labels: list[str], atomic_numbers: list[int], logger=None +) -> list[Species] | None: """Given a list of atomic labels and atomic numbers, returns the corresponding list of Species objects. @@ -81,7 +81,7 @@ def species( """ if labels is None or atomic_numbers is None: return None - species: Set[str] = set() + species: set[str] = set() species_list = [] for label, atomic_number in zip(labels, atomic_numbers): if label not in species: @@ -130,8 +130,8 @@ def lattice_parameters_from_array(lattice_vectors: NDArray[Any]) -> LatticeParam def cell_from_ase_atoms( atoms: Atoms, - masses: Union[List[float], Dict[Any, Any]] = None, - atom_labels: List[str] = None, + masses: list[float] | dict[Any, Any] = None, + atom_labels: list[str] = None, ) -> Cell: """Extracts Cell metainfo from the given ASE Atoms. Undefined angle values are not stored. @@ -169,7 +169,7 @@ def cell_from_ase_atoms( def structure_from_ase_atoms( - system: Atoms, wyckoff_sets: List[WyckoffSetMatID] = None, logger=None + system: Atoms, wyckoff_sets: list[WyckoffSetMatID] = None, logger=None ) -> Structure: """Returns a populated NOMAD Structure instance from an ase.Atoms-object. diff --git a/nomad/normalizing/material.py b/nomad/normalizing/material.py index fb4e22d7f4a1f432a23817d0f66e617246ba5474..078e4f0452e60c6d3dbe1bcb3ee98352f4c9b8b6 100644 --- a/nomad/normalizing/material.py +++ b/nomad/normalizing/material.py @@ -126,13 +126,13 @@ class MaterialNormalizer: return material - def material_classification(self) -> Dict[str, List[str]]: + def material_classification(self) -> dict[str, list[str]]: try: sec_springer = self.repr_system['springer_material'][0] except Exception: return None - classes: Dict[str, List[str]] = {} + classes: dict[str, list[str]] = {} try: classifications = sec_springer['classification'] except KeyError: @@ -147,9 +147,7 @@ class MaterialNormalizer: classes['compound_class_springer'] = compound_classes return classes - def material_name( - self, symbols: Union[List, NDArray], counts: Union[List, NDArray] - ) -> str: + def material_name(self, symbols: list | NDArray, counts: list | NDArray) -> str: if symbols is None or counts is None: return None name = None diff --git a/nomad/normalizing/metainfo.py b/nomad/normalizing/metainfo.py index 4ab8107ccf11d964e135f81cf3f5e2ee347d4dd7..b76df84ca7d766f35c6556a831a857fec93980bc 100644 --- a/nomad/normalizing/metainfo.py +++ b/nomad/normalizing/metainfo.py @@ -24,7 +24,7 @@ from . 
import Normalizer class MetainfoNormalizer(Normalizer): - domain: Optional[str] = None + domain: str | None = None def normalize_section(self, archive: EntryArchive, section, logger): normalize = None diff --git a/nomad/normalizing/method.py b/nomad/normalizing/method.py index f8d846d13f2b3bced778bc3f579b89df507511ef..c4423ec72d3f398c46ccea59451f3b8eab808f71 100644 --- a/nomad/normalizing/method.py +++ b/nomad/normalizing/method.py @@ -481,8 +481,8 @@ class MethodNormalizer: # TODO: add normalizer for atom_parameters.label return eos_dict.hash() def calc_k_line_density( - self, k_lattices: List[List[float]], nks: List[int] - ) -> Optional[float]: + self, k_lattices: list[list[float]], nks: list[int] + ) -> float | None: """ Compute the lowest k_line_density value: k_line_density (for a uniformly spaced grid) is the number of k-points per reciprocal length unit @@ -514,7 +514,7 @@ class ElectronicMethod(ABC): self, logger, entry_archive: EntryArchive = None, - methods: List[ArchiveSection] = [None], + methods: list[ArchiveSection] = [None], repr_method: ArchiveSection = None, repr_system: MSection = None, method: Method = None, @@ -591,7 +591,7 @@ class DFTMethod(ElectronicMethod): ) return simulation - def basis_set_type(self, repr_method: ArchiveSection) -> Optional[str]: + def basis_set_type(self, repr_method: ArchiveSection) -> str | None: name = None for em in repr_method.electrons_representation or []: if em.scope: @@ -640,7 +640,7 @@ class DFTMethod(ElectronicMethod): name = '(L)APW+lo' return name - def basis_set_name(self) -> Optional[str]: + def basis_set_name(self) -> str | None: try: name = self._repr_method.basis_set[0].name except Exception: @@ -648,8 +648,8 @@ class DFTMethod(ElectronicMethod): return name def hubbard_kanamori_model( - self, methods: List[ArchiveSection] - ) -> List[HubbardKanamoriModel]: + self, methods: list[ArchiveSection] + ) -> list[HubbardKanamoriModel]: """Generate a list of normalized HubbardKanamoriModel for `results.method`""" hubbard_kanamori_models = [] for sec_method in methods: @@ -845,7 +845,7 @@ class DFTMethod(ElectronicMethod): ) return treatment - def xc_functional_names(self, method_xc_functional: Section) -> Optional[List[str]]: + def xc_functional_names(self, method_xc_functional: Section) -> list[str] | None: if self._repr_method: functionals = set() try: @@ -866,7 +866,7 @@ class DFTMethod(ElectronicMethod): def xc_functional_type( self, - xc_functionals: Optional[list[str]], + xc_functionals: list[str] | None, abbrev_mapping: dict[str, str] = xc_treatments, ) -> str: """Assign the rung on Jacob\'s Ladder based on a set of libxc names. @@ -893,7 +893,7 @@ class DFTMethod(ElectronicMethod): return config.services.unavailable_value return abbrev_mapping[highest_rung_abbrev] - def exact_exchange_mixing_factor(self, xc_functional_names: List[str]): + def exact_exchange_mixing_factor(self, xc_functional_names: list[str]): """Assign the exact exchange mixing factor to `results` section when explicitly stated. 
Else, fall back on XC functional default.""" @@ -931,7 +931,7 @@ class ExcitedStateMethod(ElectronicMethod): """ExcitedState (GW, BSE, or DFT+GW, DFT+BSE) Method normalized into results.simulation""" def simulation(self) -> Simulation: - xs: Union[None, GW, BSE] = None + xs: None | GW | BSE = None simulation = Simulation() if 'GW' in self._method_name: self._method.method_name = 'GW' @@ -1053,7 +1053,7 @@ class MethodNormalizerBasisSet(ABC): pass @abstractmethod - def setup(self) -> Tuple: + def setup(self) -> tuple: """Used to define a list of mandatory and optional settings for a subclass. @@ -1061,15 +1061,15 @@ class MethodNormalizerBasisSet(ABC): Should return a tuple of two lists: the first one defining mandatory keys and the second one defining optional keys. """ - mandatory: List = [] - optional: List = [] + mandatory: list = [] + optional: list = [] return mandatory, optional class BasisSetFHIAims(MethodNormalizerBasisSet): """Basis set settings for FHI-Aims (code-dependent).""" - def setup(self) -> Tuple: + def setup(self) -> tuple: # Get previously defined values from superclass mandatory, optional = super().setup() @@ -1141,7 +1141,7 @@ class BasisSetFHIAims(MethodNormalizerBasisSet): class BasisSetExciting(MethodNormalizerBasisSet): """Basis set settings for Exciting (code-dependent).""" - def setup(self) -> Tuple: + def setup(self) -> tuple: # Get previously defined values from superclass mandatory, optional = super().setup() @@ -1170,22 +1170,17 @@ class BasisSetExciting(MethodNormalizerBasisSet): for group in groups: label = group.x_exciting_geometry_atom_labels try: - muffin_tin_settings['{}_muffin_tin_radius'.format(label)] = ( - '%.6f' - % ( - group.x_exciting_muffin_tin_radius.to( - ureg.angstrom - ).magnitude - ) + muffin_tin_settings[f'{label}_muffin_tin_radius'] = '%.6f' % ( + group.x_exciting_muffin_tin_radius.to(ureg.angstrom).magnitude ) except Exception: - muffin_tin_settings['{}_muffin_tin_radius'.format(label)] = None + muffin_tin_settings[f'{label}_muffin_tin_radius'] = None try: - muffin_tin_settings['{}_muffin_tin_points'.format(label)] = ( + muffin_tin_settings[f'{label}_muffin_tin_points'] = ( '%d' % group.x_exciting_muffin_tin_points ) except Exception: - muffin_tin_settings['{}_muffin_tin_points'.format(label)] = None + muffin_tin_settings[f'{label}_muffin_tin_points'] = None self.settings['muffin_tin_settings'] = muffin_tin_settings except Exception: pass diff --git a/nomad/normalizing/normalizer.py b/nomad/normalizing/normalizer.py index 53967ed92c41eb86109ef0b964d8de2310d85ec3..82eb23aa250800f302c03f9d2a228438cced1cea 100644 --- a/nomad/normalizing/normalizer.py +++ b/nomad/normalizing/normalizer.py @@ -32,7 +32,7 @@ class Normalizer(metaclass=ABCMeta): not mutate the state of the shared normalizer instance. """ - domain: Optional[str] = 'dft' + domain: str | None = 'dft' """Deprecated: The domain this normalizer should be used in. Default for all normalizer is 'DFT'.""" normalizer_level = 0 """Deprecated: Specifies the order of normalization with respect to other normalizers. 
Lower level @@ -64,7 +64,7 @@ class SystemBasedNormalizer(Normalizer, metaclass=ABCMeta): self.only_representatives = only_representatives @property - def quantities(self) -> List[str]: + def quantities(self) -> list[str]: return [ 'atom_labels', 'atom_positions', diff --git a/nomad/normalizing/optimade.py b/nomad/normalizing/optimade.py index deb90e98accb52b2c89c7439e46278217119b9c4..d19b0e09853d7e0cfaa1cadd486058de90b53ac8 100644 --- a/nomad/normalizing/optimade.py +++ b/nomad/normalizing/optimade.py @@ -142,7 +142,7 @@ class OptimadeNormalizer(SystemBasedNormalizer): # elements atoms = normalized_atom_labels(nomad_species) atom_count = len(atoms) - atom_counts: Dict[str, int] = {} + atom_counts: dict[str, int] = {} for atom in atoms: current = atom_counts.setdefault(atom, 0) current += 1 diff --git a/nomad/normalizing/results.py b/nomad/normalizing/results.py index 7c24b174ea75407b095a16cdcad6d32c7be3bf00..3e734e6f26049715ad910ec2ed7fba2041e3be61 100644 --- a/nomad/normalizing/results.py +++ b/nomad/normalizing/results.py @@ -535,7 +535,7 @@ class ResultsNormalizer(Normalizer): mapped_data.insert(0, results_data) return mapped_data - def resolve_spectra(self, path: list[str]) -> Optional[list[Spectra]]: + def resolve_spectra(self, path: list[str]) -> list[Spectra] | None: """Returns a section containing the references for a Spectra. This section is then stored under `archive.results.properties.spectroscopic`. @@ -572,7 +572,7 @@ class ResultsNormalizer(Normalizer): def resolve_magnetic_shielding( self, path: list[str] - ) -> Optional[list[MagneticShielding]]: + ) -> list[MagneticShielding] | None: """Returns a section containing the references for the (atomic) Magnetic Shielding. This section is then stored under `archive.results.properties.magnetic`. @@ -599,7 +599,7 @@ class ResultsNormalizer(Normalizer): def resolve_spin_spin_coupling( self, path: list[str] - ) -> Optional[list[SpinSpinCoupling]]: + ) -> list[SpinSpinCoupling] | None: """Returns a section containing the references for the Spin Spin Coupling. This section is then stored under `archive.results.properties.magnetic`. @@ -633,7 +633,7 @@ class ResultsNormalizer(Normalizer): def resolve_magnetic_susceptibility( self, path: list[str] - ) -> Optional[list[MagneticSusceptibility]]: + ) -> list[MagneticSusceptibility] | None: """Returns a section containing the references for the Magnetic Susceptibility. This section is then stored under `archive.results.properties.magnetic`. @@ -769,7 +769,7 @@ class ResultsNormalizer(Normalizer): spct_electronic = spectra return spct_electronic - def band_structure_phonon(self) -> Union[BandStructurePhonon, None]: + def band_structure_phonon(self) -> BandStructurePhonon | None: """Returns a new section containing a phonon band structure. In the case of multiple valid band structures, only the latest one is considered. @@ -797,7 +797,7 @@ class ResultsNormalizer(Normalizer): return None - def dos_phonon(self) -> Union[DOSPhonon, None]: + def dos_phonon(self) -> DOSPhonon | None: """Returns a section containing phonon dos data. In the case of multiple valid data sources, only the latest one is reported. @@ -817,7 +817,7 @@ class ResultsNormalizer(Normalizer): return None - def energy_free_helmholtz(self) -> Union[EnergyFreeHelmholtz, None]: + def energy_free_helmholtz(self) -> EnergyFreeHelmholtz | None: """Returns a section Helmholtz free energy data. In the case of multiple valid data sources, only the latest one is reported. 
@@ -844,7 +844,7 @@ class ResultsNormalizer(Normalizer): return None - def heat_capacity_constant_volume(self) -> Union[HeatCapacityConstantVolume, None]: + def heat_capacity_constant_volume(self) -> HeatCapacityConstantVolume | None: """Returns a section containing heat capacity data. In the case of multiple valid data sources, only the latest one is reported. @@ -870,7 +870,7 @@ class ResultsNormalizer(Normalizer): return None - def geometry_optimization(self) -> Union[GeometryOptimization, None]: + def geometry_optimization(self) -> GeometryOptimization | None: """Populates both geometry optimization methodology and calculated properties based on the first found geometry optimization workflow. """ @@ -905,7 +905,7 @@ class ResultsNormalizer(Normalizer): return None - def get_md_provenance(self, workflow: Workflow) -> Optional[MolecularDynamics]: + def get_md_provenance(self, workflow: Workflow) -> MolecularDynamics | None: """Retrieves the MD provenance from the given workflow.""" md = None if workflow.m_def.name == 'MolecularDynamics': diff --git a/nomad/normalizing/topology.py b/nomad/normalizing/topology.py index 4f223487d5875bee561d115f4b9b3657d0155546..da0cf2ee5ae501af6ad9a80bc5106bd8ea900b9e 100644 --- a/nomad/normalizing/topology.py +++ b/nomad/normalizing/topology.py @@ -63,7 +63,7 @@ from nomad.normalizing.common import ( conventional_description = 'The conventional cell of the material from which the subsystem is constructed from.' subsystem_description = 'Automatically detected subsystem.' chemical_symbols = np.array(chemical_symbols) -with open(pathlib.Path(__file__).parent / 'data/top_50k_material_ids.json', 'r') as fin: +with open(pathlib.Path(__file__).parent / 'data/top_50k_material_ids.json') as fin: top_50k_material_ids = json.load(fin) @@ -115,8 +115,8 @@ def get_topology_original(atoms=None, archive: EntryArchive = None) -> System: def add_system_info( system: System, - topologies: Dict[str, System], - masses: Union[List[float], Dict[str, float]] = None, + topologies: dict[str, System], + masses: list[float] | dict[str, float] = None, ) -> None: """Given a system with minimal information, attempts to add all values than can be derived. @@ -175,7 +175,7 @@ def add_system_info( def add_system( - system: System, topologies: Dict[str, System], parent: Optional[System] = None + system: System, topologies: dict[str, System], parent: System | None = None ) -> None: """Adds the given system to the topology.""" index = len(topologies) @@ -203,7 +203,7 @@ class TopologyNormalizer: entry_archive, repr_system=repr_system ) - def topology(self, material) -> Optional[List[System]]: + def topology(self, material) -> list[System] | None: """Returns a dictionary that contains all of the topologies mapped by id.""" # If topology already exists (e.g. written by another normalizer), do # not overwrite it. @@ -222,7 +222,7 @@ class TopologyNormalizer: return None - def topology_calculation(self) -> Optional[List[System]]: + def topology_calculation(self) -> list[System] | None: """Extracts the system topology as defined in the original calculation. This topology typically comes from e.g. classical force fields that define a topology for the system. 
@@ -246,11 +246,11 @@ class TopologyNormalizer: ): return None - topology: Dict[str, System] = {} + topology: dict[str, System] = {} original = get_topology_original(atoms, self.entry_archive) original.atoms_ref = atoms add_system(original, topology) - label_to_indices: Dict[str, list] = defaultdict(list) + label_to_indices: dict[str, list] = defaultdict(list) def add_group(groups, parent=None): if not groups: @@ -302,10 +302,8 @@ class TopologyNormalizer: old_labels.append(instance_indices) else: self.logger.warn( - ( - 'the topology contains entries with the same label but with ' - 'different number of atoms' - ) + 'the topology contains entries with the same label but with ' + 'different number of atoms' ) add_group(groups, original) @@ -328,7 +326,7 @@ class TopologyNormalizer: return list(topology.values()) - def topology_matid(self, material: Material) -> Optional[List[System]]: + def topology_matid(self, material: Material) -> list[System] | None: """ Returns a list of systems that have been identified with MatID. """ @@ -342,7 +340,7 @@ class TopologyNormalizer: return None # Create topology for the original system - topology: Dict[str, System] = {} + topology: dict[str, System] = {} original = get_topology_original(nomad_atoms, self.entry_archive) original.atoms_ref = nomad_atoms add_system(original, topology) @@ -413,7 +411,7 @@ class TopologyNormalizer: conventional_cell, topology, masses=self.masses - if isinstance(self.masses, Dict) + if isinstance(self.masses, dict) else None, ) else: @@ -510,7 +508,7 @@ class TopologyNormalizer: add_system(conv_system, topology, subsystem) add_system_info(conv_system, topology, masses=self.masses) - def _create_subsystem(self, cluster: Cluster) -> Optional[System]: + def _create_subsystem(self, cluster: Cluster) -> System | None: """ Creates a new subsystem as detected by MatID. 
""" diff --git a/nomad/parsing/artificial.py b/nomad/parsing/artificial.py index 4af5409e77466d443c05f45118b6a5bb69cd16b7..484f97ae15268f3feae2cd449a3ac47b0b0b650e 100644 --- a/nomad/parsing/artificial.py +++ b/nomad/parsing/artificial.py @@ -78,7 +78,7 @@ class TemplateParser(Parser): if logger is not None: logger.debug('received logger') - template_json = json.load(open(mainfile, 'r')) + template_json = json.load(open(mainfile)) loaded_archive = EntryArchive.m_from_dict(template_json) archive.m_add_sub_section(EntryArchive.run, loaded_archive.run[0]) archive.m_add_sub_section(EntryArchive.workflow2, loaded_archive.workflow2) @@ -116,7 +116,7 @@ class ChaosParser(Parser): def parse( self, mainfile: str, archive: EntryArchive, logger=None, child_archives=None ) -> None: - chaos_json = json.load(open(mainfile, 'r')) + chaos_json = json.load(open(mainfile)) if isinstance(chaos_json, str): chaos = chaos_json elif isinstance(chaos_json, dict): diff --git a/nomad/parsing/file_parser/file_parser.py b/nomad/parsing/file_parser/file_parser.py index a9dd56b443c5e55b1be6e99b55fd45c0d3bcaf40..488fb026565f2a62887969b29fd7994ce8a6f6e8 100644 --- a/nomad/parsing/file_parser/file_parser.py +++ b/nomad/parsing/file_parser/file_parser.py @@ -15,7 +15,8 @@ from abc import ABC, abstractmethod import os import pint -from typing import Any, Dict, Callable, IO, Union, List +from typing import Any, Dict, IO, Union, List +from collections.abc import Callable import gzip import bz2 import lzma @@ -39,9 +40,7 @@ class FileParser(ABC): open: function to open file """ - def __init__( - self, mainfile: Union[str, IO] = None, logger=None, open: Callable = None - ): + def __init__(self, mainfile: str | IO = None, logger=None, open: Callable = None): self._mainfile: str = None self._mainfile_obj: IO = None if isinstance(mainfile, str): @@ -54,8 +53,8 @@ class FileParser(ABC): self.logger = logger if logger is not None else get_logger(__name__) # a key is necessary for xml parsers, where parsing is done dynamically self._key: str = None - self._kwargs: Dict[str, Any] = {} - self._results: Dict[str, Any] = None + self._kwargs: dict[str, Any] = {} + self._results: dict[str, Any] = None self._file_handler: Any = None def reset(self): @@ -164,7 +163,7 @@ class FileParser(ABC): self, key: str, default: Any = None, - unit: Union[pint.Unit, pint.Quantity] = None, + unit: pint.Unit | pint.Quantity = None, **kwargs, ): """ @@ -265,9 +264,7 @@ class Parser(ABC): logger = None child_archives = None - def get_mainfile_keys( - self, filename: str, decoded_buffer: str - ) -> Union[bool, List[str]]: + def get_mainfile_keys(self, filename: str, decoded_buffer: str) -> bool | list[str]: """ If child archives are necessary for the entry, a list of keys for the archives are returned. @@ -275,7 +272,7 @@ class Parser(ABC): return True # TODO replace with MSection.m_update_from_dict once it takes in type Quantity? - def parse_section(self, data: Dict[str, Any], root: MSection) -> None: + def parse_section(self, data: dict[str, Any], root: MSection) -> None: """ Write the quantities in data into an archive section. """ @@ -293,7 +290,7 @@ class Parser(ABC): root.m_set(root.m_get_quantity_definition(key), val) - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """ Converts the parsed metadata into a dictionary following the nomad archive schema. 
""" diff --git a/nomad/parsing/file_parser/text_parser.py b/nomad/parsing/file_parser/text_parser.py index 9124a3661262d31e3403eae97654e4fce679612b..6b8804c507a0198e3bb4391b020af63e27232afb 100644 --- a/nomad/parsing/file_parser/text_parser.py +++ b/nomad/parsing/file_parser/text_parser.py @@ -18,7 +18,8 @@ import io import re import numpy as np import pint -from typing import List, Union, Callable, Type, Any +from typing import List, Union, Type, Any +from collections.abc import Callable from nomad.parsing.file_parser import FileParser from nomad.metainfo import Quantity as mQuantity @@ -50,7 +51,7 @@ class ParsePattern: if self._re_pattern is None: head = r'%s[\s\S]*?' % self._head if self._head else '' key = r'%s\s*\:*\=*\s*' % self._key if self._key else '' - self._re_pattern = r'%s%s\s*\:*\=*\s*(%s)%s' % ( + self._re_pattern = r'{}{}\s*\:*\=*\s*({}){}'.format( head, key, self._value, @@ -113,14 +114,14 @@ class Quantity: def __init__( self, - quantity: Union[str, mQuantity], - re_pattern: Union[str, ParsePattern], + quantity: str | mQuantity, + re_pattern: str | ParsePattern, **kwargs, ): self.name: str self.dtype: str self.unit: str - self.shape: List[int] + self.shape: list[int] if isinstance(quantity, str): self.name = quantity self.dtype = None @@ -268,14 +269,14 @@ class TextParser(FileParser): def __init__( self, mainfile: str = None, - quantities: List[Quantity] = None, + quantities: list[Quantity] = None, logger=None, **kwargs, ): if logger is None: logger = get_logger(__name__) super().__init__(mainfile, logger=logger, open=kwargs.get('open', None)) - self._quantities: List[Quantity] = quantities + self._quantities: list[Quantity] = quantities self.findall: bool = kwargs.get('findall', True) self._kwargs = kwargs self._file_length: int = kwargs.get('file_length', 0) @@ -320,7 +321,7 @@ class TextParser(FileParser): return self._quantities @quantities.setter - def quantities(self, val: List[Quantity]): + def quantities(self, val: list[Quantity]): """ Sets the quantities list. """ @@ -393,7 +394,7 @@ class TextParser(FileParser): for key in self.keys(): yield key, self.get(key) - def _add_value(self, quantity: Quantity, value: List[str], units): + def _add_value(self, quantity: Quantity, value: list[str], units): """ Converts the list of parsed blocks into data and apply the corresponding units. """ @@ -417,7 +418,7 @@ class TextParser(FileParser): 'Error setting value', data=dict(quantity=quantity.name) ) - def _parse_quantities(self, quantities: List[Quantity]): + def _parse_quantities(self, quantities: list[Quantity]): """ Parse a list of quantities. """ @@ -592,7 +593,7 @@ class DataTextParser(TextParser): """ def __init__(self, **kwargs): - self._dtype: Type = kwargs.get('dtype', float) + self._dtype: type = kwargs.get('dtype', float) self._mainfile_contents: str = kwargs.get('mainfile_contents', '') super().__init__(**kwargs) diff --git a/nomad/parsing/parser.py b/nomad/parsing/parser.py index f6c234fa97ec3be8cec0ed26d37495d4f8431110..73ec9a4e2ef346dd50dfd43d5b3c05c551e05351 100644 --- a/nomad/parsing/parser.py +++ b/nomad/parsing/parser.py @@ -16,7 +16,8 @@ # limitations under the License. 
# -from typing import List, Iterable, Dict, Union, Any, IO +from typing import List, Dict, Union, Any, IO +from collections.abc import Iterable from abc import ABCMeta, abstractmethod import re import os @@ -44,7 +45,7 @@ class Parser(metaclass=ABCMeta): name = 'parsers/parser' level = 0 creates_children = False - aliases: List[str] = [] + aliases: list[str] = [] """ Level 0 parsers are run first, then level 1, and so on. Normally the value should be 0, use higher values only when a parser depends on other parsers. @@ -62,7 +63,7 @@ class Parser(metaclass=ABCMeta): buffer: bytes, decoded_buffer: str, compression: str = None, - ) -> Union[bool, Iterable[str]]: + ) -> bool | Iterable[str]: """ Checks if a file is a mainfile for the parser. Should return True or a set of *keys* (non-empty strings) if it is a mainfile, otherwise a falsey value. @@ -95,7 +96,7 @@ class Parser(metaclass=ABCMeta): mainfile: str, archive: EntryArchive, logger=None, - child_archives: Dict[str, EntryArchive] = None, + child_archives: dict[str, EntryArchive] = None, ) -> None: """ Runs the parser on the given mainfile and populates the result in the given @@ -124,7 +125,7 @@ class Parser(metaclass=ABCMeta): pass @classmethod - def main(cls, mainfile, mainfile_keys: List[str] = None): + def main(cls, mainfile, mainfile_keys: list[str] = None): archive = EntryArchive() archive.m_create(EntryMetadata) if mainfile_keys: @@ -216,7 +217,7 @@ class MatchingParser(Parser): level: int = 0, domain='dft', metadata: dict = None, - supported_compressions: List[str] = [], + supported_compressions: list[str] = [], **kwargs, ) -> None: super().__init__() @@ -248,13 +249,13 @@ class MatchingParser(Parser): self._ls = lru_cache(maxsize=16)(lambda directory: os.listdir(directory)) - def read_metadata_file(self, metadata_file: str) -> Dict[str, Any]: + def read_metadata_file(self, metadata_file: str) -> dict[str, Any]: """ Read parser metadata from a yaml file. 
""" logger = utils.get_logger(__name__) try: - with open(metadata_file, 'r', encoding='UTF-8') as f: + with open(metadata_file, encoding='UTF-8') as f: parser_metadata = yaml.load(f, Loader=yaml.SafeLoader) except Exception as e: logger.warning('failed to read parser metadata', exc_info=e) @@ -269,7 +270,7 @@ class MatchingParser(Parser): buffer: bytes, decoded_buffer: str, compression: str = None, - ) -> Union[bool, Iterable[str]]: + ) -> bool | Iterable[str]: if self._mainfile_binary_header is not None: if self._mainfile_binary_header not in buffer: return False @@ -392,7 +393,7 @@ class MatchingParser(Parser): # TODO remove this after merging hdf5 reference, only for parser compatibility -def to_hdf5(value: Any, f: Union[str, IO], path: str): +def to_hdf5(value: Any, f: str | IO, path: str): with h5py.File(f, 'a') as root: segments = path.rsplit('/', 1) group = root.require_group(segments[0]) if len(segments) == 2 else root @@ -469,7 +470,7 @@ class MatchingParserInterface(MatchingParser): buffer: bytes, decoded_buffer: str, compression: str = None, - ) -> Union[bool, Iterable[str]]: + ) -> bool | Iterable[str]: is_mainfile = super().is_mainfile( filename=filename, mime=mime, @@ -556,7 +557,7 @@ class ArchiveParser(MatchingParser): def parse( self, mainfile: str, archive: EntryArchive, logger=None, child_archives=None ): - with open(mainfile, 'rt') as f: + with open(mainfile) as f: self.parse_file(mainfile, f, archive, logger) self.validate_defintions(archive, logger) diff --git a/nomad/parsing/parsers.py b/nomad/parsing/parsers.py index 60fb18d989beb2a6730038c87ddd0de3c36d1857..3cd806d4b20645dd63308ac0af61709b167cfcc6 100644 --- a/nomad/parsing/parsers.py +++ b/nomad/parsing/parsers.py @@ -58,8 +58,8 @@ except ImportError: def match_parser( - mainfile_path: str, strict=True, parser_name: Optional[str] = None -) -> Tuple[Parser, List[str]]: + mainfile_path: str, strict=True, parser_name: str | None = None +) -> tuple[Parser, list[str]]: """ Performs parser matching. This means it take the given mainfile and potentially opens it with the given callback and tries to identify a parser that can parse @@ -141,7 +141,7 @@ def match_parser( except Exception: pass else: - with open(mainfile_path, 'wt') as text_file: + with open(mainfile_path, 'w') as text_file: text_file.write(content) # TODO: deal with multiple possible parser specs @@ -167,12 +167,12 @@ class ParserContext(Context): def run_parser( mainfile_path: str, parser: Parser, - mainfile_keys: List[str] = None, + mainfile_keys: list[str] = None, logger=None, server_context: bool = False, username: str = None, password: str = None, -) -> List[EntryArchive]: +) -> list[EntryArchive]: """ Parses a file, given the path, the parser, and mainfile_keys, as returned by :func:`match_parser`, and returns the resulting EntryArchive objects. Parsers that have @@ -288,7 +288,7 @@ if config.process.use_empty_parsers: parsers.append(BrokenParser()) """ A dict to access parsers by name. Usually 'parsers/<...>', e.g. 'parsers/vasp'. """ -parser_dict: Dict[str, Parser] = { +parser_dict: dict[str, Parser] = { parser.name: parser for parser in parsers + empty_parsers } # Register also aliases diff --git a/nomad/parsing/tabular.py b/nomad/parsing/tabular.py index 7847032a47a42eb290827e3b5de279c916685a6b..55c6364902853ff362b335701ae3db1bbba7ea01 100644 --- a/nomad/parsing/tabular.py +++ b/nomad/parsing/tabular.py @@ -16,7 +16,9 @@ # limitations under the License. 
# import os -from typing import List, Dict, Callable, Set, Any, Tuple, Iterator, Union, Iterable +from typing import List, Dict, Set, Any, Tuple, Union +from collections.abc import Callable +from collections.abc import Iterator, Iterable import pandas as pd import re @@ -76,7 +78,7 @@ def create_archive(entry_dict, context, file_name, file_type, logger): ) -def traverse_to_target_data_file(section, path_list: List[str]): +def traverse_to_target_data_file(section, path_list: list[str]): if len(path_list) == 0 and (isinstance(section, str) or section is None): return section else: @@ -106,7 +108,7 @@ class TableData(ArchiveSection): ) def normalize(self, archive, logger): - super(TableData, self).normalize(archive, logger) + super().normalize(archive, logger) if self.fill_archive_from_datafile: for quantity_def in self.m_def.all_quantities.values(): @@ -158,7 +160,7 @@ class TableData(ArchiveSection): mapping_options = annotation.mapping_options if mapping_options: - row_sections_counter: Dict[str, int] = {} + row_sections_counter: dict[str, int] = {} for mapping_option in mapping_options: try: file_mode = mapping_option.file_mode @@ -554,10 +556,10 @@ def append_section_to_subsection( def _parse_row_mode(main_section, row_sections, data, logger): # Getting list of all repeating sections where new instances are going to be read from excel/csv file # and appended. - section_names: List[str] = row_sections + section_names: list[str] = row_sections # A list to track if the top-most level section has ever been visited - list_of_visited_sections: List[str] = [] + list_of_visited_sections: list[str] = [] for section_name in section_names: section_name_list = section_name.split('/') @@ -607,10 +609,10 @@ def _get_relative_path(section_def) -> Iterator[str]: @cached(LRUCache(maxsize=10)) def _create_column_to_quantity_mapping(section_def: Section): - mapping: Dict[str, Callable[[MSection, Any], MSection]] = {} + mapping: dict[str, Callable[[MSection, Any], MSection]] = {} - def add_section_def(section_def: Section, path: List[Tuple[SubSection, Section]]): - properties: Set[Property] = set() + def add_section_def(section_def: Section, path: list[tuple[SubSection, Section]]): + properties: set[Property] = set() for quantity in section_def.all_quantities.values(): if quantity in properties: @@ -674,7 +676,7 @@ def _create_column_to_quantity_mapping(section_def: Section): ) section.m_set(quantity, value) - _section_path_list: List[str] = list(_get_relative_path(section)) + _section_path_list: list[str] = list(_get_relative_path(section)) _section_path_str: str = '/'.join(_section_path_list) section_path_to_top_subsection.append(_section_path_str) @@ -743,7 +745,7 @@ def parse_table(pd_dataframe, section_def: Section, logger): data: pd.DataFrame = pd_dataframe data_dict = data.to_dict() - sections: List[MSection] = [] + sections: list[MSection] = [] sheet_name = list(data_dict[0])[0] mapping = _create_column_to_quantity_mapping(section_def) # type: ignore @@ -782,7 +784,7 @@ def parse_table(pd_dataframe, section_def: Section, logger): except Exception: continue - path_quantities_to_top_subsection: Set[str] = set() + path_quantities_to_top_subsection: set[str] = set() for row_index, row in df.iterrows(): for col_index in range(0, max_no_of_repeated_columns + 1): section = section_def.section_cls() @@ -793,7 +795,7 @@ def parse_table(pd_dataframe, section_def: Section, logger): if col_name in df: try: - temp_quantity_path_container: List[str] = [] + temp_quantity_path_container: list[str] = [] 
mapping[column]( section, row[col_name], @@ -832,7 +834,7 @@ def parse_table(pd_dataframe, section_def: Section, logger): else: try: for item in path_quantities_to_top_subsection: - section_name: List[str] = item.split('/')[1:] + section_name: list[str] = item.split('/')[1:] _append_subsections_from_section( section_name, sections[row_index], section ) @@ -846,7 +848,7 @@ def parse_table(pd_dataframe, section_def: Section, logger): def _strip_whitespaces_from_df_columns(df): - transformed_column_names: Dict[str, str] = {} + transformed_column_names: dict[str, str] = {} for col_name in list(df.columns): cleaned_col_name = col_name.strip().split('.')[0] count = 0 @@ -861,7 +863,7 @@ def _strip_whitespaces_from_df_columns(df): def _append_subsections_from_section( - section_name: List[str], target_section: MSection, source_section: MSection + section_name: list[str], target_section: MSection, source_section: MSection ): if len(section_name) == 1: for ( @@ -887,13 +889,13 @@ def read_table_data( file_or_path=None, comment: str = None, sep: str = None, - skiprows: Union[list[int], int] = None, + skiprows: list[int] | int = None, separator: str = None, filters: dict = None, ): import pandas as pd - def filter_columns(df: pd.DataFrame, filters: Union[None, dict]) -> pd.DataFrame: + def filter_columns(df: pd.DataFrame, filters: None | dict) -> pd.DataFrame: if not filters: return df @@ -987,7 +989,7 @@ class TabularDataParser(MatchingParser): buffer: bytes, decoded_buffer: str, compression: str = None, - ) -> Union[bool, Iterable[str]]: + ) -> bool | Iterable[str]: # We use the main file regex capabilities of the superclass to check if this is a # .csv file import pandas as pd diff --git a/nomad/processing/base.py b/nomad/processing/base.py index 6f90f096e0cfc21a8a4022ef8663d00899210d44..8207337e762194ec80884338bdc4038729dabc5c 100644 --- a/nomad/processing/base.py +++ b/nomad/processing/base.py @@ -351,7 +351,7 @@ class Proc(Document): force: bool = False, worker_hostname: str = None, process_status: str = ProcessStatus.READY, - errors: List[str] = [], + errors: list[str] = [], clear_queue: bool = True, ): """ @@ -374,7 +374,7 @@ class Proc(Document): cls, worker_hostname: str = None, process_status=ProcessStatus.READY, - errors: List[str] = [], + errors: list[str] = [], clear_queue: bool = True, ): """ @@ -403,7 +403,7 @@ class Proc(Document): raise e if obj is None: - raise KeyError('%s with id %s does not exist' % (cls.__name__, id)) + raise KeyError(f'{cls.__name__} with id {id} does not exist') return obj @@ -447,7 +447,7 @@ class Proc(Document): for error in errors: if isinstance(error, Exception): failed_with_exception = True - self.errors.append('%s: %s' % (error.__class__.__name__, str(error))) + self.errors.append(f'{error.__class__.__name__}: {str(error)}') Proc.log( logger, log_level, @@ -542,7 +542,7 @@ class Proc(Document): ): queue = worker_direct(self.worker_hostname).name - priority = config.celery.priorities.get('%s.%s' % (cls_name, func_name), 1) + priority = config.celery.priorities.get(f'{cls_name}.{func_name}', 1) logger = utils.get_logger(__name__, cls=cls_name, id=self_id, func=func_name) logger.info( @@ -559,7 +559,7 @@ class Proc(Document): ) def __str__(self): - return 'proc celery_task_id=%s worker_hostname=%s' % ( + return 'proc celery_task_id={} worker_hostname={}'.format( self.celery_task_id, self.worker_hostname, ) @@ -766,7 +766,7 @@ class Proc(Document): def _sync_complete_process( self, force_clear_queue_on_failure=False - ) -> Tuple[str, List[Any], Dict[str, 
Any]]: + ) -> tuple[str, list[Any], dict[str, Any]]: """ Used to complete a process (when done, successful or not). Returns a triple containing information about the next process to run (if any), of the @@ -851,7 +851,7 @@ def all_subclasses(cls): all_proc_cls = {cls.__name__: cls for cls in all_subclasses(Proc)} """ Name dictionary for all Proc classes. """ -process_flags: Dict[str, Dict[str, ProcessFlags]] = defaultdict(dict) +process_flags: dict[str, dict[str, ProcessFlags]] = defaultdict(dict) """ { <Proc class name>: { <process func name>: ProcessFlags } } """ diff --git a/nomad/processing/data.py b/nomad/processing/data.py index 656dfc010036506bab35f22d204746bf07fc9c25..b1298479f752560b900d314f2de1a3066bc38d88 100644 --- a/nomad/processing/data.py +++ b/nomad/processing/data.py @@ -36,12 +36,10 @@ from typing import ( List, Tuple, Set, - Iterator, Dict, - Iterable, - Sequence, Union, ) +from collections.abc import Iterator, Iterable, Sequence from pydantic import ValidationError from pydantic_core import InitErrorDetails, PydanticCustomError import rfc3161ng @@ -149,7 +147,7 @@ mongo_entry_metadata_except_system_fields = tuple( for quantity_name in mongo_entry_metadata if quantity_name not in mongo_system_metadata ) -editable_metadata: Dict[str, metainfo.Definition] = { +editable_metadata: dict[str, metainfo.Definition] = { quantity.name: quantity for quantity in EditableUserMetadata.m_def.definitions if isinstance(quantity, metainfo.Quantity) @@ -196,12 +194,12 @@ _log_processors = [ def get_rfc3161_token( hash_string: str, - server: Optional[str] = None, - cert: Optional[str] = None, - username: Optional[str] = None, - password: Optional[str] = None, - hash_algorithm: Optional[str] = None, -) -> Optional[bytes]: + server: str | None = None, + cert: str | None = None, + username: str | None = None, + password: str | None = None, + hash_algorithm: str | None = None, +) -> bytes | None: """ Get RFC3161 compliant time stamp as a list of int. """ @@ -252,8 +250,8 @@ class MetadataEditRequestHandler: @classmethod def edit_metadata( - cls, edit_request_json: Dict[str, Any], upload_id: str, user: datamodel.User - ) -> Dict[str, Any]: + cls, edit_request_json: dict[str, Any], upload_id: str, user: datamodel.User + ) -> dict[str, Any]: """ Method to verify and execute a generic request to edit metadata from a certain user. The request is specified as a json dictionary (requests defined by metadata files @@ -306,7 +304,7 @@ class MetadataEditRequestHandler: self, logger, user: datamodel.User, - edit_request: Union[StagingUploadFiles, Dict[str, Any]], + edit_request: StagingUploadFiles | dict[str, Any], upload_id: str = None, ): # Initialization @@ -319,40 +317,40 @@ class MetadataEditRequestHandler: self.edit_request = edit_request self.upload_id = upload_id - self.errors: List[ + self.errors: list[ CustomErrorWrapper ] = [] # A list of all encountered errors, if any - self.edit_attempt_locs: List[ - Tuple[str, ...] + self.edit_attempt_locs: list[ + tuple[str, ...] ] = [] # locs where user has attempted to edit something self.required_auth_level = ( AuthLevel.none ) # Maximum required auth level for the edit - self.required_auth_level_locs: List[ - Tuple[str, ...] + self.required_auth_level_locs: list[ + tuple[str, ...] 
] = [] # locs where maximal auth level is needed - self.encountered_users: Dict[ + self.encountered_users: dict[ str, str ] = {} # { ref: user_id | None }, ref = user_id | username | email - self.encountered_datasets: Dict[ + self.encountered_datasets: dict[ str, datamodel.Dataset ] = {} # { ref : dataset | None }, ref = dataset_id | dataset_name # Used when edit_request = json dict self.edit_request_obj: MetadataEditRequest = None - self.verified_metadata: Dict[ + self.verified_metadata: dict[ str, Any ] = {} # The metadata specified at the top/root level - self.verified_entries: Dict[ - str, Dict[str, Any] + self.verified_entries: dict[ + str, dict[str, Any] ] = {} # Metadata specified for individual entries - self.affected_uploads: List['Upload'] = ( + self.affected_uploads: list['Upload'] = ( None # A MetadataEditRequest may involve multiple uploads ) # Used when edit_request = files - self.verified_file_metadata_cache: Dict[str, Dict[str, Any]] = {} - self.root_file_entries: Dict[str, Dict[str, Any]] = ( + self.verified_file_metadata_cache: dict[str, dict[str, Any]] = {} + self.root_file_entries: dict[str, dict[str, Any]] = ( None # `entries` defined in the root metadata file ) @@ -448,7 +446,7 @@ class MetadataEditRequestHandler: if self.errors: raise RequestValidationError(errors=self.errors) - def get_upload_mongo_metadata(self, upload: 'Upload') -> Dict[str, Any]: + def get_upload_mongo_metadata(self, upload: 'Upload') -> dict[str, Any]: """ Returns a dictionary with metadata to set on the mongo Upload object. If the provided `edit_request` is a json dictionary the :func: `validate_json_request`) is assumed @@ -467,13 +465,13 @@ class MetadataEditRequestHandler: def get_entry_mongo_metadata( self, upload: 'Upload', entry: 'Entry' - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ Returns a dictionary with metadata to set on the mongo entry object. If the provided `edit_request` is a json dictionary the :func: `validate_json_request`) is assumed to have been run first. """ - verified_metadata: Dict[str, Any] = {} + verified_metadata: dict[str, Any] = {} if isinstance(self.edit_request, dict): # edit_request = json dict if self.verified_metadata: @@ -517,18 +515,18 @@ class MetadataEditRequestHandler: verified_metadata.update(verified_entry_metadata) return self._mongo_metadata(entry, verified_metadata) - def _error(self, msg: str, loc: Union[str, Tuple[str, ...]]): + def _error(self, msg: str, loc: str | tuple[str, ...]): """Registers an error associated with a particular location.""" self.errors.append(CustomErrorWrapper(Exception(msg), loc=loc)) self.logger.error(msg, loc=loc) def _verify_metadata( self, - raw_metadata: Dict[str, Any], - loc: Tuple[str, ...], + raw_metadata: dict[str, Any], + loc: tuple[str, ...], can_edit_upload_quantities: bool, auth_level: AuthLevel = None, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ Performs basic validation of a dictionary with *raw* metadata (i.e. metadata with key-value pairs as defined in the request json dictionary or metadata files), and @@ -561,10 +559,10 @@ class MetadataEditRequestHandler: self, quantity_name: str, raw_value: Any, - loc: Tuple[str, ...], + loc: tuple[str, ...], can_edit_upload_quantities: bool, auth_level: AuthLevel, - ) -> Tuple[bool, Any]: + ) -> tuple[bool, Any]: """ Performs validation of a single value. Returns (success, verified_value). 
""" @@ -716,14 +714,14 @@ class MetadataEditRequestHandler: assert False, 'Unhandled value type' # Should not happen def _mongo_metadata( - self, mongo_doc: Union['Upload', 'Entry'], verified_metadata: Dict[str, Any] - ) -> Dict[str, Any]: + self, mongo_doc: Union['Upload', 'Entry'], verified_metadata: dict[str, Any] + ) -> dict[str, Any]: """ Calculates the upload or entry level *mongo* metadata, given a `mongo_doc` and a dictionary with *verified* metadata. The mongo metadata are the key-value pairs to set on `mongo_doc` in order to carry out the edit request. """ - rv: Dict[str, Any] = {} + rv: dict[str, Any] = {} for quantity_name, verified_value in verified_metadata.items(): if ( isinstance(mongo_doc, Entry) @@ -807,7 +805,7 @@ class MetadataEditRequestHandler: return restrict_query_to_upload(query, upload_id) return query - def _find_request_uploads(self) -> List['Upload']: + def _find_request_uploads(self) -> list['Upload']: """Returns a list of :class:`Upload`s matching the edit request.""" query = self._restricted_request_query(self.upload_id) if query: @@ -846,10 +844,9 @@ class MetadataEditRequestHandler: yield Entry.get(result['entry_id']) else: # We have no query. Return all entries for the upload - for entry in Entry.objects(upload_id=upload.upload_id): - yield entry + yield from Entry.objects(upload_id=upload.upload_id) - def _verified_file_metadata(self, path_dir: str) -> Dict[str, Any]: + def _verified_file_metadata(self, path_dir: str) -> dict[str, Any]: """ Gets the verified metadata defined in a metadata file in the provided directory. The `path_dir` should be relative to the `raw` folder. Empty string gives the "root" @@ -865,7 +862,7 @@ class MetadataEditRequestHandler: ).metadata_file_cached(path_dir) if path_dir == '': can_edit_upload_quantities = True - loc: Tuple[str, ...] = ('/',) + loc: tuple[str, ...] 
= ('/',) if 'entries' in file_metadata: self.root_file_entries = file_metadata.pop('entries') if not isinstance(self.root_file_entries, dict): @@ -971,8 +968,8 @@ class Entry(Proc): self._is_initial_processing: bool = False self._upload: Upload = None self._upload_files: StagingUploadFiles = None - self._proc_logs: List[Any] = [] - self._child_entries: List['Entry'] = [] + self._proc_logs: list[Any] = [] + self._child_entries: list['Entry'] = [] self._entry_metadata: EntryMetadata = None self._perform_index = True @@ -1384,8 +1381,7 @@ class Entry(Proc): def _main_and_child_entries(self) -> Iterable['Entry']: yield self - for child_entry in self._child_entries: - yield child_entry + yield from self._child_entries def on_success(self): # Mark any child entries as successfully completed (necessary because the child entries @@ -1665,7 +1661,7 @@ class Entry(Proc): return self._proc_logs def __str__(self): - return 'entry %s entry_id=%s upload_id%s' % ( + return 'entry {} entry_id={} upload_id{}'.format( super().__str__(), self.entry_id, self.upload_id, @@ -1785,7 +1781,7 @@ class Upload(Proc): def get_logger(self, **kwargs): logger = super().get_logger() main_author_user = self.main_author_user - main_author_name = '%s %s' % ( + main_author_name = '{} {}'.format( main_author_user.first_name, main_author_user.last_name, ) @@ -1938,7 +1934,7 @@ class Upload(Proc): upload_auth = client.Auth( user=config.keycloak.username, password=config.keycloak.password ) - upload_parameters: Dict[str, Any] = {} + upload_parameters: dict[str, Any] = {} if embargo_length is not None: upload_parameters.update(embargo_length=embargo_length) upload_url = ( @@ -1966,7 +1962,7 @@ class Upload(Proc): @process() def process_example_upload( - self, entry_point_id: str, file_operations: List[Dict[str, Any]] = None + self, entry_point_id: str, file_operations: list[dict[str, Any]] = None ): """Used to initiate the processing of an example upload entry point. This process is only triggered once per example upload, and any further @@ -2012,8 +2008,8 @@ class Upload(Proc): @process() def process_upload( self, - file_operations: List[Dict[str, Any]] = None, - reprocess_settings: Dict[str, Any] = None, + file_operations: list[dict[str, Any]] = None, + reprocess_settings: dict[str, Any] = None, path_filter: str = None, only_updated_files: bool = False, ): @@ -2049,8 +2045,8 @@ class Upload(Proc): def _process_upload_local( self, - file_operations: List[Dict[str, Any]] = None, - reprocess_settings: Dict[str, Any] = None, + file_operations: list[dict[str, Any]] = None, + reprocess_settings: dict[str, Any] = None, path_filter: str = None, only_updated_files: bool = False, ): @@ -2123,7 +2119,7 @@ class Upload(Proc): ) # Process entries, if matched; remove existing entries if unmatched. 
- old_entries_dict: Dict[str, Entry] = { + old_entries_dict: dict[str, Entry] = { e.entry_id: e for e in Entry.objects(upload_id=self.upload_id, mainfile=target_path) } @@ -2137,7 +2133,7 @@ class Upload(Proc): self.upload_id, ) - mainfile_keys_including_main_entry: List[str] = [None] + ( + mainfile_keys_including_main_entry: list[str] = [None] + ( mainfile_keys or [] ) # type: ignore for mainfile_key in mainfile_keys_including_main_entry: @@ -2206,7 +2202,7 @@ class Upload(Proc): @classmethod def _passes_process_filter( - cls, mainfile: str, path_filter: str, updated_files: Set[str] + cls, mainfile: str, path_filter: str, updated_files: set[str] ) -> bool: if path_filter: # Filter by path_filter @@ -2222,8 +2218,8 @@ class Upload(Proc): return True def update_files( - self, file_operations: List[Dict[str, Any]], only_updated_files: bool - ) -> Set[str]: + self, file_operations: list[dict[str, Any]], only_updated_files: bool + ) -> set[str]: """ Performed before the actual parsing/normalizing. It first ensures that there is a folder for the upload in the staging area (if the upload is published, the files @@ -2247,7 +2243,7 @@ class Upload(Proc): StagingUploadFiles(self.upload_id, create=True) staging_upload_files = self.staging_upload_files - updated_files: Set[str] = set() if only_updated_files else None + updated_files: set[str] = set() if only_updated_files else None # Execute the requested file_operations, if any if file_operations: @@ -2307,7 +2303,7 @@ class Upload(Proc): # created stripped POTCAR stripped_path = path + '.stripped' with open( - self.staging_upload_files.raw_file_object(stripped_path).os_path, 'wt' + self.staging_upload_files.raw_file_object(stripped_path).os_path, 'w' ) as stripped_f: stripped_f.write( 'Stripped POTCAR file. Checksum of original file (sha224): %s\n' @@ -2328,8 +2324,8 @@ class Upload(Proc): ) def match_mainfiles( - self, path_filter: str, updated_files: Set[str] - ) -> Iterator[Tuple[str, str, Parser]]: + self, path_filter: str, updated_files: set[str] + ) -> Iterator[tuple[str, str, Parser]]: """ Generator function that matches all files in the upload to all parsers to determine the upload's mainfiles. 
@@ -2345,7 +2341,7 @@ class Upload(Proc): if path_filter: # path_filter provided, just scan this path - scan: List[Tuple[str, bool]] = [(path_filter, True)] + scan: list[tuple[str, bool]] = [(path_filter, True)] elif updated_files is not None: # Set with updated_files provided, only scan these scan = [(path, False) for path in updated_files] @@ -2373,7 +2369,7 @@ class Upload(Proc): staging_upload_files.raw_file_object(path_info.path).os_path ) if parser is not None: - mainfile_keys_including_main_entry: List[str] = [None] + ( + mainfile_keys_including_main_entry: list[str] = [None] + ( mainfile_keys or [] ) # type: ignore for mainfile_key in mainfile_keys_including_main_entry: @@ -2389,7 +2385,7 @@ class Upload(Proc): self, reprocess_settings: Reprocess, path_filter: str = None, - updated_files: Set[str] = None, + updated_files: set[str] = None, ): """ The process step used to identify mainfile/parser combinations among the upload's files, @@ -2468,7 +2464,7 @@ class Upload(Proc): not self.published or reprocess_settings.delete_unmatched_published_entries ): - entries_to_delete: List[str] = list(old_entries) + entries_to_delete: list[str] = list(old_entries) delete_partial_archives_from_mongo(entries_to_delete) for entry_id in entries_to_delete: search.delete_entry( @@ -2510,7 +2506,7 @@ class Upload(Proc): can_create: bool, metadata_handler: MetadataEditRequestHandler, logger, - ) -> Tuple[Entry, bool, MetadataEditRequestHandler]: + ) -> tuple[Entry, bool, MetadataEditRequestHandler]: entry_id = utils.generate_entry_id(self.upload_id, mainfile, mainfile_key) entry = None was_created = False @@ -2550,7 +2546,7 @@ class Upload(Proc): return entry, was_created, metadata_handler def parse_next_level( - self, min_level: int, path_filter: str = None, updated_files: Set[str] = None + self, min_level: int, path_filter: str = None, updated_files: set[str] = None ) -> bool: """ Triggers processing on the next level of parsers (parsers with level >= min_level). @@ -2559,7 +2555,7 @@ class Upload(Proc): try: logger = self.get_logger() next_level: int = None - next_entries: List[Entry] = None + next_entries: list[Entry] = None with utils.timer(logger, 'entries processing called'): # Determine what the next level is and which entries belongs to this level for entry in Entry.objects(upload_id=self.upload_id, mainfile_key=None): @@ -2859,7 +2855,7 @@ class Upload(Proc): ) @contextmanager - def entries_metadata(self) -> Iterator[List[EntryMetadata]]: + def entries_metadata(self) -> Iterator[list[EntryMetadata]]: """ This is the :py:mod:`nomad.datamodel` transformation method to transform processing upload's entries into list of :class:`EntryMetadata` objects. @@ -2874,7 +2870,7 @@ class Upload(Proc): finally: self.upload_files.close() # Because full_entry_metadata reads the archive files. - def entries_mongo_metadata(self) -> List[EntryMetadata]: + def entries_mongo_metadata(self) -> list[EntryMetadata]: """ Returns a list of :class:`EntryMetadata` containing the mongo metadata only, for all entries of this upload. @@ -2885,7 +2881,7 @@ class Upload(Proc): ] @process() - def edit_upload_metadata(self, edit_request_json: Dict[str, Any], user_id: str): + def edit_upload_metadata(self, edit_request_json: dict[str, Any], user_id: str): """ A @process that executes a metadata edit request, restricted to a specific upload, on behalf of the provided user. 
The `edit_request_json` should be a json dict of the @@ -2918,7 +2914,7 @@ class Upload(Proc): # Entry level metadata last_edit_time = datetime.utcnow() entry_mongo_writes = [] - updated_metadata: List[datamodel.EntryMetadata] = [] + updated_metadata: list[datamodel.EntryMetadata] = [] for entry in handler.find_request_entries(self): entry_updates = handler.get_entry_mongo_metadata(self, entry) entry_updates['last_edit_time'] = last_edit_time @@ -2953,14 +2949,14 @@ class Upload(Proc): f'Failed to update ES, there were {failed_es} fails' ) - def entry_ids(self) -> List[str]: + def entry_ids(self) -> list[str]: return [entry.entry_id for entry in Entry.objects(upload_id=self.upload_id)] @process(is_blocking=True) def import_bundle( self, bundle_path: str, - import_settings: Dict[str, Any], + import_settings: dict[str, Any], embargo_length: int = None, ): """ @@ -2994,4 +2990,4 @@ class Upload(Proc): bundle_importer.close() def __str__(self): - return 'upload %s upload_id%s' % (super().__str__(), self.upload_id) + return f'upload {super().__str__()} upload_id{self.upload_id}' diff --git a/nomad/search.py b/nomad/search.py index 686ecf6a7007e909b0df4438c56040aeb7428bbb..329bfe1a09ec23a08c5c3853b4291d041b8b78af 100644 --- a/nomad/search.py +++ b/nomad/search.py @@ -37,17 +37,15 @@ import math from enum import Enum from typing import ( Any, - Callable, Dict, - Generator, - Iterable, - Iterator, List, Optional, Tuple, Union, cast, ) +from collections.abc import Callable +from collections.abc import Generator, Iterable, Iterator import elasticsearch.helpers from elasticsearch.exceptions import RequestError, TransportError @@ -239,10 +237,10 @@ _refresh = refresh def index( - entries: Union[EntryArchive, List[EntryArchive]], + entries: EntryArchive | list[EntryArchive], update_materials: bool = False, refresh: bool = False, -) -> Dict[str, str]: +) -> dict[str, str]: """ Index the given entries based on their archive. Either creates or updates the underlying elasticsearch documents. If an underlying elasticsearch document already exists it @@ -258,7 +256,7 @@ def index( return errors -def index_materials(entries: Union[EntryArchive, List[EntryArchive]], **kwargs): +def index_materials(entries: EntryArchive | list[EntryArchive], **kwargs): """ Index the materials within the given entries based on their archive. The entries have to be indexed first. @@ -370,7 +368,7 @@ _all_author_quantities = [ def _api_to_es_required( required: MetadataRequired, pagination: MetadataPagination, doc_type: DocumentType -) -> Tuple[Optional[List[str]], Optional[List[str]], bool]: +) -> tuple[list[str] | None, list[str] | None, bool]: """ Translates an API include/exclude argument into the appropriate ES arguments. Note that certain fields cannot be excluded from the underlying @@ -454,7 +452,7 @@ def _es_to_api_pagination( # itself: internally ES can perform the sorting on a different # value which is reported under meta.sort. after_value = last.meta.sort[0] - next_page_after_value = '%s:%s' % (after_value, last[doc_type.id_field]) + next_page_after_value = f'{after_value}:{last[doc_type.id_field]}' # For dynamic YAML quantities the field name is normalized to not include # the data type @@ -474,7 +472,7 @@ def _es_to_entry_dict( required: MetadataRequired = None, requires_filtering: bool = False, doc_type=None, -) -> Dict[str, Any]: +) -> dict[str, Any]: """ Translates an ES hit response into a response data object that is expected by the API. 
@@ -584,7 +582,7 @@ def _owner_es_query( query_dict = {(prefix + field): value for field, value in kwargs.items()} return Q(query_type, **query_dict) - def viewers_query(user_id: Optional[str], *, force_groups: bool = False) -> Q: + def viewers_query(user_id: str | None, *, force_groups: bool = False) -> Q: """Filter for user viewers and group viewers. force_groups: If true, add group filter even if user_id is None.""" @@ -674,7 +672,7 @@ def get_definition(path): def validate_quantity( - quantity_name: str, doc_type: DocumentType = None, loc: List[str] = None + quantity_name: str, doc_type: DocumentType = None, loc: list[str] = None ) -> SearchQuantity: """ Validates the given quantity name against the given document type. @@ -1014,7 +1012,7 @@ def _api_to_es_query( def validate_pagination( - pagination: Pagination, doc_type: DocumentType, loc: List[str] = None + pagination: Pagination, doc_type: DocumentType, loc: list[str] = None ): order_quantity = None if pagination.order_by is not None: @@ -1051,8 +1049,8 @@ def validate_pagination( def _api_to_es_sort( - pagination: Pagination, doc_type: DocumentType, loc: List[str] = None -) -> Tuple[Dict[str, Any], SearchQuantity, str]: + pagination: Pagination, doc_type: DocumentType, loc: list[str] = None +) -> tuple[dict[str, Any], SearchQuantity, str]: """ Creates an ES sort based on the API's pagination model. @@ -1063,7 +1061,7 @@ def _api_to_es_sort( """ order_quantity, page_after_value = validate_pagination(pagination, doc_type, loc) - sort: Dict[str, Any] = {} + sort: dict[str, Any] = {} if order_quantity.dynamic: path = order_quantity.get_dynamic_path() postfix = '.keyword' @@ -1256,7 +1254,7 @@ def _api_to_es_aggregation( else: agg.size = 10 - terms_kwargs: Dict[str, Any] = {} + terms_kwargs: dict[str, Any] = {} if agg.include is not None: if isinstance(agg.include, str): terms_kwargs['include'] = f'.*{agg.include}.*' @@ -1272,7 +1270,7 @@ def _api_to_es_aggregation( es_agg = es_aggs.bucket(agg_name, terms) if agg.entries is not None and agg.entries.size > 0: - kwargs: Dict[str, Any] = {} + kwargs: dict[str, Any] = {} if agg.entries.required is not None: if agg.entries.required.include is not None: kwargs.update(_source=dict(includes=agg.entries.required.include)) @@ -1330,7 +1328,7 @@ def _api_to_es_aggregation( f'The quantity {quantity} cannot be used in a histogram aggregation', loc=['aggregations', name, AggType.HISTOGRAM, 'quantity'], ) - params: Dict[str, Any] = {} + params: dict[str, Any] = {} if agg.offset is not None: params['offset'] = agg.offset if agg.extended_bounds is not None: @@ -1385,8 +1383,8 @@ def _es_to_api_aggregation( es_response, name: str, agg: AggregationBase, - histogram_responses: Dict[str, HistogramAggregation], - bucket_values: Dict[str, float], + histogram_responses: dict[str, HistogramAggregation], + bucket_values: dict[str, float], doc_type: DocumentType, ): """ @@ -1574,14 +1572,14 @@ def _es_to_api_aggregation( def _specific_agg( agg: Aggregation, -) -> Union[ - TermsAggregation, - AutoDateHistogramAggregation, - DateHistogramAggregation, - HistogramAggregation, - MinMaxAggregation, - StatisticsAggregation, -]: +) -> ( + TermsAggregation + | AutoDateHistogramAggregation + | DateHistogramAggregation + | HistogramAggregation + | MinMaxAggregation + | StatisticsAggregation +): if agg.terms is not None: return agg.terms @@ -1606,21 +1604,20 @@ def _specific_agg( def _and_clauses(query: Query) -> Generator[Query, None, None]: if isinstance(query, models.And): for clause in query.op: - for query_clause in 
_and_clauses(clause): - yield query_clause + yield from _and_clauses(clause) yield query def _buckets_to_interval( owner: str = 'public', - query: Union[Query, EsQuery] = None, + query: Query | EsQuery = None, pagination: MetadataPagination = None, required: MetadataRequired = None, - aggregations: Dict[str, Aggregation] = {}, + aggregations: dict[str, Aggregation] = {}, user_id: str = None, index: Index = entry_index, -) -> Tuple[Dict[str, Aggregation], Dict[str, HistogramAggregation], Dict[str, float]]: +) -> tuple[dict[str, Aggregation], dict[str, HistogramAggregation], dict[str, float]]: """Converts any histogram aggregations with the number of buckets into a query with an interval. This is required because elasticsearch does not yet support providing only the number of buckets. @@ -1629,9 +1626,9 @@ def _buckets_to_interval( interval cannot be defined in such cases, so we use a dummy value of 1. """ # Get the histograms which are determined by the number of buckets - histogram_requests: Dict[str, HistogramAggregation] = {} - histogram_responses: Dict[str, HistogramAggregation] = {} - bucket_values: Dict[str, float] = {} + histogram_requests: dict[str, HistogramAggregation] = {} + histogram_responses: dict[str, HistogramAggregation] = {} + bucket_values: dict[str, float] = {} aggs = {name: _specific_agg(agg) for name, agg in aggregations.items()} for agg_name, agg in aggs.items(): if isinstance(agg, HistogramAggregation): @@ -1719,10 +1716,10 @@ def _buckets_to_interval( def search( owner: str = 'public', - query: Union[Query, EsQuery] = None, + query: Query | EsQuery = None, pagination: MetadataPagination = None, required: MetadataRequired = None, - aggregations: Dict[str, Aggregation] = {}, + aggregations: dict[str, Aggregation] = {}, user_id: str = None, index: Index = entry_index, ) -> MetadataResponse: @@ -1864,7 +1861,7 @@ def search( # aggregations if len(aggregations) > 0: more_response_data['aggregations'] = cast( - Dict[str, Any], + dict[str, Any], { name: _es_to_api_aggregation( es_response, @@ -1900,13 +1897,13 @@ def search( def search_iterator( owner: str = 'public', - query: Union[Query, EsQuery] = None, + query: Query | EsQuery = None, order_by: str = 'entry_id', required: MetadataRequired = None, - aggregations: Dict[str, Aggregation] = {}, + aggregations: dict[str, Aggregation] = {}, user_id: str = None, index: Index = entry_index, -) -> Iterator[Dict[str, Any]]: +) -> Iterator[dict[str, Any]]: """ Works like :func:`search`, but returns an iterator for iterating over the results. Consequently, you cannot specify `pagination`, only `order_buy`. @@ -1927,8 +1924,7 @@ def search_iterator( page_after_value = response.pagination.next_page_after_value - for result in response.data: - yield result + yield from response.data if page_after_value is None or len(response.data) == 0: break diff --git a/nomad/utils/__init__.py b/nomad/utils/__init__.py index e5887a47209e2f5222afa0a0fe9982bbf6506f1c..a580d088f98d59cb7e34c246cb29de1f56f13131 100644 --- a/nomad/utils/__init__.py +++ b/nomad/utils/__init__.py @@ -38,7 +38,8 @@ Depending on the configuration all logs will also be send to a central logstash. .. 
autofunc::nomad.utils.strip """ -from typing import List, Iterable, Union, Any, Dict, Optional +from typing import List, Union, Any, Dict, Optional +from collections.abc import Iterable from collections import OrderedDict from functools import reduce from itertools import takewhile @@ -123,10 +124,10 @@ class ClassicLogger: all_kwargs = dict(self.kwargs) all_kwargs.update(**kwargs) - message = '%s (%s)' % ( + message = '{} ({})'.format( event, ', '.join( - ['%s=%s' % (str(key), str(value)) for key, value in all_kwargs.items()] + [f'{str(key)}={str(value)}' for key, value in all_kwargs.items()] ), ) method(message) @@ -322,10 +323,10 @@ def timer( class archive: @staticmethod def create(upload_id: str, entry_id: str) -> str: - return '%s/%s' % (upload_id, entry_id) + return f'{upload_id}/{entry_id}' @staticmethod - def items(archive_id: str) -> List[str]: + def items(archive_id: str) -> list[str]: return archive_id.split('/') @staticmethod @@ -486,7 +487,7 @@ class RestrictedDict(OrderedDict): if not self._lazy: # Check that only the defined keys are used if key not in self._mandatory_keys and key not in self._optional_keys: - raise KeyError("The key '{}' is not allowed.".format(key)) + raise KeyError(f"The key '{key}' is not allowed.") # Check that forbidden values are not used. try: @@ -495,7 +496,7 @@ class RestrictedDict(OrderedDict): pass # Unhashable value will not match else: if match: - raise ValueError("The value '{}' is not allowed.".format(key)) + raise ValueError(f"The value '{key}' is not allowed.") super().__setitem__(key, value) @@ -503,12 +504,12 @@ class RestrictedDict(OrderedDict): # Check that only the defined keys are used for key in self.keys(): if key not in self._mandatory_keys and key not in self._optional_keys: - raise KeyError("The key '{}' is not allowed.".format(key)) + raise KeyError(f"The key '{key}' is not allowed.") # Check that all mandatory values are all defined for key in self._mandatory_keys: if key not in self: - raise KeyError("The mandatory key '{}' is not present.".format(key)) + raise KeyError(f"The mandatory key '{key}' is not present.") # Check that forbidden values are not used. for key, value in self.items(): @@ -695,7 +696,7 @@ def rebuild_dict(src: dict, separator: str = '.'): result[index], ) - ret: Dict[str, Any] = {} + ret: dict[str, Any] = {} for key, value in src.items(): helper_dict(key, value, ret) @@ -835,7 +836,7 @@ def slugify(value): return re.sub(r'[-\s]+', '-', value).strip('-_') -def query_list_to_dict(path_list: List[Union[str, int]], value: Any) -> Dict[str, Any]: +def query_list_to_dict(path_list: list[str | int], value: Any) -> dict[str, Any]: """Transforms a list of path fragments into a dictionary query. E.g. the list ['run', 0, 'system', 2, 'atoms'] @@ -858,7 +859,7 @@ def query_list_to_dict(path_list: List[Union[str, int]], value: Any) -> Dict[str A nested dictionary representing the query path. """ - returned: Dict[str, Any] = {} + returned: dict[str, Any] = {} current = returned n_items = len(path_list) i = 0 @@ -876,7 +877,7 @@ def query_list_to_dict(path_list: List[Union[str, int]], value: Any) -> Dict[str return returned -def traverse_reversed(archive: Any, path: List[str]) -> Any: +def traverse_reversed(archive: Any, path: list[str]) -> Any: """Traverses the given metainfo path in reverse order. Useful in finding the latest reported section or value. 
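Several hunks in this change, including the one that follows, replace explicit delegation loops with `yield from`. A minimal standalone sketch of why the two forms are equivalent (the generators below are hypothetical, not NOMAD code):

    def _inner():
        yield 1
        yield 2

    def old_style():
        # pattern being removed: re-yield every item by hand
        for item in _inner():
            yield item

    def new_style():
        # pattern being introduced: delegate directly to the sub-generator
        yield from _inner()

    assert list(old_style()) == list(new_style()) == [1, 2]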
@@ -898,21 +899,19 @@ def traverse_reversed(archive: Any, path: List[str]) -> Any: if i == len(path) - 1: yield section else: - for s in traverse(section, path, i + 1): - yield s + yield from traverse(section, path, i + 1) else: if i == len(path) - 1: yield sections else: - for s in traverse(sections, path, i + 1): - yield s + yield from traverse(sections, path, i + 1) for t in traverse(archive, path, 0): if t is not None: yield t -def extract_section(root: Any, path: List[str], full_list: bool = False): +def extract_section(root: Any, path: list[str], full_list: bool = False): """Extracts a section from source following the path and the last elements of the section lists. If full_list is True, the resolved section gives the full list instead of the last element. diff --git a/nomad/utils/exampledata.py b/nomad/utils/exampledata.py index 91112184a659732495729790443f1d9805bfb716..31ac1586733439f0093ebabc1505cf4b08898cb9 100644 --- a/nomad/utils/exampledata.py +++ b/nomad/utils/exampledata.py @@ -42,10 +42,10 @@ class ExampleData: """ def __init__(self, **kwargs): - self.upload_entries: Dict[str, List[str]] = dict() - self.uploads: Dict[str, Dict[str, Any]] = dict() - self.entries: Dict[str, EntryMetadata] = dict() - self.archives: Dict[str, EntryArchive] = dict() + self.upload_entries: dict[str, list[str]] = dict() + self.uploads: dict[str, dict[str, Any]] = dict() + self.entries: dict[str, EntryMetadata] = dict() + self.archives: dict[str, EntryArchive] = dict() self.entry_defaults = kwargs self._entry_id_counter = 1 @@ -139,10 +139,10 @@ class ExampleData: def create_entry_from_file( self, mainfile: str, - entry_archive: Optional[EntryArchive] = None, - entry_id: Optional[str] = None, - upload_id: Optional[str] = None, - parser_name: Optional[str] = None, + entry_archive: EntryArchive | None = None, + entry_id: str | None = None, + upload_id: str | None = None, + parser_name: str | None = None, ): """Creates an entry from a mainfile which then gets parsed and normalized.""" from nomad.parsing import parsers @@ -232,7 +232,7 @@ class ExampleData: upload_id: str = None, material_id: str = None, mainfile: str = None, - results: Union[Results, dict] = None, + results: Results | dict = None, archive: dict = None, **kwargs, ) -> EntryArchive: @@ -351,7 +351,7 @@ class ExampleData: id: int, h: int, o: int, - extra: List[str], + extra: list[str], periodicity: int, optimade: bool = True, metadata: dict = None, diff --git a/nomad/utils/json_transformer.py b/nomad/utils/json_transformer.py index 3d3aa094931473ea7cf244bb7148d11b82705778..9310a8894fae2a9a46ad97781187d84c8aea839c 100644 --- a/nomad/utils/json_transformer.py +++ b/nomad/utils/json_transformer.py @@ -28,7 +28,7 @@ class Transformer: self.mapping_dict = mapping_dict @staticmethod - def parse_path(path: str) -> list[Union[str, int]]: + def parse_path(path: str) -> list[str | int]: """ Parses a JMESPath-like path into a list of keys and indices. @@ -234,7 +234,7 @@ class Transformer: return target def dict_to_dict( - self, source: dict[str, Any], rules: 'Rules', target: Optional[Any] = None + self, source: dict[str, Any], rules: 'Rules', target: Any | None = None ) -> Any: """ Applies all rules in a Rules object to transform the source dictionary into the target. @@ -257,7 +257,7 @@ class Transformer: self, source_data: dict[str, Any], mapping_name: str = None, - target_data: Optional[Any] = None, + target_data: Any | None = None, ) -> Any: """ Transforms the source data into the target data based on the specified mapping. 
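Signatures such as `parse_path(path: str) -> list[str | int]` above are evaluated when the function is defined, so the new spellings are only safe on Python 3.10+, matching the existing `requires-python = ">=3.10"` pin and the `--py310-plus` flag used in the script below. A minimal standalone sketch (illustration only, not part of the patch) showing that these annotations remain introspectable:

    from typing import get_args, get_origin

    annotation = list[str | int]          # same shape as the parse_path return type
    assert get_origin(annotation) is list
    assert get_args(get_args(annotation)[0]) == (str, int)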
diff --git a/nomad/utils/structlogging.py b/nomad/utils/structlogging.py index c31d9e4d552426affdca591e58615908a2039ab9..879f4a208a39f6b07a4abe59fc719085efb0ed8e 100644 --- a/nomad/utils/structlogging.py +++ b/nomad/utils/structlogging.py @@ -147,7 +147,7 @@ class LogstashFormatter(logstash.formatter.LogstashFormatterBase): ]: key = 'nomad.%s' % key else: - key = '%s.%s' % (record.name, key) + key = f'{record.name}.{key}' message[key] = value else: @@ -241,9 +241,7 @@ class ConsoleFormatter(LogstashFormatter): else: print_key = key if not cls.short_format or print_key not in ['deployment', 'service']: - out.write( - '\n - %s: %s' % (print_key, str(message_dict.get(key, None))) - ) + out.write(f'\n - {print_key}: {str(message_dict.get(key, None))}') return out.getvalue() diff --git a/pyproject.toml b/pyproject.toml index 407b2c516a1c8d9ebf7108d9a01681856bfe405c..2ec3f67eb796cfa5da9700100d4927529f543bc2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ authors = [ ] dynamic = ["version"] license = { text = "Apache-2.0" } -requires-python = ">=3.10" +requires-python = ">=3.10" # remember to update scripts/pyupgrade.sh dependencies = [ 'aniso8601>=7.0.0', diff --git a/scripts/pyupgrade.sh b/scripts/pyupgrade.sh new file mode 100644 index 0000000000000000000000000000000000000000..f60515a860e24595820277ebcf0386e392612453 --- /dev/null +++ b/scripts/pyupgrade.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# Upgrade syntax of all python files under nomad/ folder using pyupgrade +# pyupgrade is not installed by default +# install it using `pip install pyupgrade` or `uv pip install pyupgrade` +# using modern syntax to maximise maintainability and readability +# it is also possible to use pyupgrade as a commit hook + +if ! command -v pyupgrade &> /dev/null; then + echo "Error: pyupgrade is not installed. Please install it using 'pip install pyupgrade'." + exit 1 +fi + +# Navigate to the parent folder based on script location +cd "$(dirname "$0")/.." || exit 1 + +# Find all Python files in the "nomad" folder and apply pyupgrade +find nomad -type f -name "*.py" | while read -r file; do + pyupgrade --py310-plus "$file" +done
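The new scripts/pyupgrade.sh can be invoked as `bash scripts/pyupgrade.sh` from anywhere (it changes into the repository root itself) and applies `pyupgrade --py310-plus` to every Python file under nomad/. A minimal before/after sketch of the kind of modernization this change adopts; the function below is hypothetical and only illustrates the pattern:

    # Before (old spellings removed throughout this change):
    #     from typing import List, Optional
    #     def make_archive_ids(upload_id: Optional[str]) -> List[str]:
    #         return ['%s/%s' % (upload_id, n) for n in range(2)]

    # After (builtin generics, PEP 604 unions, f-strings):
    def make_archive_ids(upload_id: str | None) -> list[str]:
        return [f'{upload_id}/{n}' for n in range(2)]

    assert make_archive_ids('upload_1') == ['upload_1/0', 'upload_1/1']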