summary refs log tree commit diff
diff options
context:
space:
mode:
authorDavid Robertson <davidr@element.io>2022-06-11 18:35:02 +0100
committerDavid Robertson <davidr@element.io>2022-06-11 18:35:02 +0100
commitd6ff8bdf96770edd6b65775e70d85110a1e8d67e (patch)
treec5b789ac0ef0ced9d029551299b2968bcce1f6ec
parentTrack `now` as a float (diff)
downloadsynapse-d6ff8bdf96770edd6b65775e70d85110a1e8d67e.tar.xz
Introduce a `Timer` dataclass
-rw-r--r--tests/utils.py54
1 files changed, 37 insertions, 17 deletions
diff --git a/tests/utils.py b/tests/utils.py
index cae851a0eb..a4df72921d 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -15,8 +15,9 @@
 
 import atexit
 import os
-from typing import TYPE_CHECKING, Dict, Union, cast, overload
+from typing import TYPE_CHECKING, Callable, Dict, List, ParamSpec, Union, cast, overload
 
+import attr
 from typing_extensions import Literal
 
 from synapse.api.constants import EventTypes
@@ -218,13 +219,22 @@ def mock_getRawHeaders(headers=None):
     return getRawHeaders
 
 
+P = ParamSpec("P")
+
+
+@attr.s(slots=True, auto_attribs=True)
+class Timer:
+    absolute_time: float
+    callback: Callable[[], None]
+    expired: bool
+
+
 class MockClock:
     now = 1000.0
 
     def __init__(self) -> None:
-        # list of lists of [absolute_time, callback, expired] in no particular
-        # order
-        self.timers = []
+        # Timers in no particular order
+        self.timers: List[Timer] = []
         self.loopers = []
 
     def time(self) -> float:
@@ -233,27 +243,39 @@ class MockClock:
     def time_msec(self) -> int:
         return int(self.time() * 1000)
 
-    def call_later(self, delay, callback, *args, **kwargs):
+    def call_later(
+        self,
+        delay: float,
+        callback: Callable[P, object],
+        *args: P.args,
+        **kwargs: P.kwargs,
+    ) -> Timer:
         ctx = current_context()
 
-        def wrapped_callback():
+        def wrapped_callback() -> None:
             set_current_context(ctx)
             callback(*args, **kwargs)
 
-        t = [self.now + delay, wrapped_callback, False]
+        t = Timer(self.now + delay, wrapped_callback, False)
         self.timers.append(t)
 
         return t
 
-    def looping_call(self, function, interval, *args, **kwargs):
+    def looping_call(
+        self,
+        function: Callable[P, object],
+        interval: float,
+        *args: P.args,
+        **kwargs: P.kwargs,
+    ) -> None:
         self.loopers.append([function, interval / 1000.0, self.now, args, kwargs])
 
-    def cancel_call_later(self, timer, ignore_errs=False):
-        if timer[2]:
+    def cancel_call_later(self, timer: Timer, ignore_errs: bool = False) -> None:
+        if timer.expired:
             if not ignore_errs:
                 raise Exception("Cannot cancel an expired timer")
 
-        timer[2] = True
+        timer.expired = True
         self.timers = [t for t in self.timers if t != timer]
 
     # For unit testing
@@ -264,14 +286,12 @@ class MockClock:
         self.timers = []
 
         for t in timers:
-            time, callback, expired = t
-
-            if expired:
+            if t.expired:
                 raise Exception("Timer already expired")
 
-            if self.now >= time:
-                t[2] = True
-                callback()
+            if self.now >= t.absolute_time:
+                t.expired = True
+                t.callback()
             else:
                 self.timers.append(t)