summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rwxr-xr-xsynapse/_scripts/synapse_port_db.py16
-rw-r--r--synapse/api/ratelimiting.py94
-rw-r--r--synapse/api/room_versions.py33
-rw-r--r--synapse/app/homeserver.py3
-rw-r--r--synapse/appservice/api.py14
-rw-r--r--synapse/config/emailconfig.py45
-rw-r--r--synapse/config/ratelimiting.py7
-rw-r--r--synapse/config/registration.py11
-rw-r--r--synapse/config/repository.py35
-rw-r--r--synapse/event_auth.py26
-rw-r--r--synapse/events/builder.py10
-rw-r--r--synapse/federation/federation_client.py2
-rw-r--r--synapse/federation/federation_server.py16
-rw-r--r--synapse/federation/sender/__init__.py10
-rw-r--r--synapse/federation/transport/client.py2
-rw-r--r--synapse/handlers/appservice.py11
-rw-r--r--synapse/handlers/e2e_keys.py16
-rw-r--r--synapse/handlers/federation_event.py284
-rw-r--r--synapse/handlers/identity.py90
-rw-r--r--synapse/handlers/message.py26
-rw-r--r--synapse/handlers/room.py45
-rw-r--r--synapse/handlers/room_member.py115
-rw-r--r--synapse/handlers/ui_auth/checkers.py21
-rw-r--r--synapse/logging/opentracing.py50
-rw-r--r--synapse/notifier.py18
-rw-r--r--synapse/push/pusherpool.py27
-rw-r--r--synapse/replication/http/_base.py4
-rw-r--r--synapse/replication/tcp/client.py17
-rw-r--r--synapse/replication/tcp/streams/events.py1
-rw-r--r--synapse/rest/admin/users.py1
-rw-r--r--synapse/rest/client/account.py113
-rw-r--r--synapse/rest/client/keys.py4
-rw-r--r--synapse/rest/client/login.py10
-rw-r--r--synapse/rest/client/read_marker.py56
-rw-r--r--synapse/rest/client/receipts.py20
-rw-r--r--synapse/rest/client/register.py59
-rw-r--r--synapse/rest/client/room_keys.py13
-rw-r--r--synapse/rest/client/sendtodevice.py4
-rw-r--r--synapse/rest/client/sync.py12
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py62
-rw-r--r--synapse/rest/media/v1/thumbnail_resource.py40
-rw-r--r--synapse/rest/synapse/client/password_reset.py8
-rw-r--r--synapse/state/__init__.py194
-rw-r--r--synapse/storage/_base.py9
-rw-r--r--synapse/storage/controllers/__init__.py4
-rw-r--r--synapse/storage/controllers/persist_events.py15
-rw-r--r--synapse/storage/controllers/state.py2
-rw-r--r--synapse/storage/database.py50
-rw-r--r--synapse/storage/databases/main/appservice.py58
-rw-r--r--synapse/storage/databases/main/cache.py7
-rw-r--r--synapse/storage/databases/main/censor_events.py2
-rw-r--r--synapse/storage/databases/main/devices.py2
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py47
-rw-r--r--synapse/storage/databases/main/events.py10
-rw-r--r--synapse/storage/databases/main/events_bg_updates.py87
-rw-r--r--synapse/storage/databases/main/events_worker.py107
-rw-r--r--synapse/storage/databases/main/monthly_active_users.py1
-rw-r--r--synapse/storage/databases/main/purge_events.py35
-rw-r--r--synapse/storage/databases/main/push_rule.py1
-rw-r--r--synapse/storage/databases/main/room.py6
-rw-r--r--synapse/storage/databases/main/roommember.py43
-rw-r--r--synapse/storage/databases/main/stream.py32
-rw-r--r--synapse/storage/databases/state/bg_updates.py3
-rw-r--r--synapse/storage/databases/state/store.py156
-rw-r--r--synapse/storage/schema/__init__.py7
-rw-r--r--synapse/storage/schema/main/delta/72/03bg_populate_events_columns.py47
-rw-r--r--synapse/storage/schema/main/delta/72/03drop_event_reference_hashes.sql17
-rw-r--r--synapse/storage/schema/main/delta/72/03remove_groups.sql31
-rw-r--r--synapse/storage/util/partial_state_events_tracker.py1
-rw-r--r--synapse/util/caches/lrucache.py38
70 files changed, 1514 insertions, 949 deletions
diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py
index 26834a437e..543bba27c2 100755
--- a/synapse/_scripts/synapse_port_db.py
+++ b/synapse/_scripts/synapse_port_db.py
@@ -166,22 +166,6 @@ IGNORED_TABLES = {
     "ui_auth_sessions",
     "ui_auth_sessions_credentials",
     "ui_auth_sessions_ips",
-    # Groups/communities is no longer supported.
-    "group_attestations_remote",
-    "group_attestations_renewals",
-    "group_invites",
-    "group_roles",
-    "group_room_categories",
-    "group_rooms",
-    "group_summary_roles",
-    "group_summary_room_categories",
-    "group_summary_rooms",
-    "group_summary_users",
-    "group_users",
-    "groups",
-    "local_group_membership",
-    "local_group_updates",
-    "remote_profile_cache",
 }
 
 
diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py
index 54d13026c9..f43965c1c8 100644
--- a/synapse/api/ratelimiting.py
+++ b/synapse/api/ratelimiting.py
@@ -27,6 +27,33 @@ class Ratelimiter:
     """
     Ratelimit actions marked by arbitrary keys.
 
+    (Note that the source code speaks of "actions" and "burst_count" rather than
+    "tokens" and a "bucket_size".)
+
+    This is a "leaky bucket as a meter". For each key to be tracked there is a bucket
+    containing some number 0 <= T <= `burst_count` of tokens corresponding to previously
+    permitted requests for that key. Each bucket starts empty, and gradually leaks
+    tokens at a rate of `rate_hz`.
+
+    Upon an incoming request, we must determine:
+    - the key that this request falls under (which bucket to inspect), and
+    - the cost C of this request in tokens.
+    Then, if there is room in the bucket for C tokens (T + C <= `burst_count`),
+    the request is permitted and `cost` tokens are added to the bucket.
+    Otherwise the request is denied, and the bucket continues to hold T tokens.
+
+    This means that the limiter enforces an average request frequency of `rate_hz`,
+    while accumulating a buffer of up to `burst_count` requests which can be consumed
+    instantaneously.
+
+    The tricky bit is the leaking. We do not want to have a periodic process which
+    leaks every bucket! Instead, we track
+    - the time point when the bucket was last completely empty, and
+    - how many tokens have added to the bucket permitted since then.
+    Then for each incoming request, we can calculate how many tokens have leaked
+    since this time point, and use that to decide if we should accept or reject the
+    request.
+
     Args:
         clock: A homeserver clock, for retrieving the current time
         rate_hz: The long term number of actions that can be performed in a second.
@@ -41,14 +68,30 @@ class Ratelimiter:
         self.burst_count = burst_count
         self.store = store
 
-        # A ordered dictionary keeping track of actions, when they were last
-        # performed and how often. Each entry is a mapping from a key of arbitrary type
-        # to a tuple representing:
-        #   * How many times an action has occurred since a point in time
-        #   * The point in time
-        #   * The rate_hz of this particular entry. This can vary per request
+        # An ordered dictionary representing the token buckets tracked by this rate
+        # limiter. Each entry maps a key of arbitrary type to a tuple representing:
+        #   * The number of tokens currently in the bucket,
+        #   * The time point when the bucket was last completely empty, and
+        #   * The rate_hz (leak rate) of this particular bucket.
         self.actions: OrderedDict[Hashable, Tuple[float, float, float]] = OrderedDict()
 
+    def _get_key(
+        self, requester: Optional[Requester], key: Optional[Hashable]
+    ) -> Hashable:
+        """Use the requester's MXID as a fallback key if no key is provided."""
+        if key is None:
+            if not requester:
+                raise ValueError("Must supply at least one of `requester` or `key`")
+
+            key = requester.user.to_string()
+        return key
+
+    def _get_action_counts(
+        self, key: Hashable, time_now_s: float
+    ) -> Tuple[float, float, float]:
+        """Retrieve the action counts, with a fallback representing an empty bucket."""
+        return self.actions.get(key, (0.0, time_now_s, 0.0))
+
     async def can_do_action(
         self,
         requester: Optional[Requester],
@@ -88,11 +131,7 @@ class Ratelimiter:
                 * The reactor timestamp for when the action can be performed next.
                   -1 if rate_hz is less than or equal to zero
         """
-        if key is None:
-            if not requester:
-                raise ValueError("Must supply at least one of `requester` or `key`")
-
-            key = requester.user.to_string()
+        key = self._get_key(requester, key)
 
         if requester:
             # Disable rate limiting of users belonging to any AS that is configured
@@ -121,7 +160,7 @@ class Ratelimiter:
         self._prune_message_counts(time_now_s)
 
         # Check if there is an existing count entry for this key
-        action_count, time_start, _ = self.actions.get(key, (0.0, time_now_s, 0.0))
+        action_count, time_start, _ = self._get_action_counts(key, time_now_s)
 
         # Check whether performing another action is allowed
         time_delta = time_now_s - time_start
@@ -164,6 +203,37 @@ class Ratelimiter:
 
         return allowed, time_allowed
 
+    def record_action(
+        self,
+        requester: Optional[Requester],
+        key: Optional[Hashable] = None,
+        n_actions: int = 1,
+        _time_now_s: Optional[float] = None,
+    ) -> None:
+        """Record that an action(s) took place, even if they violate the rate limit.
+
+        This is useful for tracking the frequency of events that happen across
+        federation which we still want to impose local rate limits on. For instance, if
+        we are alice.com monitoring a particular room, we cannot prevent bob.com
+        from joining users to that room. However, we can track the number of recent
+        joins in the room and refuse to serve new joins ourselves if there have been too
+        many in the room across both homeservers.
+
+        Args:
+            requester: The requester that is doing the action, if any.
+            key: An arbitrary key used to classify an action. Defaults to the
+                requester's user ID.
+            n_actions: The number of times the user wants to do this action. If the user
+                cannot do all of the actions, the user's action count is not incremented
+                at all.
+            _time_now_s: The current time. Optional, defaults to the current time according
+                to self.clock. Only used by tests.
+        """
+        key = self._get_key(requester, key)
+        time_now_s = _time_now_s if _time_now_s is not None else self.clock.time()
+        action_count, time_start, rate_hz = self._get_action_counts(key, time_now_s)
+        self.actions[key] = (action_count + n_actions, time_start, rate_hz)
+
     def _prune_message_counts(self, time_now_s: float) -> None:
         """Remove message count entries that have not exceeded their defined
         rate_hz limit
diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py
index 3f85d61b46..00e81b3afc 100644
--- a/synapse/api/room_versions.py
+++ b/synapse/api/room_versions.py
@@ -84,6 +84,8 @@ class RoomVersion:
     # MSC3787: Adds support for a `knock_restricted` join rule, mixing concepts of
     # knocks and restricted join rules into the same join condition.
     msc3787_knock_restricted_join_rule: bool
+    # MSC3667: Enforce integer power levels
+    msc3667_int_only_power_levels: bool
 
 
 class RoomVersions:
@@ -103,6 +105,7 @@ class RoomVersions:
         msc2716_historical=False,
         msc2716_redactions=False,
         msc3787_knock_restricted_join_rule=False,
+        msc3667_int_only_power_levels=False,
     )
     V2 = RoomVersion(
         "2",
@@ -120,6 +123,7 @@ class RoomVersions:
         msc2716_historical=False,
         msc2716_redactions=False,
         msc3787_knock_restricted_join_rule=False,
+        msc3667_int_only_power_levels=False,
     )
     V3 = RoomVersion(
         "3",
@@ -137,6 +141,7 @@ class RoomVersions:
         msc2716_historical=False,
         msc2716_redactions=False,
         msc3787_knock_restricted_join_rule=False,
+        msc3667_int_only_power_levels=False,
     )
     V4 = RoomVersion(
         "4",
@@ -154,6 +159,7 @@ class RoomVersions:
         msc2716_historical=False,
         msc2716_redactions=False,
         msc3787_knock_restricted_join_rule=False,
+        msc3667_int_only_power_levels=False,
     )
     V5 = RoomVersion(
         "5",
@@ -171,6 +177,7 @@ class RoomVersions:
         msc2716_historical=False,
         msc2716_redactions=False,
         msc3787_knock_restricted_join_rule=False,
+        msc3667_int_only_power_levels=False,
     )
     V6 = RoomVersion(
         "6",
@@ -188,6 +195,7 @@ class RoomVersions:
         msc2716_historical=False,
         msc2716_redactions=False,
         msc3787_knock_restricted_join_rule=False,
+        msc3667_int_only_power_levels=False,
     )
     MSC2176 = RoomVersion(
         "org.matrix.msc2176",
@@ -205,6 +213,7 @@ class RoomVersions:
         msc2716_historical=False,
         msc2716_redactions=False,
         msc3787_knock_restricted_join_rule=False,
+        msc3667_int_only_power_levels=False,
     )
     V7 = RoomVersion(
         "7",
@@ -222,6 +231,7 @@ class RoomVersions:
         msc2716_historical=False,
         msc2716_redactions=False,
         msc3787_knock_restricted_join_rule=False,
+        msc3667_int_only_power_levels=False,
     )
     V8 = RoomVersion(
         "8",
@@ -239,6 +249,7 @@ class RoomVersions:
         msc2716_historical=False,
         msc2716_redactions=False,
         msc3787_knock_restricted_join_rule=False,
+        msc3667_int_only_power_levels=False,
     )
     V9 = RoomVersion(
         "9",
@@ -256,6 +267,7 @@ class RoomVersions:
         msc2716_historical=False,
         msc2716_redactions=False,
         msc3787_knock_restricted_join_rule=False,
+        msc3667_int_only_power_levels=False,
     )
     MSC2716v3 = RoomVersion(
         "org.matrix.msc2716v3",
@@ -273,6 +285,7 @@ class RoomVersions:
         msc2716_historical=True,
         msc2716_redactions=True,
         msc3787_knock_restricted_join_rule=False,
+        msc3667_int_only_power_levels=False,
     )
     MSC3787 = RoomVersion(
         "org.matrix.msc3787",
@@ -290,6 +303,25 @@ class RoomVersions:
         msc2716_historical=False,
         msc2716_redactions=False,
         msc3787_knock_restricted_join_rule=True,
+        msc3667_int_only_power_levels=False,
+    )
+    V10 = RoomVersion(
+        "10",
+        RoomDisposition.STABLE,
+        EventFormatVersions.V3,
+        StateResolutionVersions.V2,
+        enforce_key_validity=True,
+        special_case_aliases_auth=False,
+        strict_canonicaljson=True,
+        limit_notifications_power_levels=True,
+        msc2176_redaction_rules=False,
+        msc3083_join_rules=True,
+        msc3375_redaction_rules=True,
+        msc2403_knocking=True,
+        msc2716_historical=False,
+        msc2716_redactions=False,
+        msc3787_knock_restricted_join_rule=True,
+        msc3667_int_only_power_levels=True,
     )
 
 
@@ -308,6 +340,7 @@ KNOWN_ROOM_VERSIONS: Dict[str, RoomVersion] = {
         RoomVersions.V9,
         RoomVersions.MSC2716v3,
         RoomVersions.MSC3787,
+        RoomVersions.V10,
     )
 }
 
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 745e704141..6bafa7d3f3 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -44,7 +44,6 @@ from synapse.app._base import (
     register_start,
 )
 from synapse.config._base import ConfigError, format_config_error
-from synapse.config.emailconfig import ThreepidBehaviour
 from synapse.config.homeserver import HomeServerConfig
 from synapse.config.server import ListenerConfig
 from synapse.federation.transport.server import TransportLayerServer
@@ -202,7 +201,7 @@ class SynapseHomeServer(HomeServer):
                 }
             )
 
-            if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+            if self.config.email.can_verify_email:
                 from synapse.rest.synapse.client.password_reset import (
                     PasswordResetSubmitTokenResource,
                 )
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index df1c214462..0963fb3bb4 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -53,6 +53,18 @@ sent_events_counter = Counter(
     "synapse_appservice_api_sent_events", "Number of events sent to the AS", ["service"]
 )
 
+sent_ephemeral_counter = Counter(
+    "synapse_appservice_api_sent_ephemeral",
+    "Number of ephemeral events sent to the AS",
+    ["service"],
+)
+
+sent_todevice_counter = Counter(
+    "synapse_appservice_api_sent_todevice",
+    "Number of todevice messages sent to the AS",
+    ["service"],
+)
+
 HOUR_IN_MS = 60 * 60 * 1000
 
 
@@ -310,6 +322,8 @@ class ApplicationServiceApi(SimpleHttpClient):
                 )
             sent_transactions_counter.labels(service.id).inc()
             sent_events_counter.labels(service.id).inc(len(serialized_events))
+            sent_ephemeral_counter.labels(service.id).inc(len(ephemeral))
+            sent_todevice_counter.labels(service.id).inc(len(to_device_messages))
             return True
         except CodeMessageException as e:
             logger.warning(
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index 6e11fbdb9a..3ead80d985 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -18,7 +18,6 @@
 import email.utils
 import logging
 import os
-from enum import Enum
 from typing import Any
 
 import attr
@@ -131,41 +130,22 @@ class EmailConfig(Config):
 
         self.email_enable_notifs = email_config.get("enable_notifs", False)
 
-        self.threepid_behaviour_email = (
-            # Have Synapse handle the email sending if account_threepid_delegates.email
-            # is not defined
-            # msisdn is currently always remote while Synapse does not support any method of
-            # sending SMS messages
-            ThreepidBehaviour.REMOTE
-            if self.root.registration.account_threepid_delegate_email
-            else ThreepidBehaviour.LOCAL
-        )
-
         if config.get("trust_identity_server_for_password_resets"):
             raise ConfigError(
                 'The config option "trust_identity_server_for_password_resets" '
-                'has been replaced by "account_threepid_delegate". '
-                "Please consult the configuration manual at docs/usage/configuration/config_documentation.md for "
-                "details and update your config file."
+                "is no longer supported. Please remove it from the config file."
             )
 
-        self.local_threepid_handling_disabled_due_to_email_config = False
-        if (
-            self.threepid_behaviour_email == ThreepidBehaviour.LOCAL
-            and email_config == {}
-        ):
-            # We cannot warn the user this has happened here
-            # Instead do so when a user attempts to reset their password
-            self.local_threepid_handling_disabled_due_to_email_config = True
-
-            self.threepid_behaviour_email = ThreepidBehaviour.OFF
+        # If we have email config settings, assume that we can verify ownership of
+        # email addresses.
+        self.can_verify_email = email_config != {}
 
         # Get lifetime of a validation token in milliseconds
         self.email_validation_token_lifetime = self.parse_duration(
             email_config.get("validation_token_lifetime", "1h")
         )
 
-        if self.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+        if self.can_verify_email:
             missing = []
             if not self.email_notif_from:
                 missing.append("email.notif_from")
@@ -356,18 +336,3 @@ class EmailConfig(Config):
                     "Config option email.invite_client_location must be a http or https URL",
                     path=("email", "invite_client_location"),
                 )
-
-
-class ThreepidBehaviour(Enum):
-    """
-    Enum to define the behaviour of Synapse with regards to when it contacts an identity
-    server for 3pid registration and password resets
-
-    REMOTE = use an external server to send tokens
-    LOCAL = send tokens ourselves
-    OFF = disable registration via 3pid and password resets
-    """
-
-    REMOTE = "remote"
-    LOCAL = "local"
-    OFF = "off"
diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index 4fc1784efe..5a91917b4a 100644
--- a/synapse/config/ratelimiting.py
+++ b/synapse/config/ratelimiting.py
@@ -112,6 +112,13 @@ class RatelimitConfig(Config):
             defaults={"per_second": 0.01, "burst_count": 10},
         )
 
+        # Track the rate of joins to a given room. If there are too many, temporarily
+        # prevent local joins and remote joins via this server.
+        self.rc_joins_per_room = RateLimitConfig(
+            config.get("rc_joins_per_room", {}),
+            defaults={"per_second": 1, "burst_count": 10},
+        )
+
         # Ratelimit cross-user key requests:
         # * For local requests this is keyed by the sending device.
         # * For requests received over federation this is keyed by the origin.
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index fcf99be092..685a0423c5 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -20,6 +20,13 @@ from synapse.config._base import Config, ConfigError
 from synapse.types import JsonDict, RoomAlias, UserID
 from synapse.util.stringutils import random_string_with_symbols, strtobool
 
+NO_EMAIL_DELEGATE_ERROR = """\
+Delegation of email verification to an identity server is no longer supported. To
+continue to allow users to add email addresses to their accounts, and use them for
+password resets, configure Synapse with an SMTP server via the `email` setting, and
+remove `account_threepid_delegates.email`.
+"""
+
 
 class RegistrationConfig(Config):
     section = "registration"
@@ -51,7 +58,9 @@ class RegistrationConfig(Config):
         self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
 
         account_threepid_delegates = config.get("account_threepid_delegates") or {}
-        self.account_threepid_delegate_email = account_threepid_delegates.get("email")
+        if "email" in account_threepid_delegates:
+            raise ConfigError(NO_EMAIL_DELEGATE_ERROR)
+        # self.account_threepid_delegate_email = account_threepid_delegates.get("email")
         self.account_threepid_delegate_msisdn = account_threepid_delegates.get("msisdn")
         self.default_identity_server = config.get("default_identity_server")
         self.allow_guest_access = config.get("allow_guest_access", False)
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index 3c69dd325f..1033496bb4 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -42,6 +42,18 @@ THUMBNAIL_SIZE_YAML = """\
         #    method: %(method)s
 """
 
+# A map from the given media type to the type of thumbnail we should generate
+# for it.
+THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP = {
+    "image/jpeg": "jpeg",
+    "image/jpg": "jpeg",
+    "image/webp": "jpeg",
+    # Thumbnails can only be jpeg or png. We choose png thumbnails for gif
+    # because it can have transparency.
+    "image/gif": "png",
+    "image/png": "png",
+}
+
 HTTP_PROXY_SET_WARNING = """\
 The Synapse config url_preview_ip_range_blacklist will be ignored as an HTTP(s) proxy is configured."""
 
@@ -79,13 +91,22 @@ def parse_thumbnail_requirements(
         width = size["width"]
         height = size["height"]
         method = size["method"]
-        jpeg_thumbnail = ThumbnailRequirement(width, height, method, "image/jpeg")
-        png_thumbnail = ThumbnailRequirement(width, height, method, "image/png")
-        requirements.setdefault("image/jpeg", []).append(jpeg_thumbnail)
-        requirements.setdefault("image/jpg", []).append(jpeg_thumbnail)
-        requirements.setdefault("image/webp", []).append(jpeg_thumbnail)
-        requirements.setdefault("image/gif", []).append(png_thumbnail)
-        requirements.setdefault("image/png", []).append(png_thumbnail)
+
+        for format, thumbnail_format in THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP.items():
+            requirement = requirements.setdefault(format, [])
+            if thumbnail_format == "jpeg":
+                requirement.append(
+                    ThumbnailRequirement(width, height, method, "image/jpeg")
+                )
+            elif thumbnail_format == "png":
+                requirement.append(
+                    ThumbnailRequirement(width, height, method, "image/png")
+                )
+            else:
+                raise Exception(
+                    "Unknown thumbnail mapping from %s to %s. This is a Synapse problem, please report!"
+                    % (format, thumbnail_format)
+                )
     return {
         media_type: tuple(thumbnails) for media_type, thumbnails in requirements.items()
     }
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index 0fc2c4b27e..965cb265da 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -740,6 +740,32 @@ def _check_power_levels(
         except Exception:
             raise SynapseError(400, "Not a valid power level: %s" % (v,))
 
+    # Reject events with stringy power levels if required by room version
+    if (
+        event.type == EventTypes.PowerLevels
+        and room_version_obj.msc3667_int_only_power_levels
+    ):
+        for k, v in event.content.items():
+            if k in {
+                "users_default",
+                "events_default",
+                "state_default",
+                "ban",
+                "redact",
+                "kick",
+                "invite",
+            }:
+                if not isinstance(v, int):
+                    raise SynapseError(400, f"{v!r} must be an integer.")
+            if k in {"events", "notifications", "users"}:
+                if not isinstance(v, dict) or not all(
+                    isinstance(v, int) for v in v.values()
+                ):
+                    raise SynapseError(
+                        400,
+                        f"{v!r} must be a dict wherein all the values are integers.",
+                    )
+
     key = (event.type, event.state_key)
     current_state = auth_events.get(key)
 
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index 98c203ada0..17f624b68f 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -24,9 +24,11 @@ from synapse.api.room_versions import (
     RoomVersion,
 )
 from synapse.crypto.event_signing import add_hashes_and_signatures
+from synapse.event_auth import auth_types_for_event
 from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict
 from synapse.state import StateHandler
 from synapse.storage.databases.main import DataStore
+from synapse.storage.state import StateFilter
 from synapse.types import EventID, JsonDict
 from synapse.util import Clock
 from synapse.util.stringutils import random_string
@@ -120,8 +122,12 @@ class EventBuilder:
             The signed and hashed event.
         """
         if auth_event_ids is None:
-            state_ids = await self._state.get_current_state_ids(
-                self.room_id, prev_event_ids
+            state_ids = await self._state.compute_state_after_events(
+                self.room_id,
+                prev_event_ids,
+                state_filter=StateFilter.from_types(
+                    auth_types_for_event(self.room_version, self)
+                ),
             )
             auth_event_ids = self._event_auth_handler.compute_auth_events(
                 self, state_ids
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 66e6305562..7c450ecad0 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -217,7 +217,7 @@ class FederationClient(FederationBase):
         )
 
     async def claim_client_keys(
-        self, destination: str, content: JsonDict, timeout: int
+        self, destination: str, content: JsonDict, timeout: Optional[int]
     ) -> JsonDict:
         """Claims one-time keys for a device hosted on a remote server.
 
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 5dfdc86740..ae550d3f4d 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -118,6 +118,7 @@ class FederationServer(FederationBase):
         self._federation_event_handler = hs.get_federation_event_handler()
         self.state = hs.get_state_handler()
         self._event_auth_handler = hs.get_event_auth_handler()
+        self._room_member_handler = hs.get_room_member_handler()
 
         self._state_storage_controller = hs.get_storage_controllers().state
 
@@ -621,6 +622,15 @@ class FederationServer(FederationBase):
             )
             raise IncompatibleRoomVersionError(room_version=room_version)
 
+        # Refuse the request if that room has seen too many joins recently.
+        # This is in addition to the HS-level rate limiting applied by
+        # BaseFederationServlet.
+        # type-ignore: mypy doesn't seem able to deduce the type of the limiter(!?)
+        await self._room_member_handler._join_rate_per_room_limiter.ratelimit(  # type: ignore[has-type]
+            requester=None,
+            key=room_id,
+            update=False,
+        )
         pdu = await self.handler.on_make_join_request(origin, room_id, user_id)
         return {"event": pdu.get_templated_pdu_json(), "room_version": room_version}
 
@@ -655,6 +665,12 @@ class FederationServer(FederationBase):
         room_id: str,
         caller_supports_partial_state: bool = False,
     ) -> Dict[str, Any]:
+        await self._room_member_handler._join_rate_per_room_limiter.ratelimit(  # type: ignore[has-type]
+            requester=None,
+            key=room_id,
+            update=False,
+        )
+
         event, context = await self._on_send_membership_event(
             origin, content, Membership.JOIN, room_id
         )
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 99a794c042..94a65ac65f 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -351,7 +351,11 @@ class FederationSender(AbstractFederationSender):
             self._is_processing = True
             while True:
                 last_token = await self.store.get_federation_out_pos("events")
-                next_token, events = await self.store.get_all_new_events_stream(
+                (
+                    next_token,
+                    events,
+                    event_to_received_ts,
+                ) = await self.store.get_all_new_events_stream(
                     last_token, self._last_poked_id, limit=100
                 )
 
@@ -476,7 +480,7 @@ class FederationSender(AbstractFederationSender):
                         await self._send_pdu(event, sharded_destinations)
 
                         now = self.clock.time_msec()
-                        ts = await self.store.get_received_ts(event.event_id)
+                        ts = event_to_received_ts[event.event_id]
                         assert ts is not None
                         synapse.metrics.event_processing_lag_by_event.labels(
                             "federation_sender"
@@ -509,7 +513,7 @@ class FederationSender(AbstractFederationSender):
 
                 if events:
                     now = self.clock.time_msec()
-                    ts = await self.store.get_received_ts(events[-1].event_id)
+                    ts = event_to_received_ts[events[-1].event_id]
                     assert ts is not None
 
                     synapse.metrics.event_processing_lag.labels(
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 9e84bd677e..32074b8ca6 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -619,7 +619,7 @@ class TransportLayerClient:
         )
 
     async def claim_client_keys(
-        self, destination: str, query_content: JsonDict, timeout: int
+        self, destination: str, query_content: JsonDict, timeout: Optional[int]
     ) -> JsonDict:
         """Claim one-time keys for a list of devices hosted on a remote server.
 
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 814553e098..203b62e015 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -104,14 +104,15 @@ class ApplicationServicesHandler:
         with Measure(self.clock, "notify_interested_services"):
             self.is_processing = True
             try:
-                limit = 100
                 upper_bound = -1
                 while upper_bound < self.current_max:
+                    last_token = await self.store.get_appservice_last_pos()
                     (
                         upper_bound,
                         events,
-                    ) = await self.store.get_new_events_for_appservice(
-                        self.current_max, limit
+                        event_to_received_ts,
+                    ) = await self.store.get_all_new_events_stream(
+                        last_token, self.current_max, limit=100, get_prev_content=True
                     )
 
                     events_by_room: Dict[str, List[EventBase]] = {}
@@ -150,7 +151,7 @@ class ApplicationServicesHandler:
                             )
 
                         now = self.clock.time_msec()
-                        ts = await self.store.get_received_ts(event.event_id)
+                        ts = event_to_received_ts[event.event_id]
                         assert ts is not None
 
                         synapse.metrics.event_processing_lag_by_event.labels(
@@ -187,7 +188,7 @@ class ApplicationServicesHandler:
 
                     if events:
                         now = self.clock.time_msec()
-                        ts = await self.store.get_received_ts(events[-1].event_id)
+                        ts = event_to_received_ts[events[-1].event_id]
                         assert ts is not None
 
                         synapse.metrics.event_processing_lag.labels(
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 52bb5c9c55..84c28c480e 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Tuple
 
 import attr
 from canonicaljson import encode_canonical_json
@@ -92,7 +92,11 @@ class E2eKeysHandler:
 
     @trace
     async def query_devices(
-        self, query_body: JsonDict, timeout: int, from_user_id: str, from_device_id: str
+        self,
+        query_body: JsonDict,
+        timeout: int,
+        from_user_id: str,
+        from_device_id: Optional[str],
     ) -> JsonDict:
         """Handle a device key query from a client
 
@@ -120,9 +124,7 @@ class E2eKeysHandler:
                 the number of in-flight queries at a time.
         """
         async with self._query_devices_linearizer.queue((from_user_id, from_device_id)):
-            device_keys_query: Dict[str, Iterable[str]] = query_body.get(
-                "device_keys", {}
-            )
+            device_keys_query: Dict[str, List[str]] = query_body.get("device_keys", {})
 
             # separate users by domain.
             # make a map from domain to user_id to device_ids
@@ -392,7 +394,7 @@ class E2eKeysHandler:
 
     @trace
     async def query_local_devices(
-        self, query: Dict[str, Optional[List[str]]]
+        self, query: Mapping[str, Optional[List[str]]]
     ) -> Dict[str, Dict[str, dict]]:
         """Get E2E device keys for local users
 
@@ -461,7 +463,7 @@ class E2eKeysHandler:
 
     @trace
     async def claim_one_time_keys(
-        self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: int
+        self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: Optional[int]
     ) -> JsonDict:
         local_query: List[Tuple[str, str, str]] = []
         remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {}
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index c74117c19a..e4a5b64d10 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import collections
 import itertools
 import logging
 from http import HTTPStatus
@@ -347,7 +348,7 @@ class FederationEventHandler:
         event.internal_metadata.send_on_behalf_of = origin
 
         context = await self._state_handler.compute_event_context(event)
-        context = await self._check_event_auth(origin, event, context)
+        await self._check_event_auth(origin, event, context)
         if context.rejected:
             raise SynapseError(
                 403, f"{event.membership} event was rejected", Codes.FORBIDDEN
@@ -485,7 +486,7 @@ class FederationEventHandler:
                 partial_state=partial_state,
             )
 
-            context = await self._check_event_auth(origin, event, context)
+            await self._check_event_auth(origin, event, context)
             if context.rejected:
                 raise SynapseError(400, "Join event was rejected")
 
@@ -1036,6 +1037,9 @@ class FederationEventHandler:
         # XXX: this doesn't sound right? it means that we'll end up with incomplete
         #   state.
         failed_to_fetch = desired_events - event_metadata.keys()
+        # `event_id` could be missing from `event_metadata` because it's not necessarily
+        # a state event. We've already checked that we've fetched it above.
+        failed_to_fetch.discard(event_id)
         if failed_to_fetch:
             logger.warning(
                 "Failed to fetch missing state events for %s %s",
@@ -1116,11 +1120,7 @@ class FederationEventHandler:
             state_ids_before_event=state_ids,
         )
         try:
-            context = await self._check_event_auth(
-                origin,
-                event,
-                context,
-            )
+            await self._check_event_auth(origin, event, context)
         except AuthError as e:
             # This happens only if we couldn't find the auth events. We'll already have
             # logged a warning, so now we just convert to a FederationError.
@@ -1495,11 +1495,8 @@ class FederationEventHandler:
         )
 
     async def _check_event_auth(
-        self,
-        origin: str,
-        event: EventBase,
-        context: EventContext,
-    ) -> EventContext:
+        self, origin: str, event: EventBase, context: EventContext
+    ) -> None:
         """
         Checks whether an event should be rejected (for failing auth checks).
 
@@ -1509,9 +1506,6 @@ class FederationEventHandler:
             context:
                 The event context.
 
-        Returns:
-            The updated context object.
-
         Raises:
             AuthError if we were unable to find copies of the event's auth events.
                (Most other failures just cause us to set `context.rejected`.)
@@ -1526,7 +1520,7 @@ class FederationEventHandler:
             logger.warning("While validating received event %r: %s", event, e)
             # TODO: use a different rejected reason here?
             context.rejected = RejectedReason.AUTH_ERROR
-            return context
+            return
 
         # next, check that we have all of the event's auth events.
         #
@@ -1538,6 +1532,9 @@ class FederationEventHandler:
         )
 
         # ... and check that the event passes auth at those auth events.
+        # https://spec.matrix.org/v1.3/server-server-api/#checks-performed-on-receipt-of-a-pdu:
+        #   4. Passes authorization rules based on the event’s auth events,
+        #      otherwise it is rejected.
         try:
             await check_state_independent_auth_rules(self._store, event)
             check_state_dependent_auth_rules(event, claimed_auth_events)
@@ -1546,55 +1543,90 @@ class FederationEventHandler:
                 "While checking auth of %r against auth_events: %s", event, e
             )
             context.rejected = RejectedReason.AUTH_ERROR
-            return context
+            return
+
+        # now check the auth rules pass against the room state before the event
+        # https://spec.matrix.org/v1.3/server-server-api/#checks-performed-on-receipt-of-a-pdu:
+        #   5. Passes authorization rules based on the state before the event,
+        #      otherwise it is rejected.
+        #
+        # ... however, if we only have partial state for the room, then there is a good
+        # chance that we'll be missing some of the state needed to auth the new event.
+        # So, we state-resolve the auth events that we are given against the state that
+        # we know about, which ensures things like bans are applied. (Note that we'll
+        # already have checked we have all the auth events, in
+        # _load_or_fetch_auth_events_for_event above)
+        if context.partial_state:
+            room_version = await self._store.get_room_version_id(event.room_id)
+
+            local_state_id_map = await context.get_prev_state_ids()
+            claimed_auth_events_id_map = {
+                (ev.type, ev.state_key): ev.event_id for ev in claimed_auth_events
+            }
+
+            state_for_auth_id_map = (
+                await self._state_resolution_handler.resolve_events_with_store(
+                    event.room_id,
+                    room_version,
+                    [local_state_id_map, claimed_auth_events_id_map],
+                    event_map=None,
+                    state_res_store=StateResolutionStore(self._store),
+                )
+            )
+        else:
+            event_types = event_auth.auth_types_for_event(event.room_version, event)
+            state_for_auth_id_map = await context.get_prev_state_ids(
+                StateFilter.from_types(event_types)
+            )
 
-        # now check auth against what we think the auth events *should* be.
-        event_types = event_auth.auth_types_for_event(event.room_version, event)
-        prev_state_ids = await context.get_prev_state_ids(
-            StateFilter.from_types(event_types)
+        calculated_auth_event_ids = self._event_auth_handler.compute_auth_events(
+            event, state_for_auth_id_map, for_verification=True
         )
 
-        auth_events_ids = self._event_auth_handler.compute_auth_events(
-            event, prev_state_ids, for_verification=True
+        # if those are the same, we're done here.
+        if collections.Counter(event.auth_event_ids()) == collections.Counter(
+            calculated_auth_event_ids
+        ):
+            return
+
+        # otherwise, re-run the auth checks based on what we calculated.
+        calculated_auth_events = await self._store.get_events_as_list(
+            calculated_auth_event_ids
         )
-        auth_events_x = await self._store.get_events(auth_events_ids)
+
+        # log the differences
+
+        claimed_auth_event_map = {(e.type, e.state_key): e for e in claimed_auth_events}
         calculated_auth_event_map = {
-            (e.type, e.state_key): e for e in auth_events_x.values()
+            (e.type, e.state_key): e for e in calculated_auth_events
         }
+        logger.info(
+            "event's auth_events are different to our calculated auth_events. "
+            "Claimed but not calculated: %s. Calculated but not claimed: %s",
+            [
+                ev
+                for k, ev in claimed_auth_event_map.items()
+                if k not in calculated_auth_event_map
+                or calculated_auth_event_map[k].event_id != ev.event_id
+            ],
+            [
+                ev
+                for k, ev in calculated_auth_event_map.items()
+                if k not in claimed_auth_event_map
+                or claimed_auth_event_map[k].event_id != ev.event_id
+            ],
+        )
 
         try:
-            updated_auth_events = await self._update_auth_events_for_auth(
+            check_state_dependent_auth_rules(event, calculated_auth_events)
+        except AuthError as e:
+            logger.warning(
+                "While checking auth of %r against room state before the event: %s",
                 event,
-                calculated_auth_event_map=calculated_auth_event_map,
-            )
-        except Exception:
-            # We don't really mind if the above fails, so lets not fail
-            # processing if it does. However, it really shouldn't fail so
-            # let's still log as an exception since we'll still want to fix
-            # any bugs.
-            logger.exception(
-                "Failed to double check auth events for %s with remote. "
-                "Ignoring failure and continuing processing of event.",
-                event.event_id,
-            )
-            updated_auth_events = None
-
-        if updated_auth_events:
-            context = await self._update_context_for_auth_events(
-                event, context, updated_auth_events
+                e,
             )
-            auth_events_for_auth = updated_auth_events
-        else:
-            auth_events_for_auth = calculated_auth_event_map
-
-        try:
-            check_state_dependent_auth_rules(event, auth_events_for_auth.values())
-        except AuthError as e:
-            logger.warning("Failed auth resolution for %r because %s", event, e)
             context.rejected = RejectedReason.AUTH_ERROR
 
-        return context
-
     async def _maybe_kick_guest_users(self, event: EventBase) -> None:
         if event.type != EventTypes.GuestAccess:
             return
@@ -1704,93 +1736,6 @@ class FederationEventHandler:
             soft_failed_event_counter.inc()
             event.internal_metadata.soft_failed = True
 
-    async def _update_auth_events_for_auth(
-        self,
-        event: EventBase,
-        calculated_auth_event_map: StateMap[EventBase],
-    ) -> Optional[StateMap[EventBase]]:
-        """Helper for _check_event_auth. See there for docs.
-
-        Checks whether a given event has the expected auth events. If it
-        doesn't then we talk to the remote server to compare state to see if
-        we can come to a consensus (e.g. if one server missed some valid
-        state).
-
-        This attempts to resolve any potential divergence of state between
-        servers, but is not essential and so failures should not block further
-        processing of the event.
-
-        Args:
-            event:
-
-            calculated_auth_event_map:
-                Our calculated auth_events based on the state of the room
-                at the event's position in the DAG.
-
-        Returns:
-            updated auth event map, or None if no changes are needed.
-
-        """
-        assert not event.internal_metadata.outlier
-
-        # check for events which are in the event's claimed auth_events, but not
-        # in our calculated event map.
-        event_auth_events = set(event.auth_event_ids())
-        different_auth = event_auth_events.difference(
-            e.event_id for e in calculated_auth_event_map.values()
-        )
-
-        if not different_auth:
-            return None
-
-        logger.info(
-            "auth_events refers to events which are not in our calculated auth "
-            "chain: %s",
-            different_auth,
-        )
-
-        # XXX: currently this checks for redactions but I'm not convinced that is
-        # necessary?
-        different_events = await self._store.get_events_as_list(different_auth)
-
-        # double-check they're all in the same room - we should already have checked
-        # this but it doesn't hurt to check again.
-        for d in different_events:
-            assert (
-                d.room_id == event.room_id
-            ), f"Event {event.event_id} refers to auth_event {d.event_id} which is in a different room"
-
-        # now we state-resolve between our own idea of the auth events, and the remote's
-        # idea of them.
-
-        local_state = calculated_auth_event_map.values()
-        remote_auth_events = dict(calculated_auth_event_map)
-        remote_auth_events.update({(d.type, d.state_key): d for d in different_events})
-        remote_state = remote_auth_events.values()
-
-        room_version = await self._store.get_room_version_id(event.room_id)
-        new_state = await self._state_handler.resolve_events(
-            room_version, (local_state, remote_state), event
-        )
-        different_state = {
-            (d.type, d.state_key): d
-            for d in new_state.values()
-            if calculated_auth_event_map.get((d.type, d.state_key)) != d
-        }
-        if not different_state:
-            logger.info("State res returned no new state")
-            return None
-
-        logger.info(
-            "After state res: updating auth_events with new state %s",
-            different_state.values(),
-        )
-
-        # take a copy of calculated_auth_event_map before we modify it.
-        auth_events = dict(calculated_auth_event_map)
-        auth_events.update(different_state)
-        return auth_events
-
     async def _load_or_fetch_auth_events_for_event(
         self, destination: str, event: EventBase
     ) -> Collection[EventBase]:
@@ -1888,61 +1833,6 @@ class FederationEventHandler:
 
         await self._auth_and_persist_outliers(room_id, remote_auth_events)
 
-    async def _update_context_for_auth_events(
-        self, event: EventBase, context: EventContext, auth_events: StateMap[EventBase]
-    ) -> EventContext:
-        """Update the state_ids in an event context after auth event resolution,
-        storing the changes as a new state group.
-
-        Args:
-            event: The event we're handling the context for
-
-            context: initial event context
-
-            auth_events: Events to update in the event context.
-
-        Returns:
-            new event context
-        """
-        # exclude the state key of the new event from the current_state in the context.
-        if event.is_state():
-            event_key: Optional[Tuple[str, str]] = (event.type, event.state_key)
-        else:
-            event_key = None
-        state_updates = {
-            k: a.event_id for k, a in auth_events.items() if k != event_key
-        }
-
-        current_state_ids = await context.get_current_state_ids()
-        current_state_ids = dict(current_state_ids)  # type: ignore
-
-        current_state_ids.update(state_updates)
-
-        prev_state_ids = await context.get_prev_state_ids()
-        prev_state_ids = dict(prev_state_ids)
-
-        prev_state_ids.update({k: a.event_id for k, a in auth_events.items()})
-
-        # create a new state group as a delta from the existing one.
-        prev_group = context.state_group
-        state_group = await self._state_storage_controller.store_state_group(
-            event.event_id,
-            event.room_id,
-            prev_group=prev_group,
-            delta_ids=state_updates,
-            current_state_ids=current_state_ids,
-        )
-
-        return EventContext.with_state(
-            storage=self._storage_controllers,
-            state_group=state_group,
-            state_group_before_event=context.state_group_before_event,
-            state_delta_due_to_event=state_updates,
-            prev_group=prev_group,
-            delta_ids=state_updates,
-            partial_state=context.partial_state,
-        )
-
     async def _run_push_actions_and_persist_event(
         self, event: EventBase, context: EventContext, backfilled: bool = False
     ) -> None:
@@ -2093,6 +1983,10 @@ class FederationEventHandler:
             event, event_pos, max_stream_token, extra_users=extra_users
         )
 
+        if event.type == EventTypes.Member and event.membership == Membership.JOIN:
+            # TODO retrieve the previous state, and exclude join -> join transitions
+            self._notifier.notify_user_joined_room(event.event_id, event.room_id)
+
     def _sanity_check_event(self, ev: EventBase) -> None:
         """
         Do some early sanity checks of a received event
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 9bca2bc4b2..9571d461c8 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -26,7 +26,6 @@ from synapse.api.errors import (
     SynapseError,
 )
 from synapse.api.ratelimiting import Ratelimiter
-from synapse.config.emailconfig import ThreepidBehaviour
 from synapse.http import RequestTimedOutError
 from synapse.http.client import SimpleHttpClient
 from synapse.http.site import SynapseRequest
@@ -163,8 +162,7 @@ class IdentityHandler:
         sid: str,
         mxid: str,
         id_server: str,
-        id_access_token: Optional[str] = None,
-        use_v2: bool = True,
+        id_access_token: str,
     ) -> JsonDict:
         """Bind a 3PID to an identity server
 
@@ -174,8 +172,7 @@ class IdentityHandler:
             mxid: The MXID to bind the 3PID to
             id_server: The domain of the identity server to query
             id_access_token: The access token to authenticate to the identity
-                server with, if necessary. Required if use_v2 is true
-            use_v2: Whether to use v2 Identity Service API endpoints. Defaults to True
+                server with
 
         Raises:
             SynapseError: On any of the following conditions
@@ -187,24 +184,15 @@ class IdentityHandler:
         """
         logger.debug("Proxying threepid bind request for %s to %s", mxid, id_server)
 
-        # If an id_access_token is not supplied, force usage of v1
-        if id_access_token is None:
-            use_v2 = False
-
         if not valid_id_server_location(id_server):
             raise SynapseError(
                 400,
                 "id_server must be a valid hostname with optional port and path components",
             )
 
-        # Decide which API endpoint URLs to use
-        headers = {}
         bind_data = {"sid": sid, "client_secret": client_secret, "mxid": mxid}
-        if use_v2:
-            bind_url = "https://%s/_matrix/identity/v2/3pid/bind" % (id_server,)
-            headers["Authorization"] = create_id_access_token_header(id_access_token)  # type: ignore
-        else:
-            bind_url = "https://%s/_matrix/identity/api/v1/3pid/bind" % (id_server,)
+        bind_url = "https://%s/_matrix/identity/v2/3pid/bind" % (id_server,)
+        headers = {"Authorization": create_id_access_token_header(id_access_token)}
 
         try:
             # Use the blacklisting http client as this call is only to identity servers
@@ -223,21 +211,14 @@ class IdentityHandler:
 
             return data
         except HttpResponseException as e:
-            if e.code != 404 or not use_v2:
-                logger.error("3PID bind failed with Matrix error: %r", e)
-                raise e.to_synapse_error()
+            logger.error("3PID bind failed with Matrix error: %r", e)
+            raise e.to_synapse_error()
         except RequestTimedOutError:
             raise SynapseError(500, "Timed out contacting identity server")
         except CodeMessageException as e:
             data = json_decoder.decode(e.msg)  # XXX WAT?
             return data
 
-        logger.info("Got 404 when POSTing JSON %s, falling back to v1 URL", bind_url)
-        res = await self.bind_threepid(
-            client_secret, sid, mxid, id_server, id_access_token, use_v2=False
-        )
-        return res
-
     async def try_unbind_threepid(self, mxid: str, threepid: dict) -> bool:
         """Attempt to remove a 3PID from an identity server, or if one is not provided, all
         identity servers we're aware the binding is present on
@@ -300,8 +281,8 @@ class IdentityHandler:
                 "id_server must be a valid hostname with optional port and path components",
             )
 
-        url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,)
-        url_bytes = b"/_matrix/identity/api/v1/3pid/unbind"
+        url = "https://%s/_matrix/identity/v2/3pid/unbind" % (id_server,)
+        url_bytes = b"/_matrix/identity/v2/3pid/unbind"
 
         content = {
             "mxid": mxid,
@@ -434,48 +415,6 @@ class IdentityHandler:
 
         return session_id
 
-    async def requestEmailToken(
-        self,
-        id_server: str,
-        email: str,
-        client_secret: str,
-        send_attempt: int,
-        next_link: Optional[str] = None,
-    ) -> JsonDict:
-        """
-        Request an external server send an email on our behalf for the purposes of threepid
-        validation.
-
-        Args:
-            id_server: The identity server to proxy to
-            email: The email to send the message to
-            client_secret: The unique client_secret sends by the user
-            send_attempt: Which attempt this is
-            next_link: A link to redirect the user to once they submit the token
-
-        Returns:
-            The json response body from the server
-        """
-        params = {
-            "email": email,
-            "client_secret": client_secret,
-            "send_attempt": send_attempt,
-        }
-        if next_link:
-            params["next_link"] = next_link
-
-        try:
-            data = await self.http_client.post_json_get_json(
-                id_server + "/_matrix/identity/api/v1/validate/email/requestToken",
-                params,
-            )
-            return data
-        except HttpResponseException as e:
-            logger.info("Proxied requestToken failed: %r", e)
-            raise e.to_synapse_error()
-        except RequestTimedOutError:
-            raise SynapseError(500, "Timed out contacting identity server")
-
     async def requestMsisdnToken(
         self,
         id_server: str,
@@ -549,18 +488,7 @@ class IdentityHandler:
         validation_session = None
 
         # Try to validate as email
-        if self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
-            # Remote emails will only be used if a valid identity server is provided.
-            assert (
-                self.hs.config.registration.account_threepid_delegate_email is not None
-            )
-
-            # Ask our delegated email identity server
-            validation_session = await self.threepid_from_creds(
-                self.hs.config.registration.account_threepid_delegate_email,
-                threepid_creds,
-            )
-        elif self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+        if self.hs.config.email.can_verify_email:
             # Get a validated session matching these details
             validation_session = await self.store.get_threepid_validation_session(
                 "email", client_secret, sid=sid, validated=True
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 1980e37dae..bd7baef051 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -463,6 +463,7 @@ class EventCreationHandler:
         )
         self._events_shard_config = self.config.worker.events_shard_config
         self._instance_name = hs.get_instance_name()
+        self._notifier = hs.get_notifier()
 
         self.room_prejoin_state_types = self.hs.config.api.room_prejoin_state
 
@@ -1444,7 +1445,12 @@ class EventCreationHandler:
             if state_entry.state_group in self._external_cache_joined_hosts_updates:
                 return
 
-            joined_hosts = await self.store.get_joined_hosts(event.room_id, state_entry)
+            state = await state_entry.get_state(
+                self._storage_controllers.state, StateFilter.all()
+            )
+            joined_hosts = await self.store.get_joined_hosts(
+                event.room_id, state, state_entry
+            )
 
             # Note that the expiry times must be larger than the expiry time in
             # _external_cache_joined_hosts_updates.
@@ -1545,6 +1551,16 @@ class EventCreationHandler:
                 requester, is_admin_redaction=is_admin_redaction
             )
 
+        if event.type == EventTypes.Member and event.membership == Membership.JOIN:
+            (
+                current_membership,
+                _,
+            ) = await self.store.get_local_current_membership_for_user_in_room(
+                event.state_key, event.room_id
+            )
+            if current_membership != Membership.JOIN:
+                self._notifier.notify_user_joined_room(event.event_id, event.room_id)
+
         await self._maybe_kick_guest_users(event, context)
 
         if event.type == EventTypes.CanonicalAlias:
@@ -1844,13 +1860,8 @@ class EventCreationHandler:
 
         # For each room we need to find a joined member we can use to send
         # the dummy event with.
-        latest_event_ids = await self.store.get_prev_events_for_room(room_id)
-        members = await self.state.get_current_users_in_room(
-            room_id, latest_event_ids=latest_event_ids
-        )
+        members = await self.store.get_local_users_in_room(room_id)
         for user_id in members:
-            if not self.hs.is_mine_id(user_id):
-                continue
             requester = create_requester(user_id, authenticated_entity=self.server_name)
             try:
                 event, context = await self.create_event(
@@ -1861,7 +1872,6 @@ class EventCreationHandler:
                         "room_id": room_id,
                         "sender": user_id,
                     },
-                    prev_event_ids=latest_event_ids,
                 )
 
                 event.internal_metadata.proactively_send = False
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index a54f163c0a..978d3ee39f 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -889,7 +889,11 @@ class RoomCreationHandler:
         # override any attempt to set room versions via the creation_content
         creation_content["room_version"] = room_version.identifier
 
-        last_stream_id = await self._send_events_for_new_room(
+        (
+            last_stream_id,
+            last_sent_event_id,
+            depth,
+        ) = await self._send_events_for_new_room(
             requester,
             room_id,
             preset_config=preset_config,
@@ -905,7 +909,7 @@ class RoomCreationHandler:
         if "name" in config:
             name = config["name"]
             (
-                _,
+                name_event,
                 last_stream_id,
             ) = await self.event_creation_handler.create_and_send_nonmember_event(
                 requester,
@@ -917,12 +921,16 @@ class RoomCreationHandler:
                     "content": {"name": name},
                 },
                 ratelimit=False,
+                prev_event_ids=[last_sent_event_id],
+                depth=depth,
             )
+            last_sent_event_id = name_event.event_id
+            depth += 1
 
         if "topic" in config:
             topic = config["topic"]
             (
-                _,
+                topic_event,
                 last_stream_id,
             ) = await self.event_creation_handler.create_and_send_nonmember_event(
                 requester,
@@ -934,7 +942,11 @@ class RoomCreationHandler:
                     "content": {"topic": topic},
                 },
                 ratelimit=False,
+                prev_event_ids=[last_sent_event_id],
+                depth=depth,
             )
+            last_sent_event_id = topic_event.event_id
+            depth += 1
 
         # we avoid dropping the lock between invites, as otherwise joins can
         # start coming in and making the createRoom slow.
@@ -949,7 +961,7 @@ class RoomCreationHandler:
 
             for invitee in invite_list:
                 (
-                    _,
+                    member_event_id,
                     last_stream_id,
                 ) = await self.room_member_handler.update_membership_locked(
                     requester,
@@ -959,7 +971,11 @@ class RoomCreationHandler:
                     ratelimit=False,
                     content=content,
                     new_room=True,
+                    prev_event_ids=[last_sent_event_id],
+                    depth=depth,
                 )
+                last_sent_event_id = member_event_id
+                depth += 1
 
         for invite_3pid in invite_3pid_list:
             id_server = invite_3pid["id_server"]
@@ -968,7 +984,10 @@ class RoomCreationHandler:
             medium = invite_3pid["medium"]
             # Note that do_3pid_invite can raise a  ShadowBanError, but this was
             # handled above by emptying invite_3pid_list.
-            last_stream_id = await self.hs.get_room_member_handler().do_3pid_invite(
+            (
+                member_event_id,
+                last_stream_id,
+            ) = await self.hs.get_room_member_handler().do_3pid_invite(
                 room_id,
                 requester.user,
                 medium,
@@ -977,7 +996,11 @@ class RoomCreationHandler:
                 requester,
                 txn_id=None,
                 id_access_token=id_access_token,
+                prev_event_ids=[last_sent_event_id],
+                depth=depth,
             )
+            last_sent_event_id = member_event_id
+            depth += 1
 
         result = {"room_id": room_id}
 
@@ -1005,20 +1028,22 @@ class RoomCreationHandler:
         power_level_content_override: Optional[JsonDict] = None,
         creator_join_profile: Optional[JsonDict] = None,
         ratelimit: bool = True,
-    ) -> int:
+    ) -> Tuple[int, str, int]:
         """Sends the initial events into a new room.
 
         `power_level_content_override` doesn't apply when initial state has
         power level state event content.
 
         Returns:
-            The stream_id of the last event persisted.
+            A tuple containing the stream ID, event ID and depth of the last
+            event sent to the room.
         """
 
         creator_id = creator.user.to_string()
 
         event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""}
 
+        depth = 1
         last_sent_event_id: Optional[str] = None
 
         def create(etype: str, content: JsonDict, **kwargs: Any) -> JsonDict:
@@ -1031,6 +1056,7 @@ class RoomCreationHandler:
 
         async def send(etype: str, content: JsonDict, **kwargs: Any) -> int:
             nonlocal last_sent_event_id
+            nonlocal depth
 
             event = create(etype, content, **kwargs)
             logger.debug("Sending %s in new room", etype)
@@ -1047,9 +1073,11 @@ class RoomCreationHandler:
                 # Note: we don't pass state_event_ids here because this triggers
                 # an additional query per event to look them up from the events table.
                 prev_event_ids=[last_sent_event_id] if last_sent_event_id else [],
+                depth=depth,
             )
 
             last_sent_event_id = sent_event.event_id
+            depth += 1
 
             return last_stream_id
 
@@ -1075,6 +1103,7 @@ class RoomCreationHandler:
             content=creator_join_profile,
             new_room=True,
             prev_event_ids=[last_sent_event_id],
+            depth=depth,
         )
         last_sent_event_id = member_event_id
 
@@ -1168,7 +1197,7 @@ class RoomCreationHandler:
                 content={"algorithm": RoomEncryptionAlgorithms.DEFAULT},
             )
 
-        return last_sent_stream_id
+        return last_sent_stream_id, last_sent_event_id, depth
 
     def _generate_room_id(self) -> str:
         """Generates a random room ID.
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 04c44b2ccb..30b4cb23df 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -94,12 +94,29 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             rate_hz=hs.config.ratelimiting.rc_joins_local.per_second,
             burst_count=hs.config.ratelimiting.rc_joins_local.burst_count,
         )
+        # Tracks joins from local users to rooms this server isn't a member of.
+        # I.e. joins this server makes by requesting /make_join /send_join from
+        # another server.
         self._join_rate_limiter_remote = Ratelimiter(
             store=self.store,
             clock=self.clock,
             rate_hz=hs.config.ratelimiting.rc_joins_remote.per_second,
             burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count,
         )
+        # TODO: find a better place to keep this Ratelimiter.
+        #   It needs to be
+        #    - written to by event persistence code
+        #    - written to by something which can snoop on replication streams
+        #    - read by the RoomMemberHandler to rate limit joins from local users
+        #    - read by the FederationServer to rate limit make_joins and send_joins from
+        #      other homeservers
+        #   I wonder if a homeserver-wide collection of rate limiters might be cleaner?
+        self._join_rate_per_room_limiter = Ratelimiter(
+            store=self.store,
+            clock=self.clock,
+            rate_hz=hs.config.ratelimiting.rc_joins_per_room.per_second,
+            burst_count=hs.config.ratelimiting.rc_joins_per_room.burst_count,
+        )
 
         # Ratelimiter for invites, keyed by room (across all issuers, all
         # recipients).
@@ -136,6 +153,18 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         )
 
         self.request_ratelimiter = hs.get_request_ratelimiter()
+        hs.get_notifier().add_new_join_in_room_callback(self._on_user_joined_room)
+
+    def _on_user_joined_room(self, event_id: str, room_id: str) -> None:
+        """Notify the rate limiter that a room join has occurred.
+
+        Use this to inform the RoomMemberHandler about joins that have either
+        - taken place on another homeserver, or
+        - on another worker in this homeserver.
+        Joins actioned by this worker should use the usual `ratelimit` method, which
+        checks the limit and increments the counter in one go.
+        """
+        self._join_rate_per_room_limiter.record_action(requester=None, key=room_id)
 
     @abc.abstractmethod
     async def _remote_join(
@@ -285,6 +314,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         allow_no_prev_events: bool = False,
         prev_event_ids: Optional[List[str]] = None,
         state_event_ids: Optional[List[str]] = None,
+        depth: Optional[int] = None,
         txn_id: Optional[str] = None,
         ratelimit: bool = True,
         content: Optional[dict] = None,
@@ -315,6 +345,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 prev_events are set so we need to set them ourself via this argument.
                 This should normally be left as None, which will cause the auth_event_ids
                 to be calculated based on the room state at the prev_events.
+            depth: Override the depth used to order the event in the DAG.
+                Should normally be set to None, which will cause the depth to be calculated
+                based on the prev_events.
 
             txn_id:
             ratelimit:
@@ -370,6 +403,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             allow_no_prev_events=allow_no_prev_events,
             prev_event_ids=prev_event_ids,
             state_event_ids=state_event_ids,
+            depth=depth,
             require_consent=require_consent,
             outlier=outlier,
             historical=historical,
@@ -391,6 +425,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             # up blocking profile updates.
             if newly_joined and ratelimit:
                 await self._join_rate_limiter_local.ratelimit(requester)
+                await self._join_rate_per_room_limiter.ratelimit(
+                    requester, key=room_id, update=False
+                )
 
         result_event = await self.event_creation_handler.handle_new_client_event(
             requester,
@@ -466,6 +503,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         allow_no_prev_events: bool = False,
         prev_event_ids: Optional[List[str]] = None,
         state_event_ids: Optional[List[str]] = None,
+        depth: Optional[int] = None,
     ) -> Tuple[str, int]:
         """Update a user's membership in a room.
 
@@ -501,6 +539,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 prev_events are set so we need to set them ourself via this argument.
                 This should normally be left as None, which will cause the auth_event_ids
                 to be calculated based on the room state at the prev_events.
+            depth: Override the depth used to order the event in the DAG.
+                Should normally be set to None, which will cause the depth to be calculated
+                based on the prev_events.
 
         Returns:
             A tuple of the new event ID and stream ID.
@@ -540,6 +581,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                     allow_no_prev_events=allow_no_prev_events,
                     prev_event_ids=prev_event_ids,
                     state_event_ids=state_event_ids,
+                    depth=depth,
                 )
 
         return result
@@ -562,6 +604,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         allow_no_prev_events: bool = False,
         prev_event_ids: Optional[List[str]] = None,
         state_event_ids: Optional[List[str]] = None,
+        depth: Optional[int] = None,
     ) -> Tuple[str, int]:
         """Helper for update_membership.
 
@@ -599,6 +642,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 prev_events are set so we need to set them ourself via this argument.
                 This should normally be left as None, which will cause the auth_event_ids
                 to be calculated based on the room state at the prev_events.
+            depth: Override the depth used to order the event in the DAG.
+                Should normally be set to None, which will cause the depth to be calculated
+                based on the prev_events.
 
         Returns:
             A tuple of the new event ID and stream ID.
@@ -732,6 +778,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 allow_no_prev_events=allow_no_prev_events,
                 prev_event_ids=prev_event_ids,
                 state_event_ids=state_event_ids,
+                depth=depth,
                 content=content,
                 require_consent=require_consent,
                 outlier=outlier,
@@ -740,14 +787,14 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
 
         latest_event_ids = await self.store.get_prev_events_for_room(room_id)
 
-        current_state_ids = await self.state_handler.get_current_state_ids(
-            room_id, latest_event_ids=latest_event_ids
+        state_before_join = await self.state_handler.compute_state_after_events(
+            room_id, latest_event_ids
         )
 
         # TODO: Refactor into dictionary of explicitly allowed transitions
         # between old and new state, with specific error messages for some
         # transitions and generic otherwise
-        old_state_id = current_state_ids.get((EventTypes.Member, target.to_string()))
+        old_state_id = state_before_join.get((EventTypes.Member, target.to_string()))
         if old_state_id:
             old_state = await self.store.get_event(old_state_id, allow_none=True)
             old_membership = old_state.content.get("membership") if old_state else None
@@ -798,11 +845,11 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             if action == "kick":
                 raise AuthError(403, "The target user is not in the room")
 
-        is_host_in_room = await self._is_host_in_room(current_state_ids)
+        is_host_in_room = await self._is_host_in_room(state_before_join)
 
         if effective_membership_state == Membership.JOIN:
             if requester.is_guest:
-                guest_can_join = await self._can_guest_join(current_state_ids)
+                guest_can_join = await self._can_guest_join(state_before_join)
                 if not guest_can_join:
                     # This should be an auth check, but guests are a local concept,
                     # so don't really fit into the general auth process.
@@ -840,13 +887,23 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
 
             # Check if a remote join should be performed.
             remote_join, remote_room_hosts = await self._should_perform_remote_join(
-                target.to_string(), room_id, remote_room_hosts, content, is_host_in_room
+                target.to_string(),
+                room_id,
+                remote_room_hosts,
+                content,
+                is_host_in_room,
+                state_before_join,
             )
             if remote_join:
                 if ratelimit:
                     await self._join_rate_limiter_remote.ratelimit(
                         requester,
                     )
+                    await self._join_rate_per_room_limiter.ratelimit(
+                        requester,
+                        key=room_id,
+                        update=False,
+                    )
 
                 inviter = await self._get_inviter(target.to_string(), room_id)
                 if inviter and not self.hs.is_mine(inviter):
@@ -967,6 +1024,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             ratelimit=ratelimit,
             prev_event_ids=latest_event_ids,
             state_event_ids=state_event_ids,
+            depth=depth,
             content=content,
             require_consent=require_consent,
             outlier=outlier,
@@ -979,6 +1037,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         remote_room_hosts: List[str],
         content: JsonDict,
         is_host_in_room: bool,
+        state_before_join: StateMap[str],
     ) -> Tuple[bool, List[str]]:
         """
         Check whether the server should do a remote join (as opposed to a local
@@ -998,6 +1057,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             content: The content to use as the event body of the join. This may
                 be modified.
             is_host_in_room: True if the host is in the room.
+            state_before_join: The state before the join event (i.e. the resolution of
+                the states after its parent events).
 
         Returns:
             A tuple of:
@@ -1014,20 +1075,17 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         # If the host is in the room, but not one of the authorised hosts
         # for restricted join rules, a remote join must be used.
         room_version = await self.store.get_room_version(room_id)
-        current_state_ids = await self._storage_controllers.state.get_current_state_ids(
-            room_id
-        )
 
         # If restricted join rules are not being used, a local join can always
         # be used.
         if not await self.event_auth_handler.has_restricted_join_rules(
-            current_state_ids, room_version
+            state_before_join, room_version
         ):
             return False, []
 
         # If the user is invited to the room or already joined, the join
         # event can always be issued locally.
-        prev_member_event_id = current_state_ids.get((EventTypes.Member, user_id), None)
+        prev_member_event_id = state_before_join.get((EventTypes.Member, user_id), None)
         prev_member_event = None
         if prev_member_event_id:
             prev_member_event = await self.store.get_event(prev_member_event_id)
@@ -1042,10 +1100,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         #
         # If not, generate a new list of remote hosts based on which
         # can issue invites.
-        event_map = await self.store.get_events(current_state_ids.values())
+        event_map = await self.store.get_events(state_before_join.values())
         current_state = {
             state_key: event_map[event_id]
-            for state_key, event_id in current_state_ids.items()
+            for state_key, event_id in state_before_join.items()
         }
         allowed_servers = get_servers_from_users(
             get_users_which_can_issue_invite(current_state)
@@ -1059,7 +1117,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
 
         # Ensure the member should be allowed access via membership in a room.
         await self.event_auth_handler.check_restricted_join_rules(
-            current_state_ids, room_version, user_id, prev_member_event
+            state_before_join, room_version, user_id, prev_member_event
         )
 
         # If this is going to be a local join, additional information must
@@ -1069,7 +1127,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             EventContentFields.AUTHORISING_USER
         ] = await self.event_auth_handler.get_user_which_could_invite(
             room_id,
-            current_state_ids,
+            state_before_join,
         )
 
         return False, []
@@ -1322,7 +1380,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         requester: Requester,
         txn_id: Optional[str],
         id_access_token: Optional[str] = None,
-    ) -> int:
+        prev_event_ids: Optional[List[str]] = None,
+        depth: Optional[int] = None,
+    ) -> Tuple[str, int]:
         """Invite a 3PID to a room.
 
         Args:
@@ -1335,9 +1395,13 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             txn_id: The transaction ID this is part of, or None if this is not
                 part of a transaction.
             id_access_token: The optional identity server access token.
+            depth: Override the depth used to order the event in the DAG.
+            prev_event_ids: The event IDs to use as the prev events
+                Should normally be set to None, which will cause the depth to be calculated
+                based on the prev_events.
 
         Returns:
-             The new stream ID.
+            Tuple of event ID and stream ordering position
 
         Raises:
             ShadowBanError if the requester has been shadow-banned.
@@ -1383,7 +1447,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             # We don't check the invite against the spamchecker(s) here (through
             # user_may_invite) because we'll do it further down the line anyway (in
             # update_membership_locked).
-            _, stream_id = await self.update_membership(
+            event_id, stream_id = await self.update_membership(
                 requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id
             )
         else:
@@ -1402,7 +1466,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                     additional_fields=spam_check[1],
                 )
 
-            stream_id = await self._make_and_store_3pid_invite(
+            event, stream_id = await self._make_and_store_3pid_invite(
                 requester,
                 id_server,
                 medium,
@@ -1411,9 +1475,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 inviter,
                 txn_id=txn_id,
                 id_access_token=id_access_token,
+                prev_event_ids=prev_event_ids,
+                depth=depth,
             )
+            event_id = event.event_id
 
-        return stream_id
+        return event_id, stream_id
 
     async def _make_and_store_3pid_invite(
         self,
@@ -1425,7 +1492,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         user: UserID,
         txn_id: Optional[str],
         id_access_token: Optional[str] = None,
-    ) -> int:
+        prev_event_ids: Optional[List[str]] = None,
+        depth: Optional[int] = None,
+    ) -> Tuple[EventBase, int]:
         room_state = await self._storage_controllers.state.get_current_state(
             room_id,
             StateFilter.from_types(
@@ -1518,8 +1587,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             },
             ratelimit=False,
             txn_id=txn_id,
+            prev_event_ids=prev_event_ids,
+            depth=depth,
         )
-        return stream_id
+        return event, stream_id
 
     async def _is_host_in_room(self, current_state_ids: StateMap[str]) -> bool:
         # Have we just created the room, and is this about to be the very
diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
index 05cebb5d4d..a744d68c64 100644
--- a/synapse/handlers/ui_auth/checkers.py
+++ b/synapse/handlers/ui_auth/checkers.py
@@ -19,7 +19,6 @@ from twisted.web.client import PartialDownloadError
 
 from synapse.api.constants import LoginType
 from synapse.api.errors import Codes, LoginError, SynapseError
-from synapse.config.emailconfig import ThreepidBehaviour
 from synapse.util import json_decoder
 
 if TYPE_CHECKING:
@@ -153,7 +152,7 @@ class _BaseThreepidAuthChecker:
 
         logger.info("Getting validated threepid. threepidcreds: %r", (threepid_creds,))
 
-        # msisdns are currently always ThreepidBehaviour.REMOTE
+        # msisdns are currently always verified via the IS
         if medium == "msisdn":
             if not self.hs.config.registration.account_threepid_delegate_msisdn:
                 raise SynapseError(
@@ -164,18 +163,7 @@ class _BaseThreepidAuthChecker:
                 threepid_creds,
             )
         elif medium == "email":
-            if (
-                self.hs.config.email.threepid_behaviour_email
-                == ThreepidBehaviour.REMOTE
-            ):
-                assert self.hs.config.registration.account_threepid_delegate_email
-                threepid = await identity_handler.threepid_from_creds(
-                    self.hs.config.registration.account_threepid_delegate_email,
-                    threepid_creds,
-                )
-            elif (
-                self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL
-            ):
+            if self.hs.config.email.can_verify_email:
                 threepid = None
                 row = await self.store.get_threepid_validation_session(
                     medium,
@@ -227,10 +215,7 @@ class EmailIdentityAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChec
         _BaseThreepidAuthChecker.__init__(self, hs)
 
     def is_enabled(self) -> bool:
-        return self.hs.config.email.threepid_behaviour_email in (
-            ThreepidBehaviour.REMOTE,
-            ThreepidBehaviour.LOCAL,
-        )
+        return self.hs.config.email.can_verify_email
 
     async def check_auth(self, authdict: dict, clientip: str) -> Any:
         return await self._check_threepid("email", authdict)
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 50c57940f9..17e729f0c7 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -84,14 +84,13 @@ the function becomes the operation name for the span.
        return something_usual_and_useful
 
 
-Operation names can be explicitly set for a function by passing the
-operation name to ``trace``
+Operation names can be explicitly set for a function by using ``trace_with_opname``:
 
 .. code-block:: python
 
-   from synapse.logging.opentracing import trace
+   from synapse.logging.opentracing import trace_with_opname
 
-   @trace(opname="a_better_operation_name")
+   @trace_with_opname("a_better_operation_name")
    def interesting_badly_named_function(*args, **kwargs):
        # Does all kinds of cool and expected things
        return something_usual_and_useful
@@ -798,33 +797,31 @@ def extract_text_map(carrier: Dict[str, str]) -> Optional["opentracing.SpanConte
 # Tracing decorators
 
 
-def trace(func=None, opname: Optional[str] = None):
+def trace_with_opname(opname: str) -> Callable[[Callable[P, R]], Callable[P, R]]:
     """
-    Decorator to trace a function.
-    Sets the operation name to that of the function's or that given
-    as operation_name. See the module's doc string for usage
-    examples.
+    Decorator to trace a function with a custom opname.
+
+    See the module's doc string for usage examples.
+
     """
 
-    def decorator(func):
+    def decorator(func: Callable[P, R]) -> Callable[P, R]:
         if opentracing is None:
             return func  # type: ignore[unreachable]
 
-        _opname = opname if opname else func.__name__
-
         if inspect.iscoroutinefunction(func):
 
             @wraps(func)
-            async def _trace_inner(*args, **kwargs):
-                with start_active_span(_opname):
-                    return await func(*args, **kwargs)
+            async def _trace_inner(*args: P.args, **kwargs: P.kwargs) -> R:
+                with start_active_span(opname):
+                    return await func(*args, **kwargs)  # type: ignore[misc]
 
         else:
             # The other case here handles both sync functions and those
             # decorated with inlineDeferred.
             @wraps(func)
-            def _trace_inner(*args, **kwargs):
-                scope = start_active_span(_opname)
+            def _trace_inner(*args: P.args, **kwargs: P.kwargs) -> R:
+                scope = start_active_span(opname)
                 scope.__enter__()
 
                 try:
@@ -858,12 +855,21 @@ def trace(func=None, opname: Optional[str] = None):
                     scope.__exit__(type(e), None, e.__traceback__)
                     raise
 
-        return _trace_inner
+        return _trace_inner  # type: ignore[return-value]
 
-    if func:
-        return decorator(func)
-    else:
-        return decorator
+    return decorator
+
+
+def trace(func: Callable[P, R]) -> Callable[P, R]:
+    """
+    Decorator to trace a function.
+
+    Sets the operation name to that of the function's name.
+
+    See the module's doc string for usage examples.
+    """
+
+    return trace_with_opname(func.__name__)(func)
 
 
 def tag_args(func: Callable[P, R]) -> Callable[P, R]:
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 54b0ec4b97..c42bb8266a 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -228,6 +228,7 @@ class Notifier:
 
         # Called when there are new things to stream over replication
         self.replication_callbacks: List[Callable[[], None]] = []
+        self._new_join_in_room_callbacks: List[Callable[[str, str], None]] = []
 
         self._federation_client = hs.get_federation_http_client()
 
@@ -280,6 +281,19 @@ class Notifier:
         """
         self.replication_callbacks.append(cb)
 
+    def add_new_join_in_room_callback(self, cb: Callable[[str, str], None]) -> None:
+        """Add a callback that will be called when a user joins a room.
+
+        This only fires on genuine membership changes, e.g. "invite" -> "join".
+        Membership transitions like "join" -> "join" (for e.g. displayname changes) do
+        not trigger the callback.
+
+        When called, the callback receives two arguments: the event ID and the room ID.
+        It should *not* return a Deferred - if it needs to do any asynchronous work, a
+        background thread should be started and wrapped with run_as_background_process.
+        """
+        self._new_join_in_room_callbacks.append(cb)
+
     async def on_new_room_event(
         self,
         event: EventBase,
@@ -723,6 +737,10 @@ class Notifier:
         for cb in self.replication_callbacks:
             cb()
 
+    def notify_user_joined_room(self, event_id: str, room_id: str) -> None:
+        for cb in self._new_join_in_room_callbacks:
+            cb(event_id, room_id)
+
     def notify_remote_server_up(self, server: str) -> None:
         """Notify any replication that a remote server has come back up"""
         # We call federation_sender directly rather than registering as a
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index d0cc657b44..1e0ef44fc7 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -328,7 +328,7 @@ class PusherPool:
             return None
 
         try:
-            p = self.pusher_factory.create_pusher(pusher_config)
+            pusher = self.pusher_factory.create_pusher(pusher_config)
         except PusherConfigException as e:
             logger.warning(
                 "Pusher incorrectly configured id=%i, user=%s, appid=%s, pushkey=%s: %s",
@@ -346,23 +346,28 @@ class PusherPool:
             )
             return None
 
-        if not p:
+        if not pusher:
             return None
 
-        appid_pushkey = "%s:%s" % (pusher_config.app_id, pusher_config.pushkey)
+        appid_pushkey = "%s:%s" % (pusher.app_id, pusher.pushkey)
 
-        byuser = self.pushers.setdefault(pusher_config.user_name, {})
+        byuser = self.pushers.setdefault(pusher.user_id, {})
         if appid_pushkey in byuser:
-            byuser[appid_pushkey].on_stop()
-        byuser[appid_pushkey] = p
+            previous_pusher = byuser[appid_pushkey]
+            previous_pusher.on_stop()
 
-        synapse_pushers.labels(type(p).__name__, p.app_id).inc()
+            synapse_pushers.labels(
+                type(previous_pusher).__name__, previous_pusher.app_id
+            ).dec()
+        byuser[appid_pushkey] = pusher
+
+        synapse_pushers.labels(type(pusher).__name__, pusher.app_id).inc()
 
         # Check if there *may* be push to process. We do this as this check is a
         # lot cheaper to do than actually fetching the exact rows we need to
         # push.
-        user_id = pusher_config.user_name
-        last_stream_ordering = pusher_config.last_stream_ordering
+        user_id = pusher.user_id
+        last_stream_ordering = pusher.last_stream_ordering
         if last_stream_ordering:
             have_notifs = await self.store.get_if_maybe_push_in_range_for_user(
                 user_id, last_stream_ordering
@@ -372,9 +377,9 @@ class PusherPool:
             # risk missing push.
             have_notifs = True
 
-        p.on_started(have_notifs)
+        pusher.on_started(have_notifs)
 
-        return p
+        return pusher
 
     async def remove_pusher(self, app_id: str, pushkey: str, user_id: str) -> None:
         appid_pushkey = "%s:%s" % (app_id, pushkey)
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index a4ae4040c3..561ad5bf04 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -29,7 +29,7 @@ from synapse.http import RequestTimedOutError
 from synapse.http.server import HttpServer, is_method_cancellable
 from synapse.http.site import SynapseRequest
 from synapse.logging import opentracing
-from synapse.logging.opentracing import trace
+from synapse.logging.opentracing import trace_with_opname
 from synapse.types import JsonDict
 from synapse.util.caches.response_cache import ResponseCache
 from synapse.util.stringutils import random_string
@@ -196,7 +196,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
                 "ascii"
             )
 
-        @trace(opname="outgoing_replication_request")
+        @trace_with_opname("outgoing_replication_request")
         async def send_request(*, instance_name: str = "master", **kwargs: Any) -> Any:
             with outgoing_gauge.track_inprogress():
                 if instance_name == local_instance_name:
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 2f59245058..e4f2201c92 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -21,7 +21,7 @@ from twisted.internet.interfaces import IAddress, IConnector
 from twisted.internet.protocol import ReconnectingClientFactory
 from twisted.python.failure import Failure
 
-from synapse.api.constants import EventTypes, ReceiptTypes
+from synapse.api.constants import EventTypes, Membership, ReceiptTypes
 from synapse.federation import send_queue
 from synapse.federation.sender import FederationSender
 from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
@@ -219,6 +219,21 @@ class ReplicationDataHandler:
                     membership=row.data.membership,
                 )
 
+                # If this event is a join, make a note of it so we have an accurate
+                # cross-worker room rate limit.
+                # TODO: Erik said we should exclude rows that came from ex_outliers
+                #  here, but I don't see how we can determine that. I guess we could
+                #  add a flag to row.data?
+                if (
+                    row.data.type == EventTypes.Member
+                    and row.data.membership == Membership.JOIN
+                    and not row.data.outlier
+                ):
+                    # TODO retrieve the previous state, and exclude join -> join transitions
+                    self.notifier.notify_user_joined_room(
+                        row.data.event_id, row.data.room_id
+                    )
+
         await self._presence_handler.process_replication_rows(
             stream_name, instance_name, token, rows
         )
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index 26f4fa7cfd..14b6705862 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -98,6 +98,7 @@ class EventsStreamEventRow(BaseEventsStreamRow):
     relates_to: Optional[str]
     membership: Optional[str]
     rejected: bool
+    outlier: bool
 
 
 @attr.s(slots=True, frozen=True, auto_attribs=True)
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index f0614a2897..ba2f7fa6d8 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -373,6 +373,7 @@ class UserRestServletV2(RestServlet):
                     if (
                         self.hs.config.email.email_enable_notifs
                         and self.hs.config.email.email_notif_for_new_users
+                        and medium == "email"
                     ):
                         await self.pusher_pool.add_pusher(
                             user_id=user_id,
diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py
index bdc4a9c068..0cc87a4001 100644
--- a/synapse/rest/client/account.py
+++ b/synapse/rest/client/account.py
@@ -28,7 +28,6 @@ from synapse.api.errors import (
     SynapseError,
     ThreepidValidationError,
 )
-from synapse.config.emailconfig import ThreepidBehaviour
 from synapse.handlers.ui_auth import UIAuthSessionDataConstants
 from synapse.http.server import HttpServer, finish_request, respond_with_html
 from synapse.http.servlet import (
@@ -64,7 +63,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
         self.config = hs.config
         self.identity_handler = hs.get_identity_handler()
 
-        if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+        if self.config.email.can_verify_email:
             self.mailer = Mailer(
                 hs=self.hs,
                 app_name=self.config.email.email_app_name,
@@ -73,11 +72,10 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
             )
 
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
-        if self.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF:
-            if self.config.email.local_threepid_handling_disabled_due_to_email_config:
-                logger.warning(
-                    "User password resets have been disabled due to lack of email config"
-                )
+        if not self.config.email.can_verify_email:
+            logger.warning(
+                "User password resets have been disabled due to lack of email config"
+            )
             raise SynapseError(
                 400, "Email-based password resets have been disabled on this server"
             )
@@ -129,35 +127,21 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
 
             raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)
 
-        if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
-            assert self.hs.config.registration.account_threepid_delegate_email
-
-            # Have the configured identity server handle the request
-            ret = await self.identity_handler.requestEmailToken(
-                self.hs.config.registration.account_threepid_delegate_email,
-                email,
-                client_secret,
-                send_attempt,
-                next_link,
-            )
-        else:
-            # Send password reset emails from Synapse
-            sid = await self.identity_handler.send_threepid_validation(
-                email,
-                client_secret,
-                send_attempt,
-                self.mailer.send_password_reset_mail,
-                next_link,
-            )
-
-            # Wrap the session id in a JSON object
-            ret = {"sid": sid}
+        # Send password reset emails from Synapse
+        sid = await self.identity_handler.send_threepid_validation(
+            email,
+            client_secret,
+            send_attempt,
+            self.mailer.send_password_reset_mail,
+            next_link,
+        )
 
         threepid_send_requests.labels(type="email", reason="password_reset").observe(
             send_attempt
         )
 
-        return 200, ret
+        # Wrap the session id in a JSON object
+        return 200, {"sid": sid}
 
 
 class PasswordRestServlet(RestServlet):
@@ -349,7 +333,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
         self.identity_handler = hs.get_identity_handler()
         self.store = self.hs.get_datastores().main
 
-        if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+        if self.config.email.can_verify_email:
             self.mailer = Mailer(
                 hs=self.hs,
                 app_name=self.config.email.email_app_name,
@@ -358,11 +342,10 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
             )
 
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
-        if self.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF:
-            if self.config.email.local_threepid_handling_disabled_due_to_email_config:
-                logger.warning(
-                    "Adding emails have been disabled due to lack of an email config"
-                )
+        if not self.config.email.can_verify_email:
+            logger.warning(
+                "Adding emails have been disabled due to lack of an email config"
+            )
             raise SynapseError(
                 400, "Adding an email to your account is disabled on this server"
             )
@@ -413,35 +396,20 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
 
             raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
 
-        if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
-            assert self.hs.config.registration.account_threepid_delegate_email
-
-            # Have the configured identity server handle the request
-            ret = await self.identity_handler.requestEmailToken(
-                self.hs.config.registration.account_threepid_delegate_email,
-                email,
-                client_secret,
-                send_attempt,
-                next_link,
-            )
-        else:
-            # Send threepid validation emails from Synapse
-            sid = await self.identity_handler.send_threepid_validation(
-                email,
-                client_secret,
-                send_attempt,
-                self.mailer.send_add_threepid_mail,
-                next_link,
-            )
-
-            # Wrap the session id in a JSON object
-            ret = {"sid": sid}
+        sid = await self.identity_handler.send_threepid_validation(
+            email,
+            client_secret,
+            send_attempt,
+            self.mailer.send_add_threepid_mail,
+            next_link,
+        )
 
         threepid_send_requests.labels(type="email", reason="add_threepid").observe(
             send_attempt
         )
 
-        return 200, ret
+        # Wrap the session id in a JSON object
+        return 200, {"sid": sid}
 
 
 class MsisdnThreepidRequestTokenRestServlet(RestServlet):
@@ -534,25 +502,18 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
         self.config = hs.config
         self.clock = hs.get_clock()
         self.store = hs.get_datastores().main
-        if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+        if self.config.email.can_verify_email:
             self._failure_email_template = (
                 self.config.email.email_add_threepid_template_failure_html
             )
 
     async def on_GET(self, request: Request) -> None:
-        if self.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF:
-            if self.config.email.local_threepid_handling_disabled_due_to_email_config:
-                logger.warning(
-                    "Adding emails have been disabled due to lack of an email config"
-                )
-            raise SynapseError(
-                400, "Adding an email to your account is disabled on this server"
+        if not self.config.email.can_verify_email:
+            logger.warning(
+                "Adding emails have been disabled due to lack of an email config"
             )
-        elif self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
             raise SynapseError(
-                400,
-                "This homeserver is not validating threepids. Use an identity server "
-                "instead.",
+                400, "Adding an email to your account is disabled on this server"
             )
 
         sid = parse_string(request, "sid", required=True)
@@ -743,10 +704,12 @@ class ThreepidBindRestServlet(RestServlet):
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         body = parse_json_object_from_request(request)
 
-        assert_params_in_dict(body, ["id_server", "sid", "client_secret"])
+        assert_params_in_dict(
+            body, ["id_server", "sid", "id_access_token", "client_secret"]
+        )
         id_server = body["id_server"]
         sid = body["sid"]
-        id_access_token = body.get("id_access_token")  # optional
+        id_access_token = body["id_access_token"]
         client_secret = body["client_secret"]
         assert_valid_client_secret(client_secret)
 
diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py
index ce806e3c11..eb1b85721f 100644
--- a/synapse/rest/client/keys.py
+++ b/synapse/rest/client/keys.py
@@ -26,7 +26,7 @@ from synapse.http.servlet import (
     parse_string,
 )
 from synapse.http.site import SynapseRequest
-from synapse.logging.opentracing import log_kv, set_tag, trace
+from synapse.logging.opentracing import log_kv, set_tag, trace_with_opname
 from synapse.types import JsonDict, StreamToken
 
 from ._base import client_patterns, interactive_auth_handler
@@ -71,7 +71,7 @@ class KeyUploadServlet(RestServlet):
         self.e2e_keys_handler = hs.get_e2e_keys_handler()
         self.device_handler = hs.get_device_handler()
 
-    @trace(opname="upload_keys")
+    @trace_with_opname("upload_keys")
     async def on_POST(
         self, request: SynapseRequest, device_id: Optional[str]
     ) -> Tuple[int, JsonDict]:
diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py
index dd75e40f34..0437c87d8d 100644
--- a/synapse/rest/client/login.py
+++ b/synapse/rest/client/login.py
@@ -28,7 +28,7 @@ from typing import (
 
 from typing_extensions import TypedDict
 
-from synapse.api.errors import Codes, LoginError, SynapseError
+from synapse.api.errors import Codes, InvalidClientTokenError, LoginError, SynapseError
 from synapse.api.ratelimiting import Ratelimiter
 from synapse.api.urls import CLIENT_API_PREFIX
 from synapse.appservice import ApplicationService
@@ -172,7 +172,13 @@ class LoginRestServlet(RestServlet):
 
         try:
             if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE:
-                appservice = self.auth.get_appservice_by_req(request)
+                requester = await self.auth.get_user_by_req(request)
+                appservice = requester.app_service
+
+                if appservice is None:
+                    raise InvalidClientTokenError(
+                        "This login method is only valid for application services"
+                    )
 
                 if appservice.is_rate_limited():
                     await self._address_ratelimiter.ratelimit(
diff --git a/synapse/rest/client/read_marker.py b/synapse/rest/client/read_marker.py
index 3644705e6a..8896f2df50 100644
--- a/synapse/rest/client/read_marker.py
+++ b/synapse/rest/client/read_marker.py
@@ -40,6 +40,10 @@ class ReadMarkerRestServlet(RestServlet):
         self.read_marker_handler = hs.get_read_marker_handler()
         self.presence_handler = hs.get_presence_handler()
 
+        self._known_receipt_types = {ReceiptTypes.READ, ReceiptTypes.FULLY_READ}
+        if hs.config.experimental.msc2285_enabled:
+            self._known_receipt_types.add(ReceiptTypes.READ_PRIVATE)
+
     async def on_POST(
         self, request: SynapseRequest, room_id: str
     ) -> Tuple[int, JsonDict]:
@@ -49,13 +53,7 @@ class ReadMarkerRestServlet(RestServlet):
 
         body = parse_json_object_from_request(request)
 
-        valid_receipt_types = {
-            ReceiptTypes.READ,
-            ReceiptTypes.FULLY_READ,
-            ReceiptTypes.READ_PRIVATE,
-        }
-
-        unrecognized_types = set(body.keys()) - valid_receipt_types
+        unrecognized_types = set(body.keys()) - self._known_receipt_types
         if unrecognized_types:
             # It's fine if there are unrecognized receipt types, but let's log
             # it to help debug clients that have typoed the receipt type.
@@ -65,31 +63,25 @@ class ReadMarkerRestServlet(RestServlet):
             # types.
             logger.info("Ignoring unrecognized receipt types: %s", unrecognized_types)
 
-        read_event_id = body.get(ReceiptTypes.READ, None)
-        if read_event_id:
-            await self.receipts_handler.received_client_receipt(
-                room_id,
-                ReceiptTypes.READ,
-                user_id=requester.user.to_string(),
-                event_id=read_event_id,
-            )
-
-        read_private_event_id = body.get(ReceiptTypes.READ_PRIVATE, None)
-        if read_private_event_id and self.config.experimental.msc2285_enabled:
-            await self.receipts_handler.received_client_receipt(
-                room_id,
-                ReceiptTypes.READ_PRIVATE,
-                user_id=requester.user.to_string(),
-                event_id=read_private_event_id,
-            )
-
-        read_marker_event_id = body.get(ReceiptTypes.FULLY_READ, None)
-        if read_marker_event_id:
-            await self.read_marker_handler.received_client_read_marker(
-                room_id,
-                user_id=requester.user.to_string(),
-                event_id=read_marker_event_id,
-            )
+        for receipt_type in self._known_receipt_types:
+            event_id = body.get(receipt_type, None)
+            # TODO Add validation to reject non-string event IDs.
+            if not event_id:
+                continue
+
+            if receipt_type == ReceiptTypes.FULLY_READ:
+                await self.read_marker_handler.received_client_read_marker(
+                    room_id,
+                    user_id=requester.user.to_string(),
+                    event_id=event_id,
+                )
+            else:
+                await self.receipts_handler.received_client_receipt(
+                    room_id,
+                    receipt_type,
+                    user_id=requester.user.to_string(),
+                    event_id=event_id,
+                )
 
         return 200, {}
 
diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py
index 4b03eb876b..409bfd43c1 100644
--- a/synapse/rest/client/receipts.py
+++ b/synapse/rest/client/receipts.py
@@ -39,31 +39,27 @@ class ReceiptRestServlet(RestServlet):
 
     def __init__(self, hs: "HomeServer"):
         super().__init__()
-        self.hs = hs
         self.auth = hs.get_auth()
         self.receipts_handler = hs.get_receipts_handler()
         self.read_marker_handler = hs.get_read_marker_handler()
         self.presence_handler = hs.get_presence_handler()
 
+        self._known_receipt_types = {ReceiptTypes.READ}
+        if hs.config.experimental.msc2285_enabled:
+            self._known_receipt_types.update(
+                (ReceiptTypes.READ_PRIVATE, ReceiptTypes.FULLY_READ)
+            )
+
     async def on_POST(
         self, request: SynapseRequest, room_id: str, receipt_type: str, event_id: str
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
 
-        if self.hs.config.experimental.msc2285_enabled and receipt_type not in [
-            ReceiptTypes.READ,
-            ReceiptTypes.READ_PRIVATE,
-            ReceiptTypes.FULLY_READ,
-        ]:
+        if receipt_type not in self._known_receipt_types:
             raise SynapseError(
                 400,
-                "Receipt type must be 'm.read', 'org.matrix.msc2285.read.private' or 'm.fully_read'",
+                f"Receipt type must be {', '.join(self._known_receipt_types)}",
             )
-        elif (
-            not self.hs.config.experimental.msc2285_enabled
-            and receipt_type != ReceiptTypes.READ
-        ):
-            raise SynapseError(400, "Receipt type must be 'm.read'")
 
         parse_json_object_from_request(request, allow_empty_body=False)
 
diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index e8e51a9c66..a8402cdb3a 100644
--- a/synapse/rest/client/register.py
+++ b/synapse/rest/client/register.py
@@ -31,7 +31,6 @@ from synapse.api.errors import (
 )
 from synapse.api.ratelimiting import Ratelimiter
 from synapse.config import ConfigError
-from synapse.config.emailconfig import ThreepidBehaviour
 from synapse.config.homeserver import HomeServerConfig
 from synapse.config.ratelimiting import FederationRateLimitConfig
 from synapse.config.server import is_threepid_reserved
@@ -74,7 +73,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
         self.identity_handler = hs.get_identity_handler()
         self.config = hs.config
 
-        if self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+        if self.hs.config.email.can_verify_email:
             self.mailer = Mailer(
                 hs=self.hs,
                 app_name=self.config.email.email_app_name,
@@ -83,13 +82,10 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
             )
 
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
-        if self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF:
-            if (
-                self.hs.config.email.local_threepid_handling_disabled_due_to_email_config
-            ):
-                logger.warning(
-                    "Email registration has been disabled due to lack of email config"
-                )
+        if not self.hs.config.email.can_verify_email:
+            logger.warning(
+                "Email registration has been disabled due to lack of email config"
+            )
             raise SynapseError(
                 400, "Email-based registration has been disabled on this server"
             )
@@ -138,35 +134,21 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
 
             raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
 
-        if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
-            assert self.hs.config.registration.account_threepid_delegate_email
-
-            # Have the configured identity server handle the request
-            ret = await self.identity_handler.requestEmailToken(
-                self.hs.config.registration.account_threepid_delegate_email,
-                email,
-                client_secret,
-                send_attempt,
-                next_link,
-            )
-        else:
-            # Send registration emails from Synapse
-            sid = await self.identity_handler.send_threepid_validation(
-                email,
-                client_secret,
-                send_attempt,
-                self.mailer.send_registration_mail,
-                next_link,
-            )
-
-            # Wrap the session id in a JSON object
-            ret = {"sid": sid}
+        # Send registration emails from Synapse
+        sid = await self.identity_handler.send_threepid_validation(
+            email,
+            client_secret,
+            send_attempt,
+            self.mailer.send_registration_mail,
+            next_link,
+        )
 
         threepid_send_requests.labels(type="email", reason="register").observe(
             send_attempt
         )
 
-        return 200, ret
+        # Wrap the session id in a JSON object
+        return 200, {"sid": sid}
 
 
 class MsisdnRegisterRequestTokenRestServlet(RestServlet):
@@ -260,7 +242,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
         self.clock = hs.get_clock()
         self.store = hs.get_datastores().main
 
-        if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+        if self.config.email.can_verify_email:
             self._failure_email_template = (
                 self.config.email.email_registration_template_failure_html
             )
@@ -270,11 +252,10 @@ class RegistrationSubmitTokenServlet(RestServlet):
             raise SynapseError(
                 400, "This medium is currently not supported for registration"
             )
-        if self.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF:
-            if self.config.email.local_threepid_handling_disabled_due_to_email_config:
-                logger.warning(
-                    "User registration via email has been disabled due to lack of email config"
-                )
+        if not self.config.email.can_verify_email:
+            logger.warning(
+                "User registration via email has been disabled due to lack of email config"
+            )
             raise SynapseError(
                 400, "Email-based registration is disabled on this server"
             )
diff --git a/synapse/rest/client/room_keys.py b/synapse/rest/client/room_keys.py
index 37e39570f6..f7081f638e 100644
--- a/synapse/rest/client/room_keys.py
+++ b/synapse/rest/client/room_keys.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Optional, Tuple
+from typing import TYPE_CHECKING, Optional, Tuple, cast
 
 from synapse.api.errors import Codes, NotFoundError, SynapseError
 from synapse.http.server import HttpServer
@@ -127,7 +127,7 @@ class RoomKeysServlet(RestServlet):
         requester = await self.auth.get_user_by_req(request, allow_guest=False)
         user_id = requester.user.to_string()
         body = parse_json_object_from_request(request)
-        version = parse_string(request, "version")
+        version = parse_string(request, "version", required=True)
 
         if session_id:
             body = {"sessions": {session_id: body}}
@@ -196,8 +196,11 @@ class RoomKeysServlet(RestServlet):
         user_id = requester.user.to_string()
         version = parse_string(request, "version", required=True)
 
-        room_keys = await self.e2e_room_keys_handler.get_room_keys(
-            user_id, version, room_id, session_id
+        room_keys = cast(
+            JsonDict,
+            await self.e2e_room_keys_handler.get_room_keys(
+                user_id, version, room_id, session_id
+            ),
         )
 
         # Convert room_keys to the right format to return.
@@ -240,7 +243,7 @@ class RoomKeysServlet(RestServlet):
 
         requester = await self.auth.get_user_by_req(request, allow_guest=False)
         user_id = requester.user.to_string()
-        version = parse_string(request, "version")
+        version = parse_string(request, "version", required=True)
 
         ret = await self.e2e_room_keys_handler.delete_room_keys(
             user_id, version, room_id, session_id
diff --git a/synapse/rest/client/sendtodevice.py b/synapse/rest/client/sendtodevice.py
index 3322c8ef48..1a8e9a96d4 100644
--- a/synapse/rest/client/sendtodevice.py
+++ b/synapse/rest/client/sendtodevice.py
@@ -19,7 +19,7 @@ from synapse.http import servlet
 from synapse.http.server import HttpServer
 from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request
 from synapse.http.site import SynapseRequest
-from synapse.logging.opentracing import set_tag, trace
+from synapse.logging.opentracing import set_tag, trace_with_opname
 from synapse.rest.client.transactions import HttpTransactionCache
 from synapse.types import JsonDict
 
@@ -43,7 +43,7 @@ class SendToDeviceRestServlet(servlet.RestServlet):
         self.txns = HttpTransactionCache(hs)
         self.device_message_handler = hs.get_device_message_handler()
 
-    @trace(opname="sendToDevice")
+    @trace_with_opname("sendToDevice")
     def on_PUT(
         self, request: SynapseRequest, message_type: str, txn_id: str
     ) -> Awaitable[Tuple[int, JsonDict]]:
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index 8bbf35148d..c2989765ce 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -37,7 +37,7 @@ from synapse.handlers.sync import (
 from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
 from synapse.http.site import SynapseRequest
-from synapse.logging.opentracing import trace
+from synapse.logging.opentracing import trace_with_opname
 from synapse.types import JsonDict, StreamToken
 from synapse.util import json_decoder
 
@@ -210,7 +210,7 @@ class SyncRestServlet(RestServlet):
         logger.debug("Event formatting complete")
         return 200, response_content
 
-    @trace(opname="sync.encode_response")
+    @trace_with_opname("sync.encode_response")
     async def encode_response(
         self,
         time_now: int,
@@ -315,7 +315,7 @@ class SyncRestServlet(RestServlet):
             ]
         }
 
-    @trace(opname="sync.encode_joined")
+    @trace_with_opname("sync.encode_joined")
     async def encode_joined(
         self,
         rooms: List[JoinedSyncResult],
@@ -340,7 +340,7 @@ class SyncRestServlet(RestServlet):
 
         return joined
 
-    @trace(opname="sync.encode_invited")
+    @trace_with_opname("sync.encode_invited")
     async def encode_invited(
         self,
         rooms: List[InvitedSyncResult],
@@ -371,7 +371,7 @@ class SyncRestServlet(RestServlet):
 
         return invited
 
-    @trace(opname="sync.encode_knocked")
+    @trace_with_opname("sync.encode_knocked")
     async def encode_knocked(
         self,
         rooms: List[KnockedSyncResult],
@@ -420,7 +420,7 @@ class SyncRestServlet(RestServlet):
 
         return knocked
 
-    @trace(opname="sync.encode_archived")
+    @trace_with_opname("sync.encode_archived")
     async def encode_archived(
         self,
         rooms: List[ArchivedSyncResult],
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 54a849eac9..b36c98a08e 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -109,10 +109,64 @@ class MediaInfo:
 
 class PreviewUrlResource(DirectServeJsonResource):
     """
-    Generating URL previews is a complicated task which many potential pitfalls.
-
-    See docs/development/url_previews.md for discussion of the design and
-    algorithm followed in this module.
+    The `GET /_matrix/media/r0/preview_url` endpoint provides a generic preview API
+    for URLs which outputs Open Graph (https://ogp.me/) responses (with some Matrix
+    specific additions).
+
+    This does have trade-offs compared to other designs:
+
+    * Pros:
+      * Simple and flexible; can be used by any clients at any point
+    * Cons:
+      * If each homeserver provides one of these independently, all the homeservers in a
+        room may needlessly DoS the target URI
+      * The URL metadata must be stored somewhere, rather than just using Matrix
+        itself to store the media.
+      * Matrix cannot be used to distribute the metadata between homeservers.
+
+    When Synapse is asked to preview a URL it does the following:
+
+    1. Checks against a URL blacklist (defined as `url_preview_url_blacklist` in the
+       config).
+    2. Checks the URL against an in-memory cache and returns the result if it exists. (This
+       is also used to de-duplicate processing of multiple in-flight requests at once.)
+    3. Kicks off a background process to generate a preview:
+       1. Checks URL and timestamp against the database cache and returns the result if it
+          has not expired and was successful (a 2xx return code).
+       2. Checks if the URL matches an oEmbed (https://oembed.com/) pattern. If it
+          does, update the URL to download.
+       3. Downloads the URL and stores it into a file via the media storage provider
+          and saves the local media metadata.
+       4. If the media is an image:
+          1. Generates thumbnails.
+          2. Generates an Open Graph response based on image properties.
+       5. If the media is HTML:
+          1. Decodes the HTML via the stored file.
+          2. Generates an Open Graph response from the HTML.
+          3. If a JSON oEmbed URL was found in the HTML via autodiscovery:
+             1. Downloads the URL and stores it into a file via the media storage provider
+                and saves the local media metadata.
+             2. Convert the oEmbed response to an Open Graph response.
+             3. Override any Open Graph data from the HTML with data from oEmbed.
+          4. If an image exists in the Open Graph response:
+             1. Downloads the URL and stores it into a file via the media storage
+                provider and saves the local media metadata.
+             2. Generates thumbnails.
+             3. Updates the Open Graph response based on image properties.
+       6. If the media is JSON and an oEmbed URL was found:
+          1. Convert the oEmbed response to an Open Graph response.
+          2. If a thumbnail or image is in the oEmbed response:
+             1. Downloads the URL and stores it into a file via the media storage
+                provider and saves the local media metadata.
+             2. Generates thumbnails.
+             3. Updates the Open Graph response based on image properties.
+       7. Stores the result in the database cache.
+    4. Returns the result.
+
+    The in-memory cache expires after 1 hour.
+
+    Expired entries in the database cache (and their associated media files) are
+    deleted every 10 seconds. The default expiration time is 1 hour from download.
     """
 
     isLeaf = True
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index 2295adfaa7..5f725c7600 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -17,9 +17,11 @@
 import logging
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
 
-from synapse.api.errors import SynapseError
+from synapse.api.errors import Codes, SynapseError, cs_error
+from synapse.config.repository import THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP
 from synapse.http.server import (
     DirectServeJsonResource,
+    respond_with_json,
     set_corp_headers,
     set_cors_headers,
 )
@@ -309,6 +311,19 @@ class ThumbnailResource(DirectServeJsonResource):
             url_cache: True if this is from a URL cache.
             server_name: The server name, if this is a remote thumbnail.
         """
+        logger.debug(
+            "_select_and_respond_with_thumbnail: media_id=%s desired=%sx%s (%s) thumbnail_infos=%s",
+            media_id,
+            desired_width,
+            desired_height,
+            desired_method,
+            thumbnail_infos,
+        )
+
+        # If `dynamic_thumbnails` is enabled, we expect Synapse to go down a
+        # different code path to handle it.
+        assert not self.dynamic_thumbnails
+
         if thumbnail_infos:
             file_info = self._select_thumbnail(
                 desired_width,
@@ -384,8 +399,29 @@ class ThumbnailResource(DirectServeJsonResource):
                 file_info.thumbnail.length,
             )
         else:
+            # This might be because:
+            # 1. We can't create thumbnails for the given media (corrupted or
+            #    unsupported file type), or
+            # 2. The thumbnailing process never ran or errored out initially
+            #    when the media was first uploaded (these bugs should be
+            #    reported and fixed).
+            # Note that we don't attempt to generate a thumbnail now because
+            # `dynamic_thumbnails` is disabled.
             logger.info("Failed to find any generated thumbnails")
-            respond_404(request)
+
+            respond_with_json(
+                request,
+                400,
+                cs_error(
+                    "Cannot find any thumbnails for the requested media (%r). This might mean the media is not a supported_media_format=(%s) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)"
+                    % (
+                        request.postpath,
+                        ", ".join(THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP.keys()),
+                    ),
+                    code=Codes.UNKNOWN,
+                ),
+                send_cors=True,
+            )
 
     def _select_thumbnail(
         self,
diff --git a/synapse/rest/synapse/client/password_reset.py b/synapse/rest/synapse/client/password_reset.py
index 6ac9dbc7c9..b9402cfb75 100644
--- a/synapse/rest/synapse/client/password_reset.py
+++ b/synapse/rest/synapse/client/password_reset.py
@@ -17,7 +17,6 @@ from typing import TYPE_CHECKING, Tuple
 from twisted.web.server import Request
 
 from synapse.api.errors import ThreepidValidationError
-from synapse.config.emailconfig import ThreepidBehaviour
 from synapse.http.server import DirectServeHtmlResource
 from synapse.http.servlet import parse_string
 from synapse.util.stringutils import assert_valid_client_secret
@@ -46,9 +45,6 @@ class PasswordResetSubmitTokenResource(DirectServeHtmlResource):
         self.clock = hs.get_clock()
         self.store = hs.get_datastores().main
 
-        self._local_threepid_handling_disabled_due_to_email_config = (
-            hs.config.email.local_threepid_handling_disabled_due_to_email_config
-        )
         self._confirmation_email_template = (
             hs.config.email.email_password_reset_template_confirmation_html
         )
@@ -59,8 +55,8 @@ class PasswordResetSubmitTokenResource(DirectServeHtmlResource):
             hs.config.email.email_password_reset_template_failure_html
         )
 
-        # This resource should not be mounted if threepid behaviour is not LOCAL
-        assert hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL
+        # This resource should only be mounted if email validation is enabled
+        assert hs.config.email.can_verify_email
 
     async def _async_render_GET(self, request: Request) -> Tuple[int, bytes]:
         sid = parse_string(request, "sid", required=True)
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 781d9f06da..e3faa52cd6 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -24,14 +24,12 @@ from typing import (
     DefaultDict,
     Dict,
     FrozenSet,
-    Iterable,
     List,
     Mapping,
     Optional,
     Sequence,
     Set,
     Tuple,
-    Union,
 )
 
 import attr
@@ -47,6 +45,7 @@ from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServ
 from synapse.state import v1, v2
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.storage.roommember import ProfileInfo
+from synapse.storage.state import StateFilter
 from synapse.types import StateMap
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -54,6 +53,7 @@ from synapse.util.metrics import Measure, measure_func
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
+    from synapse.storage.controllers import StateStorageController
     from synapse.storage.databases.main import DataStore
 
 logger = logging.getLogger(__name__)
@@ -83,17 +83,23 @@ def _gen_state_id() -> str:
 
 
 class _StateCacheEntry:
-    __slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"]
+    __slots__ = ["_state", "state_group", "prev_group", "delta_ids"]
 
     def __init__(
         self,
-        state: StateMap[str],
+        state: Optional[StateMap[str]],
         state_group: Optional[int],
         prev_group: Optional[int] = None,
         delta_ids: Optional[StateMap[str]] = None,
     ):
+        if state is None and state_group is None:
+            raise Exception("Either state or state group must be not None")
+
         # A map from (type, state_key) to event_id.
-        self.state = frozendict(state)
+        #
+        # This can be None if we have a `state_group` (as then we can fetch the
+        # state from the DB.)
+        self._state = frozendict(state) if state is not None else None
 
         # the ID of a state group if one and only one is involved.
         # otherwise, None otherwise?
@@ -102,20 +108,30 @@ class _StateCacheEntry:
         self.prev_group = prev_group
         self.delta_ids = frozendict(delta_ids) if delta_ids is not None else None
 
-        # The `state_id` is a unique ID we generate that can be used as ID for
-        # this collection of state. Usually this would be the same as the
-        # state group, but on worker instances we can't generate a new state
-        # group each time we resolve state, so we generate a separate one that
-        # isn't persisted and is used solely for caches.
-        # `state_id` is either a state_group (and so an int) or a string. This
-        # ensures we don't accidentally persist a state_id as a stateg_group
-        if state_group:
-            self.state_id: Union[str, int] = state_group
-        else:
-            self.state_id = _gen_state_id()
+    async def get_state(
+        self,
+        state_storage: "StateStorageController",
+        state_filter: Optional["StateFilter"] = None,
+    ) -> StateMap[str]:
+        """Get the state map for this entry, either from the in-memory state or
+        looking up the state group in the DB.
+        """
+
+        if self._state is not None:
+            return self._state
+
+        assert self.state_group is not None
+
+        return await state_storage.get_state_ids_for_group(
+            self.state_group, state_filter
+        )
 
     def __len__(self) -> int:
-        return len(self.state)
+        # The len should is used to estimate how large this cache entry is, for
+        # cache eviction purposes. This is why if `self.state` is None it's fine
+        # to return 1.
+
+        return len(self._state) if self._state else 1
 
 
 class StateHandler:
@@ -137,23 +153,29 @@ class StateHandler:
             ReplicationUpdateCurrentStateRestServlet.make_client(hs)
         )
 
-    async def get_current_state_ids(
+    async def compute_state_after_events(
         self,
         room_id: str,
-        latest_event_ids: Collection[str],
+        event_ids: Collection[str],
+        state_filter: Optional[StateFilter] = None,
     ) -> StateMap[str]:
-        """Get the current state, or the state at a set of events, for a room
+        """Fetch the state after each of the given event IDs. Resolve them and return.
+
+        This is typically used where `event_ids` is a collection of forward extremities
+        in a room, intended to become the `prev_events` of a new event E. If so, the
+        return value of this function represents the state before E.
 
         Args:
-            room_id:
-            latest_event_ids: The forward extremities to resolve.
+            room_id: the room_id containing the given events.
+            event_ids: the events whose state should be fetched and resolved.
 
         Returns:
-            the state dict, mapping from (event_type, state_key) -> event_id
+            the state dict (a mapping from (event_type, state_key) -> event_id) which
+            holds the resolution of the states after the given event IDs.
         """
-        logger.debug("calling resolve_state_groups from get_current_state_ids")
-        ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
-        return ret.state
+        logger.debug("calling resolve_state_groups from compute_state_after_events")
+        ret = await self.resolve_state_groups_for_events(room_id, event_ids)
+        return await ret.get_state(self._state_storage_controller, state_filter)
 
     async def get_current_users_in_room(
         self, room_id: str, latest_event_ids: List[str]
@@ -177,7 +199,8 @@ class StateHandler:
 
         logger.debug("calling resolve_state_groups from get_current_users_in_room")
         entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
-        return await self.store.get_joined_users_from_state(room_id, entry)
+        state = await entry.get_state(self._state_storage_controller, StateFilter.all())
+        return await self.store.get_joined_users_from_state(room_id, state, entry)
 
     async def get_hosts_in_room_at_events(
         self, room_id: str, event_ids: Collection[str]
@@ -192,7 +215,8 @@ class StateHandler:
             The hosts in the room at the given events
         """
         entry = await self.resolve_state_groups_for_events(room_id, event_ids)
-        return await self.store.get_joined_hosts(room_id, entry)
+        state = await entry.get_state(self._state_storage_controller, StateFilter.all())
+        return await self.store.get_joined_hosts(room_id, state, entry)
 
     async def compute_event_context(
         self,
@@ -227,10 +251,19 @@ class StateHandler:
         #
         if state_ids_before_event:
             # if we're given the state before the event, then we use that
-            state_group_before_event = None
             state_group_before_event_prev_group = None
             deltas_to_state_group_before_event = None
-            entry = None
+
+            # .. though we need to get a state group for it.
+            state_group_before_event = (
+                await self._state_storage_controller.store_state_group(
+                    event.event_id,
+                    event.room_id,
+                    prev_group=None,
+                    delta_ids=None,
+                    current_state_ids=state_ids_before_event,
+                )
+            )
 
         else:
             # otherwise, we'll need to resolve the state across the prev_events.
@@ -264,36 +297,32 @@ class StateHandler:
                 await_full_state=False,
             )
 
-            state_ids_before_event = entry.state
-            state_group_before_event = entry.state_group
             state_group_before_event_prev_group = entry.prev_group
             deltas_to_state_group_before_event = entry.delta_ids
+            state_ids_before_event = None
+
+            # We make sure that we have a state group assigned to the state.
+            if entry.state_group is None:
+                # store_state_group requires us to have either a previous state group
+                # (with deltas) or the complete state map. So, if we don't have a
+                # previous state group, load the complete state map now.
+                if state_group_before_event_prev_group is None:
+                    state_ids_before_event = await entry.get_state(
+                        self._state_storage_controller, StateFilter.all()
+                    )
 
-        #
-        # make sure that we have a state group at that point. If it's not a state event,
-        # that will be the state group for the new event. If it *is* a state event,
-        # it might get rejected (in which case we'll need to persist it with the
-        # previous state group)
-        #
-
-        if not state_group_before_event:
-            state_group_before_event = (
-                await self._state_storage_controller.store_state_group(
-                    event.event_id,
-                    event.room_id,
-                    prev_group=state_group_before_event_prev_group,
-                    delta_ids=deltas_to_state_group_before_event,
-                    current_state_ids=state_ids_before_event,
+                state_group_before_event = (
+                    await self._state_storage_controller.store_state_group(
+                        event.event_id,
+                        event.room_id,
+                        prev_group=state_group_before_event_prev_group,
+                        delta_ids=deltas_to_state_group_before_event,
+                        current_state_ids=state_ids_before_event,
+                    )
                 )
-            )
-
-            # Assign the new state group to the cached state entry.
-            #
-            # Note that this can race in that we could generate multiple state
-            # groups for the same state entry, but that is just inefficient
-            # rather than dangerous.
-            if entry and entry.state_group is None:
                 entry.state_group = state_group_before_event
+            else:
+                state_group_before_event = entry.state_group
 
         #
         # now if it's not a state event, we're done
@@ -315,13 +344,18 @@ class StateHandler:
         #
 
         key = (event.type, event.state_key)
-        if key in state_ids_before_event:
-            replaces = state_ids_before_event[key]
-            if replaces != event.event_id:
-                event.unsigned["replaces_state"] = replaces
 
-        state_ids_after_event = dict(state_ids_before_event)
-        state_ids_after_event[key] = event.event_id
+        if state_ids_before_event is not None:
+            replaces = state_ids_before_event.get(key)
+        else:
+            replaces_state_map = await entry.get_state(
+                self._state_storage_controller, StateFilter.from_types([key])
+            )
+            replaces = replaces_state_map.get(key)
+
+        if replaces and replaces != event.event_id:
+            event.unsigned["replaces_state"] = replaces
+
         delta_ids = {key: event.event_id}
 
         state_group_after_event = (
@@ -330,7 +364,7 @@ class StateHandler:
                 event.room_id,
                 prev_group=state_group_before_event,
                 delta_ids=delta_ids,
-                current_state_ids=state_ids_after_event,
+                current_state_ids=None,
             )
         )
 
@@ -372,9 +406,6 @@ class StateHandler:
         state_group_ids_set = set(state_group_ids)
         if len(state_group_ids_set) == 1:
             (state_group_id,) = state_group_ids_set
-            state = await self._state_storage_controller.get_state_for_groups(
-                state_group_ids_set
-            )
             (
                 prev_group,
                 delta_ids,
@@ -382,7 +413,7 @@ class StateHandler:
                 state_group_id
             )
             return _StateCacheEntry(
-                state=state[state_group_id],
+                state=None,
                 state_group=state_group_id,
                 prev_group=prev_group,
                 delta_ids=delta_ids,
@@ -405,31 +436,6 @@ class StateHandler:
         )
         return result
 
-    async def resolve_events(
-        self,
-        room_version: str,
-        state_sets: Collection[Iterable[EventBase]],
-        event: EventBase,
-    ) -> StateMap[EventBase]:
-        logger.info(
-            "Resolving state for %s with %d groups", event.room_id, len(state_sets)
-        )
-        state_set_ids = [
-            {(ev.type, ev.state_key): ev.event_id for ev in st} for st in state_sets
-        ]
-
-        state_map = {ev.event_id: ev for st in state_sets for ev in st}
-
-        new_state = await self._state_resolution_handler.resolve_events_with_store(
-            event.room_id,
-            room_version,
-            state_set_ids,
-            event_map=state_map,
-            state_res_store=StateResolutionStore(self.store),
-        )
-
-        return {key: state_map[ev_id] for key, ev_id in new_state.items()}
-
     async def update_current_state(self, room_id: str) -> None:
         """Recalculates the current state for a room, and persists it.
 
@@ -752,6 +758,12 @@ def _make_state_cache_entry(
     delta_ids: Optional[StateMap[str]] = None
 
     for old_group, old_state in state_groups_ids.items():
+        if old_state.keys() - new_state.keys():
+            # Currently we don't support deltas that remove keys from the state
+            # map, so we have to ignore this group as a candidate to base the
+            # new group on.
+            continue
+
         n_delta_ids = {k: v for k, v in new_state.items() if old_state.get(k) != v}
         if not delta_ids or len(n_delta_ids) < len(delta_ids):
             prev_group = old_group
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index b8c8dcd76b..a2f8310388 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -96,6 +96,10 @@ class SQLBaseStore(metaclass=ABCMeta):
         cache doesn't exist. Mainly used for invalidating caches on workers,
         where they may not have the cache.
 
+        Note that this function does not invalidate any remote caches, only the
+        local in-memory ones. Any remote invalidation must be performed before
+        calling this.
+
         Args:
             cache_name
             key: Entry to invalidate. If None then invalidates the entire
@@ -112,7 +116,10 @@ class SQLBaseStore(metaclass=ABCMeta):
         if key is None:
             cache.invalidate_all()
         else:
-            cache.invalidate(tuple(key))
+            # Prefer any local-only invalidation method. Invalidating any non-local
+            # cache must be be done before this.
+            invalidate_method = getattr(cache, "invalidate_local", cache.invalidate)
+            invalidate_method(tuple(key))
 
 
 def db_to_json(db_content: Union[memoryview, bytes, bytearray, str]) -> Any:
diff --git a/synapse/storage/controllers/__init__.py b/synapse/storage/controllers/__init__.py
index 55649719f6..45101cda7a 100644
--- a/synapse/storage/controllers/__init__.py
+++ b/synapse/storage/controllers/__init__.py
@@ -43,4 +43,6 @@ class StorageControllers:
 
         self.persistence = None
         if stores.persist_events:
-            self.persistence = EventsPersistenceStorageController(hs, stores)
+            self.persistence = EventsPersistenceStorageController(
+                hs, stores, self.state
+            )
diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py
index ea499ce0f8..cf98b0ab48 100644
--- a/synapse/storage/controllers/persist_events.py
+++ b/synapse/storage/controllers/persist_events.py
@@ -48,9 +48,11 @@ from synapse.events.snapshot import EventContext
 from synapse.logging import opentracing
 from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.controllers.state import StateStorageController
 from synapse.storage.databases import Databases
 from synapse.storage.databases.main.events import DeltaState
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
+from synapse.storage.state import StateFilter
 from synapse.types import (
     PersistedEventPosition,
     RoomStreamToken,
@@ -308,7 +310,12 @@ class EventsPersistenceStorageController:
     current state and forward extremity changes.
     """
 
-    def __init__(self, hs: "HomeServer", stores: Databases):
+    def __init__(
+        self,
+        hs: "HomeServer",
+        stores: Databases,
+        state_controller: StateStorageController,
+    ):
         # We ultimately want to split out the state store from the main store,
         # so we use separate variables here even though they point to the same
         # store for now.
@@ -325,6 +332,7 @@ class EventsPersistenceStorageController:
             self._process_event_persist_queue_task
         )
         self._state_resolution_handler = hs.get_state_resolution_handler()
+        self._state_controller = state_controller
 
     async def _process_event_persist_queue_task(
         self,
@@ -504,7 +512,7 @@ class EventsPersistenceStorageController:
             state_res_store=StateResolutionStore(self.main_store),
         )
 
-        return res.state
+        return await res.get_state(self._state_controller, StateFilter.all())
 
     async def _persist_event_batch(
         self, _room_id: str, task: _PersistEventsTask
@@ -940,7 +948,8 @@ class EventsPersistenceStorageController:
                 events_context,
             )
 
-        return res.state, None, new_latest_event_ids
+        full_state = await res.get_state(self._state_controller)
+        return full_state, None, new_latest_event_ids
 
     async def _prune_extremities(
         self,
diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index d3a44bc876..e08f956e6e 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -346,7 +346,7 @@ class StateStorageController:
         room_id: str,
         prev_group: Optional[int],
         delta_ids: Optional[StateMap[str]],
-        current_state_ids: StateMap[str],
+        current_state_ids: Optional[StateMap[str]],
     ) -> int:
         """Store a new set of state, returning a newly assigned state group.
 
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index e21ab08515..ea672ff89e 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -23,6 +23,7 @@ from time import monotonic as monotonic_time
 from typing import (
     TYPE_CHECKING,
     Any,
+    Awaitable,
     Callable,
     Collection,
     Dict,
@@ -168,6 +169,7 @@ class LoggingDatabaseConnection:
         *,
         txn_name: Optional[str] = None,
         after_callbacks: Optional[List["_CallbackListEntry"]] = None,
+        async_after_callbacks: Optional[List["_AsyncCallbackListEntry"]] = None,
         exception_callbacks: Optional[List["_CallbackListEntry"]] = None,
     ) -> "LoggingTransaction":
         if not txn_name:
@@ -178,6 +180,7 @@ class LoggingDatabaseConnection:
             name=txn_name,
             database_engine=self.engine,
             after_callbacks=after_callbacks,
+            async_after_callbacks=async_after_callbacks,
             exception_callbacks=exception_callbacks,
         )
 
@@ -209,6 +212,9 @@ class LoggingDatabaseConnection:
 
 # The type of entry which goes on our after_callbacks and exception_callbacks lists.
 _CallbackListEntry = Tuple[Callable[..., object], Tuple[object, ...], Dict[str, object]]
+_AsyncCallbackListEntry = Tuple[
+    Callable[..., Awaitable], Tuple[object, ...], Dict[str, object]
+]
 
 P = ParamSpec("P")
 R = TypeVar("R")
@@ -227,6 +233,10 @@ class LoggingTransaction:
             that have been added by `call_after` which should be run on
             successful completion of the transaction. None indicates that no
             callbacks should be allowed to be scheduled to run.
+        async_after_callbacks: A list that asynchronous callbacks will be appended
+            to by `async_call_after` which should run, before after_callbacks, on
+            successful completion of the transaction. None indicates that no
+            callbacks should be allowed to be scheduled to run.
         exception_callbacks: A list that callbacks will be appended
             to that have been added by `call_on_exception` which should be run
             if transaction ends with an error. None indicates that no callbacks
@@ -238,6 +248,7 @@ class LoggingTransaction:
         "name",
         "database_engine",
         "after_callbacks",
+        "async_after_callbacks",
         "exception_callbacks",
     ]
 
@@ -247,12 +258,14 @@ class LoggingTransaction:
         name: str,
         database_engine: BaseDatabaseEngine,
         after_callbacks: Optional[List[_CallbackListEntry]] = None,
+        async_after_callbacks: Optional[List[_AsyncCallbackListEntry]] = None,
         exception_callbacks: Optional[List[_CallbackListEntry]] = None,
     ):
         self.txn = txn
         self.name = name
         self.database_engine = database_engine
         self.after_callbacks = after_callbacks
+        self.async_after_callbacks = async_after_callbacks
         self.exception_callbacks = exception_callbacks
 
     def call_after(
@@ -277,6 +290,28 @@ class LoggingTransaction:
         # type-ignore: need mypy containing https://github.com/python/mypy/pull/12668
         self.after_callbacks.append((callback, args, kwargs))  # type: ignore[arg-type]
 
+    def async_call_after(
+        self, callback: Callable[P, Awaitable], *args: P.args, **kwargs: P.kwargs
+    ) -> None:
+        """Call the given asynchronous callback on the main twisted thread after
+        the transaction has finished (but before those added in `call_after`).
+
+        Mostly used to invalidate remote caches after transactions.
+
+        Note that transactions may be retried a few times if they encounter database
+        errors such as serialization failures. Callbacks given to `async_call_after`
+        will accumulate across transaction attempts and will _all_ be called once a
+        transaction attempt succeeds, regardless of whether previous transaction
+        attempts failed. Otherwise, if all transaction attempts fail, all
+        `call_on_exception` callbacks will be run instead.
+        """
+        # if self.async_after_callbacks is None, that means that whatever constructed the
+        # LoggingTransaction isn't expecting there to be any callbacks; assert that
+        # is not the case.
+        assert self.async_after_callbacks is not None
+        # type-ignore: need mypy containing https://github.com/python/mypy/pull/12668
+        self.async_after_callbacks.append((callback, args, kwargs))  # type: ignore[arg-type]
+
     def call_on_exception(
         self, callback: Callable[P, object], *args: P.args, **kwargs: P.kwargs
     ) -> None:
@@ -574,6 +609,7 @@ class DatabasePool:
         conn: LoggingDatabaseConnection,
         desc: str,
         after_callbacks: List[_CallbackListEntry],
+        async_after_callbacks: List[_AsyncCallbackListEntry],
         exception_callbacks: List[_CallbackListEntry],
         func: Callable[Concatenate[LoggingTransaction, P], R],
         *args: P.args,
@@ -597,6 +633,7 @@ class DatabasePool:
             conn
             desc
             after_callbacks
+            async_after_callbacks
             exception_callbacks
             func
             *args
@@ -659,6 +696,7 @@ class DatabasePool:
                 cursor = conn.cursor(
                     txn_name=name,
                     after_callbacks=after_callbacks,
+                    async_after_callbacks=async_after_callbacks,
                     exception_callbacks=exception_callbacks,
                 )
                 try:
@@ -798,6 +836,7 @@ class DatabasePool:
 
         async def _runInteraction() -> R:
             after_callbacks: List[_CallbackListEntry] = []
+            async_after_callbacks: List[_AsyncCallbackListEntry] = []
             exception_callbacks: List[_CallbackListEntry] = []
 
             if not current_context():
@@ -809,6 +848,7 @@ class DatabasePool:
                         self.new_transaction,
                         desc,
                         after_callbacks,
+                        async_after_callbacks,
                         exception_callbacks,
                         func,
                         *args,
@@ -817,13 +857,17 @@ class DatabasePool:
                         **kwargs,
                     )
 
+                # We order these assuming that async functions call out to external
+                # systems (e.g. to invalidate a cache) and the sync functions make these
+                # changes on any local in-memory caches/similar, and thus must be second.
+                for async_callback, async_args, async_kwargs in async_after_callbacks:
+                    await async_callback(*async_args, **async_kwargs)
                 for after_callback, after_args, after_kwargs in after_callbacks:
                     after_callback(*after_args, **after_kwargs)
-
                 return cast(R, result)
             except Exception:
-                for after_callback, after_args, after_kwargs in exception_callbacks:
-                    after_callback(*after_args, **after_kwargs)
+                for exception_callback, after_args, after_kwargs in exception_callbacks:
+                    exception_callback(*after_args, **after_kwargs)
                 raise
 
         # To handle cancellation, we ensure that `after_callback`s and
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index e284454b66..64b70a7b28 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -371,52 +371,30 @@ class ApplicationServiceTransactionWorkerStore(
             device_list_summary=DeviceListUpdates(),
         )
 
-    async def set_appservice_last_pos(self, pos: int) -> None:
-        def set_appservice_last_pos_txn(txn: LoggingTransaction) -> None:
-            txn.execute(
-                "UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
-            )
+    async def get_appservice_last_pos(self) -> int:
+        """
+        Get the last stream ordering position for the appservice process.
+        """
 
-        await self.db_pool.runInteraction(
-            "set_appservice_last_pos", set_appservice_last_pos_txn
+        return await self.db_pool.simple_select_one_onecol(
+            table="appservice_stream_position",
+            retcol="stream_ordering",
+            keyvalues={},
+            desc="get_appservice_last_pos",
         )
 
-    async def get_new_events_for_appservice(
-        self, current_id: int, limit: int
-    ) -> Tuple[int, List[EventBase]]:
-        """Get all new events for an appservice"""
-
-        def get_new_events_for_appservice_txn(
-            txn: LoggingTransaction,
-        ) -> Tuple[int, List[str]]:
-            sql = (
-                "SELECT e.stream_ordering, e.event_id"
-                " FROM events AS e"
-                " WHERE"
-                " (SELECT stream_ordering FROM appservice_stream_position)"
-                "     < e.stream_ordering"
-                " AND e.stream_ordering <= ?"
-                " ORDER BY e.stream_ordering ASC"
-                " LIMIT ?"
-            )
-
-            txn.execute(sql, (current_id, limit))
-            rows = txn.fetchall()
-
-            upper_bound = current_id
-            if len(rows) == limit:
-                upper_bound = rows[-1][0]
-
-            return upper_bound, [row[1] for row in rows]
+    async def set_appservice_last_pos(self, pos: int) -> None:
+        """
+        Set the last stream ordering position for the appservice process.
+        """
 
-        upper_bound, event_ids = await self.db_pool.runInteraction(
-            "get_new_events_for_appservice", get_new_events_for_appservice_txn
+        await self.db_pool.simple_update_one(
+            table="appservice_stream_position",
+            keyvalues={},
+            updatevalues={"stream_ordering": pos},
+            desc="set_appservice_last_pos",
         )
 
-        events = await self.get_events_as_list(event_ids, get_prev_content=True)
-
-        return upper_bound, events
-
     async def get_type_stream_id_for_appservice(
         self, service: ApplicationService, type: str
     ) -> int:
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 1653a6a9b6..2367ddeea3 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -193,7 +193,10 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
         relates_to: Optional[str],
         backfilled: bool,
     ) -> None:
-        self._invalidate_get_event_cache(event_id)
+        # This invalidates any local in-memory cached event objects, the original
+        # process triggering the invalidation is responsible for clearing any external
+        # cached objects.
+        self._invalidate_local_get_event_cache(event_id)
         self.have_seen_event.invalidate((room_id, event_id))
 
         self.get_latest_event_ids_in_room.invalidate((room_id,))
@@ -208,7 +211,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
             self._events_stream_cache.entity_has_changed(room_id, stream_ordering)
 
         if redacts:
-            self._invalidate_get_event_cache(redacts)
+            self._invalidate_local_get_event_cache(redacts)
             # Caches which might leak edits must be invalidated for the event being
             # redacted.
             self.get_relations_for_event.invalidate((redacts,))
diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py
index fd3fc298b3..58177ecec1 100644
--- a/synapse/storage/databases/main/censor_events.py
+++ b/synapse/storage/databases/main/censor_events.py
@@ -194,7 +194,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
             # changed its content in the database. We can't call
             # self._invalidate_cache_and_stream because self.get_event_cache isn't of the
             # right type.
-            txn.call_after(self._get_event_cache.invalidate, (event.event_id,))
+            self.invalidate_get_event_cache_after_txn(txn, event.event_id)
             # Send that invalidation to replication so that other workers also invalidate
             # the event cache.
             self._send_invalidation_to_replication(
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index adde5d0978..7a6ed332aa 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -669,7 +669,7 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
 
     @trace
     async def get_user_devices_from_cache(
-        self, query_list: List[Tuple[str, str]]
+        self, query_list: List[Tuple[str, Optional[str]]]
     ) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]:
         """Get the devices (and keys if any) for remote users from the cache.
 
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 9b293475c8..60f622ad71 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -22,11 +22,14 @@ from typing import (
     List,
     Optional,
     Tuple,
+    Union,
     cast,
+    overload,
 )
 
 import attr
 from canonicaljson import encode_canonical_json
+from typing_extensions import Literal
 
 from synapse.api.constants import DeviceKeyAlgorithms
 from synapse.appservice import (
@@ -113,7 +116,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
             user_devices = devices[user_id]
             results = []
             for device_id, device in user_devices.items():
-                result = {"device_id": device_id}
+                result: JsonDict = {"device_id": device_id}
 
                 keys = device.keys
                 if keys:
@@ -156,6 +159,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
             rv[user_id] = {}
             for device_id, device_info in device_keys.items():
                 r = device_info.keys
+                if r is None:
+                    continue
+
                 r["unsigned"] = {}
                 display_name = device_info.display_name
                 if display_name is not None:
@@ -164,13 +170,42 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
 
         return rv
 
+    @overload
+    async def get_e2e_device_keys_and_signatures(
+        self,
+        query_list: Collection[Tuple[str, Optional[str]]],
+        include_all_devices: Literal[False] = False,
+    ) -> Dict[str, Dict[str, DeviceKeyLookupResult]]:
+        ...
+
+    @overload
+    async def get_e2e_device_keys_and_signatures(
+        self,
+        query_list: Collection[Tuple[str, Optional[str]]],
+        include_all_devices: bool = False,
+        include_deleted_devices: Literal[False] = False,
+    ) -> Dict[str, Dict[str, DeviceKeyLookupResult]]:
+        ...
+
+    @overload
+    async def get_e2e_device_keys_and_signatures(
+        self,
+        query_list: Collection[Tuple[str, Optional[str]]],
+        include_all_devices: Literal[True],
+        include_deleted_devices: Literal[True],
+    ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
+        ...
+
     @trace
     async def get_e2e_device_keys_and_signatures(
         self,
-        query_list: List[Tuple[str, Optional[str]]],
+        query_list: Collection[Tuple[str, Optional[str]]],
         include_all_devices: bool = False,
         include_deleted_devices: bool = False,
-    ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
+    ) -> Union[
+        Dict[str, Dict[str, DeviceKeyLookupResult]],
+        Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]],
+    ]:
         """Fetch a list of device keys
 
         Any cross-signatures made on the keys by the owner of the device are also
@@ -1044,7 +1079,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
                 _claim_e2e_one_time_key = _claim_e2e_one_time_key_simple
                 db_autocommit = False
 
-            row = await self.db_pool.runInteraction(
+            claim_row = await self.db_pool.runInteraction(
                 "claim_e2e_one_time_keys",
                 _claim_e2e_one_time_key,
                 user_id,
@@ -1052,11 +1087,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
                 algorithm,
                 db_autocommit=db_autocommit,
             )
-            if row:
+            if claim_row:
                 device_results = results.setdefault(user_id, {}).setdefault(
                     device_id, {}
                 )
-                device_results[row[0]] = row[1]
+                device_results[claim_row[0]] = claim_row[1]
                 continue
 
             # No one-time key available, so see if there's a fallback
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index eb4efbb93c..156e1bd5ab 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1293,7 +1293,7 @@ class PersistEventsStore:
         depth_updates: Dict[str, int] = {}
         for event, context in events_and_contexts:
             # Remove the any existing cache entries for the event_ids
-            txn.call_after(self.store._invalidate_get_event_cache, event.event_id)
+            self.store.invalidate_get_event_cache_after_txn(txn, event.event_id)
             # Then update the `stream_ordering` position to mark the latest
             # event as the front of the room. This should not be done for
             # backfilled events because backfilled events have negative
@@ -1669,13 +1669,13 @@ class PersistEventsStore:
             if not row["rejects"] and not row["redacts"]:
                 to_prefill.append(EventCacheEntry(event=event, redacted_event=None))
 
-        def prefill() -> None:
+        async def prefill() -> None:
             for cache_entry in to_prefill:
-                self.store._get_event_cache.set(
+                await self.store._get_event_cache.set(
                     (cache_entry.event.event_id,), cache_entry
                 )
 
-        txn.call_after(prefill)
+        txn.async_call_after(prefill)
 
     def _store_redaction(self, txn: LoggingTransaction, event: EventBase) -> None:
         """Invalidate the caches for the redacted event.
@@ -1684,7 +1684,7 @@ class PersistEventsStore:
         _invalidate_caches_for_event.
         """
         assert event.redacts is not None
-        txn.call_after(self.store._invalidate_get_event_cache, event.redacts)
+        self.store.invalidate_get_event_cache_after_txn(txn, event.redacts)
         txn.call_after(self.store.get_relations_for_event.invalidate, (event.redacts,))
         txn.call_after(self.store.get_applicable_edit.invalidate, (event.redacts,))
 
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index eeca85fc94..6e8aeed7b4 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -67,6 +67,8 @@ class _BackgroundUpdates:
     EVENT_EDGES_DROP_INVALID_ROWS = "event_edges_drop_invalid_rows"
     EVENT_EDGES_REPLACE_INDEX = "event_edges_replace_index"
 
+    EVENTS_POPULATE_STATE_KEY_REJECTIONS = "events_populate_state_key_rejections"
+
 
 @attr.s(slots=True, frozen=True, auto_attribs=True)
 class _CalculateChainCover:
@@ -253,6 +255,11 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
             replaces_index="ev_edges_id",
         )
 
+        self.db_pool.updates.register_background_update_handler(
+            _BackgroundUpdates.EVENTS_POPULATE_STATE_KEY_REJECTIONS,
+            self._background_events_populate_state_key_rejections,
+        )
+
     async def _background_reindex_fields_sender(
         self, progress: JsonDict, batch_size: int
     ) -> int:
@@ -1399,3 +1406,83 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
             )
 
         return batch_size
+
+    async def _background_events_populate_state_key_rejections(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
+        """Back-populate `events.state_key` and `events.rejection_reason"""
+
+        min_stream_ordering_exclusive = progress["min_stream_ordering_exclusive"]
+        max_stream_ordering_inclusive = progress["max_stream_ordering_inclusive"]
+
+        def _populate_txn(txn: LoggingTransaction) -> bool:
+            """Returns True if we're done."""
+
+            # first we need to find an endpoint.
+            # we need to find the final row in the batch of batch_size, which means
+            # we need to skip over (batch_size-1) rows and get the next row.
+            txn.execute(
+                """
+                SELECT stream_ordering FROM events
+                WHERE stream_ordering > ? AND stream_ordering <= ?
+                ORDER BY stream_ordering
+                LIMIT 1 OFFSET ?
+                """,
+                (
+                    min_stream_ordering_exclusive,
+                    max_stream_ordering_inclusive,
+                    batch_size - 1,
+                ),
+            )
+
+            endpoint = None
+            row = txn.fetchone()
+            if row:
+                endpoint = row[0]
+
+            where_clause = "stream_ordering > ?"
+            args = [min_stream_ordering_exclusive]
+            if endpoint:
+                where_clause += " AND stream_ordering <= ?"
+                args.append(endpoint)
+
+            # now do the updates.
+            txn.execute(
+                f"""
+                UPDATE events
+                SET state_key = (SELECT state_key FROM state_events se WHERE se.event_id = events.event_id),
+                    rejection_reason = (SELECT reason FROM rejections rej WHERE rej.event_id = events.event_id)
+                WHERE ({where_clause})
+                """,
+                args,
+            )
+
+            logger.info(
+                "populated new `events` columns up to %s/%i: updated %i rows",
+                endpoint,
+                max_stream_ordering_inclusive,
+                txn.rowcount,
+            )
+
+            if endpoint is None:
+                # we're done
+                return True
+
+            progress["min_stream_ordering_exclusive"] = endpoint
+            self.db_pool.updates._background_update_progress_txn(
+                txn,
+                _BackgroundUpdates.EVENTS_POPULATE_STATE_KEY_REJECTIONS,
+                progress,
+            )
+            return False
+
+        done = await self.db_pool.runInteraction(
+            desc="events_populate_state_key_rejections", func=_populate_txn
+        )
+
+        if done:
+            await self.db_pool.updates._end_background_update(
+                _BackgroundUpdates.EVENTS_POPULATE_STATE_KEY_REJECTIONS
+            )
+
+        return batch_size
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index b99b107784..5914a35420 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -79,7 +79,7 @@ from synapse.types import JsonDict, get_domain_from_id
 from synapse.util import unwrapFirstError
 from synapse.util.async_helpers import ObservableDeferred, delay_cancellation
 from synapse.util.caches.descriptors import cached, cachedList
-from synapse.util.caches.lrucache import LruCache
+from synapse.util.caches.lrucache import AsyncLruCache
 from synapse.util.iterutils import batch_iter
 from synapse.util.metrics import Measure
 
@@ -238,7 +238,9 @@ class EventsWorkerStore(SQLBaseStore):
                 5 * 60 * 1000,
             )
 
-        self._get_event_cache: LruCache[Tuple[str], EventCacheEntry] = LruCache(
+        self._get_event_cache: AsyncLruCache[
+            Tuple[str], EventCacheEntry
+        ] = AsyncLruCache(
             cache_name="*getEvent*",
             max_size=hs.config.caches.event_cache_size,
         )
@@ -292,25 +294,6 @@ class EventsWorkerStore(SQLBaseStore):
 
         super().process_replication_rows(stream_name, instance_name, token, rows)
 
-    async def get_received_ts(self, event_id: str) -> Optional[int]:
-        """Get received_ts (when it was persisted) for the event.
-
-        Raises an exception for unknown events.
-
-        Args:
-            event_id: The event ID to query.
-
-        Returns:
-            Timestamp in milliseconds, or None for events that were persisted
-            before received_ts was implemented.
-        """
-        return await self.db_pool.simple_select_one_onecol(
-            table="events",
-            keyvalues={"event_id": event_id},
-            retcol="received_ts",
-            desc="get_received_ts",
-        )
-
     async def have_censored_event(self, event_id: str) -> bool:
         """Check if an event has been censored, i.e. if the content of the event has been erased
         from the database due to a redaction.
@@ -617,7 +600,7 @@ class EventsWorkerStore(SQLBaseStore):
         Returns:
             map from event id to result
         """
-        event_entry_map = self._get_events_from_cache(
+        event_entry_map = await self._get_events_from_cache(
             event_ids,
         )
 
@@ -729,12 +712,46 @@ class EventsWorkerStore(SQLBaseStore):
 
         return event_entry_map
 
-    def _invalidate_get_event_cache(self, event_id: str) -> None:
-        self._get_event_cache.invalidate((event_id,))
+    def invalidate_get_event_cache_after_txn(
+        self, txn: LoggingTransaction, event_id: str
+    ) -> None:
+        """
+        Prepares a database transaction to invalidate the get event cache for a given
+        event ID when executed successfully. This is achieved by attaching two callbacks
+        to the transaction, one to invalidate the async cache and one for the in memory
+        sync cache (importantly called in that order).
+
+        Arguments:
+            txn: the database transaction to attach the callbacks to
+            event_id: the event ID to be invalidated from caches
+        """
+
+        txn.async_call_after(self._invalidate_async_get_event_cache, event_id)
+        txn.call_after(self._invalidate_local_get_event_cache, event_id)
+
+    async def _invalidate_async_get_event_cache(self, event_id: str) -> None:
+        """
+        Invalidates an event in the asyncronous get event cache, which may be remote.
+
+        Arguments:
+            event_id: the event ID to invalidate
+        """
+
+        await self._get_event_cache.invalidate((event_id,))
+
+    def _invalidate_local_get_event_cache(self, event_id: str) -> None:
+        """
+        Invalidates an event in local in-memory get event caches.
+
+        Arguments:
+            event_id: the event ID to invalidate
+        """
+
+        self._get_event_cache.invalidate_local((event_id,))
         self._event_ref.pop(event_id, None)
         self._current_event_fetches.pop(event_id, None)
 
-    def _get_events_from_cache(
+    async def _get_events_from_cache(
         self, events: Iterable[str], update_metrics: bool = True
     ) -> Dict[str, EventCacheEntry]:
         """Fetch events from the caches.
@@ -749,7 +766,7 @@ class EventsWorkerStore(SQLBaseStore):
 
         for event_id in events:
             # First check if it's in the event cache
-            ret = self._get_event_cache.get(
+            ret = await self._get_event_cache.get(
                 (event_id,), None, update_metrics=update_metrics
             )
             if ret:
@@ -771,7 +788,7 @@ class EventsWorkerStore(SQLBaseStore):
 
                 # We add the entry back into the cache as we want to keep
                 # recently queried events in the cache.
-                self._get_event_cache.set((event_id,), cache_entry)
+                await self._get_event_cache.set((event_id,), cache_entry)
 
         return event_map
 
@@ -965,7 +982,13 @@ class EventsWorkerStore(SQLBaseStore):
                 }
 
                 row_dict = self.db_pool.new_transaction(
-                    conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch
+                    conn,
+                    "do_fetch",
+                    [],
+                    [],
+                    [],
+                    self._fetch_event_rows,
+                    events_to_fetch,
                 )
 
                 # We only want to resolve deferreds from the main thread
@@ -1148,7 +1171,7 @@ class EventsWorkerStore(SQLBaseStore):
                 event=original_ev, redacted_event=redacted_event
             )
 
-            self._get_event_cache.set((event_id,), cache_entry)
+            await self._get_event_cache.set((event_id,), cache_entry)
             result_map[event_id] = cache_entry
 
             if not redacted_event:
@@ -1382,7 +1405,9 @@ class EventsWorkerStore(SQLBaseStore):
         # if the event cache contains the event, obviously we've seen it.
 
         cache_results = {
-            (rid, eid) for (rid, eid) in keys if self._get_event_cache.contains((eid,))
+            (rid, eid)
+            for (rid, eid) in keys
+            if await self._get_event_cache.contains((eid,))
         }
         results = dict.fromkeys(cache_results, True)
         remaining = [k for k in keys if k not in cache_results]
@@ -1465,7 +1490,7 @@ class EventsWorkerStore(SQLBaseStore):
 
     async def get_all_new_forward_event_rows(
         self, instance_name: str, last_id: int, current_id: int, limit: int
-    ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
+    ) -> List[Tuple[int, str, str, str, str, str, str, str, bool, bool]]:
         """Returns new events, for the Events replication stream
 
         Args:
@@ -1481,10 +1506,11 @@ class EventsWorkerStore(SQLBaseStore):
 
         def get_all_new_forward_event_rows(
             txn: LoggingTransaction,
-        ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
+        ) -> List[Tuple[int, str, str, str, str, str, str, str, bool, bool]]:
             sql = (
                 "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
-                " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
+                " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL,"
+                " e.outlier"
                 " FROM events AS e"
                 " LEFT JOIN redactions USING (event_id)"
                 " LEFT JOIN state_events AS se USING (event_id)"
@@ -1498,7 +1524,8 @@ class EventsWorkerStore(SQLBaseStore):
             )
             txn.execute(sql, (last_id, current_id, instance_name, limit))
             return cast(
-                List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall()
+                List[Tuple[int, str, str, str, str, str, str, str, bool, bool]],
+                txn.fetchall(),
             )
 
         return await self.db_pool.runInteraction(
@@ -1507,7 +1534,7 @@ class EventsWorkerStore(SQLBaseStore):
 
     async def get_ex_outlier_stream_rows(
         self, instance_name: str, last_id: int, current_id: int
-    ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
+    ) -> List[Tuple[int, str, str, str, str, str, str, str, bool, bool]]:
         """Returns de-outliered events, for the Events replication stream
 
         Args:
@@ -1522,11 +1549,14 @@ class EventsWorkerStore(SQLBaseStore):
 
         def get_ex_outlier_stream_rows_txn(
             txn: LoggingTransaction,
-        ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
+        ) -> List[Tuple[int, str, str, str, str, str, str, str, bool, bool]]:
             sql = (
                 "SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
-                " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
+                " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL,"
+                " e.outlier"
                 " FROM events AS e"
+                # NB: the next line (inner join) is what makes this query different from
+                # get_all_new_forward_event_rows.
                 " INNER JOIN ex_outlier_stream AS out USING (event_id)"
                 " LEFT JOIN redactions USING (event_id)"
                 " LEFT JOIN state_events AS se USING (event_id)"
@@ -1541,7 +1571,8 @@ class EventsWorkerStore(SQLBaseStore):
 
             txn.execute(sql, (last_id, current_id, instance_name))
             return cast(
-                List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall()
+                List[Tuple[int, str, str, str, str, str, str, str, bool, bool]],
+                txn.fetchall(),
             )
 
         return await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index 9a63f953fb..efd136a864 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -66,6 +66,7 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore):
                 "initialise_mau_threepids",
                 [],
                 [],
+                [],
                 self._initialise_reserved_users,
                 hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value],
             )
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 87b0d09039..f6822707e4 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -19,6 +19,8 @@ from synapse.api.errors import SynapseError
 from synapse.storage.database import LoggingTransaction
 from synapse.storage.databases.main import CacheInvalidationWorkerStore
 from synapse.storage.databases.main.state import StateGroupWorkerStore
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.engines._base import IsolationLevel
 from synapse.types import RoomStreamToken
 
 logger = logging.getLogger(__name__)
@@ -302,7 +304,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
                 self._invalidate_cache_and_stream(
                     txn, self.have_seen_event, (room_id, event_id)
                 )
-                self._invalidate_get_event_cache(event_id)
+                self.invalidate_get_event_cache_after_txn(txn, event_id)
 
         logger.info("[purge] done")
 
@@ -317,11 +319,38 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
         Returns:
             The list of state groups to delete.
         """
-        return await self.db_pool.runInteraction(
-            "purge_room", self._purge_room_txn, room_id
+
+        # This first runs the purge transaction with READ_COMMITTED isolation level,
+        # meaning any new rows in the tables will not trigger a serialization error.
+        # We then run the same purge a second time without this isolation level to
+        # purge any of those rows which were added during the first.
+
+        state_groups_to_delete = await self.db_pool.runInteraction(
+            "purge_room",
+            self._purge_room_txn,
+            room_id=room_id,
+            isolation_level=IsolationLevel.READ_COMMITTED,
+        )
+
+        state_groups_to_delete.extend(
+            await self.db_pool.runInteraction(
+                "purge_room",
+                self._purge_room_txn,
+                room_id=room_id,
+            ),
         )
 
+        return state_groups_to_delete
+
     def _purge_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[int]:
+        # This collides with event persistence so we cannot write new events and metadata into
+        # a room while deleting it or this transaction will fail.
+        if isinstance(self.database_engine, PostgresEngine):
+            txn.execute(
+                "SELECT room_version FROM rooms WHERE room_id = ? FOR UPDATE",
+                (room_id,),
+            )
+
         # First, fetch all the state groups that should be deleted, before
         # we delete that information.
         txn.execute(
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 86649c1e6c..768f95d16c 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -228,6 +228,7 @@ class PushRulesWorkerStore(
             iterable=user_ids,
             retcols=("*",),
             desc="bulk_get_push_rules",
+            batch_size=1000,
         )
 
         rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 13d6a1d5c0..d6d485507b 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -175,7 +175,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
                   rooms.creator, state.encryption, state.is_federatable AS federatable,
                   rooms.is_public AS public, state.join_rules, state.guest_access,
                   state.history_visibility, curr.current_state_events AS state_events,
-                  state.avatar, state.topic
+                  state.avatar, state.topic, state.room_type
                 FROM rooms
                 LEFT JOIN room_stats_state state USING (room_id)
                 LEFT JOIN room_stats_current curr USING (room_id)
@@ -596,7 +596,8 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
             SELECT state.room_id, state.name, state.canonical_alias, curr.joined_members,
               curr.local_users_in_room, rooms.room_version, rooms.creator,
               state.encryption, state.is_federatable, rooms.is_public, state.join_rules,
-              state.guest_access, state.history_visibility, curr.current_state_events
+              state.guest_access, state.history_visibility, curr.current_state_events,
+              state.room_type
             FROM room_stats_state state
             INNER JOIN room_stats_current curr USING (room_id)
             INNER JOIN rooms USING (room_id)
@@ -646,6 +647,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
                         "guest_access": room[11],
                         "history_visibility": room[12],
                         "state_events": room[13],
+                        "room_type": room[14],
                     }
                 )
 
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 0b5e4e4254..df6b82660e 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -31,7 +31,6 @@ import attr
 
 from synapse.api.constants import EventTypes, Membership
 from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
 from synapse.metrics import LaterGauge
 from synapse.metrics.background_process_metrics import (
     run_as_background_process,
@@ -244,7 +243,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             txn: LoggingTransaction,
         ) -> Dict[str, ProfileInfo]:
             clause, ids = make_in_list_sql_clause(
-                self.database_engine, "m.user_id", user_ids
+                self.database_engine, "c.state_key", user_ids
             )
 
             sql = """
@@ -780,26 +779,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         return shared_room_ids or frozenset()
 
-    async def get_joined_users_from_context(
-        self, event: EventBase, context: EventContext
-    ) -> Dict[str, ProfileInfo]:
-        state_group: Union[object, int] = context.state_group
-        if not state_group:
-            # If state_group is None it means it has yet to be assigned a
-            # state group, i.e. we need to make sure that calls with a state_group
-            # of None don't hit previous cached calls with a None state_group.
-            # To do this we set the state_group to a new object as object() != object()
-            state_group = object()
-
-        current_state_ids = await context.get_current_state_ids()
-        assert current_state_ids is not None
-        assert state_group is not None
-        return await self._get_joined_users_from_context(
-            event.room_id, state_group, current_state_ids, event=event, context=context
-        )
-
     async def get_joined_users_from_state(
-        self, room_id: str, state_entry: "_StateCacheEntry"
+        self, room_id: str, state: StateMap[str], state_entry: "_StateCacheEntry"
     ) -> Dict[str, ProfileInfo]:
         state_group: Union[object, int] = state_entry.state_group
         if not state_group:
@@ -812,18 +793,17 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         assert state_group is not None
         with Measure(self._clock, "get_joined_users_from_state"):
             return await self._get_joined_users_from_context(
-                room_id, state_group, state_entry.state, context=state_entry
+                room_id, state_group, state, context=state_entry
             )
 
-    @cached(num_args=2, cache_context=True, iterable=True, max_entries=100000)
+    @cached(num_args=2, iterable=True, max_entries=100000)
     async def _get_joined_users_from_context(
         self,
         room_id: str,
         state_group: Union[object, int],
         current_state_ids: StateMap[str],
-        cache_context: _CacheContext,
         event: Optional[EventBase] = None,
-        context: Optional[Union[EventContext, "_StateCacheEntry"]] = None,
+        context: Optional["_StateCacheEntry"] = None,
     ) -> Dict[str, ProfileInfo]:
         # We don't use `state_group`, it's there so that we can cache based
         # on it. However, it's important that it's never None, since two current_states
@@ -863,7 +843,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         # We don't update the event cache hit ratio as it completely throws off
         # the hit ratio counts. After all, we don't populate the cache if we
         # miss it here
-        event_map = self._get_events_from_cache(member_event_ids, update_metrics=False)
+        event_map = await self._get_events_from_cache(
+            member_event_ids, update_metrics=False
+        )
 
         missing_member_event_ids = []
         for event_id in member_event_ids:
@@ -922,7 +904,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             iterable=event_ids,
             retcols=("user_id", "display_name", "avatar_url", "event_id"),
             keyvalues={"membership": Membership.JOIN},
-            batch_size=500,
+            batch_size=1000,
             desc="_get_joined_profiles_from_event_ids",
         )
 
@@ -1017,7 +999,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         )
 
     async def get_joined_hosts(
-        self, room_id: str, state_entry: "_StateCacheEntry"
+        self, room_id: str, state: StateMap[str], state_entry: "_StateCacheEntry"
     ) -> FrozenSet[str]:
         state_group: Union[object, int] = state_entry.state_group
         if not state_group:
@@ -1030,7 +1012,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         assert state_group is not None
         with Measure(self._clock, "get_joined_hosts"):
             return await self._get_joined_hosts(
-                room_id, state_group, state_entry=state_entry
+                room_id, state_group, state, state_entry=state_entry
             )
 
     @cached(num_args=2, max_entries=10000, iterable=True)
@@ -1038,6 +1020,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         self,
         room_id: str,
         state_group: Union[object, int],
+        state: StateMap[str],
         state_entry: "_StateCacheEntry",
     ) -> FrozenSet[str]:
         # We don't use `state_group`, it's there so that we can cache based on
@@ -1093,7 +1076,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
                 # The cache doesn't match the state group or prev state group,
                 # so we calculate the result from first principles.
                 joined_users = await self.get_joined_users_from_state(
-                    room_id, state_entry
+                    room_id, state, state_entry
                 )
 
                 cache.hosts_to_joined_users = {}
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 3a1df7776c..2590b52f73 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -1022,8 +1022,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         }
 
     async def get_all_new_events_stream(
-        self, from_id: int, current_id: int, limit: int
-    ) -> Tuple[int, List[EventBase]]:
+        self, from_id: int, current_id: int, limit: int, get_prev_content: bool = False
+    ) -> Tuple[int, List[EventBase], Dict[str, Optional[int]]]:
         """Get all new events
 
         Returns all events with from_id < stream_ordering <= current_id.
@@ -1032,19 +1032,21 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             from_id:  the stream_ordering of the last event we processed
             current_id:  the stream_ordering of the most recently processed event
             limit: the maximum number of events to return
+            get_prev_content: whether to fetch previous event content
 
         Returns:
-            A tuple of (next_id, events), where `next_id` is the next value to
-            pass as `from_id` (it will either be the stream_ordering of the
-            last returned event, or, if fewer than `limit` events were found,
-            the `current_id`).
+            A tuple of (next_id, events, event_to_received_ts), where `next_id`
+            is the next value to pass as `from_id` (it will either be the
+            stream_ordering of the last returned event, or, if fewer than `limit`
+            events were found, the `current_id`). The `event_to_received_ts` is
+            a dictionary mapping event ID to the event `received_ts`.
         """
 
         def get_all_new_events_stream_txn(
             txn: LoggingTransaction,
-        ) -> Tuple[int, List[str]]:
+        ) -> Tuple[int, Dict[str, Optional[int]]]:
             sql = (
-                "SELECT e.stream_ordering, e.event_id"
+                "SELECT e.stream_ordering, e.event_id, e.received_ts"
                 " FROM events AS e"
                 " WHERE"
                 " ? < e.stream_ordering AND e.stream_ordering <= ?"
@@ -1059,15 +1061,21 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             if len(rows) == limit:
                 upper_bound = rows[-1][0]
 
-            return upper_bound, [row[1] for row in rows]
+            event_to_received_ts: Dict[str, Optional[int]] = {
+                row[1]: row[2] for row in rows
+            }
+            return upper_bound, event_to_received_ts
 
-        upper_bound, event_ids = await self.db_pool.runInteraction(
+        upper_bound, event_to_received_ts = await self.db_pool.runInteraction(
             "get_all_new_events_stream", get_all_new_events_stream_txn
         )
 
-        events = await self.get_events_as_list(event_ids)
+        events = await self.get_events_as_list(
+            event_to_received_ts.keys(),
+            get_prev_content=get_prev_content,
+        )
 
-        return upper_bound, events
+        return upper_bound, events, event_to_received_ts
 
     async def get_federation_out_pos(self, typ: str) -> int:
         if self._need_to_reset_federation_stream_positions:
diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index fa9eadaca7..a7fcc564a9 100644
--- a/synapse/storage/databases/state/bg_updates.py
+++ b/synapse/storage/databases/state/bg_updates.py
@@ -24,6 +24,7 @@ from synapse.storage.database import (
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.state import StateFilter
 from synapse.types import MutableStateMap, StateMap
+from synapse.util.caches import intern_string
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -136,7 +137,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
                 txn.execute(sql % (where_clause,), args)
                 for row in txn:
                     typ, state_key, event_id = row
-                    key = (typ, state_key)
+                    key = (intern_string(typ), intern_string(state_key))
                     results[group][key] = event_id
         else:
             max_entries_returned = state_filter.max_entries_returned()
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 609a2b88bf..afbc85ad0c 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -400,14 +400,17 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         room_id: str,
         prev_group: Optional[int],
         delta_ids: Optional[StateMap[str]],
-        current_state_ids: StateMap[str],
+        current_state_ids: Optional[StateMap[str]],
     ) -> int:
         """Store a new set of state, returning a newly assigned state group.
 
+        At least one of `current_state_ids` and `prev_group` must be provided. Whenever
+        `prev_group` is not None, `delta_ids` must also not be None.
+
         Args:
             event_id: The event ID for which the state was calculated
             room_id
-            prev_group: A previous state group for the room, optional.
+            prev_group: A previous state group for the room.
             delta_ids: The delta between state at `prev_group` and
                 `current_state_ids`, if `prev_group` was given. Same format as
                 `current_state_ids`.
@@ -418,10 +421,41 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             The state group ID
         """
 
-        def _store_state_group_txn(txn: LoggingTransaction) -> int:
-            if current_state_ids is None:
-                # AFAIK, this can never happen
-                raise Exception("current_state_ids cannot be None")
+        if prev_group is None and current_state_ids is None:
+            raise Exception("current_state_ids and prev_group can't both be None")
+
+        if prev_group is not None and delta_ids is None:
+            raise Exception("delta_ids is None when prev_group is not None")
+
+        def insert_delta_group_txn(
+            txn: LoggingTransaction, prev_group: int, delta_ids: StateMap[str]
+        ) -> Optional[int]:
+            """Try and persist the new group as a delta.
+
+            Requires that we have the state as a delta from a previous state group.
+
+            Returns:
+                The state group if successfully created, or None if the state
+                needs to be persisted as a full state.
+            """
+            is_in_db = self.db_pool.simple_select_one_onecol_txn(
+                txn,
+                table="state_groups",
+                keyvalues={"id": prev_group},
+                retcol="id",
+                allow_none=True,
+            )
+            if not is_in_db:
+                raise Exception(
+                    "Trying to persist state with unpersisted prev_group: %r"
+                    % (prev_group,)
+                )
+
+            # if the chain of state group deltas is going too long, we fall back to
+            # persisting a complete state group.
+            potential_hops = self._count_state_group_hops_txn(txn, prev_group)
+            if potential_hops >= MAX_STATE_DELTA_HOPS:
+                return None
 
             state_group = self._state_group_seq_gen.get_next_id_txn(txn)
 
@@ -431,51 +465,45 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
                 values={"id": state_group, "room_id": room_id, "event_id": event_id},
             )
 
-            # We persist as a delta if we can, while also ensuring the chain
-            # of deltas isn't tooo long, as otherwise read performance degrades.
-            if prev_group:
-                is_in_db = self.db_pool.simple_select_one_onecol_txn(
-                    txn,
-                    table="state_groups",
-                    keyvalues={"id": prev_group},
-                    retcol="id",
-                    allow_none=True,
-                )
-                if not is_in_db:
-                    raise Exception(
-                        "Trying to persist state with unpersisted prev_group: %r"
-                        % (prev_group,)
-                    )
-
-                potential_hops = self._count_state_group_hops_txn(txn, prev_group)
-            if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
-                assert delta_ids is not None
-
-                self.db_pool.simple_insert_txn(
-                    txn,
-                    table="state_group_edges",
-                    values={"state_group": state_group, "prev_state_group": prev_group},
-                )
+            self.db_pool.simple_insert_txn(
+                txn,
+                table="state_group_edges",
+                values={"state_group": state_group, "prev_state_group": prev_group},
+            )
 
-                self.db_pool.simple_insert_many_txn(
-                    txn,
-                    table="state_groups_state",
-                    keys=("state_group", "room_id", "type", "state_key", "event_id"),
-                    values=[
-                        (state_group, room_id, key[0], key[1], state_id)
-                        for key, state_id in delta_ids.items()
-                    ],
-                )
-            else:
-                self.db_pool.simple_insert_many_txn(
-                    txn,
-                    table="state_groups_state",
-                    keys=("state_group", "room_id", "type", "state_key", "event_id"),
-                    values=[
-                        (state_group, room_id, key[0], key[1], state_id)
-                        for key, state_id in current_state_ids.items()
-                    ],
-                )
+            self.db_pool.simple_insert_many_txn(
+                txn,
+                table="state_groups_state",
+                keys=("state_group", "room_id", "type", "state_key", "event_id"),
+                values=[
+                    (state_group, room_id, key[0], key[1], state_id)
+                    for key, state_id in delta_ids.items()
+                ],
+            )
+
+            return state_group
+
+        def insert_full_state_txn(
+            txn: LoggingTransaction, current_state_ids: StateMap[str]
+        ) -> int:
+            """Persist the full state, returning the new state group."""
+            state_group = self._state_group_seq_gen.get_next_id_txn(txn)
+
+            self.db_pool.simple_insert_txn(
+                txn,
+                table="state_groups",
+                values={"id": state_group, "room_id": room_id, "event_id": event_id},
+            )
+
+            self.db_pool.simple_insert_many_txn(
+                txn,
+                table="state_groups_state",
+                keys=("state_group", "room_id", "type", "state_key", "event_id"),
+                values=[
+                    (state_group, room_id, key[0], key[1], state_id)
+                    for key, state_id in current_state_ids.items()
+                ],
+            )
 
             # Prefill the state group caches with this group.
             # It's fine to use the sequence like this as the state group map
@@ -491,7 +519,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
                 self._state_group_members_cache.update,
                 self._state_group_members_cache.sequence,
                 key=state_group,
-                value=dict(current_member_state_ids),
+                value=current_member_state_ids,
             )
 
             current_non_member_state_ids = {
@@ -503,13 +531,35 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
                 self._state_group_cache.update,
                 self._state_group_cache.sequence,
                 key=state_group,
-                value=dict(current_non_member_state_ids),
+                value=current_non_member_state_ids,
             )
 
             return state_group
 
+        if prev_group is not None:
+            state_group = await self.db_pool.runInteraction(
+                "store_state_group.insert_delta_group",
+                insert_delta_group_txn,
+                prev_group,
+                delta_ids,
+            )
+            if state_group is not None:
+                return state_group
+
+        # We're going to persist the state as a complete group rather than
+        # a delta, so first we need to ensure we have loaded the state map
+        # from the database.
+        if current_state_ids is None:
+            assert prev_group is not None
+            assert delta_ids is not None
+            groups = await self._get_state_for_groups([prev_group])
+            current_state_ids = dict(groups[prev_group])
+            current_state_ids.update(delta_ids)
+
         return await self.db_pool.runInteraction(
-            "store_state_group", _store_state_group_txn
+            "store_state_group.insert_full_state",
+            insert_full_state_txn,
+            current_state_ids,
         )
 
     async def purge_unreferenced_state_groups(
diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index dc237e3032..a9a88c8bfd 100644
--- a/synapse/storage/schema/__init__.py
+++ b/synapse/storage/schema/__init__.py
@@ -74,13 +74,14 @@ Changes in SCHEMA_VERSION = 71:
 
 Changes in SCHEMA_VERSION = 72:
     - event_edges.(room_id, is_state) are no longer written to.
+    - Tables related to groups are dropped.
 """
 
 
 SCHEMA_COMPAT_VERSION = (
-    # We no longer maintain `event_edges.room_id`, so synapses with SCHEMA_VERSION < 71
-    # will break.
-    71
+    # The groups tables are no longer accessible, so synapses with SCHEMA_VERSION < 72
+    # could break.
+    72
 )
 """Limit on how far the synapse codebase can be rolled back without breaking db compat
 
diff --git a/synapse/storage/schema/main/delta/72/03bg_populate_events_columns.py b/synapse/storage/schema/main/delta/72/03bg_populate_events_columns.py
new file mode 100644
index 0000000000..55a5d092cc
--- /dev/null
+++ b/synapse/storage/schema/main/delta/72/03bg_populate_events_columns.py
@@ -0,0 +1,47 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+
+from synapse.storage.types import Cursor
+
+
+def run_create(cur: Cursor, database_engine, *args, **kwargs):
+    """Add a bg update to populate the `state_key` and `rejection_reason` columns of `events`"""
+
+    # we know that any new events will have the columns populated (and that has been
+    # the case since schema_version 68, so there is no chance of rolling back now).
+    #
+    # So, we only need to make sure that existing rows are updated. We read the
+    # current min and max stream orderings, since that is guaranteed to include all
+    # the events that were stored before the new columns were added.
+    cur.execute("SELECT MIN(stream_ordering), MAX(stream_ordering) FROM events")
+    (min_stream_ordering, max_stream_ordering) = cur.fetchone()
+
+    if min_stream_ordering is None:
+        # no rows, nothing to do.
+        return
+
+    cur.execute(
+        "INSERT into background_updates (ordering, update_name, progress_json)"
+        " VALUES (7203, 'events_populate_state_key_rejections', ?)",
+        (
+            json.dumps(
+                {
+                    "min_stream_ordering_exclusive": min_stream_ordering - 1,
+                    "max_stream_ordering_inclusive": max_stream_ordering,
+                }
+            ),
+        ),
+    )
diff --git a/synapse/storage/schema/main/delta/72/03drop_event_reference_hashes.sql b/synapse/storage/schema/main/delta/72/03drop_event_reference_hashes.sql
new file mode 100644
index 0000000000..0da668aa3a
--- /dev/null
+++ b/synapse/storage/schema/main/delta/72/03drop_event_reference_hashes.sql
@@ -0,0 +1,17 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- event_reference_hashes is unused, so we can drop it
+DROP TABLE event_reference_hashes;
diff --git a/synapse/storage/schema/main/delta/72/03remove_groups.sql b/synapse/storage/schema/main/delta/72/03remove_groups.sql
new file mode 100644
index 0000000000..b7c5894de8
--- /dev/null
+++ b/synapse/storage/schema/main/delta/72/03remove_groups.sql
@@ -0,0 +1,31 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Remove the tables which powered the unspecced groups/communities feature.
+DROP TABLE IF EXISTS group_attestations_remote;
+DROP TABLE IF EXISTS group_attestations_renewals;
+DROP TABLE IF EXISTS group_invites;
+DROP TABLE IF EXISTS group_roles;
+DROP TABLE IF EXISTS group_room_categories;
+DROP TABLE IF EXISTS group_rooms;
+DROP TABLE IF EXISTS group_summary_roles;
+DROP TABLE IF EXISTS group_summary_room_categories;
+DROP TABLE IF EXISTS group_summary_rooms;
+DROP TABLE IF EXISTS group_summary_users;
+DROP TABLE IF EXISTS group_users;
+DROP TABLE IF EXISTS groups;
+DROP TABLE IF EXISTS local_group_membership;
+DROP TABLE IF EXISTS local_group_updates;
+DROP TABLE IF EXISTS remote_profile_cache;
diff --git a/synapse/storage/util/partial_state_events_tracker.py b/synapse/storage/util/partial_state_events_tracker.py
index 211437cfaa..466e5137f2 100644
--- a/synapse/storage/util/partial_state_events_tracker.py
+++ b/synapse/storage/util/partial_state_events_tracker.py
@@ -166,6 +166,7 @@ class PartialCurrentStateTracker:
             logger.info(
                 "Awaiting un-partial-stating of room %s",
                 room_id,
+                stack_info=True,
             )
 
             await make_deferred_yieldable(d)
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 8ed5325c5d..31f41fec82 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -730,3 +730,41 @@ class LruCache(Generic[KT, VT]):
         # This happens e.g. in the sync code where we have an expiring cache of
         # lru caches.
         self.clear()
+
+
+class AsyncLruCache(Generic[KT, VT]):
+    """
+    An asynchronous wrapper around a subset of the LruCache API.
+
+    On its own this doesn't change the behaviour but allows subclasses that
+    utilize external cache systems that require await behaviour to be created.
+    """
+
+    def __init__(self, *args, **kwargs):  # type: ignore
+        self._lru_cache: LruCache[KT, VT] = LruCache(*args, **kwargs)
+
+    async def get(
+        self, key: KT, default: Optional[T] = None, update_metrics: bool = True
+    ) -> Optional[VT]:
+        return self._lru_cache.get(key, update_metrics=update_metrics)
+
+    async def set(self, key: KT, value: VT) -> None:
+        self._lru_cache.set(key, value)
+
+    async def invalidate(self, key: KT) -> None:
+        # This method should invalidate any external cache and then invalidate the LruCache.
+        return self._lru_cache.invalidate(key)
+
+    def invalidate_local(self, key: KT) -> None:
+        """Remove an entry from the local cache
+
+        This variant of `invalidate` is useful if we know that the external
+        cache has already been invalidated.
+        """
+        return self._lru_cache.invalidate(key)
+
+    async def contains(self, key: KT) -> bool:
+        return self._lru_cache.contains(key)
+
+    async def clear(self) -> None:
+        self._lru_cache.clear()