summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--changelog.d/11503.misc1
-rw-r--r--tests/server.py199
-rw-r--r--tests/storage/test_base.py3
-rw-r--r--tests/storage/test_roommember.py2
-rw-r--r--tests/utils.py175
5 files changed, 195 insertions, 185 deletions
diff --git a/changelog.d/11503.misc b/changelog.d/11503.misc
new file mode 100644
index 0000000000..03a24a9224
--- /dev/null
+++ b/changelog.d/11503.misc
@@ -0,0 +1 @@
+Refactor `tests.util.setup_test_homeserver` and `tests.server.setup_test_homeserver`.
\ No newline at end of file
diff --git a/tests/server.py b/tests/server.py
index 40cf5b12c3..ca2b7a5b97 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -11,9 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import hashlib
 import json
 import logging
+import time
+import uuid
+import warnings
 from collections import deque
 from io import SEEK_END, BytesIO
 from typing import (
@@ -27,6 +30,7 @@ from typing import (
     Type,
     Union,
 )
+from unittest.mock import Mock
 
 import attr
 from typing_extensions import Deque
@@ -53,11 +57,24 @@ from twisted.web.http_headers import Headers
 from twisted.web.resource import IResource
 from twisted.web.server import Request, Site
 
+from synapse.config.database import DatabaseConnectionConfig
 from synapse.http.site import SynapseRequest
+from synapse.server import HomeServer
+from synapse.storage import DataStore
+from synapse.storage.engines import PostgresEngine, create_engine
 from synapse.types import JsonDict
 from synapse.util import Clock
 
-from tests.utils import setup_test_homeserver as _sth
+from tests.utils import (
+    LEAVE_DB,
+    POSTGRES_BASE_DB,
+    POSTGRES_HOST,
+    POSTGRES_PASSWORD,
+    POSTGRES_USER,
+    USE_POSTGRES_FOR_TESTS,
+    MockClock,
+    default_config,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -450,14 +467,11 @@ class ThreadPool:
         return d
 
 
-def setup_test_homeserver(cleanup_func, *args, **kwargs):
+def _make_test_homeserver_synchronous(server: HomeServer) -> None:
     """
-    Set up a synchronous test server, driven by the reactor used by
-    the homeserver.
+    Make the given test homeserver's database interactions synchronous.
     """
-    server = _sth(cleanup_func, *args, **kwargs)
 
-    # Make the thread pool synchronous.
     clock = server.get_clock()
 
     for database in server.get_datastores().databases:
@@ -485,6 +499,7 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
 
         pool.runWithConnection = runWithConnection
         pool.runInteraction = runInteraction
+        # Replace the thread pool with a threadless 'thread' pool
         pool.threadpool = ThreadPool(clock._reactor)
         pool.running = True
 
@@ -492,8 +507,6 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
     # thread, so we need to disable the dedicated thread behaviour.
     server.get_datastores().main.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = False
 
-    return server
-
 
 def get_clock() -> Tuple[ThreadedMemoryReactorClock, Clock]:
     clock = ThreadedMemoryReactorClock()
@@ -673,3 +686,171 @@ def connect_client(
     client.makeConnection(FakeTransport(server, reactor))
 
     return client, server
+
+
+class TestHomeServer(HomeServer):
+    DATASTORE_CLASS = DataStore
+
+
+def setup_test_homeserver(
+    cleanup_func,
+    name="test",
+    config=None,
+    reactor=None,
+    homeserver_to_use: Type[HomeServer] = TestHomeServer,
+    **kwargs,
+):
+    """
+    Setup a homeserver suitable for running tests against.  Keyword arguments
+    are passed to the Homeserver constructor.
+
+    If no datastore is supplied, one is created and given to the homeserver.
+
+    Args:
+        cleanup_func : The function used to register a cleanup routine for
+                       after the test.
+
+    Calling this method directly is deprecated: you should instead derive from
+    HomeserverTestCase.
+    """
+    if reactor is None:
+        from twisted.internet import reactor
+
+    if config is None:
+        config = default_config(name, parse=True)
+
+    config.ldap_enabled = False
+
+    if "clock" not in kwargs:
+        kwargs["clock"] = MockClock()
+
+    if USE_POSTGRES_FOR_TESTS:
+        test_db = "synapse_test_%s" % uuid.uuid4().hex
+
+        database_config = {
+            "name": "psycopg2",
+            "args": {
+                "database": test_db,
+                "host": POSTGRES_HOST,
+                "password": POSTGRES_PASSWORD,
+                "user": POSTGRES_USER,
+                "cp_min": 1,
+                "cp_max": 5,
+            },
+        }
+    else:
+        database_config = {
+            "name": "sqlite3",
+            "args": {"database": ":memory:", "cp_min": 1, "cp_max": 1},
+        }
+
+    if "db_txn_limit" in kwargs:
+        database_config["txn_limit"] = kwargs["db_txn_limit"]
+
+    database = DatabaseConnectionConfig("master", database_config)
+    config.database.databases = [database]
+
+    db_engine = create_engine(database.config)
+
+    # Create the database before we actually try and connect to it, based off
+    # the template database we generate in setupdb()
+    if isinstance(db_engine, PostgresEngine):
+        db_conn = db_engine.module.connect(
+            database=POSTGRES_BASE_DB,
+            user=POSTGRES_USER,
+            host=POSTGRES_HOST,
+            password=POSTGRES_PASSWORD,
+        )
+        db_conn.autocommit = True
+        cur = db_conn.cursor()
+        cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
+        cur.execute(
+            "CREATE DATABASE %s WITH TEMPLATE %s;" % (test_db, POSTGRES_BASE_DB)
+        )
+        cur.close()
+        db_conn.close()
+
+    hs = homeserver_to_use(
+        name,
+        config=config,
+        version_string="Synapse/tests",
+        reactor=reactor,
+    )
+
+    # Install @cache_in_self attributes
+    for key, val in kwargs.items():
+        setattr(hs, "_" + key, val)
+
+    # Mock TLS
+    hs.tls_server_context_factory = Mock()
+    hs.tls_client_options_factory = Mock()
+
+    hs.setup()
+    if homeserver_to_use == TestHomeServer:
+        hs.setup_background_tasks()
+
+    if isinstance(db_engine, PostgresEngine):
+        database = hs.get_datastores().databases[0]
+
+        # We need to do cleanup on PostgreSQL
+        def cleanup():
+            import psycopg2
+
+            # Close all the db pools
+            database._db_pool.close()
+
+            dropped = False
+
+            # Drop the test database
+            db_conn = db_engine.module.connect(
+                database=POSTGRES_BASE_DB,
+                user=POSTGRES_USER,
+                host=POSTGRES_HOST,
+                password=POSTGRES_PASSWORD,
+            )
+            db_conn.autocommit = True
+            cur = db_conn.cursor()
+
+            # Try a few times to drop the DB. Some things may hold on to the
+            # database for a few more seconds due to flakiness, preventing
+            # us from dropping it when the test is over. If we can't drop
+            # it, warn and move on.
+            for _ in range(5):
+                try:
+                    cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
+                    db_conn.commit()
+                    dropped = True
+                except psycopg2.OperationalError as e:
+                    warnings.warn(
+                        "Couldn't drop old db: " + str(e), category=UserWarning
+                    )
+                    time.sleep(0.5)
+
+            cur.close()
+            db_conn.close()
+
+            if not dropped:
+                warnings.warn("Failed to drop old DB.", category=UserWarning)
+
+        if not LEAVE_DB:
+            # Register the cleanup hook
+            cleanup_func(cleanup)
+
+    # bcrypt is far too slow to be doing in unit tests
+    # Need to let the HS build an auth handler and then mess with it
+    # because AuthHandler's constructor requires the HS, so we can't make one
+    # beforehand and pass it in to the HS's constructor (chicken / egg)
+    async def hash(p):
+        return hashlib.md5(p.encode("utf8")).hexdigest()
+
+    hs.get_auth_handler().hash = hash
+
+    async def validate_hash(p, h):
+        return hashlib.md5(p.encode("utf8")).hexdigest() == h
+
+    hs.get_auth_handler().validate_hash = validate_hash
+
+    # Make the threadpool and database transactions synchronous for testing.
+    _make_test_homeserver_synchronous(hs)
+
+    return hs
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index ddad44bd6c..3e4f0579c9 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -23,7 +23,8 @@ from synapse.storage.database import DatabasePool
 from synapse.storage.engines import create_engine
 
 from tests import unittest
-from tests.utils import TestHomeServer, default_config
+from tests.server import TestHomeServer
+from tests.utils import default_config
 
 
 class SQLBaseStoreTestCase(unittest.TestCase):
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index fccab733c0..5cfdfe9b85 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -19,8 +19,8 @@ from synapse.rest.client import login, room
 from synapse.types import UserID, create_requester
 
 from tests import unittest
+from tests.server import TestHomeServer
 from tests.test_utils import event_injection
-from tests.utils import TestHomeServer
 
 
 class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
diff --git a/tests/utils.py b/tests/utils.py
index 983859120f..6d013e8518 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -14,12 +14,7 @@
 # limitations under the License.
 
 import atexit
-import hashlib
 import os
-import time
-import uuid
-import warnings
-from typing import Type
 from unittest.mock import Mock, patch
 from urllib import parse as urlparse
 
@@ -28,14 +23,11 @@ from twisted.internet import defer
 from synapse.api.constants import EventTypes
 from synapse.api.errors import CodeMessageException, cs_error
 from synapse.api.room_versions import RoomVersions
-from synapse.config.database import DatabaseConnectionConfig
 from synapse.config.homeserver import HomeServerConfig
 from synapse.config.server import DEFAULT_ROOM_VERSION
 from synapse.logging.context import current_context, set_current_context
-from synapse.server import HomeServer
-from synapse.storage import DataStore
 from synapse.storage.database import LoggingDatabaseConnection
-from synapse.storage.engines import PostgresEngine, create_engine
+from synapse.storage.engines import create_engine
 from synapse.storage.prepare_database import prepare_database
 
 # set this to True to run the tests against postgres instead of sqlite.
@@ -182,171 +174,6 @@ def default_config(name, parse=False):
     return config_dict
 
 
-class TestHomeServer(HomeServer):
-    DATASTORE_CLASS = DataStore
-
-
-def setup_test_homeserver(
-    cleanup_func,
-    name="test",
-    config=None,
-    reactor=None,
-    homeserver_to_use: Type[HomeServer] = TestHomeServer,
-    **kwargs,
-):
-    """
-    Setup a homeserver suitable for running tests against.  Keyword arguments
-    are passed to the Homeserver constructor.
-
-    If no datastore is supplied, one is created and given to the homeserver.
-
-    Args:
-        cleanup_func : The function used to register a cleanup routine for
-                       after the test.
-
-    Calling this method directly is deprecated: you should instead derive from
-    HomeserverTestCase.
-    """
-    if reactor is None:
-        from twisted.internet import reactor
-
-    if config is None:
-        config = default_config(name, parse=True)
-
-    config.ldap_enabled = False
-
-    if "clock" not in kwargs:
-        kwargs["clock"] = MockClock()
-
-    if USE_POSTGRES_FOR_TESTS:
-        test_db = "synapse_test_%s" % uuid.uuid4().hex
-
-        database_config = {
-            "name": "psycopg2",
-            "args": {
-                "database": test_db,
-                "host": POSTGRES_HOST,
-                "password": POSTGRES_PASSWORD,
-                "user": POSTGRES_USER,
-                "cp_min": 1,
-                "cp_max": 5,
-            },
-        }
-    else:
-        database_config = {
-            "name": "sqlite3",
-            "args": {"database": ":memory:", "cp_min": 1, "cp_max": 1},
-        }
-
-    if "db_txn_limit" in kwargs:
-        database_config["txn_limit"] = kwargs["db_txn_limit"]
-
-    database = DatabaseConnectionConfig("master", database_config)
-    config.database.databases = [database]
-
-    db_engine = create_engine(database.config)
-
-    # Create the database before we actually try and connect to it, based off
-    # the template database we generate in setupdb()
-    if isinstance(db_engine, PostgresEngine):
-        db_conn = db_engine.module.connect(
-            database=POSTGRES_BASE_DB,
-            user=POSTGRES_USER,
-            host=POSTGRES_HOST,
-            password=POSTGRES_PASSWORD,
-        )
-        db_conn.autocommit = True
-        cur = db_conn.cursor()
-        cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
-        cur.execute(
-            "CREATE DATABASE %s WITH TEMPLATE %s;" % (test_db, POSTGRES_BASE_DB)
-        )
-        cur.close()
-        db_conn.close()
-
-    hs = homeserver_to_use(
-        name,
-        config=config,
-        version_string="Synapse/tests",
-        reactor=reactor,
-    )
-
-    # Install @cache_in_self attributes
-    for key, val in kwargs.items():
-        setattr(hs, "_" + key, val)
-
-    # Mock TLS
-    hs.tls_server_context_factory = Mock()
-    hs.tls_client_options_factory = Mock()
-
-    hs.setup()
-    if homeserver_to_use == TestHomeServer:
-        hs.setup_background_tasks()
-
-    if isinstance(db_engine, PostgresEngine):
-        database = hs.get_datastores().databases[0]
-
-        # We need to do cleanup on PostgreSQL
-        def cleanup():
-            import psycopg2
-
-            # Close all the db pools
-            database._db_pool.close()
-
-            dropped = False
-
-            # Drop the test database
-            db_conn = db_engine.module.connect(
-                database=POSTGRES_BASE_DB,
-                user=POSTGRES_USER,
-                host=POSTGRES_HOST,
-                password=POSTGRES_PASSWORD,
-            )
-            db_conn.autocommit = True
-            cur = db_conn.cursor()
-
-            # Try a few times to drop the DB. Some things may hold on to the
-            # database for a few more seconds due to flakiness, preventing
-            # us from dropping it when the test is over. If we can't drop
-            # it, warn and move on.
-            for _ in range(5):
-                try:
-                    cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
-                    db_conn.commit()
-                    dropped = True
-                except psycopg2.OperationalError as e:
-                    warnings.warn(
-                        "Couldn't drop old db: " + str(e), category=UserWarning
-                    )
-                    time.sleep(0.5)
-
-            cur.close()
-            db_conn.close()
-
-            if not dropped:
-                warnings.warn("Failed to drop old DB.", category=UserWarning)
-
-        if not LEAVE_DB:
-            # Register the cleanup hook
-            cleanup_func(cleanup)
-
-    # bcrypt is far too slow to be doing in unit tests
-    # Need to let the HS build an auth handler and then mess with it
-    # because AuthHandler's constructor requires the HS, so we can't make one
-    # beforehand and pass it in to the HS's constructor (chicken / egg)
-    async def hash(p):
-        return hashlib.md5(p.encode("utf8")).hexdigest()
-
-    hs.get_auth_handler().hash = hash
-
-    async def validate_hash(p, h):
-        return hashlib.md5(p.encode("utf8")).hexdigest() == h
-
-    hs.get_auth_handler().validate_hash = validate_hash
-
-    return hs
-
-
 def mock_getRawHeaders(headers=None):
     headers = headers if headers is not None else {}