-rw-r--r--  changelog.d/14866.bugfix | 1
-rw-r--r--  changelog.d/14879.misc | 1
-rw-r--r--  changelog.d/14880.bugfix | 1
-rw-r--r--  changelog.d/14886.misc | 1
-rw-r--r--  changelog.d/14887.misc | 1
-rw-r--r--  changelog.d/14904.misc | 1
-rw-r--r--  changelog.d/14916.misc | 1
-rw-r--r--  changelog.d/14920.misc | 1
-rw-r--r--  changelog.d/14922.misc | 1
-rw-r--r--  changelog.d/14927.misc | 1
-rw-r--r--  docs/development/dependencies.md | 14
-rw-r--r--  mypy.ini | 13
-rwxr-xr-x  scripts-dev/release.py | 2
-rw-r--r--  synapse/api/constants.py | 7
-rw-r--r--  synapse/events/utils.py | 3
-rw-r--r--  synapse/handlers/account_data.py | 6
-rw-r--r--  synapse/handlers/admin.py | 4
-rw-r--r--  synapse/handlers/device.py | 6
-rw-r--r--  synapse/handlers/event_auth.py | 8
-rw-r--r--  synapse/handlers/federation.py | 26
-rw-r--r--  synapse/handlers/federation_event.py | 5
-rw-r--r--  synapse/handlers/initial_sync.py | 16
-rw-r--r--  synapse/handlers/pagination.py | 12
-rw-r--r--  synapse/handlers/presence.py | 18
-rw-r--r--  synapse/handlers/relations.py | 8
-rw-r--r--  synapse/handlers/room.py | 14
-rw-r--r--  synapse/handlers/room_summary.py | 4
-rw-r--r--  synapse/handlers/search.py | 8
-rw-r--r--  synapse/handlers/sso.py | 9
-rw-r--r--  synapse/handlers/sync.py | 4
-rw-r--r--  synapse/module_api/__init__.py | 2
-rw-r--r--  synapse/notifier.py | 3
-rw-r--r--  synapse/rest/client/push_rule.py | 4
-rw-r--r--  synapse/rest/client/transactions.py | 3
-rw-r--r--  synapse/storage/databases/main/relations.py | 46
-rw-r--r--  synapse/storage/databases/main/stream.py | 197
-rw-r--r--  synapse/streams/__init__.py | 6
-rw-r--r--  synapse/streams/config.py | 11
-rw-r--r--  tests/events/test_presence_router.py | 58
-rw-r--r--  tests/events/test_snapshot.py | 17
-rw-r--r--  tests/events/test_utils.py | 71
-rw-r--r--  tests/federation/transport/test_knocking.py | 6
-rw-r--r--  tests/handlers/test_typing.py | 99
-rw-r--r--  tests/logging/__init__.py | 6
-rw-r--r--  tests/logging/test_opentracing.py | 4
-rw-r--r--  tests/logging/test_remote_handler.py | 25
-rw-r--r--  tests/logging/test_terse_json.py | 30
-rw-r--r--  tests/rest/client/test_sync.py | 4
-rw-r--r--  tests/rest/client/test_transactions.py | 42
-rw-r--r--  tests/storage/test_user_directory.py | 5
-rw-r--r--  tests/unittest.py | 3
51 files changed, 507 insertions, 332 deletions
diff --git a/changelog.d/14866.bugfix b/changelog.d/14866.bugfix
new file mode 100644
index 0000000000..540f918cbd
--- /dev/null
+++ b/changelog.d/14866.bugfix
@@ -0,0 +1 @@
+Fix a bug introduced in Synapse 1.53.0 where `next_batch` tokens from `/sync` could not be used with the `/relations` endpoint.
diff --git a/changelog.d/14879.misc b/changelog.d/14879.misc
new file mode 100644
index 0000000000..d44571b731
--- /dev/null
+++ b/changelog.d/14879.misc
@@ -0,0 +1 @@
+Add missing type hints.
diff --git a/changelog.d/14880.bugfix b/changelog.d/14880.bugfix
new file mode 100644
index 0000000000..e56c567082
--- /dev/null
+++ b/changelog.d/14880.bugfix
@@ -0,0 +1 @@
+Fix a bug when using the `send_local_online_presence_to` module API.
diff --git a/changelog.d/14886.misc b/changelog.d/14886.misc
new file mode 100644
index 0000000000..9f5384e60e
--- /dev/null
+++ b/changelog.d/14886.misc
@@ -0,0 +1 @@
+Add missing type hints.
\ No newline at end of file
diff --git a/changelog.d/14887.misc b/changelog.d/14887.misc
new file mode 100644
index 0000000000..9f5384e60e
--- /dev/null
+++ b/changelog.d/14887.misc
@@ -0,0 +1 @@
+Add missing type hints.
\ No newline at end of file
diff --git a/changelog.d/14904.misc b/changelog.d/14904.misc
new file mode 100644
index 0000000000..d44571b731
--- /dev/null
+++ b/changelog.d/14904.misc
@@ -0,0 +1 @@
+Add missing type hints.
diff --git a/changelog.d/14916.misc b/changelog.d/14916.misc
new file mode 100644
index 0000000000..59914d4b8a
--- /dev/null
+++ b/changelog.d/14916.misc
@@ -0,0 +1 @@
+Document how to handle Dependabot pull requests.
diff --git a/changelog.d/14920.misc b/changelog.d/14920.misc
new file mode 100644
index 0000000000..7988897d39
--- /dev/null
+++ b/changelog.d/14920.misc
@@ -0,0 +1 @@
+Fix typo in release script.
diff --git a/changelog.d/14922.misc b/changelog.d/14922.misc
new file mode 100644
index 0000000000..2cc3614dfd
--- /dev/null
+++ b/changelog.d/14922.misc
@@ -0,0 +1 @@
+Use `StrCollection` to avoid potential bugs with `Collection[str]`.
diff --git a/changelog.d/14927.misc b/changelog.d/14927.misc
new file mode 100644
index 0000000000..9f5384e60e
--- /dev/null
+++ b/changelog.d/14927.misc
@@ -0,0 +1 @@
+Add missing type hints.
\ No newline at end of file
diff --git a/docs/development/dependencies.md b/docs/development/dependencies.md
index b734cc5826..c4449c51f7 100644
--- a/docs/development/dependencies.md
+++ b/docs/development/dependencies.md
@@ -258,6 +258,20 @@ because [`build`](https://github.com/pypa/build) is a standardish tool which
 doesn't require poetry. (It's what we use in CI too). However, you could try
 `poetry build` too.
 
+## ...handle a Dependabot pull request?
+
+Synapse uses Dependabot to keep the `poetry.lock` file up-to-date. When it
+creates a pull request, a GitHub Action will run to automatically create a
+changelog file. Ensure that:
+
+* the lockfile changes look reasonable;
+* the upstream changelog file (linked in the description) doesn't include any
+  breaking changes;
+* continuous integration passes (due to permissions, the GitHub Actions run on
+  the changelog commit will fail; look at the initial commit of the pull request).
+
+In particular, any updates to the type hints (usually packages which start with
+`types-`) should be safe to merge if linting passes.
 
 # Troubleshooting
 
diff --git a/mypy.ini b/mypy.ini
index 63366dad99..978d92940b 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -35,19 +35,12 @@ exclude = (?x)
    |tests/api/test_auth.py
    |tests/app/test_openid_listener.py
    |tests/appservice/test_scheduler.py
-   |tests/events/test_presence_router.py
-   |tests/events/test_utils.py
    |tests/federation/test_federation_catch_up.py
    |tests/federation/test_federation_sender.py
-   |tests/federation/transport/test_knocking.py
-   |tests/handlers/test_typing.py
    |tests/http/federation/test_matrix_federation_agent.py
    |tests/http/federation/test_srv_resolver.py
    |tests/http/test_proxyagent.py
-   |tests/logging/__init__.py
-   |tests/logging/test_terse_json.py
    |tests/module_api/test_api.py
-   |tests/rest/client/test_transactions.py
    |tests/rest/media/v1/test_media_storage.py
    |tests/server.py
    |tests/test_state.py
@@ -87,12 +80,18 @@ disallow_untyped_defs = True
 [mypy-tests.crypto.*]
 disallow_untyped_defs = True
 
+[mypy-tests.events.*]
+disallow_untyped_defs = True
+
 [mypy-tests.federation.transport.test_client]
 disallow_untyped_defs = True
 
 [mypy-tests.handlers.*]
 disallow_untyped_defs = True
 
+[mypy-tests.logging.*]
+disallow_untyped_defs = True
+
 [mypy-tests.metrics.*]
 disallow_untyped_defs = True
 
diff --git a/scripts-dev/release.py b/scripts-dev/release.py
index 6974fd7895..008a5bd965 100755
--- a/scripts-dev/release.py
+++ b/scripts-dev/release.py
@@ -438,7 +438,7 @@ def _upload(gh_token: Optional[str]) -> None:
     repo = get_repo_and_check_clean_checkout()
     tag = repo.tag(f"refs/tags/{tag_name}")
     if repo.head.commit != tag.commit:
-        click.echo("Tag {tag_name} (tag.commit) is not currently checked out!")
+        click.echo(f"Tag {tag_name} ({tag.commit}) is not currently checked out!")
         click.get_current_context().abort()
 
     # Query all the assets corresponding to this release.
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 6432d32d83..6f9239d21c 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -17,6 +17,8 @@
 
 """Contains constants from the specification."""
 
+import enum
+
 from typing_extensions import Final
 
 # the max size of a (canonical-json-encoded) event
@@ -290,3 +292,8 @@ class ApprovalNoticeMedium:
 
     NONE = "org.matrix.msc3866.none"
     EMAIL = "org.matrix.msc3866.email"
+
+
+class Direction(enum.Enum):
+    BACKWARDS = "b"
+    FORWARDS = "f"
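Illustrative usage of the new enum (not part of the patch): the members carry the same one-letter values the API already uses on the wire, so call sites can swap raw strings for enum members without changing serialization:

    from synapse.api.constants import Direction

    assert Direction.BACKWARDS.value == "b"
    assert Direction("f") is Direction.FORWARDS  # parse a wire value

    # Call sites below now pass the enum instead of a bare string, e.g.
    # paginate_room_events(..., direction=Direction.FORWARDS)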
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index ae57a4df5e..52e4b467e8 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -605,10 +605,11 @@ class EventClientSerializer:
 
 
 _PowerLevel = Union[str, int]
+PowerLevelsContent = Mapping[str, Union[_PowerLevel, Mapping[str, _PowerLevel]]]
 
 
 def copy_and_fixup_power_levels_contents(
-    old_power_levels: Mapping[str, Union[_PowerLevel, Mapping[str, _PowerLevel]]]
+    old_power_levels: PowerLevelsContent,
 ) -> Dict[str, Union[int, Dict[str, int]]]:
     """Copy the content of a power_levels event, unfreezing frozendicts along the way.
 
diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index 834006356a..d500b21809 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import logging
 import random
-from typing import TYPE_CHECKING, Awaitable, Callable, Collection, List, Optional, Tuple
+from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional, Tuple
 
 from synapse.api.constants import AccountDataTypes
 from synapse.replication.http.account_data import (
@@ -26,7 +26,7 @@ from synapse.replication.http.account_data import (
     ReplicationRemoveUserAccountDataRestServlet,
 )
 from synapse.streams import EventSource
-from synapse.types import JsonDict, StreamKeyType, UserID
+from synapse.types import JsonDict, StrCollection, StreamKeyType, UserID
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -322,7 +322,7 @@ class AccountDataEventSource(EventSource[int, JsonDict]):
         user: UserID,
         from_key: int,
         limit: int,
-        room_ids: Collection[str],
+        room_ids: StrCollection,
         is_guest: bool,
         explicit_room_id: Optional[str] = None,
     ) -> Tuple[List[JsonDict], int]:
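Context for the many `Collection[str]` → `StrCollection` swaps in this diff (see changelog.d/14922.misc): `str` is itself a `Collection[str]`, so a bare string passed where a collection of strings is expected type-checks yet iterates character by character. `StrCollection` excludes that case; a rough sketch of how such an alias can be defined (the exact definition lives in `synapse/types`):

    from typing import AbstractSet, List, Tuple, Union

    # Accepts lists, tuples and sets of str, but not a bare str.
    StrCollection = Union[Tuple[str, ...], List[str], AbstractSet[str]]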
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 5bf8e86387..c81ea34758 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -16,7 +16,7 @@ import abc
 import logging
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
 
-from synapse.api.constants import Membership
+from synapse.api.constants import Direction, Membership
 from synapse.events import EventBase
 from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID
 from synapse.visibility import filter_events_for_client
@@ -197,7 +197,7 @@ class AdminHandler:
             # efficient method perhaps but it does guarantee we get everything.
             while True:
                 events, _ = await self.store.paginate_room_events(
-                    room_id, from_key, to_key, limit=100, direction="f"
+                    room_id, from_key, to_key, limit=100, direction=Direction.FORWARDS
                 )
                 if not events:
                     break
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 58180ae2fa..5c06073901 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -18,7 +18,6 @@ from http import HTTPStatus
 from typing import (
     TYPE_CHECKING,
     Any,
-    Collection,
     Dict,
     Iterable,
     List,
@@ -45,6 +44,7 @@ from synapse.metrics.background_process_metrics import (
 )
 from synapse.types import (
     JsonDict,
+    StrCollection,
     StreamKeyType,
     StreamToken,
     UserID,
@@ -146,7 +146,7 @@ class DeviceWorkerHandler:
 
     @cancellable
     async def get_device_changes_in_shared_rooms(
-        self, user_id: str, room_ids: Collection[str], from_token: StreamToken
+        self, user_id: str, room_ids: StrCollection, from_token: StreamToken
     ) -> Set[str]:
         """Get the set of users whose devices have changed who share a room with
         the given user.
@@ -551,7 +551,7 @@ class DeviceHandler(DeviceWorkerHandler):
     @trace
     @measure_func("notify_device_update")
     async def notify_device_update(
-        self, user_id: str, device_ids: Collection[str]
+        self, user_id: str, device_ids: StrCollection
     ) -> None:
         """Notify that a user's device(s) has changed. Pokes the notifier, and
         remote servers if the user is local.
diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py
index f91dbbecb7..a23a8ce2a1 100644
--- a/synapse/handlers/event_auth.py
+++ b/synapse/handlers/event_auth.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Collection, List, Mapping, Optional, Union
+from typing import TYPE_CHECKING, List, Mapping, Optional, Union
 
 from synapse import event_auth
 from synapse.api.constants import (
@@ -29,7 +29,7 @@ from synapse.event_auth import (
 )
 from synapse.events import EventBase
 from synapse.events.builder import EventBuilder
-from synapse.types import StateMap, get_domain_from_id
+from synapse.types import StateMap, StrCollection, get_domain_from_id
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -290,7 +290,7 @@ class EventAuthHandler:
 
     async def get_rooms_that_allow_join(
         self, state_ids: StateMap[str]
-    ) -> Collection[str]:
+    ) -> StrCollection:
         """
         Generate a list of rooms in which membership allows access to a room.
 
@@ -331,7 +331,7 @@ class EventAuthHandler:
 
         return result
 
-    async def is_user_in_rooms(self, room_ids: Collection[str], user_id: str) -> bool:
+    async def is_user_in_rooms(self, room_ids: StrCollection, user_id: str) -> bool:
         """
         Check whether a user is a member of any of the provided rooms.
 
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 233f8c113d..dc1cbf5c3d 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -20,17 +20,7 @@ import itertools
 import logging
 from enum import Enum
 from http import HTTPStatus
-from typing import (
-    TYPE_CHECKING,
-    Collection,
-    Dict,
-    Iterable,
-    List,
-    Optional,
-    Set,
-    Tuple,
-    Union,
-)
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple, Union
 
 import attr
 from prometheus_client import Histogram
@@ -70,7 +60,7 @@ from synapse.replication.http.federation import (
 )
 from synapse.storage.databases.main.events import PartialStateConflictError
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
-from synapse.types import JsonDict, get_domain_from_id
+from synapse.types import JsonDict, StrCollection, get_domain_from_id
 from synapse.types.state import StateFilter
 from synapse.util.async_helpers import Linearizer
 from synapse.util.retryutils import NotRetryingDestination
@@ -179,7 +169,7 @@ class FederationHandler:
         # A dictionary mapping room IDs to (initial destination, other destinations)
         # tuples.
         self._partial_state_syncs_maybe_needing_restart: Dict[
-            str, Tuple[Optional[str], Collection[str]]
+            str, Tuple[Optional[str], StrCollection]
         ] = {}
         # A lock guarding the partial state flag for rooms.
         # When the lock is held for a given room, no other concurrent code may
@@ -437,7 +427,7 @@ class FederationHandler:
             )
         )
 
-        async def try_backfill(domains: Collection[str]) -> bool:
+        async def try_backfill(domains: StrCollection) -> bool:
             # TODO: Should we try multiple of these at a time?
 
             # Number of contacted remote homeservers that have denied our backfill
@@ -1730,7 +1720,7 @@ class FederationHandler:
     def _start_partial_state_room_sync(
         self,
         initial_destination: Optional[str],
-        other_destinations: Collection[str],
+        other_destinations: StrCollection,
         room_id: str,
     ) -> None:
         """Starts the background process to resync the state of a partial state room,
@@ -1812,7 +1802,7 @@ class FederationHandler:
     async def _sync_partial_state_room(
         self,
         initial_destination: Optional[str],
-        other_destinations: Collection[str],
+        other_destinations: StrCollection,
         room_id: str,
     ) -> None:
         """Background process to resync the state of a partial-state room
@@ -1949,9 +1939,9 @@ class FederationHandler:
 
 def _prioritise_destinations_for_partial_state_resync(
     initial_destination: Optional[str],
-    other_destinations: Collection[str],
+    other_destinations: StrCollection,
     room_id: str,
-) -> Collection[str]:
+) -> StrCollection:
     """Work out the order in which we should ask servers to resync events.
 
     If an `initial_destination` is given, it takes top priority. Otherwise
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 904a721483..e037acbca2 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -80,6 +80,7 @@ from synapse.types import (
     PersistedEventPosition,
     RoomStreamToken,
     StateMap,
+    StrCollection,
     UserID,
     get_domain_from_id,
 )
@@ -615,7 +616,7 @@ class FederationEventHandler:
 
     @trace
     async def backfill(
-        self, dest: str, room_id: str, limit: int, extremities: Collection[str]
+        self, dest: str, room_id: str, limit: int, extremities: StrCollection
     ) -> None:
         """Trigger a backfill request to `dest` for the given `room_id`
 
@@ -1565,7 +1566,7 @@ class FederationEventHandler:
     @trace
     @tag_args
     async def _get_events_and_persist(
-        self, destination: str, room_id: str, event_ids: Collection[str]
+        self, destination: str, room_id: str, event_ids: StrCollection
     ) -> None:
         """Fetch the given events from a server, and persist them as outliers.
 
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 8c2260ad7d..191529bd8e 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -15,7 +15,13 @@
 import logging
 from typing import TYPE_CHECKING, List, Optional, Tuple, cast
 
-from synapse.api.constants import AccountDataTypes, EduTypes, EventTypes, Membership
+from synapse.api.constants import (
+    AccountDataTypes,
+    Direction,
+    EduTypes,
+    EventTypes,
+    Membership,
+)
 from synapse.api.errors import SynapseError
 from synapse.events import EventBase
 from synapse.events.utils import SerializeEventConfig
@@ -57,7 +63,13 @@ class InitialSyncHandler:
         self.validator = EventValidator()
         self.snapshot_cache: ResponseCache[
             Tuple[
-                str, Optional[StreamToken], Optional[StreamToken], str, int, bool, bool
+                str,
+                Optional[StreamToken],
+                Optional[StreamToken],
+                Direction,
+                int,
+                bool,
+                bool,
             ]
         ] = ResponseCache(hs.get_clock(), "initial_sync_cache")
         self._event_serializer = hs.get_event_client_serializer()
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 8c8ff18a1a..ceefa16b49 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -13,13 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set
+from typing import TYPE_CHECKING, Dict, List, Optional, Set
 
 import attr
 
 from twisted.python.failure import Failure
 
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import Direction, EventTypes, Membership
 from synapse.api.errors import SynapseError
 from synapse.api.filtering import Filter
 from synapse.events.utils import SerializeEventConfig
@@ -28,7 +28,7 @@ from synapse.logging.opentracing import trace
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.rest.admin._base import assert_user_is_admin
 from synapse.streams.config import PaginationConfig
-from synapse.types import JsonDict, Requester, StreamKeyType
+from synapse.types import JsonDict, Requester, StrCollection, StreamKeyType
 from synapse.types.state import StateFilter
 from synapse.util.async_helpers import ReadWriteLock
 from synapse.util.stringutils import random_string
@@ -391,7 +391,7 @@ class PaginationHandler:
         """
         return self._delete_by_id.get(delete_id)
 
-    def get_delete_ids_by_room(self, room_id: str) -> Optional[Collection[str]]:
+    def get_delete_ids_by_room(self, room_id: str) -> Optional[StrCollection]:
         """Get all active delete ids by room
 
         Args:
@@ -448,7 +448,7 @@ class PaginationHandler:
 
         if pagin_config.from_token:
             from_token = pagin_config.from_token
-        elif pagin_config.direction == "f":
+        elif pagin_config.direction == Direction.FORWARDS:
             from_token = (
                 await self.hs.get_event_sources().get_start_token_for_pagination(
                     room_id
@@ -476,7 +476,7 @@ class PaginationHandler:
                     room_id, requester, allow_departed_users=True
                 )
 
-            if pagin_config.direction == "b":
+            if pagin_config.direction == Direction.BACKWARDS:
                 # if we're going backwards, we might need to backfill. This
                 # requires that we have a topo token.
                 if room_token.topological:
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 43e4e7b1b4..87af31aa27 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -64,7 +64,13 @@ from synapse.replication.tcp.commands import ClearUserSyncsCommand
 from synapse.replication.tcp.streams import PresenceFederationStream, PresenceStream
 from synapse.storage.databases.main import DataStore
 from synapse.streams import EventSource
-from synapse.types import JsonDict, StreamKeyType, UserID, get_domain_from_id
+from synapse.types import (
+    JsonDict,
+    StrCollection,
+    StreamKeyType,
+    UserID,
+    get_domain_from_id,
+)
 from synapse.util.async_helpers import Linearizer
 from synapse.util.metrics import Measure
 from synapse.util.wheel_timer import WheelTimer
@@ -320,7 +326,7 @@ class BasePresenceHandler(abc.ABC):
         for destination, host_states in hosts_to_states.items():
             self._federation.send_presence_to_destinations(host_states, [destination])
 
-    async def send_full_presence_to_users(self, user_ids: Collection[str]) -> None:
+    async def send_full_presence_to_users(self, user_ids: StrCollection) -> None:
         """
         Adds to the list of users who should receive a full snapshot of presence
         upon their next sync. Note that this only works for local users.
@@ -1601,7 +1607,7 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
         # Having a default limit doesn't match the EventSource API, but some
         # callers do not provide it. It is unused in this class.
         limit: int = 0,
-        room_ids: Optional[Collection[str]] = None,
+        room_ids: Optional[StrCollection] = None,
         is_guest: bool = False,
         explicit_room_id: Optional[str] = None,
         include_offline: bool = True,
@@ -1688,7 +1694,7 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
 
             # The set of users that we're interested in and that have had a presence update.
             # We'll actually pull the presence updates for these users at the end.
-            interested_and_updated_users: Collection[str]
+            interested_and_updated_users: StrCollection
 
             if from_key is not None:
                 # First get all users that have had a presence update
@@ -2120,7 +2126,7 @@ class PresenceFederationQueue:
         # stream_id, destinations, user_ids)`. We don't store the full states
         # for efficiency, and remote workers will already have the full states
         # cached.
-        self._queue: List[Tuple[int, int, Collection[str], Set[str]]] = []
+        self._queue: List[Tuple[int, int, StrCollection, Set[str]]] = []
 
         self._next_id = 1
 
@@ -2142,7 +2148,7 @@ class PresenceFederationQueue:
         self._queue = self._queue[index:]
 
     def send_presence_to_destinations(
-        self, states: Collection[UserPresenceState], destinations: Collection[str]
+        self, states: Collection[UserPresenceState], destinations: StrCollection
     ) -> None:
         """Send the presence states to the given destinations.
 
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index e96f9999a8..0fb15391e0 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Collection, Dict, FrozenSet, Iterable, List, O
 
 import attr
 
-from synapse.api.constants import EventTypes, RelationTypes
+from synapse.api.constants import Direction, EventTypes, RelationTypes
 from synapse.api.errors import SynapseError
 from synapse.events import EventBase, relation_from_event
 from synapse.logging.context import make_deferred_yieldable, run_in_background
@@ -413,7 +413,11 @@ class RelationsHandler:
 
                 # Attempt to find another event to use as the latest event.
                 potential_events, _ = await self._main_store.get_relations_for_event(
-                    event_id, event, room_id, RelationTypes.THREAD, direction="f"
+                    event_id,
+                    event,
+                    room_id,
+                    RelationTypes.THREAD,
+                    direction=Direction.FORWARDS,
                 )
 
                 # Filter out ignored users.
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 572c7b4db3..60a6d9cf3c 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -20,16 +20,7 @@ import random
 import string
 from collections import OrderedDict
 from http import HTTPStatus
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Awaitable,
-    Collection,
-    Dict,
-    List,
-    Optional,
-    Tuple,
-)
+from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple
 
 import attr
 from typing_extensions import TypedDict
@@ -72,6 +63,7 @@ from synapse.types import (
     RoomID,
     RoomStreamToken,
     StateMap,
+    StrCollection,
     StreamKeyType,
     StreamToken,
     UserID,
@@ -1644,7 +1636,7 @@ class RoomEventSource(EventSource[RoomStreamToken, EventBase]):
         user: UserID,
         from_key: RoomStreamToken,
         limit: int,
-        room_ids: Collection[str],
+        room_ids: StrCollection,
         is_guest: bool,
         explicit_room_id: Optional[str] = None,
     ) -> Tuple[List[EventBase], RoomStreamToken]:
diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py
index c6b869c6f4..4472019fbc 100644
--- a/synapse/handlers/room_summary.py
+++ b/synapse/handlers/room_summary.py
@@ -36,7 +36,7 @@ from synapse.api.errors import (
 )
 from synapse.api.ratelimiting import Ratelimiter
 from synapse.events import EventBase
-from synapse.types import JsonDict, Requester
+from synapse.types import JsonDict, Requester, StrCollection
 from synapse.util.caches.response_cache import ResponseCache
 
 if TYPE_CHECKING:
@@ -870,7 +870,7 @@ class _RoomQueueEntry:
     # The room ID of this entry.
     room_id: str
     # The server to query if the room is not known locally.
-    via: Sequence[str]
+    via: StrCollection
     # The minimum number of hops necessary to get to this room (compared to the
     # originally requested room).
     depth: int = 0
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 40f4635c4e..9bbf83047d 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -14,7 +14,7 @@
 
 import itertools
 import logging
-from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
 
 import attr
 from unpaddedbase64 import decode_base64, encode_base64
@@ -23,7 +23,7 @@ from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import NotFoundError, SynapseError
 from synapse.api.filtering import Filter
 from synapse.events import EventBase
-from synapse.types import JsonDict, StreamKeyType, UserID
+from synapse.types import JsonDict, StrCollection, StreamKeyType, UserID
 from synapse.types.state import StateFilter
 from synapse.visibility import filter_events_for_client
 
@@ -418,7 +418,7 @@ class SearchHandler:
     async def _search_by_rank(
         self,
         user: UserID,
-        room_ids: Collection[str],
+        room_ids: StrCollection,
         search_term: str,
         keys: Iterable[str],
         search_filter: Filter,
@@ -491,7 +491,7 @@ class SearchHandler:
     async def _search_by_recent(
         self,
         user: UserID,
-        room_ids: Collection[str],
+        room_ids: StrCollection,
         search_term: str,
         keys: Iterable[str],
         search_filter: Filter,
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 44e70fc4b8..4a27c0f051 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -20,7 +20,6 @@ from typing import (
     Any,
     Awaitable,
     Callable,
-    Collection,
     Dict,
     Iterable,
     List,
@@ -47,6 +46,7 @@ from synapse.http.server import respond_with_html, respond_with_redirect
 from synapse.http.site import SynapseRequest
 from synapse.types import (
     JsonDict,
+    StrCollection,
     UserID,
     contains_invalid_mxid_characters,
     create_requester,
@@ -141,7 +141,8 @@ class UserAttributes:
     confirm_localpart: bool = False
     display_name: Optional[str] = None
     picture: Optional[str] = None
-    emails: Collection[str] = attr.Factory(list)
+    # mypy thinks these are incompatible for some reason.
+    emails: StrCollection = attr.Factory(list)  # type: ignore[assignment]
 
 
 @attr.s(slots=True, auto_attribs=True)
@@ -159,7 +160,7 @@ class UsernameMappingSession:
 
     # attributes returned by the ID mapper
     display_name: Optional[str]
-    emails: Collection[str]
+    emails: StrCollection
 
     # An optional dictionary of extra attributes to be provided to the client in the
     # login response.
@@ -174,7 +175,7 @@ class UsernameMappingSession:
     # choices made by the user
     chosen_localpart: Optional[str] = None
     use_display_name: bool = True
-    emails_to_use: Collection[str] = ()
+    emails_to_use: StrCollection = ()
     terms_accepted_version: Optional[str] = None
 
 
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 5ebd3ea855..5235e29460 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -17,7 +17,6 @@ from typing import (
     TYPE_CHECKING,
     AbstractSet,
     Any,
-    Collection,
     Dict,
     FrozenSet,
     List,
@@ -62,6 +61,7 @@ from synapse.types import (
     Requester,
     RoomStreamToken,
     StateMap,
+    StrCollection,
     StreamKeyType,
     StreamToken,
     UserID,
@@ -1179,7 +1179,7 @@ class SyncHandler:
     async def _find_missing_partial_state_memberships(
         self,
         room_id: str,
-        members_to_fetch: Collection[str],
+        members_to_fetch: StrCollection,
         events_with_membership_auth: Mapping[str, EventBase],
         found_state_ids: StateMap[str],
     ) -> StateMap[str]:
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 6153a48257..d22dd19d38 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -1158,7 +1158,7 @@ class ModuleApi:
             # Send to remote destinations.
             destination = UserID.from_string(user).domain
             presence_handler.get_federation_queue().send_presence_to_destinations(
-                presence_events, destination
+                presence_events, [destination]
             )
 
     def looping_background_call(
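The one-line fix above is exactly the pitfall `StrCollection` guards against: a bare string satisfies `Collection[str]` but iterates per character. Illustrative only:

    destination = "example.org"
    for d in destination:      # yields "e", "x", "a", ... (the bug)
        pass
    for d in [destination]:    # yields "example.org" (the fix)
        pass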
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 2b0e52f23c..a8832a3f8e 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -46,6 +46,7 @@ from synapse.types import (
     JsonDict,
     PersistedEventPosition,
     RoomStreamToken,
+    StrCollection,
     StreamKeyType,
     StreamToken,
     UserID,
@@ -716,7 +717,7 @@ class Notifier:
 
     async def _get_room_ids(
         self, user: UserID, explicit_room_id: Optional[str]
-    ) -> Tuple[Collection[str], bool]:
+    ) -> Tuple[StrCollection, bool]:
         joined_room_ids = await self.store.get_rooms_for_user(user.to_string())
         if explicit_room_id:
             if explicit_room_id in joined_room_ids:
diff --git a/synapse/rest/client/push_rule.py b/synapse/rest/client/push_rule.py
index 8191b4e32c..ad5c10c99d 100644
--- a/synapse/rest/client/push_rule.py
+++ b/synapse/rest/client/push_rule.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING, List, Sequence, Tuple, Union
+from typing import TYPE_CHECKING, List, Tuple, Union
 
 from synapse.api.errors import (
     NotFoundError,
@@ -169,7 +169,7 @@ class PushRuleRestServlet(RestServlet):
             raise UnrecognizedRequestError()
 
 
-def _rule_spec_from_path(path: Sequence[str]) -> RuleSpec:
+def _rule_spec_from_path(path: List[str]) -> RuleSpec:
     """Turn a sequence of path components into a rule spec
 
     Args:
diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py
index 61375651bc..3f40f1874a 100644
--- a/synapse/rest/client/transactions.py
+++ b/synapse/rest/client/transactions.py
@@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Tuple
 
 from typing_extensions import ParamSpec
 
+from twisted.internet.defer import Deferred
 from twisted.python.failure import Failure
 from twisted.web.server import Request
 
@@ -90,7 +91,7 @@ class HttpTransactionCache:
         fn: Callable[P, Awaitable[Tuple[int, JsonDict]]],
         *args: P.args,
         **kwargs: P.kwargs,
-    ) -> Awaitable[Tuple[int, JsonDict]]:
+    ) -> "Deferred[Tuple[int, JsonDict]]":
         """Fetches the response for this transaction, or executes the given function
         to produce a response for this transaction.
 
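A note on the `Awaitable` → `Deferred` narrowing above (illustrative, not part of the patch): the cache already hands out a Twisted `Deferred`, and the tighter annotation lets callers use Deferred-specific APIs in addition to `await`:

    async def caller(d: "Deferred[Tuple[int, JsonDict]]") -> None:
        d.addTimeout(30, reactor)  # Deferred-only API (reactor assumed in scope)
        code, body = await d       # still awaitable from a coroutine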
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 84f844b79e..0018d6f7ab 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -30,7 +30,7 @@ from typing import (
 
 import attr
 
-from synapse.api.constants import MAIN_TIMELINE, RelationTypes
+from synapse.api.constants import MAIN_TIMELINE, Direction, RelationTypes
 from synapse.api.errors import SynapseError
 from synapse.events import EventBase
 from synapse.storage._base import SQLBaseStore
@@ -40,9 +40,13 @@ from synapse.storage.database import (
     LoggingTransaction,
     make_in_list_sql_clause,
 )
-from synapse.storage.databases.main.stream import generate_pagination_where_clause
+from synapse.storage.databases.main.stream import (
+    generate_next_token,
+    generate_pagination_bounds,
+    generate_pagination_where_clause,
+)
 from synapse.storage.engines import PostgresEngine
-from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken
+from synapse.types import JsonDict, StreamKeyType, StreamToken
 from synapse.util.caches.descriptors import cached, cachedList
 
 if TYPE_CHECKING:
@@ -164,7 +168,7 @@ class RelationsWorkerStore(SQLBaseStore):
         relation_type: Optional[str] = None,
         event_type: Optional[str] = None,
         limit: int = 5,
-        direction: str = "b",
+        direction: Direction = Direction.BACKWARDS,
         from_token: Optional[StreamToken] = None,
         to_token: Optional[StreamToken] = None,
     ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
@@ -177,8 +181,8 @@ class RelationsWorkerStore(SQLBaseStore):
             relation_type: Only fetch events with this relation type, if given.
             event_type: Only fetch events with this event type, if given.
             limit: Only fetch the most recent `limit` events.
-            direction: Whether to fetch the most recent first (`"b"`) or the
-                oldest first (`"f"`).
+            direction: Whether to fetch the most recent first (backwards) or the
+                oldest first (forwards).
             from_token: Fetch rows from the given token, or from the start if None.
             to_token: Fetch rows up to the given token, or up to the end if None.
 
@@ -207,24 +211,23 @@ class RelationsWorkerStore(SQLBaseStore):
             where_clause.append("type = ?")
             where_args.append(event_type)
 
+        order, from_bound, to_bound = generate_pagination_bounds(
+            direction,
+            from_token.room_key if from_token else None,
+            to_token.room_key if to_token else None,
+        )
+
         pagination_clause = generate_pagination_where_clause(
             direction=direction,
             column_names=("topological_ordering", "stream_ordering"),
-            from_token=from_token.room_key.as_historical_tuple()
-            if from_token
-            else None,
-            to_token=to_token.room_key.as_historical_tuple() if to_token else None,
+            from_token=from_bound,
+            to_token=to_bound,
             engine=self.database_engine,
         )
 
         if pagination_clause:
             where_clause.append(pagination_clause)
 
-        if direction == "b":
-            order = "DESC"
-        else:
-            order = "ASC"
-
         sql = """
             SELECT event_id, relation_type, sender, topological_ordering, stream_ordering
             FROM event_relations
@@ -266,16 +269,9 @@ class RelationsWorkerStore(SQLBaseStore):
                 topo_orderings = topo_orderings[:limit]
                 stream_orderings = stream_orderings[:limit]
 
-                topo = topo_orderings[-1]
-                token = stream_orderings[-1]
-                if direction == "b":
-                    # Tokens are positions between events.
-                    # This token points *after* the last event in the chunk.
-                    # We need it to point to the event before it in the chunk
-                    # when we are going backwards so we subtract one from the
-                    # stream part.
-                    token -= 1
-                next_key = RoomStreamToken(topo, token)
+                next_key = generate_next_token(
+                    direction, topo_orderings[-1], stream_orderings[-1]
+                )
 
                 if from_token:
                     next_token = from_token.copy_and_replace(
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index d28fc65df9..818c46182e 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -55,6 +55,7 @@ from typing_extensions import Literal
 
 from twisted.internet import defer
 
+from synapse.api.constants import Direction
 from synapse.api.filtering import Filter
 from synapse.events import EventBase
 from synapse.logging.context import make_deferred_yieldable, run_in_background
@@ -86,7 +87,6 @@ MAX_STREAM_SIZE = 1000
 _STREAM_TOKEN = "stream"
 _TOPOLOGICAL_TOKEN = "topological"
 
-
 # Used as return values for pagination APIs
 @attr.s(slots=True, frozen=True, auto_attribs=True)
 class _EventDictReturn:
@@ -104,7 +104,7 @@ class _EventsAround:
 
 
 def generate_pagination_where_clause(
-    direction: str,
+    direction: Direction,
     column_names: Tuple[str, str],
     from_token: Optional[Tuple[Optional[int], int]],
     to_token: Optional[Tuple[Optional[int], int]],
@@ -130,27 +130,26 @@ def generate_pagination_where_clause(
           token, but include those that match the to token.
 
     Args:
-        direction: Whether we're paginating backwards("b") or forwards ("f").
+        direction: Whether we're paginating backwards or forwards.
         column_names: The column names to bound. Must *not* be user defined as
             these get inserted directly into the SQL statement without escapes.
         from_token: The start point for the pagination. This is an exclusive
-            minimum bound if direction is "f", and an inclusive maximum bound if
-            direction is "b".
+            minimum bound if direction is forwards, and an inclusive maximum bound if
+            direction is backwards.
        to_token: The end point for the pagination. This is an inclusive
-            maximum bound if direction is "f", and an exclusive minimum bound if
-            direction is "b".
+            maximum bound if direction is forwards, and an exclusive minimum bound if
+            direction is backwards.
         engine: The database engine to generate the clauses for
 
     Returns:
         The sql expression
     """
-    assert direction in ("b", "f")
 
     where_clause = []
     if from_token:
         where_clause.append(
             _make_generic_sql_bound(
-                bound=">=" if direction == "b" else "<",
+                bound=">=" if direction == Direction.BACKWARDS else "<",
                 column_names=column_names,
                 values=from_token,
                 engine=engine,
@@ -160,7 +159,7 @@ def generate_pagination_where_clause(
     if to_token:
         where_clause.append(
             _make_generic_sql_bound(
-                bound="<" if direction == "b" else ">=",
+                bound="<" if direction == Direction.BACKWARDS else ">=",
                 column_names=column_names,
                 values=to_token,
                 engine=engine,
@@ -170,6 +169,104 @@ def generate_pagination_where_clause(
     return " AND ".join(where_clause)
 
 
+def generate_pagination_bounds(
+    direction: Direction,
+    from_token: Optional[RoomStreamToken],
+    to_token: Optional[RoomStreamToken],
+) -> Tuple[
+    str, Optional[Tuple[Optional[int], int]], Optional[Tuple[Optional[int], int]]
+]:
+    """
+    Generate a start and end point for this page of events.
+
+    Args:
+        direction: Whether pagination is going forwards or backwards.
+        from_token: The token to start pagination at, or None to start at the first value.
+        to_token: The token to end pagination at, or None to not limit the end point.
+
+    Returns:
+        A three-tuple of:
+
+            ASC or DESC for sorting of the query.
+
+            The starting position as a tuple of ints representing
+            (topological position, stream position) or None if no from_token was
+            provided. The topological position may be None for live tokens.
+
+            The end position in the same format as the starting position, or None
+            if no to_token was provided.
+    """
+
+    # Tokens really represent positions between elements, but we use
+    # the convention of pointing to the event before the gap. Hence
+    # we have a bit of asymmetry when it comes to equalities.
+    if direction == Direction.BACKWARDS:
+        order = "DESC"
+    else:
+        order = "ASC"
+
+    # The bounds for the stream tokens are complicated by the fact
+    # that we need to handle the instance_map part of the tokens. We do this
+    # by fetching all events between the min stream token and the maximum
+    # stream token (as returned by `RoomStreamToken.get_max_stream_pos`) and
+    # then filtering the results.
+    from_bound: Optional[Tuple[Optional[int], int]] = None
+    if from_token:
+        if from_token.topological is not None:
+            from_bound = from_token.as_historical_tuple()
+        elif direction == Direction.BACKWARDS:
+            from_bound = (
+                None,
+                from_token.get_max_stream_pos(),
+            )
+        else:
+            from_bound = (
+                None,
+                from_token.stream,
+            )
+
+    to_bound: Optional[Tuple[Optional[int], int]] = None
+    if to_token:
+        if to_token.topological is not None:
+            to_bound = to_token.as_historical_tuple()
+        elif direction == Direction.BACKWARDS:
+            to_bound = (
+                None,
+                to_token.stream,
+            )
+        else:
+            to_bound = (
+                None,
+                to_token.get_max_stream_pos(),
+            )
+
+    return order, from_bound, to_bound
+
+
+def generate_next_token(
+    direction: Direction, last_topo_ordering: int, last_stream_ordering: int
+) -> RoomStreamToken:
+    """
+    Generate the next room stream token based on the currently returned data.
+
+    Args:
+        direction: Whether pagination is going forwards or backwards.
+        last_topo_ordering: The last topological ordering being returned.
+        last_stream_ordering: The last stream ordering being returned.
+
+    Returns:
+        A new RoomStreamToken to return to the client.
+    """
+    if direction == Direction.BACKWARDS:
+        # Tokens are positions between events.
+        # This token points *after* the last event in the chunk.
+        # We need it to point to the event before it in the chunk
+        # when we are going backwards so we subtract one from the
+        # stream part.
+        last_stream_ordering -= 1
+    return RoomStreamToken(last_topo_ordering, last_stream_ordering)
+
+
 def _make_generic_sql_bound(
     bound: str,
     column_names: Tuple[str, str],
@@ -1103,7 +1200,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             txn,
             room_id,
             before_token,
-            direction="b",
+            direction=Direction.BACKWARDS,
             limit=before_limit,
             event_filter=event_filter,
         )
@@ -1113,7 +1210,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             txn,
             room_id,
             after_token,
-            direction="f",
+            direction=Direction.FORWARDS,
             limit=after_limit,
             event_filter=event_filter,
         )
@@ -1276,7 +1373,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         room_id: str,
         from_token: RoomStreamToken,
         to_token: Optional[RoomStreamToken] = None,
-        direction: str = "b",
+        direction: Direction = Direction.BACKWARDS,
         limit: int = -1,
         event_filter: Optional[Filter] = None,
     ) -> Tuple[List[_EventDictReturn], RoomStreamToken]:
@@ -1287,8 +1384,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             room_id
             from_token: The token used to stream from
             to_token: A token which if given limits the results to only those before
-            direction: Either 'b' or 'f' to indicate whether we are paginating
-                forwards or backwards from `from_key`.
+            direction: Indicates whether we are paginating forwards or backwards
+                from `from_key`.
             limit: The maximum number of events to return.
             event_filter: If provided filters the events to
                 those that match the filter.
@@ -1300,47 +1397,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             `to_token`), or `limit` is zero.
         """
 
-        # Tokens really represent positions between elements, but we use
-        # the convention of pointing to the event before the gap. Hence
-        # we have a bit of asymmetry when it comes to equalities.
         args = [False, room_id]
-        if direction == "b":
-            order = "DESC"
-        else:
-            order = "ASC"
-
-        # The bounds for the stream tokens are complicated by the fact
-        # that we need to handle the instance_map part of the tokens. We do this
-        # by fetching all events between the min stream token and the maximum
-        # stream token (as returned by `RoomStreamToken.get_max_stream_pos`) and
-        # then filtering the results.
-        if from_token.topological is not None:
-            from_bound: Tuple[Optional[int], int] = from_token.as_historical_tuple()
-        elif direction == "b":
-            from_bound = (
-                None,
-                from_token.get_max_stream_pos(),
-            )
-        else:
-            from_bound = (
-                None,
-                from_token.stream,
-            )
 
-        to_bound: Optional[Tuple[Optional[int], int]] = None
-        if to_token:
-            if to_token.topological is not None:
-                to_bound = to_token.as_historical_tuple()
-            elif direction == "b":
-                to_bound = (
-                    None,
-                    to_token.stream,
-                )
-            else:
-                to_bound = (
-                    None,
-                    to_token.get_max_stream_pos(),
-                )
+        order, from_bound, to_bound = generate_pagination_bounds(
+            direction, from_token, to_token
+        )
 
         bounds = generate_pagination_where_clause(
             direction=direction,
@@ -1427,8 +1488,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             _EventDictReturn(event_id, topological_ordering, stream_ordering)
             for event_id, instance_name, topological_ordering, stream_ordering in txn
             if _filter_results(
-                lower_token=to_token if direction == "b" else from_token,
-                upper_token=from_token if direction == "b" else to_token,
+                lower_token=to_token
+                if direction == Direction.BACKWARDS
+                else from_token,
+                upper_token=from_token
+                if direction == Direction.BACKWARDS
+                else to_token,
                 instance_name=instance_name,
                 topological_ordering=topological_ordering,
                 stream_ordering=stream_ordering,
@@ -1436,16 +1501,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         ][:limit]
 
         if rows:
-            topo = rows[-1].topological_ordering
-            token = rows[-1].stream_ordering
-            if direction == "b":
-                # Tokens are positions between events.
-                # This token points *after* the last event in the chunk.
-                # We need it to point to the event before it in the chunk
-                # when we are going backwards so we subtract one from the
-                # stream part.
-                token -= 1
-            next_token = RoomStreamToken(topo, token)
+            assert rows[-1].topological_ordering is not None
+            next_token = generate_next_token(
+                direction, rows[-1].topological_ordering, rows[-1].stream_ordering
+            )
         else:
             # TODO (erikj): We should work out what to do here instead.
             next_token = to_token if to_token else from_token
@@ -1458,7 +1517,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         room_id: str,
         from_key: RoomStreamToken,
         to_key: Optional[RoomStreamToken] = None,
-        direction: str = "b",
+        direction: Direction = Direction.BACKWARDS,
         limit: int = -1,
         event_filter: Optional[Filter] = None,
     ) -> Tuple[List[EventBase], RoomStreamToken]:
@@ -1468,8 +1527,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             room_id
             from_key: The token used to stream from
             to_key: A token which if given limits the results to only those before
-            direction: Either 'b' or 'f' to indicate whether we are paginating
-                forwards or backwards from `from_key`.
+            direction: Indicates whether we are paginating forwards or backwards
+                from `from_key`.
             limit: The maximum number of events to return.
             event_filter: If provided filters the events to those that match the filter.
 
diff --git a/synapse/streams/__init__.py b/synapse/streams/__init__.py
index 2dcd43d0a2..c6c8a0315c 100644
--- a/synapse/streams/__init__.py
+++ b/synapse/streams/__init__.py
@@ -12,9 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Collection, Generic, List, Optional, Tuple, TypeVar
+from typing import Generic, List, Optional, Tuple, TypeVar
 
-from synapse.types import UserID
+from synapse.types import StrCollection, UserID
 
 # The key, this is either a stream token or int.
 K = TypeVar("K")
@@ -28,7 +28,7 @@ class EventSource(Generic[K, R]):
         user: UserID,
         from_key: K,
         limit: int,
-        room_ids: Collection[str],
+        room_ids: StrCollection,
         is_guest: bool,
         explicit_room_id: Optional[str] = None,
     ) -> Tuple[List[R], K]:
diff --git a/synapse/streams/config.py b/synapse/streams/config.py
index 6df2de919c..5cb7875181 100644
--- a/synapse/streams/config.py
+++ b/synapse/streams/config.py
@@ -16,6 +16,7 @@ from typing import Optional
 
 import attr
 
+from synapse.api.constants import Direction
 from synapse.api.errors import SynapseError
 from synapse.http.servlet import parse_integer, parse_string
 from synapse.http.site import SynapseRequest
@@ -34,7 +35,7 @@ class PaginationConfig:
 
     from_token: Optional[StreamToken]
     to_token: Optional[StreamToken]
-    direction: str
+    direction: Direction
     limit: int
 
     @classmethod
@@ -45,9 +46,13 @@ class PaginationConfig:
         default_limit: int,
         default_dir: str = "f",
     ) -> "PaginationConfig":
-        direction = parse_string(
-            request, "dir", default=default_dir, allowed_values=["f", "b"]
+        direction_str = parse_string(
+            request,
+            "dir",
+            default=default_dir,
+            allowed_values=[Direction.FORWARDS.value, Direction.BACKWARDS.value],
         )
+        direction = Direction(direction_str)
 
         from_tok_str = parse_string(request, "from")
         to_tok_str = parse_string(request, "to")
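Illustrative round-trip of the new parsing (not part of the patch): the query string still carries `"f"`/`"b"`, and `Direction(direction_str)` converts the validated value into the enum:

    # e.g. a request like GET /messages?dir=b
    direction_str = "b"                       # validated by allowed_values above
    direction = Direction(direction_str)      # Direction.BACKWARDS
    assert direction.value == direction_str   # serializes back to the wire form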
diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py
index b703e4472e..a9893def74 100644
--- a/tests/events/test_presence_router.py
+++ b/tests/events/test_presence_router.py
@@ -16,6 +16,8 @@ from unittest.mock import Mock
 
 import attr
 
+from twisted.test.proto_helpers import MemoryReactor
+
 from synapse.api.constants import EduTypes
 from synapse.events.presence_router import PresenceRouter, load_legacy_presence_router
 from synapse.federation.units import Transaction
@@ -23,11 +25,13 @@ from synapse.handlers.presence import UserPresenceState
 from synapse.module_api import ModuleApi
 from synapse.rest import admin
 from synapse.rest.client import login, presence, room
+from synapse.server import HomeServer
 from synapse.types import JsonDict, StreamToken, create_requester
+from synapse.util import Clock
 
 from tests.handlers.test_sync import generate_sync_config
 from tests.test_utils import simple_async_mock
-from tests.unittest import FederatingHomeserverTestCase, TestCase, override_config
+from tests.unittest import FederatingHomeserverTestCase, override_config
 
 
 @attr.s
@@ -49,9 +53,7 @@ class LegacyPresenceRouterTestModule:
         }
         return users_to_state
 
-    async def get_interested_users(
-        self, user_id: str
-    ) -> Union[Set[str], PresenceRouter.ALL_USERS]:
+    async def get_interested_users(self, user_id: str) -> Union[Set[str], str]:
         if user_id in self._config.users_who_should_receive_all_presence:
             return PresenceRouter.ALL_USERS
 
@@ -71,9 +73,14 @@ class LegacyPresenceRouterTestModule:
         # Initialise a typed config object
         config = PresenceRouterTestConfig()
 
-        config.users_who_should_receive_all_presence = config_dict.get(
+        users_who_should_receive_all_presence = config_dict.get(
             "users_who_should_receive_all_presence"
         )
+        assert isinstance(users_who_should_receive_all_presence, list)
+
+        config.users_who_should_receive_all_presence = (
+            users_who_should_receive_all_presence
+        )
 
         return config
 
@@ -96,9 +103,7 @@ class PresenceRouterTestModule:
         }
         return users_to_state
 
-    async def get_interested_users(
-        self, user_id: str
-    ) -> Union[Set[str], PresenceRouter.ALL_USERS]:
+    async def get_interested_users(self, user_id: str) -> Union[Set[str], str]:
         if user_id in self._config.users_who_should_receive_all_presence:
             return PresenceRouter.ALL_USERS
 
@@ -118,9 +123,14 @@ class PresenceRouterTestModule:
         # Initialise a typed config object
         config = PresenceRouterTestConfig()
 
-        config.users_who_should_receive_all_presence = config_dict.get(
+        users_who_should_receive_all_presence = config_dict.get(
             "users_who_should_receive_all_presence"
         )
+        assert isinstance(users_who_should_receive_all_presence, list)
+
+        config.users_who_should_receive_all_presence = (
+            users_who_should_receive_all_presence
+        )
 
         return config
 
@@ -140,7 +150,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
         presence.register_servlets,
     ]
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         # Mock out the calls over federation.
         fed_transport_client = Mock(spec=["send_transaction"])
         fed_transport_client.send_transaction = simple_async_mock({})
@@ -153,7 +163,9 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
 
         return hs
 
-    def prepare(self, reactor, clock, homeserver):
+    def prepare(
+        self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+    ) -> None:
         self.sync_handler = self.hs.get_sync_handler()
         self.module_api = homeserver.get_module_api()
 
@@ -176,7 +188,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
             },
         }
     )
-    def test_receiving_all_presence_legacy(self):
+    def test_receiving_all_presence_legacy(self) -> None:
         self.receiving_all_presence_test_body()
 
     @override_config(
@@ -193,10 +205,10 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
             ],
         }
     )
-    def test_receiving_all_presence(self):
+    def test_receiving_all_presence(self) -> None:
         self.receiving_all_presence_test_body()
 
-    def receiving_all_presence_test_body(self):
+    def receiving_all_presence_test_body(self) -> None:
         """Test that a user that does not share a room with another other can receive
         presence for them, due to presence routing.
         """
@@ -302,7 +314,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
             },
         }
     )
-    def test_send_local_online_presence_to_with_module_legacy(self):
+    def test_send_local_online_presence_to_with_module_legacy(self) -> None:
         self.send_local_online_presence_to_with_module_test_body()
 
     @override_config(
@@ -321,10 +333,10 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
             ],
         }
     )
-    def test_send_local_online_presence_to_with_module(self):
+    def test_send_local_online_presence_to_with_module(self) -> None:
         self.send_local_online_presence_to_with_module_test_body()
 
-    def send_local_online_presence_to_with_module_test_body(self):
+    def send_local_online_presence_to_with_module_test_body(self) -> None:
         """Tests that send_local_presence_to_users sends local online presence to a set
         of specified local and remote users, with a custom PresenceRouter module enabled.
         """
@@ -447,18 +459,18 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
                     continue
 
                 # EDUs can contain multiple presence updates
-                for presence_update in edu["content"]["push"]:
+                for presence_edu in edu["content"]["push"]:
                     # Check for presence updates that contain the user IDs we're after
-                    found_users.add(presence_update["user_id"])
+                    found_users.add(presence_edu["user_id"])
 
                     # Ensure that no offline states are being sent out
-                    self.assertNotEqual(presence_update["presence"], "offline")
+                    self.assertNotEqual(presence_edu["presence"], "offline")
 
         self.assertEqual(found_users, expected_users)
 
 
 def send_presence_update(
-    testcase: TestCase,
+    testcase: FederatingHomeserverTestCase,
     user_id: str,
     access_token: str,
     presence_state: str,
@@ -479,7 +491,7 @@ def send_presence_update(
 
 
 def sync_presence(
-    testcase: TestCase,
+    testcase: FederatingHomeserverTestCase,
     user_id: str,
     since_token: Optional[StreamToken] = None,
 ) -> Tuple[List[UserPresenceState], StreamToken]:
@@ -500,7 +512,7 @@ def sync_presence(
     requester = create_requester(user_id)
     sync_config = generate_sync_config(requester.user.to_string())
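+    # Fetch the sync handler from the homeserver directly; not every test case
+    # defines a sync_handler attribute.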
     sync_result = testcase.get_success(
-        testcase.sync_handler.wait_for_sync_for_user(
+        testcase.hs.get_sync_handler().wait_for_sync_for_user(
             requester, sync_config, since_token
         )
     )
diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py
index 8ddce83b83..6687c28e8f 100644
--- a/tests/events/test_snapshot.py
+++ b/tests/events/test_snapshot.py
@@ -12,9 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.rest import admin
 from synapse.rest.client import login, room
+from synapse.server import HomeServer
+from synapse.util import Clock
 
 from tests import unittest
 from tests.test_utils.event_injection import create_event
@@ -27,7 +32,7 @@ class TestEventContext(unittest.HomeserverTestCase):
         room.register_servlets,
     ]
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.store = hs.get_datastores().main
         self._storage_controllers = hs.get_storage_controllers()
 
@@ -35,7 +40,7 @@ class TestEventContext(unittest.HomeserverTestCase):
         self.user_tok = self.login("u1", "pass")
         self.room_id = self.helper.create_room_as(tok=self.user_tok)
 
-    def test_serialize_deserialize_msg(self):
+    def test_serialize_deserialize_msg(self) -> None:
         """Test that an EventContext for a message event is the same after
         serialize/deserialize.
         """
@@ -51,7 +56,7 @@ class TestEventContext(unittest.HomeserverTestCase):
 
         self._check_serialize_deserialize(event, context)
 
-    def test_serialize_deserialize_state_no_prev(self):
+    def test_serialize_deserialize_state_no_prev(self) -> None:
         """Test that an EventContext for a state event (with not previous entry)
         is the same after serialize/deserialize.
         """
@@ -67,7 +72,7 @@ class TestEventContext(unittest.HomeserverTestCase):
 
         self._check_serialize_deserialize(event, context)
 
-    def test_serialize_deserialize_state_prev(self):
+    def test_serialize_deserialize_state_prev(self) -> None:
         """Test that an EventContext for a state event (which replaces a
         previous entry) is the same after serialize/deserialize.
         """
@@ -84,7 +89,9 @@ class TestEventContext(unittest.HomeserverTestCase):
 
         self._check_serialize_deserialize(event, context)
 
-    def _check_serialize_deserialize(self, event, context):
+    def _check_serialize_deserialize(
+        self, event: EventBase, context: EventContext
+    ) -> None:
         serialized = self.get_success(context.serialize(event, self.store))
 
         d_context = EventContext.deserialize(self._storage_controllers, serialized)
diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py
index a79256846f..ff7b349d75 100644
--- a/tests/events/test_utils.py
+++ b/tests/events/test_utils.py
@@ -13,21 +13,24 @@
 # limitations under the License.
 
 import unittest as stdlib_unittest
+from typing import Any, List, Mapping, Optional
 
 from synapse.api.constants import EventContentFields
 from synapse.api.room_versions import RoomVersions
-from synapse.events import make_event_from_dict
+from synapse.events import EventBase, make_event_from_dict
 from synapse.events.utils import (
+    PowerLevelsContent,
     SerializeEventConfig,
     copy_and_fixup_power_levels_contents,
     maybe_upsert_event_field,
     prune_event,
     serialize_event,
 )
+from synapse.types import JsonDict
 from synapse.util.frozenutils import freeze
 
 
-def MockEvent(**kwargs):
+def MockEvent(**kwargs: Any) -> EventBase:
     if "event_id" not in kwargs:
         kwargs["event_id"] = "fake_event_id"
     if "type" not in kwargs:
@@ -60,7 +63,7 @@ class TestMaybeUpsertEventField(stdlib_unittest.TestCase):
 
 
 class PruneEventTestCase(stdlib_unittest.TestCase):
-    def run_test(self, evdict, matchdict, **kwargs):
+    def run_test(self, evdict: JsonDict, matchdict: JsonDict, **kwargs: Any) -> None:
         """
         Asserts that a new event constructed with `evdict` will look like
         `matchdict` when it is redacted.
@@ -74,7 +77,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase):
             prune_event(make_event_from_dict(evdict, **kwargs)).get_dict(), matchdict
         )
 
-    def test_minimal(self):
+    def test_minimal(self) -> None:
         self.run_test(
             {"type": "A", "event_id": "$test:domain"},
             {
@@ -86,7 +89,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase):
             },
         )
 
-    def test_basic_keys(self):
+    def test_basic_keys(self) -> None:
         """Ensure that the keys that should be untouched are kept."""
         # Note that some of the values below don't really make sense, but the
         # pruning of events doesn't worry about the values of any fields (with
@@ -138,7 +141,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase):
             room_version=RoomVersions.MSC2176,
         )
 
-    def test_unsigned(self):
+    def test_unsigned(self) -> None:
         """Ensure that unsigned properties get stripped (except age_ts and replaces_state)."""
         self.run_test(
             {
@@ -159,7 +162,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase):
             },
         )
 
-    def test_content(self):
+    def test_content(self) -> None:
         """The content dictionary should be stripped in most cases."""
         self.run_test(
             {"type": "C", "event_id": "$test:domain", "content": {"things": "here"}},
@@ -194,7 +197,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase):
                 },
             )
 
-    def test_create(self):
+    def test_create(self) -> None:
         """Create events are partially redacted until MSC2176."""
         self.run_test(
             {
@@ -223,7 +226,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase):
             room_version=RoomVersions.MSC2176,
         )
 
-    def test_power_levels(self):
+    def test_power_levels(self) -> None:
         """Power level events keep a variety of content keys."""
         self.run_test(
             {
@@ -273,7 +276,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase):
             room_version=RoomVersions.MSC2176,
         )
 
-    def test_alias_event(self):
+    def test_alias_event(self) -> None:
         """Alias events have special behavior up through room version 6."""
         self.run_test(
             {
@@ -302,7 +305,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase):
             room_version=RoomVersions.V6,
         )
 
-    def test_redacts(self):
+    def test_redacts(self) -> None:
         """Redaction events have no special behaviour until MSC2174/MSC2176."""
 
         self.run_test(
@@ -328,7 +331,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase):
             room_version=RoomVersions.MSC2176,
         )
 
-    def test_join_rules(self):
+    def test_join_rules(self) -> None:
         """Join rules events have changed behavior starting with MSC3083."""
         self.run_test(
             {
@@ -371,7 +374,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase):
             room_version=RoomVersions.V8,
         )
 
-    def test_member(self):
+    def test_member(self) -> None:
         """Member events have changed behavior starting with MSC3375."""
         self.run_test(
             {
@@ -417,12 +420,12 @@ class PruneEventTestCase(stdlib_unittest.TestCase):
 
 
 class SerializeEventTestCase(stdlib_unittest.TestCase):
-    def serialize(self, ev, fields):
+    def serialize(self, ev: EventBase, fields: Optional[List[str]]) -> JsonDict:
         return serialize_event(
             ev, 1479807801915, config=SerializeEventConfig(only_event_fields=fields)
         )
 
-    def test_event_fields_works_with_keys(self):
+    def test_event_fields_works_with_keys(self) -> None:
         self.assertEqual(
             self.serialize(
                 MockEvent(sender="@alice:localhost", room_id="!foo:bar"), ["room_id"]
@@ -430,7 +433,7 @@ class SerializeEventTestCase(stdlib_unittest.TestCase):
             {"room_id": "!foo:bar"},
         )
 
-    def test_event_fields_works_with_nested_keys(self):
+    def test_event_fields_works_with_nested_keys(self) -> None:
         self.assertEqual(
             self.serialize(
                 MockEvent(
@@ -443,7 +446,7 @@ class SerializeEventTestCase(stdlib_unittest.TestCase):
             {"content": {"body": "A message"}},
         )
 
-    def test_event_fields_works_with_dot_keys(self):
+    def test_event_fields_works_with_dot_keys(self) -> None:
         self.assertEqual(
             self.serialize(
                 MockEvent(
@@ -456,7 +459,7 @@ class SerializeEventTestCase(stdlib_unittest.TestCase):
             {"content": {"key.with.dots": {}}},
         )
 
-    def test_event_fields_works_with_nested_dot_keys(self):
+    def test_event_fields_works_with_nested_dot_keys(self) -> None:
         self.assertEqual(
             self.serialize(
                 MockEvent(
@@ -472,7 +475,7 @@ class SerializeEventTestCase(stdlib_unittest.TestCase):
             {"content": {"nested.dot.key": {"leaf.key": 42}}},
         )
 
-    def test_event_fields_nops_with_unknown_keys(self):
+    def test_event_fields_nops_with_unknown_keys(self) -> None:
         self.assertEqual(
             self.serialize(
                 MockEvent(
@@ -485,7 +488,7 @@ class SerializeEventTestCase(stdlib_unittest.TestCase):
             {"content": {"foo": "bar"}},
         )
 
-    def test_event_fields_nops_with_non_dict_keys(self):
+    def test_event_fields_nops_with_non_dict_keys(self) -> None:
         self.assertEqual(
             self.serialize(
                 MockEvent(
@@ -498,7 +501,7 @@ class SerializeEventTestCase(stdlib_unittest.TestCase):
             {},
         )
 
-    def test_event_fields_nops_with_array_keys(self):
+    def test_event_fields_nops_with_array_keys(self) -> None:
         self.assertEqual(
             self.serialize(
                 MockEvent(
@@ -511,7 +514,7 @@ class SerializeEventTestCase(stdlib_unittest.TestCase):
             {},
         )
 
-    def test_event_fields_all_fields_if_empty(self):
+    def test_event_fields_all_fields_if_empty(self) -> None:
         self.assertEqual(
             self.serialize(
                 MockEvent(
@@ -531,16 +534,16 @@ class SerializeEventTestCase(stdlib_unittest.TestCase):
             },
         )
 
-    def test_event_fields_fail_if_fields_not_str(self):
+    def test_event_fields_fail_if_fields_not_str(self) -> None:
         with self.assertRaises(TypeError):
             self.serialize(
-                MockEvent(room_id="!foo:bar", content={"foo": "bar"}), ["room_id", 4]
+                MockEvent(room_id="!foo:bar", content={"foo": "bar"}), ["room_id", 4]  # type: ignore[list-item]
             )
 
 
 class CopyPowerLevelsContentTestCase(stdlib_unittest.TestCase):
     def setUp(self) -> None:
-        self.test_content = {
+        self.test_content: PowerLevelsContent = {
             "ban": 50,
             "events": {"m.room.name": 100, "m.room.power_levels": 100},
             "events_default": 0,
@@ -553,10 +556,11 @@ class CopyPowerLevelsContentTestCase(stdlib_unittest.TestCase):
             "users_default": 0,
         }
 
-    def _test(self, input):
+    def _test(self, input: PowerLevelsContent) -> None:
         a = copy_and_fixup_power_levels_contents(input)
 
         self.assertEqual(a["ban"], 50)
+        assert isinstance(a["events"], Mapping)
         self.assertEqual(a["events"]["m.room.name"], 100)
 
         # make sure that changing the copy changes the copy and not the orig
@@ -564,18 +568,19 @@ class CopyPowerLevelsContentTestCase(stdlib_unittest.TestCase):
         a["events"]["m.room.power_levels"] = 20
 
         self.assertEqual(input["ban"], 50)
+        assert isinstance(input["events"], Mapping)
         self.assertEqual(input["events"]["m.room.power_levels"], 100)
 
-    def test_unfrozen(self):
+    def test_unfrozen(self) -> None:
         self._test(self.test_content)
 
-    def test_frozen(self):
+    def test_frozen(self) -> None:
         input = freeze(self.test_content)
         self._test(input)
 
-    def test_stringy_integers(self):
+    def test_stringy_integers(self) -> None:
         """String representations of decimal integers are converted to integers."""
-        input = {
+        input: PowerLevelsContent = {
             "a": "100",
             "b": {
                 "foo": 99,
@@ -603,9 +608,9 @@ class CopyPowerLevelsContentTestCase(stdlib_unittest.TestCase):
 
     def test_invalid_types_raise_type_error(self) -> None:
         with self.assertRaises(TypeError):
-            copy_and_fixup_power_levels_contents({"a": ["hello", "grandma"]})  # type: ignore[arg-type]
-            copy_and_fixup_power_levels_contents({"a": None})  # type: ignore[arg-type]
+            copy_and_fixup_power_levels_contents({"a": ["hello", "grandma"]})  # type: ignore[dict-item]
+            copy_and_fixup_power_levels_contents({"a": None})  # type: ignore[dict-item]
 
     def test_invalid_nesting_raises_type_error(self) -> None:
         with self.assertRaises(TypeError):
-            copy_and_fixup_power_levels_contents({"a": {"b": {"c": 1}}})
+            copy_and_fixup_power_levels_contents({"a": {"b": {"c": 1}}})  # type: ignore[dict-item]
diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py
index d21c11b716..ff589c0b6c 100644
--- a/tests/federation/transport/test_knocking.py
+++ b/tests/federation/transport/test_knocking.py
@@ -23,10 +23,10 @@ from synapse.server import HomeServer
 from synapse.types import RoomAlias
 
 from tests.test_utils import event_injection
-from tests.unittest import FederatingHomeserverTestCase, TestCase
+from tests.unittest import FederatingHomeserverTestCase, HomeserverTestCase
 
 
-class KnockingStrippedStateEventHelperMixin(TestCase):
+class KnockingStrippedStateEventHelperMixin(HomeserverTestCase):
     def send_example_state_events_to_room(
         self,
         hs: "HomeServer",
@@ -49,7 +49,7 @@ class KnockingStrippedStateEventHelperMixin(TestCase):
         # To set a canonical alias, we'll need to point an alias at the room first.
         canonical_alias = "#fancy_alias:test"
         self.get_success(
-            self.store.create_room_alias_association(
+            self.hs.get_datastores().main.create_room_alias_association(
                 RoomAlias.from_string(canonical_alias), room_id, ["test"]
             )
         )
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index efbb5a8dbb..1fe9563c98 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -14,21 +14,22 @@
 
 
 import json
-from typing import Dict
+from typing import Dict, List, Set
 from unittest.mock import ANY, Mock, call
 
-from twisted.internet import defer
 from twisted.test.proto_helpers import MemoryReactor
 from twisted.web.resource import Resource
 
 from synapse.api.constants import EduTypes
 from synapse.api.errors import AuthError
 from synapse.federation.transport.server import TransportLayerServer
+from synapse.handlers.typing import TypingWriterHandler
 from synapse.server import HomeServer
 from synapse.types import JsonDict, Requester, UserID, create_requester
 from synapse.util import Clock
 
 from tests import unittest
+from tests.server import ThreadedMemoryReactorClock
 from tests.test_utils import make_awaitable
 from tests.unittest import override_config
 
@@ -62,7 +63,11 @@ def _make_edu_transaction_json(edu_type: str, content: JsonDict) -> bytes:
 
 
 class TypingNotificationsTestCase(unittest.HomeserverTestCase):
-    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+    def make_homeserver(
+        self,
+        reactor: ThreadedMemoryReactorClock,
+        clock: Clock,
+    ) -> HomeServer:
         # we mock out the keyring so as to skip the authentication check on the
         # federation API call.
         mock_keyring = Mock(spec=["verify_json_for_server"])
@@ -75,8 +80,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
         # the tests assume that we are starting at unix time 1000
         reactor.pump((1000,))
 
+        self.mock_hs_notifier = Mock()
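+        # Kept on self so that prepare() can inspect its on_new_event calls below.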
         hs = self.setup_test_homeserver(
-            notifier=Mock(),
+            notifier=self.mock_hs_notifier,
             federation_http_client=mock_federation_client,
             keyring=mock_keyring,
             replication_streams={},
@@ -90,32 +96,38 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
         return d
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-        mock_notifier = hs.get_notifier()
-        self.on_new_event = mock_notifier.on_new_event
+        self.on_new_event = self.mock_hs_notifier.on_new_event
 
-        self.handler = hs.get_typing_handler()
+        # hs.get_typing_handler returns a TypingWriterHandler when called on the
+        # main process, and a FollowerTypingHandler on workers.
+        # We rely on methods only available on the former, so assert we have the
+        # correct type here. We have to assign self.handler after the assert,
+        # otherwise mypy will treat it as a FollowerTypingHandler.
+        handler = hs.get_typing_handler()
+        assert isinstance(handler, TypingWriterHandler)
+        self.handler = handler
 
         self.event_source = hs.get_event_sources().sources.typing
 
         self.datastore = hs.get_datastores().main
+
         self.datastore.get_destination_retry_timings = Mock(
             return_value=make_awaitable(None)
         )
 
-        self.datastore.get_device_updates_by_remote = Mock(
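+        # Mypy objects to overwriting methods with Mocks, hence the type
+        # ignores on the assignments below.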
+        self.datastore.get_device_updates_by_remote = Mock(  # type: ignore[assignment]
             return_value=make_awaitable((0, []))
         )
 
-        self.datastore.get_destination_last_successful_stream_ordering = Mock(
+        self.datastore.get_destination_last_successful_stream_ordering = Mock(  # type: ignore[assignment]
             return_value=make_awaitable(None)
         )
 
-        def get_received_txn_response(*args):
-            return defer.succeed(None)
-
-        self.datastore.get_received_txn_response = get_received_txn_response
+        self.datastore.get_received_txn_response = Mock(  # type: ignore[assignment]
+            return_value=make_awaitable(None)
+        )
 
-        self.room_members = []
+        self.room_members: List[UserID] = []
 
         async def check_user_in_room(room_id: str, requester: Requester) -> None:
             if requester.user.to_string() not in [
@@ -124,47 +136,54 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
                 raise AuthError(401, "User is not in the room")
             return None
 
-        hs.get_auth().check_user_in_room = check_user_in_room
+        hs.get_auth().check_user_in_room = Mock(  # type: ignore[assignment]
+            side_effect=check_user_in_room
+        )
 
         async def check_host_in_room(room_id: str, server_name: str) -> bool:
             return room_id == ROOM_ID
 
-        hs.get_event_auth_handler().is_host_in_room = check_host_in_room
+        hs.get_event_auth_handler().is_host_in_room = Mock(  # type: ignore[assignment]
+            side_effect=check_host_in_room
+        )
 
-        async def get_current_hosts_in_room(room_id: str):
+        async def get_current_hosts_in_room(room_id: str) -> Set[str]:
             return {member.domain for member in self.room_members}
 
-        hs.get_storage_controllers().state.get_current_hosts_in_room = (
-            get_current_hosts_in_room
+        hs.get_storage_controllers().state.get_current_hosts_in_room = Mock(  # type: ignore[assignment]
+            side_effect=get_current_hosts_in_room
         )
 
-        hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = (
-            get_current_hosts_in_room
+        hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = Mock(  # type: ignore[assignment]
+            side_effect=get_current_hosts_in_room
         )
 
-        async def get_users_in_room(room_id: str):
+        async def get_users_in_room(room_id: str) -> Set[str]:
             return {str(u) for u in self.room_members}
 
-        self.datastore.get_users_in_room = get_users_in_room
+        self.datastore.get_users_in_room = Mock(side_effect=get_users_in_room)
 
-        self.datastore.get_user_directory_stream_pos = Mock(
+        self.datastore.get_user_directory_stream_pos = Mock(  # type: ignore[assignment]
             side_effect=(
-                # we deliberately return a non-None stream pos to avoid doing an initial_spam
+                # we deliberately return a non-None stream pos to avoid
+                # doing an initial_sync
                 lambda: make_awaitable(1)
             )
         )
 
-        self.datastore.get_partial_current_state_deltas = Mock(return_value=(0, None))
+        self.datastore.get_partial_current_state_deltas = Mock(return_value=(0, None))  # type: ignore[assignment]
 
-        self.datastore.get_to_device_stream_token = lambda: 0
-        self.datastore.get_new_device_msgs_for_remote = (
-            lambda *args, **kargs: make_awaitable(([], 0))
+        self.datastore.get_to_device_stream_token = Mock(  # type: ignore[assignment]
+            side_effect=lambda: 0
+        )
+        self.datastore.get_new_device_msgs_for_remote = Mock(  # type: ignore[assignment]
+            side_effect=lambda *args, **kwargs: make_awaitable(([], 0))
         )
-        self.datastore.delete_device_msgs_for_remote = (
-            lambda *args, **kargs: make_awaitable(None)
+        self.datastore.delete_device_msgs_for_remote = Mock(  # type: ignore[assignment]
+            side_effect=lambda *args, **kwargs: make_awaitable(None)
         )
-        self.datastore.set_received_txn_response = (
-            lambda *args, **kwargs: make_awaitable(None)
+        self.datastore.set_received_txn_response = Mock(  # type: ignore[assignment]
+            side_effect=lambda *args, **kwargs: make_awaitable(None)
         )
 
     def test_started_typing_local(self) -> None:
@@ -186,7 +205,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
         self.assertEqual(self.event_source.get_current_key(), 1)
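+        # The typing event source does not use the limit argument, so 0 is a
+        # safe value now that an int is required.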
         events = self.get_success(
             self.event_source.get_new_events(
-                user=U_APPLE, from_key=0, limit=None, room_ids=[ROOM_ID], is_guest=False
+                user=U_APPLE, from_key=0, limit=0, room_ids=[ROOM_ID], is_guest=False
             )
         )
         self.assertEqual(
@@ -257,7 +276,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
         self.assertEqual(self.event_source.get_current_key(), 1)
         events = self.get_success(
             self.event_source.get_new_events(
-                user=U_APPLE, from_key=0, limit=None, room_ids=[ROOM_ID], is_guest=False
+                user=U_APPLE, from_key=0, limit=0, room_ids=[ROOM_ID], is_guest=False
             )
         )
         self.assertEqual(
@@ -298,7 +317,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
             self.event_source.get_new_events(
                 user=U_APPLE,
                 from_key=0,
-                limit=None,
+                limit=0,
                 room_ids=[OTHER_ROOM_ID],
                 is_guest=False,
             )
@@ -351,7 +370,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
         self.assertEqual(self.event_source.get_current_key(), 1)
         events = self.get_success(
             self.event_source.get_new_events(
-                user=U_APPLE, from_key=0, limit=None, room_ids=[ROOM_ID], is_guest=False
+                user=U_APPLE, from_key=0, limit=0, room_ids=[ROOM_ID], is_guest=False
             )
         )
         self.assertEqual(
@@ -387,7 +406,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
             self.event_source.get_new_events(
                 user=U_APPLE,
                 from_key=0,
-                limit=None,
+                limit=0,
                 room_ids=[ROOM_ID],
                 is_guest=False,
             )
@@ -412,7 +431,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
             self.event_source.get_new_events(
                 user=U_APPLE,
                 from_key=1,
-                limit=None,
+                limit=0,
                 room_ids=[ROOM_ID],
                 is_guest=False,
             )
@@ -447,7 +466,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
             self.event_source.get_new_events(
                 user=U_APPLE,
                 from_key=0,
-                limit=None,
+                limit=0,
                 room_ids=[ROOM_ID],
                 is_guest=False,
             )
diff --git a/tests/logging/__init__.py b/tests/logging/__init__.py
index 1acf5666a8..1c5de95a80 100644
--- a/tests/logging/__init__.py
+++ b/tests/logging/__init__.py
@@ -13,9 +13,11 @@
 # limitations under the License.
 import logging
 
+from tests.unittest import TestCase
 
-class LoggerCleanupMixin:
-    def get_logger(self, handler):
+
+class LoggerCleanupMixin(TestCase):
+    def get_logger(self, handler: logging.Handler) -> logging.Logger:
         """
         Attach a handler to a logger and add clean-ups to remove the handler again.
         """
diff --git a/tests/logging/test_opentracing.py b/tests/logging/test_opentracing.py
index 0917e478a5..e28ba84cc2 100644
--- a/tests/logging/test_opentracing.py
+++ b/tests/logging/test_opentracing.py
@@ -153,7 +153,7 @@ class LogContextScopeManagerTestCase(TestCase):
 
         scopes = []
 
-        async def task(i: int):
+        async def task(i: int) -> None:
             scope = start_active_span(
                 f"task{i}",
                 tracer=self._tracer,
@@ -165,7 +165,7 @@ class LogContextScopeManagerTestCase(TestCase):
             self.assertEqual(self._tracer.active_span, scope.span)
             scope.close()
 
-        async def root():
+        async def root() -> None:
             with start_active_span("root span", tracer=self._tracer) as root_scope:
                 self.assertEqual(self._tracer.active_span, root_scope.span)
                 scopes.append(root_scope)
diff --git a/tests/logging/test_remote_handler.py b/tests/logging/test_remote_handler.py
index b0d046fe00..c08954d887 100644
--- a/tests/logging/test_remote_handler.py
+++ b/tests/logging/test_remote_handler.py
@@ -11,7 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from twisted.test.proto_helpers import AccumulatingProtocol
+from typing import Tuple
+
+from twisted.internet.protocol import Protocol
+from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
 
 from synapse.logging import RemoteHandler
 
@@ -20,7 +23,9 @@ from tests.server import FakeTransport, get_clock
 from tests.unittest import TestCase
 
 
-def connect_logging_client(reactor, client_id):
+def connect_logging_client(
+    reactor: MemoryReactorClock, client_id: int
+) -> Tuple[Protocol, AccumulatingProtocol]:
     # This is essentially tests.server.connect_client, but disabling autoflush on
     # the client transport. This is necessary to avoid an infinite loop due to
     # sending of data via the logging transport causing additional logs to be
@@ -35,10 +40,10 @@ def connect_logging_client(reactor, client_id):
 
 
 class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
-    def setUp(self):
+    def setUp(self) -> None:
         self.reactor, _ = get_clock()
 
-    def test_log_output(self):
+    def test_log_output(self) -> None:
         """
         The remote handler delivers logs over TCP.
         """
@@ -51,6 +56,7 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
         client, server = connect_logging_client(self.reactor, 0)
 
         # Trigger data being sent
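+        # connect_logging_client attaches a FakeTransport, so assert that here
+        # to let mypy know flush() exists.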
+        assert isinstance(client.transport, FakeTransport)
         client.transport.flush()
 
         # One log message, with a single trailing newline
@@ -61,7 +67,7 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
         # Ensure the data passed through properly.
         self.assertEqual(logs[0], "Hello there, wally!")
 
-    def test_log_backpressure_debug(self):
+    def test_log_backpressure_debug(self) -> None:
         """
         When backpressure is hit, DEBUG logs will be shed.
         """
@@ -83,6 +89,7 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
 
         # Allow the reconnection
         client, server = connect_logging_client(self.reactor, 0)
+        assert isinstance(client.transport, FakeTransport)
         client.transport.flush()
 
         # Only the 7 infos made it through, the debugs were elided
@@ -90,7 +97,7 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
         self.assertEqual(len(logs), 7)
         self.assertNotIn(b"debug", server.data)
 
-    def test_log_backpressure_info(self):
+    def test_log_backpressure_info(self) -> None:
         """
         When backpressure is hit, DEBUG and INFO logs will be shed.
         """
@@ -116,6 +123,7 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
 
         # Allow the reconnection
         client, server = connect_logging_client(self.reactor, 0)
+        assert isinstance(client.transport, FakeTransport)
         client.transport.flush()
 
         # The 10 warnings made it through, the debugs and infos were elided
@@ -124,7 +132,7 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
         self.assertNotIn(b"debug", server.data)
         self.assertNotIn(b"info", server.data)
 
-    def test_log_backpressure_cut_middle(self):
+    def test_log_backpressure_cut_middle(self) -> None:
         """
         When backpressure is hit and no more DEBUGs or INFOs can be culled,
         it will cut the middle messages out.
@@ -140,6 +148,7 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
 
         # Allow the reconnection
         client, server = connect_logging_client(self.reactor, 0)
+        assert isinstance(client.transport, FakeTransport)
         client.transport.flush()
 
         # The first five and last five warnings made it through, the debugs and
@@ -151,7 +160,7 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
             logs,
         )
 
-    def test_cancel_connection(self):
+    def test_cancel_connection(self) -> None:
         """
         Gracefully handle the connection being cancelled.
         """
diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py
index 0b0d8737c1..fa27f1279a 100644
--- a/tests/logging/test_terse_json.py
+++ b/tests/logging/test_terse_json.py
@@ -14,24 +14,28 @@
 import json
 import logging
 from io import BytesIO, StringIO
+from typing import cast
 from unittest.mock import Mock, patch
 
+from twisted.web.http import HTTPChannel
 from twisted.web.server import Request
 
 from synapse.http.site import SynapseRequest
 from synapse.logging._terse_json import JsonFormatter, TerseJsonFormatter
 from synapse.logging.context import LoggingContext, LoggingContextFilter
+from synapse.types import JsonDict
 
 from tests.logging import LoggerCleanupMixin
-from tests.server import FakeChannel
+from tests.server import FakeChannel, get_clock
 from tests.unittest import TestCase
 
 
 class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
-    def setUp(self):
+    def setUp(self) -> None:
         self.output = StringIO()
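+        # FakeChannel now requires a real reactor instead of None, so create one here.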
+        self.reactor, _ = get_clock()
 
-    def get_log_line(self):
+    def get_log_line(self) -> JsonDict:
         # One log message, with a single trailing newline.
         data = self.output.getvalue()
         logs = data.splitlines()
@@ -39,7 +43,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
         self.assertEqual(data.count("\n"), 1)
         return json.loads(logs[0])
 
-    def test_terse_json_output(self):
+    def test_terse_json_output(self) -> None:
         """
         The Terse JSON formatter converts log messages to JSON.
         """
@@ -61,7 +65,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
         self.assertCountEqual(log.keys(), expected_log_keys)
         self.assertEqual(log["log"], "Hello there, wally!")
 
-    def test_extra_data(self):
+    def test_extra_data(self) -> None:
         """
         Additional information can be included in the structured logging.
         """
@@ -93,7 +97,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
         self.assertEqual(log["int"], 3)
         self.assertIs(log["bool"], True)
 
-    def test_json_output(self):
+    def test_json_output(self) -> None:
         """
         The JSON formatter converts log messages to JSON.
         """
@@ -114,7 +118,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
         self.assertCountEqual(log.keys(), expected_log_keys)
         self.assertEqual(log["log"], "Hello there, wally!")
 
-    def test_with_context(self):
+    def test_with_context(self) -> None:
         """
         The logging context should be added to the JSON response.
         """
@@ -139,7 +143,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
         self.assertEqual(log["log"], "Hello there, wally!")
         self.assertEqual(log["request"], "name")
 
-    def test_with_request_context(self):
+    def test_with_request_context(self) -> None:
         """
         Information from the logging context request should be added to the JSON response.
         """
@@ -154,11 +158,13 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
         site.server_version_string = "Server v1"
         site.reactor = Mock()
         site.experimental_cors_msc3886 = False
-        request = SynapseRequest(FakeChannel(site, None), site)
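+        # FakeChannel stands in for a real HTTPChannel, so cast it to satisfy
+        # the type checker.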
+        request = SynapseRequest(
+            cast(HTTPChannel, FakeChannel(site, self.reactor)), site
+        )
         # Call requestReceived to finish instantiating the object.
         request.content = BytesIO()
-        # Partially skip some of the internal processing of SynapseRequest.
-        request._started_processing = Mock()
+        # Partially skip some internal processing of SynapseRequest.
+        request._started_processing = Mock()  # type: ignore[assignment]
         request.request_metrics = Mock(spec=["name"])
         with patch.object(Request, "render"):
             request.requestReceived(b"POST", b"/_matrix/client/versions", b"1.1")
@@ -200,7 +206,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
         self.assertEqual(log["protocol"], "1.1")
         self.assertEqual(log["user_agent"], "")
 
-    def test_with_exception(self):
+    def test_with_exception(self) -> None:
         """
         The logging exception type & value should be added to the JSON response.
         """
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index c9afa0f3dd..b9047194dd 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -294,9 +294,7 @@ class SyncTypingTests(unittest.HomeserverTestCase):
             self.make_request("GET", sync_url % (access_token, next_batch))
 
 
-class SyncKnockTestCase(
-    unittest.HomeserverTestCase, KnockingStrippedStateEventHelperMixin
-):
+class SyncKnockTestCase(KnockingStrippedStateEventHelperMixin):
     servlets = [
         synapse.rest.admin.register_servlets,
         login.register_servlets,
diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py
index 21a1ca2a68..3086e1b565 100644
--- a/tests/rest/client/test_transactions.py
+++ b/tests/rest/client/test_transactions.py
@@ -13,18 +13,22 @@
 # limitations under the License.
 
 from http import HTTPStatus
+from typing import Any, Generator, Tuple, cast
 from unittest.mock import Mock, call
 
-from twisted.internet import defer, reactor
+from twisted.internet import defer, reactor as _reactor
 
 from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context
 from synapse.rest.client.transactions import CLEANUP_PERIOD_MS, HttpTransactionCache
+from synapse.types import ISynapseReactor, JsonDict
 from synapse.util import Clock
 
 from tests import unittest
 from tests.test_utils import make_awaitable
 from tests.utils import MockClock
 
+reactor = cast(ISynapseReactor, _reactor)
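+# (the imported global reactor is untyped; the cast above tells mypy which
+# interfaces it provides)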
+
 
 class HttpTransactionCacheTestCase(unittest.TestCase):
     def setUp(self) -> None:
@@ -34,11 +38,13 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
         self.hs.get_auth = Mock()
         self.cache = HttpTransactionCache(self.hs)
 
-        self.mock_http_response = (HTTPStatus.OK, "GOOD JOB!")
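+        # The transaction cache is typed to return a (status, JsonDict) pair,
+        # so use a dict body here.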
+        self.mock_http_response = (HTTPStatus.OK, {"result": "GOOD JOB!"})
         self.mock_key = "foo"
 
     @defer.inlineCallbacks
-    def test_executes_given_function(self):
+    def test_executes_given_function(
+        self,
+    ) -> Generator["defer.Deferred[Any]", object, None]:
         cb = Mock(return_value=make_awaitable(self.mock_http_response))
         res = yield self.cache.fetch_or_execute(
             self.mock_key, cb, "some_arg", keyword="arg"
@@ -47,7 +53,9 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
         self.assertEqual(res, self.mock_http_response)
 
     @defer.inlineCallbacks
-    def test_deduplicates_based_on_key(self):
+    def test_deduplicates_based_on_key(
+        self,
+    ) -> Generator["defer.Deferred[Any]", object, None]:
         cb = Mock(return_value=make_awaitable(self.mock_http_response))
         for i in range(3):  # invoke multiple times
             res = yield self.cache.fetch_or_execute(
@@ -58,18 +66,20 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
         cb.assert_called_once_with("some_arg", keyword="arg", changing_args=0)
 
     @defer.inlineCallbacks
-    def test_logcontexts_with_async_result(self):
+    def test_logcontexts_with_async_result(
+        self,
+    ) -> Generator["defer.Deferred[Any]", object, None]:
         @defer.inlineCallbacks
-        def cb():
+        def cb() -> Generator["defer.Deferred[object]", object, Tuple[int, JsonDict]]:
             yield Clock(reactor).sleep(0)
-            return "yay"
+            return 1, {}
 
         @defer.inlineCallbacks
-        def test():
+        def test() -> Generator["defer.Deferred[Any]", object, None]:
             with LoggingContext("c") as c1:
                 res = yield self.cache.fetch_or_execute(self.mock_key, cb)
                 self.assertIs(current_context(), c1)
-                self.assertEqual(res, "yay")
+                self.assertEqual(res, (1, {}))
 
         # run the test twice in parallel
         d = defer.gatherResults([test(), test()])
@@ -78,13 +88,15 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
         self.assertIs(current_context(), SENTINEL_CONTEXT)
 
     @defer.inlineCallbacks
-    def test_does_not_cache_exceptions(self):
+    def test_does_not_cache_exceptions(
+        self,
+    ) -> Generator["defer.Deferred[Any]", object, None]:
         """Checks that, if the callback throws an exception, it is called again
         for the next request.
         """
         called = [False]
 
-        def cb():
+        def cb() -> "defer.Deferred[Tuple[int, JsonDict]]":
             if called[0]:
                 # return a valid result the second time
                 return defer.succeed(self.mock_http_response)
@@ -104,13 +116,15 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
             self.assertIs(current_context(), test_context)
 
     @defer.inlineCallbacks
-    def test_does_not_cache_failures(self):
+    def test_does_not_cache_failures(
+        self,
+    ) -> Generator["defer.Deferred[Any]", object, None]:
         """Checks that, if the callback returns a failure, it is called again
         for the next request.
         """
         called = [False]
 
-        def cb():
+        def cb() -> "defer.Deferred[Tuple[int, JsonDict]]":
             if called[0]:
                 # return a valid result the second time
                 return defer.succeed(self.mock_http_response)
@@ -130,7 +144,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
             self.assertIs(current_context(), test_context)
 
     @defer.inlineCallbacks
-    def test_cleans_up(self):
+    def test_cleans_up(self) -> Generator["defer.Deferred[Any]", object, None]:
         cb = Mock(return_value=make_awaitable(self.mock_http_response))
         yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg")
         # should NOT have cleaned up yet
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index 3ba896ecf3..f1ca523d23 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -28,6 +28,7 @@ from synapse.storage.background_updates import _BackgroundUpdateHandler
 from synapse.storage.roommember import ProfileInfo
 from synapse.util import Clock
 
+from tests.server import ThreadedMemoryReactorClock
 from tests.test_utils.event_injection import inject_member_event
 from tests.unittest import HomeserverTestCase, override_config
 
@@ -138,7 +139,9 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
         register.register_servlets,
     ]
 
-    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+    def make_homeserver(
+        self, reactor: ThreadedMemoryReactorClock, clock: Clock
+    ) -> HomeServer:
         self.appservice = ApplicationService(
             token="i_am_an_app_service",
             id="1234",
diff --git a/tests/unittest.py b/tests/unittest.py
index a120c2976c..fa92dd94eb 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -75,6 +75,7 @@ from synapse.util.httpresourcetree import create_resource_tree
 from tests.server import (
     CustomHeaderType,
     FakeChannel,
+    ThreadedMemoryReactorClock,
     get_clock,
     make_request,
     setup_test_homeserver,
@@ -360,7 +361,7 @@ class HomeserverTestCase(TestCase):
                 store.db_pool.updates.do_next_background_update(False), by=0.1
             )
 
-    def make_homeserver(self, reactor: MemoryReactor, clock: Clock):
+    def make_homeserver(self, reactor: ThreadedMemoryReactorClock, clock: Clock):
         """
         Make and return a homeserver.