summary refs log tree commit diff
path: root/tests/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/utils.py')
-rw-r--r--tests/utils.py394
1 files changed, 295 insertions, 99 deletions
diff --git a/tests/utils.py b/tests/utils.py
index 9bff3ff3b9..52ab762010 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -13,7 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import atexit
 import hashlib
+import os
+import time
+import uuid
+import warnings
 from inspect import getcallargs
 
 from mock import Mock, patch
@@ -21,123 +26,284 @@ from six.moves.urllib import parse as urlparse
 
 from twisted.internet import defer, reactor
 
+from synapse.api.constants import EventTypes
 from synapse.api.errors import CodeMessageException, cs_error
+from synapse.config.server import ServerConfig
 from synapse.federation.transport import server
 from synapse.http.server import HttpServer
 from synapse.server import HomeServer
-from synapse.storage import PostgresEngine
-from synapse.storage.engines import create_engine
-from synapse.storage.prepare_database import prepare_database
+from synapse.storage import DataStore
+from synapse.storage.engines import PostgresEngine, create_engine
+from synapse.storage.prepare_database import (
+    _get_or_create_schema_state,
+    _setup_new_database,
+    prepare_database,
+)
 from synapse.util.logcontext import LoggingContext
 from synapse.util.ratelimitutils import FederationRateLimiter
 
 # set this to True to run the tests against postgres instead of sqlite.
-# It requires you to have a local postgres database called synapse_test, within
-# which ALL TABLES WILL BE DROPPED
-USE_POSTGRES_FOR_TESTS = False
+USE_POSTGRES_FOR_TESTS = os.environ.get("SYNAPSE_POSTGRES", False)
+LEAVE_DB = os.environ.get("SYNAPSE_LEAVE_DB", False)
+POSTGRES_USER = os.environ.get("SYNAPSE_POSTGRES_USER", "postgres")
+POSTGRES_BASE_DB = "_synapse_unit_tests_base_%s" % (os.getpid(),)
+
+
+def setupdb():
+
+    # If we're using PostgreSQL, set up the db once
+    if USE_POSTGRES_FOR_TESTS:
+        pgconfig = {
+            "name": "psycopg2",
+            "args": {
+                "database": POSTGRES_BASE_DB,
+                "user": POSTGRES_USER,
+                "cp_min": 1,
+                "cp_max": 5,
+            },
+        }
+        config = Mock()
+        config.password_providers = []
+        config.database_config = pgconfig
+        db_engine = create_engine(pgconfig)
+        db_conn = db_engine.module.connect(user=POSTGRES_USER)
+        db_conn.autocommit = True
+        cur = db_conn.cursor()
+        cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,))
+        cur.execute("CREATE DATABASE %s;" % (POSTGRES_BASE_DB,))
+        cur.close()
+        db_conn.close()
+
+        # Set up in the db
+        db_conn = db_engine.module.connect(
+            database=POSTGRES_BASE_DB, user=POSTGRES_USER
+        )
+        cur = db_conn.cursor()
+        _get_or_create_schema_state(cur, db_engine)
+        _setup_new_database(cur, db_engine)
+        db_conn.commit()
+        cur.close()
+        db_conn.close()
+
+        def _cleanup():
+            db_conn = db_engine.module.connect(user=POSTGRES_USER)
+            db_conn.autocommit = True
+            cur = db_conn.cursor()
+            cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,))
+            cur.close()
+            db_conn.close()
+
+        atexit.register(_cleanup)
+
+
+def default_config(name):
+    """
+    Create a reasonable test config.
+    """
+    config = Mock()
+    config.signing_key = [MockKey()]
+    config.event_cache_size = 1
+    config.enable_registration = True
+    config.macaroon_secret_key = "not even a little secret"
+    config.expire_access_token = False
+    config.server_name = name
+    config.trusted_third_party_id_servers = []
+    config.room_invite_state_types = []
+    config.password_providers = []
+    config.worker_replication_url = ""
+    config.worker_app = None
+    config.email_enable_notifs = False
+    config.block_non_admin_invites = False
+    config.federation_domain_whitelist = None
+    config.federation_rc_reject_limit = 10
+    config.federation_rc_sleep_limit = 10
+    config.federation_rc_sleep_delay = 100
+    config.federation_rc_concurrent = 10
+    config.filter_timeline_limit = 5000
+    config.user_directory_search_all_users = False
+    config.user_consent_server_notice_content = None
+    config.block_events_without_consent_error = None
+    config.user_consent_at_registration = False
+    config.user_consent_policy_name = "Privacy Policy"
+    config.media_storage_providers = []
+    config.autocreate_auto_join_rooms = True
+    config.auto_join_rooms = []
+    config.limit_usage_by_mau = False
+    config.hs_disabled = False
+    config.hs_disabled_message = ""
+    config.hs_disabled_limit_type = ""
+    config.max_mau_value = 50
+    config.mau_trial_days = 0
+    config.mau_stats_only = False
+    config.mau_limits_reserved_threepids = []
+    config.admin_contact = None
+    config.rc_messages_per_second = 10000
+    config.rc_message_burst_count = 10000
+
+    config.use_frozen_dicts = False
+
+    # we need a sane default_room_version, otherwise attempts to create rooms will
+    # fail.
+    config.default_room_version = "1"
+
+    # disable user directory updates, because they get done in the
+    # background, which upsets the test runner.
+    config.update_user_directory = False
+
+    def is_threepid_reserved(threepid):
+        return ServerConfig.is_threepid_reserved(config, threepid)
+
+    config.is_threepid_reserved.side_effect = is_threepid_reserved
+
+    return config
+
+
+class TestHomeServer(HomeServer):
+    DATASTORE_CLASS = DataStore
 
 
 @defer.inlineCallbacks
-def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None,
-                          **kargs):
-    """Setup a homeserver suitable for running tests against. Keyword arguments
-    are passed to the Homeserver constructor. If no datastore is supplied a
-    datastore backed by an in-memory sqlite db will be given to the HS.
+def setup_test_homeserver(
+    cleanup_func,
+    name="test",
+    datastore=None,
+    config=None,
+    reactor=None,
+    homeserverToUse=TestHomeServer,
+    **kargs
+):
+    """
+    Setup a homeserver suitable for running tests against.  Keyword arguments
+    are passed to the Homeserver constructor.
+
+    If no datastore is supplied, one is created and given to the homeserver.
+
+    Args:
+        cleanup_func : The function used to register a cleanup routine for
+                       after the test.
     """
     if reactor is None:
         from twisted.internet import reactor
 
     if config is None:
-        config = Mock()
-        config.signing_key = [MockKey()]
-        config.event_cache_size = 1
-        config.enable_registration = True
-        config.macaroon_secret_key = "not even a little secret"
-        config.expire_access_token = False
-        config.server_name = name
-        config.trusted_third_party_id_servers = []
-        config.room_invite_state_types = []
-        config.password_providers = []
-        config.worker_replication_url = ""
-        config.worker_app = None
-        config.email_enable_notifs = False
-        config.block_non_admin_invites = False
-        config.federation_domain_whitelist = None
-        config.federation_rc_reject_limit = 10
-        config.federation_rc_sleep_limit = 10
-        config.federation_rc_sleep_delay = 100
-        config.federation_rc_concurrent = 10
-        config.filter_timeline_limit = 5000
-        config.user_directory_search_all_users = False
-        config.user_consent_server_notice_content = None
-        config.block_events_without_consent_error = None
-        config.media_storage_providers = []
-        config.auto_join_rooms = []
-
-        # disable user directory updates, because they get done in the
-        # background, which upsets the test runner.
-        config.update_user_directory = False
-
-    config.use_frozen_dicts = True
+        config = default_config(name)
+
     config.ldap_enabled = False
 
     if "clock" not in kargs:
         kargs["clock"] = MockClock()
 
     if USE_POSTGRES_FOR_TESTS:
+        test_db = "synapse_test_%s" % uuid.uuid4().hex
+
         config.database_config = {
             "name": "psycopg2",
-            "args": {
-                "database": "synapse_test",
-                "cp_min": 1,
-                "cp_max": 5,
-            },
+            "args": {"database": test_db, "cp_min": 1, "cp_max": 5},
         }
     else:
         config.database_config = {
             "name": "sqlite3",
-            "args": {
-                "database": ":memory:",
-                "cp_min": 1,
-                "cp_max": 1,
-            },
+            "args": {"database": ":memory:", "cp_min": 1, "cp_max": 1},
         }
 
     db_engine = create_engine(config.database_config)
 
+    # Create the database before we actually try and connect to it, based off
+    # the template database we generate in setupdb()
+    if datastore is None and isinstance(db_engine, PostgresEngine):
+        db_conn = db_engine.module.connect(
+            database=POSTGRES_BASE_DB, user=POSTGRES_USER
+        )
+        db_conn.autocommit = True
+        cur = db_conn.cursor()
+        cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
+        cur.execute(
+            "CREATE DATABASE %s WITH TEMPLATE %s;" % (test_db, POSTGRES_BASE_DB)
+        )
+        cur.close()
+        db_conn.close()
+
     # we need to configure the connection pool to run the on_new_connection
     # function, so that we can test code that uses custom sqlite functions
     # (like rank).
     config.database_config["args"]["cp_openfun"] = db_engine.on_new_connection
 
     if datastore is None:
-        hs = HomeServer(
-            name, config=config,
+        hs = homeserverToUse(
+            name,
+            config=config,
             db_config=config.database_config,
             version_string="Synapse/tests",
             database_engine=db_engine,
             room_list_handler=object(),
             tls_server_context_factory=Mock(),
+            tls_client_options_factory=Mock(),
             reactor=reactor,
             **kargs
         )
-        db_conn = hs.get_db_conn()
-        # make sure that the database is empty
-        if isinstance(db_engine, PostgresEngine):
-            cur = db_conn.cursor()
-            cur.execute("SELECT tablename FROM pg_tables where schemaname='public'")
-            rows = cur.fetchall()
-            for r in rows:
-                cur.execute("DROP TABLE %s CASCADE" % r[0])
-        yield prepare_database(db_conn, db_engine, config)
+
+        # Prepare the DB on SQLite -- PostgreSQL is a copy of an already up to
+        # date db
+        if not isinstance(db_engine, PostgresEngine):
+            db_conn = hs.get_db_conn()
+            yield prepare_database(db_conn, db_engine, config)
+            db_conn.commit()
+            db_conn.close()
+
+        else:
+            # We need to do cleanup on PostgreSQL
+            def cleanup():
+                import psycopg2
+
+                # Close all the db pools
+                hs.get_db_pool().close()
+
+                dropped = False
+
+                # Drop the test database
+                db_conn = db_engine.module.connect(
+                    database=POSTGRES_BASE_DB, user=POSTGRES_USER
+                )
+                db_conn.autocommit = True
+                cur = db_conn.cursor()
+
+                # Try a few times to drop the DB. Some things may hold on to the
+                # database for a few more seconds due to flakiness, preventing
+                # us from dropping it when the test is over. If we can't drop
+                # it, warn and move on.
+                for x in range(5):
+                    try:
+                        cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
+                        db_conn.commit()
+                        dropped = True
+                    except psycopg2.OperationalError as e:
+                        warnings.warn(
+                            "Couldn't drop old db: " + str(e), category=UserWarning
+                        )
+                        time.sleep(0.5)
+
+                cur.close()
+                db_conn.close()
+
+                if not dropped:
+                    warnings.warn("Failed to drop old DB.", category=UserWarning)
+
+            if not LEAVE_DB:
+                # Register the cleanup hook
+                cleanup_func(cleanup)
+
         hs.setup()
     else:
-        hs = HomeServer(
-            name, db_pool=None, datastore=datastore, config=config,
+        hs = homeserverToUse(
+            name,
+            db_pool=None,
+            datastore=datastore,
+            config=config,
             version_string="Synapse/tests",
             database_engine=db_engine,
             room_list_handler=object(),
             tls_server_context_factory=Mock(),
+            tls_client_options_factory=Mock(),
             reactor=reactor,
             **kargs
         )
@@ -146,8 +312,10 @@ def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None
     # Need to let the HS build an auth handler and then mess with it
     # because AuthHandler's constructor requires the HS, so we can't make one
     # beforehand and pass it in to the HS's constructor (chicken / egg)
-    hs.get_auth_handler().hash = lambda p: hashlib.md5(p).hexdigest()
-    hs.get_auth_handler().validate_hash = lambda p, h: hashlib.md5(p).hexdigest() == h
+    hs.get_auth_handler().hash = lambda p: hashlib.md5(p.encode('utf8')).hexdigest()
+    hs.get_auth_handler().validate_hash = (
+        lambda p, h: hashlib.md5(p.encode('utf8')).hexdigest() == h
+    )
 
     fed = kargs.get("resource_for_federation", None)
     if fed:
@@ -161,7 +329,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None
                 sleep_limit=hs.config.federation_rc_sleep_limit,
                 sleep_msec=hs.config.federation_rc_sleep_delay,
                 reject_limit=hs.config.federation_rc_reject_limit,
-                concurrent_requests=hs.config.federation_rc_concurrent
+                concurrent_requests=hs.config.federation_rc_concurrent,
             ),
         )
 
@@ -187,7 +355,6 @@ def mock_getRawHeaders(headers=None):
 
 # This is a mock /resource/ not an entire server
 class MockHttpResource(HttpServer):
-
     def __init__(self, prefix=""):
         self.callbacks = []  # 3-tuple of method/pattern/function
         self.prefix = prefix
@@ -197,7 +364,9 @@ class MockHttpResource(HttpServer):
 
     @patch('twisted.web.http.Request')
     @defer.inlineCallbacks
-    def trigger(self, http_method, path, content, mock_request, federation_auth=False):
+    def trigger(
+        self, http_method, path, content, mock_request, federation_auth_origin=None
+    ):
         """ Fire an HTTP event.
 
         Args:
@@ -206,6 +375,7 @@ class MockHttpResource(HttpServer):
             content : The HTTP body
             mock_request : Mocked request to pass to the event so it can get
                            content.
+            federation_auth_origin (bytes|None): domain to authenticate as, for federation
         Returns:
             A tuple of (code, response)
         Raises:
@@ -220,14 +390,16 @@ class MockHttpResource(HttpServer):
         mock_content.configure_mock(**config)
         mock_request.content = mock_content
 
-        mock_request.method = http_method
-        mock_request.uri = path
+        mock_request.method = http_method.encode('ascii')
+        mock_request.uri = path.encode('ascii')
 
         mock_request.getClientIP.return_value = "-"
 
         headers = {}
-        if federation_auth:
-            headers[b"Authorization"] = [b"X-Matrix origin=test,key=,sig="]
+        if federation_auth_origin is not None:
+            headers[b"Authorization"] = [
+                b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin,)
+            ]
         mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers)
 
         # return the right path if the event requires it
@@ -251,15 +423,9 @@ class MockHttpResource(HttpServer):
             matcher = pattern.match(path)
             if matcher:
                 try:
-                    args = [
-                        urlparse.unquote(u)
-                        for u in matcher.groups()
-                    ]
-
-                    (code, response) = yield func(
-                        mock_request,
-                        *args
-                    )
+                    args = [urlparse.unquote(u) for u in matcher.groups()]
+
+                    (code, response) = yield func(mock_request, *args)
                     defer.returnValue((code, response))
                 except CodeMessageException as e:
                     defer.returnValue((e.code, cs_error(e.msg, code=e.errcode)))
@@ -360,8 +526,7 @@ class MockClock(object):
 
 def _format_call(args, kwargs):
     return ", ".join(
-        ["%r" % (a) for a in args] +
-        ["%s=%r" % (k, v) for k, v in kwargs.items()]
+        ["%r" % (a) for a in args] + ["%s=%r" % (k, v) for k, v in kwargs.items()]
     )
 
 
@@ -379,8 +544,9 @@ class DeferredMockCallable(object):
         self.calls.append((args, kwargs))
 
         if not self.expectations:
-            raise ValueError("%r has no pending calls to handle call(%s)" % (
-                self, _format_call(args, kwargs))
+            raise ValueError(
+                "%r has no pending calls to handle call(%s)"
+                % (self, _format_call(args, kwargs))
             )
 
         for (call, result, d) in self.expectations:
@@ -388,9 +554,9 @@ class DeferredMockCallable(object):
                 d.callback(None)
                 return result
 
-        failure = AssertionError("Was not expecting call(%s)" % (
-            _format_call(args, kwargs)
-        ))
+        failure = AssertionError(
+            "Was not expecting call(%s)" % (_format_call(args, kwargs))
+        )
 
         for _, _, d in self.expectations:
             try:
@@ -406,17 +572,19 @@ class DeferredMockCallable(object):
     @defer.inlineCallbacks
     def await_calls(self, timeout=1000):
         deferred = defer.DeferredList(
-            [d for _, _, d in self.expectations],
-            fireOnOneErrback=True
+            [d for _, _, d in self.expectations], fireOnOneErrback=True
         )
 
         timer = reactor.callLater(
             timeout / 1000,
             deferred.errback,
-            AssertionError("%d pending calls left: %s" % (
-                len([e for e in self.expectations if not e[2].called]),
-                [e for e in self.expectations if not e[2].called]
-            ))
+            AssertionError(
+                "%d pending calls left: %s"
+                % (
+                    len([e for e in self.expectations if not e[2].called]),
+                    [e for e in self.expectations if not e[2].called],
+                )
+            ),
         )
 
         yield deferred
@@ -431,7 +599,35 @@ class DeferredMockCallable(object):
             self.calls = []
 
             raise AssertionError(
-                "Expected not to received any calls, got:\n" + "\n".join([
-                    "call(%s)" % _format_call(c[0], c[1]) for c in calls
-                ])
+                "Expected not to received any calls, got:\n"
+                + "\n".join(["call(%s)" % _format_call(c[0], c[1]) for c in calls])
             )
+
+
+@defer.inlineCallbacks
+def create_room(hs, room_id, creator_id):
+    """Creates and persist a creation event for the given room
+
+    Args:
+        hs
+        room_id (str)
+        creator_id (str)
+    """
+
+    store = hs.get_datastore()
+    event_builder_factory = hs.get_event_builder_factory()
+    event_creation_handler = hs.get_event_creation_handler()
+
+    builder = event_builder_factory.new(
+        {
+            "type": EventTypes.Create,
+            "state_key": "",
+            "sender": creator_id,
+            "room_id": room_id,
+            "content": {},
+        }
+    )
+
+    event, context = yield event_creation_handler.create_new_client_event(builder)
+
+    yield store.persist_event(event, context)