From d6ff8bdf96770edd6b65775e70d85110a1e8d67e Mon Sep 17 00:00:00 2001 From: David Robertson Date: Sat, 11 Jun 2022 18:35:02 +0100 Subject: Introduce a `Timer` dataclass --- tests/utils.py | 54 +++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 37 insertions(+), 17 deletions(-) diff --git a/tests/utils.py b/tests/utils.py index cae851a0eb..a4df72921d 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -15,8 +15,9 @@ import atexit import os -from typing import TYPE_CHECKING, Dict, Union, cast, overload +from typing import TYPE_CHECKING, Callable, Dict, List, ParamSpec, Union, cast, overload +import attr from typing_extensions import Literal from synapse.api.constants import EventTypes @@ -218,13 +219,22 @@ def mock_getRawHeaders(headers=None): return getRawHeaders +P = ParamSpec("P") + + +@attr.s(slots=True, auto_attribs=True) +class Timer: + absolute_time: float + callback: Callable[[], None] + expired: bool + + class MockClock: now = 1000.0 def __init__(self) -> None: - # list of lists of [absolute_time, callback, expired] in no particular - # order - self.timers = [] + # Timers in no particular order + self.timers: List[Timer] = [] self.loopers = [] def time(self) -> float: @@ -233,27 +243,39 @@ class MockClock: def time_msec(self) -> int: return int(self.time() * 1000) - def call_later(self, delay, callback, *args, **kwargs): + def call_later( + self, + delay: float, + callback: Callable[P, object], + *args: P.args, + **kwargs: P.kwargs, + ) -> Timer: ctx = current_context() - def wrapped_callback(): + def wrapped_callback() -> None: set_current_context(ctx) callback(*args, **kwargs) - t = [self.now + delay, wrapped_callback, False] + t = Timer(self.now + delay, wrapped_callback, False) self.timers.append(t) return t - def looping_call(self, function, interval, *args, **kwargs): + def looping_call( + self, + function: Callable[P, object], + interval: float, + *args: P.args, + **kwargs: P.kwargs, + ) -> None: self.loopers.append([function, interval / 1000.0, self.now, args, kwargs]) - def cancel_call_later(self, timer, ignore_errs=False): - if timer[2]: + def cancel_call_later(self, timer: Timer, ignore_errs: bool = False) -> None: + if timer.expired: if not ignore_errs: raise Exception("Cannot cancel an expired timer") - timer[2] = True + timer.expired = True self.timers = [t for t in self.timers if t != timer] # For unit testing @@ -264,14 +286,12 @@ class MockClock: self.timers = [] for t in timers: - time, callback, expired = t - - if expired: + if t.expired: raise Exception("Timer already expired") - if self.now >= time: - t[2] = True - callback() + if self.now >= t.absolute_time: + t.expired = True + t.callback() else: self.timers.append(t) -- cgit 1.4.1