summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
authorRichard van der Hoff <github@rvanderhoff.org.uk>2018-01-27 18:20:55 +0100
committerGitHub <noreply@github.com>2018-01-27 18:20:55 +0100
commit2186d7c06ee0c4b200b4567c7848cca434444fc4 (patch)
treed6e87eb8800eafd86502295fb8c3164b0f32a14e /tests
parentMerge pull request #2828 from matrix-org/rav/kill_memory_datastore (diff)
parentMake it possible to run tests against postgres (diff)
downloadsynapse-2186d7c06ee0c4b200b4567c7848cca434444fc4.tar.xz
Merge pull request #2829 from matrix-org/rav/postgres_in_tests
Make it possible to run tests against postgres
Diffstat (limited to 'tests')
-rw-r--r--tests/crypto/test_keyring.py4
-rw-r--r--tests/utils.py81
2 files changed, 47 insertions, 38 deletions
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index c899fecf5d..d4ec02ffc2 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -167,7 +167,7 @@ class KeyringTestCase(unittest.TestCase):
 
             # wait a tick for it to send the request to the perspectives server
             # (it first tries the datastore)
-            yield async.sleep(0.005)
+            yield async.sleep(1)   # XXX find out why this takes so long!
             self.http_client.post_json.assert_called_once()
 
             self.assertIs(LoggingContext.current_context(), context_11)
@@ -183,7 +183,7 @@ class KeyringTestCase(unittest.TestCase):
                 res_deferreds_2 = kr.verify_json_objects_for_server(
                     [("server10", json1)],
                 )
-                yield async.sleep(0.005)
+                yield async.sleep(01)
                 self.http_client.post_json.assert_not_called()
                 res_deferreds_2[0].addBoth(self.check_context, None)
 
diff --git a/tests/utils.py b/tests/utils.py
index de33deb0b2..d1f59551e8 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -19,18 +19,23 @@ import urllib
 import urlparse
 
 from mock import Mock, patch
-from twisted.enterprise.adbapi import ConnectionPool
 from twisted.internet import defer, reactor
 
 from synapse.api.errors import CodeMessageException, cs_error
 from synapse.federation.transport import server
 from synapse.http.server import HttpServer
 from synapse.server import HomeServer
+from synapse.storage import PostgresEngine
 from synapse.storage.engines import create_engine
 from synapse.storage.prepare_database import prepare_database
 from synapse.util.logcontext import LoggingContext
 from synapse.util.ratelimitutils import FederationRateLimiter
 
+# set this to True to run the tests against postgres instead of sqlite.
+# It requires you to have a local postgres database called synapse_test, within
+# which ALL TABLES WILL BE DROPPED
+USE_POSTGRES_FOR_TESTS = False
+
 
 @defer.inlineCallbacks
 def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
@@ -60,30 +65,62 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
         config.update_user_directory = False
 
     config.use_frozen_dicts = True
-    config.database_config = {"name": "sqlite3"}
     config.ldap_enabled = False
 
     if "clock" not in kargs:
         kargs["clock"] = MockClock()
 
+    if USE_POSTGRES_FOR_TESTS:
+        config.database_config = {
+            "name": "psycopg2",
+            "args": {
+                "database": "synapse_test",
+                "cp_min": 1,
+                "cp_max": 5,
+            },
+        }
+    else:
+        config.database_config = {
+            "name": "sqlite3",
+            "args": {
+                "database": ":memory:",
+                "cp_min": 1,
+                "cp_max": 1,
+            },
+        }
+
+    db_engine = create_engine(config.database_config)
+
+    # we need to configure the connection pool to run the on_new_connection
+    # function, so that we can test code that uses custom sqlite functions
+    # (like rank).
+    config.database_config["args"]["cp_openfun"] = db_engine.on_new_connection
+
     if datastore is None:
-        db_pool = SQLiteMemoryDbPool()
-        yield db_pool.prepare()
         hs = HomeServer(
-            name, db_pool=db_pool, config=config,
+            name, config=config,
+            db_config=config.database_config,
             version_string="Synapse/tests",
-            database_engine=create_engine(config.database_config),
-            get_db_conn=db_pool.get_db_conn,
+            database_engine=db_engine,
             room_list_handler=object(),
             tls_server_context_factory=Mock(),
             **kargs
         )
+        db_conn = hs.get_db_conn()
+        # make sure that the database is empty
+        if isinstance(db_engine, PostgresEngine):
+            cur = db_conn.cursor()
+            cur.execute("SELECT tablename FROM pg_tables where schemaname='public'")
+            rows = cur.fetchall()
+            for r in rows:
+                cur.execute("DROP TABLE %s CASCADE" % r[0])
+        yield prepare_database(db_conn, db_engine, config)
         hs.setup()
     else:
         hs = HomeServer(
             name, db_pool=None, datastore=datastore, config=config,
             version_string="Synapse/tests",
-            database_engine=create_engine(config.database_config),
+            database_engine=db_engine,
             room_list_handler=object(),
             tls_server_context_factory=Mock(),
             **kargs
@@ -302,34 +339,6 @@ class MockClock(object):
         return d
 
 
-class SQLiteMemoryDbPool(ConnectionPool, object):
-    def __init__(self):
-        super(SQLiteMemoryDbPool, self).__init__(
-            "sqlite3", ":memory:",
-            cp_min=1,
-            cp_max=1,
-        )
-
-        self.config = Mock()
-        self.config.password_providers = []
-        self.config.database_config = {"name": "sqlite3"}
-
-    def prepare(self):
-        engine = self.create_engine()
-        return self.runWithConnection(
-            lambda conn: prepare_database(conn, engine, self.config)
-        )
-
-    def get_db_conn(self):
-        conn = self.connect()
-        engine = self.create_engine()
-        prepare_database(conn, engine, self.config)
-        return conn
-
-    def create_engine(self):
-        return create_engine(self.config.database_config)
-
-
 def _format_call(args, kwargs):
     return ", ".join(
         ["%r" % (a) for a in args] +