summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/__init__.py2
-rw-r--r--synapse/app/generic_worker.py3
-rw-r--r--synapse/config/server.py39
-rw-r--r--synapse/events/presence_router.py104
-rw-r--r--synapse/federation/federation_client.py2
-rw-r--r--synapse/federation/federation_server.py26
-rw-r--r--synapse/federation/sender/__init__.py25
-rw-r--r--synapse/federation/transport/server.py4
-rw-r--r--synapse/handlers/device.py4
-rw-r--r--synapse/handlers/e2e_keys.py12
-rw-r--r--synapse/handlers/federation.py161
-rw-r--r--synapse/handlers/presence.py278
-rw-r--r--synapse/handlers/room_member.py75
-rw-r--r--synapse/handlers/sync.py10
-rw-r--r--synapse/http/site.py112
-rw-r--r--synapse/logging/context.py70
-rw-r--r--synapse/metrics/__init__.py16
-rw-r--r--synapse/metrics/background_process_metrics.py18
-rw-r--r--synapse/module_api/__init__.py50
-rw-r--r--synapse/replication/tcp/protocol.py5
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py2
-rw-r--r--synapse/server.py5
-rw-r--r--synapse/state/__init__.py5
-rw-r--r--synapse/storage/databases/main/media_repository.py21
-rw-r--r--synapse/storage/databases/main/schema/delta/59/11drop_thumbnail_constraint.sql.postgres22
-rw-r--r--synapse/util/caches/expiringcache.py83
26 files changed, 904 insertions, 250 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py
index 419299bf01..1d2883acf6 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -48,7 +48,7 @@ try:
 except ImportError:
     pass
 
-__version__ = "1.31.0rc1"
+__version__ = "1.31.0"
 
 if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
     # We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 3df2aa5c2b..d1c2079233 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -281,6 +281,7 @@ class GenericWorkerPresence(BasePresenceHandler):
         self.hs = hs
         self.is_mine_id = hs.is_mine_id
 
+        self.presence_router = hs.get_presence_router()
         self._presence_enabled = hs.config.use_presence
 
         # The number of ongoing syncs on this process, by user id.
@@ -395,7 +396,7 @@ class GenericWorkerPresence(BasePresenceHandler):
         return _user_syncing()
 
     async def notify_from_replication(self, states, stream_id):
-        parties = await get_interested_parties(self.store, states)
+        parties = await get_interested_parties(self.store, self.presence_router, states)
         room_ids_to_states, users_to_states = parties
 
         self.notifier.on_new_event(
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 5f8910b6e1..8decc9d10d 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -27,6 +27,7 @@ import yaml
 from netaddr import AddrFormatError, IPNetwork, IPSet
 
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
+from synapse.util.module_loader import load_module
 from synapse.util.stringutils import parse_and_validate_server_name
 
 from ._base import Config, ConfigError
@@ -238,7 +239,20 @@ class ServerConfig(Config):
         self.public_baseurl = config.get("public_baseurl")
 
         # Whether to enable user presence.
-        self.use_presence = config.get("use_presence", True)
+        presence_config = config.get("presence") or {}
+        self.use_presence = presence_config.get("enabled")
+        if self.use_presence is None:
+            self.use_presence = config.get("use_presence", True)
+
+        # Custom presence router module
+        self.presence_router_module_class = None
+        self.presence_router_config = None
+        presence_router_config = presence_config.get("presence_router")
+        if presence_router_config:
+            (
+                self.presence_router_module_class,
+                self.presence_router_config,
+            ) = load_module(presence_router_config, ("presence", "presence_router"))
 
         # Whether to update the user directory or not. This should be set to
         # false only if we are updating the user directory in a worker
@@ -834,9 +848,28 @@ class ServerConfig(Config):
         #
         #soft_file_limit: 0
 
-        # Set to false to disable presence tracking on this homeserver.
+        # Presence tracking allows users to see the state (e.g online/offline)
+        # of other local and remote users.
         #
-        #use_presence: false
+        presence:
+          # Uncomment to disable presence tracking on this homeserver. This option
+          # replaces the previous top-level 'use_presence' option.
+          #
+          #enabled: false
+
+          # Presence routers are third-party modules that can specify additional logic
+          # to where presence updates from users are routed.
+          #
+          presence_router:
+            # The custom module's class. Uncomment to use a custom presence router module.
+            #
+            #module: "my_custom_router.PresenceRouter"
+
+            # Configuration options of the custom module. Refer to your module's
+            # documentation for available options.
+            #
+            #config:
+            #  example_option: 'something'
 
         # Whether to require authentication to retrieve profile data (avatars,
         # display names) of other users through the client API. Defaults to
diff --git a/synapse/events/presence_router.py b/synapse/events/presence_router.py
new file mode 100644
index 0000000000..24cd389d80
--- /dev/null
+++ b/synapse/events/presence_router.py
@@ -0,0 +1,104 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, Dict, Iterable, Set, Union
+
+from synapse.api.presence import UserPresenceState
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+
+class PresenceRouter:
+    """
+    A module that the homeserver will call upon to help route user presence updates to
+    additional destinations. If a custom presence router is configured, calls will be
+    passed to that instead.
+    """
+
+    ALL_USERS = "ALL"
+
+    def __init__(self, hs: "HomeServer"):
+        self.custom_presence_router = None
+
+        # Check whether a custom presence router module has been configured
+        if hs.config.presence_router_module_class:
+            # Initialise the module
+            self.custom_presence_router = hs.config.presence_router_module_class(
+                config=hs.config.presence_router_config, module_api=hs.get_module_api()
+            )
+
+            # Ensure the module has implemented the required methods
+            required_methods = ["get_users_for_states", "get_interested_users"]
+            for method_name in required_methods:
+                if not hasattr(self.custom_presence_router, method_name):
+                    raise Exception(
+                        "PresenceRouter module '%s' must implement all required methods: %s"
+                        % (
+                            hs.config.presence_router_module_class.__name__,
+                            ", ".join(required_methods),
+                        )
+                    )
+
+    async def get_users_for_states(
+        self,
+        state_updates: Iterable[UserPresenceState],
+    ) -> Dict[str, Set[UserPresenceState]]:
+        """
+        Given an iterable of user presence updates, determine where each one
+        needs to go.
+
+        Args:
+            state_updates: An iterable of user presence state updates.
+
+        Returns:
+          A dictionary of user_id -> set of UserPresenceState, indicating which
+          presence updates each user should receive.
+        """
+        if self.custom_presence_router is not None:
+            # Ask the custom module
+            return await self.custom_presence_router.get_users_for_states(
+                state_updates=state_updates
+            )
+
+        # Don't include any extra destinations for presence updates
+        return {}
+
+    async def get_interested_users(self, user_id: str) -> Union[Set[str], ALL_USERS]:
+        """
+        Retrieve a list of users that `user_id` is interested in receiving the
+        presence of. This will be in addition to those they share a room with.
+        Optionally, the object PresenceRouter.ALL_USERS can be returned to indicate
+        that this user should receive all incoming local and remote presence updates.
+
+        Note that this method will only be called for local users, but can return users
+        that are local or remote.
+
+        Args:
+            user_id: A user requesting presence updates.
+
+        Returns:
+            A set of user IDs to return presence updates for, or ALL_USERS to return all
+            known updates.
+        """
+        if self.custom_presence_router is not None:
+            # Ask the custom module for interested users
+            return await self.custom_presence_router.get_interested_users(
+                user_id=user_id
+            )
+
+        # A custom presence router is not defined.
+        # Don't report any additional interested users
+        return set()
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index afdb5bf2fa..55533d7501 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -102,7 +102,7 @@ class FederationClient(FederationBase):
             max_len=1000,
             expiry_ms=120 * 1000,
             reset_expiry_on_get=False,
-        )
+        )  # type: ExpiringCache[str, EventBase]
 
     def _clear_tried_cache(self):
         """Clear pdu_destination_tried cache"""
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 71cb120ef7..b9f8d966a6 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -739,22 +739,20 @@ class FederationServer(FederationBase):
 
         await self.handler.on_receive_pdu(origin, pdu, sent_to_us_directly=True)
 
-    def __str__(self):
+    def __str__(self) -> str:
         return "<ReplicationLayer(%s)>" % self.server_name
 
     async def exchange_third_party_invite(
         self, sender_user_id: str, target_user_id: str, room_id: str, signed: Dict
-    ):
-        ret = await self.handler.exchange_third_party_invite(
+    ) -> None:
+        await self.handler.exchange_third_party_invite(
             sender_user_id, target_user_id, room_id, signed
         )
-        return ret
 
-    async def on_exchange_third_party_invite_request(self, event_dict: Dict):
-        ret = await self.handler.on_exchange_third_party_invite_request(event_dict)
-        return ret
+    async def on_exchange_third_party_invite_request(self, event_dict: Dict) -> None:
+        await self.handler.on_exchange_third_party_invite_request(event_dict)
 
-    async def check_server_matches_acl(self, server_name: str, room_id: str):
+    async def check_server_matches_acl(self, server_name: str, room_id: str) -> None:
         """Check if the given server is allowed by the server ACLs in the room
 
         Args:
@@ -878,7 +876,7 @@ class FederationHandlerRegistry:
 
     def register_edu_handler(
         self, edu_type: str, handler: Callable[[str, JsonDict], Awaitable[None]]
-    ):
+    ) -> None:
         """Sets the handler callable that will be used to handle an incoming
         federation EDU of the given type.
 
@@ -897,7 +895,7 @@ class FederationHandlerRegistry:
 
     def register_query_handler(
         self, query_type: str, handler: Callable[[dict], Awaitable[JsonDict]]
-    ):
+    ) -> None:
         """Sets the handler callable that will be used to handle an incoming
         federation query of the given type.
 
@@ -915,15 +913,17 @@ class FederationHandlerRegistry:
 
         self.query_handlers[query_type] = handler
 
-    def register_instance_for_edu(self, edu_type: str, instance_name: str):
+    def register_instance_for_edu(self, edu_type: str, instance_name: str) -> None:
         """Register that the EDU handler is on a different instance than master."""
         self._edu_type_to_instance[edu_type] = [instance_name]
 
-    def register_instances_for_edu(self, edu_type: str, instance_names: List[str]):
+    def register_instances_for_edu(
+        self, edu_type: str, instance_names: List[str]
+    ) -> None:
         """Register that the EDU handler is on multiple instances."""
         self._edu_type_to_instance[edu_type] = instance_names
 
-    async def on_edu(self, edu_type: str, origin: str, content: dict):
+    async def on_edu(self, edu_type: str, origin: str, content: dict) -> None:
         if not self.config.use_presence and edu_type == EduTypes.Presence:
             return
 
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 8babb1ebbe..d821dcbf6a 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -44,6 +44,7 @@ from synapse.types import JsonDict, ReadReceipt, RoomStreamToken
 from synapse.util.metrics import Measure, measure_func
 
 if TYPE_CHECKING:
+    from synapse.events.presence_router import PresenceRouter
     from synapse.server import HomeServer
 
 logger = logging.getLogger(__name__)
@@ -162,6 +163,7 @@ class FederationSender(AbstractFederationSender):
         self.clock = hs.get_clock()
         self.is_mine_id = hs.is_mine_id
 
+        self._presence_router = None  # type: Optional[PresenceRouter]
         self._transaction_manager = TransactionManager(hs)
 
         self._instance_name = hs.get_instance_name()
@@ -584,7 +586,22 @@ class FederationSender(AbstractFederationSender):
         """Given a list of states populate self.pending_presence_by_dest and
         poke to send a new transaction to each destination
         """
-        hosts_and_states = await get_interested_remotes(self.store, states, self.state)
+        # We pull the presence router here instead of __init__
+        # to prevent a dependency cycle:
+        #
+        # AuthHandler -> Notifier -> FederationSender
+        # -> PresenceRouter -> ModuleApi -> AuthHandler
+        if self._presence_router is None:
+            self._presence_router = self.hs.get_presence_router()
+
+        assert self._presence_router is not None
+
+        hosts_and_states = await get_interested_remotes(
+            self.store,
+            self._presence_router,
+            states,
+            self.state,
+        )
 
         for destinations, states in hosts_and_states:
             for destination in destinations:
@@ -717,16 +734,18 @@ class FederationSender(AbstractFederationSender):
                 self._catchup_after_startup_timer = None
                 break
 
+            last_processed = destinations_to_wake[-1]
+
             destinations_to_wake = [
                 d
                 for d in destinations_to_wake
                 if self._federation_shard_config.should_handle(self._instance_name, d)
             ]
 
-            for last_processed in destinations_to_wake:
+            for destination in destinations_to_wake:
                 logger.info(
                     "Destination %s has outstanding catch-up, waking up.",
                     last_processed,
                 )
-                self.wake_destination(last_processed)
+                self.wake_destination(destination)
                 await self.clock.sleep(CATCH_UP_STARTUP_INTERVAL_SEC)
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 84e39c5a46..5ef0556ef7 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -620,8 +620,8 @@ class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet):
     PATH = "/exchange_third_party_invite/(?P<room_id>[^/]*)"
 
     async def on_PUT(self, origin, content, query, room_id):
-        content = await self.handler.on_exchange_third_party_invite_request(content)
-        return 200, content
+        await self.handler.on_exchange_third_party_invite_request(content)
+        return 200, {}
 
 
 class FederationClientKeysQueryServlet(BaseFederationServlet):
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 54293d0b9c..7e76db3e2a 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -631,7 +631,7 @@ class DeviceListUpdater:
             max_len=10000,
             expiry_ms=30 * 60 * 1000,
             iterable=True,
-        )
+        )  # type: ExpiringCache[str, Set[str]]
 
         # Attempt to resync out of sync device lists every 30s.
         self._resync_retry_in_progress = False
@@ -760,7 +760,7 @@ class DeviceListUpdater:
         """Given a list of updates for a user figure out if we need to do a full
         resync, or whether we have enough data that we can just apply the delta.
         """
-        seen_updates = self._seen_updates.get(user_id, set())
+        seen_updates = self._seen_updates.get(user_id, set())  # type: Set[str]
 
         extremity = await self.store.get_device_list_last_stream_id_for_remote(user_id)
 
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 739653a3fa..92b18378fc 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -38,7 +38,6 @@ from synapse.types import (
 )
 from synapse.util import json_decoder, unwrapFirstError
 from synapse.util.async_helpers import Linearizer
-from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.retryutils import NotRetryingDestination
 
 if TYPE_CHECKING:
@@ -1292,17 +1291,6 @@ class SigningKeyEduUpdater:
         # user_id -> list of updates waiting to be handled.
         self._pending_updates = {}  # type: Dict[str, List[Tuple[JsonDict, JsonDict]]]
 
-        # Recently seen stream ids. We don't bother keeping these in the DB,
-        # but they're useful to have them about to reduce the number of spurious
-        # resyncs.
-        self._seen_updates = ExpiringCache(
-            cache_name="signing_key_update_edu",
-            clock=self.clock,
-            max_len=10000,
-            expiry_ms=30 * 60 * 1000,
-            iterable=True,
-        )
-
     async def incoming_signing_key_update(
         self, origin: str, edu_content: JsonDict
     ) -> None:
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 3ebee38ebe..5ea8a7b603 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -21,7 +21,17 @@ import itertools
 import logging
 from collections.abc import Container
 from http import HTTPStatus
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union
+from typing import (
+    TYPE_CHECKING,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+    Union,
+)
 
 import attr
 from signedjson.key import decode_verify_key_bytes
@@ -171,15 +181,17 @@ class FederationHandler(BaseHandler):
 
         self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
 
-    async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None:
+    async def on_receive_pdu(
+        self, origin: str, pdu: EventBase, sent_to_us_directly: bool = False
+    ) -> None:
         """Process a PDU received via a federation /send/ transaction, or
         via backfill of missing prev_events
 
         Args:
-            origin (str): server which initiated the /send/ transaction. Will
+            origin: server which initiated the /send/ transaction. Will
                 be used to fetch missing events or state.
-            pdu (FrozenEvent): received PDU
-            sent_to_us_directly (bool): True if this event was pushed to us; False if
+            pdu: received PDU
+            sent_to_us_directly: True if this event was pushed to us; False if
                 we pulled it as the result of a missing prev_event.
         """
 
@@ -411,13 +423,15 @@ class FederationHandler(BaseHandler):
 
         await self._process_received_pdu(origin, pdu, state=state)
 
-    async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
+    async def _get_missing_events_for_pdu(
+        self, origin: str, pdu: EventBase, prevs: Set[str], min_depth: int
+    ) -> None:
         """
         Args:
-            origin (str): Origin of the pdu. Will be called to get the missing events
+            origin: Origin of the pdu. Will be called to get the missing events
             pdu: received pdu
-            prevs (set(str)): List of event ids which we are missing
-            min_depth (int): Minimum depth of events to return.
+            prevs: List of event ids which we are missing
+            min_depth: Minimum depth of events to return.
         """
 
         room_id = pdu.room_id
@@ -778,7 +792,7 @@ class FederationHandler(BaseHandler):
         origin: str,
         event: EventBase,
         state: Optional[Iterable[EventBase]],
-    ):
+    ) -> None:
         """Called when we have a new pdu. We need to do auth checks and put it
         through the StateHandler.
 
@@ -887,7 +901,9 @@ class FederationHandler(BaseHandler):
             logger.exception("Failed to resync device for %s", sender)
 
     @log_function
-    async def backfill(self, dest, room_id, limit, extremities):
+    async def backfill(
+        self, dest: str, room_id: str, limit: int, extremities: List[str]
+    ) -> List[EventBase]:
         """Trigger a backfill request to `dest` for the given `room_id`
 
         This will attempt to get more events from the remote. If the other side
@@ -1142,16 +1158,15 @@ class FederationHandler(BaseHandler):
 
         curr_state = await self.state_handler.get_current_state(room_id)
 
-        def get_domains_from_state(state):
+        def get_domains_from_state(state: StateMap[EventBase]) -> List[Tuple[str, int]]:
             """Get joined domains from state
 
             Args:
-                state (dict[tuple, FrozenEvent]): State map from type/state
-                    key to event.
+                state: State map from type/state key to event.
 
             Returns:
-                list[tuple[str, int]]: Returns a list of servers with the
-                lowest depth of their joins. Sorted by lowest depth first.
+                Returns a list of servers with the lowest depth of their joins.
+                 Sorted by lowest depth first.
             """
             joined_users = [
                 (state_key, int(event.depth))
@@ -1179,7 +1194,7 @@ class FederationHandler(BaseHandler):
             domain for domain, depth in curr_domains if domain != self.server_name
         ]
 
-        async def try_backfill(domains):
+        async def try_backfill(domains: List[str]) -> bool:
             # TODO: Should we try multiple of these at a time?
             for dom in domains:
                 try:
@@ -1258,21 +1273,25 @@ class FederationHandler(BaseHandler):
         }
 
         for e_id, _ in sorted_extremeties_tuple:
-            likely_domains = get_domains_from_state(states[e_id])
+            likely_extremeties_domains = get_domains_from_state(states[e_id])
 
             success = await try_backfill(
-                [dom for dom, _ in likely_domains if dom not in tried_domains]
+                [
+                    dom
+                    for dom, _ in likely_extremeties_domains
+                    if dom not in tried_domains
+                ]
             )
             if success:
                 return True
 
-            tried_domains.update(dom for dom, _ in likely_domains)
+            tried_domains.update(dom for dom, _ in likely_extremeties_domains)
 
         return False
 
     async def _get_events_and_persist(
         self, destination: str, room_id: str, events: Iterable[str]
-    ):
+    ) -> None:
         """Fetch the given events from a server, and persist them as outliers.
 
         This function *does not* recursively get missing auth events of the
@@ -1348,7 +1367,7 @@ class FederationHandler(BaseHandler):
             event_infos,
         )
 
-    def _sanity_check_event(self, ev):
+    def _sanity_check_event(self, ev: EventBase) -> None:
         """
         Do some early sanity checks of a received event
 
@@ -1357,9 +1376,7 @@ class FederationHandler(BaseHandler):
         or cascade of event fetches.
 
         Args:
-            ev (synapse.events.EventBase): event to be checked
-
-        Returns: None
+            ev: event to be checked
 
         Raises:
             SynapseError if the event does not pass muster
@@ -1380,7 +1397,7 @@ class FederationHandler(BaseHandler):
             )
             raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events")
 
-    async def send_invite(self, target_host, event):
+    async def send_invite(self, target_host: str, event: EventBase) -> EventBase:
         """Sends the invite to the remote server for signing.
 
         Invites must be signed by the invitee's server before distribution.
@@ -1528,12 +1545,13 @@ class FederationHandler(BaseHandler):
 
             run_in_background(self._handle_queued_pdus, room_queue)
 
-    async def _handle_queued_pdus(self, room_queue):
+    async def _handle_queued_pdus(
+        self, room_queue: List[Tuple[EventBase, str]]
+    ) -> None:
         """Process PDUs which got queued up while we were busy send_joining.
 
         Args:
-            room_queue (list[FrozenEvent, str]): list of PDUs to be processed
-                and the servers that sent them
+            room_queue: list of PDUs to be processed and the servers that sent them
         """
         for p, origin in room_queue:
             try:
@@ -1612,7 +1630,7 @@ class FederationHandler(BaseHandler):
 
         return event
 
-    async def on_send_join_request(self, origin, pdu):
+    async def on_send_join_request(self, origin: str, pdu: EventBase) -> JsonDict:
         """We have received a join event for a room. Fully process it and
         respond with the current state and auth chains.
         """
@@ -1668,7 +1686,7 @@ class FederationHandler(BaseHandler):
 
     async def on_invite_request(
         self, origin: str, event: EventBase, room_version: RoomVersion
-    ):
+    ) -> EventBase:
         """We've got an invite event. Process and persist it. Sign it.
 
         Respond with the now signed event.
@@ -1841,7 +1859,7 @@ class FederationHandler(BaseHandler):
 
         return event
 
-    async def on_send_leave_request(self, origin, pdu):
+    async def on_send_leave_request(self, origin: str, pdu: EventBase) -> None:
         """ We have received a leave event for a room. Fully process it."""
         event = pdu
 
@@ -1969,12 +1987,17 @@ class FederationHandler(BaseHandler):
         else:
             return None
 
-    async def get_min_depth_for_context(self, context):
+    async def get_min_depth_for_context(self, context: str) -> int:
         return await self.store.get_min_depth(context)
 
     async def _handle_new_event(
-        self, origin, event, state=None, auth_events=None, backfilled=False
-    ):
+        self,
+        origin: str,
+        event: EventBase,
+        state: Optional[Iterable[EventBase]] = None,
+        auth_events: Optional[MutableStateMap[EventBase]] = None,
+        backfilled: bool = False,
+    ) -> EventContext:
         context = await self._prep_event(
             origin, event, state=state, auth_events=auth_events, backfilled=backfilled
         )
@@ -2280,40 +2303,14 @@ class FederationHandler(BaseHandler):
             logger.warning("Soft-failing %r because %s", event, e)
             event.internal_metadata.soft_failed = True
 
-    async def on_query_auth(
-        self, origin, event_id, room_id, remote_auth_chain, rejects, missing
-    ):
-        in_room = await self.auth.check_host_in_room(room_id, origin)
-        if not in_room:
-            raise AuthError(403, "Host not in room.")
-
-        event = await self.store.get_event(event_id, check_room_id=room_id)
-
-        # Just go through and process each event in `remote_auth_chain`. We
-        # don't want to fall into the trap of `missing` being wrong.
-        for e in remote_auth_chain:
-            try:
-                await self._handle_new_event(origin, e)
-            except AuthError:
-                pass
-
-        # Now get the current auth_chain for the event.
-        local_auth_chain = await self.store.get_auth_chain(
-            room_id, list(event.auth_event_ids()), include_given=True
-        )
-
-        # TODO: Check if we would now reject event_id. If so we need to tell
-        # everyone.
-
-        ret = await self.construct_auth_difference(local_auth_chain, remote_auth_chain)
-
-        logger.debug("on_query_auth returning: %s", ret)
-
-        return ret
-
     async def on_get_missing_events(
-        self, origin, room_id, earliest_events, latest_events, limit
-    ):
+        self,
+        origin: str,
+        room_id: str,
+        earliest_events: List[str],
+        latest_events: List[str],
+        limit: int,
+    ) -> List[EventBase]:
         in_room = await self.auth.check_host_in_room(room_id, origin)
         if not in_room:
             raise AuthError(403, "Host not in room.")
@@ -2617,8 +2614,8 @@ class FederationHandler(BaseHandler):
         assumes that we have already processed all events in remote_auth
 
         Params:
-            local_auth (list)
-            remote_auth (list)
+            local_auth
+            remote_auth
 
         Returns:
             dict
@@ -2742,8 +2739,8 @@ class FederationHandler(BaseHandler):
 
     @log_function
     async def exchange_third_party_invite(
-        self, sender_user_id, target_user_id, room_id, signed
-    ):
+        self, sender_user_id: str, target_user_id: str, room_id: str, signed: JsonDict
+    ) -> None:
         third_party_invite = {"signed": signed}
 
         event_dict = {
@@ -2835,8 +2832,12 @@ class FederationHandler(BaseHandler):
         await member_handler.send_membership_event(None, event, context)
 
     async def add_display_name_to_third_party_invite(
-        self, room_version, event_dict, event, context
-    ):
+        self,
+        room_version: str,
+        event_dict: JsonDict,
+        event: EventBase,
+        context: EventContext,
+    ) -> Tuple[EventBase, EventContext]:
         key = (
             EventTypes.ThirdPartyInvite,
             event.content["third_party_invite"]["signed"]["token"],
@@ -2872,13 +2873,13 @@ class FederationHandler(BaseHandler):
         EventValidator().validate_new(event, self.config)
         return (event, context)
 
-    async def _check_signature(self, event, context):
+    async def _check_signature(self, event: EventBase, context: EventContext) -> None:
         """
         Checks that the signature in the event is consistent with its invite.
 
         Args:
-            event (Event): The m.room.member event to check
-            context (EventContext):
+            event: The m.room.member event to check
+            context:
 
         Raises:
             AuthError: if signature didn't match any keys, or key has been
@@ -2964,13 +2965,13 @@ class FederationHandler(BaseHandler):
 
         raise last_exception
 
-    async def _check_key_revocation(self, public_key, url):
+    async def _check_key_revocation(self, public_key: str, url: str) -> None:
         """
         Checks whether public_key has been revoked.
 
         Args:
-            public_key (str): base-64 encoded public key.
-            url (str): Key revocation URL.
+            public_key: base-64 encoded public key.
+            url: Key revocation URL.
 
         Raises:
             AuthError: if they key has been revoked.
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index da92feacc9..c817f2952d 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -25,7 +25,17 @@ The methods that define policy are:
 import abc
 import logging
 from contextlib import contextmanager
-from typing import TYPE_CHECKING, Dict, Iterable, List, Set, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Dict,
+    FrozenSet,
+    Iterable,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    Union,
+)
 
 from prometheus_client import Counter
 from typing_extensions import ContextManager
@@ -34,6 +44,7 @@ import synapse.metrics
 from synapse.api.constants import EventTypes, Membership, PresenceState
 from synapse.api.errors import SynapseError
 from synapse.api.presence import UserPresenceState
+from synapse.events.presence_router import PresenceRouter
 from synapse.logging.context import run_in_background
 from synapse.logging.utils import log_function
 from synapse.metrics import LaterGauge
@@ -42,7 +53,7 @@ from synapse.state import StateHandler
 from synapse.storage.databases.main import DataStore
 from synapse.types import Collection, JsonDict, UserID, get_domain_from_id
 from synapse.util.async_helpers import Linearizer
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import _CacheContext, cached
 from synapse.util.metrics import Measure
 from synapse.util.wheel_timer import WheelTimer
 
@@ -209,6 +220,7 @@ class PresenceHandler(BasePresenceHandler):
         self.notifier = hs.get_notifier()
         self.federation = hs.get_federation_sender()
         self.state = hs.get_state_handler()
+        self.presence_router = hs.get_presence_router()
         self._presence_enabled = hs.config.use_presence
 
         federation_registry = hs.get_federation_registry()
@@ -653,7 +665,7 @@ class PresenceHandler(BasePresenceHandler):
         """
         stream_id, max_token = await self.store.update_presence(states)
 
-        parties = await get_interested_parties(self.store, states)
+        parties = await get_interested_parties(self.store, self.presence_router, states)
         room_ids_to_states, users_to_states = parties
 
         self.notifier.on_new_event(
@@ -1041,7 +1053,12 @@ class PresenceEventSource:
         #
         #   Presence -> Notifier -> PresenceEventSource -> Presence
         #
+        # Same with get_module_api, get_presence_router
+        #
+        #   AuthHandler -> Notifier -> PresenceEventSource -> ModuleApi -> AuthHandler
         self.get_presence_handler = hs.get_presence_handler
+        self.get_module_api = hs.get_module_api
+        self.get_presence_router = hs.get_presence_router
         self.clock = hs.get_clock()
         self.store = hs.get_datastore()
         self.state = hs.get_state_handler()
@@ -1055,7 +1072,7 @@ class PresenceEventSource:
         include_offline=True,
         explicit_room_id=None,
         **kwargs
-    ):
+    ) -> Tuple[List[UserPresenceState], int]:
         # The process for getting presence events are:
         #  1. Get the rooms the user is in.
         #  2. Get the list of user in the rooms.
@@ -1068,7 +1085,17 @@ class PresenceEventSource:
         # We don't try and limit the presence updates by the current token, as
         # sending down the rare duplicate is not a concern.
 
+        user_id = user.to_string()
+        stream_change_cache = self.store.presence_stream_cache
+
         with Measure(self.clock, "presence.get_new_events"):
+            if user_id in self.get_module_api()._send_full_presence_to_local_users:
+                # This user has been specified by a module to receive all current, online
+                # user presence. Removing from_key and setting include_offline to false
+                # will do effectively this.
+                from_key = None
+                include_offline = False
+
             if from_key is not None:
                 from_key = int(from_key)
 
@@ -1091,59 +1118,209 @@ class PresenceEventSource:
                 # doesn't return. C.f. #5503.
                 return [], max_token
 
-            presence = self.get_presence_handler()
-            stream_change_cache = self.store.presence_stream_cache
-
+            # Figure out which other users this user should receive updates for
             users_interested_in = await self._get_interested_in(user, explicit_room_id)
 
-            user_ids_changed = set()  # type: Collection[str]
-            changed = None
-            if from_key:
-                changed = stream_change_cache.get_all_entities_changed(from_key)
+            # We have a set of users that we're interested in the presence of. We want to
+            # cross-reference that with the users that have actually changed their presence.
 
-            if changed is not None and len(changed) < 500:
-                assert isinstance(user_ids_changed, set)
+            # Check whether this user should see all user updates
 
-                # For small deltas, its quicker to get all changes and then
-                # work out if we share a room or they're in our presence list
-                get_updates_counter.labels("stream").inc()
-                for other_user_id in changed:
-                    if other_user_id in users_interested_in:
-                        user_ids_changed.add(other_user_id)
-            else:
-                # Too many possible updates. Find all users we can see and check
-                # if any of them have changed.
-                get_updates_counter.labels("full").inc()
+            if users_interested_in == PresenceRouter.ALL_USERS:
+                # Provide presence state for all users
+                presence_updates = await self._filter_all_presence_updates_for_user(
+                    user_id, include_offline, from_key
+                )
 
-                if from_key:
-                    user_ids_changed = stream_change_cache.get_entities_changed(
-                        users_interested_in, from_key
+                # Remove the user from the list of users to receive all presence
+                if user_id in self.get_module_api()._send_full_presence_to_local_users:
+                    self.get_module_api()._send_full_presence_to_local_users.remove(
+                        user_id
                     )
+
+                return presence_updates, max_token
+
+            # Make mypy happy. users_interested_in should now be a set
+            assert not isinstance(users_interested_in, str)
+
+            # The set of users that we're interested in and that have had a presence update.
+            # We'll actually pull the presence updates for these users at the end.
+            interested_and_updated_users = (
+                set()
+            )  # type: Union[Set[str], FrozenSet[str]]
+
+            if from_key:
+                # First get all users that have had a presence update
+                updated_users = stream_change_cache.get_all_entities_changed(from_key)
+
+                # Cross-reference users we're interested in with those that have had updates.
+                # Use a slightly-optimised method for processing smaller sets of updates.
+                if updated_users is not None and len(updated_users) < 500:
+                    # For small deltas, it's quicker to get all changes and then
+                    # cross-reference with the users we're interested in
+                    get_updates_counter.labels("stream").inc()
+                    for other_user_id in updated_users:
+                        if other_user_id in users_interested_in:
+                            # mypy thinks this variable could be a FrozenSet as it's possibly set
+                            # to one in the `get_entities_changed` call below, and `add()` is not
+                            # method on a FrozenSet. That doesn't affect us here though, as
+                            # `interested_and_updated_users` is clearly a set() above.
+                            interested_and_updated_users.add(other_user_id)  # type: ignore
                 else:
-                    user_ids_changed = users_interested_in
+                    # Too many possible updates. Find all users we can see and check
+                    # if any of them have changed.
+                    get_updates_counter.labels("full").inc()
 
-            updates = await presence.current_state_for_users(user_ids_changed)
+                    interested_and_updated_users = (
+                        stream_change_cache.get_entities_changed(
+                            users_interested_in, from_key
+                        )
+                    )
+            else:
+                # No from_key has been specified. Return the presence for all users
+                # this user is interested in
+                interested_and_updated_users = users_interested_in
+
+            # Retrieve the current presence state for each user
+            users_to_state = await self.get_presence_handler().current_state_for_users(
+                interested_and_updated_users
+            )
+            presence_updates = list(users_to_state.values())
 
-        if include_offline:
-            return (list(updates.values()), max_token)
+        # Remove the user from the list of users to receive all presence
+        if user_id in self.get_module_api()._send_full_presence_to_local_users:
+            self.get_module_api()._send_full_presence_to_local_users.remove(user_id)
+
+        if not include_offline:
+            # Filter out offline presence states
+            presence_updates = self._filter_offline_presence_state(presence_updates)
+
+        return presence_updates, max_token
+
+    async def _filter_all_presence_updates_for_user(
+        self,
+        user_id: str,
+        include_offline: bool,
+        from_key: Optional[int] = None,
+    ) -> List[UserPresenceState]:
+        """
+        Computes the presence updates a user should receive.
+
+        First pulls presence updates from the database. Then consults PresenceRouter
+        for whether any updates should be excluded by user ID.
+
+        Args:
+            user_id: The User ID of the user to compute presence updates for.
+            include_offline: Whether to include offline presence states from the results.
+            from_key: The minimum stream ID of updates to pull from the database
+                before filtering.
+
+        Returns:
+            A list of presence states for the given user to receive.
+        """
+        if from_key:
+            # Only return updates since the last sync
+            updated_users = self.store.presence_stream_cache.get_all_entities_changed(
+                from_key
+            )
+            if not updated_users:
+                updated_users = []
+
+            # Get the actual presence update for each change
+            users_to_state = await self.get_presence_handler().current_state_for_users(
+                updated_users
+            )
+            presence_updates = list(users_to_state.values())
+
+            if not include_offline:
+                # Filter out offline states
+                presence_updates = self._filter_offline_presence_state(presence_updates)
         else:
-            return (
-                [s for s in updates.values() if s.state != PresenceState.OFFLINE],
-                max_token,
+            users_to_state = await self.store.get_presence_for_all_users(
+                include_offline=include_offline
             )
 
+            presence_updates = list(users_to_state.values())
+
+        # TODO: This feels wildly inefficient, and it's unfortunate we need to ask the
+        # module for information on a number of users when we then only take the info
+        # for a single user
+
+        # Filter through the presence router
+        users_to_state_set = await self.get_presence_router().get_users_for_states(
+            presence_updates
+        )
+
+        # We only want the mapping for the syncing user
+        presence_updates = list(users_to_state_set[user_id])
+
+        # Return presence information for all users
+        return presence_updates
+
+    def _filter_offline_presence_state(
+        self, presence_updates: Iterable[UserPresenceState]
+    ) -> List[UserPresenceState]:
+        """Given an iterable containing user presence updates, return a list with any offline
+        presence states removed.
+
+        Args:
+            presence_updates: Presence states to filter
+
+        Returns:
+            A new list with any offline presence states removed.
+        """
+        return [
+            update
+            for update in presence_updates
+            if update.state != PresenceState.OFFLINE
+        ]
+
     def get_current_key(self):
         return self.store.get_current_presence_token()
 
     @cached(num_args=2, cache_context=True)
-    async def _get_interested_in(self, user, explicit_room_id, cache_context):
+    async def _get_interested_in(
+        self,
+        user: UserID,
+        explicit_room_id: Optional[str] = None,
+        cache_context: Optional[_CacheContext] = None,
+    ) -> Union[Set[str], str]:
         """Returns the set of users that the given user should see presence
-        updates for
+        updates for.
+
+        Args:
+            user: The user to retrieve presence updates for.
+            explicit_room_id: The users that are in the room will be returned.
+
+        Returns:
+            A set of user IDs to return presence updates for, or "ALL" to return all
+            known updates.
         """
         user_id = user.to_string()
         users_interested_in = set()
         users_interested_in.add(user_id)  # So that we receive our own presence
 
+        # cache_context isn't likely to ever be None due to the @cached decorator,
+        # but we can't have a non-optional argument after the optional argument
+        # explicit_room_id either. Assert cache_context is not None so we can use it
+        # without mypy complaining.
+        assert cache_context
+
+        # Check with the presence router whether we should poll additional users for
+        # their presence information
+        additional_users = await self.get_presence_router().get_interested_users(
+            user.to_string()
+        )
+        if additional_users == PresenceRouter.ALL_USERS:
+            # If the module requested that this user see the presence updates of *all*
+            # users, then simply return that instead of calculating what rooms this
+            # user shares
+            return PresenceRouter.ALL_USERS
+
+        # Add the additional users from the router
+        users_interested_in.update(additional_users)
+
+        # Find the users who share a room with this user
         users_who_share_room = await self.store.get_users_who_share_room_with_user(
             user_id, on_invalidate=cache_context.invalidate
         )
@@ -1314,14 +1491,15 @@ def handle_update(prev_state, new_state, is_mine, wheel_timer, now):
 
 
 async def get_interested_parties(
-    store: DataStore, states: List[UserPresenceState]
+    store: DataStore, presence_router: PresenceRouter, states: List[UserPresenceState]
 ) -> Tuple[Dict[str, List[UserPresenceState]], Dict[str, List[UserPresenceState]]]:
     """Given a list of states return which entities (rooms, users)
     are interested in the given states.
 
     Args:
-        store
-        states
+        store: The homeserver's data store.
+        presence_router: A module for augmenting the destinations for presence updates.
+        states: A list of incoming user presence updates.
 
     Returns:
         A 2-tuple of `(room_ids_to_states, users_to_states)`,
@@ -1337,11 +1515,22 @@ async def get_interested_parties(
         # Always notify self
         users_to_states.setdefault(state.user_id, []).append(state)
 
+    # Ask a presence routing module for any additional parties if one
+    # is loaded.
+    router_users_to_states = await presence_router.get_users_for_states(states)
+
+    # Update the dictionaries with additional destinations and state to send
+    for user_id, user_states in router_users_to_states.items():
+        users_to_states.setdefault(user_id, []).extend(user_states)
+
     return room_ids_to_states, users_to_states
 
 
 async def get_interested_remotes(
-    store: DataStore, states: List[UserPresenceState], state_handler: StateHandler
+    store: DataStore,
+    presence_router: PresenceRouter,
+    states: List[UserPresenceState],
+    state_handler: StateHandler,
 ) -> List[Tuple[Collection[str], List[UserPresenceState]]]:
     """Given a list of presence states figure out which remote servers
     should be sent which.
@@ -1349,9 +1538,10 @@ async def get_interested_remotes(
     All the presence states should be for local users only.
 
     Args:
-        store
-        states
-        state_handler
+        store: The homeserver's data store.
+        presence_router: A module for augmenting the destinations for presence updates.
+        states: A list of incoming user presence updates.
+        state_handler:
 
     Returns:
         A list of 2-tuples of destinations and states, where for
@@ -1363,7 +1553,9 @@ async def get_interested_remotes(
     # First we look up the rooms each user is in (as well as any explicit
     # subscriptions), then for each distinct room we look up the remote
     # hosts in those rooms.
-    room_ids_to_states, users_to_states = await get_interested_parties(store, states)
+    room_ids_to_states, users_to_states = await get_interested_parties(
+        store, presence_router, states
+    )
 
     for room_id, states in room_ids_to_states.items():
         hosts = await state_handler.get_current_hosts_in_room(room_id)
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 1cf12f3255..894ef859f4 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -20,7 +20,7 @@ from http import HTTPStatus
 from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
 
 from synapse import types
-from synapse.api.constants import AccountDataTypes, EventTypes, Membership
+from synapse.api.constants import AccountDataTypes, EventTypes, JoinRules, Membership
 from synapse.api.errors import (
     AuthError,
     Codes,
@@ -29,6 +29,7 @@ from synapse.api.errors import (
     SynapseError,
 )
 from synapse.api.ratelimiting import Ratelimiter
+from synapse.api.room_versions import RoomVersion
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID
@@ -178,6 +179,62 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
 
         await self._invites_per_user_limiter.ratelimit(requester, invitee_user_id)
 
+    async def _can_join_without_invite(
+        self, state_ids: StateMap[str], room_version: RoomVersion, user_id: str
+    ) -> bool:
+        """
+        Check whether a user can join a room without an invite.
+
+        When joining a room with restricted joined rules (as defined in MSC3083),
+        the membership of spaces must be checked during join.
+
+        Args:
+            state_ids: The state of the room as it currently is.
+            room_version: The room version of the room being joined.
+            user_id: The user joining the room.
+
+        Returns:
+            True if the user can join the room, false otherwise.
+        """
+        # This only applies to room versions which support the new join rule.
+        if not room_version.msc3083_join_rules:
+            return True
+
+        # If there's no join rule, then it defaults to public (so this doesn't apply).
+        join_rules_event_id = state_ids.get((EventTypes.JoinRules, ""), None)
+        if not join_rules_event_id:
+            return True
+
+        # If the join rule is not restricted, this doesn't apply.
+        join_rules_event = await self.store.get_event(join_rules_event_id)
+        if join_rules_event.content.get("join_rule") != JoinRules.MSC3083_RESTRICTED:
+            return True
+
+        # If allowed is of the wrong form, then only allow invited users.
+        allowed_spaces = join_rules_event.content.get("allow", [])
+        if not isinstance(allowed_spaces, list):
+            return False
+
+        # Get the list of joined rooms and see if there's an overlap.
+        joined_rooms = await self.store.get_rooms_for_user(user_id)
+
+        # Pull out the other room IDs, invalid data gets filtered.
+        for space in allowed_spaces:
+            if not isinstance(space, dict):
+                continue
+
+            space_id = space.get("space")
+            if not isinstance(space_id, str):
+                continue
+
+            # The user was joined to one of the spaces specified, they can join
+            # this room!
+            if space_id in joined_rooms:
+                return True
+
+        # The user was not in any of the required spaces.
+        return False
+
     async def _local_membership_update(
         self,
         requester: Requester,
@@ -235,9 +292,25 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
 
         if event.membership == Membership.JOIN:
             newly_joined = True
+            user_is_invited = False
             if prev_member_event_id:
                 prev_member_event = await self.store.get_event(prev_member_event_id)
                 newly_joined = prev_member_event.membership != Membership.JOIN
+                user_is_invited = prev_member_event.membership == Membership.INVITE
+
+            # If the member is not already in the room and is not accepting an invite,
+            # check if they should be allowed access via membership in a space.
+            if (
+                newly_joined
+                and not user_is_invited
+                and not await self._can_join_without_invite(
+                    prev_state_ids, event.room_version, user_id
+                )
+            ):
+                raise AuthError(
+                    403,
+                    "You do not belong to any of the required spaces to join this room.",
+                )
 
             # Only rate-limit if the user actually joined the room, otherwise we'll end
             # up blocking profile updates.
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 7b356ba7e5..ff11266c67 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -252,13 +252,13 @@ class SyncHandler:
         self.storage = hs.get_storage()
         self.state_store = self.storage.state
 
-        # ExpiringCache((User, Device)) -> LruCache(state_key => event_id)
+        # ExpiringCache((User, Device)) -> LruCache(user_id => event_id)
         self.lazy_loaded_members_cache = ExpiringCache(
             "lazy_loaded_members_cache",
             self.clock,
             max_len=0,
             expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE,
-        )
+        )  # type: ExpiringCache[Tuple[str, Optional[str]], LruCache[str, str]]
 
     async def wait_for_sync_for_user(
         self,
@@ -733,8 +733,10 @@ class SyncHandler:
 
     def get_lazy_loaded_members_cache(
         self, cache_key: Tuple[str, Optional[str]]
-    ) -> LruCache:
-        cache = self.lazy_loaded_members_cache.get(cache_key)
+    ) -> LruCache[str, str]:
+        cache = self.lazy_loaded_members_cache.get(
+            cache_key
+        )  # type: Optional[LruCache[str, str]]
         if cache is None:
             logger.debug("creating LruCache for %r", cache_key)
             cache = LruCache(LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE)
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 47754aff43..c0c873ce32 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -14,7 +14,7 @@
 import contextlib
 import logging
 import time
-from typing import Optional, Type, Union
+from typing import Optional, Tuple, Type, Union
 
 import attr
 from zope.interface import implementer
@@ -26,7 +26,11 @@ from twisted.web.server import Request, Site
 from synapse.config.server import ListenerConfig
 from synapse.http import get_request_user_agent, redact_uri
 from synapse.http.request_metrics import RequestMetrics, requests_counter
-from synapse.logging.context import LoggingContext, PreserveLoggingContext
+from synapse.logging.context import (
+    ContextRequest,
+    LoggingContext,
+    PreserveLoggingContext,
+)
 from synapse.types import Requester
 
 logger = logging.getLogger(__name__)
@@ -63,7 +67,7 @@ class SynapseRequest(Request):
 
         # The requester, if authenticated. For federation requests this is the
         # server name, for client requests this is the Requester object.
-        self.requester = None  # type: Optional[Union[Requester, str]]
+        self._requester = None  # type: Optional[Union[Requester, str]]
 
         # we can't yet create the logcontext, as we don't know the method.
         self.logcontext = None  # type: Optional[LoggingContext]
@@ -93,6 +97,31 @@ class SynapseRequest(Request):
             self.site.site_tag,
         )
 
+    @property
+    def requester(self) -> Optional[Union[Requester, str]]:
+        return self._requester
+
+    @requester.setter
+    def requester(self, value: Union[Requester, str]) -> None:
+        # Store the requester, and update some properties based on it.
+
+        # This should only be called once.
+        assert self._requester is None
+
+        self._requester = value
+
+        # A logging context should exist by now (and have a ContextRequest).
+        assert self.logcontext is not None
+        assert self.logcontext.request is not None
+
+        (
+            requester,
+            authenticated_entity,
+        ) = self.get_authenticated_entity()
+        self.logcontext.request.requester = requester
+        # If there's no authenticated entity, it was the requester.
+        self.logcontext.request.authenticated_entity = authenticated_entity or requester
+
     def get_request_id(self):
         return "%s-%i" % (self.get_method(), self.request_seq)
 
@@ -126,13 +155,60 @@ class SynapseRequest(Request):
             return self.method.decode("ascii")
         return method
 
+    def get_authenticated_entity(self) -> Tuple[Optional[str], Optional[str]]:
+        """
+        Get the "authenticated" entity of the request, which might be the user
+        performing the action, or a user being puppeted by a server admin.
+
+        Returns:
+            A tuple:
+                The first item is a string representing the user making the request.
+
+                The second item is a string or None representing the user who
+                authenticated when making this request. See
+                Requester.authenticated_entity.
+        """
+        # Convert the requester into a string that we can log
+        if isinstance(self._requester, str):
+            return self._requester, None
+        elif isinstance(self._requester, Requester):
+            requester = self._requester.user.to_string()
+            authenticated_entity = self._requester.authenticated_entity
+
+            # If this is a request where the target user doesn't match the user who
+            # authenticated (e.g. and admin is puppetting a user) then we return both.
+            if self._requester.user.to_string() != authenticated_entity:
+                return requester, authenticated_entity
+
+            return requester, None
+        elif self._requester is not None:
+            # This shouldn't happen, but we log it so we don't lose information
+            # and can see that we're doing something wrong.
+            return repr(self._requester), None  # type: ignore[unreachable]
+
+        return None, None
+
     def render(self, resrc):
         # this is called once a Resource has been found to serve the request; in our
         # case the Resource in question will normally be a JsonResource.
 
         # create a LogContext for this request
         request_id = self.get_request_id()
-        self.logcontext = LoggingContext(request_id, request=request_id)
+        self.logcontext = LoggingContext(
+            request_id,
+            request=ContextRequest(
+                request_id=request_id,
+                ip_address=self.getClientIP(),
+                site_tag=self.site.site_tag,
+                # The requester is going to be unknown at this point.
+                requester=None,
+                authenticated_entity=None,
+                method=self.get_method(),
+                url=self.get_redacted_uri(),
+                protocol=self.clientproto.decode("ascii", errors="replace"),
+                user_agent=get_request_user_agent(self),
+            ),
+        )
 
         # override the Server header which is set by twisted
         self.setHeader("Server", self.site.server_version_string)
@@ -277,25 +353,6 @@ class SynapseRequest(Request):
         # to the client (nb may be negative)
         response_send_time = self.finish_time - self._processing_finished_time
 
-        # Convert the requester into a string that we can log
-        authenticated_entity = None
-        if isinstance(self.requester, str):
-            authenticated_entity = self.requester
-        elif isinstance(self.requester, Requester):
-            authenticated_entity = self.requester.authenticated_entity
-
-            # If this is a request where the target user doesn't match the user who
-            # authenticated (e.g. and admin is puppetting a user) then we log both.
-            if self.requester.user.to_string() != authenticated_entity:
-                authenticated_entity = "{},{}".format(
-                    authenticated_entity,
-                    self.requester.user.to_string(),
-                )
-        elif self.requester is not None:
-            # This shouldn't happen, but we log it so we don't lose information
-            # and can see that we're doing something wrong.
-            authenticated_entity = repr(self.requester)  # type: ignore[unreachable]
-
         user_agent = get_request_user_agent(self, "-")
 
         code = str(self.code)
@@ -305,6 +362,13 @@ class SynapseRequest(Request):
             code += "!"
 
         log_level = logging.INFO if self._should_log_request() else logging.DEBUG
+
+        # If this is a request where the target user doesn't match the user who
+        # authenticated (e.g. and admin is puppetting a user) then we log both.
+        requester, authenticated_entity = self.get_authenticated_entity()
+        if authenticated_entity:
+            requester = "{}.{}".format(authenticated_entity, requester)
+
         self.site.access_logger.log(
             log_level,
             "%s - %s - {%s}"
@@ -312,7 +376,7 @@ class SynapseRequest(Request):
             ' %sB %s "%s %s %s" "%s" [%d dbevts]',
             self.getClientIP(),
             self.site.site_tag,
-            authenticated_entity,
+            requester,
             processing_time,
             response_send_time,
             usage.ru_utime,
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index 03cf3c2b8e..e78343f554 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -22,7 +22,6 @@ them.
 
 See doc/log_contexts.rst for details on how this works.
 """
-
 import inspect
 import logging
 import threading
@@ -30,6 +29,7 @@ import types
 import warnings
 from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union
 
+import attr
 from typing_extensions import Literal
 
 from twisted.internet import defer, threads
@@ -181,6 +181,29 @@ class ContextResourceUsage:
         return res
 
 
+@attr.s(slots=True)
+class ContextRequest:
+    """
+    A bundle of attributes from the SynapseRequest object.
+
+    This exists to:
+
+    * Avoid a cycle between LoggingContext and SynapseRequest.
+    * Be a single variable that can be passed from parent LoggingContexts to
+      their children.
+    """
+
+    request_id = attr.ib(type=str)
+    ip_address = attr.ib(type=str)
+    site_tag = attr.ib(type=str)
+    requester = attr.ib(type=Optional[str])
+    authenticated_entity = attr.ib(type=Optional[str])
+    method = attr.ib(type=str)
+    url = attr.ib(type=str)
+    protocol = attr.ib(type=str)
+    user_agent = attr.ib(type=str)
+
+
 LoggingContextOrSentinel = Union["LoggingContext", "_Sentinel"]
 
 
@@ -256,7 +279,7 @@ class LoggingContext:
         self,
         name: Optional[str] = None,
         parent_context: "Optional[LoggingContext]" = None,
-        request: Optional[str] = None,
+        request: Optional[ContextRequest] = None,
     ) -> None:
         self.previous_context = current_context()
         self.name = name
@@ -281,7 +304,11 @@ class LoggingContext:
         self.parent_context = parent_context
 
         if self.parent_context is not None:
-            self.parent_context.copy_to(self)
+            # we track the current request_id
+            self.request = self.parent_context.request
+
+            # we also track the current scope:
+            self.scope = self.parent_context.scope
 
         if request is not None:
             # the request param overrides the request from the parent context
@@ -289,7 +316,7 @@ class LoggingContext:
 
     def __str__(self) -> str:
         if self.request:
-            return str(self.request)
+            return self.request.request_id
         return "%s@%x" % (self.name, id(self))
 
     @classmethod
@@ -556,8 +583,23 @@ class LoggingContextFilter(logging.Filter):
         # we end up in a death spiral of infinite loops, so let's check, for
         # robustness' sake.
         if context is not None:
-            # Logging is interested in the request.
-            record.request = context.request  # type: ignore
+            # Logging is interested in the request ID. Note that for backwards
+            # compatibility this is stored as the "request" on the record.
+            record.request = str(context)  # type: ignore
+
+            # Add some data from the HTTP request.
+            request = context.request
+            if request is None:
+                return True
+
+            record.ip_address = request.ip_address  # type: ignore
+            record.site_tag = request.site_tag  # type: ignore
+            record.requester = request.requester  # type: ignore
+            record.authenticated_entity = request.authenticated_entity  # type: ignore
+            record.method = request.method  # type: ignore
+            record.url = request.url  # type: ignore
+            record.protocol = request.protocol  # type: ignore
+            record.user_agent = request.user_agent  # type: ignore
 
         return True
 
@@ -630,8 +672,8 @@ def set_current_context(context: LoggingContextOrSentinel) -> LoggingContextOrSe
 def nested_logging_context(suffix: str) -> LoggingContext:
     """Creates a new logging context as a child of another.
 
-    The nested logging context will have a 'request' made up of the parent context's
-    request, plus the given suffix.
+    The nested logging context will have a 'name' made up of the parent context's
+    name, plus the given suffix.
 
     CPU/db usage stats will be added to the parent context's on exit.
 
@@ -641,7 +683,7 @@ def nested_logging_context(suffix: str) -> LoggingContext:
             # ... do stuff
 
     Args:
-        suffix: suffix to add to the parent context's 'request'.
+        suffix: suffix to add to the parent context's 'name'.
 
     Returns:
         LoggingContext: new logging context.
@@ -653,11 +695,17 @@ def nested_logging_context(suffix: str) -> LoggingContext:
         )
         parent_context = None
         prefix = ""
+        request = None
     else:
         assert isinstance(curr_context, LoggingContext)
         parent_context = curr_context
-        prefix = str(parent_context.request)
-    return LoggingContext(parent_context=parent_context, request=prefix + "-" + suffix)
+        prefix = str(parent_context.name)
+        request = parent_context.request
+    return LoggingContext(
+        prefix + "-" + suffix,
+        parent_context=parent_context,
+        request=request,
+    )
 
 
 def preserve_fn(f):
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index 3b499efc07..13a5bc4558 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -214,7 +214,12 @@ class GaugeBucketCollector:
     Prometheus, and optimise for that case.
     """
 
-    __slots__ = ("_name", "_documentation", "_bucket_bounds", "_metric")
+    __slots__ = (
+        "_name",
+        "_documentation",
+        "_bucket_bounds",
+        "_metric",
+    )
 
     def __init__(
         self,
@@ -242,11 +247,16 @@ class GaugeBucketCollector:
         if self._bucket_bounds[-1] != float("inf"):
             self._bucket_bounds.append(float("inf"))
 
-        self._metric = self._values_to_metric([])
+        # We initially set this to None. We won't report metrics until
+        # this has been initialised after a successful data update
+        self._metric = None  # type: Optional[GaugeHistogramMetricFamily]
+
         registry.register(self)
 
     def collect(self):
-        yield self._metric
+        # Don't report metrics unless we've already collected some data
+        if self._metric is not None:
+            yield self._metric
 
     def update_data(self, values: Iterable[float]):
         """Update the data to be reported by the metric
diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index b56986d8e7..e8a9096c03 100644
--- a/synapse/metrics/background_process_metrics.py
+++ b/synapse/metrics/background_process_metrics.py
@@ -16,7 +16,7 @@
 import logging
 import threading
 from functools import wraps
-from typing import TYPE_CHECKING, Dict, Optional, Set
+from typing import TYPE_CHECKING, Dict, Optional, Set, Union
 
 from prometheus_client.core import REGISTRY, Counter, Gauge
 
@@ -199,11 +199,11 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar
         _background_process_start_count.labels(desc).inc()
         _background_process_in_flight_count.labels(desc).inc()
 
-        with BackgroundProcessLoggingContext(desc, "%s-%i" % (desc, count)) as context:
+        with BackgroundProcessLoggingContext(desc, count) as context:
             try:
                 ctx = noop_context_manager()
                 if bg_start_span:
-                    ctx = start_active_span(desc, tags={"request_id": context.request})
+                    ctx = start_active_span(desc, tags={"request_id": str(context)})
                 with ctx:
                     return await maybe_awaitable(func(*args, **kwargs))
             except Exception:
@@ -242,13 +242,19 @@ class BackgroundProcessLoggingContext(LoggingContext):
     processes.
     """
 
-    __slots__ = ["_proc"]
+    __slots__ = ["_id", "_proc"]
 
-    def __init__(self, name: str, request: Optional[str] = None):
-        super().__init__(name, request=request)
+    def __init__(self, name: str, id: Optional[Union[int, str]] = None):
+        super().__init__(name)
+        self._id = id
 
         self._proc = _BackgroundProcess(name, self)
 
+    def __str__(self) -> str:
+        if self._id is not None:
+            return "%s-%s" % (self.name, self._id)
+        return "%s@%x" % (self.name, id(self))
+
     def start(self, rusage: "Optional[resource._RUsage]"):
         """Log context has started running (again)."""
 
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 781e02fbbb..3ecd46c038 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -50,11 +50,20 @@ class ModuleApi:
         self._auth = hs.get_auth()
         self._auth_handler = auth_handler
         self._server_name = hs.hostname
+        self._presence_stream = hs.get_event_sources().sources["presence"]
 
         # We expose these as properties below in order to attach a helpful docstring.
         self._http_client = hs.get_simple_http_client()  # type: SimpleHttpClient
         self._public_room_list_manager = PublicRoomListManager(hs)
 
+        # The next time these users sync, they will receive the current presence
+        # state of all local users. Users are added by send_local_online_presence_to,
+        # and removed after a successful sync.
+        #
+        # We make this a private variable to deter modules from accessing it directly,
+        # though other classes in Synapse will still do so.
+        self._send_full_presence_to_local_users = set()
+
     @property
     def http_client(self):
         """Allows making outbound HTTP requests to remote resources.
@@ -385,6 +394,47 @@ class ModuleApi:
 
         return event
 
+    async def send_local_online_presence_to(self, users: Iterable[str]) -> None:
+        """
+        Forces the equivalent of a presence initial_sync for a set of local or remote
+        users. The users will receive presence for all currently online users that they
+        are considered interested in.
+
+        Updates to remote users will be sent immediately, whereas local users will receive
+        them on their next sync attempt.
+
+        Note that this method can only be run on the main or federation_sender worker
+        processes.
+        """
+        if not self._hs.should_send_federation():
+            raise Exception(
+                "send_local_online_presence_to can only be run "
+                "on processes that send federation",
+            )
+
+        for user in users:
+            if self._hs.is_mine_id(user):
+                # Modify SyncHandler._generate_sync_entry_for_presence to call
+                # presence_source.get_new_events with an empty `from_key` if
+                # that user's ID were in a list modified by ModuleApi somewhere.
+                # That user would then get all presence state on next incremental sync.
+
+                # Force a presence initial_sync for this user next time
+                self._send_full_presence_to_local_users.add(user)
+            else:
+                # Retrieve presence state for currently online users that this user
+                # is considered interested in
+                presence_events, _ = await self._presence_stream.get_new_events(
+                    UserID.from_string(user), from_key=None, include_offline=False
+                )
+
+                # Send to remote destinations
+                await make_deferred_yieldable(
+                    # We pull the federation sender here as we can only do so on workers
+                    # that support sending presence
+                    self._hs.get_federation_sender().send_presence(presence_events)
+                )
+
 
 class PublicRoomListManager:
     """Contains methods for adding to, removing from and querying whether a room
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index e829add257..d10d574246 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -184,8 +184,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
 
         # a logcontext which we use for processing incoming commands. We declare it as a
         # background process so that the CPU stats get reported to prometheus.
-        ctx_name = "replication-conn-%s" % self.conn_id
-        self._logging_context = BackgroundProcessLoggingContext(ctx_name, ctx_name)
+        self._logging_context = BackgroundProcessLoggingContext(
+            "replication-conn", self.conn_id
+        )
 
     def connectionMade(self):
         logger.info("[%s] Connection established", self.id())
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index c4ed9dfdb4..814145a04a 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -175,7 +175,7 @@ class PreviewUrlResource(DirectServeJsonResource):
             clock=self.clock,
             # don't spider URLs more often than once an hour
             expiry_ms=ONE_HOUR,
-        )
+        )  # type: ExpiringCache[str, ObservableDeferred]
 
         if self._worker_run_media_background_jobs:
             self._cleaner_loop = self.clock.looping_call(
diff --git a/synapse/server.py b/synapse/server.py
index e42f7b1a18..cfb55c230d 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -51,6 +51,7 @@ from synapse.crypto import context_factory
 from synapse.crypto.context_factory import RegularPolicyForHTTPS
 from synapse.crypto.keyring import Keyring
 from synapse.events.builder import EventBuilderFactory
+from synapse.events.presence_router import PresenceRouter
 from synapse.events.spamcheck import SpamChecker
 from synapse.events.third_party_rules import ThirdPartyEventRules
 from synapse.events.utils import EventClientSerializer
@@ -426,6 +427,10 @@ class HomeServer(metaclass=abc.ABCMeta):
             raise Exception("Workers cannot write typing")
 
     @cache_in_self
+    def get_presence_router(self) -> PresenceRouter:
+        return PresenceRouter(self)
+
+    @cache_in_self
     def get_typing_handler(self) -> FollowerTypingHandler:
         if self.config.worker.writers.typing == self.get_instance_name():
             # Use get_typing_writer_handler to ensure that we use the same
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index c3d6e80c49..c0f79ffdc8 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -22,6 +22,7 @@ from typing import (
     Callable,
     DefaultDict,
     Dict,
+    FrozenSet,
     Iterable,
     List,
     Optional,
@@ -515,7 +516,7 @@ class StateResolutionHandler:
             expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
             iterable=True,
             reset_expiry_on_get=True,
-        )
+        )  # type: ExpiringCache[FrozenSet[int], _StateCacheEntry]
 
         #
         # stuff for tracking time spent on state-res by room
@@ -536,7 +537,7 @@ class StateResolutionHandler:
         state_groups_ids: Dict[int, StateMap[str]],
         event_map: Optional[Dict[str, EventBase]],
         state_res_store: "StateResolutionStore",
-    ):
+    ) -> _StateCacheEntry:
         """Resolves conflicts between a set of state groups
 
         Always generates a new state group (unless we hit the cache), so should
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 4f3d192562..b7820ac7ff 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -22,6 +22,9 @@ from synapse.storage.database import DatabasePool
 BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD = (
     "media_repository_drop_index_wo_method"
 )
+BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2 = (
+    "media_repository_drop_index_wo_method_2"
+)
 
 
 class MediaSortOrder(Enum):
@@ -85,23 +88,35 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
             unique=True,
         )
 
+        # the original impl of _drop_media_index_without_method was broken (see
+        # https://github.com/matrix-org/synapse/issues/8649), so we replace the original
+        # impl with a no-op and run the fixed migration as
+        # media_repository_drop_index_wo_method_2.
+        self.db_pool.updates.register_noop_background_update(
+            BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD
+        )
         self.db_pool.updates.register_background_update_handler(
-            BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD,
+            BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2,
             self._drop_media_index_without_method,
         )
 
     async def _drop_media_index_without_method(self, progress, batch_size):
+        """background update handler which removes the old constraints.
+
+        Note that this is only run on postgres.
+        """
+
         def f(txn):
             txn.execute(
                 "ALTER TABLE local_media_repository_thumbnails DROP CONSTRAINT IF EXISTS local_media_repository_thumbn_media_id_thumbnail_width_thum_key"
             )
             txn.execute(
-                "ALTER TABLE remote_media_cache_thumbnails DROP CONSTRAINT IF EXISTS remote_media_repository_thumbn_media_id_thumbnail_width_thum_key"
+                "ALTER TABLE remote_media_cache_thumbnails DROP CONSTRAINT IF EXISTS remote_media_cache_thumbnails_media_origin_media_id_thumbna_key"
             )
 
         await self.db_pool.runInteraction("drop_media_indices_without_method", f)
         await self.db_pool.updates._end_background_update(
-            BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD
+            BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD_2
         )
         return 1
 
diff --git a/synapse/storage/databases/main/schema/delta/59/11drop_thumbnail_constraint.sql.postgres b/synapse/storage/databases/main/schema/delta/59/11drop_thumbnail_constraint.sql.postgres
new file mode 100644
index 0000000000..54c1bca3b1
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/11drop_thumbnail_constraint.sql.postgres
@@ -0,0 +1,22 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- drop old constraints on remote_media_cache_thumbnails
+--
+-- This was originally part of 57.07, but it was done wrong, per
+-- https://github.com/matrix-org/synapse/issues/8649, so we do it again.
+INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES
+  (5911, 'media_repository_drop_index_wo_method_2', '{}', 'remote_media_repository_thumbnails_method_idx');
+
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index e15f7ee698..4dc3477e89 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -15,40 +15,50 @@
 
 import logging
 from collections import OrderedDict
+from typing import Any, Generic, Optional, TypeVar, Union, overload
+
+import attr
+from typing_extensions import Literal
 
 from synapse.config import cache as cache_config
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.util import Clock
 from synapse.util.caches import register_cache
 
 logger = logging.getLogger(__name__)
 
 
-SENTINEL = object()
+SENTINEL = object()  # type: Any
+
 
+T = TypeVar("T")
+KT = TypeVar("KT")
+VT = TypeVar("VT")
 
-class ExpiringCache:
+
+class ExpiringCache(Generic[KT, VT]):
     def __init__(
         self,
-        cache_name,
-        clock,
-        max_len=0,
-        expiry_ms=0,
-        reset_expiry_on_get=False,
-        iterable=False,
+        cache_name: str,
+        clock: Clock,
+        max_len: int = 0,
+        expiry_ms: int = 0,
+        reset_expiry_on_get: bool = False,
+        iterable: bool = False,
     ):
         """
         Args:
-            cache_name (str): Name of this cache, used for logging.
-            clock (Clock)
-            max_len (int): Max size of dict. If the dict grows larger than this
+            cache_name: Name of this cache, used for logging.
+            clock
+            max_len: Max size of dict. If the dict grows larger than this
                 then the oldest items get automatically evicted. Default is 0,
                 which indicates there is no max limit.
-            expiry_ms (int): How long before an item is evicted from the cache
+            expiry_ms: How long before an item is evicted from the cache
                 in milliseconds. Default is 0, indicating items never get
                 evicted based on time.
-            reset_expiry_on_get (bool): If true, will reset the expiry time for
+            reset_expiry_on_get: If true, will reset the expiry time for
                 an item on access. Defaults to False.
-            iterable (bool): If true, the size is calculated by summing the
+            iterable: If true, the size is calculated by summing the
                 sizes of all entries, rather than the number of entries.
         """
         self._cache_name = cache_name
@@ -62,7 +72,7 @@ class ExpiringCache:
         self._expiry_ms = expiry_ms
         self._reset_expiry_on_get = reset_expiry_on_get
 
-        self._cache = OrderedDict()
+        self._cache = OrderedDict()  # type: OrderedDict[KT, _CacheEntry]
 
         self.iterable = iterable
 
@@ -79,12 +89,12 @@ class ExpiringCache:
 
         self._clock.looping_call(f, self._expiry_ms / 2)
 
-    def __setitem__(self, key, value):
+    def __setitem__(self, key: KT, value: VT) -> None:
         now = self._clock.time_msec()
         self._cache[key] = _CacheEntry(now, value)
         self.evict()
 
-    def evict(self):
+    def evict(self) -> None:
         # Evict if there are now too many items
         while self._max_size and len(self) > self._max_size:
             _key, value = self._cache.popitem(last=False)
@@ -93,7 +103,7 @@ class ExpiringCache:
             else:
                 self.metrics.inc_evictions()
 
-    def __getitem__(self, key):
+    def __getitem__(self, key: KT) -> VT:
         try:
             entry = self._cache[key]
             self.metrics.inc_hits()
@@ -106,7 +116,7 @@ class ExpiringCache:
 
         return entry.value
 
-    def pop(self, key, default=SENTINEL):
+    def pop(self, key: KT, default: T = SENTINEL) -> Union[VT, T]:
         """Removes and returns the value with the given key from the cache.
 
         If the key isn't in the cache then `default` will be returned if
@@ -115,29 +125,40 @@ class ExpiringCache:
         Identical functionality to `dict.pop(..)`.
         """
 
-        value = self._cache.pop(key, default)
+        value = self._cache.pop(key, SENTINEL)
+        # The key was not found.
         if value is SENTINEL:
-            raise KeyError(key)
+            if default is SENTINEL:
+                raise KeyError(key)
+            return default
 
-        return value
+        return value.value
 
-    def __contains__(self, key):
+    def __contains__(self, key: KT) -> bool:
         return key in self._cache
 
-    def get(self, key, default=None):
+    @overload
+    def get(self, key: KT, default: Literal[None] = None) -> Optional[VT]:
+        ...
+
+    @overload
+    def get(self, key: KT, default: T) -> Union[VT, T]:
+        ...
+
+    def get(self, key: KT, default: Optional[T] = None) -> Union[VT, Optional[T]]:
         try:
             return self[key]
         except KeyError:
             return default
 
-    def setdefault(self, key, value):
+    def setdefault(self, key: KT, value: VT) -> VT:
         try:
             return self[key]
         except KeyError:
             self[key] = value
             return value
 
-    def _prune_cache(self):
+    def _prune_cache(self) -> None:
         if not self._expiry_ms:
             # zero expiry time means don't expire. This should never get called
             # since we have this check in start too.
@@ -166,7 +187,7 @@ class ExpiringCache:
             len(self),
         )
 
-    def __len__(self):
+    def __len__(self) -> int:
         if self.iterable:
             return sum(len(entry.value) for entry in self._cache.values())
         else:
@@ -190,9 +211,7 @@ class ExpiringCache:
         return False
 
 
+@attr.s(slots=True)
 class _CacheEntry:
-    __slots__ = ["time", "value"]
-
-    def __init__(self, time, value):
-        self.time = time
-        self.value = value
+    time = attr.ib(type=int)
+    value = attr.ib()