summary refs log tree commit diff
path: root/tests/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/utils.py')
-rw-r--r--tests/utils.py212
1 files changed, 153 insertions, 59 deletions
diff --git a/tests/utils.py b/tests/utils.py
index 5d49692c58..52326d4f67 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -13,7 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import atexit
 import hashlib
+import os
+import uuid
 from inspect import getcallargs
 
 from mock import Mock, patch
@@ -27,22 +30,80 @@ from synapse.http.server import HttpServer
 from synapse.server import HomeServer
 from synapse.storage import PostgresEngine
 from synapse.storage.engines import create_engine
-from synapse.storage.prepare_database import prepare_database
+from synapse.storage.prepare_database import (
+    _get_or_create_schema_state,
+    _setup_new_database,
+    prepare_database,
+)
 from synapse.util.logcontext import LoggingContext
 from synapse.util.ratelimitutils import FederationRateLimiter
 
 # set this to True to run the tests against postgres instead of sqlite.
-# It requires you to have a local postgres database called synapse_test, within
-# which ALL TABLES WILL BE DROPPED
-USE_POSTGRES_FOR_TESTS = False
+USE_POSTGRES_FOR_TESTS = os.environ.get("SYNAPSE_POSTGRES", False)
+POSTGRES_USER = os.environ.get("SYNAPSE_POSTGRES_USER", "postgres")
+POSTGRES_BASE_DB = "_synapse_unit_tests_base_%s" % (os.getpid(),)
+
+
+def setupdb():
+
+    # If we're using PostgreSQL, set up the db once
+    if USE_POSTGRES_FOR_TESTS:
+        pgconfig = {
+            "name": "psycopg2",
+            "args": {
+                "database": POSTGRES_BASE_DB,
+                "user": POSTGRES_USER,
+                "cp_min": 1,
+                "cp_max": 5,
+            },
+        }
+        config = Mock()
+        config.password_providers = []
+        config.database_config = pgconfig
+        db_engine = create_engine(pgconfig)
+        db_conn = db_engine.module.connect(user=POSTGRES_USER)
+        db_conn.autocommit = True
+        cur = db_conn.cursor()
+        cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,))
+        cur.execute("CREATE DATABASE %s;" % (POSTGRES_BASE_DB,))
+        cur.close()
+        db_conn.close()
+
+        # Set up in the db
+        db_conn = db_engine.module.connect(
+            database=POSTGRES_BASE_DB, user=POSTGRES_USER
+        )
+        cur = db_conn.cursor()
+        _get_or_create_schema_state(cur, db_engine)
+        _setup_new_database(cur, db_engine)
+        db_conn.commit()
+        cur.close()
+        db_conn.close()
+
+        def _cleanup():
+            db_conn = db_engine.module.connect(user=POSTGRES_USER)
+            db_conn.autocommit = True
+            cur = db_conn.cursor()
+            cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,))
+            cur.close()
+            db_conn.close()
+
+        atexit.register(_cleanup)
 
 
 @defer.inlineCallbacks
-def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None,
-                          **kargs):
-    """Setup a homeserver suitable for running tests against. Keyword arguments
-    are passed to the Homeserver constructor. If no datastore is supplied a
-    datastore backed by an in-memory sqlite db will be given to the HS.
+def setup_test_homeserver(
+    cleanup_func, name="test", datastore=None, config=None, reactor=None, **kargs
+):
+    """
+    Setup a homeserver suitable for running tests against.  Keyword arguments
+    are passed to the Homeserver constructor.
+
+    If no datastore is supplied, one is created and given to the homeserver.
+
+    Args:
+        cleanup_func : The function used to register a cleanup routine for
+                       after the test.
     """
     if reactor is None:
         from twisted.internet import reactor
@@ -74,8 +135,11 @@ def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None
         config.media_storage_providers = []
         config.auto_join_rooms = []
         config.limit_usage_by_mau = False
+        config.hs_disabled = False
+        config.hs_disabled_message = ""
         config.max_mau_value = 50
         config.mau_limits_reserved_threepids = []
+        config.admin_uri = None
 
         # we need a sane default_room_version, otherwise attempts to create rooms will
         # fail.
@@ -92,26 +156,35 @@ def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None
         kargs["clock"] = MockClock()
 
     if USE_POSTGRES_FOR_TESTS:
+        test_db = "synapse_test_%s" % uuid.uuid4().hex
+
         config.database_config = {
             "name": "psycopg2",
-            "args": {
-                "database": "synapse_test",
-                "cp_min": 1,
-                "cp_max": 5,
-            },
+            "args": {"database": test_db, "cp_min": 1, "cp_max": 5},
         }
     else:
         config.database_config = {
             "name": "sqlite3",
-            "args": {
-                "database": ":memory:",
-                "cp_min": 1,
-                "cp_max": 1,
-            },
+            "args": {"database": ":memory:", "cp_min": 1, "cp_max": 1},
         }
 
     db_engine = create_engine(config.database_config)
 
+    # Create the database before we actually try and connect to it, based off
+    # the template database we generate in setupdb()
+    if datastore is None and isinstance(db_engine, PostgresEngine):
+        db_conn = db_engine.module.connect(
+            database=POSTGRES_BASE_DB, user=POSTGRES_USER
+        )
+        db_conn.autocommit = True
+        cur = db_conn.cursor()
+        cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
+        cur.execute(
+            "CREATE DATABASE %s WITH TEMPLATE %s;" % (test_db, POSTGRES_BASE_DB)
+        )
+        cur.close()
+        db_conn.close()
+
     # we need to configure the connection pool to run the on_new_connection
     # function, so that we can test code that uses custom sqlite functions
     # (like rank).
@@ -119,32 +192,58 @@ def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None
 
     if datastore is None:
         hs = HomeServer(
-            name, config=config,
+            name,
+            config=config,
             db_config=config.database_config,
             version_string="Synapse/tests",
             database_engine=db_engine,
             room_list_handler=object(),
             tls_server_context_factory=Mock(),
+            tls_client_options_factory=Mock(),
             reactor=reactor,
             **kargs
         )
-        db_conn = hs.get_db_conn()
-        # make sure that the database is empty
-        if isinstance(db_engine, PostgresEngine):
-            cur = db_conn.cursor()
-            cur.execute("SELECT tablename FROM pg_tables where schemaname='public'")
-            rows = cur.fetchall()
-            for r in rows:
-                cur.execute("DROP TABLE %s CASCADE" % r[0])
-        yield prepare_database(db_conn, db_engine, config)
+
+        # Prepare the DB on SQLite -- PostgreSQL is a copy of an already up to
+        # date db
+        if not isinstance(db_engine, PostgresEngine):
+            db_conn = hs.get_db_conn()
+            yield prepare_database(db_conn, db_engine, config)
+            db_conn.commit()
+            db_conn.close()
+
+        else:
+            # We need to do cleanup on PostgreSQL
+            def cleanup():
+                # Close all the db pools
+                hs.get_db_pool().close()
+
+                # Drop the test database
+                db_conn = db_engine.module.connect(
+                    database=POSTGRES_BASE_DB, user=POSTGRES_USER
+                )
+                db_conn.autocommit = True
+                cur = db_conn.cursor()
+                cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
+                db_conn.commit()
+                cur.close()
+                db_conn.close()
+
+            # Register the cleanup hook
+            cleanup_func(cleanup)
+
         hs.setup()
     else:
         hs = HomeServer(
-            name, db_pool=None, datastore=datastore, config=config,
+            name,
+            db_pool=None,
+            datastore=datastore,
+            config=config,
             version_string="Synapse/tests",
             database_engine=db_engine,
             room_list_handler=object(),
             tls_server_context_factory=Mock(),
+            tls_client_options_factory=Mock(),
             reactor=reactor,
             **kargs
         )
@@ -154,8 +253,9 @@ def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None
     # because AuthHandler's constructor requires the HS, so we can't make one
     # beforehand and pass it in to the HS's constructor (chicken / egg)
     hs.get_auth_handler().hash = lambda p: hashlib.md5(p.encode('utf8')).hexdigest()
-    hs.get_auth_handler().validate_hash = lambda p, h: hashlib.md5(
-        p.encode('utf8')).hexdigest() == h
+    hs.get_auth_handler().validate_hash = (
+        lambda p, h: hashlib.md5(p.encode('utf8')).hexdigest() == h
+    )
 
     fed = kargs.get("resource_for_federation", None)
     if fed:
@@ -169,7 +269,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None
                 sleep_limit=hs.config.federation_rc_sleep_limit,
                 sleep_msec=hs.config.federation_rc_sleep_delay,
                 reject_limit=hs.config.federation_rc_reject_limit,
-                concurrent_requests=hs.config.federation_rc_concurrent
+                concurrent_requests=hs.config.federation_rc_concurrent,
             ),
         )
 
@@ -195,7 +295,6 @@ def mock_getRawHeaders(headers=None):
 
 # This is a mock /resource/ not an entire server
 class MockHttpResource(HttpServer):
-
     def __init__(self, prefix=""):
         self.callbacks = []  # 3-tuple of method/pattern/function
         self.prefix = prefix
@@ -259,15 +358,9 @@ class MockHttpResource(HttpServer):
             matcher = pattern.match(path)
             if matcher:
                 try:
-                    args = [
-                        urlparse.unquote(u)
-                        for u in matcher.groups()
-                    ]
-
-                    (code, response) = yield func(
-                        mock_request,
-                        *args
-                    )
+                    args = [urlparse.unquote(u) for u in matcher.groups()]
+
+                    (code, response) = yield func(mock_request, *args)
                     defer.returnValue((code, response))
                 except CodeMessageException as e:
                     defer.returnValue((e.code, cs_error(e.msg, code=e.errcode)))
@@ -368,8 +461,7 @@ class MockClock(object):
 
 def _format_call(args, kwargs):
     return ", ".join(
-        ["%r" % (a) for a in args] +
-        ["%s=%r" % (k, v) for k, v in kwargs.items()]
+        ["%r" % (a) for a in args] + ["%s=%r" % (k, v) for k, v in kwargs.items()]
     )
 
 
@@ -387,8 +479,9 @@ class DeferredMockCallable(object):
         self.calls.append((args, kwargs))
 
         if not self.expectations:
-            raise ValueError("%r has no pending calls to handle call(%s)" % (
-                self, _format_call(args, kwargs))
+            raise ValueError(
+                "%r has no pending calls to handle call(%s)"
+                % (self, _format_call(args, kwargs))
             )
 
         for (call, result, d) in self.expectations:
@@ -396,9 +489,9 @@ class DeferredMockCallable(object):
                 d.callback(None)
                 return result
 
-        failure = AssertionError("Was not expecting call(%s)" % (
-            _format_call(args, kwargs)
-        ))
+        failure = AssertionError(
+            "Was not expecting call(%s)" % (_format_call(args, kwargs))
+        )
 
         for _, _, d in self.expectations:
             try:
@@ -414,17 +507,19 @@ class DeferredMockCallable(object):
     @defer.inlineCallbacks
     def await_calls(self, timeout=1000):
         deferred = defer.DeferredList(
-            [d for _, _, d in self.expectations],
-            fireOnOneErrback=True
+            [d for _, _, d in self.expectations], fireOnOneErrback=True
         )
 
         timer = reactor.callLater(
             timeout / 1000,
             deferred.errback,
-            AssertionError("%d pending calls left: %s" % (
-                len([e for e in self.expectations if not e[2].called]),
-                [e for e in self.expectations if not e[2].called]
-            ))
+            AssertionError(
+                "%d pending calls left: %s"
+                % (
+                    len([e for e in self.expectations if not e[2].called]),
+                    [e for e in self.expectations if not e[2].called],
+                )
+            ),
         )
 
         yield deferred
@@ -439,7 +534,6 @@ class DeferredMockCallable(object):
             self.calls = []
 
             raise AssertionError(
-                "Expected not to received any calls, got:\n" + "\n".join([
-                    "call(%s)" % _format_call(c[0], c[1]) for c in calls
-                ])
+                "Expected not to received any calls, got:\n"
+                + "\n".join(["call(%s)" % _format_call(c[0], c[1]) for c in calls])
             )