diff --git a/tests/server.py b/tests/server.py
index b29df37595..40cf5b12c3 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -11,12 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import hashlib
+
import json
import logging
-import time
-import uuid
-import warnings
from collections import deque
from io import SEEK_END, BytesIO
from typing import (
@@ -30,7 +27,6 @@ from typing import (
Type,
Union,
)
-from unittest.mock import Mock
import attr
from typing_extensions import Deque
@@ -57,24 +53,11 @@ from twisted.web.http_headers import Headers
from twisted.web.resource import IResource
from twisted.web.server import Request, Site
-from synapse.config.database import DatabaseConnectionConfig
from synapse.http.site import SynapseRequest
-from synapse.server import HomeServer
-from synapse.storage import DataStore
-from synapse.storage.engines import PostgresEngine, create_engine
from synapse.types import JsonDict
from synapse.util import Clock
-from tests.utils import (
- LEAVE_DB,
- POSTGRES_BASE_DB,
- POSTGRES_HOST,
- POSTGRES_PASSWORD,
- POSTGRES_USER,
- USE_POSTGRES_FOR_TESTS,
- MockClock,
- default_config,
-)
+from tests.utils import setup_test_homeserver as _sth
logger = logging.getLogger(__name__)
@@ -467,11 +450,14 @@ class ThreadPool:
return d
-def make_test_homeserver_synchronous(server: HomeServer) -> None:
+def setup_test_homeserver(cleanup_func, *args, **kwargs):
"""
- Make the given test homeserver's database interactions synchronous.
+ Set up a synchronous test server, driven by the reactor used by
+ the homeserver.
"""
+ server = _sth(cleanup_func, *args, **kwargs)
+ # Make the thread pool synchronous.
clock = server.get_clock()
for database in server.get_datastores().databases:
@@ -499,7 +485,6 @@ def make_test_homeserver_synchronous(server: HomeServer) -> None:
pool.runWithConnection = runWithConnection
pool.runInteraction = runInteraction
- # Replace the thread pool with a threadless 'thread' pool
pool.threadpool = ThreadPool(clock._reactor)
pool.running = True
@@ -507,6 +492,8 @@ def make_test_homeserver_synchronous(server: HomeServer) -> None:
# thread, so we need to disable the dedicated thread behaviour.
server.get_datastores().main.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = False
+ return server
+
def get_clock() -> Tuple[ThreadedMemoryReactorClock, Clock]:
clock = ThreadedMemoryReactorClock()
@@ -686,171 +673,3 @@ def connect_client(
client.makeConnection(FakeTransport(server, reactor))
return client, server
-
-
-class TestHomeServer(HomeServer):
- DATASTORE_CLASS = DataStore
-
-
-def setup_test_homeserver(
- cleanup_func,
- name="test",
- config=None,
- reactor=None,
- homeserver_to_use: Type[HomeServer] = TestHomeServer,
- **kwargs,
-):
- """
- Setup a homeserver suitable for running tests against. Keyword arguments
- are passed to the Homeserver constructor.
-
- If no datastore is supplied, one is created and given to the homeserver.
-
- Args:
- cleanup_func : The function used to register a cleanup routine for
- after the test.
-
- Calling this method directly is deprecated: you should instead derive from
- HomeserverTestCase.
- """
- if reactor is None:
- from twisted.internet import reactor
-
- if config is None:
- config = default_config(name, parse=True)
-
- config.ldap_enabled = False
-
- if "clock" not in kwargs:
- kwargs["clock"] = MockClock()
-
- if USE_POSTGRES_FOR_TESTS:
- test_db = "synapse_test_%s" % uuid.uuid4().hex
-
- database_config = {
- "name": "psycopg2",
- "args": {
- "database": test_db,
- "host": POSTGRES_HOST,
- "password": POSTGRES_PASSWORD,
- "user": POSTGRES_USER,
- "cp_min": 1,
- "cp_max": 5,
- },
- }
- else:
- database_config = {
- "name": "sqlite3",
- "args": {"database": ":memory:", "cp_min": 1, "cp_max": 1},
- }
-
- if "db_txn_limit" in kwargs:
- database_config["txn_limit"] = kwargs["db_txn_limit"]
-
- database = DatabaseConnectionConfig("master", database_config)
- config.database.databases = [database]
-
- db_engine = create_engine(database.config)
-
- # Create the database before we actually try and connect to it, based off
- # the template database we generate in setupdb()
- if isinstance(db_engine, PostgresEngine):
- db_conn = db_engine.module.connect(
- database=POSTGRES_BASE_DB,
- user=POSTGRES_USER,
- host=POSTGRES_HOST,
- password=POSTGRES_PASSWORD,
- )
- db_conn.autocommit = True
- cur = db_conn.cursor()
- cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
- cur.execute(
- "CREATE DATABASE %s WITH TEMPLATE %s;" % (test_db, POSTGRES_BASE_DB)
- )
- cur.close()
- db_conn.close()
-
- hs = homeserver_to_use(
- name,
- config=config,
- version_string="Synapse/tests",
- reactor=reactor,
- )
-
- # Install @cache_in_self attributes
- for key, val in kwargs.items():
- setattr(hs, "_" + key, val)
-
- # Mock TLS
- hs.tls_server_context_factory = Mock()
- hs.tls_client_options_factory = Mock()
-
- hs.setup()
- if homeserver_to_use == TestHomeServer:
- hs.setup_background_tasks()
-
- if isinstance(db_engine, PostgresEngine):
- database = hs.get_datastores().databases[0]
-
- # We need to do cleanup on PostgreSQL
- def cleanup():
- import psycopg2
-
- # Close all the db pools
- database._db_pool.close()
-
- dropped = False
-
- # Drop the test database
- db_conn = db_engine.module.connect(
- database=POSTGRES_BASE_DB,
- user=POSTGRES_USER,
- host=POSTGRES_HOST,
- password=POSTGRES_PASSWORD,
- )
- db_conn.autocommit = True
- cur = db_conn.cursor()
-
- # Try a few times to drop the DB. Some things may hold on to the
- # database for a few more seconds due to flakiness, preventing
- # us from dropping it when the test is over. If we can't drop
- # it, warn and move on.
- for _ in range(5):
- try:
- cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
- db_conn.commit()
- dropped = True
- except psycopg2.OperationalError as e:
- warnings.warn(
- "Couldn't drop old db: " + str(e), category=UserWarning
- )
- time.sleep(0.5)
-
- cur.close()
- db_conn.close()
-
- if not dropped:
- warnings.warn("Failed to drop old DB.", category=UserWarning)
-
- if not LEAVE_DB:
- # Register the cleanup hook
- cleanup_func(cleanup)
-
- # bcrypt is far too slow to be doing in unit tests
- # Need to let the HS build an auth handler and then mess with it
- # because AuthHandler's constructor requires the HS, so we can't make one
- # beforehand and pass it in to the HS's constructor (chicken / egg)
- async def hash(p):
- return hashlib.md5(p.encode("utf8")).hexdigest()
-
- hs.get_auth_handler().hash = hash
-
- async def validate_hash(p, h):
- return hashlib.md5(p.encode("utf8")).hexdigest() == h
-
- hs.get_auth_handler().validate_hash = validate_hash
-
- # Make the threadpool and database transactions synchronous for testing.
- make_test_homeserver_synchronous(hs)
-
- return hs
|