From d0dfdd0d6dd671f1986772d5d5fcfcfe802b64b1 Mon Sep 17 00:00:00 2001
From: Theodore Chang <tlcfem@gmail.com>
Date: Sat, 29 Apr 2023 09:51:38 +0200
Subject: [PATCH] Use modern stmp impl

---
 pyproject.toml       |  1 +
 requirements-dev.txt |  4 ++-
 tests/conftest.py    | 86 +++++++++++++++++---------------------------
 3 files changed, 37 insertions(+), 54 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index c71931d35d..823ecf65e8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -129,6 +129,7 @@ dev = [
     'mkdocs-material==8.1.1',
     'mkdocs-material-extensions==1.0.3',
     'mkdocs-macros-plugin==0.6.3',
+    'aiosmtpd'
 ]
 
 [project.scripts]
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 2c3e208140..cfd5b606a4 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -4,6 +4,7 @@
 #
 #    pip-compile --annotation-style=line --extra=dev --extra=infrastructure --extra=parsing --output-file=requirements-dev.txt dependencies/matid/pyproject.toml dependencies/nomad-dos-fingerprints/pyproject.toml dependencies/parsers/atomistic/pyproject.toml dependencies/parsers/database/pyproject.toml dependencies/parsers/eelsdb/pyproject.toml dependencies/parsers/electronic/pyproject.toml dependencies/parsers/nexus/pyproject.toml dependencies/parsers/workflow/pyproject.toml pyproject.toml requirements.txt
 #
+aiosmtpd==1.4.4.post2     # via -r requirements.txt, nomad-lab (pyproject.toml)
 alabaster==0.7.12         # via -r requirements.txt, sphinx
 alembic==1.9.1            # via -r requirements.txt, jupyterhub
 amqp==5.1.1               # via -r requirements.txt, kombu
@@ -23,7 +24,8 @@ astunparse==1.6.3         # via -r requirements.txt, mdtraj
 async-generator==1.10     # via -r requirements.txt, jupyterhub
 async-timeout==4.0.2      # via -r requirements.txt, redis
 atomicwrites==1.4.1       # via pytest
-attrs==22.2.0             # via -r requirements.txt, cattrs, jsonschema, pytest, requests-cache
+atpublic==3.1.1           # via -r requirements.txt, aiosmtpd
+attrs==22.2.0             # via -r requirements.txt, aiosmtpd, cattrs, jsonschema, pytest, requests-cache
 babel==2.11.0             # via -r requirements.txt, jupyterlab-server, sphinx
 backcall==0.2.0           # via -r requirements.txt, ipython
 bagit==1.8.1              # via -r requirements.txt, nomad-lab (pyproject.toml)
diff --git a/tests/conftest.py b/tests/conftest.py
index cffde080ba..aef69765d8 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -15,15 +15,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
-from typing import Tuple, List
+from typing import Tuple
 import math
 import pytest
 import logging
 from collections import namedtuple
-from smtpd import SMTPServer
-from threading import Lock, Thread
-import asyncore
 import time
 from datetime import datetime
 import shutil
@@ -33,6 +29,8 @@ from typing import List
 import json
 import logging
 import warnings
+
+from aiosmtpd.controller import Controller
 from fastapi.testclient import TestClient
 
 from nomad import config, infrastructure, processing, utils, datamodel, bundles
@@ -467,94 +465,76 @@ def with_warn(caplog):
     assert count > 0
 
 
-'''
-Fixture for mocked SMTP server for testing.
-Based on https://gist.github.com/akheron/cf3863cdc424f08929e4cb7dc365ef23.
-'''
-
 RecordedMessage = namedtuple(
     'RecordedMessage',
     'peer envelope_from envelope_recipients data',
 )
 
 
-class ThreadSafeList:
-    def __init__(self, *args, **kwds):
-        self._items = []
-        self._lock = Lock()
-
-    def clear(self):
-        with self._lock:
-            self._items = []
+class Handler:
+    def __init__(self):
+        self.messages = []
 
-    def add(self, item):
-        with self._lock:
-            self._items.append(item)
+    async def handle_exception(self, exc):
+        return '250 Dummy'
 
-    def copy(self):
-        with self._lock:
-            return self._items[:]
+    async def handle_DATA(self, server, session, envelope):
+        peer = session.peer
+        mailfrom = envelope.mail_from
+        rcpttos = envelope.rcpt_tos
+        data = envelope.content
+        msg = RecordedMessage(peer, mailfrom, rcpttos, data)
+        self.messages.append(msg)
 
 
-class SMTPServerThread(Thread):
-    def __init__(self, messages):
-        super().__init__()
-        self.messages = messages
+class SMTPServer:
+    def __init__(self):
         self.host_port = None
         self.smtp = None
+        self.handler = None
 
     def run(self):
-        _messages = self.messages
-
-        class _SMTPServer(SMTPServer):
-            def process_message(self, peer, mailfrom, rcpttos, data, **kwargs):
-                msg = RecordedMessage(peer, mailfrom, rcpttos, data)
-                _messages.add(msg)
-
-        self.smtp = _SMTPServer(('127.0.0.1', config.mail.port), None)
-        self.host_port = self.smtp.socket.getsockname()
-        try:
-            asyncore.loop(1)
-        except Exception:
-            pass
+        self.handler = Handler()
+        self.smtp = Controller(self.handler, hostname='127.0.0.1', port=config.mail.port)
+        self.smtp.start()
+        self.host_port = self.smtp.hostname, self.smtp.port
 
     def close(self):
         if self.smtp is not None:
-            self.smtp.close()
+            self.smtp.stop()
 
 
 class SMTPServerFixture:
     def __init__(self):
-        self._messages = ThreadSafeList()
-        self._thread = SMTPServerThread(self._messages)
-        self._thread.start()
+        self.server = SMTPServer()
+        self.server.run()
+        _ = self.host_port
 
     @property
     def host_port(self):
         '''SMTP server's listening address as a (host, port) tuple'''
-        while self._thread.host_port is None:
+        while self.server.host_port is None:
             time.sleep(0.1)
-        return self._thread.host_port
+        return self.server.host_port
 
     @property
     def host(self):
-        return self.host_port[0]
+        return self.server.host_port[0]
 
     @property
     def port(self):
-        return self.host_port[1]
+        return self.server.host_port[1]
 
     @property
     def messages(self):
         '''A list of RecordedMessage objects'''
-        return self._messages.copy()
+        return self.server.handler.messages[:]
 
     def clear(self):
-        self._messages.clear()
+        self.server.handler.messages = []
 
     def close(self):
-        self._thread.close()
-        self._thread.join(1)
+        self.server.close()
 
 
 @pytest.fixture(scope='session')
-- 
GitLab