summary refs log tree commit diff
path: root/tests/utils.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--tests/utils.py136
1 files changed, 92 insertions, 44 deletions
diff --git a/tests/utils.py b/tests/utils.py
index aca6a0083b..424cc4c2a0 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -15,12 +15,17 @@
 
 import atexit
 import os
+from typing import Any, Callable, Dict, List, Tuple, Union, overload
+
+import attr
+from typing_extensions import Literal, ParamSpec
 
 from synapse.api.constants import EventTypes
 from synapse.api.room_versions import RoomVersions
 from synapse.config.homeserver import HomeServerConfig
 from synapse.config.server import DEFAULT_ROOM_VERSION
 from synapse.logging.context import current_context, set_current_context
+from synapse.server import HomeServer
 from synapse.storage.database import LoggingDatabaseConnection
 from synapse.storage.engines import create_engine
 from synapse.storage.prepare_database import prepare_database
@@ -50,12 +55,11 @@ SQLITE_PERSIST_DB = os.environ.get("SYNAPSE_TEST_PERSIST_SQLITE_DB") is not None
 POSTGRES_DBNAME_FOR_INITIAL_CREATE = "postgres"
 
 
-def setupdb():
+def setupdb() -> None:
     # If we're using PostgreSQL, set up the db once
     if USE_POSTGRES_FOR_TESTS:
         # create a PostgresEngine
         db_engine = create_engine({"name": "psycopg2", "args": {}})
-
         # connect to postgres to create the base database.
         db_conn = db_engine.module.connect(
             user=POSTGRES_USER,
@@ -82,11 +86,11 @@ def setupdb():
             port=POSTGRES_PORT,
             password=POSTGRES_PASSWORD,
         )
-        db_conn = LoggingDatabaseConnection(db_conn, db_engine, "tests")
-        prepare_database(db_conn, db_engine, None)
-        db_conn.close()
+        logging_conn = LoggingDatabaseConnection(db_conn, db_engine, "tests")
+        prepare_database(logging_conn, db_engine, None)
+        logging_conn.close()
 
-        def _cleanup():
+        def _cleanup() -> None:
             db_conn = db_engine.module.connect(
                 user=POSTGRES_USER,
                 host=POSTGRES_HOST,
@@ -103,7 +107,19 @@ def setupdb():
         atexit.register(_cleanup)
 
 
-def default_config(name, parse=False):
+@overload
+def default_config(name: str, parse: Literal[False] = ...) -> Dict[str, object]:
+    ...
+
+
+@overload
+def default_config(name: str, parse: Literal[True]) -> HomeServerConfig:
+    ...
+
+
+def default_config(
+    name: str, parse: bool = False
+) -> Union[Dict[str, object], HomeServerConfig]:
     """
     Create a reasonable test config.
     """
@@ -181,90 +197,122 @@ def default_config(name, parse=False):
     return config_dict
 
 
-def mock_getRawHeaders(headers=None):
+def mock_getRawHeaders(headers=None):  # type: ignore[no-untyped-def]
     headers = headers if headers is not None else {}
 
-    def getRawHeaders(name, default=None):
+    def getRawHeaders(name, default=None):  # type: ignore[no-untyped-def]
+        # If the requested header is present, the real twisted function returns
+        # List[str] if name is a str and List[bytes] if name is a bytes.
+        # This mock doesn't support that behaviour.
+        # Fortunately, none of the current callers of mock_getRawHeaders() provide a
+        # headers dict, so we don't encounter this discrepancy in practice.
         return headers.get(name, default)
 
     return getRawHeaders
 
 
+P = ParamSpec("P")
+
+
+@attr.s(slots=True, auto_attribs=True)
+class Timer:
+    absolute_time: float
+    callback: Callable[[], None]
+    expired: bool
+
+
+# TODO: Make this generic over a ParamSpec?
+@attr.s(slots=True, auto_attribs=True)
+class Looper:
+    func: Callable[..., Any]
+    interval: float  # seconds
+    last: float
+    args: Tuple[object, ...]
+    kwargs: Dict[str, object]
+
+
 class MockClock:
-    now = 1000
+    now = 1000.0
 
-    def __init__(self):
-        # list of lists of [absolute_time, callback, expired] in no particular
-        # order
-        self.timers = []
-        self.loopers = []
+    def __init__(self) -> None:
+        # Timers in no particular order
+        self.timers: List[Timer] = []
+        self.loopers: List[Looper] = []
 
-    def time(self):
+    def time(self) -> float:
         return self.now
 
-    def time_msec(self):
-        return self.time() * 1000
+    def time_msec(self) -> int:
+        return int(self.time() * 1000)
 
-    def call_later(self, delay, callback, *args, **kwargs):
+    def call_later(
+        self,
+        delay: float,
+        callback: Callable[P, object],
+        *args: P.args,
+        **kwargs: P.kwargs,
+    ) -> Timer:
         ctx = current_context()
 
-        def wrapped_callback():
+        def wrapped_callback() -> None:
             set_current_context(ctx)
             callback(*args, **kwargs)
 
-        t = [self.now + delay, wrapped_callback, False]
+        t = Timer(self.now + delay, wrapped_callback, False)
         self.timers.append(t)
 
         return t
 
-    def looping_call(self, function, interval, *args, **kwargs):
-        self.loopers.append([function, interval / 1000.0, self.now, args, kwargs])
-
-    def cancel_call_later(self, timer, ignore_errs=False):
-        if timer[2]:
+    def looping_call(
+        self,
+        function: Callable[P, object],
+        interval: float,
+        *args: P.args,
+        **kwargs: P.kwargs,
+    ) -> None:
+        # This type-ignore should be redundant once we use a mypy release with
+        # https://github.com/python/mypy/pull/12668.
+        self.loopers.append(Looper(function, interval / 1000.0, self.now, args, kwargs))  # type: ignore[arg-type]
+
+    def cancel_call_later(self, timer: Timer, ignore_errs: bool = False) -> None:
+        if timer.expired:
             if not ignore_errs:
                 raise Exception("Cannot cancel an expired timer")
 
-        timer[2] = True
+        timer.expired = True
         self.timers = [t for t in self.timers if t != timer]
 
     # For unit testing
-    def advance_time(self, secs):
+    def advance_time(self, secs: float) -> None:
         self.now += secs
 
         timers = self.timers
         self.timers = []
 
         for t in timers:
-            time, callback, expired = t
-
-            if expired:
+            if t.expired:
                 raise Exception("Timer already expired")
 
-            if self.now >= time:
-                t[2] = True
-                callback()
+            if self.now >= t.absolute_time:
+                t.expired = True
+                t.callback()
             else:
                 self.timers.append(t)
 
         for looped in self.loopers:
-            func, interval, last, args, kwargs = looped
-            if last + interval < self.now:
-                func(*args, **kwargs)
-                looped[2] = self.now
+            if looped.last + looped.interval < self.now:
+                looped.func(*looped.args, **looped.kwargs)
+                looped.last = self.now
 
-    def advance_time_msec(self, ms):
+    def advance_time_msec(self, ms: float) -> None:
         self.advance_time(ms / 1000.0)
 
-    def time_bound_deferred(self, d, *args, **kwargs):
-        # We don't bother timing things out for now.
-        return d
-
 
-async def create_room(hs, room_id: str, creator_id: str):
+async def create_room(hs: HomeServer, room_id: str, creator_id: str) -> None:
     """Creates and persist a creation event for the given room"""
 
     persistence_store = hs.get_storage_controllers().persistence
+    assert persistence_store is not None
     store = hs.get_datastores().main
     event_builder_factory = hs.get_event_builder_factory()
     event_creation_handler = hs.get_event_creation_handler()