summary refs log tree commit diff
path: root/tests/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/utils.py')
-rw-r--r--tests/utils.py184
1 files changed, 121 insertions, 63 deletions
diff --git a/tests/utils.py b/tests/utils.py
index 9f7ff94575..dd347a0c59 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -16,7 +16,9 @@
 import atexit
 import hashlib
 import os
+import time
 import uuid
+import warnings
 from inspect import getcallargs
 
 from mock import Mock, patch
@@ -26,11 +28,12 @@ from twisted.internet import defer, reactor
 
 from synapse.api.constants import EventTypes
 from synapse.api.errors import CodeMessageException, cs_error
+from synapse.config.server import ServerConfig
 from synapse.federation.transport import server
 from synapse.http.server import HttpServer
 from synapse.server import HomeServer
-from synapse.storage import PostgresEngine
-from synapse.storage.engines import create_engine
+from synapse.storage import DataStore
+from synapse.storage.engines import PostgresEngine, create_engine
 from synapse.storage.prepare_database import (
     _get_or_create_schema_state,
     _setup_new_database,
@@ -41,6 +44,7 @@ from synapse.util.ratelimitutils import FederationRateLimiter
 
 # set this to True to run the tests against postgres instead of sqlite.
 USE_POSTGRES_FOR_TESTS = os.environ.get("SYNAPSE_POSTGRES", False)
+LEAVE_DB = os.environ.get("SYNAPSE_LEAVE_DB", False)
 POSTGRES_USER = os.environ.get("SYNAPSE_POSTGRES_USER", "postgres")
 POSTGRES_BASE_DB = "_synapse_unit_tests_base_%s" % (os.getpid(),)
 
@@ -92,10 +96,77 @@ def setupdb():
         atexit.register(_cleanup)
 
 
+def default_config(name):
+    """
+    Create a reasonable test config.
+    """
+    config = Mock()
+    config.signing_key = [MockKey()]
+    config.event_cache_size = 1
+    config.enable_registration = True
+    config.macaroon_secret_key = "not even a little secret"
+    config.expire_access_token = False
+    config.server_name = name
+    config.trusted_third_party_id_servers = []
+    config.room_invite_state_types = []
+    config.password_providers = []
+    config.worker_replication_url = ""
+    config.worker_app = None
+    config.email_enable_notifs = False
+    config.block_non_admin_invites = False
+    config.federation_domain_whitelist = None
+    config.federation_rc_reject_limit = 10
+    config.federation_rc_sleep_limit = 10
+    config.federation_rc_sleep_delay = 100
+    config.federation_rc_concurrent = 10
+    config.filter_timeline_limit = 5000
+    config.user_directory_search_all_users = False
+    config.user_consent_server_notice_content = None
+    config.block_events_without_consent_error = None
+    config.media_storage_providers = []
+    config.auto_join_rooms = []
+    config.limit_usage_by_mau = False
+    config.hs_disabled = False
+    config.hs_disabled_message = ""
+    config.hs_disabled_limit_type = ""
+    config.max_mau_value = 50
+    config.mau_trial_days = 0
+    config.mau_limits_reserved_threepids = []
+    config.admin_contact = None
+    config.rc_messages_per_second = 10000
+    config.rc_message_burst_count = 10000
+
+    config.use_frozen_dicts = False
+
+    # we need a sane default_room_version, otherwise attempts to create rooms will
+    # fail.
+    config.default_room_version = "1"
+
+    # disable user directory updates, because they get done in the
+    # background, which upsets the test runner.
+    config.update_user_directory = False
+
+    def is_threepid_reserved(threepid):
+        return ServerConfig.is_threepid_reserved(config, threepid)
+
+    config.is_threepid_reserved.side_effect = is_threepid_reserved
+
+    return config
+
+
+class TestHomeServer(HomeServer):
+    DATASTORE_CLASS = DataStore
+
+
 @defer.inlineCallbacks
 def setup_test_homeserver(
-    cleanup_func, name="test", datastore=None, config=None, reactor=None,
-    homeserverToUse=HomeServer, **kargs
+    cleanup_func,
+    name="test",
+    datastore=None,
+    config=None,
+    reactor=None,
+    homeserverToUse=TestHomeServer,
+    **kargs
 ):
     """
     Setup a homeserver suitable for running tests against.  Keyword arguments
@@ -111,48 +182,8 @@ def setup_test_homeserver(
         from twisted.internet import reactor
 
     if config is None:
-        config = Mock()
-        config.signing_key = [MockKey()]
-        config.event_cache_size = 1
-        config.enable_registration = True
-        config.macaroon_secret_key = "not even a little secret"
-        config.expire_access_token = False
-        config.server_name = name
-        config.trusted_third_party_id_servers = []
-        config.room_invite_state_types = []
-        config.password_providers = []
-        config.worker_replication_url = ""
-        config.worker_app = None
-        config.email_enable_notifs = False
-        config.block_non_admin_invites = False
-        config.federation_domain_whitelist = None
-        config.federation_rc_reject_limit = 10
-        config.federation_rc_sleep_limit = 10
-        config.federation_rc_sleep_delay = 100
-        config.federation_rc_concurrent = 10
-        config.filter_timeline_limit = 5000
-        config.user_directory_search_all_users = False
-        config.user_consent_server_notice_content = None
-        config.block_events_without_consent_error = None
-        config.media_storage_providers = []
-        config.auto_join_rooms = []
-        config.limit_usage_by_mau = False
-        config.hs_disabled = False
-        config.hs_disabled_message = ""
-        config.hs_disabled_limit_type = ""
-        config.max_mau_value = 50
-        config.mau_limits_reserved_threepids = []
-        config.admin_uri = None
-
-        # we need a sane default_room_version, otherwise attempts to create rooms will
-        # fail.
-        config.default_room_version = "1"
-
-        # disable user directory updates, because they get done in the
-        # background, which upsets the test runner.
-        config.update_user_directory = False
-
-    config.use_frozen_dicts = True
+        config = default_config(name)
+
     config.ldap_enabled = False
 
     if "clock" not in kargs:
@@ -218,22 +249,44 @@ def setup_test_homeserver(
         else:
             # We need to do cleanup on PostgreSQL
             def cleanup():
+                import psycopg2
+
                 # Close all the db pools
                 hs.get_db_pool().close()
 
+                dropped = False
+
                 # Drop the test database
                 db_conn = db_engine.module.connect(
                     database=POSTGRES_BASE_DB, user=POSTGRES_USER
                 )
                 db_conn.autocommit = True
                 cur = db_conn.cursor()
-                cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
-                db_conn.commit()
+
+                # Try a few times to drop the DB. Some things may hold on to the
+                # database for a few more seconds due to flakiness, preventing
+                # us from dropping it when the test is over. If we can't drop
+                # it, warn and move on.
+                for x in range(5):
+                    try:
+                        cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
+                        db_conn.commit()
+                        dropped = True
+                    except psycopg2.OperationalError as e:
+                        warnings.warn(
+                            "Couldn't drop old db: " + str(e), category=UserWarning
+                        )
+                        time.sleep(0.5)
+
                 cur.close()
                 db_conn.close()
 
-            # Register the cleanup hook
-            cleanup_func(cleanup)
+                if not dropped:
+                    warnings.warn("Failed to drop old DB.", category=UserWarning)
+
+            if not LEAVE_DB:
+                # Register the cleanup hook
+                cleanup_func(cleanup)
 
         hs.setup()
     else:
@@ -307,7 +360,9 @@ class MockHttpResource(HttpServer):
 
     @patch('twisted.web.http.Request')
     @defer.inlineCallbacks
-    def trigger(self, http_method, path, content, mock_request, federation_auth=False):
+    def trigger(
+        self, http_method, path, content, mock_request, federation_auth_origin=None
+    ):
         """ Fire an HTTP event.
 
         Args:
@@ -316,6 +371,7 @@ class MockHttpResource(HttpServer):
             content : The HTTP body
             mock_request : Mocked request to pass to the event so it can get
                            content.
+            federation_auth_origin (bytes|None): domain to authenticate as, for federation
         Returns:
             A tuple of (code, response)
         Raises:
@@ -336,8 +392,10 @@ class MockHttpResource(HttpServer):
         mock_request.getClientIP.return_value = "-"
 
         headers = {}
-        if federation_auth:
-            headers[b"Authorization"] = [b"X-Matrix origin=test,key=,sig="]
+        if federation_auth_origin is not None:
+            headers[b"Authorization"] = [
+                b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin,)
+            ]
         mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers)
 
         # return the right path if the event requires it
@@ -556,16 +614,16 @@ def create_room(hs, room_id, creator_id):
     event_builder_factory = hs.get_event_builder_factory()
     event_creation_handler = hs.get_event_creation_handler()
 
-    builder = event_builder_factory.new({
-        "type": EventTypes.Create,
-        "state_key": "",
-        "sender": creator_id,
-        "room_id": room_id,
-        "content": {},
-    })
-
-    event, context = yield event_creation_handler.create_new_client_event(
-        builder
+    builder = event_builder_factory.new(
+        {
+            "type": EventTypes.Create,
+            "state_key": "",
+            "sender": creator_id,
+            "room_id": room_id,
+            "content": {},
+        }
     )
 
+    event, context = yield event_creation_handler.create_new_client_event(builder)
+
     yield store.persist_event(event, context)