summary refs log tree commit diff
path: root/tests/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/utils.py')
-rw-r--r--tests/utils.py150
1 files changed, 99 insertions, 51 deletions
diff --git a/tests/utils.py b/tests/utils.py
index 2dfcb70a93..615b9f8cca 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018-2019 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -28,8 +29,8 @@ from twisted.internet import defer, reactor
 
 from synapse.api.constants import EventTypes, RoomVersions
 from synapse.api.errors import CodeMessageException, cs_error
-from synapse.config.server import ServerConfig
-from synapse.federation.transport import server
+from synapse.config.homeserver import HomeServerConfig
+from synapse.federation.transport import server as federation_server
 from synapse.http.server import HttpServer
 from synapse.server import HomeServer
 from synapse.storage import DataStore
@@ -43,30 +44,32 @@ from synapse.util.logcontext import LoggingContext
 from synapse.util.ratelimitutils import FederationRateLimiter
 
 # set this to True to run the tests against postgres instead of sqlite.
+#
+# When running under postgres, we first create a base database with the name
+# POSTGRES_BASE_DB and update it to the current schema. Then, for each test case, we
+# create another unique database, using the base database as a template.
 USE_POSTGRES_FOR_TESTS = os.environ.get("SYNAPSE_POSTGRES", False)
 LEAVE_DB = os.environ.get("SYNAPSE_LEAVE_DB", False)
-POSTGRES_USER = os.environ.get("SYNAPSE_POSTGRES_USER", "postgres")
+POSTGRES_USER = os.environ.get("SYNAPSE_POSTGRES_USER", None)
+POSTGRES_HOST = os.environ.get("SYNAPSE_POSTGRES_HOST", None)
+POSTGRES_PASSWORD = os.environ.get("SYNAPSE_POSTGRES_PASSWORD", None)
 POSTGRES_BASE_DB = "_synapse_unit_tests_base_%s" % (os.getpid(),)
 
+# the dbname we will connect to in order to create the base database.
+POSTGRES_DBNAME_FOR_INITIAL_CREATE = "postgres"
 
-def setupdb():
 
+def setupdb():
     # If we're using PostgreSQL, set up the db once
     if USE_POSTGRES_FOR_TESTS:
-        pgconfig = {
-            "name": "psycopg2",
-            "args": {
-                "database": POSTGRES_BASE_DB,
-                "user": POSTGRES_USER,
-                "cp_min": 1,
-                "cp_max": 5,
-            },
-        }
-        config = Mock()
-        config.password_providers = []
-        config.database_config = pgconfig
-        db_engine = create_engine(pgconfig)
-        db_conn = db_engine.module.connect(user=POSTGRES_USER)
+        # create a PostgresEngine
+        db_engine = create_engine({"name": "psycopg2", "args": {}})
+
+        # connect to postgres to create the base database.
+        db_conn = db_engine.module.connect(
+            user=POSTGRES_USER, host=POSTGRES_HOST, password=POSTGRES_PASSWORD,
+            dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE,
+        )
         db_conn.autocommit = True
         cur = db_conn.cursor()
         cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,))
@@ -76,7 +79,10 @@ def setupdb():
 
         # Set up in the db
         db_conn = db_engine.module.connect(
-            database=POSTGRES_BASE_DB, user=POSTGRES_USER
+            database=POSTGRES_BASE_DB,
+            user=POSTGRES_USER,
+            host=POSTGRES_HOST,
+            password=POSTGRES_PASSWORD,
         )
         cur = db_conn.cursor()
         _get_or_create_schema_state(cur, db_engine)
@@ -86,7 +92,10 @@ def setupdb():
         db_conn.close()
 
         def _cleanup():
-            db_conn = db_engine.module.connect(user=POSTGRES_USER)
+            db_conn = db_engine.module.connect(
+                user=POSTGRES_USER, host=POSTGRES_HOST, password=POSTGRES_PASSWORD,
+                dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE,
+            )
             db_conn.autocommit = True
             cur = db_conn.cursor()
             cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,))
@@ -100,13 +109,25 @@ def default_config(name):
     """
     Create a reasonable test config.
     """
-    config = Mock()
-    config.signing_key = [MockKey()]
+    config_dict = {
+        "server_name": name,
+        "media_store_path": "media",
+        "uploads_path": "uploads",
+
+        # the test signing key is just an arbitrary ed25519 key to keep the config
+        # parser happy
+        "signing_key": "ed25519 a_lPym qvioDNmfExFBRPgdTU+wtFYKq4JfwFRv7sYVgWvmgJg",
+    }
+
+    config = HomeServerConfig()
+    config.parse_config_dict(config_dict)
+
+    # TODO: move this stuff into config_dict or get rid of it
     config.event_cache_size = 1
     config.enable_registration = True
+    config.enable_registration_captcha = False
     config.macaroon_secret_key = "not even a little secret"
     config.expire_access_token = False
-    config.server_name = name
     config.trusted_third_party_id_servers = []
     config.room_invite_state_types = []
     config.password_providers = []
@@ -139,9 +160,20 @@ def default_config(name):
     config.admin_contact = None
     config.rc_messages_per_second = 10000
     config.rc_message_burst_count = 10000
+    config.rc_registration.per_second = 10000
+    config.rc_registration.burst_count = 10000
+    config.rc_login_address.per_second = 10000
+    config.rc_login_address.burst_count = 10000
+    config.rc_login_account.per_second = 10000
+    config.rc_login_account.burst_count = 10000
+    config.rc_login_failed_attempts.per_second = 10000
+    config.rc_login_failed_attempts.burst_count = 10000
     config.saml2_enabled = False
     config.public_baseurl = None
     config.default_identity_server = None
+    config.key_refresh_interval = 24 * 60 * 60 * 1000
+    config.old_signing_keys = {}
+    config.tls_fingerprints = []
 
     config.use_frozen_dicts = False
 
@@ -153,13 +185,6 @@ def default_config(name):
     # background, which upsets the test runner.
     config.update_user_directory = False
 
-    def is_threepid_reserved(threepid):
-        return ServerConfig.is_threepid_reserved(
-            config.mau_limits_reserved_threepids, threepid
-        )
-
-    config.is_threepid_reserved.side_effect = is_threepid_reserved
-
     return config
 
 
@@ -186,6 +211,9 @@ def setup_test_homeserver(
     Args:
         cleanup_func : The function used to register a cleanup routine for
                        after the test.
+
+    Calling this method directly is deprecated: you should instead derive from
+    HomeserverTestCase.
     """
     if reactor is None:
         from twisted.internet import reactor
@@ -203,7 +231,14 @@ def setup_test_homeserver(
 
         config.database_config = {
             "name": "psycopg2",
-            "args": {"database": test_db, "cp_min": 1, "cp_max": 5},
+            "args": {
+                "database": test_db,
+                "host": POSTGRES_HOST,
+                "password": POSTGRES_PASSWORD,
+                "user": POSTGRES_USER,
+                "cp_min": 1,
+                "cp_max": 5,
+            },
         }
     else:
         config.database_config = {
@@ -217,7 +252,10 @@ def setup_test_homeserver(
     # the template database we generate in setupdb()
     if datastore is None and isinstance(db_engine, PostgresEngine):
         db_conn = db_engine.module.connect(
-            database=POSTGRES_BASE_DB, user=POSTGRES_USER
+            database=POSTGRES_BASE_DB,
+            user=POSTGRES_USER,
+            host=POSTGRES_HOST,
+            password=POSTGRES_PASSWORD,
         )
         db_conn.autocommit = True
         cur = db_conn.cursor()
@@ -240,7 +278,6 @@ def setup_test_homeserver(
             db_config=config.database_config,
             version_string="Synapse/tests",
             database_engine=db_engine,
-            room_list_handler=object(),
             tls_server_context_factory=Mock(),
             tls_client_options_factory=Mock(),
             reactor=reactor,
@@ -267,7 +304,10 @@ def setup_test_homeserver(
 
                 # Drop the test database
                 db_conn = db_engine.module.connect(
-                    database=POSTGRES_BASE_DB, user=POSTGRES_USER
+                    database=POSTGRES_BASE_DB,
+                    user=POSTGRES_USER,
+                    host=POSTGRES_HOST,
+                    password=POSTGRES_PASSWORD,
                 )
                 db_conn.autocommit = True
                 cur = db_conn.cursor()
@@ -298,6 +338,8 @@ def setup_test_homeserver(
                 cleanup_func(cleanup)
 
         hs.setup()
+        if homeserverToUse.__name__ == "TestHomeServer":
+            hs.setup_master()
     else:
         hs = homeserverToUse(
             name,
@@ -306,7 +348,6 @@ def setup_test_homeserver(
             config=config,
             version_string="Synapse/tests",
             database_engine=db_engine,
-            room_list_handler=object(),
             tls_server_context_factory=Mock(),
             tls_client_options_factory=Mock(),
             reactor=reactor,
@@ -324,23 +365,27 @@ def setup_test_homeserver(
 
     fed = kargs.get("resource_for_federation", None)
     if fed:
-        server.register_servlets(
-            hs,
-            resource=fed,
-            authenticator=server.Authenticator(hs),
-            ratelimiter=FederationRateLimiter(
-                hs.get_clock(),
-                window_size=hs.config.federation_rc_window_size,
-                sleep_limit=hs.config.federation_rc_sleep_limit,
-                sleep_msec=hs.config.federation_rc_sleep_delay,
-                reject_limit=hs.config.federation_rc_reject_limit,
-                concurrent_requests=hs.config.federation_rc_concurrent,
-            ),
-        )
+        register_federation_servlets(hs, fed)
 
     defer.returnValue(hs)
 
 
+def register_federation_servlets(hs, resource):
+    federation_server.register_servlets(
+        hs,
+        resource=resource,
+        authenticator=federation_server.Authenticator(hs),
+        ratelimiter=FederationRateLimiter(
+            hs.get_clock(),
+            window_size=hs.config.federation_rc_window_size,
+            sleep_limit=hs.config.federation_rc_sleep_limit,
+            sleep_msec=hs.config.federation_rc_sleep_delay,
+            reject_limit=hs.config.federation_rc_reject_limit,
+            concurrent_requests=hs.config.federation_rc_concurrent,
+        ),
+    )
+
+
 def get_mock_call_args(pattern_func, mock_func):
     """ Return the arguments the mock function was called with interpreted
     by the pattern functions argument list.
@@ -457,6 +502,9 @@ class MockKey(object):
     def verify(self, message, sig):
         assert sig == b"\x9a\x87$"
 
+    def encode(self):
+        return b"<fake_encoded_key>"
+
 
 class MockClock(object):
     now = 1000
@@ -486,7 +534,7 @@ class MockClock(object):
         return t
 
     def looping_call(self, function, interval):
-        self.loopers.append([function, interval / 1000., self.now])
+        self.loopers.append([function, interval / 1000.0, self.now])
 
     def cancel_call_later(self, timer, ignore_errs=False):
         if timer[2]:
@@ -522,7 +570,7 @@ class MockClock(object):
                 looped[2] = self.now
 
     def advance_time_msec(self, ms):
-        self.advance_time(ms / 1000.)
+        self.advance_time(ms / 1000.0)
 
     def time_bound_deferred(self, d, *args, **kwargs):
         # We don't bother timing things out for now.
@@ -631,7 +679,7 @@ def create_room(hs, room_id, creator_id):
             "sender": creator_id,
             "room_id": room_id,
             "content": {},
-        }
+        },
     )
 
     event, context = yield event_creation_handler.create_new_client_event(builder)