From 65a5b0e98f8a30d9861126cbcd665b9ca8a57246 Mon Sep 17 00:00:00 2001
From: David Sikter <david.sikter@physik.hu-berlin.de>
Date: Thu, 8 Jul 2021 10:25:52 +0200
Subject: [PATCH] Unified zip utility methods

---
 nomad/app/v1/routers/entries.py | 15 +++++----
 nomad/app/v1/utils.py           | 59 +++++---------------------------
 nomad/files.py                  | 60 +++++++++++++++++++++++++--------
 3 files changed, 62 insertions(+), 72 deletions(-)

diff --git a/nomad/app/v1/routers/entries.py b/nomad/app/v1/routers/entries.py
index cd372d5be1..c8f95e5bbb 100644
--- a/nomad/app/v1/routers/entries.py
+++ b/nomad/app/v1/routers/entries.py
@@ -29,6 +29,7 @@ import json
 import orjson
 
 from nomad import files, config, utils
+from nomad.files import StreamedFile, create_zipstream
 from nomad.utils import strip
 from nomad.archive import RequiredReader, RequiredValidationError, ArchiveQueryError
 from nomad.archive import ArchiveQueryError
@@ -38,8 +39,8 @@ from nomad.metainfo.elasticsearch_extension import entry_type
 
 from .auth import create_user_dependency
 from ..utils import (
-    create_streamed_zipfile, create_download_stream_zipped, create_download_stream_raw_file,
-    DownloadItem, File, create_responses)
+    create_download_stream_zipped, create_download_stream_raw_file,
+    DownloadItem, create_responses)
 from ..models import (
     PaginationResponse, MetadataPagination, WithQuery, WithQueryAndPagination, MetadataRequired,
     MetadataResponse, Metadata, Files, Query, User, Owner,
@@ -747,8 +748,8 @@ def _answer_entries_archive_download_request(
 
     required_reader = RequiredReader('*')
 
-    # a generator of File objects to create the streamed zip from
-    def file_generator():
+    # a generator of StreamedFile objects to create the zipstream from
+    def streamed_files():
         # go through all entries that match the query
         for entry_metadata in _do_exaustive_search(owner, query, include=search_includes, user=user):
             path = os.path.join(entry_metadata['upload_id'], '%s.json' % entry_metadata['entry_id'])
@@ -758,7 +759,7 @@ def _answer_entries_archive_download_request(
                 f = io.BytesIO(orjson.dumps(
                     archive_data, option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS))
 
-                yield File(path=path, f=f, size=f.getbuffer().nbytes)
+                yield StreamedFile(path=path, f=f, size=f.getbuffer().nbytes)
             except KeyError as e:
                 logger.error('missing archive', entry_id=entry_metadata['entry_id'], exc_info=e)
 
@@ -767,11 +768,11 @@ def _answer_entries_archive_download_request(
 
         # add the manifest at the end
         manifest_content = json.dumps(manifest).encode()
-        yield File(path='manifest.json', f=io.BytesIO(manifest_content), size=len(manifest_content))
+        yield StreamedFile(path='manifest.json', f=io.BytesIO(manifest_content), size=len(manifest_content))
 
     try:
         # create the streaming response with zip file contents
-        content = create_streamed_zipfile(file_generator(), compress=files_params.compress)
+        content = create_zipstream(streamed_files(), compress=files_params.compress)
         return StreamingResponse(content, media_type='application/zip')
     finally:
         uploads.close()
diff --git a/nomad/app/v1/utils.py b/nomad/app/v1/utils.py
index a187604a61..17b78b7b4e 100644
--- a/nomad/app/v1/utils.py
+++ b/nomad/app/v1/utils.py
@@ -22,19 +22,12 @@ import urllib
 import io
 import json
 import os
-import sys
 import inspect
 from fastapi import Request, Query, HTTPException, status  # pylint: disable=unused-import
 from pydantic import ValidationError, BaseModel  # pylint: disable=unused-import
-import zipstream
 import gzip
 import lzma
-from nomad.files import UploadFiles
-
-if sys.version_info >= (3, 7):
-    import zipfile
-else:
-    import zipfile37 as zipfile  # pragma: no cover
+from nomad.files import UploadFiles, StreamedFile, create_zipstream
 
 
 def parameter_dependency_from_model(name: str, model_cls):
@@ -84,42 +77,6 @@ def parameter_dependency_from_model(name: str, model_cls):
     return func
 
 
-class File(BaseModel):
-    path: str
-    f: Any
-    size: int
-
-
-def create_streamed_zipfile(
-        files: Iterator[File],
-        compress: bool = False) -> Iterator[bytes]:
-
-    '''
-    Creates a streaming zipfile object that can be used in fastapi's ``StreamingResponse``.
-    '''
-
-    def path_to_write_generator():
-        for file_obj in files:
-            def content_generator():
-                while True:
-                    data = file_obj.f.read(1024 * 64)
-                    if not data:
-                        break
-                    yield data
-
-            yield dict(
-                arcname=file_obj.path,
-                iterable=content_generator(),
-                buffer_size=file_obj.size)
-
-    compression = zipfile.ZIP_DEFLATED if compress else zipfile.ZIP_STORED
-    zip_stream = zipstream.ZipFile(mode='w', compression=compression, allowZip64=True)
-    zip_stream.paths_to_write = path_to_write_generator()
-
-    for chunk in zip_stream:
-        yield chunk
-
-
 class DownloadItem(BaseModel):
     ''' Defines an object (file or folder) for download. '''
     upload_id: str
@@ -147,7 +104,7 @@ def create_download_stream_zipped(
         create_manifest_file: if set, a manifest file is created in the root folder.
         compress: if the zip file should be compressed or not
     '''
-    def file_generator(upload_files) -> Iterator[File]:
+    def streamed_files(upload_files) -> Iterator[StreamedFile]:
         manifest = []
         try:
             items: Iterator[DownloadItem] = (
@@ -174,7 +131,7 @@ def create_download_stream_zipped(
                     # File
                     if download_item.zip_path not in streamed_paths:
                         streamed_paths.add(download_item.zip_path)
-                        yield File(
+                        yield StreamedFile(
                             path=download_item.zip_path,
                             f=upload_files.raw_file(download_item.raw_path, 'rb'),
                             size=upload_files.raw_file_size(download_item.raw_path))
@@ -189,7 +146,7 @@ def create_download_stream_zipped(
                             zip_path = os.path.join(download_item.zip_path, relative_path)
                             if zip_path not in streamed_paths:
                                 streamed_paths.add(zip_path)
-                                yield File(
+                                yield StreamedFile(
                                     path=zip_path,
                                     f=upload_files.raw_file(path_info.path, 'rb'),
                                     size=path_info.size)
@@ -200,13 +157,13 @@ def create_download_stream_zipped(
 
             if create_manifest_file:
                 manifest_content = json.dumps(manifest).encode()
-                yield File(path='manifest.json', f=io.BytesIO(manifest_content), size=len(manifest_content))
+                yield StreamedFile(path='manifest.json', f=io.BytesIO(manifest_content), size=len(manifest_content))
 
         finally:
             if upload_files:
                 upload_files.close()
 
-    return create_streamed_zipfile(file_generator(upload_files), compress=compress)
+    return create_zipstream(streamed_files(upload_files), compress=compress)
 
 
 def create_download_stream_raw_file(
@@ -258,9 +215,9 @@ def create_download_stream_raw_file(
     upload_files.close()
 
 
-def create_stream_from_string(content: str) -> Iterator[bytes]:
+def create_stream_from_string(content: str) -> io.BytesIO:
     ''' For returning strings as content using '''
-    yield content.encode()
+    return io.BytesIO(content.encode())
 
 
 def create_responses(*args):
diff --git a/nomad/files.py b/nomad/files.py
index 34b59a350d..ea413ab616 100644
--- a/nomad/files.py
+++ b/nomad/files.py
@@ -230,7 +230,7 @@ class StreamedFile(BaseModel):
     Convenience class for representing a streamed file, together with information about
     file size and an associated path.
     '''
-    f: Iterable[bytes]
+    f: Any
     path: str
     size: int
 
@@ -294,6 +294,40 @@ class FileSource:
                         size=os.stat(os_path).st_size)
 
 
+def create_zipstream_content(streamed_files: Iterable[StreamedFile]) -> Iterable[Dict]:
+    '''
+    Generator which "casts" a sequence of StreamedFiles to a sequence of dictionaries, of
+    the form which is required by the `zipstream` library, i.e. dictionaries with keys
+    `arcname`, `iterable` and `buffer_size`. Useful for generating zipstreams.
+    '''
+    for streamed_file in streamed_files:
+
+        def content_generator():
+            while True:
+                data = streamed_file.f.read(1024 * 64)
+                if not data:
+                    break
+                yield data
+
+        yield dict(
+            arcname=streamed_file.path,
+            iterable=content_generator(),
+            buffer_size=streamed_file.size)
+
+
+def create_zipstream(
+        streamed_files: Iterable[StreamedFile],
+        compress: bool = False) -> Iterator[bytes]:
+    '''
+    Creates a zip stream, i.e. a streamed zip file.
+    '''
+    compression = zipfile.ZIP_DEFLATED if compress else zipfile.ZIP_STORED
+    zip_stream = zipstream.ZipFile(mode='w', compression=compression, allowZip64=True)
+    zip_stream.paths_to_write = create_zipstream_content(streamed_files)
+
+    return iter(zip_stream)
+
+
 class UploadFiles(DirectoryObject, metaclass=ABCMeta):
     ''' Abstract base class for upload files. '''
     def __init__(
@@ -1292,28 +1326,26 @@ class UploadBundle:
                 json.dump(bundle_info, bundle_info_file, indent=2)
         else:
             # Exporting zipped
-            def path_to_write_generator():
+            def streamed_files():
                 # Generator for generating zip file content
                 # 1. Yield all the selected regular files
                 for file_source in upload_files.files_for_bundle(
                         include_raw_files, include_protected_raw_files, include_archive_files):
                     for streamed_file in file_source.to_streamed_files():
-                        yield dict(
-                            arcname=os.path.join(upload_id, streamed_file.path),
-                            iterable=streamed_file.f,
-                            buffer_size=streamed_file.size)
+                        # Add upload_id at the beginning of the path
+                        streamed_file.path = os.path.join(upload_id, streamed_file.path)
+                        yield streamed_file
                 # 2. Finally, also yield a stream for the bundle_info.json
-                bundle_info_bytes = json.dumps(bundle_info, indent=2).encode('utf8')
-                yield dict(
-                    arcname=os.path.join(upload_id, 'bundle_info.json'),
-                    iterable=io.BytesIO(bundle_info_bytes),
-                    buffer_size=len(bundle_info_bytes))
+                bundle_info_bytes = json.dumps(bundle_info, indent=2).encode()
+                yield StreamedFile(
+                    path=os.path.join(upload_id, 'bundle_info.json'),
+                    f=io.BytesIO(bundle_info_bytes),
+                    size=len(bundle_info_bytes))
 
-            zip_stream = zipstream.ZipFile(mode='w', compression=zipfile.ZIP_DEFLATED, allowZip64=True)
-            zip_stream.paths_to_write = path_to_write_generator()
+            zip_stream = create_zipstream(streamed_files())
 
             if export_as_stream:
-                return iter(zip_stream)
+                return zip_stream
             else:
                 # Write to zip file
                 with open(export_path, 'wb') as f:
-- 
GitLab