summary refs log tree commit diff
path: root/tests/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/utils.py')
-rw-r--r--tests/utils.py90
1 files changed, 64 insertions, 26 deletions
diff --git a/tests/utils.py b/tests/utils.py
index 358b5b72b7..52405502e9 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -19,6 +19,9 @@ from synapse.api.constants import EventTypes
 from synapse.storage.prepare_database import prepare_database
 from synapse.storage.engines import create_engine
 from synapse.server import HomeServer
+from synapse.federation.transport import server
+from synapse.types import Requester
+from synapse.util.ratelimitutils import FederationRateLimiter
 
 from synapse.util.logcontext import LoggingContext
 
@@ -44,9 +47,13 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
         config = Mock()
         config.signing_key = [MockKey()]
         config.event_cache_size = 1
-        config.disable_registration = False
+        config.enable_registration = True
         config.macaroon_secret_key = "not even a little secret"
         config.server_name = "server.under.test"
+        config.trusted_third_party_id_servers = []
+        config.room_invite_state_types = []
+
+    config.database_config = {"name": "sqlite3"}
 
     if "clock" not in kargs:
         kargs["clock"] = MockClock()
@@ -57,14 +64,16 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
         hs = HomeServer(
             name, db_pool=db_pool, config=config,
             version_string="Synapse/tests",
-            database_engine=create_engine("sqlite3"),
+            database_engine=create_engine(config),
+            get_db_conn=db_pool.get_db_conn,
             **kargs
         )
+        hs.setup()
     else:
         hs = HomeServer(
             name, db_pool=None, datastore=datastore, config=config,
             version_string="Synapse/tests",
-            database_engine=create_engine("sqlite3"),
+            database_engine=create_engine(config),
             **kargs
         )
 
@@ -80,6 +89,22 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
 
     hs.build_handlers = swap_out_hash_for_testing(hs.build_handlers)
 
+    fed = kargs.get("resource_for_federation", None)
+    if fed:
+        server.register_servlets(
+            hs,
+            resource=fed,
+            authenticator=server.Authenticator(hs),
+            ratelimiter=FederationRateLimiter(
+                hs.get_clock(),
+                window_size=hs.config.federation_rc_window_size,
+                sleep_limit=hs.config.federation_rc_sleep_limit,
+                sleep_msec=hs.config.federation_rc_sleep_delay,
+                reject_limit=hs.config.federation_rc_reject_limit,
+                concurrent_requests=hs.config.federation_rc_concurrent
+            ),
+        )
+
     defer.returnValue(hs)
 
 
@@ -131,7 +156,7 @@ class MockHttpResource(HttpServer):
 
         mock_request.getClientIP.return_value = "-"
 
-        mock_request.requestHeaders.getRawHeaders.return_value=[
+        mock_request.requestHeaders.getRawHeaders.return_value = [
             "X-Matrix origin=test,key=,sig="
         ]
 
@@ -203,12 +228,12 @@ class MockClock(object):
     def time_msec(self):
         return self.time() * 1000
 
-    def call_later(self, delay, callback):
+    def call_later(self, delay, callback, *args, **kwargs):
         current_context = LoggingContext.current_context()
 
         def wrapped_callback():
             LoggingContext.thread_local.current_context = current_context
-            callback()
+            callback(*args, **kwargs)
 
         t = [self.now + delay, wrapped_callback, False]
         self.timers.append(t)
@@ -218,9 +243,10 @@ class MockClock(object):
     def looping_call(self, function, interval):
         pass
 
-    def cancel_call_later(self, timer):
+    def cancel_call_later(self, timer, ignore_errs=False):
         if timer[2]:
-            raise Exception("Cannot cancel an expired timer")
+            if not ignore_errs:
+                raise Exception("Cannot cancel an expired timer")
 
         timer[2] = True
         self.timers = [t for t in self.timers if t != timer]
@@ -256,12 +282,24 @@ class SQLiteMemoryDbPool(ConnectionPool, object):
             cp_max=1,
         )
 
+        self.config = Mock()
+        self.config.database_config = {"name": "sqlite3"}
+
     def prepare(self):
-        engine = create_engine("sqlite3")
+        engine = self.create_engine()
         return self.runWithConnection(
-            lambda conn: prepare_database(conn, engine)
+            lambda conn: prepare_database(conn, engine, self.config)
         )
 
+    def get_db_conn(self):
+        conn = self.connect()
+        engine = self.create_engine()
+        prepare_database(conn, engine, self.config)
+        return conn
+
+    def create_engine(self):
+        return create_engine(self.config)
+
 
 class MemoryDataStore(object):
 
@@ -333,13 +371,12 @@ class MemoryDataStore(object):
 
     def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
         return [
-            self.members[r].get(user_id) for r in self.members
-            if user_id in self.members[r] and
-                self.members[r][user_id].membership in membership_list
+            m[user_id] for m in self.members.values()
+            if user_id in m and m[user_id].membership in membership_list
         ]
 
     def get_room_events_stream(self, user_id=None, from_key=None, to_key=None,
-                            limit=0, with_feedback=False):
+                               limit=0, with_feedback=False):
         return ([], from_key)  # TODO
 
     def get_joined_hosts_for_room(self, room_id):
@@ -349,7 +386,6 @@ class MemoryDataStore(object):
         if event.type == EventTypes.Member:
             room_id = event.room_id
             user = event.state_key
-            membership = event.membership
             self.members.setdefault(room_id, {})[user] = event
 
         if hasattr(event, "state_key"):
@@ -429,9 +465,9 @@ class DeferredMockCallable(object):
                 d.callback(None)
                 return result
 
-        failure = AssertionError("Was not expecting call(%s)" %
+        failure = AssertionError("Was not expecting call(%s)" % (
             _format_call(args, kwargs)
-        )
+        ))
 
         for _, _, d in self.expectations:
             try:
@@ -452,14 +488,12 @@ class DeferredMockCallable(object):
         )
 
         timer = reactor.callLater(
-            timeout/1000,
+            timeout / 1000,
             deferred.errback,
-            AssertionError(
-                "%d pending calls left: %s"% (
-                    len([e for e in self.expectations if not e[2].called]),
-                    [e for e in self.expectations if not e[2].called]
-                )
-            )
+            AssertionError("%d pending calls left: %s" % (
+                len([e for e in self.expectations if not e[2].called]),
+                [e for e in self.expectations if not e[2].called]
+            ))
         )
 
         yield deferred
@@ -473,8 +507,12 @@ class DeferredMockCallable(object):
             calls = self.calls
             self.calls = []
 
-            raise AssertionError("Expected not to received any calls, got:\n" +
-                "\n".join([
+            raise AssertionError(
+                "Expected not to received any calls, got:\n" + "\n".join([
                     "call(%s)" % _format_call(c[0], c[1]) for c in calls
                 ])
             )
+
+
+def requester_for_user(user):
+    return Requester(user, None, False)