summary refs log tree commit diff
path: root/tests/utils.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--tests/utils.py100
1 files changed, 50 insertions, 50 deletions
diff --git a/tests/utils.py b/tests/utils.py
index 46ef2959f2..59c020a051 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -30,19 +30,16 @@ from twisted.internet import defer, reactor
 from synapse.api.constants import EventTypes
 from synapse.api.errors import CodeMessageException, cs_error
 from synapse.api.room_versions import RoomVersions
+from synapse.config.database import DatabaseConnectionConfig
 from synapse.config.homeserver import HomeServerConfig
 from synapse.config.server import DEFAULT_ROOM_VERSION
 from synapse.federation.transport import server as federation_server
 from synapse.http.server import HttpServer
-from synapse.logging.context import LoggingContext
+from synapse.logging.context import current_context, set_current_context
 from synapse.server import HomeServer
 from synapse.storage import DataStore
 from synapse.storage.engines import PostgresEngine, create_engine
-from synapse.storage.prepare_database import (
-    _get_or_create_schema_state,
-    _setup_new_database,
-    prepare_database,
-)
+from synapse.storage.prepare_database import prepare_database
 from synapse.util.ratelimitutils import FederationRateLimiter
 
 # set this to True to run the tests against postgres instead of sqlite.
@@ -77,7 +74,10 @@ def setupdb():
         db_conn.autocommit = True
         cur = db_conn.cursor()
         cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,))
-        cur.execute("CREATE DATABASE %s;" % (POSTGRES_BASE_DB,))
+        cur.execute(
+            "CREATE DATABASE %s ENCODING 'UTF8' LC_COLLATE='C' LC_CTYPE='C' "
+            "template=template0;" % (POSTGRES_BASE_DB,)
+        )
         cur.close()
         db_conn.close()
 
@@ -88,11 +88,7 @@ def setupdb():
             host=POSTGRES_HOST,
             password=POSTGRES_PASSWORD,
         )
-        cur = db_conn.cursor()
-        _get_or_create_schema_state(cur, db_engine)
-        _setup_new_database(cur, db_engine)
-        db_conn.commit()
-        cur.close()
+        prepare_database(db_conn, db_engine, None)
         db_conn.close()
 
         def _cleanup():
@@ -117,6 +113,7 @@ def default_config(name, parse=False):
     """
     config_dict = {
         "server_name": name,
+        "send_federation": False,
         "media_store_path": "media",
         "uploads_path": "uploads",
         # the test signing key is just an arbitrary ed25519 key to keep the config
@@ -145,7 +142,6 @@ def default_config(name, parse=False):
         "limit_usage_by_mau": False,
         "hs_disabled": False,
         "hs_disabled_message": "",
-        "hs_disabled_limit_type": "",
         "max_mau_value": 50,
         "mau_trial_days": 0,
         "mau_stats_only": False,
@@ -171,6 +167,7 @@ def default_config(name, parse=False):
         # disable user directory updates, because they get done in the
         # background, which upsets the test runner.
         "update_user_directory": False,
+        "caches": {"global_factor": 1},
     }
 
     if parse:
@@ -185,7 +182,6 @@ class TestHomeServer(HomeServer):
     DATASTORE_CLASS = DataStore
 
 
-@defer.inlineCallbacks
 def setup_test_homeserver(
     cleanup_func,
     name="test",
@@ -222,7 +218,7 @@ def setup_test_homeserver(
     if USE_POSTGRES_FOR_TESTS:
         test_db = "synapse_test_%s" % uuid.uuid4().hex
 
-        config.database_config = {
+        database_config = {
             "name": "psycopg2",
             "args": {
                 "database": test_db,
@@ -234,12 +230,15 @@ def setup_test_homeserver(
             },
         }
     else:
-        config.database_config = {
+        database_config = {
             "name": "sqlite3",
             "args": {"database": ":memory:", "cp_min": 1, "cp_max": 1},
         }
 
-    db_engine = create_engine(config.database_config)
+    database = DatabaseConnectionConfig("master", database_config)
+    config.database.databases = [database]
+
+    db_engine = create_engine(database.config)
 
     # Create the database before we actually try and connect to it, based off
     # the template database we generate in setupdb()
@@ -259,39 +258,30 @@ def setup_test_homeserver(
         cur.close()
         db_conn.close()
 
-    # we need to configure the connection pool to run the on_new_connection
-    # function, so that we can test code that uses custom sqlite functions
-    # (like rank).
-    config.database_config["args"]["cp_openfun"] = db_engine.on_new_connection
-
     if datastore is None:
         hs = homeserverToUse(
             name,
             config=config,
-            db_config=config.database_config,
             version_string="Synapse/tests",
-            database_engine=db_engine,
             tls_server_context_factory=Mock(),
             tls_client_options_factory=Mock(),
             reactor=reactor,
             **kargs
         )
 
-        # Prepare the DB on SQLite -- PostgreSQL is a copy of an already up to
-        # date db
-        if not isinstance(db_engine, PostgresEngine):
-            db_conn = hs.get_db_conn()
-            yield prepare_database(db_conn, db_engine, config)
-            db_conn.commit()
-            db_conn.close()
+        hs.setup()
+        if homeserverToUse.__name__ == "TestHomeServer":
+            hs.setup_master()
+
+        if isinstance(db_engine, PostgresEngine):
+            database = hs.get_datastores().databases[0]
 
-        else:
             # We need to do cleanup on PostgreSQL
             def cleanup():
                 import psycopg2
 
                 # Close all the db pools
-                hs.get_db_pool().close()
+                database._db_pool.close()
 
                 dropped = False
 
@@ -330,17 +320,12 @@ def setup_test_homeserver(
                 # Register the cleanup hook
                 cleanup_func(cleanup)
 
-        hs.setup()
-        if homeserverToUse.__name__ == "TestHomeServer":
-            hs.setup_master()
     else:
         hs = homeserverToUse(
             name,
-            db_pool=None,
             datastore=datastore,
             config=config,
             version_string="Synapse/tests",
-            database_engine=db_engine,
             tls_server_context_factory=Mock(),
             tls_client_options_factory=Mock(),
             reactor=reactor,
@@ -351,10 +336,15 @@ def setup_test_homeserver(
     # Need to let the HS build an auth handler and then mess with it
     # because AuthHandler's constructor requires the HS, so we can't make one
     # beforehand and pass it in to the HS's constructor (chicken / egg)
-    hs.get_auth_handler().hash = lambda p: hashlib.md5(p.encode("utf8")).hexdigest()
-    hs.get_auth_handler().validate_hash = (
-        lambda p, h: hashlib.md5(p.encode("utf8")).hexdigest() == h
-    )
+    async def hash(p):
+        return hashlib.md5(p.encode("utf8")).hexdigest()
+
+    hs.get_auth_handler().hash = hash
+
+    async def validate_hash(p, h):
+        return hashlib.md5(p.encode("utf8")).hexdigest() == h
+
+    hs.get_auth_handler().validate_hash = validate_hash
 
     fed = kargs.get("resource_for_federation", None)
     if fed:
@@ -463,7 +453,9 @@ class MockHttpResource(HttpServer):
                 try:
                     args = [urlparse.unquote(u) for u in matcher.groups()]
 
-                    (code, response) = yield func(mock_request, *args)
+                    (code, response) = yield defer.ensureDeferred(
+                        func(mock_request, *args)
+                    )
                     return code, response
                 except CodeMessageException as e:
                     return (e.code, cs_error(e.msg, code=e.errcode))
@@ -510,10 +502,10 @@ class MockClock(object):
         return self.time() * 1000
 
     def call_later(self, delay, callback, *args, **kwargs):
-        current_context = LoggingContext.current_context()
+        ctx = current_context()
 
         def wrapped_callback():
-            LoggingContext.thread_local.current_context = current_context
+            set_current_context(ctx)
             callback(*args, **kwargs)
 
         t = [self.now + delay, wrapped_callback, False]
@@ -521,8 +513,8 @@ class MockClock(object):
 
         return t
 
-    def looping_call(self, function, interval):
-        self.loopers.append([function, interval / 1000.0, self.now])
+    def looping_call(self, function, interval, *args, **kwargs):
+        self.loopers.append([function, interval / 1000.0, self.now, args, kwargs])
 
     def cancel_call_later(self, timer, ignore_errs=False):
         if timer[2]:
@@ -552,9 +544,9 @@ class MockClock(object):
                 self.timers.append(t)
 
         for looped in self.loopers:
-            func, interval, last = looped
+            func, interval, last, args, kwargs = looped
             if last + interval < self.now:
-                func()
+                func(*args, **kwargs)
                 looped[2] = self.now
 
     def advance_time_msec(self, ms):
@@ -655,10 +647,18 @@ def create_room(hs, room_id, creator_id):
         creator_id (str)
     """
 
+    persistence_store = hs.get_storage().persistence
     store = hs.get_datastore()
     event_builder_factory = hs.get_event_builder_factory()
     event_creation_handler = hs.get_event_creation_handler()
 
+    yield store.store_room(
+        room_id=room_id,
+        room_creator_user_id=creator_id,
+        is_public=False,
+        room_version=RoomVersions.V1,
+    )
+
     builder = event_builder_factory.for_room_version(
         RoomVersions.V1,
         {
@@ -672,4 +672,4 @@ def create_room(hs, room_id, creator_id):
 
     event, context = yield event_creation_handler.create_new_client_event(builder)
 
-    yield store.persist_event(event, context)
+    yield persistence_store.persist_event(event, context)