Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
nomad-lab
nomad-FAIR
Commits
6ad17658
Commit
6ad17658
authored
Nov 03, 2021
by
Markus Scheidgen
Browse files
Added exclude from search option to aggregations.
#573
,
#575
parent
e029faf8
Pipeline
#114491
passed with stages
in 29 minutes and 21 seconds
Changes
5
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
nomad/app/v1/models.py
View file @
6ad17658
...
...
@@ -788,6 +788,19 @@ class QuantityAggregation(AggregationBase):
The manatory name of the quantity for the aggregation. Aggregations
can only be computed for those search metadata that have discrete values;
an aggregation buckets entries that have the same value for this quantity.'''
))
exclude_from_search
:
bool
=
Field
(
False
,
description
=
strip
(
'''
If set to true, top-level search criteria involving the aggregation quantity, will not
be applied for this aggregation. Therefore, the aggregation will return all
values for the quantity, even if the possible values where filtered by the query.
There are two limitations. This is only supported with queries that start with a
dictionary. It will not work for queries with a boolean operator. It can only
exclude top-level criteria at the root of the query dictionary. Nested criteria,
e.g. within complex and/or constructs, cannot be considered. Using this might also
prohibit pagination with page_after_value on aggregations in the same request.
'''
)
)
class
BucketAggregation
(
QuantityAggregation
):
...
...
nomad/search.py
View file @
6ad17658
...
...
@@ -437,9 +437,20 @@ def validate_quantity(
return
quantity
def
_create_es_must
(
queries
:
Dict
[
str
,
EsQuery
]):
# dictionary is like an "and" of all items in the dict
if
len
(
queries
)
==
0
:
return
Q
()
if
len
(
queries
)
==
1
:
return
list
(
queries
.
values
())[
0
]
return
Q
(
'bool'
,
must
=
list
(
queries
.
values
()))
def
validate_api_query
(
query
:
Query
,
doc_type
:
DocumentType
,
owner_query
:
EsQuery
,
prefix
:
str
=
None
)
->
EsQuery
:
prefix
:
str
=
None
,
results_dict
:
Dict
[
str
,
EsQuery
]
=
None
)
->
EsQuery
:
'''
Creates an ES query based on the API's query model. This needs to be a normalized
query expression with explicit objects for logical, set, and comparison operators.
...
...
@@ -460,6 +471,11 @@ def validate_api_query(
materials queries.
prefix:
An optional prefix that is added to all quantity names. Used for recursion.
results_dict:
If an empty dictionary is given and the query is a mapping, the top-level
criteria from this mapping will be added as individual es queries. The
keys will be the mapping keys and values the respective es queries. A logical
and (or es "must") would result in the overall resulting es query.
Returns:
A elasticsearch dsl query object.
...
...
@@ -564,11 +580,20 @@ def validate_api_query(
return
Q
()
if
len
(
query
)
==
1
:
key
=
next
(
iter
(
query
))
return
validate_criteria
(
key
,
query
[
key
])
name
=
next
(
iter
(
query
))
es_criteria_query
=
validate_criteria
(
name
,
query
[
name
])
if
results_dict
is
not
None
:
results_dict
[
name
]
=
es_criteria_query
return
es_criteria_query
return
Q
(
'bool'
,
must
=
[
validate_criteria
(
name
,
value
)
for
name
,
value
in
query
.
items
()])
es_criteria_queries
=
[]
for
name
,
value
in
query
.
items
():
es_criteria_query
=
validate_criteria
(
name
,
value
)
es_criteria_queries
.
append
(
es_criteria_query
)
if
results_dict
is
not
None
:
results_dict
[
name
]
=
es_criteria_query
return
Q
(
'bool'
,
must
=
es_criteria_queries
)
raise
NotImplementedError
()
...
...
@@ -595,7 +620,8 @@ def validate_pagination(pagination: Pagination, doc_type: DocumentType, loc: Lis
def
_api_to_es_aggregation
(
es_search
:
Search
,
name
:
str
,
agg
:
AggregationBase
,
doc_type
:
DocumentType
)
->
A
:
es_search
:
Search
,
name
:
str
,
agg
:
AggregationBase
,
doc_type
:
DocumentType
,
post_agg_queries
:
Dict
[
str
,
EsQuery
])
->
A
:
'''
Creates an ES aggregation based on the API's aggregation model.
'''
...
...
@@ -603,6 +629,12 @@ def _api_to_es_aggregation(
agg_name
=
f
'agg:
{
name
}
'
es_aggs
=
es_search
.
aggs
if
post_agg_queries
:
filter
=
post_agg_queries
if
isinstance
(
agg
,
QuantityAggregation
)
and
agg
.
exclude_from_search
:
filter
=
{
name
:
query
for
name
,
query
in
post_agg_queries
.
items
()
if
name
!=
agg
.
quantity
}
es_aggs
=
es_aggs
.
bucket
(
f
'
{
agg_name
}
:filtered'
,
A
(
'filter'
,
filter
=
_create_es_must
(
filter
)))
if
isinstance
(
agg
,
StatisticsAggregation
):
for
metric_name
in
agg
.
metrics
:
metrics
=
doc_type
.
metrics
...
...
@@ -620,13 +652,11 @@ def _api_to_es_aggregation(
return
agg
=
cast
(
QuantityAggregation
,
agg
)
longest_nested_key
=
None
quantity
=
validate_quantity
(
agg
.
quantity
,
doc_type
=
doc_type
,
loc
=
[
'aggregation'
,
'quantity'
])
for
nested_key
in
doc_type
.
nested_object_keys
:
if
agg
.
quantity
.
startswith
(
nested_key
):
es_aggs
=
es_
search
.
aggs
.
bucket
(
'nested_agg:%s'
%
name
,
'nested'
,
path
=
nested_key
)
es_aggs
=
es_aggs
.
bucket
(
'nested_agg:%s'
%
name
,
'nested'
,
path
=
nested_key
)
longest_nested_key
=
nested_key
es_agg
=
None
...
...
@@ -674,6 +704,11 @@ def _api_to_es_aggregation(
}
if
page_after_value
is
not
None
:
if
post_agg_queries
:
raise
QueryValidationError
(
f
'aggregation page_after_value cannot be used with exclude_from_search in the same request'
,
loc
=
[
'aggregations'
,
name
,
'terms'
,
'pagination'
,
'page_after_value'
])
if
order_quantity
is
None
:
composite
[
'after'
]
=
{
name
:
page_after_value
}
else
:
...
...
@@ -770,6 +805,11 @@ def _es_to_api_aggregation(
the given aggregation.
'''
es_aggs
=
es_response
.
aggs
filtered_agg_name
=
f
'agg:
{
name
}
:filtered'
if
filtered_agg_name
in
es_response
.
aggs
:
es_aggs
=
es_aggs
[
f
'agg:
{
name
}
:filtered'
]
aggregation_dict
=
agg
.
dict
(
by_alias
=
True
)
if
isinstance
(
agg
,
StatisticsAggregation
):
...
...
@@ -785,7 +825,7 @@ def _es_to_api_aggregation(
longest_nested_key
=
None
for
nested_key
in
doc_type
.
nested_object_keys
:
if
agg
.
quantity
.
startswith
(
nested_key
):
es_aggs
=
es_
response
.
aggs
[
f
'nested_agg:
{
name
}
'
]
es_aggs
=
es_aggs
[
f
'nested_agg:
{
name
}
'
]
longest_nested_key
=
nested_key
has_no_pagination
=
getattr
(
agg
,
'pagination'
,
None
)
is
None
...
...
@@ -907,22 +947,24 @@ def search(
doc_type
=
index
.
doc_type
# owner
and query
# owner
owner_query
=
_owner_es_query
(
owner
=
owner
,
user_id
=
user_id
,
doc_type
=
doc_type
)
# query
if
query
is
None
:
query
=
{}
es_query_dict
:
Dict
[
str
,
EsQuery
]
=
{}
if
isinstance
(
query
,
EsQuery
):
es_query
=
cast
(
EsQuery
,
query
)
else
:
es_query
=
validate_api_query
(
cast
(
Query
,
query
),
doc_type
=
doc_type
,
owner_query
=
owner_query
)
cast
(
Query
,
query
),
doc_type
=
doc_type
,
owner_query
=
owner_query
,
results_dict
=
es_query_dict
)
if
doc_type
!=
entry_type
:
es_query
&=
Q
(
'nested'
,
path
=
'entries'
,
query
=
owner_query
)
else
:
es_query
&=
owner_query
owner_query
=
Q
(
'nested'
,
path
=
'entries'
,
query
=
owner_query
)
es_query
&=
owner_query
# pagination
if
pagination
is
None
:
...
...
@@ -933,7 +975,6 @@ def search(
search
=
Search
(
index
=
index
.
index_name
)
search
=
search
.
query
(
es_query
)
# TODO this depends on doc_type
if
pagination
.
order_by
is
None
:
pagination
.
order_by
=
doc_type
.
id_field
...
...
@@ -974,15 +1015,44 @@ def search(
search
=
search
.
source
(
includes
=
required
.
include
,
excludes
=
required
.
exclude
)
# aggregations
for
name
,
agg
in
aggregations
.
items
():
_api_to_es_aggregation
(
search
,
name
,
_specific_agg
(
agg
),
doc_type
=
doc_type
)
aggs
=
[(
name
,
_specific_agg
(
agg
))
for
name
,
agg
in
aggregations
.
items
()]
post_agg_queries
:
Dict
[
str
,
EsQuery
]
=
{}
excluded_agg_quantities
=
{
agg
.
quantity
for
_
,
agg
in
aggs
if
isinstance
(
agg
,
QuantityAggregation
)
and
agg
.
exclude_from_search
}
if
len
(
excluded_agg_quantities
)
>
0
:
if
not
isinstance
(
query
,
dict
):
# "exclude_from_search" only work for toplevel mapping queries
raise
QueryValidationError
(
f
'the query has to be a dictionary if there is an aggregation with exclude_from_search'
,
loc
=
[
'query'
])
pre_agg_queries
=
{
quantity
:
es_query
for
quantity
,
es_query
in
es_query_dict
.
items
()
if
quantity
not
in
excluded_agg_quantities
}
post_agg_queries
=
{
quantity
:
es_query
for
quantity
,
es_query
in
es_query_dict
.
items
()
if
quantity
in
excluded_agg_quantities
}
search
=
search
.
post_filter
(
_create_es_must
(
post_agg_queries
))
search
=
search
.
query
(
_create_es_must
(
pre_agg_queries
)
&
owner_query
)
else
:
search
=
search
.
query
(
es_query
)
# pylint: disable=no-member
for
name
,
agg
in
aggs
:
_api_to_es_aggregation
(
search
,
name
,
agg
,
doc_type
=
doc_type
,
post_agg_queries
=
post_agg_queries
)
# execute
try
:
es_response
=
search
.
execute
()
except
RequestError
as
e
:
raise
SearchError
(
e
)
more_response_data
=
{}
# pagination
...
...
tests/app/v1/routers/common.py
View file @
6ad17658
...
...
@@ -333,6 +333,134 @@ def aggregation_test_parameters(entity_id: str, material_prefix: str, entry_pref
]
def
aggregation_exclude_from_search_test_parameters
(
entry_prefix
:
str
,
total_per_entity
:
int
,
total
:
int
):
entry_id
=
f
'
{
entry_prefix
}
entry_id'
upload_id
=
f
'
{
entry_prefix
}
upload_id'
program_name
=
f
'
{
entry_prefix
}
results.method.simulation.program_name'
return
[
pytest
.
param
(
{
f
'
{
entry_id
}
:any'
:
[
'id_01'
]
},
[
{
'exclude_from_search'
:
True
,
'quantity'
:
entry_id
}
],
[
10
],
1
,
200
,
id
=
'exclude'
),
pytest
.
param
(
{
f
'
{
entry_id
}
:any'
:
[
'id_01'
]
},
[
{
'exclude_from_search'
:
False
,
'quantity'
:
entry_id
}
],
[
total_per_entity
],
1
,
200
,
id
=
'dont-exclude'
),
pytest
.
param
(
{
f
'
{
entry_id
}
:any'
:
[
'id_01'
],
upload_id
:
'id_published'
,
program_name
:
'VASP'
},
[
{
'exclude_from_search'
:
True
,
'quantity'
:
entry_id
},
{
'exclude_from_search'
:
True
,
'quantity'
:
upload_id
}
],
[
10
,
1
],
1
,
200
,
id
=
'two-aggs'
),
pytest
.
param
(
{
f
'
{
entry_id
}
:any'
:
[
'id_01'
]
},
[
{
'exclude_from_search'
:
True
,
'quantity'
:
entry_id
},
{
'exclude_from_search'
:
False
,
'quantity'
:
entry_id
}
],
[
10
,
total_per_entity
],
1
,
200
,
id
=
'two-aggs-same-quantity'
),
pytest
.
param
(
{},
[
{
'exclude_from_search'
:
True
,
'quantity'
:
entry_id
}
],
[
10
],
total
,
200
,
id
=
'not-in-query'
),
pytest
.
param
(
{},
[
{
'exclude_from_search'
:
True
,
'quantity'
:
entry_id
,
'pagination'
:
{
'page_size'
:
20
}
}
],
[
20
],
total
,
200
,
id
=
'with-pagination'
),
pytest
.
param
(
{
'or'
:
[{
entry_id
:
'id_01'
}]
},
[
{
'exclude_from_search'
:
True
,
'quantity'
:
entry_id
}
],
[
0
],
0
,
422
,
id
=
'non-dict-query'
),
pytest
.
param
(
{
f
'
{
entry_id
}
:any'
:
[
'id_01'
]
},
[
{
'exclude_from_search'
:
True
,
'quantity'
:
entry_id
},
{
'quantity'
:
entry_id
,
'pagination'
:
{
'page_after_value'
:
'id_published'
}
}
],
[
0
],
0
,
422
,
id
=
'with-page-after-value'
)
]
def
assert_response
(
response
,
status_code
=
None
):
''' General assertions for status_code and error messages '''
if
status_code
and
response
.
status_code
!=
status_code
:
...
...
tests/app/v1/routers/test_entries.py
View file @
6ad17658
...
...
@@ -29,8 +29,8 @@ from tests.utils import ExampleData
from
tests.test_files
import
example_mainfile_contents
,
append_raw_files
# pylint: disable=unused-import
from
.common
import
(
a
ssert_response
,
assert_base_metadata
_response
,
assert_metadata_response
,
assert_required
,
assert_aggregations
,
assert_pagination
,
a
ggregation_exclude_from_search_test_parameters
,
assert
_response
,
assert_
base_
metadata_response
,
assert_metadata_response
,
assert_required
,
assert_aggregations
,
assert_pagination
,
perform_metadata_test
,
post_query_test_parameters
,
get_query_test_parameters
,
perform_owner_test
,
owner_test_parameters
,
pagination_test_parameters
,
aggregation_test_parameters
)
...
...
@@ -368,6 +368,27 @@ def test_entries_aggregations(client, data, test_user_auth, aggregation, total,
default_key
=
'entry_id'
)
@
pytest
.
mark
.
parametrize
(
'query,aggs,agg_lengths,total,status_code'
,
aggregation_exclude_from_search_test_parameters
(
entry_prefix
=
''
,
total_per_entity
=
1
,
total
=
23
))
def
test_entries_aggregations_exclude_from_search
(
client
,
data
,
query
,
aggs
,
agg_lengths
,
total
,
status_code
):
aggs
=
{
f
'agg_
{
i
}
'
:
{
'terms'
:
agg
}
for
i
,
agg
in
enumerate
(
aggs
)}
response_json
=
perform_entries_metadata_test
(
client
,
owner
=
'visible'
,
query
=
query
,
aggregations
=
aggs
,
pagination
=
dict
(
page_size
=
0
),
status_code
=
status_code
,
http_method
=
'post'
)
if
response_json
is
None
:
return
assert
response_json
[
'pagination'
][
'total'
]
==
total
for
i
,
length
in
enumerate
(
agg_lengths
):
response_agg
=
response_json
[
'aggregations'
][
f
'agg_
{
i
}
'
][
'terms'
]
assert
len
(
response_agg
[
'data'
])
==
length
@
pytest
.
mark
.
parametrize
(
'required, status_code'
,
[
pytest
.
param
({
'include'
:
[
'entry_id'
,
'upload_id'
]},
200
,
id
=
'include'
),
pytest
.
param
({
'include'
:
[
'results.*'
,
'upload_id'
]},
200
,
id
=
'include-section'
),
...
...
tests/app/v1/routers/test_materials.py
View file @
6ad17658
...
...
@@ -24,7 +24,7 @@ from nomad.metainfo.elasticsearch_extension import material_entry_type
from
tests.test_files
import
example_mainfile_contents
# pylint: disable=unused-import
from
.common
import
(
assert_pagination
,
assert_metadata_response
,
assert_required
,
assert_aggregations
,
aggregation_exclude_from_search_test_parameters
,
assert_pagination
,
assert_metadata_response
,
assert_required
,
assert_aggregations
,
perform_metadata_test
,
perform_owner_test
,
owner_test_parameters
,
post_query_test_parameters
,
get_query_test_parameters
,
pagination_test_parameters
,
aggregation_test_parameters
)
...
...
@@ -74,6 +74,27 @@ def test_materials_aggregations(client, data, test_user_auth, aggregation, total
default_key
=
'material_id'
)
@
pytest
.
mark
.
parametrize
(
'query,aggs,agg_lengths,total,status_code'
,
aggregation_exclude_from_search_test_parameters
(
entry_prefix
=
'entries.'
,
total_per_entity
=
3
,
total
=
6
))
def
test_materials_aggregations_exclude_from_search
(
client
,
data
,
query
,
aggs
,
agg_lengths
,
total
,
status_code
):
aggs
=
{
f
'agg_
{
i
}
'
:
{
'terms'
:
agg
}
for
i
,
agg
in
enumerate
(
aggs
)}
response_json
=
perform_materials_metadata_test
(
client
,
owner
=
'visible'
,
query
=
query
,
aggregations
=
aggs
,
pagination
=
dict
(
page_size
=
0
),
status_code
=
status_code
,
http_method
=
'post'
)
if
response_json
is
None
:
return
assert
response_json
[
'pagination'
][
'total'
]
==
total
for
i
,
length
in
enumerate
(
agg_lengths
):
response_agg
=
response_json
[
'aggregations'
][
f
'agg_
{
i
}
'
][
'terms'
]
assert
len
(
response_agg
[
'data'
])
==
length
@
pytest
.
mark
.
parametrize
(
'required, status_code'
,
[
pytest
.
param
({
'include'
:
[
'material_id'
,
program_name
]},
200
,
id
=
'include'
),
pytest
.
param
({
'include'
:
[
'entries.*'
,
program_name
]},
200
,
id
=
'include-section'
),
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment