summary refs log tree commit diff
path: root/tests/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/utils.py')
-rw-r--r--tests/utils.py31
1 files changed, 29 insertions, 2 deletions
diff --git a/tests/utils.py b/tests/utils.py
index aee69b1caa..3b1eb50d8d 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2014 OpenMarket Ltd
+# Copyright 2014-2016 OpenMarket Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -19,6 +19,8 @@ from synapse.api.constants import EventTypes
 from synapse.storage.prepare_database import prepare_database
 from synapse.storage.engines import create_engine
 from synapse.server import HomeServer
+from synapse.federation.transport import server
+from synapse.util.ratelimitutils import FederationRateLimiter
 
 from synapse.util.logcontext import LoggingContext
 
@@ -44,9 +46,10 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
         config = Mock()
         config.signing_key = [MockKey()]
         config.event_cache_size = 1
-        config.disable_registration = False
+        config.enable_registration = True
         config.macaroon_secret_key = "not even a little secret"
         config.server_name = "server.under.test"
+        config.trusted_third_party_id_servers = []
 
     if "clock" not in kargs:
         kargs["clock"] = MockClock()
@@ -58,8 +61,10 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
             name, db_pool=db_pool, config=config,
             version_string="Synapse/tests",
             database_engine=create_engine("sqlite3"),
+            get_db_conn=db_pool.get_db_conn,
             **kargs
         )
+        hs.setup()
     else:
         hs = HomeServer(
             name, db_pool=None, datastore=datastore, config=config,
@@ -80,6 +85,22 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
 
     hs.build_handlers = swap_out_hash_for_testing(hs.build_handlers)
 
+    fed = kargs.get("resource_for_federation", None)
+    if fed:
+        server.register_servlets(
+            hs,
+            resource=fed,
+            authenticator=server.Authenticator(hs),
+            ratelimiter=FederationRateLimiter(
+                hs.get_clock(),
+                window_size=hs.config.federation_rc_window_size,
+                sleep_limit=hs.config.federation_rc_sleep_limit,
+                sleep_msec=hs.config.federation_rc_sleep_delay,
+                reject_limit=hs.config.federation_rc_reject_limit,
+                concurrent_requests=hs.config.federation_rc_concurrent
+            ),
+        )
+
     defer.returnValue(hs)
 
 
@@ -262,6 +283,12 @@ class SQLiteMemoryDbPool(ConnectionPool, object):
             lambda conn: prepare_database(conn, engine)
         )
 
+    def get_db_conn(self):
+        conn = self.connect()
+        engine = create_engine("sqlite3")
+        prepare_database(conn, engine)
+        return conn
+
 
 class MemoryDataStore(object):