summary refs log tree commit diff
path: root/tests/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/utils.py')
-rw-r--r--tests/utils.py39
1 files changed, 18 insertions, 21 deletions
diff --git a/tests/utils.py b/tests/utils.py
index f3935648a0..c67fa1ca35 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -154,7 +154,7 @@ class MockHttpResource(HttpServer):
 
         mock_request.getClientIP.return_value = "-"
 
-        mock_request.requestHeaders.getRawHeaders.return_value=[
+        mock_request.requestHeaders.getRawHeaders.return_value = [
             "X-Matrix origin=test,key=,sig="
         ]
 
@@ -226,12 +226,12 @@ class MockClock(object):
     def time_msec(self):
         return self.time() * 1000
 
-    def call_later(self, delay, callback):
+    def call_later(self, delay, callback, *args, **kwargs):
         current_context = LoggingContext.current_context()
 
         def wrapped_callback():
             LoggingContext.thread_local.current_context = current_context
-            callback()
+            callback(*args, **kwargs)
 
         t = [self.now + delay, wrapped_callback, False]
         self.timers.append(t)
@@ -241,9 +241,10 @@ class MockClock(object):
     def looping_call(self, function, interval):
         pass
 
-    def cancel_call_later(self, timer):
+    def cancel_call_later(self, timer, ignore_errs=False):
         if timer[2]:
-            raise Exception("Cannot cancel an expired timer")
+            if not ignore_errs:
+                raise Exception("Cannot cancel an expired timer")
 
         timer[2] = True
         self.timers = [t for t in self.timers if t != timer]
@@ -368,13 +369,12 @@ class MemoryDataStore(object):
 
     def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
         return [
-            self.members[r].get(user_id) for r in self.members
-            if user_id in self.members[r] and
-                self.members[r][user_id].membership in membership_list
+            m[user_id] for m in self.members.values()
+            if user_id in m and m[user_id].membership in membership_list
         ]
 
     def get_room_events_stream(self, user_id=None, from_key=None, to_key=None,
-                            limit=0, with_feedback=False):
+                               limit=0, with_feedback=False):
         return ([], from_key)  # TODO
 
     def get_joined_hosts_for_room(self, room_id):
@@ -384,7 +384,6 @@ class MemoryDataStore(object):
         if event.type == EventTypes.Member:
             room_id = event.room_id
             user = event.state_key
-            membership = event.membership
             self.members.setdefault(room_id, {})[user] = event
 
         if hasattr(event, "state_key"):
@@ -464,9 +463,9 @@ class DeferredMockCallable(object):
                 d.callback(None)
                 return result
 
-        failure = AssertionError("Was not expecting call(%s)" %
+        failure = AssertionError("Was not expecting call(%s)" % (
             _format_call(args, kwargs)
-        )
+        ))
 
         for _, _, d in self.expectations:
             try:
@@ -487,14 +486,12 @@ class DeferredMockCallable(object):
         )
 
         timer = reactor.callLater(
-            timeout/1000,
+            timeout / 1000,
             deferred.errback,
-            AssertionError(
-                "%d pending calls left: %s"% (
-                    len([e for e in self.expectations if not e[2].called]),
-                    [e for e in self.expectations if not e[2].called]
-                )
-            )
+            AssertionError("%d pending calls left: %s" % (
+                len([e for e in self.expectations if not e[2].called]),
+                [e for e in self.expectations if not e[2].called]
+            ))
         )
 
         yield deferred
@@ -508,8 +505,8 @@ class DeferredMockCallable(object):
             calls = self.calls
             self.calls = []
 
-            raise AssertionError("Expected not to received any calls, got:\n" +
-                "\n".join([
+            raise AssertionError(
+                "Expected not to received any calls, got:\n" + "\n".join([
                     "call(%s)" % _format_call(c[0], c[1]) for c in calls
                 ])
             )