Diffstat (limited to 'synapse')
-rw-r--r--  synapse/app/_base.py                                  |   4
-rw-r--r--  synapse/config/logger.py                              |   7
-rw-r--r--  synapse/config/registration.py                        |   9
-rw-r--r--  synapse/config/server.py                              |   6
-rw-r--r--  synapse/events/utils.py                               |   7
-rw-r--r--  synapse/federation/federation_server.py               |   4
-rw-r--r--  synapse/federation/transport/client.py                |  10
-rw-r--r--  synapse/handlers/appservice.py                        |   4
-rw-r--r--  synapse/handlers/auth.py                              |   2
-rw-r--r--  synapse/handlers/oidc.py                              |   2
-rw-r--r--  synapse/handlers/push_rules.py                        | 138
-rw-r--r--  synapse/handlers/receipts.py                          |   4
-rw-r--r--  synapse/handlers/relations.py                         |   4
-rw-r--r--  synapse/handlers/search.py                            |   6
-rw-r--r--  synapse/handlers/ui_auth/checkers.py                  |   4
-rw-r--r--  synapse/http/server.py                                |  16
-rw-r--r--  synapse/logging/context.py                            |  26
-rw-r--r--  synapse/module_api/__init__.py                        |  68
-rw-r--r--  synapse/module_api/errors.py                          |   4
-rw-r--r--  synapse/rest/client/push_rule.py                      | 112
-rw-r--r--  synapse/rest/client/register.py                       |   7
-rw-r--r--  synapse/server.py                                     |   5
-rw-r--r--  synapse/storage/databases/main/__init__.py            |  21
-rw-r--r--  synapse/storage/databases/main/appservice.py          |   4
-rw-r--r--  synapse/storage/databases/main/deviceinbox.py         |  79
-rw-r--r--  synapse/storage/databases/main/devices.py             |  51
-rw-r--r--  synapse/storage/databases/main/events.py              |  15
-rw-r--r--  synapse/storage/databases/main/group_server.py        |   4
-rw-r--r--  synapse/storage/databases/main/keys.py                |  15
-rw-r--r--  synapse/storage/databases/main/media_repository.py    |  13
-rw-r--r--  synapse/storage/databases/main/monthly_active_users.py |   6
-rw-r--r--  synapse/storage/databases/main/presence.py            |  19
-rw-r--r--  synapse/storage/databases/main/purge_events.py        |   8
-rw-r--r--  synapse/storage/databases/main/push_rule.py           |  15
-rw-r--r--  synapse/storage/databases/main/pusher.py              |  49
-rw-r--r--  synapse/storage/databases/main/receipts.py            |  56
-rw-r--r--  synapse/storage/databases/main/state.py               |   9
-rw-r--r--  synapse/storage/databases/main/ui_auth.py             |  12
-rw-r--r--  synapse/storage/prepare_database.py                   |   4
-rw-r--r--  synapse/util/caches/ttlcache.py                       |   2
-rw-r--r--  synapse/util/frozenutils.py                           |   3
41 files changed, 579 insertions, 255 deletions
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 37321f9133..d28b87a3f4 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -48,7 +48,6 @@ from twisted.logger import LoggingFile, LogLevel
 from twisted.protocols.tls import TLSMemoryBIOFactory
 from twisted.python.threadpool import ThreadPool
 
-import synapse
 from synapse.api.constants import MAX_PDU_SIZE
 from synapse.app import check_bind_error
 from synapse.app.phone_stats_home import start_phone_stats_home
@@ -60,6 +59,7 @@ from synapse.events.spamcheck import load_legacy_spam_checkers
 from synapse.events.third_party_rules import load_legacy_third_party_event_rules
 from synapse.handlers.auth import load_legacy_password_auth_providers
 from synapse.logging.context import PreserveLoggingContext
+from synapse.logging.opentracing import init_tracer
 from synapse.metrics import install_gc_manager, register_threadpool
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.metrics.jemalloc import setup_jemalloc_stats
@@ -431,7 +431,7 @@ async def start(hs: "HomeServer") -> None:
     refresh_certificate(hs)
 
     # Start the tracer
-    synapse.logging.opentracing.init_tracer(hs)  # type: ignore[attr-defined] # noqa
+    init_tracer(hs)
 
     # Instantiate the modules so they can register their web resources to the module API
     # before we start the listeners.
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index 99db9e1e39..470b8b4492 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -110,13 +110,6 @@ loggers:
         # information such as access tokens.
         level: INFO
 
-    twisted:
-        # We send the twisted logging directly to the file handler,
-        # to work around https://github.com/matrix-org/synapse/issues/3471
-        # when using "buffer" logger. Use "console" to log to stderr instead.
-        handlers: [file]
-        propagate: false
-
 root:
     level: INFO
 
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index 39e9acb62a..70eb7e6a97 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -43,6 +43,9 @@ class RegistrationConfig(Config):
         self.registration_requires_token = config.get(
             "registration_requires_token", False
         )
+        self.enable_registration_token_3pid_bypass = config.get(
+            "enable_registration_token_3pid_bypass", False
+        )
         self.registration_shared_secret = config.get("registration_shared_secret")
 
         self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
@@ -309,6 +312,12 @@ class RegistrationConfig(Config):
         #
         #registration_requires_token: true
 
+        # Allow users to submit a token during registration to bypass any required 3pid
+        # steps configured in `registrations_require_3pid`.
+        # Defaults to false, in which case token-based registrations must still complete any configured 3pid steps.
+        #
+        #enable_registration_token_3pid_bypass: false
+
         # If set, allows registration of standard or admin accounts by anyone who
         # has the shared secret, even if registration is otherwise disabled.
         #
diff --git a/synapse/config/server.py b/synapse/config/server.py
index d771045b52..b6cd326416 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -186,7 +186,7 @@ KNOWN_RESOURCES = {
 class HttpResourceConfig:
     names: List[str] = attr.ib(
         factory=list,
-        validator=attr.validators.deep_iterable(attr.validators.in_(KNOWN_RESOURCES)),  # type: ignore
+        validator=attr.validators.deep_iterable(attr.validators.in_(KNOWN_RESOURCES)),
     )
     compress: bool = attr.ib(
         default=False,
@@ -231,9 +231,7 @@ class ManholeConfig:
 class LimitRemoteRoomsConfig:
     enabled: bool = attr.ib(validator=attr.validators.instance_of(bool), default=False)
     complexity: Union[float, int] = attr.ib(
-        validator=attr.validators.instance_of(
-            (float, int)  # type: ignore[arg-type] # noqa
-        ),
+        validator=attr.validators.instance_of((float, int)),
         default=1.0,
     )
     complexity_error: str = attr.ib(
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index f8d3ba5456..a6c48308b3 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -27,7 +27,6 @@ from typing import (
 )
 
 import attr
-from frozendict import frozendict
 
 from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
 from synapse.api.errors import Codes, SynapseError
@@ -204,7 +203,9 @@ def _copy_field(src: JsonDict, dst: JsonDict, field: List[str]) -> None:
     key_to_move = field.pop(-1)
     sub_dict = src
     for sub_field in field:  # e.g. sub_field => "content"
-        if sub_field in sub_dict and type(sub_dict[sub_field]) in [dict, frozendict]:
+        if sub_field in sub_dict and isinstance(
+            sub_dict[sub_field], collections.abc.Mapping
+        ):
             sub_dict = sub_dict[sub_field]
         else:
             return
@@ -622,7 +623,7 @@ def validate_canonicaljson(value: Any) -> None:
         # Note that Infinity, -Infinity, and NaN are also considered floats.
         raise SynapseError(400, "Bad JSON value: float", Codes.BAD_JSON)
 
-    elif isinstance(value, (dict, frozendict)):
+    elif isinstance(value, collections.abc.Mapping):
         for v in value.values():
             validate_canonicaljson(v)
 
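Both hunks above rely on the fact that `frozendict` implements the `Mapping` interface, so the single abstract check replaces the old `(dict, frozendict)` tuple. A quick illustration (not part of the commit; assumes the `frozendict` package is still installed):

import collections.abc

from frozendict import frozendict

for value in ({"a": 1}, frozendict({"a": 1}), [("a", 1)]):
    print(type(value).__name__, isinstance(value, collections.abc.Mapping))
# dict True
# frozendict True
# list False
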
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index beab1227b8..884b5d60b4 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -268,8 +268,8 @@ class FederationServer(FederationBase):
             transaction_id=transaction_id,
             destination=destination,
             origin=origin,
-            origin_server_ts=transaction_data.get("origin_server_ts"),  # type: ignore
-            pdus=transaction_data.get("pdus"),  # type: ignore
+            origin_server_ts=transaction_data.get("origin_server_ts"),  # type: ignore[arg-type]
+            pdus=transaction_data.get("pdus"),
             edus=transaction_data.get("edus"),
         )
 
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 1421050b9a..9ce06dfa28 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -229,21 +229,21 @@ class TransportLayerClient:
         """
         logger.debug(
             "send_data dest=%s, txid=%s",
-            transaction.destination,  # type: ignore
-            transaction.transaction_id,  # type: ignore
+            transaction.destination,
+            transaction.transaction_id,
         )
 
-        if transaction.destination == self.server_name:  # type: ignore
+        if transaction.destination == self.server_name:
             raise RuntimeError("Transport layer cannot send to itself!")
 
         # FIXME: This is only used by the tests. The actual json sent is
         # generated by the json_data_callback.
         json_data = transaction.get_dict()
 
-        path = _create_v1_path("/send/%s", transaction.transaction_id)  # type: ignore
+        path = _create_v1_path("/send/%s", transaction.transaction_id)
 
         return await self.client.put_json(
-            transaction.destination,  # type: ignore
+            transaction.destination,
             path=path,
             data=json_data,
             json_data_callback=json_data_callback,
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 1b57840506..b3894666cc 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -416,7 +416,7 @@ class ApplicationServicesHandler:
         return typing
 
     async def _handle_receipts(
-        self, service: ApplicationService, new_token: Optional[int]
+        self, service: ApplicationService, new_token: int
     ) -> List[JsonDict]:
         """
         Return the latest read receipts that the given application service should receive.
@@ -447,7 +447,7 @@ class ApplicationServicesHandler:
 
         receipts_source = self.event_sources.sources.receipt
         receipts, _ = await receipts_source.get_new_events_as(
-            service=service, from_key=from_key
+            service=service, from_key=from_key, to_key=new_token
         )
         return receipts
 
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 86991d26ce..22678d486d 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -481,7 +481,7 @@ class AuthHandler:
             sid = authdict["session"]
 
         # Convert the URI and method to strings.
-        uri = request.uri.decode("utf-8")  # type: ignore
+        uri = request.uri.decode("utf-8")
         method = request.method.decode("utf-8")
 
         # If there's no session ID, create a new session.
diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py
index 724b9cfcb4..f6ffb7d18d 100644
--- a/synapse/handlers/oidc.py
+++ b/synapse/handlers/oidc.py
@@ -966,7 +966,7 @@ class OidcProvider:
                         "Mapping provider does not support de-duplicating Matrix IDs"
                     )
 
-                attributes = await self._user_mapping_provider.map_user_attributes(  # type: ignore
+                attributes = await self._user_mapping_provider.map_user_attributes(
                     userinfo, token
                 )
 
diff --git a/synapse/handlers/push_rules.py b/synapse/handlers/push_rules.py
new file mode 100644
index 0000000000..2599160bcc
--- /dev/null
+++ b/synapse/handlers/push_rules.py
@@ -0,0 +1,138 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING, List, Optional, Union
+
+import attr
+
+from synapse.api.errors import SynapseError, UnrecognizedRequestError
+from synapse.push.baserules import BASE_RULE_IDS
+from synapse.storage.push_rule import RuleNotFoundException
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class RuleSpec:
+    scope: str
+    template: str
+    rule_id: str
+    attr: Optional[str]
+
+
+class PushRulesHandler:
+    """A class to handle changes in push rules for users."""
+
+    def __init__(self, hs: "HomeServer"):
+        self._notifier = hs.get_notifier()
+        self._main_store = hs.get_datastores().main
+
+    async def set_rule_attr(
+        self, user_id: str, spec: RuleSpec, val: Union[bool, JsonDict]
+    ) -> None:
+        """Set an attribute (enabled or actions) on an existing push rule.
+
+        Notifies listeners (e.g. sync handler) of the change.
+
+        Args:
+            user_id: the user for which to modify the push rule.
+            spec: the spec of the push rule to modify.
+            val: the value to change the attribute to.
+
+        Raises:
+            RuleNotFoundException if the rule being modified doesn't exist.
+            SynapseError(400) if the value is malformed.
+            UnrecognizedRequestError if the attribute to change is unknown.
+            InvalidRuleException if we're trying to change the actions on a rule but
+                the provided actions aren't compliant with the spec.
+        """
+        if spec.attr not in ("enabled", "actions"):
+            # For the sake of potential future expansion, we don't want to
+            # report a 404 for an unknown attribute, so first check that it
+            # corresponds to a known attribute.
+            raise UnrecognizedRequestError()
+
+        namespaced_rule_id = f"global/{spec.template}/{spec.rule_id}"
+        rule_id = spec.rule_id
+        is_default_rule = rule_id.startswith(".")
+        if is_default_rule:
+            if namespaced_rule_id not in BASE_RULE_IDS:
+                raise RuleNotFoundException("Unknown rule %r" % (namespaced_rule_id,))
+        if spec.attr == "enabled":
+            if isinstance(val, dict) and "enabled" in val:
+                val = val["enabled"]
+            if not isinstance(val, bool):
+                # Legacy fallback
+                # This should *actually* take a dict, but many clients pass
+                # bools directly, so let's not break them.
+                raise SynapseError(400, "Value for 'enabled' must be boolean")
+            await self._main_store.set_push_rule_enabled(
+                user_id, namespaced_rule_id, val, is_default_rule
+            )
+        elif spec.attr == "actions":
+            if not isinstance(val, dict):
+                raise SynapseError(400, "Value must be a dict")
+            actions = val.get("actions")
+            if not isinstance(actions, list):
+                raise SynapseError(400, "Value for 'actions' must be a list")
+            check_actions(actions)
+            rule_id = spec.rule_id
+            is_default_rule = rule_id.startswith(".")
+            if is_default_rule:
+                if namespaced_rule_id not in BASE_RULE_IDS:
+                    raise RuleNotFoundException(
+                        "Unknown rule %r" % (namespaced_rule_id,)
+                    )
+            await self._main_store.set_push_rule_actions(
+                user_id, namespaced_rule_id, actions, is_default_rule
+            )
+        else:
+            raise UnrecognizedRequestError()
+
+        self.notify_user(user_id)
+
+    def notify_user(self, user_id: str) -> None:
+        """Notify listeners about a push rule change.
+
+        Args:
+            user_id: the user ID the change is for.
+        """
+        stream_id = self._main_store.get_max_push_rules_stream_id()
+        self._notifier.on_new_event("push_rules_key", stream_id, users=[user_id])
+
+
+def check_actions(actions: List[Union[str, JsonDict]]) -> None:
+    """Check if the given actions are spec compliant.
+
+    Args:
+        actions: the actions to check.
+
+    Raises:
+        InvalidRuleException if the rules aren't compliant with the spec.
+    """
+    if not isinstance(actions, list):
+        raise InvalidRuleException("No actions found")
+
+    for a in actions:
+        if a in ["notify", "dont_notify", "coalesce"]:
+            pass
+        elif isinstance(a, dict) and "set_tweak" in a:
+            pass
+        else:
+            raise InvalidRuleException("Unrecognised action %s" % a)
+
+
+class InvalidRuleException(Exception):
+    pass
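
For reference, a sketch of how the extracted `check_actions` behaves (illustrative only; assumes a Synapse checkout on the import path):

from synapse.handlers.push_rules import InvalidRuleException, check_actions

# Spec-compliant actions pass silently.
check_actions(["notify", {"set_tweak": "highlight", "value": False}])

# Anything else raises.
try:
    check_actions(["alert"])
except InvalidRuleException as e:
    print(e)  # Unrecognised action alert
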
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 6250bb3bdf..cfe860decc 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -239,13 +239,14 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
         return events, to_key
 
     async def get_new_events_as(
-        self, from_key: int, service: ApplicationService
+        self, from_key: int, to_key: int, service: ApplicationService
     ) -> Tuple[List[JsonDict], int]:
         """Returns a set of new read receipt events that an appservice
         may be interested in.
 
         Args:
             from_key: the stream position at which events should be fetched from
+            to_key: the stream position up to which events should be fetched
             service: The appservice which may be interested
 
         Returns:
@@ -255,7 +256,6 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
                 * The current read receipt stream token.
         """
         from_key = int(from_key)
-        to_key = self.get_current_key()
 
         if from_key == to_key:
             return [], to_key
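
Threading `to_key` through from the caller (see the appservice handler change above) pins the upper bound of the window at the stream token the caller was notified about, rather than whatever `get_current_key()` returns by the time the query runs. A toy model of the windowing, not Synapse code:

stream = list(range(1, 18))  # stand-in for receipt stream positions

def new_events(from_key: int, to_key: int) -> list:
    return [pos for pos in stream if from_key < pos <= to_key]

# The handler was notified up to token 15; positions 16 and 17 arrived
# later and belong to the next notification, so they must not be
# returned here and then skipped by the next batch.
print(new_events(10, 15))  # [11, 12, 13, 14, 15]
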
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index 5efb561273..b5dc9f74b3 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import collections.abc
 import logging
 from typing import (
     TYPE_CHECKING,
@@ -24,7 +25,6 @@ from typing import (
 )
 
 import attr
-from frozendict import frozendict
 
 from synapse.api.constants import RelationTypes
 from synapse.api.errors import SynapseError
@@ -380,7 +380,7 @@ class RelationsHandler:
             # Do not bundle aggregations for an event which represents an edit or an
             # annotation. It does not make sense for them to have related events.
             relates_to = event.content.get("m.relates_to")
-            if isinstance(relates_to, (dict, frozendict)):
+            if isinstance(relates_to, collections.abc.Mapping):
                 relation_type = relates_to.get("rel_type")
                 if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
                     continue
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 102dd4b57d..5619f8f50e 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -357,7 +357,7 @@ class SearchHandler:
             itertools.chain(
                 # The events_before and events_after for each context.
                 itertools.chain.from_iterable(
-                    itertools.chain(context["events_before"], context["events_after"])  # type: ignore[arg-type]
+                    itertools.chain(context["events_before"], context["events_after"])
                     for context in contexts.values()
                 ),
                 # The returned events.
@@ -373,10 +373,10 @@ class SearchHandler:
 
         for context in contexts.values():
             context["events_before"] = self._event_serializer.serialize_events(
-                context["events_before"], time_now, bundle_aggregations=aggregations  # type: ignore[arg-type]
+                context["events_before"], time_now, bundle_aggregations=aggregations
             )
             context["events_after"] = self._event_serializer.serialize_events(
-                context["events_after"], time_now, bundle_aggregations=aggregations  # type: ignore[arg-type]
+                context["events_after"], time_now, bundle_aggregations=aggregations
             )
 
         results = [
diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
index 472b029af3..e2a441066d 100644
--- a/synapse/handlers/ui_auth/checkers.py
+++ b/synapse/handlers/ui_auth/checkers.py
@@ -256,7 +256,9 @@ class RegistrationTokenAuthChecker(UserInteractiveAuthChecker):
     def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
         self.hs = hs
-        self._enabled = bool(hs.config.registration.registration_requires_token)
+        self._enabled = bool(
+            hs.config.registration.registration_requires_token
+        ) or bool(hs.config.registration.enable_registration_token_3pid_bypass)
         self.store = hs.get_datastores().main
 
     def is_enabled(self) -> bool:
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 31ca841889..1cf49830e8 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -295,7 +295,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
             if isawaitable(raw_callback_return):
                 callback_return = await raw_callback_return
             else:
-                callback_return = raw_callback_return  # type: ignore
+                callback_return = raw_callback_return
 
             return callback_return
 
@@ -469,7 +469,7 @@ class JsonResource(DirectServeJsonResource):
         if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)):
             callback_return = await raw_callback_return
         else:
-            callback_return = raw_callback_return  # type: ignore
+            callback_return = raw_callback_return
 
         return callback_return
 
@@ -683,6 +683,9 @@ def respond_with_json(
     Returns:
         twisted.web.server.NOT_DONE_YET if the request is still active.
     """
+    # The response code must always be set, for logging purposes.
+    request.setResponseCode(code)
+
     # could alternatively use request.notifyFinish() and flip a flag when
     # the Deferred fires, but since the flag is RIGHT THERE it seems like
     # a waste.
@@ -697,7 +700,6 @@ def respond_with_json(
     else:
         encoder = _encode_json_bytes
 
-    request.setResponseCode(code)
     request.setHeader(b"Content-Type", b"application/json")
     request.setHeader(b"Cache-Control", b"no-cache, no-store, must-revalidate")
 
@@ -728,13 +730,15 @@ def respond_with_json_bytes(
     Returns:
         twisted.web.server.NOT_DONE_YET if the request is still active.
     """
+    # The response code must always be set, for logging purposes.
+    request.setResponseCode(code)
+
     if request._disconnected:
         logger.warning(
             "Not sending response to request %s, already disconnected.", request
         )
         return None
 
-    request.setResponseCode(code)
     request.setHeader(b"Content-Type", b"application/json")
     request.setHeader(b"Content-Length", b"%d" % (len(json_bytes),))
     request.setHeader(b"Cache-Control", b"no-cache, no-store, must-revalidate")
@@ -840,6 +844,9 @@ def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes) -> N
         code: The HTTP response code.
         html_bytes: The HTML bytes to use as the response body.
     """
+    # The response code must always be set, for logging purposes.
+    request.setResponseCode(code)
+
     # could alternatively use request.notifyFinish() and flip a flag when
     # the Deferred fires, but since the flag is RIGHT THERE it seems like
     # a waste.
@@ -849,7 +856,6 @@ def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes) -> N
         )
         return None
 
-    request.setResponseCode(code)
     request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
     request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
 
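All three hunks apply the same reordering: the status code is recorded before the early return for disconnected clients, so the access log shows the status the server intended rather than the default 200. The pattern in miniature (hypothetical stand-ins, not Synapse's actual classes):

def respond(request, code: int, body: bytes) -> None:
    # The response code must always be set, for logging purposes.
    request.setResponseCode(code)

    if request._disconnected:
        # Nothing is written, but the logged status is still `code`.
        return

    request.setHeader(b"Content-Length", b"%d" % (len(body),))
    request.write(body)
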
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index 88cd8a9e1c..fd9cb97920 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -722,6 +722,11 @@ P = ParamSpec("P")
 R = TypeVar("R")
 
 
+async def _unwrap_awaitable(awaitable: Awaitable[R]) -> R:
+    """Unwraps an arbitrary awaitable by awaiting it."""
+    return await awaitable
+
+
 @overload
 def preserve_fn(  # type: ignore[misc]
     f: Callable[P, Awaitable[R]],
@@ -802,17 +807,20 @@ def run_in_background(  # type: ignore[misc]
         # by synchronous exceptions, so let's turn them into Failures.
         return defer.fail()
 
+    # `res` may be a coroutine, `Deferred`, some other kind of awaitable, or a plain
+    # value. Convert it to a `Deferred`.
     if isinstance(res, typing.Coroutine):
+        # Wrap the coroutine in a `Deferred`.
         res = defer.ensureDeferred(res)
-
-    # At this point we should have a Deferred, if not then f was a synchronous
-    # function, wrap it in a Deferred for consistency.
-    if not isinstance(res, defer.Deferred):
-        # `res` is not a `Deferred` and not a `Coroutine`.
-        # There are no other types of `Awaitable`s we expect to encounter in Synapse.
-        assert not isinstance(res, Awaitable)
-
-        return defer.succeed(res)
+    elif isinstance(res, defer.Deferred):
+        pass
+    elif isinstance(res, Awaitable):
+        # `res` is probably some kind of completed awaitable, such as a `DoneAwaitable`
+        # or `Future` from `make_awaitable`.
+        res = defer.ensureDeferred(_unwrap_awaitable(res))
+    else:
+        # `res` is a plain value. Wrap it in a `Deferred`.
+        res = defer.succeed(res)
 
     if res.called and not res.paused:
         # The function should have maintained the logcontext, so we can
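
The rewritten branch above normalises the four shapes `f` can return into a `Deferred`. The same dispatch in isolation (a sketch, not the Synapse function itself):

import typing
from typing import Any, Awaitable

from twisted.internet import defer


async def _unwrap_awaitable(awaitable: Awaitable[Any]) -> Any:
    return await awaitable


def to_deferred(res: Any) -> "defer.Deferred[Any]":
    if isinstance(res, typing.Coroutine):
        # A coroutine from an async function: wrap it directly.
        return defer.ensureDeferred(res)
    elif isinstance(res, defer.Deferred):
        return res
    elif isinstance(res, Awaitable):
        # Some other awaitable, e.g. an already-completed Future: await it
        # via a coroutine so ensureDeferred can handle it.
        return defer.ensureDeferred(_unwrap_awaitable(res))
    else:
        # A plain value from a synchronous function.
        return defer.succeed(res)
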
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 8f9e629274..834fe1b62c 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -82,6 +82,7 @@ from synapse.handlers.auth import (
     ON_LOGGED_OUT_CALLBACK,
     AuthHandler,
 )
+from synapse.handlers.push_rules import RuleSpec, check_actions
 from synapse.http.client import SimpleHttpClient
 from synapse.http.server import (
     DirectServeHtmlResource,
@@ -109,6 +110,7 @@ from synapse.storage.state import StateFilter
 from synapse.types import (
     DomainSpecificString,
     JsonDict,
+    JsonMapping,
     Requester,
     StateMap,
     UserID,
@@ -151,6 +153,7 @@ __all__ = [
     "PRESENCE_ALL_USERS",
     "LoginResponse",
     "JsonDict",
+    "JsonMapping",
     "EventBase",
     "StateMap",
     "ProfileInfo",
@@ -193,6 +196,7 @@ class ModuleApi:
         self._clock: Clock = hs.get_clock()
         self._registration_handler = hs.get_registration_handler()
         self._send_email_handler = hs.get_send_email_handler()
+        self._push_rules_handler = hs.get_push_rules_handler()
         self.custom_template_dir = hs.config.server.custom_template_directory
 
         try:
@@ -1350,6 +1354,68 @@ class ModuleApi:
         """
         await self._store.add_user_bound_threepid(user_id, medium, address, id_server)
 
+    def check_push_rule_actions(
+        self, actions: List[Union[str, Dict[str, str]]]
+    ) -> None:
+        """Checks if the given push rule actions are valid according to the Matrix
+        specification.
+
+        See https://spec.matrix.org/v1.2/client-server-api/#actions for the list of valid
+        actions.
+
+        Added in Synapse v1.58.0.
+
+        Args:
+            actions: the actions to check.
+
+        Raises:
+            synapse.module_api.errors.InvalidRuleException if the actions are invalid.
+        """
+        check_actions(actions)
+
+    async def set_push_rule_action(
+        self,
+        user_id: str,
+        scope: str,
+        kind: str,
+        rule_id: str,
+        actions: List[Union[str, Dict[str, str]]],
+    ) -> None:
+        """Changes the actions of an existing push rule for the given user.
+
+        See https://spec.matrix.org/v1.2/client-server-api/#push-rules for more
+        information about push rules and their syntax.
+
+        Can only be called on the main process.
+
+        Added in Synapse v1.58.0.
+
+        Args:
+            user_id: the user for which to change the push rule's actions.
+            scope: the push rule's scope, currently only "global" is allowed.
+            kind: the push rule's kind.
+            rule_id: the push rule's identifier.
+            actions: the actions to run when the rule's conditions match.
+
+        Raises:
+            RuntimeError if this method is called on a worker or `scope` is invalid.
+            synapse.module_api.errors.RuleNotFoundException if the rule being modified
+                can't be found.
+            synapse.module_api.errors.InvalidRuleException if the actions are invalid.
+        """
+        if self.worker_app is not None:
+            raise RuntimeError("module tried to change push rule actions on a worker")
+
+        if scope != "global":
+            raise RuntimeError(
+                "invalid scope %s, only 'global' is currently allowed" % scope
+            )
+
+        spec = RuleSpec(scope, kind, rule_id, "actions")
+        await self._push_rules_handler.set_rule_attr(
+            user_id, spec, {"actions": actions}
+        )
+
 
 class PublicRoomListManager:
     """Contains methods for adding to, removing from and querying whether a room
@@ -1419,7 +1485,7 @@ class AccountDataManager:
                 f"{user_id} is not local to this homeserver; can't access account data for remote users."
             )
 
-    async def get_global(self, user_id: str, data_type: str) -> Optional[JsonDict]:
+    async def get_global(self, user_id: str, data_type: str) -> Optional[JsonMapping]:
         """
         Gets some global account data, of a specified type, for the specified user.
 
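A hedged example of module code driving the two new methods (the module class and the choice of rule are illustrative, not from this commit):

from synapse.module_api import ModuleApi
from synapse.module_api.errors import InvalidRuleException


class ExampleModule:
    def __init__(self, config: dict, api: ModuleApi):
        self._api = api

    async def silence_messages_for(self, user_id: str) -> None:
        actions = ["dont_notify"]

        # Validate before writing; bad actions raise InvalidRuleException.
        try:
            self._api.check_push_rule_actions(actions)
        except InvalidRuleException:
            return

        # ".m.rule.message" is one of the default underride rules.
        await self._api.set_push_rule_action(
            user_id=user_id,
            scope="global",
            kind="underride",
            rule_id=".m.rule.message",
            actions=actions,
        )
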
diff --git a/synapse/module_api/errors.py b/synapse/module_api/errors.py
index 1db900e41f..e58e0e60fe 100644
--- a/synapse/module_api/errors.py
+++ b/synapse/module_api/errors.py
@@ -20,10 +20,14 @@ from synapse.api.errors import (
     SynapseError,
 )
 from synapse.config._base import ConfigError
+from synapse.handlers.push_rules import InvalidRuleException
+from synapse.storage.push_rule import RuleNotFoundException
 
 __all__ = [
     "InvalidClientCredentialsError",
     "RedirectException",
     "SynapseError",
     "ConfigError",
+    "InvalidRuleException",
+    "RuleNotFoundException",
 ]
diff --git a/synapse/rest/client/push_rule.py b/synapse/rest/client/push_rule.py
index a93f6fd5e0..b98640b14a 100644
--- a/synapse/rest/client/push_rule.py
+++ b/synapse/rest/client/push_rule.py
@@ -12,9 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING, List, Optional, Sequence, Tuple, Union
-
-import attr
+from typing import TYPE_CHECKING, List, Sequence, Tuple, Union
 
 from synapse.api.errors import (
     NotFoundError,
@@ -22,6 +20,7 @@ from synapse.api.errors import (
     SynapseError,
     UnrecognizedRequestError,
 )
+from synapse.handlers.push_rules import InvalidRuleException, RuleSpec, check_actions
 from synapse.http.server import HttpServer
 from synapse.http.servlet import (
     RestServlet,
@@ -29,7 +28,6 @@ from synapse.http.servlet import (
     parse_string,
 )
 from synapse.http.site import SynapseRequest
-from synapse.push.baserules import BASE_RULE_IDS
 from synapse.push.clientformat import format_push_rules_for_user
 from synapse.push.rulekinds import PRIORITY_CLASS_MAP
 from synapse.rest.client._base import client_patterns
@@ -40,14 +38,6 @@ if TYPE_CHECKING:
     from synapse.server import HomeServer
 
 
-@attr.s(slots=True, frozen=True, auto_attribs=True)
-class RuleSpec:
-    scope: str
-    template: str
-    rule_id: str
-    attr: Optional[str]
-
-
 class PushRuleRestServlet(RestServlet):
     PATTERNS = client_patterns("/(?P<path>pushrules/.*)$", v1=True)
     SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = (
@@ -60,6 +50,7 @@ class PushRuleRestServlet(RestServlet):
         self.store = hs.get_datastores().main
         self.notifier = hs.get_notifier()
         self._is_worker = hs.config.worker.worker_app is not None
+        self._push_rules_handler = hs.get_push_rules_handler()
 
     async def on_PUT(self, request: SynapseRequest, path: str) -> Tuple[int, JsonDict]:
         if self._is_worker:
@@ -81,8 +72,13 @@ class PushRuleRestServlet(RestServlet):
         user_id = requester.user.to_string()
 
         if spec.attr:
-            await self.set_rule_attr(user_id, spec, content)
-            self.notify_user(user_id)
+            try:
+                await self._push_rules_handler.set_rule_attr(user_id, spec, content)
+            except InvalidRuleException as e:
+                raise SynapseError(400, "Invalid actions: %s" % e)
+            except RuleNotFoundException:
+                raise NotFoundError("Unknown rule")
+
             return 200, {}
 
         if spec.rule_id.startswith("."):
@@ -98,23 +94,23 @@ class PushRuleRestServlet(RestServlet):
 
         before = parse_string(request, "before")
         if before:
-            before = _namespaced_rule_id(spec, before)
+            before = f"global/{spec.template}/{before}"
 
         after = parse_string(request, "after")
         if after:
-            after = _namespaced_rule_id(spec, after)
+            after = f"global/{spec.template}/{after}"
 
         try:
             await self.store.add_push_rule(
                 user_id=user_id,
-                rule_id=_namespaced_rule_id_from_spec(spec),
+                rule_id=f"global/{spec.template}/{spec.rule_id}",
                 priority_class=priority_class,
                 conditions=conditions,
                 actions=actions,
                 before=before,
                 after=after,
             )
-            self.notify_user(user_id)
+            self._push_rules_handler.notify_user(user_id)
         except InconsistentRuleException as e:
             raise SynapseError(400, str(e))
         except RuleNotFoundException as e:
@@ -133,11 +129,11 @@ class PushRuleRestServlet(RestServlet):
         requester = await self.auth.get_user_by_req(request)
         user_id = requester.user.to_string()
 
-        namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
+        namespaced_rule_id = f"global/{spec.template}/{spec.rule_id}"
 
         try:
             await self.store.delete_push_rule(user_id, namespaced_rule_id)
-            self.notify_user(user_id)
+            self._push_rules_handler.notify_user(user_id)
             return 200, {}
         except StoreError as e:
             if e.code == 404:
@@ -172,55 +168,6 @@ class PushRuleRestServlet(RestServlet):
         else:
             raise UnrecognizedRequestError()
 
-    def notify_user(self, user_id: str) -> None:
-        stream_id = self.store.get_max_push_rules_stream_id()
-        self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id])
-
-    async def set_rule_attr(
-        self, user_id: str, spec: RuleSpec, val: Union[bool, JsonDict]
-    ) -> None:
-        if spec.attr not in ("enabled", "actions"):
-            # for the sake of potential future expansion, shouldn't report
-            # 404 in the case of an unknown request so check it corresponds to
-            # a known attribute first.
-            raise UnrecognizedRequestError()
-
-        namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
-        rule_id = spec.rule_id
-        is_default_rule = rule_id.startswith(".")
-        if is_default_rule:
-            if namespaced_rule_id not in BASE_RULE_IDS:
-                raise NotFoundError("Unknown rule %s" % (namespaced_rule_id,))
-        if spec.attr == "enabled":
-            if isinstance(val, dict) and "enabled" in val:
-                val = val["enabled"]
-            if not isinstance(val, bool):
-                # Legacy fallback
-                # This should *actually* take a dict, but many clients pass
-                # bools directly, so let's not break them.
-                raise SynapseError(400, "Value for 'enabled' must be boolean")
-            await self.store.set_push_rule_enabled(
-                user_id, namespaced_rule_id, val, is_default_rule
-            )
-        elif spec.attr == "actions":
-            if not isinstance(val, dict):
-                raise SynapseError(400, "Value must be a dict")
-            actions = val.get("actions")
-            if not isinstance(actions, list):
-                raise SynapseError(400, "Value for 'actions' must be dict")
-            _check_actions(actions)
-            namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
-            rule_id = spec.rule_id
-            is_default_rule = rule_id.startswith(".")
-            if is_default_rule:
-                if namespaced_rule_id not in BASE_RULE_IDS:
-                    raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,))
-            await self.store.set_push_rule_actions(
-                user_id, namespaced_rule_id, actions, is_default_rule
-            )
-        else:
-            raise UnrecognizedRequestError()
-
 
 def _rule_spec_from_path(path: Sequence[str]) -> RuleSpec:
     """Turn a sequence of path components into a rule spec
@@ -291,24 +238,11 @@ def _rule_tuple_from_request_object(
         raise InvalidRuleException("No actions found")
     actions = req_obj["actions"]
 
-    _check_actions(actions)
+    check_actions(actions)
 
     return conditions, actions
 
 
-def _check_actions(actions: List[Union[str, JsonDict]]) -> None:
-    if not isinstance(actions, list):
-        raise InvalidRuleException("No actions found")
-
-    for a in actions:
-        if a in ["notify", "dont_notify", "coalesce"]:
-            pass
-        elif isinstance(a, dict) and "set_tweak" in a:
-            pass
-        else:
-            raise InvalidRuleException("Unrecognised action")
-
-
 def _filter_ruleset_with_path(ruleset: JsonDict, path: List[str]) -> JsonDict:
     if path == []:
         raise UnrecognizedRequestError(
@@ -357,17 +291,5 @@ def _priority_class_from_spec(spec: RuleSpec) -> int:
     return pc
 
 
-def _namespaced_rule_id_from_spec(spec: RuleSpec) -> str:
-    return _namespaced_rule_id(spec, spec.rule_id)
-
-
-def _namespaced_rule_id(spec: RuleSpec, rule_id: str) -> str:
-    return "global/%s/%s" % (spec.template, rule_id)
-
-
-class InvalidRuleException(Exception):
-    pass
-
-
 def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     PushRuleRestServlet(hs).register(http_server)
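
The removed `_namespaced_rule_id*` helpers encoded a single convention, now written inline as f-strings (scope is currently always "global"). Spelled out with illustrative values:

scope, template, rule_id = "global", "override", ".m.rule.suppress_notices"
namespaced_rule_id = f"{scope}/{template}/{rule_id}"
print(namespaced_rule_id)  # global/override/.m.rule.suppress_notices
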
diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index 70baf50fa4..13ef6b35a0 100644
--- a/synapse/rest/client/register.py
+++ b/synapse/rest/client/register.py
@@ -929,6 +929,10 @@ def _calculate_registration_flows(
         # always let users provide both MSISDN & email
         flows.append([LoginType.MSISDN, LoginType.EMAIL_IDENTITY])
 
+    # Add a flow that doesn't require any 3pids, if the config requests it.
+    if config.registration.enable_registration_token_3pid_bypass:
+        flows.append([LoginType.REGISTRATION_TOKEN])
+
     # Prepend m.login.terms to all flows if we're requiring consent
     if config.consent.user_consent_at_registration:
         for flow in flows:
@@ -942,7 +946,8 @@ def _calculate_registration_flows(
     # Prepend registration token to all flows if we're requiring a token
     if config.registration.registration_requires_token:
         for flow in flows:
-            flow.insert(0, LoginType.REGISTRATION_TOKEN)
+            if LoginType.REGISTRATION_TOKEN not in flow:
+                flow.insert(0, LoginType.REGISTRATION_TOKEN)
 
     return flows
 
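A toy reconstruction (not the real function) of how the flow list ends up when an email 3pid is required, registration tokens are required, and the new bypass is enabled; the login type strings stand in for the real `LoginType` constants:

EMAIL = "m.login.email.identity"
TOKEN = "m.login.registration_token"

flows = [[EMAIL]]      # from registrations_require_3pid
flows.append([TOKEN])  # the new token-only bypass flow

# Prepend the token to every flow that doesn't already contain it.
for flow in flows:
    if TOKEN not in flow:
        flow.insert(0, TOKEN)

print(flows)
# [['m.login.registration_token', 'm.login.email.identity'],
#  ['m.login.registration_token']]
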
diff --git a/synapse/server.py b/synapse/server.py
index 37c72bd83a..d49c76518a 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -91,6 +91,7 @@ from synapse.handlers.presence import (
     WorkerPresenceHandler,
 )
 from synapse.handlers.profile import ProfileHandler
+from synapse.handlers.push_rules import PushRulesHandler
 from synapse.handlers.read_marker import ReadMarkerHandler
 from synapse.handlers.receipts import ReceiptsHandler
 from synapse.handlers.register import RegistrationHandler
@@ -811,6 +812,10 @@ class HomeServer(metaclass=abc.ABCMeta):
         return AccountHandler(self)
 
     @cache_in_self
+    def get_push_rules_handler(self) -> PushRulesHandler:
+        return PushRulesHandler(self)
+
+    @cache_in_self
     def get_outbound_redis_connection(self) -> "ConnectionHandler":
         """
         The Redis connection used for replication.
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 951031af50..5895b89202 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -15,12 +15,17 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple, cast
 
 from synapse.config.homeserver import HomeServerConfig
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.databases.main.stats import UserSortOrder
-from synapse.storage.engines import PostgresEngine
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
+from synapse.storage.types import Cursor
 from synapse.storage.util.id_generators import (
     IdGenerator,
     MultiWriterIdGenerator,
@@ -266,7 +271,9 @@ class DataStore(
             A tuple of a list of mappings from user to information and a count of total users.
         """
 
-        def get_users_paginate_txn(txn):
+        def get_users_paginate_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[JsonDict], int]:
             filters = []
             args = [self.hs.config.server.server_name]
 
@@ -301,7 +308,7 @@ class DataStore(
                 """
             sql = "SELECT COUNT(*) as total_users " + sql_base
             txn.execute(sql, args)
-            count = txn.fetchone()[0]
+            count = cast(Tuple[int], txn.fetchone())[0]
 
             sql = f"""
                 SELECT name, user_type, is_guest, admin, deactivated, shadow_banned,
@@ -338,7 +345,9 @@ class DataStore(
         )
 
 
-def check_database_before_upgrade(cur, database_engine, config: HomeServerConfig):
+def check_database_before_upgrade(
+    cur: Cursor, database_engine: BaseDatabaseEngine, config: HomeServerConfig
+) -> None:
     """Called before upgrading an existing database to check that it is broadly sane
     compared with the configuration.
     """
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index fa732edcca..945707b0ec 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import logging
 import re
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Pattern, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Pattern, Tuple, cast
 
 from synapse.appservice import (
     ApplicationService,
@@ -83,7 +83,7 @@ class ApplicationServiceWorkerStore(RoomMemberWorkerStore):
             txn.execute(
                 "SELECT COALESCE(max(txn_id), 0) FROM application_services_txns"
             )
-            return txn.fetchone()[0]  # type: ignore
+            return cast(Tuple[int], txn.fetchone())[0]
 
         self._as_txn_seq_gen = build_sequence_generator(
             db_conn,
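
The `cast` replaces a blanket `# type: ignore`: `fetchone()` is typed as returning an optional row, but an aggregate query always produces exactly one row, so asserting the shape is sound. The pattern in isolation (a sketch):

from typing import Tuple, cast


def max_txn_id(txn) -> int:
    txn.execute("SELECT COALESCE(max(txn_id), 0) FROM application_services_txns")
    # An aggregate always yields one row, so fetchone() cannot be None here.
    return cast(Tuple[int], txn.fetchone())[0]
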
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index b4a1b041b1..599b418383 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -14,7 +14,17 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple, cast
+from typing import (
+    TYPE_CHECKING,
+    Collection,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    cast,
+)
 
 from synapse.logging import issue9533_logger
 from synapse.logging.opentracing import log_kv, set_tag, trace
@@ -118,7 +128,13 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             prefilled_cache=device_outbox_prefill,
         )
 
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
+    def process_replication_rows(
+        self,
+        stream_name: str,
+        instance_name: str,
+        token: int,
+        rows: Iterable[ToDeviceStream.ToDeviceStreamRow],
+    ) -> None:
         if stream_name == ToDeviceStream.NAME:
            # If replication is happening then postgres must be in use.
             assert isinstance(self._device_inbox_id_gen, MultiWriterIdGenerator)
@@ -134,7 +150,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                     )
         return super().process_replication_rows(stream_name, instance_name, token, rows)
 
-    def get_to_device_stream_token(self):
+    def get_to_device_stream_token(self) -> int:
         return self._device_inbox_id_gen.get_current_token()
 
     async def get_messages_for_user_devices(
@@ -301,7 +317,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
         if not user_ids_to_query:
             return {}, to_stream_id
 
-        def get_device_messages_txn(txn: LoggingTransaction):
+        def get_device_messages_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]:
             # Build a query to select messages from any of the given devices that
             # are between the given stream id bounds.
 
@@ -428,7 +446,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                 log_kv({"message": "No changes in cache since last check"})
                 return 0
 
-        def delete_messages_for_device_txn(txn):
+        def delete_messages_for_device_txn(txn: LoggingTransaction) -> int:
             sql = (
                 "DELETE FROM device_inbox"
                 " WHERE user_id = ? AND device_id = ?"
@@ -455,15 +473,14 @@ class DeviceInboxWorkerStore(SQLBaseStore):
 
     @trace
     async def get_new_device_msgs_for_remote(
-        self, destination, last_stream_id, current_stream_id, limit
-    ) -> Tuple[List[dict], int]:
+        self, destination: str, last_stream_id: int, current_stream_id: int, limit: int
+    ) -> Tuple[List[JsonDict], int]:
         """
         Args:
-            destination(str): The name of the remote server.
-            last_stream_id(int|long): The last position of the device message stream
+            destination: The name of the remote server.
+            last_stream_id: The last position of the device message stream
                 that the server sent up to.
-            current_stream_id(int|long): The current position of the device
-                message stream.
+            current_stream_id: The current position of the device message stream.
         Returns:
             A list of messages for the device and where in the stream the messages got to.
         """
@@ -485,7 +502,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             return [], last_stream_id
 
         @trace
-        def get_new_messages_for_remote_destination_txn(txn):
+        def get_new_messages_for_remote_destination_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[JsonDict], int]:
             sql = (
                 "SELECT stream_id, messages_json FROM device_federation_outbox"
                 " WHERE destination = ?"
@@ -527,7 +546,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             up_to_stream_id: Where to delete messages up to.
         """
 
-        def delete_messages_for_remote_destination_txn(txn):
+        def delete_messages_for_remote_destination_txn(txn: LoggingTransaction) -> None:
             sql = (
                 "DELETE FROM device_federation_outbox"
                 " WHERE destination = ?"
@@ -566,7 +585,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
         if last_id == current_id:
             return [], current_id, False
 
-        def get_all_new_device_messages_txn(txn):
+        def get_all_new_device_messages_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
             # We limit like this as we might have multiple rows per stream_id, and
             # we want to make sure we always get all entries for any stream_id
             # we return.
@@ -607,8 +628,8 @@ class DeviceInboxWorkerStore(SQLBaseStore):
     @trace
     async def add_messages_to_device_inbox(
         self,
-        local_messages_by_user_then_device: dict,
-        remote_messages_by_destination: dict,
+        local_messages_by_user_then_device: Dict[str, Dict[str, JsonDict]],
+        remote_messages_by_destination: Dict[str, JsonDict],
     ) -> int:
         """Used to send messages from this server.
 
@@ -624,7 +645,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
 
         assert self._can_write_to_device
 
-        def add_messages_txn(txn, now_ms, stream_id):
+        def add_messages_txn(
+            txn: LoggingTransaction, now_ms: int, stream_id: int
+        ) -> None:
             # Add the local messages directly to the local inbox.
             self._add_messages_to_local_device_inbox_txn(
                 txn, stream_id, local_messages_by_user_then_device
@@ -677,11 +700,16 @@ class DeviceInboxWorkerStore(SQLBaseStore):
         return self._device_inbox_id_gen.get_current_token()
 
     async def add_messages_from_remote_to_device_inbox(
-        self, origin: str, message_id: str, local_messages_by_user_then_device: dict
+        self,
+        origin: str,
+        message_id: str,
+        local_messages_by_user_then_device: Dict[str, Dict[str, JsonDict]],
     ) -> int:
         assert self._can_write_to_device
 
-        def add_messages_txn(txn, now_ms, stream_id):
+        def add_messages_txn(
+            txn: LoggingTransaction, now_ms: int, stream_id: int
+        ) -> None:
             # Check if we've already inserted a matching message_id for that
             # origin. This can happen if the origin doesn't receive our
             # acknowledgement from the first time we received the message.
@@ -727,8 +755,11 @@ class DeviceInboxWorkerStore(SQLBaseStore):
         return stream_id
 
     def _add_messages_to_local_device_inbox_txn(
-        self, txn, stream_id, messages_by_user_then_device
-    ):
+        self,
+        txn: LoggingTransaction,
+        stream_id: int,
+        messages_by_user_then_device: Dict[str, Dict[str, JsonDict]],
+    ) -> None:
         assert self._can_write_to_device
 
         local_by_user_then_device = {}
@@ -840,8 +871,10 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
             self._remove_dead_devices_from_device_inbox,
         )
 
-    async def _background_drop_index_device_inbox(self, progress, batch_size):
-        def reindex_txn(conn):
+    async def _background_drop_index_device_inbox(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
+        def reindex_txn(conn: LoggingDatabaseConnection) -> None:
             txn = conn.cursor()
             txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
             txn.close()
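
Most hunks in this file (and in devices.py below) only add type annotations to nested transaction callbacks. The recurring shape they pin down, in a simplified sketch with a hypothetical pool stand-in:

async def delete_messages_for_device(db_pool, user_id: str, device_id: str) -> int:
    def delete_txn(txn) -> int:
        txn.execute(
            "DELETE FROM device_inbox WHERE user_id = ? AND device_id = ?",
            (user_id, device_id),
        )
        return txn.rowcount

    # runInteraction runs the callback inside a single transaction and
    # returns its result.
    return await db_pool.runInteraction("delete_messages_for_device", delete_txn)
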
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 483dd80406..2df4dd4ed4 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -25,6 +25,7 @@ from typing import (
     Optional,
     Set,
     Tuple,
+    cast,
 )
 
 from synapse.api.errors import Codes, StoreError
@@ -136,7 +137,9 @@ class DeviceWorkerStore(SQLBaseStore):
             Number of devices of this users.
         """
 
-        def count_devices_by_users_txn(txn, user_ids):
+        def count_devices_by_users_txn(
+            txn: LoggingTransaction, user_ids: List[str]
+        ) -> int:
             sql = """
                 SELECT count(*)
                 FROM devices
@@ -149,7 +152,7 @@ class DeviceWorkerStore(SQLBaseStore):
             )
 
             txn.execute(sql + clause, args)
-            return txn.fetchone()[0]
+            return cast(Tuple[int], txn.fetchone())[0]
 
         if not user_ids:
             return 0
@@ -468,7 +471,7 @@ class DeviceWorkerStore(SQLBaseStore):
         """
         txn.execute(sql, (destination, from_stream_id, now_stream_id, limit))
 
-        return list(txn)
+        return cast(List[Tuple[str, str, int, Optional[str]]], txn.fetchall())
 
     async def _get_device_update_edus_by_remote(
         self,
@@ -549,7 +552,7 @@ class DeviceWorkerStore(SQLBaseStore):
     async def _get_last_device_update_for_remote_user(
         self, destination: str, user_id: str, from_stream_id: int
     ) -> int:
-        def f(txn):
+        def f(txn: LoggingTransaction) -> int:
             prev_sent_id_sql = """
                 SELECT coalesce(max(stream_id), 0) as stream_id
                 FROM device_lists_outbound_last_success
@@ -767,7 +770,7 @@ class DeviceWorkerStore(SQLBaseStore):
         if not user_ids_to_check:
             return set()
 
-        def _get_users_whose_devices_changed_txn(txn):
+        def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]:
             changes = set()
 
             stream_id_where_clause = "stream_id > ?"
@@ -966,7 +969,9 @@ class DeviceWorkerStore(SQLBaseStore):
     async def mark_remote_user_device_list_as_unsubscribed(self, user_id: str) -> None:
         """Mark that we no longer track device lists for remote user."""
 
-        def _mark_remote_user_device_list_as_unsubscribed_txn(txn):
+        def _mark_remote_user_device_list_as_unsubscribed_txn(
+            txn: LoggingTransaction,
+        ) -> None:
             self.db_pool.simple_delete_txn(
                 txn,
                 table="device_lists_remote_extremeties",
@@ -1004,7 +1009,7 @@ class DeviceWorkerStore(SQLBaseStore):
         )
 
     def _store_dehydrated_device_txn(
-        self, txn, user_id: str, device_id: str, device_data: str
+        self, txn: LoggingTransaction, user_id: str, device_id: str, device_data: str
     ) -> Optional[str]:
         old_device_id = self.db_pool.simple_select_one_onecol_txn(
             txn,
@@ -1081,7 +1086,7 @@ class DeviceWorkerStore(SQLBaseStore):
         """
         yesterday = self._clock.time_msec() - prune_age
 
-        def _prune_txn(txn):
+        def _prune_txn(txn: LoggingTransaction) -> None:
             # look for (user, destination) pairs which have an update older than
             # the cutoff.
             #
@@ -1204,8 +1209,10 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
             "drop_device_lists_outbound_last_success_non_unique_idx",
         )
 
-    async def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
-        def f(conn):
+    async def _drop_device_list_streams_non_unique_indexes(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
+        def f(conn: LoggingDatabaseConnection) -> None:
             txn = conn.cursor()
             txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id")
             txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id")
@@ -1217,7 +1224,9 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
         )
         return 1
 
-    async def _remove_duplicate_outbound_pokes(self, progress, batch_size):
+    async def _remove_duplicate_outbound_pokes(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         # for some reason, we have accumulated duplicate entries in
         # device_lists_outbound_pokes, which makes prune_outbound_device_list_pokes less
         # efficient.
@@ -1230,7 +1239,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
             {"stream_id": 0, "destination": "", "user_id": "", "device_id": ""},
         )
 
-        def _txn(txn):
+        def _txn(txn: LoggingTransaction) -> int:
             clause, args = make_tuple_comparison_clause(
                 [(x, last_row[x]) for x in KEY_COLS]
             )
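
For context, `make_tuple_comparison_clause` takes (column, value) pairs and emits a SQL row comparison, letting the background update page through the table in batches and resume from the last row it processed (keyset pagination). Roughly what the generated clause amounts to, sketched with the `KEY_COLS` and initial progress row used above:

```python
from typing import Any, Dict, List, Tuple

KEY_COLS = ["stream_id", "destination", "user_id", "device_id"]

def tuple_comparison(last_row: Dict[str, Any]) -> Tuple[str, List[Any]]:
    # e.g. "(stream_id, destination, user_id, device_id) > (?, ?, ?, ?)"
    # with args [0, "", "", ""] on the first batch.
    clause = "(%s) > (%s)" % (
        ", ".join(KEY_COLS),
        ", ".join("?" for _ in KEY_COLS),
    )
    return clause, [last_row[col] for col in KEY_COLS]
```
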
@@ -1602,7 +1611,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
 
         context = get_active_span_text_map()
 
-        def add_device_changes_txn(txn, stream_ids):
+        def add_device_changes_txn(
+            txn: LoggingTransaction, stream_ids: List[int]
+        ) -> None:
             self._add_device_change_to_stream_txn(
                 txn,
                 user_id,
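
`add_device_changes_txn` and its sibling below now declare `stream_ids: List[int]`, matching what the stream ID generators allocate. A hedged sketch of the call-site shape, assuming the `get_next_mult` async context manager that Synapse's ID generators provide (it allocates a batch of IDs and marks them persisted on exit):

```python
# Inside the store class; one stream ID per changed device.
async with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids:
    await self.db_pool.runInteraction(
        "add_device_change_to_streams", add_device_changes_txn, stream_ids
    )
```
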
@@ -1635,8 +1646,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         txn: LoggingTransaction,
         user_id: str,
         device_ids: Collection[str],
-        stream_ids: List[str],
-    ):
+        stream_ids: List[int],
+    ) -> None:
         txn.call_after(
             self._device_list_stream_cache.entity_has_changed,
             user_id,
@@ -1720,7 +1731,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         user_id: str,
         device_ids: Iterable[str],
         room_ids: Collection[str],
-        stream_ids: List[str],
+        stream_ids: List[int],
         context: Dict[str, str],
     ) -> None:
         """Record the user in the room has updated their device."""
@@ -1775,7 +1786,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             LIMIT ?
         """
 
-        def get_uncoverted_outbound_room_pokes_txn(txn):
+        def get_uncoverted_outbound_room_pokes_txn(
+            txn: LoggingTransaction,
+        ) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]:
             txn.execute(sql, (limit,))
 
             return [
@@ -1808,7 +1821,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         Marks the associated row in `device_lists_changes_in_room` as handled.
         """
 
-        def add_device_list_outbound_pokes_txn(txn, stream_ids: List[int]):
+        def add_device_list_outbound_pokes_txn(
+            txn: LoggingTransaction, stream_ids: List[int]
+        ) -> None:
             if hosts:
                 self._add_device_outbound_poke_to_stream_txn(
                     txn,
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 2a1e567ce0..9a6c2fd47a 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -47,6 +47,7 @@ from synapse.storage.database import (
 )
 from synapse.storage.databases.main.events_worker import EventCacheEntry
 from synapse.storage.databases.main.search import SearchEntry
+from synapse.storage.engines.postgres import PostgresEngine
 from synapse.storage.util.id_generators import AbstractStreamIdGenerator
 from synapse.storage.util.sequence import SequenceGenerator
 from synapse.types import StateMap, get_domain_from_id
@@ -364,6 +365,20 @@ class PersistEventsStore:
         min_stream_order = events_and_contexts[0][0].internal_metadata.stream_ordering
         max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
 
+        # We check that the room still exists for events we're trying to
+        # persist. This is to protect against races with deleting a room.
+        #
+        # Annoyingly, SQLite doesn't support row-level locking.
+        if isinstance(self.database_engine, PostgresEngine):
+            for room_id in {e.room_id for e, _ in events_and_contexts}:
+                txn.execute(
+                    "SELECT room_version FROM rooms WHERE room_id = ? FOR SHARE",
+                    (room_id,),
+                )
+                row = txn.fetchone()
+                if row is None:
+                    raise Exception(f"Room does not exist {room_id}")
+
         # stream orderings should have been assigned by now
         assert min_stream_order
         assert max_stream_order
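
The `FOR SHARE` here takes a shared row-level lock on the room's row, which is what makes the existence check race-free against `_purge_room_txn` (changed below in `purge_events.py`): the purge's `DELETE FROM rooms` needs an exclusive lock on the same row. On SQLite the check is skipped because write transactions already serialize at the database level. A sketch of the two sides, assuming they run in concurrent PostgreSQL transactions:

```python
# Transaction A -- persisting events:
txn.execute(
    "SELECT room_version FROM rooms WHERE room_id = ? FOR SHARE",
    (room_id,),
)
if txn.fetchone() is None:
    # The purge committed first; refuse to persist into a deleted room.
    raise Exception(f"Room {room_id} does not exist")

# Transaction B -- purging the room:
# blocks until A commits, since DELETE needs an exclusive lock on the row
# that A holds FOR SHARE; if B commits first, A sees no row and aborts.
txn.execute("DELETE FROM rooms WHERE room_id = ?", (room_id,))
```
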
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index 0aef121d83..04efad9e9a 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -522,7 +522,9 @@ class GroupServerWorkerStore(SQLBaseStore):
             desc="get_joined_groups",
         )
 
-    async def get_all_groups_for_user(self, user_id, now_token) -> List[JsonDict]:
+    async def get_all_groups_for_user(
+        self, user_id: str, now_token: int
+    ) -> List[JsonDict]:
         def _get_all_groups_for_user_txn(txn: LoggingTransaction) -> List[JsonDict]:
             sql = """
                 SELECT group_id, type, membership, u.content
diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index 6990f3ed1d..0a19f607bd 100644
--- a/synapse/storage/databases/main/keys.py
+++ b/synapse/storage/databases/main/keys.py
@@ -15,11 +15,12 @@
 
 import itertools
 import logging
-from typing import Dict, Iterable, List, Optional, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Tuple
 
 from signedjson.key import decode_verify_key_bytes
 
 from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import LoggingTransaction
 from synapse.storage.keys import FetchKeyResult
 from synapse.storage.types import Cursor
 from synapse.util.caches.descriptors import cached, cachedList
@@ -35,7 +36,9 @@ class KeyStore(SQLBaseStore):
     """Persistence for signature verification keys"""
 
     @cached()
-    def _get_server_verify_key(self, server_name_and_key_id):
+    def _get_server_verify_key(
+        self, server_name_and_key_id: Tuple[str, str]
+    ) -> FetchKeyResult:
         raise NotImplementedError()
 
     @cachedList(
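
`_get_server_verify_key` only ever raises: it exists so the `@cachedList` method that follows has a per-key cache to read from and populate, which is why it can now carry a return annotation for a value it never produces directly. The general shape of the pairing, as an illustrative sketch rather than this store's real queries:

```python
from typing import Dict, Iterable

from synapse.util.caches.descriptors import cached, cachedList

class ExampleStore:
    @cached()
    def _get_item(self, key: str) -> str:
        # Never called directly; @cachedList fills and reads this cache.
        raise NotImplementedError()

    @cachedList(cached_method_name="_get_item", list_name="keys")
    async def get_items(self, keys: Iterable[str]) -> Dict[str, str]:
        # One batched lookup; results are cached per key under _get_item.
        return {key: key.upper() for key in keys}
```
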
@@ -179,19 +182,21 @@ class KeyStore(SQLBaseStore):
 
     async def get_server_keys_json(
         self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]]
-    ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[dict]]:
+    ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]:
         """Retrieve the key json for a list of server_keys and key ids.
-        If no keys are found for a given server, key_id and source then
-        that server, key_id, and source triplet entry will be an empty list.
+        If no keys are found for a given (server, key_id, source) triplet,
+        then its entry in the result will be an empty list.
         The JSON is returned as a byte array so that it can be efficiently
         used in an HTTP response.
         Args:
-            server_keys (list): List of (server_name, key_id, source) triplets.
+            server_keys: List of (server_name, key_id, source) triplets.
         Returns:
             A mapping from (server_name, key_id, source) triplets to a list of dicts
         """
 
-        def _get_server_keys_json_txn(txn):
+        def _get_server_keys_json_txn(
+            txn: LoggingTransaction,
+        ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]:
             results = {}
             for server_name, key_id, from_server in server_keys:
                 keyvalues = {"server_name": server_name}
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 322ed05390..40ac377ca9 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -388,7 +388,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         return await self.db_pool.runInteraction("get_url_cache", get_url_cache_txn)
 
     async def store_url_cache(
-        self, url, response_code, etag, expires_ts, og, media_id, download_ts
+        self,
+        url: str,
+        response_code: int,
+        etag: Optional[str],
+        expires_ts: int,
+        og: Optional[str],
+        media_id: str,
+        download_ts: int,
     ) -> None:
         await self.db_pool.simple_insert(
             "local_media_repository_url_cache",
@@ -441,7 +448,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         )
 
     async def get_cached_remote_media(
-        self, origin, media_id: str
+        self, origin: str, media_id: str
     ) -> Optional[Dict[str, Any]]:
         return await self.db_pool.simple_select_one(
             "remote_media_cache",
@@ -608,7 +615,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         )
 
     async def delete_remote_media(self, media_origin: str, media_id: str) -> None:
-        def delete_remote_media_txn(txn):
+        def delete_remote_media_txn(txn: LoggingTransaction) -> None:
             self.db_pool.simple_delete_txn(
                 txn,
                 "remote_media_cache",
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index 4f1c22c71b..5beb8f1d4b 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -232,10 +232,10 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore):
             # is racy.
             # Have resolved to invalidate the whole cache for now and do
             # something about it if and when the perf becomes significant
-            self._invalidate_all_cache_and_stream(  # type: ignore[attr-defined]
+            self._invalidate_all_cache_and_stream(
                 txn, self.user_last_seen_monthly_active
             )
-            self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ())  # type: ignore[attr-defined]
+            self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ())
 
         reserved_users = await self.get_registered_reserved_users()
         await self.db_pool.runInteraction(
@@ -363,7 +363,7 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore):
 
         if self._limit_usage_by_mau or self._mau_stats_only:
             # Trial users and guests should not be included as part of MAU group
-            is_guest = await self.is_guest(user_id)  # type: ignore[attr-defined]
+            is_guest = await self.is_guest(user_id)
             if is_guest:
                 return
             is_trial = await self.is_trial_user(user_id)
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index d3c4611686..b47c511450 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple, cast
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Tuple, cast
 
 from synapse.api.presence import PresenceState, UserPresenceState
 from synapse.replication.tcp.streams import PresenceStream
@@ -103,7 +103,9 @@ class PresenceStore(PresenceBackgroundUpdateStore):
             prefilled_cache=presence_cache_prefill,
         )
 
-    async def update_presence(self, presence_states) -> Tuple[int, int]:
+    async def update_presence(
+        self, presence_states: List[UserPresenceState]
+    ) -> Tuple[int, int]:
         assert self._can_persist_presence
 
         stream_ordering_manager = self._presence_id_gen.get_next_mult(
@@ -121,7 +123,10 @@ class PresenceStore(PresenceBackgroundUpdateStore):
         return stream_orderings[-1], self._presence_id_gen.get_current_token()
 
     def _update_presence_txn(
-        self, txn: LoggingTransaction, stream_orderings, presence_states
+        self,
+        txn: LoggingTransaction,
+        stream_orderings: List[int],
+        presence_states: List[UserPresenceState],
     ) -> None:
         for stream_id, state in zip(stream_orderings, presence_states):
             txn.call_after(
@@ -405,7 +410,13 @@ class PresenceStore(PresenceBackgroundUpdateStore):
         self._presence_on_startup = []
         return active_on_startup
 
-    def process_replication_rows(self, stream_name, instance_name, token, rows) -> None:
+    def process_replication_rows(
+        self,
+        stream_name: str,
+        instance_name: str,
+        token: int,
+        rows: Iterable[Any],
+    ) -> None:
         if stream_name == PresenceStream.NAME:
             self._presence_id_gen.advance(instance_name, token)
             for row in rows:
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 2e3818e432..bfc85b3add 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -324,7 +324,12 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
         )
 
     def _purge_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[int]:
-        # First we fetch all the state groups that should be deleted, before
+        # We *immediately* delete the room from the rooms table. This ensures
+        # that we don't race when persisting events (as that transaction checks
+        # that the room exists).
+        txn.execute("DELETE FROM rooms WHERE room_id = ?", (room_id,))
+
+        # Next, we fetch all the state groups that should be deleted, before
         # we delete that information.
         txn.execute(
             """
@@ -403,7 +408,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
             "room_stats_state",
             "room_stats_current",
             "room_stats_earliest_token",
-            "rooms",
             "stream_ordering_to_exterm",
             "users_in_public_rooms",
             "users_who_share_private_rooms",
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 92539f5d41..eb85bbd392 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -16,7 +16,7 @@ import abc
 import logging
 from typing import TYPE_CHECKING, Dict, List, Tuple, Union
 
-from synapse.api.errors import NotFoundError, StoreError
+from synapse.api.errors import StoreError
 from synapse.push.baserules import list_with_base_rules
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.storage._base import SQLBaseStore, db_to_json
@@ -618,7 +618,7 @@ class PushRuleStore(PushRulesWorkerStore):
                 are always stored in the database `push_rules` table).
 
         Raises:
-            NotFoundError if the rule does not exist.
+            RuleNotFoundException if the rule does not exist.
         """
         async with self._push_rules_stream_id_gen.get_next() as stream_id:
             event_stream_ordering = self._stream_id_gen.get_current_token()
@@ -668,8 +668,7 @@ class PushRuleStore(PushRulesWorkerStore):
             )
             txn.execute(sql, (user_id, rule_id))
             if txn.fetchone() is None:
-                # needed to set NOT_FOUND code.
-                raise NotFoundError("Push rule does not exist.")
+                raise RuleNotFoundException("Push rule does not exist.")
 
         self.db_pool.simple_upsert_txn(
             txn,
@@ -698,9 +697,6 @@ class PushRuleStore(PushRulesWorkerStore):
         """
         Sets the `actions` state of a push rule.
 
-        Will throw NotFoundError if the rule does not exist; the Code for this
-        is NOT_FOUND.
-
         Args:
-            user_id: the user ID of the user who wishes to enable/disable the rule
+            user_id: the user ID of the user whose push rule is being changed,
                 e.g. '@tina:example.org'
@@ -712,6 +708,9 @@ class PushRuleStore(PushRulesWorkerStore):
             is_default_rule: True if and only if this is a server-default rule.
                 This skips the check for existence (as only user-created rules
                 are always stored in the database `push_rules` table).
+
+        Raises:
+            RuleNotFoundException if the rule does not exist.
         """
         actions_json = json_encoder.encode(actions)
 
@@ -744,7 +743,6 @@
                 except StoreError as serr:
                     if serr.code == 404:
-                        # this sets the NOT_FOUND error Code
-                        raise NotFoundError("Push rule does not exist")
+                        raise RuleNotFoundException("Push rule does not exist")
                     else:
                         raise
 
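With the storage layer raising `RuleNotFoundException` (defined elsewhere in this changeset) rather than `NotFoundError`, producing the 404 becomes the caller's job. A hedged sketch of what a REST-layer caller might look like; the import path and wrapper function are illustrative:

```python
from synapse.api.errors import NotFoundError

# Import path assumed from the rest of this changeset:
from synapse.storage.databases.main.push_rule import RuleNotFoundException

async def set_actions_or_404(store, user_id: str, rule_id: str, actions) -> None:
    # Translate the storage-level exception into an HTTP 404 at the REST
    # boundary, where the status code actually matters.
    try:
        await store.set_push_rule_actions(user_id, rule_id, actions, False)
    except RuleNotFoundException:
        raise NotFoundError("Unknown rule")
```
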
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index cf64cd63a4..91286c9b65 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -14,11 +14,25 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Iterable,
+    Iterator,
+    List,
+    Optional,
+    Tuple,
+    cast,
+)
 
 from synapse.push import PusherConfig, ThrottleParams
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.util.id_generators import StreamIdGenerator
 from synapse.types import JsonDict
 from synapse.util import json_encoder
@@ -117,7 +131,7 @@ class PusherWorkerStore(SQLBaseStore):
         return self._decode_pushers_rows(ret)
 
     async def get_all_pushers(self) -> Iterator[PusherConfig]:
-        def get_pushers(txn):
+        def get_pushers(txn: LoggingTransaction) -> Iterator[PusherConfig]:
             txn.execute("SELECT * FROM pushers")
             rows = self.db_pool.cursor_to_dict(txn)
 
@@ -152,7 +166,9 @@ class PusherWorkerStore(SQLBaseStore):
         if last_id == current_id:
             return [], current_id, False
 
-        def get_all_updated_pushers_rows_txn(txn):
+        def get_all_updated_pushers_rows_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
             sql = """
                 SELECT id, user_name, app_id, pushkey
                 FROM pushers
@@ -160,10 +176,13 @@ class PusherWorkerStore(SQLBaseStore):
                 ORDER BY id ASC LIMIT ?
             """
             txn.execute(sql, (last_id, current_id, limit))
-            updates = [
-                (stream_id, (user_name, app_id, pushkey, False))
-                for stream_id, user_name, app_id, pushkey in txn
-            ]
+            updates = cast(
+                List[Tuple[int, tuple]],
+                [
+                    (stream_id, (user_name, app_id, pushkey, False))
+                    for stream_id, user_name, app_id, pushkey in txn
+                ],
+            )
 
             sql = """
                 SELECT stream_id, user_id, app_id, pushkey
@@ -192,12 +211,12 @@ class PusherWorkerStore(SQLBaseStore):
         )
 
     @cached(num_args=1, max_entries=15000)
-    async def get_if_user_has_pusher(self, user_id: str):
+    async def get_if_user_has_pusher(self, user_id: str) -> None:
         # This only exists for the cachedList decorator
         raise NotImplementedError()
 
     async def update_pusher_last_stream_ordering(
-        self, app_id, pushkey, user_id, last_stream_ordering
+        self, app_id: str, pushkey: str, user_id: str, last_stream_ordering: int
     ) -> None:
         await self.db_pool.simple_update_one(
             "pushers",
@@ -291,7 +310,7 @@ class PusherWorkerStore(SQLBaseStore):
 
         last_user = progress.get("last_user", "")
 
-        def _delete_pushers(txn) -> int:
+        def _delete_pushers(txn: LoggingTransaction) -> int:
 
             sql = """
                 SELECT name FROM users
@@ -339,7 +358,7 @@ class PusherWorkerStore(SQLBaseStore):
 
         last_pusher = progress.get("last_pusher", 0)
 
-        def _delete_pushers(txn) -> int:
+        def _delete_pushers(txn: LoggingTransaction) -> int:
 
             sql = """
                 SELECT p.id, access_token FROM pushers AS p
@@ -396,7 +415,7 @@ class PusherWorkerStore(SQLBaseStore):
 
         last_pusher = progress.get("last_pusher", 0)
 
-        def _delete_pushers(txn) -> int:
+        def _delete_pushers(txn: LoggingTransaction) -> int:
 
             sql = """
                 SELECT p.id, p.user_name, p.app_id, p.pushkey
@@ -502,7 +521,7 @@ class PusherStore(PusherWorkerStore):
     async def delete_pusher_by_app_id_pushkey_user_id(
         self, app_id: str, pushkey: str, user_id: str
     ) -> None:
-        def delete_pusher_txn(txn, stream_id):
+        def delete_pusher_txn(txn: LoggingTransaction, stream_id: int) -> None:
             self._invalidate_cache_and_stream(  # type: ignore[attr-defined]
                 txn, self.get_if_user_has_pusher, (user_id,)
             )
@@ -547,7 +566,7 @@ class PusherStore(PusherWorkerStore):
         # account.
         pushers = list(await self.get_pushers_by_user_id(user_id))
 
-        def delete_pushers_txn(txn, stream_ids):
+        def delete_pushers_txn(txn: LoggingTransaction, stream_ids: List[int]) -> None:
             self._invalidate_cache_and_stream(  # type: ignore[attr-defined]
                 txn, self.get_if_user_has_pusher, (user_id,)
             )
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 332e901dda..7d96f4feda 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -122,10 +122,21 @@ class ReceiptsWorkerStore(SQLBaseStore):
         receipts = await self.get_receipts_for_room(room_id, ReceiptTypes.READ)
         return {r["user_id"] for r in receipts}
 
-    @cached(num_args=2)
+    @cached()
     async def get_receipts_for_room(
         self, room_id: str, receipt_type: str
     ) -> List[Dict[str, Any]]:
+        """
+        Fetch the event IDs of each user's latest receipt in a room, for the given receipt type.
+
+        Args:
+            room_id: The room ID to fetch the receipt for.
+            receipt_type: The receipt type to fetch.
+
+        Returns:
+            A list of dictionaries, one for each user ID. Each dictionary
+            contains a user ID and the event ID of that user's latest receipt.
+        """
         return await self.db_pool.simple_select_list(
             table="receipts_linearized",
             keyvalues={"room_id": room_id, "receipt_type": receipt_type},
@@ -133,10 +144,21 @@ class ReceiptsWorkerStore(SQLBaseStore):
             desc="get_receipts_for_room",
         )
 
-    @cached(num_args=3)
+    @cached()
     async def get_last_receipt_event_id_for_user(
         self, user_id: str, room_id: str, receipt_type: str
     ) -> Optional[str]:
+        """
+        Fetch the event ID of the given user's latest receipt in a room, for the given receipt type.
+
+        Args:
+            user_id: The user to fetch receipts for.
+            room_id: The room ID to fetch the receipt for.
+            receipt_type: The receipt type to fetch.
+
+        Returns:
+            The event ID of the latest receipt, if one exists; otherwise `None`.
+        """
         return await self.db_pool.simple_select_one_onecol(
             table="receipts_linearized",
             keyvalues={
@@ -149,10 +171,23 @@ class ReceiptsWorkerStore(SQLBaseStore):
             allow_none=True,
         )
 
-    @cached(num_args=2)
+    @cached()
     async def get_receipts_for_user(
         self, user_id: str, receipt_type: str
     ) -> Dict[str, str]:
+        """
+        Fetch the event IDs for the latest receipts sent by the given user.
+
+        Args:
+            user_id: The user to fetch receipts for.
+            receipt_type: The receipt type to fetch.
+
+        Returns:
+            A map of room ID to the event ID of the latest receipt for that room.
+
+            If the user has not sent a receipt to a room then it will not appear
+            in the returned dictionary.
+        """
         rows = await self.db_pool.simple_select_list(
             table="receipts_linearized",
             keyvalues={"user_id": user_id, "receipt_type": receipt_type},
@@ -165,6 +200,17 @@ class ReceiptsWorkerStore(SQLBaseStore):
     async def get_receipts_for_user_with_orderings(
         self, user_id: str, receipt_type: str
     ) -> JsonDict:
+        """
+        Fetch receipts for all rooms that the given user is joined to.
+
+        Args:
+            user_id: The user to fetch receipts for.
+            receipt_type: The receipt type to fetch.
+
+        Returns:
+            A map of room ID to the latest receipt information.
+        """
+
         def f(txn: LoggingTransaction) -> List[Tuple[str, str, int, int]]:
             sql = (
                 "SELECT rl.room_id, rl.event_id,"
@@ -241,7 +287,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
         return await self._get_linearized_receipts_for_room(room_id, to_key, from_key)
 
-    @cached(num_args=3, tree=True)
+    @cached(tree=True)
     async def _get_linearized_receipts_for_room(
         self, room_id: str, to_key: int, from_key: Optional[int] = None
     ) -> List[JsonDict]:
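
The dropped `num_args` arguments in this file were redundant: `@cached` keys on every positional argument after `self` by default, so `num_args=2` on a two-argument method (and `num_args=3` on a three-argument one) changed nothing. A minimal illustration:

```python
from typing import List

from synapse.util.caches.descriptors import cached

class Example:
    @cached()  # equivalent to the old @cached(num_args=2) for this signature
    async def get_receipts_for_room(self, room_id: str, receipt_type: str) -> List[dict]:
        # Entries are keyed on (room_id, receipt_type); num_args only
        # matters when fewer than all arguments should form the key.
        return []
```
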
@@ -541,7 +587,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
         data: JsonDict,
         stream_id: int,
     ) -> Optional[int]:
-        """Inserts a read-receipt into the database if it's newer than the current RR
+        """Inserts a receipt into the database if it's newer than the current one.
 
         Returns:
             None if the RR is older than the current RR
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index e653841fe5..18ae8aee29 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -12,11 +12,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import collections.abc
 import logging
 from typing import TYPE_CHECKING, Collection, Dict, Iterable, Optional, Set, Tuple
 
-from frozendict import frozendict
-
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
@@ -160,7 +159,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         predecessor = create_event.content.get("predecessor", None)
 
         # Ensure the key is a dictionary
-        if not isinstance(predecessor, (dict, frozendict)):
+        if not isinstance(predecessor, collections.abc.Mapping):
             return None
 
         # The keys must be strings since the data is JSON.
@@ -370,10 +369,10 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
 
     def _update_state_for_partial_state_event_txn(
         self,
-        txn,
+        txn: LoggingTransaction,
         event: EventBase,
         context: EventContext,
-    ):
+    ) -> None:
         # we shouldn't have any outliers here
         assert not event.internal_metadata.is_outlier()
 
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 2d339b6008..f38bedbbcd 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -131,7 +131,7 @@ class UIAuthWorkerStore(SQLBaseStore):
         session_id: str,
         stage_type: str,
         result: Union[str, bool, JsonDict],
-    ):
+    ) -> None:
         """
         Mark a session stage as completed.
 
@@ -200,7 +200,9 @@ class UIAuthWorkerStore(SQLBaseStore):
             desc="set_ui_auth_client_dict",
         )
 
-    async def set_ui_auth_session_data(self, session_id: str, key: str, value: Any):
+    async def set_ui_auth_session_data(
+        self, session_id: str, key: str, value: Any
+    ) -> None:
         """
         Store a key-value pair into the sessions data associated with this
         request. This data is stored server-side and cannot be modified by
@@ -223,7 +225,7 @@ class UIAuthWorkerStore(SQLBaseStore):
 
     def _set_ui_auth_session_data_txn(
         self, txn: LoggingTransaction, session_id: str, key: str, value: Any
-    ):
+    ) -> None:
         # Get the current value.
         result = cast(
             Dict[str, Any],
@@ -275,7 +277,7 @@ class UIAuthWorkerStore(SQLBaseStore):
         session_id: str,
         user_agent: str,
         ip: str,
-    ):
+    ) -> None:
         """Add the given user agent / IP to the tracking table"""
         await self.db_pool.simple_upsert(
             table="ui_auth_sessions_ips",
@@ -318,7 +320,7 @@ class UIAuthWorkerStore(SQLBaseStore):
 
     def _delete_old_ui_auth_sessions_txn(
         self, txn: LoggingTransaction, expiration_time: int
-    ):
+    ) -> None:
         # Get the expired sessions.
         sql = "SELECT session_id FROM ui_auth_sessions WHERE creation_time <= ?"
         txn.execute(sql, [expiration_time])
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index e3153d1a4a..546d6bae6e 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -501,11 +501,11 @@ def _upgrade_existing_database(
 
                 if hasattr(module, "run_create"):
                     logger.info("Running %s:run_create", relative_path)
-                    module.run_create(cur, database_engine)  # type: ignore
+                    module.run_create(cur, database_engine)
 
                 if not is_empty and hasattr(module, "run_upgrade"):
                     logger.info("Running %s:run_upgrade", relative_path)
-                    module.run_upgrade(cur, database_engine, config=config)  # type: ignore
+                    module.run_upgrade(cur, database_engine, config=config)
             elif ext == ".pyc" or file_name == "__pycache__":
                 # Sometimes .pyc files turn up anyway even though we've
                 # disabled their generation; e.g. from distribution package
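
For reference, `run_create` and `run_upgrade` are the optional hooks a Python schema-delta module can define; per the logic above, `run_create` runs unconditionally while `run_upgrade` runs only against a non-empty (existing) database. A sketch of such a delta file, with an illustrative table name:

```python
# Sketch of a Python schema delta module.
def run_create(cur, database_engine):
    # Runs on both fresh and existing databases.
    cur.execute("CREATE TABLE IF NOT EXISTS example_table (id BIGINT)")

def run_upgrade(cur, database_engine, config):
    # Runs only when upgrading an existing database.
    cur.execute("INSERT INTO example_table (id) VALUES (0)")
```
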
diff --git a/synapse/util/caches/ttlcache.py b/synapse/util/caches/ttlcache.py
index 0b9ac26b69..f6b3ee31e4 100644
--- a/synapse/util/caches/ttlcache.py
+++ b/synapse/util/caches/ttlcache.py
@@ -107,7 +107,7 @@ class TTLCache(Generic[KT, VT]):
         self._metrics.inc_hits()
         return e.value, e.expiry_time, e.ttl
 
-    def pop(self, key: KT, default: T = SENTINEL) -> Union[VT, T]:  # type: ignore
+    def pop(self, key: KT, default: T = SENTINEL) -> Union[VT, T]:
         """Remove a value from the cache
 
         If key is in the cache, remove it and return its value, else return default.
diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py
index 9c405eb4d7..7223af1a36 100644
--- a/synapse/util/frozenutils.py
+++ b/synapse/util/frozenutils.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import collections.abc
 from typing import Any
 
 from frozendict import frozendict
@@ -35,7 +36,7 @@ def freeze(o: Any) -> Any:
 
 
 def unfreeze(o: Any) -> Any:
-    if isinstance(o, (dict, frozendict)):
+    if isinstance(o, collections.abc.Mapping):
         return {k: unfreeze(v) for k, v in o.items()}
 
     if isinstance(o, (bytes, str)):
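
Here and in `state.py` above, testing against `collections.abc.Mapping` replaces the enumeration of concrete mapping types: `dict`, `frozendict`, and any other mapping all satisfy the one ABC check, and `state.py` can drop its `frozendict` import entirely. A quick illustration:

```python
import collections.abc

from frozendict import frozendict

# Both concrete mapping types satisfy the single ABC check...
assert isinstance({"a": 1}, collections.abc.Mapping)
assert isinstance(frozendict({"a": 1}), collections.abc.Mapping)
# ...while non-mappings fall through to unfreeze()'s other branches.
assert not isinstance([("a", 1)], collections.abc.Mapping)
```
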