Commit 626af9bf authored by Markus Scheidgen

Refactored zipfile streaming. #457

parent c72c9ccc
Pipeline #89619 passed with stages in 25 minutes and 43 seconds
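At a high level, this commit moves manifest collection out of the individual API handlers and into streamed_zipfile itself: every yielded file now carries an optional manifest part, and streamed_zipfile merges the parts and appends a single manifest.json entry to the zip. A minimal sketch of the changed tuple contract (comments are descriptive, not part of the diff):

    # before: 4-tuples; callers accumulated and yielded manifest.json themselves
    # (zipped_filename, file_id, open_io, file_size)
    #
    # after: 5-tuples plus a manifest kwarg on streamed_zipfile
    # (zipped_filename, file_id, manifest_part, open_io, file_size)
    # manifest_part is a dict (or None) that streamed_zipfile merges into the
    # zip's manifest.json; callers pass manifest=dict() to request a manifest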
@@ -25,7 +25,6 @@ from typing import Dict, Any
 from io import BytesIO
 from flask import request, g
 from flask_restplus import abort, Resource, fields
-import json
 import orjson
 import urllib.parse
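archive.py keeps orjson for serializing archive entries; the stdlib json import is no longer needed here because manifest serialization moves to common.py. For reference, a small sketch of the orjson call pattern used in this file (payload is illustrative):

    import orjson

    data = {'section_run': {'program_name': 'VASP'}}  # illustrative payload
    # orjson.dumps returns bytes; OPT_INDENT_2 pretty-prints and
    # OPT_NON_STR_KEYS allows non-string dict keys, matching this diff
    content = orjson.dumps(data, option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS)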
@@ -163,6 +162,15 @@ class ArchiveDownloadResource(Resource):
                 for entry in calcs:
                     upload_id = entry['upload_id']
                     calc_id = entry['calc_id']
+
+                    manifest = {
+                        calc_id: {
+                            key: entry[key]
+                            for key in ArchiveDownloadResource.manifest_quantities
+                            if entry.get(key) is not None
+                        }
+                    }
+
                     if upload_files is None or upload_files.upload_id != upload_id:
                         if upload_files is not None:
                             upload_files.close()
@@ -182,40 +190,21 @@ class ArchiveDownloadResource(Resource):
                         option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS))
 
                     yield (
-                        '%s.%s' % (calc_id, 'json'), calc_id,
+                        '%s.%s' % (calc_id, 'json'), calc_id, manifest,
                         lambda calc_id: f,
                         lambda calc_id: f.getbuffer().nbytes)
 
-                    manifest[calc_id] = {
-                        key: entry[key]
-                        for key in ArchiveDownloadResource.manifest_quantities
-                        if entry.get(key) is not None
-                    }
-
                 if upload_files is not None:
                     upload_files.close()
 
-                try:
-                    manifest_contents = json.dumps(manifest).encode('utf-8')
-                except Exception as e:
-                    manifest_contents = json.dumps(
-                        dict(error='Could not create the manifest: %s' % (e))).encode('utf-8')
-                    common.logger.error(
-                        'could not create raw query manifest', exc_info=e)
-
-                yield (
-                    'manifest.json', 'manifest',
-                    lambda *args: BytesIO(manifest_contents),
-                    lambda *args: len(manifest_contents))
-
             except Exception as e:
-                common.logger.warning(
+                common.logger.error(
                     'unexpected error while streaming raw data from query',
                     exc_info=e,
                     query=urllib.parse.urlencode(request.args, doseq=True))
 
         return streamed_zipfile(
-            generator(), zipfile_name='nomad_archive.zip', compress=compress)
+            generator(), zipfile_name='nomad_archive.zip', compress=compress, manifest=dict())
 
 _archive_query_model = api.inherit('ArchiveSearch', search_model, {
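With this change, each archive entry contributes its manifest data at the moment it is yielded rather than into a dict that the handler must keep alive for the whole stream. Assuming manifest_quantities covers fields such as upload_id and calc_id (illustrative; the actual quantity list is defined elsewhere in the resource), a single manifest part would look like:

    # hypothetical values for one entry
    manifest_part = {
        'calc-id-123': {
            'upload_id': 'upload-id-456',
            'calc_id': 'calc-id-123',
        }
    }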
@@ -34,7 +34,7 @@ import os.path
 import gzip
 from functools import wraps
 
-from nomad import search, config, datamodel
+from nomad import search, config, datamodel, utils
 from nomad.app.optimade import filterparser
 from nomad.app.common import RFC3339DateTime, rfc3339DateTime
 from nomad.files import Restricted
@@ -282,19 +282,22 @@ def upload_route(ns, prefix: str = ''):
 def streamed_zipfile(
-        files: Iterable[Tuple[str, str, Callable[[str], IO], Callable[[str], int]]],
-        zipfile_name: str, compress: bool = False):
+        files: Iterable[Tuple[str, str, dict, Callable[[str], IO], Callable[[str], int]]],
+        zipfile_name: str, compress: bool = False, manifest: dict = None):
     '''
     Creates a response that streams the given files as a streamed zip file. Ensures that
     each given file is only streamed once, based on its filename in the resulting zipfile.
 
     Arguments:
         files: An iterable of tuples with the filename to be used in the resulting zipfile,
-            an file id within the upload, a callable that gives an binary IO object for the
-            file id, and a callable that gives the file size for the file id.
+            a file id within the upload, an optional manifest, a callable that gives a
+            binary IO object for the file id, and a callable that gives the file size for
+            the file id.
         zipfile_name: A name that will be used in the content disposition attachment
             used as an HTTP response.
         compress: Uses compression. Default is stored only.
+        manifest: The dict contents of the manifest file. Will be extended if the files
+            provide manifest information.
     '''
     streamed_files: Set[str] = set()
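For orientation, a minimal, hypothetical caller of the new signature (not part of this commit):

    from io import BytesIO

    def files():
        content = b'hello world'
        manifest_part = {'file-1': {'note': 'illustrative metadata'}}
        yield (
            'hello.txt', 'file-1', manifest_part,
            lambda file_id: BytesIO(content),
            lambda file_id: len(content))

    # passing manifest=dict() requests a manifest.json even if no file
    # contributes a manifest part
    response = streamed_zipfile(files(), zipfile_name='example.zip', manifest=dict())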
@@ -306,11 +309,16 @@ def streamed_zipfile(
         Replace the directory based iter of zipstream with an iter over all given
         files.
         '''
+        collected_manifest = manifest
         # the actual contents
-        for zipped_filename, file_id, open_io, file_size in files:
+        for zipped_filename, file_id, manifest_part, open_io, file_size in files:
+            if manifest_part is not None:
+                if collected_manifest is None:
+                    collected_manifest = {}
+                collected_manifest.update(manifest_part)
             if zipped_filename in streamed_files:
                 continue
             streamed_files.add(zipped_filename)
 
             # Write a file to the zipstream.
             try:
@@ -319,7 +327,7 @@ def streamed_zipfile(
                 def iter_content():
                     while True:
                         data = f.read(1024 * 64)
-                        if not data:
+                        if len(data) == 0:
                            break
                         yield data
@@ -335,6 +343,21 @@ def streamed_zipfile(
                 # due to the streaming nature, we cannot raise 401 here
                 # we just leave it out in the download
                 pass
+            except Exception as e:
+                utils.get_logger(__name__).error(
+                    'unexpected exception while streaming zipfile', exc_info=e)
+                if collected_manifest is None:
+                    collected_manifest = {}
+                collected_manifest.setdefault('errors', []).append(
+                    'unexpected exception while streaming zipfile: %s' % str(e))
+
+        try:
+            if collected_manifest is not None:
+                manifest_content = json.dumps(collected_manifest).encode('utf-8')
+                yield dict(
+                    arcname='manifest.json', iterable=[manifest_content],
+                    buffer_size=len(manifest_content))
+        except Exception as e:
+            utils.get_logger(__name__).error(
+                'could not stream zipfile manifest', exc_info=e)
 
     compression = zipfile.ZIP_DEFLATED if compress else zipfile.ZIP_STORED
     zip_stream = zipstream.ZipFile(mode='w', compression=compression, allowZip64=True)
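The manifest is emitted through zipstream's entry convention, where a dict with arcname, iterable, and buffer_size describes one in-memory zip member. A standalone sketch of the same idea, assuming the python-zipstream package used by this module:

    import json
    import zipstream

    zip_stream = zipstream.ZipFile(mode='w', allowZip64=True)
    manifest_content = json.dumps({'calc-1': {'mainfile': 'a/b.xml'}}).encode('utf-8')
    # write_iter is the public equivalent of the dict entries yielded above
    zip_stream.write_iter('manifest.json', iter([manifest_content]))

    with open('example.zip', 'wb') as f:
        for chunk in zip_stream:
            f.write(chunk)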
@@ -22,12 +22,10 @@ The raw API of the nomad@FAIRDI APIs. Can be used to retrieve raw calculation files.
 from typing import IO, Any, Union, List
 import os.path
-from io import BytesIO
 from flask import request, send_file
 from flask_restplus import abort, Resource, fields
 import magic
 import fnmatch
-import json
 import gzip
 import lzma
 import urllib.parse
@@ -427,119 +425,15 @@ class RawFileQueryResource(Resource):
         The zip file will contain a ``manifest.json`` with the repository meta data.
         '''
         logger = common.logger.bind(query=urllib.parse.urlencode(request.args, doseq=True))
 
-        patterns: List[str] = None
         try:
             args = _raw_file_from_query_parser.parse_args()
-            compress = args.get('compress', False)
-            strip = args.get('strip', False)
-            pattern = args.get('file_pattern', None)
-            if isinstance(pattern, str):
-                patterns = [pattern]
-            elif pattern is None:
-                patterns = []
-            else:
-                patterns = pattern
         except Exception:
-            abort(400, message='bad parameter types')
+            abort(400, message='could not parse request arguments')
 
         search_request = search.SearchRequest()
         apply_search_parameters(search_request, _raw_file_from_query_parser.parse_args())
-        search_request.include('calc_id', 'upload_id', 'mainfile')
-
-        def path(entry):
-            return '%s/%s' % (entry['upload_id'], entry['mainfile'])
-
-        calcs = search_request.execute_scan(
-            order_by='upload_id',
-            size=config.services.download_scan_size,
-            scroll=config.services.download_scan_timeout)
-
-        if strip:
-            if search_request.execute()['total'] > config.raw_file_strip_cutoff:
-                abort(400, 'The requested download has to many files for using "strip".')
-
-            calcs = list(calcs)
-            paths = [path(entry) for entry in calcs]
-            common_prefix_len = len(utils.common_prefix(paths))
-        else:
-            common_prefix_len = 0
-
-        def generator():
-            try:
-                manifest = {}
-                directories = set()
-                upload_files = None
-                streamed, skipped = 0, 0
-                for entry in calcs:
-                    upload_id = entry['upload_id']
-                    mainfile = entry['mainfile']
-                    if upload_files is None or upload_files.upload_id != upload_id:
-                        logger.info('opening next upload for raw file streaming', upload_id=upload_id)
-                        if upload_files is not None:
-                            upload_files.close()
-
-                        upload_files = UploadFiles.get(upload_id)
-
-                    if upload_files is None:
-                        logger.error('upload files do not exist', upload_id=upload_id)
-                        continue
-
-                    def open_file(upload_filename):
-                        return upload_files.raw_file(upload_filename, 'rb')
-
-                    upload_files._is_authorized = create_authorization_predicate(
-                        upload_id=upload_id, calc_id=entry['calc_id'])
-                    directory = os.path.dirname(mainfile)
-                    directory_w_upload = os.path.join(upload_files.upload_id, directory)
-                    if directory_w_upload not in directories:
-                        streamed += 1
-                        directories.add(directory_w_upload)
-                        for filename, file_size in upload_files.raw_file_list(directory=directory):
-                            filename = os.path.join(directory, filename)
-                            filename_w_upload = os.path.join(upload_files.upload_id, filename)
-                            filename_wo_prefix = filename_w_upload[common_prefix_len:]
-                            if len(patterns) == 0 or any(
-                                    fnmatch.fnmatchcase(os.path.basename(filename_wo_prefix), pattern)
-                                    for pattern in patterns):
-                                yield (
-                                    filename_wo_prefix, filename, open_file,
-                                    lambda *args, **kwargs: file_size)
-                    else:
-                        skipped += 1
-
-                    if (streamed + skipped) % 10000 == 0:
-                        logger.info('streaming raw files', streamed=streamed, skipped=skipped)
-
-                    manifest[path(entry)] = {
-                        key: entry[key]
-                        for key in RawFileQueryResource.manifest_quantities
-                        if entry.get(key) is not None
-                    }
-
-                if upload_files is not None:
-                    upload_files.close()
-
-                logger.info('streaming raw file manifest')
-                try:
-                    manifest_contents = json.dumps(manifest).encode('utf-8')
-                except Exception as e:
-                    manifest_contents = json.dumps(
-                        dict(error='Could not create the manifest: %s' % (e))).encode('utf-8')
-                    logger.error('could not create raw query manifest', exc_info=e)
-
-                yield (
-                    'manifest.json', 'manifest',
-                    lambda *args: BytesIO(manifest_contents),
-                    lambda *args: len(manifest_contents))
-
-            except Exception as e:
-                logger.warning(
-                    'unexpected error while streaming raw data from query', exc_info=e)
-
-        logger.info('start streaming raw files')
-        return streamed_zipfile(
-            generator(), zipfile_name='nomad_raw_files.zip', compress=compress)
+        return respond_to_raw_files_query(search_request, args, logger)
 
     @api.doc('post_raw_files_from_query')
     @api.expect(_raw_file_from_query_model)
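Both raw query handlers now delegate to a single helper (see the new respond_to_raw_files_query below); a comment sketch of the resulting flow:

    # GET  /raw/query  -> parse query args -> respond_to_raw_files_query(search_request, args, logger)
    # POST /raw/query  -> parse JSON body  -> respond_to_raw_files_query(search_request, post_data, logger)
    # both sources provide 'compress', 'strip', and 'file_pattern'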
@@ -558,28 +452,17 @@ class RawFileQueryResource(Resource):
         The zip file will contain a ``manifest.json`` with the repository meta data.
         '''
-        patterns: List[str] = None
         try:
-            data_in = request.get_json()
-            compress = data_in.get('compress', False)
-            strip = data_in.get('strip', False)
-            pattern = data_in.get('file_pattern', None)
-            if isinstance(pattern, str):
-                patterns = [pattern]
-            elif pattern is None:
-                patterns = []
-            else:
-                patterns = pattern
-            query = data_in.get('query', {})
+            post_data = request.get_json()
+            query = post_data.get('query', {})
             query_expression = {key: val for key, val in query.items() if '$' in key}
         except Exception:
-            abort(400, message='bad parameter types')
+            abort(400, message='could not parse request body')
 
         logger = common.logger.bind(query=urllib.parse.urlencode(query, doseq=True))
 
         search_request = search.SearchRequest()
         apply_search_parameters(search_request, query)
-        search_request.include('calc_id', 'upload_id', 'mainfile')
 
         if query_expression:
             try:
@@ -587,99 +470,106 @@ class RawFileQueryResource(Resource):
             except AssertionError as e:
                 abort(400, str(e))
 
-        def path(entry):
-            return '%s/%s' % (entry['upload_id'], entry['mainfile'])
-
-        calcs = search_request.execute_scan(
-            order_by='upload_id',
-            size=config.services.download_scan_size,
-            scroll=config.services.download_scan_timeout)
-
-        if strip:
-            if search_request.execute()['total'] > config.raw_file_strip_cutoff:
-                abort(400, 'The requested download has to many files for using "strip".')
-
-            calcs = list(calcs)
-            paths = [path(entry) for entry in calcs]
-            common_prefix_len = len(utils.common_prefix(paths))
-        else:
-            common_prefix_len = 0
-
-        def generator():
-            try:
-                manifest = {}
-                directories = set()
-                upload_files = None
-                streamed, skipped = 0, 0
-                for entry in calcs:
-                    upload_id = entry['upload_id']
-                    mainfile = entry['mainfile']
-                    if upload_files is None or upload_files.upload_id != upload_id:
-                        logger.info('opening next upload for raw file streaming', upload_id=upload_id)
-                        if upload_files is not None:
-                            upload_files.close()
-
-                        upload_files = UploadFiles.get(upload_id)
-
-                    if upload_files is None:
-                        logger.error('upload files do not exist', upload_id=upload_id)
-                        continue
-
-                    def open_file(upload_filename):
-                        return upload_files.raw_file(upload_filename, 'rb')
-
-                    upload_files._is_authorized = create_authorization_predicate(
-                        upload_id=upload_id, calc_id=entry['calc_id'])
-                    directory = os.path.dirname(mainfile)
-                    directory_w_upload = os.path.join(upload_files.upload_id, directory)
-                    if directory_w_upload not in directories:
-                        streamed += 1
-                        directories.add(directory_w_upload)
-                        for filename, file_size in upload_files.raw_file_list(directory=directory):
-                            filename = os.path.join(directory, filename)
-                            filename_w_upload = os.path.join(upload_files.upload_id, filename)
-                            filename_wo_prefix = filename_w_upload[common_prefix_len:]
-                            if len(patterns) == 0 or any(
-                                    fnmatch.fnmatchcase(os.path.basename(filename_wo_prefix), pattern)
-                                    for pattern in patterns):
-                                yield (
-                                    filename_wo_prefix, filename, open_file,
-                                    lambda *args, **kwargs: file_size)
-                    else:
-                        skipped += 1
-
-                    if (streamed + skipped) % 10000 == 0:
-                        logger.info('streaming raw files', streamed=streamed, skipped=skipped)
-
-                    manifest[path(entry)] = {
-                        key: entry[key]
-                        for key in RawFileQueryResource.manifest_quantities
-                        if entry.get(key) is not None
-                    }
-
-                if upload_files is not None:
-                    upload_files.close()
-
-                logger.info('streaming raw file manifest')
-                try:
-                    manifest_contents = json.dumps(manifest).encode('utf-8')
-                except Exception as e:
-                    manifest_contents = json.dumps(
-                        dict(error='Could not create the manifest: %s' % (e))).encode('utf-8')
-                    logger.error('could not create raw query manifest', exc_info=e)
-
-                yield (
-                    'manifest.json', 'manifest',
-                    lambda *args: BytesIO(manifest_contents),
-                    lambda *args: len(manifest_contents))
-
-            except Exception as e:
-                logger.warning(
-                    'unexpected error while streaming raw data from query', exc_info=e)
-
-        logger.info('start streaming raw files')
-        return streamed_zipfile(
-            generator(), zipfile_name='nomad_raw_files.zip', compress=compress)
+        return respond_to_raw_files_query(search_request, post_data, logger)
+
+
+def respond_to_raw_files_query(search_request, args, logger):
+    patterns: List[str] = None
+    try:
+        compress = args.get('compress', False)
+        strip = args.get('strip', False)
+        pattern = args.get('file_pattern', None)
+        if isinstance(pattern, str):
+            patterns = [pattern]
+        elif pattern is None:
+            patterns = []
+        else:
+            patterns = pattern
+    except Exception:
+        abort(400, message='bad parameter types')
+
+    search_request.include('calc_id', 'upload_id', 'mainfile')
+
+    def path(entry):
+        return '%s/%s' % (entry['upload_id'], entry['mainfile'])
+
+    calcs = search_request.execute_scan(
+        order_by='upload_id',
+        size=config.services.download_scan_size,
+        scroll=config.services.download_scan_timeout)
+
+    if strip:
+        if search_request.execute()['total'] > config.raw_file_strip_cutoff:
+            abort(400, 'The requested download has too many files for using "strip".')
+
+        calcs = list(calcs)
+        paths = [path(entry) for entry in calcs]
+        common_prefix_len = len(utils.common_prefix(paths))
+    else:
+        common_prefix_len = 0
+
+    def generator():
+        try:
+            directories = set()
+            upload_files = None
+            streamed, skipped = 0, 0
+            for entry in calcs:
+                manifest = {
+                    path(entry): {
+                        key: entry[key]
+                        for key in RawFileQueryResource.manifest_quantities
+                        if entry.get(key) is not None
+                    }
+                }
+
+                upload_id = entry['upload_id']
+                mainfile = entry['mainfile']
+                if upload_files is None or upload_files.upload_id != upload_id:
+                    logger.info('opening next upload for raw file streaming', upload_id=upload_id)
+                    if upload_files is not None:
+                        upload_files.close()
+
+                    upload_files = UploadFiles.get(upload_id)
+
+                if upload_files is None:
+                    logger.error('upload files do not exist', upload_id=upload_id)
+                    continue
+
+                def open_file(upload_filename):
+                    return upload_files.raw_file(upload_filename, 'rb')
+
+                upload_files._is_authorized = create_authorization_predicate(
+                    upload_id=upload_id, calc_id=entry['calc_id'])
+                directory = os.path.dirname(mainfile)
+                directory_w_upload = os.path.join(upload_files.upload_id, directory)
+                if directory_w_upload not in directories:
+                    streamed += 1
+                    directories.add(directory_w_upload)
+                    for filename, file_size in upload_files.raw_file_list(directory=directory):
+                        filename = os.path.join(directory, filename)
+                        filename_w_upload = os.path.join(upload_files.upload_id, filename)
+                        filename_wo_prefix = filename_w_upload[common_prefix_len:]
+                        if len(patterns) == 0 or any(
+                                fnmatch.fnmatchcase(os.path.basename(filename_wo_prefix), pattern)
+                                for pattern in patterns):
+                            yield (
+                                filename_wo_prefix, filename, manifest, open_file,
+                                lambda *args, **kwargs: file_size)
+                else:
+                    skipped += 1
+
+                if (streamed + skipped) % 10000 == 0:
+                    logger.info('streaming raw files', streamed=streamed, skipped=skipped)
+
+            if upload_files is not None:
+                upload_files.close()
+        except Exception as e:
+            logger.error('unexpected error while streaming raw data from query', exc_info=e)
+
+    logger.info('start streaming raw files')
+    return streamed_zipfile(
+        generator(), zipfile_name='nomad_raw_files.zip', compress=compress, manifest=dict())
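From a client's perspective the endpoint is unchanged; errors that occur while streaming now surface inside manifest.json instead of silently truncating the download. A hypothetical client call (host and query are illustrative, mirroring the new test below):

    import io
    import json
    import zipfile

    import requests  # illustrative; any HTTP client works

    rv = requests.get('https://nomad-lab.eu/api/raw/query?atoms=Si')
    with zipfile.ZipFile(io.BytesIO(rv.content)) as zip_file:
        with zip_file.open('manifest.json') as f:
            manifest = json.load(f)
    # entries are keyed by '<upload_id>/<mainfile>'; streaming errors, if any,
    # are collected under manifest['errors']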
@@ -696,7 +586,7 @@ def respond_to_get_raw_files(upload_id, files, compress=False, strip=False):
     try:
         return streamed_zipfile(
             [(
-                filename[common_prefix_len:], filename,
+                filename[common_prefix_len:], filename, None,
                 lambda upload_filename: upload_files.raw_file(upload_filename, 'rb'),
                 lambda upload_filename: upload_files.raw_file_size(upload_filename)
             ) for filename in files],
@@ -1811,6 +1811,22 @@ class TestRaw(UploadFilesBasedTests):
         assert rv.status_code == 200
         assert_zip_file(rv, files=1)
 
+    def test_raw_files_from_query_file_error(self, api, processeds, test_user_auth, monkeypatch):
+        def raw_file(self, *args, **kwargs):
+            raise Exception('test error')
+
+        monkeypatch.setattr('nomad.files.StagingUploadFiles.raw_file', raw_file)
+
+        url = '/raw/query?%s' % urlencode({'atoms': 'Si'})
+        rv = api.get(url, headers=test_user_auth)
+
+        assert rv.status_code == 200
+        with zipfile.ZipFile(io.BytesIO(rv.data)) as zip_file:
+            with zip_file.open('manifest.json', 'r') as f:
+                manifest = json.load(f)
+                assert 'errors' in manifest
+                assert len(manifest['errors']) > 0
+
     @pytest.mark.parametrize('files, pattern, strip', [
         (1, '*.json', False),
         (1, '*.json', True),