diff --git a/changelog.d/11503.misc b/changelog.d/11503.misc
new file mode 100644
index 0000000000..03a24a9224
--- /dev/null
+++ b/changelog.d/11503.misc
@@ -0,0 +1 @@
+Refactor `tests.util.setup_test_homeserver` and `tests.server.setup_test_homeserver`.
\ No newline at end of file
diff --git a/tests/server.py b/tests/server.py
index 40cf5b12c3..ca2b7a5b97 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -11,9 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import hashlib
import json
import logging
+import time
+import uuid
+import warnings
from collections import deque
from io import SEEK_END, BytesIO
from typing import (
@@ -27,6 +30,7 @@ from typing import (
Type,
Union,
)
+from unittest.mock import Mock
import attr
from typing_extensions import Deque
@@ -53,11 +57,24 @@ from twisted.web.http_headers import Headers
from twisted.web.resource import IResource
from twisted.web.server import Request, Site
+from synapse.config.database import DatabaseConnectionConfig
from synapse.http.site import SynapseRequest
+from synapse.server import HomeServer
+from synapse.storage import DataStore
+from synapse.storage.engines import PostgresEngine, create_engine
from synapse.types import JsonDict
from synapse.util import Clock
-from tests.utils import setup_test_homeserver as _sth
+from tests.utils import (
+ LEAVE_DB,
+ POSTGRES_BASE_DB,
+ POSTGRES_HOST,
+ POSTGRES_PASSWORD,
+ POSTGRES_USER,
+ USE_POSTGRES_FOR_TESTS,
+ MockClock,
+ default_config,
+)
logger = logging.getLogger(__name__)
@@ -450,14 +467,11 @@ class ThreadPool:
return d
-def setup_test_homeserver(cleanup_func, *args, **kwargs):
+def _make_test_homeserver_synchronous(server: HomeServer) -> None:
"""
- Set up a synchronous test server, driven by the reactor used by
- the homeserver.
+ Make the given test homeserver's database interactions synchronous.
"""
- server = _sth(cleanup_func, *args, **kwargs)
- # Make the thread pool synchronous.
clock = server.get_clock()
for database in server.get_datastores().databases:
@@ -485,6 +499,7 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
pool.runWithConnection = runWithConnection
pool.runInteraction = runInteraction
+ # Replace the thread pool with a threadless 'thread' pool
pool.threadpool = ThreadPool(clock._reactor)
pool.running = True
@@ -492,8 +507,6 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
# thread, so we need to disable the dedicated thread behaviour.
server.get_datastores().main.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = False
- return server
-
def get_clock() -> Tuple[ThreadedMemoryReactorClock, Clock]:
clock = ThreadedMemoryReactorClock()
@@ -673,3 +686,171 @@ def connect_client(
client.makeConnection(FakeTransport(server, reactor))
return client, server
+
+
+class TestHomeServer(HomeServer):
+ DATASTORE_CLASS = DataStore
+
+
+def setup_test_homeserver(
+ cleanup_func,
+ name="test",
+ config=None,
+ reactor=None,
+ homeserver_to_use: Type[HomeServer] = TestHomeServer,
+ **kwargs,
+):
+ """
+ Setup a homeserver suitable for running tests against. Keyword arguments
+ are passed to the Homeserver constructor.
+
+ If no datastore is supplied, one is created and given to the homeserver.
+
+ Args:
+ cleanup_func : The function used to register a cleanup routine for
+ after the test.
+
+ Calling this method directly is deprecated: you should instead derive from
+ HomeserverTestCase.
+ """
+ if reactor is None:
+ from twisted.internet import reactor
+
+ if config is None:
+ config = default_config(name, parse=True)
+
+ config.ldap_enabled = False
+
+ if "clock" not in kwargs:
+ kwargs["clock"] = MockClock()
+
+ if USE_POSTGRES_FOR_TESTS:
+ test_db = "synapse_test_%s" % uuid.uuid4().hex
+
+ database_config = {
+ "name": "psycopg2",
+ "args": {
+ "database": test_db,
+ "host": POSTGRES_HOST,
+ "password": POSTGRES_PASSWORD,
+ "user": POSTGRES_USER,
+ "cp_min": 1,
+ "cp_max": 5,
+ },
+ }
+ else:
+ database_config = {
+ "name": "sqlite3",
+ "args": {"database": ":memory:", "cp_min": 1, "cp_max": 1},
+ }
+
+ if "db_txn_limit" in kwargs:
+ database_config["txn_limit"] = kwargs["db_txn_limit"]
+
+ database = DatabaseConnectionConfig("master", database_config)
+ config.database.databases = [database]
+
+ db_engine = create_engine(database.config)
+
+ # Create the database before we actually try and connect to it, based off
+ # the template database we generate in setupdb()
+ if isinstance(db_engine, PostgresEngine):
+ db_conn = db_engine.module.connect(
+ database=POSTGRES_BASE_DB,
+ user=POSTGRES_USER,
+ host=POSTGRES_HOST,
+ password=POSTGRES_PASSWORD,
+ )
+ db_conn.autocommit = True
+ cur = db_conn.cursor()
+ cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
+ cur.execute(
+ "CREATE DATABASE %s WITH TEMPLATE %s;" % (test_db, POSTGRES_BASE_DB)
+ )
+ cur.close()
+ db_conn.close()
+
+ hs = homeserver_to_use(
+ name,
+ config=config,
+ version_string="Synapse/tests",
+ reactor=reactor,
+ )
+
+ # Install @cache_in_self attributes
+ for key, val in kwargs.items():
+ setattr(hs, "_" + key, val)
+
+ # Mock TLS
+ hs.tls_server_context_factory = Mock()
+ hs.tls_client_options_factory = Mock()
+
+ hs.setup()
+ if homeserver_to_use == TestHomeServer:
+ hs.setup_background_tasks()
+
+ if isinstance(db_engine, PostgresEngine):
+ database = hs.get_datastores().databases[0]
+
+ # We need to do cleanup on PostgreSQL
+ def cleanup():
+ import psycopg2
+
+ # Close all the db pools
+ database._db_pool.close()
+
+ dropped = False
+
+ # Drop the test database
+ db_conn = db_engine.module.connect(
+ database=POSTGRES_BASE_DB,
+ user=POSTGRES_USER,
+ host=POSTGRES_HOST,
+ password=POSTGRES_PASSWORD,
+ )
+ db_conn.autocommit = True
+ cur = db_conn.cursor()
+
+ # Try a few times to drop the DB. Some things may hold on to the
+ # database for a few more seconds due to flakiness, preventing
+ # us from dropping it when the test is over. If we can't drop
+ # it, warn and move on.
+ for _ in range(5):
+ try:
+ cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
+ db_conn.commit()
+ dropped = True
+ except psycopg2.OperationalError as e:
+ warnings.warn(
+ "Couldn't drop old db: " + str(e), category=UserWarning
+ )
+ time.sleep(0.5)
+
+ cur.close()
+ db_conn.close()
+
+ if not dropped:
+ warnings.warn("Failed to drop old DB.", category=UserWarning)
+
+ if not LEAVE_DB:
+ # Register the cleanup hook
+ cleanup_func(cleanup)
+
+ # bcrypt is far too slow to be doing in unit tests
+ # Need to let the HS build an auth handler and then mess with it
+ # because AuthHandler's constructor requires the HS, so we can't make one
+ # beforehand and pass it in to the HS's constructor (chicken / egg)
+ async def hash(p):
+ return hashlib.md5(p.encode("utf8")).hexdigest()
+
+ hs.get_auth_handler().hash = hash
+
+ async def validate_hash(p, h):
+ return hashlib.md5(p.encode("utf8")).hexdigest() == h
+
+ hs.get_auth_handler().validate_hash = validate_hash
+
+ # Make the threadpool and database transactions synchronous for testing.
+ _make_test_homeserver_synchronous(hs)
+
+ return hs
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index ddad44bd6c..3e4f0579c9 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -23,7 +23,8 @@ from synapse.storage.database import DatabasePool
from synapse.storage.engines import create_engine
from tests import unittest
-from tests.utils import TestHomeServer, default_config
+from tests.server import TestHomeServer
+from tests.utils import default_config
class SQLBaseStoreTestCase(unittest.TestCase):
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index fccab733c0..5cfdfe9b85 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -19,8 +19,8 @@ from synapse.rest.client import login, room
from synapse.types import UserID, create_requester
from tests import unittest
+from tests.server import TestHomeServer
from tests.test_utils import event_injection
-from tests.utils import TestHomeServer
class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
diff --git a/tests/utils.py b/tests/utils.py
index 983859120f..6d013e8518 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -14,12 +14,7 @@
# limitations under the License.
import atexit
-import hashlib
import os
-import time
-import uuid
-import warnings
-from typing import Type
from unittest.mock import Mock, patch
from urllib import parse as urlparse
@@ -28,14 +23,11 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.api.errors import CodeMessageException, cs_error
from synapse.api.room_versions import RoomVersions
-from synapse.config.database import DatabaseConnectionConfig
from synapse.config.homeserver import HomeServerConfig
from synapse.config.server import DEFAULT_ROOM_VERSION
from synapse.logging.context import current_context, set_current_context
-from synapse.server import HomeServer
-from synapse.storage import DataStore
from synapse.storage.database import LoggingDatabaseConnection
-from synapse.storage.engines import PostgresEngine, create_engine
+from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database
# set this to True to run the tests against postgres instead of sqlite.
@@ -182,171 +174,6 @@ def default_config(name, parse=False):
return config_dict
-class TestHomeServer(HomeServer):
- DATASTORE_CLASS = DataStore
-
-
-def setup_test_homeserver(
- cleanup_func,
- name="test",
- config=None,
- reactor=None,
- homeserver_to_use: Type[HomeServer] = TestHomeServer,
- **kwargs,
-):
- """
- Setup a homeserver suitable for running tests against. Keyword arguments
- are passed to the Homeserver constructor.
-
- If no datastore is supplied, one is created and given to the homeserver.
-
- Args:
- cleanup_func : The function used to register a cleanup routine for
- after the test.
-
- Calling this method directly is deprecated: you should instead derive from
- HomeserverTestCase.
- """
- if reactor is None:
- from twisted.internet import reactor
-
- if config is None:
- config = default_config(name, parse=True)
-
- config.ldap_enabled = False
-
- if "clock" not in kwargs:
- kwargs["clock"] = MockClock()
-
- if USE_POSTGRES_FOR_TESTS:
- test_db = "synapse_test_%s" % uuid.uuid4().hex
-
- database_config = {
- "name": "psycopg2",
- "args": {
- "database": test_db,
- "host": POSTGRES_HOST,
- "password": POSTGRES_PASSWORD,
- "user": POSTGRES_USER,
- "cp_min": 1,
- "cp_max": 5,
- },
- }
- else:
- database_config = {
- "name": "sqlite3",
- "args": {"database": ":memory:", "cp_min": 1, "cp_max": 1},
- }
-
- if "db_txn_limit" in kwargs:
- database_config["txn_limit"] = kwargs["db_txn_limit"]
-
- database = DatabaseConnectionConfig("master", database_config)
- config.database.databases = [database]
-
- db_engine = create_engine(database.config)
-
- # Create the database before we actually try and connect to it, based off
- # the template database we generate in setupdb()
- if isinstance(db_engine, PostgresEngine):
- db_conn = db_engine.module.connect(
- database=POSTGRES_BASE_DB,
- user=POSTGRES_USER,
- host=POSTGRES_HOST,
- password=POSTGRES_PASSWORD,
- )
- db_conn.autocommit = True
- cur = db_conn.cursor()
- cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
- cur.execute(
- "CREATE DATABASE %s WITH TEMPLATE %s;" % (test_db, POSTGRES_BASE_DB)
- )
- cur.close()
- db_conn.close()
-
- hs = homeserver_to_use(
- name,
- config=config,
- version_string="Synapse/tests",
- reactor=reactor,
- )
-
- # Install @cache_in_self attributes
- for key, val in kwargs.items():
- setattr(hs, "_" + key, val)
-
- # Mock TLS
- hs.tls_server_context_factory = Mock()
- hs.tls_client_options_factory = Mock()
-
- hs.setup()
- if homeserver_to_use == TestHomeServer:
- hs.setup_background_tasks()
-
- if isinstance(db_engine, PostgresEngine):
- database = hs.get_datastores().databases[0]
-
- # We need to do cleanup on PostgreSQL
- def cleanup():
- import psycopg2
-
- # Close all the db pools
- database._db_pool.close()
-
- dropped = False
-
- # Drop the test database
- db_conn = db_engine.module.connect(
- database=POSTGRES_BASE_DB,
- user=POSTGRES_USER,
- host=POSTGRES_HOST,
- password=POSTGRES_PASSWORD,
- )
- db_conn.autocommit = True
- cur = db_conn.cursor()
-
- # Try a few times to drop the DB. Some things may hold on to the
- # database for a few more seconds due to flakiness, preventing
- # us from dropping it when the test is over. If we can't drop
- # it, warn and move on.
- for _ in range(5):
- try:
- cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
- db_conn.commit()
- dropped = True
- except psycopg2.OperationalError as e:
- warnings.warn(
- "Couldn't drop old db: " + str(e), category=UserWarning
- )
- time.sleep(0.5)
-
- cur.close()
- db_conn.close()
-
- if not dropped:
- warnings.warn("Failed to drop old DB.", category=UserWarning)
-
- if not LEAVE_DB:
- # Register the cleanup hook
- cleanup_func(cleanup)
-
- # bcrypt is far too slow to be doing in unit tests
- # Need to let the HS build an auth handler and then mess with it
- # because AuthHandler's constructor requires the HS, so we can't make one
- # beforehand and pass it in to the HS's constructor (chicken / egg)
- async def hash(p):
- return hashlib.md5(p.encode("utf8")).hexdigest()
-
- hs.get_auth_handler().hash = hash
-
- async def validate_hash(p, h):
- return hashlib.md5(p.encode("utf8")).hexdigest() == h
-
- hs.get_auth_handler().validate_hash = validate_hash
-
- return hs
-
-
def mock_getRawHeaders(headers=None):
headers = headers if headers is not None else {}
|